Ejemplo n.º 1
0
    def __call__(
        self,
        search_space: Optional[SearchSpace] = None,
        experiment: Optional[Experiment] = None,
        data: Optional[Data] = None,
        silently_filter_kwargs: bool = True,  # TODO[Lena]: default to False
        **kwargs: Any,
    ) -> ModelBridge:
        """Build the model and model bridge registered under this key.

        Args:
            search_space: Search space handed to the bridge. When falsy, the
                experiment's ``search_space`` is used instead.
            experiment: Experiment handed to the bridge; also the fallback
                source of the search space.
            data: Data handed to the bridge.
            silently_filter_kwargs: When False, ``kwargs`` (together with
                ``search_space``, ``experiment`` and ``data``) are first passed
                to ``validate_kwarg_typing`` against the model and bridge
                classes.
            **kwargs: Extra arguments, consolidated separately for the model
                class and for the bridge class.

        Returns:
            The newly constructed bridge wrapping the newly constructed model.

        Raises:
            AssertionError: if this key is missing from
                ``MODEL_KEY_TO_MODEL_SETUP`` or if neither ``search_space`` nor
                ``experiment`` is given.
        """
        key = self.value
        assert key in MODEL_KEY_TO_MODEL_SETUP, f"Unknown model {key}"
        # Every bridge needs a search space, given directly or via experiment.
        assert search_space or experiment, "Search space or experiment required."

        setup = MODEL_KEY_TO_MODEL_SETUP[key]
        model_cls, bridge_cls = setup.model_class, setup.bridge_class

        if not silently_filter_kwargs:
            validate_kwarg_typing(  # TODO[Lena]: T46467254, pragma: no cover
                typed_callables=[model_cls, bridge_cls],
                search_space=search_space,
                experiment=experiment,
                data=data,
                **kwargs,
            )

        # Model arguments: the model class's own defaults combined with the
        # user kwargs, keyed by the model class's argument names
        # (consolidate_kwargs is defined elsewhere; presumably later sources win).
        model_defaults = get_function_default_arguments(model_cls)
        model_kwargs = consolidate_kwargs(
            kwargs_iterable=[model_defaults, kwargs],
            keywords=get_function_argument_names(model_cls),
        )
        model = model_cls(**model_kwargs)

        # Bridge arguments, in order: bridge defaults, the registry's standard
        # bridge kwargs, the registry's transforms, then user kwargs. The three
        # arguments passed explicitly to the bridge below are omitted.
        bridge_sources = [
            get_function_default_arguments(bridge_cls),
            setup.standard_bridge_kwargs,
            {"transforms": setup.transforms},
            kwargs,
        ]
        bridge_kwargs = consolidate_kwargs(
            kwargs_iterable=bridge_sources,
            keywords=get_function_argument_names(
                function=bridge_cls, omit=["experiment", "search_space", "data"]
            ),
        )

        resolved_space = (
            search_space if search_space else not_none(experiment).search_space
        )
        model_bridge = bridge_cls(
            search_space=resolved_space,
            experiment=experiment,
            data=data,
            model=model,
            **bridge_kwargs,
        )

        # Record the consolidated kwargs (run through
        # _encode_callables_as_references) so they can be saved on generator runs.
        model_bridge._set_kwargs_to_save(
            model_key=key,
            model_kwargs=_encode_callables_as_references(model_kwargs),
            bridge_kwargs=_encode_callables_as_references(bridge_kwargs),
        )
        return model_bridge
Ejemplo n.º 2
0
 def _get_model_kwargs(
     info: ModelSetup, kwargs: Optional[Dict[str, Any]] = None
 ) -> Dict[str, Any]:
     """Consolidate constructor kwargs for ``info.model_class``.

     The model class's default arguments and ``kwargs`` (which may be None)
     are handed to ``consolidate_kwargs`` together with the model class's
     argument names, and its result is returned unchanged.
     """
     model_cls = info.model_class
     defaults = get_function_default_arguments(model_cls)
     return consolidate_kwargs(
         kwargs_iterable=[defaults, kwargs],
         keywords=get_function_argument_names(model_cls),
     )
