b73814eefce01b272ac880ed484c9c850a1ffad7
[linpy.git] / polyp.py
1
2 import functools
3 import numbers
4 import re
5 import subprocess
6
7 from fractions import Fraction, gcd
8
9
# Public names exported by a star-import of this module.
__all__ = [
    # linear expressions and their constructors
    'Expression',
    'Constant',
    'Symbol',
    'symbols',
    # constraint constructors
    'Eq',
    'Le',
    'Lt',
    'Ge',
    'Gt',
    # polyhedra and the two predefined instances
    'Polyhedron',
    'empty',
    'universe',
]
17
18
# When true, every request sent to iscc is printed along with its answer.
_iscc_debug = False


def _iscc(input):
    """
    Send `input` to the external iscc program and return what it wrote on
    its standard output, stripped of leading and trailing whitespace. A
    terminating semicolon is appended to `input` when it is missing.
    """
    command = input if input.endswith(';') else input + ';'
    process = subprocess.Popen(
        ['iscc'],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        universal_newlines=True)
    # stderr is not piped, so the second item of the pair is always None
    stdout, _ = process.communicate(input=command)
    answer = stdout.strip()
    if _iscc_debug:
        print('ISCC({!r}) = {!r}'.format(command, answer))
    return answer
32
33
def _polymorphic_method(func):
    """
    Decorate a binary method so that its right operand may be given either
    as an expression or as a rational number. Rational numbers are wrapped
    with Constant() before `func` is called; for any other kind of operand
    the wrapper returns NotImplemented without calling `func`.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if isinstance(b, numbers.Rational):
            b = Constant(b)
        elif not isinstance(b, Expression):
            return NotImplemented
        return func(a, b)
    return wrapper
44
def _polymorphic_operator(func):
    """
    Decorate a binary function so that its left operand may be given either
    as an expression or as a rational number. Rational numbers are wrapped
    with Constant(); any other kind of left operand raises TypeError.
    """
    # Only the left operand is examined here: the decorated function is
    # expected to call a polymorphic method, which deals with the right one.
    @functools.wraps(func)
    def wrapper(a, b):
        if isinstance(a, numbers.Rational):
            return func(Constant(a), b)
        if isinstance(a, Expression):
            return func(a, b)
        raise TypeError('arguments must be linear expressions')
    return wrapper
57
58
class Expression:
    """
    This class implements linear expressions.

    An expression is stored as a dictionary mapping symbol names to their
    non-zero rational coefficients, together with a rational constant term.
    Arithmetic operators build new expressions. The == operator is a plain
    structural comparison returning a boolean, whereas <=, <, >= and > build
    polyhedra.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Build an expression from `coefficients` and `constant`.

        `coefficients` is either a dictionary or an iterable of
        (symbol, coefficient) pairs, where a symbol is a string or an
        expression for which issymbol() is true. Pairs whose coefficient is
        zero are dropped.

        Raise TypeError if a symbol is not acceptable, or if a coefficient or
        the constant is not a rational number.
        """
        self = super().__new__(cls)
        self._coefficients = {}
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is not None:
            for symbol, coefficient in coefficients:
                if isinstance(symbol, Expression) and symbol.issymbol():
                    symbol = symbol.symbol()
                elif not isinstance(symbol, str):
                    raise TypeError('symbols must be strings')
                if not isinstance(coefficient, numbers.Rational):
                    raise TypeError('coefficients must be rational numbers')
                if coefficient != 0:
                    self._coefficients[symbol] = coefficient
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        self._constant = constant
        return self

    @classmethod
    def _fromiscc(cls, symbols, string):
        """
        Parse `string`, an expression over the names in `symbols` written in
        iscc notation, and return the corresponding expression.
        """
        # iscc writes products implicitly ("2x"); make them explicit. The
        # lookbehind prevents digits located inside an identifier ("a1b")
        # from being mistaken for a coefficient.
        string = re.sub(r'(?<!\w)(\d+)\s*([a-zA-Z_]\w*)', r'\1*\2', string)
        context = {symbol: cls({symbol: 1}) for symbol in symbols}
        # NOTE(security): eval() executes arbitrary Python code. This is
        # tolerable only as long as `string` comes from iscc and never from
        # an untrusted source.
        result = eval(string, context)
        if isinstance(result, numbers.Rational):
            # a purely numeric string evaluates to a bare number
            result = cls(constant=result)
        return result

    def symbols(self):
        """
        Yield the names of the symbols of the expression, in sorted order.
        """
        yield from sorted(self._coefficients)

    @property
    def dimension(self):
        """
        The number of symbols of the expression.
        """
        return len(self._coefficients)

    def coefficient(self, symbol):
        """
        Return the coefficient of `symbol`, or 0 if the symbol does not
        appear in the expression. `symbol` is a string or an expression for
        which issymbol() is true; anything else raises TypeError.
        """
        if isinstance(symbol, Expression) and symbol.issymbol():
            symbol = symbol.symbol()
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """
        Yield the (symbol, coefficient) pairs of the expression, sorted by
        symbol name.
        """
        for symbol in self.symbols():
            yield symbol, self._coefficients[symbol]

    @property
    def constant(self):
        """
        The constant term of the expression.
        """
        return self._constant

    def isconstant(self):
        """
        Return true if the expression has no symbol.
        """
        return len(self._coefficients) == 0

    def values(self):
        """
        Yield the coefficients of the expression, sorted by symbol name,
        followed by the constant term.
        """
        for symbol in self.symbols():
            yield self._coefficients[symbol]
        yield self.constant

    def symbol(self):
        """
        Return the name of the symbol the expression consists of. Raise
        ValueError if the expression is not a symbol.
        """
        if not self.issymbol():
            raise ValueError('not a symbol: {}'.format(self))
        return next(iter(self._coefficients))

    def issymbol(self):
        """
        Return true if the expression is a single symbol, that is one name
        with coefficient 1 and no constant term.
        """
        # The coefficient has to be checked as well: 2*x is not a symbol and
        # must not be usable as a symbol name.
        return (self._constant == 0 and
            list(self._coefficients.values()) == [1])

    def __bool__(self):
        """
        Return false for the zero expression, true otherwise.
        """
        return (not self.isconstant()) or bool(self.constant)

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        """
        Return the sum of the two expressions.
        """
        coefficients = dict(self._coefficients)
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) + coefficient
        constant = self.constant + other.constant
        # coefficients that cancelled out are dropped by the constructor
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        """
        Return the difference of the two expressions.
        """
        coefficients = dict(self._coefficients)
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) - coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    def __rsub__(self, other):
        return -(self - other)

    @_polymorphic_method
    def __mul__(self, other):
        """
        Return the product of the two expressions, at least one of which
        must be constant. Raise ValueError if the product is not linear.
        """
        if other.isconstant():
            factor = other.constant
            coefficients = {symbol: coefficient * factor
                for symbol, coefficient in self._coefficients.items()}
            return Expression(coefficients, self.constant * factor)
        if self.isconstant():
            # Returning NotImplemented would not work here: both operands
            # have the same type, so Python never tries the reflected method.
            return other * self
        raise ValueError('non-linear expression: '
            '{} * {}'.format(self._prepr(), other._prepr()))

    __rmul__ = __mul__

    def __repr__(self):
        string = ''
        i = 0
        for symbol in self.symbols():
            coefficient = self[symbol]
            if coefficient == 1:
                if i == 0:
                    string += symbol
                else:
                    string += ' + {}'.format(symbol)
            elif coefficient == -1:
                if i == 0:
                    string += '-{}'.format(symbol)
                else:
                    string += ' - {}'.format(symbol)
            elif i == 0:
                string += '{}*{}'.format(coefficient, symbol)
            elif coefficient > 0:
                string += ' + {}*{}'.format(coefficient, symbol)
            else:
                string += ' - {}*{}'.format(-coefficient, symbol)
            i += 1
        constant = self.constant
        if constant != 0 and i == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        if string == '':
            string = '0'
        return string

    def _prepr(self, always=False):
        """
        Return the representation of the expression, enclosed in parentheses
        unless it is a constant or a symbol. If `always` is true, the
        parentheses are added in every case.
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
            self._coefficients == other._coefficients and \
            self.constant == other.constant

    def __hash__(self):
        # Objects that compare equal must have the same hash: a constant
        # expression is equal to its number, and the order in which the
        # coefficients were inserted is irrelevant to equality.
        if self.isconstant():
            return hash(self._constant)
        return hash((tuple(sorted(self._coefficients.items())),
            self._constant))

    @_polymorphic_method
    def _eq(self, other):
        """
        Return the polyhedron defined by self - other = 0.
        """
        return Polyhedron(equalities=[self - other])

    @_polymorphic_method
    def __le__(self, other):
        """
        Return the polyhedron defined by self - other <= 0.
        """
        return Polyhedron(inequalities=[self - other])

    @_polymorphic_method
    def __lt__(self, other):
        """
        Return the polyhedron defined by self - other + 1 <= 0; the strict
        inequality is encoded by adding 1 to a non-strict one.
        """
        return Polyhedron(inequalities=[self - other + 1])

    @_polymorphic_method
    def __ge__(self, other):
        """
        Return the polyhedron defined by other - self <= 0.
        """
        return Polyhedron(inequalities=[other - self])

    @_polymorphic_method
    def __gt__(self, other):
        """
        Return the polyhedron defined by other - self + 1 <= 0; the strict
        inequality is encoded by adding 1 to a non-strict one.
        """
        return Polyhedron(inequalities=[other - self + 1])
