Implement methods Polyhedron.__contains__(), Domain.__contains__()
[linpy.git] / pypol / polyhedra.py
1
2 import functools
3 import math
4 import numbers
5
6 from . import islhelper
7
8 from .islhelper import mainctx, libisl
9 from .coordinates import Point
10 from .linexprs import Expression, Symbol, Rational
11 from .domains import Domain
12
13
# Public names of this module: the Polyhedron class, the comparison helpers
# that build polyhedra, and the two predefined polyhedra.
__all__ = [
    'Polyhedron',
    'Lt', 'Le', 'Eq', 'Ne', 'Ge', 'Gt',
    'Empty', 'Universe',
]
19
20
class Polyhedron(Domain):
    """
    A domain described by a conjunction of linear constraints.

    A polyhedron stores a tuple of equalities (expressions constrained to be
    equal to zero) and a tuple of inequalities (expressions constrained to be
    greater than or equal to zero).
    """

    __slots__ = (
        '_equalities',
        '_inequalities',
        '_constraints',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, equalities=None, inequalities=None):
        """
        Build a polyhedron.

        The first argument may be:
        - a string, parsed through fromstring();
        - a Polyhedron, returned unchanged;
        - another Domain, converted through its aspolyhedron() method;
        - an iterable of Expression instances (or None), each constrained to
          be equal to zero.

        In the last case only, inequalities may be an iterable of Expression
        instances, each constrained to be greater than or equal to zero.

        Raises TypeError when inequalities is given together with a string or
        a domain, or when a constraint is not an Expression instance.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        elif isinstance(equalities, Polyhedron):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return equalities
        elif isinstance(equalities, Domain):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return equalities.aspolyhedron()
        # Build fresh lists instead of overwriting the arguments in place, so
        # that the caller's sequences are left untouched and any iterable
        # (tuple, generator, ...) is accepted.
        scaled_equalities = []
        if equalities is not None:
            for equality in equalities:
                if not isinstance(equality, Expression):
                    raise TypeError('equalities must be linear expressions')
                scaled_equalities.append(equality.scaleint())
        scaled_inequalities = []
        if inequalities is not None:
            for inequality in inequalities:
                if not isinstance(inequality, Expression):
                    raise TypeError('inequalities must be linear expressions')
                scaled_inequalities.append(inequality.scaleint())
        symbols = cls._xsymbols(scaled_equalities + scaled_inequalities)
        islbset = cls._toislbasicset(scaled_equalities, scaled_inequalities,
            symbols)
        return cls._fromislbasicset(islbset, symbols)

    @property
    def equalities(self):
        """Tuple of the expressions constrained to be equal to zero."""
        return self._equalities

    @property
    def inequalities(self):
        """Tuple of the expressions constrained to be non-negative."""
        return self._inequalities

    @property
    def constraints(self):
        """Tuple of all constraints: equalities first, then inequalities."""
        return self._constraints

    @property
    def polyhedra(self):
        """One-element tuple containing this polyhedron."""
        return (self,)

    def disjoint(self):
        """Return this polyhedron itself."""
        return self

    def isuniverse(self):
        """
        Return True if isl reports the polyhedron as the universe set.
        """
        islbset = self._toislbasicset(self.equalities, self.inequalities,
            self.symbols)
        universe = bool(libisl.isl_basic_set_is_universe(islbset))
        # The basic set was created only for this query.
        libisl.isl_basic_set_free(islbset)
        return universe

    def aspolyhedron(self):
        """Return this polyhedron itself."""
        return self

    def __contains__(self, point):
        """
        Return True if point satisfies every constraint of the polyhedron.

        Raises TypeError if point is not a Point instance, and ValueError if
        the point and the polyhedron are not defined on the same symbols.
        """
        if not isinstance(point, Point):
            raise TypeError('point must be a Point instance')
        if self.symbols != point.symbols:
            raise ValueError('arguments must belong to the same space')
        # point.coordinates() is deliberately called again for every
        # constraint: it may return a one-shot iterator.
        for equality in self.equalities:
            if equality.subs(point.coordinates()) != 0:
                return False
        for inequality in self.inequalities:
            if inequality.subs(point.coordinates()) < 0:
                return False
        return True

    def subs(self, symbol, expression=None):
        """
        Return the polyhedron obtained by applying subs(symbol, expression)
        to every constraint.
        """
        equalities = [equality.subs(symbol, expression)
            for equality in self.equalities]
        inequalities = [inequality.subs(symbol, expression)
            for inequality in self.inequalities]
        return Polyhedron(equalities, inequalities)

    @classmethod
    def _fromislbasicset(cls, islbset, symbols):
        """
        Build a Polyhedron from an isl basic set whose set dimensions are
        indexed in the order of symbols.

        The basic set is freed by this method and must not be used
        afterwards.
        """
        islconstraints = islhelper.isl_basic_set_constraints(islbset)
        equalities = []
        inequalities = []
        for islconstraint in islconstraints:
            constant = libisl.isl_constraint_get_constant_val(islconstraint)
            constant = islhelper.isl_val_to_int(constant)
            coefficients = {}
            for index, symbol in enumerate(symbols):
                coefficient = libisl.isl_constraint_get_coefficient_val(
                    islconstraint, libisl.isl_dim_set, index)
                coefficient = islhelper.isl_val_to_int(coefficient)
                # Only non-zero coefficients are kept.
                if coefficient != 0:
                    coefficients[symbol] = coefficient
            expression = Expression(coefficients, constant)
            if libisl.isl_constraint_is_equality(islconstraint):
                equalities.append(expression)
            else:
                inequalities.append(expression)
        libisl.isl_basic_set_free(islbset)
        # Allocate the instance directly: going through Polyhedron.__new__
        # would build an isl basic set again.
        self = object.__new__(Polyhedron)
        self._equalities = tuple(equalities)
        self._inequalities = tuple(inequalities)
        self._constraints = tuple(equalities + inequalities)
        self._symbols = cls._xsymbols(self._constraints)
        self._dimension = len(self._symbols)
        return self

    @classmethod
    def _toislbasicset(cls, equalities, inequalities, symbols):
        """
        Build an isl basic set from the given constraints.

        Every symbol used by a constraint must appear in symbols; its
        position there is the index of the matching isl set dimension.
        The caller owns the returned basic set and has to free it.
        """
        dimension = len(symbols)
        indices = {symbol: index for index, symbol in enumerate(symbols)}
        islsp = libisl.isl_space_set_alloc(mainctx, 0, dimension)
        islbset = libisl.isl_basic_set_universe(libisl.isl_space_copy(islsp))
        islls = libisl.isl_local_space_from_space(islsp)
        # Equalities and inequalities are translated the same way; only the
        # isl allocator differs.
        for alloc, expressions in (
                (libisl.isl_equality_alloc, equalities),
                (libisl.isl_inequality_alloc, inequalities)):
            for expression in expressions:
                islconstraint = alloc(libisl.isl_local_space_copy(islls))
                for symbol, coefficient in expression.coefficients():
                    islval = str(coefficient).encode()
                    islval = libisl.isl_val_read_from_str(mainctx, islval)
                    index = indices[symbol]
                    islconstraint = libisl.isl_constraint_set_coefficient_val(
                        islconstraint, libisl.isl_dim_set, index, islval)
                if expression.constant != 0:
                    islval = str(expression.constant).encode()
                    islval = libisl.isl_val_read_from_str(mainctx, islval)
                    islconstraint = libisl.isl_constraint_set_constant_val(
                        islconstraint, islval)
                islbset = libisl.isl_basic_set_add_constraint(islbset,
                    islconstraint)
        # Every constraint received its own copy of the local space, so the
        # original reference must be released here to avoid leaking it.
        libisl.isl_local_space_free(islls)
        return islbset

    @classmethod
    def fromstring(cls, string):
        """
        Parse string with Domain.fromstring() and return the result.

        Raises ValueError if the parsed domain is not a Polyhedron.
        """
        domain = Domain.fromstring(string)
        if not isinstance(domain, Polyhedron):
            raise ValueError('non-polyhedral expression: {!r}'.format(string))
        return domain

    def __repr__(self):
        """
        Return 'Empty', 'Universe', the single constraint, or an 'And(...)'
        of all the constraints.
        """
        if self.isempty():
            return 'Empty'
        if self.isuniverse():
            return 'Universe'
        strings = ['0 == {}'.format(equality)
            for equality in self.equalities]
        strings.extend('0 <= {}'.format(inequality)
            for inequality in self.inequalities)
        if len(strings) == 1:
            return strings[0]
        return 'And({})'.format(', '.join(strings))

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert expr with Domain.fromsympy() and return the result.

        Raises ValueError if the converted domain is not a Polyhedron.
        """
        domain = Domain.fromsympy(expr)
        if not isinstance(domain, Polyhedron):
            raise ValueError('non-polyhedral expression: {!r}'.format(expr))
        return domain

    def tosympy(self):
        """
        Return a sympy And() of Eq(..., 0) and Ge(..., 0) relations, one per
        constraint.
        """
        import sympy
        constraints = [sympy.Eq(equality.tosympy(), 0)
            for equality in self.equalities]
        constraints.extend(sympy.Ge(inequality.tosympy(), 0)
            for inequality in self.inequalities)
        return sympy.And(*constraints)

    @classmethod
    def _sort_polygon_2d(cls, points):
        """
        Return the points sorted by their angle around their average point.

        Sequences of three points or fewer are returned unchanged.
        """
        if len(points) <= 3:
            return points
        # Vector is not imported at module level.
        from .coordinates import Vector
        o = sum((Vector(point) for point in points)) / len(points)
        o = Point(o.coordinates())
        angles = {}
        for m in points:
            om = Vector(o, m)
            dx, dy = (coordinate for symbol, coordinate in om.coordinates())
            angles[m] = math.atan2(dy, dx)
        return sorted(points, key=angles.get)

    @classmethod
    def _sort_polygon_3d(cls, points):
        """
        Return the points sorted by their signed angle around their average
        point, measured from the first point.

        Sequences of three points or fewer are returned unchanged.
        """
        if len(points) <= 3:
            return points
        # Vector is not imported at module level.
        from .coordinates import Vector
        o = sum((Vector(point) for point in points)) / len(points)
        o = Point(o.coordinates())
        a, b = points[:2]
        oa = Vector(o, a)
        ob = Vector(o, b)
        norm_oa = oa.norm()
        # Unit vector used to give a sign to the angles.
        u = (oa.cross(ob)).asunit()
        angles = {a: 0.}
        for m in points[1:]:
            om = Vector(o, m)
            normprod = norm_oa * om.norm()
            # Rounding errors may push the cosine slightly outside [-1, 1],
            # which would make math.acos() raise ValueError.
            cosinus = max(-1., min(1., oa.dot(om) / normprod))
            sinus = u.dot(oa.cross(om)) / normprod
            angle = math.acos(cosinus)
            angle = math.copysign(angle, sinus)
            angles[m] = angle
        return sorted(points, key=angles.get)

    def plot(self):
        """
        Display a two-dimensional polyhedron with matplotlib.

        Returns the list of points given to the matplotlib path, including
        the trailing dummy point required by the closing code. For a
        three-dimensional polyhedron nothing is displayed and 0 is returned.

        Raises TypeError for any other dimension.
        """
        import matplotlib.pyplot as plt
        from matplotlib.path import Path
        import matplotlib.patches as patches

        dimension = len(self.symbols)
        if dimension > 3:
            raise TypeError('cannot plot a polyhedron with more than '
                '3 symbols')
        if dimension == 3:
            return 0
        if dimension != 2:
            # Previously this case ended in an UnboundLocalError.
            raise TypeError('cannot plot a polyhedron with fewer than '
                '2 symbols')

        points = []
        for vert in self.vertices():
            points.append(tuple(vert.get(sym)
                for sym in sorted(vert, key=Symbol.sortkey)))
        # CLOSEPOLY needs a vertex of its own, which matplotlib ignores.
        points.append((0.0, 0.0))
        codes = [Path.MOVETO]
        codes.extend([Path.LINETO] * (len(points) - 2))
        codes.append(Path.CLOSEPOLY)
        path = Path(points, codes)
        fig = plt.figure()
        ax = fig.add_subplot(111)
        patch = patches.PathPatch(path, facecolor='blue', lw=2)
        ax.add_patch(patch)
        ax.set_xlim(-5, 5)
        ax.set_ylim(-5, 5)
        plt.show()
        return points
