Ejemplo n.º 1
0
def get_GPKG(
    experiment: Experiment,
    data: Data,
    search_space: Optional[SearchSpace] = None,
    cost_intercept: float = 0.01,
    dtype: torch.dtype = torch.double,
    device: torch.device = DEFAULT_TORCH_DEVICE,
    # NOTE: this default list is built once at definition time and shared by
    # every call that omits ``transforms``; callees must not mutate it.
    transforms: List[Type[Transform]] = Cont_X_trans + Y_trans,
    transform_configs: Optional[Dict[str, TConfig]] = None,
    **kwargs: Any,
) -> TorchModelBridge:
    """Instantiates a GP model that generates points with KG.

    Args:
        experiment: Experiment the model is built for. Its search space is
            used when ``search_space`` is not given.
        data: Observations to fit the model on. Must be non-empty.
        search_space: Search space for the model. Defaults to
            ``experiment.search_space``.
        cost_intercept: Forwarded to ``Models.GPKG`` as ``cost_intercept``.
        dtype: Forwarded to ``Models.GPKG`` as ``torch_dtype``.
        device: Forwarded to ``Models.GPKG`` as ``torch_device``.
        transforms: Transform classes forwarded to ``Models.GPKG``.
        transform_configs: Optional per-transform configuration, forwarded
            unchanged.
        **kwargs: Only ``linear_truncated`` is read (default ``True``). It is
            forwarded only when the search space contains a fidelity
            parameter. All other keys are ignored.

    Returns:
        The result of ``Models.GPKG(...)``, checked to be a
        ``TorchModelBridge`` via ``checked_cast``.

    Raises:
        ValueError: If ``data.df`` is empty.
    """
    if search_space is None:
        search_space = experiment.search_space
    if data.df.empty:  # pragma: no cover
        raise ValueError("GP+KG BotorchModel requires non-empty data.")

    inputs: Dict[str, Any] = {
        "search_space": search_space,
        "experiment": experiment,
        "data": data,
        "cost_intercept": cost_intercept,
        "torch_dtype": dtype,
        "torch_device": device,
        "transforms": transforms,
        "transform_configs": transform_configs,
    }

    # Inspect the search space the model is actually built on (not
    # ``experiment.parameters``), so an explicitly passed ``search_space`` is
    # honored. Only the parameter objects are needed, so iterate values.
    # ``linear_truncated`` is presumably only meaningful for multi-fidelity
    # models, hence it is forwarded only when a fidelity parameter exists.
    if any(p.is_fidelity for p in search_space.parameters.values()):
        inputs["linear_truncated"] = kwargs.get("linear_truncated", True)
    return checked_cast(TorchModelBridge, Models.GPKG(**inputs))  # pyre-ignore: [16]
Ejemplo n.º 2
0
def get_GPKG(
    experiment: Experiment,
    data: Data,
    search_space: Optional[SearchSpace] = None,
    cost_intercept: float = 0.01,
    dtype: torch.dtype = torch.double,
    device: torch.device = DEFAULT_TORCH_DEVICE,
    transforms: List[Type[Transform]] = Cont_X_trans + Y_trans,
    winsorization_limits: Optional[Tuple[Optional[float],
                                         Optional[float]]] = None,
    **kwargs: Any,
) -> TorchModelBridge:
    """Instantiates a GP model that generates points with KG.

    Args:
        experiment: Experiment the model is built for. Its search space is
            used when ``search_space`` is omitted.
        data: Observations to fit on. Must be non-empty.
        search_space: Optional override for ``experiment.search_space``.
        cost_intercept: Forwarded to ``Models.GPKG`` as ``cost_intercept``.
        dtype: Forwarded as ``torch_dtype``.
        device: Forwarded as ``torch_device``.
        transforms: Transform classes forwarded to ``Models.GPKG``.
        winsorization_limits: Optional ``(lower, upper)`` pair. When given, it
            becomes a ``"Winsorize"`` transform config. A falsy bound
            (``None`` or ``0``) is replaced by ``0.0``.
        **kwargs: Only ``linear_truncated`` is read (default ``True``). It is
            always forwarded; all other keys are ignored.

    Returns:
        The result of ``Models.GPKG(...)``, checked via ``checked_cast`` to be
        a ``TorchModelBridge``.

    Raises:
        ValueError: If ``data.df`` is empty.
    """
    space = experiment.search_space if search_space is None else search_space
    if data.df.empty:  # pragma: no cover
        raise ValueError("GP+KG BotorchModel requires non-empty data.")

    # Stays empty (not None) when no winsorization is requested.
    configs = {}
    if winsorization_limits is not None:
        lower_limit = winsorization_limits[0]
        upper_limit = winsorization_limits[1]
        configs["Winsorize"] = {
            "winsorization_lower": lower_limit if lower_limit else 0.0,
            "winsorization_upper": upper_limit if upper_limit else 0.0,
        }

    bridge = Models.GPKG(
        search_space=space,
        experiment=experiment,
        data=data,
        cost_intercept=cost_intercept,
        linear_truncated=kwargs.get("linear_truncated", True),
        torch_dtype=dtype,
        torch_device=device,
        transforms=transforms,
        transform_configs=configs,
    )
    return checked_cast(TorchModelBridge, bridge)