Better implementation of _polymorphic_operator
[linpy.git] / pypol / linear.py
1
import functools
import numbers

try:
    from fractions import Fraction, gcd
except ImportError:
    # fractions.gcd was removed in Python 3.9; math.gcd replaces it.
    from fractions import Fraction
    from math import gcd
6
7
# Public names of the module, i.e. what ``from ... import *`` exports.
__all__ = [
    'Expression',
    'constant',
    'symbol',
    'symbols',
    'eq',
    'le',
    'lt',
    'ge',
    'gt',
    'Polyhedron',
    'empty',
    'universe',
]
15
16
def _polymorphic_method(func):
    """
    Decorate a binary method so that it always receives an Expression as
    its right operand.

    An Expression operand is passed through unchanged, a rational number is
    first converted with constant(), and for any other type the wrapper
    returns NotImplemented without calling func.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(b, Expression):
            if not isinstance(b, numbers.Rational):
                return NotImplemented
            b = constant(b)
        return func(a, b)
    return wrapper
27
def _polymorphic_operator(func):
    """
    Decorate a binary function so that it always receives an Expression as
    its left operand.

    A rational left operand is converted with constant(); anything that is
    neither a rational number nor an Expression raises TypeError.
    """
    # A polymorphic operator should call a polymorphic method, which takes
    # care of the right operand: only the left one is checked here.
    @functools.wraps(func)
    def wrapper(a, b):
        if isinstance(a, numbers.Rational):
            a = constant(a)
        elif not isinstance(a, Expression):
            raise TypeError('arguments must be linear expressions')
        return func(a, b)
    return wrapper
40
41
class Expression:
    """
    This class implements linear expressions.

    An expression is stored as a mapping from symbol names to non-null
    rational coefficients, plus a rational constant term. Arithmetic
    operators return new expressions; the ordering operators (<=, <, >=, >)
    build Polyhedron objects, while == is a plain structural comparison.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Build an expression from its coefficients and its constant term.

        coefficients -- a dict, or an iterable of (symbol, coefficient)
            pairs, where symbols are strings or symbol expressions and
            coefficients are rational numbers; null coefficients are not
            stored. A string is handed over to fromstring() instead.
        constant -- the rational constant term.

        Raise TypeError when an argument has the wrong type.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        self = super().__new__(cls)
        self._coefficients = {}
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is not None:
            for symbol, coefficient in coefficients:
                if isinstance(symbol, Expression) and symbol.issymbol():
                    symbol = symbol.symbol()
                elif not isinstance(symbol, str):
                    raise TypeError('symbols must be strings')
                if not isinstance(coefficient, numbers.Rational):
                    raise TypeError('coefficients must be rational numbers')
                if coefficient != 0:
                    self._coefficients[symbol] = coefficient
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        self._constant = constant
        return self

    def symbols(self):
        """Yield the symbol names of the expression, in sorted order."""
        yield from sorted(self._coefficients)

    @property
    def dimension(self):
        """The number of symbols with a non-null coefficient."""
        return len(self._coefficients)

    def coefficient(self, symbol):
        """
        Return the coefficient of a symbol, or 0 if the symbol is absent.

        The symbol is a string or a symbol expression; any other type raises
        TypeError.
        """
        if isinstance(symbol, Expression) and symbol.issymbol():
            symbol = symbol.symbol()
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, coefficient) pairs, sorted by symbol name."""
        for symbol in self.symbols():
            yield symbol, self._coefficients[symbol]

    @property
    def constant(self):
        """The constant term of the expression."""
        return self._constant

    def isconstant(self):
        """Return True if the expression has no symbol."""
        return not self._coefficients

    def values(self):
        """Yield the coefficients in symbol order, then the constant."""
        for symbol in self.symbols():
            yield self._coefficients[symbol]
        yield self._constant

    def symbol(self):
        """
        Return the name of the symbol the expression consists of.

        Raise ValueError if the expression is not a symbol.
        """
        if not self.issymbol():
            raise ValueError('not a symbol: {}'.format(self))
        return next(iter(self._coefficients))

    def issymbol(self):
        """
        Return True if the expression is a single symbol, with coefficient 1
        and no constant term.
        """
        # The coefficient has to be checked as well: 2*x is not a symbol and
        # must not be accepted wherever a symbol name is expected.
        return (self._constant == 0 and
                list(self._coefficients.values()) == [1])

    def __bool__(self):
        """Return False only for the null expression."""
        return (not self.isconstant()) or bool(self.constant)

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        """Return the sum of two expressions."""
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) + coefficient
        return Expression(coefficients, self.constant + other.constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        """Return the difference self - other."""
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) - coefficient
        return Expression(coefficients, self.constant - other.constant)

    @_polymorphic_method
    def __rsub__(self, other):
        """Return the difference other - self."""
        # Subtraction is not commutative, so __sub__ cannot be reused as is:
        # the operands have to be swapped back.
        return other - self

    def _scale(self, factor):
        """Return the expression with every value multiplied by factor."""
        coefficients = {symbol: coefficient * factor
            for symbol, coefficient in self._coefficients.items()}
        return Expression(coefficients, self._constant * factor)

    @_polymorphic_method
    def __mul__(self, other):
        """
        Return the product of two expressions, one of which must be constant.

        Raise ValueError if both operands contain symbols.
        """
        if other.isconstant():
            return self._scale(other.constant)
        # Python does not try the reflected method when both operands have
        # the same type, so the constant * expression case is handled here
        # instead of returning NotImplemented.
        if self.isconstant():
            return other._scale(self.constant)
        raise ValueError('non-linear expression: '
            '{} * {}'.format(self._parenstr(), other._parenstr()))

    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        """
        Return the quotient of the expression by a constant expression.

        Raise ValueError if the divisor contains symbols.
        """
        if not other.isconstant():
            raise ValueError('non-linear expression: '
                '{} / {}'.format(self._parenstr(), other._parenstr()))
        divisor = other.constant
        coefficients = {symbol: Fraction(coefficient, divisor)
            for symbol, coefficient in self.coefficients()}
        return Expression(coefficients, Fraction(self.constant, divisor))

    @_polymorphic_method
    def __rtruediv__(self, other):
        """Return the quotient other / self, self being constant."""
        # other has been converted to an Expression by the decorator, hence
        # __truediv__ performs both the division and the linearity check.
        return other / self

    def __str__(self):
        # Each term is recorded as (is negative, unsigned text) so that the
        # sign can be rendered differently for the first term ('-x') and for
        # the following ones (' - x').
        terms = []
        for symbol, coefficient in self.coefficients():
            if abs(coefficient) == 1:
                text = symbol
            else:
                text = '{}*{}'.format(abs(coefficient), symbol)
            terms.append((coefficient < 0, text))
        if self.constant != 0:
            terms.append((self.constant < 0, str(abs(self.constant))))
        if not terms:
            return '0'
        negative, text = terms[0]
        strings = ['-' + text if negative else text]
        for negative, text in terms[1:]:
            strings.append((' - ' if negative else ' + ') + text)
        return ''.join(strings)

    def _parenstr(self, always=False):
        """
        Return the string form of the expression, between parentheses unless
        it is a constant or a symbol (or always, if requested).
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        return '({})'.format(string)

    def __repr__(self):
        items = ', '.join('{!r}: {!r}'.format(symbol, coefficient)
            for symbol, coefficient in self.coefficients())
        return '{}({{{}}}, {!r})'.format(self.__class__.__name__, items,
            self.constant)

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
            self._coefficients == other._coefficients and \
            self.constant == other.constant

    def __hash__(self):
        # A constant expression compares equal to the rational number it
        # holds, hence it must hash like that number.
        if not self._coefficients:
            return hash(self._constant)
        # A dict is not hashable: hash its items in a deterministic order.
        return hash((tuple(sorted(self._coefficients.items())),
            self._constant))

    def _canonify(self):
        """
        Return the expression multiplied by the least common multiple of the
        denominators of its values.
        """
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        """Return the polyhedron defined by self == other."""
        return Polyhedron(equalities=[(self - other)._canonify()])

    @_polymorphic_method
    def __le__(self, other):
        """Return the polyhedron defined by self <= other."""
        return Polyhedron(inequalities=[(self - other)._canonify()])

    @_polymorphic_method
    def __lt__(self, other):
        """Return the polyhedron defined by self - other + 1 <= 0."""
        return Polyhedron(inequalities=[(self - other)._canonify() + 1])

    @_polymorphic_method
    def __ge__(self, other):
        """Return the polyhedron defined by other <= self."""
        return Polyhedron(inequalities=[(other - self)._canonify()])

    @_polymorphic_method
    def __gt__(self, other):
        """Return the polyhedron defined by other - self + 1 <= 0."""
        return Polyhedron(inequalities=[(other - self)._canonify() + 1])
