Helper functions symbolname and symbolnames
[linpy.git] / pypol / linexprs.py
1 import ast
2 import functools
3 import numbers
4 import re
5
6 from collections import OrderedDict
7 from fractions import Fraction, gcd
8
9
# Public names exported by a star-import of this module.
__all__ = [
    'Expression',
    'Symbol', 'symbols', 'symbolname', 'symbolnames',
    'Constant',
]
15
16
def _polymorphic(func):
    """
    Decorate a binary method so that it also accepts rational numbers.

    The wrapped method is called with the right operand unchanged when it is
    an Expression, and with the right operand wrapped in a Constant when it
    is a numbers.Rational.  For any other operand the wrapper returns
    NotImplemented without calling the method.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, Expression):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            right = Constant(right)
        return func(left, right)
    return wrapper
27
28
class Expression:
    """
    This class implements linear expressions.

    An expression is a sum of rational multiples of named symbols plus a
    rational constant term.  Creating an Expression may return a Constant or
    a Symbol instance when the arguments describe one (see __new__).
    """

    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
        '_hash',
    )

    def __new__(cls, coefficients=None, constant=0):
        """
        Build an expression from coefficients and a constant term.

        coefficients may be None, a string parsed with fromstring(), a dict
        mapping symbols to coefficients, or an iterable of
        (symbol, coefficient) pairs.  Symbols are strings or Symbol
        instances; coefficients and constant are rational numbers or
        Constant instances.  Pairs with a zero coefficient are dropped.

        Return a Constant when no non-zero coefficient remains, and a Symbol
        for a single symbol with coefficient 1 and a zero constant.

        Raise TypeError when a string is combined with a non-zero constant,
        or when a coefficient or the constant has an unsupported type.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is None:
            return Constant(constant)
        coefficients = [(symbol, coefficient)
            for symbol, coefficient in coefficients if coefficient != 0]
        if len(coefficients) == 0:
            return Constant(constant)
        elif len(coefficients) == 1 and constant == 0:
            symbol, coefficient = coefficients[0]
            if coefficient == 1:
                return Symbol(symbol)
        normalized = {}
        for symbol, coefficient in coefficients:
            symbol = symbolname(symbol)
            if isinstance(coefficient, Constant):
                coefficient = coefficient.constant
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers '
                    'or Constant instances')
            normalized[symbol] = coefficient
        if isinstance(constant, Constant):
            constant = constant.constant
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number '
                'or a Constant instance')
        # object.__new__ is called on the type itself; there is no need to
        # create a throwaway object() instance just to reach it.
        self = object.__new__(cls)
        # Coefficients are stored sorted by symbol name so that iteration
        # order, hashing and string output are deterministic.
        self._coefficients = OrderedDict(sorted(normalized.items()))
        self._constant = constant
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        self._hash = hash((tuple(self._coefficients.items()), self._constant))
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient of symbol (a string or a Symbol instance),
        or 0 when the symbol does not appear in the expression.
        """
        symbol = symbolname(symbol)
        try:
            return self._coefficients[symbol]
        except KeyError:
            return 0

    def __getitem__(self, symbol):
        """Return self.coefficient(symbol)."""
        # Dispatch through self so that subclasses overriding coefficient()
        # are honoured; a plain class-level alias would stay bound to this
        # class's implementation, which reads the _coefficients slot that
        # Symbol and Constant never set.
        return self.coefficient(symbol)

    def coefficients(self):
        """Yield the (symbol name, coefficient) pairs of the expression."""
        yield from self._coefficients.items()

    @property
    def constant(self):
        """The constant term of the expression."""
        return self._constant

    @property
    def symbols(self):
        """The tuple of symbol names of the expression, sorted by name."""
        return self._symbols

    @property
    def dimension(self):
        """The number of symbols of the expression."""
        return self._dimension

    def __hash__(self):
        return self._hash

    def isconstant(self):
        """Return True if the expression is a constant; False here."""
        return False

    def issymbol(self):
        """Return True if the expression is a single symbol; False here."""
        return False

    def values(self):
        """
        Yield the coefficient of each symbol, in the order of self.symbols,
        followed by the constant term.
        """
        for symbol in self.symbols:
            yield self.coefficient(symbol)
        yield self.constant

    def __bool__(self):
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        """Return the sum of two expressions."""
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] += coefficient
            else:
                coefficients[symbol] = coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    # Addition is commutative, so the reflected operation is the same.
    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        """Return the difference of two expressions."""
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            if symbol in coefficients:
                coefficients[symbol] -= coefficient
            else:
                coefficients[symbol] = -coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    def __rsub__(self, other):
        # other - self == -(self - other)
        return -(self - other)

    @_polymorphic
    def __mul__(self, other):
        """
        Return the product of the expression and a constant.

        Raise ValueError when neither operand is constant, since the product
        would not be linear.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] *= other.constant
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if not self.isconstant():
            raise ValueError('non-linear expression: '
                '{} * {}'.format(self._parenstr(), other._parenstr()))
        # self is constant and other is not: defer to other.__rmul__(self),
        # which handles the constant operand on its right-hand side.
        return NotImplemented

    # Multiplication is commutative, so the reflected operation is the same.
    __rmul__ = __mul__

    @_polymorphic
    def __truediv__(self, other):
        """
        Return the quotient of the expression by a constant.

        Raise ValueError when the divisor is not constant.
        """
        if other.isconstant():
            coefficients = dict(self.coefficients())
            for symbol in coefficients:
                coefficients[symbol] = \
                    Fraction(coefficients[symbol], other.constant)
            constant = Fraction(self.constant, other.constant)
            return Expression(coefficients, constant)
        raise ValueError('non-linear expression: '
            '{} / {}'.format(self._parenstr(), other._parenstr()))

    @_polymorphic
    def __rtruediv__(self, other):
        """
        Return other / self, where other is a rational number or an
        expression.

        Raise ValueError when self is not constant, since the quotient would
        not be linear.
        """
        # The previous implementation tested isinstance(other, self), which
        # always raised TypeError because self is an instance, not a type.
        # _polymorphic now converts rational operands to Constant instances.
        if self.isconstant():
            return other / self
        raise ValueError('non-linear expression: '
            '{} / {}'.format(other._parenstr(), self._parenstr()))

    @_polymorphic
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        # The comparison goes through the public accessors because Symbol
        # and Constant instances never set the _coefficients slot.
        return isinstance(other, Expression) and \
            dict(self.coefficients()) == dict(other.coefficients()) and \
            self.constant == other.constant

    @_polymorphic
    def __le__(self, other):
        """Return Le(self, other), with Le taken from the polyhedra module."""
        from .polyhedra import Le
        return Le(self, other)

    @_polymorphic
    def __lt__(self, other):
        """Return Lt(self, other), with Lt taken from the polyhedra module."""
        from .polyhedra import Lt
        return Lt(self, other)

    @_polymorphic
    def __ge__(self, other):
        """Return Ge(self, other), with Ge taken from the polyhedra module."""
        from .polyhedra import Ge
        return Ge(self, other)

    @_polymorphic
    def __gt__(self, other):
        """Return Gt(self, other), with Gt taken from the polyhedra module."""
        from .polyhedra import Gt
        return Gt(self, other)

    def _toint(self):
        """
        Return the expression multiplied by the least common multiple of the
        denominators of its coefficients and constant term.
        """
        # fractions.gcd was removed in Python 3.9; math.gcd is its
        # replacement, imported locally so this method does not depend on
        # the removed name.
        from math import gcd
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcm

    @staticmethod
    def _astnumber(node):
        """
        Return the value of a numeric literal AST node, or None when node is
        not a numeric literal.
        """
        if isinstance(node, getattr(ast, 'Constant', ())):
            # Python >= 3.8 parses every literal to ast.Constant; keep only
            # the numeric ones (bool is excluded although it subclasses int).
            value = node.value
            if isinstance(value, (int, float, complex)) and \
                    not isinstance(value, bool):
                return value
        elif type(node).__name__ == 'Num':
            # Python < 3.8 numeric literal.  The class is matched by name so
            # that the deprecated ast.Num attribute is never accessed.
            return node.n
        return None

    @classmethod
    def _fromast(cls, node):
        """
        Convert an AST node into an expression.

        Names, numeric literals, unary minus and the binary operators
        +, -, * and / are supported.  Raise SyntaxError for any other node.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        number = cls._astnumber(node)
        if number is not None:
            return Constant(number)
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    # A number or closing parenthesis followed by an identifier or opening
    # parenthesis, i.e. a place where a multiplication sign is implied.
    _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a string such as '2x + 3*y - 1' into an expression.

        Raise SyntaxError when the string is not a supported expression.
        """
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = cls._RE_NUM_VAR.sub(r'\1*\2', string)
        # NOTE: the second positional argument of ast.parse is the filename,
        # not the mode, so the string is parsed in the default 'exec' mode;
        # _fromast unwraps the resulting Module and Expr nodes.
        tree = ast.parse(string, 'eval')
        return cls._fromast(tree)

    def __str__(self):
        string = ''
        i = 0
        for symbol in self.symbols:
            coefficient = self.coefficient(symbol)
            if coefficient == 1:
                if i == 0:
                    string += symbol
                else:
                    string += ' + {}'.format(symbol)
            elif coefficient == -1:
                if i == 0:
                    string += '-{}'.format(symbol)
                else:
                    string += ' - {}'.format(symbol)
            else:
                if i == 0:
                    string += '{}*{}'.format(coefficient, symbol)
                elif coefficient > 0:
                    string += ' + {}*{}'.format(coefficient, symbol)
                else:
                    assert coefficient < 0
                    coefficient *= -1
                    string += ' - {}*{}'.format(coefficient, symbol)
            i += 1
        constant = self.constant
        if constant != 0 and i == 0:
            # Nothing printed yet: the constant carries its own sign.
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            constant *= -1
            string += ' - {}'.format(constant)
        if string == '':
            string = '0'
        return string

    def _parenstr(self, always=False):
        """
        Return str(self), wrapped in parentheses unless the expression is a
        constant or a symbol and always is false.
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, str(self))

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy expression into an expression.

        Raise ValueError when a term of expr is neither a rational constant
        nor a rational multiple of a sympy.Symbol.
        """
        import sympy
        coefficients = {}
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Symbol):
                symbol = symbol.name
                coefficients[symbol] = coefficient
            else:
                raise ValueError('non-linear expression: {!r}'.format(expr))
        return cls(coefficients, constant)

    def tosympy(self):
        """Return the expression built from sympy.Symbol terms."""
        import sympy
        expr = 0
        for symbol, coefficient in self.coefficients():
            term = coefficient * sympy.Symbol(symbol)
            expr += term
        expr += self.constant
        return expr
