9058c0890866b5bbf390d55a928c622b33d46f9e
[linpy.git] / polyp.py
1
2 import functools
3 import numbers
4 import re
5 import subprocess
6
7 from fractions import Fraction, gcd
8
9
# Public API of this module (names exported by ``from polyp import *``).
__all__ = [
    # linear expressions
    'Expression',
    'Constant',
    'Symbol',
    'symbols',
    # constraint builders
    'Eq',
    'Le',
    'Lt',
    'Ge',
    'Gt',
    # polyhedra
    'Polyhedron',
    'empty',
    'universe',
]
17
18
def _strsymbol(symbol):
    """
    Return the name of *symbol* as a string.

    *symbol* may be a plain string (returned as is) or a single-symbol
    Expression (its printed form is returned).

    Raises TypeError for anything else.
    """
    if isinstance(symbol, str):
        return symbol
    if isinstance(symbol, Expression) and symbol.issymbol():
        return str(symbol)
    raise TypeError('symbol must be a string or a symbol')
26
def _strsymbols(symbols):
    """
    Return a list of symbol names.

    A string is split on commas and whitespace ("x, y z" -> ['x', 'y', 'z']).
    Any other iterable has each element converted with _strsymbol.
    """
    if not isinstance(symbols, str):
        return [_strsymbol(item) for item in symbols]
    return symbols.replace(',', ' ').split()
32
33
class _ISCC:
    """
    Helpers for talking to the external ``iscc`` calculator: build iscc
    set/map literals from constraint expressions and run iscc on a script.
    """

    # Command line used to start iscc.
    command = ['iscc']
    # When true, every iscc input/output pair is printed.
    debug = False

    @classmethod
    def eval(cls, input):
        """
        Feed *input* to iscc and return its standard output, stripped.

        A trailing ';' is appended when missing.  Standard error is not
        captured and the exit status is not inspected.
        """
        script = input if input.endswith(';') else input + ';'
        process = subprocess.Popen(cls.command,
                                   stdin=subprocess.PIPE,
                                   stdout=subprocess.PIPE,
                                   universal_newlines=True)
        stdout, _ = process.communicate(input=script)
        result = stdout.strip()
        if cls.debug:
            print('iscc: {!r} -> {!r}'.format(script, result))
        return result

    @classmethod
    def symbols(cls, symbols):
        """Format symbol names as an iscc tuple, e.g. '[x, y]'."""
        names = _strsymbols(symbols)
        return '[{}]'.format(', '.join(names))

    @classmethod
    def constraints(cls, equalities, inequalities):
        """
        Join constraints into an iscc formula: each equality becomes
        'e = 0', each inequality 'e <= 0', combined with ' and '.
        """
        parts = ['{} = 0'.format(eq) for eq in equalities]
        parts.extend('{} <= 0'.format(ineq) for ineq in inequalities)
        return ' and '.join(parts)

    @classmethod
    def set(cls, symbols, equalities, inequalities):
        """Return the iscc set literal '{ [symbols] : constraints }'."""
        tuple_str = cls.symbols(symbols)
        formula = cls.constraints(equalities, inequalities)
        return '{{ {} : {} }}'.format(tuple_str, formula)

    @classmethod
    def map(cls, lsymbols, rsymbols, equalities, inequalities):
        """Return the iscc map literal '{ [lsymbols] -> [rsymbols] : constraints }'."""
        domain = cls.symbols(lsymbols)
        codomain = cls.symbols(rsymbols)
        formula = cls.constraints(equalities, inequalities)
        return '{{ {} -> {} : {} }}'.format(domain, codomain, formula)

# Shared instance used by the rest of the module.
_iscc = _ISCC()
78
79
def _polymorphic_method(func):
    """
    Decorate a binary Expression method so that its right operand may be
    either an Expression or a rational number (promoted with Constant).

    For any other operand type the decorated method returns NotImplemented,
    letting Python try the reflected operation.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(b, Expression):
            if not isinstance(b, numbers.Rational):
                return NotImplemented
            b = Constant(b)
        return func(a, b)
    return wrapper
90
def _polymorphic_operator(func):
    """
    Decorate a module-level binary operator (Eq, Le, ...) so its left
    operand may be an Expression or a rational number (promoted with
    Constant).

    Only the left operand is checked: the wrapped operator is expected to
    call a polymorphic method, which deals with the right operand.

    Raises TypeError when the left operand is neither.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if isinstance(a, numbers.Rational):
            return func(Constant(a), b)
        if isinstance(a, Expression):
            return func(a, b)
        raise TypeError('arguments must be linear expressions')
    return wrapper
