コード例 #1
0
    def __call__(
        self,
        search_space: Optional[SearchSpace] = None,
        experiment: Optional[Experiment] = None,
        data: Optional[Data] = None,
        silently_filter_kwargs: bool = True,  # TODO[Lena]: default to False
        **kwargs: Any,
    ) -> ModelBridge:
        """Construct the model registered under this key and wrap it in its bridge.

        Args:
            search_space: Search space handed to the bridge. When falsy, the
                experiment's search space is used instead.
            experiment: Experiment handed to the bridge.
            data: Data handed to the bridge.
            silently_filter_kwargs: When False, ``kwargs`` (together with
                ``search_space``, ``experiment`` and ``data``) are first run
                through ``validate_kwarg_typing`` against the model and bridge
                classes. When True, that check is skipped.
            **kwargs: Overrides for the model and bridge constructor arguments.
                They are merged with each class's defaults via
                ``consolidate_kwargs``, restricted to that class's argument
                names.

        Returns:
            The constructed bridge. Its model and bridge kwargs (with callables
            encoded as references) are recorded through ``_set_kwargs_to_save``.

        Raises:
            AssertionError: If this key is not registered in
                ``MODEL_KEY_TO_MODEL_SETUP``, or if neither ``search_space`` nor
                ``experiment`` is provided. These are ``assert`` statements, so
                they are skipped under ``python -O``.
        """
        assert self.value in MODEL_KEY_TO_MODEL_SETUP, f"Unknown model {self.value}"
        # Every bridge needs a search space, either directly or via the experiment.
        assert search_space or experiment, "Search space or experiment required."
        setup = MODEL_KEY_TO_MODEL_SETUP[self.value]
        model_cls, bridge_cls = setup.model_class, setup.bridge_class

        if not silently_filter_kwargs:
            # Opt-in strict check of the kwargs against both constructors.
            validate_kwarg_typing(  # TODO[Lena]: T46467254, pragma: no cover
                typed_callables=[model_cls, bridge_cls],
                search_space=search_space,
                experiment=experiment,
                data=data,
                **kwargs,
            )

        # Model arguments: the model class's defaults, then the caller's kwargs.
        model_defaults = get_function_default_arguments(model_cls)
        model_kwargs = consolidate_kwargs(
            kwargs_iterable=[model_defaults, kwargs],
            keywords=get_function_argument_names(model_cls),
        )
        model = model_cls(**model_kwargs)

        # Bridge arguments are layered in this order: bridge defaults, the
        # registry's standard kwargs, the registry's transforms, and finally the
        # caller's kwargs. Presumably later layers win; see consolidate_kwargs.
        # experiment/search_space/data are passed explicitly below, so they are
        # omitted from the accepted keyword set.
        layered_bridge_kwargs = [
            get_function_default_arguments(bridge_cls),
            setup.standard_bridge_kwargs,
            {"transforms": setup.transforms},
            kwargs,
        ]
        bridge_kwargs = consolidate_kwargs(
            kwargs_iterable=layered_bridge_kwargs,
            keywords=get_function_argument_names(
                function=bridge_cls, omit=["experiment", "search_space", "data"]
            ),
        )

        # Fall back to the experiment's search space when none was supplied.
        effective_search_space = search_space or not_none(experiment).search_space
        model_bridge = bridge_cls(
            search_space=effective_search_space,
            experiment=experiment,
            data=data,
            model=model,
            **bridge_kwargs,
        )

        # Remember how this bridge was built so it can be saved on a generator run.
        model_bridge._set_kwargs_to_save(
            model_key=self.value,
            model_kwargs=_encode_callables_as_references(model_kwargs),
            bridge_kwargs=_encode_callables_as_references(bridge_kwargs),
        )
        return model_bridge
