3cd76335e443cba2cda1c711bb8b1c5f6ac50cbd
4 from fractions
import Fraction
, gcd
7 from pypol
.isl
import libisl
11 'Expression', 'Constant', 'Symbol', 'symbols',
12 'eq', 'le', 'lt', 'ge', 'gt',
18 def _polymorphic_method(func
):
19 @functools.wraps(func
)
21 if isinstance(b
, Expression
):
23 if isinstance(b
, numbers
.Rational
):
29 def _polymorphic_operator(func
):
30 # A polymorphic operator should call a polymorphic method, hence we just
31 # have to test the left operand.
32 @functools.wraps(func
)
34 if isinstance(a
, numbers
.Rational
):
37 elif isinstance(a
, Expression
):
39 raise TypeError('arguments must be linear expressions')
43 _main_ctx
= isl
.Context()
48 This class implements linear expressions.
51 def __new__(cls
, coefficients
=None, constant
=0):
52 if isinstance(coefficients
, str):
54 raise TypeError('too many arguments')
55 return cls
.fromstring(coefficients
)
56 if isinstance(coefficients
, dict):
57 coefficients
= coefficients
.items()
58 if coefficients
is None:
59 return Constant(constant
)
60 coefficients
= [(symbol
, coefficient
)
61 for symbol
, coefficient
in coefficients
if coefficient
!= 0]
62 if len(coefficients
) == 0:
63 return Constant(constant
)
64 elif len(coefficients
) == 1 and constant
== 0:
65 symbol
, coefficient
= coefficients
[0]
68 self
= object().__new
__(cls
)
69 self
._coefficients
= {}
70 for symbol
, coefficient
in coefficients
:
71 if isinstance(symbol
, Symbol
):
73 elif not isinstance(symbol
, str):
74 raise TypeError('symbols must be strings or Symbol instances')
75 if isinstance(coefficient
, Constant
):
76 coefficient
= coefficient
.constant
77 if not isinstance(coefficient
, numbers
.Rational
):
78 raise TypeError('coefficients must be rational numbers or Constant instances')
79 self
._coefficients
[symbol
] = coefficient
80 if isinstance(constant
, Constant
):
81 constant
= constant
.constant
82 if not isinstance(constant
, numbers
.Rational
):
83 raise TypeError('constant must be a rational number or a Constant instance')
84 self
._constant
= constant
85 self
._symbols
= tuple(sorted(self
._coefficients
))
86 self
._dimension
= len(self
._symbols
)
90 def fromstring(cls
, string
):
91 raise NotImplementedError
99 return self
._dimension
101 def coefficient(self
, symbol
):
102 if isinstance(symbol
, Symbol
):
104 elif not isinstance(symbol
, str):
105 raise TypeError('symbol must be a string or a Symbol instance')
107 return self
._coefficients
[symbol
]
111 __getitem__
= coefficient
def coefficients(self):
    """
    Generate the (symbol, coefficient) pairs of the expression.

    Symbols are visited in the order given by self.symbols, and each
    coefficient is obtained through self.coefficient().
    """
    for name in self.symbols:
        pair = (name, self.coefficient(name))
        yield pair
119 return self
._constant
121 def isconstant(self
):
125 for symbol
in self
.symbols
:
126 yield self
.coefficient(symbol
)
131 raise ValueError('not a symbol: {}'.format(self
))
def __add__(self, other):
    """
    Return the sum of two linear expressions.

    Coefficients attached to the same symbol are added together, symbols
    that appear in only one operand keep their coefficient, and the
    constant terms are summed.
    """
    merged = dict(self.coefficients())
    for name, weight in other.coefficients():
        merged[name] = merged[name] + weight if name in merged else weight
    total = self.constant + other.constant
    return Expression(merged, total)
def __sub__(self, other):
    """
    Return the difference self - other of two linear expressions.

    Coefficients of other are subtracted from the matching coefficients of
    self; symbols that only appear in other are stored negated. The
    constant terms are subtracted as well.
    """
    merged = dict(self.coefficients())
    for name, weight in other.coefficients():
        merged[name] = merged[name] - weight if name in merged else -weight
    remainder = self.constant - other.constant
    return Expression(merged, remainder)
