def __call__(
    self,
    search_space: Optional[SearchSpace] = None,
    experiment: Optional[Experiment] = None,
    data: Optional[Data] = None,
    silently_filter_kwargs: bool = True,  # TODO[Lena]: default to False
    **kwargs: Any,
) -> ModelBridge:
    """Instantiate the model registered under ``self.value`` and wrap it in
    its ``ModelBridge``.

    Args:
        search_space: Search space handed to the bridge. When not given
            (falsy), the experiment's search space is used instead.
        experiment: Experiment handed to the bridge.
        data: Data handed to the bridge.
        silently_filter_kwargs: When False, ``kwargs`` are first checked with
            ``validate_kwarg_typing`` against the model and bridge classes.
        **kwargs: Extra arguments routed to the model and/or bridge
            constructors. ``consolidate_kwargs`` receives each constructor's
            argument names as ``keywords``, so presumably only matching keys
            are forwarded.

    Returns:
        The constructed model bridge. The model and bridge kwargs (with
        callables encoded via ``_encode_callables_as_references``) are
        recorded on it through ``_set_kwargs_to_save``.

    Raises:
        AssertionError: If ``self.value`` has no registered model setup, or
            if neither ``search_space`` nor ``experiment`` is given (unless
            assertions are disabled with ``-O``).
    """
    assert self.value in MODEL_KEY_TO_MODEL_SETUP, f"Unknown model {self.value}"
    # Every bridge needs a search space, given directly or via the experiment.
    assert search_space or experiment, "Search space or experiment required."

    setup = MODEL_KEY_TO_MODEL_SETUP[self.value]
    model_cls, bridge_cls = setup.model_class, setup.bridge_class

    if not silently_filter_kwargs:
        validate_kwarg_typing(  # TODO[Lena]: T46467254, pragma: no cover
            typed_callables=[model_cls, bridge_cls],
            search_space=search_space,
            experiment=experiment,
            data=data,
            **kwargs,
        )

    # Model arguments are layered as: constructor defaults, then caller
    # kwargs (presumably later layers override earlier ones).
    model_sources = [get_function_default_arguments(model_cls), kwargs]
    model_kwargs = consolidate_kwargs(
        kwargs_iterable=model_sources,
        keywords=get_function_argument_names(model_cls),
    )
    model = model_cls(**model_kwargs)

    # Bridge arguments are layered as: constructor defaults, the setup's
    # standard kwargs, the setup's transforms, then caller kwargs.
    bridge_sources = [
        get_function_default_arguments(bridge_cls),
        setup.standard_bridge_kwargs,
        {"transforms": setup.transforms},
        kwargs,
    ]
    # `experiment`, `search_space` and `data` are passed explicitly below,
    # so they are left out of the gathered bridge keywords.
    bridge_keywords = get_function_argument_names(
        function=bridge_cls, omit=["experiment", "search_space", "data"]
    )
    bridge_kwargs = consolidate_kwargs(
        kwargs_iterable=bridge_sources,
        keywords=bridge_keywords,
    )

    bridge_search_space = search_space or not_none(experiment).search_space
    model_bridge = bridge_cls(
        search_space=bridge_search_space,
        experiment=experiment,
        data=data,
        model=model,
        **bridge_kwargs,
    )

    # Record the construction kwargs on the bridge, to be saved on the
    # generator run.
    encoded_model_kwargs = _encode_callables_as_references(model_kwargs)
    encoded_bridge_kwargs = _encode_callables_as_references(bridge_kwargs)
    model_bridge._set_kwargs_to_save(
        model_key=self.value,
        model_kwargs=encoded_model_kwargs,
        bridge_kwargs=encoded_bridge_kwargs,
    )
    return model_bridge
