fabf2a2d18df95413e1257a19d65eeea41881fdd
[linpy.git] / pypol / linear.py
1
2 import functools
3 import numbers
4
5 from fractions import Fraction, gcd
6
7
# Names exported by ``from ... import *``: the expression and polyhedron
# classes, their constructor helpers, and the two predefined polyhedrons.
__all__ = [
    'Expression',
    'constant',
    'symbol',
    'symbols',
    'eq',
    'le',
    'lt',
    'ge',
    'gt',
    'Polyhedron',
    'empty',
    'universe',
]
15
16
def _polymorphic_method(func):
    """
    Decorate a binary Expression method so its right operand may be a number.

    A rational right operand is promoted with ``constant`` before ``func`` is
    called.  Any other non-Expression operand makes the wrapper return
    ``NotImplemented``, letting Python try the reflected operation.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(b, Expression):
            if not isinstance(b, numbers.Rational):
                return NotImplemented
            b = constant(b)
        return func(a, b)
    return wrapper
27
def _polymorphic_operator(func):
    """
    Decorate a module-level binary operator so either operand may be a number.

    Rational operands are promoted with ``constant``.  If either operand is
    still not an Expression afterwards, TypeError is raised instead of
    calling ``func``.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        left = constant(a) if isinstance(a, numbers.Rational) else a
        right = constant(b) if isinstance(b, numbers.Rational) else b
        if not (isinstance(left, Expression) and isinstance(right, Expression)):
            raise TypeError('arguments must be linear expressions')
        return func(left, right)
    return wrapper
