de6f4ca1d71b047fb76981c045b3fc7c6d8fe83e
[linpy.git] / pypol / linear.py
1 import functools
2 import numbers
3 import ctypes, ctypes.util
4 from pypol import isl
5
6 from fractions import Fraction, gcd
7
# Locate the isl shared library before loading it through ctypes.
_libisl_path = ctypes.util.find_library('isl')
if _libisl_path is None:
    # find_library() returns None when no library matches.  On POSIX systems
    # CDLL(None) would then open the running process instead of isl, and the
    # failure would only show up later as an unrelated missing symbol.
    raise ImportError('could not locate the isl shared library')
libisl = ctypes.CDLL(_libisl_path)

# isl_printer_get_str returns a C string rather than the default C int.
libisl.isl_printer_get_str.restype = ctypes.c_char_p

# Names exported by "from ... import *".
__all__ = [
    'Expression',
    'constant', 'symbol', 'symbols',
    'eq', 'le', 'lt', 'ge', 'gt',
    'Polyhedron',
    'empty', 'universe'
]


# isl context shared by every isl call made from this module.
_CONTEXT = isl.Context()
22
def _polymorphic_method(func):
    """
    Decorate a binary method so that its right operand is always an
    Expression.

    A rational right operand is wrapped with constant() before func is
    called.  For any other operand type the wrapper returns NotImplemented
    without calling func.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(b, Expression):
            if not isinstance(b, numbers.Rational):
                return NotImplemented
            b = constant(b)
        return func(a, b)
    return wrapper
33
def _polymorphic_operator(func):
    """
    Decorate a binary function so that its left operand is always an
    Expression.

    A rational left operand is wrapped with constant() before func is called.
    Raises TypeError when the left operand is neither a rational nor an
    Expression.
    """
    # A polymorphic operator should call a polymorphic method, hence only the
    # left operand has to be tested here.
    @functools.wraps(func)
    def wrapper(a, b):
        if isinstance(a, numbers.Rational):
            return func(constant(a), b)
        if isinstance(a, Expression):
            return func(a, b)
        raise TypeError('arguments must be linear expressions')
    return wrapper
46
47
class Expression:
    """
    This class implements linear expressions.

    An expression is a sum of symbols (strings), each weighted by a non-zero
    rational coefficient, plus a rational constant term.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Create an expression.

        coefficients is either None, a dict or an iterable of
        (symbol, coefficient) pairs; a symbol is a string or an expression
        for which issymbol() is true.  Zero coefficients are dropped.  When
        coefficients is a string, the call is forwarded to fromstring().

        Raises TypeError if a symbol, a coefficient or the constant has the
        wrong type, or if a constant is passed along with a string.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        self = super().__new__(cls)
        self._coefficients = {}
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is not None:
            for symbol, coefficient in coefficients:
                if isinstance(symbol, Expression) and symbol.issymbol():
                    symbol = symbol.symbol()
                elif not isinstance(symbol, str):
                    raise TypeError('symbols must be strings')
                if not isinstance(coefficient, numbers.Rational):
                    raise TypeError('coefficients must be rational numbers')
                if coefficient != 0:
                    self._coefficients[symbol] = coefficient
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        self._constant = constant
        return self

    def symbols(self):
        """
        Yield the symbols of the expression, sorted.
        """
        yield from sorted(self._coefficients)

    @property
    def dimension(self):
        """
        Number of symbols of the expression.
        """
        return len(self._coefficients)

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol, or 0 if the expression does not
        use it.

        Raises TypeError if symbol is neither a string nor a symbol
        expression.
        """
        if isinstance(symbol, Expression) and symbol.issymbol():
            symbol = symbol.symbol()
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """
        Yield (symbol, coefficient) pairs, sorted by symbol.
        """
        for symbol in self.symbols():
            yield symbol, self._coefficients[symbol]

    @property
    def constant(self):
        """
        Constant term of the expression.
        """
        return self._constant

    def isconstant(self):
        """
        Return True if the expression has no symbol.
        """
        return not self._coefficients

    def values(self):
        """
        Yield every coefficient, in symbol order, then the constant term.
        """
        for symbol in self.symbols():
            yield self._coefficients[symbol]
        yield self.constant

    def values_int(self):
        """
        Return the coefficient of the first symbol or, if there is no
        symbol, the constant term converted to int.
        """
        # NOTE(review): this returns after the first symbol and does not
        # convert that coefficient; the name suggests an integer counterpart
        # of values() -- confirm the intent before relying on it.
        for symbol in self.symbols():
            return self.coefficient(symbol)
        return int(self.constant)

    def symbol(self):
        """
        Return the name of the symbol this expression consists of.

        Raises ValueError if the expression is not a symbol.
        """
        if not self.issymbol():
            raise ValueError('not a symbol: {}'.format(self))
        return next(self.symbols())

    def issymbol(self):
        """
        Return True if the expression is a single symbol, i.e. exactly one
        symbol with coefficient 1 and no constant term.
        """
        # The coefficient has to be checked too: "2*x" is not a symbol, and
        # its string form must never be used as a symbol name.
        return self._constant == 0 and \
            list(self._coefficients.values()) == [1]

    def __bool__(self):
        """
        An expression is false only if it is the constant zero.
        """
        return (not self.isconstant()) or bool(self.constant)

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        """
        Return the sum of two expressions.
        """
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) + coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        """
        Return the difference self - other.
        """
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) - coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic_method
    def __rsub__(self, other):
        """
        Return the difference other - self.
        """
        # Subtraction is not commutative, so __sub__ cannot be reused as is.
        return other - self

    @_polymorphic_method
    def __mul__(self, other):
        """
        Return the product of two expressions, one of which must be constant.

        Raises ValueError if neither operand is constant, since the product
        would not be linear.
        """
        if other.isconstant():
            factor, operand = other.constant, self
        elif self.isconstant():
            # Both operands are expressions, so Python never tries the
            # reflected method; the "constant * expression" case has to be
            # handled here.
            factor, operand = self.constant, other
        else:
            raise ValueError('non-linear expression: '
                '{} * {}'.format(self._parenstr(), other._parenstr()))
        coefficients = {symbol: coefficient * factor
            for symbol, coefficient in operand.coefficients()}
        return Expression(coefficients, operand.constant * factor)

    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        """
        Return the quotient self / other, where other must be constant.

        Raises ValueError if other is not constant, and ZeroDivisionError if
        other is zero.
        """
        if not other.isconstant():
            raise ValueError('non-linear expression: '
                '{} / {}'.format(self._parenstr(), other._parenstr()))
        divisor = other.constant
        coefficients = {symbol: Fraction(coefficient, divisor)
            for symbol, coefficient in self.coefficients()}
        return Expression(coefficients, Fraction(self.constant, divisor))

    @_polymorphic_method
    def __rtruediv__(self, other):
        """
        Return the quotient other / self, where self must be constant.
        """
        return other / self

    def __str__(self):
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            first = (i == 0)
            if coefficient == 1:
                string += symbol if first else ' + {}'.format(symbol)
            elif coefficient == -1:
                string += ('-{}' if first else ' - {}').format(symbol)
            elif first:
                string += '{}*{}'.format(coefficient, symbol)
            elif coefficient > 0:
                string += ' + {}*{}'.format(coefficient, symbol)
            else:
                string += ' - {}*{}'.format(-coefficient, symbol)
        constant = self.constant
        if not self._coefficients:
            return '{}'.format(constant) if constant != 0 else '0'
        if constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    def _parenstr(self, always=False):
        """
        Return the string form of the expression, wrapped in parentheses
        unless it is a constant or a symbol (or always is true).
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        return '({})'.format(string)

    def __repr__(self):
        items = ', '.join('{!r}: {!r}'.format(symbol, coefficient)
            for symbol, coefficient in self.coefficients())
        return '{}({{{}}}, {!r})'.format(self.__class__.__name__,
            items, self.constant)

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return self._coefficients == other._coefficients and \
            self.constant == other.constant

    def __hash__(self):
        # A constant expression compares equal to the number it holds, so it
        # has to hash like that number.
        if not self._coefficients:
            return hash(self._constant)
        # A dict is not hashable: hash its sorted items instead.
        return hash((tuple(sorted(self._coefficients.items())),
            self._constant))

    def _canonify(self):
        """
        Return the expression multiplied by the least common multiple of the
        denominators of its values, so that all of them become integral.
        """
        # NOTE(review): gcd is the module-level name imported from fractions;
        # that function was removed in Python 3.9, math.gcd replaces it.
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        """
        Return the polyhedron defined by self == other.
        """
        return Polyhedron(equalities=[(self - other)._canonify()])

    # Polyhedron inequalities read "expression <= 0".  The strict forms add 1
    # to the canonified expression: "e + 1 <= 0" stands for "e < 0", which
    # presumably relies on integer-valued points -- confirm.

    @_polymorphic_method
    def __le__(self, other):
        return Polyhedron(inequalities=[(self - other)._canonify()])

    @_polymorphic_method
    def __lt__(self, other):
        return Polyhedron(inequalities=[(self - other)._canonify() + 1])

    @_polymorphic_method
    def __ge__(self, other):
        return Polyhedron(inequalities=[(other - self)._canonify()])

    @_polymorphic_method
    def __gt__(self, other):
        return Polyhedron(inequalities=[(other - self)._canonify() + 1])
