Code cleanup
[linpy.git] / pypol / linear.py
1 import ctypes, ctypes.util
2 import functools
3 import numbers
4
5 from fractions import Fraction, gcd
6
7 from . import isl, islhelper
8 from .isl import libisl, Context
9
10
# Public API of the module: expression construction, comparison helpers
# returning polyhedra, and the Polyhedron class with its special values.
__all__ = [
    'Expression',
    'constant',
    'symbol',
    'symbols',
    'eq',
    'le',
    'lt',
    'ge',
    'gt',
    'Polyhedron',
    'empty',
    'universe',
]
18
19
def _polymorphic_method(func):
    """Decorate a binary Expression method so that its right operand may be
    an Expression or a rational number.

    A rational right operand is promoted with `constant()` before `func` is
    called.  Any other operand type makes the wrapper return
    `NotImplemented`, so that Python can try the reflected operation.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(b, Expression):
            if not isinstance(b, numbers.Rational):
                return NotImplemented
            b = constant(b)
        return func(a, b)
    return wrapper
30
def _polymorphic_operator(func):
    """Decorate a module-level binary operator (eq, le, ...) so that its left
    operand may be an Expression or a rational number.

    Only the left operand is checked here: the wrapped function dispatches to
    an Expression method decorated with `_polymorphic_method`, which takes
    care of the right operand.

    Raises:
        TypeError: if the left operand is neither rational nor an Expression.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if isinstance(a, numbers.Rational):
            left = constant(a)
        elif isinstance(a, Expression):
            left = a
        else:
            raise TypeError('arguments must be linear expressions')
        return func(left, b)
    return wrapper
