5b5d8aa50ce4e1f6887b8cd7dd34a53b5bfd54f1
5 from fractions
import Fraction
, gcd
10 'constant', 'symbol', 'symbols',
11 'eq', 'le', 'lt', 'ge', 'gt',
19 This class implements linear expressions.
22 def __new__(cls
, coefficients
=None, constant
=0):
23 if isinstance(coefficients
, str):
25 raise TypeError('too many arguments')
26 return cls
.fromstring(coefficients
)
27 self
= super().__new
__(cls
)
28 self
._coefficients
= {}
29 if isinstance(coefficients
, dict):
30 coefficients
= coefficients
.items()
31 if coefficients
is not None:
32 for symbol
, coefficient
in coefficients
:
33 if isinstance(symbol
, Expression
) and symbol
.issymbol():
35 elif not isinstance(symbol
, str):
36 raise TypeError('symbols must be strings')
37 if not isinstance(coefficient
, numbers
.Rational
):
38 raise TypeError('coefficients must be rational numbers')
40 self
._coefficients
[symbol
] = coefficient
41 if not isinstance(constant
, numbers
.Rational
):
42 raise TypeError('constant must be a rational number')
43 self
._constant
= constant
47 yield from sorted(self
._coefficients
)
51 return len(list(self
.symbols()))
53 def coefficient(self
, symbol
):
54 if isinstance(symbol
, Expression
) and symbol
.issymbol():
56 elif not isinstance(symbol
, str):
57 raise TypeError('symbol must be a string')
59 return self
._coefficients
[symbol
]
63 __getitem__
= coefficient
65 def coefficients(self
):
66 for symbol
in self
.symbols():
67 yield symbol
, self
.coefficient(symbol
)
74 return len(self
._coefficients
) == 0
77 for symbol
in self
.symbols():
78 yield self
.coefficient(symbol
)
82 if not self
.issymbol():
83 raise ValueError('not a symbol: {}'.format(self
))
84 for symbol
in self
.symbols():
88 return len(self
._coefficients
) == 1 and self
._constant
== 0
91 return (not self
.isconstant()) or bool(self
.constant
)
99 def _polymorphic(func
):
100 @functools.wraps(func
)
101 def wrapper(self
, other
):
102 if isinstance(other
, Expression
):
103 return func(self
, other
)
104 if isinstance(other
, numbers
.Rational
):
105 other
= Expression(constant
=other
)
106 return func(self
, other
)
107 return NotImplemented
111 def __add__(self
, other
):
112 coefficients
= dict(self
.coefficients())
113 for symbol
, coefficient
in other
.coefficients():
114 if symbol
in coefficients
:
115 coefficients
[symbol
] += coefficient
117 coefficients
[symbol
] = coefficient
118 constant
= self
.constant
+ other
.constant
119 return Expression(coefficients
, constant
)
124 def __sub__(self
, other
):
125 coefficients
= dict(self
.coefficients())
126 for symbol
, coefficient
in other
.coefficients():
127 if symbol
in coefficients
:
128 coefficients
[symbol
] -= coefficient
130 coefficients
[symbol
] = -coefficient
131 constant
= self
.constant
- other
.constant
132 return Expression(coefficients
, constant
)
137 def __mul__(self
, other
):
138 if other
.isconstant():
139 coefficients
= dict(self
.coefficients())
140 for symbol
in coefficients
:
141 coefficients
[symbol
] *= other
.constant
142 constant
= self
.constant
* other
.constant
143 return Expression(coefficients
, constant
)
144 if isinstance(other
, Expression
) and not self
.isconstant():
145 raise ValueError('non-linear expression: '
146 '{} * {}'.format(self
._parenstr
(), other
._parenstr
()))
147 return NotImplemented
152 def __truediv__(self
, other
):
153 if other
.isconstant():
154 coefficients
= dict(self
.coefficients())
155 for symbol
in coefficients
:
156 coefficients
[symbol
] = \
157 Fraction(coefficients
[symbol
], other
.constant
)
158 constant
= Fraction(self
.constant
, other
.constant
)
159 return Expression(coefficients
, constant
)
160 if isinstance(other
, Expression
):
161 raise ValueError('non-linear expression: '
162 '{} / {}'.format(self
._parenstr
(), other
._parenstr
()))
163 return NotImplemented
165 def __rtruediv__(self
, other
):
166 if isinstance(other
, Rational
):
167 if self
.isconstant():
168 constant
= Fraction(other
, self
.constant
)
169 return Expression(constant
=constant
)
171 raise ValueError('non-linear expression: '
172 '{} / {}'.format(other
._parenstr
(), self
._parenstr
()))
173 return NotImplemented
177 symbols
= sorted(self
.symbols())
179 for symbol
in symbols
:
180 coefficient
= self
[symbol
]
185 string
+= ' + {}'.format(symbol
)
186 elif coefficient
== -1:
188 string
+= '-{}'.format(symbol
)
190 string
+= ' - {}'.format(symbol
)
193 string
+= '{}*{}'.format(coefficient
, symbol
)
194 elif coefficient
> 0:
195 string
+= ' + {}*{}'.format(coefficient
, symbol
)
197 assert coefficient
< 0
199 string
+= ' - {}*{}'.format(coefficient
, symbol
)
201 constant
= self
.constant
202 if constant
!= 0 and i
== 0:
203 string
+= '{}'.format(constant
)
205 string
+= ' + {}'.format(constant
)
208 string
+= ' - {}'.format(constant
)
211 def _parenstr(self
, always
=False):
213 if not always
and (self
.isconstant() or self
.issymbol()):
216 return '({})'.format(string
)
219 string
= '{}({{'.format(self
.__class
__.__name
__)
220 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
223 string
+= '{!r}: {!r}'.format(symbol
, coefficient
)
224 string
+= '}}, {!r})'.format(self
.constant
)
228 def fromstring(cls
, string
):
229 raise NotImplementedError
232 def __eq__(self
, other
):
234 # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
235 return isinstance(other
, Expression
) and \
236 self
._coefficients
== other
._coefficients
and \
237 self
.constant
== other
.constant
240 return hash((self
._coefficients
, self
._constant
))
243 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
244 [value
.denominator
for value
in self
.values()])
248 def _eq(self
, other
):
249 return Polyhedron(equalities
=[(self
- other
)._canonify
()])
252 def __le__(self
, other
):
253 return Polyhedron(inequalities
=[(self
- other
)._canonify
()])
256 def __lt__(self
, other
):
257 return Polyhedron(inequalities
=[(self
- other
)._canonify
() + 1])
260 def __ge__(self
, other
):
261 return Polyhedron(inequalities
=[(other
- self
)._canonify
()])
264 def __gt__(self
, other
):
265 return Polyhedron(inequalities
=[(other
- self
)._canonify
() + 1])
268 def constant(numerator
=0, denominator
=None):
269 return Expression(constant
=Fraction(numerator
, denominator
))
272 if not isinstance(name
, str):
273 raise TypeError('name must be a string')
274 return Expression(coefficients
={name
: 1})
277 if isinstance(names
, str):
278 names
= names
.replace(',', ' ').split()
279 return (symbol(name
) for name
in names
)
283 @functools.wraps(func
)
285 if isinstance(a
, numbers
.Rational
):
287 if isinstance(b
, numbers
.Rational
):
289 if isinstance(a
, Expression
) and isinstance(b
, Expression
):
291 raise TypeError('arguments must be linear expressions')
317 This class implements polyhedrons.
320 def __new__(cls
, equalities
=None, inequalities
=None):
321 if isinstance(equalities
, str):
322 if inequalities
is not None:
323 raise TypeError('too many arguments')
324 return cls
.fromstring(equalities
)
325 self
= super().__new
__(cls
)
326 self
._equalities
= []
327 if equalities
is not None:
328 for constraint
in equalities
:
329 for value
in constraint
.values():
330 if value
.denominator
!= 1:
331 raise TypeError('non-integer constraint: '
332 '{} == 0'.format(constraint
))
333 self
._equalities
.append(constraint
)
334 self
._inequalities
= []
335 if inequalities
is not None:
336 for constraint
in inequalities
:
337 for value
in constraint
.values():
338 if value
.denominator
!= 1:
339 raise TypeError('non-integer constraint: '
340 '{} <= 0'.format(constraint
))
341 self
._inequalities
.append(constraint
)
345 def equalities(self
):
346 yield from self
._equalities
349 def inequalities(self
):
350 yield from self
._inequalities
352 def constraints(self
):
353 yield from self
.equalities
354 yield from self
.inequalities
358 for constraint
in self
.constraints():
359 s
.update(constraint
.symbols
)
364 return len(self
.symbols())
367 # return false if the polyhedron is empty, true otherwise
368 raise NotImplementedError
370 def __contains__(self
, value
):
371 # is the value in the polyhedron?
372 raise NotImplementedError
374 def __eq__(self
, other
):
375 raise NotImplementedError
380 def isuniverse(self
):
381 return self
== universe
383 def isdisjoint(self
, other
):
384 # return true if the polyhedron has no elements in common with other
385 raise NotImplementedError
387 def issubset(self
, other
):
388 raise NotImplementedError
390 def __le__(self
, other
):
391 return self
.issubset(other
)
393 def __lt__(self
, other
):
394 raise NotImplementedError
396 def issuperset(self
, other
):
397 # test whether every element in other is in the polyhedron
398 raise NotImplementedError
400 def __ge__(self
, other
):
401 return self
.issuperset(other
)
403 def __gt__(self
, other
):
404 raise NotImplementedError
406 def union(self
, *others
):
407 # return a new polyhedron with elements from the polyhedron and all
408 # others (convex union)
409 raise NotImplementedError
411 def __or__(self
, other
):
412 return self
.union(other
)
414 def intersection(self
, *others
):
415 # return a new polyhedron with elements common to the polyhedron and all
417 # a poor man's implementation could be:
418 # equalities = list(self.equalities)
419 # inequalities = list(self.inequalities)
420 # for other in others:
421 # equalities.extend(other.equalities)
422 # inequalities.extend(other.inequalities)
423 # return self.__class__(equalities, inequalities)
424 raise NotImplementedError
426 def __and__(self
, other
):
427 return self
.intersection(other
)
429 def difference(self
, *others
):
430 # return a new polyhedron with elements in the polyhedron that are not
432 raise NotImplementedError
434 def __sub__(self
, other
):
435 return self
.difference(other
)
439 for constraint
in self
.equalities
:
440 constraints
.append('{} == 0'.format(constraint
))
441 for constraint
in self
.inequalities
:
442 constraints
.append('{} <= 0'.format(constraint
))
443 return '{{{}}}'.format(', '.join(constraints
))
446 equalities
= list(self
.equalities
)
447 inequalities
= list(self
.inequalities
)
448 return '{}(equalities={!r}, inequalities={!r})' \
449 ''.format(self
.__class
__.__name
__, equalities
, inequalities
)
452 def fromstring(cls
, string
):
453 raise NotImplementedError
458 universe
= Polyhedron()