def model(self, X):
    """
    Low-rank-plus-diagonal Gaussian model with a learnable mean vector.

    All prior hyperparameters are ``pyro.param`` sites initialised from
    ``self.param_init``; every parameter and latent site name carries the
    ``self._id`` suffix.  ``loc`` and ``cov_diag`` are drawn from LogNormal
    priors, ``cov_factor`` from a Normal prior, and rows of ``X`` are
    observed in subsampled mini-batches of ``self.batch_size``.

    :param X: two-dimensional data; dimension 0 is subsampled.
    :return: the value of the ``'obs'`` sample site.
    """
    _id = self._id
    n_factors = self.K
    n_rows, n_dims = X.shape

    # Initial values for the learnable prior hyperparameters.
    mean_loc_0 = self.param_init[f'loc_prior_loc_init_{_id}']
    mean_scale_0 = self.param_init[f'loc_prior_scale_init_{_id}']
    diag_loc_0 = self.param_init[f'cov_diag_prior_loc_init_{_id}']
    diag_scale_0 = self.param_init[f'cov_diag_prior_scale_init_{_id}']
    factor_loc_0 = self.param_init[f'cov_factor_prior_loc_init_{_id}']
    factor_scale_0 = self.param_init[f'cov_factor_prior_scale_init_{_id}']

    with pyro.plate(f'D_{_id}', n_dims):
        # Per-dimension mean.
        mean_loc = pyro.param(f'loc_loc_prior_{_id}', mean_loc_0)
        mean_scale = pyro.param(
            f'loc_scale_prior_{_id}',
            mean_scale_0,
            constraint=constraints.positive,
        )
        mean_prior = dist.LogNormal(mean_loc, mean_scale)
        loc = pyro.sample(f'loc_{_id}', mean_prior)

        # Per-dimension diagonal variance.
        diag_loc = pyro.param(f'cov_diag_prior_loc_{_id}', diag_loc_0)
        diag_scale = pyro.param(
            f'cov_diag_prior_scale_{_id}',
            diag_scale_0,
            constraint=constraints.positive,
        )
        diag_prior = dist.LogNormal(diag_loc, diag_scale)
        diag_draw = pyro.sample(f'cov_diag_{_id}', diag_prior)

        # Factor loadings: one row per factor, nested inside the D plate.
        with pyro.plate(f'K_{_id}', n_factors):
            factor_loc = pyro.param(
                f'cov_factor_prior_loc_{_id}', factor_loc_0
            )
            factor_scale = pyro.param(
                f'cov_factor_prior_scale_{_id}',
                factor_scale_0,
                constraint=constraints.positive,
            )
            factor_prior = dist.Normal(factor_loc, factor_scale)
            factor_draw = pyro.sample(f'cov_factor_{_id}', factor_prior)

    # `jitter` is resolved from outside this function.
    cov_diag = diag_draw + jitter
    # Swap the last two axes so the factor axis comes last.
    cov_factor = factor_draw.transpose(-2, -1)

    with pyro.plate(
        f'N_{_id}', size=n_rows, subsample_size=self.batch_size
    ) as ind:
        likelihood = dist.LowRankMultivariateNormal(
            loc, cov_factor=cov_factor, cov_diag=cov_diag
        )
        minibatch = X.index_select(0, ind)
        X = pyro.sample('obs', likelihood, obs=minibatch)
    return X
