b330045f8d6b86fabedb243311990b19ca03ef20
6 from collections
import OrderedDict
7 from fractions
import Fraction
, gcd
12 'Symbol', 'symbols', 'symbolname', 'symbolnames',
17 def _polymorphic(func
):
18 @functools.wraps(func
)
19 def wrapper(left
, right
):
20 if isinstance(right
, Expression
):
21 return func(left
, right
)
22 elif isinstance(right
, numbers
.Rational
):
23 right
= Constant(right
)
24 return func(left
, right
)
31 This class implements linear expressions.
42 def __new__(cls
, coefficients
=None, constant
=0):
43 if isinstance(coefficients
, str):
45 raise TypeError('too many arguments')
46 return cls
.fromstring(coefficients
)
47 if isinstance(coefficients
, dict):
48 coefficients
= coefficients
.items()
49 if coefficients
is None:
50 return Constant(constant
)
51 coefficients
= [(symbol
, coefficient
)
52 for symbol
, coefficient
in coefficients
if coefficient
!= 0]
53 if len(coefficients
) == 0:
54 return Constant(constant
)
55 elif len(coefficients
) == 1 and constant
== 0:
56 symbol
, coefficient
= coefficients
[0]
59 self
= object().__new
__(cls
)
60 self
._coefficients
= {}
61 for symbol
, coefficient
in coefficients
:
62 symbol
= symbolname(symbol
)
63 if isinstance(coefficient
, Constant
):
64 coefficient
= coefficient
.constant
65 if not isinstance(coefficient
, numbers
.Rational
):
66 raise TypeError('coefficients must be rational numbers '
67 'or Constant instances')
68 self
._coefficients
[symbol
] = coefficient
69 self
._coefficients
= OrderedDict(sorted(self
._coefficients
.items()))
70 if isinstance(constant
, Constant
):
71 constant
= constant
.constant
72 if not isinstance(constant
, numbers
.Rational
):
73 raise TypeError('constant must be a rational number '
74 'or a Constant instance')
75 self
._constant
= constant
76 self
._symbols
= tuple(self
._coefficients
)
77 self
._dimension
= len(self
._symbols
)
78 self
._hash
= hash((tuple(self
._coefficients
.items()), self
._constant
))
81 def coefficient(self
, symbol
):
82 symbol
= symbolname(symbol
)
84 return self
._coefficients
[symbol
]
88 __getitem__
= coefficient
def coefficients(self):
    """Iterate over the (symbol, coefficient) pairs of the expression.

    Yields the items of the internal ``_coefficients`` mapping, in that
    mapping's own iteration order.
    """
    yield from self._coefficients.items()
103 return self
._dimension
108 def isconstant(self
):
115 for symbol
in self
.symbols
:
116 yield self
.coefficient(symbol
)
129 def __add__(self
, other
):
130 coefficients
= dict(self
.coefficients())
131 for symbol
, coefficient
in other
.coefficients():
132 if symbol
in coefficients
:
133 coefficients
[symbol
] += coefficient
135 coefficients
[symbol
] = coefficient
136 constant
= self
.constant
+ other
.constant
137 return Expression(coefficients
, constant
)
142 def __sub__(self
, other
):
143 coefficients
= dict(self
.coefficients())
144 for symbol
, coefficient
in other
.coefficients():
145 if symbol
in coefficients
:
146 coefficients
[symbol
] -= coefficient
148 coefficients
[symbol
] = -coefficient
149 constant
= self
.constant
- other
.constant
150 return Expression(coefficients
, constant
)
152 def __rsub__(self
, other
):
153 return -(self
- other
)
def __mul__(self, other):
    """Multiply the expression by a constant expression.

    When ``other.isconstant()`` is true, every coefficient and the
    constant term of ``self`` are scaled by ``other.constant`` and a new
    ``Expression`` is returned.

    Raises:
        ValueError: if ``other`` is a non-constant ``Expression`` and
            ``self`` is not constant either (the product would not be
            linear).

    Returns ``NotImplemented`` in the remaining case, so that Python can
    fall back to the reflected operation of ``other``.
    """
    if other.isconstant():
        factor = other.constant
        coefficients = {symbol: coefficient * factor
                        for symbol, coefficient in self.coefficients()}
        constant = self.constant * factor
        return Expression(coefficients, constant)
    if isinstance(other, Expression) and not self.isconstant():
        raise ValueError('non-linear expression: '
                         '{} * {}'.format(self._parenstr(),
                                          other._parenstr()))
    return NotImplemented
