Improve Polyhedron.fromstring
[linpy.git] / pypol / linear.py
1 import ast
2 import functools
3 import numbers
4 import re
5
6 from fractions import Fraction, gcd
7
8 from pypol import isl
9 from pypol.isl import libisl
10
11
# Public API of the module, grouped by kind of object.
__all__ = [
    # linear expressions
    'Expression',
    'Constant',
    'Symbol',
    'symbols',
    # constraint builders
    'eq',
    'le',
    'lt',
    'ge',
    'gt',
    # polyhedra
    'Polyhedron',
    'Empty',
    'Universe',
]
18
19
def _polymorphic_method(func):
    """
    Decorate a binary Expression method so its second operand may be a number.

    Expression operands are passed through unchanged and rational numbers
    are wrapped in a Constant.  For any other operand the wrapped method
    returns NotImplemented, letting Python try the reflected operation.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(b, Expression):
            if not isinstance(b, numbers.Rational):
                return NotImplemented
            b = Constant(b)
        return func(a, b)
    return wrapper
30
def _polymorphic_operator(func):
    """
    Decorate a module-level binary operator so its left operand may be a number.

    Rational numbers are wrapped in a Constant and Expressions are passed
    through.  Anything else raises TypeError.
    """
    # Only the left operand is normalized: func is expected to call a method
    # decorated with _polymorphic_method, which takes care of the right one.
    @functools.wraps(func)
    def wrapper(a, b):
        if isinstance(a, numbers.Rational):
            a = Constant(a)
        elif not isinstance(a, Expression):
            raise TypeError('arguments must be linear expressions')
        return func(a, b)
    return wrapper
43
44
# Module-wide isl context, created once at import time; Polyhedron._toisl()
# allocates its isl spaces and values in this context.
_main_ctx = isl.Context()
46
47
class Expression:
    """
    This class implements linear expressions.

    An expression maps symbol names to rational coefficients and carries a
    rational constant term.  The constructor normalizes degenerate cases: an
    expression without non-zero coefficients is returned as a Constant, and a
    single symbol with coefficient 1 and no constant is returned as a Symbol.
    """

    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a linear expression.

        coefficients -- a string (parsed with fromstring()), a dict, or an
            iterable of (symbol, coefficient) pairs.  Symbols are strings or
            Symbol instances; coefficients are rational numbers or Constant
            instances.  Zero coefficients are dropped.
        constant -- a rational number or a Constant instance.

        Raises TypeError for invalid symbol, coefficient or constant types, or
        when a string is combined with a non-zero constant.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is None:
            return Constant(constant)
        coefficients = [(symbol, coefficient)
            for symbol, coefficient in coefficients if coefficient != 0]
        if len(coefficients) == 0:
            return Constant(constant)
        elif len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return Symbol(symbol)
        self = object.__new__(cls)
        self._coefficients = {}
        for symbol, coefficient in coefficients:
            if isinstance(symbol, Symbol):
                symbol = symbol.name
            elif not isinstance(symbol, str):
                raise TypeError('symbols must be strings or Symbol instances')
            if isinstance(coefficient, Constant):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers or Constant instances')
            self._coefficients[symbol] = coefficient
        if isinstance(constant, Constant):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number or a Constant instance')
        self._constant = constant
        self._symbols = tuple(sorted(self._coefficients))
        self._dimension = len(self._symbols)
        return self

    @classmethod
    def _fromast(cls, node):
        """
        Build an expression from a Python AST node.

        Supports names, integer literals, unary +/- and the binary operators
        + - * /.  Raises SyntaxError for any other construct.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            # Root node produced by ast.parse(..., mode='eval').
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif type(node).__name__ in ('Constant', 'Num'):
            # Python >= 3.8 produces ast.Constant; older versions produce
            # ast.Num (deprecated, removed in 3.14).  Matching by name avoids
            # touching the deprecated attribute.
            value = node.value if hasattr(node, 'value') else node.n
            # bool is a numbers.Number subclass but is not a valid literal here.
            if isinstance(value, numbers.Number) and not isinstance(value, bool):
                return Constant(value)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.UAdd):
            return +cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a linear expression such as '2x + 3y - 1'.

        Implicit multiplication is allowed between a number or a closing
        parenthesis and a symbol or an opening parenthesis ('2x', '3(x + 1)').
        Raises SyntaxError on invalid input.
        """
        # The lookbehind stops digits inside identifiers (e.g. 'a1b') from
        # being treated as a numeric factor.
        string = re.sub(r'((?<!\w)\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', string)
        # mode must be passed by keyword: the second positional parameter of
        # ast.parse is the filename.  Leading blanks would be an indent error.
        tree = ast.parse(string.strip(), mode='eval')
        return cls._fromast(tree)

    @property
    def symbols(self):
        """Sorted tuple of the names of the symbols with a non-zero coefficient."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols in the expression."""
        return self._dimension

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol (a string or a Symbol), 0 if absent.

        Raises TypeError if symbol is neither a string nor a Symbol.
        """
        if isinstance(symbol, Symbol):
            symbol = str(symbol)
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string or a Symbol instance')
        try:
            return self._coefficients[symbol]
        except KeyError:
            return 0

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, coefficient) pairs in symbol order."""
        for symbol in self.symbols:
            yield symbol, self.coefficient(symbol)

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    def isconstant(self):
        return False

    def values(self):
        """Yield every coefficient in symbol order, then the constant term."""
        for symbol in self.symbols:
            yield self.coefficient(symbol)
        yield self.constant

    def issymbol(self):
        return False

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] += coefficient
            else:
                coefficients[symbol] = coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] -= coefficient
            else:
                coefficients[symbol] = -coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic_method
    def __rsub__(self, other):
        # Reflected subtraction "other - self".  The decorator turns a number
        # into a Constant, and returns NotImplemented for unsupported types.
        return other - self

    @_polymorphic_method
    def __mul__(self, other):
        """
        Multiply by a constant.

        Raises ValueError when both operands are non-constant, because the
        product would not be linear.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] *= other.constant
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._parenstr(), other._parenstr()))
        # self is constant: let Python retry with other.__rmul__(self).
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        """
        Divide by a constant; coefficients become Fractions.

        Raises ValueError when dividing by a non-constant expression.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] = \
                        Fraction(coefficients[symbol], other.constant)
            constant = Fraction(self.constant, other.constant)
            return Expression(coefficients, constant)
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    @_polymorphic_method
    def __rtruediv__(self, other):
        """
        Reflected division "other / self", defined only when self is constant.

        Raises ValueError when self is not constant.
        """
        if self.isconstant():
            constant = Fraction(other.constant, self.constant)
            return Expression(constant=constant)
        raise ValueError('non-linear expression: '
                '{} / {}'.format(other._parenstr(), self._parenstr()))

    def __str__(self):
        string = ''
        i = 0
        for symbol in self.symbols:
            coefficient = self.coefficient(symbol)
            if coefficient == 1:
                if i == 0:
                    string += symbol
                else:
                    string += ' + {}'.format(symbol)
            elif coefficient == -1:
                if i == 0:
                    string += '-{}'.format(symbol)
                else:
                    string += ' - {}'.format(symbol)
            else:
                if i == 0:
                    string += '{}*{}'.format(coefficient, symbol)
                elif coefficient > 0:
                    string += ' + {}*{}'.format(coefficient, symbol)
                else:
                    assert coefficient < 0
                    coefficient *= -1
                    string += ' - {}*{}'.format(coefficient, symbol)
            i += 1
        constant = self.constant
        if constant != 0 and i == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            constant *= -1
            string += ' - {}'.format(constant)
        if string == '':
            string = '0'
        return string

    def _parenstr(self, always=False):
        """str(self), parenthesized unless it is a constant or a symbol (or always)."""
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    def __repr__(self):
        string = '{}({{'.format(self.__class__.__name__)
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if i != 0:
                string += ', '
            string += '{!r}: {!r}'.format(symbol, coefficient)
        string += '}}, {!r})'.format(self.constant)
        return string

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
                self._coefficients == other._coefficients and \
                self.constant == other.constant

    def __hash__(self):
        return hash((tuple(sorted(self._coefficients.items())), self._constant))

    def _toint(self):
        """Return self scaled by the lcm of its denominators, so every value is an integer."""
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
                [value.denominator for value in self.values()])
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        """Polyhedron of the points where self == other."""
        return Polyhedron(equalities=[(self - other)._toint()])

    @_polymorphic_method
    def __le__(self, other):
        return Polyhedron(inequalities=[(other - self)._toint()])

    @_polymorphic_method
    def __lt__(self, other):
        # Integer points: a < b  <=>  b - a - 1 >= 0 (after scaling to integers).
        return Polyhedron(inequalities=[(other - self)._toint() - 1])

    @_polymorphic_method
    def __ge__(self, other):
        return Polyhedron(inequalities=[(self - other)._toint()])

    @_polymorphic_method
    def __gt__(self, other):
        return Polyhedron(inequalities=[(self - other)._toint() - 1])
