Simplify TypeError messages in Expression.__new__
[linpy.git] / pypol / domains.py
1 import ast
2 import functools
3 import re
4
5 from fractions import Fraction
6
7 from . import islhelper
8 from .islhelper import mainctx, libisl
9 from .geometry import GeometricObject, Point
10 from .linexprs import Expression, Symbol
11
12
# Names exported by a star-import of this module: the Domain class and the
# And / Or / Not helper functions defined at the bottom of the file.
__all__ = [
    'Domain',
    'And', 'Or', 'Not',
]
17
18
@functools.total_ordering
class Domain(GeometricObject):
    """
    Union of polyhedra.

    A domain stores a tuple of polyhedra, the ordered tuple of symbols they
    involve and the number of those symbols. Most operations convert the
    polyhedra to an isl set with _toislset(), call libisl on it and convert
    the result back with _fromislset().
    """

    __slots__ = (
        '_polyhedra',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, *polyhedra):
        """
        Build a domain from a string, a geometric object or polyhedra.

        With a single argument, a string is parsed with fromstring() and a
        GeometricObject is converted with its aspolyhedron() method; any
        other type raises TypeError. With several arguments, each of them
        must be a Polyhedron instance (TypeError otherwise) and the result is
        their union as built by _fromislset(). Without argument, Empty is
        returned.
        """
        from .polyhedra import Polyhedron
        if not polyhedra:
            # the union of no polyhedra is empty, consistently with Or();
            # _toislset() cannot handle an empty sequence
            from .polyhedra import Empty
            return Empty
        if len(polyhedra) == 1:
            argument = polyhedra[0]
            if isinstance(argument, str):
                return cls.fromstring(argument)
            elif isinstance(argument, GeometricObject):
                return argument.aspolyhedron()
            else:
                raise TypeError('argument must be a string '
                    'or a GeometricObject instance')
        else:
            for polyhedron in polyhedra:
                if not isinstance(polyhedron, Polyhedron):
                    raise TypeError('arguments must be Polyhedron instances')
            symbols = cls._xsymbols(polyhedra)
            islset = cls._toislset(polyhedra, symbols)
            return cls._fromislset(islset, symbols)

    @classmethod
    def _xsymbols(cls, iterator):
        """
        Return the ordered tuple of symbols present in iterator.

        Each item of iterator must expose a symbols attribute; the symbols
        are deduplicated and sorted with Symbol.sortkey.
        """
        symbols = set()
        for item in iterator:
            symbols.update(item.symbols)
        return tuple(sorted(symbols, key=Symbol.sortkey))

    @property
    def polyhedra(self):
        """
        The tuple of polyhedra whose union is this domain.
        """
        return self._polyhedra

    @property
    def symbols(self):
        """
        The ordered tuple of symbols involved in this domain.
        """
        return self._symbols

    @property
    def dimension(self):
        """
        The number of symbols involved in this domain.
        """
        return self._dimension

    def disjoint(self):
        """
        Return the domain obtained by applying isl_set_make_disjoint().
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        # NOTE(review): mainctx is passed as first argument here, unlike the
        # other libisl set calls in this class -- confirm against the binding
        islset = libisl.isl_set_make_disjoint(mainctx, islset)
        return self._fromislset(islset, self.symbols)

    def isempty(self):
        """
        Return True if isl reports the domain as empty.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        empty = bool(libisl.isl_set_is_empty(islset))
        libisl.isl_set_free(islset)
        return empty

    def __bool__(self):
        """
        A domain is truthy when it is not empty.
        """
        return not self.isempty()

    def isuniverse(self):
        """
        Return the result of isl_set_plain_is_universe() as a boolean.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        universe = bool(libisl.isl_set_plain_is_universe(islset))
        libisl.isl_set_free(islset)
        return universe

    def isbounded(self):
        """
        Return True if isl reports the domain as bounded.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        bounded = bool(libisl.isl_set_is_bounded(islset))
        libisl.isl_set_free(islset)
        return bounded

    def __eq__(self, other):
        """
        Return True if both domains are equal according to isl.

        Both operands are expressed over the union of their symbols before
        being compared. other must expose symbols, polyhedra and _toislset.
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        equal = bool(libisl.isl_set_is_equal(islset1, islset2))
        libisl.isl_set_free(islset1)
        libisl.isl_set_free(islset2)
        return equal

    def isdisjoint(self, other):
        """
        Return True if isl reports the two domains as disjoint.
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = self._toislset(other.polyhedra, symbols)
        disjoint = bool(libisl.isl_set_is_disjoint(islset1, islset2))
        libisl.isl_set_free(islset1)
        libisl.isl_set_free(islset2)
        return disjoint

    def issubset(self, other):
        """
        Return True if isl reports this domain as a subset of other.
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = self._toislset(other.polyhedra, symbols)
        subset = bool(libisl.isl_set_is_subset(islset1, islset2))
        libisl.isl_set_free(islset1)
        libisl.isl_set_free(islset2)
        return subset

    def __le__(self, other):
        """
        self <= other is self.issubset(other).
        """
        return self.issubset(other)

    def __lt__(self, other):
        """
        Return True if isl reports this domain as a strict subset of other.
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = self._toislset(other.polyhedra, symbols)
        strict = bool(libisl.isl_set_is_strict_subset(islset1, islset2))
        libisl.isl_set_free(islset1)
        libisl.isl_set_free(islset2)
        return strict

    def complement(self):
        """
        Return the complement of the domain, computed by isl.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_complement(islset)
        return self._fromislset(islset, self.symbols)

    def __invert__(self):
        """
        ~self is self.complement().
        """
        return self.complement()

    def simplify(self):
        """
        Return the domain after isl_set_remove_redundancies().
        """
        # does not change anything in any of the examples,
        # isl seems to do this naturally
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_remove_redundancies(islset)
        return self._fromislset(islset, self.symbols)

    def aspolyhedron(self):
        """
        Return the polyhedral hull of the domain as a Polyhedron.
        """
        # several types of hull are available,
        # polyhedral seems to be the more appropriate, to be checked
        from .polyhedra import Polyhedron
        islset = self._toislset(self.polyhedra, self.symbols)
        islbset = libisl.isl_set_polyhedral_hull(islset)
        return Polyhedron._fromislbasicset(islbset, self.symbols)

    def asdomain(self):
        """
        Return the domain itself.
        """
        return self

    def project(self, dims):
        """
        Return the domain with the symbols listed in dims projected out.

        dims is tested with the in operator several times, so it must be a
        container rather than a one-shot iterator.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        # walk the symbols from last to first and project out each run of n
        # consecutive symbols to remove as soon as the run ends, so that the
        # indices of the symbols still to be visited remain valid
        n = 0
        for index, symbol in reversed(list(enumerate(self.symbols))):
            if symbol in dims:
                n += 1
            elif n > 0:
                islset = libisl.isl_set_project_out(islset, libisl.isl_dim_set, index + 1, n)
                n = 0
        if n > 0:
            # the run reaches the first symbol
            islset = libisl.isl_set_project_out(islset, libisl.isl_dim_set, 0, n)
        dims = [symbol for symbol in self.symbols if symbol not in dims]
        return Domain._fromislset(islset, dims)

    def sample(self):
        """
        Return a sample point of the domain.

        The point is a dictionary mapping each symbol to its coordinate,
        converted with islhelper.isl_val_to_int(). Raise ValueError if isl
        returns a void point.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islpoint = libisl.isl_set_sample_point(islset)
        if bool(libisl.isl_point_is_void(islpoint)):
            libisl.isl_point_free(islpoint)
            raise ValueError('domain must be non-empty')
        point = {}
        for index, symbol in enumerate(self.symbols):
            coordinate = libisl.isl_point_get_coordinate_val(islpoint,
                libisl.isl_dim_set, index)
            coordinate = islhelper.isl_val_to_int(coordinate)
            point[symbol] = coordinate
        libisl.isl_point_free(islpoint)
        return point

    def intersection(self, *others):
        """
        Return the intersection of this domain with all the others.

        Without argument, the domain itself is returned.
        """
        if len(others) == 0:
            return self
        symbols = self._xsymbols((self,) + others)
        islset1 = self._toislset(self.polyhedra, symbols)
        for other in others:
            islset2 = other._toislset(other.polyhedra, symbols)
            islset1 = libisl.isl_set_intersect(islset1, islset2)
        return self._fromislset(islset1, symbols)

    def __and__(self, other):
        """
        self & other is self.intersection(other).
        """
        return self.intersection(other)

    def union(self, *others):
        """
        Return the union of this domain with all the others.

        Without argument, the domain itself is returned.
        """
        if len(others) == 0:
            return self
        symbols = self._xsymbols((self,) + others)
        islset1 = self._toislset(self.polyhedra, symbols)
        for other in others:
            islset2 = other._toislset(other.polyhedra, symbols)
            islset1 = libisl.isl_set_union(islset1, islset2)
        return self._fromislset(islset1, symbols)

    def __or__(self, other):
        """
        self | other is self.union(other).
        """
        return self.union(other)

    def __add__(self, other):
        """
        self + other is self.union(other).
        """
        return self.union(other)

    def difference(self, other):
        """
        Return the domain minus other, computed by isl_set_subtract().
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        islset = libisl.isl_set_subtract(islset1, islset2)
        return self._fromislset(islset, symbols)

    def __sub__(self, other):
        """
        self - other is self.difference(other).
        """
        return self.difference(other)

    def lexmin(self):
        """
        Return the domain obtained by applying isl_set_lexmin().
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_lexmin(islset)
        return self._fromislset(islset, self.symbols)

    def lexmax(self):
        """
        Return the domain obtained by applying isl_set_lexmax().
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_lexmax(islset)
        return self._fromislset(islset, self.symbols)

    def num_parameters(self):
        """
        Return isl_basic_set_dim() of the object for isl_dim_set.

        NOTE(review): relies on self.equalities, self.inequalities and
        self._toislbasicset, none of which is defined in this class --
        presumably only usable on polyhedra; and the dimension queried is
        isl_dim_set although the name mentions parameters. To be confirmed.
        """
        # could be useful with large, complicated polyhedrons
        islbset = self._toislbasicset(self.equalities, self.inequalities, self.symbols)
        num = libisl.isl_basic_set_dim(islbset, libisl.isl_dim_set)
        # isl_basic_set_dim() only borrows its argument, release it here
        libisl.isl_basic_set_free(islbset)
        return num

    def involves_dims(self, dims):
        """
        Return True if the domain depends on at least one symbol of dims.

        Symbols of dims which are not symbols of the domain are ignored;
        the result is False when none of them belongs to the domain.
        """
        # could be useful with large, complicated polyhedrons
        dims = list(dims)
        # positions are taken in self.symbols, the order used by _toislset()
        indices = [index for index, symbol in enumerate(self.symbols)
            if symbol in dims]
        if not indices:
            return False
        islset = self._toislset(self.polyhedra, self.symbols)
        try:
            for index in indices:
                if bool(libisl.isl_set_involves_dims(islset,
                        libisl.isl_dim_set, index, 1)):
                    return True
            return False
        finally:
            libisl.isl_set_free(islset)

    # matches one coordinate of the string form of a vertex expression:
    # '(<num>)' optionally followed by '/<den>'
    _RE_COORDINATE = re.compile(r'\((?P<num>\-?\d+)\)(/(?P<den>\d+))?')

    def vertices(self):
        """
        Return the list of vertices, as Point instances.

        NOTE(review): relies on self.equalities, self.inequalities and
        self._toislbasicset, none of which is defined in this class --
        presumably only usable on polyhedra. To be confirmed.
        """
        islbset = self._toislbasicset(self.equalities, self.inequalities, self.symbols)
        vertices = libisl.isl_basic_set_compute_vertices(islbset)
        vertices = islhelper.isl_vertices_vertices(vertices)
        points = []
        for vertex in vertices:
            expr = libisl.isl_vertex_get_expr(vertex)
            coordinates = []
            # NOTE(review): this is a string comparison, so a version such as
            # '0.9' would not be considered lower than '0.13' -- verify the
            # format of islhelper.isl_version
            if islhelper.isl_version < '0.13':
                # the expression is handled as a basic set: each constraint
                # with a non-zero coefficient yields one coordinate
                constraints = islhelper.isl_basic_set_constraints(expr)
                for constraint in constraints:
                    constant = libisl.isl_constraint_get_constant_val(constraint)
                    constant = islhelper.isl_val_to_int(constant)
                    for index, symbol in enumerate(self.symbols):
                        coefficient = libisl.isl_constraint_get_coefficient_val(constraint,
                            libisl.isl_dim_set, index)
                        coefficient = islhelper.isl_val_to_int(coefficient)
                        if coefficient != 0:
                            coordinate = -Fraction(constant, coefficient)
                            coordinates.append((symbol, coordinate))
            else:
                # horrible hack, find a cleaner solution: the coordinates are
                # read from the string form of the expression
                string = islhelper.isl_multi_aff_to_str(expr)
                matches = self._RE_COORDINATE.finditer(string)
                for symbol, match in zip(self.symbols, matches):
                    numerator = int(match.group('num'))
                    denominator = match.group('den')
                    denominator = 1 if denominator is None else int(denominator)
                    coordinate = Fraction(numerator, denominator)
                    coordinates.append((symbol, coordinate))
            points.append(Point(coordinates))
        return points

    def points(self):
        """
        Return the list of points of the domain, as Point instances.

        Raise ValueError if the domain is not bounded.
        """
        if not self.isbounded():
            raise ValueError('domain must be bounded')
        islset = self._toislset(self.polyhedra, self.symbols)
        islpoints = islhelper.isl_set_points(islset)
        points = []
        for islpoint in islpoints:
            coordinates = {}
            for index, symbol in enumerate(self.symbols):
                coordinate = libisl.isl_point_get_coordinate_val(islpoint,
                    libisl.isl_dim_set, index)
                coordinate = islhelper.isl_val_to_int(coordinate)
                coordinates[symbol] = coordinate
            points.append(Point(coordinates))
        return points

    def __contains__(self, point):
        """
        Return True if point belongs to at least one of the polyhedra.
        """
        for polyhedron in self.polyhedra:
            if point in polyhedron:
                return True
        return False

    def subs(self, symbol, expression=None):
        """
        Apply subs(symbol, expression) to each polyhedron and return the
        domain built from the results.
        """
        polyhedra = [polyhedron.subs(symbol, expression)
            for polyhedron in self.polyhedra]
        return Domain(*polyhedra)

    @classmethod
    def _fromislset(cls, islset, symbols):
        """
        Convert an isl set into a domain; the isl set is freed.

        Return Empty if the set has no basic set, the polyhedron itself if
        it has only one, and a new Domain instance otherwise.
        """
        from .polyhedra import Polyhedron
        islset = libisl.isl_set_remove_divs(islset)
        islbsets = islhelper.isl_set_basic_sets(islset)
        libisl.isl_set_free(islset)
        polyhedra = []
        for islbset in islbsets:
            polyhedron = Polyhedron._fromislbasicset(islbset, symbols)
            polyhedra.append(polyhedron)
        if len(polyhedra) == 0:
            from .polyhedra import Empty
            return Empty
        elif len(polyhedra) == 1:
            return polyhedra[0]
        else:
            # bypass Domain.__new__, which would go through isl again
            self = object().__new__(Domain)
            self._polyhedra = tuple(polyhedra)
            self._symbols = cls._xsymbols(polyhedra)
            self._dimension = len(self._symbols)
            return self

    @classmethod
    def _toislset(cls, polyhedra, symbols):
        """
        Return a new isl set, union of the polyhedra expressed over symbols.

        polyhedra must be a non-empty sequence (IndexError otherwise).
        """
        polyhedron = polyhedra[0]
        islbset = polyhedron._toislbasicset(polyhedron.equalities,
            polyhedron.inequalities, symbols)
        islset1 = libisl.isl_set_from_basic_set(islbset)
        for polyhedron in polyhedra[1:]:
            islbset = polyhedron._toislbasicset(polyhedron.equalities,
                polyhedron.inequalities, symbols)
            islset2 = libisl.isl_set_from_basic_set(islbset)
            islset1 = libisl.isl_set_union(islset1, islset2)
        return islset1

    @classmethod
    def _fromast(cls, node):
        """
        Convert a Python AST node into a domain.

        Supported nodes are a module holding a single statement, an
        expression statement, the unary ~ operator, the binary & and |
        operators and (chained) comparisons using <, <=, ==, >= and >.
        Anything else raises SyntaxError.
        """
        from .polyhedra import Polyhedron
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.UnaryOp):
            domain = cls._fromast(node.operand)
            # the operator is stored in node.op, as an ast.Invert instance
            if isinstance(node.op, ast.Invert):
                return Not(domain)
        elif isinstance(node, ast.BinOp):
            domain1 = cls._fromast(node.left)
            domain2 = cls._fromast(node.right)
            if isinstance(node.op, ast.BitAnd):
                return And(domain1, domain2)
            elif isinstance(node.op, ast.BitOr):
                return Or(domain1, domain2)
        elif isinstance(node, ast.Compare):
            equalities = []
            inequalities = []
            left = Expression._fromast(node.left)
            # a chain a < b <= c is processed pair by pair; equalities hold
            # expressions equal to zero, inequalities expressions >= 0
            for op, comparator in zip(node.ops, node.comparators):
                right = Expression._fromast(comparator)
                if isinstance(op, ast.Lt):
                    inequalities.append(right - left - 1)
                elif isinstance(op, ast.LtE):
                    inequalities.append(right - left)
                elif isinstance(op, ast.Eq):
                    equalities.append(left - right)
                elif isinstance(op, ast.GtE):
                    inequalities.append(left - right)
                elif isinstance(op, ast.Gt):
                    inequalities.append(left - right - 1)
                else:
                    # unsupported operator, fall through to SyntaxError
                    break
                left = right
            else:
                return Polyhedron(equalities, inequalities)
        raise SyntaxError('invalid syntax')

    # patterns used by fromstring() to rewrite a string into Python syntax
    _RE_BRACES = re.compile(r'^\{\s*|\s*\}$')
    _RE_EQ = re.compile(r'([^<=>])=([^<=>])')
    _RE_AND = re.compile(r'\band\b|,|&&|/\\|∧|∩')
    _RE_OR = re.compile(r'\bor\b|;|\|\||\\/|∨|∪')
    _RE_NOT = re.compile(r'\bnot\b|!|¬')
    _RE_NUM_VAR = Expression._RE_NUM_VAR
    _RE_OPERATORS = re.compile(r'(&|\||~)')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a string and return the corresponding domain.

        The string is rewritten into a Python expression using &, | and ~,
        parsed with the ast module and converted with _fromast().
        """
        # remove curly brackets
        string = cls._RE_BRACES.sub(r'', string)
        # replace '=' by '=='
        string = cls._RE_EQ.sub(r'\1==\2', string)
        # replace 'and', 'or', 'not'
        string = cls._RE_AND.sub(r' & ', string)
        string = cls._RE_OR.sub(r' | ', string)
        string = cls._RE_NOT.sub(r' ~', string)
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = cls._RE_NUM_VAR.sub(r'\1*\2', string)
        # add parentheses to force precedence: tokens at even positions are
        # the operands, tokens at odd positions the captured operators
        tokens = cls._RE_OPERATORS.split(string)
        for i, token in enumerate(tokens):
            if i % 2 == 0:
                token = '({})'.format(token)
                tokens[i] = token
        string = ''.join(tokens)
        # the second positional argument of ast.parse() is the file name, so
        # the string is parsed in the default 'exec' mode and the resulting
        # ast.Module is unwrapped by _fromast()
        tree = ast.parse(string, 'eval')
        return cls._fromast(tree)

    def __repr__(self):
        """
        Return the representation 'Or(<polyhedron>, <polyhedron>, ...)'.
        """
        assert len(self.polyhedra) >= 2
        strings = [repr(polyhedron) for polyhedron in self.polyhedra]
        return 'Or({})'.format(', '.join(strings))

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy expression into a domain.

        Logical and relational sympy functions are mapped to their
        counterparts of this package, applied to the converted arguments;
        any other sympy.Expr is converted with Expression.fromsympy().
        Raise ValueError for anything else.
        """
        import sympy
        from .polyhedra import Lt, Le, Eq, Ne, Ge, Gt
        funcmap = {
            sympy.And: And, sympy.Or: Or, sympy.Not: Not,
            sympy.Lt: Lt, sympy.Le: Le,
            sympy.Eq: Eq, sympy.Ne: Ne,
            sympy.Ge: Ge, sympy.Gt: Gt,
        }
        if expr.func in funcmap:
            args = [Domain.fromsympy(arg) for arg in expr.args]
            return funcmap[expr.func](*args)
        elif isinstance(expr, sympy.Expr):
            return Expression.fromsympy(expr)
        raise ValueError('non-domain expression: {!r}'.format(expr))

    def tosympy(self):
        """
        Return the sympy.Or of the sympy form of each polyhedron.
        """
        import sympy
        polyhedra = [polyhedron.tosympy() for polyhedron in self.polyhedra]
        return sympy.Or(*polyhedra)
462
463
def And(*domains):
    """
    Return the intersection of the given domains.

    The first domain is intersected with all the following ones through its
    intersection() method; without argument, Universe is returned.
    """
    if domains:
        first, *rest = domains
        return first.intersection(*rest)
    from .polyhedra import Universe
    return Universe
470
def Or(*domains):
    """
    Return the union of the given domains.

    The first domain is united with all the following ones through its
    union() method; without argument, Empty is returned.
    """
    if domains:
        first, *rest = domains
        return first.union(*rest)
    from .polyhedra import Empty
    return Empty
477
def Not(domain):
    """
    Return the result of applying the ~ operator to domain.
    """
    inverted = ~domain
    return inverted