Remove empty lines at top of files
[linpy.git] / pypol / linexprs.py
1 import ast
2 import functools
3 import numbers
4 import re
5
6 from collections import OrderedDict, defaultdict, Mapping
7 from fractions import Fraction, gcd
8
9
# Names exported by a star-import of this module.
__all__ = [
    'Expression',
    'Symbol',
    'Dummy',
    'symbols',
    'Rational',
]
15
16
def _polymorphic(func):
    """
    Decorate a binary method so that its right operand is an Expression.

    An Expression operand is passed through unchanged, a numbers.Rational
    operand is wrapped in a Rational first, and for any other operand the
    wrapper returns NotImplemented without calling func.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, Expression):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            # plain rational number: promote it to a constant expression
            right = Rational(right)
        return func(left, right)
    return wrapper
27
28
class Expression:
    """
    This class implements linear expressions.

    A linear expression is a sum of symbols, each multiplied by a rational
    coefficient, plus a rational constant term. Construction goes through
    __new__, which returns a Symbol or a Rational when the requested
    expression reduces to one of them.
    """

    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, coefficients=None, constant=0):
        """
        Build a linear expression.

        coefficients may be a string (parsed with fromstring, in which case
        constant must be left to its default), None (the result is the
        constant alone), a mapping from Symbol to coefficient, or an
        iterable of (Symbol, coefficient) pairs. Coefficients and constant
        must be rational numbers or Rational instances.

        Raises TypeError on a non-Symbol key, a non-rational coefficient or
        constant, or when a constant is given along with a string.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return Expression.fromstring(coefficients)
        if coefficients is None:
            return Rational(constant)
        if isinstance(coefficients, Mapping):
            coefficients = coefficients.items()
        # Materialise once: the pairs are traversed twice below, which would
        # silently drop every term of a one-shot iterator.
        coefficients = list(coefficients)
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
        coefficients = [(symbol, coefficient)
            for symbol, coefficient in coefficients if coefficient != 0]
        if not coefficients:
            return Rational(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                # the expression is exactly one symbol: return it as is
                return symbol
        self = object.__new__(cls)
        self._coefficients = OrderedDict()
        # terms are stored in the order given by the symbols' sort keys
        for symbol, coefficient in sorted(coefficients,
                key=lambda item: item[0].sortkey()):
            if isinstance(coefficient, Rational):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers '
                    'or Rational instances')
            self._coefficients[symbol] = coefficient
        if isinstance(constant, Rational):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number '
                'or a Rational instance')
        self._constant = constant
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol, or 0 if it does not appear.

        Raises TypeError if symbol is not a Symbol instance.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return self._coefficients.get(symbol, 0)

    def __getitem__(self, symbol):
        """Return self.coefficient(symbol)."""
        # Dispatch through the instance so that subclasses overriding
        # coefficient() (which do not set _coefficients) are honoured.
        return self.coefficient(symbol)

    def coefficients(self):
        """Yield the (symbol, coefficient) pairs of the expression."""
        yield from self._coefficients.items()

    @property
    def constant(self):
        """The constant term of the expression."""
        return self._constant

    @property
    def symbols(self):
        """The tuple of symbols of the expression, in storage order."""
        return self._symbols

    @property
    def dimension(self):
        """The number of symbols of the expression."""
        return self._dimension

    def __hash__(self):
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        """Return True if the expression is a constant (see Rational)."""
        return False

    def issymbol(self):
        """Return True if the expression is a symbol (see Symbol)."""
        return False

    def values(self):
        """Yield every coefficient, then the constant term."""
        yield from self._coefficients.values()
        yield self.constant

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        """Return the sum of two linear expressions."""
        coefficients = defaultdict(Rational, self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] += coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        """Return the difference of two linear expressions."""
        coefficients = defaultdict(Rational, self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] -= coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    def __rsub__(self, other):
        # other - self == -(self - other)
        return -(self - other)

    @_polymorphic
    def __mul__(self, other):
        """
        Return the product of the expression by a constant.

        Raises ValueError if neither operand is constant, since the product
        would not be linear. Returns NotImplemented when only self is
        constant, so that the reflected operation can take over.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] *= other.constant
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic
    def __truediv__(self, other):
        """
        Return the quotient of the expression by a constant.

        Raises ValueError if the divisor is not constant.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] = Rational(coefficients[symbol], other.constant)
            constant = Rational(self.constant, other.constant)
            return Expression(coefficients, constant)
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    @_polymorphic
    def __rtruediv__(self, other):
        """
        Return other / self, where other has been promoted to an Expression.

        Raises ValueError if self is not constant, since the quotient would
        not be linear.
        """
        if self.isconstant():
            # the divisor is constant, so the regular division applies
            return other / self
        raise ValueError('non-linear expression: '
            '{} / {}'.format(other._parenstr(), self._parenstr()))

    @_polymorphic
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        # Go through coefficients() rather than _coefficients: the slot is
        # not set on Symbol and Rational operands.
        return dict(self.coefficients()) == dict(other.coefficients()) and \
            self.constant == other.constant

    @_polymorphic
    def __le__(self, other):
        from .polyhedra import Le
        return Le(self, other)

    @_polymorphic
    def __lt__(self, other):
        from .polyhedra import Lt
        return Lt(self, other)

    @_polymorphic
    def __ge__(self, other):
        from .polyhedra import Ge
        return Ge(self, other)

    @_polymorphic
    def __gt__(self, other):
        from .polyhedra import Gt
        return Gt(self, other)

    def scaleint(self):
        """
        Return the expression multiplied by the least common multiple of
        the denominators of its values.
        """
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcm

    def subs(self, symbol, expression=None):
        """
        Substitute symbols by expressions and return the result.

        Either call subs(symbol, expression), or pass as single argument a
        mapping or an iterable of (symbol, expression) pairs. Substitutions
        are applied one after the other, in iteration order.
        """
        if expression is None:
            if isinstance(symbol, Mapping):
                symbol = symbol.items()
            substitutions = symbol
        else:
            substitutions = [(symbol, expression)]
        result = self
        for symbol, expression in substitutions:
            # drop the substituted symbol, then add back its replacement
            # scaled by the coefficient it had
            coefficients = [(othersymbol, coefficient)
                for othersymbol, coefficient in result.coefficients()
                if othersymbol != symbol]
            coefficient = result.coefficient(symbol)
            constant = result.constant
            result = Expression(coefficients, constant) + coefficient*expression
        return result

    @classmethod
    def _fromast(cls, node):
        """
        Convert a node of a Python abstract syntax tree to an expression.

        Names, numbers, unary minus and the +, -, *, / binary operators are
        supported; anything else raises SyntaxError.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, ast.Num):
            return Rational(node.n)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """Parse string and return the corresponding expression."""
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = Expression._RE_NUM_VAR.sub(r'\1*\2', string)
        # NOTE: the second positional argument of ast.parse is the file
        # name, not the mode; the source is parsed in the default 'exec'
        # mode, hence the Module and Expr cases in _fromast.
        tree = ast.parse(string, 'eval')
        return cls._fromast(tree)

    def __repr__(self):
        parts = []
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                parts.append('' if i == 0 else ' + ')
                parts.append('{!r}'.format(symbol))
            elif coefficient == -1:
                parts.append('-' if i == 0 else ' - ')
                parts.append('{!r}'.format(symbol))
            else:
                if i == 0:
                    parts.append('{}*{!r}'.format(coefficient, symbol))
                elif coefficient > 0:
                    parts.append(' + {}*{!r}'.format(coefficient, symbol))
                else:
                    parts.append(' - {}*{!r}'.format(-coefficient, symbol))
        string = ''.join(parts)
        constant = self.constant
        if len(string) == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    def _parenstr(self, always=False):
        """
        Return the string form of the expression, within parentheses
        unless it is a constant or a symbol and always is false.
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy expression to an Expression.

        Raises ValueError if a term is neither a constant nor a multiple of
        a sympy symbol.
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return Expression(coefficients, constant)

    def tosympy(self):
        """Convert the expression to a sympy expression."""
        import sympy
        expr = 0
        for symbol, coefficient in self.coefficients():
            term = coefficient * sympy.Symbol(symbol.name)
            expr += term
        expr += self.constant
        return expr
