Code example #1
from typing import Tuple

import mxnet as mx

# NOTE: import path assumed; these names ship with GluonTS's MXNet backend.
from gluonts.mx.distribution import (
    Distribution,
    Gaussian,
    Laplace,
    TransformedDistribution,
    Uniform,
)


def test_distribution_shapes(
    distr: Distribution,
    expected_batch_shape: Tuple,
    expected_event_shape: Tuple,
) -> None:
    assert distr.batch_shape == expected_batch_shape
    assert distr.event_shape == expected_event_shape

    # A plain draw has no leading sample axis.
    x = distr.sample()

    assert x.shape == distr.batch_shape + distr.event_shape

    # The loss (negative log-likelihood) collapses the event dimensions,
    # leaving one value per batch element.
    loss = distr.loss(x)

    assert loss.shape == distr.batch_shape

    # An explicit num_samples adds a leading sample axis, even for 1.
    x1 = distr.sample(num_samples=1)

    assert x1.shape == (1,) + distr.batch_shape + distr.event_shape

    x3 = distr.sample(num_samples=3)

    assert x3.shape == (3,) + distr.batch_shape + distr.event_shape

    def has_quantile(d):
        return isinstance(d, (Uniform, Gaussian, Laplace))

    # Quantiles are only checked for distributions that implement them,
    # either directly or through the base of a TransformedDistribution.
    if has_quantile(distr) or (
        isinstance(distr, TransformedDistribution)
        and has_quantile(distr.base_distribution)
    ):
        qs1 = distr.quantile(mx.nd.array([0.5]))
        assert qs1.shape == (1,) + distr.batch_shape + distr.event_shape
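
A minimal sketch of how this test helper might be invoked, assuming GluonTS's `Gaussian` (the parameter shapes are illustrative): a univariate Gaussian with `mu`/`sigma` of shape (3, 4) has `batch_shape == (3, 4)` and `event_shape == ()`.

# Hypothetical invocation of the helper above; parameter values are
# illustrative, not from the source.
mu = mx.nd.zeros(shape=(3, 4))
sigma = mx.nd.ones(shape=(3, 4))
test_distribution_shapes(
    Gaussian(mu=mu, sigma=sigma),
    expected_batch_shape=(3, 4),
    expected_event_shape=(),
)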
Code example #2
    def get_samples_for_loss(self, distr: Distribution) -> Tensor:
        """
        Get samples to compute the final loss. These are samples directly drawn from the given `distr` if coherence is
        not enforced yet; otherwise the drawn samples are reconciled.

        Parameters
        ----------
        distr
            Distribution instances

        Returns
        -------
        samples
            Tensor with shape (num_samples, batch_size, seq_len, target_dim)

        """
        samples = distr.sample_rep(num_samples=self.num_samples_for_loss,
                                   dtype="float32")

        # Track the current epoch and the fraction of training completed.
        self.batch_no += 1
        epoch_no = self.batch_no // self.num_batches_per_epoch + 1
        epoch_frac = epoch_no / self.epochs

        # Reconcile samples only once training is past the warm-start
        # fraction and coherent training samples were requested.
        if (self.coherent_train_samples
                and epoch_frac > self.warmstart_epoch_frac):
            coherent_samples = self.reconcile_samples(samples)
            assert_shape(coherent_samples, samples.shape)
            return coherent_samples
        else:
            return samples
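
To make the warm-start gating concrete, here is a small numeric illustration (all numbers are made up, not from the source):

# Illustrative values only: 100 batches per epoch, 10 epochs total.
num_batches_per_epoch, epochs = 100, 10
batch_no = 250                                      # halfway into epoch 3
epoch_no = batch_no // num_batches_per_epoch + 1    # -> 3
epoch_frac = epoch_no / epochs                      # -> 0.3
# With e.g. warmstart_epoch_frac = 0.2, the gate is now open and the
# drawn samples would be reconciled before computing the loss.
assert epoch_frac > 0.2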
Code example #3
    def loss(self, F, target: Tensor, distr: Distribution) -> Tensor:
        """
        Computes loss given the output of the network in the form of
        distribution. The loss is given by:

            `self.CRPS_weight` * `loss_CRPS` + `self.likelihood_weight` *
            `neg_likelihoods`,

         where
          * `loss_CRPS` is computed on the samples drawn from the predicted
            `distr` (optionally after reconciling them),
          * `neg_likelihoods` are either computed directly using the predicted
            `distr` or from the estimated distribution based on (coherent)
            samples, depending on the `sample_LH` flag.

        Parameters
        ----------
        F
        target
            Tensor with shape (batch_size, seq_len, target_dim)
        distr
            Distribution instances

        Returns
        -------
        Loss
            Tensor with shape (batch_size, seq_length, 1)
        """

        # Sample from the predicted distribution if we are computing CRPS loss
        # or likelihood using the distribution based on (coherent) samples.
        # Samples shape: (num_samples, batch_size, seq_len, target_dim)
        if self.sample_LH or (self.CRPS_weight > 0.0):
            samples = self.get_samples_for_loss(distr=distr)

        if self.sample_LH:
            # Estimate the distribution based on (coherent) samples.
            distr = LowrankMultivariateGaussian.fit(F, samples=samples, rank=0)

        neg_likelihoods = -distr.log_prob(target).expand_dims(axis=-1)

        loss_CRPS = F.zeros_like(neg_likelihoods)
        if self.CRPS_weight > 0.0:
            loss_CRPS = (
                EmpiricalDistribution(samples=samples, event_dim=1)
                .crps_univariate(x=target)
                .expand_dims(axis=-1)
            )

        return (
            self.CRPS_weight * loss_CRPS
            + self.likelihood_weight * neg_likelihoods
        )
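
To see how the two weighted terms combine, a tiny stand-alone sketch with made-up weights and per-element loss values (plain Python floats stand in for tensors):

# Illustrative weights, not from the source.
CRPS_weight, likelihood_weight = 0.7, 0.3
loss_CRPS, neg_likelihood = 0.5, 1.2   # per-element loss values
total = CRPS_weight * loss_CRPS + likelihood_weight * neg_likelihood
assert abs(total - 0.71) < 1e-9
# With likelihood_weight = 0.0 the training signal is pure CRPS;
# with CRPS_weight = 0.0 it reduces to the negative log-likelihood.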
Code example #4
    def loss(self, F, target: Tensor, distr: Distribution) -> Tensor:
        """
        Returns negative log-likelihood of `target` under `distr`.

        Parameters
        ----------
        F
        target
            Tensor with shape (batch_size, seq_len, target_dim)
        distr
            Distribution instances

        Returns
        -------
        Loss
            Tensor with shape (batch_size, seq_length, 1)
        """
        # Add a trailing axis so every likelihood loss has the same shape:
        # (batch_size, seq_len, 1).
        return -distr.log_prob(target).expand_dims(axis=-1)
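
A minimal usage sketch, assuming GluonTS's `Gaussian` and an instance `nll` of the class above (names and shapes are illustrative):

# Hypothetical usage; `nll` is an instance of the class defining `loss`.
batch_size, seq_len = 8, 24
mu = mx.nd.zeros(shape=(batch_size, seq_len))
sigma = mx.nd.ones(shape=(batch_size, seq_len))
distr = Gaussian(mu=mu, sigma=sigma)
target = distr.sample()                       # shape (8, 24)
loss = nll.loss(mx.nd, target=target, distr=distr)
assert loss.shape == (batch_size, seq_len, 1)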