4f3aeef724d112b84f593bdd76794581e879c34e
[linpy.git] / pypol / linear.py
1 import ctypes, ctypes.util
2 import functools
3 import numbers
4
5 from fractions import Fraction, gcd
6
7 from . import isl
8 from .isl import libisl
9
10
# Public names exported by ``from ... import *``.
__all__ = [
    'Expression',
    'constant', 'symbol', 'symbols',
    'eq', 'le', 'lt', 'ge', 'gt',
    'Polyhedron',
    'empty', 'universe'
]
18
19
def _polymorphic_method(func):
    """
    Decorate a binary Expression method so its right operand is normalized.

    Expression operands are passed through unchanged and rational numbers
    are promoted with constant(); any other operand type makes the method
    return NotImplemented, so Python can try the reflected operation.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(b, Expression):
            if not isinstance(b, numbers.Rational):
                return NotImplemented
            b = constant(b)
        return func(a, b)
    return wrapper
30
def _polymorphic_operator(func):
    """
    Decorate a module-level comparison helper such as eq() or le().

    Only the left operand needs checking: it is promoted to a constant
    expression when it is a rational number, and the wrapped function then
    calls a polymorphic Expression method that normalizes the right operand.
    Raises TypeError for any other left operand.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if isinstance(a, numbers.Rational):
            return func(constant(a), b)
        if isinstance(a, Expression):
            return func(a, b)
        raise TypeError('arguments must be linear expressions')
    return wrapper
