Cleaner and faster linear expressions
[linpy.git] / pypol / linexprs.py
1 import ast
2 import functools
3 import numbers
4 import re
5
6 from collections import OrderedDict, defaultdict
7 from fractions import Fraction, gcd
8
9
# Public API of this module: the names bound by ``from ... import *``.
__all__ = ['Expression', 'Symbol', 'symbols', 'Constant']
15
16
def _polymorphic(func):
    """
    Adapt a binary Expression operator to accept plain rational operands.

    The wrapped operator receives the right operand unchanged when it is
    already an Expression, and wrapped in a Constant when it is a
    numbers.Rational.  For any other type the wrapper returns NotImplemented,
    so Python can fall back to the reflected operation.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, Expression):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            right = Constant(right)
        return func(left, right)
    return wrapper
27
28
class Expression:
    """
    This class implements linear expressions.

    A linear expression is a sum of rational multiples of Symbol instances
    plus a rational constant, e.g. ``2*x - y/3 + 5``.  Construction
    normalizes the result: zero coefficients are dropped, symbols are kept
    sorted by name, an expression without symbols is returned as a Constant,
    and a single symbol with coefficient 1 and constant 0 is returned as that
    Symbol.
    """

    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a linear expression.

        coefficients -- a string (parsed with fromstring; constant must then
            be falsy), None (the result is Constant(constant)), a dict
            mapping Symbol to coefficient, or an iterable of
            (Symbol, coefficient) pairs
        constant -- a rational number or a Constant instance

        Coefficients may be rational numbers or Constant instances.
        Raises TypeError for a string combined with a constant, non-Symbol
        keys, or non-rational coefficients/constant.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return Expression.fromstring(coefficients)
        if coefficients is None:
            return Constant(constant)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        # Materialize the pairs once: they are traversed twice below, and a
        # one-shot iterator (e.g. a generator) would otherwise be exhausted
        # by the validation loop, silently dropping every term.
        coefficients = list(coefficients)
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
        coefficients = [(symbol, coefficient)
                        for symbol, coefficient in coefficients
                        if coefficient != 0]
        if not coefficients:
            return Constant(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return symbol
        self = object.__new__(cls)
        # Sorting by name gives every expression a canonical term order,
        # which the order-sensitive __hash__ relies on.
        self._coefficients = OrderedDict()
        for symbol, coefficient in sorted(coefficients,
                                          key=lambda item: item[0].name):
            if isinstance(coefficient, Constant):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers '
                                'or Constant instances')
            self._coefficients[symbol] = coefficient
        if isinstance(constant, Constant):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number '
                            'or a Constant instance')
        self._constant = constant
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol, or 0 if it does not occur.

        Raises TypeError if symbol is not a Symbol instance.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Iterate over the (symbol, coefficient) pairs, sorted by name."""
        yield from self._coefficients.items()

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    @property
    def symbols(self):
        """Tuple of the symbols with a nonzero coefficient, sorted by name."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols in the expression."""
        return self._dimension

    def __hash__(self):
        # Consistent with __eq__: equal expressions have the same terms in
        # the same (name-sorted) order and the same constant.
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        return False

    def issymbol(self):
        return False

    def values(self):
        """Yield every coefficient, then the constant term."""
        yield from self._coefficients.values()
        yield self.constant

    def __bool__(self):
        # A normalized Expression always has at least one nonzero term.
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) + coefficient
        return Expression(coefficients, self.constant + other.constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) - coefficient
        return Expression(coefficients, self.constant - other.constant)

    @_polymorphic
    def __rsub__(self, other):
        # Decorated so that unsupported operand types yield NotImplemented
        # instead of evaluating -NotImplemented.
        return -(self - other)

    @_polymorphic
    def __mul__(self, other):
        """
        Multiply by a constant.

        Raises ValueError when both operands are non-constant (the product
        would not be linear).
        """
        if other.isconstant():
            factor = other.constant
            coefficients = {symbol: coefficient * factor
                            for symbol, coefficient in self.coefficients()}
            return Expression(coefficients, self.constant * factor)
        if not self.isconstant():
            raise ValueError('non-linear expression: '
                             '{} * {}'.format(self._parenstr(),
                                              other._parenstr()))
        # self is constant and other is not: let Python try
        # other.__rmul__(self), which takes the constant branch above.
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic
    def __truediv__(self, other):
        """
        Divide by a constant.

        Raises ZeroDivisionError when dividing by zero and ValueError when
        the divisor is not constant (the quotient would not be linear).
        """
        if other.isconstant():
            divisor = other.constant
            coefficients = {symbol: Fraction(coefficient, divisor)
                            for symbol, coefficient in self.coefficients()}
            return Expression(coefficients, Fraction(self.constant, divisor))
        raise ValueError('non-linear expression: '
                         '{} / {}'.format(self._parenstr(), other._parenstr()))

    @_polymorphic
    def __rtruediv__(self, other):
        # Reached for ``number / expression``: other has been wrapped in a
        # Constant, so delegate to its __truediv__, which divides when self
        # is constant and raises ValueError (non-linear) otherwise.
        return other / self

    @_polymorphic
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        # Compare through coefficients(): Symbol and Constant instances never
        # set the _coefficients slot, so reading it directly raised
        # AttributeError for e.g. ``expr == 0``.
        return isinstance(other, Expression) and \
            self.constant == other.constant and \
            dict(self.coefficients()) == dict(other.coefficients())

    # Ordering operators build constraint objects from the sibling
    # polyhedra module; the import is done lazily, presumably to avoid a
    # circular import.

    @_polymorphic
    def __le__(self, other):
        from .polyhedra import Le
        return Le(self, other)

    @_polymorphic
    def __lt__(self, other):
        from .polyhedra import Lt
        return Lt(self, other)

    @_polymorphic
    def __ge__(self, other):
        from .polyhedra import Ge
        return Ge(self, other)

    @_polymorphic
    def __gt__(self, other):
        from .polyhedra import Gt
        return Gt(self, other)

    def scaleint(self):
        """
        Return the expression multiplied by the least common multiple of the
        denominators of its coefficients and constant, so that all of them
        become integers.
        """
        # math.gcd rather than fractions.gcd, which was removed in Python 3.9.
        import math
        lcm = functools.reduce(lambda a, b: a * b // math.gcd(a, b),
                               [value.denominator for value in self.values()])
        return self * lcm

    def subs(self, symbol, expression=None):
        """
        Substitute expressions (or numbers) for symbols.

        Call as subs(symbol, expression), or with a single argument that is
        a dict or an iterable of (symbol, expression) pairs.  Substitutions
        are applied one after the other, in iteration order.
        """
        if expression is None:
            if isinstance(symbol, dict):
                symbol = symbol.items()
            substitutions = symbol
        else:
            substitutions = [(symbol, expression)]
        result = self
        for symbol, expression in substitutions:
            coefficients = [(othersymbol, coefficient)
                            for othersymbol, coefficient in result.coefficients()
                            if othersymbol != symbol]
            coefficient = result.coefficient(symbol)
            constant = result.constant
            result = Expression(coefficients, constant) + coefficient*expression
        return result

    @classmethod
    def _fromast(cls, node):
        """
        Build an expression from a Python AST node.

        Supported: a module holding a single expression statement, names,
        int/float literals, unary + and -, and binary +, -, *, /.
        Anything else raises SyntaxError.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, getattr(ast, 'Constant', ())) and \
                isinstance(node.value, (int, float)) and \
                not isinstance(node.value, bool):
            # Python >= 3.8 parses numeric literals as ast.Constant.
            return Constant(node.value)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.UAdd):
            return +cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        elif isinstance(node, getattr(ast, 'Num', ())):
            # Numeric literals on Python < 3.8; checked last because ast.Num
            # is deprecated since 3.8 and removed in 3.14.
            return Constant(node.n)
        raise SyntaxError('invalid syntax')

    # A number or closing parenthesis directly followed (up to whitespace)
    # by an identifier or opening parenthesis, e.g. '5x', '2 (x+1)', ')('.
    # The lookbehind keeps digits inside identifiers such as 'x1y' from
    # being split into 'x1*y'.
    _RE_NUM_VAR = re.compile(r'((?<!\w)\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a linear expression from a string such as '2x - 3y/4 + 1'.

        Raises SyntaxError for unsupported syntax and ValueError for
        non-linear products or quotients.
        """
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = Expression._RE_NUM_VAR.sub(r'\1*\2', string)
        # ast.parse's second positional parameter is the filename, not the
        # mode; the default 'exec' mode yields the Module/Expr nodes that
        # _fromast expects.
        tree = ast.parse(string)
        return cls._fromast(tree)

    def __repr__(self):
        terms = list(self.coefficients())
        string = ''
        for i, (symbol, coefficient) in enumerate(terms):
            if i == 0:
                # The leading term carries its own sign.
                if coefficient == 1:
                    string += symbol.name
                elif coefficient == -1:
                    string += '-{}'.format(symbol.name)
                else:
                    string += '{}*{}'.format(coefficient, symbol.name)
            else:
                sign = '+' if coefficient > 0 else '-'
                magnitude = abs(coefficient)
                if magnitude == 1:
                    string += ' {} {}'.format(sign, symbol.name)
                else:
                    string += ' {} {}*{}'.format(sign, magnitude, symbol.name)
        constant = self.constant
        if not terms:
            if constant != 0:
                string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string or '0'

    def _parenstr(self, always=False):
        """Return str(self), parenthesized unless it is a constant or a
        symbol (or always, if always is true)."""
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy expression into a linear expression.

        Raises ValueError if expr contains a term that is neither a constant
        nor a rational multiple of a single sympy.Symbol.
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return Expression(coefficients, constant)

    def tosympy(self):
        """
        Convert to an equivalent sympy expression.

        Coefficients and the constant are converted to exact sympy.Rational
        values, so the result is a sympy object even when the expression is
        constant (previously a bare Fraction was returned in that case).
        """
        import sympy

        def rational(value):
            return sympy.Rational(value.numerator, value.denominator)

        expr = rational(self.constant)
        for symbol, coefficient in self.coefficients():
            expr += rational(coefficient) * sympy.Symbol(symbol.name)
        return expr