def __rsub__(self, other):
    """
    Return other - self, computed as the negation of self - other.
    """
    forward = self - other
    return -forward
def __mul__(self, other):
    """
    Return the product of the expression with a constant expression.

    Every coefficient and the constant term of self are multiplied by
    other.constant.

    Raises ValueError when both operands are non-constant expressions,
    since the product would not be linear. Returns NotImplemented when
    other is not constant but self is.
    """
    if not other.isconstant():
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented
    factor = other.constant
    scaled = {name: weight * factor for name, weight in self.coefficients()}
    return Expression(scaled, self.constant * factor)
def __truediv__(self, other):
    """
    Return the quotient of the expression by a constant expression.

    Every coefficient and the constant term of self are divided by
    other.constant using exact Fraction arithmetic.

    Raises ValueError when other is a non-constant expression, since the
    quotient would not be linear. Returns NotImplemented for any other
    non-constant operand.
    """
    if not other.isconstant():
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented
    divisor = other.constant
    quotients = {name: Fraction(weight, divisor)
            for name, weight in self.coefficients()}
    return Expression(quotients, Fraction(self.constant, divisor))
def __rtruediv__(self, other):
    """
    Return other / self for a rational left operand.

    The quotient is only linear when self is a constant expression; in
    that case a constant expression holding the exact quotient is
    returned.

    Raises ValueError if self is not a constant expression. Returns
    NotImplemented when other is not a rational number.
    """
    # isinstance() requires a type as its second argument; testing against
    # the instance `self` would raise TypeError on every call.
    if not isinstance(other, numbers.Rational):
        return NotImplemented
    if self.isconstant():
        return Expression(constant=Fraction(other, self.constant))
    # other is a plain number on this path, so it is formatted directly
    # rather than through _parenstr(), which only expressions provide.
    raise ValueError('non-linear expression: '
            '{} / {}'.format(other, self._parenstr()))
214 for symbol
in self
.symbols
:
215 coefficient
= self
.coefficient(symbol
)
220 string
+= ' + {}'.format(symbol
)
221 elif coefficient
== -1:
223 string
+= '-{}'.format(symbol
)
225 string
+= ' - {}'.format(symbol
)
228 string
+= '{}*{}'.format(coefficient
, symbol
)
229 elif coefficient
> 0:
230 string
+= ' + {}*{}'.format(coefficient
, symbol
)
232 assert coefficient
< 0
234 string
+= ' - {}*{}'.format(coefficient
, symbol
)
236 constant
= self
.constant
237 if constant
!= 0 and i
== 0:
238 string
+= '{}'.format(constant
)
240 string
+= ' + {}'.format(constant
)
243 string
+= ' - {}'.format(constant
)
248 def _parenstr(self
, always
=False):
250 if not always
and (self
.isconstant() or self
.issymbol()):
253 return '({})'.format(string
)
256 string
= '{}({{'.format(self
.__class
__.__name
__)
257 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
260 string
+= '{!r}: {!r}'.format(symbol
, coefficient
)
261 string
+= '}}, {!r})'.format(self
.constant
)
265 def __eq__(self
, other
):
267 # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
268 return isinstance(other
, Expression
) and \
269 self
._coefficients
== other
._coefficients
and \
270 self
.constant
== other
.constant
273 return hash((tuple(sorted(self
._coefficients
.items())), self
._constant
))
276 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
277 [value
.denominator
for value
in self
.values()])
def _eq(self, other):
    """
    Return the polyhedron defined by the single equality self == other.

    The constraint stored is (self - other) after conversion by _toint().
    """
    constraint = (self - other)._toint()
    return Polyhedron(equalities=[constraint])