def model(self, X):
    """
    Zero-mean low-rank Gaussian model with shrinkage priors on the factors.

    A global scale ``tau`` (HalfNormal with a learnable scale) and local
    squared scales (InverseGamma, mixed through the auxiliary site ``'b'``)
    are combined by an outer product into the standard deviation of a
    zero-centred Normal prior on ``cov_factor``.  ``cov_diag`` has a
    LogNormal prior with learnable hyperparameters.  Rows of ``X`` are
    observed in subsampled mini-batches of ``self.batch_size``.

    :param X: two-dimensional data; dimension 0 is subsampled.
    :return: the value of the ``'obs'`` sample site.
    """
    _id = self._id
    n_rows, n_dims = X.shape

    tau_scale_0 = self.param_init[f'global_shrinkage_prior_scale_init_{_id}']
    diag_loc_0 = self.param_init[f'cov_diag_prior_loc_init_{_id}']
    diag_scale_0 = self.param_init[f'cov_diag_prior_scale_init_{_id}']

    # Global shrinkage scale.
    tau_scale = pyro.param(
        f'global_shrinkage_scale_prior_{_id}',
        tau_scale_0,
        constraint=constraints.positive,
    )
    tau = pyro.sample(f'global_shrinkage_{_id}', dist.HalfNormal(tau_scale))

    # Local shrinkage: auxiliary variable first, then the squared scales.
    aux_prior = dist.InverseGamma(0.5, 1./torch.ones(n_dims)**2).to_event(1)
    b = pyro.sample('b', aux_prior)
    local_prior = dist.InverseGamma(0.5, 1./b).to_event(1)
    local_sq = pyro.sample(f'local_shrinkage_{_id}', local_prior)

    # Diagonal variance; `jitter` is resolved from outside this function.
    diag_loc = pyro.param(f'cov_diag_prior_loc_{_id}', diag_loc_0)
    diag_scale = pyro.param(
        f'cov_diag_prior_scale_{_id}',
        diag_scale_0,
        constraint=constraints.positive,
    )
    diag_prior = dist.LogNormal(diag_loc, diag_scale).to_event(1)
    cov_diag = pyro.sample(f'cov_diag_{_id}', diag_prior) + jitter

    # Both operands of the outer product are the same in either branch.
    local_scale = torch.sqrt(local_sq.squeeze())
    tau_tiled = tau.repeat((tau.dim() - 1) * (1,) + (n_dims,))
    if local_scale.dim() == 1:
        # plain outer product
        factor_scale = torch.ger(local_scale, tau_tiled)
    else:
        # batched outer product
        factor_scale = torch.einsum('bp, br->bpr', local_scale, tau_tiled)

    factor_prior = dist.Normal(0., factor_scale).to_event(2)
    cov_factor = pyro.sample(f'cov_factor_{_id}', factor_prior)
    cov_factor = cov_factor.transpose(-2, -1)

    with pyro.plate(
        f'N_{_id}', size=n_rows, subsample_size=self.batch_size, dim=-1
    ) as ind:
        likelihood = dist.LowRankMultivariateNormal(
            torch.zeros(n_dims), cov_factor=cov_factor, cov_diag=cov_diag
        )
        X = pyro.sample('obs', likelihood, obs=X.index_select(0, ind))
    return X
def guide(self, batch, subsample, full_size):
    """
    Amortized guide: a linear layer maps each batch column to the location
    of a shared low-rank Gaussian over the grouped ``state_*`` sites.

    :param batch: two-dimensional tensor; its first dimension is used as
        the number of time steps and it is transposed before entering the
        linear layer.
    :param subsample: forwarded to ``self.plate`` as ``subsample``.
    :param full_size: forwarded to ``self.plate`` as the plate size.
    """
    # Only the first dimension is used; unpacking still requires 2-D input.
    num_time_steps, _ = batch.shape
    self.map_estimate("drift")

    group = self.group(match="state_[0-9]*")

    # `rank` is resolved from the enclosing scope.
    def init_cov_diag():
        return torch.full(group.event_shape, 0.01)

    def init_cov_factor():
        return torch.randn(group.event_shape + (rank,)) * 0.01

    cov_diag = pyro.param(
        "state_cov_diag", init_cov_diag, constraint=constraints.positive
    )
    cov_factor = pyro.param("state_cov_factor", init_cov_factor)

    # Build the amortizing layer once and keep it on the instance.
    if not hasattr(self, "nn"):
        width = group.event_shape.numel()
        layer = torch.nn.Linear(width, width)
        layer.weight.data.fill_(1.0 / num_time_steps)
        layer.bias.data.fill_(-0.5)
        self.nn = layer
    pyro.module("state_nn", self.nn)

    with self.plate("data", full_size, subsample=subsample):
        loc = self.nn(batch.t())
        posterior = dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag)
        group.sample("states", posterior)
