Simplify Expression.__mul__(), Expression.__truediv__()
[linpy.git] / pypol / linexprs.py
index 229e8d9..b74628b 100644 (file)
@@ -134,22 +134,27 @@ class Expression:
     def __rsub__(self, other):
         return -(self - other)
 
     def __rsub__(self, other):
         return -(self - other)
 
-    @_polymorphic
     def __mul__(self, other):
     def __mul__(self, other):
-        if isinstance(other, Rational):
-            return other.__rmul__(self)
+        if isinstance(other, numbers.Rational):
+            coefficients = dict(self._coefficients)
+            for symbol in coefficients:
+                coefficients[symbol] *= other
+            constant = self._constant * other
+            return Expression(coefficients, constant)
         return NotImplemented
 
     __rmul__ = __mul__
 
         return NotImplemented
 
     __rmul__ = __mul__
 
-    @_polymorphic
     def __truediv__(self, other):
     def __truediv__(self, other):
-        if isinstance(other, Rational):
-            return other.__rtruediv__(self)
+        if isinstance(other, numbers.Rational):
+            coefficients = dict(self._coefficients)
+            for symbol in coefficients:
+                coefficients[symbol] /= other
+            constant = self._constant / other
+            # import pdb; pdb.set_trace()
+            return Expression(coefficients, constant)
         return NotImplemented
 
         return NotImplemented
 
-    __rtruediv__ = __truediv__
-
     @_polymorphic
     def __eq__(self, other):
         # "normal" equality
     @_polymorphic
     def __eq__(self, other):
         # "normal" equality
@@ -240,18 +245,17 @@ class Expression:
         string = ''
         for i, (symbol, coefficient) in enumerate(self.coefficients()):
             if coefficient == 1:
         string = ''
         for i, (symbol, coefficient) in enumerate(self.coefficients()):
             if coefficient == 1:
-                string += '' if i == 0 else ' + '
-                string += '{!r}'.format(symbol)
+                if i != 0:
+                    string += ' + '
             elif coefficient == -1:
                 string += '-' if i == 0 else ' - '
             elif coefficient == -1:
                 string += '-' if i == 0 else ' - '
-                string += '{!r}'.format(symbol)
+            elif i == 0:
+                string += '{}*'.format(coefficient)
+            elif coefficient > 0:
+                string += ' + {}*'.format(coefficient)
             else:
             else:
-                if i == 0:
-                    string += '{}*{!r}'.format(coefficient, symbol)
-                elif coefficient > 0:
-                    string += ' + {}*{!r}'.format(coefficient, symbol)
-                else:
-                    string += ' - {}*{!r}'.format(-coefficient, symbol)
+                string += ' - {}*'.format(-coefficient)
+            string += '{}'.format(symbol)
         constant = self.constant
         if len(string) == 0:
             string += '{}'.format(constant)
         constant = self.constant
         if len(string) == 0:
             string += '{}'.format(constant)
@@ -261,6 +265,30 @@ class Expression:
             string += ' - {}'.format(-constant)
         return string
 
             string += ' - {}'.format(-constant)
         return string
 
+    def _repr_latex_(self):
+        string = ''
+        for i, (symbol, coefficient) in enumerate(self.coefficients()):
+            if coefficient == 1:
+                if i != 0:
+                    string += ' + '
+            elif coefficient == -1:
+                string += '-' if i == 0 else ' - '
+            elif i == 0:
+                string += '{}'.format(coefficient._repr_latex_().strip('$'))
+            elif coefficient > 0:
+                string += ' + {}'.format(coefficient._repr_latex_().strip('$'))
+            elif coefficient < 0:
+                string += ' - {}'.format((-coefficient)._repr_latex_().strip('$'))
+            string += '{}'.format(symbol._repr_latex_().strip('$'))
+        constant = self.constant
+        if len(string) == 0:
+            string += '{}'.format(constant._repr_latex_().strip('$'))
+        elif constant > 0:
+            string += ' + {}'.format(constant._repr_latex_().strip('$'))
+        elif constant < 0:
+            string += ' - {}'.format((-constant)._repr_latex_().strip('$'))
+        return '${}$'.format(string)
+
     def _parenstr(self, always=False):
         string = str(self)
         if not always and (self.isconstant() or self.issymbol()):
     def _parenstr(self, always=False):
         string = str(self)
         if not always and (self.isconstant() or self.issymbol()):
