# Import paths below assume the GluonTS `gluonts.mx.distribution` layout.
from typing import Tuple

import mxnet as mx

from gluonts.mx.distribution import (
    Distribution,
    Gaussian,
    Laplace,
    TransformedDistribution,
    Uniform,
)


def test_distribution_shapes(
    distr: Distribution,
    expected_batch_shape: Tuple,
    expected_event_shape: Tuple,
):
    assert distr.batch_shape == expected_batch_shape
    assert distr.event_shape == expected_event_shape

    x = distr.sample()
    assert x.shape == distr.batch_shape + distr.event_shape

    loss = distr.loss(x)
    assert loss.shape == distr.batch_shape

    x1 = distr.sample(num_samples=1)
    assert x1.shape == (1,) + distr.batch_shape + distr.event_shape

    x3 = distr.sample(num_samples=3)
    assert x3.shape == (3,) + distr.batch_shape + distr.event_shape

    def has_quantile(d):
        return isinstance(d, (Uniform, Gaussian, Laplace))

    # The quantile check only applies to distributions (or transformed
    # distributions whose base) with an implemented quantile function.
    if has_quantile(distr) or (
        isinstance(distr, TransformedDistribution)
        and has_quantile(distr.base_distribution)
    ):
        qs1 = distr.quantile(mx.nd.array([0.5]))
        assert qs1.shape == (1,) + distr.batch_shape + distr.event_shape
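# A minimal usage sketch for the test above, assuming a plain Gaussian
# with batch_shape (3, 2) and empty event_shape; the function name and
# parameter values here are illustrative.
def example_gaussian_shapes():
    distr = Gaussian(
        mu=mx.nd.zeros(shape=(3, 2)),
        sigma=mx.nd.ones(shape=(3, 2)),
    )
    test_distribution_shapes(
        distr=distr,
        expected_batch_shape=(3, 2),
        expected_event_shape=(),
    )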
def get_samples_for_loss(self, distr: Distribution) -> Tensor:
    """
    Get samples to compute the final loss.

    These are samples drawn directly from the given `distr` if coherence
    is not enforced yet; otherwise the drawn samples are reconciled.

    Parameters
    ----------
    distr
        Distribution instance.

    Returns
    -------
    samples
        Tensor with shape (num_samples, batch_size, seq_len, target_dim).
    """
    samples = distr.sample_rep(
        num_samples=self.num_samples_for_loss, dtype="float32"
    )

    # Determine which epoch we are currently in.
    self.batch_no += 1
    epoch_no = self.batch_no // self.num_batches_per_epoch + 1
    epoch_frac = epoch_no / self.epochs

    # Reconcile only if coherent training samples are requested and the
    # warm-start fraction of training has elapsed.
    if (
        self.coherent_train_samples
        and epoch_frac > self.warmstart_epoch_frac
    ):
        coherent_samples = self.reconcile_samples(samples)
        assert_shape(coherent_samples, samples.shape)
        return coherent_samples
    else:
        return samples
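# A hedged sketch of what a reconciliation step could look like, using
# orthogonal projection onto the coherent subspace spanned by a summation
# matrix S. The names `reconcile_samples_sketch` and `S` are illustrative
# assumptions, not the library's exact implementation.
import numpy as np

def reconcile_samples_sketch(samples: np.ndarray, S: np.ndarray) -> np.ndarray:
    # P = S (S^T S)^{-1} S^T projects onto the image of S; applying it
    # along the last (target_dim) axis makes each sample coherent.
    P = S @ np.linalg.inv(S.T @ S) @ S.T
    return samples @ P.T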
def loss(self, F, target: Tensor, distr: Distribution) -> Tensor:
    """
    Computes the loss, given the output of the network in the form of a
    distribution.

    The loss is given by:

        `self.CRPS_weight` * `loss_CRPS`
        + `self.likelihood_weight` * `neg_likelihoods`,

    where

    * `loss_CRPS` is computed on the samples drawn from the predicted
      `distr` (optionally after reconciling them),
    * `neg_likelihoods` are either computed directly using the predicted
      `distr`, or from the distribution estimated from the (coherent)
      samples, depending on the `sample_LH` flag.

    Parameters
    ----------
    F
    target
        Tensor with shape (batch_size, seq_len, target_dim)
    distr
        Distribution instance.

    Returns
    -------
    Loss
        Tensor with shape (batch_size, seq_len, 1)
    """
    # Sample from the predicted distribution if we are computing the CRPS
    # loss, or the likelihood using the distribution estimated from the
    # (coherent) samples.
    # Samples shape: (num_samples, batch_size, seq_len, target_dim)
    if self.sample_LH or (self.CRPS_weight > 0.0):
        samples = self.get_samples_for_loss(distr=distr)

    if self.sample_LH:
        # Estimate the distribution from the (coherent) samples.
        distr = LowrankMultivariateGaussian.fit(F, samples=samples, rank=0)

    neg_likelihoods = -distr.log_prob(target).expand_dims(axis=-1)

    loss_CRPS = F.zeros_like(neg_likelihoods)
    if self.CRPS_weight > 0.0:
        loss_CRPS = (
            EmpiricalDistribution(samples=samples, event_dim=1)
            .crps_univariate(x=target)
            .expand_dims(axis=-1)
        )

    return (
        self.CRPS_weight * loss_CRPS
        + self.likelihood_weight * neg_likelihoods
    )
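# A minimal sketch of the sample-based CRPS term used above, assuming the
# standard empirical estimator CRPS(F, x) = E|X - x| - 0.5 * E|X - X'|,
# computed independently per entry. This illustrates the idea, not the
# exact `EmpiricalDistribution.crps_univariate` implementation.
import numpy as np

def crps_from_samples(samples: np.ndarray, x: np.ndarray) -> np.ndarray:
    # samples: (num_samples, ...), x: (...) broadcastable against samples
    abs_err = np.abs(samples - x).mean(axis=0)
    spread = np.abs(samples[:, None] - samples[None, :]).mean(axis=(0, 1))
    return abs_err - 0.5 * spread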
def loss(self, F, target: Tensor, distr: Distribution) -> Tensor:
    """
    Returns the negative log-likelihood of `target` under `distr`.

    Parameters
    ----------
    F
    target
        Tensor with shape (batch_size, seq_len, target_dim)
    distr
        Distribution instance.

    Returns
    -------
    Loss
        Tensor with shape (batch_size, seq_len, 1)
    """
    # Expand the last axis so that all likelihood-based losses share the
    # same shape: (batch_size, seq_len, 1).
    return -distr.log_prob(target).expand_dims(axis=-1)
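# A quick usage sketch, assuming a univariate Gaussian (so `log_prob` is
# taken elementwise over (batch_size, seq_len)); import paths assume the
# `gluonts.mx.distribution` layout.
import mxnet as mx
from gluonts.mx.distribution import Gaussian

distr = Gaussian(
    mu=mx.nd.zeros(shape=(4, 10)),
    sigma=mx.nd.ones(shape=(4, 10)),
)
target = mx.nd.random.normal(shape=(4, 10))
nll = -distr.log_prob(target).expand_dims(axis=-1)
assert nll.shape == (4, 10, 1)  # (batch_size, seq_len, 1)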