def get_posterior(self, *args, **kwargs):
    """
    Returns a LowRankMultivariateNormal posterior distribution.

    ``self.cov_factor`` is multiplied by ``self.scale`` (broadcast over a
    new trailing axis) and the diagonal term is ``self.scale`` squared.
    ``args`` and ``kwargs`` are accepted but not used.
    """
    sigma = self.scale
    scaled_factor = self.cov_factor * sigma.unsqueeze(-1)
    diagonal = sigma * sigma
    return dist.LowRankMultivariateNormal(self.loc, scaled_factor, diagonal)
def get_posterior(self, *args, **kwargs):
    """
    Returns a LowRankMultivariateNormal posterior distribution.

    Registers three parameters named after ``self.prefix``: ``_loc``
    (initialised by ``self._init_loc``), ``_cov_factor`` (normal draws with
    standard deviation ``sqrt(0.5 / rank)``) and the positive ``_cov_diag``
    (filled with 0.5).  The factor and diagonal initial values are created
    from ``loc`` via ``new_empty`` / ``new_full``.  ``args`` and ``kwargs``
    are accepted but not used.
    """
    loc = pyro.param("{}_loc".format(self.prefix), self._init_loc)

    def init_factor():
        blank = loc.new_empty(self.latent_dim, self.rank)
        return blank.normal_(0, (0.5 / self.rank) ** 0.5)

    def init_diagonal():
        return loc.new_full((self.latent_dim,), 0.5)

    cov_factor = pyro.param("{}_cov_factor".format(self.prefix), init_factor)
    cov_diag = pyro.param(
        "{}_cov_diag".format(self.prefix),
        init_diagonal,
        constraint=constraints.positive,
    )
    return dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag)
def get_posterior(self, *args, **kwargs):
    """
    Returns a LowRankMultivariateNormal posterior distribution.

    Registers three parameters named after ``self.prefix``: ``_loc``
    (zeros), ``_cov_factor`` (standard-normal draws scaled by
    ``sqrt(0.5 / rank)``) and the positive ``_cov_diag`` (all 0.5).
    ``args`` and ``kwargs`` are accepted but not used.
    """
    def init_loc():
        return torch.zeros(self.latent_dim)

    def init_factor():
        return torch.randn(self.latent_dim, self.rank) * (0.5 / self.rank) ** 0.5

    def init_diagonal():
        return torch.ones(self.latent_dim) * 0.5

    mean = pyro.param("{}_loc".format(self.prefix), init_loc)
    cov_factor = pyro.param("{}_cov_factor".format(self.prefix), init_factor)
    cov_diag = pyro.param(
        "{}_cov_diag".format(self.prefix),
        init_diagonal,
        constraint=constraints.positive,
    )
    return dist.LowRankMultivariateNormal(mean, cov_factor, cov_diag)
def guide(self, batch, subsample, full_size):
    """
    Guide with one learnable location row per data point and a shared
    low-rank Gaussian covariance over the grouped ``state_*`` sites.

    :param batch: accepted but not used by this guide.
    :param subsample: forwarded to ``self.plate`` as ``subsample``.
    :param full_size: plate size, and leading dimension of ``state_loc``.
    """
    self.map_estimate("drift")

    group = self.group(match="state_[0-9]*")

    # `rank` is resolved from the enclosing scope.
    def init_cov_diag():
        return torch.full(group.event_shape, 0.01)

    def init_cov_factor():
        return torch.randn(group.event_shape + (rank,)) * 0.01

    def init_loc():
        return torch.full((full_size,) + group.event_shape, 0.5)

    cov_diag = pyro.param(
        "state_cov_diag", init_cov_diag, constraint=constraints.positive
    )
    cov_factor = pyro.param("state_cov_factor", init_cov_factor)

    with self.plate("data", full_size, subsample=subsample):
        # Declared inside the plate with event_dim=1.
        loc = pyro.param("state_loc", init_loc, event_dim=1)
        posterior = dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag)
        group.sample("states", posterior)
