X-Git-Url: https://scm.cri.mines-paristech.fr/git/linpy.git/blobdiff_plain/b02f9551644488e5943f968ac847fe4ed7690d6b..c3149dd6dbf0dd296c85676fcf9f997ead2470f0:/linpy/linexprs.py diff --git a/linpy/linexprs.py b/linpy/linexprs.py index b2cec53..d2554a0 100644 --- a/linpy/linexprs.py +++ b/linpy/linexprs.py @@ -122,7 +122,7 @@ class LinExpr: """ if not isinstance(symbol, Symbol): raise TypeError('symbol must be a Symbol instance') - return self._coefficients.get(symbol, 0) + return self._coefficients.get(symbol, Fraction(0)) __getitem__ = coefficient @@ -131,8 +131,7 @@ class LinExpr: Iterate over the pairs (symbol, value) of linear terms in the expression. The constant term is ignored. """ - for symbol, coefficient in self._coefficients.items(): - yield symbol, coefficient + yield from self._coefficients.items() @property def constant(self): @@ -179,8 +178,7 @@ class LinExpr: Iterate over the coefficient values in the expression, and the constant term. """ - for coefficient in self._coefficients.values(): - yield coefficient + yield from self._coefficients.values() yield self._constant def __bool__(self): @@ -249,9 +247,10 @@ class LinExpr: """ Test whether two linear expressions are equal. """ - return isinstance(other, LinExpr) and \ - self._coefficients == other._coefficients and \ - self._constant == other._constant + if isinstance(other, LinExpr): + return self._coefficients == other._coefficients and \ + self._constant == other._constant + return NotImplemented def __le__(self, other): from .polyhedra import Le @@ -274,9 +273,9 @@ class LinExpr: Return the expression multiplied by its lowest common denominator to make all values integer. """ - lcm = functools.reduce(lambda a, b: a*b // gcd(a, b), + lcd = functools.reduce(lambda a, b: a*b // gcd(a, b), [value.denominator for value in self.values()]) - return self * lcm + return self * lcd def subs(self, symbol, expression=None): """ @@ -295,21 +294,16 @@ class LinExpr: 2*x + y + 1 """ if expression is None: - if isinstance(symbol, Mapping): - symbol = symbol.items() - substitutions = symbol + substitutions = dict(symbol) else: - substitutions = [(symbol, expression)] - result = self - for symbol, expression in substitutions: + substitutions = {symbol: expression} + for symbol in substitutions: if not isinstance(symbol, Symbol): raise TypeError('symbols must be Symbol instances') - coefficients = [(othersymbol, coefficient) - for othersymbol, coefficient in result._coefficients.items() - if othersymbol != symbol] - coefficient = result._coefficients.get(symbol, 0) - constant = result._constant - result = LinExpr(coefficients, constant) + coefficient*expression + result = self._constant + for symbol, coefficient in self._coefficients.items(): + expression = substitutions.get(symbol, symbol) + result += coefficient * expression return result @classmethod @@ -337,7 +331,7 @@ class LinExpr: return left / right raise SyntaxError('invalid syntax') - _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()') + _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d]\w*|\()') @classmethod def fromstring(cls, string): @@ -345,7 +339,7 @@ class LinExpr: Create an expression from a string. Raise SyntaxError if the string is not properly formatted. """ - # add implicit multiplication operators, e.g. '5x' -> '5*x' + # Add implicit multiplication operators, e.g. '5x' -> '5*x'. string = LinExpr._RE_NUM_VAR.sub(r'\1*\2', string) tree = ast.parse(string, 'eval') expr = cls._fromast(tree) @@ -422,7 +416,8 @@ class LinExpr: if symbol == sympy.S.One: constant = coefficient elif isinstance(symbol, sympy.Dummy): - # we cannot properly convert dummy symbols + # We cannot properly convert dummy symbols with respect to + # symbol equalities. raise TypeError('cannot convert dummy symbols') elif isinstance(symbol, sympy.Symbol): symbol = Symbol(symbol.name) @@ -456,6 +451,13 @@ class Symbol(LinExpr): Two instances of Symbol are equal if they have the same name. """ + __slots__ = ( + '_name', + '_constant', + '_symbols', + '_dimension', + ) + def __new__(cls, name): """ Return a symbol with the name string given in argument. @@ -469,12 +471,17 @@ class Symbol(LinExpr): raise SyntaxError('invalid syntax') self = object().__new__(cls) self._name = name - self._coefficients = {self: Fraction(1)} self._constant = Fraction(0) self._symbols = (self,) self._dimension = 1 return self + @property + def _coefficients(self): + # This is not implemented as an attribute, because __hash__ is not + # callable in __new__ in class Dummy. + return {self: Fraction(1)} + @property def name(self): """ @@ -499,7 +506,9 @@ class Symbol(LinExpr): return True def __eq__(self, other): - return self.sortkey() == other.sortkey() + if isinstance(other, Symbol): + return self.sortkey() == other.sortkey() + return NotImplemented def asdummy(self): """ @@ -557,15 +566,8 @@ class Dummy(Symbol): """ if name is None: name = 'Dummy_{}'.format(Dummy._count) - elif not isinstance(name, str): - raise TypeError('name must be a string') - self = object().__new__(cls) + self = super().__new__(cls, name) self._index = Dummy._count - self._name = name.strip() - self._coefficients = {self: Fraction(1)} - self._constant = Fraction(0) - self._symbols = (self,) - self._dimension = 1 Dummy._count += 1 return self @@ -590,6 +592,13 @@ class Rational(LinExpr, Fraction): fractions.Fraction classes. """ + __slots__ = ( + '_coefficients', + '_constant', + '_symbols', + '_dimension', + ) + Fraction.__slots__ + def __new__(cls, numerator=0, denominator=None): self = object().__new__(cls) self._coefficients = {}