Raise ValueError when sampling an empty Domain
author Vivien Maisonneuve <v.maisonneuve@gmail.com>
Sat, 5 Jul 2014 14:53:14 +0000 (16:53 +0200)
committer Vivien Maisonneuve <v.maisonneuve@gmail.com>
Sat, 5 Jul 2014 14:53:14 +0000 (16:53 +0200)
pypol/domains.py
pypol/tests/test_domains.py

index 6088e32..1ffed45 100644 (file)
@@ -168,17 +168,17 @@ class Domain:
         return Domain._fromislset(islset, dims)
 
     def sample(self):
-        from .polyhedra import Polyhedron
         islset = self._toislset(self.polyhedra, self.symbols)
         islpoint = libisl.isl_set_sample_point(islset)
+        if bool(libisl.isl_point_is_void(islpoint)):
+            libisl.isl_point_free(islpoint)
+            raise ValueError('domain must be non-empty')
         point = {}
         for index, symbol in enumerate(self.symbols):
             coordinate = libisl.isl_point_get_coordinate_val(islpoint,
                 libisl.isl_dim_set, index)
             coordinate = islhelper.isl_val_to_int(coordinate)
             point[symbol] = coordinate
-        if bool(libisl.isl_point_is_void(islpoint)):
-            point = None
         libisl.isl_point_free(islpoint)
         return point
 
index e4f996c..755547e 100644 (file)
@@ -104,7 +104,8 @@ class TestDomain(unittest.TestCase):
 
     def test_sample(self):
         self.assertEqual(self.square6.sample(), {Symbol('x'): 1, Symbol('y'): 3})
-        self.assertEqual(self.empty.sample(), None)
+        with self.assertRaises(ValueError):
+            self.empty.sample()
         self.assertEqual(self.universe.sample(), {})
 
     def test_intersection(self):