Example #1
    def model(self, X):
        _id = self._id
        K = self.K
        N, D = X.shape
        loc_loc_init = self.param_init[f'loc_prior_loc_init_{_id}']
        loc_scale_init = self.param_init[f'loc_prior_scale_init_{_id}']
        cov_diag_loc_init = self.param_init[f'cov_diag_prior_loc_init_{_id}']
        cov_diag_scale_init = self.param_init[f'cov_diag_prior_scale_init_{_id}']
        cov_factor_loc_init = self.param_init[f'cov_factor_prior_loc_init_{_id}']
        cov_factor_scale_init = self.param_init[f'cov_factor_prior_scale_init_{_id}']
        with pyro.plate(f'D_{_id}', D):
            loc_loc = pyro.param(f'loc_loc_prior_{_id}', loc_loc_init)
            loc_scale = pyro.param(f'loc_scale_prior_{_id}', loc_scale_init, constraint=constraints.positive)
            loc = pyro.sample(f'loc_{_id}', dist.LogNormal(loc_loc, loc_scale))

            cov_diag_loc = pyro.param(f'cov_diag_prior_loc_{_id}', cov_diag_loc_init)
            cov_diag_scale = pyro.param(f'cov_diag_prior_scale_{_id}', cov_diag_scale_init, constraint=constraints.positive)
            cov_diag = pyro.sample(f'cov_diag_{_id}', dist.LogNormal(cov_diag_loc, cov_diag_scale))
            cov_diag = cov_diag + jitter  # jitter is assumed to be a small positive module-level constant
            cov_factor = None
            with pyro.plate(f'K_{_id}', K):
                cov_factor_loc = pyro.param(f'cov_factor_prior_loc_{_id}', cov_factor_loc_init)
                cov_factor_scale = pyro.param(f'cov_factor_prior_scale_{_id}', cov_factor_scale_init, constraint=constraints.positive)
                cov_factor = pyro.sample(f'cov_factor_{_id}', dist.Normal(cov_factor_loc, cov_factor_scale))
            cov_factor = cov_factor.transpose(-2, -1)
        with pyro.plate(f'N_{_id}', size=N, subsample_size=self.batch_size) as ind:
            X = pyro.sample('obs', dist.LowRankMultivariateNormal(loc, cov_factor=cov_factor, cov_diag=cov_diag), obs=X.index_select(0, ind))
        return X
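A minimal, self-contained sketch of the pattern the model above relies on (illustrative sizes, not from the source): LowRankMultivariateNormal stores the covariance implicitly as cov_factor @ cov_factor.T + diag(cov_diag), so only D*K + D numbers are kept instead of a dense D x D matrix.

import torch
import pyro.distributions as dist

D, K = 5, 2
loc = torch.zeros(D)
cov_factor = torch.randn(D, K)    # (D, K) factor, the layout produced by the transpose above
cov_diag = torch.rand(D) + 1e-4   # strictly positive diagonal (the role played by jitter)

mvn = dist.LowRankMultivariateNormal(loc, cov_factor=cov_factor, cov_diag=cov_diag)
x = mvn.sample((3,))              # shape (3, 5)
assert torch.allclose(mvn.covariance_matrix,
                      cov_factor @ cov_factor.T + torch.diag(cov_diag))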
Example #2
    def model(self, X):
        _id = self._id
        N, D = X.shape
        global_shrinkage_prior_scale_init = self.param_init[f'global_shrinkage_prior_scale_init_{_id}']
        cov_diag_prior_loc_init = self.param_init[f'cov_diag_prior_loc_init_{_id}']
        cov_diag_prior_scale_init = self.param_init[f'cov_diag_prior_scale_init_{_id}']


        global_shrinkage_prior_scale = pyro.param(f'global_shrinkage_scale_prior_{_id}', global_shrinkage_prior_scale_init, constraint=constraints.positive)
        tau = pyro.sample(f'global_shrinkage_{_id}', dist.HalfNormal(global_shrinkage_prior_scale))
        
        b = pyro.sample('b', dist.InverseGamma(0.5, 1. / torch.ones(D) ** 2).to_event(1))
        lambdasquared = pyro.sample(f'local_shrinkage_{_id}', dist.InverseGamma(0.5, 1. / b).to_event(1))
        
        cov_diag_loc = pyro.param(f'cov_diag_prior_loc_{_id}', cov_diag_prior_loc_init)
        cov_diag_scale = pyro.param(f'cov_diag_prior_scale_{_id}', cov_diag_prior_scale_init, constraint=constraints.positive)
        cov_diag = pyro.sample(f'cov_diag_{_id}', dist.LogNormal(cov_diag_loc, cov_diag_scale).to_event(1))
        cov_diag = cov_diag + jitter  # jitter is assumed to be a small positive module-level constant
        
        lambdasquared = lambdasquared.squeeze()
        if lambdasquared.dim() == 1:
            # outer product of the local shrinkage scales and the broadcast global scale
            # (torch.ger is the deprecated name of torch.outer)
            cov_factor_scale = torch.ger(torch.sqrt(lambdasquared), tau.repeat((tau.dim() - 1) * (1,) + (D,)))
        else:
            # batched outer product
            cov_factor_scale = torch.einsum('bp,br->bpr', torch.sqrt(lambdasquared), tau.repeat((tau.dim() - 1) * (1,) + (D,)))
        cov_factor = pyro.sample(f'cov_factor_{_id}', dist.Normal(0., cov_factor_scale).to_event(2))
        cov_factor = cov_factor.transpose(-2, -1)
        with pyro.plate(f'N_{_id}', size=N, subsample_size=self.batch_size, dim=-1) as ind:
            X = pyro.sample('obs', dist.LowRankMultivariateNormal(torch.zeros(D), cov_factor=cov_factor, cov_diag=cov_diag), obs=X.index_select(0, ind))
        return X
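A small shape check of the shrinkage outer product above, assuming a scalar (0-dim) tau and K local scales; torch.outer is the modern name of torch.ger:

import torch

K, D = 3, 5
tau = torch.tensor(0.7)        # global shrinkage, a 0-dim sample
lambdasquared = torch.rand(K)  # local shrinkage, one per factor

# with tau.dim() == 0, tau.repeat((tau.dim() - 1) * (1,) + (D,)) reduces to tau.repeat(D)
cov_factor_scale = torch.outer(torch.sqrt(lambdasquared), tau.repeat(D))
assert cov_factor_scale.shape == (K, D)   # one scale per (factor, feature) pair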
Example #3
    def guide(self, batch, subsample, full_size):
        num_time_steps, batch_size = batch.shape
        self.map_estimate("drift")

        group = self.group(match="state_[0-9]*")
        cov_diag = pyro.param(
            "state_cov_diag",
            lambda: torch.full(group.event_shape, 0.01),
            constraint=constraints.positive,
        )
        cov_factor = pyro.param(
            "state_cov_factor", lambda: torch.randn(group.event_shape + (rank,)) * 0.01
        )  # rank is assumed to be defined at module level

        if not hasattr(self, "nn"):
            self.nn = torch.nn.Linear(
                group.event_shape.numel(), group.event_shape.numel()
            )
            self.nn.weight.data.fill_(1.0 / num_time_steps)
            self.nn.bias.data.fill_(-0.5)
        pyro.module("state_nn", self.nn)
        with self.plate("data", full_size, subsample=subsample):
            loc = self.nn(batch.t())
            group.sample(
                "states", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag)
            )
Example #4
 def get_posterior(self, *args, **kwargs):
     """
     Returns a LowRankMultivariateNormal posterior distribution.
     """
     scale = self.scale
     cov_factor = self.cov_factor * scale.unsqueeze(-1)
     cov_diag = scale * scale
     return dist.LowRankMultivariateNormal(self.loc, cov_factor, cov_diag)
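The rescaling in this get_posterior has a tidy closed form: with cov_factor = F * s.unsqueeze(-1) and cov_diag = s * s, the implied covariance is diag(s) @ (F @ F.T + I) @ diag(s). A quick numeric check (illustrative sizes):

import torch

D, K = 4, 2
s = torch.rand(D) + 0.1   # per-dimension scale
F = torch.randn(D, K)     # unscaled low-rank factor

cov = (F * s.unsqueeze(-1)) @ (F * s.unsqueeze(-1)).T + torch.diag(s * s)
expected = torch.diag(s) @ (F @ F.T + torch.eye(D)) @ torch.diag(s)
assert torch.allclose(cov, expected, atol=1e-6)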
Example #5
 def get_posterior(self, *args, **kwargs):
     """
     Returns a LowRankMultivariateNormal posterior distribution.
     """
     loc = pyro.param("{}_loc".format(self.prefix), self._init_loc)
     factor = pyro.param("{}_cov_factor".format(self.prefix),
                         lambda: loc.new_empty(self.latent_dim, self.rank).normal_(0, (0.5 / self.rank) ** 0.5))
     diagonal = pyro.param("{}_cov_diag".format(self.prefix),
                           lambda: loc.new_full((self.latent_dim,), 0.5),
                           constraint=constraints.positive)
     return dist.LowRankMultivariateNormal(loc, factor, diagonal)
Example #6
 def get_posterior(self, *args, **kwargs):
     """
     Returns a LowRankMultivariateNormal posterior distribution.
     """
     loc = pyro.param("{}_loc".format(self.prefix),
                      lambda: torch.zeros(self.latent_dim))
     factor = pyro.param("{}_cov_factor".format(self.prefix),
                         lambda: torch.randn(self.latent_dim, self.rank) * (0.5 / self.rank) ** 0.5)
     diagonal = pyro.param("{}_cov_diag".format(self.prefix),
                           lambda: torch.ones(self.latent_dim) * 0.5,
                           constraint=constraints.positive)
     return dist.LowRankMultivariateNormal(loc, factor, diagonal)
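Examples #5 and #6 are variants of get_posterior from Pyro's AutoLowRankMultivariateNormal autoguide. In user code the guide is normally used through the public class rather than by calling get_posterior directly; a runnable sketch with a toy model (the model itself is illustrative, not from the source):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoLowRankMultivariateNormal
from pyro.optim import Adam

def model(data):  # toy regression model, for illustration only
    w = pyro.sample("w", dist.Normal(torch.zeros(5), 1.).to_event(1))
    with pyro.plate("data", data.shape[0]):
        pyro.sample("obs", dist.Normal(data @ w, 1.), obs=data.sum(-1))

guide = AutoLowRankMultivariateNormal(model, rank=2)
svi = SVI(model, guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())
data = torch.randn(100, 5)
for step in range(100):
    svi.step(data)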
Example #7
 def guide(self, batch, subsample, full_size):
     self.map_estimate("drift")
     group = self.group(match="state_[0-9]*")
     cov_diag = pyro.param("state_cov_diag",
                           lambda: torch.full(group.event_shape, 0.01),
                           constraint=constraints.positive)
     cov_factor = pyro.param("state_cov_factor",
                             lambda: torch.randn(group.event_shape + (rank,)) * 0.01)
     with self.plate("data", full_size, subsample=subsample):
         loc = pyro.param("state_loc",
                          lambda: torch.full((full_size,) + group.event_shape, 0.5),
                          event_dim=1)
         group.sample("states", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag))
Example #8
 def sample_latent(self, *args, **kwargs):
     """
     Samples the (single) multivariate normal latent used in the auto guide.
     """
     loc = pyro.param("{}_loc".format(self.prefix),
                      lambda: torch.zeros(self.latent_dim))
     W_term = pyro.param("{}_W_term".format(self.prefix),
                         lambda: torch.randn(self.rank, self.latent_dim) * (0.5 / self.rank) ** 0.5)
     D_term = pyro.param("{}_D_term".format(self.prefix),
                         lambda: torch.ones(self.latent_dim) * 0.5,
                         constraint=constraints.positive)
     return pyro.sample("_{}_latent".format(self.prefix),
                        dist.LowRankMultivariateNormal(loc, W_term, D_term),
                        infer={"is_auxiliary": True})
Example #9
    def model(self):
        self.set_mode("model")

        # W = (inv(Luu) @ Kuf).T
        # Qff = Kfu @ inv(Kuu) @ Kuf = W @ W.T
        # Formulas for each approximation method are
        # DTC:  y_cov = Qff + noise,                   trace_term = 0
        # FITC: y_cov = Qff + diag(Kff - Qff) + noise, trace_term = 0
        # VFE:  y_cov = Qff + noise,                   trace_term = tr(Kff-Qff) / noise
        # y_cov = W @ W.T + D
        # trace_term is added into log_prob

        N = self.X.size(0)
        M = self.Xu.size(0)
        Kuu = self.kernel(self.Xu).contiguous()
        Kuu.view(-1)[::M + 1] += self.jitter  # add jitter to the diagonal
        Luu = Kuu.cholesky()  # Tensor.cholesky is the deprecated form of torch.linalg.cholesky
        Kuf = self.kernel(self.Xu, self.X)
        W = Kuf.trtrs(Luu, upper=False)[0].t()  # trtrs was later replaced by torch.linalg.solve_triangular

        D = self.noise.expand(N)
        if self.approx == "FITC" or self.approx == "VFE":
            Kffdiag = self.kernel(self.X, diag=True)
            Qffdiag = W.pow(2).sum(dim=-1)
            if self.approx == "FITC":
                D = D + Kffdiag - Qffdiag
            else:  # approx = "VFE"
                trace_term = (Kffdiag - Qffdiag).sum() / self.noise
                trace_term = trace_term.clamp(min=0)

        zero_loc = self.X.new_zeros(N)
        f_loc = zero_loc + self.mean_function(self.X)
        if self.y is None:
            f_var = D + W.pow(2).sum(dim=-1)
            return f_loc, f_var
        else:
            if self.approx == "VFE":
                pyro.sample("trace_term", dist.Bernoulli(probs=torch.exp(-trace_term / 2.)),
                            obs=trace_term.new_tensor(1.))

            return pyro.sample("y",
                               dist.LowRankMultivariateNormal(f_loc, W, D)
                                   .expand_by(self.y.shape[:-1])
                                   .to_event(self.y.dim() - 1),
                               obs=self.y)
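The comment blocks in examples #9 and #10 rest on the identity Qff = Kfu @ inv(Kuu) @ Kuf = W @ W.T with W = (inv(Luu) @ Kuf).T. A numeric sanity check with a toy RBF kernel (kernel and sizes are illustrative, not from the source):

import torch

torch.manual_seed(0)
N, M = 6, 3
X = torch.randn(N, 4)
Xu = torch.randn(M, 4)
rbf = lambda A, B: torch.exp(-torch.cdist(A, B) ** 2 / 2)

Kuu = rbf(Xu, Xu) + 1e-5 * torch.eye(M)   # jittered inducing-point kernel
Kuf = rbf(Xu, X)
Luu = torch.linalg.cholesky(Kuu)
W = torch.linalg.solve_triangular(Luu, Kuf, upper=False).t()   # (N, M)

Qff = Kuf.t() @ torch.linalg.solve(Kuu, Kuf)
assert torch.allclose(W @ W.t(), Qff, atol=1e-4)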
Example #10
    def model(self):
        self.set_mode("model")

        Xu = self.get_param("Xu")
        noise = self.get_param("noise")

        # W = inv(Luu) @ Kuf
        # Qff = Kfu @ inv(Kuu) @ Kuf = W.T @ W
        # Formulas for each approximation method are
        # DTC:  y_cov = Qff + noise,                   trace_term = 0
        # FITC: y_cov = Qff + diag(Kff - Qff) + noise, trace_term = 0
        # VFE:  y_cov = Qff + noise,                   trace_term = tr(Kff-Qff) / noise
        # y_cov = W.T @ W + D
        # trace_term is added into log_prob

        M = Xu.shape[0]
        Kuu = self.kernel(Xu) + torch.eye(M, out=Xu.new_empty(M, M)) * self.jitter
        Luu = Kuu.potrf(upper=False)  # potrf is the old PyTorch name for Cholesky factorization
        Kuf = self.kernel(Xu, self.X)
        W = matrix_triangular_solve_compat(Kuf, Luu, upper=False)

        D = noise.expand(W.shape[1])
        trace_term = 0
        if self.approx == "FITC" or self.approx == "VFE":
            Kffdiag = self.kernel(self.X, diag=True)
            Qffdiag = W.pow(2).sum(dim=0)
            if self.approx == "FITC":
                D = D + Kffdiag - Qffdiag
            else:  # approx = "VFE"
                trace_term += (Kffdiag - Qffdiag).sum() / noise

        zero_loc = self.X.new_zeros(self.X.shape[0])
        f_loc = zero_loc + self.mean_function(self.X)
        if self.y is None:
            f_var = D + W.pow(2).sum(dim=0)
            return f_loc, f_var
        else:
            y_name = param_with_module_name(self.name, "y")
            return pyro.sample(
                y_name,
                dist.LowRankMultivariateNormal(f_loc, W, D, trace_term)
                    .expand_by(self.y.shape[:-1])
                    .independent(self.y.dim() - 1),  # .independent is the old name of .to_event
                obs=self.y)
Example #11
def projectedMixture(X, batch_size, prior_parameters):
    """
    Covariances of all clusters are tied: we learn a single shared covariance, plus the mixture weights and means.
    """
    N, D = X.shape
    locloc, locscale, scaleloc, scalescale, component_logits_concentration, cov_factor_loc,cov_factor_scale = prior_parameters[0]
    C = locloc.shape[0]
    K = cov_factor_loc.shape[0]
    component_logits = pyro.sample('component_logits', dist.Dirichlet(component_logits_concentration))
    with pyro.plate('D', D):
        cov_diag = pyro.sample('scale', dist.LogNormal(scaleloc, scalescale))
        with pyro.plate('K', K):
            cov_factor = pyro.sample('cov_factor', dist.Normal(cov_factor_loc, cov_factor_scale))
        cov_factor = cov_factor.transpose(0, 1)
        with pyro.plate('C', C):
            locs = pyro.sample('locs', dist.Normal(locloc, locscale))
    with pyro.plate('N', size=N, subsample_size=batch_size) as ind:
        # despite its name, component_logits is a probability vector (a Dirichlet sample), so it is passed to Categorical as probs
        assignment = pyro.sample('assignment', dist.Categorical(component_logits), infer={"enumerate": "parallel"})
        X = pyro.sample('obs', dist.LowRankMultivariateNormal(locs.index_select(-2, assignment), cov_factor, cov_diag), obs=X.index_select(0, ind))
    return X
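The per-datum mean lookup locs.index_select(-2, assignment) in the example above is a plain row gather; a standalone shape check (illustrative sizes):

import torch

C, D, B = 3, 4, 6
locs = torch.randn(C, D)                   # one mean per mixture component
assignment = torch.randint(C, (B,))        # a component index per datum

means = locs.index_select(-2, assignment)  # (B, D): row i is locs[assignment[i]]
assert torch.equal(means, locs[assignment])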
Example #12
def incrementalPpca(X, batch_size, prior_parameters):
    N, D = X.shape
    prior_parameters,_ = prior_parameters
    K, scaleloc, scalescale, cov_factor_loc, cov_factor_scale = prior_parameters
    cov_diag = pyro.sample('scale', dist.LogNormal(scaleloc, scalescale))
    cov_diag = cov_diag * torch.ones(D)  # broadcast the scalar scale across all D dimensions
    with pyro.plate('D', D):
        cov_factor = None
        if K > 1:
            with pyro.plate('K', K-1):
                cov_factor = pyro.sample('cov_factor', dist.Normal(cov_factor_loc[:K-1, :], cov_factor_scale[:K-1, :]))
            cov_factor_new = pyro.sample('cov_factor_new', dist.Normal(cov_factor_loc[-1, :], cov_factor_scale[-1, :]))
            cov_factor = torch.cat([cov_factor, torch.unsqueeze(cov_factor_new, dim=0)])
        else:
            with pyro.plate('K', K):
                cov_factor = pyro.sample('cov_factor', dist.Normal(cov_factor_loc, cov_factor_scale))
        cov_factor = cov_factor.transpose(0, 1)
    with pyro.plate('N', size=N, subsample_size=batch_size) as ind:
        X = pyro.sample('obs', dist.LowRankMultivariateNormal(torch.zeros(D), cov_factor=cov_factor, cov_diag=cov_diag), obs=X.index_select(0, ind))
    return X
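The incremental trick above keeps the first K-1 factor rows as one sample site and introduces the new row as a separate site; the concatenation is pure shape manipulation (illustrative sizes):

import torch

K, D = 3, 5
cov_factor = torch.randn(K - 1, D)   # rows for the existing factors
cov_factor_new = torch.randn(D)      # the newly added factor
cov_factor = torch.cat([cov_factor, cov_factor_new.unsqueeze(0)])
assert cov_factor.shape == (K, D)
cov_factor = cov_factor.transpose(0, 1)   # (D, K), as LowRankMultivariateNormal expects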
Example #13
def zeroMeanFactor2(X, batch_size, prior_parameters):
    """
    Parameters are K, scalelocinit, scalescaleinit, cov_factor_loc_init, cov_factor_scale_init.

    TODO: check all shapes for superfluous singleton dimensions.
    """
    N, D = X.shape
    K, scalelocinit, scalescaleinit, cov_factor_loc_init, cov_factor_scale_init = prior_parameters[0]
    with pyro.plate('D', D, dim=-1):
        cov_diag_loc = pyro.param('scale_loc_prior', scalelocinit)
        cov_diag_scale = pyro.param('scale_scale_prior', scalescaleinit, constraint=constraints.positive)
        cov_diag = pyro.sample('scale', dist.LogNormal(cov_diag_loc, cov_diag_scale))
        cov_diag = cov_diag*torch.ones(D)
        # sample variables
        cov_factor = None
        if K > 1:
            with pyro.plate('K', K-1, dim=-2):
                cov_factor_loc = pyro.param('cov_factor_prior_loc_{}'.format(K), cov_factor_loc_init[:K-1, :])
                cov_factor_scale = pyro.param('cov_factor_prior_scale_{}'.format(K), cov_factor_scale_init[:K-1, :], constraint=constraints.positive)
                cov_factor = pyro.sample('cov_factor', dist.Normal(cov_factor_loc, cov_factor_scale))
            cov_factor_new_loc = pyro.param('cov_factor_new_loc_prior_{}'.format(K), cov_factor_loc_init[-1, :])
            cov_factor_new_scale = pyro.param('cov_factor_new_scale_prior_{}'.format(K), cov_factor_scale_init[-1, :], constraint=constraints.positive)
            cov_factor_new = pyro.sample('cov_factor_new', dist.Normal(cov_factor_new_loc, cov_factor_new_scale))
            # when using pyro.infer.Predictive, cov_factor_new is sometimes sampled as a 2-d tensor instead of 1-d
            if cov_factor_new.dim() == cov_factor.dim():
                cov_factor = torch.cat([cov_factor, cov_factor_new], dim=-2)
            else:
                cov_factor = torch.cat([cov_factor, torch.unsqueeze(cov_factor_new, dim=-2)], dim=-2)
        else:
            with pyro.plate('K', K):
                cov_factor_loc = pyro.param('cov_factor_prior_loc_{}'.format(K), cov_factor_loc_init)
                cov_factor_scale = pyro.param('cov_factor_prior_scale_{}'.format(K), cov_factor_scale_init, constraint=constraints.positive)
                cov_factor = pyro.sample('cov_factor', dist.Normal(cov_factor_loc, cov_factor_scale))
        cov_factor = cov_factor.transpose(-2, -1)
    with pyro.plate('N', size=N, subsample_size=batch_size, dim=-1) as ind:
        X = pyro.sample('obs', dist.LowRankMultivariateNormal(torch.zeros(D), cov_factor=cov_factor, cov_diag=cov_diag), obs=X.index_select(0, ind))
    return X
Example #14
def zeroMeanFactor(X, batch_size, prior_parameters):
    """
    Parameters are K, scaleloc, scalescale, cov_factor_loc, cov_factor_scale.
    """
    N, D = X.shape
    K, scaleloc, scalescale, cov_factor_loc, cov_factor_scale = prior_parameters[0]
    with pyro.plate('D', D):
        cov_diag = pyro.sample('scale', dist.LogNormal(scaleloc, scalescale))
        cov_factor = None
        if K > 1:
            with pyro.plate('K', K-1):
                cov_factor = pyro.sample('cov_factor', dist.Normal(cov_factor_loc[:K-1, :], cov_factor_scale[:K-1, :]))
            cov_factor_new = pyro.sample('cov_factor_new', dist.Normal(cov_factor_loc[-1, :], cov_factor_scale[-1, :]))
            cov_factor = torch.cat([cov_factor, torch.unsqueeze(cov_factor_new, dim=0)])
        else:
            with pyro.plate('K', K):
                cov_factor = pyro.sample('cov_factor', dist.Normal(cov_factor_loc, cov_factor_scale))
        cov_factor = cov_factor.transpose(0, 1)
        loc = torch.zeros(D)
    with pyro.plate('N', size=N, subsample_size=batch_size) as ind:
        X = pyro.sample('obs', dist.LowRankMultivariateNormal(loc, cov_factor=cov_factor, cov_diag=cov_diag), obs=X.index_select(0, ind))
    return X
Example #15
 def predict(self, x) -> dist.LowRankMultivariateNormal:
     mean, diag, factors = self(x)
     return dist.LowRankMultivariateNormal(mean, factors, diag)
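This predict assumes the module's forward pass returns a (mean, diag, factors) triple with matching shapes; a hypothetical minimal module honouring that contract (names and layer choices are assumptions, not from the source):

import torch
import torch.nn as nn
import pyro.distributions as dist

class LowRankHead(nn.Module):
    """Hypothetical head producing the (mean, diag, factors) triple for predict()."""
    def __init__(self, in_dim: int, out_dim: int, rank: int):
        super().__init__()
        self.mean = nn.Linear(in_dim, out_dim)
        self.log_diag = nn.Linear(in_dim, out_dim)
        self.factors = nn.Linear(in_dim, out_dim * rank)
        self.out_dim, self.rank = out_dim, rank

    def forward(self, x):
        diag = self.log_diag(x).exp()  # exponentiate to keep the diagonal positive
        factors = self.factors(x).reshape(*x.shape[:-1], self.out_dim, self.rank)
        return self.mean(x), diag, factors

    def predict(self, x) -> dist.LowRankMultivariateNormal:
        mean, diag, factors = self(x)
        return dist.LowRankMultivariateNormal(mean, factors, diag)

For example, LowRankHead(16, 8, rank=2).predict(torch.randn(16)) yields an 8-dimensional Gaussian with a rank-2-plus-diagonal covariance.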