326
327
class Symbol(Expression):
    """
    A linear expression reduced to a single named variable, with
    coefficient 1 and no constant term.
    """

    __slots__ = (
        '_name',
    )

    def __new__(cls, name):
        """
        Create the symbol called name; surrounding whitespace is stripped.

        Raises TypeError if name is not a string.
        """
        if isinstance(name, str):
            instance = object.__new__(cls)
            instance._name = name.strip()
            return instance
        raise TypeError('name must be a string')

    @property
    def name(self):
        """The name of the symbol."""
        return self._name

    def __hash__(self):
        return hash(self.sortkey())

    def coefficient(self, symbol):
        """
        Return 1 if symbol equals this symbol, 0 otherwise.

        Raises TypeError if symbol is not a Symbol instance.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return 1 if symbol == self else 0

    def coefficients(self):
        """Yield the single pair (self, 1)."""
        yield self, 1

    @property
    def constant(self):
        """The constant term, always 0 for a symbol."""
        return 0

    @property
    def symbols(self):
        """The one-element tuple holding this symbol."""
        return (self,)

    @property
    def dimension(self):
        """The number of symbols, always 1."""
        return 1

    def sortkey(self):
        """Return the tuple used to order symbols within expressions."""
        return (self.name,)

    def issymbol(self):
        return True

    def values(self):
        yield 1

    def __eq__(self, other):
        # a Dummy is never equal to a regular symbol, whatever its name
        if isinstance(other, Dummy) or not isinstance(other, Symbol):
            return False
        return self.name == other.name

    def asdummy(self):
        """Return a new Dummy carrying the name of this symbol."""
        return Dummy(self.name)

    @classmethod
    def _fromast(cls, node):
        """
        Convert a node of a Python abstract syntax tree to a symbol.

        Only a bare name is accepted, possibly wrapped in a Module or Expr
        node; anything else raises SyntaxError.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        if isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        if isinstance(node, ast.Name):
            return Symbol(node.id)
        raise SyntaxError('invalid syntax')

    def __repr__(self):
        return self.name

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy symbol to a symbol of the same name.

        Raises TypeError if expr is not a sympy.Symbol instance.
        """
        import sympy
        if not isinstance(expr, sympy.Symbol):
            raise TypeError('expr must be a sympy.Symbol instance')
        return cls(expr.name)
407
408
class Dummy(Symbol):
    """
    A symbol identified by a creation index rather than by its name: two
    dummies are equal only if they carry the same index.
    """

    # '_name' is already declared as a slot by Symbol; redeclaring it here
    # would shadow the base slot and waste space in every instance.
    __slots__ = (
        '_index',
    )

    # number of dummies created so far, used as the next index
    _count = 0

    def __new__(cls, name=None):
        """
        Create a new dummy symbol.

        When name is omitted, it is generated from the current counter
        value. Raises TypeError if name is given and is not a string.
        """
        if name is None:
            name = 'Dummy_{}'.format(Dummy._count)
        elif not isinstance(name, str):
            # same check and message as Symbol.__new__
            raise TypeError('name must be a string')
        self = object.__new__(cls)
        self._name = name.strip()
        self._index = Dummy._count
        Dummy._count += 1
        return self

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """Return the (name, index) tuple used to order symbols."""
        return self._name, self._index

    def __eq__(self, other):
        return isinstance(other, Dummy) and self._index == other._index

    def __repr__(self):
        return '_{}'.format(self.name)
438
439
def symbols(names):
    """
    Return a tuple of Symbol instances, one per name.

    names is either an iterable of names, or a single string in which the
    names are separated by commas and/or whitespace.
    """
    if isinstance(names, str):
        # commas count as separators, like whitespace
        names = names.replace(',', ' ').split()
    return tuple(map(Symbol, names))
444
445
class Rational(Expression):
    """
    A linear expression reduced to a constant rational number, stored as
    a fractions.Fraction.
    """

    # '_constant' is already declared as a slot by Expression; redeclaring
    # it here would shadow the base slot and waste space in every instance.
    __slots__ = ()

    def __new__(cls, numerator=0, denominator=None):
        """
        Create the constant numerator / denominator.

        Both arguments may be Rational instances, in which case their
        constant is used; they are then handed to fractions.Fraction, which
        defines the accepted values and the exceptions raised.
        """
        self = object.__new__(cls)
        # unwrap Rational operands: Fraction does not know about them
        if isinstance(numerator, Rational):
            numerator = numerator.constant
        if isinstance(denominator, Rational):
            denominator = denominator.constant
        self._constant = Fraction(numerator, denominator)
        return self

    def __hash__(self):
        return hash(self.constant)

    def coefficient(self, symbol):
        """
        Return 0, as a constant holds no symbol.

        Raises TypeError if symbol is not a Symbol instance.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return 0

    def coefficients(self):
        """Yield nothing, as a constant holds no symbol."""
        yield from ()

    @property
    def symbols(self):
        """The symbols of the expression, always the empty tuple."""
        return ()

    @property
    def dimension(self):
        """The number of symbols, always 0."""
        return 0

    def isconstant(self):
        return True

    def values(self):
        yield self._constant

    @_polymorphic
    def __eq__(self, other):
        return isinstance(other, Rational) and self.constant == other.constant

    def __bool__(self):
        return self.constant != 0

    @classmethod
    def fromstring(cls, string):
        """
        Parse string with fractions.Fraction and return the constant.

        Raises TypeError if string is not a str instance.
        """
        if not isinstance(string, str):
            raise TypeError('string must be a string instance')
        return Rational(Fraction(string))

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy rational, or a plain rational number, to a Rational.

        Raises TypeError for any other value.
        """
        import sympy
        if isinstance(expr, sympy.Rational):
            return Rational(expr.p, expr.q)
        elif isinstance(expr, numbers.Rational):
            return Rational(expr)
        else:
            raise TypeError('expr must be a sympy.Rational instance')