332
333
class Symbol(Expression):
    """
    A linear expression made of a single symbol with coefficient 1 and a
    zero constant term.
    """

    # '_hash' is already a slot of Expression; redeclaring a base-class slot
    # is discouraged by the Python data model, so only the new slot is
    # listed here.
    __slots__ = (
        '_name',
    )

    def __new__(cls, name):
        """
        Create a symbol from name, a string or another Symbol instance.

        Raise TypeError (from symbolname) for any other type.
        """
        name = symbolname(name)
        self = object.__new__(cls)
        self._name = name
        self._hash = hash(self._name)
        return self

    @property
    def name(self):
        """The name of the symbol."""
        return self._name

    def __hash__(self):
        return self._hash

    def coefficient(self, symbol):
        """Return 1 if symbol names this symbol, 0 otherwise."""
        symbol = symbolname(symbol)
        if symbol == self.name:
            return 1
        else:
            return 0

    # The alias inherited from Expression is bound to Expression.coefficient,
    # which reads the _coefficients slot that a Symbol never sets; rebind it
    # to the override above so that indexing works.
    __getitem__ = coefficient

    def coefficients(self):
        """Yield the single (name, 1) pair of the symbol."""
        yield self.name, 1

    @property
    def constant(self):
        """The constant term, always 0 for a symbol."""
        return 0

    @property
    def symbols(self):
        """A 1-tuple holding the name of the symbol."""
        return self.name,

    @property
    def dimension(self):
        """The number of symbols, always 1."""
        return 1

    def issymbol(self):
        """Return True: the expression is a single symbol."""
        return True

    def __eq__(self, other):
        return isinstance(other, Symbol) and self.name == other.name

    @classmethod
    def _fromast(cls, node):
        """
        Convert an AST node into a symbol.

        Only a bare name is accepted; raise SyntaxError for any other node.
        """
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        raise SyntaxError('invalid syntax')

    def __repr__(self):
        return '{}({!r})'.format(self.__class__.__name__, self._name)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy.Symbol into a symbol.

        Raise TypeError when expr is not a sympy.Symbol instance.
        """
        import sympy
        if isinstance(expr, sympy.Symbol):
            return cls(expr.name)
        else:
            raise TypeError('expr must be a sympy.Symbol instance')
403
404
def symbols(names):
    """
    Return a generator of Symbol instances, one per name.

    names is either an iterable of names or a single string in which the
    names are separated by commas and/or whitespace.
    """
    tokens = names.replace(',', ' ').split() if isinstance(names, str) \
        else names
    return (Symbol(token) for token in tokens)
409
def symbolname(symbol):
    """
    Return the name designated by symbol.

    A string is returned with surrounding whitespace stripped; a Symbol
    instance yields its name.  Raise TypeError for any other type.
    """
    if isinstance(symbol, str):
        return symbol.strip()
    if isinstance(symbol, Symbol):
        return symbol.name
    raise TypeError('symbol must be a string or a Symbol instance')
417
def symbolnames(symbols):
    """
    Return the names designated by symbols.

    A string is split on commas and/or whitespace and the resulting list is
    returned.  Any other iterable yields a generator applying symbolname()
    to each of its items.
    """
    if not isinstance(symbols, str):
        return (symbolname(item) for item in symbols)
    return symbols.replace(',', ' ').split()
422
423
class Constant(Expression):
    """
    A linear expression without symbols: a rational constant.
    """

    # '_constant' and '_hash' are already slots of Expression; redeclaring
    # base-class slots is discouraged by the Python data model, so no new
    # slot is declared here.
    __slots__ = ()

    def __new__(cls, numerator=0, denominator=None):
        """
        Create a constant from another Constant, or from any arguments
        accepted by fractions.Fraction(numerator, denominator).
        """
        self = object.__new__(cls)
        if denominator is None and isinstance(numerator, Constant):
            self._constant = numerator.constant
        else:
            self._constant = Fraction(numerator, denominator)
        self._hash = hash(self._constant)
        return self

    def __hash__(self):
        return self._hash

    def coefficient(self, symbol):
        """
        Return 0, the coefficient of any symbol in a constant.

        Raise TypeError (from symbolname) when symbol is neither a string
        nor a Symbol instance.
        """
        # Called only to validate the argument type.
        symbol = symbolname(symbol)
        return 0

    # The alias inherited from Expression is bound to Expression.coefficient,
    # which reads the _coefficients slot that a Constant never sets; rebind
    # it to the override above so that indexing works.
    __getitem__ = coefficient

    def coefficients(self):
        """Yield nothing: a constant has no symbol."""
        yield from []

    @property
    def symbols(self):
        """The empty tuple: a constant has no symbol."""
        return ()

    @property
    def dimension(self):
        """The number of symbols, always 0."""
        return 0

    def isconstant(self):
        """Return True: the expression is a constant."""
        return True

    @_polymorphic
    def __eq__(self, other):
        return isinstance(other, Constant) and self.constant == other.constant

    def __bool__(self):
        return self.constant != 0

    @classmethod
    def fromstring(cls, string):
        """
        Create a constant from a string accepted by fractions.Fraction.

        Raise TypeError when string is not a str instance.
        """
        if isinstance(string, str):
            return Constant(Fraction(string))
        else:
            raise TypeError('string must be a string instance')

    def __repr__(self):
        if self.constant.denominator == 1:
            return '{}({!r})'.format(self.__class__.__name__,
                self.constant.numerator)
        else:
            return '{}({!r}, {!r})'.format(self.__class__.__name__,
                self.constant.numerator, self.constant.denominator)

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy.Rational or a numbers.Rational into a constant.

        Raise TypeError for any other type.
        """
        import sympy
        if isinstance(expr, sympy.Rational):
            return cls(expr.p, expr.q)
        elif isinstance(expr, numbers.Rational):
            return cls(expr)
        else:
            raise TypeError('expr must be a sympy.Rational instance')