X-Git-Url: https://scm.cri.mines-paristech.fr/git/linpy.git/blobdiff_plain/d06ab92943ec2e10a2bd798ca7c1b5cea395bf34..ce31a1de5082c55b5eed1825ae95d827c55a8b92:/pypol/linexprs.py?ds=sidebyside diff --git a/pypol/linexprs.py b/pypol/linexprs.py index 9a1ed64..5ec5efd 100644 --- a/pypol/linexprs.py +++ b/pypol/linexprs.py @@ -3,13 +3,13 @@ import functools import numbers import re -from collections import OrderedDict, defaultdict +from collections import OrderedDict, defaultdict, Mapping from fractions import Fraction, gcd __all__ = [ 'Expression', - 'Symbol', 'symbols', + 'Symbol', 'Dummy', 'symbols', 'Rational', ] @@ -45,7 +45,7 @@ class Expression: return Expression.fromstring(coefficients) if coefficients is None: return Rational(constant) - if isinstance(coefficients, dict): + if isinstance(coefficients, Mapping): coefficients = coefficients.items() for symbol, coefficient in coefficients: if not isinstance(symbol, Symbol): @@ -61,7 +61,7 @@ class Expression: self = object().__new__(cls) self._coefficients = OrderedDict() for symbol, coefficient in sorted(coefficients, - key=lambda item: item[0].name): + key=lambda item: item[0].sortkey()): if isinstance(coefficient, Rational): coefficient = coefficient.constant if not isinstance(coefficient, numbers.Rational): @@ -218,7 +218,7 @@ class Expression: def subs(self, symbol, expression=None): if expression is None: - if isinstance(symbol, dict): + if isinstance(symbol, Mapping): symbol = symbol.items() substitutions = symbol else: @@ -269,39 +269,27 @@ class Expression: def __repr__(self): string = '' - i = 0 - for symbol in self.symbols: - coefficient = self.coefficient(symbol) + for i, (symbol, coefficient) in enumerate(self.coefficients()): if coefficient == 1: - if i == 0: - string += symbol.name - else: - string += ' + {}'.format(symbol) + string += '' if i == 0 else ' + ' + string += '{!r}'.format(symbol) elif coefficient == -1: - if i == 0: - string += '-{}'.format(symbol) - else: - string += ' - {}'.format(symbol) + string += '-' if i == 0 else ' - ' + string += '{!r}'.format(symbol) else: if i == 0: - string += '{}*{}'.format(coefficient, symbol) + string += '{}*{!r}'.format(coefficient, symbol) elif coefficient > 0: - string += ' + {}*{}'.format(coefficient, symbol) + string += ' + {}*{!r}'.format(coefficient, symbol) else: - assert coefficient < 0 - coefficient *= -1 - string += ' - {}*{}'.format(coefficient, symbol) - i += 1 + string += ' - {}*{!r}'.format(-coefficient, symbol) constant = self.constant - if constant != 0 and i == 0: + if len(string) == 0: string += '{}'.format(constant) elif constant > 0: string += ' + {}'.format(constant) elif constant < 0: - constant *= -1 - string += ' - {}'.format(constant) - if string == '': - string = '0' + string += ' - {}'.format(-constant) return string def _parenstr(self, always=False): @@ -355,7 +343,7 @@ class Symbol(Expression): return self._name def __hash__(self): - return hash(self._name) + return hash(self.sortkey()) def coefficient(self, symbol): if not isinstance(symbol, Symbol): @@ -380,6 +368,9 @@ class Symbol(Expression): def dimension(self): return 1 + def sortkey(self): + return self.name, + def issymbol(self): return True @@ -387,7 +378,11 @@ class Symbol(Expression): yield 1 def __eq__(self, other): - return isinstance(other, Symbol) and self.name == other.name + return not isinstance(other, Dummy) and isinstance(other, Symbol) \ + and self.name == other.name + + def asdummy(self): + return Dummy(self.name) @classmethod def _fromast(cls, node): @@ -399,15 +394,49 @@ class Symbol(Expression): return Symbol(node.id) raise SyntaxError('invalid syntax') + def __repr__(self): + return self.name + @classmethod def fromsympy(cls, expr): import sympy if isinstance(expr, sympy.Symbol): - return Symbol(expr.name) + return cls(expr.name) else: raise TypeError('expr must be a sympy.Symbol instance') +class Dummy(Symbol): + + __slots__ = ( + '_name', + '_index', + ) + + _count = 0 + + def __new__(cls, name=None): + if name is None: + name = 'Dummy_{}'.format(Dummy._count) + self = object().__new__(cls) + self._name = name.strip() + self._index = Dummy._count + Dummy._count += 1 + return self + + def __hash__(self): + return hash(self.sortkey()) + + def sortkey(self): + return self._name, self._index + + def __eq__(self, other): + return isinstance(other, Dummy) and self._index == other._index + + def __repr__(self): + return '_{}'.format(self.name) + + def symbols(names): if isinstance(names, str): names = names.replace(',', ' ').split()