def sample_latent(self, *args, **kwargs):
    """
    Samples the (single) multivariate normal latent used in the auto guide.

    Registers ``_loc`` (zeros), ``_W_term`` (a ``rank x latent_dim`` matrix
    of standard-normal draws scaled by ``sqrt(0.5 / rank)``) and the
    positive ``_D_term`` (all 0.5), each named after ``self.prefix``, then
    draws the auxiliary site ``_<prefix>_latent``.  ``args`` and ``kwargs``
    are accepted but not used.
    """
    def init_loc():
        return torch.zeros(self.latent_dim)

    def init_w():
        return torch.randn(self.rank, self.latent_dim) * (0.5 / self.rank) ** 0.5

    def init_d():
        return torch.ones(self.latent_dim) * 0.5

    mean = pyro.param("{}_loc".format(self.prefix), init_loc)
    w_term = pyro.param("{}_W_term".format(self.prefix), init_w)
    d_term = pyro.param(
        "{}_D_term".format(self.prefix),
        init_d,
        constraint=constraints.positive,
    )

    posterior = dist.LowRankMultivariateNormal(mean, w_term, d_term)
    return pyro.sample(
        "_{}_latent".format(self.prefix),
        posterior,
        infer={"is_auxiliary": True},
    )
def model(self): self.set_mode("model") # W = (inv(Luu) @ Kuf).T # Qff = Kfu @ inv(Kuu) @ Kuf = W @ W.T # Fomulas for each approximation method are # DTC: y_cov = Qff + noise, trace_term = 0 # FITC: y_cov = Qff + diag(Kff - Qff) + noise, trace_term = 0 # VFE: y_cov = Qff + noise, trace_term = tr(Kff-Qff) / noise # y_cov = W @ W.T + D # trace_term is added into log_prob N = self.X.size(0) M = self.Xu.size(0) Kuu = self.kernel(self.Xu).contiguous() Kuu.view(-1)[::M + 1] += self.jitter # add jitter to the diagonal Luu = Kuu.cholesky() Kuf = self.kernel(self.Xu, self.X) W = Kuf.trtrs(Luu, upper=False)[0].t() D = self.noise.expand(N) if self.approx == "FITC" or self.approx == "VFE": Kffdiag = self.kernel(self.X, diag=True) Qffdiag = W.pow(2).sum(dim=-1) if self.approx == "FITC": D = D + Kffdiag - Qffdiag else: # approx = "VFE" trace_term = (Kffdiag - Qffdiag).sum() / self.noise trace_term = trace_term.clamp(min=0) zero_loc = self.X.new_zeros(N) f_loc = zero_loc + self.mean_function(self.X) if self.y is None: f_var = D + W.pow(2).sum(dim=-1) return f_loc, f_var else: if self.approx == "VFE": pyro.sample("trace_term", dist.Bernoulli(probs=torch.exp(-trace_term / 2.)), obs=trace_term.new_tensor(1.)) return pyro.sample("y", dist.LowRankMultivariateNormal(f_loc, W, D) .expand_by(self.y.shape[:-1]) .to_event(self.y.dim() - 1), obs=self.y)
def model(self):
    """
    Sparse Gaussian-process model with a low-rank-plus-diagonal covariance.

    Reads ``Xu`` and ``noise`` through ``self.get_param``, builds ``W`` from
    the kernel and a diagonal ``D`` from ``noise``; how ``D`` and
    ``trace_term`` are formed depends on ``self.approx`` ("DTC", "FITC" or
    "VFE").

    :return: ``(f_loc, f_var)`` when ``self.y`` is ``None``; otherwise the
        value of the sample site named by
        ``param_with_module_name(self.name, "y")``, observed at ``self.y``.
    """
    self.set_mode("model")

    Xu = self.get_param("Xu")
    noise = self.get_param("noise")

    # W = inv(Luu) @ Kuf
    # Qff = Kfu @ inv(Kuu) @ Kuf = W.T @ W
    # Formulas for each approximation method are
    # DTC:  y_cov = Qff + noise,                   trace_term = 0
    # FITC: y_cov = Qff + diag(Kff - Qff) + noise, trace_term = 0
    # VFE:  y_cov = Qff + noise,                   trace_term = tr(Kff-Qff) / noise
    # y_cov = W.T @ W + D
    # trace_term is added into log_prob

    M = Xu.shape[0]
    # Identity written into a buffer created from Xu, scaled by the jitter,
    # then added to the kernel matrix.
    Kuu = self.kernel(Xu) + torch.eye(M, out=Xu.new_empty(M, M)) * self.jitter
    Luu = Kuu.potrf(upper=False)
    Kuf = self.kernel(Xu, self.X)
    # NOTE(review): helper defined outside this block; per the comment above
    # it presumably returns inv(Luu) @ Kuf -- confirm.
    W = matrix_triangular_solve_compat(Kuf, Luu, upper=False)

    D = noise.expand(W.shape[1])
    trace_term = 0
    if self.approx == "FITC" or self.approx == "VFE":
        Kffdiag = self.kernel(self.X, diag=True)
        Qffdiag = W.pow(2).sum(dim=0)
        if self.approx == "FITC":
            D = D + Kffdiag - Qffdiag
        else:  # approx = "VFE"
            trace_term += (Kffdiag - Qffdiag).sum() / noise

    zero_loc = self.X.new_zeros(self.X.shape[0])
    f_loc = zero_loc + self.mean_function(self.X)
    if self.y is None:
        # No observations: return the mean and the diagonal of W.T @ W + D.
        f_var = D + W.pow(2).sum(dim=0)
        return f_loc, f_var
    else:
        y_name = param_with_module_name(self.name, "y")
        # trace_term is passed to the distribution as a fourth positional
        # argument; it is expanded over the leading dimensions of y, which
        # are then declared through `.independent`.
        return pyro.sample(
            y_name,
            dist.LowRankMultivariateNormal(
                f_loc, W, D, trace_term).expand_by(
                    self.y.shape[:-1]).independent(self.y.dim() - 1),
            obs=self.y)