def __truediv__(self, other):
    """Divide the expression by a constant expression.

    When ``other.isconstant()`` is true, every coefficient and the
    constant term of ``self`` are divided by ``other.constant`` using
    exact ``Fraction`` arithmetic and a new ``Expression`` is returned.

    Raises:
        ValueError: if ``other`` is a non-constant ``Expression`` (the
            quotient would not be linear).
        ZeroDivisionError: raised by ``Fraction`` when ``other.constant``
            is zero.

    Returns ``NotImplemented`` when ``other`` is neither constant nor an
    ``Expression``.
    """
    if other.isconstant():
        divisor = other.constant
        coefficients = {symbol: Fraction(coefficient, divisor)
                        for symbol, coefficient in self.coefficients()}
        constant = Fraction(self.constant, divisor)
        return Expression(coefficients, constant)
    if isinstance(other, Expression):
        raise ValueError('non-linear expression: '
                         '{} / {}'.format(self._parenstr(),
                                          other._parenstr()))
    return NotImplemented
184 def __rtruediv__(self
, other
):
185 if isinstance(other
, self
):
186 if self
.isconstant():
187 constant
= Fraction(other
, self
.constant
)
188 return Expression(constant
=constant
)
190 raise ValueError('non-linear expression: '
191 '{} / {}'.format(other
._parenstr
(), self
._parenstr
()))
192 return NotImplemented
195 def __eq__(self
, other
):
197 # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
198 return isinstance(other
, Expression
) and \
199 self
._coefficients
== other
._coefficients
and \
200 self
.constant
== other
.constant
def __le__(self, other):
    """Return ``Le(self, other)`` for the comparison ``self <= other``.

    ``Le`` is imported from the sibling ``polyhedra`` module inside the
    method body rather than at module level.
    """
    from .polyhedra import Le
    return Le(self, other)
def __lt__(self, other):
    """Return ``Lt(self, other)`` for the comparison ``self < other``.

    ``Lt`` is imported from the sibling ``polyhedra`` module inside the
    method body rather than at module level.
    """
    from .polyhedra import Lt
    return Lt(self, other)
def __ge__(self, other):
    """Return ``Ge(self, other)`` for the comparison ``self >= other``.

    ``Ge`` is imported from the sibling ``polyhedra`` module inside the
    method body rather than at module level.
    """
    from .polyhedra import Ge
    return Ge(self, other)
def __gt__(self, other):
    """Return ``Gt(self, other)`` for the comparison ``self > other``.

    ``Gt`` is imported from the sibling ``polyhedra`` module inside the
    method body rather than at module level.
    """
    from .polyhedra import Gt
    return Gt(self, other)
