Improve hash functions in linexprs
[linpy.git] / pypol / linexprs.py
import ast
import functools
import numbers
import re

from collections import OrderedDict
from fractions import Fraction

try:
    from math import gcd
except ImportError:
    # Python < 3.5: math.gcd does not exist yet (fractions.gcd was
    # deprecated in 3.5 and removed in 3.9).
    from fractions import gcd
8
9
# Public names exported by "from ... import *".
__all__ = ['Expression', 'Symbol', 'symbols', 'Constant']
15
16
def _polymorphic(func):
    """
    Decorate a binary operator so that its right operand is an Expression.

    A rational right operand is wrapped in a Constant before calling *func*;
    any other non-Expression operand makes the operator return
    NotImplemented, letting Python try the reflected operation.
    """
    @functools.wraps(func)
    def coerced(left, right):
        if not isinstance(right, Expression):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            right = Constant(right)
        return func(left, right)
    return coerced
27
28
class Expression:
    """
    This class implements linear expressions.

    An expression is a sum of rational coefficients times symbols, plus a
    rational constant, e.g. 2*x - y + 1/2.  The constructor normalizes its
    result: it returns a Constant when no symbol is left, and a Symbol for a
    lone symbol with coefficient 1 and no constant.  The hash is computed
    once, at construction time.
    """

    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
        '_hash',
    )

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a linear expression.

        coefficients -- a string (parsed with fromstring()), a dict or an
            iterable of (symbol, coefficient) pairs, or None.  Symbols are
            strings or Symbol instances; coefficients are rational numbers
            or Constant instances.  Zero coefficients are dropped.
        constant -- a rational number or a Constant instance.

        Returns a Constant when no symbol is left, a Symbol for a single
        symbol with coefficient 1 and a zero constant, and an Expression
        otherwise.

        Raises TypeError if a string is given together with a non-zero
        constant, or if a symbol, coefficient or constant has a wrong type.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is None:
            return Constant(constant)
        coefficients = [(symbol, coefficient)
            for symbol, coefficient in coefficients if coefficient != 0]
        if len(coefficients) == 0:
            return Constant(constant)
        elif len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return Symbol(symbol)
        # Allocate through object.__new__ directly; going through object()
        # would first build a throwaway object.
        self = object.__new__(cls)
        self._coefficients = {}
        for symbol, coefficient in coefficients:
            if isinstance(symbol, Symbol):
                symbol = symbol.name
            elif not isinstance(symbol, str):
                raise TypeError('symbols must be strings or Symbol instances')
            if isinstance(coefficient, Constant):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers '
                    'or Constant instances')
            self._coefficients[symbol] = coefficient
        # Sorting by name gives a canonical order: __eq__ compares
        # OrderedDicts (order-sensitive) and the hash uses the item order.
        self._coefficients = OrderedDict(sorted(self._coefficients.items()))
        if isinstance(constant, Constant):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number '
                'or a Constant instance')
        self._constant = constant
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        self._hash = hash((tuple(self._coefficients.items()), self._constant))
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient of *symbol* (a string or a Symbol), or 0 if
        the symbol does not occur in the expression.

        Raises TypeError for any other argument type.
        """
        if isinstance(symbol, Symbol):
            # Use the name itself: str(symbol) formats the whole expression.
            symbol = symbol.name
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string or a Symbol instance')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Iterate over the (symbol name, coefficient) pairs."""
        yield from self._coefficients.items()

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    @property
    def symbols(self):
        """Tuple of the symbol names of the expression."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols of the expression."""
        return self._dimension

    def __hash__(self):
        return self._hash

    def isconstant(self):
        """Return False; Constant overrides this to return True."""
        return False

    def issymbol(self):
        """Return False; Symbol overrides this to return True."""
        return False

    def values(self):
        """Iterate over the coefficients, in symbol order, then the constant."""
        yield from self._coefficients.values()
        yield self.constant

    def __bool__(self):
        # An expression with a non-zero coefficient is never the zero
        # expression; Constant overrides this.
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) + coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) - coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        # Compute other - self.  _polymorphic has wrapped a rational *other*
        # in a Constant, and returns NotImplemented for unsupported types
        # instead of letting them reach a unary minus of NotImplemented.
        return other - self

    @_polymorphic
    def __mul__(self, other):
        """
        Multiply by a constant.  Raises ValueError when both operands have
        symbols, since the product would not be linear.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] *= other.constant
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                '{} * {}'.format(self._parenstr(), other._parenstr()))
        # self is a constant and other is not: Python then calls
        # other.__rmul__(self), which scales other by self.
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic
    def __truediv__(self, other):
        """
        Divide by a constant; the resulting values are Fractions.  Raises
        ValueError when the divisor has symbols, and ZeroDivisionError when
        it is zero.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] = \
                    Fraction(coefficients[symbol], other.constant)
            constant = Fraction(self.constant, other.constant)
            return Expression(coefficients, constant)
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    @_polymorphic
    def __rtruediv__(self, other):
        # Compute other / self.  _polymorphic has wrapped a rational *other*
        # in a Constant, so regular division yields a Constant when self is
        # constant and raises the non-linear ValueError otherwise (e.g. 1/x).
        return other / self

    @_polymorphic
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
            self._coefficients == other._coefficients and \
            self.constant == other.constant

    # The ordering operators delegate to the constraint classes of
    # .polyhedra, imported locally (presumably to avoid a circular import).

    @_polymorphic
    def __le__(self, other):
        from .polyhedra import Le
        return Le(self, other)

    @_polymorphic
    def __lt__(self, other):
        from .polyhedra import Lt
        return Lt(self, other)

    @_polymorphic
    def __ge__(self, other):
        from .polyhedra import Ge
        return Ge(self, other)

    @_polymorphic
    def __gt__(self, other):
        from .polyhedra import Gt
        return Gt(self, other)

    def _toint(self):
        """
        Return self multiplied by the least common multiple of the
        denominators of its coefficients and constant, so that every value
        becomes an integer.
        """
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcm

    # AST node classes that can hold a numeric literal on this interpreter:
    # ast.Num up to Python 3.7, ast.Constant from 3.8 on.  ast.Num is
    # deprecated since 3.8 and removed in 3.14, so it is only used when the
    # ast module still defines it.
    _NUM_NODES = tuple(getattr(ast, name) for name in ('Constant', 'Num')
        if name in vars(ast))

    @classmethod
    def _fromast(cls, node):
        """
        Build an expression from a Python syntax tree.

        Accepted nodes: a module holding a single expression statement,
        names, numeric literals, unary minus/plus, and the binary operators
        + - * /.  Raises SyntaxError for anything else.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, cls._NUM_NODES):
            # ast.Constant stores the literal in .value, legacy ast.Num in .n
            value = node.value if hasattr(node, 'value') else node.n
            # Only numeric literals qualify (not bools, strings, None...),
            # as with ast.Num.
            if isinstance(value, numbers.Number) and \
                    not isinstance(value, bool):
                return Constant(value)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.UAdd):
            return +cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    # A number or a closing parenthesis directly followed by a name or an
    # opening parenthesis.  NOTE(review): digits at the end of an identifier
    # also match, so e.g. 'a1b' is read as 'a1*b'.
    _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a string such as '2x - y/2 + 1' into an expression.

        Raises SyntaxError for unsupported syntax and ValueError for
        non-linear products or quotients.
        """
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = cls._RE_NUM_VAR.sub(r'\1*\2', string)
        # The second positional parameter of ast.parse is the file name, so
        # the mode is passed by keyword.  In 'exec' mode the tree is a
        # Module, which _fromast unwraps.
        tree = ast.parse(string, mode='exec')
        return cls._fromast(tree)

    def __str__(self):
        """
        Return a readable form such as '2*x - y + 1/2'; the zero expression
        is rendered as '0'.
        """
        string = ''
        i = 0
        for symbol, coefficient in self.coefficients():
            if coefficient == 1:
                if i == 0:
                    string += symbol
                else:
                    string += ' + {}'.format(symbol)
            elif coefficient == -1:
                if i == 0:
                    string += '-{}'.format(symbol)
                else:
                    string += ' - {}'.format(symbol)
            else:
                if i == 0:
                    string += '{}*{}'.format(coefficient, symbol)
                elif coefficient > 0:
                    string += ' + {}*{}'.format(coefficient, symbol)
                else:
                    assert coefficient < 0
                    coefficient *= -1
                    string += ' - {}*{}'.format(coefficient, symbol)
            i += 1
        constant = self.constant
        if constant != 0 and i == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            constant *= -1
            string += ' - {}'.format(constant)
        if string == '':
            string = '0'
        return string

    def _parenstr(self, always=False):
        """
        Return str(self), wrapped in parentheses unless the expression is a
        constant or a symbol and *always* is false.
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, str(self))

    @classmethod
    def fromsympy(cls, expr):
        """
        Build an expression from a sympy expression (requires sympy).

        Raises ValueError if a term is neither a constant nor a coefficient
        times a sympy.Symbol.  Coefficients are read through their .p/.q
        attributes, so they are presumably expected to be sympy rationals.
        """
        import sympy
        coefficients = {}
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = symbol.name
                coefficients[symbol] = coefficient
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return cls(coefficients, constant)

    def tosympy(self):
        """Return the equivalent sympy expression (requires sympy)."""
        import sympy
        expr = 0
        for symbol, coefficient in self.coefficients():
            term = coefficient * sympy.Symbol(symbol)
            expr += term
        expr += self.constant
        return expr
338
339
class Symbol(Expression):
    """
    A variable of a linear expression, identified by its name; it is the
    expression 1*name.
    """

    # Only the slot this subclass adds.  The other slots are inherited from
    # Expression; re-declaring a base-class slot in a subclass shadows the
    # base descriptor and wastes space in every instance.
    __slots__ = (
        '_name',
    )

    def __new__(cls, name):
        """
        Create the symbol called *name*: a string, or a Symbol whose name is
        reused.  Surrounding whitespace is stripped from the name.

        Raises TypeError for any other argument type.
        """
        if isinstance(name, Symbol):
            name = name.name
        elif not isinstance(name, str):
            raise TypeError('name must be a string or a Symbol instance')
        name = name.strip()
        self = object.__new__(cls)
        self._coefficients = OrderedDict([(name, 1)])
        self._constant = 0
        # A one-element tuple: tuple(name) would split a multi-character
        # name into its individual characters.
        self._symbols = (name,)
        self._name = name
        self._dimension = 1
        self._hash = hash(self._name)
        return self

    @property
    def name(self):
        """The name of the symbol, without surrounding whitespace."""
        return self._name

    def issymbol(self):
        return True

    @classmethod
    def _fromast(cls, node):
        """
        Build a Symbol from a syntax tree consisting of a single name;
        raises SyntaxError for anything else.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        raise SyntaxError('invalid syntax')

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._name)

    @classmethod
    def fromsympy(cls, expr):
        """
        Build a Symbol from a sympy.Symbol (requires sympy); raises
        TypeError for any other expression.
        """
        import sympy
        if isinstance(expr, sympy.Symbol):
            return cls(expr.name)
        else:
            raise TypeError('expr must be a sympy.Symbol instance')