def __call__(
    self,
    search_space: Optional[SearchSpace] = None,
    experiment: Optional[Experiment] = None,
    data: Optional[Data] = None,
    silently_filter_kwargs: bool = True,  # TODO[Lena]: default to False
    **kwargs: Any,
) -> ModelBridge:
    """Instantiate the model registered under ``self.value`` and wrap it in
    its ``ModelBridge``.

    Args:
        search_space: Search space handed to the bridge. When not given
            (falsy), the experiment's search space is used instead.
        experiment: Experiment handed to the bridge.
        data: Data handed to the bridge.
        silently_filter_kwargs: When False, ``kwargs`` are first checked with
            ``validate_kwarg_typing`` against the model and bridge classes.
        **kwargs: Extra arguments routed to the model and/or bridge
            constructors. ``consolidate_kwargs`` receives each constructor's
            argument names as ``keywords``, so presumably only matching keys
            are forwarded.

    Returns:
        The constructed model bridge, with the kwargs used to build it
        recorded through ``_set_kwargs_to_save``. For ``TorchModel`` models,
        callable model kwargs and bridge kwargs whose names start with
        ``"torch"`` are omitted from the recorded kwargs only; the model and
        bridge themselves are still built with the full kwargs.

    Raises:
        AssertionError: If ``self.value`` has no registered model setup, or
            if neither ``search_space`` nor ``experiment`` is given (unless
            assertions are disabled with ``-O``).
    """
    # Same diagnostic as the sibling implementation, so an unknown key is
    # reported by name instead of as a bare AssertionError.
    assert self.value in MODEL_KEY_TO_MODEL_SETUP, f"Unknown model {self.value}"
    # All model bridges require either a search space or an experiment.
    assert search_space or experiment, "Search space or experiment required."
    model_setup_info = MODEL_KEY_TO_MODEL_SETUP[self.value]
    model_class = model_setup_info.model_class
    bridge_class = model_setup_info.bridge_class
    if not silently_filter_kwargs:
        validate_kwarg_typing(  # TODO[Lena]: T46467254, pragma: no cover
            typed_callables=[model_class, bridge_class],
            search_space=search_space,
            experiment=experiment,
            data=data,
            **kwargs,
        )

    # Create model with consolidated arguments: defaults + passed in kwargs.
    model_kwargs = consolidate_kwargs(
        kwargs_iterable=[get_function_default_arguments(model_class), kwargs],
        keywords=get_function_argument_names(model_class),
    )
    model = model_class(**model_kwargs)

    # Create `ModelBridge`: defaults + standard kwargs + passed in kwargs.
    # `experiment`, `search_space` and `data` are passed explicitly below, so
    # they are omitted from the consolidated keywords.
    bridge_kwargs = consolidate_kwargs(
        kwargs_iterable=[
            get_function_default_arguments(bridge_class),
            model_setup_info.standard_bridge_kwargs,
            {"transforms": model_setup_info.transforms},
            kwargs,
        ],
        keywords=get_function_argument_names(
            function=bridge_class, omit=["experiment", "search_space", "data"]
        ),
    )

    # Create model bridge with the consolidated kwargs.
    model_bridge = bridge_class(
        search_space=search_space or not_none(experiment).search_space,
        experiment=experiment,
        data=data,
        model=model,
        **bridge_kwargs,
    )

    # Temporarily ignore BoTorch callable & torch-prefixed arguments, as those
    # are not serializable to JSON out-of-the-box. This only narrows what is
    # recorded below; construction above already used the full kwargs.
    # TODO[Lena]: T46527142
    if isinstance(model, TorchModel):
        model_kwargs = {kw: p for kw, p in model_kwargs.items() if not callable(p)}
        bridge_kwargs = {
            kw: p for kw, p in bridge_kwargs.items() if not kw.startswith("torch")
        }

    # Store all kwargs on model bridge, to be saved on generator run.
    model_bridge._set_kwargs_to_save(
        model_key=self.value, model_kwargs=model_kwargs, bridge_kwargs=bridge_kwargs
    )
    return model_bridge
