X-Git-Url: https://scm.cri.mines-paristech.fr/git/linpy.git/blobdiff_plain/1d494bb187b70135df721c13306d7f26fdf33f50..10808766a204fcc854ae30fe471ada80bab1f60f:/pypol/polyhedra.py diff --git a/pypol/polyhedra.py b/pypol/polyhedra.py index 787e965..ac67cf8 100644 --- a/pypol/polyhedra.py +++ b/pypol/polyhedra.py @@ -1,7 +1,5 @@ -import ast import functools import numbers -import re from . import islhelper @@ -93,12 +91,11 @@ class Polyhedron(Domain): equalities = [] inequalities = [] for islconstraint in islconstraints: - islpr = libisl.isl_printer_to_str(mainctx) constant = libisl.isl_constraint_get_constant_val(islconstraint) constant = islhelper.isl_val_to_int(constant) coefficients = {} - for dim, symbol in enumerate(symbols): - coefficient = libisl.isl_constraint_get_coefficient_val(islconstraint, libisl.isl_dim_set, dim) + for index, symbol in enumerate(symbols): + coefficient = libisl.isl_constraint_get_coefficient_val(islconstraint, libisl.isl_dim_set, index) coefficient = islhelper.isl_val_to_int(coefficient) if coefficient != 0: coefficients[symbol] = coefficient @@ -119,85 +116,44 @@ class Polyhedron(Domain): @classmethod def _toislbasicset(cls, equalities, inequalities, symbols): dimension = len(symbols) + indices = {symbol: index for index, symbol in enumerate(symbols)} islsp = libisl.isl_space_set_alloc(mainctx, 0, dimension) islbset = libisl.isl_basic_set_universe(libisl.isl_space_copy(islsp)) islls = libisl.isl_local_space_from_space(islsp) for equality in equalities: isleq = libisl.isl_equality_alloc(libisl.isl_local_space_copy(islls)) for symbol, coefficient in equality.coefficients(): - val = str(coefficient).encode() - val = libisl.isl_val_read_from_str(mainctx, val) - sid = symbols.index(symbol) + islval = str(coefficient).encode() + islval = libisl.isl_val_read_from_str(mainctx, islval) + index = indices[symbol] isleq = libisl.isl_constraint_set_coefficient_val(isleq, - libisl.isl_dim_set, sid, val) + libisl.isl_dim_set, index, islval) if equality.constant != 0: - val = str(equality.constant).encode() - val = libisl.isl_val_read_from_str(mainctx, val) - isleq = libisl.isl_constraint_set_constant_val(isleq, val) + islval = str(equality.constant).encode() + islval = libisl.isl_val_read_from_str(mainctx, islval) + isleq = libisl.isl_constraint_set_constant_val(isleq, islval) islbset = libisl.isl_basic_set_add_constraint(islbset, isleq) for inequality in inequalities: islin = libisl.isl_inequality_alloc(libisl.isl_local_space_copy(islls)) for symbol, coefficient in inequality.coefficients(): - val = str(coefficient).encode() - val = libisl.isl_val_read_from_str(mainctx, val) - sid = symbols.index(symbol) + islval = str(coefficient).encode() + islval = libisl.isl_val_read_from_str(mainctx, islval) + index = indices[symbol] islin = libisl.isl_constraint_set_coefficient_val(islin, - libisl.isl_dim_set, sid, val) + libisl.isl_dim_set, index, islval) if inequality.constant != 0: - val = str(inequality.constant).encode() - val = libisl.isl_val_read_from_str(mainctx, val) - islin = libisl.isl_constraint_set_constant_val(islin, val) + islval = str(inequality.constant).encode() + islval = libisl.isl_val_read_from_str(mainctx, islval) + islin = libisl.isl_constraint_set_constant_val(islin, islval) islbset = libisl.isl_basic_set_add_constraint(islbset, islin) return islbset - @classmethod - def _fromast(cls, node): - if isinstance(node, ast.Module) and len(node.body) == 1: - return cls._fromast(node.body[0]) - elif isinstance(node, ast.Expr): - return cls._fromast(node.value) - elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitAnd): - equalities1, inequalities1 = cls._fromast(node.left) - equalities2, inequalities2 = cls._fromast(node.right) - equalities = equalities1 + equalities2 - inequalities = inequalities1 + inequalities2 - return equalities, inequalities - elif isinstance(node, ast.Compare): - equalities = [] - inequalities = [] - left = Expression._fromast(node.left) - for i in range(len(node.ops)): - op = node.ops[i] - right = Expression._fromast(node.comparators[i]) - if isinstance(op, ast.Lt): - inequalities.append(right - left - 1) - elif isinstance(op, ast.LtE): - inequalities.append(right - left) - elif isinstance(op, ast.Eq): - equalities.append(left - right) - elif isinstance(op, ast.GtE): - inequalities.append(left - right) - elif isinstance(op, ast.Gt): - inequalities.append(left - right - 1) - else: - break - left = right - else: - return equalities, inequalities - raise SyntaxError('invalid syntax') - @classmethod def fromstring(cls, string): - string = string.strip() - string = re.sub(r'^\{\s*|\s*\}$', '', string) - string = re.sub(r'([^<=>])=([^<=>])', r'\1==\2', string) - string = re.sub(r'(\d+|\))\s*([^\W\d_]\w*|\()', r'\1*\2', string) - tokens = re.split(r',|;|and|&&|/\\|∧', string, flags=re.I) - tokens = ['({})'.format(token) for token in tokens] - string = ' & '.join(tokens) - tree = ast.parse(string, 'eval') - equalities, inequalities = cls._fromast(tree) - return cls(equalities, inequalities) + domain = Domain.fromstring(string) + if not isinstance(domain, Polyhedron): + raise ValueError('non-polyhedral expression: {!r}'.format(string)) + return domain def __repr__(self): if self.isempty():