43
44
class Expression:
    """
    This class implements linear expressions.

    An expression maps symbol names (strings) to non-zero rational
    coefficients and carries a rational constant term.  Arithmetic with other
    expressions or rational numbers yields new expressions; the comparison
    operators (and `_eq`) yield Polyhedron objects.
    """

    def __new__(cls, coefficients=None, constant=0):
        """Create an expression.

        Args:
            coefficients: None, a dict, or an iterable of
                (symbol, coefficient) pairs.  A symbol is a string or a symbol
                Expression (converted with str()); coefficients must be
                rational, and zero coefficients are dropped.  A string is
                forwarded to `fromstring` instead.
            constant: the rational constant term.

        Raises:
            TypeError: on an invalid symbol, coefficient or constant, or when
                a string is given together with a non-zero constant.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        self = super().__new__(cls)
        self._coefficients = {}
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is not None:
            for symbol, coefficient in coefficients:
                if isinstance(symbol, Expression) and symbol.issymbol():
                    symbol = str(symbol)
                elif not isinstance(symbol, str):
                    raise TypeError('symbols must be strings')
                if not isinstance(coefficient, numbers.Rational):
                    raise TypeError('coefficients must be rational numbers')
                if coefficient != 0:
                    self._coefficients[symbol] = coefficient
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        self._constant = constant
        return self

    def symbols(self):
        """Yield the symbol names of the expression in sorted order."""
        yield from sorted(self._coefficients)

    @property
    def dimension(self):
        """Number of symbols with a non-zero coefficient."""
        return len(self._coefficients)

    def coefficient(self, symbol):
        """Return the coefficient of `symbol`, or 0 when it does not appear.

        `symbol` is a name or a symbol Expression.

        Raises:
            TypeError: if `symbol` is neither.
        """
        if isinstance(symbol, Expression) and symbol.issymbol():
            symbol = str(symbol)
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    @property
    def coefficients(self):
        """Yield (symbol, coefficient) pairs, sorted by symbol."""
        for symbol in self.symbols():
            yield symbol, self._coefficients[symbol]

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    def isconstant(self):
        """Return True if no symbol has a non-zero coefficient."""
        return len(self._coefficients) == 0

    def values(self):
        """Yield the coefficients (sorted by symbol), then the constant."""
        for _, coefficient in self.coefficients:
            yield coefficient
        yield self.constant

    def values_int(self):
        """Like `values`, but each value is converted with int().

        int() truncates towards zero, so this is only lossless for
        expressions whose values are all integers.
        """
        for _, coefficient in self.coefficients:
            yield int(coefficient)
        yield int(self.constant)

    def symbol(self):
        """Return the name of a symbol expression.

        Raises:
            ValueError: if the expression is not a single symbol.
        """
        if not self.issymbol():
            raise ValueError('not a symbol: {}'.format(self))
        return next(self.symbols())

    def issymbol(self):
        """Return True if the expression is exactly `1*name` for one name."""
        # The coefficient must be 1: `2*x` is not a symbol, and treating it as
        # one would let `str(2*x)` be used as a symbol name.
        return self._constant == 0 and \
                list(self._coefficients.values()) == [1]

    def __bool__(self):
        # False only for the zero expression.
        return (not self.isconstant()) or bool(self.constant)

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        coefficients = dict(self.coefficients)
        for symbol, coefficient in other.coefficients:
            coefficients[symbol] = coefficients.get(symbol, 0) + coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        coefficients = dict(self.coefficients)
        for symbol, coefficient in other.coefficients:
            coefficients[symbol] = coefficients.get(symbol, 0) - coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic_method
    def __rsub__(self, other):
        # `other` has been promoted to an Expression by the decorator.
        return other - self

    @_polymorphic_method
    def __mul__(self, other):
        """Multiply two expressions; at least one must be constant.

        Raises:
            ValueError: if both operands contain symbols (non-linear).
        """
        if other.isconstant():
            coefficients = {symbol: coefficient * other.constant
                    for symbol, coefficient in self.coefficients}
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if self.isconstant():
            # constant * expression: swap the operands.  Returning
            # NotImplemented would not work here because Python does not try
            # the reflected method when both operands have the same type.
            return other * self
        raise ValueError('non-linear expression: '
                '{} * {}'.format(self._parenstr(), other._parenstr()))

    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        """Divide by a constant expression.

        Raises:
            ValueError: if the divisor contains symbols (non-linear).
            ZeroDivisionError: if the divisor is zero.
        """
        if other.isconstant():
            coefficients = {symbol: Fraction(coefficient, other.constant)
                    for symbol, coefficient in self.coefficients}
            constant = Fraction(self.constant, other.constant)
            return Expression(coefficients, constant)
        raise ValueError('non-linear expression: '
                '{} / {}'.format(self._parenstr(), other._parenstr()))

    def __rtruediv__(self, other):
        """Compute `other / self` for a rational `other`.

        Raises:
            ValueError: if `self` contains symbols (non-linear).
        """
        if not isinstance(other, numbers.Rational):
            return NotImplemented
        if self.isconstant():
            return Expression(constant=Fraction(other, self.constant))
        raise ValueError('non-linear expression: '
                '{} / {}'.format(other, self._parenstr()))

    def __str__(self):
        string = ''
        i = 0  # number of symbol terms written so far
        for symbol, coefficient in self.coefficients:
            if coefficient == 1:
                string += symbol if i == 0 else ' + {}'.format(symbol)
            elif coefficient == -1:
                string += '-{}'.format(symbol) if i == 0 \
                        else ' - {}'.format(symbol)
            elif i == 0:
                string += '{}*{}'.format(coefficient, symbol)
            elif coefficient > 0:
                string += ' + {}*{}'.format(coefficient, symbol)
            else:
                string += ' - {}*{}'.format(-coefficient, symbol)
            i += 1
        constant = self.constant
        if constant != 0 and i == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string or '0'

    def _parenstr(self, always=False):
        """Return str(self), in parentheses unless `always` is false and the
        expression is a constant or a single symbol."""
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        return '({})'.format(string)

    def __repr__(self):
        items = ', '.join('{!r}: {!r}'.format(symbol, coefficient)
                for symbol, coefficient in self.coefficients)
        return '{}({{{}}}, {!r})'.format(self.__class__.__name__, items,
                self.constant)

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return self._coefficients == other._coefficients and \
                self.constant == other.constant

    def __hash__(self):
        if self.isconstant():
            # A constant expression compares equal to its rational value, so
            # both must hash alike.
            return hash(self._constant)
        return hash((tuple(self.coefficients), self._constant))

    def _canonify(self):
        """Return the expression multiplied by the least common multiple of
        the denominators of its coefficients and constant, so that every
        value becomes an integer."""
        # fractions.gcd was removed in Python 3.9; math.gcd is the replacement.
        from math import gcd
        lcm = functools.reduce(lambda a, b: a * b // gcd(a, b),
                (value.denominator for value in self.values()))
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        """Return the Polyhedron {self == other}."""
        return Polyhedron(equalities=[(self - other)._canonify()])

    @_polymorphic_method
    def __le__(self, other):
        """Return the Polyhedron {other - self >= 0}."""
        return Polyhedron(inequalities=[(other - self)._canonify()])

    @_polymorphic_method
    def __lt__(self, other):
        """Return the Polyhedron {other - self > 0}, encoded for integer
        points as (scaled other - self) - 1 >= 0."""
        return Polyhedron(inequalities=[(other - self)._canonify() - 1])

    @_polymorphic_method
    def __ge__(self, other):
        """Return the Polyhedron {self - other >= 0}."""
        return Polyhedron(inequalities=[(self - other)._canonify()])

    @_polymorphic_method
    def __gt__(self, other):
        """Return the Polyhedron {self - other > 0}, encoded for integer
        points as (scaled self - other) - 1 >= 0."""
        return Polyhedron(inequalities=[(self - other)._canonify() - 1])
293
294
def constant(numerator=0, denominator=None):
    """Return a constant Expression.

    A rational `numerator` given without a denominator is used unchanged;
    in every other case the value is `Fraction(numerator, denominator)`.
    """
    use_as_is = denominator is None and isinstance(numerator, numbers.Rational)
    value = numerator if use_as_is else Fraction(numerator, denominator)
    return Expression(constant=value)
300
def symbol(name):
    """Return the Expression `1*name`.

    Raises:
        TypeError: if `name` is not a string.
    """
    if isinstance(name, str):
        return Expression(coefficients={name: 1})
    raise TypeError('name must be a string')
305
def symbols(names):
    """Return a generator of symbol Expressions.

    `names` is an iterable of names, or a single string in which names are
    separated by commas and/or whitespace (e.g. 'x, y z').
    """
    split_names = names.replace(',', ' ').split() \
            if isinstance(names, str) else names
    return (symbol(each) for each in split_names)
310
311
@_polymorphic_operator
def eq(a, b):
    """Return the Polyhedron of points where `a == b`.

    Either operand may be an Expression or a rational number.
    """
    return a._eq(b)


@_polymorphic_operator
def le(a, b):
    """Return the Polyhedron of points where `a <= b`."""
    return a <= b


@_polymorphic_operator
def lt(a, b):
    """Return the Polyhedron of points where `a < b`."""
    return a < b


@_polymorphic_operator
def ge(a, b):
    """Return the Polyhedron of points where `a >= b`."""
    return a >= b


@_polymorphic_operator
def gt(a, b):
    """Return the Polyhedron of points where `a > b`."""
    return a > b
331
332
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is given by a list of equalities (each read as `expr == 0`)
    and a list of inequalities (each read as `expr >= 0`).  Every constraint
    must have integer coefficients and an integer constant.  The matching isl
    basic set is built at construction time and stored in `_bset`.
    """

    def __new__(cls, equalities=None, inequalities=None):
        """Build a polyhedron from equality and inequality Expressions.

        A single string argument is forwarded to `fromstring`.

        Raises:
            TypeError: if a string is combined with inequalities, or if a
                constraint has a non-integer coefficient or constant.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = []
        if equalities is not None:
            for constraint in equalities:
                if not cls._isinteger(constraint):
                    raise TypeError('non-integer constraint: '
                            '{} == 0'.format(constraint))
                self._equalities.append(constraint)
        self._inequalities = []
        if inequalities is not None:
            for constraint in inequalities:
                if not cls._isinteger(constraint):
                    raise TypeError('non-integer constraint: '
                            '{} >= 0'.format(constraint))
                self._inequalities.append(constraint)
        self._bset = self._to_isl()
        return self

    @staticmethod
    def _isinteger(constraint):
        # True when every coefficient and the constant of `constraint` has
        # denominator 1.
        return all(value.denominator == 1 for value in constraint.values())

    @property
    def equalities(self):
        """Yield the equality constraints (`expr == 0`)."""
        yield from self._equalities

    @property
    def inequalities(self):
        """Yield the inequality constraints (`expr >= 0`)."""
        yield from self._inequalities

    def isempty(self):
        """Return True if isl reports the basic set as empty."""
        return bool(libisl.isl_basic_set_is_empty(self._bset))

    def constraints(self):
        """Yield the equalities, then the inequalities."""
        yield from self.equalities
        yield from self.inequalities

    def symbols(self):
        """Return the sorted list of symbols used by any constraint."""
        s = set()
        for constraint in self.constraints():
            s.update(constraint.symbols())
        return sorted(s)

    @property
    def dimension(self):
        return len(self.symbols())

    def __bool__(self):
        # A polyhedron is truthy when it is not empty.
        return not self.isempty()

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        raise NotImplementedError

    def is_empty(self):
        """Alias of `isempty`."""
        return self.isempty()

    def isuniverse(self):
        """Return True if isl reports the basic set as the universe."""
        return bool(libisl.isl_basic_set_is_universe(self._bset))

    def isdisjoint(self, other):
        # return true if the polyhedron has no elements in common with other
        raise NotImplementedError

    def issubset(self, other):
        raise NotImplementedError

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        raise NotImplementedError

    def issuperset(self, other):
        # test whether every element in other is in the polyhedron
        raise NotImplementedError

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        raise NotImplementedError

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        # return a new polyhedron with elements common to the polyhedron and
        # all others
        # a poor man's implementation could be:
        # equalities = list(self.equalities)
        # inequalities = list(self.inequalities)
        # for other in others:
        #     equalities.extend(other.equalities)
        #     inequalities.extend(other.inequalities)
        # return self.__class__(equalities, inequalities)
        raise NotImplementedError

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, *others):
        # return a new polyhedron with elements in the polyhedron that are not
        # in the others
        raise NotImplementedError

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = ['{} == 0'.format(c) for c in self.equalities]
        constraints += ['{} >= 0'.format(c) for c in self.inequalities]
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        return '{}(equalities={!r}, inequalities={!r})'.format(
                self.__class__.__name__, list(self.equalities),
                list(self.inequalities))

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    def _symbolunion(self, *others):
        """Return the sorted union of the symbols of self and `others`."""
        symbols = set(self.symbols())
        for other in others:
            symbols.update(other.symbols())
        return sorted(symbols)

    def _to_isl(self, symbols=None):
        """Build the isl basic set described by the constraints.

        Args:
            symbols: ordered sequence of symbol names; the position of a name
                is the isl set dimension used for that symbol.  Defaults to
                `self.symbols()`.  It must contain every symbol of every
                constraint (`symbols.index` raises ValueError otherwise).

        Returns:
            the isl basic set returned by libisl.
        """
        if symbols is None:
            symbols = self.symbols()
        # NOTE(review): `ctx` is not referenced once this method returns;
        # confirm that `Context` does not release the isl context while the
        # returned set still uses it.
        ctx = Context()
        space = libisl.isl_space_set_alloc(ctx, 0, len(symbols))
        bset = libisl.isl_basic_set_universe(libisl.isl_space_copy(space))
        # `space` is consumed here; `ls` is copied for each constraint below
        # and released at the end.
        ls = libisl.isl_local_space_from_space(space)
        # isl_basic_set_add_constraint takes ownership of its constraint, so a
        # fresh constraint is allocated for every Expression.
        for expression in self.equalities:
            constraint = libisl.isl_equality_alloc(
                    libisl.isl_local_space_copy(ls))
            constraint = self._set_isl_values(constraint, expression, symbols)
            bset = libisl.isl_basic_set_add_constraint(bset, constraint)
        for expression in self.inequalities:
            constraint = libisl.isl_inequality_alloc(
                    libisl.isl_local_space_copy(ls))
            constraint = self._set_isl_values(constraint, expression, symbols)
            bset = libisl.isl_basic_set_add_constraint(bset, constraint)
        libisl.isl_local_space_free(ls)
        return bset

    @staticmethod
    def _set_isl_values(constraint, expression, symbols):
        # Copy the constant and coefficients of `expression` into the isl
        # constraint and return the updated constraint (the isl setters
        # consume their argument and return the result).  Values are passed
        # through int() because ctypes cannot convert Fraction objects.
        constraint = libisl.isl_constraint_set_constant_si(
                constraint, int(expression.constant))
        for symbol, coefficient in expression.coefficients:
            # use 3 for type isl_dim_set
            constraint = libisl.isl_constraint_set_coefficient_si(
                    constraint, islhelper.isl_dim_set,
                    symbols.index(symbol), int(coefficient))
        return constraint

    def from_isl(self, bset):
        """Return a constraint matrix of the isl basic set `bset`.

        Work in progress: the goal is to turn an isl basic set back into the
        Python version of a polyhedron.  isl example code gives the isl form
        as "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}" while our
        printer gives b'{ [i0] : 1 = 0 }'.

        Returns the equalities matrix when this polyhedron has equalities,
        otherwise the inequalities matrix when it has inequalities, otherwise
        None.
        """
        # NOTE(review): isl declares isl_basic_set_equalities_matrix and
        # isl_basic_set_inequalities_matrix with four isl_dim_type arguments
        # after the set, but only one (3) is passed here -- confirm before
        # relying on the result.
        if self._equalities:
            return libisl.isl_basic_set_equalities_matrix(bset, 3)
        if self._inequalities:
            return libisl.isl_basic_set_inequalities_matrix(bset, 3)
        return None
556
# Placeholders for the special polyhedra; the intended definitions,
# `eq(0, 1)` for `empty` and `Polyhedron()` for `universe`, are not enabled.
empty = universe = None
559
if __name__ == '__main__':
    # Demo: build {a + 2*x + 2 >= 0, 3*a + 2*b + 3 >= 0} and report whether
    # it is empty.
    ex1 = Expression(coefficients={'a': 1, 'x': 2}, constant=2)
    ex2 = Expression(coefficients={'a': 3, 'b': 2}, constant=3)
    p = Polyhedron(inequalities=[ex1, ex2])
    print(p)
    print('empty:', p.isempty())