103
104
class Expression:
    """
    This class implements linear expressions.

    An expression maps symbol names to non-zero rational coefficients and
    carries a rational constant term, e.g. ``2*x - y + 3``.  ``==`` is
    structural equality; ``<=``, ``<``, ``>=``, ``>`` (and ``_eq``) build
    Polyhedron objects.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Create an expression.

        coefficients: None, a dict, or an iterable of (symbol, coefficient)
            pairs.  Symbols are strings or single-symbol expressions;
            coefficients must be rational; zero coefficients are dropped.
        constant: the rational constant term (default 0).

        Raises TypeError on an invalid symbol, a non-rational coefficient or
        a non-rational constant.
        """
        self = super().__new__(cls)
        self._coefficients = {}
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is not None:
            for symbol, coefficient in coefficients:
                symbol = _strsymbol(symbol)
                if not isinstance(coefficient, numbers.Rational):
                    raise TypeError('coefficients must be rational numbers')
                if coefficient != 0:
                    self._coefficients[symbol] = coefficient
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        self._constant = constant
        return self

    @classmethod
    def _fromiscc(cls, symbols, string):
        """
        Parse an affine expression as printed by iscc, e.g. ``2x - y + 1``.

        symbols: names of the symbols that may appear in *string*.

        A result that evaluates to a bare rational number is wrapped with
        Constant so callers always get an Expression back.
        """
        symbols = _strsymbols(symbols)
        # iscc writes products by juxtaposition ("2x"); insert the '*'.
        # The leading \b keeps digits that belong to an identifier (as in
        # "a1b") from being treated as a coefficient.
        string = re.sub(r'\b(\d+)\s*([a-zA-Z_]\w*)',
                        lambda m: '{}*{}'.format(m.group(1), m.group(2)),
                        string)
        context = {symbol: Symbol(symbol) for symbol in symbols}
        # NOTE(review): eval() trusts its input; it must only ever see iscc
        # output, never strings from an untrusted source.
        result = eval(string, context)
        if isinstance(result, numbers.Rational):
            result = Constant(result)
        return result

    def symbols(self):
        """Yield the names of the symbols with a non-zero coefficient, sorted."""
        yield from sorted(self._coefficients)

    @property
    def dimension(self):
        """Number of symbols with a non-zero coefficient."""
        return len(self._coefficients)

    def coefficient(self, symbol):
        """Return the coefficient of *symbol*, or 0 if it does not appear."""
        symbol = _strsymbol(symbol)
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, coefficient) pairs in sorted symbol order."""
        for symbol in self.symbols():
            yield symbol, self.coefficient(symbol)

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    def isconstant(self):
        """Return True if no symbol has a non-zero coefficient."""
        return len(self._coefficients) == 0

    def values(self):
        """Yield the coefficients in sorted symbol order, then the constant."""
        for symbol in self.symbols():
            yield self.coefficient(symbol)
        yield self.constant

    def symbol(self):
        """
        Return the name of a single-symbol expression.

        Raises ValueError if the expression is not a symbol.
        """
        if not self.issymbol():
            raise ValueError('not a symbol: {}'.format(self))
        return next(self.symbols())

    def issymbol(self):
        """
        Return True if the expression is exactly one symbol: a single
        coefficient equal to 1 and a zero constant.
        """
        # The coefficient must be 1: otherwise "2*x" would be accepted as a
        # symbol and _strsymbol would turn it into a symbol named '2*x'.
        return (len(self._coefficients) == 1 and self._constant == 0 and
                next(iter(self._coefficients.values())) == 1)

    def __bool__(self):
        return (not self.isconstant()) or bool(self.constant)

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        """Return the sum of two expressions (or an expression and a number)."""
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] += coefficient
            else:
                coefficients[symbol] = coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        """Return ``self - other``."""
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] -= coefficient
            else:
                coefficients[symbol] = -coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    def __rsub__(self, other):
        """Return ``other - self``."""
        return -(self - other)

    @_polymorphic_method
    def __mul__(self, other):
        """
        Multiply by a constant expression or rational number.

        Raises ValueError when both operands contain symbols, since the
        product would not be linear.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] *= other.constant
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._prepr(), other._prepr()))
        # self is constant and other is not: let Python try other * self.
        return NotImplemented

    __rmul__ = __mul__

    def __repr__(self):
        """Return the expression as text, e.g. '2*x - y + 3' (or '0')."""
        string = ''
        i = 0
        for symbol in self.symbols():
            coefficient = self[symbol]
            if coefficient == 1:
                if i == 0:
                    string += symbol
                else:
                    string += ' + {}'.format(symbol)
            elif coefficient == -1:
                if i == 0:
                    string += '-{}'.format(symbol)
                else:
                    string += ' - {}'.format(symbol)
            else:
                if i == 0:
                    string += '{}*{}'.format(coefficient, symbol)
                elif coefficient > 0:
                    string += ' + {}*{}'.format(coefficient, symbol)
                else:
                    assert coefficient < 0
                    coefficient *= -1
                    string += ' - {}*{}'.format(coefficient, symbol)
            i += 1
        constant = self.constant
        if constant != 0 and i == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            constant *= -1
            string += ' - {}'.format(constant)
        if string == '':
            string = '0'
        return string

    def _prepr(self, always=False):
        """Return the text form, parenthesized unless it is a constant or a symbol."""
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
            self._coefficients == other._coefficients and \
            self.constant == other.constant

    def __hash__(self):
        # Hash the (symbol, coefficient) pairs as an unordered set: __eq__
        # compares the dicts regardless of insertion order, so e.g. x + y and
        # y + x must hash alike.  The coefficients take part in the hash too.
        return hash((frozenset(self._coefficients.items()), self._constant))

    @_polymorphic_method
    def _eq(self, other):
        """Return the polyhedron ``self == other``."""
        return Polyhedron(equalities=[self - other])

    @_polymorphic_method
    def __le__(self, other):
        """Return the polyhedron ``self <= other``."""
        return Polyhedron(inequalities=[self - other])

    @_polymorphic_method
    def __lt__(self, other):
        """Return the polyhedron ``self < other`` (as ``self - other + 1 <= 0``)."""
        return Polyhedron(inequalities=[self - other + 1])

    @_polymorphic_method
    def __ge__(self, other):
        """Return the polyhedron ``self >= other``."""
        return Polyhedron(inequalities=[other - self])

    @_polymorphic_method
    def __gt__(self, other):
        """Return the polyhedron ``self > other`` (as ``other - self + 1 <= 0``)."""
        return Polyhedron(inequalities=[other - self + 1])