def test_validate_kwarg_typing(self):
    """Exercise `validate_kwarg_typing` on valid kwargs, unknown keywords,
    keywords shared by several callables, and type mismatches."""

    def typed_callable(arg1: int, arg2: str = None) -> None:
        pass

    def typed_callable_with_dict(arg3: int, arg4: Dict[str, int]) -> None:
        pass

    def typed_callable_valid(arg3: int, arg4: str = None) -> None:
        pass

    def typed_callable_dup_keyword(arg2: int, arg4: str = None) -> None:
        pass

    def typed_callable_with_callable(
        arg1: int, arg2: Callable[[int], Dict[str, int]]
    ) -> None:
        pass

    def typed_callable_extra_arg(arg1: int, arg2: str, arg3: bool) -> None:
        pass

    # Valid kwargs must not raise. Each entry: (description, callables, kwargs).
    valid_cases = [
        (
            "simple kwargs",
            [typed_callable, typed_callable_valid],
            {"arg1": 1, "arg2": "test", "arg3": 2},
        ),
        (
            "complex data structure",
            [typed_callable, typed_callable_with_dict],
            {"arg1": 1, "arg2": "test", "arg3": 2, "arg4": {"k1": 1}},
        ),
        (
            "callable as arg (same arg count but diff type)",
            [typed_callable_with_callable],
            {"arg1": 1, "arg2": typed_callable},
        ),
        (
            "callable as arg (diff arg count)",
            [typed_callable_with_callable],
            {"arg1": 1, "arg2": typed_callable_extra_arg},
        ),
    ]
    for description, typed_callables, kwargs in valid_cases:
        try:
            validate_kwarg_typing(typed_callables, **kwargs)
        except Exception as err:
            # `self.fail` (rather than `assertTrue(False)`) reports which case
            # failed and what was raised.
            self.fail(f"Exception raised on valid kwargs ({description}): {err!r}")

    # kwargs contains extra keywords: only the validation call should raise.
    kwargs = {"arg1": 1, "arg2": "test", "arg3": 3, "arg5": 4}
    typed_callables = [typed_callable, typed_callable_valid]
    with self.assertRaises(ValueError):
        validate_kwarg_typing(typed_callables, **kwargs)

    # callables have duplicate keywords
    with patch.object(logger, "debug") as mock_debug:
        kwargs = {"arg1": 1, "arg2": "test", "arg4": "test_again"}
        typed_callables = [typed_callable, typed_callable_dup_keyword]
        validate_kwarg_typing(typed_callables, **kwargs)
        mock_debug.assert_called_once_with(
            f"`{typed_callables}` have duplicate keyword argument: arg2."
        )

    # mismatch types
    with patch.object(logger, "warning") as mock_warning:
        kwargs = {"arg1": 1, "arg2": "test", "arg3": "test_again"}
        typed_callables = [typed_callable, typed_callable_valid]
        validate_kwarg_typing(typed_callables, **kwargs)
        expected_message = (
            f"`{typed_callable_valid}` expected argument `arg3` to be of type"
            f" {type(1)}. Got test_again (type: {type('test_again')})."
        )
        mock_warning.assert_called_once_with(expected_message)

    # mismatch types with Dict
    with patch.object(logger, "warning") as mock_warning:
        str_dic = {"k1": "test"}
        kwargs = {"arg1": 1, "arg2": "test", "arg3": 2, "arg4": str_dic}
        typed_callables = [typed_callable, typed_callable_with_dict]
        validate_kwarg_typing(typed_callables, **kwargs)
        expected_message = (
            f"`{typed_callable_with_dict}` expected argument `arg4` to be of type"
            f" typing.Dict[str, int]. Got {str_dic} (type: {type(str_dic)})."
        )
        mock_warning.assert_called_once_with(expected_message)

    # mismatch types with callable as arg
    with patch.object(logger, "warning") as mock_warning:
        kwargs = {"arg1": 1, "arg2": "test_again"}
        typed_callables = [typed_callable_with_callable]
        validate_kwarg_typing(typed_callables, **kwargs)
        expected_message = (
            f"`{typed_callable_with_callable}` expected argument `arg2` to be of"
            f" type typing.Callable[[int], typing.Dict[str, int]]. "
            f"Got test_again (type: {type('test_again')})."
        )
        mock_warning.assert_called_once_with(expected_message)