import mxnet as mx
import numpy as np
from pydantic import PositiveFloat, PositiveInt

# MultivariateGaussian, MultivariateGaussianOutput, and the helper
# maximum_likelihood_estimate_sgd are assumed to come from the surrounding
# test module / library (exact import paths depend on the library version).


def test_multivariate_gaussian() -> None:
    num_samples = 2000
    dim = 2

    mu = np.arange(0, dim) / float(dim)

    # Build the covariance from its Cholesky factor: unit diagonal plus
    # small strictly-lower-triangular entries.
    L_diag = np.ones((dim,))
    L_low = 0.1 * np.ones((dim, dim)) * np.tri(dim, k=-1)
    L = np.diag(L_diag) + L_low
    Sigma = L.dot(L.transpose())

    distr = MultivariateGaussian(mu=mx.nd.array(mu), L=mx.nd.array(L))

    samples = distr.sample(num_samples)

    # Fit the distribution parameters to the samples by SGD maximum
    # likelihood, then compare the estimates against the ground truth.
    mu_hat, L_hat = maximum_likelihood_estimate_sgd(
        MultivariateGaussianOutput(dim=dim),
        samples,
        # TODO: biases would need to be reworked to support the
        # multivariate case.
        init_biases=None,
        hybridize=False,
        learning_rate=PositiveFloat(0.01),
        num_epochs=PositiveInt(10),
    )

    distr = MultivariateGaussian(
        mu=mx.nd.array([mu_hat]), L=mx.nd.array([L_hat])
    )

    Sigma_hat = distr.variance[0].asnumpy()

    assert np.allclose(
        mu_hat, mu, atol=0.1, rtol=0.1
    ), f"mu did not match: mu = {mu}, mu_hat = {mu_hat}"
    assert np.allclose(
        Sigma_hat, Sigma, atol=0.1, rtol=0.1
    ), f"Sigma did not match: Sigma = {Sigma}, Sigma_hat = {Sigma_hat}"
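# --- Illustrative sketch (not library code) --------------------------------
# The test above parameterizes the covariance through its Cholesky factor,
# Sigma = L L^T, which guarantees a valid (symmetric positive definite)
# covariance. A minimal standalone NumPy version of the same construction,
# sampling via the reparameterization x = mu + L eps with eps ~ N(0, I):

import numpy as np

rng = np.random.default_rng(0)
dim = 2

mu = np.arange(0, dim) / float(dim)
L = np.diag(np.ones(dim)) + 0.1 * np.ones((dim, dim)) * np.tri(dim, k=-1)
Sigma = L @ L.T  # symmetric positive definite by construction

eps = rng.standard_normal((2000, dim))
x = mu + eps @ L.T  # each row: mu + L @ eps_i

# Empirical moments recover mu and Sigma up to sampling noise.
assert np.allclose(x.mean(axis=0), mu, atol=0.1)
assert np.allclose(np.cov(x, rowvar=False), Sigma, atol=0.1)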
def sample(
    self, num_samples: Optional[int] = None, scale: Optional[Tensor] = None
) -> Tensor:
    r"""
    Generates samples from the LDS: p(z_1, z_2, \ldots, z_{`seq_length`}).

    Parameters
    ----------
    num_samples
        Number of samples to generate
    scale
        Scale of each sequence in x, shape (batch_size, output_dim)

    Returns
    -------
    Tensor
        Samples, shape (num_samples, batch_size, seq_length, output_dim)
    """
    F = self.F

    # Note on shapes: here we work with tensors of the following shape
    # in each time step t: (num_samples, batch_size, dim, dim),
    # where dim can be obs_dim or latent_dim or a constant 1 to facilitate
    # generalized matrix multiplication (gemm2)

    # Sample observation noise for all time steps
    # noise_std: (batch_size, seq_length, obs_dim, 1)
    noise_std = F.stack(*self.noise_std, axis=1).expand_dims(axis=-1)

    # samples_eps_obs[t]: (num_samples, batch_size, obs_dim, 1)
    samples_eps_obs = (
        Gaussian(noise_std.zeros_like(), noise_std)
        .sample(num_samples)
        .split(axis=-3, num_outputs=self.seq_length, squeeze_axis=True)
    )

    # Sample standard normal for all time steps
    # samples_eps_std_normal[t]: (num_samples, batch_size, obs_dim, 1)
    samples_std_normal = (
        Gaussian(noise_std.zeros_like(), noise_std.ones_like())
        .sample(num_samples)
        .split(axis=-3, num_outputs=self.seq_length, squeeze_axis=True)
    )

    # Sample the prior state.
    # samples_lat_state: (num_samples, batch_size, latent_dim, 1)
    # The prior covariance is observed to be slightly negative definite
    # whenever there is excessive zero padding at the beginning of the
    # time series. We add positive tolerance to the diagonal to avoid
    # numerical issues. Note that `jitter_cholesky` adds positive
    # tolerance only if the decomposition without jitter fails.
    state = MultivariateGaussian(
        self.prior_mean,
        jitter_cholesky(
            F, self.prior_cov, self.latent_dim, float_type=np.float32
        ),
    )
    samples_lat_state = state.sample(num_samples).expand_dims(axis=-1)

    samples_seq = []
    for t in range(self.seq_length):
        # Expand all coefficients to include samples in axis 0
        # emission_coeff_t: (num_samples, batch_size, obs_dim, latent_dim)
        # transition_coeff_t:
        #   (num_samples, batch_size, latent_dim, latent_dim)
        # innovation_coeff_t: (num_samples, batch_size, 1, latent_dim)
        emission_coeff_t, transition_coeff_t, innovation_coeff_t = [
            _broadcast_param(coeff, axes=[0], sizes=[num_samples])
            if num_samples is not None
            else coeff
            for coeff in [
                self.emission_coeff[t],
                self.transition_coeff[t],
                self.innovation_coeff[t],
            ]
        ]

        # Expand residuals as well
        # residual_t: (num_samples, batch_size, obs_dim, 1)
        residual_t = (
            _broadcast_param(
                self.residuals[t].expand_dims(axis=-1),
                axes=[0],
                sizes=[num_samples],
            )
            if num_samples is not None
            else self.residuals[t].expand_dims(axis=-1)
        )

        # (num_samples, batch_size, 1, obs_dim)
        samples_t = (
            F.linalg_gemm2(emission_coeff_t, samples_lat_state)
            + residual_t
            + samples_eps_obs[t]
        )
        samples_t = (
            samples_t.swapaxes(dim1=2, dim2=3)
            if num_samples is not None
            else samples_t.swapaxes(dim1=1, dim2=2)
        )
        samples_seq.append(samples_t)

        # sample next state: (num_samples, batch_size, latent_dim, 1)
        samples_lat_state = F.linalg_gemm2(
            transition_coeff_t, samples_lat_state
        ) + F.linalg_gemm2(
            innovation_coeff_t, samples_std_normal[t], transpose_a=True
        )

    # (num_samples, batch_size, seq_length, obs_dim)
    samples = F.concat(*samples_seq, dim=-2)
    return (
        samples
        if scale is None
        else F.broadcast_mul(
            samples,
            scale.expand_dims(axis=1).expand_dims(axis=0)
            if num_samples is not None
            else scale.expand_dims(axis=1),
        )
    )
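# --- Illustrative sketch (not library code) --------------------------------
# The comment above refers to `jitter_cholesky`, which adds positive jitter
# to the diagonal only when the plain Cholesky decomposition fails. A rough
# NumPy sketch of that fallback strategy (the actual helper operates on
# MXNet tensors; names below are hypothetical):

import numpy as np


def jitter_cholesky_sketch(cov: np.ndarray, max_tries: int = 10) -> np.ndarray:
    """Cholesky factorization with a diagonal-jitter fallback."""
    try:
        # No jitter needed if cov is already positive definite.
        return np.linalg.cholesky(cov)
    except np.linalg.LinAlgError:
        pass
    jitter = 1e-6 * float(np.mean(np.diag(cov)))
    for _ in range(max_tries):
        try:
            return np.linalg.cholesky(cov + jitter * np.eye(cov.shape[-1]))
        except np.linalg.LinAlgError:
            jitter *= 10.0  # escalate until the factorization succeeds
    raise np.linalg.LinAlgError("matrix not positive definite, even with jitter")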
def sample_marginals(
    self, num_samples: Optional[int] = None, scale: Optional[Tensor] = None
) -> Tensor:
    r"""
    Generates samples from the marginals p(z_t),
    t = 1, \ldots, `seq_length`.

    Parameters
    ----------
    num_samples
        Number of samples to generate
    scale
        Scale of each sequence in x, shape (batch_size, output_dim)

    Returns
    -------
    Tensor
        Samples, shape (num_samples, batch_size, seq_length, output_dim)
    """
    F = self.F

    state_mean = self.prior_mean.expand_dims(axis=-1)
    state_cov = self.prior_cov

    output_mean_seq = []
    output_cov_seq = []

    for t in range(self.seq_length):
        # compute and store observation mean at time t
        output_mean = F.linalg_gemm2(
            self.emission_coeff[t], state_mean
        ) + self.residuals[t].expand_dims(axis=-1)

        output_mean_seq.append(output_mean)

        # compute and store observation cov at time t
        output_cov = F.linalg_gemm2(
            self.emission_coeff[t],
            F.linalg_gemm2(
                state_cov, self.emission_coeff[t], transpose_b=True
            ),
        ) + make_nd_diag(
            F=F, x=self.noise_std[t] * self.noise_std[t], d=self.output_dim
        )
        output_cov_seq.append(output_cov.expand_dims(axis=1))

        state_mean = F.linalg_gemm2(self.transition_coeff[t], state_mean)

        state_cov = F.linalg_gemm2(
            self.transition_coeff[t],
            F.linalg_gemm2(
                state_cov, self.transition_coeff[t], transpose_b=True
            ),
        ) + F.linalg_gemm2(
            self.innovation_coeff[t],
            self.innovation_coeff[t],
            transpose_a=True,
        )

    output_mean = F.concat(*output_mean_seq, dim=1)
    output_cov = F.concat(*output_cov_seq, dim=1)

    L = F.linalg_potrf(output_cov)

    output_distribution = MultivariateGaussian(output_mean, L)

    samples = output_distribution.sample(num_samples=num_samples)

    return (
        samples
        if scale is None
        else F.broadcast_mul(samples, scale.expand_dims(axis=1))
    )
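# --- Illustrative sketch (not library code) --------------------------------
# `sample_marginals` propagates the latent mean/covariance forward and maps
# them through the emission model at every step. Stripped of the MXNet
# batching, the recursion is the following (hypothetical NumPy helper; H, A,
# g, s, b mirror emission_coeff, transition_coeff, innovation_coeff,
# noise_std, residuals):

import numpy as np


def lds_marginal_moments(prior_mean, prior_cov, emission, transition,
                         innovation, noise_std, residuals):
    """Per-step observation mean and covariance of a linear dynamical system."""
    m, P = prior_mean, prior_cov
    means, covs = [], []
    for H, A, g, s, b in zip(emission, transition, innovation,
                             noise_std, residuals):
        means.append(H @ m + b)                    # E[z_t]   = H_t m_t + b_t
        covs.append(H @ P @ H.T + np.diag(s * s))  # Cov[z_t] = H_t P_t H_t^T + diag(s_t^2)
        m = A @ m                                  # m_{t+1}  = A_t m_t
        P = A @ P @ A.T + np.outer(g, g)           # P_{t+1}  = A_t P_t A_t^T + g_t g_t^T
    return means, covs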
def sample(
    self, num_samples: Optional[int] = None, scale: Optional[Tensor] = None
) -> Tensor:
    r"""
    Generates samples from the LDS: p(z_1, z_2, \ldots, z_{`seq_length`}).

    Parameters
    ----------
    num_samples
        Number of samples to generate
    scale
        Scale of each sequence in x, shape (batch_size, output_dim)

    Returns
    -------
    Tensor
        Samples, shape (num_samples, batch_size, seq_length, output_dim)
    """
    F = self.F

    # Note on shapes: here we work with tensors of the following shape
    # in each time step t: (num_samples, batch_size, dim, dim),
    # where dim can be obs_dim or latent_dim or a constant 1 to facilitate
    # generalized matrix multiplication (gemm2)

    # Sample observation noise for all time steps
    # noise_std: (batch_size, seq_length, obs_dim, 1)
    noise_std = F.stack(*self.noise_std, axis=1).expand_dims(axis=-1)

    # samples_eps_obs[t]: (num_samples, batch_size, obs_dim, 1)
    samples_eps_obs = (
        Gaussian(noise_std.zeros_like(), noise_std)
        .sample(num_samples)
        .split(axis=2, num_outputs=self.seq_length, squeeze_axis=True)
    )

    # Sample standard normal for all time steps
    # samples_eps_std_normal[t]: (num_samples, batch_size, obs_dim, 1)
    samples_std_normal = (
        Gaussian(noise_std.zeros_like(), noise_std.ones_like())
        .sample(num_samples)
        .split(axis=2, num_outputs=self.seq_length, squeeze_axis=True)
    )

    # Sample the prior state.
    # samples_lat_state: (num_samples, batch_size, latent_dim, 1)
    state = MultivariateGaussian(
        self.prior_mean, F.linalg_potrf(self.prior_cov)
    )
    samples_lat_state = state.sample(num_samples).expand_dims(axis=-1)

    samples_seq = []
    for t in range(self.seq_length):
        # Expand all coefficients to include samples in axis 0
        # emission_coeff_t: (num_samples, batch_size, obs_dim, latent_dim)
        # transition_coeff_t:
        #   (num_samples, batch_size, latent_dim, latent_dim)
        # innovation_coeff_t: (num_samples, batch_size, 1, latent_dim)
        emission_coeff_t, transition_coeff_t, innovation_coeff_t = [
            _broadcast_param(coeff, axes=[0], sizes=[num_samples])
            for coeff in [
                self.emission_coeff[t],
                self.transition_coeff[t],
                self.innovation_coeff[t],
            ]
        ]

        # Expand residuals as well
        # residual_t: (num_samples, batch_size, obs_dim, 1)
        residual_t = _broadcast_param(
            self.residuals[t].expand_dims(axis=-1),
            axes=[0],
            sizes=[num_samples],
        )

        # (num_samples, batch_size, 1, obs_dim)
        samples_t = (
            F.linalg_gemm2(emission_coeff_t, samples_lat_state)
            + residual_t
            + samples_eps_obs[t]
        ).swapaxes(dim1=2, dim2=3)
        samples_seq.append(samples_t)

        # sample next state: (num_samples, batch_size, latent_dim, 1)
        samples_lat_state = F.linalg_gemm2(
            transition_coeff_t, samples_lat_state
        ) + F.linalg_gemm2(
            innovation_coeff_t, samples_std_normal[t], transpose_a=True
        )

    # (num_samples, batch_size, seq_length, obs_dim)
    samples = F.concat(*samples_seq, dim=2)
    return (
        samples
        if scale is None
        else F.broadcast_mul(samples, scale.expand_dims(axis=1))
    )
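# --- Illustrative sketch (not library code) --------------------------------
# The loop in `sample` above is ancestral sampling through the state-space
# recursion. For a single sample path, without the MXNet broadcasting, it
# reduces to the following (hypothetical NumPy helper; assumes a scalar
# innovation per step, matching innovation_coeff of shape (1, latent_dim)):

import numpy as np


def lds_sample_path(prior_mean, prior_cov, emission, transition,
                    innovation, noise_std, residuals, rng):
    """Draw one trajectory z_1, ..., z_T from a linear dynamical system."""
    # Latent state from the prior: l_1 ~ N(prior_mean, prior_cov)
    l = rng.multivariate_normal(prior_mean, prior_cov)
    path = []
    for H, A, g, s, b in zip(emission, transition, innovation,
                             noise_std, residuals):
        eps = s * rng.standard_normal(len(s))  # observation noise
        path.append(H @ l + b + eps)           # z_t     = H_t l_t + b_t + eps_t
        l = A @ l + g * rng.standard_normal()  # l_{t+1} = A_t l_t + g_t xi_t
    return np.stack(path)                      # shape: (seq_length, obs_dim)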