def get_simulations(
    self,
    starting_round: int = 0,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""
    Return the stored $(\theta, x, label)$ triples from rounds >= `starting_round`.

    Rounds are counted from zero, so the default returns everything that has
    been passed to this object. The labels were inferred from the
    `valid_or_invalid_criterion`.

    Args:
        starting_round: Earliest round whose simulations are included.

    Returns: Parameters, simulation outputs, labels.
    """
    # The three per-round stores share one round index, so each is sliced
    # the same way; fetch them in the order theta, x, label.
    per_round_stores = (
        self._theta_roundwise,
        self._x_roundwise,
        self._label_roundwise,
    )
    theta, x, label = (
        get_simulations_since_round(store, self._data_round_index, starting_round)
        for store in per_round_stores
    )
    return theta, x, label
Esempio n. 2
0
    def get_simulations(
        self,
        starting_round: int = 0,
        exclude_invalid_x: bool = True,
        warn_on_invalid: bool = True,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        r"""
        Returns all $\theta$, $x$, and prior_masks from rounds >= `starting_round`.

        If requested, do not return invalid data.

        Args:
            starting_round: The earliest round to return samples from (we start counting
                from zero).
            exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞`
                from the returned data.
            warn_on_invalid: Whether to give out a warning if invalid simulations were
                found.

        Returns: Parameters, simulation outputs, prior masks — all three filtered by
            the same validity mask from `handle_invalid_x`.
        """

        theta = get_simulations_since_round(
            self._theta_roundwise, self._data_round_index, starting_round
        )
        x = get_simulations_since_round(
            self._x_roundwise, self._data_round_index, starting_round
        )
        prior_masks = get_simulations_since_round(
            self._prior_masks, self._data_round_index, starting_round
        )

        # Check for NaNs / infs in simulations; `is_valid_x` selects the rows
        # that are kept and returned.
        is_valid_x, num_nans, num_infs = handle_invalid_x(x, exclude_invalid_x)
        x_valid = x[is_valid_x]

        # Check for problematic z-scoring on exactly the rows that are returned.
        # Running it on the unfiltered `x` would let rows that are about to be
        # dropped (NaN / ±inf) influence the check.
        warn_if_zscoring_changes_data(x_valid)
        if warn_on_invalid:
            warn_on_invalid_x(num_nans, num_infs, exclude_invalid_x)
            warn_on_invalid_x_for_snpec_leakage(
                num_nans, num_infs, exclude_invalid_x, type(self).__name__, self._round
            )

        return theta[is_valid_x], x_valid, prior_masks[is_valid_x]