749f7c365e08567d671c72626ea19a9e0fbce27a
1 import ctypes
, ctypes
.util
5 from fractions
import Fraction
, gcd
8 from .isl
import libisl
12 'Expression', 'Constant', 'Symbol', 'symbols',
13 'eq', 'le', 'lt', 'ge', 'gt',
19 def _polymorphic_method(func
):
20 @functools.wraps(func
)
22 if isinstance(b
, Expression
):
24 if isinstance(b
, numbers
.Rational
):
30 def _polymorphic_operator(func
):
31 # A polymorphic operator should call a polymorphic method, hence we just
32 # have to test the left operand.
33 @functools.wraps(func
)
35 if isinstance(a
, numbers
.Rational
):
38 elif isinstance(a
, Expression
):
40 raise TypeError('arguments must be linear expressions')
44 _main_ctx
= isl
.Context()
49 This class implements linear expressions.
52 def __new__(cls
, coefficients
=None, constant
=0):
53 if isinstance(coefficients
, str):
55 raise TypeError('too many arguments')
56 return cls
.fromstring(coefficients
)
57 if isinstance(coefficients
, dict):
58 coefficients
= coefficients
.items()
59 if coefficients
is None:
60 return Constant(constant
)
61 coefficients
= [(symbol
, coefficient
)
62 for symbol
, coefficient
in coefficients
if coefficient
!= 0]
63 if len(coefficients
) == 0:
64 return Constant(constant
)
65 elif len(coefficients
) == 1 and constant
== 0:
66 symbol
, coefficient
= coefficients
[0]
69 self
= object().__new
__(cls
)
70 self
._coefficients
= {}
71 for symbol
, coefficient
in coefficients
:
72 if isinstance(symbol
, Symbol
):
74 elif not isinstance(symbol
, str):
75 raise TypeError('symbols must be strings or Symbol instances')
76 if isinstance(coefficient
, Constant
):
77 coefficient
= coefficient
.constant
78 if not isinstance(coefficient
, numbers
.Rational
):
79 raise TypeError('coefficients must be rational numbers or Constant instances')
80 self
._coefficients
[symbol
] = coefficient
81 if isinstance(constant
, Constant
):
82 constant
= constant
.constant
83 if not isinstance(constant
, numbers
.Rational
):
84 raise TypeError('constant must be a rational number or a Constant instance')
85 self
._constant
= constant
86 self
._symbols
= tuple(sorted(self
._coefficients
))
87 self
._dimension
= len(self
._symbols
)
91 def fromstring(cls
, string
):
92 raise NotImplementedError
100 return self
._dimension
102 def coefficient(self
, symbol
):
103 if isinstance(symbol
, Symbol
):
105 elif not isinstance(symbol
, str):
106 raise TypeError('symbol must be a string or a Symbol instance')
108 return self
._coefficients
[symbol
]
112 __getitem__
= coefficient
114 def coefficients(self
):
115 for symbol
in self
.symbols
:
116 yield symbol
, self
.coefficient(symbol
)
120 return self
._constant
122 def isconstant(self
):
126 for symbol
in self
.symbols
:
127 yield self
.coefficient(symbol
)
132 raise ValueError('not a symbol: {}'.format(self
))
147 def __add__(self
, other
):
148 coefficients
= dict(self
.coefficients())
149 for symbol
, coefficient
in other
.coefficients():
150 if symbol
in coefficients
:
151 coefficients
[symbol
] += coefficient
153 coefficients
[symbol
] = coefficient
154 constant
= self
.constant
+ other
.constant
155 return Expression(coefficients
, constant
)
160 def __sub__(self
, other
):
161 coefficients
= dict(self
.coefficients())
162 for symbol
, coefficient
in other
.coefficients():
163 if symbol
in coefficients
:
164 coefficients
[symbol
] -= coefficient
166 coefficients
[symbol
] = -coefficient
167 constant
= self
.constant
- other
.constant
168 return Expression(coefficients
, constant
)
170 def __rsub__(self
, other
):
171 return -(self
- other
)
174 def __mul__(self
, other
):
175 if other
.isconstant():
176 coefficients
= dict(self
.coefficients())
177 for symbol
in coefficients
:
178 coefficients
[symbol
] *= other
.constant
179 constant
= self
.constant
* other
.constant
180 return Expression(coefficients
, constant
)
181 if isinstance(other
, Expression
) and not self
.isconstant():
182 raise ValueError('non-linear expression: '
183 '{} * {}'.format(self
._parenstr
(), other
._parenstr
()))
184 return NotImplemented
189 def __truediv__(self
, other
):
190 if other
.isconstant():
191 coefficients
= dict(self
.coefficients())
192 for symbol
in coefficients
:
193 coefficients
[symbol
] = \
194 Fraction(coefficients
[symbol
], other
.constant
)
195 constant
= Fraction(self
.constant
, other
.constant
)
196 return Expression(coefficients
, constant
)
197 if isinstance(other
, Expression
):
198 raise ValueError('non-linear expression: '
199 '{} / {}'.format(self
._parenstr
(), other
._parenstr
()))
200 return NotImplemented
202 def __rtruediv__(self
, other
):
203 if isinstance(other
, self
):
204 if self
.isconstant():
205 constant
= Fraction(other
, self
.constant
)
206 return Expression(constant
=constant
)
208 raise ValueError('non-linear expression: '
209 '{} / {}'.format(other
._parenstr
(), self
._parenstr
()))
210 return NotImplemented
215 for symbol
in self
.symbols
:
216 coefficient
= self
.coefficient(symbol
)
221 string
+= ' + {}'.format(symbol
)
222 elif coefficient
== -1:
224 string
+= '-{}'.format(symbol
)
226 string
+= ' - {}'.format(symbol
)
229 string
+= '{}*{}'.format(coefficient
, symbol
)
230 elif coefficient
> 0:
231 string
+= ' + {}*{}'.format(coefficient
, symbol
)
233 assert coefficient
< 0
235 string
+= ' - {}*{}'.format(coefficient
, symbol
)
237 constant
= self
.constant
238 if constant
!= 0 and i
== 0:
239 string
+= '{}'.format(constant
)
241 string
+= ' + {}'.format(constant
)
244 string
+= ' - {}'.format(constant
)
249 def _parenstr(self
, always
=False):
251 if not always
and (self
.isconstant() or self
.issymbol()):
254 return '({})'.format(string
)
257 string
= '{}({{'.format(self
.__class
__.__name
__)
258 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
261 string
+= '{!r}: {!r}'.format(symbol
, coefficient
)
262 string
+= '}}, {!r})'.format(self
.constant
)
266 def __eq__(self
, other
):
268 # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
269 return isinstance(other
, Expression
) and \
270 self
._coefficients
== other
._coefficients
and \
271 self
.constant
== other
.constant
274 return hash((tuple(sorted(self
._coefficients
.items())), self
._constant
))
277 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
278 [value
.denominator
for value
in self
.values()])
282 def _eq(self
, other
):
283 return Polyhedron(equalities
=[(self
- other
)._toint
()])
286 def __le__(self
, other
):
287 return Polyhedron(inequalities
=[(other
- self
)._toint
()])
290 def __lt__(self
, other
):
291 return Polyhedron(inequalities
=[(other
- self
)._toint
() - 1])
294 def __ge__(self
, other
):
295 return Polyhedron(inequalities
=[(self
- other
)._toint
()])
298 def __gt__(self
, other
):
299 return Polyhedron(inequalities
=[(self
- other
)._toint
() - 1])
302 class Constant(Expression
):
304 def __new__(cls
, numerator
=0, denominator
=None):
305 self
= object().__new
__(cls
)
306 if denominator
is None:
307 if isinstance(numerator
, numbers
.Rational
):
308 self
._constant
= numerator
309 elif isinstance(numerator
, Constant
):
310 self
._constant
= numerator
.constant
312 raise TypeError('constant must be a rational number or a Constant instance')
314 self
._constant
= Fraction(numerator
, denominator
)
315 self
._coefficients
= {}
320 def isconstant(self
):
324 return bool(self
.constant
)
327 return '{}({!r})'.format(self
.__class
__.__name
__, self
._constant
)
330 class Symbol(Expression
):
332 def __new__(cls
, name
):
333 if isinstance(name
, Symbol
):
335 elif not isinstance(name
, str):
336 raise TypeError('name must be a string or a Symbol instance')
337 self
= object().__new
__(cls
)
338 self
._coefficients
= {name
: 1}
340 self
._symbols
= tuple(name
)
353 return '{}({!r})'.format(self
.__class
__.__name
__, self
._symbol
)
356 if isinstance(names
, str):
357 names
= names
.replace(',', ' ').split()
358 return (symbol(name
) for name
in names
)
361 @_polymorphic_operator
365 @_polymorphic_operator
369 @_polymorphic_operator
373 @_polymorphic_operator
377 @_polymorphic_operator
384 This class implements polyhedrons.
387 def __new__(cls
, equalities
=None, inequalities
=None):
388 if isinstance(equalities
, str):
389 if inequalities
is not None:
390 raise TypeError('too many arguments')
391 return cls
.fromstring(equalities
)
392 self
= super().__new
__(cls
)
393 self
._equalities
= []
394 if equalities
is not None:
395 for constraint
in equalities
:
396 for value
in constraint
.values():
397 if value
.denominator
!= 1:
398 raise TypeError('non-integer constraint: '
399 '{} == 0'.format(constraint
))
400 self
._equalities
.append(constraint
)
401 self
._equalities
= tuple(self
._equalities
)
402 self
._inequalities
= []
403 if inequalities
is not None:
404 for constraint
in inequalities
:
405 for value
in constraint
.values():
406 if value
.denominator
!= 1:
407 raise TypeError('non-integer constraint: '
408 '{} <= 0'.format(constraint
))
409 self
._inequalities
.append(constraint
)
410 self
._inequalities
= tuple(self
._inequalities
)
411 self
._constraints
= self
._equalities
+ self
._inequalities
412 self
._symbols
= set()
413 for constraint
in self
._constraints
:
414 self
.symbols
.update(constraint
.symbols
)
415 self
._symbols
= tuple(sorted(self
._symbols
))
419 def fromstring(cls
, string
):
420 raise NotImplementedError
423 def equalities(self
):
424 return self
._equalities
427 def inequalities(self
):
428 return self
._inequalities
431 def constraints(self
):
432 return self
._constraints
440 return len(self
.symbols
)
443 return not self
.is_empty()
445 def __contains__(self
, value
):
446 # is the value in the polyhedron?
447 raise NotImplementedError
449 def __eq__(self
, other
):
450 raise NotImplementedError
454 return bool(libisl
.isl_basic_set_is_empty(bset
))
456 def isuniverse(self
):
457 raise NotImplementedError
459 def isdisjoint(self
, other
):
460 # return true if the polyhedron has no elements in common with other
461 raise NotImplementedError
463 def issubset(self
, other
):
464 raise NotImplementedError
466 def __le__(self
, other
):
467 return self
.issubset(other
)
469 def __lt__(self
, other
):
470 raise NotImplementedError
472 def issuperset(self
, other
):
473 # test whether every element in other is in the polyhedron
474 raise NotImplementedError
476 def __ge__(self
, other
):
477 return self
.issuperset(other
)
479 def __gt__(self
, other
):
480 raise NotImplementedError
482 def union(self
, *others
):
483 # return a new polyhedron with elements from the polyhedron and all
484 # others (convex union)
485 raise NotImplementedError
487 def __or__(self
, other
):
488 return self
.union(other
)
490 def intersection(self
, *others
):
491 # return a new polyhedron with elements common to the polyhedron and all
493 # a poor man's implementation could be:
494 # equalities = list(self.equalities)
495 # inequalities = list(self.inequalities)
496 # for other in others:
497 # equalities.extend(other.equalities)
498 # inequalities.extend(other.inequalities)
499 # return self.__class__(equalities, inequalities)
500 raise NotImplementedError
502 def __and__(self
, other
):
503 return self
.intersection(other
)
505 def difference(self
, *others
):
506 # return a new polyhedron with elements in the polyhedron that are not
508 raise NotImplementedError
510 def __sub__(self
, other
):
511 return self
.difference(other
)
515 for constraint
in self
.equalities
:
516 constraints
.append('{} == 0'.format(constraint
))
517 for constraint
in self
.inequalities
:
518 constraints
.append('{} >= 0'.format(constraint
))
519 return '{{{}}}'.format(', '.join(constraints
))
522 equalities
= list(self
.equalities
)
523 inequalities
= list(self
.inequalities
)
524 return '{}(equalities={!r}, inequalities={!r})' \
525 ''.format(self
.__class
__.__name
__, equalities
, inequalities
)
527 def _symbolunion(self
, *others
):
528 symbols
= set(self
.symbols
)
530 symbols
.update(other
.symbols
)
531 return sorted(symbols
)
533 def _toisl(self
, symbols
=None):
535 symbols
= self
.symbols
536 num_coefficients
= len(symbols
)
537 space
= libisl
.isl_space_set_alloc(_main_ctx
, 0, num_coefficients
)
538 bset
= libisl
.isl_basic_set_universe(libisl
.isl_space_copy(space
))
539 ls
= libisl
.isl_local_space_from_space(space
)
540 ceq
= libisl
.isl_equality_alloc(libisl
.isl_local_space_copy(ls
))
541 cin
= libisl
.isl_inequality_alloc(libisl
.isl_local_space_copy(ls
))
542 '''if there are equalities/inequalities, take each constant and coefficient and add as a constraint to the basic set'''
543 if list(self
.equalities
): #check if any equalities exist
544 for eq
in self
.equalities
:
545 coeff_eq
= dict(eq
.coefficients())
548 ceq
= libisl
.isl_constraint_set_constant_si(ceq
, value
)
550 num
= coeff_eq
.get(eq
)
551 iden
= symbols
.index(eq
)
552 ceq
= libisl
.isl_constraint_set_coefficient_si(ceq
, libisl
.isl_dim_set
, iden
, num
) #use 3 for type isl_dim_set
553 bset
= libisl
.isl_basic_set_add_constraint(bset
, ceq
)
554 if list(self
.inequalities
): #check if any inequalities exist
555 for ineq
in self
.inequalities
:
556 coeff_in
= dict(ineq
.coefficients())
558 value
= ineq
.constant
559 cin
= libisl
.isl_constraint_set_constant_si(cin
, value
)
560 for ineq
in coeff_in
:
561 num
= coeff_in
.get(ineq
)
562 iden
= symbols
.index(ineq
)
563 cin
= libisl
.isl_constraint_set_coefficient_si(cin
, libisl
.isl_dim_set
, iden
, num
) #use 3 for type isl_dim_set
564 bset
= libisl
.isl_basic_set_add_constraint(bset
, cin
)
565 bset
= isl
.BasicSet(bset
)
569 def _fromisl(cls
, bset
):
570 raise NotImplementedError
573 return cls(equalities
, inequalities
)
574 '''takes basic set in isl form and puts back into python version of polyhedron
575 isl example code gives isl form as:
576 "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}")
577 our printer is giving form as:
578 b'{ [i0] : 1 = 0 }' '''
580 # if self._equalities:
581 # constraints = libisl.isl_basic_set_equalities_matrix(bset, 3)
582 # elif self._inequalities:
583 # constraints = libisl.isl_basic_set_inequalities_matrix(bset, 3)
587 empty
= None #eq(0,1)
588 universe
= None #Polyhedron()
591 if __name__
== '__main__':
592 ex1
= Expression(coefficients
={'a': 1, 'x': 2}, constant
=2)
593 ex2
= Expression(coefficients
={'a': 3 , 'b': 2}, constant
=3)
594 p
= Polyhedron(inequalities
=[ex1
, ex2
])
597 print('empty ?', p
.isempty())
598 print('empty ?', eq(0, 1).isempty())