288
289
def _polymorphic(func):
    """
    Decorator for binary functions working on linear expressions.

    Each of the two arguments of the decorated function may be a rational
    number (an instance of numbers.Rational), which is converted with
    Rational() before func is called, or an Expression instance, which is
    passed unchanged. Any other argument raises TypeError.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if isinstance(left, numbers.Rational):
            left = Rational(left)
        elif not isinstance(left, Expression):
            raise TypeError('left must be a rational number '
                'or a linear expression')
        if isinstance(right, numbers.Rational):
            right = Rational(right)
        elif not isinstance(right, Expression):
            raise TypeError('right must be a rational number '
                'or a linear expression')
        return func(left, right)
    return wrapper
305
@_polymorphic
def Lt(left, right):
    """
    Return the polyhedron defined by the inequality right - left - 1 >= 0,
    which is equivalent to left < right on integer values.
    """
    slack = right - left - 1
    return Polyhedron(equalities=[], inequalities=[slack])
309
@_polymorphic
def Le(left, right):
    """
    Return the polyhedron defined by the inequality right - left >= 0,
    that is left <= right.
    """
    slack = right - left
    return Polyhedron(equalities=[], inequalities=[slack])
313
@_polymorphic
def Eq(left, right):
    """
    Return the polyhedron defined by the equality left - right == 0,
    that is left == right.
    """
    difference = left - right
    return Polyhedron(equalities=[difference], inequalities=[])
317
@_polymorphic
def Ne(left, right):
    """
    Return the result of applying the ~ operator to Eq(left, right).
    """
    equal = Eq(left, right)
    return ~equal
321
@_polymorphic
def Gt(left, right):
    """
    Return the polyhedron defined by the inequality left - right - 1 >= 0,
    which is equivalent to left > right on integer values.
    """
    slack = left - right - 1
    return Polyhedron(equalities=[], inequalities=[slack])
325
@_polymorphic
def Ge(left, right):
    """
    Return the polyhedron defined by the inequality left - right >= 0,
    that is left >= right.
    """
    slack = left - right
    return Polyhedron(equalities=[], inequalities=[slack])
329
330
# Polyhedron built from the unsatisfiable equality 1 == 0.
Empty = Eq(left=1, right=0)

# Polyhedron built from an empty list of equalities and no inequalities.
Universe = Polyhedron(equalities=[])