X-Git-Url: https://scm.cri.mines-paristech.fr/git/linpy.git/blobdiff_plain/ba15f3f33f837b1291f74bc94081e99b860d3228..refs/heads/master:/linpy/linexprs.py?ds=inline diff --git a/linpy/linexprs.py b/linpy/linexprs.py index 82d75d0..ccfbbfa 100644 --- a/linpy/linexprs.py +++ b/linpy/linexprs.py @@ -20,14 +20,16 @@ import functools import numbers import re -from collections import OrderedDict, defaultdict, Mapping +from collections import defaultdict, Mapping, OrderedDict from fractions import Fraction, gcd __all__ = [ + 'Dummy', 'LinExpr', - 'Symbol', 'Dummy', 'symbols', 'Rational', + 'Symbol', + 'symbols', ] @@ -59,10 +61,10 @@ class LinExpr: def __new__(cls, coefficients=None, constant=0): """ Return a linear expression from a dictionary or a sequence, that maps - symbols to their coefficients, and a constant term. The coefficients and - the constant term must be rational numbers. + symbols to their coefficients, and a constant term. The coefficients + and the constant term must be rational numbers. - For example, the linear expression x + 2y + 1 can be constructed using + For example, the linear expression x + 2*y + 1 can be constructed using one of the following instructions: >>> x, y = symbols('x y') @@ -76,12 +78,12 @@ class LinExpr: Alternatively, linear expressions can be constructed from a string: - >>> LinExpr('x + 2*y + 1') + >>> LinExpr('x + 2y + 1') A linear expression with a single symbol of coefficient 1 and no - constant term is automatically subclassed as a Symbol instance. A linear - expression with no symbol, only a constant term, is automatically - subclassed as a Rational instance. + constant term is automatically subclassed as a Symbol instance. A + linear expression with no symbol, only a constant term, is + automatically subclassed as a Rational instance. """ if isinstance(coefficients, str): if constant != 0: @@ -105,8 +107,9 @@ class LinExpr: symbol, coefficient = coefficients[0] if coefficient == 1: return symbol - coefficients = [(symbol, Fraction(coefficient)) - for symbol, coefficient in coefficients if coefficient != 0] + coefficients = [(symbol_, Fraction(coefficient_)) + for symbol_, coefficient_ in coefficients + if coefficient_ != 0] coefficients.sort(key=lambda item: item[0].sortkey()) self = object().__new__(cls) self._coefficients = OrderedDict(coefficients) @@ -122,7 +125,7 @@ class LinExpr: """ if not isinstance(symbol, Symbol): raise TypeError('symbol must be a Symbol instance') - return Rational(self._coefficients.get(symbol, 0)) + return self._coefficients.get(symbol, Fraction(0)) __getitem__ = coefficient @@ -131,15 +134,14 @@ class LinExpr: Iterate over the pairs (symbol, value) of linear terms in the expression. The constant term is ignored. """ - for symbol, coefficient in self._coefficients.items(): - yield symbol, Rational(coefficient) + yield from self._coefficients.items() @property def constant(self): """ The constant term of the expression. """ - return Rational(self._constant) + return self._constant @property def symbols(self): @@ -179,9 +181,8 @@ class LinExpr: Iterate over the coefficient values in the expression, and the constant term. """ - for coefficient in self._coefficients.values(): - yield Rational(coefficient) - yield Rational(self._constant) + yield from self._coefficients.values() + yield self._constant def __bool__(self): return True @@ -225,7 +226,8 @@ class LinExpr: Return the product of the linear expression by a rational. """ if isinstance(other, numbers.Rational): - coefficients = ((symbol, coefficient * other) + coefficients = ( + (symbol, coefficient * other) for symbol, coefficient in self._coefficients.items()) constant = self._constant * other return LinExpr(coefficients, constant) @@ -238,7 +240,8 @@ class LinExpr: Return the quotient of the linear expression by a rational. """ if isinstance(other, numbers.Rational): - coefficients = ((symbol, coefficient / other) + coefficients = ( + (symbol, coefficient / other) for symbol, coefficient in self._coefficients.items()) constant = self._constant / other return LinExpr(coefficients, constant) @@ -247,36 +250,43 @@ class LinExpr: @_polymorphic def __eq__(self, other): """ - Test whether two linear expressions are equal. + Test whether two linear expressions are equal. Unlike methods + LinExpr.__lt__(), LinExpr.__le__(), LinExpr.__ge__(), LinExpr.__gt__(), + the result is a boolean value, not a polyhedron. To express that two + linear expressions are equal or not equal, use functions Eq() and Ne() + instead. """ - return isinstance(other, LinExpr) and \ - self._coefficients == other._coefficients and \ + return self._coefficients == other._coefficients and \ self._constant == other._constant - def __le__(self, other): - from .polyhedra import Le - return Le(self, other) - + @_polymorphic def __lt__(self, other): - from .polyhedra import Lt - return Lt(self, other) + from .polyhedra import Polyhedron + return Polyhedron([], [other - self - 1]) + @_polymorphic + def __le__(self, other): + from .polyhedra import Polyhedron + return Polyhedron([], [other - self]) + + @_polymorphic def __ge__(self, other): - from .polyhedra import Ge - return Ge(self, other) + from .polyhedra import Polyhedron + return Polyhedron([], [self - other]) + @_polymorphic def __gt__(self, other): - from .polyhedra import Gt - return Gt(self, other) + from .polyhedra import Polyhedron + return Polyhedron([], [self - other - 1]) def scaleint(self): """ Return the expression multiplied by its lowest common denominator to make all values integer. """ - lcm = functools.reduce(lambda a, b: a*b // gcd(a, b), - [value.denominator for value in self.values()]) - return self * lcm + lcd = functools.reduce(lambda a, b: a*b // gcd(a, b), + [value.denominator for value in self.values()]) + return self * lcd def subs(self, symbol, expression=None): """ @@ -295,21 +305,16 @@ class LinExpr: 2*x + y + 1 """ if expression is None: - if isinstance(symbol, Mapping): - symbol = symbol.items() - substitutions = symbol + substitutions = dict(symbol) else: - substitutions = [(symbol, expression)] - result = self - for symbol, expression in substitutions: + substitutions = {symbol: expression} + for symbol in substitutions: if not isinstance(symbol, Symbol): raise TypeError('symbols must be Symbol instances') - coefficients = [(othersymbol, coefficient) - for othersymbol, coefficient in result._coefficients.items() - if othersymbol != symbol] - coefficient = result._coefficients.get(symbol, 0) - constant = result._constant - result = LinExpr(coefficients, constant) + coefficient*expression + result = Rational(self._constant) + for symbol, coefficient in self._coefficients.items(): + expression = substitutions.get(symbol, symbol) + result += coefficient * expression return result @classmethod @@ -337,7 +342,7 @@ class LinExpr: return left / right raise SyntaxError('invalid syntax') - _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()') + _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d]\w*|\()') @classmethod def fromstring(cls, string): @@ -345,10 +350,13 @@ class LinExpr: Create an expression from a string. Raise SyntaxError if the string is not properly formatted. """ - # add implicit multiplication operators, e.g. '5x' -> '5*x' + # Add implicit multiplication operators, e.g. '5x' -> '5*x'. string = LinExpr._RE_NUM_VAR.sub(r'\1*\2', string) tree = ast.parse(string, 'eval') - return cls._fromast(tree) + expression = cls._fromast(tree) + if not isinstance(expression, cls): + raise SyntaxError('invalid syntax') + return expression def __repr__(self): string = '' @@ -374,30 +382,6 @@ class LinExpr: string += ' - {}'.format(-constant) return string - def _repr_latex_(self): - string = '' - for i, (symbol, coefficient) in enumerate(self.coefficients()): - if coefficient == 1: - if i != 0: - string += ' + ' - elif coefficient == -1: - string += '-' if i == 0 else ' - ' - elif i == 0: - string += '{}'.format(coefficient._repr_latex_().strip('$')) - elif coefficient > 0: - string += ' + {}'.format(coefficient._repr_latex_().strip('$')) - elif coefficient < 0: - string += ' - {}'.format((-coefficient)._repr_latex_().strip('$')) - string += '{}'.format(symbol._repr_latex_().strip('$')) - constant = self.constant - if len(string) == 0: - string += '{}'.format(constant._repr_latex_().strip('$')) - elif constant > 0: - string += ' + {}'.format(constant._repr_latex_().strip('$')) - elif constant < 0: - string += ' - {}'.format((-constant)._repr_latex_().strip('$')) - return '$${}$$'.format(string) - def _parenstr(self, always=False): string = str(self) if not always and (self.isconstant() or self.issymbol()): @@ -406,36 +390,45 @@ class LinExpr: return '({})'.format(string) @classmethod - def fromsympy(cls, expr): + def fromsympy(cls, expression): """ - Create a linear expression from a sympy expression. Raise ValueError is + Create a linear expression from a SymPy expression. Raise TypeError is the sympy expression is not linear. """ import sympy coefficients = [] constant = 0 - for symbol, coefficient in expr.as_coefficients_dict().items(): + for symbol, coefficient in expression.as_coefficients_dict().items(): coefficient = Fraction(coefficient.p, coefficient.q) if symbol == sympy.S.One: constant = coefficient + elif isinstance(symbol, sympy.Dummy): + # We cannot properly convert dummy symbols with respect to + # symbol equalities. + raise TypeError('cannot convert dummy symbols') elif isinstance(symbol, sympy.Symbol): symbol = Symbol(symbol.name) coefficients.append((symbol, coefficient)) else: - raise ValueError('non-linear expression: {!r}'.format(expr)) - return LinExpr(coefficients, constant) + raise TypeError('non-linear expression: {!r}'.format( + expression)) + expression = LinExpr(coefficients, constant) + if not isinstance(expression, cls): + raise TypeError('cannot convert to a {} instance'.format( + cls.__name__)) + return expression def tosympy(self): """ - Convert the linear expression to a sympy expression. + Convert the linear expression to a SymPy expression. """ import sympy - expr = 0 + expression = 0 for symbol, coefficient in self.coefficients(): term = coefficient * sympy.Symbol(symbol.name) - expr += term - expr += self.constant - return expr + expression += term + expression += self.constant + return expression class Symbol(LinExpr): @@ -447,20 +440,37 @@ class Symbol(LinExpr): Two instances of Symbol are equal if they have the same name. """ + __slots__ = ( + '_name', + '_constant', + '_symbols', + '_dimension', + ) + def __new__(cls, name): """ Return a symbol with the name string given in argument. """ if not isinstance(name, str): raise TypeError('name must be a string') + node = ast.parse(name) + try: + name = node.body[0].value.id + except (AttributeError, SyntaxError): + raise SyntaxError('invalid syntax') self = object().__new__(cls) - self._name = name.strip() - self._coefficients = {self: Fraction(1)} + self._name = name self._constant = Fraction(0) self._symbols = (self,) self._dimension = 1 return self + @property + def _coefficients(self): + # This is not implemented as an attribute, because __hash__ is not + # callable in __new__ in class Dummy. + return {self: Fraction(1)} + @property def name(self): """ @@ -485,7 +495,9 @@ class Symbol(LinExpr): return True def __eq__(self, other): - return self.sortkey() == other.sortkey() + if isinstance(other, Symbol): + return self.sortkey() == other.sortkey() + return NotImplemented def asdummy(self): """ @@ -493,31 +505,23 @@ class Symbol(LinExpr): """ return Dummy(self.name) - @classmethod - def _fromast(cls, node): - if isinstance(node, ast.Module) and len(node.body) == 1: - return cls._fromast(node.body[0]) - elif isinstance(node, ast.Expr): - return cls._fromast(node.value) - elif isinstance(node, ast.Name): - return Symbol(node.id) - raise SyntaxError('invalid syntax') - def __repr__(self): return self.name - def _repr_latex_(self): - return '$${}$$'.format(self.name) - @classmethod - def fromsympy(cls, expr): - import sympy - if isinstance(expr, sympy.Dummy): - return Dummy(expr.name) - elif isinstance(expr, sympy.Symbol): - return Symbol(expr.name) - else: - raise TypeError('expr must be a sympy.Symbol instance') +def symbols(names): + """ + This function returns a tuple of symbols whose names are taken from a comma + or whitespace delimited string, or a sequence of strings. It is useful to + define several symbols at once. + + >>> x, y = symbols('x y') + >>> x, y = symbols('x, y') + >>> x, y = symbols(['x', 'y']) + """ + if isinstance(names, str): + names = names.replace(',', ' ').split() + return tuple(Symbol(name) for name in names) class Dummy(Symbol): @@ -548,15 +552,8 @@ class Dummy(Symbol): """ if name is None: name = 'Dummy_{}'.format(Dummy._count) - elif not isinstance(name, str): - raise TypeError('name must be a string') - self = object().__new__(cls) + self = super().__new__(cls, name) self._index = Dummy._count - self._name = name.strip() - self._coefficients = {self: Fraction(1)} - self._constant = Fraction(0) - self._symbols = (self,) - self._dimension = 1 Dummy._count += 1 return self @@ -569,24 +566,6 @@ class Dummy(Symbol): def __repr__(self): return '_{}'.format(self.name) - def _repr_latex_(self): - return '$${}_{{{}}}$$'.format(self.name, self._index) - - -def symbols(names): - """ - This function returns a tuple of symbols whose names are taken from a comma - or whitespace delimited string, or a sequence of strings. It is useful to - define several symbols at once. - - >>> x, y = symbols('x y') - >>> x, y = symbols('x, y') - >>> x, y = symbols(['x', 'y']) - """ - if isinstance(names, str): - names = names.replace(',', ' ').split() - return tuple(Symbol(name) for name in names) - class Rational(LinExpr, Fraction): """ @@ -596,6 +575,13 @@ class Rational(LinExpr, Fraction): fractions.Fraction classes. """ + __slots__ = ( + '_coefficients', + '_constant', + '_symbols', + '_dimension', + ) + Fraction.__slots__ + def __new__(cls, numerator=0, denominator=None): self = object().__new__(cls) self._coefficients = {} @@ -624,23 +610,3 @@ class Rational(LinExpr, Fraction): return '{!r}'.format(self.numerator) else: return '{!r}/{!r}'.format(self.numerator, self.denominator) - - def _repr_latex_(self): - if self.denominator == 1: - return '$${}$$'.format(self.numerator) - elif self.numerator < 0: - return '$$-\\frac{{{}}}{{{}}}$$'.format(-self.numerator, - self.denominator) - else: - return '$$\\frac{{{}}}{{{}}}$$'.format(self.numerator, - self.denominator) - - @classmethod - def fromsympy(cls, expr): - import sympy - if isinstance(expr, sympy.Rational): - return Rational(expr.p, expr.q) - elif isinstance(expr, numbers.Rational): - return Rational(expr) - else: - raise TypeError('expr must be a sympy.Rational instance')