def projectedMixture(X, batch_size, prior_parameters):
    """
    Mixture model whose components share one low-rank-plus-diagonal
    covariance; the mixture weights, the component means and that single
    covariance are the latent quantities.

    :param X: two-dimensional data; dimension 0 is subsampled.
    :param batch_size: subsample size of the ``'N'`` plate.
    :param prior_parameters: sequence whose first element unpacks into
        seven prior hyperparameters (mean loc/scale, diagonal loc/scale,
        Dirichlet concentration, factor loc/scale).
    :return: the value of the ``'obs'`` sample site.
    """
    n_rows, n_dims = X.shape
    (locloc, locscale, scaleloc, scalescale,
     concentration, factor_loc, factor_scale) = prior_parameters[0]
    n_components = locloc.shape[0]
    n_factors = factor_loc.shape[0]

    weights = pyro.sample('component_logits', dist.Dirichlet(concentration))

    with pyro.plate('D', n_dims):
        cov_diag = pyro.sample('scale', dist.LogNormal(scaleloc, scalescale))
        with pyro.plate('K', n_factors):
            factor_draw = pyro.sample(
                'cov_factor', dist.Normal(factor_loc, factor_scale)
            )
        with pyro.plate('C', n_components):
            locs = pyro.sample('locs', dist.Normal(locloc, locscale))

    # Swap axes 0 and 1 so the factor axis comes last.
    cov_factor = factor_draw.transpose(0, 1)

    with pyro.plate('N', size=n_rows, subsample_size=batch_size) as ind:
        assignment = pyro.sample(
            'assignment',
            dist.Categorical(weights),
            infer={"enumerate": "parallel"},
        )
        component_mean = locs.index_select(-2, assignment)
        likelihood = dist.LowRankMultivariateNormal(
            component_mean, cov_factor, cov_diag
        )
        X = pyro.sample('obs', likelihood, obs=X.index_select(0, ind))
    return X