305
306
def Constant(numerator=0, denominator=None):
    """
    Return a constant linear expression.

    A rational *numerator* without *denominator* is used as is; otherwise
    the value is built as ``Fraction(numerator, denominator)``.
    """
    if denominator is not None or not isinstance(numerator, numbers.Rational):
        return Expression(constant=Fraction(numerator, denominator))
    return Expression(constant=numerator)
312
def Symbol(name):
    """Return the expression made of the single symbol *name* (coefficient 1)."""
    return Expression(coefficients={_strsymbol(name): 1})
316
def symbols(names):
    """
    Return a lazy iterator of Symbol expressions, one per name.

    *names* is converted eagerly with _strsymbols (so "x y, z" works); the
    Symbol objects are created as the iterator is consumed.
    """
    symbol_names = _strsymbols(names)
    return (Symbol(symbol_name) for symbol_name in symbol_names)
320
321
@_polymorphic_operator
def Eq(a, b):
    """
    Return the polyhedron ``a == b``.

    Raises TypeError when an operand is neither a linear expression nor a
    rational number.
    """
    result = a._eq(b)
    # _eq is a polymorphic method: for an unsupported right operand it
    # returns NotImplemented instead of raising.  Le/Lt/Ge/Gt go through
    # the comparison operators, where Python turns that into a TypeError;
    # do the same here rather than returning the sentinel to the caller.
    if result is NotImplemented:
        raise TypeError('arguments must be linear expressions')
    return result
325
@_polymorphic_operator
def Le(a, b):
    """Return the polyhedron ``a <= b``."""
    polyhedron = a <= b
    return polyhedron

@_polymorphic_operator
def Lt(a, b):
    """Return the polyhedron ``a < b``."""
    polyhedron = a < b
    return polyhedron

@_polymorphic_operator
def Ge(a, b):
    """Return the polyhedron ``a >= b``."""
    polyhedron = a >= b
    return polyhedron

@_polymorphic_operator
def Gt(a, b):
    """Return the polyhedron ``a > b``."""
    polyhedron = a > b
    return polyhedron
