0db7edda4aa3fbff8ceac335c8f19cdd63394156
[linpy.git] / pypol / linexprs.py
1 import ast
2 import functools
3 import numbers
4 import re
5
6 from fractions import Fraction, gcd
7
8
9 __all__ = [
10 'Expression',
11 'Symbol', 'symbols',
12 'Constant',
13 ]
14
15
def _polymorphic(func):
    """
    Decorator for binary methods taking a linear expression as right operand.

    The wrapped method is called unchanged when the right operand is already
    an Expression. A rational number is first promoted to a Constant. For any
    other operand type the wrapper returns NotImplemented without calling the
    wrapped method.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, Expression):
            if not isinstance(right, numbers.Rational):
                # Unsupported operand: let Python try the reflected operation.
                return NotImplemented
            right = Constant(right)
        return func(left, right)
    return wrapper
26
27
class Expression:
    """
    This class implements linear expressions.

    An expression stores a mapping from symbol names (strings) to rational
    coefficients, plus a rational constant term. Instances are built by
    __new__, which may return a Constant or a Symbol instead of a plain
    Expression when the arguments describe one.
    """

    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, coefficients=None, constant=0):
        """
        Create a linear expression.

        coefficients may be:
          - a string, parsed with fromstring() (constant must then be falsy,
            otherwise TypeError is raised);
          - a dict or an iterable of (symbol, coefficient) pairs, where each
            symbol is a string or a Symbol and each coefficient a rational
            number or a Constant;
          - None, in which case Constant(constant) is returned.

        Pairs with a zero coefficient are dropped. If nothing remains, a
        Constant is returned; if a single symbol with coefficient 1 and a
        zero constant remains, a Symbol is returned.

        Raises TypeError on invalid symbol, coefficient or constant types.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is None:
            return Constant(constant)
        coefficients = [(symbol, coefficient)
                for symbol, coefficient in coefficients if coefficient != 0]
        if len(coefficients) == 0:
            return Constant(constant)
        elif len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return Symbol(symbol)
        # __new__ is a static method: call it on the type, not on a
        # throw-away object() instance.
        self = object.__new__(cls)
        self._coefficients = {}
        for symbol, coefficient in coefficients:
            if isinstance(symbol, Symbol):
                symbol = symbol.name
            elif not isinstance(symbol, str):
                raise TypeError('symbols must be strings or Symbol instances')
            if isinstance(coefficient, Constant):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers '
                        'or Constant instances')
            self._coefficients[symbol] = coefficient
        if isinstance(constant, Constant):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number '
                    'or a Constant instance')
        self._constant = constant
        self._symbols = tuple(sorted(self._coefficients))
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol (a string or a Symbol), or 0 if the
        symbol does not appear in the expression.

        Raises TypeError if symbol is neither a string nor a Symbol.
        """
        if isinstance(symbol, Symbol):
            # Use the symbol name directly, as __new__ does, instead of
            # going through the string rendering of the symbol.
            symbol = symbol.name
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string or a Symbol instance')
        try:
            return self._coefficients[symbol]
        except KeyError:
            return 0

    __getitem__ = coefficient

    def coefficients(self):
        """
        Yield (symbol name, coefficient) pairs, in sorted symbol order.
        """
        for symbol in self.symbols:
            yield symbol, self.coefficient(symbol)

    @property
    def constant(self):
        """Constant term of the expression."""
        return self._constant

    @property
    def symbols(self):
        """Tuple of the symbol names of the expression."""
        return self._symbols

    @property
    def dimension(self):
        """Number of symbols of the expression."""
        return self._dimension

    def isconstant(self):
        """Return True for constant expressions (overridden by Constant)."""
        return False

    def issymbol(self):
        """Return True for single symbols (overridden by Symbol)."""
        return False

    def values(self):
        """
        Yield the coefficients in symbol order, followed by the constant.
        """
        for symbol in self.symbols:
            yield self.coefficient(symbol)
        yield self.constant

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        """Return the sum of two linear expressions."""
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] += coefficient
            else:
                coefficients[symbol] = coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        """Return the difference of two linear expressions."""
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] -= coefficient
            else:
                coefficients[symbol] = -coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    def __rsub__(self, other):
        # other - self == -(self - other)
        return -(self - other)

    @_polymorphic
    def __mul__(self, other):
        """
        Multiply the expression by a constant expression.

        Raises ValueError when neither operand is constant, since the product
        would not be linear. Returns NotImplemented when only self is
        constant, so that the reflected operation on other is tried.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] *= other.constant
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    __rmul__ = __mul__

    @_polymorphic
    def __truediv__(self, other):
        """
        Divide the expression by a constant expression.

        Raises ValueError when the divisor is not constant.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] = \
                        Fraction(coefficients[symbol], other.constant)
            constant = Fraction(self.constant, other.constant)
            return Expression(coefficients, constant)
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented

    def __rtruediv__(self, other):
        """
        Divide the rational number other by this expression.

        Only defined when this expression is constant; otherwise ValueError
        is raised. Returns NotImplemented for non-rational left operands.
        """
        # The left operand reaching this method is a plain number, so test it
        # against numbers.Rational (isinstance() requires a type, not the
        # instance self).
        if isinstance(other, numbers.Rational):
            if self.isconstant():
                constant = Fraction(other, self.constant)
                return Expression(constant=constant)
            else:
                # Plain numbers have no _parenstr(); wrap in a Constant.
                raise ValueError('non-linear expression: '
                        '{} / {}'.format(Constant(other)._parenstr(),
                            self._parenstr()))
        return NotImplemented

    @_polymorphic
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return isinstance(other, Expression) and \
                self._coefficients == other._coefficients and \
                self.constant == other.constant

    @_polymorphic
    def __le__(self, other):
        # Imported lazily, inside the method, from the sibling module.
        from .polyhedra import Le
        return Le(self, other)

    @_polymorphic
    def __lt__(self, other):
        from .polyhedra import Lt
        return Lt(self, other)

    @_polymorphic
    def __ge__(self, other):
        from .polyhedra import Ge
        return Ge(self, other)

    @_polymorphic
    def __gt__(self, other):
        from .polyhedra import Gt
        return Gt(self, other)

    def __hash__(self):
        return hash((tuple(sorted(self._coefficients.items())), self._constant))

    def _toint(self):
        """
        Return the expression multiplied by the least common multiple of the
        denominators of its coefficients and constant.
        """
        # fractions.gcd was removed in Python 3.9; prefer math.gcd and fall
        # back to the old location on interpreters that lack it.
        try:
            from math import gcd
        except ImportError:
            from fractions import gcd
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
                [value.denominator for value in self.values()])
        return self * lcm

    @classmethod
    def _fromast(cls, node):
        """
        Convert a Python AST node into a linear expression.

        Supports a module made of a single expression statement, names,
        numeric literals, unary minus and the binary operators +, -, * and /.
        Raises SyntaxError for anything else.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif (isinstance(node, ast.Constant)
                and isinstance(node.value, numbers.Real)
                and not isinstance(node.value, bool)):
            # ast.Num is deprecated; numeric literals are ast.Constant nodes.
            # Booleans are Constant nodes too but are not numeric literals.
            return Constant(node.value)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a string such as '2x + 3(y - 1)' into a linear expression.

        Raises SyntaxError if the string is not a supported expression.
        """
        # Make implicit multiplications explicit: a number or a closing
        # parenthesis directly followed by an identifier or an opening
        # parenthesis gets a '*' inserted in between.
        string = re.sub(r'(\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', string)
        # NOTE: the second positional argument of ast.parse() is the
        # filename, not the mode; the source is parsed in the default 'exec'
        # mode and the resulting ast.Module is unwrapped by _fromast().
        tree = ast.parse(string, 'eval')
        return cls._fromast(tree)

    def __str__(self):
        string = ''
        # i counts the terms already written, to pick the right separator.
        i = 0
        for symbol in self.symbols:
            coefficient = self.coefficient(symbol)
            if coefficient == 1:
                if i == 0:
                    string += symbol
                else:
                    string += ' + {}'.format(symbol)
            elif coefficient == -1:
                if i == 0:
                    string += '-{}'.format(symbol)
                else:
                    string += ' - {}'.format(symbol)
            else:
                if i == 0:
                    string += '{}*{}'.format(coefficient, symbol)
                elif coefficient > 0:
                    string += ' + {}*{}'.format(coefficient, symbol)
                else:
                    assert coefficient < 0
                    coefficient *= -1
                    string += ' - {}*{}'.format(coefficient, symbol)
            i += 1
        constant = self.constant
        if constant != 0 and i == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            constant *= -1
            string += ' - {}'.format(constant)
        if string == '':
            string = '0'
        return string

    def _parenstr(self, always=False):
        """
        Return the string form of the expression, wrapped in parentheses
        unless it is a constant or a symbol (or always if always is true).
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, str(self))

    @classmethod
    def fromsympy(cls, expr):
        """
        Build a linear expression from a sympy expression.

        Raises ValueError if a term of expr is neither the constant term nor
        a multiple of a sympy.Symbol.
        """
        import sympy
        coefficients = {}
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = symbol.name
                coefficients[symbol] = coefficient
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return cls(coefficients, constant)

    def tosympy(self):
        """Return the expression as a sympy expression."""
        import sympy
        expr = 0
        for symbol, coefficient in self.coefficients():
            term = coefficient * sympy.Symbol(symbol)
            expr += term
        expr += self.constant
        return expr