def __le__(self, other):
    """
    Return the polyhedron defined by the single inequality self <= other.

    The constraint stored is (other - self) after conversion by _toint().
    """
    constraint = (other - self)._toint()
    return Polyhedron(inequalities=[constraint])
def __lt__(self, other):
    """
    Return the polyhedron defined by the single inequality self < other.

    The constraint stored is (other - self) converted by _toint(), minus
    one, which turns the strict comparison into a non-strict one.
    """
    constraint = (other - self)._toint() - 1
    return Polyhedron(inequalities=[constraint])
def __ge__(self, other):
    """
    Return the polyhedron defined by the single inequality self >= other.

    The constraint stored is (self - other) after conversion by _toint().
    """
    constraint = (self - other)._toint()
    return Polyhedron(inequalities=[constraint])
def __gt__(self, other):
    """
    Return the polyhedron defined by the single inequality self > other.

    The constraint stored is (self - other) converted by _toint(), minus
    one, which turns the strict comparison into a non-strict one.
    """
    constraint = (self - other)._toint() - 1
    return Polyhedron(inequalities=[constraint])
301 class Constant(Expression
):
303 def __new__(cls
, numerator
=0, denominator
=None):
304 self
= object().__new
__(cls
)
305 if denominator
is None:
306 if isinstance(numerator
, numbers
.Rational
):
307 self
._constant
= numerator
308 elif isinstance(numerator
, Constant
):
309 self
._constant
= numerator
.constant
311 raise TypeError('constant must be a rational number or a Constant instance')
313 self
._constant
= Fraction(numerator
, denominator
)
314 self
._coefficients
= {}
319 def isconstant(self
):
323 return bool(self
.constant
)
326 return '{}({!r})'.format(self
.__class
__.__name
__, self
._constant
)
329 class Symbol(Expression
):
331 def __new__(cls
, name
):
332 if isinstance(name
, Symbol
):
334 elif not isinstance(name
, str):
335 raise TypeError('name must be a string or a Symbol instance')
336 self
= object().__new
__(cls
)
337 self
._coefficients
= {name
: 1}
339 self
._symbols
= tuple(name
)
352 return '{}({!r})'.format(self
.__class
__.__name
__, self
._symbol
)
355 if isinstance(names
, str):
356 names
= names
.replace(',', ' ').split()
357 return (Symbol(name
) for name
in names
)
360 @_polymorphic_operator
364 @_polymorphic_operator
368 @_polymorphic_operator
372 @_polymorphic_operator
376 @_polymorphic_operator
383 This class implements polyhedrons.
386 def __new__(cls
, equalities
=None, inequalities
=None):
387 if isinstance(equalities
, str):
388 if inequalities
is not None:
389 raise TypeError('too many arguments')
390 return cls
.fromstring(equalities
)
391 self
= super().__new
__(cls
)
392 self
._equalities
= []
393 if equalities
is not None:
394 for constraint
in equalities
:
395 for value
in constraint
.values():
396 if value
.denominator
!= 1:
397 raise TypeError('non-integer constraint: '
398 '{} == 0'.format(constraint
))
399 self
._equalities
.append(constraint
)
400 self
._equalities
= tuple(self
._equalities
)
401 self
._inequalities
= []
402 if inequalities
is not None:
403 for constraint
in inequalities
:
404 for value
in constraint
.values():
405 if value
.denominator
!= 1:
406 raise TypeError('non-integer constraint: '
407 '{} <= 0'.format(constraint
))
408 self
._inequalities
.append(constraint
)
409 self
._inequalities
= tuple(self
._inequalities
)
410 self
._constraints
= self
._equalities
+ self
._inequalities
411 self
._symbols
= set()
412 for constraint
in self
._constraints
:
413 self
.symbols
.update(constraint
.symbols
)
414 self
._symbols
= tuple(sorted(self
._symbols
))
418 def fromstring(cls
, string
):
419 raise NotImplementedError
422 def equalities(self
):
423 return self
._equalities
426 def inequalities(self
):
427 return self
._inequalities
430 def constraints(self
):
431 return self
._constraints
439 return len(self
.symbols
)
442 return not self
.is_empty()
444 def __contains__(self
, value
):
445 # is the value in the polyhedron?
446 raise NotImplementedError
448 def __eq__(self
, other
):
449 # works correctly when symbols is not passed
450 # should be equal if values are the same even if symbols are different
452 other
= other
._toisl
()
453 return bool(libisl
.isl_basic_set_plain_is_equal(bset
, other
))
457 return bool(libisl
.isl_basic_set_is_empty(bset
))
459 def isuniverse(self
):
461 return bool(libisl
.isl_basic_set_is_universe(bset
))
463 def isdisjoint(self
, other
):
464 # return true if the polyhedron has no elements in common with other
465 #symbols = self._symbolunion(other)
467 other
= other
._toisl
()
468 return bool(libisl
.isl_set_is_disjoint(bset
, other
))
def issubset(self, other):
    """
    Return True if every point of the polyhedron also belongs to other.

    Both polyhedra are converted to isl over the union of their symbols
    before being compared.
    """
    symbols = self._symbolunion(other)
    bset = self._toisl(symbols)
    other = other._toisl(symbols)
    # A subset test must be non-strict (equal polyhedra are subsets of
    # each other) and isl checks whether its FIRST operand is included in
    # the second, so self goes first.
    # NOTE(review): assumes the libisl binding exposes
    # isl_basic_set_is_subset like the other isl_basic_set_* calls used in
    # this class - confirm.
    return bool(libisl.isl_basic_set_is_subset(bset, other))