コード例 #2
0
    def __call__(
        self,
        search_space: Optional[SearchSpace] = None,
        experiment: Optional[Experiment] = None,
        data: Optional[Data] = None,
        silently_filter_kwargs: bool = True,  # TODO[Lena]: default to False
        **kwargs: Any,
    ) -> ModelBridge:
        """Construct the model registered under this key and wrap it in its bridge.

        Args:
            search_space: Search space handed to the bridge. When falsy, the
                experiment's search space is used instead.
            experiment: Experiment handed to the bridge.
            data: Data handed to the bridge.
            silently_filter_kwargs: When False, ``kwargs`` (together with
                ``search_space``, ``experiment`` and ``data``) are first run
                through ``validate_kwarg_typing`` against the model and bridge
                classes. When True, that check is skipped.
            **kwargs: Overrides for the model and bridge constructor arguments.
                They are merged with each class's defaults via
                ``consolidate_kwargs``, restricted to that class's argument
                names.

        Returns:
            The constructed bridge, with its model and bridge kwargs recorded
            through ``_set_kwargs_to_save``. For a ``TorchModel``, callable
            model kwargs and bridge kwargs whose name starts with ``"torch"``
            are left out of what is recorded. They are still used to build the
            model and bridge.

        Raises:
            AssertionError: If this key is not registered in
                ``MODEL_KEY_TO_MODEL_SETUP``, or if neither ``search_space`` nor
                ``experiment`` is provided. These are ``assert`` statements, so
                they are skipped under ``python -O``.
        """
        # Report which key was unknown instead of failing with a bare assertion.
        assert self.value in MODEL_KEY_TO_MODEL_SETUP, f"Unknown model {self.value}"
        # All model bridges require either a search space or an experiment.
        assert search_space or experiment, "Search space or experiment required."
        model_setup_info = MODEL_KEY_TO_MODEL_SETUP[self.value]
        model_class = model_setup_info.model_class
        bridge_class = model_setup_info.bridge_class
        if not silently_filter_kwargs:
            validate_kwarg_typing(  # TODO[Lena]: T46467254, pragma: no cover
                typed_callables=[model_class, bridge_class],
                search_space=search_space,
                experiment=experiment,
                data=data,
                **kwargs,
            )

        # Create model with consolidated arguments: defaults + passed in kwargs.
        model_kwargs = consolidate_kwargs(
            kwargs_iterable=[get_function_default_arguments(model_class), kwargs],
            keywords=get_function_argument_names(model_class),
        )
        model = model_class(**model_kwargs)

        # Create `ModelBridge`: defaults + standard kwargs + passed in kwargs.
        # experiment/search_space/data are passed explicitly below, so they are
        # omitted from the accepted keyword set.
        bridge_kwargs = consolidate_kwargs(
            kwargs_iterable=[
                get_function_default_arguments(bridge_class),
                model_setup_info.standard_bridge_kwargs,
                {"transforms": model_setup_info.transforms},
                kwargs,
            ],
            keywords=get_function_argument_names(
                function=bridge_class, omit=["experiment", "search_space", "data"]
            ),
        )

        # Create model bridge with the consolidated kwargs.
        model_bridge = bridge_class(
            search_space=search_space or not_none(experiment).search_space,
            experiment=experiment,
            data=data,
            model=model,
            **bridge_kwargs,
        )

        # Temporarily ignore Botorch callable & torch-typed arguments, as those
        # are not serializable to JSON out-of-the-box. TODO[Lena]: T46527142
        # The filtering only affects what is recorded below; the model and
        # bridge above were already built with the full argument sets.
        if isinstance(model, TorchModel):
            model_kwargs = {kw: p for kw, p in model_kwargs.items() if not callable(p)}
            bridge_kwargs = {
                kw: p for kw, p in bridge_kwargs.items() if not kw.startswith("torch")
            }

        # Store all kwargs on model bridge, to be saved on generator run.
        model_bridge._set_kwargs_to_save(
            model_key=self.value, model_kwargs=model_kwargs, bridge_kwargs=bridge_kwargs
        )
        return model_bridge
