Check for SymPy presence in unit tests
[linpy.git] / pypol / linear.py
import ast
import functools
import numbers
import re

from fractions import Fraction
from math import gcd

from . import isl
from .isl import libisl
10
11
# Public API of the module: expression types and helpers, the comparison
# functions building constraints, and the polyhedron type with its two
# distinguished instances.
__all__ = [
    'Expression',
    'Constant',
    'Symbol',
    'symbols',
    'eq',
    'le',
    'lt',
    'ge',
    'gt',
    'Polyhedron',
    'Empty',
    'Universe',
]
18
19
def _polymorphic_method(func):
    """
    Decorate a binary Expression method so its right operand may also be a
    plain rational number.

    Expression operands are passed through untouched, rationals are wrapped
    in a Constant, and any other operand makes the wrapper return
    NotImplemented so that Python can try the reflected operation.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(b, Expression):
            if not isinstance(b, numbers.Rational):
                return NotImplemented
            b = Constant(b)
        return func(a, b)
    return wrapper
30
def _polymorphic_operator(func):
    """
    Decorate a module-level comparison function (eq, le, ...) so its left
    operand may be a rational number as well as an Expression.

    A polymorphic operator calls a polymorphic method, so only the left
    operand is checked here. The method validates the right operand and
    returns NotImplemented when it is unsupported; that sentinel is turned
    into a TypeError instead of being handed back to the caller.

    Raises TypeError when either operand is not a linear expression.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if isinstance(a, numbers.Rational):
            a = Constant(a)
        elif not isinstance(a, Expression):
            raise TypeError('arguments must be linear expressions')
        result = func(a, b)
        if result is NotImplemented:
            # The right operand was rejected by the polymorphic method.
            raise TypeError('arguments must be linear expressions')
        return result
    return wrapper