Ejemplo n.º 3
0
 def _get_bridge_kwargs(
     info: ModelSetup, kwargs: Optional[Dict[str, Any]] = None
 ) -> Dict[str, Any]:
     """Consolidate constructor kwargs for ``info.bridge_class``.

     Sources, in order: the bridge class's defaults, the setup's standard
     bridge kwargs, the setup's transforms, then ``kwargs`` (may be None).
     They are restricted to the bridge's argument names, excluding
     ``experiment``, ``search_space`` and ``data``, which callers pass
     explicitly.
     """
     bridge_cls = info.bridge_class
     sources = [
         get_function_default_arguments(bridge_cls),
         info.standard_bridge_kwargs,
         {"transforms": info.transforms},
         kwargs,
     ]
     accepted = get_function_argument_names(
         function=bridge_cls, omit=["experiment", "search_space", "data"]
     )
     return consolidate_kwargs(kwargs_iterable=sources, keywords=accepted)
Ejemplo n.º 4
0
    def __call__(
        self,
        search_space: Optional[SearchSpace] = None,
        experiment: Optional[Experiment] = None,
        data: Optional[Data] = None,
        silently_filter_kwargs: bool = True,  # TODO[Lena]: default to False
        **kwargs: Any,
    ) -> ModelBridge:
        """Build the model and model bridge registered under this key.

        Args:
            search_space: Search space handed to the bridge. When falsy, the
                experiment's ``search_space`` is used instead.
            experiment: Experiment handed to the bridge; also the fallback
                source of the search space.
            data: Data handed to the bridge.
            silently_filter_kwargs: When False, ``kwargs`` (together with
                ``search_space``, ``experiment`` and ``data``) are first passed
                to ``validate_kwarg_typing`` against the model and bridge
                classes.
            **kwargs: Extra arguments, consolidated separately for the model
                class and for the bridge class.

        Returns:
            The newly constructed bridge wrapping the newly constructed model.

        Raises:
            AssertionError: if this key is missing from
                ``MODEL_KEY_TO_MODEL_SETUP`` or if neither ``search_space`` nor
                ``experiment`` is given.
        """
        # Name the offending key instead of raising a bare AssertionError.
        assert self.value in MODEL_KEY_TO_MODEL_SETUP, f"Unknown model {self.value}"
        # All model bridges require either a search space or an experiment.
        assert search_space or experiment, "Search space or experiment required."
        model_setup_info = MODEL_KEY_TO_MODEL_SETUP[self.value]
        model_class = model_setup_info.model_class
        bridge_class = model_setup_info.bridge_class
        if not silently_filter_kwargs:
            validate_kwarg_typing(  # TODO[Lena]: T46467254, pragma: no cover
                typed_callables=[model_class, bridge_class],
                search_space=search_space,
                experiment=experiment,
                data=data,
                **kwargs,
            )

        # Create model with consolidated arguments: defaults + passed in kwargs.
        model_kwargs = consolidate_kwargs(
            kwargs_iterable=[get_function_default_arguments(model_class), kwargs],
            keywords=get_function_argument_names(model_class),
        )
        model = model_class(**model_kwargs)

        # Create `ModelBridge`: defaults + standard kwargs + passed in kwargs.
        # experiment/search_space/data are omitted since they are passed
        # explicitly to the bridge constructor below.
        bridge_kwargs = consolidate_kwargs(
            kwargs_iterable=[
                get_function_default_arguments(bridge_class),
                model_setup_info.standard_bridge_kwargs,
                {"transforms": model_setup_info.transforms},
                kwargs,
            ],
            keywords=get_function_argument_names(
                function=bridge_class, omit=["experiment", "search_space", "data"]
            ),
        )

        # Create model bridge with the consolidated kwargs.
        model_bridge = bridge_class(
            search_space=search_space or not_none(experiment).search_space,
            experiment=experiment,
            data=data,
            model=model,
            **bridge_kwargs,
        )

        # Temporarily ignore Botorch callable & torch-typed arguments, as those
        # are not serializable to JSON out-of-the-box. TODO[Lena]: T46527142
        # Only the copies recorded for saving are filtered; the model and
        # bridge above were already built with the full kwargs.
        if isinstance(model, TorchModel):
            model_kwargs = {kw: p for kw, p in model_kwargs.items() if not callable(p)}
            bridge_kwargs = {
                kw: p for kw, p in bridge_kwargs.items() if not kw.startswith("torch")
            }

        # Store all kwargs on model bridge, to be saved on generator run.
        model_bridge._set_kwargs_to_save(
            model_key=self.value, model_kwargs=model_kwargs, bridge_kwargs=bridge_kwargs
        )
        return model_bridge