Add _repr_latex_ methods for IPython prettyprint
author Vivien Maisonneuve <v.maisonneuve@gmail.com>
Sun, 13 Jul 2014 07:04:48 +0000 (09:04 +0200)
committer Vivien Maisonneuve <v.maisonneuve@gmail.com>
Sun, 13 Jul 2014 07:12:33 +0000 (09:12 +0200)
pypol/domains.py
pypol/linexprs.py
pypol/polyhedra.py
pypol/tests/test_linexprs.py
pypol/tests/test_polyhedra.py

index d80fd91..10d12c5 100644 (file)
@@ -438,6 +438,12 @@ class Domain(GeometricObject):
         strings = [repr(polyhedron) for polyhedron in self.polyhedra]
         return 'Or({})'.format(', '.join(strings))
 
         strings = [repr(polyhedron) for polyhedron in self.polyhedra]
         return 'Or({})'.format(', '.join(strings))
 
+    def _repr_latex_(self):
+        strings = []
+        for polyhedron in self.polyhedra:
+            strings.append('({})'.format(polyhedron._repr_latex_().strip('$')))
+        return '${}$'.format(' \\vee '.join(strings))
+
     @classmethod
     def fromsympy(cls, expr):
         import sympy
     @classmethod
     def fromsympy(cls, expr):
         import sympy
index 229e8d9..c8745b5 100644 (file)
@@ -240,18 +240,17 @@ class Expression:
         string = ''
         for i, (symbol, coefficient) in enumerate(self.coefficients()):
             if coefficient == 1:
         string = ''
         for i, (symbol, coefficient) in enumerate(self.coefficients()):
             if coefficient == 1:
-                string += '' if i == 0 else ' + '
-                string += '{!r}'.format(symbol)
+                if i != 0:
+                    string += ' + '
             elif coefficient == -1:
                 string += '-' if i == 0 else ' - '
             elif coefficient == -1:
                 string += '-' if i == 0 else ' - '
-                string += '{!r}'.format(symbol)
+            elif i == 0:
+                string += '{}*'.format(coefficient)
+            elif coefficient > 0:
+                string += ' + {}*'.format(coefficient)
             else:
             else:
-                if i == 0:
-                    string += '{}*{!r}'.format(coefficient, symbol)
-                elif coefficient > 0:
-                    string += ' + {}*{!r}'.format(coefficient, symbol)
-                else:
-                    string += ' - {}*{!r}'.format(-coefficient, symbol)
+                string += ' - {}*'.format(-coefficient)
+            string += '{}'.format(symbol)
         constant = self.constant
         if len(string) == 0:
             string += '{}'.format(constant)
         constant = self.constant
         if len(string) == 0:
             string += '{}'.format(constant)
@@ -261,6 +260,30 @@ class Expression:
             string += ' - {}'.format(-constant)
         return string
 
             string += ' - {}'.format(-constant)
         return string
 
+    def _repr_latex_(self):
+        string = ''
+        for i, (symbol, coefficient) in enumerate(self.coefficients()):
+            if coefficient == 1:
+                if i != 0:
+                    string += ' + '
+            elif coefficient == -1:
+                string += '-' if i == 0 else ' - '
+            elif i == 0:
+                string += '{}'.format(coefficient._repr_latex_().strip('$'))
+            elif coefficient > 0:
+                string += ' + {}'.format(coefficient._repr_latex_().strip('$'))
+            elif coefficient < 0:
+                string += ' - {}'.format((-coefficient)._repr_latex_().strip('$'))
+            string += '{}'.format(symbol._repr_latex_().strip('$'))
+        constant = self.constant
+        if len(string) == 0:
+            string += '{}'.format(constant._repr_latex_().strip('$'))
+        elif constant > 0:
+            string += ' + {}'.format(constant._repr_latex_().strip('$'))
+        elif constant < 0:
+            string += ' - {}'.format((-constant)._repr_latex_().strip('$'))
+        return '${}$'.format(string)
+
     def _parenstr(self, always=False):
         string = str(self)
         if not always and (self.isconstant() or self.issymbol()):
     def _parenstr(self, always=False):
         string = str(self)
         if not always and (self.isconstant() or self.issymbol()):
@@ -340,6 +363,9 @@ class Symbol(Expression):
     def __repr__(self):
         return self.name
 
     def __repr__(self):
         return self.name
 
+    def _repr_latex_(self):
+        return '${}$'.format(self.name)
+
     @classmethod
     def fromsympy(cls, expr):
         import sympy
     @classmethod
     def fromsympy(cls, expr):
         import sympy