388
389
def symbols(names):
    """
    Return a generator of Symbol instances.

    names -- either a string of names separated by whitespace and/or commas
        (e.g. 'x, y z'), or an iterable of names or Symbol instances.
    """
    if not isinstance(names, str):
        return (Symbol(name) for name in names)
    tokens = names.replace(',', ' ').split()
    return (Symbol(token) for token in tokens)
394
395
class Constant(Expression):
    """
    A linear expression without any symbol, i.e. a rational number.  The
    value is stored as a fractions.Fraction.
    """

    def __new__(cls, numerator=0, denominator=None):
        """
        Create the constant numerator / denominator.

        Each argument may be a Constant, which is unwrapped first, or any
        value fractions.Fraction accepts (e.g. a pair of integers or
        Fractions, or a single number or numeric string).

        Raises whatever Fraction raises for invalid arguments, e.g.
        ZeroDivisionError for a zero denominator.
        """
        # Unwrap Constant arguments so that Fraction only sees plain numbers;
        # this also lets e.g. Constant(Constant(1), Constant(2)) build 1/2.
        if isinstance(numerator, Constant):
            numerator = numerator.constant
        if isinstance(denominator, Constant):
            denominator = denominator.constant
        # Allocate through object.__new__ directly; going through object()
        # would first build a throwaway object.
        self = object.__new__(cls)
        self._constant = Fraction(numerator, denominator)
        self._coefficients = OrderedDict()
        self._symbols = ()
        self._dimension = 0
        # Same hash as the wrapped number, consistent with __eq__, which
        # compares a Constant equal to the rational it wraps.
        self._hash = hash(self._constant)
        return self

    def isconstant(self):
        return True

    def __bool__(self):
        """True unless the constant is zero."""
        return self.constant != 0

    @classmethod
    def fromstring(cls, string):
        """
        Parse a number such as '3' or '-1/2' (any string Fraction accepts);
        raises TypeError if *string* is not a str.
        """
        if isinstance(string, str):
            return Constant(Fraction(string))
        else:
            raise TypeError('string must be a string instance')

    def __repr__(self):
        # Constant(n) for integers, Constant(p, q) for proper fractions.
        if self.constant.denominator == 1:
            return '{}({!r})'.format(self.__class__.__name__,
                self.constant.numerator)
        else:
            return '{}({!r}, {!r})'.format(self.__class__.__name__,
                self.constant.numerator, self.constant.denominator)

    @classmethod
    def fromsympy(cls, expr):
        """
        Build a Constant from a sympy.Rational or a Python rational number
        (requires sympy); raises TypeError for anything else.
        """
        import sympy
        if isinstance(expr, sympy.Rational):
            return cls(expr.p, expr.q)
        elif isinstance(expr, numbers.Rational):
            return cls(expr)
        else:
            raise TypeError('expr must be a sympy.Rational instance')
438 else:
439 raise TypeError('expr must be a sympy.Rational instance')