コード例 #3
0
ファイル: test_kwargutils.py プロジェクト: proteanblank/Ax
    def test_validate_kwarg_typing(self):
        """Exercise `validate_kwarg_typing` on valid kwargs, extra keywords,
        duplicate keywords across callables, and type mismatches."""

        def typed_callable(arg1: int, arg2: str = None) -> None:
            pass

        def typed_callable_with_dict(arg3: int, arg4: Dict[str, int]) -> None:
            pass

        def typed_callable_valid(arg3: int, arg4: str = None) -> None:
            pass

        def typed_callable_dup_keyword(arg2: int, arg4: str = None) -> None:
            pass

        def typed_callable_with_callable(
            arg1: int, arg2: Callable[[int], Dict[str, int]]
        ) -> None:
            pass

        def typed_callable_extra_arg(arg1: int, arg2: str, arg3: bool) -> None:
            pass

        def assert_valid(typed_callables, **kwargs):
            """Fail the test, reporting the cause, if validation raises."""
            try:
                validate_kwarg_typing(typed_callables, **kwargs)
            except Exception as e:  # any exception here means the test failed
                self.fail(f"Exception raised on valid kwargs: {e!r}")

        # pass
        kwargs = {"arg1": 1, "arg2": "test", "arg3": 2}
        assert_valid([typed_callable, typed_callable_valid], **kwargs)

        # pass with complex data structure
        kwargs = {"arg1": 1, "arg2": "test", "arg3": 2, "arg4": {"k1": 1}}
        assert_valid([typed_callable, typed_callable_with_dict], **kwargs)

        # callable as arg (same arg count but diff type)
        kwargs = {"arg1": 1, "arg2": typed_callable}
        assert_valid([typed_callable_with_callable], **kwargs)

        # callable as arg (diff arg count)
        kwargs = {"arg1": 1, "arg2": typed_callable_extra_arg}
        assert_valid([typed_callable_with_callable], **kwargs)

        # kwargs contains extra keywords; only the call under test is guarded.
        kwargs = {"arg1": 1, "arg2": "test", "arg3": 3, "arg5": 4}
        typed_callables = [typed_callable, typed_callable_valid]
        with self.assertRaises(ValueError):
            validate_kwarg_typing(typed_callables, **kwargs)

        # callables have duplicate keywords
        with patch.object(logger, "debug") as mock_debug:
            kwargs = {"arg1": 1, "arg2": "test", "arg4": "test_again"}
            typed_callables = [typed_callable, typed_callable_dup_keyword]
            validate_kwarg_typing(typed_callables, **kwargs)
            mock_debug.assert_called_once_with(
                f"`{typed_callables}` have duplicate keyword argument: arg2."
            )

        # mismatch types
        with patch.object(logger, "warning") as mock_warning:
            kwargs = {"arg1": 1, "arg2": "test", "arg3": "test_again"}
            typed_callables = [typed_callable, typed_callable_valid]
            validate_kwarg_typing(typed_callables, **kwargs)
            expected_message = (
                f"`{typed_callable_valid}` expected argument `arg3` to be of type"
                f" {type(1)}. Got test_again (type: {type('test_again')})."
            )
            mock_warning.assert_called_once_with(expected_message)

        # mismatch types with Dict
        with patch.object(logger, "warning") as mock_warning:
            str_dic = {"k1": "test"}
            kwargs = {"arg1": 1, "arg2": "test", "arg3": 2, "arg4": str_dic}
            typed_callables = [typed_callable, typed_callable_with_dict]
            validate_kwarg_typing(typed_callables, **kwargs)
            expected_message = (
                f"`{typed_callable_with_dict}` expected argument `arg4` to be of type"
                f" typing.Dict[str, int]. Got {str_dic} (type: {type(str_dic)})."
            )
            mock_warning.assert_called_once_with(expected_message)

        # mismatch types with callable as arg
        with patch.object(logger, "warning") as mock_warning:
            kwargs = {"arg1": 1, "arg2": "test_again"}
            typed_callables = [typed_callable_with_callable]
            validate_kwarg_typing(typed_callables, **kwargs)
            expected_message = (
                f"`{typed_callable_with_callable}` expected argument `arg2` to be of"
                f" type typing.Callable[[int], typing.Dict[str, int]]. "
                f"Got test_again (type: {type('test_again')})."
            )
            mock_warning.assert_called_once_with(expected_message)