X-Git-Url: https://scm.cri.mines-paristech.fr/git/linpy.git/blobdiff_plain/b046bdcc9044a88743a98bb06951f917edafae59..d4b772a5d2f29c4f54564ea09f5b65289cadcaa1:/pypol/linexprs.py?ds=sidebyside diff --git a/pypol/linexprs.py b/pypol/linexprs.py index 229e8d9..b74628b 100644 --- a/pypol/linexprs.py +++ b/pypol/linexprs.py @@ -134,22 +134,27 @@ class Expression: def __rsub__(self, other): return -(self - other) - @_polymorphic def __mul__(self, other): - if isinstance(other, Rational): - return other.__rmul__(self) + if isinstance(other, numbers.Rational): + coefficients = dict(self._coefficients) + for symbol in coefficients: + coefficients[symbol] *= other + constant = self._constant * other + return Expression(coefficients, constant) return NotImplemented __rmul__ = __mul__ - @_polymorphic def __truediv__(self, other): - if isinstance(other, Rational): - return other.__rtruediv__(self) + if isinstance(other, numbers.Rational): + coefficients = dict(self._coefficients) + for symbol in coefficients: + coefficients[symbol] /= other + constant = self._constant / other + # import pdb; pdb.set_trace() + return Expression(coefficients, constant) return NotImplemented - __rtruediv__ = __truediv__ - @_polymorphic def __eq__(self, other): # "normal" equality @@ -240,18 +245,17 @@ class Expression: string = '' for i, (symbol, coefficient) in enumerate(self.coefficients()): if coefficient == 1: - string += '' if i == 0 else ' + ' - string += '{!r}'.format(symbol) + if i != 0: + string += ' + ' elif coefficient == -1: string += '-' if i == 0 else ' - ' - string += '{!r}'.format(symbol) + elif i == 0: + string += '{}*'.format(coefficient) + elif coefficient > 0: + string += ' + {}*'.format(coefficient) else: - if i == 0: - string += '{}*{!r}'.format(coefficient, symbol) - elif coefficient > 0: - string += ' + {}*{!r}'.format(coefficient, symbol) - else: - string += ' - {}*{!r}'.format(-coefficient, symbol) + string += ' - {}*'.format(-coefficient) + string += '{}'.format(symbol) constant = self.constant if len(string) == 0: string += '{}'.format(constant) @@ -261,6 +265,30 @@ class Expression: string += ' - {}'.format(-constant) return string + def _repr_latex_(self): + string = '' + for i, (symbol, coefficient) in enumerate(self.coefficients()): + if coefficient == 1: + if i != 0: + string += ' + ' + elif coefficient == -1: + string += '-' if i == 0 else ' - ' + elif i == 0: + string += '{}'.format(coefficient._repr_latex_().strip('$')) + elif coefficient > 0: + string += ' + {}'.format(coefficient._repr_latex_().strip('$')) + elif coefficient < 0: + string += ' - {}'.format((-coefficient)._repr_latex_().strip('$')) + string += '{}'.format(symbol._repr_latex_().strip('$')) + constant = self.constant + if len(string) == 0: + string += '{}'.format(constant._repr_latex_().strip('$')) + elif constant > 0: + string += ' + {}'.format(constant._repr_latex_().strip('$')) + elif constant < 0: + string += ' - {}'.format((-constant)._repr_latex_().strip('$')) + return '${}$'.format(string) + def _parenstr(self, always=False): string = str(self) if not always and (self.isconstant() or self.issymbol()): @@ -301,8 +329,8 @@ class Symbol(Expression): raise TypeError('name must be a string') self = object().__new__(cls) self._name = name.strip() - self._coefficients = {self: 1} - self._constant = 0 + self._coefficients = {self: Fraction(1)} + self._constant = Fraction(0) self._symbols = (self,) self._dimension = 1 return self @@ -340,6 +368,9 @@ class Symbol(Expression): def __repr__(self): return self.name + def _repr_latex_(self): + return '${}$'.format(self.name) + @classmethod def fromsympy(cls, expr): import sympy @@ -359,8 +390,8 @@ class Dummy(Symbol): self = object().__new__(cls) self._index = Dummy._count self._name = name.strip() - self._coefficients = {self: 1} - self._constant = 0 + self._coefficients = {self: Fraction(1)} + self._constant = Fraction(0) self._symbols = (self,) self._dimension = 1 Dummy._count += 1 @@ -378,6 +409,9 @@ class Dummy(Symbol): def __repr__(self): return '_{}'.format(self.name) + def _repr_latex_(self): + return '${}_{{{}}}$'.format(self.name, self._index) + def symbols(names): if isinstance(names, str): @@ -408,29 +442,27 @@ class Rational(Expression, Fraction): def __bool__(self): return Fraction.__bool__(self) - @_polymorphic - def __mul__(self, other): - coefficients = dict(other._coefficients) - for symbol in coefficients: - coefficients[symbol] *= self._constant - constant = other._constant * self._constant - return Expression(coefficients, constant) - - __rmul__ = __mul__ - - @_polymorphic - def __rtruediv__(self, other): - coefficients = dict(other._coefficients) - for symbol in coefficients: - coefficients[symbol] /= self._constant - constant = other._constant / self._constant - return Expression(coefficients, constant) - @classmethod def fromstring(cls, string): if not isinstance(string, str): raise TypeError('string must be a string instance') - return Rational(Fraction(string)) + return Rational(string) + + def __repr__(self): + if self.denominator == 1: + return '{!r}'.format(self.numerator) + else: + return '{!r}/{!r}'.format(self.numerator, self.denominator) + + def _repr_latex_(self): + if self.denominator == 1: + return '${}$'.format(self.numerator) + elif self.numerator < 0: + return '$-\\frac{{{}}}{{{}}}$'.format(-self.numerator, + self.denominator) + else: + return '$\\frac{{{}}}{{{}}}$'.format(self.numerator, + self.denominator) @classmethod def fromsympy(cls, expr):