def incrementalPpca(X, batch_size, prior_parameters):
    """
    Zero-mean low-rank Gaussian model with a single shared noise scale.

    One ``'scale'`` value is sampled and multiplied by a length-``D`` vector
    of ones to form the diagonal.  When ``K > 1`` the first ``K - 1`` factor
    rows are sampled at site ``'cov_factor'`` and the last row separately at
    ``'cov_factor_new'``; otherwise all rows come from ``'cov_factor'``.

    :param X: two-dimensional data; dimension 0 is subsampled.
    :param batch_size: subsample size of the ``'N'`` plate.
    :param prior_parameters: pair whose first element unpacks into
        ``(K, scaleloc, scalescale, cov_factor_loc, cov_factor_scale)``;
        the second element is ignored.
    :return: the value of the ``'obs'`` sample site.
    """
    n_rows, n_dims = X.shape
    hyper, _ = prior_parameters
    K, scaleloc, scalescale, factor_loc, factor_scale = hyper

    noise = pyro.sample('scale', dist.LogNormal(scaleloc, scalescale))
    cov_diag = noise * torch.ones(n_dims)

    with pyro.plate('D', n_dims):
        if K > 1:
            with pyro.plate('K', K-1):
                old_rows = pyro.sample(
                    'cov_factor',
                    dist.Normal(factor_loc[:K-1, :], factor_scale[:K-1, :]),
                )
            new_row = pyro.sample(
                'cov_factor_new',
                dist.Normal(factor_loc[-1, :], factor_scale[-1, :]),
            )
            # Append the newest row below the earlier ones.
            stacked = torch.cat([old_rows, torch.unsqueeze(new_row, dim=0)])
        else:
            with pyro.plate('K', K):
                stacked = pyro.sample(
                    'cov_factor', dist.Normal(factor_loc, factor_scale)
                )

    cov_factor = stacked.transpose(0, 1)

    with pyro.plate('N', size=n_rows, subsample_size=batch_size) as ind:
        likelihood = dist.LowRankMultivariateNormal(
            torch.zeros(n_dims), cov_factor=cov_factor, cov_diag=cov_diag
        )
        X = pyro.sample('obs', likelihood, obs=X.index_select(0, ind))
    return X
def zeroMeanFactor2(X, batch_size, prior_parameters):
    """
    Zero-mean low-rank Gaussian model whose prior hyperparameters are
    learnable ``pyro.param`` sites.

    ``prior_parameters[0]`` unpacks into ``(K, scale loc init, scale scale
    init, cov_factor loc init, cov_factor scale init)``.  When ``K > 1`` the
    first ``K - 1`` factor rows and the last row have separate parameters
    and sample sites (``'cov_factor'`` and ``'cov_factor_new'``).

    TODO: check all shapes for superfluous singleton dimensions.

    :param X: two-dimensional data; dimension 0 is subsampled.
    :param batch_size: subsample size of the ``'N'`` plate.
    :return: the value of the ``'obs'`` sample site.
    """
    n_rows, n_dims = X.shape
    K, diag_loc_0, diag_scale_0, factor_loc_0, factor_scale_0 = prior_parameters[0]

    with pyro.plate('D', n_dims, dim=-1):
        # The loc of the LogNormal is left unconstrained.
        diag_loc = pyro.param('scale_loc_prior', diag_loc_0)
        diag_scale = pyro.param(
            'scale_scale_prior', diag_scale_0, constraint=constraints.positive
        )
        diag_draw = pyro.sample('scale', dist.LogNormal(diag_loc, diag_scale))
        cov_diag = diag_draw * torch.ones(n_dims)

        if K > 1:
            with pyro.plate('K', K-1, dim=-2):
                old_loc = pyro.param(
                    'cov_factor_prior_loc_{}'.format(K), factor_loc_0[:K-1, :]
                )
                old_scale = pyro.param(
                    'cov_factor_prior_scale_{}'.format(K),
                    factor_scale_0[:K-1, :],
                    constraint=constraints.positive,
                )
                old_rows = pyro.sample(
                    'cov_factor', dist.Normal(old_loc, old_scale)
                )

            new_loc = pyro.param(
                'cov_factor_new_loc_prior_{}'.format(K), factor_loc_0[-1, :]
            )
            new_scale = pyro.param(
                'cov_factor_new_scale_prior_{}'.format(K),
                factor_scale_0[-1, :],
                constraint=constraints.positive,
            )
            new_row = pyro.sample(
                'cov_factor_new', dist.Normal(new_loc, new_scale)
            )

            # Under pyro.infer.Predictive the new row can already carry the
            # same number of dimensions as the old rows; add the row axis
            # only when it is missing.
            if new_row.dim() != old_rows.dim():
                new_row = torch.unsqueeze(new_row, dim=-2)
            stacked = torch.cat([old_rows, new_row], dim=-2)
        else:
            with pyro.plate('K', K):
                all_loc = pyro.param(
                    'cov_factor_prior_loc_{}'.format(K), factor_loc_0
                )
                all_scale = pyro.param(
                    'cov_factor_prior_scale_{}'.format(K),
                    factor_scale_0,
                    constraint=constraints.positive,
                )
                stacked = pyro.sample(
                    'cov_factor', dist.Normal(all_loc, all_scale)
                )

    cov_factor = stacked.transpose(-2, -1)

    with pyro.plate('N', size=n_rows, subsample_size=batch_size, dim=-1) as ind:
        likelihood = dist.LowRankMultivariateNormal(
            torch.zeros(n_dims), cov_factor=cov_factor, cov_diag=cov_diag
        )
        X = pyro.sample('obs', likelihood, obs=X.index_select(0, ind))
    return X