264
265
def Constant(numerator=0, denominator=None):
    """
    Return the constant expression numerator/denominator. A rational
    numerator given without a denominator is used as is; in every other case
    the value is built with Fraction(numerator, denominator).
    """
    value = numerator
    if denominator is not None or not isinstance(numerator, numbers.Rational):
        value = Fraction(numerator, denominator)
    return Expression(constant=value)
271
def Symbol(name):
    """
    Return the expression made of the single symbol `name`, with
    coefficient 1. Raise TypeError if `name` is not a string.
    """
    if isinstance(name, str):
        return Expression(coefficients={name: 1})
    raise TypeError('name must be a string')
276
def symbols(names):
    """
    Return a generator of symbols, one per name in `names`. When `names` is
    a string, the names are the words obtained by splitting it on commas
    and whitespace; otherwise `names` is iterated as is.
    """
    identifiers = names
    if isinstance(identifiers, str):
        identifiers = identifiers.replace(',', ' ').split()
    return (Symbol(identifier) for identifier in identifiers)
281
282
@_polymorphic_operator
def Eq(a, b):
    """
    Return the polyhedron of the points where a is equal to b. The left
    operand may be a linear expression or a rational number; any other kind
    of left operand raises TypeError.
    """
    return a._eq(b)
286
@_polymorphic_operator
def Le(a, b):
    """
    Return the polyhedron of the points where a is less than or equal to b.
    The left operand may be a linear expression or a rational number; any
    other kind of left operand raises TypeError.
    """
    return a <= b
