Example #1
 def forward(self, input=None, T=False, tensorize_output=False):
     assert (input is None) != self.sample_axis
     if T:
         assert input is not None
     if input is None:
         output = self.cores[-1].new_ones([1, 1])
     else:
         if T:
             output = torch_utils.reshape_torch(input, self.n)
         else:
             output = input.clone()
     if T:
         for k in range(self.d):
             output = torch.einsum(
                 'ijk,jl->ilk',
                 torch_utils.reshape_torch(output, [self.r[k]*self.n[k], -1]),
                 torch_utils.reshape_torch(self.cores[k], [self.r[k]*self.n[k], -1], use_batch=False)
             )
     else:
         for k in range(self.d-1, -1, -1):
             if (k == self.d-1) and (input is None):
                 continue
             output = torch.einsum(
                 'ijk,lj->ilk',
                 torch_utils.reshape_torch(output, [self.r[k+1], -1]),
                 torch_utils.reshape_torch(self.cores[k], [self.r[k]*self.n[k], -1], use_batch=False)
             )
     if tensorize_output:
         output = torch_utils.reshape_torch(output, self.n)
     else:
         output = torch_utils.flatten_torch(output)
     return output
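A note on the einsum pattern 'ijk,jl->ilk' used in the T branch: it is a mode product that contracts the middle axis of a 3-D tensor with a matrix while carrying the other axes along. A minimal, self-contained illustration (shapes are invented; the repo's torch_utils helpers are not needed for it):

    import torch

    # 'ijk,jl->ilk' contracts axis j of a 3-D tensor with a matrix.
    x = torch.randn(2, 3, 5)   # axes [i, j, k]
    m = torch.randn(3, 4)      # axes [j, l]
    out = torch.einsum('ijk,jl->ilk', x, m)
    assert out.shape == (2, 4, 5)
    # Equivalent explicit form: move j last, matrix-multiply, move back.
    ref = x.transpose(1, 2).reshape(-1, 3).mm(m).reshape(2, 5, 4).transpose(1, 2)
    assert torch.allclose(out, ref, atol=1e-6)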
Example #2
 def inverse_batch(self, input_batch):
     # Least-squares inverse of the CP mapping: x @ A @ (A^T A)^{-1},
     # where A = self.recover() and the Gram matrix A^T A is the
     # Hadamard product of the factor Gram matrices.
     output = torch_utils.flatten_torch(input_batch).mm(self.recover())
     M = input_batch.new_ones([self.R, self.R])
     for k in range(self.d):
         M *= self.factors[k].t().mm(self.factors[k])
     # Invert M via its SVD: M^{-1} = U diag(1/S) V^T.
     U, S, V = torch.svd(M)
     output = output.mm(U/S).mm(V.t())
     return output
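The Hadamard accumulation of M relies on a standard identity: the Gram matrix of a column-wise Khatri-Rao product equals the element-wise product of the factor Gram matrices. A small self-contained check (the krp helper below is illustrative and only stands in for what krp_cw_torch presumably computes):

    import torch

    def krp(a, b):
        # Column-wise Khatri-Rao product: Kronecker product of matching columns.
        n, r = a.shape
        m, _ = b.shape
        return torch.einsum('ir,jr->ijr', a, b).reshape(n * m, r)

    A, B = torch.randn(4, 3), torch.randn(5, 3)
    lhs = krp(A, B).t().mm(krp(A, B))   # Gram of the Khatri-Rao product
    rhs = A.t().mm(A) * B.t().mm(B)     # Hadamard product of factor Grams
    assert torch.allclose(lhs, rhs, atol=1e-5)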
Example #3
    def module_ppca_gm_means_sigma_guide(self, input_batch, epsilon):
        batch_size = input_batch.shape[0]
        if self.likelihood == 'normal':
            if self.group_isotropic:
                ppca_gm_sigma_p = pyro.param(f'ppca_gm_sigma_p',
                                             input_batch.new_ones(1, 1),
                                             constraint=constraints.positive)
                ppca_gm_sigma = pyro.sample(
                    f'ppca_gm_sigma',
                    dist.Delta(ppca_gm_sigma_p).independent(1))
            else:
                ppca_gm_sigma = input_batch.new_ones(1, 1)
                ppca_gm_sigma_list = []
                for i in range(self.d):
                    ppca_gm_sigma_p = pyro.param(
                        f'ppca_gm_sigma_{i}_p',
                        input_batch.new_ones(1, self.n[i]),
                        constraint=constraints.positive)
                    ppca_gm_sigma_list.append(
                        pyro.sample(
                            f'ppca_gm_sigma_{i}',
                            dist.Delta(ppca_gm_sigma_p).independent(1)))
                    ppca_gm_sigma = torch_utils.krp_cw_torch(
                        ppca_gm_sigma_list[i], ppca_gm_sigma, column=False)
        else:
            ppca_gm_sigma = input_batch.new_ones(1, 1)
            ppca_gm_sigma_list = [
                input_batch.new_ones(1, self.n[i]) for i in range(self.d)
            ]

        alpha_gm_p = pyro.param(
            f'alpha_gm_p', input_batch.new_ones([1, self.group_hidden_dim]))
        alpha_gm = pyro.sample(f'alpha_gm',
                               dist.Delta(alpha_gm_p).independent(1))

        if self.group_iterm is None:
            z_mu = self.group_term.linear_mapping.inverse_batch(input_batch)
        else:
            z_mu = self.group_iterm(torch_utils.flatten_torch(input_batch),
                                    T=True)
        if self.group_isotropic:
            zk_mean, zk_cov = self.group_term.get_posterior_gaussian_mean_covariance(
                input_batch,
                noise_sigma=ppca_gm_sigma[0]
                if ppca_gm_sigma is not None else input_batch.new_ones(1),
                z_mu=z_mu,
                z_sigma=alpha_gm[0])
        else:
            zk_mean, zk_cov = self.group_term.get_posterior_gaussian_mean_covariance(
                input_batch,
                noise_sigma=[x for x in ppca_gm_sigma_list],
                z_mu=z_mu,
                z_sigma=alpha_gm[0])
        ppca_gm_means = self.group_term(
            zk_mean + epsilon[:, :self.group_hidden_dim].mm(
                zk_cov.view(self.group_hidden_dim, self.group_hidden_dim)))
        return ppca_gm_means, ppca_gm_sigma
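The pyro.param / dist.Delta pairing above is the usual MAP-style guide idiom: the guide pins a latent to a learnable point estimate instead of a full posterior. A minimal sketch of the pattern outside this class (all names below are invented; only the pyro.param/pyro.sample/dist.Delta calls mirror the snippet):

    import torch
    import pyro
    import pyro.distributions as dist
    from torch.distributions import constraints

    def model(data):
        sigma = pyro.sample('sigma', dist.LogNormal(0., torch.ones(1)).to_event(1))
        with pyro.plate('data', data.shape[0]):
            pyro.sample('obs', dist.Normal(0., sigma), obs=data)

    def guide(data):
        # Learnable point estimate, kept positive by the constraint.
        sigma_p = pyro.param('sigma_p', torch.ones(1),
                             constraint=constraints.positive)
        # Delta puts all mass at sigma_p; .to_event(1) matches the event
        # shape declared in the model (.independent(1) is its older name).
        pyro.sample('sigma', dist.Delta(sigma_p).to_event(1))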
Example #4
 def inverse_batch(self, input_batch, tensorize_output=False, fast_svd=False):
     output = input_batch.clone()
     if self.use_core:
         output_batch = self.core.clone()
         for k in range(self.d):
             output_batch = torch_utils.prodTenMat_torch(
                 output_batch,
                 self.factors[k].t().mm(self.factors[k]),
                 k+1,
                 1
             )
             output = torch_utils.prodTenMat_torch(
                 output,
                 self.factors[k],
                 k+1,
                 0
             )
         output = torch_utils.flatten_torch(output, use_batch=True)
         output = output.mm(torch_utils.reshape_torch(self.core, [-1, self.r0], use_batch=False))
         output_batch = torch_utils.reshape_torch(output_batch, [-1, self.r0], use_batch=False).t()
         output_batch = output_batch.mm(
             torch_utils.reshape_torch(self.core, [-1, self.r0], use_batch=False)
         )
         u, s, v = torch.svd(output_batch)
         output = output.mm(u/s).mm(v.t())
         return output
     
     for k in range(self.d):
         if fast_svd:
             u, s, v = torch_utils.fast_svd_torch(self.factors[k])
         else:
             u, s, v = torch.svd(self.factors[k])
             u, s, v = u[:, :self.r[k]], s[:self.r[k]], v[:, :self.r[k]]
         output = torch_utils.prodTenMat_torch(output, u/s, k+1, 0)
         output = torch_utils.prodTenMat_torch(output, v, k+1, 1)
     if tensorize_output and output.dim() == 2:
         output = torch_utils.reshape_torch(output, self.r)
     if not tensorize_output and output.dim() != 2:
         output = torch_utils.flatten_torch(output)
     return output
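Both inverse_batch variants lean on the same (u/s).mm(v.t()) idiom: dividing the columns of u by the singular values turns an SVD into an explicit (pseudo-)inverse. A self-contained sanity check on a well-conditioned symmetric matrix:

    import torch

    M = torch.randn(6, 6)
    M = M.mm(M.t()) + 0.5 * torch.eye(6)   # symmetric positive definite
    u, s, v = torch.svd(M)                 # M = u @ diag(s) @ v.T
    M_inv = (u / s).mm(v.t())              # u/s scales column j of u by 1/s[j]
    assert torch.allclose(M.mm(M_inv), torch.eye(6), atol=1e-3)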
Example #5
    def module_ppca_gm_means_sigma(self, input_batch, epsilon):
        if self.likelihood == 'normal':
            if self.group_isotropic:
                ppca_gm_sigma = pyro.sample(
                    f'ppca_gm_sigma',
                    dist.LogNormal(0, input_batch.new_ones(1,
                                                           1)).independent(1))
            else:
                ppca_gm_sigma_list = []
                ppca_gm_sigma = input_batch.new_ones(1, 1)
                for i in range(self.d):
                    ppca_gm_sigma_list.append(
                        pyro.sample(
                            f'ppca_gm_sigma_{i}',
                            dist.LogNormal(0, input_batch.new_ones(
                                1, self.n[i])).independent(1)))
                    ppca_gm_sigma = torch_utils.krp_cw_torch(
                        ppca_gm_sigma_list[i], ppca_gm_sigma, column=False)

        alpha_gm = pyro.sample(
            f'alpha_gm',
            dist.LogNormal(0, input_batch.new_ones([1, self.group_hidden_dim
                                                    ])).independent(1))

        if self.group_iterm is None:
            ppca_gm_means = self.group_term.multi_project(
                input_batch, remove_bias=False, tensorize=False)
        else:
            ppca_gm_means = self.group_term(
                self.group_iterm(torch_utils.flatten_torch(input_batch),
                                 T=True))

        ppca_gm_means += self.group_term(epsilon[:, :self.group_hidden_dim] *
                                         alpha_gm[:, :self.group_hidden_dim])

        if self.likelihood == 'bernoulli':
            return ppca_gm_means.sigmoid()
        if self.likelihood == 'normal':
            return ppca_gm_means, ppca_gm_sigma
        raise ValueError
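The epsilon[:, :h] * alpha update is the reparameterization trick: a single standard-normal draw shared across modules is scaled by learned deviations so gradients can flow through the sampling step. A minimal sketch with made-up shapes:

    import torch

    batch, hidden = 8, 4
    eps = torch.randn(batch, hidden)                     # eps ~ N(0, I), drawn once
    mu = torch.zeros(batch, hidden, requires_grad=True)
    alpha = torch.ones(1, hidden, requires_grad=True)
    z = mu + eps * alpha    # a differentiable sample from N(mu, diag(alpha^2))
    z.sum().backward()      # gradients reach both mu and alpha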
Example #6
 def recover(self, weights=None):
     if self.use_core:
         output = self.core.clone()
         for k in range(self.d):
             if weights is None:
                 output = torch_utils.prodTenMat_torch(output, self.factors[k], k, matrix_axis=1)
             else:
                 output = torch_utils.prodTenMat_torch(
                     output, self.factors[k]*weights[k].view(self.n[k], 1), k, matrix_axis=1
                 )
         if not self.sample_axis:
             output = torch_utils.flatten_torch(output, use_batch=False)
         else:
             output = torch_utils.flatten_torch(output.t(), use_batch=True).t()
     else:
         # Note: materializing the full Kronecker product is very inefficient.
         output = self.factors[0].new_ones([1, 1])
         for k in range(self.d-1, -1, -1):
             if weights is None:
                 output = torch_utils.kron_torch(output, self.factors[k])
             else:
                 output = torch_utils.kron_torch(output, self.factors[k]*weights[k])
     return output
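The fallback branch materializes the full operator as a Kronecker product of the factors. The identity it relies on, in a self-contained two-factor check (conventions assumed; row-major flattening as in PyTorch):

    import torch

    A, B = torch.randn(3, 2), torch.randn(4, 5)
    X = torch.randn(2, 5)                        # a 2 x 5 "core"
    lhs = A.mm(X).mm(B.t()).reshape(-1)          # apply factors mode-wise
    rhs = torch.kron(A, B).mv(X.reshape(-1))     # one big Kronecker operator
    assert torch.allclose(lhs, rhs, atol=1e-5)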
Example #7
 def recover(self, weights=None):
     output = self.cores[0].new_ones([1, 1])
     for k in range(self.d):
         if weights is None:
             output = output.mm(torch_utils.reshape_torch(self.cores[k], [self.r[k], -1]))
         else:
             output = output.mm(
                 torch_utils.reshape_torch(self.cores[k]*weights[k].view(1, self.n[k], 1), [self.r[k], -1])
             )
         output = torch_utils.reshape_torch(output, [-1, self.r[k]])
     if not self.sample_axis:
         output = torch_utils.flatten_torch(output, use_batch=False)
     else:
         output = torch_utils.reshape_torch(output, [self.N, self.r[-1]], use_batch=False)
     return output
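The loop above materializes the TT representation by absorbing one core at a time. The same sweep written without the repo's reshape helper, for reference (core layout [r_k, n_k, r_{k+1}] with r_0 = 1 is assumed, not taken from the repo):

    import torch

    def tt_full_matrix(cores):
        out = torch.ones(1, 1)                         # [rows=1, r_0=1]
        for core in cores:
            r, n, r_next = core.shape
            out = out.mm(core.reshape(r, n * r_next))  # [rows, n * r_next]
            out = out.reshape(-1, r_next)              # rows absorb n
        return out                                     # [prod(n), r_d]

    cores = [torch.randn(1, 2, 3), torch.randn(3, 4, 2)]
    assert tt_full_matrix(cores).shape == (8, 2)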
Example #8
 def recover(self, weights=None, tensorize_output=False):
     nrows = self.N
     if not self.sample_axis:
         nrows = nrows // self.n[0]
     output = self.factors[0].new_zeros(nrows, self.R)
     for r in range(self.R):
         tmp = self.factors[0].new_ones(1)
         for k in range(self.d-1, -1, -1):
             if (k == 0) and (not self.sample_axis):
                 break
             tmp = torch.ger(tmp, self.factors[k][:, r])
             tmp = tmp.view([-1]).contiguous()
             if weights is not None:
                 tmp *= weights[k][r]
         output[:, r] = tmp
     if not self.sample_axis:
         output = self.factors[0].mm(output.t())
         output = torch_utils.flatten_torch(output, use_batch=False)
     return output
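Each CP column assembled above is a chain of outer products; flattening an outer product of vectors gives their Kronecker product (torch.ger is the older name for torch.outer). A self-contained check:

    import torch

    a, b, c = torch.randn(2), torch.randn(3), torch.randn(4)
    col = torch.outer(torch.outer(a, b).reshape(-1), c).reshape(-1)
    assert torch.allclose(col, torch.kron(torch.kron(a, b), c))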
Example #9
 def inverse_batch(self, input_batch, tensorize_output=False, fast_svd=False):
     output = torch_utils.reshape_torch(input_batch, self.n)
     for k in range(self.d):
         if fast_svd:
             u, s, v = torch_utils.fast_svd_torch(
                 torch_utils.reshape_torch(
                     self.cores[k], [self.r[k]*self.n[k], -1], use_batch=False
                 )
             )
         else:
             u, s, v = torch.svd(
                 torch_utils.reshape_torch(self.cores[k], [self.r[k]*self.n[k], -1], use_batch=False)
             )
         output = torch.einsum(
             'ijk,jl->ilk',
             torch_utils.reshape_torch(output, [self.r[k]*self.n[k], -1]),
             (u/s).mm(v.t())
         )
     output = torch_utils.flatten_torch(output)
     return output
Example #10
 def forward(self, input=None, T=False, tensorize_output=False):
     assert (input is None) != self.sample_axis
     if self.use_core:
         output = self.core.clone()
     else:
         if T:
             output = torch_utils.reshape_torch(input, self.n)
         else:
             output = torch_utils.reshape_torch(input, self.r)
     offset = int((not self.use_core) and self.sample_axis)
     for k in range(self.d):
         output = torch_utils.prodTenMat_torch(
             output, self.factors[k], k+offset, matrix_axis=int(not T)
         )
     if self.use_core and self.sample_axis:
         output = torch_utils.prodTenMat_torch(output, input, self.d, matrix_axis=1)
         permutation = [self.d] + list(range(self.d))
         output = output.permute(permutation)
     if not tensorize_output:
         return torch_utils.flatten_torch(output, use_batch=self.sample_axis)
     return output
Example #11
 def recover(self, weights=None):
     output = self.factors[0].new_ones([1, self.R])
     for k in range(self.d-1, -1, -1):
         if (k == 0) and (not self.sample_axis):
             break
         if weights is None:
             output = torch_utils.krp_cw_torch(output, self.factors[k])
         else:
             output = torch_utils.krp_cw_torch(output, self.factors[k]*weights[k].view(self.n[k], 1))
         if k == self.P:
             output = torch.repeat_interleave(output, torch.tensor(self.L, device=self.factors[0].device), dim=1)
     if not self.sample_axis:
         output = self.factors[0].mm(output.t())
         output = torch_utils.flatten_torch(output, use_batch=False)
     return output
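torch.repeat_interleave with a tensor of counts, as used at k == self.P, duplicates column j of the Khatri-Rao product according to the j-th count; this is how one shared factor column can feed several rank-one components (the meaning of self.P and self.L is the repo's and is assumed here). Illustration:

    import torch

    x = torch.arange(6.).reshape(2, 3)             # three columns
    counts = torch.tensor([1, 2, 3])
    y = torch.repeat_interleave(x, counts, dim=1)  # columns: c0, c1, c1, c2, c2, c2
    assert y.shape == (2, 6)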
Example #12
    def measure_principal_angles(self, input_batch, mode, fast=True):
        batch_size = input_batch.shape[0]
        mu_k = []
        for k in range(self.K):
            tmp = self.terms[k].get_bias(tensorize_output=False,
                                         use_batch=True)
            if tmp is None:
                tmp = input_batch.new_zeros(1, self.output_dim)
            mu_k.append(tmp)

        mu_k = torch.cat(mu_k)

        if self.group_term is not None:
            mu_g = self.group_term.get_bias(tensorize_output=False,
                                            use_batch=True)
            if mu_g is None:
                mu_g = input_batch.new_zeros(1, self.output_dim)
            projected_batch = self.group_term.multi_project(input_batch,
                                                            remove_bias=False,
                                                            tensorize=False)
            projected_mu = mu_k + mu_g
            projected_mu -= self.group_term.multi_project(
                torch_utils.reshape_torch(projected_mu, self.n),
                remove_bias=False,
                tensorize=False)
        input_batch = torch_utils.flatten_torch(input_batch)
        output_angles = input_batch.new_zeros([batch_size, self.K])
        if fast:
            chi = 1.2
        for k in range(self.K):
            if fast:
                Uk, _, _ = torch_utils.fast_svd_torch(
                    self.terms[k].get_sources(self.source_mode), chi=chi)
            else:
                Uk, _, _ = torch.svd(
                    self.terms[k].get_sources(self.source_mode))
                Uk = Uk[:, :self.terms[k].linear_mapping.r[mode]]
            Uk = Uk.t()
            '''
            if self.group_term is not None:
                current_batch = input_batch - projected_batch - projected_mu[k:k+1, :]
            else:
                current_batch = input_batch - mu_k[k:k+1, :]
            '''
            current_batch = torch_utils.reshape_torch(input_batch, self.n)
            current_batch = torch_utils.swapaxes_torch(current_batch, 1,
                                                       mode + 1)
            current_batch = torch_utils.reshape_torch(current_batch,
                                                      [self.n[mode], -1])
            tmp_r = current_batch.shape[-1]
            for i in range(batch_size):
                if fast:
                    u, _, _ = torch_utils.fast_svd_torch(current_batch[i],
                                                         chi=chi)
                else:
                    u, _, _ = torch.svd(current_batch[i])
                _, s, _ = torch.svd(Uk.mm(u[:, :tmp_r]))
                output_angles[i, k] = s[0]
        return output_angles
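The s[0] read off per sample is the cosine of the smallest principal angle between two subspaces: for orthonormal bases U and V, the singular values of U^T V are exactly those cosines. A self-contained version of the measurement:

    import torch

    A, B = torch.randn(10, 3), torch.randn(10, 4)
    U, _ = torch.linalg.qr(A)          # orthonormal basis of span(A)
    V, _ = torch.linalg.qr(B)          # orthonormal basis of span(B)
    _, s, _ = torch.svd(U.t().mm(V))
    cos_theta = s[0]                   # 1.0 means the subspaces share a direction
    assert 0.0 <= float(cos_theta) <= 1.0 + 1e-6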
Example #13
    def module_ppca_means_sigmas_weights_guide(self,
                                               input_batch,
                                               epsilon,
                                               expert=False,
                                               xi_greedy=None,
                                               highlight_peak=None):
        # z_k = zk_mean + eps @ zk_cov,  eps ~ N(0, I)
        # ppca_mean = W_k z_k,  ppca_sigma = sigma_k^2
        max_hidden_dim = max(self.hidden_dims)
        batch_size = input_batch.shape[0]

        gamma = input_batch.new_zeros(batch_size, self.K)
        if expert:
            output_angles = self.measure_principal_angles(
                input_batch, mode=self.source_mode)
        else:
            pi_p = pyro.param('pi_p',
                              input_batch.new_ones(self.K) / self.K,
                              constraint=constraints.positive)
            output_angles = pyro.sample(
                'pi',
                dist.Dirichlet(pi_p),
            ).log()
        if highlight_peak is not None:
            output_angles = highlight_peak * output_angles
        output_angles = output_angles.log_softmax(dim=-1)

        if xi_greedy is not None:
            phi = dist.LogNormal(input_batch.new_zeros([batch_size, self.K]),
                                 input_batch.new_ones(
                                     [batch_size,
                                      self.K])).to_event(1).sample()
            output_angles = (1 - xi_greedy) * output_angles + xi_greedy * phi

        ppca_means = input_batch.new_zeros(
            [self.K, batch_size, self.output_dim])
        if self.likelihood == 'normal':
            with pyro.plate('ppca_sigma_plate', self.K):
                if self.terms_isotropic:
                    ppca_sigmas_p = pyro.param(
                        f'ppca_sigmas_p',
                        input_batch.new_ones(self.K, 1),
                        constraint=constraints.positive)
                    ppca_sigmas = pyro.sample(
                        f'ppca_sigmas',
                        dist.Delta(ppca_sigmas_p).independent(1))
                else:
                    ppca_sigmas = input_batch.new_ones(self.K, 1)
                    ppca_sigmas_list = []
                    for i in range(self.d):
                        ppca_sigmas_p = pyro.param(
                            f'ppca_sigmas_{i}_p',
                            input_batch.new_ones(self.K, self.n[i]),
                            constraint=constraints.positive)
                        ppca_sigmas_list.append(
                            pyro.sample(
                                f'ppca_sigmas_{i}',
                                dist.Delta(ppca_sigmas_p).independent(1)))
                        ppca_sigmas = torch_utils.krp_cw_torch(
                            ppca_sigmas_list[i], ppca_sigmas, column=False)
        else:
            ppca_sigmas = None
            ppca_sigmas_list = [
                input_batch.new_ones(self.K, self.n[i]) for i in range(self.d)
            ]
        with pyro.plate('alpha_plate', self.K):
            alpha_p = pyro.param(f'alpha_p',
                                 input_batch.new_ones([self.K,
                                                       max_hidden_dim]),
                                 constraint=constraints.positive)
            alpha = pyro.sample(f'alpha',
                                dist.Delta(alpha_p).independent(1))

        for k in range(self.K):
            if self.iterms is None:
                z_mu = self.terms[k].linear_mapping.inverse_batch(input_batch)
            else:
                z_mu = self.iterms[k](torch_utils.flatten_torch(input_batch),
                                      T=True)
            if self.terms_isotropic:
                zk_mean, zk_cov = self.terms[
                    k].get_posterior_gaussian_mean_covariance(
                        input_batch,
                        noise_sigma=ppca_sigmas[k]
                        if ppca_sigmas is not None else 1,
                        z_mu=z_mu,
                        z_sigma=alpha[k])
            else:
                zk_mean, zk_cov = self.terms[
                    k].get_posterior_gaussian_mean_covariance(
                        input_batch,
                        noise_sigma=[x[k] for x in ppca_sigmas_list],
                        z_mu=z_mu,
                        z_sigma=alpha[k])
            ppca_means[k, :, :] = self.terms[k](
                zk_mean + epsilon[:, :self.hidden_dims[k]].mm(
                    zk_cov.view(self.hidden_dims[k], self.hidden_dims[k])))
            if self.likelihood == 'bernoulli':
                gamma[:, k] = dist.Bernoulli(
                    ppca_means[k].sigmoid(),
                    validate_args=False).to_event(1).log_prob(
                        torch_utils.flatten_torch(input_batch))
            elif self.likelihood == 'normal':
                gamma[:, k] = dist.Normal(
                    loc=ppca_means[k],
                    scale=ppca_sigmas[k]).to_event(1).log_prob(
                        torch_utils.flatten_torch(input_batch))
            else:
                raise ValueError
        gamma = (output_angles + gamma).softmax(dim=-1)
        '''
        pps = pyro.get_param_store()
        Nk = gamma.sum(dim=0)
        tmp1 = input_batch.new_zeros([self.K, max_hidden_dim])
        tmp2 = input_batch.new_zeros(self.K)
        for k in range(self.K):
            tmp1[k] = (
                gamma[:, k:k+1]*(
                    zk_mean + epsilon[:, :self.hidden_dims[k]].mm(
                        zk_cov.view(self.hidden_dims[k], self.hidden_dims[k])
                    )**2.
                )
            ).sum(dim=0) / Nk[k]
            tmp2[k] = (
                (gamma[:, k]*torch.norm(torch_utils.flatten_torch(input_batch) - ppca_means[k], dim=1)**2.).sum()
            ) / Nk[k]
        pname = 'alpha_p'
        pps.replace_param(pname, tmp1, pps[pname])
        pname = 'ppca_sigmas_p'
        pps.replace_param(pname, tmp2, pps[pname])
        '''
        return gamma, ppca_means, ppca_sigmas
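The closing softmax forms mixture responsibilities in log space: gamma[i, k] is proportional to pi_k * p(x_i | component k), computed stably as softmax(log pi + log-likelihood). A reduced illustration:

    import torch

    log_pi = torch.log_softmax(torch.randn(3), dim=-1)   # K = 3 mixture weights
    log_lik = torch.randn(5, 3)                          # per-sample, per-component
    gamma = (log_pi + log_lik).softmax(dim=-1)           # responsibilities
    assert torch.allclose(gamma.sum(dim=-1), torch.ones(5))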
Example #14
    def model_with(self,
                   input_batch,
                   labels=None,
                   subsample_size=None,
                   subsample=None,
                   xi_greedy=None,
                   highlight_peak=None,
                   normalize=False,
                   isolate_group=None,
                   orthogonolize_terms=None,
                   expert=False):
        pyro.module('terms', self.terms)
        if self.iterms is not None:
            pyro.module('iterms', self.iterms)
        batch_size = input_batch.shape[0]

        if labels is not None:
            assert len(labels) == batch_size
        if subsample_size is None:
            subsample_size = batch_size
            if subsample is None:
                subsample = torch.arange(subsample_size)
        max_hidden_dim = max(self.hidden_dims)
        if self.group_term is not None:
            max_hidden_dim = max(max_hidden_dim, self.group_hidden_dim)
        with pyro.plate('epsilon_plate', subsample_size):
            epsilon = pyro.sample(
                'epsilon',
                dist.Normal(
                    input_batch.new_zeros(subsample_size, max_hidden_dim),
                    1.).independent(1))
        if self.group_term is not None:
            pyro.module('group_term', self.group_term)
            if self.likelihood == 'bernoulli':
                ppca_gm_means = self.module_ppca_gm_means_sigma(
                    input_batch[subsample], epsilon)
            elif self.likelihood == 'normal':
                ppca_gm_means, ppca_gm_sigma = self.module_ppca_gm_means_sigma(
                    input_batch[subsample], epsilon)
            else:
                raise ValueError

        if self.likelihood == 'bernoulli':
            pi, ppca_means = self.module_ppca_means_sigmas_weights(
                input_batch[subsample],
                epsilon,
                expert=expert,
                highlight_peak=highlight_peak)
        elif self.likelihood == 'normal':
            pi, ppca_means, ppca_sigmas = self.module_ppca_means_sigmas_weights(
                input_batch[subsample],
                epsilon,
                expert=expert,
                highlight_peak=highlight_peak)
        else:
            raise ValueError

        with pyro.plate(f'samples',
                        batch_size,
                        subsample_size=subsample_size,
                        subsample=subsample,
                        device=input_batch.device) as i:
            assignments = pyro.sample('assignments', dist.Categorical(pi))
            if self.likelihood == 'normal':
                if assignments.dim() == 1:
                    if self.group_term is None:
                        pyro.sample(
                            f'obs',
                            dist.Normal(
                                ppca_means[assignments,
                                           torch.arange(subsample_size), :],
                                ppca_sigmas[assignments]
                            ).to_event(1),
                            obs=torch_utils.flatten_torch(input_batch[i]))
                    else:
                        pyro.sample(
                            f'obs',
                            dist.Normal(
                                (ppca_means + torch_utils.reshape_torch(
                                    ppca_gm_means, [1, subsample_size, -1],
                                    use_batch=False)
                                 )[assignments,
                                   torch.arange(subsample_size), :],
                                (torch_utils.reshape_torch(ppca_sigmas,
                                                           [self.K, -1],
                                                           use_batch=False) +
                                 ppca_gm_sigma[0])[assignments]
                            ).to_event(1),
                            obs=torch_utils.flatten_torch(input_batch[i]))
                else:
                    if self.group_term is None:
                        pyro.sample(
                            f'obs',
                            dist.Normal(ppca_means[assignments, :, :][:, 0],
                                        ppca_sigmas[assignments].view(
                                            self.K, 1, -1)
                                        ).to_event(1),
                            obs=torch_utils.flatten_torch(input_batch[i]))
                    else:
                        pyro.sample(
                            f'obs',
                            dist.Normal(
                                (ppca_means + torch_utils.reshape_torch(
                                    ppca_gm_means, [1, subsample_size, -1],
                                    use_batch=False))[assignments, :, :][:, 0],
                                torch_utils.reshape_torch(
                                    (ppca_sigmas.view(self.K, -1) +
                                     ppca_gm_sigma)[assignments],
                                    [self.K, 1, -1],
                                    use_batch=False)).to_event(1),
                            obs=torch_utils.flatten_torch(input_batch[i]))
            elif self.likelihood == 'bernoulli':
                if assignments.dim() == 1:
                    if self.group_term is None:
                        pyro.sample(
                            f'obs',
                            dist.Bernoulli(
                                ppca_means[assignments,
                                           torch.arange(subsample_size), :],
                                validate_args=False).to_event(1),
                            obs=torch_utils.flatten_torch(input_batch[i]))
                    else:
                        pyro.sample(
                            f'obs',
                            dist.Bernoulli(
                                (ppca_means + torch_utils.reshape_torch(
                                    ppca_gm_means, [1, subsample_size, -1],
                                    use_batch=False)
                                 )[assignments,
                                   torch.arange(subsample_size), :],
                                validate_args=False).to_event(1),
                            obs=torch_utils.flatten_torch(input_batch[i]))
                else:
                    if self.group_term is None:
                        pyro.sample(
                            f'obs',
                            dist.Bernoulli(ppca_means[assignments, :, :][:, 0],
                                           validate_args=False).to_event(1),
                            obs=torch_utils.flatten_torch(input_batch[i]))
                    else:
                        pyro.sample(
                            f'obs',
                            dist.Bernoulli(
                                (ppca_means + torch_utils.reshape_torch(
                                    ppca_gm_means, [1, subsample_size, -1],
                                    use_batch=False))[assignments, :, :][:, 0],
                                validate_args=False).to_event(1),
                            obs=torch_utils.flatten_torch(input_batch[i]))
            else:
                raise ValueError
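The ppca_means[assignments, torch.arange(subsample_size), :] indexing used throughout the obs sites is a per-sample gather from a [K, batch, D] stack of component means. In isolation:

    import torch

    K, n, D = 3, 5, 2
    means = torch.randn(K, n, D)                      # one mean per component and sample
    assignments = torch.randint(0, K, (n,))           # sampled component per sample
    picked = means[assignments, torch.arange(n), :]   # shape [n, D]
    assert picked.shape == (n, D)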
Example #15
 def inverse_batch(self, input_batch, fast_svd=True):
     U, S, V = self.get_svds(coupled=True, fast=fast_svd)
     output = torch_utils.flatten_torch(input_batch).mm(U/S).mm(V.t())
     return output
Example #16
    def guide_with(self,
                   input_batch,
                   labels=None,
                   subsample_size=None,
                   subsample=None,
                   xi_greedy=None,
                   highlight_peak=None,
                   normalize=False,
                   isolate_group=None,
                   orthogonolize_terms=None,
                   expert=False):
        pyro.module('terms', self.terms)
        if self.iterms is not None:
            pyro.module('iterms', self.iterms)
        batch_size = input_batch.shape[0]
        if labels is not None:
            assert len(labels) == batch_size
        if subsample_size is None:
            subsample_size = batch_size
            if subsample is None:
                subsample = torch.arange(subsample_size)

        max_hidden_dim = max(self.hidden_dims)
        if self.group_term is not None:
            max_hidden_dim = max(max_hidden_dim, self.group_hidden_dim)
        with pyro.plate('epsilon_plate', subsample_size):
            epsilon = pyro.sample(
                'epsilon',
                dist.Normal(
                    input_batch.new_zeros(subsample_size, max_hidden_dim),
                    1.).independent(1))

        if self.group_term is not None:
            pyro.module('group_term', self.group_term)
            if self.training and (isolate_group is not None):
                self.isolate_group_factors(mode=isolate_group)
            ppca_gm_means, ppca_gm_sigma = self.module_ppca_gm_means_sigma_guide(
                input_batch[subsample], epsilon)
        if self.training and (orthogonolize_terms is not None):
            self.orthogonolize_k_factors(mode=orthogonolize_terms)
        if self.training and normalize:
            with torch.no_grad():
                self.normalize_parameters()
        if self.group_term is not None:
            gamma, ppca_means, ppca_sigmas = self.module_ppca_means_sigmas_weights_guide(
                torch_utils.flatten_torch(input_batch)[subsample] -
                dist.Normal(ppca_gm_means, ppca_gm_sigma).sample(),
                epsilon,
                expert=expert,
                xi_greedy=xi_greedy,
                highlight_peak=highlight_peak)
        else:
            gamma, ppca_means, ppca_sigmas = self.module_ppca_means_sigmas_weights_guide(
                input_batch[subsample],
                epsilon,
                expert=expert,
                xi_greedy=xi_greedy,
                highlight_peak=highlight_peak)
        with pyro.plate(f'samples',
                        batch_size,
                        subsample_size=subsample_size,
                        subsample=subsample,
                        device=input_batch.device):
            assignments = pyro.sample('assignments', dist.Categorical(gamma))
Example #17
 def get_posterior_gaussian_mean_covariance(self, x_batch, noise_sigma=1, z_mu=0., z_sigma=1):
     if self.bias is not None:
         if callable(self.bias):
             output_mean = x_batch - torch_utils.reshape_torch(self.bias(), [1]+self.n, use_batch=False)
         else:
             output_mean = x_batch - torch_utils.reshape_torch(self.bias, [1]+self.n, use_batch=False)
         output_mean = torch_utils.reshape_torch(output_mean, self.linear_mapping.n)
     else:
         output_mean = torch_utils.reshape_torch(x_batch, self.linear_mapping.n)
     if isinstance(self.linear_mapping, TuckerTensor):
         if not isinstance(noise_sigma, list):
             svds = self.linear_mapping.get_svds()
         else:
             svds = self.linear_mapping.get_svds(weights=[x.sqrt() for x in noise_sigma])
         if svds[-1]:
             U, P, L, Q, _ = svds
             S_cov = Q*L.t()
             for k in range(self.d):
                 output_mean = torch_utils.prodTenMat_torch(output_mean, U[k], k+1, 0)
             output_mean = torch_utils.flatten_torch(output_mean)
             output_mean = output_mean.mm(L*P.t())
             output_mean = output_mean.mm(Q)
         else:
             U, S, V, shapes_s, _ = svds
             S_cov = x_batch.new_ones([1, 1])
             for k in range(self.d):
                 S_cov = torch_utils.kron_torch(S_cov, V[k]*S[k].t())
                 output_mean = torch_utils.prodTenMat_torch(output_mean, U[k]*S[k].t(), k+1, 0)
                 output_mean = torch_utils.prodTenMat_torch(output_mean, V[k], k+1, 1)
     elif (
         isinstance(self.linear_mapping, CPTensor) or
         isinstance(self.linear_mapping, LROTensor)
     ):
         if not isinstance(noise_sigma, list):
             U, S, V = self.linear_mapping.get_svds(coupled=True)
         else:
             U, S, V = self.linear_mapping.get_svds(weights=[x.sqrt() for x in noise_sigma], coupled=True)
         S_cov = V*S.t()
         output_mean = torch_utils.flatten_torch(output_mean)
         output_mean = output_mean.mm(U*S)
         output_mean = output_mean.mm(V.t())
     elif isinstance(self.linear_mapping, TTTensor):
         for k in range(self.d):
             shape = [self.n[k], self.linear_mapping.r[k+1]]
             tmp = self.linear_mapping.cores[k]
             if isinstance(noise_sigma, list):
                 tmp = tmp * noise_sigma[k].sqrt().view(1, -1, 1)
             if k > 0:
                 tmp = torch.einsum('ij,iab,jac->bc', S_cov, tmp, tmp)
             else:
                 tmp = torch.einsum('aib,aic->bc', tmp, tmp)
             u, s, v = torch.svd(tmp)
             S_cov = (u/s).mm(v.t())
             
             shape = [self.linear_mapping.r[k]*self.n[k], -1]
             tmp = self.linear_mapping.cores[k]
             output_mean = torch_utils.reshape_torch(output_mean, shape)
             output_mean = torch.einsum(
                 'ijk,jl->ilk', output_mean, torch_utils.reshape_torch(tmp, shape, use_batch=False)
             )
     else:
         raise ValueError
     if not isinstance(noise_sigma, list):
         if torch.is_tensor(noise_sigma):
             S_cov = S_cov/noise_sigma.sqrt()
         else:
             S_cov = S_cov/np.sqrt(noise_sigma)
     if not isinstance(self.linear_mapping, TTTensor):
         S_cov = S_cov.mm(S_cov.t())
     n = S_cov.shape[0]
     mask = torch.eye(n, n, device=x_batch.device).bool()
     S_cov[mask] += 1./z_sigma
     u, s, v = torch.svd(S_cov)
     S_cov = (u/s).mm(v.t())
     output_mean = torch_utils.flatten_torch(output_mean)
     output_mean = output_mean + z_mu / z_sigma
     output_mean = output_mean.mm(S_cov)
     S_cov = S_cov.unsqueeze(0)
     return output_mean, S_cov
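What the method assembles is the standard linear-Gaussian posterior, with the inverse taken through SVDs of the structured factors instead of one dense solve. The dense reference formula, as a sketch (zero prior mean and scalar noise assumed):

    import torch

    # For x = W z + eps, eps ~ N(0, sigma^2 I), z ~ N(0, alpha I), the
    # posterior z | x is N(m, S) with
    #   S = (W^T W / sigma^2 + I / alpha)^(-1)
    #   m = S W^T x / sigma^2
    N, D, sigma, alpha = 20, 4, 0.5, 2.0
    W, x = torch.randn(N, D), torch.randn(N)
    S = torch.inverse(W.t().mm(W) / sigma**2 + torch.eye(D) / alpha)
    m = S.mv(W.t().mv(x) / sigma**2)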
Example #18
 def multi_project(self, input_batch, remove_bias=True, tensorize=False):
     if remove_bias and (self.bias is not None):
         if callable(self.bias):
             output_batch = input_batch - torch_utils.reshape_torch(self.bias(), [1]+self.n, use_batch=False)
         else:
             output_batch = input_batch - torch_utils.reshape_torch(self.bias, [1]+self.n, use_batch=False)
         output_batch = torch_utils.reshape_torch(output_batch, self.linear_mapping.n)
     else:
         output_batch = torch_utils.reshape_torch(input_batch, self.linear_mapping.n)
         
     if isinstance(self.linear_mapping, TuckerTensor):
         svds = self.linear_mapping.get_svds()
         if svds[-1]:
             U, P, _, _, _ = svds
             for k in range(self.d):
                 output_batch = torch_utils.prodTenMat_torch(output_batch, U[k], k+1, 0)
             output_batch = torch_utils.flatten_torch(output_batch)
             output_batch = output_batch.mm(P.t())
             output_batch = output_batch.mm(P)
             output_batch = torch_utils.reshape_torch(output_batch, self.linear_mapping.r)
             for k in range(self.d):
                 output_batch = torch_utils.prodTenMat_torch(output_batch, U[k], k+1, 1)
         else:
             U, _, _, _, _ = svds
             for k in range(self.d):
                 output_batch = torch_utils.prodTenMat_torch(output_batch, U[k], k+1, 0)
                 output_batch = torch_utils.prodTenMat_torch(output_batch, U[k], k+1, 1)
     elif (
         isinstance(self.linear_mapping, CPTensor) or
         isinstance(self.linear_mapping, LROTensor)
     ):
         U, _, _ = self.linear_mapping.get_svds(coupled=True)
         output_batch = torch_utils.flatten_torch(output_batch)
         output_batch = output_batch.mm(U)
         output_batch = output_batch.mm(U.t())
     elif isinstance(self.linear_mapping, TTTensor):
         orth_list = []
         output_batch = input_batch.clone()
         for k in range(self.d):
             if k > 0:
                 tmp = torch_utils.prodTenMat_torch(
                     self.linear_mapping.cores[k],
                     tmp,
                     0,
                     1
                 )
             else:
                 tmp = self.linear_mapping.cores[k]
             # tmp now holds core k with the previous step's R factor folded in.
             tmp = torch_utils.reshape_torch(tmp, [-1, self.linear_mapping.r[k+1]], use_batch=False)
             u, s, v = torch.svd(tmp)
             orth_list.append(u[:, :self.linear_mapping.r[k+1]])
             tmp = s*(v[:, :self.linear_mapping.r[k+1]].t())
             output_batch = torch_utils.reshape_torch(
                 output_batch,
                 [self.linear_mapping.r[k]*self.linear_mapping.n[k], -1],
                 use_batch=True
             )
             output_batch = torch_utils.prodTenMat_torch(output_batch, u, 1, 0)
         u, s, v = torch.svd(tmp)
         output_batch = output_batch.squeeze(2).mm(u).mm(u.t()) ##### ??? 
         for k in range(self.d-1, -1, -1):
             if k == self.d-1:
                 output_batch = output_batch.mm(orth_list[k].t())
             else:
                 output_batch = torch_utils.prodTenMat_torch(output_batch, orth_list[k], self.d-k, 1)
             output_batch = torch_utils.reshape_torch(
                 output_batch, self.n[k:]+[self.linear_mapping.r[k]], use_batch=True
             )
     else:
         raise ValueError
     if (tensorize) and (output_batch.dim() == 2):
         return torch_utils.reshape_torch(output_batch, self.linear_mapping.n)
     if (not tensorize) and (output_batch.dim() > 2):
         return torch_utils.flatten_torch(output_batch)
     return output_batch
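The TT branch performs a left-to-right orthogonalization sweep: each core absorbs the R factor carried over from the previous step and is then re-split by an SVD so the stored basis is orthonormal. One step of that sweep in isolation (shapes invented):

    import torch

    r, n, r_next = 2, 3, 4
    core = torch.randn(r, n, r_next)
    R_prev = torch.randn(r, r)                     # carried from the previous core
    tmp = torch.einsum('ij,jab->iab', R_prev, core).reshape(r * n, r_next)
    u, s, v = torch.svd(tmp)                       # u: orthonormal basis to keep
    R_next = s.unsqueeze(1) * v.t()                # carried on to the next core
    assert torch.allclose(u.mm(R_next), tmp, atol=1e-5)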
Example #19
    def module_ppca_means_sigmas_weights(self,
                                         input_batch,
                                         epsilon,
                                         expert=False,
                                         highlight_peak=None):
        # z_k = zk_mean + eps * alpha_k,  eps ~ N(0, I)
        # ppca_mean = W_k z_k,  ppca_sigma = sigma_k^2
        max_hidden_dim = max(self.hidden_dims)
        batch_size = input_batch.shape[0]
        if expert:
            pi = self.measure_principal_angles(input_batch,
                                               mode=self.source_mode)
            if highlight_peak is not None:
                pi = highlight_peak * pi
            pi = pi.softmax(dim=-1)
        else:
            pi = pyro.sample(
                'pi', dist.Dirichlet(input_batch.new_ones(self.K) / self.K))
            if highlight_peak is not None:
                pi = highlight_peak * pi.log()
                pi = pi.softmax(dim=-1)

        ppca_means = input_batch.new_zeros(
            [self.K, batch_size, self.output_dim])
        if self.likelihood == 'normal':
            with pyro.plate('ppca_sigma_plate', self.K):
                if self.terms_isotropic:
                    ppca_sigmas = pyro.sample(
                        f'ppca_sigmas',
                        dist.LogNormal(0,
                                       input_batch.new_ones(self.K,
                                                            1)).independent(1))
                else:
                    ppca_sigmas_list = []
                    ppca_sigmas = input_batch.new_ones(self.K, 1)
                    for i in range(self.d):
                        ppca_sigmas_list.append(
                            pyro.sample(
                                f'ppca_sigmas_{i}',
                                dist.LogNormal(
                                    0, input_batch.new_ones(
                                        self.K, self.n[i])).independent(1)))
                        ppca_sigmas = torch_utils.krp_cw_torch(
                            ppca_sigmas_list[i], ppca_sigmas, column=False)
        with pyro.plate('alpha_plate', self.K):
            alpha = pyro.sample(
                f'alpha',
                dist.LogNormal(0,
                               input_batch.new_ones([self.K, max_hidden_dim
                                                     ])).independent(1))
        for k in range(self.K):
            if self.iterms is None:
                ppca_means[k, :, :] += self.terms[k].multi_project(
                    input_batch, remove_bias=False, tensorize=False)
            else:
                ppca_means[k, :, :] += self.terms[k](self.iterms[k](
                    torch_utils.flatten_torch(input_batch), T=True))

            ppca_means[k, :, :] += self.terms[k](
                epsilon[:, :self.hidden_dims[k]] *
                alpha[k:k + 1, :self.hidden_dims[k]])
        if self.likelihood == 'bernoulli':
            return pi, ppca_means.sigmoid()
        if self.likelihood == 'normal':
            return pi, ppca_means, ppca_sigmas
        raise ValueError
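A note on the mixing-weight prior above: Dirichlet(1/K, ..., 1/K) has concentration below one, so it favors sparse weight vectors dominated by a few components. Quick illustration:

    import torch
    import pyro.distributions as dist

    K = 4
    pi = dist.Dirichlet(torch.ones(K) / K).sample()   # sparse-leaning weights
    assert torch.allclose(pi.sum(), torch.tensor(1.))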