1 # Copyright 2014 MINES ParisTech
3 # This file is part of LinPy.
5 # LinPy is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
10 # LinPy is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with LinPy. If not, see <http://www.gnu.org/licenses/>.
23 from collections
import OrderedDict
, defaultdict
, Mapping
24 from fractions
import Fraction
, gcd
29 'Symbol', 'Dummy', 'symbols',
34 def _polymorphic(func
):
35 @functools.wraps(func
)
36 def wrapper(left
, right
):
37 if isinstance(right
, LinExpr
):
38 return func(left
, right
)
39 elif isinstance(right
, numbers
.Rational
):
40 right
= Rational(right
)
41 return func(left
, right
)
48 A linear expression consists of a list of coefficient-variable pairs
49 that capture the linear terms, plus a constant term. Linear expressions
50 are used to build constraints. They are temporary objects that typically
53 Linear expressions are generally built using overloaded operators. For
54 example, if x is a Symbol, then x + 1 is an instance of LinExpr.
56 LinExpr instances are hashable, and should be treated as immutable.
59 def __new__(cls
, coefficients
=None, constant
=0):
61 Return a linear expression from a dictionary or a sequence, that maps
62 symbols to their coefficients, and a constant term. The coefficients and
63 the constant term must be rational numbers.
65 For example, the linear expression x + 2y + 1 can be constructed using
66 one of the following instructions:
68 >>> x, y = symbols('x y')
69 >>> LinExpr({x: 1, y: 2}, 1)
70 >>> LinExpr([(x, 1), (y, 2)], 1)
72 However, it may be easier to use overloaded operators:
74 >>> x, y = symbols('x y')
77 Alternatively, linear expressions can be constructed from a string:
79 >>> LinExpr('x + 2*y + 1')
81 A linear expression with a single symbol of coefficient 1 and no
82 constant term is automatically subclassed as a Symbol instance. A linear
83 expression with no symbol, only a constant term, is automatically
84 subclassed as a Rational instance.
86 if isinstance(coefficients
, str):
88 raise TypeError('too many arguments')
89 return LinExpr
.fromstring(coefficients
)
90 if coefficients
is None:
91 return Rational(constant
)
92 if isinstance(coefficients
, Mapping
):
93 coefficients
= coefficients
.items()
94 coefficients
= list(coefficients
)
95 for symbol
, coefficient
in coefficients
:
96 if not isinstance(symbol
, Symbol
):
97 raise TypeError('symbols must be Symbol instances')
98 if not isinstance(coefficient
, numbers
.Rational
):
99 raise TypeError('coefficients must be rational numbers')
100 if not isinstance(constant
, numbers
.Rational
):
101 raise TypeError('constant must be a rational number')
102 if len(coefficients
) == 0:
103 return Rational(constant
)
104 if len(coefficients
) == 1 and constant
== 0:
105 symbol
, coefficient
= coefficients
[0]
108 coefficients
= [(symbol
, Fraction(coefficient
))
109 for symbol
, coefficient
in coefficients
if coefficient
!= 0]
110 coefficients
.sort(key
=lambda item
: item
[0].sortkey())
111 self
= object().__new
__(cls
)
112 self
._coefficients
= OrderedDict(coefficients
)
113 self
._constant
= Fraction(constant
)
114 self
._symbols
= tuple(self
._coefficients
)
115 self
._dimension
= len(self
._symbols
)
118 def coefficient(self
, symbol
):
120 Return the coefficient value of the given symbol, or 0 if the symbol
121 does not appear in the expression.
123 if not isinstance(symbol
, Symbol
):
124 raise TypeError('symbol must be a Symbol instance')
125 return Rational(self
._coefficients
.get(symbol
, 0))
127 __getitem__
= coefficient
129 def coefficients(self
):
131 Iterate over the pairs (symbol, value) of linear terms in the
132 expression. The constant term is ignored.
134 for symbol
, coefficient
in self
._coefficients
.items():
135 yield symbol
, Rational(coefficient
)
140 The constant term of the expression.
142 return Rational(self
._constant
)
147 The tuple of symbols present in the expression, sorted according to
155 The dimension of the expression, i.e. the number of symbols present in
158 return self
._dimension
161 return hash((tuple(self
._coefficients
.items()), self
._constant
))
163 def isconstant(self
):
165 Return True if the expression only consists of a constant term. In this
166 case, it is a Rational instance.
172 Return True if an expression only consists of a symbol with coefficient
173 1. In this case, it is a Symbol instance.
179 Iterate over the coefficient values in the expression, and the constant
182 for coefficient
in self
._coefficients
.values():
183 yield Rational(coefficient
)
184 yield Rational(self
._constant
)
196 def __add__(self
, other
):
198 Return the sum of two linear expressions.
200 coefficients
= defaultdict(Fraction
, self
._coefficients
)
201 for symbol
, coefficient
in other
._coefficients
.items():
202 coefficients
[symbol
] += coefficient
203 constant
= self
._constant
+ other
._constant
204 return LinExpr(coefficients
, constant
)
209 def __sub__(self
, other
):
211 Return the difference between two linear expressions.
213 coefficients
= defaultdict(Fraction
, self
._coefficients
)
214 for symbol
, coefficient
in other
._coefficients
.items():
215 coefficients
[symbol
] -= coefficient
216 constant
= self
._constant
- other
._constant
217 return LinExpr(coefficients
, constant
)
220 def __rsub__(self
, other
):
223 def __mul__(self
, other
):
225 Return the product of the linear expression by a rational.
227 if isinstance(other
, numbers
.Rational
):
228 coefficients
= ((symbol
, coefficient
* other
)
229 for symbol
, coefficient
in self
._coefficients
.items())
230 constant
= self
._constant
* other
231 return LinExpr(coefficients
, constant
)
232 return NotImplemented
236 def __truediv__(self
, other
):
238 Return the quotient of the linear expression by a rational.
240 if isinstance(other
, numbers
.Rational
):
241 coefficients
= ((symbol
, coefficient
/ other
)
242 for symbol
, coefficient
in self
._coefficients
.items())
243 constant
= self
._constant
/ other
244 return LinExpr(coefficients
, constant
)
245 return NotImplemented
248 def __eq__(self
, other
):
250 Test whether two linear expressions are equal.
252 return isinstance(other
, LinExpr
) and \
253 self
._coefficients
== other
._coefficients
and \
254 self
._constant
== other
._constant
256 def __le__(self
, other
):
257 from .polyhedra
import Le
258 return Le(self
, other
)
260 def __lt__(self
, other
):
261 from .polyhedra
import Lt
262 return Lt(self
, other
)
264 def __ge__(self
, other
):
265 from .polyhedra
import Ge
266 return Ge(self
, other
)
268 def __gt__(self
, other
):
269 from .polyhedra
import Gt
270 return Gt(self
, other
)
274 Return the expression multiplied by its lowest common denominator to
275 make all values integer.
277 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
278 [value
.denominator
for value
in self
.values()])
281 def subs(self
, symbol
, expression
=None):
283 Substitute the given symbol by an expression and return the resulting
284 expression. Raise TypeError if the resulting expression is not linear.
286 >>> x, y = symbols('x y')
291 To perform multiple substitutions at once, pass a sequence or a
292 dictionary of (old, new) pairs to subs.
294 >>> e.subs({x: y, y: x})
297 if expression
is None:
298 if isinstance(symbol
, Mapping
):
299 symbol
= symbol
.items()
300 substitutions
= symbol
302 substitutions
= [(symbol
, expression
)]
304 for symbol
, expression
in substitutions
:
305 if not isinstance(symbol
, Symbol
):
306 raise TypeError('symbols must be Symbol instances')
307 coefficients
= [(othersymbol
, coefficient
)
308 for othersymbol
, coefficient
in result
._coefficients
.items()
309 if othersymbol
!= symbol
]
310 coefficient
= result
._coefficients
.get(symbol
, 0)
311 constant
= result
._constant
312 result
= LinExpr(coefficients
, constant
) + coefficient
*expression
316 def _fromast(cls
, node
):
317 if isinstance(node
, ast
.Module
) and len(node
.body
) == 1:
318 return cls
._fromast
(node
.body
[0])
319 elif isinstance(node
, ast
.Expr
):
320 return cls
._fromast
(node
.value
)
321 elif isinstance(node
, ast
.Name
):
322 return Symbol(node
.id)
323 elif isinstance(node
, ast
.Num
):
324 return Rational(node
.n
)
325 elif isinstance(node
, ast
.UnaryOp
) and isinstance(node
.op
, ast
.USub
):
326 return -cls
._fromast
(node
.operand
)
327 elif isinstance(node
, ast
.BinOp
):
328 left
= cls
._fromast
(node
.left
)
329 right
= cls
._fromast
(node
.right
)
330 if isinstance(node
.op
, ast
.Add
):
332 elif isinstance(node
.op
, ast
.Sub
):
334 elif isinstance(node
.op
, ast
.Mult
):
336 elif isinstance(node
.op
, ast
.Div
):
338 raise SyntaxError('invalid syntax')
340 _RE_NUM_VAR
= re
.compile(r
'(\d+|\))\s*([^\W\d_]\w*|\()')
343 def fromstring(cls
, string
):
345 Create an expression from a string. Raise SyntaxError if the string is
346 not properly formatted.
348 # add implicit multiplication operators, e.g. '5x' -> '5*x'
349 string
= LinExpr
._RE
_NUM
_VAR
.sub(r
'\1*\2', string
)
350 tree
= ast
.parse(string
, 'eval')
351 expr
= cls
._fromast
(tree
)
352 if not isinstance(expr
, cls
):
353 raise SyntaxError('invalid syntax')
358 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
362 elif coefficient
== -1:
363 string
+= '-' if i
== 0 else ' - '
365 string
+= '{}*'.format(coefficient
)
366 elif coefficient
> 0:
367 string
+= ' + {}*'.format(coefficient
)
369 string
+= ' - {}*'.format(-coefficient
)
370 string
+= '{}'.format(symbol
)
371 constant
= self
.constant
373 string
+= '{}'.format(constant
)
375 string
+= ' + {}'.format(constant
)
377 string
+= ' - {}'.format(-constant
)
380 def _repr_latex_(self
):
382 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
386 elif coefficient
== -1:
387 string
+= '-' if i
== 0 else ' - '
389 string
+= '{}'.format(coefficient
._repr
_latex
_().strip('$'))
390 elif coefficient
> 0:
391 string
+= ' + {}'.format(coefficient
._repr
_latex
_().strip('$'))
392 elif coefficient
< 0:
393 string
+= ' - {}'.format((-coefficient
)._repr
_latex
_().strip('$'))
394 string
+= '{}'.format(symbol
._repr
_latex
_().strip('$'))
395 constant
= self
.constant
397 string
+= '{}'.format(constant
._repr
_latex
_().strip('$'))
399 string
+= ' + {}'.format(constant
._repr
_latex
_().strip('$'))
401 string
+= ' - {}'.format((-constant
)._repr
_latex
_().strip('$'))
402 return '$${}$$'.format(string
)
404 def _parenstr(self
, always
=False):
406 if not always
and (self
.isconstant() or self
.issymbol()):
409 return '({})'.format(string
)
412 def fromsympy(cls
, expr
):
414 Create a linear expression from a sympy expression. Raise ValueError is
415 the sympy expression is not linear.
420 for symbol
, coefficient
in expr
.as_coefficients_dict().items():
421 coefficient
= Fraction(coefficient
.p
, coefficient
.q
)
422 if symbol
== sympy
.S
.One
:
423 constant
= coefficient
424 elif isinstance(symbol
, sympy
.Symbol
):
425 symbol
= Symbol(symbol
.name
)
426 coefficients
.append((symbol
, coefficient
))
428 raise ValueError('non-linear expression: {!r}'.format(expr
))
429 return LinExpr(coefficients
, constant
)
433 Convert the linear expression to a sympy expression.
437 for symbol
, coefficient
in self
.coefficients():
438 term
= coefficient
* sympy
.Symbol(symbol
.name
)
440 expr
+= self
.constant
444 class Symbol(LinExpr
):
446 Symbols are the basic components to build expressions and constraints.
447 They correspond to mathematical variables. Symbols are instances of
448 class LinExpr and inherit its functionalities.
450 Two instances of Symbol are equal if they have the same name.
453 def __new__(cls
, name
):
455 Return a symbol with the name string given in argument.
457 if not isinstance(name
, str):
458 raise TypeError('name must be a string')
459 node
= ast
.parse(name
)
461 name
= node
.body
[0].value
.id
462 except (AttributeError, SyntaxError):
463 raise SyntaxError('invalid syntax')
464 self
= object().__new
__(cls
)
466 self
._coefficients
= {self
: Fraction(1)}
467 self
._constant
= Fraction(0)
468 self
._symbols
= (self
,)
475 The name of the symbol.
480 return hash(self
.sortkey())
484 Return a sorting key for the symbol. It is useful to sort a list of
485 symbols in a consistent order, as comparison functions are overridden
486 (see the documentation of class LinExpr).
488 >>> sort(symbols, key=Symbol.sortkey)
495 def __eq__(self
, other
):
496 return self
.sortkey() == other
.sortkey()
500 Return a new Dummy symbol instance with the same name.
502 return Dummy(self
.name
)
507 def _repr_latex_(self
):
508 return '$${}$$'.format(self
.name
)
511 def fromsympy(cls
, expr
):
513 if isinstance(expr
, sympy
.Dummy
):
514 return Dummy(expr
.name
)
515 elif isinstance(expr
, sympy
.Symbol
):
516 return Symbol(expr
.name
)
518 raise TypeError('expr must be a sympy.Symbol instance')
523 A variation of Symbol in which all symbols are unique and identified by
524 an internal count index. If a name is not supplied then a string value
525 of the count index will be used. This is useful when a unique, temporary
526 variable is needed and the name of the variable used in the expression
529 Unlike Symbol, Dummy instances with the same name are not equal:
532 >>> x1, x2 = Dummy('x'), Dummy('x')
543 def __new__(cls
, name
=None):
545 Return a fresh dummy symbol with the name string given in argument.
548 name
= 'Dummy_{}'.format(Dummy
._count
)
549 elif not isinstance(name
, str):
550 raise TypeError('name must be a string')
551 self
= object().__new
__(cls
)
552 self
._index
= Dummy
._count
553 self
._name
= name
.strip()
554 self
._coefficients
= {self
: Fraction(1)}
555 self
._constant
= Fraction(0)
556 self
._symbols
= (self
,)
562 return hash(self
.sortkey())
565 return self
._name
, self
._index
568 return '_{}'.format(self
.name
)
570 def _repr_latex_(self
):
571 return '$${}_{{{}}}$$'.format(self
.name
, self
._index
)
576 This function returns a tuple of symbols whose names are taken from a comma
577 or whitespace delimited string, or a sequence of strings. It is useful to
578 define several symbols at once.
580 >>> x, y = symbols('x y')
581 >>> x, y = symbols('x, y')
582 >>> x, y = symbols(['x', 'y'])
584 if isinstance(names
, str):
585 names
= names
.replace(',', ' ').split()
586 return tuple(Symbol(name
) for name
in names
)
589 class Rational(LinExpr
, Fraction
):
591 A particular case of linear expressions are rational values, i.e. linear
592 expressions consisting only of a constant term, with no symbol. They are
593 implemented by the Rational class, that inherits from both LinExpr and
594 fractions.Fraction classes.
597 def __new__(cls
, numerator
=0, denominator
=None):
598 self
= object().__new
__(cls
)
599 self
._coefficients
= {}
600 self
._constant
= Fraction(numerator
, denominator
)
603 self
._numerator
= self
._constant
.numerator
604 self
._denominator
= self
._constant
.denominator
608 return Fraction
.__hash
__(self
)
614 def isconstant(self
):
618 return Fraction
.__bool
__(self
)
621 if self
.denominator
== 1:
622 return '{!r}'.format(self
.numerator
)
624 return '{!r}/{!r}'.format(self
.numerator
, self
.denominator
)
626 def _repr_latex_(self
):
627 if self
.denominator
== 1:
628 return '$${}$$'.format(self
.numerator
)
629 elif self
.numerator
< 0:
630 return '$$-\\frac{{{}}}{{{}}}$$'.format(-self
.numerator
,
633 return '$$\\frac{{{}}}{{{}}}$$'.format(self
.numerator
,
637 def fromsympy(cls
, expr
):
639 if isinstance(expr
, sympy
.Rational
):
640 return Rational(expr
.p
, expr
.q
)
641 elif isinstance(expr
, numbers
.Rational
):
642 return Rational(expr
)
644 raise TypeError('expr must be a sympy.Rational instance')