74b54775c1b99176e7b0af1440df7119b406b69d
6 from fractions
import Fraction
, gcd
9 from .isl
import libisl
13 'Expression', 'Constant', 'Symbol', 'symbols',
14 'eq', 'le', 'lt', 'ge', 'gt',
20 def _polymorphic_method(func
):
21 @functools.wraps(func
)
23 if isinstance(b
, Expression
):
25 if isinstance(b
, numbers
.Rational
):
31 def _polymorphic_operator(func
):
32 # A polymorphic operator should call a polymorphic method, hence we just
33 # have to test the left operand.
34 @functools.wraps(func
)
36 if isinstance(a
, numbers
.Rational
):
39 elif isinstance(a
, Expression
):
41 raise TypeError('arguments must be linear expressions')
45 _main_ctx
= isl
.Context()
50 This class implements linear expressions.
60 def __new__(cls
, coefficients
=None, constant
=0):
61 if isinstance(coefficients
, str):
63 raise TypeError('too many arguments')
64 return cls
.fromstring(coefficients
)
65 if isinstance(coefficients
, dict):
66 coefficients
= coefficients
.items()
67 if coefficients
is None:
68 return Constant(constant
)
69 coefficients
= [(symbol
, coefficient
)
70 for symbol
, coefficient
in coefficients
if coefficient
!= 0]
71 if len(coefficients
) == 0:
72 return Constant(constant
)
73 elif len(coefficients
) == 1 and constant
== 0:
74 symbol
, coefficient
= coefficients
[0]
77 self
= object().__new
__(cls
)
78 self
._coefficients
= {}
79 for symbol
, coefficient
in coefficients
:
80 if isinstance(symbol
, Symbol
):
82 elif not isinstance(symbol
, str):
83 raise TypeError('symbols must be strings or Symbol instances')
84 if isinstance(coefficient
, Constant
):
85 coefficient
= coefficient
.constant
86 if not isinstance(coefficient
, numbers
.Rational
):
87 raise TypeError('coefficients must be rational numbers or Constant instances')
88 self
._coefficients
[symbol
] = coefficient
89 if isinstance(constant
, Constant
):
90 constant
= constant
.constant
91 if not isinstance(constant
, numbers
.Rational
):
92 raise TypeError('constant must be a rational number or a Constant instance')
93 self
._constant
= constant
94 self
._symbols
= tuple(sorted(self
._coefficients
))
95 self
._dimension
= len(self
._symbols
)
99 def _fromast(cls
, node
):
100 if isinstance(node
, ast
.Module
) and len(node
.body
) == 1:
101 return cls
._fromast
(node
.body
[0])
102 elif isinstance(node
, ast
.Expr
):
103 return cls
._fromast
(node
.value
)
104 elif isinstance(node
, ast
.Name
):
105 return Symbol(node
.id)
106 elif isinstance(node
, ast
.Num
):
107 return Constant(node
.n
)
108 elif isinstance(node
, ast
.UnaryOp
) and isinstance(node
.op
, ast
.USub
):
109 return -cls
._fromast
(node
.operand
)
110 elif isinstance(node
, ast
.BinOp
):
111 left
= cls
._fromast
(node
.left
)
112 right
= cls
._fromast
(node
.right
)
113 if isinstance(node
.op
, ast
.Add
):
115 elif isinstance(node
.op
, ast
.Sub
):
117 elif isinstance(node
.op
, ast
.Mult
):
119 elif isinstance(node
.op
, ast
.Div
):
121 raise SyntaxError('invalid syntax')
124 def fromstring(cls
, string
):
125 string
= re
.sub(r
'(\d+|\))\s*([^\W\d_]\w*|\()', r
'\1*\2', string
)
126 tree
= ast
.parse(string
, 'eval')
127 return cls
._fromast
(tree
)
135 return self
._dimension
137 def coefficient(self
, symbol
):
138 if isinstance(symbol
, Symbol
):
140 elif not isinstance(symbol
, str):
141 raise TypeError('symbol must be a string or a Symbol instance')
143 return self
._coefficients
[symbol
]
147 __getitem__
= coefficient
149 def coefficients(self
):
150 for symbol
in self
.symbols
:
151 yield symbol
, self
.coefficient(symbol
)
155 return self
._constant
157 def isconstant(self
):
161 for symbol
in self
.symbols
:
162 yield self
.coefficient(symbol
)
178 def __add__(self
, other
):
179 coefficients
= dict(self
.coefficients())
180 for symbol
, coefficient
in other
.coefficients():
181 if symbol
in coefficients
:
182 coefficients
[symbol
] += coefficient
184 coefficients
[symbol
] = coefficient
185 constant
= self
.constant
+ other
.constant
186 return Expression(coefficients
, constant
)
191 def __sub__(self
, other
):
192 coefficients
= dict(self
.coefficients())
193 for symbol
, coefficient
in other
.coefficients():
194 if symbol
in coefficients
:
195 coefficients
[symbol
] -= coefficient
197 coefficients
[symbol
] = -coefficient
198 constant
= self
.constant
- other
.constant
199 return Expression(coefficients
, constant
)
201 def __rsub__(self
, other
):
202 return -(self
- other
)
205 def __mul__(self
, other
):
206 if other
.isconstant():
207 coefficients
= dict(self
.coefficients())
208 for symbol
in coefficients
:
209 coefficients
[symbol
] *= other
.constant
210 constant
= self
.constant
* other
.constant
211 return Expression(coefficients
, constant
)
212 if isinstance(other
, Expression
) and not self
.isconstant():
213 raise ValueError('non-linear expression: '
214 '{} * {}'.format(self
._parenstr
(), other
._parenstr
()))
215 return NotImplemented
220 def __truediv__(self
, other
):
221 if other
.isconstant():
222 coefficients
= dict(self
.coefficients())
223 for symbol
in coefficients
:
224 coefficients
[symbol
] = \
225 Fraction(coefficients
[symbol
], other
.constant
)
226 constant
= Fraction(self
.constant
, other
.constant
)
227 return Expression(coefficients
, constant
)
228 if isinstance(other
, Expression
):
229 raise ValueError('non-linear expression: '
230 '{} / {}'.format(self
._parenstr
(), other
._parenstr
()))
231 return NotImplemented
233 def __rtruediv__(self
, other
):
234 if isinstance(other
, self
):
235 if self
.isconstant():
236 constant
= Fraction(other
, self
.constant
)
237 return Expression(constant
=constant
)
239 raise ValueError('non-linear expression: '
240 '{} / {}'.format(other
._parenstr
(), self
._parenstr
()))
241 return NotImplemented
246 for symbol
in self
.symbols
:
247 coefficient
= self
.coefficient(symbol
)
252 string
+= ' + {}'.format(symbol
)
253 elif coefficient
== -1:
255 string
+= '-{}'.format(symbol
)
257 string
+= ' - {}'.format(symbol
)
260 string
+= '{}*{}'.format(coefficient
, symbol
)
261 elif coefficient
> 0:
262 string
+= ' + {}*{}'.format(coefficient
, symbol
)
264 assert coefficient
< 0
266 string
+= ' - {}*{}'.format(coefficient
, symbol
)
268 constant
= self
.constant
269 if constant
!= 0 and i
== 0:
270 string
+= '{}'.format(constant
)
272 string
+= ' + {}'.format(constant
)
275 string
+= ' - {}'.format(constant
)
280 def _parenstr(self
, always
=False):
282 if not always
and (self
.isconstant() or self
.issymbol()):
285 return '({})'.format(string
)
288 string
= '{}({{'.format(self
.__class
__.__name
__)
289 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
292 string
+= '{!r}: {!r}'.format(symbol
, coefficient
)
293 string
+= '}}, {!r})'.format(self
.constant
)
297 def __eq__(self
, other
):
299 # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
300 return isinstance(other
, Expression
) and \
301 self
._coefficients
== other
._coefficients
and \
302 self
.constant
== other
.constant
305 return hash((tuple(sorted(self
._coefficients
.items())), self
._constant
))
308 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
309 [value
.denominator
for value
in self
.values()])
313 def _eq(self
, other
):
314 return Polyhedron(equalities
=[(self
- other
)._toint
()])
317 def __le__(self
, other
):
318 return Polyhedron(inequalities
=[(other
- self
)._toint
()])
321 def __lt__(self
, other
):
322 return Polyhedron(inequalities
=[(other
- self
)._toint
() - 1])
325 def __ge__(self
, other
):
326 return Polyhedron(inequalities
=[(self
- other
)._toint
()])
329 def __gt__(self
, other
):
330 return Polyhedron(inequalities
=[(self
- other
)._toint
() - 1])
333 def fromsympy(cls
, expr
):
337 for symbol
, coefficient
in expr
.as_coefficients_dict().items():
338 coefficient
= Fraction(coefficient
.p
, coefficient
.q
)
339 if symbol
== sympy
.S
.One
:
340 constant
= coefficient
341 elif isinstance(symbol
, sympy
.Symbol
):
343 coefficients
[symbol
] = coefficient
345 raise ValueError('non-linear expression: {!r}'.format(expr
))
346 return cls(coefficients
, constant
)
351 for symbol
, coefficient
in self
.coefficients():
352 term
= coefficient
* sympy
.Symbol(symbol
)
354 expr
+= self
.constant
358 class Constant(Expression
):
360 def __new__(cls
, numerator
=0, denominator
=None):
361 self
= object().__new
__(cls
)
362 if denominator
is None:
363 if isinstance(numerator
, numbers
.Rational
):
364 self
._constant
= numerator
365 elif isinstance(numerator
, Constant
):
366 self
._constant
= numerator
.constant
368 raise TypeError('constant must be a rational number or a Constant instance')
370 self
._constant
= Fraction(numerator
, denominator
)
371 self
._coefficients
= {}
376 def isconstant(self
):
380 return bool(self
.constant
)
383 if self
.constant
.denominator
== 1:
384 return '{}({!r})'.format(self
.__class
__.__name
__, self
.constant
)
386 return '{}({!r}, {!r})'.format(self
.__class
__.__name
__,
387 self
.constant
.numerator
, self
.constant
.denominator
)
390 def fromsympy(cls
, expr
):
392 if isinstance(expr
, sympy
.Rational
):
393 return cls(expr
.p
, expr
.q
)
394 elif isinstance(expr
, numbers
.Rational
):
397 raise TypeError('expr must be a sympy.Rational instance')
400 class Symbol(Expression
):
402 __slots__
= Expression
.__slots
__ + (
406 def __new__(cls
, name
):
407 if isinstance(name
, Symbol
):
409 elif not isinstance(name
, str):
410 raise TypeError('name must be a string or a Symbol instance')
411 self
= object().__new
__(cls
)
412 self
._coefficients
= {name
: 1}
414 self
._symbols
= tuple(name
)
427 return '{}({!r})'.format(self
.__class
__.__name
__, self
._name
)
430 def fromsympy(cls
, expr
):
432 if isinstance(expr
, sympy
.Symbol
):
433 return cls(expr
.name
)
435 raise TypeError('expr must be a sympy.Symbol instance')
439 if isinstance(names
, str):
440 names
= names
.replace(',', ' ').split()
441 return (Symbol(name
) for name
in names
)
444 @_polymorphic_operator
448 @_polymorphic_operator
452 @_polymorphic_operator
456 @_polymorphic_operator
460 @_polymorphic_operator
467 This class implements polyhedrons.
477 def __new__(cls
, equalities
=None, inequalities
=None):
478 if isinstance(equalities
, str):
479 if inequalities
is not None:
480 raise TypeError('too many arguments')
481 return cls
.fromstring(equalities
)
482 self
= super().__new
__(cls
)
483 self
._equalities
= []
484 if equalities
is not None:
485 for constraint
in equalities
:
486 for value
in constraint
.values():
487 if value
.denominator
!= 1:
488 raise TypeError('non-integer constraint: '
489 '{} == 0'.format(constraint
))
490 self
._equalities
.append(constraint
)
491 self
._equalities
= tuple(self
._equalities
)
492 self
._inequalities
= []
493 if inequalities
is not None:
494 for constraint
in inequalities
:
495 for value
in constraint
.values():
496 if value
.denominator
!= 1:
497 raise TypeError('non-integer constraint: '
498 '{} <= 0'.format(constraint
))
499 self
._inequalities
.append(constraint
)
500 self
._inequalities
= tuple(self
._inequalities
)
501 self
._constraints
= self
._equalities
+ self
._inequalities
502 self
._symbols
= set()
503 for constraint
in self
._constraints
:
504 self
.symbols
.update(constraint
.symbols
)
505 self
._symbols
= tuple(sorted(self
._symbols
))
509 def _fromast(cls
, node
):
510 if isinstance(node
, ast
.Module
) and len(node
.body
) == 1:
511 return cls
._fromast
(node
.body
[0])
512 elif isinstance(node
, ast
.Expr
):
513 return cls
._fromast
(node
.value
)
514 elif isinstance(node
, ast
.BinOp
) and isinstance(node
.op
, ast
.BitAnd
):
515 equalities1
, inequalities1
= cls
._fromast
(node
.left
)
516 equalities2
, inequalities2
= cls
._fromast
(node
.right
)
517 equalities
= equalities1
+ equalities2
518 inequalities
= inequalities1
+ inequalities2
519 return equalities
, inequalities
520 elif isinstance(node
, ast
.Compare
):
523 left
= Expression
._fromast
(node
.left
)
524 for i
in range(len(node
.ops
)):
526 right
= Expression
._fromast
(node
.comparators
[i
])
527 if isinstance(op
, ast
.Lt
):
528 inequalities
.append(right
- left
- 1)
529 elif isinstance(op
, ast
.LtE
):
530 inequalities
.append(right
- left
)
531 elif isinstance(op
, ast
.Eq
):
532 equalities
.append(left
- right
)
533 elif isinstance(op
, ast
.GtE
):
534 inequalities
.append(left
- right
)
535 elif isinstance(op
, ast
.Gt
):
536 inequalities
.append(left
- right
- 1)
541 return equalities
, inequalities
542 raise SyntaxError('invalid syntax')
545 def fromstring(cls
, string
):
546 string
= string
.strip()
547 string
= re
.sub(r
'^\{\s*|\s*\}$', '', string
)
548 string
= re
.sub(r
'([^<=>])=([^<=>])', r
'\1==\2', string
)
549 string
= re
.sub(r
'(\d+|\))\s*([^\W\d_]\w*|\()', r
'\1*\2', string
)
550 tokens
= re
.split(r
',|;|and|&&|/\\|∧', string
, flags
=re
.I
)
551 tokens
= ['({})'.format(token
) for token
in tokens
]
552 string
= ' & '.join(tokens
)
553 tree
= ast
.parse(string
, 'eval')
554 equalities
, inequalities
= cls
._fromast
(tree
)
555 return cls(equalities
, inequalities
)
558 def equalities(self
):
559 return self
._equalities
562 def inequalities(self
):
563 return self
._inequalities
566 def constraints(self
):
567 return self
._constraints
575 return len(self
.symbols
)
578 return not self
.is_empty()
580 def __contains__(self
, value
):
581 # is the value in the polyhedron?
582 raise NotImplementedError
584 def __eq__(self
, other
):
585 # works correctly when symbols is not passed
586 # should be equal if values are the same even if symbols are different
588 other
= other
._toisl
()
589 return bool(libisl
.isl_basic_set_plain_is_equal(bset
, other
))
593 return bool(libisl
.isl_basic_set_is_empty(bset
))
595 def isuniverse(self
):
597 return bool(libisl
.isl_basic_set_is_universe(bset
))
599 def isdisjoint(self
, other
):
600 # return true if the polyhedron has no elements in common with other
601 #symbols = self._symbolunion(other)
603 other
= other
._toisl
()
604 return bool(libisl
.isl_set_is_disjoint(bset
, other
))
606 def issubset(self
, other
):
607 # check if self(bset) is a subset of other
608 symbols
= self
._symbolunion
(other
)
609 bset
= self
._toisl
(symbols
)
610 other
= other
._toisl
(symbols
)
611 return bool(libisl
.isl_set_is_strict_subset(other
, bset
))
613 def __le__(self
, other
):
614 return self
.issubset(other
)
616 def __lt__(self
, other
):
617 symbols
= self
._symbolunion
(other
)
618 bset
= self
._toisl
(symbols
)
619 other
= other
._toisl
(symbols
)
620 return bool(libisl
.isl_set_is_strict_subset(other
, bset
))
622 def issuperset(self
, other
):
623 # test whether every element in other is in the polyhedron
624 raise NotImplementedError
626 def __ge__(self
, other
):
627 return self
.issuperset(other
)
629 def __gt__(self
, other
):
630 symbols
= self
._symbolunion
(other
)
631 bset
= self
._toisl
(symbols
)
632 other
= other
._toisl
(symbols
)
633 bool(libisl
.isl_set_is_strict_subset(other
, bset
))
634 raise NotImplementedError
636 def union(self
, *others
):
637 # return a new polyhedron with elements from the polyhedron and all
638 # others (convex union)
639 raise NotImplementedError
641 def __or__(self
, other
):
642 return self
.union(other
)
644 def intersection(self
, *others
):
645 # return a new polyhedron with elements common to the polyhedron and all
647 # a poor man's implementation could be:
648 # equalities = list(self.equalities)
649 # inequalities = list(self.inequalities)
650 # for other in others:
651 # equalities.extend(other.equalities)
652 # inequalities.extend(other.inequalities)
653 # return self.__class__(equalities, inequalities)
654 raise NotImplementedError
656 def __and__(self
, other
):
657 return self
.intersection(other
)
659 def difference(self
, other
):
660 # return a new polyhedron with elements in the polyhedron that are not in the other
661 symbols
= self
._symbolunion
(other
)
662 bset
= self
._toisl
(symbols
)
663 other
= other
._toisl
(symbols
)
664 difference
= libisl
.isl_set_subtract(bset
, other
)
667 def __sub__(self
, other
):
668 return self
.difference(other
)
672 for constraint
in self
.equalities
:
673 constraints
.append('{} == 0'.format(constraint
))
674 for constraint
in self
.inequalities
:
675 constraints
.append('{} >= 0'.format(constraint
))
676 return '{{{}}}'.format(', '.join(constraints
))
681 elif self
.isuniverse():
684 equalities
= list(self
.equalities
)
685 inequalities
= list(self
.inequalities
)
686 return '{}(equalities={!r}, inequalities={!r})' \
687 ''.format(self
.__class
__.__name
__, equalities
, inequalities
)
690 def _fromsympy(cls
, expr
):
694 if expr
.func
== sympy
.And
:
695 for arg
in expr
.args
:
696 arg_eqs
, arg_ins
= cls
._fromsympy
(arg
)
697 equalities
.extend(arg_eqs
)
698 inequalities
.extend(arg_ins
)
699 elif expr
.func
== sympy
.Eq
:
700 expr
= Expression
.fromsympy(expr
.args
[0] - expr
.args
[1])
701 equalities
.append(expr
)
703 if expr
.func
== sympy
.Lt
:
704 expr
= Expression
.fromsympy(expr
.args
[1] - expr
.args
[0] - 1)
705 elif expr
.func
== sympy
.Le
:
706 expr
= Expression
.fromsympy(expr
.args
[1] - expr
.args
[0])
707 elif expr
.func
== sympy
.Ge
:
708 expr
= Expression
.fromsympy(expr
.args
[0] - expr
.args
[1])
709 elif expr
.func
== sympy
.Gt
:
710 expr
= Expression
.fromsympy(expr
.args
[0] - expr
.args
[1] - 1)
712 raise ValueError('non-polyhedral expression: {!r}'.format(expr
))
713 inequalities
.append(expr
)
714 return equalities
, inequalities
717 def fromsympy(cls
, expr
):
719 equalities
, inequalities
= cls
._fromsympy
(expr
)
720 return cls(equalities
, inequalities
)
725 for equality
in self
.equalities
:
726 constraints
.append(sympy
.Eq(equality
.tosympy(), 0))
727 for inequality
in self
.inequalities
:
728 constraints
.append(sympy
.Ge(inequality
.tosympy(), 0))
729 return sympy
.And(*constraints
)
731 def _symbolunion(self
, *others
):
732 symbols
= set(self
.symbols
)
734 symbols
.update(other
.symbols
)
735 return sorted(symbols
)
737 def _toisl(self
, symbols
=None):
739 symbols
= self
.symbols
740 dimension
= len(symbols
)
741 space
= libisl
.isl_space_set_alloc(_main_ctx
, 0, dimension
)
742 bset
= libisl
.isl_basic_set_universe(libisl
.isl_space_copy(space
))
743 ls
= libisl
.isl_local_space_from_space(space
)
744 for equality
in self
.equalities
:
745 ceq
= libisl
.isl_equality_alloc(libisl
.isl_local_space_copy(ls
))
746 for symbol
, coefficient
in equality
.coefficients():
747 val
= str(coefficient
).encode()
748 val
= libisl
.isl_val_read_from_str(_main_ctx
, val
)
749 dim
= symbols
.index(symbol
)
750 ceq
= libisl
.isl_constraint_set_coefficient_val(ceq
, libisl
.isl_dim_set
, dim
, val
)
751 if equality
.constant
!= 0:
752 val
= str(equality
.constant
).encode()
753 val
= libisl
.isl_val_read_from_str(_main_ctx
, val
)
754 ceq
= libisl
.isl_constraint_set_constant_val(ceq
, val
)
755 bset
= libisl
.isl_basic_set_add_constraint(bset
, ceq
)
756 for inequality
in self
.inequalities
:
757 cin
= libisl
.isl_inequality_alloc(libisl
.isl_local_space_copy(ls
))
758 for symbol
, coefficient
in inequality
.coefficients():
759 val
= str(coefficient
).encode()
760 val
= libisl
.isl_val_read_from_str(_main_ctx
, val
)
761 dim
= symbols
.index(symbol
)
762 cin
= libisl
.isl_constraint_set_coefficient_val(cin
, libisl
.isl_dim_set
, dim
, val
)
763 if inequality
.constant
!= 0:
764 val
= str(inequality
.constant
).encode()
765 val
= libisl
.isl_val_read_from_str(_main_ctx
, val
)
766 cin
= libisl
.isl_constraint_set_constant_val(cin
, val
)
767 bset
= libisl
.isl_basic_set_add_constraint(bset
, cin
)
768 bset
= isl
.BasicSet(bset
)
772 def _fromisl(cls
, bset
, symbols
):
773 raise NotImplementedError
776 return cls(equalities
, inequalities
)
777 '''takes basic set in isl form and puts back into python version of polyhedron
778 isl example code gives isl form as:
779 "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}")
780 our printer is giving form as:
781 { [i0, i1] : 2i1 >= -2 - i0 } '''
785 Universe
= Polyhedron()
788 if __name__
== '__main__':
789 #p = Polyhedron('2a + 2b + 1 == 0') # empty
790 p
= Polyhedron('3x + 2y + 3 == 0, y == 0') # not empty
793 print(ip
.constraints())