1c400bd22cc8a275d096f4145622cac0e3575c55
[linpy.git] / pypol / domains.py
1 import ast
2 import functools
3 import re
4 import math
5
6 from fractions import Fraction
7
8 from . import islhelper
9 from .islhelper import mainctx, libisl
10 from .linexprs import Expression, Symbol, Rational
11 from .geometry import GeometricObject, Point, Vector
12
13
# Names exported by ``from <package>.domains import *``.
__all__ = ['Domain', 'And', 'Or', 'Not']
18
19
@functools.total_ordering
class Domain(GeometricObject):
    """
    A domain: the union of a finite number of polyhedra.

    Set operations are computed with isl.  The polyhedra are converted to
    an isl set over a common tuple of symbols, and the isl function is
    applied.  The result is converted back with _fromislset(), which
    returns:

    - the Empty domain when no polyhedron is left;
    - a plain Polyhedron when exactly one is left;
    - a Domain otherwise.
    """

    __slots__ = (
        '_polyhedra',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, *polyhedra):
        """
        Build a domain.

        - Domain(string): parse string with fromstring().
        - Domain(obj), obj a Domain: return obj itself.
        - Domain(obj), obj another GeometricObject: return
          obj.aspolyhedron().
        - Domain(p1, p2, ...): return the union of the Polyhedron arguments.
        - Domain(): return the Empty domain, like Or() with no argument.

        Raise TypeError for any other argument.
        """
        from .polyhedra import Polyhedron
        if len(polyhedra) == 1:
            argument = polyhedra[0]
            if isinstance(argument, str):
                return cls.fromstring(argument)
            elif isinstance(argument, Domain):
                # Taking the polyhedral hull of a domain would silently lose
                # its union structure.
                return argument
            elif isinstance(argument, GeometricObject):
                return argument.aspolyhedron()
            else:
                raise TypeError('argument must be a string '
                    'or a GeometricObject instance')
        if not polyhedra:
            # _toislset() needs at least one polyhedron; the empty union is
            # the Empty domain.
            from .polyhedra import Empty
            return Empty
        for polyhedron in polyhedra:
            if not isinstance(polyhedron, Polyhedron):
                raise TypeError('arguments must be Polyhedron instances')
        symbols = cls._xsymbols(polyhedra)
        islset = cls._toislset(polyhedra, symbols)
        return cls._fromislset(islset, symbols)

    @classmethod
    def _xsymbols(cls, iterator):
        """
        Return the sorted tuple of all symbols used by the items of iterator.

        Symbols are ordered with Symbol.sortkey.  The index of a symbol in
        this tuple is the isl set dimension it is mapped to (see sample()
        and points()).
        """
        symbols = set()
        for item in iterator:
            symbols.update(item.symbols)
        return tuple(sorted(symbols, key=Symbol.sortkey))

    @property
    def polyhedra(self):
        """The tuple of polyhedra whose union is this domain."""
        return self._polyhedra

    @property
    def symbols(self):
        """The sorted tuple of symbols the domain is expressed over."""
        return self._symbols

    @property
    def dimension(self):
        """The number of symbols of the domain."""
        return self._dimension

    def _islset_check(self, function):
        """Apply a one-set isl predicate to this domain and return a bool."""
        islset = self._toislset(self.polyhedra, self.symbols)
        try:
            return bool(function(islset))
        finally:
            libisl.isl_set_free(islset)

    def _islset_compare(self, other, function):
        """
        Apply a two-set isl predicate to self and other and return a bool.

        Both sets are built over the union of the two symbol tuples so that
        their dimensions line up.
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        try:
            return bool(function(islset1, islset2))
        finally:
            libisl.isl_set_free(islset1)
            libisl.isl_set_free(islset2)

    def disjoint(self):
        """
        Return an equivalent domain computed by isl_set_make_disjoint().
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        # isl_set_make_disjoint() takes the set as its only argument.
        islset = libisl.isl_set_make_disjoint(islset)
        return self._fromislset(islset, self.symbols)

    def isempty(self):
        """Return True if the domain contains no point."""
        return self._islset_check(libisl.isl_set_is_empty)

    def __bool__(self):
        """A domain is truthy when it is not empty."""
        return not self.isempty()

    def isuniverse(self):
        """
        Return True if the domain is obviously the universe.

        This uses isl's "plain" check, which only detects obvious cases and
        may return False for a universe written in a non-trivial way.
        """
        return self._islset_check(libisl.isl_set_plain_is_universe)

    def isbounded(self):
        """Return True if the domain is bounded."""
        return self._islset_check(libisl.isl_set_is_bounded)

    def __eq__(self, other):
        """Return True if self and other contain exactly the same points."""
        if not isinstance(other, Domain):
            return NotImplemented
        return self._islset_compare(other, libisl.isl_set_is_equal)

    def isdisjoint(self, other):
        """Return True if self and other have an empty intersection."""
        return self._islset_compare(other, libisl.isl_set_is_disjoint)

    def issubset(self, other):
        """Return True if every point of self is in other."""
        return self._islset_compare(other, libisl.isl_set_is_subset)

    def __le__(self, other):
        """Return True if self is a subset of other."""
        if not isinstance(other, Domain):
            return NotImplemented
        return self.issubset(other)

    def __lt__(self, other):
        """Return True if self is a strict subset of other."""
        if not isinstance(other, Domain):
            return NotImplemented
        return self._islset_compare(other, libisl.isl_set_is_strict_subset)

    def __ge__(self, other):
        """
        Return True if self is a superset of other.

        Defined explicitly because functools.total_ordering would otherwise
        derive it as "not self < other".  That is wrong for inclusion, which
        is only a partial order: two incomparable domains would compare as
        superset of each other.
        """
        if not isinstance(other, Domain):
            return NotImplemented
        return other.issubset(self)

    def __gt__(self, other):
        """
        Return True if self is a strict superset of other.

        Defined explicitly for the same reason as __ge__().
        """
        if not isinstance(other, Domain):
            return NotImplemented
        return other._islset_compare(self, libisl.isl_set_is_strict_subset)

    def complement(self):
        """Return the complement of the domain."""
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_complement(islset)
        return self._fromislset(islset, self.symbols)

    def __invert__(self):
        """Return the complement of the domain (operator ~)."""
        return self.complement()

    def simplify(self):
        """Return an equivalent domain without redundant constraints."""
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_remove_redundancies(islset)
        return self._fromislset(islset, self.symbols)

    def aspolyhedron(self):
        """Return the polyhedral hull of the domain as a Polyhedron."""
        from .polyhedra import Polyhedron
        islset = self._toislset(self.polyhedra, self.symbols)
        islbset = libisl.isl_set_polyhedral_hull(islset)
        return Polyhedron._fromislbasicset(islbset, self.symbols)

    def asdomain(self):
        """Return self: a domain is already a domain."""
        return self

    def project(self, dims):
        """
        Return the domain with the symbols in dims projected out.

        The result is expressed over the remaining symbols.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        # Walk the symbols right to left, projecting out each run of
        # consecutive symbols to remove in a single isl call.  Going right to
        # left keeps the indices of the not-yet-processed runs valid.
        n = 0
        for index, symbol in reversed(list(enumerate(self.symbols))):
            if symbol in dims:
                n += 1
            elif n > 0:
                islset = libisl.isl_set_project_out(islset,
                    libisl.isl_dim_set, index + 1, n)
                n = 0
        if n > 0:
            # A run of symbols to remove that starts at dimension 0.
            islset = libisl.isl_set_project_out(islset,
                libisl.isl_dim_set, 0, n)
        dims = [symbol for symbol in self.symbols if symbol not in dims]
        return Domain._fromislset(islset, dims)

    def sample(self):
        """
        Return one point of the domain as a {symbol: coordinate} dict.

        Raise ValueError if the domain is empty.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islpoint = libisl.isl_set_sample_point(islset)
        if bool(libisl.isl_point_is_void(islpoint)):
            libisl.isl_point_free(islpoint)
            raise ValueError('domain must be non-empty')
        point = {}
        for index, symbol in enumerate(self.symbols):
            coordinate = libisl.isl_point_get_coordinate_val(islpoint,
                libisl.isl_dim_set, index)
            coordinate = islhelper.isl_val_to_int(coordinate)
            point[symbol] = coordinate
        libisl.isl_point_free(islpoint)
        return point

    def intersection(self, *others):
        """
        Return the intersection of self and all the others.

        With no other domain, return self.
        """
        if len(others) == 0:
            return self
        symbols = self._xsymbols((self,) + others)
        islset1 = self._toislset(self.polyhedra, symbols)
        for other in others:
            islset2 = other._toislset(other.polyhedra, symbols)
            islset1 = libisl.isl_set_intersect(islset1, islset2)
        return self._fromislset(islset1, symbols)

    def __and__(self, other):
        """Return the intersection of self and other (operator &)."""
        return self.intersection(other)

    def union(self, *others):
        """
        Return the union of self and all the others.

        With no other domain, return self.
        """
        if len(others) == 0:
            return self
        symbols = self._xsymbols((self,) + others)
        islset1 = self._toislset(self.polyhedra, symbols)
        for other in others:
            islset2 = other._toislset(other.polyhedra, symbols)
            islset1 = libisl.isl_set_union(islset1, islset2)
        return self._fromislset(islset1, symbols)

    def __or__(self, other):
        """Return the union of self and other (operator |)."""
        return self.union(other)

    def __add__(self, other):
        """Return the union of self and other (operator +)."""
        return self.union(other)

    def difference(self, other):
        """Return the points of self that are not in other."""
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        islset = libisl.isl_set_subtract(islset1, islset2)
        return self._fromislset(islset, symbols)

    def __sub__(self, other):
        """Return the difference of self and other (operator -)."""
        return self.difference(other)

    def lexmin(self):
        """Return the lexicographic minimum of the domain, as a domain."""
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_lexmin(islset)
        return self._fromislset(islset, self.symbols)

    def lexmax(self):
        """Return the lexicographic maximum of the domain, as a domain."""
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_lexmax(islset)
        return self._fromislset(islset, self.symbols)

    def num_parameters(self):
        """
        Return the number of set dimensions of the domain, as reported by
        isl_set_dim(..., isl_dim_set) on the set built over self.symbols.
        """
        # Build a full isl set (not a basic set) so this also works for a
        # union of several polyhedra.
        islset = self._toislset(self.polyhedra, self.symbols)
        try:
            return libisl.isl_set_dim(islset, libisl.isl_dim_set)
        finally:
            libisl.isl_set_free(islset)

    def involves_dims(self, dims):
        """
        Return True if the domain's constraints depend on any symbol of dims.

        Symbols of dims that the domain is not expressed over are ignored.
        """
        dims = set(dims)
        # Use the isl dimension order, i.e. the order of self.symbols.
        indices = [index for index, symbol in enumerate(self.symbols)
            if symbol in dims]
        if not indices:
            return False
        islset = self._toislset(self.polyhedra, self.symbols)
        try:
            # isl_set_involves_dims() tests a contiguous range of
            # dimensions; query each one since dims need not be adjacent.
            return any(bool(libisl.isl_set_involves_dims(islset,
                libisl.isl_dim_set, index, 1)) for index in indices)
        finally:
            libisl.isl_set_free(islset)

    # Matches one vertex coordinate in isl's string form, e.g. '(3)' or
    # '(-1)/2'.
    _RE_COORDINATE = re.compile(r'\((?P<num>\-?\d+)\)(/(?P<den>\d+))?')

    def vertices(self):
        """
        Return the vertices of the domain as a list of Points with Fraction
        coordinates.

        Raise ValueError if the domain is unbounded.

        NOTE(review): this reads self.equalities, self.inequalities and
        self._toislbasicset(), which Domain itself does not define.  It is
        presumably only usable on Polyhedron instances -- confirm.
        """
        if not self.isbounded():
            raise ValueError('domain must be bounded')
        islbset = self._toislbasicset(self.equalities, self.inequalities,
            self.symbols)
        vertices = libisl.isl_basic_set_compute_vertices(islbset)
        vertices = islhelper.isl_vertices_vertices(vertices)
        points = []
        for vertex in vertices:
            expr = libisl.isl_vertex_get_expr(vertex)
            coordinates = []
            # NOTE(review): plain string comparison of version numbers;
            # confirm the format of islhelper.isl_version.
            if islhelper.isl_version < '0.13':
                # The vertex is given as constraints of the form
                # coefficient * symbol + constant = 0.
                constraints = islhelper.isl_basic_set_constraints(expr)
                for constraint in constraints:
                    constant = libisl.isl_constraint_get_constant_val(
                        constraint)
                    constant = islhelper.isl_val_to_int(constant)
                    for index, symbol in enumerate(self.symbols):
                        coefficient = libisl.isl_constraint_get_coefficient_val(
                            constraint, libisl.isl_dim_set, index)
                        coefficient = islhelper.isl_val_to_int(coefficient)
                        if coefficient != 0:
                            coordinate = -Fraction(constant, coefficient)
                            coordinates.append((symbol, coordinate))
            else:
                # The vertex is given as a multi-affine expression; parse
                # its coordinates from the string form.
                string = islhelper.isl_multi_aff_to_str(expr)
                matches = self._RE_COORDINATE.finditer(string)
                for symbol, match in zip(self.symbols, matches):
                    numerator = int(match.group('num'))
                    denominator = match.group('den')
                    denominator = 1 if denominator is None else int(denominator)
                    coordinate = Fraction(numerator, denominator)
                    coordinates.append((symbol, coordinate))
            points.append(Point(coordinates))
        return points

    def points(self):
        """
        Return the integer points of the domain as a list of Points.

        Raise ValueError if the domain is unbounded.
        """
        if not self.isbounded():
            raise ValueError('domain must be bounded')
        islset = self._toislset(self.polyhedra, self.symbols)
        islpoints = islhelper.isl_set_points(islset)
        points = []
        for islpoint in islpoints:
            coordinates = {}
            for index, symbol in enumerate(self.symbols):
                coordinate = libisl.isl_point_get_coordinate_val(islpoint,
                    libisl.isl_dim_set, index)
                coordinate = islhelper.isl_val_to_int(coordinate)
                coordinates[symbol] = coordinate
            points.append(Point(coordinates))
        return points

    @classmethod
    def _polygon_inner_point(cls, points):
        """Return the mean of points, which must be non-empty."""
        symbols = points[0].symbols
        coordinates = {symbol: 0 for symbol in symbols}
        for point in points:
            for symbol, coordinate in point.coordinates():
                coordinates[symbol] += coordinate
        for symbol in symbols:
            coordinates[symbol] /= len(points)
        return Point(coordinates)

    @classmethod
    def _sort_polygon_2d(cls, points):
        """
        Return 2D points sorted by angle around their mean point.

        Used to order the vertices of a polygon for plotting.  Three points
        or fewer are returned unchanged.
        """
        if len(points) <= 3:
            return points
        o = cls._polygon_inner_point(points)
        angles = {}
        for m in points:
            om = Vector(o, m)
            dx, dy = (coordinate for symbol, coordinate in om.coordinates())
            angle = math.atan2(dy, dx)
            angles[m] = angle
        return sorted(points, key=angles.get)

    @classmethod
    def _sort_polygon_3d(cls, points):
        """
        Return the 3D points of a polygon sorted by angle around their mean
        point.

        Angles are measured from the first point and signed by the
        polygon's normal.  Three points or fewer are returned unchanged.
        Raise ValueError if all points are aligned with the mean point.
        """
        if len(points) <= 3:
            return points
        o = cls._polygon_inner_point(points)
        a = points[0]
        oa = Vector(o, a)
        norm_oa = oa.norm()
        for b in points[1:]:
            ob = Vector(o, b)
            u = oa.cross(ob)
            if not u.isnull():
                u = u.asunit()
                break
        else:
            raise ValueError('degenerate polygon')
        angles = {a: 0.}
        for m in points[1:]:
            om = Vector(o, m)
            normprod = norm_oa * om.norm()
            # Clamp to [-1, 1]: rounding can push the ratio slightly outside
            # that range, which makes math.acos() raise ValueError.
            cosinus = min(max(oa.dot(om) / normprod, -1.), 1.)
            sinus = u.dot(oa.cross(om)) / normprod
            angle = math.acos(cosinus)
            angle = math.copysign(angle, sinus)
            angles[m] = angle
        return sorted(points, key=angles.get)

    def faces(self):
        """
        Return the faces of the domain's polyhedra as lists of vertices.

        A face is made of the vertices that satisfy one constraint with
        equality.  It is kept only when it has at least three vertices.
        """
        faces = []
        for polyhedron in self.polyhedra:
            vertices = polyhedron.vertices()
            for constraint in polyhedron.constraints:
                face = []
                for vertex in vertices:
                    if constraint.subs(vertex.coordinates()) == 0:
                        face.append(vertex)
                if len(face) >= 3:
                    faces.append(face)
        return faces

    def _plot_2d(self, plot=None, **kwargs):
        """
        Draw each polyhedron as a matplotlib Polygon on plot.

        A new figure is created when plot is None.  The axis limits are
        widened to fit the vertices.  Return the axes.
        """
        import matplotlib.pyplot as plt
        from matplotlib.patches import Polygon
        if plot is None:
            fig = plt.figure()
            plot = fig.add_subplot(1, 1, 1)
        xmin, xmax = plot.get_xlim()
        ymin, ymax = plot.get_ylim()
        for polyhedron in self.polyhedra:
            vertices = polyhedron._sort_polygon_2d(polyhedron.vertices())
            xys = [tuple(vertex.values()) for vertex in vertices]
            xs, ys = zip(*xys)
            xmin, xmax = min(xmin, float(min(xs))), max(xmax, float(max(xs)))
            ymin, ymax = min(ymin, float(min(ys))), max(ymax, float(max(ys)))
            plot.add_patch(Polygon(xys, closed=True, **kwargs))
        plot.set_xlim(xmin, xmax)
        plot.set_ylim(ymin, ymax)
        return plot

    def _plot_3d(self, plot=None, **kwargs):
        """
        Draw the faces of the domain as a Poly3DCollection on 3D axes.

        New 3D axes are created when plot is None.  The axis limits are
        widened to fit the vertices.  Return the axes.
        """
        import matplotlib.pyplot as plt
        # Importing Axes3D registers the '3d' projection on old matplotlib.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
        from mpl_toolkits.mplot3d.art3d import Poly3DCollection
        if plot is None:
            fig = plt.figure()
            # Axes3D(fig) no longer attaches itself to the figure in recent
            # matplotlib versions; add_subplot() works across versions.
            axes = fig.add_subplot(1, 1, 1, projection='3d')
        else:
            axes = plot
        xmin, xmax = axes.get_xlim()
        ymin, ymax = axes.get_ylim()
        zmin, zmax = axes.get_zlim()
        poly_xyzs = []
        for vertices in self.faces():
            vertices = self._sort_polygon_3d(vertices)
            # Close the polygon outline.
            vertices.append(vertices[0])
            face_xyzs = [tuple(vertex.values()) for vertex in vertices]
            xs, ys, zs = zip(*face_xyzs)
            xmin, xmax = min(xmin, float(min(xs))), max(xmax, float(max(xs)))
            ymin, ymax = min(ymin, float(min(ys))), max(ymax, float(max(ys)))
            zmin, zmax = min(zmin, float(min(zs))), max(zmax, float(max(zs)))
            poly_xyzs.append(face_xyzs)
        collection = Poly3DCollection(poly_xyzs, **kwargs)
        axes.add_collection3d(collection)
        axes.set_xlim(xmin, xmax)
        axes.set_ylim(ymin, ymax)
        axes.set_zlim(zmin, zmax)
        return axes

    def plot(self, plot=None, **kwargs):
        """
        Plot the domain with matplotlib and return the axes.

        Extra keyword arguments are forwarded to the matplotlib patch or
        collection.  Raise ValueError unless the domain is bounded and two-
        or three-dimensional.
        """
        if not self.isbounded():
            raise ValueError('domain must be bounded')
        elif self.dimension == 2:
            return self._plot_2d(plot=plot, **kwargs)
        elif self.dimension == 3:
            return self._plot_3d(plot=plot, **kwargs)
        else:
            raise ValueError('polyhedron must be 2 or 3-dimensional')

    def __contains__(self, point):
        """Return True if point lies in at least one of the polyhedra."""
        return any(point in polyhedron for polyhedron in self.polyhedra)

    def subs(self, symbol, expression=None):
        """
        Substitute in every polyhedron and return the resulting domain.

        The arguments are forwarded unchanged to each polyhedron's subs().
        """
        polyhedra = [polyhedron.subs(symbol, expression)
            for polyhedron in self.polyhedra]
        return Domain(*polyhedra)

    @classmethod
    def _fromislset(cls, islset, symbols):
        """
        Convert an isl set into Empty, a Polyhedron, or a Domain.

        Takes ownership of islset.  Existentially quantified variables are
        removed first (isl_set_remove_divs).  The symbols of a resulting
        Domain are recomputed from its polyhedra.
        """
        from .polyhedra import Polyhedron
        islset = libisl.isl_set_remove_divs(islset)
        islbsets = islhelper.isl_set_basic_sets(islset)
        libisl.isl_set_free(islset)
        polyhedra = []
        for islbset in islbsets:
            polyhedron = Polyhedron._fromislbasicset(islbset, symbols)
            polyhedra.append(polyhedron)
        if len(polyhedra) == 0:
            from .polyhedra import Empty
            return Empty
        elif len(polyhedra) == 1:
            return polyhedra[0]
        else:
            # Bypass Domain.__new__, which would rebuild the isl set.
            self = object.__new__(Domain)
            self._polyhedra = tuple(polyhedra)
            self._symbols = cls._xsymbols(polyhedra)
            self._dimension = len(self._symbols)
            return self

    @classmethod
    def _toislset(cls, polyhedra, symbols):
        """
        Return a new isl set that is the union of polyhedra over symbols.

        polyhedra must be non-empty.  The caller owns the returned set.
        """
        polyhedron = polyhedra[0]
        islbset = polyhedron._toislbasicset(polyhedron.equalities,
            polyhedron.inequalities, symbols)
        islset1 = libisl.isl_set_from_basic_set(islbset)
        for polyhedron in polyhedra[1:]:
            islbset = polyhedron._toislbasicset(polyhedron.equalities,
                polyhedron.inequalities, symbols)
            islset2 = libisl.isl_set_from_basic_set(islbset)
            islset1 = libisl.isl_set_union(islset1, islset2)
        return islset1

    @classmethod
    def _fromast(cls, node):
        """
        Build a domain from a Python AST produced by fromstring().

        Supported forms:

        - '~' for Not;
        - '&' and '|' for And and Or;
        - chained comparisons with <, <=, ==, >=, >.  Strict comparisons
          are turned into integer inequalities (a < b becomes b - a - 1 >= 0).

        Raise SyntaxError on anything else.
        """
        from .polyhedra import Polyhedron
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.UnaryOp):
            # The operator lives in node.op; ast.Invert is the '~' node.
            if isinstance(node.op, ast.Invert):
                return Not(cls._fromast(node.operand))
        elif isinstance(node, ast.BinOp):
            domain1 = cls._fromast(node.left)
            domain2 = cls._fromast(node.right)
            if isinstance(node.op, ast.BitAnd):
                return And(domain1, domain2)
            elif isinstance(node.op, ast.BitOr):
                return Or(domain1, domain2)
        elif isinstance(node, ast.Compare):
            equalities = []
            inequalities = []
            left = Expression._fromast(node.left)
            for op, comparator in zip(node.ops, node.comparators):
                right = Expression._fromast(comparator)
                if isinstance(op, ast.Lt):
                    inequalities.append(right - left - 1)
                elif isinstance(op, ast.LtE):
                    inequalities.append(right - left)
                elif isinstance(op, ast.Eq):
                    equalities.append(left - right)
                elif isinstance(op, ast.GtE):
                    inequalities.append(left - right)
                elif isinstance(op, ast.Gt):
                    inequalities.append(left - right - 1)
                else:
                    break
                left = right
            else:
                return Polyhedron(equalities, inequalities)
        raise SyntaxError('invalid syntax')

    # Surrounding curly braces, e.g. '{ x > 0 }'.
    _RE_BRACES = re.compile(r'^\{\s*|\s*\}$')
    # A single '=' that is not part of '<=', '>=' or '=='.
    _RE_EQ = re.compile(r'([^<=>])=([^<=>])')
    _RE_AND = re.compile(r'\band\b|,|&&|/\\|∧|∩')
    _RE_OR = re.compile(r'\bor\b|;|\|\||\\/|∨|∪')
    _RE_NOT = re.compile(r'\bnot\b|!|¬')
    _RE_NUM_VAR = Expression._RE_NUM_VAR
    # Boolean operators after normalization; the group keeps them in split().
    _RE_OPERATORS = re.compile(r'(&|\||~)')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a domain from a string such as '{x >= 0, y > x; x = 5}'.

        The string is rewritten into a Python expression and parsed with
        the ast module:

        - curly braces are optional;
        - a single '=' denotes equality;
        - 'and' or ',' (and similar symbols) denote conjunction;
        - 'or' or ';' (and similar symbols) denote disjunction;
        - 'not', '!' or '¬' denote negation;
        - '5x' means 5*x.

        Raise SyntaxError if the result is not a valid domain expression.
        """
        # remove curly brackets
        string = cls._RE_BRACES.sub(r'', string)
        # replace '=' by '=='
        string = cls._RE_EQ.sub(r'\1==\2', string)
        # replace 'and', 'or', 'not'
        string = cls._RE_AND.sub(r' & ', string)
        string = cls._RE_OR.sub(r' | ', string)
        string = cls._RE_NOT.sub(r' ~', string)
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = cls._RE_NUM_VAR.sub(r'\1*\2', string)
        # Parenthesize operands to force precedence.  Even indices are
        # operands, odd indices are the captured operators.  Blank operands
        # occur around the unary '~' and must be dropped: wrapping them
        # would yield '()~(...)', which does not parse.
        tokens = cls._RE_OPERATORS.split(string)
        for i, token in enumerate(tokens):
            if i % 2 == 0:
                token = token.strip()
                if token:
                    token = '({})'.format(token)
                tokens[i] = token
        string = ''.join(tokens)
        tree = ast.parse(string, mode='eval')
        return cls._fromast(tree)

    def __repr__(self):
        """Return 'Or(...)' of the polyhedra's reprs."""
        # A Domain instance always holds at least two polyhedra; see
        # _fromislset().
        assert len(self.polyhedra) >= 2
        strings = [repr(polyhedron) for polyhedron in self.polyhedra]
        return 'Or({})'.format(', '.join(strings))

    def _repr_latex_(self):
        """
        Return a LaTeX rendering: the polyhedra joined with a logical-or
        symbol.
        """
        strings = []
        for polyhedron in self.polyhedra:
            strings.append('({})'.format(polyhedron._repr_latex_().strip('$')))
        return '${}$'.format(' \\vee '.join(strings))

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy boolean/relational expression into a domain.

        Plain sympy expressions are converted with Expression.fromsympy().
        Raise ValueError for anything else.
        """
        import sympy
        from .polyhedra import Lt, Le, Eq, Ne, Ge, Gt
        funcmap = {
            sympy.And: And, sympy.Or: Or, sympy.Not: Not,
            sympy.Lt: Lt, sympy.Le: Le,
            sympy.Eq: Eq, sympy.Ne: Ne,
            sympy.Ge: Ge, sympy.Gt: Gt,
        }
        if expr.func in funcmap:
            args = [Domain.fromsympy(arg) for arg in expr.args]
            return funcmap[expr.func](*args)
        elif isinstance(expr, sympy.Expr):
            return Expression.fromsympy(expr)
        raise ValueError('non-domain expression: {!r}'.format(expr))

    def tosympy(self):
        """Return the domain as a sympy Or of its polyhedra."""
        import sympy
        polyhedra = [polyhedron.tosympy() for polyhedron in self.polyhedra]
        return sympy.Or(*polyhedra)
678
679
def And(*domains):
    """
    Return the intersection of all the given domains.

    With no argument, return the Universe domain.
    """
    if not domains:
        from .polyhedra import Universe
        return Universe
    first, *others = domains
    return first.intersection(*others)
689
def Or(*domains):
    """
    Return the union of all the given domains.

    With no argument, return the Empty domain.
    """
    if not domains:
        from .polyhedra import Empty
        return Empty
    first, *others = domains
    return first.union(*others)
699
def Not(domain):
    """
    Return the complement of domain.

    Functional form of the ~ operator.
    """
    complement = ~domain
    return complement