def _zero_mean_forward(self, x):
    if not isinstance(x, tuple):
        x_mean = x
        x_var = None
    else:
        x_mean = x[0]
        x_var = x[1]

    y_mean = F.linear(x_mean, torch.zeros_like(self.W).t()) + self.bias

    W_var = self._get_var(self.W_logvar)
    bias_var = self._get_var(self.bias_logvar)

    if x_var is None:
        xx = x_mean * x_mean
        y_var = torch.diag_embed(F.linear(xx, W_var.t()) + bias_var)
    else:
        y_var = compute_linear_var(x_mean, x_var, torch.zeros_like(self.W),
                                   W_var, self.bias, bias_var)

    if self.deterministic:
        return y_mean, y_var
    else:
        dst = MultivariateNormal(loc=y_mean, covariance_matrix=y_var)
        sample = dst.rsample()
        return sample, None
def _mc_forward(self, x):
    if isinstance(x, tuple):
        x_mean = x[0]
        x_var = x[1]
    else:
        x_mean = x
        x_var = None  # no input variance provided; fall back to diag(x_mean**2) below

    if self.zero_mean:
        lrt_mean = 0.0
    else:
        lrt_mean = F.linear(x_mean, self.W)
    if self.bias is not None:
        lrt_mean = lrt_mean + self.bias

    sigma2 = torch.exp(self.log_alpha) * self.W * self.W
    if self.permute_sigma:
        sigma2 = sigma2.view(-1)[torch.randperm(
            self.in_features * self.out_features).cuda()].view(
                self.out_features, self.in_features)

    if x_var is None:
        x_var = torch.diag_embed(x_mean * x_mean)

    lrt_cov = compute_linear_var(x_mean, x_var, self.W.t(), sigma2.t())
    dst = MultivariateNormal(lrt_mean, covariance_matrix=lrt_cov)
    return dst.rsample(), None
def forward(self, *state_args, deterministic=True):
    x = super(Policy, self).forward(*state_args)
    mean, log_std = torch.split(x, x.shape[1] // 2, dim=1)
    log_std = self.std_clamp(log_std)

    if deterministic:
        action = torch.tanh(mean)
        log_prob = torch.zeros(log_std.shape[0]).unsqueeze_(-1)
    else:
        std = log_std.exp()
        normal = MultivariateNormal(mean, torch.diag_embed(std.pow(2)))
        action_base = normal.rsample()
        log_prob = normal.log_prob(action_base)
        log_prob.unsqueeze_(-1)
        action = torch.tanh(action_base)

        # correct the log-density for the tanh squashing
        action_bound_compensation = torch.log(
            1. - action.pow(2) + np.finfo(float).eps).sum(dim=1, keepdim=True)
        log_prob.sub_(action_bound_compensation)

    return action, log_prob
def forward(self, matrix, rets=None, **kwargs):
    """Perform forward pass.

    Only accepts keyword arguments to avoid ambiguity.

    Parameters
    ----------
    matrix : torch.Tensor
        Of shape (n_samples, n_assets, n_assets) representing the square of the covariance matrix
        if `self.sqrt=True`, else the covariance matrix itself.

    rets : torch.Tensor or None
        Of shape (n_samples, n_assets) representing expected returns (or whatever the feature
        extractor decided to encode). Note that `NCO` and `AnalyticalMarkowitz` allow for
        `rets=None` (using only minimum variance).

    kwargs : dict
        All additional input arguments the `self.allocator` needs to perform forward pass.

    Returns
    -------
    weights : torch.Tensor
        Of shape (n_samples, n_assets) representing the optimal weights.
    """
    if self.random_state is not None:
        torch.manual_seed(self.random_state)

    n_samples, n_assets, _ = matrix.shape
    dtype, device = matrix.dtype, matrix.device
    n_draws = self.n_draws or n_assets  # make sure that if None then we have the same N=M

    covmat = matrix @ matrix if self.sqrt else matrix
    dist_rets = torch.zeros(n_samples, n_assets, dtype=dtype, device=device) if rets is None else rets

    dist = MultivariateNormal(loc=dist_rets, covariance_matrix=covmat)

    portfolios = []  # n_portfolios elements of (n_samples, n_assets)

    for _ in range(self.n_portfolios):
        draws = dist.rsample((n_draws,))  # (n_draws, n_samples, n_assets)
        rets_ = draws.mean(dim=0) if rets is not None else None  # (n_samples, n_assets)
        covmat_ = CovarianceMatrix(sqrt=self.uses_sqrt)(draws.permute(1, 0, 2))  # (n_samples, n_assets, ...)

        if isinstance(self.allocator, (AnalyticalMarkowitz, NCO)):
            portfolio = self.allocator(covmat=covmat_, rets=rets_)

        elif isinstance(self.allocator, NumericalMarkowitz):
            gamma = kwargs['gamma']
            alpha = kwargs['alpha']
            portfolio = self.allocator(rets_, covmat_, gamma, alpha)

        portfolios.append(portfolio)

    portfolios_t = torch.stack(portfolios, dim=0)  # (n_portfolios, n_samples, n_assets)

    return portfolios_t.mean(dim=0)
def select_action(self, x):
    mu, cov = self.forward(x)
    tril = self.reshape_output(mu, cov)
    dist = MultivariateNormal(mu, scale_tril=tril)

    if self.pwd:
        action = dist.rsample()
    else:
        action = dist.sample()

    log_prob = dist.log_prob(action)
    entropy = dist.entropy()
    return action, log_prob, entropy
def forward(self, x, n_samples, reparam=True, squeeze=True):
    q_m = self.mean_encoder(x)
    l_mat = self.var_encoder
    q_v = l_mat.matmul(l_mat.T)
    variational_dist = MultivariateNormal(loc=q_m, scale_tril=l_mat)

    if squeeze and n_samples == 1:
        sample_shape = []
    else:
        sample_shape = (n_samples,)

    if reparam:
        latent = variational_dist.rsample(sample_shape=sample_shape)
    else:
        latent = variational_dist.sample(sample_shape=sample_shape)

    return dict(q_m=q_m, q_v=q_v, latent=latent)
def test_log_prob(self):
    loc = torch.ones(self.d)
    wdw = self.W @ torch.diag(self.D) @ self.W.t()
    sI = self.s2 * self.Id
    sigma = sI + wdw

    dist2 = MultivariateNormal(loc, covariance_matrix=sigma)
    samples = dist2.rsample([10000])
    exp_logp = dist2.log_prob(samples)

    dist1 = MultivariateNormalFactorIdentity(loc, self.s2, self.D, self.W)
    res_logp = dist1.log_prob(samples)

    self.assertAlmostEqual(float(exp_logp.mean()), float(res_logp.mean()), places=3)
class TanhNormal(Distribution):
    def __init__(self, loc, scale):
        super().__init__()
        self.normal = MultivariateNormal(loc, scale)

    def sample(self):
        return torch.tanh(self.normal.sample())

    def rsample(self):
        return torch.tanh(self.normal.rsample())

    # Calculates log probability of value using the change-of-variables technique
    # (uses log1p = log(1 + x) for extra numerical stability).
    # Note: the Jacobian term below keeps its per-dimension (event-dim) shape and is
    # not summed, while MultivariateNormal.log_prob returns a batch-shaped value.
    def log_prob(self, value):
        inv_value = (torch.log1p(value) - torch.log1p(-value)) / 2  # artanh(y)
        # log p(f^-1(y)) + log |det(J(f^-1(y)))|
        return self.normal.log_prob(inv_value) - torch.log1p(-value.pow(2) + 1e-6)

    @property
    def mean(self):
        return torch.tanh(self.normal.mean)
def loss(self, Y, dT, T, switch=None, K_zz_inv=None, K_xz=None,
         qmu=None, qs=None, anneal=1.):
    # Calculating covariance
    T_induce, induce_idx = self.get_T_induce(T)
    if K_zz_inv is None:
        K_zz_inv = self.calc_K_inv(T_induce)
    if K_xz is None:
        K_xz = self.calc_K_xz(T_induce, T)

    # GP loss
    if qmu is not None and qs is not None:
        q_u = MultivariateNormal(qmu, torch.diag(qs.exp()))
    else:
        mu, log_var = self.x_model(Y, dT, induce_idx, switch)
        q_u = MultivariateNormal(mu.squeeze(), torch.diag(log_var.squeeze().exp()))

    # sparse GP KL
    u = q_u.rsample().squeeze()
    p_u = MultivariateNormal(torch.zeros(u.shape[0], dtype=torch.float64),
                             precision_matrix=K_zz_inv)
    kl = kl_divergence(q_u, p_u)

    # HMM loss
    X = torch.mv(K_xz, torch.mv(K_zz_inv, u))
    log_pi0, log_pi, log_ab = self.calc_params(X)
    ll = self.log_like_dT(dT, log_ab) + self.likelihood.mixture_prob(Y)
    loss_hmm = -1. * hmmnorm_cython(log_pi0, log_pi.contiguous(), ll.contiguous())

    loss = loss_hmm + anneal * kl
    return loss
def _mcvi_forward(self, x):
    W_var = self._get_var(self.W_logvar)
    bias_var = self._get_var(self.bias_logvar)

    if self.certain:
        x_mean = x
        x_var = None
    else:
        x_mean = x[0]
        x_var = x[1]

    y_mean = F.linear(x_mean, self.W.t()) + self.bias

    if self.certain or not self.deterministic:
        xx = x_mean * x_mean
        y_var = torch.diag_embed(F.linear(xx, W_var.t()) + bias_var)
    else:
        y_var = compute_linear_var(x_mean, x_var, self.W, W_var,
                                   self.bias, bias_var)

    dst = MultivariateNormal(loc=y_mean, covariance_matrix=y_var)
    sample = dst.rsample()
    return sample, None
def _loss_em_mc_efficient(
    self,
    past_targets: Union[Sequence[torch.Tensor], torch.Tensor],
    past_controls: Optional[Union[Sequence[ControlInputs], ControlInputs]] = None,
) -> torch.Tensor:
    """
    Monte Carlo loss as computed in KVAE paper.
    Can be computed more efficiently if no missing data (no imputation),
    by batching some things along time-axis.
    """
    past_controls = self._expand_particle_dim(past_controls)
    n_batch = len(past_targets[0])

    # A) SSM related distributions:
    # A1) smoothing.
    latents_smoothed = self._smooth_efficient(
        past_targets=past_targets,
        past_controls=past_controls,
        return_time_tensor=True,
    )
    state_smoothed_dist = MultivariateNormal(
        loc=latents_smoothed.variables.m,
        covariance_matrix=latents_smoothed.variables.V,
    )
    x = state_smoothed_dist.rsample()
    gls_params = latents_smoothed.gls_params

    # A2) prior && posterior transition distribution.
    prior_dist = self.state_prior_model(
        None, batch_shape_to_prepend=(self.n_particle, n_batch))

    # A, B, R are already 0:T-1.
    transition_dist = MultivariateNormal(
        loc=matvec(gls_params.A[:-1], x[:-1])
        + (matvec(gls_params.B[:-1], past_controls.state[:-1])
           if gls_params.B is not None else 0.0),
        covariance_matrix=gls_params.R[:-1],
    )

    # A3) posterior predictive (auxiliary) distribution.
    auxiliary_predictive_dist = MultivariateNormal(
        loc=matvec(gls_params.C, x)
        + (matvec(gls_params.D, past_controls.target)
           if gls_params.D is not None else 0.0),
        covariance_matrix=gls_params.Q,
    )

    # A4) SSM related losses
    l_prior = (-prior_dist.log_prob(x[0:1]).sum(dim=(0, 1))
               / self.n_particle)  # time and particle dim
    l_transition = (-transition_dist.log_prob(x[1:]).sum(dim=(0, 1))
                    / self.n_particle)  # time and particle dim
    l_auxiliary = (-auxiliary_predictive_dist.log_prob(
        latents_smoothed.variables.auxiliary).sum(dim=(0, 1))
        / self.n_particle)  # time and particle dim
    l_entropy = (
        state_smoothed_dist.log_prob(x).sum(dim=(0, 1))  # negative entropy
        / self.n_particle)  # time and particle dim

    # B) VAE related distributions
    # B1) inv_measurement_dist already obtained from smoothing (as we dont want to re-compute)
    # B2) measurement (decoder) distribution
    # transpose TPBF -> PTBF to broadcast log_prob of y (TBF) correctly
    z_particle_first = latents_smoothed.variables.auxiliary.transpose(0, 1)
    measurement_dist = self.measurement_model(z_particle_first)

    # B3) VAE related losses
    l_measurement = (-measurement_dist.log_prob(past_targets).sum(dim=(0, 1))
                     / self.n_particle)  # time and particle dim
    auxiliary_variational_dist = MultivariateNormal(
        loc=latents_smoothed.variables.m_auxiliary_variational,
        covariance_matrix=latents_smoothed.variables.V_auxiliary_variational,
    )
    l_inv_measurement = (
        auxiliary_variational_dist.log_prob(z_particle_first).sum(dim=(0, 1))
        / self.n_particle)  # time and particle dim

    assert all(t.shape == l_prior.shape for t in (
        l_prior,
        l_transition,
        l_auxiliary,
        l_measurement,
        l_inv_measurement,
    ))
    l_total = (self.reconstruction_weight * l_measurement
               + l_inv_measurement
               + l_auxiliary
               + l_prior
               + l_transition
               + l_entropy)
    return l_total
def sample_activations(x, n_samples):
    x_mean, x_var = x[0], x[1]
    sampler = MultivariateNormal(loc=x_mean, covariance_matrix=x_var)
    samples = sampler.rsample([n_samples])
    return samples
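A minimal usage sketch for `sample_activations` above; the batch size, feature width, and the diagonal covariance built with `torch.diag_embed` are illustrative assumptions, not part of the original snippet.

import torch

x_mean = torch.randn(4, 3)                                # per-unit activation means
x_var = torch.diag_embed(torch.full((4, 3), 0.1))         # (4, 3, 3) diagonal covariances
acts = sample_activations((x_mean, x_var), n_samples=16)  # (16, 4, 3) reparameterised draws
print(acts.shape)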
class GaussianTorchDistribution(TorchDistribution):
    def __init__(self, mu, chol_flat, use_cuda):
        super().__init__(use_cuda)
        self._dim = mu.shape[0]

        self._mu = nn.Parameter(torch.as_tensor(mu, dtype=torch.float32),
                                requires_grad=True)
        self._chol_flat = nn.Parameter(torch.as_tensor(chol_flat, dtype=torch.float32),
                                       requires_grad=True)

        self.distribution_t = MultivariateNormal(
            self._mu, scale_tril=self.to_tril_matrix(self._chol_flat, self._dim))

    def __copy__(self):
        return GaussianTorchDistribution(self._mu, self._chol_flat, self.use_cuda)

    def __deepcopy__(self, memodict=None):
        return GaussianTorchDistribution(copy.deepcopy(self._mu),
                                         copy.deepcopy(self._chol_flat),
                                         self.use_cuda)

    @staticmethod
    def to_tril_matrix(chol_flat, dim):
        if isinstance(chol_flat, np.ndarray):
            chol = np.zeros((dim, dim))
            exp_fun = np.exp
        else:
            chol = torch.zeros((dim, dim))
            exp_fun = torch.exp

        d1, d2 = np.diag_indices(dim)
        chol[d1, d2] += exp_fun(chol_flat[0:dim])
        ld1, ld2 = np.tril_indices(dim, k=-1)
        chol[ld1, ld2] += chol_flat[dim:]

        return chol

    @staticmethod
    def flatten_matrix(mat, tril=False):
        if not tril:
            mat = scpla.cholesky(mat, lower=True)

        dim = mat.shape[0]
        d1, d2 = np.diag_indices(dim)
        ld1, ld2 = np.tril_indices(dim, k=-1)
        return np.concatenate((np.log(mat[d1, d2]), mat[ld1, ld2]))

    def entropy_t(self):
        return self.distribution_t.entropy()

    def mean_t(self):
        return self.distribution_t.mean

    def log_pdf_t(self, x):
        return self.distribution_t.log_prob(x)

    def sample(self):
        return self.distribution_t.rsample()

    def covariance_matrix(self):
        return self.distribution_t.covariance_matrix.detach().numpy()

    def set_weights(self, weights):
        set_weights([self._mu], weights[0:self._dim], self._use_cuda)
        set_weights([self._chol_flat], weights[self._dim:], self._use_cuda)

        # This is important - otherwise the changes will not be reflected!
        self.distribution_t = MultivariateNormal(
            self._mu, scale_tril=self.to_tril_matrix(self._chol_flat, self._dim))

    def get_weights(self):
        mu_weights = get_weights([self._mu])
        chol_flat_weights = get_weights([self._chol_flat])

        return np.concatenate([mu_weights, chol_flat_weights])

    def parameters(self):
        return [self._mu, self._chol_flat]
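A hedged round-trip check for the two static helpers above, assuming the class's module imports `numpy as np` and `scipy.linalg as scpla` (as its calls imply); the 2x2 covariance is illustrative only.

import numpy as np

# Illustrative SPD covariance; flatten to (log-diagonal, strict lower triangle) and back.
cov = np.array([[2.0, 0.5],
                [0.5, 1.0]])
flat = GaussianTorchDistribution.flatten_matrix(cov)          # length dim + dim*(dim-1)/2
chol = GaussianTorchDistribution.to_tril_matrix(flat, dim=2)  # numpy branch of the helper
np.testing.assert_allclose(chol @ chol.T, cov, atol=1e-6)     # recovers the original covariance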
def mvnrnd(mu, sigma, sample_shape=()):
    d = MultivariateNormal(loc=mu, covariance_matrix=sigma)
    return d.rsample(sample_shape)
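A short usage sketch for `mvnrnd` above; the two-dimensional mean and covariance are illustrative only.

import torch

mu = torch.tensor([0.0, 1.0])
sigma = torch.tensor([[1.0, 0.3],
                      [0.3, 2.0]])            # symmetric positive-definite covariance
draws = mvnrnd(mu, sigma, sample_shape=(5,))  # (5, 2) reparameterised (rsample) draws
print(draws.shape)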
def forward(self, x, mask, num_particles=4):
    log_hat_p_acc = torch.zeros(x.size(1)).to(device)  # (batch_size, )
    kl_acc = torch.zeros(x.size(1)).to(device)  # (batch_size, )

    h = Variable(torch.zeros(self.n_layers, x.size(1) * num_particles,
                             self.h_dim)).to(device)
    c = Variable(torch.zeros(self.n_layers, x.size(1) * num_particles,
                             self.h_dim)).to(device)

    # with torch.autograd.set_detect_anomaly(True):
    for t in range(x.size(0)):
        # VRNN Cell
        xts = x[t].repeat((1, num_particles)).reshape(
            (x.size(1) * num_particles, -1))
        phi_x_ts = self.phi_x(xts)  # [batch_size, num_particles, embed_size]

        enc_t = self.enc(torch.cat([phi_x_ts, h[-1]], 1))  # [batch_size, num_particles, embed_size]
        enc_mean_t = self.enc_mean(enc_t)  # [batch_size, num_particles, latent_size]
        enc_std_t = self.enc_std(enc_t)  # [batch_size, num_particles, latent_size]
        encoder_dist = MultivariateNormal(
            enc_mean_t, scale_tril=torch.diag_embed(enc_std_t))

        prior_t = self.prior(h[-1])
        prior_mean_t = self.prior_mean(prior_t)
        prior_std_t = self.prior_std(prior_t)
        prior_dist = MultivariateNormal(
            prior_mean_t, scale_tril=torch.diag_embed(prior_std_t))

        z_t_is = encoder_dist.rsample()  # reparametrizable  # [batch_size * seq_len, latent_size]
        phi_z_ts = self.phi_z(z_t_is)

        dec_t = self.dec(torch.cat([phi_z_ts, h[-1]], 1))
        dec_mean_t = self.dec_mean(dec_t)
        decoder_dist = Bernoulli(probs=dec_mean_t)

        prior_logprob_ti = prior_dist.log_prob(z_t_is.detach())
        encoder_logprob_ti = encoder_dist.log_prob(z_t_is.detach())
        decoder_logprob_ti = decoder_dist.log_prob(xts).sum(-1)

        # recurrence
        _, (h, c) = self.rnn(torch.cat([phi_x_ts, phi_z_ts], 1).unsqueeze(0), (h, c))

        kl = torch.distributions.kl_divergence(encoder_dist, prior_dist)
        kl_acc += kl.mean(-1) * mask[t]

        nll = self._nll_bernoulli(dec_mean_t, xts)
        log_alpha_ti = -(nll + kl)
        # log_alpha_ti = prior_logprob_ti + decoder_logprob_ti - encoder_logprob_ti  # [batch_size, ]
        log_alpha_ti = log_alpha_ti.reshape(x.size(1), -1)  # [batch_size, num_particles]
        log_alpha_ti = log_alpha_ti * mask[t][None].T  # [batch_size, num_particles] * [batch_size, 1]

        # hat_p = torch.exp(logweight_acc + log_alpha_ti)  # [batch_size, num_particles]
        # log_hat_p = torch.exp(logweight_acc).sum(-1))
        log_hat_p = torch.logsumexp(log_alpha_ti.clone(), dim=-1) - math.log(float(num_particles))
        log_hat_p_acc += log_hat_p * mask[t]
        # logweight_acc *= (1. - should_resample_tiled.reshape(x.size(1), num_particles).float())

    iwae_bound = torch.sum(log_hat_p_acc)
    # kl_acc = kl_acc.mean(-1)
    # kl = torch.mean(kl_acc.reshape(x.size(1), -1), dim=-1)

    # return fivo_loss, kld_loss, nll_loss, \
    #     (all_enc_mean, all_enc_std), \
    #     (all_dec_mean, all_dec_std), \
    #     log_hat_ps
    return -iwae_bound, log_hat_p_acc, _, kl_acc, log_hat_p_acc
def logprob_w_cov_gaussian_posterior(self, input, sample_size=128, z=None, std=None):
    # init
    batch_size = input.size(0)
    input = input.view(batch_size, self.input_dim)
    assert sample_size >= 2 * self.z_dim

    ''' get z and pseudo log q(newz|x) '''
    z, newz = [], []
    #cov_qz, rv_z = [], []
    logposterior = []
    inp = self.encode._forward_inp(input).detach()
    #for i in range(sample_size):
    for i in range(batch_size):
        _inp = inp[i:i + 1, :].expand(sample_size, inp.size(1))
        _nos = self.encode._forward_nos(batch_size=sample_size,
                                        std=std,
                                        device=input.device).detach()
        _z = self.encode._forward_all(_inp, _nos)  # ssz x zdim
        z += [_z.detach().unsqueeze(0)]
    z = torch.cat(z, dim=0)  # bsz x ssz x zdim
    mu_qz = torch.mean(z, dim=1)  # bsz x zdim
    for i in range(batch_size):
        _cov_qz = get_covmat(z[i, :, :])
        _rv_z = MultivariateNormal(mu_qz[i], _cov_qz)
        _newz = _rv_z.rsample(torch.Size([1, sample_size]))
        _logposterior = _rv_z.log_prob(_newz)

        #cov_qz += [_cov_qz.unsqueeze(0)]
        #rv_z += [_rv_z]
        newz += [_newz]
        logposterior += [_logposterior]
    #cov_qz = torch.cat(cov_qz, dim=0)  # bsz x zdim x zdim
    newz = torch.cat(newz, dim=0)  # bsz x ssz x zdim
    logposterior = torch.cat(logposterior, dim=0)  # bsz x ssz

    ''' get log p(z) '''
    # get prior (as unit normal dist)
    mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
    logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
    logprior = logprob_gaussian(mu_pz, logvar_pz, newz,
                                do_unsqueeze=False, do_mean=False)
    logprior = torch.sum(logprior.view(batch_size, sample_size, self.z_dim),
                         dim=2)  # bsz x ssz

    ''' get log p(x|z) '''
    # decode
    logit_x = []
    #for i in range(sample_size):
    for i in range(batch_size):
        _, _logit_x = self.decode(newz[i, :, :])  # ssz x zdim
        logit_x += [_logit_x.detach().unsqueeze(0)]
    logit_x = torch.cat(logit_x, dim=0)  # bsz x ssz x input_dim
    _input = input.unsqueeze(1).expand(batch_size, sample_size,
                                       self.input_dim)  # bsz x ssz x input_dim
    loglikelihood = -F.binary_cross_entropy_with_logits(logit_x, _input,
                                                        reduction='none')
    loglikelihood = torch.sum(loglikelihood, dim=2)  # bsz x ssz

    ''' get log p(x|z)p(z)/q(z|x) '''
    logprob = loglikelihood + logprior - logposterior  # bsz x ssz
    logprob_max, _ = torch.max(logprob, dim=1, keepdim=True)
    rprob = (logprob - logprob_max).exp()  # relative prob
    logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + logprob_max  # bsz x 1

    # return
    return logprob.mean()
def forward(self, x, mask, num_particles=4):
    logweight_acc = torch.zeros(x.size(1), num_particles).to(device)  # (batch_size, num_particles)
    log_hat_p_acc = torch.zeros(x.size(1)).to(device)  # (batch_size, )
    log_hat_p_iwae_acc = torch.zeros(x.size(1)).to(device)
    kl_acc = torch.zeros(x.size(1)).to(device)  # (batch_size, )

    # [0, 1, 2, 3, 4, 5, 6, 7, ... ]
    noresampleidxs = torch.arange(x.size(1) * num_particles).to(device)

    h = Variable(torch.zeros(self.n_layers, x.size(1) * num_particles,
                             self.h_dim)).to(device)
    c = Variable(torch.zeros(self.n_layers, x.size(1) * num_particles,
                             self.h_dim)).to(device)

    # with torch.autograd.set_detect_anomaly(True):
    for t in range(x.size(0)):
        # VRNN Cell
        xts = x[t].repeat((1, num_particles)).reshape(
            (x.size(1) * num_particles, -1))
        phi_x_ts = self.phi_x(xts)  # [batch_size * num_particle, embed_size]

        enc_t = self.enc(torch.cat([phi_x_ts, h[-1]], 1))
        enc_mean_t = self.enc_mean(enc_t)
        enc_std_t = self.enc_std(enc_t)
        encoder_dist = MultivariateNormal(
            enc_mean_t, scale_tril=torch.diag_embed(enc_std_t))

        prior_t = self.prior(h[-1])
        prior_mean_t = self.prior_mean(prior_t)
        prior_std_t = self.prior_std(prior_t)
        prior_dist = MultivariateNormal(
            prior_mean_t, scale_tril=torch.diag_embed(prior_std_t))

        z_t_is = encoder_dist.rsample()  # reparametrizable  # [batch_size * seq_len, latent_size]
        phi_z_ts = self.phi_z(z_t_is)

        dec_t = self.dec(torch.cat([phi_z_ts, h[-1]], 1))
        dec_mean_t = self.dec_mean(dec_t)
        decoder_dist = Bernoulli(probs=dec_mean_t)

        prior_logprob_ti = prior_dist.log_prob(z_t_is.detach()) + 1e-7
        encoder_logprob_ti = encoder_dist.log_prob(z_t_is.detach()) + 1e-7
        decoder_logprob_ti = decoder_dist.log_prob(xts).sum(-1) + 1e-7

        # recurrence
        _, (h, c) = self.rnn(torch.cat([phi_x_ts, phi_z_ts], 1).unsqueeze(0), (h, c))

        kl = torch.distributions.kl_divergence(encoder_dist, prior_dist)
        kl_acc += kl.mean(-1) * mask[t]

        nll = self._nll_bernoulli(dec_mean_t, xts)
        # log_alpha_ti = prior_logprob_ti + decoder_logprob_ti - encoder_logprob_ti  # [batch_size, ]
        log_alpha_ti = -(nll + kl)
        log_alpha_ti = log_alpha_ti.reshape(x.size(1), -1)  # [batch_size, num_particles]
        log_alpha_ti = log_alpha_ti * mask[t][None].T  # [batch_size, num_particles] * [batch_size, 1]

        # hat_p = torch.exp(logweight_acc + log_alpha_ti)  # [batch_size, num_particles]
        logweight_acc += log_alpha_ti

        # Add resampling procedure here
        # ess = 1. / (torch.exp(logweight_acc) ** 2).sum(-1)  # [batch_size, ]
        # logess = torch.log(1. / (torch.exp(logweight_acc) ** 2).sum(-1))
        logess_num = 2 * torch.logsumexp(logweight_acc, dim=-1)
        logess_denom = torch.logsumexp(2 * logweight_acc, dim=-1)
        logess = logess_num - logess_denom

        if not self.use_resampling_gradient:
            resample_dist = Categorical(
                logits=logweight_acc.reshape(x.size(1), num_particles))
            resampled_idxs = resample_dist.sample([num_particles]).T

            # [0, 0, 0, 0, 4, 4, 4, 4, ... ]
            sample_offset = torch.arange(x.size(1)).repeat(
                [num_particles, 1]).T.reshape(-1).to(device) * num_particles
            resampled_idxs = resampled_idxs.reshape(-1) + sample_offset

            should_resample = logess <= torch.log(
                torch.ones_like(logess).to(device) * num_particles / 2.0)
            should_resample = should_resample & mask[t].bool()
            should_resample_tiled = should_resample.repeat(
                [num_particles, 1]).T.reshape(-1)

            new_idxs = torch.where(should_resample_tiled, resampled_idxs,
                                   noresampleidxs)

            h[-1] = h[-1][new_idxs]
            c[-1] = c[-1][new_idxs]

            log_hat_p = torch.logsumexp(logweight_acc.clone(),
                                        dim=-1) - math.log(float(num_particles))
            log_hat_p_acc += log_hat_p * should_resample.float()
            logweight_acc *= (1. - should_resample_tiled.reshape(
                x.size(1), num_particles).float())
        else:
            # raise NotImplementedError
            resample_dist = RelaxedOneHotCategorical(
                logits=logweight_acc.reshape(x.size(1), num_particles),
                temperature=0.1)
            resampled_onehot_relaxedidxs = resample_dist.rsample(
                [num_particles]).permute(1, 0, 2)  # .reshape(-1, num_particles)

            should_resample = logess <= torch.log(
                torch.ones_like(logess).to(device) * num_particles / 2.0)
            should_resample = should_resample & mask[t].bool()
            should_resample_tiled = should_resample.repeat(
                [num_particles, 1]).T.reshape(-1)

            # noresample_onehot = torch.eye(x.size(1) * num_particles)
            for batch_idx in range(x.size(1)):
                if should_resample[batch_idx]:
                    # cur_slice = (batch_idx * x.size(1) * num_particles) : (batch_idx * x.size(1) * num_particles + x.size(1) * num_particles)
                    h[-1][(batch_idx * num_particles):(batch_idx * num_particles + num_particles)] = \
                        resampled_onehot_relaxedidxs[batch_idx] @ h[-1][(batch_idx * num_particles):(batch_idx * num_particles + num_particles)].clone()
                    c[-1][(batch_idx * num_particles):(batch_idx * num_particles + num_particles)] = \
                        resampled_onehot_relaxedidxs[batch_idx] @ c[-1][(batch_idx * num_particles):(batch_idx * num_particles + num_particles)].clone()

            log_hat_p = torch.logsumexp(logweight_acc.clone(),
                                        dim=-1) - math.log(float(num_particles))
            log_hat_p_acc += log_hat_p * should_resample.float()
            logweight_acc *= (1. - should_resample_tiled.reshape(
                x.size(1), num_particles).float())

        log_hat_p_iwae_acc += (torch.logsumexp(log_alpha_ti.detach(), dim=-1)
                               - math.log(float(num_particles))) * mask[t]

    # computing losses
    # kld_loss /= self.num_zs
    # nll_loss /= self.num_zs
    log_hat_p_acc += torch.logsumexp(logweight_acc, dim=-1) - math.log(float(num_particles))

    fivo_bound = torch.sum(log_hat_p_acc)
    # kl = torch.mean(kl_acc.reshape(x.size(1), -1), dim=-1)

    # return fivo_loss, kld_loss, nll_loss, \
    #     (all_enc_mean, all_enc_std), \
    #     (all_dec_mean, all_dec_std), \
    #     log_hat_ps
    return -fivo_bound, log_hat_p_acc, logweight_acc, kl_acc, log_hat_p_iwae_acc
def _loss_em_mc(
    self,
    past_targets: Union[Sequence[torch.Tensor], torch.Tensor],
    past_controls: Optional[Union[Sequence[ControlInputs], ControlInputs]] = None,
    past_targets_is_observed: Optional[Union[Sequence[torch.Tensor], torch.Tensor]] = None,
) -> torch.Tensor:
    """ Monte Carlo loss as computed in KVAE paper """
    n_batch = len(past_targets[0])

    past_controls = self._expand_particle_dim(past_controls)

    # A) SSM related distributions:
    # A1) smoothing.
    latents_smoothed = self.smooth(
        past_targets=past_targets,
        past_controls=past_controls,
        past_targets_is_observed=past_targets_is_observed,
    )
    m = torch.stack([l.variables.m for l in latents_smoothed])
    V = torch.stack([l.variables.V for l in latents_smoothed])
    z = torch.stack([l.variables.auxiliary for l in latents_smoothed])

    state_smoothed_dist = MultivariateNormal(loc=m, covariance_matrix=V)
    x = state_smoothed_dist.rsample()

    A = torch.stack([l.gls_params.A for l in latents_smoothed])
    C = torch.stack([l.gls_params.C for l in latents_smoothed])
    LR = torch.stack([l.gls_params.LR for l in latents_smoothed])
    LQ = torch.stack([l.gls_params.LQ for l in latents_smoothed])
    if latents_smoothed[0].gls_params.B is not None:
        B = torch.stack([l.gls_params.B for l in latents_smoothed])
    else:
        B = None
    if latents_smoothed[0].gls_params.D is not None:
        D = torch.stack([l.gls_params.D for l in latents_smoothed])
    else:
        D = None

    # A2) prior && posterior transition distribution.
    prior_dist = self.state_prior_model(
        None, batch_shape_to_prepend=(self.n_particle, n_batch))

    # A, B, R are already 0:T-1.
    transition_dist = MultivariateNormal(
        loc=matvec(A[:-1], x[:-1])
        + (matvec(B[:-1], past_controls.state[:-1]) if B is not None else 0.0),
        scale_tril=LR[:-1],
    )
    # A3) posterior predictive (auxiliary) distribution.
    auxiliary_predictive_dist = MultivariateNormal(
        loc=matvec(C, x)
        + (matvec(D, past_controls.target) if D is not None else 0.0),
        scale_tril=LQ,
    )

    # A4) SSM related losses
    # mean over particle dim, sum over time (after masking), leave batch dim
    l_prior = -prior_dist.log_prob(x[0:1]).mean(dim=1).sum(dim=0)
    l_transition = -transition_dist.log_prob(x[1:]).mean(dim=1).sum(dim=0)
    l_entropy = state_smoothed_dist.log_prob(x).mean(dim=1).sum(dim=0)

    _l_aux_timewise = -auxiliary_predictive_dist.log_prob(z).mean(dim=1)
    if past_targets_is_observed is not None:
        _l_aux_timewise = _l_aux_timewise * past_targets_is_observed
    l_auxiliary = _l_aux_timewise.sum(dim=0)

    # B) VAE related distributions
    # B1) inv_measurement_dist already obtained from smoothing (as we dont want to re-compute)
    # B2) measurement (decoder) distribution
    # transpose TPBF -> PTBF to broadcast log_prob of y (TBF) correctly
    z_particle_first = z.transpose(0, 1)
    measurement_dist = self.measurement_model(z_particle_first)

    # B3) VAE related losses
    # We use z_particle_first for correct broadcasting -> dim=0 is particle.
    _l_meas_timewise = -measurement_dist.log_prob(past_targets).mean(dim=0)
    if past_targets_is_observed is not None:
        _l_meas_timewise = _l_meas_timewise * past_targets_is_observed
    l_measurement = _l_meas_timewise.sum(dim=0)

    auxiliary_variational_dist = MultivariateNormal(
        loc=torch.stack([
            l.variables.m_auxiliary_variational for l in latents_smoothed
        ]),
        covariance_matrix=torch.stack([
            l.variables.V_auxiliary_variational for l in latents_smoothed
        ]),
    )
    _l_variational_timewise = auxiliary_variational_dist.log_prob(
        z_particle_first).mean(dim=0)  # again dim=0 is particle dim here.
    if past_targets_is_observed is not None:
        _l_variational_timewise = (_l_variational_timewise
                                   * past_targets_is_observed)
    l_inv_measurement = _l_variational_timewise.sum(dim=0)

    assert all(t.shape == l_prior.shape for t in (
        l_prior,
        l_transition,
        l_auxiliary,
        l_measurement,
        l_inv_measurement,
    ))

    l_total = (self.reconstruction_weight * l_measurement
               + l_inv_measurement
               + l_auxiliary
               + l_prior
               + l_transition
               + l_entropy)
    return l_total
def logprob_w_cov_gaussian_posterior(self, input, sample_size=128, z=None, std=None):
    # init
    batch_size = input.size(0)
    input = input.view(batch_size, self.input_dim)
    assert sample_size >= 2 * self.z_dim
    #assert int(math.sqrt(sample_size))**2 == sample_size

    ''' get z and pseudo log q(newz|x) '''
    #z, newz = [], []
    #logposterior = []
    #inp = self.encode._forward_inp(input).detach()
    #for i in range(batch_size):
    #    _inp = inp[i:i+1, :].expand(sample_size, inp.size(1))
    #    _nos = self.encode._forward_nos(sample_size, std=std, device=input.device).detach()
    #    _z = self.encode._forward_all(_inp, _nos)  # ssz x zdim
    #    z += [_z.detach().unsqueeze(0)]
    #z = torch.cat(z, dim=0)  # bsz x ssz x zdim
    #_nz = int(math.sqrt(sample_size))
    _, _, _, _, z, _, _, _, _ = self.encode._forward(
        input, std=std, nz=sample_size)  # bsz x ssz x zdim

    newz = []
    logposterior = []
    eye = torch.eye(self.z_dim, device=z.device)
    mu_qz = torch.mean(z, dim=1)  # bsz x zdim
    for i in range(batch_size):
        _cov_qz = get_covmat(z[i, :, :]) + 1e-5 * eye
        _rv_z = MultivariateNormal(mu_qz[i], _cov_qz)
        _newz = _rv_z.rsample(torch.Size([1, sample_size]))
        _logposterior = _rv_z.log_prob(_newz)

        newz += [_newz]
        logposterior += [_logposterior]
    newz = torch.cat(newz, dim=0)  # bsz x ssz x zdim
    logposterior = torch.cat(logposterior, dim=0)  # bsz x ssz

    ''' get log p(z) '''
    # get prior (as unit normal dist)
    mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
    logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
    logprior = logprob_gaussian(mu_pz, logvar_pz, newz,
                                do_unsqueeze=False, do_mean=False)
    logprior = torch.sum(logprior.view(batch_size, sample_size, self.z_dim),
                         dim=2)  # bsz x ssz

    ''' get log p(x|z) '''
    # decode
    mu_x, logvar_x = [], []
    #for i in range(sample_size):
    for i in range(batch_size):
        _, _mu_x, _logvar_x = self.decode(newz[i, :, :])
        mu_x += [_mu_x.detach().unsqueeze(0)]
        logvar_x += [_logvar_x.detach().unsqueeze(0)]
    mu_x = torch.cat(mu_x, dim=0)  # bsz x ssz x input_dim
    logvar_x = torch.cat(logvar_x, dim=0)  # bsz x ssz x input_dim
    _input = input.unsqueeze(1).expand(batch_size, sample_size,
                                       self.input_dim)  # bsz x ssz x input_dim
    loglikelihood = logprob_gaussian(mu_x, logvar_x, _input,
                                     do_unsqueeze=False, do_mean=False)
    loglikelihood = torch.sum(loglikelihood, dim=2)  # bsz x ssz

    ''' get log p(x|z)p(z)/q(z|x) '''
    logprob = loglikelihood + logprior - logposterior  # bsz x ssz
    logprob_max, _ = torch.max(logprob, dim=1, keepdim=True)
    rprob = (logprob - logprob_max).exp()  # relative prob
    logprob = torch.log(torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + logprob_max  # bsz x 1

    # return
    return logprob.mean()
def filter_step(
    self,
    lats_tm1: Optional[LatentsKVAE],
    tar_t: torch.Tensor,
    ctrl_t: ControlInputs,
    tar_is_obs_t: Optional[torch.Tensor] = None,
) -> LatentsKVAE:
    is_initial_step = lats_tm1 is None
    if tar_is_obs_t is None:
        tar_is_obs_t = torch.ones(
            tar_t.shape[:-1],
            dtype=tar_t.dtype,
            device=tar_t.device,
        )

    # 1) Initial step must prepare previous latents with prior and learnt z.
    if is_initial_step:
        n_particle, n_batch = self.n_particle, len(tar_t)
        state_prior = self.state_prior_model(
            None, batch_shape_to_prepend=(n_particle, n_batch),
        )
        z_init = self.z_initial[None, None].repeat(n_particle, n_batch, 1)
        lats_tm1 = LatentsKVAE(
            variables=GLSVariablesKVAE(
                m=state_prior.loc,
                V=state_prior.covariance_matrix,
                Cov=None,
                x=None,
                auxiliary=z_init,
                rnn_state=None,
                m_auxiliary_variational=None,
                V_auxiliary_variational=None,
            ),
            gls_params=None,
        )

    # 2) Compute GLS params
    rnn_state_t, rnn_output_t = self.compute_deterministic_switch_step(
        rnn_input=lats_tm1.variables.auxiliary,
        rnn_prev_state=lats_tm1.variables.rnn_state,
    )
    gls_params_t = self.gls_base_parameters(
        switch=rnn_output_t,
        controls=ctrl_t,
    )

    # Perform filter step:
    # 3) Prediction Step: Only for t > 0 and using previous GLS params.
    #    (In KVAE, they do first update then prediction step.)
    if is_initial_step:
        mp, Vp = lats_tm1.variables.m, lats_tm1.variables.V
    else:
        mp, Vp = filter_forward_prediction_step(
            m=lats_tm1.variables.m,
            V=lats_tm1.variables.V,
            R=lats_tm1.gls_params.R,
            A=lats_tm1.gls_params.A,
            b=lats_tm1.gls_params.b,
        )

    # 4) Update step
    # 4a) Observed data: Infer pseudo-obs by encoding obs && Bayes update
    auxiliary_variational_dist_t = self.encoder(tar_t)
    z_infer_t = auxiliary_variational_dist_t.rsample([self.n_particle])
    m_infer_t, V_infer_t = filter_forward_measurement_step(
        y=z_infer_t,
        m=mp,
        V=Vp,
        Q=gls_params_t.Q,
        C=gls_params_t.C,
        d=gls_params_t.d,
    )

    # 4b) Choice: inferred / predicted m, V for observed / missing data.
    is_filtered = tar_is_obs_t[None, :].repeat(self.n_particle, 1).byte()
    replace_m_fw = is_filtered[:, :, None].repeat(1, 1, mp.shape[2])
    replace_V_fw = is_filtered[:, :, None, None].repeat(
        1, 1, Vp.shape[2], Vp.shape[3],
    )
    assert replace_m_fw.shape == m_infer_t.shape == mp.shape
    assert replace_V_fw.shape == V_infer_t.shape == Vp.shape

    m_t = torch.where(replace_m_fw, m_infer_t, mp)
    V_t = torch.where(replace_V_fw, V_infer_t, Vp)

    # 4c) Missing Data: Predict pseudo-observations && No Bayes update
    mpz_t, Vpz_t = filter_forward_predictive_distribution(
        m=m_t,  # posterior predictive or one-step-predictive (if missing)
        V=V_t,
        Q=gls_params_t.Q,
        C=gls_params_t.C,
        d=gls_params_t.d,
    )
    auxiliary_predictive_dist_t = MultivariateNormal(
        loc=mpz_t,
        covariance_matrix=Vpz_t,
    )
    z_gen_t = auxiliary_predictive_dist_t.rsample()

    # 4d) Choice: inferred / predicted z for observed / missing data.
    #     One-step predictive if missing and inferred from encoder otherwise.
    replace_z = is_filtered[:, :, None].repeat(1, 1, z_gen_t.shape[2])
    z_t = torch.where(replace_z, z_infer_t, z_gen_t)

    # 5) Put result in Latents object, used in next iteration
    lats_t = LatentsKVAE(
        variables=GLSVariablesKVAE(
            m=m_t,
            V=V_t,
            Cov=None,
            x=None,
            auxiliary=z_t,
            rnn_state=rnn_state_t,
            m_auxiliary_variational=auxiliary_variational_dist_t.loc,
            V_auxiliary_variational=auxiliary_variational_dist_t.covariance_matrix,
        ),
        gls_params=gls_params_t,
    )
    return lats_t