283
def constant(numerator=0, denominator=None):
    """
    Return the constant expression numerator / denominator.

    A rational numerator given without a denominator is stored as is; in any
    other case the value is built by Fraction(numerator, denominator).
    """
    if denominator is not None or not isinstance(numerator, numbers.Rational):
        numerator = Fraction(numerator, denominator)
    return Expression(constant=numerator)
289
def symbol(name):
    """
    Return the expression made of the symbol called name, with coefficient 1.

    Raise TypeError if name is not a string.
    """
    if isinstance(name, str):
        return Expression(coefficients={name: 1})
    raise TypeError('name must be a string')
294
def symbols(names):
    """
    Return a generator of symbol expressions, one for each name.

    names is either an iterable of names, or a single string in which the
    names are separated by commas and/or whitespace.
    """
    tokens = names.replace(',', ' ').split() if isinstance(names, str) \
        else names
    return (symbol(token) for token in tokens)
299
300
@_polymorphic_operator
def eq(a, b):
    """Return the result of a._eq(b), a rational a being converted first."""
    equality = a._eq(b)
    return equality
304
@_polymorphic_operator
def le(a, b):
    """Return the result of a <= b, a rational a being converted first."""
    inequality = a <= b
    return inequality
308
@_polymorphic_operator
def lt(a, b):
    """Return the result of a < b, a rational a being converted first."""
    inequality = a < b
    return inequality
