def check_standardization(
    Y: Tensor,
    atol_mean: float = 1e-2,
    atol_std: float = 1e-2,
    raise_on_fail: bool = False,
) -> None:
    r"""Check that tensor is standardized (zero mean, unit variance).

    The mean and the standard deviation are computed across the
    `n`-dimension (`dim=-2`). If the largest absolute mean exceeds
    `atol_mean`, or the largest absolute deviation of the std from one
    exceeds `atol_std`, the data is reported as not standardized.

    Args:
        Y: The input tensor of shape `batch_shape x n x m`. Typically the
            train targets of a model. Standardization is checked across
            the `n`-dimension.
        atol_mean: The tolerance for the mean check.
        atol_std: The tolerance for the std check. The std check is skipped
            when there are fewer than two observations along the
            `n`-dimension, since the sample std is undefined in that case.
        raise_on_fail: If True, raise an exception instead of a warning.

    Raises:
        InputDataError: If the check fails and `raise_on_fail` is True.
    """
    with torch.no_grad():
        Ymean = torch.mean(Y, dim=-2)
        not_standardized = bool(torch.abs(Ymean).max() > atol_mean)
        # The sample std of fewer than two observations is NaN (and torch may
        # warn about the degrees of freedom), so a comparison against the
        # tolerance would be meaningless. Only check the std when it is defined.
        if not not_standardized and Y.shape[-2] > 1:
            Ystd = torch.std(Y, dim=-2)
            not_standardized = bool(torch.abs(Ystd - 1).max() > atol_std)
        if not_standardized:
            msg = (
                "Input data is not standardized. Please consider scaling the "
                "input to zero mean and unit variance."
            )
            if raise_on_fail:
                raise InputDataError(msg)
            warnings.warn(msg, InputDataWarning)
def check_min_max_scaling(
    X: Tensor,
    strict: bool = False,
    atol: float = 1e-2,
    raise_on_fail: bool = False,
    ignore_dims: Optional[List[int]] = None,
) -> None:
    r"""Check that tensor is normalized to the unit cube.

    The per-feature minimum and maximum are taken across the `n`-dimension
    (`dim=-2`). The data is reported as not "contained" if any value lies
    outside `[-atol, 1 + atol]`. With `strict=True` it is additionally
    reported as not "scaled" if some feature's minimum is further than
    `atol` from zero or its maximum is further than `atol` from one.

    Args:
        X: A `batch_shape x n x d` input tensor. Typically the training inputs
            of a model.
        strict: If True, require `X` to be scaled to the unit cube (rather than
            just to be contained within the unit cube).
        atol: The tolerance for the boundary checks (used both for the
            containment check and for the `strict` check).
        raise_on_fail: If True, raise an exception instead of a warning.
        ignore_dims: Subset of feature dimensions (indices into the last
            dimension of `X`) for which the check is omitted.

    Raises:
        InputDataError: If the check fails and `raise_on_fail` is True.
    """
    with torch.no_grad():
        X_check = X
        if ignore_dims:
            ignored = set(ignore_dims)
            keep = [i for i in range(X.shape[-1]) if i not in ignored]
            if not keep:
                return None
            X_check = X[..., keep]
        if X_check.numel() == 0:
            # No values to check (and nothing to reduce over).
            return None
        if X_check.dim() < 2:
            # Treat a lower-dimensional tensor as a single point so that the
            # reduction over the `n`-dimension below is well-defined.
            X_check = X_check.reshape(1, -1)
        # Reduce over the points (`dim=-2`), not over the features: min-max
        # scaling is a per-feature property of the data set.
        Xmin = torch.min(X_check, dim=-2).values
        Xmax = torch.max(X_check, dim=-2).values
        msg = None
        if strict and max(torch.abs(Xmin).max(), torch.abs(Xmax - 1).max()) > atol:
            msg = "scaled"
        # Checked last on purpose: being outside the cube takes precedence
        # over merely not spanning it.
        if torch.any(Xmin < -atol) or torch.any(Xmax > 1 + atol):
            msg = "contained"
        if msg is not None:
            msg = (
                f"Input data is not {msg} to the unit cube. "
                "Please consider min-max scaling the input data."
            )
            if raise_on_fail:
                raise InputDataError(msg)
            warnings.warn(msg, InputDataWarning)
def check_no_nans(Z: Tensor) -> None:
    r"""Ensure that a tensor is free of NaN entries.

    Args:
        Z: The tensor to inspect.

    Raises:
        InputDataError: If at least one element of `Z` is NaN.
    """
    has_nan = torch.isnan(Z).any().item()
    if not has_nan:
        return
    raise InputDataError("Input data contains NaN values.")
