From 2e558859456a109279713a2cbdd6c48a70a171c6 Mon Sep 17 00:00:00 2001 From: Vivien Maisonneuve Date: Mon, 18 Aug 2014 16:55:29 +0200 Subject: [PATCH] Check symbol names --- linpy/linexprs.py | 7 ++++++- linpy/tests/test_linexprs.py | 11 ++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/linpy/linexprs.py b/linpy/linexprs.py index cf2a980..834c3b4 100644 --- a/linpy/linexprs.py +++ b/linpy/linexprs.py @@ -456,8 +456,13 @@ class Symbol(LinExpr): """ if not isinstance(name, str): raise TypeError('name must be a string') + node = ast.parse(name) + try: + name = node.body[0].value.id + except (AttributeError, SyntaxError): + raise SyntaxError('invalid syntax') self = object().__new__(cls) - self._name = name.strip() + self._name = name self._coefficients = {self: Fraction(1)} self._constant = Fraction(0) self._symbols = (self,) diff --git a/linpy/tests/test_linexprs.py b/linpy/tests/test_linexprs.py index fb7e4a2..9599d06 100644 --- a/linpy/tests/test_linexprs.py +++ b/linpy/tests/test_linexprs.py @@ -214,11 +214,20 @@ class TestSymbol(unittest.TestCase): self.y = Symbol('y') def test_new(self): - self.assertEqual(Symbol(' x '), self.x) + self.assertEqual(Symbol('x'), self.x) with self.assertRaises(TypeError): Symbol(self.x) with self.assertRaises(TypeError): Symbol(1) + with self.assertRaises(SyntaxError): + Symbol('1') + with self.assertRaises(SyntaxError): + Symbol('x.1') + with self.assertRaises(SyntaxError): + Symbol('x 1') + Symbol('_') + Symbol('_x') + Symbol('x_1') def test_name(self): self.assertEqual(self.x.name, 'x') -- 2.20.1