import ast
import functools
import re
import math

from fractions import Fraction

from . import islhelper

from .islhelper import mainctx, libisl
from .linexprs import Expression, Symbol, Rational
from .geometry import GeometricObject, Point, Vector


__all__ = [
    'Domain',
    'And', 'Or', 'Not',
]


@functools.total_ordering
class Domain(GeometricObject):
    """
    A union of polyhedra over a common, ordered tuple of symbols.

    Most operations convert the polyhedra to an isl set, call the
    corresponding libisl function and convert the result back with
    _fromislset().
    """

    __slots__ = (
        '_polyhedra',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, *polyhedra):
        """
        Build a domain from its arguments.

        With a single argument, a string is parsed with fromstring() and a
        GeometricObject is converted with its aspolyhedron() method; any
        other type raises TypeError. With any other number of arguments,
        each one must be a Polyhedron instance (TypeError otherwise) and the
        result is built from their union.
        """
        from .polyhedra import Polyhedron
        if len(polyhedra) == 1:
            argument = polyhedra[0]
            if isinstance(argument, str):
                return cls.fromstring(argument)
            elif isinstance(argument, GeometricObject):
                return argument.aspolyhedron()
            else:
                raise TypeError('argument must be a string '
                    'or a GeometricObject instance')
        else:
            for polyhedron in polyhedra:
                if not isinstance(polyhedron, Polyhedron):
                    raise TypeError('arguments must be Polyhedron instances')
            symbols = cls._xsymbols(polyhedra)
            islset = cls._toislset(polyhedra, symbols)
            return cls._fromislset(islset, symbols)

    @classmethod
    def _xsymbols(cls, iterator):
        """
        Return the ordered tuple of symbols present in iterator.

        Every item of iterator must expose a `symbols` attribute; the union
        of these is sorted with Symbol.sortkey.
        """
        symbols = set()
        for item in iterator:
            symbols.update(item.symbols)
        return tuple(sorted(symbols, key=Symbol.sortkey))

    @property
    def polyhedra(self):
        """
        The tuple of polyhedra stored in this domain.
        """
        return self._polyhedra

    @property
    def symbols(self):
        """
        The tuple of symbols stored in this domain.
        """
        return self._symbols

    @property
    def dimension(self):
        """
        The number of symbols of this domain.
        """
        return self._dimension

    def disjoint(self):
        """
        Return this set rebuilt from the result of isl_set_make_disjoint().
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        # NOTE(review): the isl manual documents isl_set_make_disjoint() as
        # taking only the set; the extra mainctx argument is kept as in the
        # original call -- confirm against the libisl binding in use.
        islset = libisl.isl_set_make_disjoint(mainctx, islset)
        return self._fromislset(islset, self.symbols)

    def isempty(self):
        """
        Return True if this set is empty according to isl_set_is_empty().
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        empty = bool(libisl.isl_set_is_empty(islset))
        libisl.isl_set_free(islset)
        return empty

    def __bool__(self):
        """
        Return True if this set is not empty.
        """
        return not self.isempty()

    def isuniverse(self):
        """
        Return True if isl_set_plain_is_universe() reports this set as the
        universe set.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        universe = bool(libisl.isl_set_plain_is_universe(islset))
        libisl.isl_set_free(islset)
        return universe

    def isbounded(self):
        """
        Return True if this set is bounded according to isl_set_is_bounded().
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        bounded = bool(libisl.isl_set_is_bounded(islset))
        libisl.isl_set_free(islset)
        return bounded

    def __eq__(self, other):
        """
        Return True if the two sets are equal.

        Both sets are expressed over the union of their symbols before being
        compared with isl_set_is_equal().
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        equal = bool(libisl.isl_set_is_equal(islset1, islset2))
        libisl.isl_set_free(islset1)
        libisl.isl_set_free(islset2)
        return equal

    def isdisjoint(self, other):
        """
        Return True if the two sets have an empty intersection, as reported
        by isl_set_is_disjoint().
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = self._toislset(other.polyhedra, symbols)
        disjoint = bool(libisl.isl_set_is_disjoint(islset1, islset2))
        libisl.isl_set_free(islset1)
        libisl.isl_set_free(islset2)
        return disjoint

    def issubset(self, other):
        """
        Return True if this set is a subset of other, as reported by
        isl_set_is_subset().
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = self._toislset(other.polyhedra, symbols)
        subset = bool(libisl.isl_set_is_subset(islset1, islset2))
        libisl.isl_set_free(islset1)
        libisl.isl_set_free(islset2)
        return subset

    def __le__(self, other):
        """
        Return True if this set is a subset of other.
        """
        return self.issubset(other)

    def __lt__(self, other):
        """
        Return True if this set is a strict subset of other, as reported by
        isl_set_is_strict_subset().
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = self._toislset(other.polyhedra, symbols)
        strict = bool(libisl.isl_set_is_strict_subset(islset1, islset2))
        libisl.isl_set_free(islset1)
        libisl.isl_set_free(islset2)
        return strict

    def complement(self):
        """
        Return the complement of this set.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_complement(islset)
        return self._fromislset(islset, self.symbols)

    def __invert__(self):
        """
        Return the complement of this set.
        """
        return self.complement()

    def simplify(self):
        """
        Return this set after isl_set_remove_redundancies() has been applied.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_remove_redundancies(islset)
        return self._fromislset(islset, self.symbols)

    def aspolyhedron(self):
        """
        Return the polyhedral hull of this set as a Polyhedron.
        """
        from .polyhedra import Polyhedron
        islset = self._toislset(self.polyhedra, self.symbols)
        islbset = libisl.isl_set_polyhedral_hull(islset)
        return Polyhedron._fromislbasicset(islbset, self.symbols)

    def asdomain(self):
        """
        Return this domain itself.
        """
        return self

    def project(self, dims):
        """
        Return a new set with the given dimensions projected out.

        dims is a container of symbols; symbols of this set that belong to
        it are removed.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        # Walk the symbols from last to first and project out each run of
        # consecutive symbols to remove with a single isl call.
        n = 0
        for index, symbol in reversed(list(enumerate(self.symbols))):
            if symbol in dims:
                n += 1
            elif n > 0:
                islset = libisl.isl_set_project_out(islset,
                    libisl.isl_dim_set, index + 1, n)
                n = 0
        if n > 0:
            # The run reaches the first symbol.
            islset = libisl.isl_set_project_out(islset,
                libisl.isl_dim_set, 0, n)
        dims = [symbol for symbol in self.symbols if symbol not in dims]
        return Domain._fromislset(islset, dims)

    def sample(self):
        """
        Return one point of this set as a dictionary mapping each symbol to
        its coordinate.

        Raises ValueError if isl returns a void point.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islpoint = libisl.isl_set_sample_point(islset)
        if bool(libisl.isl_point_is_void(islpoint)):
            libisl.isl_point_free(islpoint)
            raise ValueError('domain must be non-empty')
        point = {}
        for index, symbol in enumerate(self.symbols):
            coordinate = libisl.isl_point_get_coordinate_val(islpoint,
                libisl.isl_dim_set, index)
            point[symbol] = islhelper.isl_val_to_int(coordinate)
        libisl.isl_point_free(islpoint)
        return point

    def intersection(self, *others):
        """
        Return the intersection of this set with all the others as a new
        set. Without arguments, return this set itself.
        """
        if not others:
            return self
        symbols = self._xsymbols((self,) + others)
        islset1 = self._toislset(self.polyhedra, symbols)
        for other in others:
            islset2 = other._toislset(other.polyhedra, symbols)
            islset1 = libisl.isl_set_intersect(islset1, islset2)
        return self._fromislset(islset1, symbols)

    def __and__(self, other):
        """
        Return the intersection of two sets as a new set.
        """
        return self.intersection(other)

    def union(self, *others):
        """
        Return the union of this set with all the others as a new set.
        Without arguments, return this set itself.
        """
        if not others:
            return self
        symbols = self._xsymbols((self,) + others)
        islset1 = self._toislset(self.polyhedra, symbols)
        for other in others:
            islset2 = other._toislset(other.polyhedra, symbols)
            islset1 = libisl.isl_set_union(islset1, islset2)
        return self._fromislset(islset1, symbols)

    def __or__(self, other):
        """
        Return the union of two sets as a new set.
        """
        return self.union(other)

    def __add__(self, other):
        """
        Return the union of two sets as a new set.
        """
        return self.union(other)

    def difference(self, other):
        """
        Return the difference of two sets as a new set.
        """
        symbols = self._xsymbols([self, other])
        islset1 = self._toislset(self.polyhedra, symbols)
        islset2 = other._toislset(other.polyhedra, symbols)
        islset = libisl.isl_set_subtract(islset1, islset2)
        return self._fromislset(islset, symbols)

    def __sub__(self, other):
        """
        Return the difference of two sets as a new set.
        """
        return self.difference(other)

    def lexmin(self):
        """
        Return a new set containing the lexicographic minimum of the
        elements in the set.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_lexmin(islset)
        return self._fromislset(islset, self.symbols)

    def lexmax(self):
        """
        Return a new set containing the lexicographic maximum of the
        elements in the set.
        """
        islset = self._toislset(self.polyhedra, self.symbols)
        islset = libisl.isl_set_lexmax(islset)
        return self._fromislset(islset, self.symbols)

    def num_parameters(self):
        """
        Return the value of isl_basic_set_dim() for the isl_dim_set
        dimension type.

        This method reads self.equalities and self.inequalities, which are
        not defined by Domain itself; presumably it is only usable on
        subclasses providing them (e.g. Polyhedron) -- verify.
        """
        islbset = self._toislbasicset(self.equalities, self.inequalities,
            self.symbols)
        # NOTE(review): islbset is not released after the query -- confirm
        # whether it should be freed here.
        num = libisl.isl_basic_set_dim(islbset, libisl.isl_dim_set)
        return num

    def involves_dims(self, dims):
        """
        Return True if isl_set_involves_dims() reports that this set depends
        on the given dimensions, and False when dims is empty.
        """
        dims = sorted(dims)
        if not dims:
            # Nothing to check; return before any isl object is allocated so
            # that there is nothing to release.
            return False
        symbols = sorted(self.symbols)
        # NOTE(review): `first` is always derived from dims[0] (or reset to
        # 0) and is an index into the re-sorted symbols rather than into
        # self.symbols; kept as in the original -- confirm intent.
        n = 0
        for dim in dims:
            if dim in symbols:
                first = symbols.index(dims[0])
                n += 1
            else:
                first = 0
        islset = self._toislset(self.polyhedra, self.symbols)
        value = bool(libisl.isl_set_involves_dims(islset,
            libisl.isl_dim_set, first, n))
        libisl.isl_set_free(islset)
        return value

    # Matches one coordinate of the form '(num)' or '(num)/den'; the groups
    # are read by name in vertices().
    _RE_COORDINATE = re.compile(r'\((?P<num>\-?\d+)\)(/(?P<den>\d+))?')

    def vertices(self):
        """
        Return the list of vertices of this set as Point instances.

        This method reads self.equalities and self.inequalities, which are
        not defined by Domain itself; presumably it is only usable on
        subclasses providing them (e.g. Polyhedron) -- verify.
        """
        islbset = self._toislbasicset(self.equalities, self.inequalities,
            self.symbols)
        vertices = libisl.isl_basic_set_compute_vertices(islbset)
        vertices = islhelper.isl_vertices_vertices(vertices)
        points = []
        for vertex in vertices:
            expr = libisl.isl_vertex_get_expr(vertex)
            coordinates = []
            # NOTE(review): if isl_version is a string this comparison is
            # lexicographic (e.g. '0.9' < '0.13' is False) -- confirm.
            if islhelper.isl_version < '0.13':
                # Here the vertex expression is read as a set of
                # constraints, each giving the coordinate of the symbols
                # with a non-zero coefficient.
                constraints = islhelper.isl_basic_set_constraints(expr)
                for constraint in constraints:
                    constant = libisl.isl_constraint_get_constant_val(
                        constraint)
                    constant = islhelper.isl_val_to_int(constant)
                    for index, symbol in enumerate(self.symbols):
                        coefficient = \
                            libisl.isl_constraint_get_coefficient_val(
                                constraint, libisl.isl_dim_set, index)
                        coefficient = islhelper.isl_val_to_int(coefficient)
                        if coefficient != 0:
                            coordinate = -Fraction(constant, coefficient)
                            coordinates.append((symbol, coordinate))
            else:
                # Here the vertex expression is printed to a string and the
                # coordinates are extracted from it in symbol order.
                string = islhelper.isl_multi_aff_to_str(expr)
                matches = self._RE_COORDINATE.finditer(string)
                for symbol, match in zip(self.symbols, matches):
                    numerator = int(match.group('num'))
                    denominator = match.group('den')
                    denominator = 1 if denominator is None \
                        else int(denominator)
                    coordinate = Fraction(numerator, denominator)
                    coordinates.append((symbol, coordinate))
            points.append(Point(coordinates))
        return points

    def points(self):
        """
        Return the list of points contained in this set.

        Raises ValueError if the set is not bounded.
        """
        if not self.isbounded():
            raise ValueError('domain must be bounded')
        islset = self._toislset(self.polyhedra, self.symbols)
        islpoints = islhelper.isl_set_points(islset)
        points = []
        for islpoint in islpoints:
            coordinates = {}
            for index, symbol in enumerate(self.symbols):
                coordinate = libisl.isl_point_get_coordinate_val(islpoint,
                    libisl.isl_dim_set, index)
                coordinates[symbol] = islhelper.isl_val_to_int(coordinate)
            points.append(Point(coordinates))
        return points

    @classmethod
    def _polygon_inner_point(cls, points):
        """
        Return the point whose coordinates are the averages of the
        coordinates of the given points.
        """
        symbols = points[0].symbols
        coordinates = {symbol: 0 for symbol in symbols}
        for point in points:
            for symbol, coordinate in point.coordinates():
                coordinates[symbol] += coordinate
        for symbol in symbols:
            coordinates[symbol] /= len(points)
        return Point(coordinates)

    @classmethod
    def _sort_polygon_2d(cls, points):
        """
        Return the points sorted by the angle of the vector going from
        their average point to each of them. Lists of at most three points
        are returned unchanged.
        """
        if len(points) <= 3:
            return points
        o = cls._polygon_inner_point(points)
        angles = {}
        for m in points:
            om = Vector(o, m)
            dx, dy = (coordinate for symbol, coordinate in om.coordinates())
            angles[m] = math.atan2(dy, dx)
        return sorted(points, key=angles.get)

    @classmethod
    def _sort_polygon_3d(cls, points):
        """
        Return the points sorted by their signed angle around their average
        point, measured from the first point. Lists of at most three points
        are returned unchanged.

        Raises ValueError if the cross product of the first point's vector
        with every other point's vector is null.
        """
        if len(points) <= 3:
            return points
        o = cls._polygon_inner_point(points)
        a = points[0]
        oa = Vector(o, a)
        norm_oa = oa.norm()
        # Find a unit vector u from the first non-null cross product; it is
        # used below to give a sign to the angles.
        for b in points[1:]:
            ob = Vector(o, b)
            u = oa.cross(ob)
            if not u.isnull():
                u = u.asunit()
                break
        else:
            raise ValueError('degenerate polygon')
        angles = {a: 0.}
        for m in points[1:]:
            om = Vector(o, m)
            normprod = norm_oa * om.norm()
            # Clamp to [-1, 1]: rounding can push the ratio slightly out of
            # the domain of acos().
            cosinus = min(max(oa.dot(om) / normprod, -1.), 1.)
            sinus = u.dot(oa.cross(om)) / normprod
            angle = math.acos(cosinus)
            angles[m] = math.copysign(angle, sinus)
        return sorted(points, key=angles.get)

    def faces(self):
        """
        Return a list of faces, each face being the list of vertices of a
        polyhedron for which one of its constraints evaluates to zero. Only
        faces with at least three vertices are kept.
        """
        faces = []
        for polyhedron in self.polyhedra:
            vertices = polyhedron.vertices()
            for constraint in polyhedron.constraints:
                face = [vertex for vertex in vertices
                    if constraint.subs(vertex.coordinates()) == 0]
                if len(face) >= 3:
                    faces.append(face)
        return faces

    def _plot_2d(self, plot=None, **kwargs):
        """
        Add one matplotlib Polygon patch per polyhedron to plot and return
        it. A new figure with a single subplot is created when plot is
        None. The axis limits are extended to contain every vertex;
        kwargs are forwarded to Polygon.
        """
        import matplotlib.pyplot as plt
        from matplotlib.patches import Polygon
        if plot is None:
            fig = plt.figure()
            plot = fig.add_subplot(1, 1, 1)
        xmin, xmax = plot.get_xlim()
        ymin, ymax = plot.get_ylim()
        for polyhedron in self.polyhedra:
            vertices = polyhedron._sort_polygon_2d(polyhedron.vertices())
            xys = [tuple(vertex.values()) for vertex in vertices]
            xs, ys = zip(*xys)
            xmin, xmax = min(xmin, float(min(xs))), max(xmax, float(max(xs)))
            ymin, ymax = min(ymin, float(min(ys))), max(ymax, float(max(ys)))
            plot.add_patch(Polygon(xys, closed=True, **kwargs))
        plot.set_xlim(xmin, xmax)
        plot.set_ylim(ymin, ymax)
        return plot

    def _plot_3d(self, plot=None, **kwargs):
        """
        Add a Poly3DCollection made of the faces of this set to plot and
        return it. A new figure with Axes3D is created when plot is None.
        The axis limits are extended to contain every vertex; kwargs are
        forwarded to Poly3DCollection.
        """
        import matplotlib.pyplot as plt
        from mpl_toolkits.mplot3d import Axes3D
        from mpl_toolkits.mplot3d.art3d import Poly3DCollection
        if plot is None:
            fig = plt.figure()
            axes = Axes3D(fig)
        else:
            axes = plot
        xmin, xmax = axes.get_xlim()
        ymin, ymax = axes.get_ylim()
        zmin, zmax = axes.get_zlim()
        poly_xyzs = []
        for vertices in self.faces():
            vertices = self._sort_polygon_3d(vertices)
            # Repeat the first vertex at the end of the face.
            vertices.append(vertices[0])
            face_xyzs = [tuple(vertex.values()) for vertex in vertices]
            xs, ys, zs = zip(*face_xyzs)
            xmin, xmax = min(xmin, float(min(xs))), max(xmax, float(max(xs)))
            ymin, ymax = min(ymin, float(min(ys))), max(ymax, float(max(ys)))
            zmin, zmax = min(zmin, float(min(zs))), max(zmax, float(max(zs)))
            poly_xyzs.append(face_xyzs)
        collection = Poly3DCollection(poly_xyzs, **kwargs)
        axes.add_collection3d(collection)
        axes.set_xlim(xmin, xmax)
        axes.set_ylim(ymin, ymax)
        axes.set_zlim(zmin, zmax)
        return axes

    def plot(self, plot=None, **kwargs):
        """
        Display a plot of this set and return the matplotlib object used.

        Raises ValueError if the set is not bounded or if its dimension is
        neither 2 nor 3.
        """
        if not self.isbounded():
            raise ValueError('domain must be bounded')
        elif self.dimension == 2:
            return self._plot_2d(plot=plot, **kwargs)
        elif self.dimension == 3:
            return self._plot_3d(plot=plot, **kwargs)
        else:
            raise ValueError('polyhedron must be 2 or 3-dimensional')

    def __contains__(self, point):
        """
        Return True if point belongs to at least one polyhedron of this set.
        """
        return any(point in polyhedron for polyhedron in self.polyhedra)

    def subs(self, symbol, expression=None):
        """
        Return the domain built from the result of calling subs() with the
        same arguments on every polyhedron of this set.
        """
        polyhedra = [polyhedron.subs(symbol, expression)
            for polyhedron in self.polyhedra]
        return Domain(*polyhedra)

    @classmethod
    def _fromislset(cls, islset, symbols):
        """
        Convert an isl set to a geometric object and free the isl set.

        Return Empty when the set has no basic set, the polyhedron itself
        when it has exactly one, and a Domain holding all the polyhedra
        otherwise.
        """
        from .polyhedra import Polyhedron
        islset = libisl.isl_set_remove_divs(islset)
        islbsets = islhelper.isl_set_basic_sets(islset)
        libisl.isl_set_free(islset)
        polyhedra = [Polyhedron._fromislbasicset(islbset, symbols)
            for islbset in islbsets]
        if len(polyhedra) == 0:
            from .polyhedra import Empty
            return Empty
        elif len(polyhedra) == 1:
            return polyhedra[0]
        else:
            # Bypass Domain.__new__(), which expects strings or polyhedra
            # and would convert them again.
            self = object.__new__(Domain)
            self._polyhedra = tuple(polyhedra)
            self._symbols = cls._xsymbols(polyhedra)
            self._dimension = len(self._symbols)
            return self

    @classmethod
    def _toislset(cls, polyhedra, symbols):
        """
        Return the isl set made of the union of the given polyhedra,
        expressed over symbols. polyhedra must not be empty.
        """
        polyhedron = polyhedra[0]
        islbset = polyhedron._toislbasicset(polyhedron.equalities,
            polyhedron.inequalities, symbols)
        islset1 = libisl.isl_set_from_basic_set(islbset)
        for polyhedron in polyhedra[1:]:
            islbset = polyhedron._toislbasicset(polyhedron.equalities,
                polyhedron.inequalities, symbols)
            islset2 = libisl.isl_set_from_basic_set(islbset)
            islset1 = libisl.isl_set_union(islset1, islset2)
        return islset1

    @classmethod
    def _fromast(cls, node):
        """
        Convert a Python AST node to a domain.

        Supported nodes are a module with a single statement, an expression
        statement, the unary operator ~, the binary operators & and |, and
        chained comparisons using <, <=, ==, >= and >. Anything else raises
        SyntaxError.
        """
        from .polyhedra import Polyhedron
        if isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.UnaryOp):
            domain = cls._fromast(node.operand)
            # The operator of a UnaryOp node is stored in `op`, and the AST
            # class for ~ is ast.Invert.
            if isinstance(node.op, ast.Invert):
                return Not(domain)
        elif isinstance(node, ast.BinOp):
            domain1 = cls._fromast(node.left)
            domain2 = cls._fromast(node.right)
            if isinstance(node.op, ast.BitAnd):
                return And(domain1, domain2)
            elif isinstance(node.op, ast.BitOr):
                return Or(domain1, domain2)
        elif isinstance(node, ast.Compare):
            equalities = []
            inequalities = []
            left = Expression._fromast(node.left)
            for op, comparator in zip(node.ops, node.comparators):
                right = Expression._fromast(comparator)
                if isinstance(op, ast.Lt):
                    inequalities.append(right - left - 1)
                elif isinstance(op, ast.LtE):
                    inequalities.append(right - left)
                elif isinstance(op, ast.Eq):
                    equalities.append(left - right)
                elif isinstance(op, ast.GtE):
                    inequalities.append(left - right)
                elif isinstance(op, ast.Gt):
                    inequalities.append(left - right - 1)
                else:
                    # Unsupported operator: fall through to SyntaxError.
                    break
                left = right
            else:
                return Polyhedron(equalities, inequalities)
        raise SyntaxError('invalid syntax')

    _RE_BRACES = re.compile(r'^\{\s*|\s*\}$')
    _RE_EQ = re.compile(r'([^<=>])=([^<=>])')
    _RE_AND = re.compile(r'\band\b|,|&&|/\\|∧|∩')
    _RE_OR = re.compile(r'\bor\b|;|\|\||\\/|∨|∪')
    _RE_NOT = re.compile(r'\bnot\b|!|¬')
    _RE_NUM_VAR = Expression._RE_NUM_VAR
    _RE_OPERATORS = re.compile(r'(&|\||~)')

    @classmethod
    def fromstring(cls, string):
        """
        Parse a string into a domain.

        The string is rewritten into a Python expression using the
        operators &, | and ~, parsed with ast.parse() and converted with
        _fromast().
        """
        # remove curly brackets
        string = cls._RE_BRACES.sub(r'', string)
        # replace '=' by '=='
        string = cls._RE_EQ.sub(r'\1==\2', string)
        # replace 'and', 'or', 'not'
        string = cls._RE_AND.sub(r' & ', string)
        string = cls._RE_OR.sub(r' | ', string)
        string = cls._RE_NOT.sub(r' ~', string)
        # add implicit multiplication operators, e.g. '5x' -> '5*x'
        string = cls._RE_NUM_VAR.sub(r'\1*\2', string)
        # add parentheses to force precedence: the split keeps the operators
        # at odd indices, so the operands are the items at even indices
        tokens = cls._RE_OPERATORS.split(string)
        for i, token in enumerate(tokens):
            if i % 2 == 0:
                tokens[i] = '({})'.format(token)
        string = ''.join(tokens)
        # The second positional argument of ast.parse() is the file name,
        # not the mode, so this yields an ast.Module node, which _fromast()
        # handles.
        tree = ast.parse(string, 'eval')
        return cls._fromast(tree)

    def __repr__(self):
        assert len(self.polyhedra) >= 2
        strings = [repr(polyhedron) for polyhedron in self.polyhedra]
        return 'Or({})'.format(', '.join(strings))

    def _repr_latex_(self):
        """
        Return the LaTeX representations of the polyhedra, each one in
        parentheses, joined with \\vee and enclosed in '$'.
        """
        strings = ['({})'.format(polyhedron._repr_latex_().strip('$'))
            for polyhedron in self.polyhedra]
        return '${}$'.format(' \\vee '.join(strings))

    @classmethod
    def fromsympy(cls, expr):
        """
        Convert a sympy expression to a domain.

        Logical and relational sympy functions are mapped to their
        counterparts of this package and applied to the converted
        arguments; other sympy.Expr instances are converted with
        Expression.fromsympy(). Anything else raises ValueError.
        """
        import sympy
        from .polyhedra import Lt, Le, Eq, Ne, Ge, Gt
        funcmap = {
            sympy.And: And, sympy.Or: Or, sympy.Not: Not,
            sympy.Lt: Lt, sympy.Le: Le,
            sympy.Eq: Eq, sympy.Ne: Ne,
            sympy.Ge: Ge, sympy.Gt: Gt,
        }
        if expr.func in funcmap:
            args = [Domain.fromsympy(arg) for arg in expr.args]
            return funcmap[expr.func](*args)
        elif isinstance(expr, sympy.Expr):
            return Expression.fromsympy(expr)
        raise ValueError('non-domain expression: {!r}'.format(expr))

    def tosympy(self):
        """
        Return a sympy.Or of the sympy conversions of the polyhedra of this
        set.
        """
        import sympy
        polyhedra = [polyhedron.tosympy() for polyhedron in self.polyhedra]
        return sympy.Or(*polyhedra)


def And(*domains):
    """
    Return the intersection of the given sets as a new set, or Universe
    when called without arguments.
    """
    if len(domains) == 0:
        from .polyhedra import Universe
        return Universe
    else:
        return domains[0].intersection(*domains[1:])


def Or(*domains):
    """
    Return the union of the given sets as a new set, or Empty when called
    without arguments.
    """
    if len(domains) == 0:
        from .polyhedra import Empty
        return Empty
    else:
        return domains[0].union(*domains[1:])


def Not(domain):
    """
    Return the complement of the given set.
    """
    return ~domain