477 def __le__(self
, other
):
478 return self
.issubset(other
)
def __lt__(self, other):
    """
    Return True if the polyhedron is a strict subset of other.

    Both polyhedra are converted to isl over the union of their symbols
    before being compared.
    """
    symbols = self._symbolunion(other)
    bset = self._toisl(symbols)
    other = other._toisl(symbols)
    # isl tests whether its first operand is strictly included in the
    # second, so self must be passed first for "self < other".
    # NOTE(review): _toisl() appears to build a basic set while this is a
    # set-level isl call - confirm the binding accepts it.
    return bool(libisl.isl_set_is_strict_subset(bset, other))
486 def issuperset(self
, other
):
487 # test whether every element in other is in the polyhedron
488 raise NotImplementedError
490 def __ge__(self
, other
):
491 return self
.issuperset(other
)
def __gt__(self, other):
    """
    Return True if the polyhedron is a strict superset of other.

    Both polyhedra are converted to isl over the union of their symbols
    before being compared.
    """
    symbols = self._symbolunion(other)
    bset = self._toisl(symbols)
    other = other._toisl(symbols)
    # isl tests whether its first operand is strictly included in the
    # second, so "self > other" is checked as "other is a strict subset of
    # self". The result is returned instead of being computed and dropped.
    return bool(libisl.isl_set_is_strict_subset(other, bset))
500 def union(self
, *others
):
501 # return a new polyhedron with elements from the polyhedron and all
502 # others (convex union)
503 raise NotImplementedError
505 def __or__(self
, other
):
506 return self
.union(other
)
508 def intersection(self
, *others
):
509 # return a new polyhedron with elements common to the polyhedron and all
511 # a poor man's implementation could be:
512 # equalities = list(self.equalities)
513 # inequalities = list(self.inequalities)
514 # for other in others:
515 # equalities.extend(other.equalities)
516 # inequalities.extend(other.inequalities)
517 # return self.__class__(equalities, inequalities)
518 raise NotImplementedError
520 def __and__(self
, other
):
521 return self
.intersection(other
)
523 def difference(self
, other
):
524 # return a new polyhedron with elements in the polyhedron that are not in the other
525 symbols
= self
._symbolunion
(other
)
526 bset
= self
._toisl
(symbols
)
527 other
= other
._toisl
(symbols
)
528 difference
= libisl
.isl_set_subtract(bset
, other
)
532 def __sub__(self
, other
):
533 return self
.difference(other
)
537 for constraint
in self
.equalities
:
538 constraints
.append('{} == 0'.format(constraint
))
539 for constraint
in self
.inequalities
:
540 constraints
.append('{} >= 0'.format(constraint
))
541 return '{{{}}}'.format(', '.join(constraints
))
546 elif self
.isuniverse():
549 equalities
= list(self
.equalities
)
550 inequalities
= list(self
.inequalities
)
551 return '{}(equalities={!r}, inequalities={!r})' \
552 ''.format(self
.__class
__.__name
__, equalities
, inequalities
)
def _symbolunion(self, *others):
    """
    Return the sorted list of symbols used by self or by any of others.
    """
    names = set(self.symbols)
    names.update(*(other.symbols for other in others))
    return sorted(names)
