44c38e7f9e92fc7ade10e43b14cadefd170292fc
[linpy.git] / pypol / domains.py
1 import ast
2 import functools
3 import re
4
5 from fractions import Fraction
6
7 from . import islhelper
8 from .islhelper import mainctx, libisl, isl_set_basic_sets
9 from .linexprs import Expression, Symbol
10
11
# Public API of this module: the Domain class and the boolean combinators.
__all__ = ['Domain', 'And', 'Or', 'Not']
16
17
@functools.total_ordering
class Domain:
    """
    A finite union of polyhedra sharing an ordered tuple of symbols.

    Most operations convert the polyhedra into an isl set with _toislset(),
    call the matching isl function, and convert the result back with
    _fromislset(). Depending on how many basic sets isl produced, that
    result is Empty, a single Polyhedron, or a Domain.
    """

    __slots__ = (
        '_polyhedra',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, *polyhedra):
        """
        Create a domain.

        - Domain(string): parse the string with fromstring().
        - Domain(polyhedron): return the Polyhedron instance unchanged.
        - Domain(p1, p2, ...): return the union of the given polyhedra.
        - Domain(): return Empty, the union of no polyhedra.

        Raises TypeError if a single argument is neither a string nor a
        Polyhedron, or if one of several arguments is not a Polyhedron.
        """
        from .polyhedra import Polyhedron
        if len(polyhedra) == 0:
            # _toislset() needs at least one polyhedron. The empty union is
            # Empty, matching Or() called without arguments.
            from .polyhedra import Empty
            return Empty
        if len(polyhedra) == 1:
            polyhedron = polyhedra[0]
            if isinstance(polyhedron, str):
                return cls.fromstring(polyhedron)
            elif isinstance(polyhedron, Polyhedron):
                return polyhedron
            else:
                raise TypeError('argument must be a string '
                    'or a Polyhedron instance')
        for polyhedron in polyhedra:
            if not isinstance(polyhedron, Polyhedron):
                raise TypeError('arguments must be Polyhedron instances')
        symbols = cls._xsymbols(polyhedra)
        islset = cls._toislset(polyhedra, symbols)
        return cls._fromislset(islset, symbols)

    @classmethod
    def _xsymbols(cls, iterator):
        """
        Return the ordered tuple of symbols present in iterator.

        Each item must expose a `symbols` attribute. The result is sorted
        with Symbol.sortkey.
        """
        symbols = set()
        for item in iterator:
            symbols.update(item.symbols)
        return tuple(sorted(symbols, key=Symbol.sortkey))

    @property
    def polyhedra(self):
        """The tuple of polyhedra whose union forms this domain."""
        return self._polyhedra

    @property
    def symbols(self):
        """The ordered tuple of symbols used by the domain."""
        return self._symbols

    @property
    def dimension(self):
        """The number of symbols of the domain."""
        return self._dimension

    def _isl_unary_test(self, islfunc):
        """Apply the isl predicate islfunc to this domain's set and return a bool."""
        islset = self._toislset(self.polyhedra, self.symbols)
        result = bool(islfunc(islset))
        libisl.isl_set_free(islset)
        return result

    def _isl_binary_test(self, other, islfunc):
        """
        Apply the two-set isl predicate islfunc to this domain and other.

        Both sets are built over the union of the two domains' symbols.
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        result = bool(islfunc(islset1, islset2))
        libisl.isl_set_free(islset1)
        libisl.isl_set_free(islset2)
        return result

    def _islpoint_todict(self, islpoint):
        """Return the coordinates of an isl point as a dict {symbol: coordinate}."""
        point = {}
        for index, symbol in enumerate(self.symbols):
            coordinate = libisl.isl_point_get_coordinate_val(islpoint,
                libisl.isl_dim_set, index)
            point[symbol] = islhelper.isl_val_to_int(coordinate)
        return point

    def disjoint(self):
        """Return an equivalent domain whose polyhedra are pairwise disjoint."""
        islset = self._toislset(self.polyhedra, self.symbols)
        # isl_set_make_disjoint() takes the set only; passing the context
        # first would hand isl a context pointer in place of the set.
        islset = libisl.isl_set_make_disjoint(islset)
        return self._fromislset(islset, self.symbols)

    def isempty(self):
        """Return True if the domain contains no point."""
        return self._isl_unary_test(libisl.isl_set_is_empty)

    def __bool__(self):
        """A domain is truthy when it is not empty."""
        return not self.isempty()

    def isuniverse(self):
        """
        Return True if isl can cheaply prove that the domain is the universe.

        Uses isl_set_plain_is_universe, a "plain" (syntactic) check.
        """
        return self._isl_unary_test(libisl.isl_set_plain_is_universe)

    def isbounded(self):
        """Return True if the domain is bounded."""
        return self._isl_unary_test(libisl.isl_set_is_bounded)

    def __eq__(self, other):
        """
        Return True if both domains describe the same set.

        The comparison uses the union of both domains' symbols. Returns
        NotImplemented for non-Domain operands.
        """
        if not isinstance(other, Domain):
            return NotImplemented
        return self._isl_binary_test(other, libisl.isl_set_is_equal)

    def isdisjoint(self, other):
        """Return True if the domain and other have no point in common."""
        return self._isl_binary_test(other, libisl.isl_set_is_disjoint)

    def issubset(self, other):
        """Return True if every point of the domain belongs to other."""
        return self._isl_binary_test(other, libisl.isl_set_is_subset)

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        """Return True if the domain is a strict subset of other."""
        return self._isl_binary_test(other, libisl.isl_set_is_strict_subset)

    def complement(self):
        """Return the complement of the domain over its own symbols."""
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_complement(islset)
        return self._fromislset(islset, self.symbols)

    def __invert__(self):
        return self.complement()

    def simplify(self):
        """
        Return the domain with redundant constraints removed.

        NOTE: the original author observed that this changed nothing in
        the examples; isl seems to remove redundancies on its own.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_remove_redundancies(islset)
        return self._fromislset(islset, self.symbols)

    def aspolyhedron(self):
        """
        Return the polyhedral hull of the domain as a single Polyhedron.

        NOTE: isl offers several kinds of hull. The polyhedral hull seemed
        the most appropriate to the original author; this is still to be
        checked.
        """
        from .polyhedra import Polyhedron
        islset = self._toislset(self.polyhedra, self.symbols)
        islbset = libisl.isl_set_polyhedral_hull(islset)
        return Polyhedron._fromislbasicset(islbset, self.symbols)

    def project(self, dims):
        """
        Return the domain with the symbols listed in dims projected out.

        The remaining symbols keep their relative order. Entries of dims
        that are not symbols of the domain are ignored.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        # Walk the symbols from last to first and count runs of consecutive
        # symbols to remove. Projecting from the end keeps the indices of
        # the lower, not-yet-processed dimensions valid.
        n = 0
        for index, symbol in reversed(list(enumerate(self.symbols))):
            if symbol in dims:
                n += 1
            elif n > 0:
                # The run starts right after the current (kept) symbol.
                islset = libisl.isl_set_project_out(islset,
                    libisl.isl_dim_set, index + 1, n)
                n = 0
        if n > 0:
            # A run that extends down to the first symbol.
            islset = libisl.isl_set_project_out(islset,
                libisl.isl_dim_set, 0, n)
        remaining = [symbol for symbol in self.symbols if symbol not in dims]
        return Domain._fromislset(islset, remaining)

    def sample(self):
        """
        Return one point of the domain as a dict {symbol: coordinate}.

        Raises ValueError if the domain is empty.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islpoint = libisl.isl_set_sample_point(islset)
        if bool(libisl.isl_point_is_void(islpoint)):
            libisl.isl_point_free(islpoint)
            raise ValueError('domain must be non-empty')
        point = self._islpoint_todict(islpoint)
        libisl.isl_point_free(islpoint)
        return point

    def intersection(self, *others):
        """Return the intersection of the domain with all of others."""
        if len(others) == 0:
            return self
        symbols = self._xsymbols((self,) + others)
        islset1 = self._toislset(self.polyhedra, symbols)
        for other in others:
            islset2 = other._toislset(other.polyhedra, symbols)
            islset1 = libisl.isl_set_intersect(islset1, islset2)
        return self._fromislset(islset1, symbols)

    def __and__(self, other):
        return self.intersection(other)

    def union(self, *others):
        """Return the union of the domain with all of others."""
        if len(others) == 0:
            return self
        symbols = self._xsymbols((self,) + others)
        islset1 = self._toislset(self.polyhedra, symbols)
        for other in others:
            islset2 = other._toislset(other.polyhedra, symbols)
            islset1 = libisl.isl_set_union(islset1, islset2)
        return self._fromislset(islset1, symbols)

    def __or__(self, other):
        return self.union(other)

    def __add__(self, other):
        return self.union(other)

    def difference(self, other):
        """Return the points of the domain that do not belong to other."""
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        islset = libisl.isl_set_subtract(islset1, islset2)
        return self._fromislset(islset, symbols)

    def __sub__(self, other):
        return self.difference(other)

    def lexmin(self):
        """Return the lexicographic minimum of the domain, as a domain."""
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_lexmin(islset)
        return self._fromislset(islset, self.symbols)

    def lexmax(self):
        """Return the lexicographic maximum of the domain, as a domain."""
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_lexmax(islset)
        return self._fromislset(islset, self.symbols)

    def num_parameters(self):
        """
        Return the number of set dimensions (isl_dim_set) of the domain.

        NOTE(review): despite the name, this counts set dimensions, i.e.
        the symbols passed to _toislset(), not isl parameters.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        num = libisl.isl_set_dim(islset, libisl.isl_dim_set)
        libisl.isl_set_free(islset)
        return num

    def involves_dims(self, dims):
        """
        Return True if the domain's constraints involve any symbol of dims.

        Entries of dims that are not symbols of the domain are ignored. If
        none of them is, the result is False.
        """
        dims = set(dims)
        # isl dimensions follow the order of self.symbols.
        indices = [index for index, symbol in enumerate(self.symbols)
            if symbol in dims]
        if not indices:
            return False
        islset = self._toislset(self.polyhedra, self.symbols)
        # Test each dimension on its own: the selected symbols need not be
        # contiguous, so a single (first, n) range could cover extra ones.
        involved = any(
            bool(libisl.isl_set_involves_dims(islset, libisl.isl_dim_set,
                index, 1))
            for index in indices)
        libisl.isl_set_free(islset)
        return involved

    _RE_COORDINATE = re.compile(r'\((?P<num>\-?\d+)\)(/(?P<den>\d+))?')

    def vertices(self):
        """
        Return the vertices as a list of dicts {symbol: Fraction}.

        NOTE(review): this relies on self.equalities, self.inequalities and
        self._toislbasicset, which are not defined on Domain in this file.
        It is presumably meant to be called on Polyhedron instances only;
        TODO confirm.
        """
        islbset = self._toislbasicset(self.equalities, self.inequalities,
            self.symbols)
        vertices = libisl.isl_basic_set_compute_vertices(islbset)
        vertices = islhelper.isl_vertices_vertices(vertices)
        points = []
        for vertex in vertices:
            expr = libisl.isl_vertex_get_expr(vertex)
            point = {}
            # NOTE(review): isl_version is compared as a string, so a
            # version such as '0.9' would compare greater than '0.13'.
            # Confirm the format islhelper provides.
            if islhelper.isl_version < '0.13':
                # Older isl returns the vertex as a basic set of equalities
                # coefficient*symbol + constant == 0.
                constraints = islhelper.isl_basic_set_constraints(expr)
                for index, symbol in enumerate(self.symbols):
                    for constraint in constraints:
                        constant = libisl.isl_constraint_get_constant_val(
                            constraint)
                        constant = islhelper.isl_val_to_int(constant)
                        coefficient = libisl.isl_constraint_get_coefficient_val(
                            constraint, libisl.isl_dim_set, index)
                        coefficient = islhelper.isl_val_to_int(coefficient)
                        if coefficient != 0:
                            # Solve coefficient*symbol + constant == 0.
                            point[symbol] = Fraction(-constant, coefficient)
            else:
                # Newer isl returns a multi_aff, whose string form holds one
                # coordinate per symbol, e.g. '(3)' or '(1)/2'.
                string = islhelper.isl_multi_aff_to_str(expr)
                matches = self._RE_COORDINATE.finditer(string)
                for symbol, match in zip(self.symbols, matches):
                    numerator = int(match.group('num'))
                    denominator = match.group('den')
                    denominator = 1 if denominator is None else int(denominator)
                    point[symbol] = Fraction(numerator, denominator)
            points.append(point)
        # isl_basic_set_compute_vertices() does not take ownership of the
        # basic set, so release it here.
        libisl.isl_basic_set_free(islbset)
        return points

    def points(self):
        """
        Return the list of points of the domain, each a dict {symbol: coordinate}.

        Raises ValueError if the domain is not bounded.
        """
        if not self.isbounded():
            raise ValueError('domain must be bounded')
        islset = self._toislset(self.polyhedra, self.symbols)
        islpoints = islhelper.isl_set_points(islset)
        return [self._islpoint_todict(islpoint) for islpoint in islpoints]

    def subs(self, symbol, expression=None):
        """Return the domain obtained by applying subs() to every polyhedron."""
        polyhedra = [polyhedron.subs(symbol, expression)
            for polyhedron in self.polyhedra]
        return Domain(*polyhedra)

    @classmethod
    def _fromislset(cls, islset, symbols):
        """
        Convert an isl set into Empty, a single Polyhedron, or a Domain.

        Takes ownership of islset, which is freed here. A multi-polyhedron
        result gets its symbols recomputed from its polyhedra.
        """
        from .polyhedra import Polyhedron
        islset = libisl.isl_set_remove_divs(islset)
        islbsets = isl_set_basic_sets(islset)
        libisl.isl_set_free(islset)
        polyhedra = []
        for islbset in islbsets:
            polyhedron = Polyhedron._fromislbasicset(islbset, symbols)
            polyhedra.append(polyhedron)
        if len(polyhedra) == 0:
            from .polyhedra import Empty
            return Empty
        elif len(polyhedra) == 1:
            return polyhedra[0]
        # Bypass __new__, which would rebuild the isl set.
        self = object.__new__(Domain)
        self._polyhedra = tuple(polyhedra)
        self._symbols = cls._xsymbols(polyhedra)
        self._dimension = len(self._symbols)
        return self

    @classmethod
    def _toislset(cls, polyhedra, symbols):
        """
        Return a new isl set that is the union of the given polyhedra.

        symbols gives the ordered set dimensions. polyhedra must be
        non-empty (its first element is indexed).
        """
        polyhedron = polyhedra[0]
        islbset = polyhedron._toislbasicset(polyhedron.equalities,
            polyhedron.inequalities, symbols)
        islset1 = libisl.isl_set_from_basic_set(islbset)
        for polyhedron in polyhedra[1:]:
            islbset = polyhedron._toislbasicset(polyhedron.equalities,
                polyhedron.inequalities, symbols)
            islset2 = libisl.isl_set_from_basic_set(islbset)
            islset1 = libisl.isl_set_union(islset1, islset2)
        return islset1

    @classmethod
    def _fromast(cls, node):
        """
        Convert a Python AST (as built by fromstring()) into a domain.

        Supports chains of <, <=, ==, >=, >, combined with '&'
        (intersection), '|' (union) and '~' (complement).

        Raises SyntaxError for any other construct.
        """
        from .polyhedra import Polyhedron
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expression):
            # Root node produced by ast.parse(..., mode='eval').
            return cls._fromast(node.body)
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.UnaryOp):
            domain = cls._fromast(node.operand)
            # The operator lives in node.op ('~' is an ast.Invert instance).
            if isinstance(node.op, ast.Invert):
                return Not(domain)
        elif isinstance(node, ast.BinOp):
            domain1 = cls._fromast(node.left)
            domain2 = cls._fromast(node.right)
            if isinstance(node.op, ast.BitAnd):
                return And(domain1, domain2)
            elif isinstance(node.op, ast.BitOr):
                return Or(domain1, domain2)
        elif isinstance(node, ast.Compare):
            # Per the conversions below, inequalities are collected as
            # expr >= 0 and equalities as expr == 0. Strict comparisons are
            # tightened by 1, which is exact over integer points.
            equalities = []
            inequalities = []
            left = Expression._fromast(node.left)
            for op, comparator in zip(node.ops, node.comparators):
                right = Expression._fromast(comparator)
                if isinstance(op, ast.Lt):
                    inequalities.append(right - left - 1)
                elif isinstance(op, ast.LtE):
                    inequalities.append(right - left)
                elif isinstance(op, ast.Eq):
                    equalities.append(left - right)
                elif isinstance(op, ast.GtE):
                    inequalities.append(left - right)
                elif isinstance(op, ast.Gt):
                    inequalities.append(left - right - 1)
                else:
                    break
                left = right
            else:
                return Polyhedron(equalities, inequalities)
        raise SyntaxError('invalid syntax')

    _RE_BRACES = re.compile(r'^\{\s*|\s*\}$')
    _RE_EQ = re.compile(r'([^<=>])=([^<=>])')
    _RE_AND = re.compile(r'\band\b|,|&&|/\\|∧|∩')
    _RE_OR = re.compile(r'\bor\b|;|\|\||\\/|∨|∪')
    _RE_NOT = re.compile(r'\bnot\b|!|¬')
    _RE_NUM_VAR = Expression._RE_NUM_VAR
    _RE_OPERATORS = re.compile(r'(&|\||~)')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a domain from text, e.g. '{x > 0 and y <= 2x}'.

        Accepted syntax:
        - '=' as equality;
        - 'and', ',', '&&', '/\\', '∧', '∩' for intersection;
        - 'or', ';', '||', '\\/', '∨', '∪' for union;
        - 'not', '!', '¬' for complement;
        - implicit multiplication such as '2x'.

        Raises SyntaxError on unsupported input.
        """
        # remove curly brackets
        string = cls._RE_BRACES.sub(r'', string)
        # replace '=' by '=='
        string = cls._RE_EQ.sub(r'\1==\2', string)
        # replace 'and', 'or', 'not'
        string = cls._RE_AND.sub(r' & ', string)
        string = cls._RE_OR.sub(r' | ', string)
        string = cls._RE_NOT.sub(r' ~', string)
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = cls._RE_NUM_VAR.sub(r'\1*\2', string)
        # Add parentheses around operands (even indices after the split)
        # so that '&', '|' and '~' bind more loosely than comparisons.
        tokens = cls._RE_OPERATORS.split(string)
        for i, token in enumerate(tokens):
            # Blank operands appear around a prefix '~' (e.g. 'a & ~b').
            # Wrapping them would produce '()' and break the syntax.
            if i % 2 == 0 and token.strip():
                tokens[i] = '({})'.format(token)
        # Strip so a leading ' ~' does not cause an "unexpected indent".
        string = ''.join(tokens).strip()
        # The second positional argument of ast.parse() is the filename,
        # so the mode must be passed by keyword.
        tree = ast.parse(string, mode='eval')
        return cls._fromast(tree)

    def __repr__(self):
        """Return 'Or(p1, p2, ...)' built from the reprs of the polyhedra."""
        assert len(self.polyhedra) >= 2
        strings = [repr(polyhedron) for polyhedron in self.polyhedra]
        return 'Or({})'.format(', '.join(strings))

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy boolean/relational expression into a domain.

        Non-boolean sympy expressions are converted with
        Expression.fromsympy(). Raises ValueError for anything else.
        """
        import sympy
        from .polyhedra import Lt, Le, Eq, Ne, Ge, Gt
        funcmap = {
            sympy.And: And, sympy.Or: Or, sympy.Not: Not,
            sympy.Lt: Lt, sympy.Le: Le,
            sympy.Eq: Eq, sympy.Ne: Ne,
            sympy.Ge: Ge, sympy.Gt: Gt,
        }
        if expr.func in funcmap:
            args = [Domain.fromsympy(arg) for arg in expr.args]
            return funcmap[expr.func](*args)
        elif isinstance(expr, sympy.Expr):
            return Expression.fromsympy(expr)
        raise ValueError('non-domain expression: {!r}'.format(expr))

    def tosympy(self):
        """Return the domain as a sympy.Or of its polyhedra's sympy forms."""
        import sympy
        polyhedra = [polyhedron.tosympy() for polyhedron in self.polyhedra]
        return sympy.Or(*polyhedra)
456 polyhedra = [polyhedron.tosympy() for polyhedron in polyhedra]
457 return sympy.Or(*polyhedra)
458
459
def And(*domains):
    """
    Return the intersection of all domains.

    With no argument, return Universe, the neutral element of intersection.
    Otherwise, call intersection() on the first domain with the others.
    """
    if not domains:
        from .polyhedra import Universe
        return Universe
    first, *rest = domains
    return first.intersection(*rest)
466
def Or(*domains):
    """
    Return the union of all domains.

    With no argument, return Empty, the neutral element of union.
    Otherwise, call union() on the first domain with the others.
    """
    if not domains:
        from .polyhedra import Empty
        return Empty
    first, *rest = domains
    return first.union(*rest)
473
def Not(domain):
    """Return the complement of domain, computed with the '~' operator."""
    complement = ~domain
    return complement