def get_simulations(
    self,
    starting_round: int = 0,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Return the stored $(\theta, x, label)$ triples from round `starting_round` on.

    Data is gathered from every round whose index is >= `starting_round`
    (rounds are counted from zero). Each label is the one inferred from the
    `valid_or_invalid_criterion`.

    Args:
        starting_round: Earliest round whose data is included.

    Returns:
        Parameters, simulation outputs and labels, in that order.
    """
    round_index = self._data_round_index
    per_round_stores = (
        self._theta_roundwise,
        self._x_roundwise,
        self._label_roundwise,
    )
    # Each store is sliced with the same round index, so theta, x and label
    # stay aligned row-for-row.
    theta, x, label = (
        get_simulations_since_round(store, round_index, starting_round)
        for store in per_round_stores
    )
    return theta, x, label
def get_simulations(
    self,
    starting_round: int = 0,
    exclude_invalid_x: bool = True,
    warn_on_invalid: bool = True,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""
    Returns all $\theta$, $x$, and prior_masks from rounds >= `starting_round`.

    If requested, do not return invalid data.

    Args:
        starting_round: The earliest round to return samples from (we start
            counting from zero).
        exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or
            `x=±∞` during training.
        warn_on_invalid: Whether to give out a warning if invalid simulations
            were found.

    Returns:
        Parameters, simulation outputs, prior masks — each restricted to the
        rows selected by the validity mask.
    """
    theta = get_simulations_since_round(
        self._theta_roundwise, self._data_round_index, starting_round
    )
    x = get_simulations_since_round(
        self._x_roundwise, self._data_round_index, starting_round
    )
    prior_masks = get_simulations_since_round(
        self._prior_masks, self._data_round_index, starting_round
    )

    # Check for NaNs and infs in simulations. `is_valid_x` is the row mask that
    # selects what this method returns.
    is_valid_x, num_nans, num_infs = handle_invalid_x(x, exclude_invalid_x)

    # Filter once and reuse: the returned x and the z-scoring check must see
    # the same rows.
    x_valid = x[is_valid_x]

    # Check for problematic z-scoring on exactly the data that is returned.
    # Previously the unfiltered x was passed, so rows dropped above could still
    # influence the check. Presumably the check z-scores with batch mean/std,
    # which any NaN/inf row would poison; confirm against
    # warn_if_zscoring_changes_data.
    warn_if_zscoring_changes_data(x_valid)

    if warn_on_invalid:
        warn_on_invalid_x(num_nans, num_infs, exclude_invalid_x)
        warn_on_invalid_x_for_snpec_leakage(
            num_nans, num_infs, exclude_invalid_x, type(self).__name__, self._round
        )

    return theta[is_valid_x], x_valid, prior_masks[is_valid_x]