332
333
class Symbol(Expression):
    """
    Linear expression made of a single symbol with coefficient 1 and a zero
    constant.
    """

    # Only the slot added by this class: the other slots are inherited from
    # Expression and must not be declared a second time.
    __slots__ = (
        '_name',
    )

    def __new__(cls, name):
        """
        Create the symbol called name (a string, stripped of surrounding
        whitespace, or another Symbol whose name is reused).

        Raises TypeError if name is neither a string nor a Symbol.
        """
        if isinstance(name, Symbol):
            name = name.name
        elif not isinstance(name, str):
            raise TypeError('name must be a string or a Symbol instance')
        name = name.strip()
        self = object.__new__(cls)
        self._coefficients = {name: 1}
        self._constant = 0
        # One-element tuple: tuple(name) would split a multi-character name
        # into its individual characters.
        self._symbols = (name,)
        self._name = name
        self._dimension = 1
        return self

    @property
    def name(self):
        """Name of the symbol."""
        return self._name

    def issymbol(self):
        return True

    @classmethod
    def _fromast(cls, node):
        """
        Convert a Python AST node into a Symbol.

        Only a bare name, possibly wrapped in a single-statement module or an
        expression statement, is accepted; SyntaxError is raised otherwise.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        raise SyntaxError('invalid syntax')

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._name)

    @classmethod
    def fromsympy(cls, expr):
        """
        Build a Symbol from a sympy.Symbol.

        Raises TypeError if expr is not a sympy.Symbol instance.
        """
        import sympy
        if isinstance(expr, sympy.Symbol):
            return cls(expr.name)
        else:
            raise TypeError('expr must be a sympy.Symbol instance')
381
382
def symbols(names):
    """
    Return a generator of Symbol instances, one per name.

    names is either an iterable of names or a single string in which the
    names are separated by commas and/or whitespace.
    """
    tokens = names
    if isinstance(tokens, str):
        # Treat commas as whitespace, then split on runs of whitespace.
        tokens = ' '.join(tokens.split(',')).split()
    return (Symbol(token) for token in tokens)
387
388
class Constant(Expression):
    """
    Constant linear expression: it has no symbols and its value is stored as
    a Fraction.
    """

    # No additional per-instance storage: everything lives in the slots
    # inherited from Expression, so instances need no __dict__.
    __slots__ = ()

    def __new__(cls, numerator=0, denominator=None):
        """
        Create the constant numerator / denominator.

        Both arguments accept anything supported by fractions.Fraction, as
        well as Constant instances, whose value is used.
        """
        self = object.__new__(cls)
        # Unwrap Constant arguments so that Fraction only sees plain numbers.
        if isinstance(numerator, Constant):
            numerator = numerator.constant
        if isinstance(denominator, Constant):
            denominator = denominator.constant
        self._constant = Fraction(numerator, denominator)
        self._coefficients = {}
        self._symbols = ()
        self._dimension = 0
        return self

    def isconstant(self):
        return True

    def __bool__(self):
        # A constant is falsy exactly when it is zero.
        return self.constant != 0

    @classmethod
    def fromstring(cls, string):
        """
        Build a Constant from a string accepted by fractions.Fraction.

        Raises TypeError if string is not a str instance.
        """
        if isinstance(string, str):
            return Constant(Fraction(string))
        else:
            raise TypeError('string must be a string instance')

    def __repr__(self):
        if self.constant.denominator == 1:
            return '{}({!r})'.format(self.__class__.__name__,
                    self.constant.numerator)
        else:
            return '{}({!r}, {!r})'.format(self.__class__.__name__,
                    self.constant.numerator, self.constant.denominator)

    @classmethod
    def fromsympy(cls, expr):
        """
        Build a Constant from a sympy.Rational or a Python rational number.

        Raises TypeError for any other type.
        """
        import sympy
        if isinstance(expr, sympy.Rational):
            return cls(expr.p, expr.q)
        elif isinstance(expr, numbers.Rational):
            return cls(expr)
        else:
            raise TypeError('expr must be a sympy.Rational instance')