43
44
# Module-wide isl context; Polyhedron._toisl allocates its isl spaces and
# values in it.
_main_ctx = isl.Context()
46
47
class Expression:
    """
    This class implements linear expressions.

    An expression maps symbol names to rational coefficients and carries a
    rational constant term. Arithmetic operators build new expressions;
    the ordering operators (<, <=, >=, >) build Polyhedron constraints,
    while == is plain structural equality (use eq() or _eq() to build an
    equality constraint). Construction normalizes degenerate cases: an
    expression without non-zero coefficients is returned as a Constant and
    ``1*x + 0`` as a Symbol.
    """

    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a linear expression.

        coefficients is a string (parsed with fromstring), a dict, or an
        iterable of (symbol, coefficient) pairs; symbols are strings or
        Symbol instances and coefficients are rational numbers or Constant
        instances. constant is a rational number or a Constant instance.

        Raises TypeError on arguments of the wrong type.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is None:
            return Constant(constant)
        # Zero coefficients carry no information; dropping them first lets
        # the normalizations below see the real number of symbols.
        coefficients = [(symbol, coefficient)
                for symbol, coefficient in coefficients if coefficient != 0]
        if not coefficients:
            return Constant(constant)
        if len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return Symbol(symbol)
        self = object.__new__(cls)
        self._coefficients = {}
        for symbol, coefficient in coefficients:
            if isinstance(symbol, Symbol):
                symbol = symbol.name
            elif not isinstance(symbol, str):
                raise TypeError('symbols must be strings or Symbol instances')
            if isinstance(coefficient, Constant):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers or Constant instances')
            self._coefficients[symbol] = coefficient
        if isinstance(constant, Constant):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number or a Constant instance')
        self._constant = constant
        self._symbols = tuple(sorted(self._coefficients))
        self._dimension = len(self._symbols)
        return self

    @classmethod
    def _fromast(cls, node):
        """
        Build an expression from a Python AST node.

        Supported nodes are names, number literals, unary minus and the
        binary operators +, -, * and /. Raises SyntaxError otherwise.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            # Root node produced by ast.parse(..., mode='eval').
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif isinstance(node, getattr(ast, 'Constant', ())) and \
                isinstance(node.value, numbers.Number) and \
                not isinstance(node.value, bool):
            # Python >= 3.8 parses every literal as ast.Constant.
            return Constant(node.value)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        elif isinstance(node, getattr(ast, 'Num', ())):
            # Python < 3.8 parses numbers as ast.Num; that class no longer
            # exists in recent versions, hence the getattr fallback.
            return Constant(node.n)
        raise SyntaxError('invalid syntax')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a linear expression such as '3x + 2y - 1'.

        A number or closing parenthesis directly followed by a name or an
        opening parenthesis is read as a multiplication ('2x' is '2*x').
        Raises SyntaxError on unsupported syntax.
        """
        # \b stops digits inside identifiers (e.g. 'a1b') from being split.
        string = re.sub(r'(\b\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', string)
        tree = ast.parse(string, mode='eval')
        return cls._fromast(tree)

    @property
    def symbols(self):
        """Sorted tuple of the symbol names with a non-zero coefficient."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols in the expression."""
        return self._dimension

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol (a name or a Symbol), 0 if absent.

        Raises TypeError if symbol is neither a string nor a Symbol.
        """
        if isinstance(symbol, Symbol):
            symbol = str(symbol)
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string or a Symbol instance')
        try:
            return self._coefficients[symbol]
        except KeyError:
            return 0

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, coefficient) pairs in the order of self.symbols."""
        for symbol in self.symbols:
            yield symbol, self.coefficient(symbol)

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    def isconstant(self):
        return False

    def values(self):
        """Yield every coefficient in symbol order, then the constant."""
        for symbol in self.symbols:
            yield self.coefficient(symbol)
        yield self.constant

    def issymbol(self):
        return False

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] += coefficient
            else:
                coefficients[symbol] = coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] -= coefficient
            else:
                coefficients[symbol] = -coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic_method
    def __rsub__(self, other):
        # other has been converted to an Expression by the decorator; an
        # unsupported operand yields NotImplemented instead of failing on
        # the negation of NotImplemented.
        return other - self

    @_polymorphic_method
    def __mul__(self, other):
        """
        Multiply by a constant expression.

        Raises ValueError when both operands contain symbols (the product
        would not be linear).
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] *= other.constant
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._parenstr(), other._parenstr()))
        # self is a constant: let Python call other.__rmul__(self).
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        """
        Divide by a constant expression, producing Fraction values.

        Raises ValueError when other contains symbols.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] = \
                        Fraction(coefficients[symbol], other.constant)
            constant = Fraction(self.constant, other.constant)
            return Expression(coefficients, constant)
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    def __rtruediv__(self, other):
        """
        Compute other / self for a rational number other.

        Only defined when self is constant; raises ValueError otherwise.
        """
        if isinstance(other, numbers.Rational):
            if self.isconstant():
                constant = Fraction(other, self.constant)
                return Expression(constant=constant)
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(other, self._parenstr()))
        return NotImplemented

    def __str__(self):
        string = ''
        i = 0
        for symbol in self.symbols:
            coefficient = self.coefficient(symbol)
            if coefficient == 1:
                if i == 0:
                    string += symbol
                else:
                    string += ' + {}'.format(symbol)
            elif coefficient == -1:
                if i == 0:
                    string += '-{}'.format(symbol)
                else:
                    string += ' - {}'.format(symbol)
            else:
                if i == 0:
                    string += '{}*{}'.format(coefficient, symbol)
                elif coefficient > 0:
                    string += ' + {}*{}'.format(coefficient, symbol)
                else:
                    assert coefficient < 0
                    coefficient *= -1
                    string += ' - {}*{}'.format(coefficient, symbol)
            i += 1
        constant = self.constant
        if constant != 0 and i == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            constant *= -1
            string += ' - {}'.format(constant)
        if string == '':
            string = '0'
        return string

    def _parenstr(self, always=False):
        """str(self), parenthesized unless it is a constant or a symbol."""
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    def __repr__(self):
        string = '{}({{'.format(self.__class__.__name__)
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if i != 0:
                string += ', '
            string += '{!r}: {!r}'.format(symbol, coefficient)
        string += '}}, {!r})'.format(self.constant)
        return string

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" (structural) equality, returning a bool; use _eq to
        # build an equality constraint.
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
                self._coefficients == other._coefficients and \
                self.constant == other.constant

    def __hash__(self):
        return hash((tuple(sorted(self._coefficients.items())), self._constant))

    def _toint(self):
        """Return self scaled by the lcm of its denominators."""
        # values() always yields the constant, so the sequence is never empty.
        lcm = functools.reduce(lambda a, b: a * b // gcd(a, b),
                [value.denominator for value in self.values()])
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        """Polyhedron for the constraint self == other."""
        return Polyhedron(equalities=[(self - other)._toint()])

    @_polymorphic_method
    def __le__(self, other):
        return Polyhedron(inequalities=[(other - self)._toint()])

    @_polymorphic_method
    def __lt__(self, other):
        # Over integers, self < other is equivalent to other - self - 1 >= 0.
        return Polyhedron(inequalities=[(other - self)._toint() - 1])

    @_polymorphic_method
    def __ge__(self, other):
        return Polyhedron(inequalities=[(self - other)._toint()])

    @_polymorphic_method
    def __gt__(self, other):
        return Polyhedron(inequalities=[(self - other)._toint() - 1])

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a linear SymPy expression.

        Raises ValueError when expr contains a non-linear term.
        """
        import sympy
        coefficients = {}
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = symbol.name
                coefficients[symbol] = coefficient
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return cls(coefficients, constant)

    def tosympy(self):
        """Convert to a SymPy expression."""
        import sympy
        expr = 0
        for symbol, coefficient in self.coefficients():
            term = coefficient * sympy.Symbol(symbol)
            expr += term
        expr += self.constant
        return expr
356
357
class Constant(Expression):
    """
    A linear expression without any symbol, i.e. a rational constant.
    """

    def __new__(cls, numerator=0, denominator=None):
        """
        Build a constant.

        With one argument, numerator must be a rational number or a
        Constant. With two, the value is Fraction(numerator, denominator).
        Raises TypeError when a lone argument has another type.
        """
        self = object.__new__(cls)
        if denominator is not None:
            value = Fraction(numerator, denominator)
        elif isinstance(numerator, numbers.Rational):
            value = numerator
        elif isinstance(numerator, Constant):
            value = numerator.constant
        else:
            raise TypeError('constant must be a rational number or a Constant instance')
        self._constant = value
        self._coefficients = {}
        self._symbols = ()
        self._dimension = 0
        return self

    def isconstant(self):
        """Return True: a constant has no symbols."""
        return True

    def __bool__(self):
        """Truthiness follows the numeric value."""
        return bool(self.constant)

    def __repr__(self):
        name = type(self).__name__
        value = self.constant
        if value.denominator != 1:
            return '{}({!r}, {!r})'.format(name, value.numerator, value.denominator)
        return '{}({!r})'.format(name, value)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a SymPy rational (or a plain rational number).

        Raises TypeError for any other input.
        """
        import sympy
        if isinstance(expr, sympy.Rational):
            return cls(expr.p, expr.q)
        if isinstance(expr, numbers.Rational):
            return cls(expr)
        raise TypeError('expr must be a sympy.Rational instance')
