def testVerifyConfig(self):
  """Checks verify_config on random-lattice and explicit-lattice ensembles.

  A 'random' ensemble config must first fail verification because its
  lattices are not fully specified. After set_random_lattice_ensemble it must
  fail on the categorical monotonicity check. It must pass once the feature
  configs have gone through set_categorical_monotonicities. An ensemble with
  explicitly listed lattices must pass verification directly.
  """

  def build_ensemble_config(feature_configs, lattices, num_lattices):
    # Both configs under test share every remaining constructor argument.
    return configs.CalibratedLatticeEnsembleConfig(
        feature_configs=feature_configs,
        lattices=lattices,
        num_lattices=num_lattices,
        lattice_rank=2,
        separate_calibrators=True,
        output_initialization=[-1.0, 1.0])

  random_config = build_ensemble_config(
      copy.deepcopy(unspecified_feature_configs), 'random', 3)

  # With lattices still set to 'random', verification must reject the config.
  with self.assertRaisesRegex(
      ValueError, 'Lattices are not fully specified for ensemble config.'):
    premade_lib.verify_config(random_config)

  # After lattices are assigned, the next failure expected is the categorical
  # monotonicity check on the unspecified feature configs.
  premade_lib.set_random_lattice_ensemble(random_config)
  expected_monotonicity_error = (
      'Element 0 for list/tuple 0 for feature categorical monotonicity is '
      'not an index: 0.0')
  with self.assertRaisesRegex(ValueError, expected_monotonicity_error):
    premade_lib.verify_config(random_config)

  # Running set_categorical_monotonicities on a fresh copy of the feature
  # configs and installing it lets the config pass verification.
  repaired_feature_configs = copy.deepcopy(unspecified_feature_configs)
  premade_lib.set_categorical_monotonicities(repaired_feature_configs)
  random_config.feature_configs = repaired_feature_configs
  premade_lib.verify_config(random_config)

  # A config whose lattices are listed explicitly verifies without fix-ups.
  explicit_config = build_ensemble_config(
      copy.deepcopy(specified_feature_configs),
      [['numerical_1', 'categorical'], ['numerical_2', 'categorical']],
      2)
  premade_lib.verify_config(explicit_config)
def testSetRandomLattices(self):
  """Checks set_random_lattice_ensemble on 'random' and explicit configs.

  For a config with lattices='random', num_lattices=3 and lattice_rank=2, the
  call must produce three lattices of two features each. For a config whose
  lattices are already listed, it must raise a ValueError.
  """

  def build_ensemble_config(feature_configs, lattices, num_lattices):
    # Both configs under test share every remaining constructor argument.
    return configs.CalibratedLatticeEnsembleConfig(
        feature_configs=feature_configs,
        lattices=lattices,
        num_lattices=num_lattices,
        lattice_rank=2,
        separate_calibrators=True,
        output_initialization=[-1.0, 1.0])

  random_config = build_ensemble_config(
      copy.deepcopy(unspecified_feature_configs), 'random', 3)
  premade_lib.set_random_lattice_ensemble(random_config)

  # Expect num_lattices (3) lattices, each holding lattice_rank (2) features.
  self.assertLen(random_config.lattices, 3)
  lattice_sizes = [len(lattice) for lattice in random_config.lattices]
  self.assertListEqual([2, 2, 2], lattice_sizes)

  # Lattices that are already listed must not be overwritten.
  explicit_config = build_ensemble_config(
      copy.deepcopy(specified_feature_configs),
      [['numerical_1', 'categorical'], ['numerical_2', 'categorical']],
      2)
  with self.assertRaisesRegex(
      ValueError, "model_config.lattices must be set to 'random'."):
    premade_lib.set_random_lattice_ensemble(explicit_config)