560 def _toisl(self
, symbols
=None):
562 symbols
= self
.symbols
563 num_coefficients
= len(symbols
)
564 space
= libisl
.isl_space_set_alloc(_main_ctx
, 0, num_coefficients
)
565 bset
= libisl
.isl_basic_set_universe(libisl
.isl_space_copy(space
))
566 ls
= libisl
.isl_local_space_from_space(space
)
567 #if there are equalities/inequalities, take each constant and coefficient and add as a constraint to the basic set
568 for eq
in self
.equalities
:
569 ceq
= libisl
.isl_equality_alloc(libisl
.isl_local_space_copy(ls
))
570 coeff_eq
= dict(eq
.coefficients())
572 value
= str(eq
.constant
).encode()
573 val
= libisl
.isl_val_read_from_str(_main_ctx
, value
)
574 ceq
= libisl
.isl_constraint_set_constant_val(ceq
, val
)
576 number
= str(coeff_eq
.get(eq
)).encode()
577 num
= libisl
.isl_val_read_from_str(_main_ctx
, number
)
578 iden
= symbols
.index(eq
)
579 ceq
= libisl
.isl_constraint_set_coefficient_val(ceq
, libisl
.isl_dim_set
, iden
, num
) #use 3 for type isl_dim_set
580 bset
= libisl
.isl_basic_set_add_constraint(bset
, ceq
)
581 for ineq
in self
.inequalities
:
582 cin
= libisl
.isl_inequality_alloc(libisl
.isl_local_space_copy(ls
))
583 coeff_in
= dict(ineq
.coefficients())
585 value
= str(ineq
.constant
).encode()
586 val
= libisl
.isl_val_read_from_str(_main_ctx
, value
)
587 cin
= libisl
.isl_constraint_set_constant_val(cin
, val
)
588 for ineq
in coeff_in
:
589 number
= str(coeff_in
.get(ineq
)).encode()
590 num
= libisl
.isl_val_read_from_str(_main_ctx
, number
)
591 iden
= symbols
.index(ineq
)
592 cin
= libisl
.isl_constraint_set_coefficient_val(cin
, libisl
.isl_dim_set
, iden
, num
) #use 3 for type isl_dim_set
593 bset
= libisl
.isl_basic_set_add_constraint(bset
, cin
)
594 bset
= isl
.BasicSet(bset
)
598 def _fromisl(cls
, bset
):
599 raise NotImplementedError
602 return cls(equalities
, inequalities
)
603 '''takes basic set in isl form and puts back into python version of polyhedron
604 isl example code gives isl form as:
605 "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}")
606 our printer is giving form as:
607 { [i0, i1] : 2i1 >= -2 - i0 } '''
610 Universe
= Polyhedron()
if __name__ == '__main__':
    # Ad-hoc manual check of the isl conversion.
    # ex1 is the expression reported as not working (even without adding
    # values); it is built here but not used further.
    ex1 = Expression({'a': 6, 'b': 6}, 3)
    ex2 = Expression({'x': 4, 'y': 2}, 3)
    p = Polyhedron(equalities=[ex2])
    p2 = Polyhedron(equalities=[ex2])
    # Print the isl form of p to check that the values survive _toisl().
    converted = p._toisl()
    print(converted)