398
399
class Symbol(Expression):
    """
    A linear expression made of a single symbol with coefficient 1 and no
    constant term.
    """

    # Only the new attribute: redeclaring the slots inherited from
    # Expression would shadow them (undefined behavior per the data model).
    __slots__ = (
        '_name',
    )

    def __new__(cls, name):
        """
        Create the symbol called name (a string or another Symbol).

        Raises TypeError for any other type of name.
        """
        if isinstance(name, Symbol):
            name = name.name
        elif not isinstance(name, str):
            raise TypeError('name must be a string or a Symbol instance')
        self = object.__new__(cls)
        self._coefficients = {name: 1}
        self._constant = 0
        # One-element tuple: tuple(name) would split the name into its
        # individual characters.
        self._symbols = (name,)
        self._name = name
        self._dimension = 1
        return self

    @property
    def name(self):
        """The symbol name."""
        return self._name

    def issymbol(self):
        return True

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._name)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a SymPy symbol.

        Raises TypeError for any other input.
        """
        import sympy
        if isinstance(expr, sympy.Symbol):
            return cls(expr.name)
        else:
            raise TypeError('expr must be a sympy.Symbol instance')
436
437
def symbols(names):
    """
    Return a generator of Symbol instances.

    names is an iterable of names (or Symbols), or a single string whose
    names are separated by commas and/or whitespace, e.g. 'x, y z'.
    """
    spelled = names.replace(',', ' ').split() if isinstance(names, str) else names
    return (Symbol(each) for each in spelled)
442
443
@_polymorphic_operator
def eq(a, b):
    """
    Build the constraint ``a == b`` as a Polyhedron.

    Expression.__eq__ is structural equality and returns a bool, so the
    constraint has to come from Expression._eq instead.
    """
    return a._eq(b)
447
@_polymorphic_operator
def le(a, b):
    """Build the constraint ``a <= b`` by delegating to ``a.__le__``."""
    compare = a.__le__
    return compare(b)
451
@_polymorphic_operator
def lt(a, b):
    """Build the constraint ``a < b`` by delegating to ``a.__lt__``."""
    compare = a.__lt__
    return compare(b)
455
@_polymorphic_operator
def ge(a, b):
    """Build the constraint ``a >= b`` by delegating to ``a.__ge__``."""
    compare = a.__ge__
    return compare(b)
459
@_polymorphic_operator
def gt(a, b):
    """Build the constraint ``a > b`` by delegating to ``a.__gt__``."""
    compare = a.__gt__
    return compare(b)
463
464
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is a conjunction of integer linear constraints: every
    expression in ``equalities`` stands for ``expr == 0`` and every
    expression in ``inequalities`` for ``expr >= 0``. Set-level queries
    (emptiness, inclusion, ...) are delegated to isl through ``_toisl``.
    """

    __slots__ = (
        '_equalities',
        '_inequalities',
        '_constraints',
        '_symbols',
    )

    def __new__(cls, equalities=None, inequalities=None):
        """
        Create a polyhedron from iterables of equality and inequality
        expressions, or parse it from a string when equalities is a str.

        Raises TypeError when a string is combined with inequalities, or
        when a constraint has a non-integer coefficient or constant.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._checkconstraints(equalities, '{} == 0')
        self._inequalities = cls._checkconstraints(inequalities, '{} >= 0')
        self._constraints = self._equalities + self._inequalities
        symbols = set()
        for constraint in self._constraints:
            symbols.update(constraint.symbols)
        self._symbols = tuple(sorted(symbols))
        return self

    @staticmethod
    def _checkconstraints(constraints, template):
        """
        Return constraints (None meaning none) as a tuple, after checking
        that all their values are integers; template renders the offending
        constraint in the TypeError message.
        """
        if constraints is None:
            return ()
        checked = []
        for constraint in constraints:
            for value in constraint.values():
                if value.denominator != 1:
                    raise TypeError('non-integer constraint: '
                            + template.format(constraint))
            checked.append(constraint)
        return tuple(checked)

    @classmethod
    def _fromast(cls, node):
        """
        Turn a parsed constraint string into (equalities, inequalities).

        Constraints are joined with '&' and comparisons may be chained
        ('0 <= x < 10'). Each constraint is scaled to integer values as the
        Expression comparison operators do. Raises SyntaxError otherwise.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            # Root node produced by ast.parse(..., mode='eval').
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitAnd):
            equalities1, inequalities1 = cls._fromast(node.left)
            equalities2, inequalities2 = cls._fromast(node.right)
            return equalities1 + equalities2, inequalities1 + inequalities2
        elif isinstance(node, ast.Compare):
            equalities = []
            inequalities = []
            left = Expression._fromast(node.left)
            for op, comparator in zip(node.ops, node.comparators):
                right = Expression._fromast(comparator)
                # Scale to integers *before* turning a strict comparison
                # into a non-strict one shifted by 1 (same as Expression.__lt__).
                if isinstance(op, ast.Lt):
                    inequalities.append((right - left)._toint() - 1)
                elif isinstance(op, ast.LtE):
                    inequalities.append((right - left)._toint())
                elif isinstance(op, ast.Eq):
                    equalities.append((left - right)._toint())
                elif isinstance(op, ast.GtE):
                    inequalities.append((left - right)._toint())
                elif isinstance(op, ast.Gt):
                    inequalities.append((left - right)._toint() - 1)
                else:
                    break
                left = right
            else:
                return equalities, inequalities
        raise SyntaxError('invalid syntax')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a polyhedron such as '{x >= 0, y <= 2x - 1}'.

        Braces are optional, '=' is accepted for '==', and constraints may
        be separated by ',', ';', 'and', '&&' or the other separators of
        the split pattern below. An empty string (or '{}') gives the
        universe. Raises SyntaxError on unsupported syntax.
        """
        string = string.strip()
        string = re.sub(r'^\{\s*|\s*\}$', '', string)
        if not string:
            return cls()
        string = re.sub(r'([^<=>])=([^<=>])', r'\1==\2', string)
        # \b stops digits inside identifiers (e.g. 'a1b') from being split.
        string = re.sub(r'(\b\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', string)
        # \b stops 'and' from matching inside names such as 'band'.
        tokens = re.split(r',|;|\band\b|&&|/\\|∧', string, flags=re.I)
        string = ' & '.join('({})'.format(token) for token in tokens)
        tree = ast.parse(string, mode='eval')
        equalities, inequalities = cls._fromast(tree)
        return cls(equalities, inequalities)

    @property
    def equalities(self):
        return self._equalities

    @property
    def inequalities(self):
        return self._inequalities

    @property
    def constraints(self):
        """Equalities followed by inequalities."""
        return self._constraints

    @property
    def symbols(self):
        """Sorted tuple of all symbol names used by the constraints."""
        return self._symbols

    @property
    def dimension(self):
        return len(self.symbols)

    def __bool__(self):
        return not self.isempty()

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        """
        Compare the isl representations of both polyhedra.

        Both sides are converted over the union of their symbols so that
        the isl spaces have matching dimensions. isl's "plain" equality
        only detects obvious equality, so equal sets written differently
        may compare unequal.
        """
        if not isinstance(other, Polyhedron):
            return NotImplemented
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_basic_set_plain_is_equal(bset, other))

    def isempty(self):
        bset = self._toisl()
        return bool(libisl.isl_basic_set_is_empty(bset))

    def isuniverse(self):
        bset = self._toisl()
        return bool(libisl.isl_basic_set_is_universe(bset))

    def isdisjoint(self, other):
        """Return True if the polyhedron has no element in common with other."""
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_set_is_disjoint(bset, other))

    def issubset(self, other):
        """Return True if every element of the polyhedron is in other."""
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_set_is_subset(bset, other))

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        """Return True if the polyhedron is a strict subset of other."""
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        # isl_set_is_strict_subset(set1, set2): is set1 strictly inside set2?
        return bool(libisl.isl_set_is_strict_subset(bset, other))

    def issuperset(self, other):
        """Return True if every element of other is in the polyhedron."""
        return other.issubset(self)

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        """Return True if other is a strict subset of the polyhedron."""
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_set_is_strict_subset(other, bset))

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        """
        Return the polyhedron of the elements common to the polyhedron and
        all others: the conjunction of all their constraints.
        """
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        for other in others:
            equalities.extend(other.equalities)
            inequalities.extend(other.inequalities)
        return self.__class__(equalities, inequalities)

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, other):
        """
        Elements of the polyhedron that are not in other.

        NOTE: returns the raw result of isl_set_subtract, not a Polyhedron,
        because _fromisl is not implemented yet.
        """
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        difference = libisl.isl_set_subtract(bset, other)
        return difference

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = []
        for constraint in self.equalities:
            constraints.append('{} == 0'.format(constraint))
        for constraint in self.inequalities:
            constraints.append('{} >= 0'.format(constraint))
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        if self.isempty():
            return 'Empty'
        elif self.isuniverse():
            return 'Universe'
        else:
            equalities = list(self.equalities)
            inequalities = list(self.inequalities)
            return '{}(equalities={!r}, inequalities={!r})' \
                    ''.format(self.__class__.__name__, equalities, inequalities)

    def _symbolunion(self, *others):
        """Sorted list of the symbols of the polyhedron and all others."""
        symbols = set(self.symbols)
        for other in others:
            symbols.update(other.symbols)
        return sorted(symbols)

    def _toisl(self, symbols=None):
        """
        Convert the polyhedron to an isl basic set.

        symbols fixes the order of the set dimensions (default:
        self.symbols); it must contain every symbol of the polyhedron.
        """
        if symbols is None:
            symbols = self.symbols
        space = libisl.isl_space_set_alloc(_main_ctx, 0, len(symbols))
        bset = libisl.isl_basic_set_universe(libisl.isl_space_copy(space))
        ls = libisl.isl_local_space_from_space(space)
        for equality in self.equalities:
            ceq = libisl.isl_equality_alloc(libisl.isl_local_space_copy(ls))
            ceq = self._fillconstraint(ceq, equality, symbols)
            bset = libisl.isl_basic_set_add_constraint(bset, ceq)
        for inequality in self.inequalities:
            cin = libisl.isl_inequality_alloc(libisl.isl_local_space_copy(ls))
            cin = self._fillconstraint(cin, inequality, symbols)
            bset = libisl.isl_basic_set_add_constraint(bset, cin)
        # Each constraint was given its own copy of ls; release ours.
        libisl.isl_local_space_free(ls)
        return isl.BasicSet(bset)

    @staticmethod
    def _fillconstraint(constraint, expression, symbols):
        """Copy the coefficients and constant of expression into an isl constraint."""
        for symbol, coefficient in expression.coefficients():
            val = str(coefficient).encode()
            val = libisl.isl_val_read_from_str(_main_ctx, val)
            dim = symbols.index(symbol)
            constraint = libisl.isl_constraint_set_coefficient_val(
                    constraint, libisl.isl_dim_set, dim, val)
        if expression.constant != 0:
            val = str(expression.constant).encode()
            val = libisl.isl_val_read_from_str(_main_ctx, val)
            constraint = libisl.isl_constraint_set_constant_val(constraint, val)
        return constraint

    @classmethod
    def _fromisl(cls, bset, symbols):
        """
        Convert an isl basic set back into a Polyhedron (not implemented).

        For reference, isl example code prints sets as
        "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}", while our
        printer gives "{ [i0, i1] : 2i1 >= -2 - i0 }".
        """
        raise NotImplementedError
740
# The empty polyhedron: the equality constraint 1 == 0 has no solution.
# Built directly so that it is always a Polyhedron rather than the result
# of an Expression comparison.
Empty = Polyhedron(equalities=[Constant(1)])
742
# The universe: a polyhedron without any constraint.
Universe = Polyhedron(equalities=None, inequalities=None)
744
745
if __name__ == '__main__':
    # Smoke test of the isl conversion. An empty example, kept for
    # reference: Polyhedron('2a + 2b + 1 == 0')
    polyhedron = Polyhedron('3x + 2y + 3 == 0, y == 0')  # not empty
    basic_set = polyhedron._toisl()
    print(basic_set)
    print(basic_set.constraints())