338
339
class Symbol(Expression):
    """
    A named variable usable in linear expressions.

    A Symbol acts as the expression ``1*name``: its only coefficient is 1
    (on itself) and its constant term is 0.  Symbols compare equal, and hash
    alike, exactly when their names are equal.
    """

    __slots__ = (
        '_name',
    )

    def __new__(cls, name):
        # The name is stored with surrounding whitespace removed.
        if isinstance(name, str):
            instance = object.__new__(cls)
            instance._name = name.strip()
            return instance
        raise TypeError('name must be a string')

    @property
    def name(self):
        """The symbol's (whitespace-stripped) name."""
        return self._name

    def __hash__(self):
        return hash(self._name)

    def coefficient(self, symbol):
        """Return 1 for this symbol, 0 for any other Symbol."""
        if isinstance(symbol, Symbol):
            return 1 if symbol == self else 0
        raise TypeError('symbol must be a Symbol instance')

    def coefficients(self):
        yield from ((self, 1),)

    @property
    def constant(self):
        return 0

    @property
    def symbols(self):
        return (self,)

    @property
    def dimension(self):
        return 1

    def issymbol(self):
        return True

    def values(self):
        yield 1

    def __eq__(self, other):
        if not isinstance(other, Symbol):
            return False
        return self.name == other.name

    @classmethod
    def _fromast(cls, node):
        """Accept only a bare name (possibly wrapped in Module/Expr nodes);
        anything else raises SyntaxError."""
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        if isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        if isinstance(node, ast.Name):
            return Symbol(node.id)
        raise SyntaxError('invalid syntax')

    @classmethod
    def fromsympy(cls, expr):
        """Convert a sympy.Symbol into a Symbol with the same name."""
        import sympy
        if not isinstance(expr, sympy.Symbol):
            raise TypeError('expr must be a sympy.Symbol instance')
        return Symbol(expr.name)