290
@_polymorphic_operator
def Lt(a, b):
    """
    Return the polyhedron of the points where a is strictly less than b. The
    left operand may be a linear expression or a rational number; any other
    kind of left operand raises TypeError.
    """
    return a < b
294
@_polymorphic_operator
def Ge(a, b):
    """
    Return the polyhedron of the points where a is greater than or equal to
    b. The left operand may be a linear expression or a rational number; any
    other kind of left operand raises TypeError.
    """
    return a >= b
298
@_polymorphic_operator
def Gt(a, b):
    """
    Return the polyhedron of the points where a is strictly greater than b.
    The left operand may be a linear expression or a rational number; any
    other kind of left operand raises TypeError.
    """
    return a > b
302
303
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is stored as a list of symbol names and two lists of linear
    expressions: the expressions constrained to be equal to zero, and the
    expressions constrained to be less than or equal to zero. Construction,
    comparisons and set operations are all delegated to the external iscc
    program.
    """

    def __new__(cls, equalities=None, inequalities=None):
        """
        Build the polyhedron defined by e = 0 for every expression e in
        `equalities` and e <= 0 for every expression e in `inequalities`.
        The constraints are sent to iscc and the polyhedron is rebuilt from
        its answer.
        """
        # both collections are traversed twice, so materialise them first
        equalities = [] if equalities is None else list(equalities)
        inequalities = [] if inequalities is None else list(inequalities)
        symbols = set()
        for constraint in equalities + inequalities:
            symbols.update(constraint.symbols())
        symbols = sorted(symbols)
        string = cls._isccpoly(symbols, equalities, inequalities)
        string = _iscc(string)
        return cls._fromiscc(symbols, string)

    @classmethod
    def _isccpoly(cls, symbols, equalities, inequalities):
        """
        Return the iscc notation of the set over the tuple `symbols` defined
        by the given equalities (= 0) and inequalities (<= 0).
        """
        strings = ['{} = 0'.format(equality) for equality in equalities]
        strings.extend('{} <= 0'.format(inequality)
            for inequality in inequalities)
        return '{{ [{}] : {} }}'.format(
            ', '.join(symbols), ' and '.join(strings))

    def _toiscc(self, symbols):
        """
        Return the iscc notation of the polyhedron, expressed over the tuple
        `symbols`.
        """
        return self._isccpoly(symbols, self.equalities, self.inequalities)

    @classmethod
    def _fromiscc(cls, symbols, string):
        """
        Build a polyhedron over `symbols` from `string`, a set printed by
        iscc. Raise ValueError if the string cannot be understood.
        """
        self = super().__new__(cls)
        self._symbols = list(symbols)
        self._equalities = []
        self._inequalities = []
        if re.match(r'^\s*\{\s*\}\s*$', string):
            # The empty set is built here as the unsatisfiable 1 = 0. The
            # module-level `empty` cannot be returned instead, because this
            # method may run while `empty` itself is being created.
            self._equalities.append(Constant(1))
            return self
        string = re.sub(r'^\s*\{\s*(.*?)\s*\}\s*$', r'\1', string)
        vstr, _, cstr = string.partition(':')
        vstr, cstr = vstr.strip(), cstr.strip()
        if vstr == '':
            raise ValueError('invalid iscc set: {!r}'.format(string))
        vstr = re.sub(r'^\s*\[\s*(.*?)\s*\]\s*$', r'\1', vstr)
        toks = [tok for tok in re.split(r'\s*,\s*', vstr) if tok]
        if len(toks) != len(self._symbols):
            raise ValueError('invalid iscc tuple: {!r}'.format(vstr))
        for symbol, tok in zip(self._symbols, toks):
            # a tuple entry which is not the symbol itself is an expression
            # the symbol is equal to
            if tok != symbol:
                expr = Expression._fromiscc(self._symbols, tok)
                self._equalities.append(Symbol(symbol) - expr)
        if cstr != '':
            # \b keeps names that merely contain "and" in one piece
            for constraint in re.split(r'\s*\band\b\s*', cstr):
                self._addconstraint(constraint)
        return self

    def _addconstraint(self, string):
        """
        Parse one constraint printed by iscc, possibly a chain of
        comparisons such as 0 <= x <= 10, and record it.
        """
        # the capturing group makes operands and operators alternate
        parts = re.split(r'\s*(<=|>=|<|>|=)\s*', string)
        if len(parts) < 3:
            raise ValueError('invalid iscc constraint: {!r}'.format(string))
        operands = [Expression._fromiscc(self._symbols, part)
            for part in parts[0::2]]
        operators = parts[1::2]
        for lhs, op, rhs in zip(operands, operators, operands[1:]):
            if op == '=':
                self._equalities.append(lhs - rhs)
            elif op == '<=':
                self._inequalities.append(lhs - rhs)
            elif op == '>=':
                self._inequalities.append(rhs - lhs)
            elif op == '<':
                self._inequalities.append(lhs - rhs + 1)
            else:
                # op == '>'
                self._inequalities.append(rhs - lhs + 1)

    @property
    def equalities(self):
        """
        An iterator over the expressions constrained to be equal to zero.
        """
        yield from self._equalities

    @property
    def inequalities(self):
        """
        An iterator over the expressions constrained to be less than or
        equal to zero.
        """
        yield from self._inequalities

    def constraints(self):
        """
        Yield the equalities of the polyhedron, then its inequalities.
        """
        yield from self.equalities
        yield from self.inequalities

    def symbols(self):
        """
        Yield the names of the symbols of the polyhedron.
        """
        yield from self._symbols

    @property
    def dimension(self):
        """
        The number of symbols of the polyhedron.
        """
        # symbols() is a generator and has no length
        return len(self._symbols)

    def __bool__(self):
        # return false if the polyhedron is empty, true otherwise
        return not self.isempty()

    def _compare(self, other, operator):
        """
        Ask iscc whether `self operator other` holds, both polyhedra being
        expressed over the sorted union of their symbols.
        """
        symbols = sorted(set(self.symbols()) | set(other.symbols()))
        string = '{} {} {}'.format(
            self._toiscc(symbols), operator, other._toiscc(symbols))
        return _iscc(string) == 'True'

    def __eq__(self, other):
        if not isinstance(other, Polyhedron):
            return NotImplemented
        return self._compare(other, '=')

    def isempty(self):
        # return true if the polyhedron is equal to the empty polyhedron
        return self == empty

    def isuniverse(self):
        # return true if the polyhedron is equal to the universe polyhedron
        return self == universe

    def isdisjoint(self, other):
        # return true if the polyhedron has no elements in common with other
        return (self & other).isempty()

    def issubset(self, other):
        # test whether every element in the polyhedron is in other
        return self._compare(other, '<=')

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        # test whether the polyhedron is a proper subset of other
        return self._compare(other, '<')

    def issuperset(self, other):
        # test whether every element in other is in the polyhedron
        return self._compare(other, '>=')

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        # test whether the polyhedron is a proper superset of other
        return self._compare(other, '>')

    def _combine(self, others, operator):
        """
        Return the polyhedron computed by iscc for the polyhedron and all
        `others` joined with the iscc binary `operator`, every operand being
        expressed over the sorted union of all the symbols.
        """
        symbols = set(self.symbols())
        for other in others:
            symbols.update(other.symbols())
        symbols = sorted(symbols)
        operands = (self,) + tuple(others)
        string = ' {} '.format(operator).join(
            operand._toiscc(symbols) for operand in operands)
        string = _iscc('poly ({})'.format(string))
        return Polyhedron._fromiscc(symbols, string)

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        return self._combine(others, '+')

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        # return a new polyhedron with elements common to the polyhedron and
        # all others
        return self._combine(others, '*')

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, *others):
        # return a new polyhedron with elements in the polyhedron that are not
        # in the others
        return self._combine(others, '-')

    def __sub__(self, other):
        return self.difference(other)

    def __repr__(self):
        constraints = []
        for constraint in self.equalities:
            constraints.append('{} == 0'.format(constraint))
        for constraint in self.inequalities:
            constraints.append('{} <= 0'.format(constraint))
        if len(constraints) == 0:
            return 'universe'
        elif len(constraints) == 1:
            string = constraints[0]
            return 'empty' if string == '1 == 0' else string
        else:
            strings = ['({})'.format(constraint) for constraint in constraints]
            return ' & '.join(strings)
509
# The empty polyhedron, defined by the unsatisfiable constraint 1 = 0.
# Building it calls iscc, which therefore runs when the module is imported.
empty = Eq(1, 0)

# The universe polyhedron, defined by no constraint at all. Building it
# calls iscc as well.
universe = Polyhedron()