def zeroMeanFactor(X, batch_size, prior_parameters):
    """
    Zero-mean low-rank Gaussian model with fixed prior hyperparameters.

    ``prior_parameters[0]`` unpacks into ``(K, scaleloc, scalescale,
    cov_factor_loc, cov_factor_scale)``.  When ``K > 1`` the first ``K - 1``
    factor rows are sampled at ``'cov_factor'`` and the last row separately
    at ``'cov_factor_new'``; otherwise all rows come from ``'cov_factor'``.

    :param X: two-dimensional data; dimension 0 is subsampled.
    :param batch_size: subsample size of the ``'N'`` plate.
    :return: the value of the ``'obs'`` sample site.
    """
    n_rows, n_dims = X.shape
    K, scaleloc, scalescale, factor_loc, factor_scale = prior_parameters[0]

    with pyro.plate('D', n_dims):
        cov_diag = pyro.sample('scale', dist.LogNormal(scaleloc, scalescale))

        if K > 1:
            with pyro.plate('K', K-1):
                old_rows = pyro.sample(
                    'cov_factor',
                    dist.Normal(factor_loc[:K-1, :], factor_scale[:K-1, :]),
                )
            new_row = pyro.sample(
                'cov_factor_new',
                dist.Normal(factor_loc[-1, :], factor_scale[-1, :]),
            )
            # Append the newest row below the earlier ones.
            stacked = torch.cat([old_rows, torch.unsqueeze(new_row, dim=0)])
        else:
            with pyro.plate('K', K):
                stacked = pyro.sample(
                    'cov_factor', dist.Normal(factor_loc, factor_scale)
                )

    cov_factor = stacked.transpose(0, 1)
    zero_mean = torch.zeros(n_dims)

    with pyro.plate('N', size=n_rows, subsample_size=batch_size) as ind:
        likelihood = dist.LowRankMultivariateNormal(
            zero_mean, cov_factor=cov_factor, cov_diag=cov_diag
        )
        X = pyro.sample('obs', likelihood, obs=X.index_select(0, ind))
    return X
def predict(self, x) -> dist.LowRankMultivariateNormal:
    """
    Call this object on ``x`` and wrap the three outputs in a distribution.

    ``self(x)`` yields ``(mean, diag, factors)`` in that order; they are
    passed to ``LowRankMultivariateNormal`` as ``(mean, factors, diag)``.
    """
    outputs = self(x)
    mean, diag, factors = outputs
    return dist.LowRankMultivariateNormal(mean, factors, diag)