341
342
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is described by a list of symbol names, a list of
    equalities (each meaning ``expr == 0``) and a list of inequalities
    (each meaning ``expr <= 0``).  Set operations and comparisons are
    delegated to the external ``iscc`` calculator.
    """

    def __new__(cls, equalities=None, inequalities=None):
        """
        Build the polyhedron defined by the given constraints.

        equalities: expressions constrained to be == 0.
        inequalities: expressions constrained to be <= 0.

        The constraints are sent through iscc, so the stored constraints
        are the ones iscc prints back, not necessarily those passed in.
        """
        # Materialize once: the constraints are iterated twice below, which
        # would silently drop them if a generator were passed.
        equalities = list(equalities) if equalities is not None else []
        inequalities = list(inequalities) if inequalities is not None else []
        symbols = set()
        for constraint in equalities + inequalities:
            symbols.update(constraint.symbols())
        symbols = sorted(symbols)
        string = _iscc.set(symbols, equalities, inequalities)
        string = _iscc.eval(string)
        return cls._fromiscc(symbols, string)

    def _toiscc(self, symbols):
        """Return this polyhedron as an iscc set literal over *symbols*."""
        return _iscc.set(symbols, self.equalities, self.inequalities)

    @classmethod
    def _fromiscc(cls, symbols, string):
        """
        Build a polyhedron from an iscc set such as
        ``{ [x, y] : 0 <= x <= 1 and y >= 0 }``.

        symbols: names of the set's dimensions, in tuple order.

        Raises ValueError if *string* does not have the expected shape.
        """
        self = super().__new__(cls)
        self._symbols = _strsymbols(symbols)
        self._equalities = []
        self._inequalities = []
        names = self._symbols
        if re.match(r'^\s*\{\s*\}\s*$', string):
            # iscc prints "{  }" for an empty set.  Build the empty
            # polyhedron (constraint 1 == 0) here instead of returning the
            # module-level ``empty``: this path runs while ``empty`` itself
            # is being created, before that name is bound.
            self._equalities.append(Constant(1))
            return self

        def parse(text):
            # Normalize a bare number to an Expression so the arithmetic
            # below always produces Expression constraints.
            expr = Expression._fromiscc(names, text)
            if isinstance(expr, numbers.Rational):
                expr = Constant(expr)
            return expr

        string = re.sub(r'^\s*\{\s*(.*?)\s*\}\s*$', lambda m: m.group(1), string)
        vstr, _, cstr = string.partition(':')
        vstr = vstr.strip()
        cstr = cstr.strip()
        if not vstr:
            raise ValueError('cannot parse iscc set: {!r}'.format(string))
        vstr = re.sub(r'^\s*\[\s*(.*?)\s*\]\s*$', lambda m: m.group(1), vstr)
        toks = [tok for tok in re.split(r'\s*,\s*', vstr) if tok]
        if len(toks) != len(names):
            raise ValueError('iscc tuple {!r} does not match symbols {!r}'
                    .format(vstr, names))
        # iscc may print a dimension as an affine expression of the others
        # (e.g. "[x, 2x]"); record that as an equality.
        for symbol, tok in zip(names, toks):
            if tok != symbol:
                self._equalities.append(Symbol(symbol) - parse(tok))
        if cstr:
            # \s+ on both sides so identifiers containing "and" survive.
            for conjunct in re.split(r'\s+and\s+', cstr):
                parts = re.split(r'\s*(<=|>=|<|>|=)\s*', conjunct)
                if len(parts) < 3:
                    raise ValueError('cannot parse iscc constraint: {!r}'
                            .format(conjunct))
                # isl prints chained comparisons such as "0 <= x <= 1":
                # parts alternates operand, operator, operand, ...
                for i in range(0, len(parts) - 1, 2):
                    lhs = parse(parts[i])
                    op = parts[i + 1]
                    rhs = parse(parts[i + 2])
                    if op == '=':
                        self._equalities.append(lhs - rhs)
                    elif op == '<=':
                        self._inequalities.append(lhs - rhs)
                    elif op == '>=':
                        self._inequalities.append(rhs - lhs)
                    elif op == '<':
                        self._inequalities.append(lhs - rhs + 1)
                    else:  # '>'
                        self._inequalities.append(rhs - lhs + 1)
        return self

    @property
    def equalities(self):
        """Iterator over the equality constraints (each ``expr == 0``)."""
        yield from self._equalities

    @property
    def inequalities(self):
        """Iterator over the inequality constraints (each ``expr <= 0``)."""
        yield from self._inequalities

    def constraints(self):
        """Yield all equalities, then all inequalities."""
        yield from self.equalities
        yield from self.inequalities

    def symbols(self):
        """Yield the symbol names of this polyhedron."""
        yield from self._symbols

    @property
    def dimension(self):
        """Number of symbols of this polyhedron."""
        return len(self._symbols)

    def __bool__(self):
        # false if the polyhedron is empty, true otherwise
        return not self.isempty()

    def _iscc_compare(self, other, operator):
        """Evaluate ``self <operator> other`` with iscc and return the bool result."""
        # Write both operands over the same (sorted) symbols so iscc compares
        # sets living in the same space.
        symbols = sorted(set(self.symbols()) | set(other.symbols()))
        string = '{} {} {}'.format(self._toiscc(symbols), operator,
                other._toiscc(symbols))
        return _iscc.eval(string) == 'True'

    def __eq__(self, other):
        if not isinstance(other, Polyhedron):
            return NotImplemented
        return self._iscc_compare(other, '=')

    def isempty(self):
        """Return True if the polyhedron has no elements."""
        return self == empty

    def isuniverse(self):
        """Return True if the polyhedron is the whole space."""
        return self == universe

    def isdisjoint(self, other):
        """Return True if the polyhedron has no elements in common with other."""
        return (self & other).isempty()

    def issubset(self, other):
        """Return True if every element of the polyhedron is in other."""
        return self._iscc_compare(other, '<=')

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        """Return True if the polyhedron is a strict subset of other."""
        return self._iscc_compare(other, '<')

    def issuperset(self, other):
        """Return True if every element of other is in the polyhedron."""
        return self._iscc_compare(other, '>=')

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        """Return True if the polyhedron is a strict superset of other."""
        return self._iscc_compare(other, '>')

    def _combine(self, others, operator):
        """
        Apply the iscc set *operator* ('+', '*' or '-') to this polyhedron
        and *others* in order, then take iscc's ``poly`` of the result.
        """
        symbols = set(self.symbols())
        for other in others:
            symbols.update(other.symbols())
        symbols = sorted(symbols)
        strings = [self._toiscc(symbols)]
        strings.extend(other._toiscc(symbols) for other in others)
        string = ' {} '.format(operator).join(strings)
        string = _iscc.eval('poly ({})'.format(string))
        return Polyhedron._fromiscc(symbols, string)

    def union(self, *others):
        """
        Return a new polyhedron with elements from the polyhedron and all
        others (convex union, via iscc ``poly``).
        """
        return self._combine(others, '+')

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        """Return a new polyhedron with elements common to all operands."""
        return self._combine(others, '*')

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, *others):
        """
        Return a new polyhedron with elements in the polyhedron that are
        not in the others (passed through iscc ``poly``).
        """
        return self._combine(others, '-')

    def __sub__(self, other):
        return self.difference(other)

    def projection(self, symbols):
        """Return the projection of the polyhedron onto *symbols*."""
        symbols = _strsymbols(symbols)
        string = _iscc.map(symbols, self.symbols(),
                self.equalities, self.inequalities)
        string = _iscc.eval('poly (dom ({}))'.format(string))
        return Polyhedron._fromiscc(symbols, string)

    def __getitem__(self, symbol):
        return self.projection([symbol])

    def __repr__(self):
        constraints = ['{} == 0'.format(c) for c in self.equalities]
        constraints += ['{} <= 0'.format(c) for c in self.inequalities]
        if not constraints:
            return 'universe'
        if len(constraints) == 1:
            string = constraints[0]
            return 'empty' if string == '1 == 0' else string
        return ' & '.join('({})'.format(c) for c in constraints)
548
# The empty polyhedron, defined by the unsatisfiable constraint 1 == 0.
# NOTE(review): building it runs iscc.  If Polyhedron._fromiscc answers an
# empty iscc result by returning the module-level name ``empty``, this
# assignment fails with NameError because ``empty`` is not bound yet --
# keep that in mind when changing either piece of code.
empty = Eq(1, 0)

# The universe: a polyhedron built with no constraints at all.
universe = Polyhedron()
552
553
if __name__ == '__main__':
    # Small demo: two unit squares that touch at the corner (1, 1).
    x, y, z = symbols('x y z')
    sq0 = (0 <= x) & (x <= 1) & (0 <= y) & (y <= 1)
    sq1 = (1 <= x) & (x <= 2) & (1 <= y) & (y <= 2)
    # Each result is computed just before it is printed.
    demos = [
        ('sq0 ∪ sq1 =', lambda: sq0 | sq1),
        ('sq0 ∩ sq1 =', lambda: sq0 & sq1),
        ('sq0[x] =', lambda: sq0[x]),
    ]
    for label, compute in demos:
        print(label, compute())