Make Domain.sample() consistent with Domain.points()
[linpy.git] / pypol / domains.py
1 import ast
2 import functools
3 import re
4
5 from . import islhelper
6
7 from .islhelper import mainctx, libisl, isl_set_basic_sets
8 from .linexprs import Expression, Symbol
9
10
# Public API of this module: the Domain class and the logical combinators.
__all__ = [
    'Domain',
    'And',
    'Or',
    'Not',
]
15
16
@functools.total_ordering
class Domain:
    """
    A union of polyhedra over a sorted tuple of symbols, backed by isl sets.

    Set operations convert the polyhedra to an isl set, call the matching
    isl function and convert the result back with _fromislset(). That
    returns Empty when there is no polyhedron, the Polyhedron itself when
    there is one, and a Domain otherwise.
    """

    __slots__ = (
        '_polyhedra',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, *polyhedra):
        """
        Build a domain.

        With a single argument, accept a string (parsed with fromstring())
        or a Polyhedron (returned unchanged). With no argument, return
        Empty. Otherwise every argument must be a Polyhedron and the result
        is their union.

        Raises TypeError on arguments of the wrong type.
        """
        from .polyhedra import Polyhedron
        if len(polyhedra) == 1:
            polyhedron = polyhedra[0]
            if isinstance(polyhedron, str):
                return cls.fromstring(polyhedron)
            elif isinstance(polyhedron, Polyhedron):
                return polyhedron
            else:
                raise TypeError('argument must be a string '
                    'or a Polyhedron instance')
        if len(polyhedra) == 0:
            # Consistent with Or(): the union of no polyhedra is empty.
            # _toislset() cannot handle an empty tuple.
            from .polyhedra import Empty
            return Empty
        for polyhedron in polyhedra:
            if not isinstance(polyhedron, Polyhedron):
                raise TypeError('arguments must be Polyhedron instances')
        symbols = cls._xsymbols(polyhedra)
        islset = cls._toislset(polyhedra, symbols)
        return cls._fromislset(islset, symbols)

    @classmethod
    def _xsymbols(cls, iterator):
        """
        Return the ordered tuple of symbols present in iterator.
        """
        symbols = set()
        for item in iterator:
            symbols.update(item.symbols)
        return tuple(sorted(symbols, key=Symbol.sortkey))

    @property
    def polyhedra(self):
        """The polyhedra whose union is this domain."""
        return self._polyhedra

    @property
    def symbols(self):
        """The symbols of the domain, ordered by Symbol.sortkey."""
        return self._symbols

    @property
    def dimension(self):
        """The number of symbols of the domain."""
        return self._dimension

    def disjoint(self):
        """
        Return an equivalent domain whose polyhedra are pairwise disjoint.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        # isl_set_make_disjoint takes the set only (no context argument).
        islset = libisl.isl_set_make_disjoint(islset)
        return self._fromislset(islset, self.symbols)

    def isempty(self):
        """Return True if the domain contains no point."""
        islset = self._toislset(self.polyhedra, self.symbols)
        empty = bool(libisl.isl_set_is_empty(islset))
        libisl.isl_set_free(islset)
        return empty

    def __bool__(self):
        return not self.isempty()

    def isuniverse(self):
        """
        Return True if isl's plain universe check reports the domain as the
        universe.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        universe = bool(libisl.isl_set_plain_is_universe(islset))
        libisl.isl_set_free(islset)
        return universe

    def isbounded(self):
        """Return True if the domain is bounded."""
        islset = self._toislset(self.polyhedra, self.symbols)
        bounded = bool(libisl.isl_set_is_bounded(islset))
        libisl.isl_set_free(islset)
        return bounded

    def __eq__(self, other):
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        equal = bool(libisl.isl_set_is_equal(islset1, islset2))
        libisl.isl_set_free(islset1)
        libisl.isl_set_free(islset2)
        return equal

    def isdisjoint(self, other):
        """Return True if the domain and other have no point in common."""
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        disjoint = bool(libisl.isl_set_is_disjoint(islset1, islset2))
        libisl.isl_set_free(islset1)
        libisl.isl_set_free(islset2)
        return disjoint

    def issubset(self, other):
        """Return True if the domain is included in other."""
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        subset = bool(libisl.isl_set_is_subset(islset1, islset2))
        libisl.isl_set_free(islset1)
        libisl.isl_set_free(islset2)
        return subset

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        strict = bool(libisl.isl_set_is_strict_subset(islset1, islset2))
        libisl.isl_set_free(islset1)
        libisl.isl_set_free(islset2)
        return strict

    def complement(self):
        """Return the complement of the domain over its own symbols."""
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_complement(islset)
        return self._fromislset(islset, self.symbols)

    def __invert__(self):
        return self.complement()

    def simplify(self):
        """Return the domain with redundant constraints removed by isl."""
        # does not change anything in any of the examples
        # isl seems to do this naturally
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_remove_redundancies(islset)
        return self._fromislset(islset, self.symbols)

    def aspolyhedron(self):
        """Return the polyhedral hull of the domain as a Polyhedron."""
        # several types of hull are available
        # polyhedral seems to be the more appropriate, to be checked
        from .polyhedra import Polyhedron
        islset = self._toislset(self.polyhedra, self.symbols)
        islbset = libisl.isl_set_polyhedral_hull(islset)
        return Polyhedron._fromislbasicset(islbset, self.symbols)

    def project(self, dims):
        """
        Return the domain with the symbols in dims projected out.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        # Walk the symbols backwards, grouping consecutive symbols to remove
        # into runs. Each run is projected out in one call, so the indices
        # of symbols not yet visited stay valid.
        n = 0
        for index, symbol in reversed(list(enumerate(self.symbols))):
            if symbol in dims:
                n += 1
            elif n > 0:
                islset = libisl.isl_set_project_out(islset,
                    libisl.isl_dim_set, index + 1, n)
                n = 0
        if n > 0:
            islset = libisl.isl_set_project_out(islset,
                libisl.isl_dim_set, 0, n)
        dims = [symbol for symbol in self.symbols if symbol not in dims]
        return Domain._fromislset(islset, dims)

    def sample(self):
        """
        Return one integer point of the domain, or None if it is empty.

        The point has the same format as the items returned by points(): a
        dict mapping each symbol to its integer coordinate.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islpoint = libisl.isl_set_sample_point(islset)
        # Test for a void point first: isl rejects coordinate queries on a
        # void point.
        if bool(libisl.isl_point_is_void(islpoint)):
            libisl.isl_point_free(islpoint)
            return None
        point = {}
        for index, symbol in enumerate(self.symbols):
            coordinate = libisl.isl_point_get_coordinate_val(islpoint,
                libisl.isl_dim_set, index)
            point[symbol] = islhelper.isl_val_to_int(coordinate)
        libisl.isl_point_free(islpoint)
        return point

    def intersection(self, *others):
        """Return the intersection of the domain with all of others."""
        if len(others) == 0:
            return self
        symbols = self._xsymbols((self,) + others)
        islset1 = self._toislset(self.polyhedra, symbols)
        for other in others:
            islset2 = other._toislset(other.polyhedra, symbols)
            islset1 = libisl.isl_set_intersect(islset1, islset2)
        return self._fromislset(islset1, symbols)

    def __and__(self, other):
        return self.intersection(other)

    def union(self, *others):
        """Return the union of the domain with all of others."""
        if len(others) == 0:
            return self
        symbols = self._xsymbols((self,) + others)
        islset1 = self._toislset(self.polyhedra, symbols)
        for other in others:
            islset2 = other._toislset(other.polyhedra, symbols)
            islset1 = libisl.isl_set_union(islset1, islset2)
        return self._fromislset(islset1, symbols)

    def __or__(self, other):
        return self.union(other)

    def __add__(self, other):
        return self.union(other)

    def difference(self, other):
        """Return the points of the domain that are not in other."""
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        islset = libisl.isl_set_subtract(islset1, islset2)
        return self._fromislset(islset, symbols)

    def __sub__(self, other):
        return self.difference(other)

    def lexmin(self):
        """Return the lexicographic minimum of the domain (isl_set_lexmin)."""
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_lexmin(islset)
        return self._fromislset(islset, self.symbols)

    def lexmax(self):
        """Return the lexicographic maximum of the domain (isl_set_lexmax)."""
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_lexmax(islset)
        return self._fromislset(islset, self.symbols)

    def num_parameters(self):
        """
        Return the number of set dimensions isl reports for the domain.

        NOTE: despite the name, this counts set dimensions (one per symbol),
        not isl parameters.
        """
        # could be useful with large, complicated polyhedrons
        # Built from self.polyhedra so it works for multi-polyhedron
        # domains, which have no equalities/inequalities of their own.
        islset = self._toislset(self.polyhedra, self.symbols)
        num = libisl.isl_set_dim(islset, libisl.isl_dim_set)
        libisl.isl_set_free(islset)
        return num

    def involves_dims(self, dims):
        """
        Return True if the domain's constraints involve at least one of the
        symbols in dims.

        Symbols that are not symbols of the domain are ignored. An empty
        dims gives False.
        """
        # could be useful with large, complicated polyhedrons
        dims = list(dims)
        # isl dimensions follow the order of self.symbols. Query each
        # requested symbol at its own index rather than assuming the
        # symbols form one contiguous range.
        indices = [index for index, symbol in enumerate(self.symbols)
            if symbol in dims]
        if not indices:
            return False
        islset = self._toislset(self.polyhedra, self.symbols)
        try:
            return any(
                bool(libisl.isl_set_involves_dims(islset,
                    libisl.isl_dim_set, index, 1))
                for index in indices)
        finally:
            libisl.isl_set_free(islset)

    def vertices(self):
        """
        Print the vertices of each polyhedron of the domain; return None.
        """
        for polyhedron in self.polyhedra:
            islbset = polyhedron._toislbasicset(polyhedron.equalities,
                polyhedron.inequalities, self.symbols)
            islvertices = libisl.isl_basic_set_compute_vertices(islbset)
            vertices = islhelper.isl_vertices_vertices(islvertices)
            for vertex in vertices:
                expr = libisl.isl_vertex_get_expr(vertex)
                # NOTE(review): isl_version is compared as a string, i.e.
                # lexicographically ('0.9' > '0.13') -- confirm its format.
                if islhelper.isl_version < '0.13':
                    string = islhelper.isl_set_to_str(expr)
                else:
                    string = islhelper.isl_multi_aff_to_str(expr)
                print(string)
            # isl_basic_set_compute_vertices does not take ownership of the
            # basic set, so release it here.
            libisl.isl_basic_set_free(islbset)

    def points(self):
        """
        Return all integer points of the domain as a list of dicts mapping
        each symbol to its integer coordinate.

        Raises ValueError if the domain is unbounded.
        """
        if not self.isbounded():
            raise ValueError('domain must be bounded')
        islset = self._toislset(self.polyhedra, self.symbols)
        islpoints = islhelper.isl_set_points(islset)
        points = []
        for islpoint in islpoints:
            point = {}
            for index, symbol in enumerate(self.symbols):
                coordinate = libisl.isl_point_get_coordinate_val(islpoint,
                    libisl.isl_dim_set, index)
                point[symbol] = islhelper.isl_val_to_int(coordinate)
            points.append(point)
        return points

    @classmethod
    def _fromislset(cls, islset, symbols):
        """
        Convert islset (which is freed here) into Empty, a Polyhedron or a
        Domain, depending on how many basic sets it contains.
        """
        from .polyhedra import Polyhedron
        islset = libisl.isl_set_remove_divs(islset)
        islbsets = isl_set_basic_sets(islset)
        libisl.isl_set_free(islset)
        polyhedra = []
        for islbset in islbsets:
            polyhedron = Polyhedron._fromislbasicset(islbset, symbols)
            polyhedra.append(polyhedron)
        if len(polyhedra) == 0:
            from .polyhedra import Empty
            return Empty
        elif len(polyhedra) == 1:
            return polyhedra[0]
        else:
            domain = object.__new__(Domain)
            domain._polyhedra = tuple(polyhedra)
            domain._symbols = cls._xsymbols(polyhedra)
            domain._dimension = len(domain._symbols)
            return domain

    @classmethod
    def _toislset(cls, polyhedra, symbols):
        """
        Return a new isl set: the union of polyhedra, expressed over
        symbols. polyhedra must be non-empty.
        """
        # Must be a classmethod: __new__ calls it as cls._toislset(...).
        polyhedron = polyhedra[0]
        islbset = polyhedron._toislbasicset(polyhedron.equalities,
            polyhedron.inequalities, symbols)
        islset1 = libisl.isl_set_from_basic_set(islbset)
        for polyhedron in polyhedra[1:]:
            islbset = polyhedron._toislbasicset(polyhedron.equalities,
                polyhedron.inequalities, symbols)
            islset2 = libisl.isl_set_from_basic_set(islbset)
            islset1 = libisl.isl_set_union(islset1, islset2)
        return islset1

    @classmethod
    def _fromast(cls, node):
        """
        Build a domain from a Python AST node.

        Supports chained comparisons, ``&``, ``|`` and ``~``. Raises
        SyntaxError on anything else.
        """
        from .polyhedra import Polyhedron
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.UnaryOp):
            domain = cls._fromast(node.operand)
            # The operator lives on node.op. The AST class is ast.Invert.
            if isinstance(node.op, ast.Invert):
                return Not(domain)
        elif isinstance(node, ast.BinOp):
            domain1 = cls._fromast(node.left)
            domain2 = cls._fromast(node.right)
            if isinstance(node.op, ast.BitAnd):
                return And(domain1, domain2)
            elif isinstance(node.op, ast.BitOr):
                return Or(domain1, domain2)
        elif isinstance(node, ast.Compare):
            equalities = []
            inequalities = []
            left = Expression._fromast(node.left)
            for i in range(len(node.ops)):
                op = node.ops[i]
                right = Expression._fromast(node.comparators[i])
                # Normalize to expr == 0 / expr >= 0. Strict inequalities
                # become non-strict by subtracting 1 (integer points).
                if isinstance(op, ast.Lt):
                    inequalities.append(right - left - 1)
                elif isinstance(op, ast.LtE):
                    inequalities.append(right - left)
                elif isinstance(op, ast.Eq):
                    equalities.append(left - right)
                elif isinstance(op, ast.GtE):
                    inequalities.append(left - right)
                elif isinstance(op, ast.Gt):
                    inequalities.append(left - right - 1)
                else:
                    break
                left = right
            else:
                return Polyhedron(equalities, inequalities)
        raise SyntaxError('invalid syntax')

    _RE_BRACES = re.compile(r'^\{\s*|\s*\}$')
    _RE_EQ = re.compile(r'([^<=>])=([^<=>])')
    _RE_AND = re.compile(r'\band\b|,|&&|/\\|∧|∩')
    _RE_OR = re.compile(r'\bor\b|;|\|\||\\/|∨|∪')
    _RE_NOT = re.compile(r'\bnot\b|!|¬')
    _RE_NUM_VAR = Expression._RE_NUM_VAR
    _RE_OPERATORS = re.compile(r'(&|\||~)')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a textual domain description such as '{x >= 0, x < 10}'.
        """
        # remove curly brackets
        string = cls._RE_BRACES.sub(r'', string)
        # replace '=' by '=='
        string = cls._RE_EQ.sub(r'\1==\2', string)
        # replace 'and', 'or', 'not'
        string = cls._RE_AND.sub(r' & ', string)
        string = cls._RE_OR.sub(r' | ', string)
        string = cls._RE_NOT.sub(r' ~', string)
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = cls._RE_NUM_VAR.sub(r'\1*\2', string)
        # add parentheses to force precedence
        tokens = cls._RE_OPERATORS.split(string)
        for i, token in enumerate(tokens):
            if i % 2 == 0:
                token = '({})'.format(token)
                tokens[i] = token
        string = ''.join(tokens)
        # The second positional argument of ast.parse is the *filename*.
        # Parsing is done in the default 'exec' mode, so the tree is an
        # ast.Module, which _fromast unwraps.
        tree = ast.parse(string, 'eval')
        return cls._fromast(tree)

    def __repr__(self):
        assert len(self.polyhedra) >= 2
        strings = [repr(polyhedron) for polyhedron in self.polyhedra]
        return 'Or({})'.format(', '.join(strings))

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy boolean/relational expression into a domain.

        Raises ValueError on unsupported expressions.
        """
        import sympy
        from .polyhedra import Lt, Le, Eq, Ne, Ge, Gt
        funcmap = {
            sympy.And: And, sympy.Or: Or, sympy.Not: Not,
            sympy.Lt: Lt, sympy.Le: Le,
            sympy.Eq: Eq, sympy.Ne: Ne,
            sympy.Ge: Ge, sympy.Gt: Gt,
        }
        if expr.func in funcmap:
            args = [Domain.fromsympy(arg) for arg in expr.args]
            return funcmap[expr.func](*args)
        elif isinstance(expr, sympy.Expr):
            return Expression.fromsympy(expr)
        raise ValueError('non-domain expression: {!r}'.format(expr))

    def tosympy(self):
        """Return the domain as a sympy Or of its polyhedra."""
        import sympy
        polyhedra = [polyhedron.tosympy() for polyhedron in self.polyhedra]
        return sympy.Or(*polyhedra)
420
421
def And(*domains):
    """
    Return the intersection of all domains.

    With no argument, return Universe (the neutral element of
    intersection).
    """
    if not domains:
        from .polyhedra import Universe
        return Universe
    first, *rest = domains
    return first.intersection(*rest)
428
def Or(*domains):
    """
    Return the union of all domains.

    With no argument, return Empty (the neutral element of union).
    """
    if not domains:
        from .polyhedra import Empty
        return Empty
    first, *rest = domains
    return first.union(*rest)
435
def Not(domain):
    """
    Return the complement of domain, computed with the ``~`` operator.
    """
    complemented = ~domain
    return complemented