예제 #1
0
 def testFullFactorial(self):
     """Every combination of the parameter values is generated, each with weight 1."""
     factorial_gen = FullFactorialGenerator()
     param_choices = [[1, 2], ["foo", "bar"]]
     # n must be -1 for this generator; the third return value is not checked.
     points, point_weights, _ = factorial_gen.gen(
         n=-1,
         parameter_values=param_choices,
         objective_weights=np.ones(1),
     )
     want = [[1, "foo"], [1, "bar"], [2, "foo"], [2, "bar"]]
     self.assertEqual(points, want)
     self.assertEqual(point_weights, [1] * len(want))
예제 #2
0
 def testFullFactorialFixedFeatures(self):
     """Fixing a feature restricts the design to points carrying that value."""
     factorial_gen = FullFactorialGenerator(max_cardinality=5, check_cardinality=True)
     points, point_weights, _ = factorial_gen.gen(
         n=-1,
         parameter_values=[[1, 2], ["foo", "bar"]],
         objective_weights=np.ones(1),
         # Pin the second parameter (index 1) to "foo".
         fixed_features={1: "foo"},
     )
     want = [[1, "foo"], [2, "foo"]]
     self.assertEqual(points, want)
     self.assertEqual(point_weights, [1] * len(want))
예제 #3
0
파일: factory.py 프로젝트: zorrock/Ax
def get_factorial(search_space: SearchSpace) -> DiscreteModelBridge:
    """Build a ``DiscreteModelBridge`` driven by a full-factorial generator.

    Args:
        search_space: Search space the model bridge is built over.

    Returns:
        A ``DiscreteModelBridge`` wrapping a fresh ``FullFactorialGenerator``,
        constructed with an empty ``Data()`` and the ``Discrete_X_trans``
        transforms.
    """
    empty_data = Data()
    factorial_model = FullFactorialGenerator()
    return DiscreteModelBridge(
        search_space=search_space,
        data=empty_data,
        model=factorial_model,
        transforms=Discrete_X_trans,
    )
예제 #4
0
def get_modelbridge(
    mock_gen_arms,
    mock_observations_from_data,
    status_quo_name: Optional[str] = None,
) -> ModelBridge:
    """Build a ``ModelBridge`` test fixture whose ``_predict`` is stubbed out.

    Args:
        mock_gen_arms: Not used in the body. NOTE(review): presumably injected
            by a ``mock.patch`` decorator on this function that is not visible
            here -- confirm before removing.
        mock_observations_from_data: Not used in the body; same assumption as
            ``mock_gen_arms``.
        status_quo_name: Optional status-quo arm name forwarded to
            ``ModelBridge``.

    Returns:
        A ``ModelBridge`` built from ``get_search_space()``, ``get_experiment()``
        and ``get_data()`` with a ``FullFactorialGenerator`` model, whose
        ``_predict`` always returns ``[get_observation().data]``.
    """
    exp = get_experiment()
    modelbridge = ModelBridge(
        search_space=get_search_space(),
        model=FullFactorialGenerator(),
        experiment=exp,
        data=get_data(),
        status_quo_name=status_quo_name,
    )
    # MagicMock's first positional argument is ``spec`` and it has no
    # ``autospec`` option (that belongs to ``mock.patch``/``create_autospec``).
    # Passing the target path positionally would spec the mock against a str,
    # so it is given as ``name`` purely for readable reprs.
    modelbridge._predict = mock.MagicMock(
        name="ax.modelbridge.base.ModelBridge._predict",
        return_value=[get_observation().data],
    )
    return modelbridge
예제 #5
0
    def testFullFactorialValidation(self):
        """``gen`` rejects over-cardinality designs and any ``n`` other than -1."""
        # Raise error because cardinality exceeds max cardinality
        # (2 * 2 * 2 = 8 candidate points > max_cardinality=5).
        generator = FullFactorialGenerator(max_cardinality=5, check_cardinality=True)
        parameter_values = [[1, 2], ["foo", "bar"], [True, False]]
        # gen() is called without unpacking its result: a tuple-unpacking
        # mismatch also raises ValueError and would make the check pass
        # even if gen() itself stopped raising.
        with self.assertRaises(ValueError):
            generator.gen(
                n=-1, parameter_values=parameter_values, objective_weights=np.ones(1)
            )

        # Raise error because n != -1
        generator = FullFactorialGenerator()
        parameter_values = [[1, 2], ["foo", "bar"]]
        with self.assertRaises(ValueError):
            generator.gen(
                n=5, parameter_values=parameter_values, objective_weights=np.ones(1)
            )