def __init__(self, name, **kwargs):
    """Build a block around the primitive registered under ``name``.

    Loads the primitive's metadata, imports the object it points to, reads
    the ``fit`` / ``produce`` specifications, splits ``kwargs`` into
    hyperparameter groups and finally applies the defaults of every tunable
    hyperparameter through ``set_hyperparameters``.

    Args:
        name: Name passed to ``load_primitive`` to look up the metadata.
        **kwargs: Forwarded to ``_extract_params`` together with the
            metadata's ``hyperparameters`` section (presumably user-supplied
            hyperparameter values — confirm against ``_extract_params``).

    Raises:
        KeyError: If the metadata lacks ``primitive``, ``produce``,
            ``produce['args']`` or ``produce['output']``, or if a tunable
            hyperparameter has no ``default``.
        Whatever ``load_primitive`` raises for an unknown name
            (presumably ``ValueError`` — confirm).
    """
    self.name = name

    spec = load_primitive(name)
    self.primitive = import_object(spec['primitive'])

    # The "fit" section is optional; a missing one behaves like an empty spec.
    fit_spec = spec.get('fit', dict())
    self._fit = fit_spec
    self.fit_args = fit_spec.get('args', [])
    self.fit_method = fit_spec.get('method')

    # The "produce" section is mandatory, as are its "args" and "output" keys.
    produce_spec = spec['produce']
    self._produce = produce_spec
    self.produce_args = produce_spec['args']
    self.produce_output = produce_spec['output']
    self.produce_method = produce_spec.get('method')

    # True only when a produce method name is given. Presumably this means
    # the primitive is a class rather than a plain function — confirm.
    self._class = bool(self.produce_method)

    hyperparams_spec = spec.get('hyperparameters', dict())
    extracted = self._extract_params(kwargs, hyperparams_spec)
    init_params, fit_params, produce_params = extracted
    self._hyperparameters = init_params
    self._fit_params = fit_params
    self._produce_params = produce_params

    self._tunable = self._get_tunable(hyperparams_spec, init_params)

    # Every tunable hyperparameter must declare a 'default'; a missing one
    # raises KeyError here.
    tunable_defaults = dict()
    for param_name, param_spec in self._tunable.items():
        # TODO: support undefined defaults
        tunable_defaults[param_name] = param_spec['default']

    self.set_hyperparameters(tunable_defaults)
def test_load_primitive_success():
    """A primitive JSON file placed in a registered path is loaded verbatim."""
    expected = {
        'name': 'temp.primitive',
        'primitive': 'temp.primitive',
    }

    with tempfile.TemporaryDirectory() as tempdir:
        # NOTE(review): nothing here unregisters tempdir after the test. If
        # add_primitives_path mutates module-level state, later tests will
        # see a path that no longer exists. Confirm, and clean up if needed.
        primitives.add_primitives_path(tempdir)

        json_path = os.path.join(tempdir, 'temp.primitive.json')
        with open(json_path, 'w') as handle:
            handle.write(json.dumps(expected, indent=4))

        # Load only after the file is closed (flushed) and while the
        # temporary directory still exists.
        loaded = primitives.load_primitive('temp.primitive')
        assert loaded == expected
def test_load_primitive_value_error():
    """Asking for an unknown primitive name makes load_primitive raise ValueError."""
    unknown_name = 'invalid.primitive'
    with pytest.raises(ValueError):
        primitives.load_primitive(unknown_name)