39
40
class Expression:
    """
    This class implements linear expressions.

    An expression is a sum of rational multiples of named symbols plus a
    rational constant, such as ``2*x - y + 1``.  Symbols whose coefficient is
    zero are never stored, and no method mutates an existing instance:
    arithmetic always returns a new Expression.

    The ordering operators (``<=``, ``<``, ``>=``, ``>``) return a Polyhedron
    holding the corresponding constraint rather than a bool; ``==`` is plain
    structural equality (the module-level ``eq`` builds an equality
    constraint).
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Build an expression from symbol coefficients and a constant term.

        ``coefficients`` may be None, a dict, or an iterable of
        ``(symbol, coefficient)`` pairs, where each symbol is a name (str) or
        a symbol Expression.  A string first argument is passed to
        ``fromstring`` instead.  Zero coefficients are dropped.

        Raises TypeError on a non-symbol key, a non-rational coefficient or
        constant, or a nonzero constant combined with a string argument.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        self = super().__new__(cls)
        self._coefficients = {}
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is not None:
            for symbol, coefficient in coefficients:
                if isinstance(symbol, Expression) and symbol.issymbol():
                    # Use the symbol's name, not str(): str() of a scaled
                    # term such as 2*x is not a valid symbol name.
                    symbol = symbol.symbol()
                elif not isinstance(symbol, str):
                    raise TypeError('symbols must be strings')
                if not isinstance(coefficient, numbers.Rational):
                    raise TypeError('coefficients must be rational numbers')
                if coefficient != 0:
                    self._coefficients[symbol] = coefficient
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        self._constant = constant
        return self

    def symbols(self):
        """Yield the names of the symbols with a nonzero coefficient, sorted."""
        yield from sorted(self._coefficients)

    @property
    def dimension(self):
        """Number of symbols appearing in the expression."""
        return len(self._coefficients)

    def coefficient(self, symbol):
        """
        Return the coefficient of ``symbol`` (a name or a symbol Expression),
        or 0 if it does not appear.

        Raises TypeError if ``symbol`` is neither a string nor a symbol.
        """
        if isinstance(symbol, Expression) and symbol.issymbol():
            symbol = symbol.symbol()
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield ``(symbol, coefficient)`` pairs in sorted symbol order."""
        for symbol in self.symbols():
            yield symbol, self._coefficients[symbol]

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    def isconstant(self):
        """Return True if no symbol appears in the expression."""
        return not self._coefficients

    def values(self):
        """Yield every coefficient in sorted symbol order, then the constant."""
        for _, coefficient in self.coefficients():
            yield coefficient
        yield self.constant

    def symbol(self):
        """
        Return the name of a symbol expression.

        Raises ValueError if the expression is not a single symbol.
        """
        if not self.issymbol():
            raise ValueError('not a symbol: {}'.format(self))
        (name,) = self._coefficients
        return name

    def issymbol(self):
        """
        Return True if the expression is exactly one symbol: a single term
        with coefficient 1 and no constant (``x`` qualifies; ``2*x``, ``-x``
        and ``x + 1`` do not).
        """
        if self._constant != 0 or len(self._coefficients) != 1:
            return False
        (coefficient,) = self._coefficients.values()
        return coefficient == 1

    def __bool__(self):
        """An expression is false only when it is the constant 0."""
        return (not self.isconstant()) or bool(self.constant)

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) + coefficient
        return Expression(coefficients, self.constant + other.constant)

    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) - coefficient
        return Expression(coefficients, self.constant - other.constant)

    @_polymorphic_method
    def __rsub__(self, other):
        # Reflected subtraction must compute ``other - self``; aliasing
        # __sub__ would compute ``self - other`` (``5 - x`` gave ``x - 5``).
        return other - self

    @_polymorphic_method
    def __mul__(self, other):
        """
        Multiply by a constant.

        Raises ValueError if both operands contain symbols, because the
        product would not be linear.
        """
        if other.isconstant():
            factor = other.constant
            coefficients = {symbol: coefficient * factor
                            for symbol, coefficient in self.coefficients()}
            return Expression(coefficients, self.constant * factor)
        if self.isconstant():
            # constant * expression: Python never calls __rmul__ when both
            # operands have the same type, so returning NotImplemented here
            # ended in a TypeError.  Swap the operands instead.
            return other * self
        raise ValueError('non-linear expression: '
                         '{} * {}'.format(self._parenstr(), other._parenstr()))

    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        """
        Divide by a constant, producing Fraction values.

        Raises ValueError if the divisor contains symbols; Fraction raises
        ZeroDivisionError if the divisor is 0.
        """
        if not other.isconstant():
            raise ValueError('non-linear expression: '
                             '{} / {}'.format(self._parenstr(),
                                              other._parenstr()))
        divisor = other.constant
        coefficients = {symbol: Fraction(coefficient, divisor)
                        for symbol, coefficient in self.coefficients()}
        return Expression(coefficients, Fraction(self.constant, divisor))

    def __rtruediv__(self, other):
        """
        Compute ``other / self`` for a rational ``other``.

        Only defined when ``self`` is constant; raises ValueError otherwise.
        """
        if not isinstance(other, numbers.Rational):
            return NotImplemented
        if not self.isconstant():
            # ``other`` is a plain number here and has no _parenstr().
            raise ValueError('non-linear expression: '
                             '{} / {}'.format(other, self._parenstr()))
        return Expression(constant=Fraction(other, self.constant))

    def __str__(self):
        """Render as e.g. ``x - 2*y + 1/2``; the zero expression is ``0``."""
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                string += symbol if i == 0 else ' + {}'.format(symbol)
            elif coefficient == -1:
                string += ('-{}' if i == 0 else ' - {}').format(symbol)
            elif i == 0:
                string += '{}*{}'.format(coefficient, symbol)
            elif coefficient > 0:
                string += ' + {}*{}'.format(coefficient, symbol)
            else:
                string += ' - {}*{}'.format(-coefficient, symbol)
        constant = self.constant
        if constant != 0 and self.isconstant():
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string or '0'

    def _parenstr(self, always=False):
        """
        Return str(self), parenthesized unless ``always`` is false and the
        expression is a constant or a single term with no constant.
        """
        string = str(self)
        single_term = len(self._coefficients) == 1 and self._constant == 0
        if not always and (self.isconstant() or single_term):
            return string
        return '({})'.format(string)

    def __repr__(self):
        items = ', '.join('{!r}: {!r}'.format(symbol, coefficient)
                          for symbol, coefficient in self.coefficients())
        return '{}({{{}}}, {!r})'.format(self.__class__.__name__, items,
                                         self.constant)

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        # The decorator guarantees ``other`` is an Expression here.
        return (self._coefficients == other._coefficients and
                self.constant == other.constant)

    def __hash__(self):
        # A dict is unhashable, so hash an immutable view of it instead.
        if self.isconstant():
            # constant(5) == 5, so their hashes must agree as well.
            return hash(self._constant)
        return hash((frozenset(self._coefficients.items()), self._constant))

    def _canonify(self):
        """
        Return the expression scaled by the least common multiple of all
        denominators, so every coefficient and the constant are integers.

        The factor is positive, so the direction of an inequality is kept.
        """
        # NOTE(review): `gcd` is fractions.gcd from the module imports; it was
        # removed in Python 3.9 -- math.gcd (3.5+) is the replacement.
        lcm = functools.reduce(lambda a, b: a * b // gcd(a, b),
                               [value.denominator for value in self.values()])
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        """Return the Polyhedron with the single equality ``self == other``."""
        return Polyhedron(equalities=[(self - other)._canonify()])

    @_polymorphic_method
    def __le__(self, other):
        """Return the Polyhedron with the constraint ``self - other <= 0``."""
        return Polyhedron(inequalities=[(self - other)._canonify()])

    @_polymorphic_method
    def __lt__(self, other):
        # ``e < 0`` is encoded as ``e + 1 <= 0``, which is only equivalent
        # when ``e`` takes integer values (presumably integer points).
        return Polyhedron(inequalities=[(self - other)._canonify() + 1])

    @_polymorphic_method
    def __ge__(self, other):
        """Return the Polyhedron with the constraint ``other - self <= 0``."""
        return Polyhedron(inequalities=[(other - self)._canonify()])

    @_polymorphic_method
    def __gt__(self, other):
        # Strict form of __ge__, with the same integer encoding as __lt__.
        return Polyhedron(inequalities=[(other - self)._canonify() + 1])