223 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
224 [value
.denominator
for value
in self
.values()])
228 def _fromast(cls
, node
):
229 if isinstance(node
, ast
.Module
) and len(node
.body
) == 1:
230 return cls
._fromast
(node
.body
[0])
231 elif isinstance(node
, ast
.Expr
):
232 return cls
._fromast
(node
.value
)
233 elif isinstance(node
, ast
.Name
):
234 return Symbol(node
.id)
235 elif isinstance(node
, ast
.Num
):
236 return Constant(node
.n
)
237 elif isinstance(node
, ast
.UnaryOp
) and isinstance(node
.op
, ast
.USub
):
238 return -cls
._fromast
(node
.operand
)
239 elif isinstance(node
, ast
.BinOp
):
240 left
= cls
._fromast
(node
.left
)
241 right
= cls
._fromast
(node
.right
)
242 if isinstance(node
.op
, ast
.Add
):
244 elif isinstance(node
.op
, ast
.Sub
):
246 elif isinstance(node
.op
, ast
.Mult
):
248 elif isinstance(node
.op
, ast
.Div
):
250 raise SyntaxError('invalid syntax')
252 def subs(self
, symbol
, expression
=None):
253 if expression
is None:
254 if isinstance(symbol
, dict):
255 symbol
= symbol
.items()
256 substitutions
= symbol
258 substitutions
= [(symbol
, expression
)]
260 for symbol
, expression
in substitutions
:
261 symbol
= symbolname(symbol
)
262 result
= result
._subs
(symbol
, expression
)
265 def _subs(self
, symbol
, expression
):
266 coefficients
= {name
: coefficient
267 for name
, coefficient
in self
.coefficients()
269 constant
= self
.constant
270 coefficient
= self
.coefficient(symbol
)
271 result
= Expression(coefficients
, self
.constant
)
272 result
+= coefficient
* expression
275 _RE_NUM_VAR
= re
.compile(r
'(\d+|\))\s*([^\W\d_]\w*|\()')
def fromstring(cls, string):
    """Build an expression from its textual form.

    The string is first rewritten with ``cls._RE_NUM_VAR`` so that an
    implicit product becomes an explicit one, then parsed with
    ``ast.parse`` and handed to ``cls._fromast``.

    Args:
        string: the source text of the expression, e.g. ``'5x + 1'``.

    Returns:
        Whatever ``cls._fromast`` returns for the parsed tree.

    Raises:
        SyntaxError: propagated from ``ast.parse`` when the rewritten
            string is not valid Python syntax.
    """
    # add implicit multiplication operators, e.g. '5x' -> '5*x'
    string = cls._RE_NUM_VAR.sub(r'\1*\2', string)
    # NOTE: the second positional parameter of ast.parse() is the file
    # name, not the mode, so the text is parsed in the default 'exec'
    # mode and the tree handed to _fromast is an ast.Module.
    tree = ast.parse(string, 'eval')
    return cls._fromast(tree)
287 for symbol
in self
.symbols
:
288 coefficient
= self
.coefficient(symbol
)
293 string
+= ' + {}'.format(symbol
)
294 elif coefficient
== -1:
296 string
+= '-{}'.format(symbol
)
298 string
+= ' - {}'.format(symbol
)
301 string
+= '{}*{}'.format(coefficient
, symbol
)
302 elif coefficient
> 0:
303 string
+= ' + {}*{}'.format(coefficient
, symbol
)
305 assert coefficient
< 0
307 string
+= ' - {}*{}'.format(coefficient
, symbol
)
309 constant
= self
.constant
310 if constant
!= 0 and i
== 0:
311 string
+= '{}'.format(constant
)
313 string
+= ' + {}'.format(constant
)
316 string
+= ' - {}'.format(constant
)
321 def _parenstr(self
, always
=False):
323 if not always
and (self
.isconstant() or self
.issymbol()):
326 return '({})'.format(string
)
329 return '{}({!r})'.format(self
.__class
__.__name
__, str(self
))
332 def fromsympy(cls
, expr
):
336 for symbol
, coefficient
in expr
.as_coefficients_dict().items():
337 coefficient
= Fraction(coefficient
.p
, coefficient
.q
)
338 if symbol
== sympy
.S
.One
:
339 constant
= coefficient
340 elif isinstance(symbol
, sympy
.Symbol
):
342 coefficients
[symbol
] = coefficient
344 raise ValueError('non-linear expression: {!r}'.format(expr
))
345 return cls(coefficients
, constant
)
350 for symbol
, coefficient
in self
.coefficients():
351 term
= coefficient
* sympy
.Symbol(symbol
)
353 expr
+= self
.constant
357 class Symbol(Expression
):
364 def __new__(cls
, name
):
365 name
= symbolname(name
)
366 self
= object().__new
__(cls
)
368 self
._hash
= hash(self
._name
)
378 def coefficient(self
, symbol
):
379 symbol
= symbolname(symbol
)
380 if symbol
== self
.name
:
385 def coefficients(self
):
def __eq__(self, other):
    """Return True when *other* is a ``Symbol`` with the same name.

    Any object that is not a ``Symbol`` instance compares unequal.
    """
    return isinstance(other, Symbol) and self.name == other.name
def _fromast(cls, node):
    """Convert a parsed syntax tree into a ``Symbol``.

    A module holding exactly one statement and an expression statement
    are unwrapped recursively; a bare name node becomes ``Symbol(node.id)``.

    Raises:
        SyntaxError: for any other kind of node.
    """
    if isinstance(node, ast.Module) and len(node.body) == 1:
        # unwrap the single top-level statement
        return cls._fromast(node.body[0])
    elif isinstance(node, ast.Expr):
        # unwrap the expression statement
        return cls._fromast(node.value)
    elif isinstance(node, ast.Name):
        return Symbol(node.id)
    raise SyntaxError('invalid syntax')
417 return '{}({!r})'.format(self
.__class
__.__name
__, self
._name
)
420 def fromsympy(cls
, expr
):
422 if isinstance(expr
, sympy
.Symbol
):
423 return cls(expr
.name
)
425 raise TypeError('expr must be a sympy.Symbol instance')
429 if isinstance(names
, str):
430 names
= names
.replace(',', ' ').split()
431 return (Symbol(name
) for name
in names
)
433 def symbolname(symbol
):
434 if isinstance(symbol
, str):
435 return symbol
.strip()
436 elif isinstance(symbol
, Symbol
):
439 raise TypeError('symbol must be a string or a Symbol instance')
def symbolnames(symbols):
    """Return the names designated by *symbols*.

    Args:
        symbols: either a string of names separated by commas and/or
            whitespace, or an iterable of items accepted by
            ``symbolname``.

    Returns:
        A list of names when *symbols* is a string; otherwise a lazy
        generator applying ``symbolname`` to each item, so any error from
        ``symbolname`` only surfaces when the generator is consumed.
    """
    if isinstance(symbols, str):
        # commas are treated like whitespace separators
        return symbols.replace(',', ' ').split()
    return (symbolname(symbol) for symbol in symbols)
447 class Constant(Expression
):
454 def __new__(cls
, numerator
=0, denominator
=None):
455 self
= object().__new
__(cls
)
456 if denominator
is None and isinstance(numerator
, Constant
):
457 self
._constant
= numerator
.constant
459 self
._constant
= Fraction(numerator
, denominator
)
460 self
._hash
= hash(self
._constant
)
466 def coefficient(self
, symbol
):
467 symbol
= symbolname(symbol
)
470 def coefficients(self
):
481 def isconstant(self
):
def __eq__(self, other):
    """Return True when *other* is a ``Constant`` with the same value.

    Any object that is not a ``Constant`` instance compares unequal.
    """
    return isinstance(other, Constant) and self.constant == other.constant
489 return self
.constant
!= 0
492 def fromstring(cls
, string
):
493 if isinstance(string
, str):
494 return Constant(Fraction(string
))
496 raise TypeError('string must be a string instance')
499 if self
.constant
.denominator
== 1:
500 return '{}({!r})'.format(self
.__class
__.__name
__,
501 self
.constant
.numerator
)
503 return '{}({!r}, {!r})'.format(self
.__class
__.__name
__,
504 self
.constant
.numerator
, self
.constant
.denominator
)
507 def fromsympy(cls
, expr
):
509 if isinstance(expr
, sympy
.Rational
):
510 return cls(expr
.p
, expr
.q
)
511 elif isinstance(expr
, numbers
.Rational
):
514 raise TypeError('expr must be a sympy.Rational instance')