From: Vivien Maisonneuve Date: Sun, 13 Jul 2014 05:56:02 +0000 (+0200) Subject: Cleaner implementation of Rational X-Git-Tag: 1.0~114 X-Git-Url: https://scm.cri.mines-paristech.fr/git/linpy.git/commitdiff_plain/2bad3743bd25bbcfe12db50e2b18ab8d070f2354 Cleaner implementation of Rational --- diff --git a/pypol/linexprs.py b/pypol/linexprs.py index b23eea8..e73449e 100644 --- a/pypol/linexprs.py +++ b/pypol/linexprs.py @@ -31,16 +31,9 @@ class Expression: This class implements linear expressions. """ - __slots__ = ( - '_coefficients', - '_constant', - '_symbols', - '_dimension', - ) - def __new__(cls, coefficients=None, constant=0): if isinstance(coefficients, str): - if constant: + if constant != 0: raise TypeError('too many arguments') return Expression.fromstring(coefficients) if coefficients is None: @@ -50,8 +43,13 @@ class Expression: for symbol, coefficient in coefficients: if not isinstance(symbol, Symbol): raise TypeError('symbols must be Symbol instances') - coefficients = [(symbol, coefficient) + if not isinstance(coefficient, numbers.Rational): + raise TypeError('coefficients must be Rational instances') + coefficients = [(symbol, Fraction(coefficient)) for symbol, coefficient in coefficients if coefficient != 0] + if not isinstance(constant, numbers.Rational): + raise TypeError('constant must be a Rational instance') + constant = Fraction(constant) if len(coefficients) == 0: return Rational(constant) if len(coefficients) == 1 and constant == 0: @@ -59,18 +57,8 @@ class Expression: if coefficient == 1: return symbol self = object().__new__(cls) - self._coefficients = OrderedDict() - for symbol, coefficient in sorted(coefficients, - key=lambda item: item[0].sortkey()): - if isinstance(coefficient, Rational): - coefficient = coefficient.constant - if not isinstance(coefficient, numbers.Rational): - raise TypeError('coefficients must be Rational instances') - self._coefficients[symbol] = coefficient - if isinstance(constant, Rational): - constant = constant.constant - if not isinstance(constant, numbers.Rational): - raise TypeError('constant must be a Rational instance') + self._coefficients = OrderedDict(sorted(coefficients, + key=lambda item: item[0].sortkey())) self._constant = constant self._symbols = tuple(self._coefficients) self._dimension = len(self._symbols) @@ -80,18 +68,19 @@ class Expression: if not isinstance(symbol, Symbol): raise TypeError('symbol must be a Symbol instance') try: - return self._coefficients[symbol] + return Rational(self._coefficients[symbol]) except KeyError: - return 0 + return Rational(0) __getitem__ = coefficient def coefficients(self): - yield from self._coefficients.items() + for symbol, coefficient in self._coefficients.items(): + yield symbol, Rational(coefficient) @property def constant(self): - return self._constant + return Rational(self._constant) @property def symbols(self): @@ -111,8 +100,9 @@ class Expression: return False def values(self): - yield from self._coefficients.values() - yield self.constant + for coefficient in self._coefficients.values(): + yield Rational(coefficient) + yield Rational(self._constant) def __bool__(self): return True @@ -125,20 +115,20 @@ class Expression: @_polymorphic def __add__(self, other): - coefficients = defaultdict(Rational, self.coefficients()) - for symbol, coefficient in other.coefficients(): + coefficients = defaultdict(Fraction, self._coefficients) + for symbol, coefficient in other._coefficients.items(): coefficients[symbol] += coefficient - constant = self.constant + other.constant + constant = self._constant + other._constant return Expression(coefficients, constant) __radd__ = __add__ @_polymorphic def __sub__(self, other): - coefficients = defaultdict(Rational, self.coefficients()) - for symbol, coefficient in other.coefficients(): + coefficients = defaultdict(Fraction, self._coefficients) + for symbol, coefficient in other._coefficients.items(): coefficients[symbol] -= coefficient - constant = self.constant - other.constant + constant = self._constant - other._constant return Expression(coefficients, constant) def __rsub__(self, other): @@ -146,40 +136,19 @@ class Expression: @_polymorphic def __mul__(self, other): - if other.isconstant(): - coefficients = dict(self.coefficients()) - for symbol in coefficients: - coefficients[symbol] *= other.constant - constant = self.constant * other.constant - return Expression(coefficients, constant) - if isinstance(other, Expression) and not self.isconstant(): - raise ValueError('non-linear expression: ' - '{} * {}'.format(self._parenstr(), other._parenstr())) + if isinstance(other, Rational): + return other.__rmul__(self) return NotImplemented __rmul__ = __mul__ @_polymorphic def __truediv__(self, other): - if other.isconstant(): - coefficients = dict(self.coefficients()) - for symbol in coefficients: - coefficients[symbol] = Rational(coefficients[symbol], other.constant) - constant = Rational(self.constant, other.constant) - return Expression(coefficients, constant) - if isinstance(other, Expression): - raise ValueError('non-linear expression: ' - '{} / {}'.format(self._parenstr(), other._parenstr())) + if isinstance(other, Rational): + return other.__rtruediv__(self) return NotImplemented - def __rtruediv__(self, other): - if isinstance(other, self): - if self.isconstant(): - return Rational(other, self.constant) - else: - raise ValueError('non-linear expression: ' - '{} / {}'.format(other._parenstr(), self._parenstr())) - return NotImplemented + __rtruediv__ = __truediv__ @_polymorphic def __eq__(self, other): @@ -187,7 +156,7 @@ class Expression: # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs return isinstance(other, Expression) and \ self._coefficients == other._coefficients and \ - self.constant == other.constant + self._constant == other._constant @_polymorphic def __le__(self, other): @@ -223,11 +192,13 @@ class Expression: substitutions = [(symbol, expression)] result = self for symbol, expression in substitutions: + if not isinstance(symbol, Symbol): + raise TypeError('symbols must be Symbol instances') coefficients = [(othersymbol, coefficient) - for othersymbol, coefficient in result.coefficients() + for othersymbol, coefficient in result._coefficients.items() if othersymbol != symbol] - coefficient = result.coefficient(symbol) - constant = result.constant + coefficient = result._coefficients.get(symbol, 0) + constant = result._constant result = Expression(coefficients, constant) + coefficient*expression return result @@ -325,15 +296,15 @@ class Expression: class Symbol(Expression): - __slots__ = ( - '_name', - ) - def __new__(cls, name): if not isinstance(name, str): raise TypeError('name must be a string') self = object().__new__(cls) self._name = name.strip() + self._coefficients = {self: 1} + self._constant = 0 + self._symbols = (self,) + self._dimension = 1 return self @property @@ -343,38 +314,12 @@ class Symbol(Expression): def __hash__(self): return hash(self.sortkey()) - def coefficient(self, symbol): - if not isinstance(symbol, Symbol): - raise TypeError('symbol must be a Symbol instance') - if symbol == self: - return 1 - else: - return 0 - - def coefficients(self): - yield self, 1 - - @property - def constant(self): - return 0 - - @property - def symbols(self): - return self, - - @property - def dimension(self): - return 1 - def sortkey(self): return self.name, def issymbol(self): return True - def values(self): - yield 1 - def __eq__(self, other): return not isinstance(other, Dummy) and isinstance(other, Symbol) \ and self.name == other.name @@ -406,19 +351,18 @@ class Symbol(Expression): class Dummy(Symbol): - __slots__ = ( - '_name', - '_index', - ) - _count = 0 def __new__(cls, name=None): if name is None: name = 'Dummy_{}'.format(Dummy._count) self = object().__new__(cls) - self._name = name.strip() self._index = Dummy._count + self._name = name.strip() + self._coefficients = {self: 1} + self._constant = 0 + self._symbols = (self,) + self._dimension = 1 Dummy._count += 1 return self @@ -441,51 +385,46 @@ def symbols(names): return tuple(Symbol(name) for name in names) -class Rational(Expression): - - __slots__ = ( - '_constant', - ) +class Rational(Expression, Fraction): def __new__(cls, numerator=0, denominator=None): - self = object().__new__(cls) - if denominator is None and isinstance(numerator, Rational): - self._constant = numerator.constant - else: - self._constant = Fraction(numerator, denominator) + self = Fraction.__new__(cls, numerator, denominator) + self._coefficients = {} + self._constant = Fraction(self) + self._symbols = () + self._dimension = 0 return self def __hash__(self): - return hash(self.constant) - - def coefficient(self, symbol): - if not isinstance(symbol, Symbol): - raise TypeError('symbol must be a Symbol instance') - return 0 - - def coefficients(self): - yield from () + return Fraction.__hash__(self) @property - def symbols(self): - return () - - @property - def dimension(self): - return 0 + def constant(self): + return self def isconstant(self): return True - def values(self): - yield self._constant + def __bool__(self): + return Fraction.__bool__(self) @_polymorphic - def __eq__(self, other): - return isinstance(other, Rational) and self.constant == other.constant + def __mul__(self, other): + coefficients = dict(other._coefficients) + for symbol in coefficients: + coefficients[symbol] *= self._constant + constant = other._constant * self._constant + return Expression(coefficients, constant) - def __bool__(self): - return self.constant != 0 + __rmul__ = __mul__ + + @_polymorphic + def __rtruediv__(self, other): + coefficients = dict(other._coefficients) + for symbol in coefficients: + coefficients[symbol] /= self._constant + constant = other._constant / self._constant + return Expression(coefficients, constant) @classmethod def fromstring(cls, string): diff --git a/pypol/polyhedra.py b/pypol/polyhedra.py index 69ed2b2..e745d7d 100644 --- a/pypol/polyhedra.py +++ b/pypol/polyhedra.py @@ -311,16 +311,18 @@ class Polyhedron(Domain): def _polymorphic(func): @functools.wraps(func) def wrapper(left, right): - if isinstance(left, numbers.Rational): - left = Rational(left) - elif not isinstance(left, Expression): - raise TypeError('left must be a a rational number ' - 'or a linear expression') - if isinstance(right, numbers.Rational): - right = Rational(right) - elif not isinstance(right, Expression): - raise TypeError('right must be a a rational number ' - 'or a linear expression') + if not isinstance(left, Expression): + if isinstance(left, numbers.Rational): + left = Rational(left) + else: + raise TypeError('left must be a a rational number ' + 'or a linear expression') + if not isinstance(right, Expression): + if isinstance(right, numbers.Rational): + right = Rational(right) + else: + raise TypeError('right must be a a rational number ' + 'or a linear expression') return func(left, right) return wrapper diff --git a/pypol/tests/test_linexprs.py b/pypol/tests/test_linexprs.py index c55c842..6ec8993 100644 --- a/pypol/tests/test_linexprs.py +++ b/pypol/tests/test_linexprs.py @@ -117,11 +117,15 @@ class TestExpression(unittest.TestCase): self.assertEqual(self.expr * 0, 0) self.assertEqual(0 * self.expr, 0) self.assertEqual(self.expr * 2, 2*self.x - 4*self.y + 6) + with self.assertRaises(TypeError): + self.x * self.x def test_truediv(self): with self.assertRaises(ZeroDivisionError): self.expr / 0 self.assertEqual(self.expr / 2, self.x / 2 - self.y + Fraction(3, 2)) + with self.assertRaises(TypeError): + self.x / self.x def test_eq(self): self.assertEqual(self.expr, self.expr) @@ -279,6 +283,10 @@ class TestRational(unittest.TestCase): self.assertEqual(Rational(self.pi), self.pi) self.assertEqual(Rational('22/7'), self.pi) + def test_hash(self): + self.assertEqual(hash(self.one), hash(1)) + self.assertEqual(hash(self.pi), hash(Fraction(22, 7))) + def test_isconstant(self): self.assertTrue(self.zero.isconstant())