@@ -301,8 +329,8 @@ class Symbol(Expression):
             raise TypeError('name must be a string')
         self = object().__new__(cls)
         self._name = name.strip()
             raise TypeError('name must be a string')
         self = object().__new__(cls)
         self._name = name.strip()
-        self._coefficients = {self: 1}
-        self._constant = 0
+        self._coefficients = {self: Fraction(1)}
+        self._constant = Fraction(0)
         self._symbols = (self,)
         self._dimension = 1
         return self
         self._symbols = (self,)
         self._dimension = 1
         return self
@@ -340,6 +368,9 @@ class Symbol(Expression):
     def __repr__(self):
         return self.name
 
     def __repr__(self):
         return self.name
 
+    def _repr_latex_(self):
+        return '${}$'.format(self.name)
+
     @classmethod
     def fromsympy(cls, expr):
         import sympy
     @classmethod
     def fromsympy(cls, expr):
         import sympy
@@ -359,8 +390,8 @@ class Dummy(Symbol):
         self = object().__new__(cls)
         self._index = Dummy._count
         self._name = name.strip()
         self = object().__new__(cls)
         self._index = Dummy._count
         self._name = name.strip()
-        self._coefficients = {self: 1}
-        self._constant = 0
+        self._coefficients = {self: Fraction(1)}
+        self._constant = Fraction(0)
         self._symbols = (self,)
         self._dimension = 1
         Dummy._count += 1
         self._symbols = (self,)
         self._dimension = 1
         Dummy._count += 1
@@ -378,6 +409,9 @@ class Dummy(Symbol):
     def __repr__(self):
         return '_{}'.format(self.name)
 
     def __repr__(self):
         return '_{}'.format(self.name)
 
+    def _repr_latex_(self):
+        return '${}_{{{}}}$'.format(self.name, self._index)
+
 
 def symbols(names):
     if isinstance(names, str):
 
 def symbols(names):
     if isinstance(names, str):
@@ -408,29 +442,27 @@ class Rational(Expression, Fraction):
     def __bool__(self):
         return Fraction.__bool__(self)
 
     def __bool__(self):
         return Fraction.__bool__(self)
 
-    @_polymorphic
-    def __mul__(self, other):
-        coefficients = dict(other._coefficients)
-        for symbol in coefficients:
-            coefficients[symbol] *= self._constant
-        constant = other._constant * self._constant
-        return Expression(coefficients, constant)
-
-    __rmul__ = __mul__
-
-    @_polymorphic
-    def __rtruediv__(self, other):
-        coefficients = dict(other._coefficients)
-        for symbol in coefficients:
-            coefficients[symbol] /= self._constant
-        constant = other._constant / self._constant
-        return Expression(coefficients, constant)
-
     @classmethod
     def fromstring(cls, string):
         if not isinstance(string, str):
             raise TypeError('string must be a string instance')
     @classmethod
     def fromstring(cls, string):
         if not isinstance(string, str):
             raise TypeError('string must be a string instance')
-        return Rational(Fraction(string))
+        return Rational(string)
+
+    def __repr__(self):
+        if self.denominator == 1:
+            return '{!r}'.format(self.numerator)
+        else:
+            return '{!r}/{!r}'.format(self.numerator, self.denominator)
+
+    def _repr_latex_(self):
+        if self.denominator == 1:
+            return '${}$'.format(self.numerator)
+        elif self.numerator < 0:
+            return '$-\\frac{{{}}}{{{}}}$'.format(-self.numerator,
+                self.denominator)
+        else:
+            return '$\\frac{{{}}}{{{}}}$'.format(self.numerator,
+                self.denominator)
 
     @classmethod
     def fromsympy(cls, expr):
 
     @classmethod
     def fromsympy(cls, expr):