43
44
# Module-wide isl context, created at import time; Polyhedron._to_isl
# allocates its isl spaces in this context.
_main_ctx = isl.Context()
46
47
class Expression:
    """
    This class implements linear expressions.

    An expression maps symbol names (strings) to non-zero rational
    coefficients and carries a rational constant term.  Arithmetic returns
    new expressions; the comparison operators build Polyhedron objects whose
    inequalities are read as ``expression >= 0``.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a linear expression.

        coefficients -- a dict, or an iterable of (symbol, coefficient)
            pairs; a symbol may be a string or a single-symbol expression.
            Zero coefficients are dropped.  A string argument is delegated
            to fromstring().
        constant -- the rational constant term.

        Raises TypeError for non-string symbols, non-rational values, or a
        string combined with a non-zero constant.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        self = super().__new__(cls)
        self._coefficients = {}
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is not None:
            for symbol, coefficient in coefficients:
                if isinstance(symbol, Expression) and symbol.issymbol():
                    symbol = str(symbol)
                elif not isinstance(symbol, str):
                    raise TypeError('symbols must be strings')
                if not isinstance(coefficient, numbers.Rational):
                    raise TypeError('coefficients must be rational numbers')
                if coefficient != 0:
                    self._coefficients[symbol] = coefficient
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        self._constant = constant
        self._symbols = tuple(sorted(self._coefficients))
        self._dimension = len(self._symbols)
        return self

    @property
    def symbols(self):
        """Sorted tuple of the symbols that have a non-zero coefficient."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols in the expression."""
        return self._dimension

    def coefficient(self, symbol):
        """Return the coefficient of *symbol*, or 0 when it is absent."""
        if isinstance(symbol, Expression) and symbol.issymbol():
            symbol = str(symbol)
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, coefficient) pairs in sorted symbol order."""
        for symbol in self.symbols:
            yield symbol, self._coefficients[symbol]

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    def isconstant(self):
        """Return True when the expression has no symbols."""
        return not self._coefficients

    def values(self):
        """Yield the coefficients in sorted symbol order, then the constant."""
        for symbol in self.symbols:
            yield self._coefficients[symbol]
        yield self.constant

    def values_int(self):
        """Like values(), but converts each value with int()."""
        # The previous version returned from inside the loop, so it only
        # ever produced the first coefficient (or the constant).
        for value in self.values():
            yield int(value)

    @property
    def symbol(self):
        """The name of a single-symbol expression; ValueError otherwise."""
        if not self.issymbol():
            raise ValueError('not a symbol: {}'.format(self))
        return self._symbols[0]

    def issymbol(self):
        """Return True for one symbol and a zero constant."""
        return len(self._coefficients) == 1 and self._constant == 0

    def __bool__(self):
        return (not self.isconstant()) or bool(self.constant)

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        coefficients = dict(self.coefficients())
        # coefficients() must be called: iterating the bound method itself
        # raised TypeError.
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) + coefficient
        return Expression(coefficients, self.constant + other.constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) - coefficient
        return Expression(coefficients, self.constant - other.constant)

    @_polymorphic_method
    def __rsub__(self, other):
        # other is promoted to an Expression, so this computes other - self
        # and unsupported operand types yield NotImplemented.
        return other - self

    @_polymorphic_method
    def __mul__(self, other):
        """Multiply by a constant; raise ValueError if the result is non-linear."""
        if other.isconstant():
            factor = other.constant
            coefficients = {symbol: coefficient * factor
                    for symbol, coefficient in self.coefficients()}
            return Expression(coefficients, self.constant * factor)
        if self.isconstant():
            # Multiplication is commutative: scale other by our constant.
            return other * self
        raise ValueError('non-linear expression: '
                '{} * {}'.format(self._parenstr(), other._parenstr()))

    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        """Divide by a constant; raise ValueError for a non-constant divisor."""
        if not other.isconstant():
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(self._parenstr(), other._parenstr()))
        divisor = other.constant
        coefficients = {symbol: Fraction(coefficient, divisor)
                for symbol, coefficient in self.coefficients()}
        return Expression(coefficients, Fraction(self.constant, divisor))

    @_polymorphic_method
    def __rtruediv__(self, other):
        # other is promoted to an Expression; dividing it by self is only
        # linear when self is constant (ValueError otherwise).
        return other / self

    def __str__(self):
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                string += symbol if i == 0 else ' + {}'.format(symbol)
            elif coefficient == -1:
                string += '-{}'.format(symbol) if i == 0 \
                        else ' - {}'.format(symbol)
            elif i == 0:
                string += '{}*{}'.format(coefficient, symbol)
            elif coefficient > 0:
                string += ' + {}*{}'.format(coefficient, symbol)
            else:
                string += ' - {}*{}'.format(-coefficient, symbol)
        constant = self.constant
        if constant == 0:
            pass
        elif self.isconstant():
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        else:
            string += ' - {}'.format(-constant)
        return string or '0'

    def _parenstr(self, always=False):
        """Return str(self), parenthesized unless it is a constant or symbol."""
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        return '({})'.format(string)

    def __repr__(self):
        items = ', '.join('{!r}: {!r}'.format(symbol, coefficient)
                for symbol, coefficient in self.coefficients())
        return '{}({{{}}}, {!r})'.format(self.__class__.__name__, items,
                self.constant)

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return self._coefficients == other._coefficients and \
                self.constant == other.constant

    def __hash__(self):
        # A dict is unhashable; hash an order-independent snapshot of the
        # coefficients instead, which agrees with __eq__.
        return hash((frozenset(self._coefficients.items()), self._constant))

    def _canonify(self):
        """Scale by the lcm of all denominators so every value is integral."""
        import math  # fractions.gcd was removed in Python 3.9
        lcm = functools.reduce(lambda a, b: a * b // math.gcd(a, b),
                [value.denominator for value in self.values()])
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        """Polyhedron of points where self == other."""
        return Polyhedron(equalities=[(self - other)._canonify()])

    @_polymorphic_method
    def __le__(self, other):
        return Polyhedron(inequalities=[(other - self)._canonify()])

    @_polymorphic_method
    def __lt__(self, other):
        # Over the integers, other - self > 0 is other - self - 1 >= 0.
        return Polyhedron(inequalities=[(other - self)._canonify() - 1])

    @_polymorphic_method
    def __ge__(self, other):
        return Polyhedron(inequalities=[(self - other)._canonify()])

    @_polymorphic_method
    def __gt__(self, other):
        return Polyhedron(inequalities=[(self - other)._canonify() - 1])
297
298
def constant(numerator=0, denominator=None):
    """
    Return a constant linear expression.

    A lone rational *numerator* is used as-is; otherwise the value is built
    with Fraction(numerator, denominator).
    """
    if denominator is not None or not isinstance(numerator, numbers.Rational):
        return Expression(constant=Fraction(numerator, denominator))
    return Expression(constant=numerator)
304
def symbol(name):
    """Return the expression made of the single symbol *name* (coefficient 1)."""
    if isinstance(name, str):
        return Expression(coefficients={name: 1})
    raise TypeError('name must be a string')
309
def symbols(names):
    """
    Return a generator of symbol expressions.

    *names* is an iterable of names, or one string whose names are
    separated by commas and/or whitespace (e.g. 'x, y z').
    """
    if not isinstance(names, str):
        return (symbol(name) for name in names)
    split_names = names.replace(',', ' ').split()
    return (symbol(name) for name in split_names)
314
315
@_polymorphic_operator
def eq(a, b):
    """Return the polyhedron where a == b; rational operands are promoted."""
    polyhedron = a._eq(b)
    return polyhedron
319
@_polymorphic_operator
def le(a, b):
    """Return the polyhedron where a <= b; rational operands are promoted."""
    polyhedron = a <= b
    return polyhedron
323
@_polymorphic_operator
def lt(a, b):
    """Return the polyhedron where a < b; rational operands are promoted."""
    polyhedron = a < b
    return polyhedron
327
@_polymorphic_operator
def ge(a, b):
    """Return the polyhedron where a >= b; rational operands are promoted."""
    polyhedron = a >= b
    return polyhedron
331
@_polymorphic_operator
def gt(a, b):
    """Return the polyhedron where a > b; rational operands are promoted."""
    polyhedron = a > b
    return polyhedron
335
336
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is the conjunction of equality constraints (a stored
    expression e means e == 0) and inequality constraints (a stored
    expression e means e >= 0).  Every coefficient and constant of a
    constraint must be an integer.
    """

    def __new__(cls, equalities=None, inequalities=None):
        """
        Create a polyhedron from iterables of constraint expressions.

        A string first argument is delegated to fromstring().  Raises
        TypeError when a string is combined with inequalities, or when a
        constraint has a non-integer value.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._integral_constraints(
                equalities, 'non-integer constraint: {} == 0')
        # Inequalities mean "expression >= 0" (see __str__), so the error
        # message must say ">= 0", not "<= 0".
        self._inequalities = cls._integral_constraints(
                inequalities, 'non-integer constraint: {} >= 0')
        self._constraints = self._equalities + self._inequalities
        symbols = set()
        for constraint in self._constraints:
            symbols.update(constraint.symbols)
        self._symbols = tuple(sorted(symbols))
        return self

    @staticmethod
    def _integral_constraints(constraints, message):
        """Return *constraints* as a tuple, rejecting non-integer values."""
        if constraints is None:
            return ()
        checked = []
        for constraint in constraints:
            if any(value.denominator != 1 for value in constraint.values()):
                raise TypeError(message.format(constraint))
            checked.append(constraint)
        return tuple(checked)

    @property
    def equalities(self):
        """Tuple of expressions constrained to == 0."""
        return self._equalities

    @property
    def inequalities(self):
        """Tuple of expressions constrained to >= 0."""
        return self._inequalities

    @property
    def constraints(self):
        """Equalities followed by inequalities."""
        return self._constraints

    @property
    def symbols(self):
        """Sorted tuple of every symbol used by a constraint."""
        return self._symbols

    @property
    def dimension(self):
        return len(self.symbols)

    def __bool__(self):
        # The previous version called a nonexistent is_empty() method.
        return not self.isempty()

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        raise NotImplementedError

    def isempty(self):
        # Compares against the module-level `empty` placeholder.
        return self == empty

    def isuniverse(self):
        return self == universe

    def isdisjoint(self, other):
        # return true if the polyhedron has no elements in common with other
        raise NotImplementedError

    def issubset(self, other):
        raise NotImplementedError

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        raise NotImplementedError

    def issuperset(self, other):
        # test whether every element in other is in the polyhedron
        raise NotImplementedError

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        raise NotImplementedError

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        # return a new polyhedron with elements common to the polyhedron and all
        # others
        # a poor man's implementation could be:
        # equalities = list(self.equalities)
        # inequalities = list(self.inequalities)
        # for other in others:
        #     equalities.extend(other.equalities)
        #     inequalities.extend(other.inequalities)
        # return self.__class__(equalities, inequalities)
        raise NotImplementedError

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, *others):
        # return a new polyhedron with elements in the polyhedron that are not
        # in the others
        raise NotImplementedError

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = []
        for constraint in self.equalities:
            constraints.append('{} == 0'.format(constraint))
        for constraint in self.inequalities:
            constraints.append('{} >= 0'.format(constraint))
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        return '{}(equalities={!r}, inequalities={!r})' \
            ''.format(self.__class__.__name__, equalities, inequalities)

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    def _symbolunion(self, *others):
        """Sorted list of the symbols of this polyhedron and *others*."""
        symbols = set(self.symbols)
        for other in others:
            symbols.update(other.symbols)
        return sorted(symbols)

    def _to_isl(self, symbols=None):
        """
        Build an isl basic set from the constraints.

        symbols -- ordering of the set dimensions (default: self.symbols);
            a symbol's index in this sequence is its dimension position.
        """
        symbols = tuple(self.symbols if symbols is None else symbols)
        space = libisl.isl_space_set_alloc(_main_ctx, 0, len(symbols))
        bset = libisl.isl_basic_set_universe(libisl.isl_space_copy(space))
        ls = libisl.isl_local_space_from_space(space)
        kinds = ((self.equalities, libisl.isl_equality_alloc),
                 (self.inequalities, libisl.isl_inequality_alloc))
        for constraints, alloc in kinds:
            for constraint in constraints:
                # isl_basic_set_add_constraint takes ownership of the
                # constraint, so allocate a fresh one per expression rather
                # than reusing a freed (and already-populated) object.
                c = alloc(libisl.isl_local_space_copy(ls))
                if constraint.constant:
                    # Values are integral (checked in __new__) but may be
                    # Fraction objects, which ctypes cannot pass as ints.
                    c = libisl.isl_constraint_set_constant_si(
                            c, int(constraint.constant))
                for symbol, coefficient in constraint.coefficients():
                    # NOTE(review): original note says isl_dim_set is 3;
                    # assumes libisl exposes it as an attribute.
                    c = libisl.isl_constraint_set_coefficient_si(
                            c, libisl.isl_dim_set, symbols.index(symbol),
                            int(coefficient))
                bset = libisl.isl_basic_set_add_constraint(bset, c)
        # Only copies of ls were handed to isl; release the original.
        libisl.isl_local_space_free(ls)
        return isl.BasicSet(bset)

    @classmethod
    def from_isl(cls, bset):
        """
        Convert an isl basic set back into a Polyhedron (not implemented).

        isl example code gives the isl form as
        "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}", while our
        printer currently gives b'{ [i0] : 1 = 0 }'.
        """
        raise NotImplementedError
550
# Placeholders for the empty and universe polyhedra.  The trailing comments
# suggest the intended definitions (eq(0, 1) and Polyhedron()); TODO confirm.
# Polyhedron.isempty() and Polyhedron.isuniverse() compare against these names.
empty = None #eq(0,1)
universe = None #Polyhedron()
553
554
if __name__ == '__main__':
    # Manual smoke run: build two inequalities and print their isl form.
    first = Expression(coefficients={'a': 1, 'x': 2}, constant=2)
    second = Expression(coefficients={'a': 3, 'b': 2}, constant=3)
    basic_set = Polyhedron(inequalities=[first, second])._to_isl()
    print(basic_set)