295
296
def constant(numerator=0, denominator=None):
    """
    Return the constant expression numerator / denominator.

    A rational numerator given without denominator is stored unchanged; any
    other combination goes through Fraction(numerator, denominator) first.
    """
    value = numerator
    if denominator is not None or not isinstance(numerator, numbers.Rational):
        value = Fraction(numerator, denominator)
    return Expression(constant=value)
302
def symbol(name):
    """
    Return the expression made of the single symbol name, with coefficient 1.

    Raises TypeError if name is not a string.
    """
    if isinstance(name, str):
        return Expression(coefficients={name: 1})
    raise TypeError('name must be a string')
307
def symbols(names):
    """
    Return a generator of symbol expressions, one per name.

    names is either an iterable of names or a single string in which the
    names are separated by commas and/or whitespace.
    """
    if isinstance(names, str):
        # Commas count as whitespace.
        names = ' '.join(names.split(',')).split()
    return (symbol(name) for name in names)
312
313
@_polymorphic_operator
def eq(a, b):
    """
    Return the polyhedron defined by a == b.

    Raises TypeError if an operand is neither a rational nor an expression.
    """
    polyhedron = a._eq(b)
    if polyhedron is NotImplemented:
        # _eq() is a plain method call: unlike an operator, nothing turns its
        # NotImplemented result into an error, and NotImplemented is truthy.
        raise TypeError('arguments must be linear expressions')
    return polyhedron