@@ -378,6 +404,9 @@ class Dummy(Symbol):
     def __repr__(self):
         return '_{}'.format(self.name)
 
     def __repr__(self):
         return '_{}'.format(self.name)
 
+    def _repr_latex_(self):
+        return '${}_{{{}}}$'.format(self.name, self._index)
+
 
 def symbols(names):
     if isinstance(names, str):
 
 def symbols(names):
     if isinstance(names, str):
@@ -430,7 +459,23 @@ class Rational(Expression, Fraction):
     def fromstring(cls, string):
         if not isinstance(string, str):
             raise TypeError('string must be a string instance')
     def fromstring(cls, string):
         if not isinstance(string, str):
             raise TypeError('string must be a string instance')
-        return Rational(Fraction(string))
+        return Rational(string)
+
+    def __repr__(self):
+        if self.denominator == 1:
+            return '{!r}'.format(self.numerator)
+        else:
+            return '{!r}/{!r}'.format(self.numerator, self.denominator)
+
+    def _repr_latex_(self):
+        if self.denominator == 1:
+            return '${}$'.format(self.numerator)
+        elif self.numerator < 0:
+            return '$-\\frac{{{}}}{{{}}}$'.format(-self.numerator,
+                self.denominator)
+        else:
+            return '$\\frac{{{}}}{{{}}}$'.format(self.numerator,
+                self.denominator)
 
     @classmethod
     def fromsympy(cls, expr):
 
     @classmethod
     def fromsympy(cls, expr):
index e745d7d..6b5f9ab 100644 (file)
@@ -182,14 +182,27 @@ class Polyhedron(Domain):
         else:
             strings = []
             for equality in self.equalities:
         else:
             strings = []
             for equality in self.equalities:
-                strings.append('0 == {}'.format(equality))
+                strings.append('Eq({}, 0)'.format(equality))
             for inequality in self.inequalities:
             for inequality in self.inequalities:
-                strings.append('0 <= {}'.format(inequality))
+                strings.append('Ge({}, 0)'.format(inequality))
             if len(strings) == 1:
                 return strings[0]
             else:
                 return 'And({})'.format(', '.join(strings))
 
             if len(strings) == 1:
                 return strings[0]
             else:
                 return 'And({})'.format(', '.join(strings))
 
+    def _repr_latex_(self):
+        if self.isempty():
+            return '$\\emptyset$'
+        elif self.isuniverse():
+            return '$\\Omega$'
+        else:
+            strings = []
+            for equality in self.equalities:
+                strings.append('{} = 0'.format(equality._repr_latex_().strip('$')))
+            for inequality in self.inequalities:
+                strings.append('{} \\ge 0'.format(inequality._repr_latex_().strip('$')))
+            return '${}$'.format(' \\wedge '.join(strings))
+
     @classmethod
     def fromsympy(cls, expr):
         domain = Domain.fromsympy(expr)
     @classmethod
     def fromsympy(cls, expr):
         domain = Domain.fromsympy(expr)
index 6ec8993..01f844f 100644 (file)
@@ -275,7 +275,7 @@ class TestRational(unittest.TestCase):
     def setUp(self):
         self.zero = Rational(0)
         self.one = Rational(1)
     def setUp(self):
         self.zero = Rational(0)
         self.one = Rational(1)
-        self.pi = Rational(Fraction(22, 7))
+        self.pi = Rational(22, 7)
 
     def test_new(self):
         self.assertEqual(Rational(), self.zero)
 
     def test_new(self):
         self.assertEqual(Rational(), self.zero)
index 689602b..e813609 100644 (file)
@@ -20,11 +20,11 @@ class TestPolyhedron(unittest.TestCase):
 
     def test_str(self):
         self.assertEqual(str(self.square),
 
     def test_str(self):
         self.assertEqual(str(self.square),
-            'And(0 <= x, 0 <= -x + 1, 0 <= y, 0 <= -y + 1)')
+            'And(Ge(x, 0), Ge(-x + 1, 0), Ge(y, 0), Ge(-y + 1, 0))')
 
     def test_repr(self):
         self.assertEqual(repr(self.square),
 
     def test_repr(self):
         self.assertEqual(repr(self.square),
-            "And(0 <= x, 0 <= -x + 1, 0 <= y, 0 <= -y + 1)")
+            "And(Ge(x, 0), Ge(-x + 1, 0), Ge(y, 0), Ge(-y + 1, 0))")
 
     def test_fromstring(self):
         self.assertEqual(Polyhedron.fromstring('{x >= 0, -x + 1 >= 0, '
 
     def test_fromstring(self):
         self.assertEqual(Polyhedron.fromstring('{x >= 0, -x + 1 >= 0, '