281
282
def constant(numerator=0, denominator=None):
    """
    Return an Expression holding only a constant term.

    A lone rational ``numerator`` is stored unchanged.  Otherwise both
    arguments are passed through ``Fraction(numerator, denominator)``.
    """
    if denominator is not None or not isinstance(numerator, numbers.Rational):
        value = Fraction(numerator, denominator)
    else:
        value = numerator
    return Expression(constant=value)
288
def symbol(name):
    """Return the Expression consisting of the single symbol ``name``."""
    if isinstance(name, str):
        return Expression(coefficients={name: 1})
    raise TypeError('name must be a string')
293
def symbols(names):
    """
    Return a generator of symbol Expressions, one per name.

    ``names`` is an iterable of strings, or a single string whose names are
    separated by commas and/or whitespace (e.g. ``'x, y z'``).
    """
    if isinstance(names, str):
        names = ' '.join(names.split(',')).split()
    return (symbol(ident) for ident in names)
298
299
@_polymorphic_operator
def eq(a, b):
    """Return the Polyhedron with the equality ``a == b``; numbers are promoted."""
    polyhedron = a._eq(b)
    return polyhedron
303
@_polymorphic_operator
def le(a, b):
    """Return the Polyhedron for ``a <= b``; numbers are promoted."""
    # Mirrored form: b >= a builds the same constraint (a - b <= 0).
    return b >= a
307
@_polymorphic_operator
def lt(a, b):
    """Return the Polyhedron for ``a < b``; numbers are promoted."""
    # Mirrored form: b > a builds the same constraint as a < b.
    return b > a
311
@_polymorphic_operator
def ge(a, b):
    """Return the Polyhedron for ``a >= b``; numbers are promoted."""
    # Mirrored form: b <= a builds the same constraint (b - a <= 0).
    return b <= a
