From 2ffea1a47578a1b1d09906d57511062d68e6abea Mon Sep 17 00:00:00 2001 From: Vivien Maisonneuve Date: Wed, 2 Jul 2014 09:00:15 +0200 Subject: [PATCH] New method Expression.subs --- pypol/_islhelper.c | 2 +- pypol/linexprs.py | 23 +++++++++++++++++++++++ pypol/tests/test_linexprs.py | 12 ++++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/pypol/_islhelper.c b/pypol/_islhelper.c index bc62968..eaacc67 100644 --- a/pypol/_islhelper.c +++ b/pypol/_islhelper.c @@ -36,7 +36,7 @@ static PyObject * isl_basic_set_constraints(PyObject *self, PyObject* args) { return NULL; } bset = (isl_basic_set *) ptr; - bset = isl_basic_set_finalize(bset); + bset = isl_basic_set_finalize(bset); // this instruction should not be required n = isl_basic_set_n_constraint(bset); if (n == -1) { PyErr_SetString(PyExc_RuntimeError, diff --git a/pypol/linexprs.py b/pypol/linexprs.py index ed68493..b330045 100644 --- a/pypol/linexprs.py +++ b/pypol/linexprs.py @@ -249,6 +249,29 @@ class Expression: return left / right raise SyntaxError('invalid syntax') + def subs(self, symbol, expression=None): + if expression is None: + if isinstance(symbol, dict): + symbol = symbol.items() + substitutions = symbol + else: + substitutions = [(symbol, expression)] + result = self + for symbol, expression in substitutions: + symbol = symbolname(symbol) + result = result._subs(symbol, expression) + return result + + def _subs(self, symbol, expression): + coefficients = {name: coefficient + for name, coefficient in self.coefficients() + if name != symbol} + constant = self.constant + coefficient = self.coefficient(symbol) + result = Expression(coefficients, self.constant) + result += coefficient * expression + return result + _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()') @classmethod diff --git a/pypol/tests/test_linexprs.py b/pypol/tests/test_linexprs.py index 1606ea0..5862351 100644 --- a/pypol/tests/test_linexprs.py +++ b/pypol/tests/test_linexprs.py @@ -145,6 +145,18 @@ class TestExpression(unittest.TestCase): self.assertEqual((self.x + self.y/2 + self.z/3)._toint(), 6*self.x + 3*self.y + 2*self.z) + def test_subs(self): + self.assertEqual(self.x.subs('x', 3), 3) + self.assertEqual(self.x.subs('x', self.x), self.x) + self.assertEqual(self.x.subs('x', self.y), self.y) + self.assertEqual(self.x.subs('x', self.x + self.y), self.x + self.y) + self.assertEqual(self.x.subs('y', 3), self.x) + self.assertEqual(self.pi.subs('x', 3), self.pi) + self.assertEqual(self.expr.subs('x', -3), -2 * self.y) + self.assertEqual(self.expr.subs([('x', self.y), ('y', self.x)]), 3 - self.x) + self.assertEqual(self.expr.subs({'x': self.y, 'y': self.x}), 3 - self.x) + self.assertEqual(self.expr.subs({self.x: self.y, self.y: self.x}), 3 - self.x) + def test_fromstring(self): self.assertEqual(Expression.fromstring('x'), self.x) self.assertEqual(Expression.fromstring('-x'), -self.x) -- 2.20.1