312
@_polymorphic_operator
def ge(a, b):
    """Return the result of a >= b, a rational a being converted first."""
    inequality = a >= b
    return inequality
316
@_polymorphic_operator
def gt(a, b):
    """Return the result of a > b, a rational a being converted first."""
    inequality = a > b
    return inequality
320
321
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is stored as two lists of constraints: the expressions that
    are rendered as "== 0" and the ones that are rendered as "<= 0". Every
    value of a constraint must have a denominator equal to 1.
    """

    def __new__(cls, equalities=None, inequalities=None):
        """
        Build a polyhedron from iterables of constraints.

        A string given as first argument is handed over to fromstring().
        Raise TypeError if a constraint has a non-integer value.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._checked(equalities, '==')
        self._inequalities = cls._checked(inequalities, '<=')
        return self

    @staticmethod
    def _checked(constraints, relation):
        """Return the constraints as a list, after checking their values."""
        checked = []
        if constraints is not None:
            for constraint in constraints:
                for value in constraint.values():
                    if value.denominator != 1:
                        raise TypeError('non-integer constraint: '
                            '{} {} 0'.format(constraint, relation))
                checked.append(constraint)
        return checked

    @property
    def equalities(self):
        """Iterator over the equality constraints."""
        yield from self._equalities

    @property
    def inequalities(self):
        """Iterator over the inequality constraints."""
        yield from self._inequalities

    def constraints(self):
        """Yield the equalities, then the inequalities."""
        yield from self.equalities
        yield from self.inequalities

    def symbols(self):
        """Yield the symbols used by the constraints, in sorted order."""
        s = set()
        for constraint in self.constraints():
            # symbols is a method of the constraint: it has to be called,
            # the bound method itself is not iterable.
            s.update(constraint.symbols())
        yield from sorted(s)

    @property
    def dimension(self):
        """The number of distinct symbols used by the constraints."""
        # symbols() is a generator, which has no len().
        return len(list(self.symbols()))

    def __bool__(self):
        # return false if the polyhedron is empty, true otherwise
        raise NotImplementedError

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        raise NotImplementedError

    def isempty(self):
        return self == empty

    def isuniverse(self):
        return self == universe

    def isdisjoint(self, other):
        # return true if the polyhedron has no elements in common with other
        raise NotImplementedError

    def issubset(self, other):
        raise NotImplementedError

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        raise NotImplementedError

    def issuperset(self, other):
        # test whether every element in other is in the polyhedron
        raise NotImplementedError

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        raise NotImplementedError

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        # return a new polyhedron with elements common to the polyhedron and
        # all others
        # a poor man's implementation could be:
        # equalities = list(self.equalities)
        # inequalities = list(self.inequalities)
        # for other in others:
        #     equalities.extend(other.equalities)
        #     inequalities.extend(other.inequalities)
        # return self.__class__(equalities, inequalities)
        raise NotImplementedError

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, *others):
        # return a new polyhedron with elements in the polyhedron that are not
        # in the others
        raise NotImplementedError

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = []
        for constraint in self.equalities:
            constraints.append('{} == 0'.format(constraint))
        for constraint in self.inequalities:
            constraints.append('{} <= 0'.format(constraint))
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        return '{}(equalities={!r}, inequalities={!r})' \
            ''.format(self.__class__.__name__, equalities, inequalities)

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError
461
462
# The polyhedron built from the constraint 1 <= 0.
empty = le(1, 0)

# The polyhedron built without any constraint.
universe = Polyhedron()