409
410
def symbols(names):
    """
    Create several Symbol instances at once.

    names is either a string of names separated by commas and/or whitespace
    (e.g. 'x, y z'), or an iterable of name strings.  Returns a tuple of
    Symbol instances in the given order.
    """
    if isinstance(names, str):
        names = ' '.join(names.split(',')).split()
    return tuple(map(Symbol, names))
415
416
class Constant(Expression):
    """
    A rational constant, viewed as a linear expression without symbols.

    The value is stored as a fractions.Fraction.
    """

    # '_constant' is already a slot of Expression.  Redeclaring it here
    # created a second, shadowing slot: wasted space in every instance, and
    # a case the Python data model documents as undefined.
    __slots__ = ()

    def __new__(cls, numerator=0, denominator=None):
        """
        Create the constant numerator/denominator.

        Both arguments accept anything fractions.Fraction accepts (with
        denominator None meaning a single-value conversion, e.g. from a
        string).  Constant instances are unwrapped first, for the numerator
        and the denominator alike.
        """
        if isinstance(numerator, Constant):
            numerator = numerator.constant
        if isinstance(denominator, Constant):
            denominator = denominator.constant
        self = object.__new__(cls)
        self._constant = Fraction(numerator, denominator)
        return self

    def __hash__(self):
        # Matches the hash of the equal plain number, e.g. hash(3).
        return hash(self.constant)

    def coefficient(self, symbol):
        """Return 0: a constant has no symbol terms."""
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return 0

    def coefficients(self):
        yield from ()

    @property
    def symbols(self):
        return ()

    @property
    def dimension(self):
        return 0

    def isconstant(self):
        return True

    def values(self):
        yield self._constant

    @_polymorphic
    def __eq__(self, other):
        return isinstance(other, Constant) and self.constant == other.constant

    def __bool__(self):
        return self.constant != 0

    @classmethod
    def fromstring(cls, string):
        """Parse a constant from a string accepted by fractions.Fraction,
        e.g. '3', '-1/2' or '0.25'."""
        if not isinstance(string, str):
            raise TypeError('string must be a string instance')
        return Constant(Fraction(string))

    @classmethod
    def fromsympy(cls, expr):
        """Convert a sympy.Rational (or a plain rational number) into a
        Constant."""
        import sympy
        if isinstance(expr, sympy.Rational):
            return Constant(expr.p, expr.q)
        elif isinstance(expr, numbers.Rational):
            return Constant(expr)
        else:
            raise TypeError('expr must be a sympy.Rational instance')