331
332
class Constant(Expression):
    """
    A linear expression with no symbols, i.e. a rational constant.
    """

    # All state lives in the slots declared by Expression; an empty __slots__
    # keeps instances from also carrying a per-instance __dict__.
    __slots__ = ()

    def __new__(cls, numerator=0, denominator=None):
        """
        Create a constant.

        With a single argument, numerator must be a rational number (stored
        as-is) or a Constant (its value is copied).  With two arguments the
        value is Fraction(numerator, denominator); Constant arguments are
        unwrapped first.

        Raises TypeError if a single argument is neither rational nor a Constant.
        """
        # object.__new__ is called directly: Expression.__new__ would return
        # Constant(0) for empty arguments and recurse back here.
        self = object.__new__(cls)
        if denominator is None:
            if isinstance(numerator, numbers.Rational):
                self._constant = numerator
            elif isinstance(numerator, Constant):
                self._constant = numerator.constant
            else:
                raise TypeError('constant must be a rational number or a Constant instance')
        else:
            if isinstance(numerator, Constant):
                numerator = numerator.constant
            if isinstance(denominator, Constant):
                denominator = denominator.constant
            self._constant = Fraction(numerator, denominator)
        self._coefficients = {}
        self._symbols = ()
        self._dimension = 0
        return self

    def isconstant(self):
        return True

    def __bool__(self):
        return bool(self.constant)

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._constant)
359
360
class Symbol(Expression):
    """
    A linear expression made of a single symbol with coefficient 1.
    """

    # Only the new attribute is declared.  Re-declaring the slots inherited
    # from Expression would shadow the base-class slots.
    __slots__ = (
        '_name',
    )

    def __new__(cls, name):
        """
        Create the symbol called name (a string, or a Symbol whose name is reused).

        Raises TypeError for any other type.
        """
        if isinstance(name, Symbol):
            name = name.name
        elif not isinstance(name, str):
            raise TypeError('name must be a string or a Symbol instance')
        self = object.__new__(cls)
        self._coefficients = {name: 1}
        self._constant = 0
        # A one-element tuple: tuple(name) would split a multi-character name
        # into its characters, e.g. ('a', 'b') for 'ab'.
        self._symbols = (name,)
        self._name = name
        self._dimension = 1
        return self

    @property
    def name(self):
        """The symbol's name."""
        return self._name

    def issymbol(self):
        return True

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._name)
389
def symbols(names):
    """
    Return a lazy generator of Symbols for the given names.

    names may be a string of names separated by commas and/or whitespace
    (e.g. 'x, y z'), or any iterable of names or Symbols.
    """
    if isinstance(names, str):
        # Treat every comma as a blank, then split on runs of whitespace.
        names = ' '.join(names.split(',')).split()
    return (Symbol(name) for name in names)