317
@_polymorphic_operator
def le(a, b):
    """
    Functional form of the comparison a <= b.
    """
    result = (a <= b)
    return result
321
@_polymorphic_operator
def lt(a, b):
    """
    Functional form of the comparison a < b.
    """
    result = (a < b)
    return result
325
@_polymorphic_operator
def ge(a, b):
    """
    Functional form of the comparison a >= b.
    """
    result = (a >= b)
    return result
329
@_polymorphic_operator
def gt(a, b):
    """
    Functional form of the comparison a > b.
    """
    result = (a > b)
    return result
333
334
class _IslPointer(ctypes.c_void_p):
    """
    Pointer to an isl object.

    ctypes converts a plain c_void_p result into a Python int, and passes
    Python ints to C as C ints.  A subclass is returned unconverted and is
    passed back as a full-width pointer.
    """


# Value of isl_dim_set in isl's enum isl_dim_type.
_ISL_DIM_SET = 3


def _isl_function(name):
    """
    Return the libisl function called name, declared as returning a pointer.
    """
    function = getattr(libisl, name)
    function.restype = _IslPointer
    return function


class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is defined by a list of equalities, read "expression == 0",
    and a list of inequalities, read "expression <= 0".
    """

    def __new__(cls, equalities=None, inequalities=None):
        """
        Create a polyhedron from iterables of expressions.

        When equalities is a string, the call is forwarded to fromstring().

        Raises TypeError if a constraint has a non-integer value, or if
        inequalities are passed along with a string.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._checked(equalities, '==')
        self._inequalities = cls._checked(inequalities, '<=')
        self._bset = self.to_isl()
        return self

    @staticmethod
    def _checked(constraints, relation):
        """
        Return constraints as a list, after checking that every value of
        every constraint is an integer.
        """
        checked = []
        if constraints is not None:
            for constraint in constraints:
                for value in constraint.values():
                    if value.denominator != 1:
                        raise TypeError('non-integer constraint: '
                            '{} {} 0'.format(constraint, relation))
                checked.append(constraint)
        return checked

    @property
    def equalities(self):
        """
        Generator over the equalities of the polyhedron.
        """
        yield from self._equalities

    @property
    def inequalities(self):
        """
        Generator over the inequalities of the polyhedron.
        """
        yield from self._inequalities

    def isempty(self):
        """
        Return True if isl reports the polyhedron as empty.
        """
        return bool(libisl.isl_basic_set_is_empty(self._bset))

    def constraints(self):
        """
        Yield every constraint, equalities first.
        """
        yield from self.equalities
        yield from self.inequalities

    def symbols(self):
        """
        Yield the symbols used by the constraints, sorted.
        """
        s = set()
        for constraint in self.constraints():
            s.update(constraint.symbols())
        yield from sorted(s)

    def symbol_count(self):
        """
        Return the sorted list of the symbols used by the constraints.
        """
        return list(self.symbols())

    @property
    def dimension(self):
        """
        Number of symbols used by the constraints.
        """
        return len(self.symbol_count())

    def __bool__(self):
        # return false if the polyhedron is empty, true otherwise
        return not self.isempty()

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        raise NotImplementedError

    def is_empty(self):
        """
        Alias of isempty().
        """
        return self.isempty()

    def isuniverse(self):
        return self == universe

    def isdisjoint(self, other):
        # return true if the polyhedron has no elements in common with other
        raise NotImplementedError

    def issubset(self, other):
        raise NotImplementedError

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        raise NotImplementedError

    def issuperset(self, other):
        # test whether every element in other is in the polyhedron
        return other.issubset(self)

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        raise NotImplementedError

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        """
        Return a new polyhedron with the elements common to the polyhedron
        and all others, obtained by gathering all their constraints.
        """
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        for other in others:
            equalities.extend(other.equalities)
            inequalities.extend(other.inequalities)
        return self.__class__(equalities, inequalities)

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, *others):
        # return a new polyhedron with elements in the polyhedron that are not
        # in the others
        raise NotImplementedError

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = []
        for constraint in self.equalities:
            constraints.append('{} == 0'.format(constraint))
        for constraint in self.inequalities:
            constraints.append('{} <= 0'.format(constraint))
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        return '{}(equalities={!r}, inequalities={!r})' \
            ''.format(self.__class__.__name__, equalities, inequalities)

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    def printer(self):
        """
        Print the isl text form of the polyhedron, and return it.
        """
        ip = _isl_function('isl_printer_to_str')(_CONTEXT)
        ip = _isl_function('isl_printer_print_basic_set')(ip, self._bset)
        string = libisl.isl_printer_get_str(ip).decode()
        libisl.isl_printer_free(ip)
        print(string)
        return string

    def to_isl(self):
        """
        Build and return the isl basic set described by the constraints.

        The set has one unnamed dimension per symbol, in sorted symbol order,
        and no parameter.
        """
        positions = {symbol: position
            for position, symbol in enumerate(self.symbols())}
        space = _isl_function('isl_space_set_alloc')(
            _CONTEXT, 0, len(positions))
        # Constraints can only shrink a set: start from the universe, since
        # the empty set would stay empty whatever is added to it.
        bset = _isl_function('isl_basic_set_universe')(
            _isl_function('isl_space_copy')(space))
        # This call takes over the reference held on space.
        ls = _isl_function('isl_local_space_from_space')(space)
        set_constant = _isl_function('isl_constraint_set_constant_si')
        set_coefficient = _isl_function('isl_constraint_set_coefficient_si')
        # isl inequalities read "expression >= 0" whereas ours read
        # "expression <= 0", hence the sign flip.
        groups = (
            ('isl_equality_alloc', self._equalities, 1),
            ('isl_inequality_alloc', self._inequalities, -1),
        )
        for allocator, expressions, sign in groups:
            for expression in expressions:
                constraint = _isl_function(allocator)(
                    _isl_function('isl_local_space_copy')(ls))
                constraint = set_constant(constraint,
                    sign * int(expression.constant))
                for symbol, coefficient in expression.coefficients():
                    constraint = set_coefficient(constraint, _ISL_DIM_SET,
                        positions[symbol], sign * int(coefficient))
                bset = _isl_function('isl_basic_set_add_constraint')(
                    bset, constraint)
        _isl_function('isl_local_space_free')(ls)
        return bset
551
# The empty polyhedron: the constraint 0 == 1 can never be satisfied.
# (1 == 1 would be satisfied everywhere, which is the universe.)
empty = eq(0, 1)


# The universe polyhedron: no constraint at all.
universe = Polyhedron()