3aef3373fa0263a55fedb8da2416a4c0691edb94
6 from collections
import OrderedDict
7 from fractions
import Fraction
, gcd
17 def _polymorphic(func
):
18 @functools.wraps(func
)
19 def wrapper(left
, right
):
20 if isinstance(right
, Expression
):
21 return func(left
, right
)
22 elif isinstance(right
, numbers
.Rational
):
23 right
= Constant(right
)
24 return func(left
, right
)
31 This class implements linear expressions.
42 def __new__(cls
, coefficients
=None, constant
=0):
43 if isinstance(coefficients
, str):
45 raise TypeError('too many arguments')
46 return cls
.fromstring(coefficients
)
47 if isinstance(coefficients
, dict):
48 coefficients
= coefficients
.items()
49 if coefficients
is None:
50 return Constant(constant
)
51 coefficients
= [(symbol
, coefficient
)
52 for symbol
, coefficient
in coefficients
if coefficient
!= 0]
53 if len(coefficients
) == 0:
54 return Constant(constant
)
55 elif len(coefficients
) == 1 and constant
== 0:
56 symbol
, coefficient
= coefficients
[0]
59 self
= object().__new
__(cls
)
60 self
._coefficients
= {}
61 for symbol
, coefficient
in coefficients
:
62 if isinstance(symbol
, Symbol
):
64 elif not isinstance(symbol
, str):
65 raise TypeError('symbols must be strings or Symbol instances')
66 if isinstance(coefficient
, Constant
):
67 coefficient
= coefficient
.constant
68 if not isinstance(coefficient
, numbers
.Rational
):
69 raise TypeError('coefficients must be rational numbers '
70 'or Constant instances')
71 self
._coefficients
[symbol
] = coefficient
72 self
._coefficients
= OrderedDict(sorted(self
._coefficients
.items()))
73 if isinstance(constant
, Constant
):
74 constant
= constant
.constant
75 if not isinstance(constant
, numbers
.Rational
):
76 raise TypeError('constant must be a rational number '
77 'or a Constant instance')
78 self
._constant
= constant
79 self
._symbols
= tuple(self
._coefficients
)
80 self
._dimension
= len(self
._symbols
)
81 self
._hash
= hash((tuple(self
._coefficients
.items()), self
._constant
))
84 def coefficient(self
, symbol
):
85 if isinstance(symbol
, Symbol
):
87 elif not isinstance(symbol
, str):
88 raise TypeError('symbol must be a string or a Symbol instance')
90 return self
._coefficients
[symbol
]
94 __getitem__
= coefficient
96 def coefficients(self
):
97 yield from self
._coefficients
.items()
101 return self
._constant
109 return self
._dimension
114 def isconstant(self
):
121 for symbol
in self
.symbols
:
122 yield self
.coefficient(symbol
)
135 def __add__(self
, other
):
136 coefficients
= dict(self
.coefficients())
137 for symbol
, coefficient
in other
.coefficients():
138 if symbol
in coefficients
:
139 coefficients
[symbol
] += coefficient
141 coefficients
[symbol
] = coefficient
142 constant
= self
.constant
+ other
.constant
143 return Expression(coefficients
, constant
)
148 def __sub__(self
, other
):
149 coefficients
= dict(self
.coefficients())
150 for symbol
, coefficient
in other
.coefficients():
151 if symbol
in coefficients
:
152 coefficients
[symbol
] -= coefficient
154 coefficients
[symbol
] = -coefficient
155 constant
= self
.constant
- other
.constant
156 return Expression(coefficients
, constant
)
158 def __rsub__(self
, other
):
159 return -(self
- other
)
162 def __mul__(self
, other
):
163 if other
.isconstant():
164 coefficients
= dict(self
.coefficients())
165 for symbol
in coefficients
:
166 coefficients
[symbol
] *= other
.constant
167 constant
= self
.constant
* other
.constant
168 return Expression(coefficients
, constant
)
169 if isinstance(other
, Expression
) and not self
.isconstant():
170 raise ValueError('non-linear expression: '
171 '{} * {}'.format(self
._parenstr
(), other
._parenstr
()))
172 return NotImplemented
177 def __truediv__(self
, other
):
178 if other
.isconstant():
179 coefficients
= dict(self
.coefficients())
180 for symbol
in coefficients
:
181 coefficients
[symbol
] = \
182 Fraction(coefficients
[symbol
], other
.constant
)
183 constant
= Fraction(self
.constant
, other
.constant
)
184 return Expression(coefficients
, constant
)
185 if isinstance(other
, Expression
):
186 raise ValueError('non-linear expression: '
187 '{} / {}'.format(self
._parenstr
(), other
._parenstr
()))
188 return NotImplemented
190 def __rtruediv__(self
, other
):
191 if isinstance(other
, self
):
192 if self
.isconstant():
193 constant
= Fraction(other
, self
.constant
)
194 return Expression(constant
=constant
)
196 raise ValueError('non-linear expression: '
197 '{} / {}'.format(other
._parenstr
(), self
._parenstr
()))
198 return NotImplemented
201 def __eq__(self
, other
):
203 # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
204 return isinstance(other
, Expression
) and \
205 self
._coefficients
== other
._coefficients
and \
206 self
.constant
== other
.constant
209 def __le__(self
, other
):
210 from .polyhedra
import Le
211 return Le(self
, other
)
214 def __lt__(self
, other
):
215 from .polyhedra
import Lt
216 return Lt(self
, other
)
219 def __ge__(self
, other
):
220 from .polyhedra
import Ge
221 return Ge(self
, other
)
224 def __gt__(self
, other
):
225 from .polyhedra
import Gt
226 return Gt(self
, other
)
229 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
230 [value
.denominator
for value
in self
.values()])
234 def _fromast(cls
, node
):
235 if isinstance(node
, ast
.Module
) and len(node
.body
) == 1:
236 return cls
._fromast
(node
.body
[0])
237 elif isinstance(node
, ast
.Expr
):
238 return cls
._fromast
(node
.value
)
239 elif isinstance(node
, ast
.Name
):
240 return Symbol(node
.id)
241 elif isinstance(node
, ast
.Num
):
242 return Constant(node
.n
)
243 elif isinstance(node
, ast
.UnaryOp
) and isinstance(node
.op
, ast
.USub
):
244 return -cls
._fromast
(node
.operand
)
245 elif isinstance(node
, ast
.BinOp
):
246 left
= cls
._fromast
(node
.left
)
247 right
= cls
._fromast
(node
.right
)
248 if isinstance(node
.op
, ast
.Add
):
250 elif isinstance(node
.op
, ast
.Sub
):
252 elif isinstance(node
.op
, ast
.Mult
):
254 elif isinstance(node
.op
, ast
.Div
):
256 raise SyntaxError('invalid syntax')
258 _RE_NUM_VAR
= re
.compile(r
'(\d+|\))\s*([^\W\d_]\w*|\()')
261 def fromstring(cls
, string
):
262 # add implicit multiplication operators, e.g. '5x' -> '5*x'
263 string
= cls
._RE
_NUM
_VAR
.sub(r
'\1*\2', string
)
264 tree
= ast
.parse(string
, 'eval')
265 return cls
._fromast
(tree
)
270 for symbol
in self
.symbols
:
271 coefficient
= self
.coefficient(symbol
)
276 string
+= ' + {}'.format(symbol
)
277 elif coefficient
== -1:
279 string
+= '-{}'.format(symbol
)
281 string
+= ' - {}'.format(symbol
)
284 string
+= '{}*{}'.format(coefficient
, symbol
)
285 elif coefficient
> 0:
286 string
+= ' + {}*{}'.format(coefficient
, symbol
)
288 assert coefficient
< 0
290 string
+= ' - {}*{}'.format(coefficient
, symbol
)
292 constant
= self
.constant
293 if constant
!= 0 and i
== 0:
294 string
+= '{}'.format(constant
)
296 string
+= ' + {}'.format(constant
)
299 string
+= ' - {}'.format(constant
)
304 def _parenstr(self
, always
=False):
306 if not always
and (self
.isconstant() or self
.issymbol()):
309 return '({})'.format(string
)
312 return '{}({!r})'.format(self
.__class
__.__name
__, str(self
))
315 def fromsympy(cls
, expr
):
319 for symbol
, coefficient
in expr
.as_coefficients_dict().items():
320 coefficient
= Fraction(coefficient
.p
, coefficient
.q
)
321 if symbol
== sympy
.S
.One
:
322 constant
= coefficient
323 elif isinstance(symbol
, sympy
.Symbol
):
325 coefficients
[symbol
] = coefficient
327 raise ValueError('non-linear expression: {!r}'.format(expr
))
328 return cls(coefficients
, constant
)
333 for symbol
, coefficient
in self
.coefficients():
334 term
= coefficient
* sympy
.Symbol(symbol
)
336 expr
+= self
.constant
340 class Symbol(Expression
):
342 __slots__
= Expression
.__slots
__ + (
346 def __new__(cls
, name
):
347 if isinstance(name
, Symbol
):
349 elif not isinstance(name
, str):
350 raise TypeError('name must be a string or a Symbol instance')
352 self
= object().__new
__(cls
)
353 self
._coefficients
= OrderedDict([(name
, 1)])
355 self
._symbols
= tuple(name
)
358 self
._hash
= hash(self
._name
)
369 def _fromast(cls
, node
):
370 if isinstance(node
, ast
.Module
) and len(node
.body
) == 1:
371 return cls
._fromast
(node
.body
[0])
372 elif isinstance(node
, ast
.Expr
):
373 return cls
._fromast
(node
.value
)
374 elif isinstance(node
, ast
.Name
):
375 return Symbol(node
.id)
376 raise SyntaxError('invalid syntax')
379 return '{}({!r})'.format(self
.__class
__.__name
__, self
._name
)
382 def fromsympy(cls
, expr
):
384 if isinstance(expr
, sympy
.Symbol
):
385 return cls(expr
.name
)
387 raise TypeError('expr must be a sympy.Symbol instance')
391 if isinstance(names
, str):
392 names
= names
.replace(',', ' ').split()
393 return (Symbol(name
) for name
in names
)
396 class Constant(Expression
):
398 def __new__(cls
, numerator
=0, denominator
=None):
399 self
= object().__new
__(cls
)
400 if denominator
is None and isinstance(numerator
, Constant
):
401 self
._constant
= numerator
.constant
403 self
._constant
= Fraction(numerator
, denominator
)
404 self
._coefficients
= OrderedDict()
407 self
._hash
= hash(self
._constant
)
410 def isconstant(self
):
414 return self
.constant
!= 0
417 def fromstring(cls
, string
):
418 if isinstance(string
, str):
419 return Constant(Fraction(string
))
421 raise TypeError('string must be a string instance')
424 if self
.constant
.denominator
== 1:
425 return '{}({!r})'.format(self
.__class
__.__name
__,
426 self
.constant
.numerator
)
428 return '{}({!r}, {!r})'.format(self
.__class
__.__name
__,
429 self
.constant
.numerator
, self
.constant
.denominator
)
432 def fromsympy(cls
, expr
):
434 if isinstance(expr
, sympy
.Rational
):
435 return cls(expr
.p
, expr
.q
)
436 elif isinstance(expr
, numbers
.Rational
):
439 raise TypeError('expr must be a sympy.Rational instance')