394
395
@_polymorphic_operator
def eq(a, b):
    """
    Return the polyhedron of the points where a == b.

    Expression.__eq__ is plain structural equality and returns a bool, so the
    constraint builder Expression._eq is called instead.
    """
    return a._eq(b)
399
@_polymorphic_operator
def le(a, b):
    """Return the polyhedron of the points where a <= b."""
    method = a.__le__
    return method(b)
403
@_polymorphic_operator
def lt(a, b):
    """Return the polyhedron of the points where a < b."""
    method = a.__lt__
    return method(b)
407
@_polymorphic_operator
def ge(a, b):
    """Return the polyhedron of the points where a >= b."""
    method = a.__ge__
    return method(b)
411
@_polymorphic_operator
def gt(a, b):
    """Return the polyhedron of the points where a > b."""
    method = a.__gt__
    return method(b)
415
416
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is a conjunction of integer linear constraints: every
    expression e in equalities stands for e == 0 and every expression e in
    inequalities stands for e >= 0.
    """

    __slots__ = (
        '_equalities',
        '_inequalities',
        '_constraints',
        '_symbols',
    )

    def __new__(cls, equalities=None, inequalities=None):
        """
        Create a polyhedron.

        equalities -- iterable of Expressions e meaning e == 0, or a string
            parsed with fromstring() (inequalities must then be omitted).
        inequalities -- iterable of Expressions e meaning e >= 0.

        Raises TypeError if a constraint has a non-integer coefficient or
        constant, or if a string is combined with inequalities.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._integerconstraints(equalities, '==')
        self._inequalities = cls._integerconstraints(inequalities, '>=')
        self._constraints = self._equalities + self._inequalities
        symbols = set()
        for constraint in self._constraints:
            symbols.update(constraint.symbols)
        self._symbols = tuple(sorted(symbols))
        return self

    @staticmethod
    def _integerconstraints(constraints, relation):
        # Return constraints as a tuple, checking that all values are
        # integers; relation ('==' or '>=') is only used in the error message.
        if constraints is None:
            return ()
        result = []
        for constraint in constraints:
            for value in constraint.values():
                if value.denominator != 1:
                    raise TypeError('non-integer constraint: '
                            '{} {} 0'.format(constraint, relation))
            result.append(constraint)
        return tuple(result)

    @classmethod
    def _fromast(cls, node):
        """
        Return (equalities, inequalities) lists for the AST built by fromstring().

        Constraints are scaled to integers the same way as Expression's
        comparison operators.  Raises SyntaxError for unsupported constructs.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            # Root node produced by ast.parse(..., mode='eval').
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitAnd):
            equalities1, inequalities1 = cls._fromast(node.left)
            equalities2, inequalities2 = cls._fromast(node.right)
            return equalities1 + equalities2, inequalities1 + inequalities2
        elif isinstance(node, ast.Compare):
            equalities = []
            inequalities = []
            left = Expression._fromast(node.left)
            for op, comparator in zip(node.ops, node.comparators):
                right = Expression._fromast(comparator)
                # Scale to integers *before* subtracting 1 for strict
                # inequalities; otherwise 'x/2 < 1' would become x <= 0.
                if isinstance(op, ast.Lt):
                    inequalities.append((right - left)._toint() - 1)
                elif isinstance(op, ast.LtE):
                    inequalities.append((right - left)._toint())
                elif isinstance(op, ast.Eq):
                    equalities.append((left - right)._toint())
                elif isinstance(op, ast.GtE):
                    inequalities.append((left - right)._toint())
                elif isinstance(op, ast.Gt):
                    inequalities.append((left - right)._toint() - 1)
                else:
                    break
                # Chained comparisons: 'a <= b < c' is 'a <= b and b < c'.
                left = right
            else:
                return equalities, inequalities
        raise SyntaxError('invalid syntax')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a polyhedron such as '{x >= 0, 2x + y = 3}'.

        Enclosing braces are optional.  Constraints are separated by ',', ';',
        'and' (any case), '&&', '/\\' or '∧'.  A single '=' means '=='.
        Implicit multiplication is allowed ('2x', '3(x + 1)').  An empty
        description gives the universe.  Raises SyntaxError on invalid input.
        """
        string = string.strip()
        string = re.sub(r'^\{\s*|\s*\}$', '', string)
        # Split first, so the implicit-multiplication rewrite cannot turn
        # '0 and y' into '0*and y'; \b stops 'and' matching inside names.
        parts = re.split(r',|;|\band\b|&&|/\\|∧', string, flags=re.I)
        tokens = []
        for token in parts:
            token = token.strip()
            if not token:
                # Tolerate empty descriptions and trailing separators.
                continue
            token = re.sub(r'(?<![<=>!])=(?![<=>])', '==', token)
            # The lookbehind stops digits inside identifiers (e.g. 'a1b')
            # from being treated as a numeric factor.
            token = re.sub(r'((?<!\w)\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', token)
            tokens.append('({})'.format(token))
        if not tokens:
            return cls()
        # mode must be passed by keyword: the second positional parameter of
        # ast.parse is the filename.
        tree = ast.parse(' & '.join(tokens), mode='eval')
        equalities, inequalities = cls._fromast(tree)
        return cls(equalities, inequalities)

    @property
    def equalities(self):
        """Tuple of Expressions e meaning e == 0."""
        return self._equalities

    @property
    def inequalities(self):
        """Tuple of Expressions e meaning e >= 0."""
        return self._inequalities

    @property
    def constraints(self):
        """Equalities followed by inequalities."""
        return self._constraints

    @property
    def symbols(self):
        """Sorted tuple of the symbol names used by the constraints."""
        return self._symbols

    @property
    def dimension(self):
        return len(self.symbols)

    def __bool__(self):
        return not self.isempty()

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        """
        Compare with another Polyhedron through isl_basic_set_plain_is_equal.

        Each polyhedron is converted over its own sorted symbols, so symbol
        names are not aligned: constraints are compared position-wise.
        """
        if not isinstance(other, Polyhedron):
            return NotImplemented
        bset = self._toisl()
        other = other._toisl()
        return bool(libisl.isl_basic_set_plain_is_equal(bset, other))

    def isempty(self):
        bset = self._toisl()
        return bool(libisl.isl_basic_set_is_empty(bset))

    def isuniverse(self):
        bset = self._toisl()
        return bool(libisl.isl_basic_set_is_universe(bset))

    def isdisjoint(self, other):
        """Return True if the polyhedron has no point in common with other."""
        # Both sides are converted over the union of their symbols, so each
        # symbol maps to the same isl dimension.
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_set_is_disjoint(bset, other))

    def issubset(self, other):
        """Return True if every point of the polyhedron is in other."""
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_set_is_subset(bset, other))

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        """Return True if the polyhedron is a strict subset of other."""
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_set_is_strict_subset(bset, other))

    def issuperset(self, other):
        """Return True if every point of other is in the polyhedron."""
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_set_is_subset(other, bset))

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        """Return True if other is a strict subset of the polyhedron."""
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        return bool(libisl.isl_set_is_strict_subset(other, bset))

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        """
        Return the polyhedron of the points common to self and all others.

        The constraint lists are simply concatenated.
        """
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        for other in others:
            equalities.extend(other.equalities)
            inequalities.extend(other.inequalities)
        return self.__class__(equalities, inequalities)

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, other):
        """
        Return the points of the polyhedron that are not in other.

        NOTE: the raw result of isl_set_subtract is returned, not a
        Polyhedron, because _fromisl() is not implemented yet.
        """
        symbols = self._symbolunion(other)
        bset = self._toisl(symbols)
        other = other._toisl(symbols)
        difference = libisl.isl_set_subtract(bset, other)
        return difference

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = []
        for constraint in self.equalities:
            constraints.append('{} == 0'.format(constraint))
        for constraint in self.inequalities:
            constraints.append('{} >= 0'.format(constraint))
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        if self.isempty():
            return 'Empty'
        elif self.isuniverse():
            return 'Universe'
        else:
            equalities = list(self.equalities)
            inequalities = list(self.inequalities)
            return '{}(equalities={!r}, inequalities={!r})' \
                ''.format(self.__class__.__name__, equalities, inequalities)

    def _symbolunion(self, *others):
        """Sorted list of the symbols used by self and all others."""
        symbols = set(self.symbols)
        for other in others:
            symbols.update(other.symbols)
        return sorted(symbols)

    def _toisl(self, symbols=None):
        """
        Convert to an isl.BasicSet.

        symbols -- ordered symbol names; symbols[i] becomes isl set dimension i.
            Defaults to self.symbols.
        """
        if symbols is None:
            symbols = self.symbols
        dimension = len(symbols)
        space = libisl.isl_space_set_alloc(_main_ctx, 0, dimension)
        bset = libisl.isl_basic_set_universe(libisl.isl_space_copy(space))
        ls = libisl.isl_local_space_from_space(space)
        # Equalities and inequalities only differ in the isl allocator.
        kinds = (
            (libisl.isl_equality_alloc, self.equalities),
            (libisl.isl_inequality_alloc, self.inequalities),
        )
        for alloc, constraints in kinds:
            for constraint in constraints:
                cisl = alloc(libisl.isl_local_space_copy(ls))
                for symbol, coefficient in constraint.coefficients():
                    val = str(coefficient).encode()
                    val = libisl.isl_val_read_from_str(_main_ctx, val)
                    dim = symbols.index(symbol)
                    cisl = libisl.isl_constraint_set_coefficient_val(
                            cisl, libisl.isl_dim_set, dim, val)
                if constraint.constant != 0:
                    val = str(constraint.constant).encode()
                    val = libisl.isl_val_read_from_str(_main_ctx, val)
                    cisl = libisl.isl_constraint_set_constant_val(cisl, val)
                bset = libisl.isl_basic_set_add_constraint(bset, cisl)
        bset = isl.BasicSet(bset)
        return bset

    @classmethod
    def _fromisl(cls, bset, symbols):
        """
        Convert an isl basic set back into a Polyhedron (not implemented yet).

        isl example code gives the isl form as
        "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}",
        while our printer gives "{ [i0, i1] : 2i1 >= -2 - i0 }".
        """
        raise NotImplementedError
692
# The empty polyhedron: the equality constraint "1 == 0" has no solution.
# (eq(0, 1) cannot be used as long as eq() goes through Expression.__eq__,
# which returns the bool False instead of a Polyhedron.)
Empty = Polyhedron(equalities=[Constant(1)])
694
# The universe: a polyhedron with no constraint at all.
Universe = Polyhedron(None, None)
696
697
if __name__ == '__main__':
    demos = (
        '2a + 2b + 1 == 0',  # empty: 2a + 2b is always even
        '3x + 2y + 3 == 0',  # not empty: e.g. x = -1, y = 0
    )
    for description in demos:
        print(Polyhedron(description)._toisl())