예제 #1
0
 def test_sampler_good_starting_point(self):
     """Sampling from the valid start point in the '2d.txt' fixture returns
     the requested number of points, each of which satisfies the constraint."""
     n_wanted = 100
     constraint = Constraint(os.path.join(TEST_DIR, '2d.txt'))
     points = Sampler(constraint).sample(n_wanted)

     self.assertEqual(len(points), n_wanted)
     # map() is lazy, so all() still stops at the first failing point.
     self.assertTrue(all(map(constraint.apply, points)))
예제 #2
0
 def test_sampler_bad_starting_point(self):
     """Sampling from the invalid start point in the '2d-badstart.txt'
     fixture, where the sampler has to find an edge first, still returns the
     requested number of points, each of which satisfies the constraint."""
     n_wanted = 100
     constraint = Constraint(os.path.join(TEST_DIR, '2d-badstart.txt'))
     points = Sampler(constraint).sample(n_wanted)

     self.assertEqual(len(points), n_wanted)
     # map() is lazy, so all() still stops at the first failing point.
     self.assertTrue(all(map(constraint.apply, points)))
예제 #3
0
def main(input_file, output_file, n_results):
    """Sample points from the constrained problem in INPUT_FILE.

    Draws N_RESULTS points and writes them to OUTPUT_FILE, one point per
    line, with coordinates separated by single spaces.
    """
    sampler = Sampler(Constraint(input_file))
    # int() accepts either an int or a numeric string for N_RESULTS.
    points = sampler.sample(int(n_results))

    # Sampling finishes before the file is opened, so a sampling error
    # leaves OUTPUT_FILE untouched.
    lines = (" ".join(map(str, point)) + '\n' for point in points)
    with open(output_file, 'w') as out:
        out.writelines(lines)
예제 #4
0
    def test_is_valid_point(self):
        """_is_valid_point accepts only points inside the unit hypercube that
        also satisfy the constraints of the '2d.txt' fixture."""
        sampler = Sampler(Constraint(os.path.join(TEST_DIR, '2d.txt')))

        cases = [
            ([0.5, 0.5], True),    # satisfies constraints
            ([0.1, 0.1], False),   # inside cube but violates constraints
            ([-0.1, 0.1], False),  # outside cube: first coordinate < 0
            ([0.1, -0.1], False),  # outside cube: second coordinate < 0
            ([1.1, 0.1], False),   # outside cube: first coordinate > 1
            ([0.1, 1.1], False),   # outside cube: second coordinate > 1
        ]
        for coords, should_be_valid in cases:
            check = self.assertTrue if should_be_valid else self.assertFalse
            check(sampler._is_valid_point(np.array(coords)))