def get_GPKG(
    experiment: Experiment,
    data: Data,
    search_space: Optional[SearchSpace] = None,
    cost_intercept: float = 0.01,
    dtype: torch.dtype = torch.double,
    device: torch.device = DEFAULT_TORCH_DEVICE,
    transforms: List[Type[Transform]] = Cont_X_trans + Y_trans,
    transform_configs: Optional[Dict[str, TConfig]] = None,
    **kwargs: Any,
) -> TorchModelBridge:
    """Instantiates a GP model that generates points with KG.

    Args:
        experiment: Experiment the model is built for; its search space is
            used when ``search_space`` is not given, and its parameters are
            inspected for fidelity parameters.
        data: Observations passed to the model; ``data.df`` must be non-empty.
        search_space: Search space passed to the model. Defaults to
            ``experiment.search_space``.
        cost_intercept: Forwarded to ``Models.GPKG`` as ``cost_intercept``.
        dtype: Forwarded to ``Models.GPKG`` as ``torch_dtype``.
        device: Forwarded to ``Models.GPKG`` as ``torch_device``.
        transforms: Forwarded to ``Models.GPKG`` as ``transforms``.
        transform_configs: Forwarded unchanged to ``Models.GPKG``.
        **kwargs: Only ``linear_truncated`` is read (default ``True``), and it
            is forwarded only when at least one experiment parameter has
            ``is_fidelity`` set. All other keys are ignored.

    Returns:
        The result of ``Models.GPKG(...)``, checked to be a
        ``TorchModelBridge``.

    Raises:
        ValueError: If ``data.df`` is empty.
    """
    if search_space is None:
        search_space = experiment.search_space
    if data.df.empty:  # pragma: no cover
        raise ValueError("GP+KG BotorchModel requires non-empty data.")

    inputs: Dict[str, Any] = {
        "search_space": search_space,
        "experiment": experiment,
        "data": data,
        "cost_intercept": cost_intercept,
        "torch_dtype": dtype,
        "torch_device": device,
        "transforms": transforms,
        "transform_configs": transform_configs,
    }

    # `linear_truncated` is only passed through when the experiment has a
    # fidelity parameter; otherwise the model's own default applies.
    if any(p.is_fidelity for p in experiment.parameters.values()):
        inputs["linear_truncated"] = kwargs.get("linear_truncated", True)
    return checked_cast(TorchModelBridge, Models.GPKG(**inputs))  # pyre-ignore: [16]
def get_GPKG(
    experiment: Experiment,
    data: Data,
    search_space: Optional[SearchSpace] = None,
    cost_intercept: float = 0.01,
    dtype: torch.dtype = torch.double,
    device: torch.device = DEFAULT_TORCH_DEVICE,
    transforms: List[Type[Transform]] = Cont_X_trans + Y_trans,
    winsorization_limits: Optional[Tuple[Optional[float], Optional[float]]] = None,
    **kwargs: Any,
) -> TorchModelBridge:
    """Instantiates a GP model that generates points with KG.

    Args:
        experiment: Experiment the model is built for; supplies the default
            search space.
        data: Observations passed to the model; ``data.df`` must be non-empty.
        search_space: Search space passed to the model. Defaults to
            ``experiment.search_space``.
        cost_intercept: Forwarded to ``Models.GPKG`` as ``cost_intercept``.
        dtype: Forwarded to ``Models.GPKG`` as ``torch_dtype``.
        device: Forwarded to ``Models.GPKG`` as ``torch_device``.
        transforms: Forwarded to ``Models.GPKG`` as ``transforms``.
        winsorization_limits: Optional ``(lower, upper)`` pair. When given,
            a ``"Winsorize"`` transform config is built from it, and a falsy
            entry (e.g. ``None``) becomes ``0.0``.
        **kwargs: Only ``linear_truncated`` is read (default ``True``) and it
            is always forwarded. All other keys are ignored.

    Returns:
        The result of ``Models.GPKG(...)``, checked to be a
        ``TorchModelBridge``.

    Raises:
        ValueError: If ``data.df`` is empty.
    """
    if search_space is None:
        search_space = experiment.search_space
    if data.df.empty:  # pragma: no cover
        raise ValueError("GP+KG BotorchModel requires non-empty data.")

    # Only populate the Winsorize entry when limits were actually supplied.
    transform_configs = {}
    if winsorization_limits is not None:
        lower_limit = winsorization_limits[0] or 0.0
        upper_limit = winsorization_limits[1] or 0.0
        transform_configs["Winsorize"] = {
            "winsorization_lower": lower_limit,
            "winsorization_upper": upper_limit,
        }

    truncated = kwargs.get("linear_truncated", True)
    model_bridge = Models.GPKG(
        search_space=search_space,
        experiment=experiment,
        data=data,
        cost_intercept=cost_intercept,
        linear_truncated=truncated,
        torch_dtype=dtype,
        torch_device=device,
        transforms=transforms,
        transform_configs=transform_configs,
    )
    return checked_cast(TorchModelBridge, model_bridge)