def validate_input_scaling(
    train_X: Tensor,
    train_Y: Tensor,
    train_Yvar: Optional[Tensor] = None,
    raise_on_fail: bool = False,
    ignore_X_dims: Optional[List[int]] = None,
) -> None:
    r"""Validate the scaling of the training data passed to a model.

    Does nothing if `settings.validate_input_scaling` is switched off.
    Otherwise the following checks are run, in this order:

    (i) `train_X`, `train_Y` and (if given) `train_Yvar` contain no NaN
        values, and `train_Yvar` contains no negative values. Violations
        always raise, regardless of `raise_on_fail`.
    (ii) `train_X` passes `check_min_max_scaling` (unit cube check) for all
        dimensions except those in `ignore_X_dims`.
    (iii) `train_Y` passes `check_standardization` (zero mean, unit variance).

    Args:
        train_X: A `n x d` or `batch_shape x n x d` (batch mode) tensor of
            training features.
        train_Y: A `n x m` or `batch_shape x n x m` (batch mode) tensor of
            training observations.
        train_Yvar: A `n x m` or `batch_shape x n x m` (batch mode) tensor of
            observed measurement noise.
        raise_on_fail: If True, raise an error instead of emitting a warning
            (only for the normalization/standardization checks).
        ignore_X_dims: Subset of the feature dimensions of `train_X` for which
            the min-max scaling check is skipped. Passed on unchanged as
            `ignore_dims` to `check_min_max_scaling`.

    Raises:
        InputDataError: If the variances contain negative values. The called
            checks may raise as well.
    """
    if settings.validate_input_scaling.off():
        return

    # Hard failures first: these raise independently of `raise_on_fail`.
    for data in (train_X, train_Y):
        check_no_nans(data)
    if train_Yvar is not None:
        check_no_nans(train_Yvar)
        has_negative_variance = (train_Yvar < 0).any()
        if has_negative_variance:
            raise InputDataError("Input data contains negative variances.")

    # Soft checks: warn by default, raise only if requested.
    check_min_max_scaling(
        X=train_X, raise_on_fail=raise_on_fail, ignore_dims=ignore_X_dims
    )
    check_standardization(Y=train_Y, raise_on_fail=raise_on_fail)
def check_min_max_scaling(
    X: Tensor,
    strict: bool = False,
    atol: float = 1e-2,
    raise_on_fail: bool = False,
    ignore_dims: Optional[List[int]] = None,
) -> None:
    r"""Check that tensor is normalized to the unit cube.

    The per-feature minimum and maximum are taken across the `n`-dimension
    (`dim=-2`) of the checked feature columns. The data is reported as not
    "contained" if any value lies outside `[-atol, 1 + atol]`. With
    `strict=True` it is additionally reported as not "scaled" if some
    feature's minimum is further than `atol` from zero or its maximum is
    further than `atol` from one.

    Args:
        X: A `batch_shape x n x d` input tensor. Typically the training inputs
            of a model.
        strict: If True, require `X` to be scaled to the unit cube (rather than
            just to be contained within the unit cube).
        atol: The tolerance for the boundary checks (used both for the
            containment check and for the `strict` check).
        raise_on_fail: If True, raise an exception instead of a warning.
        ignore_dims: Subset of dimensions where the min-max scaling check is
            omitted. Indices refer to the last dimension of `X`.

    Raises:
        InputDataError: If the check fails and `raise_on_fail` is True.
    """
    ignored = set(ignore_dims or [])
    check_dims = [i for i in range(X.shape[-1]) if i not in ignored]
    if not check_dims:
        return None

    with torch.no_grad():
        X_check = X[..., check_dims]
        if X_check.numel() == 0:
            # No data points: nothing to check (and nothing to reduce over).
            return None
        if X_check.dim() < 2:
            # Treat a 1-d tensor as a single point so that the reduction over
            # the `n`-dimension below is well-defined.
            X_check = X_check.unsqueeze(0)
        # Reduce over the points (`dim=-2`), not over the features: min-max
        # scaling is a per-feature property of the data set.
        Xmin = torch.min(X_check, dim=-2).values
        Xmax = torch.max(X_check, dim=-2).values
        msg = None
        if strict and max(torch.abs(Xmin).max(), torch.abs(Xmax - 1).max()) > atol:
            msg = "scaled"
        # Checked last on purpose: being outside the cube takes precedence
        # over merely not spanning it.
        if torch.any(Xmin < -atol) or torch.any(Xmax > 1 + atol):
            msg = "contained"
        if msg is not None:
            msg = (
                f"Input data is not {msg} to the unit cube. "
                "Please consider min-max scaling the input data."
            )
            if raise_on_fail:
                raise InputDataError(msg)
            warnings.warn(msg, InputDataWarning)