315
@_polymorphic_operator
def gt(a, b):
    """Return the Polyhedron for ``a > b``; numbers are promoted."""
    # Mirrored form: b < a builds the same constraint as a > b.
    return b < a
319
320
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is stored as a list of equalities, each an Expression ``e``
    read as ``e == 0``, and a list of inequalities, each read as ``e <= 0``.
    Every coefficient and constant of a constraint must be an integer.
    """

    def __new__(cls, equalities=None, inequalities=None):
        """
        Build a polyhedron from iterables of constraint Expressions.

        A string first argument is passed to ``fromstring`` instead.  Raises
        TypeError if a string is combined with ``inequalities``, or if a
        constraint has a non-integer coefficient or constant.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._integer_constraints(equalities, '==')
        self._inequalities = cls._integer_constraints(inequalities, '<=')
        return self

    @staticmethod
    def _integer_constraints(constraints, relation):
        """
        Return ``constraints`` (None meaning no constraints) as a new list.

        Raises TypeError on the first constraint with a non-integer value;
        ``relation`` is only used to render the error message.
        """
        checked = []
        if constraints is not None:
            for constraint in constraints:
                if any(value.denominator != 1
                       for value in constraint.values()):
                    raise TypeError('non-integer constraint: '
                                    '{} {} 0'.format(constraint, relation))
                checked.append(constraint)
        return checked

    @property
    def equalities(self):
        """Iterator over the equality constraints (each ``e == 0``)."""
        yield from self._equalities

    @property
    def inequalities(self):
        """Iterator over the inequality constraints (each ``e <= 0``)."""
        yield from self._inequalities

    def constraints(self):
        """Yield all equalities, then all inequalities."""
        yield from self.equalities
        yield from self.inequalities

    def symbols(self):
        """Yield the sorted names of all symbols used by any constraint."""
        names = set()
        for constraint in self.constraints():
            # Expression.symbols is a method: it must be called, not passed.
            names.update(constraint.symbols())
        yield from sorted(names)

    @property
    def dimension(self):
        """Number of distinct symbols used by the constraints."""
        # symbols() is a generator, which has no len().
        return len(list(self.symbols()))

    def __bool__(self):
        # return false if the polyhedron is empty, true otherwise
        raise NotImplementedError

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        raise NotImplementedError

    def isempty(self):
        return self == empty

    def isuniverse(self):
        return self == universe

    def isdisjoint(self, other):
        # return true if the polyhedron has no elements in common with other
        raise NotImplementedError

    def issubset(self, other):
        raise NotImplementedError

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        raise NotImplementedError

    def issuperset(self, other):
        # test whether every element in other is in the polyhedron
        raise NotImplementedError

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        raise NotImplementedError

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        """
        Return a new polyhedron with the elements common to this polyhedron
        and all ``others``.

        The result holds every constraint of every operand; no redundant
        constraints are removed and nothing is simplified.  Raises TypeError
        if an operand is not a Polyhedron.
        """
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        for other in others:
            if not isinstance(other, Polyhedron):
                raise TypeError('intersection requires polyhedrons')
            equalities.extend(other.equalities)
            inequalities.extend(other.inequalities)
        return self.__class__(equalities, inequalities)

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, *others):
        # return a new polyhedron with elements in the polyhedron that are not
        # in the others
        raise NotImplementedError

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        """Render as ``{e1 == 0, ..., f1 <= 0, ...}``."""
        constraints = ['{} == 0'.format(constraint)
                       for constraint in self.equalities]
        constraints.extend('{} <= 0'.format(constraint)
                           for constraint in self.inequalities)
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        return '{}(equalities={!r}, inequalities={!r})'.format(
            self.__class__.__name__, self._equalities, self._inequalities)

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError
460
461
# The empty polyhedron: its only constraint, ``1 <= 0``, never holds.
empty = Polyhedron(inequalities=[constant(1)])

# The universe: a polyhedron with no constraints at all.
universe = Polyhedron(equalities=[], inequalities=[])