def destandardizing_net(batch_t: Tensor, min_std: float = 1e-7) -> nn.Module:
    """Net that de-standardizes the output so the NN can learn the standardized target.

    Args:
        batch_t: Batched tensor from which mean and std deviation (across first
            dimension) are computed.
        min_std: Minimum value of the standard deviation to use when z-scoring to
            avoid division by zero.

    Returns:
        Neural network module for z-scoring.
    """
    is_valid_t, *_ = handle_invalid_x(batch_t, True)

    t_mean = torch.mean(batch_t[is_valid_t], dim=0)
    # BUG FIX: the original condition was `len(batch_t > 1)`, which is the
    # length of a *boolean* tensor and therefore truthy for any non-empty
    # batch. The intent is to test the batch size itself, so that a
    # single-sample batch falls through to std = 1 (torch.std over a single
    # sample along dim=0 would yield NaN).
    if len(batch_t) > 1:
        t_std = torch.std(batch_t[is_valid_t], dim=0)
        # Clamp tiny deviations so downstream division cannot blow up.
        t_std[t_std < min_std] = min_std
    else:
        t_std = 1
        logging.warning(
            """Using a one-dimensional batch will instantiate a Standardize transform with (mean, std) parameters which are not representative of the data. We allow this behavior because you might be loading a pre-trained. If this is not the case, please be sure to use a larger batch."""
        )

    return Destandardize(t_mean, t_std)
def append_simulations(self, theta: Tensor, x: Tensor) -> "RestrictionEstimator":
    r"""Store parameters and simulation outputs to use them for training later.

    Data are stored as entries in lists for each type of variable
    (parameter/data).

    Args:
        theta: Parameter sets.
        x: Simulation outputs.

    Returns:
        `RestrictionEstimator` object (returned so that this function is
        chainable).
    """
    validate_theta_and_x(theta, x)

    # Derive the valid/invalid label either from the built-in NaN/inf check
    # or from a user-supplied criterion callable.
    if self._valid_or_invalid_criterion == "nan":
        label, _, _ = handle_invalid_x(x)
    else:
        label = self._valid_or_invalid_criterion(x)
    label = label.long()

    # Round index: 0 on the first call, previous round + 1 afterwards.
    next_round = self._data_round_index[-1] + 1 if self._data_round_index else 0
    self._data_round_index.append(next_round)

    for storage, entry in (
        (self._theta_roundwise, theta),
        (self._x_roundwise, x),
        (self._label_roundwise, label),
    ):
        storage.append(entry)

    return self
def test_handle_invalid_x(x_shape, set_seed):
    """Check that `handle_invalid_x` marks only finite entries as valid."""
    batch = torch.rand(x_shape)
    # Inject both kinds of invalid values into the random batch.
    batch[batch < 0.1] = float("nan")
    batch[batch > 0.9] = float("inf")

    valid_mask, *_ = handle_invalid_x(batch, exclude_invalid_x=True)

    assert torch.isfinite(batch[valid_mask]).all()