コード例 #1
0
ファイル: tensorial.py プロジェクト: kharyuk/vbtd
 def get_svds(self, weights=None, fast=True):
     """Compute truncated SVDs of the left unfoldings of all TT cores.

     Core ``k`` is unfolded to a ``(r[k] * n[k], -1)`` matrix, decomposed,
     and the factors are truncated to ``r[k + 1]`` components.

     Args:
         weights: optional sequence of ``d`` tensors. ``weights[k]`` is
             broadcast over the mode index of core ``k`` (``view(1, n[k], 1)``,
             the same convention as ``recover``) before decomposing. Assumes
             ``weights[k]`` has ``n[k]`` entries; TODO confirm.
         fast: use ``torch_utils.fast_svd_torch`` instead of ``torch.svd``.

     Returns:
         Lists ``U, S, V`` with one (truncated) SVD factor per core.
     """
     # Previously the ``weights is None`` branch multiplied by ``weights[k]``
     # (TypeError) and the weighted branch ignored the weights entirely.
     svd = torch_utils.fast_svd_torch if fast else torch.svd
     U, S, V = [], [], []
     for k in range(self.d):
         core = self.cores[k]
         if weights is not None:
             core = core * weights[k].view(1, self.n[k], 1)
         u, s, v = svd(
             torch_utils.reshape_torch(core, [self.r[k]*self.n[k], -1], use_batch=False)
         )
         rank = self.r[k+1]
         U.append(u[:, :rank])
         S.append(s[:rank])
         V.append(v[:, :rank])
     return U, S, V
コード例 #2
0
ファイル: tensorial.py プロジェクト: kharyuk/vbtd
 def orthogonolize(self, last_core=True):
     """Left-orthogonalize the TT cores in place by successive QR sweeps.

     Core ``k`` first absorbs the triangular factor ``R`` carried over from
     core ``k - 1``. It is then unfolded to ``(r[k] * n[k], -1)`` and replaced
     by the ``Q`` factor of its QR decomposition, and its ``R`` is carried on.

     Args:
         last_core: if False, the last core only absorbs the carried ``R``
             and is not QR-decomposed. If True it is orthogonalized too and
             its own ``R`` factor is dropped.
     """
     for k in range(self.d):
         core = self.cores[k].data
         if k > 0:
             # ``r`` has r[k] columns, so it must act on the (r[k], n[k]*r[k+1])
             # unfolding; applying it to the (r[k]*n[k], -1) unfolding only
             # type-checks when n[k] == 1.
             core = r.mm(torch_utils.reshape_torch(core, [self.r[k], -1], use_batch=False))
         tmp = torch_utils.reshape_torch(core, [self.r[k]*self.n[k], -1], use_batch=False)
         if (k == self.d-1) and (not last_core):
             self.cores[k].data = torch_utils.reshape_torch(tmp, [self.r[k], self.n[k], -1], use_batch=False)
             continue
         q, r = torch.qr(tmp)
         self.cores[k].data = torch_utils.reshape_torch(q, [self.r[k], self.n[k], -1], use_batch=False)
コード例 #3
0
ファイル: tensorial.py プロジェクト: kharyuk/vbtd
 def forward(self, input=None, T=False, tensorize_output=False):
     """Apply the matrix returned by ``recover`` (or its transpose).

     ``input`` must be passed exactly when ``self.sample_axis`` is set. In
     that case each row of ``input`` is multiplied by the transpose of the
     (possibly transposed) recovered matrix. With ``tensorize_output`` the
     result is reshaped to ``self.r`` when ``T`` is set, otherwise to
     ``self.n``.
     """
     assert (input is None) != self.sample_axis
     matrix = self.recover()
     if T:
         matrix = matrix.t()
     result = torch.einsum('ij,kj->ik', input, matrix) if self.sample_axis else matrix
     if not tensorize_output:
         return result
     target_shape = self.r if T else self.n
     return torch_utils.reshape_torch(result, target_shape, use_batch=self.sample_axis)
コード例 #4
0
ファイル: tensorial.py プロジェクト: kharyuk/vbtd
 def recover(self, weights=None):
     """Contract all TT cores into a single matrix.

     Cores are multiplied left to right, starting from a ``1 x 1`` ones
     matrix. If ``weights`` is given, ``weights[k]`` is broadcast over the
     mode index of core ``k`` (``view(1, n[k], 1)``) before contraction.

     Returns:
         With ``self.sample_axis`` set, a ``(self.N, self.r[-1])`` matrix.
         Otherwise the contraction flattened by
         ``torch_utils.flatten_torch(..., use_batch=False)``.
     """
     output = self.cores[0].new_ones([1, 1])
     for k in range(self.d):
         core = self.cores[k]
         if weights is not None:
             core = core*weights[k].view(1, self.n[k], 1)
         # NOTE(review): unlike most core reshapes in this module, this call
         # does not pass use_batch=False -- confirm the default is intended.
         output = output.mm(torch_utils.reshape_torch(core, [self.r[k], -1]))
         # NOTE(review): the next iteration multiplies by an (r[k+1], -1)
         # unfolding, so [-1, r[k+1]] looks intended here unless all ranks
         # are equal -- confirm.
         output = torch_utils.reshape_torch(output, [-1, self.r[k]])
     if self.sample_axis:
         return torch_utils.reshape_torch(output, [self.N, self.r[-1]], use_batch=False)
     return torch_utils.flatten_torch(output, use_batch=False)
コード例 #5
0
ファイル: tensorial.py プロジェクト: kharyuk/vbtd
 def inverse_batch(self, input_batch, tensorize_output=False, fast_svd=False):
     """Push a batch back through the pseudo-inverses of the TT cores.

     For each core ``k``, the SVD ``u, s, v`` of its ``(r[k] * n[k], -1)``
     unfolding is taken. The running tensor's ``r[k] * n[k]`` axis is then
     contracted with ``(u / s) @ v.T``. Assuming thin SVD factors, that
     matrix is the transpose of the unfolding's pseudo-inverse, so the step
     applies the pseudo-inverse.

     ``tensorize_output`` is accepted but not used; the result is always
     flattened.
     """
     result = torch_utils.reshape_torch(input_batch, self.n)
     for k in range(self.d):
         unfolding = torch_utils.reshape_torch(
             self.cores[k], [self.r[k]*self.n[k], -1], use_batch=False
         )
         if fast_svd:
             u, s, v = fast_svd_torch(unfolding)
         else:
             u, s, v = torch.svd(unfolding)
         pinv_transposed = (u/s).mm(v.t())
         result = torch.einsum(
             'ijk,jl->ilk',
             torch_utils.reshape_torch(result, [self.r[k]*self.n[k], -1]),
             pinv_transposed
         )
     return torch_utils.flatten_torch(result)
コード例 #6
0
ファイル: tensorial.py プロジェクト: kharyuk/vbtd
 def inverse_batch(self, input_batch, tensorize_output=False, fast_svd=False):
     """Map a batch back through the Tucker mapping using SVD-based inverses.

     With a core (``self.use_core``), the batch is contracted with every
     factor, flattened and multiplied by the ``(-1, r0)`` core unfolding. The
     result is then multiplied by ``(u / s) @ v.T`` from the SVD of a
     Gram-like matrix built from the core and each ``factors[k].T @
     factors[k]``. In this branch ``tensorize_output`` and ``fast_svd`` are
     not used.

     Without a core, each factor is decomposed (``fast_svd`` selects
     ``torch_utils.fast_svd_torch``). ``u / s`` and ``v`` are applied along
     mode ``k + 1`` with matrix axes 0 and 1 respectively. The exact
     semantics depend on ``prodTenMat_torch``; presumably this applies the
     factor's pseudo-inverse -- confirm. The result is reshaped to
     ``self.r`` or flattened according to ``tensorize_output``.
     """
     output = input_batch.clone()
     if self.use_core:
         output_batch = self.core.clone()
         for k in range(self.d):
             # NOTE(review): the core is contracted along axis k+1 here, while
             # the (-1, r0) unfoldings below treat r0 as the last core axis --
             # confirm the core's axis layout.
             output_batch = torch_utils.prodTenMat_torch(
                 output_batch,
                 self.factors[k].t().mm(self.factors[k]),
                 k+1,
                 1
             )
             output = torch_utils.prodTenMat_torch(
                 output,
                 self.factors[k],
                 k+1,
                 0
             )
         output = torch_utils.flatten_torch(output, use_batch=True)
         output = output.mm(torch_utils.reshape_torch(self.core, [-1, self.r0], use_batch=False))
         # (r0, r0) matrix: core^T (factor Gram products) core
         output_batch = torch_utils.reshape_torch(output_batch, [-1, self.r0], use_batch=False).t()
         output_batch = output_batch.mm(
             torch_utils.reshape_torch(self.core, [-1, self.r0], use_batch=False)
         )
         # Inverse of that matrix via its SVD.
         u, s, v = torch.svd(output_batch)
         output = output.mm(u/s).mm(v.t())
         return output
     
     for k in range(self.d):
         if fast_svd:
             u, s, v = torch_utils.fast_svd_torch(self.factors[k])
         else:
             u, s, v = torch.svd(self.factors[k])
             # Truncation to r[k] components is only applied on this path.
             u, s, v = u[:, :self.r[k]], s[:self.r[k]], v[:, :self.r[k]]
         output = torch_utils.prodTenMat_torch(output, u/s, k+1, 0)
         output = torch_utils.prodTenMat_torch(output, v, k+1, 1)
     if tensorize_output and output.dim() == 2:
         output = torch_utils.reshape_torch(output, self.r)
     if not tensorize_output and output.dim() != 2:
         output = torch_utils.flatten_torch(output)
     return output
コード例 #7
0
ファイル: vbtd.py プロジェクト: kharyuk/vbtd
 def orthogonolize_k_factors(self, mode):
     """Make the mode-``m`` subspaces of the K terms mutually orthogonal.

     For every requested mode ``m``, the terms are visited in order. The
     left singular vectors ``uk`` of term ``k``'s mode-``m`` factor are
     computed; for a ``TTTensor`` term, core ``m`` is first unfolded along
     its mode index. The projection onto ``span(uk)`` is then subtracted from
     the same mode of every later term ``l > k``, Gram-Schmidt style.
     Parameters are overwritten in place through ``.data``.

     Args:
         mode: a single mode index or an iterable of mode indices.
     """
     modes = [mode] if isinstance(mode, int) else mode
     for m in modes:
         for k in range(self.K):
             mapping_k = self.terms[k].linear_mapping
             if isinstance(mapping_k, tensorial.TTTensor):
                 # Put the mode index first, then unfold in Fortran order.
                 unfolded = mapping_k.cores[m].permute([1, 0, 2])
                 unfolded = torch_utils.reshape_torch(
                     unfolded, [unfolded.shape[0], -1], order='F', use_batch=False
                 )
                 uk, _, _ = torch_utils.fast_svd_torch(unfolded)
             else:
                 uk, _, _ = torch_utils.fast_svd_torch(mapping_k.factors[m])
             for l in range(k + 1, self.K):
                 mapping_l = self.terms[l].linear_mapping
                 if isinstance(mapping_l, tensorial.TTTensor):
                     permuted = mapping_l.cores[m].permute([1, 0, 2])
                     permuted_shape = list(permuted.shape)
                     flat = torch_utils.reshape_torch(
                         permuted, [permuted_shape[0], -1], order='F', use_batch=False
                     )
                     flat -= torch.mm(uk, torch.mm(uk.t(), flat))
                     restored = torch_utils.reshape_torch(
                         flat, permuted_shape, order='F', use_batch=False
                     )
                     mapping_l.cores[m].data = restored.permute([1, 0, 2])
                 else:
                     factor = mapping_l.factors[m]
                     factor.data = factor - torch.mm(uk, torch.mm(uk.t(), factor))
コード例 #8
0
ファイル: tensorial.py プロジェクト: kharyuk/vbtd
 def forward(self, input=None, T=False, tensorize_output=False):
     """Contract an optional batch with the TT cores.

     ``input`` must be given exactly when ``self.sample_axis`` is set, and
     it is required whenever ``T`` is True.

     * ``T=True``: ``input`` is reshaped to ``self.n``. The cores are then
       applied left to right: the running tensor's ``r[k] * n[k]`` axis is
       summed against the rows of core ``k``'s ``(r[k] * n[k], -1)``
       unfolding.
     * ``T=False``: the cores are applied right to left. The running tensor's
       ``r[k + 1]`` axis is summed against the columns of that unfolding.
       Without ``input`` the sweep starts from a ``1 x 1`` ones tensor and
       the last core is skipped.

     The result is reshaped to ``self.n`` when ``tensorize_output`` is set,
     otherwise flattened with ``torch_utils.flatten_torch``.
     """
     assert (input is None) != self.sample_axis
     #assert (T and input is not None)
     if T:
         assert input is not None
     if input is None:
         output = self.cores[-1].new_ones([1, 1])
     else:
         if T:
             output = torch_utils.reshape_torch(input, self.n)
         else:
             output = input.clone()
     if T:
         for k in range(self.d):
             # contract the (r[k]*n[k]) axis with the rows of core k's unfolding
             output = torch.einsum(
                 'ijk,jl->ilk',
                 torch_utils.reshape_torch(output, [self.r[k]*self.n[k], -1]),
                 torch_utils.reshape_torch(self.cores[k], [self.r[k]*self.n[k], -1], use_batch=False)
             )
     else:
         for k in range(self.d-1, -1, -1):
             if (k == self.d-1) and (input is None):
                 continue
             # contract the r[k+1] axis with the columns of core k's unfolding
             output = torch.einsum(
                 'ijk,lj->ilk',
                 torch_utils.reshape_torch(output, [self.r[k+1], -1]),
                 torch_utils.reshape_torch(self.cores[k], [self.r[k]*self.n[k], -1], use_batch=False)
             )
     if tensorize_output:
         output = torch_utils.reshape_torch(output, self.n)
     else:
         output = torch_utils.flatten_torch(output)
     return output
コード例 #9
0
ファイル: tensorial.py プロジェクト: kharyuk/vbtd
 def forward(self, input=None, T=False, tensorize_output=False):
     """Apply the Tucker factors (optionally through the core).

     ``input`` must be given exactly when ``self.sample_axis`` is set.
     Without a core, ``input`` is reshaped to ``self.n`` (when ``T``) or
     ``self.r``. Every factor is then applied along its mode with matrix
     axis ``int(not T)``. Modes are shifted by one when a leading sample axis
     is present. With a core and a sample axis, ``input`` is finally applied
     along axis ``d`` and that axis is moved to the front.

     Returns the tensor as is when ``tensorize_output`` is set, otherwise
     flattened (keeping the batch axis when ``self.sample_axis``).
     """
     assert (input is None) != self.sample_axis
     if self.use_core:
         output = self.core.clone()
     else:
         output = torch_utils.reshape_torch(input, self.n if T else self.r)
     offset = int((not self.use_core) and self.sample_axis)
     factor_axis = int(not T)
     for k in range(self.d):
         output = torch_utils.prodTenMat_torch(
             output, self.factors[k], k+offset, matrix_axis=factor_axis
         )
     if self.use_core and self.sample_axis:
         output = torch_utils.prodTenMat_torch(output, input, self.d, matrix_axis=1)
         # move the freshly created sample axis (position d) to the front
         output = output.permute([self.d] + list(range(self.d)))
     if tensorize_output:
         return output
     return torch_utils.flatten_torch(output, use_batch=self.sample_axis)
コード例 #10
0
ファイル: vbtd.py プロジェクト: kharyuk/vbtd
 def isolate_group_factors(self, mode):
     """Remove the group term's mode-``m`` subspace from every individual term.

     For each requested mode ``m``, the left singular vectors of the group
     term's mode-``m`` factor are computed; for a ``TTTensor`` group term,
     core ``m`` is first unfolded along its mode index. The projection onto
     their span is then subtracted from the same mode of each of the K
     terms. Parameters are overwritten in place through ``.data``.

     Args:
         mode: a single mode index or an iterable of mode indices.
     """
     assert self.group_term is not None
     modes = [mode] if isinstance(mode, int) else mode
     for m in modes:
         group_mapping = self.group_term.linear_mapping
         if isinstance(group_mapping, tensorial.TTTensor):
             unfolded = group_mapping.cores[m].permute([1, 0, 2])
             unfolded = torch_utils.reshape_torch(
                 unfolded, [unfolded.shape[0], -1], order='F', use_batch=False
             )
             basis, _, _ = torch_utils.fast_svd_torch(unfolded)
         else:
             basis, _, _ = torch_utils.fast_svd_torch(group_mapping.factors[m])
         for k in range(self.K):
             mapping = self.terms[k].linear_mapping
             if isinstance(mapping, tensorial.TTTensor):
                 permuted = mapping.cores[m].permute([1, 0, 2])
                 permuted_shape = list(permuted.shape)
                 flat = torch_utils.reshape_torch(
                     permuted, [permuted_shape[0], -1], order='F', use_batch=False
                 )
                 flat -= torch.mm(basis, torch.mm(basis.t(), flat))
                 restored = torch_utils.reshape_torch(
                     flat, permuted_shape, order='F', use_batch=False
                 )
                 mapping.cores[m].data = restored.permute([1, 0, 2])
             else:
                 factor = mapping.factors[m]
                 factor.data = factor - torch.mm(basis, torch.mm(basis.t(), factor))
コード例 #11
0
ファイル: tensorial.py プロジェクト: kharyuk/vbtd
 def get_bias(self, tensorize_output=True, use_batch=True):
     """Return the bias, reshaped for broadcasting against a batch.

     A callable bias is evaluated; a tensor bias is cloned. The target shape
     is ``self.linear_mapping.n`` (or ``[-1]`` when not tensorizing), with a
     leading ``1`` when ``use_batch``. The bias is reshaped only when its
     rank differs from the target's. Returns None when there is no bias.
     """
     if self.bias is None:
         return None
     bias = self.bias() if callable(self.bias) else self.bias.clone()
     target_shape = list(self.linear_mapping.n) if tensorize_output else [-1]
     if use_batch:
         target_shape.insert(0, 1)
     if bias.dim() == len(target_shape):
         return bias
     return torch_utils.reshape_torch(bias, target_shape, use_batch=False)
コード例 #12
0
ファイル: tensorial.py プロジェクト: kharyuk/vbtd
 def get_sources(self, mode):
     """Combine the linear mapping's factors/cores over the given modes.

     Starting from a ones tensor sized to the mapping's rank, the factors of
     the requested modes are folded in: column-wise Khatri-Rao products for
     CP/LRO, Kronecker products for Tucker, and successive core contractions
     for TT (merging the mode indices; finished with ``swapunfold_torch``).

     Args:
         mode: a single mode index, or a sequence of consecutive increasing
             mode indices.

     Raises:
         ValueError: if the linear mapping type is not supported.
     """
     if isinstance(mode, int):
         mode = [mode]
     else:
         # only contiguous, increasing runs of modes are supported
         assert (np.diff(mode) == 1).all()
     mapping = self.linear_mapping
     if isinstance(mapping, CPTensor):
         result = mapping.factors[0].new_ones([1, mapping.R])
     elif isinstance(mapping, LROTensor):
         result = mapping.factors[0].new_ones([1, mapping.M])
     elif isinstance(mapping, TuckerTensor):
         # NOTE(review): the Tucker core (use_core case) is not included here.
         result = mapping.factors[0].new_ones([1, 1])
     elif isinstance(mapping, TTTensor):
         # NOTE(review): the last axis has size r[0]; the first einsum below
         # contracts it with axis 0 of cores[mode[0]], so this assumes
         # mode[0] == 0 or r[mode[0]] == r[0] -- confirm.
         result = mapping.cores[0].new_ones([1, 1, mapping.r[0]])
     else:
         raise ValueError(f'unsupported linear mapping: {type(mapping).__name__}')
     for m in mode:
         assert 0 <= m < self.d
         if isinstance(mapping, CPTensor):
             result = torch_utils.krp_cw_torch(mapping.factors[m], result)
         elif isinstance(mapping, LROTensor):
             tmp = mapping.factors[m]
             if m >= mapping.P:
                 # Repeat counts must live on the same device as the factor.
                 # (Previously read ``self.factors[0].device``, but the factors
                 # belong to ``self.linear_mapping``.)
                 tmp = torch.repeat_interleave(
                     tmp, torch.tensor(mapping.L, device=tmp.device), dim=1
                 )
             result = torch_utils.krp_cw_torch(tmp, result)
         elif isinstance(mapping, TTTensor):
             result = torch.einsum('ijk,klm->ijlm', result, mapping.cores[m])
             r1, n1, n2, r2 = result.shape
             # merge the accumulated mode index with the new one
             result = torch_utils.reshape_torch(result, [r1, n1*n2, r2], use_batch=False)
         else:
             # TuckerTensor (the only remaining type accepted above)
             result = torch_utils.kron_torch(mapping.factors[m], result)
     if isinstance(mapping, TTTensor):
         result = torch_utils.swapunfold_torch(result, 1, use_batch=False)
     return result
コード例 #13
0
ファイル: tensorial.py プロジェクト: kharyuk/vbtd
 def get_svds(self, weights=None, fast=True):
     """Compute SVDs of the Tucker factors (and of the core, if present).

     Args:
         weights: optional sequence of ``d`` tensors. ``weights[k]`` scales
             the rows (mode-``k`` index) of factor ``k`` before its SVD. This
             matches the TT convention of weighting the mode index. It was
             previously ignored, so callers passing per-mode noise weights
             got unweighted results.
         fast: for sufficiently tall matrices (aspect ratio >= 1.2), obtain
             the SVD from the small Gram matrix instead of decomposing the
             matrix itself.

     Returns:
         ``(U, S, V, self.r, False)`` without a core, where ``S[k]`` is a
         column vector. With a core, ``(U, P, L, Q, True)``, where ``P, L, Q``
         come from the ``(r0, prod(r))`` unfolding of the core after
         absorbing ``V[k] * S[k].T`` along every mode.
     """
     chi = 1.2  # aspect-ratio threshold for the Gram-matrix shortcut
     core_flag = self.core is not None
     if core_flag:
         assert self.r0 is not None
         G = self.core.clone()
     U, S, V = [], [], []
     for k in range(self.d):
         factor = self.factors[k]
         if weights is not None:
             factor = factor * weights[k].view(-1, 1)
         if fast and (self.n[k] / self.r[k] >= chi):
             # eigen-decomposition of F^T F gives V and S^2; U = F V S^-1
             _, s, v = torch.svd(torch.mm(factor.t(), factor))
             s = s.sqrt()
             u = factor.mm(v/s)
         else:
             u, s, v = torch.svd(factor)
         U.append(u[:, :self.r[k]])
         S.append(s[:self.r[k]].view(-1, 1))
         V.append(v[:, :self.r[k]])
         if core_flag:
             G = torch_utils.prodTenMat_torch(G, V[k]*S[k].t(), k, 0)
     if core_flag:
         G = G.permute([self.d] + list(range(self.d))).contiguous()
         rp = int(np.prod(self.r))
         tmp = torch_utils.reshape_torch(G, [self.r0, -1], use_batch=False)
         if fast and (self.r0 / rp >= chi):
             P, L, _ = torch.svd(tmp.t().mm(tmp))
             L = L.sqrt()
             Q = tmp.mm(P/L)
         elif fast and (rp / self.r0 >= chi):
             Q, L, _ = torch.svd(tmp.mm(tmp.t()))
             L = L.sqrt()
             P = torch.mm((Q / L).t(), tmp)
         else:
             P, L, Q = torch.svd(tmp.t())
             r = min(rp, self.r0)
             P, L, Q = P[:, :r], L[:r].view(-1, 1), Q[:, :r]
         # NOTE(review): the two Gram-matrix branches return L as a 1-D tensor
         # (and the second one returns P with shape (r0, prod(r))), whereas the
         # exact branch returns L as a column and P as (prod(r), r) -- confirm
         # that callers handle every variant.
         return U, P, L, Q, core_flag
     return U, S, V, self.r, core_flag
コード例 #14
0
ファイル: vbtd.py プロジェクト: kharyuk/vbtd
    def model_with(self,
                   input_batch,
                   labels=None,
                   subsample_size=None,
                   subsample=None,
                   xi_greedy=None,
                   highlight_peak=None,
                   normalize=False,
                   isolate_group=None,
                   orthogonolize_terms=None,
                   expert=False):
        """Pyro model: a mixture of K tensor-structured terms (+ optional group term).

        Registers the term modules, samples latent codes ``epsilon`` for the
        scored objects, and obtains per-term means and mixture weights ``pi``
        from ``module_ppca_means_sigmas_weights``. For the normal likelihood
        it also obtains noise scales. An optional group term adds shared
        means (and scales). A term assignment is sampled per object and the
        flattened observations are scored under a Normal or Bernoulli
        likelihood.

        Args:
            input_batch: observations; the first axis indexes objects.
            labels: optional labels; only their count is checked here.
            subsample_size: number of objects scored in this call (defaults
                to the whole batch).
            subsample: indices of the scored objects, forwarded to the
                ``'samples'`` plate.
            highlight_peak, expert: forwarded to
                ``module_ppca_means_sigmas_weights``.
            xi_greedy, normalize, isolate_group, orthogonolize_terms:
                accepted but not used in this method.

        Raises:
            ValueError: if ``self.likelihood`` is neither ``'normal'`` nor
                ``'bernoulli'``.
        """
        pyro.module('terms', self.terms)
        if self.iterms is not None:
            pyro.module('iterms', self.iterms)
        batch_size = input_batch.shape[0]

        if labels is not None:
            assert len(labels) == batch_size
        if subsample_size is None:
            subsample_size = batch_size
            if subsample is None:
                subsample = torch.arange(subsample_size)
        # NOTE(review): if subsample_size is given but subsample is None,
        # ``input_batch[subsample]`` below is ``input_batch[None]`` (adds a
        # leading axis) rather than a subsample -- callers presumably always
        # pass both; confirm.
        max_hidden_dim = max(self.hidden_dims)
        if self.group_term is not None:
            max_hidden_dim = max(max_hidden_dim, self.group_hidden_dim)
        with pyro.plate('epsilon_plate', subsample_size):
            epsilon = pyro.sample(
                'epsilon',
                dist.Normal(
                    input_batch.new_zeros(subsample_size, max_hidden_dim),
                    1.).to_event(1))
        if self.group_term is not None:
            pyro.module('group_term', self.group_term)
            if self.likelihood == 'bernoulli':
                ppca_gm_means = self.module_ppca_gm_means_sigma(
                    input_batch[subsample], epsilon)
            elif self.likelihood == 'normal':
                ppca_gm_means, ppca_gm_sigma = self.module_ppca_gm_means_sigma(
                    input_batch[subsample], epsilon)
            else:
                raise ValueError

        if self.likelihood == 'bernoulli':
            pi, ppca_means = self.module_ppca_means_sigmas_weights(
                input_batch[subsample],
                epsilon,
                expert=expert,
                highlight_peak=highlight_peak)
        elif self.likelihood == 'normal':
            pi, ppca_means, ppca_sigmas = self.module_ppca_means_sigmas_weights(
                input_batch[subsample],
                epsilon,
                expert=expert,
                highlight_peak=highlight_peak)
        else:
            raise ValueError

        with pyro.plate('samples',
                        batch_size,
                        subsample_size=subsample_size,
                        subsample=subsample,
                        device=input_batch.device) as i:
            assignments = pyro.sample('assignments', dist.Categorical(pi))
            obs = torch_utils.flatten_torch(input_batch[i])
            means = ppca_means
            if self.group_term is not None:
                means = means + torch_utils.reshape_torch(
                    ppca_gm_means, [1, subsample_size, -1], use_batch=False)
            if assignments.dim() == 1:
                # One assignment per scored object: pick that term's mean row
                # by row. The means hold subsample_size rows; the group-term
                # branches used to index with arange(batch_size), which fails
                # whenever subsample_size < batch_size.
                obs_means = means[assignments, torch.arange(subsample_size), :]
            else:
                # assignments carries extra leading (e.g. enumeration) axes
                obs_means = means[assignments, :, :][:, 0]
            if self.likelihood == 'normal':
                if assignments.dim() == 1:
                    if self.group_term is None:
                        scales = ppca_sigmas[assignments]
                    else:
                        scales = (torch_utils.reshape_torch(
                            ppca_sigmas, [self.K, -1], use_batch=False) +
                                  ppca_gm_sigma[0])[assignments]
                elif self.group_term is None:
                    scales = ppca_sigmas[assignments].view(self.K, 1, -1)
                else:
                    scales = torch_utils.reshape_torch(
                        (ppca_sigmas.view(self.K, -1) + ppca_gm_sigma)[assignments],
                        [self.K, 1, -1],
                        use_batch=False)
                likelihood = dist.Normal(obs_means, scales)
            elif self.likelihood == 'bernoulli':
                likelihood = dist.Bernoulli(obs_means, validate_args=False)
            else:
                raise ValueError
            pyro.sample('obs', likelihood.to_event(1), obs=obs)
コード例 #15
0
ファイル: tensorial.py プロジェクト: kharyuk/vbtd
 def normalize(self):
     """Scale every TT core in place so each column of its unfolding has unit norm.

     The unfolding is the ``(r[k] * n[k], -1)`` matrix of core ``k``; each of
     its columns is divided by its Frobenius norm.
     """
     for k in range(self.d):
         unfolded = torch_utils.reshape_torch(
             self.cores[k].data, [self.r[k]*self.n[k], -1], use_batch=False
         )
         column_norms = torch.norm(unfolded, p='fro', dim=0)
         self.cores[k].data = torch_utils.reshape_torch(
             unfolded / column_norms, [self.r[k], self.n[k], -1], use_batch=False
         )
コード例 #16
0
ファイル: tensorial.py プロジェクト: kharyuk/vbtd
 def get_posterior_gaussian_mean_covariance(self, x_batch, noise_sigma=1, z_mu=0., z_sigma=1):
     """Posterior mean and covariance of the latent code for a batch.

     The bias (if any) is removed and the batch is projected through the
     linear mapping, using ``get_svds`` for Tucker/CP/LRO or a left-to-right
     sweep over the TT cores. A matrix ``S_cov`` is assembled from the
     mapping. For non-TT mappings it is ``A @ A.T``, with ``A`` built from the
     SVD factors. ``1 / z_sigma`` is added to its diagonal and the result is
     inverted through an SVD.

     Args:
         x_batch: batch of observations; it is broadcast against the bias
             reshaped to ``[1] + self.n``.
         noise_sigma: either a scalar (Python number or tensor) or a list of
             per-mode tensors. For a list, their square roots weight the
             mapping (``get_svds(weights=...)`` or the TT cores). For a
             scalar, the assembled factor is divided by
             ``sqrt(noise_sigma)``.
         z_mu, z_sigma: prior parameters of the latent code; ``z_mu /
             z_sigma`` is added to the projected batch and ``1 / z_sigma``
             to the diagonal.

     Returns:
         ``(output_mean, S_cov)``, where ``S_cov`` has a leading singleton
         axis.

     Raises:
         ValueError: if the linear mapping type is not supported.
     """
     if self.bias is not None:
         bias = self.bias() if callable(self.bias) else self.bias
         output_mean = x_batch - torch_utils.reshape_torch(bias, [1]+self.n, use_batch=False)
         output_mean = torch_utils.reshape_torch(output_mean, self.linear_mapping.n)
     else:
         output_mean = torch_utils.reshape_torch(x_batch, self.linear_mapping.n)
     per_mode_noise = isinstance(noise_sigma, list)
     if isinstance(self.linear_mapping, TuckerTensor):
         if per_mode_noise:
             svds = self.linear_mapping.get_svds(weights=[x.sqrt() for x in noise_sigma])
         else:
             svds = self.linear_mapping.get_svds()
         if svds[-1]:
             # Tucker mapping with a core: (U, P, L, Q, True)
             U, P, L, Q, _ = svds
             S_cov = Q*L.t()
             for k in range(self.d):
                 output_mean = torch_utils.prodTenMat_torch(output_mean, U[k], k+1, 0)
             output_mean = torch_utils.flatten_torch(output_mean)
             output_mean = output_mean.mm(L*P.t())
             output_mean = output_mean.mm(Q)
         else:
             U, S, V, _, _ = svds
             S_cov = x_batch.new_ones([1, 1])
             for k in range(self.d):
                 S_cov = torch_utils.kron_torch(S_cov, V[k]*S[k].t())
                 output_mean = torch_utils.prodTenMat_torch(output_mean, U[k]*S[k].t(), k+1, 0)
                 output_mean = torch_utils.prodTenMat_torch(output_mean, V[k], k+1, 1)
     elif isinstance(self.linear_mapping, (CPTensor, LROTensor)):
         if per_mode_noise:
             U, S, V = self.linear_mapping.get_svds(weights=[x.sqrt() for x in noise_sigma], coupled=True)
         else:
             U, S, V = self.linear_mapping.get_svds(coupled=True)
         S_cov = V*S.t()
         output_mean = torch_utils.flatten_torch(output_mean)
         output_mean = output_mean.mm(U*S)
         output_mean = output_mean.mm(V.t())
     elif isinstance(self.linear_mapping, TTTensor):
         for k in range(self.d):
             tmp = self.linear_mapping.cores[k]
             if per_mode_noise:
                 tmp = tmp * noise_sigma[k].sqrt().view(1, -1, 1)
             if k > 0:
                 tmp = torch.einsum('ij,iab,jac->bc', S_cov, tmp, tmp)
             else:
                 tmp = torch.einsum('aib,aic->bc', tmp, tmp)
             u, s, v = torch.svd(tmp)
             S_cov = (u/s).mm(v.t())

             # NOTE(review): the projection uses the unweighted core even when
             # noise_sigma is a list -- confirm this is intended.
             shape = [self.linear_mapping.r[k]*self.n[k], -1]
             output_mean = torch_utils.reshape_torch(output_mean, shape)
             output_mean = torch.einsum(
                 'ijk,jl->ilk',
                 output_mean,
                 torch_utils.reshape_torch(self.linear_mapping.cores[k], shape, use_batch=False)
             )
     else:
         raise ValueError
     if not per_mode_noise:
         # Tensors use torch's sqrt (works on GPU / with autograd); numbers use numpy.
         # This replaces a bare ``except:`` fallback that hid unrelated errors.
         if torch.is_tensor(noise_sigma):
             S_cov = S_cov/noise_sigma.sqrt()
         else:
             S_cov = S_cov/np.sqrt(noise_sigma)
     if not isinstance(self.linear_mapping, TTTensor):
         S_cov = S_cov.mm(S_cov.t())
     n = S_cov.shape[0]
     # bool mask: uint8 masks are deprecated for indexing
     mask = torch.eye(n, n, device=x_batch.device).bool()
     S_cov[mask] += 1./z_sigma
     u, s, v = torch.svd(S_cov)
     S_cov = (u/s).mm(v.t())
     output_mean = torch_utils.flatten_torch(output_mean)
     output_mean = output_mean + z_mu / z_sigma
     output_mean = output_mean.mm(S_cov)
     S_cov = S_cov.unsqueeze(0)
     return output_mean, S_cov
コード例 #17
0
ファイル: tensorial.py プロジェクト: kharyuk/vbtd
 def multi_project(self, input_batch, remove_bias=True, tensorize=False):
     """Project a batch onto the subspace described by ``self.linear_mapping``.

     The batch is optionally de-biased by subtracting ``self.bias``, then sent
     through a project-and-reconstruct step. The form of that step depends on
     the type of ``self.linear_mapping``:

     * ``TuckerTensor``: mode products with the factors ``U[k]`` returned by
       ``get_svds()``. The last flag is 0 on the way in and 1 on the way out,
       presumably ``U^T`` then ``U`` (TODO confirm against
       ``prodTenMat_torch``). When the last element of ``get_svds()`` is
       truthy, an extra ``x @ P.T @ P`` step is applied to the flattened
       intermediate before it is reshaped to ``self.linear_mapping.r``.
     * ``CPTensor`` / ``LROTensor``: ``x @ U @ U.T`` on the flattened batch,
       with ``U`` taken from ``get_svds(coupled=True)``.
     * ``TTTensor``: an experimental sweep over the TT cores. See the
       NOTE(review) comments in that branch.

     Args:
         input_batch: batch of samples. The bias is reshaped to
             ``[1] + self.n`` before subtraction, so presumably the batch has
             shape ``(batch, *self.n)`` -- TODO confirm.
         remove_bias: if True and ``self.bias`` is not None, subtract the
             bias before projecting. The bias is called first when callable.
         tensorize: selects the output layout (see Returns).

     Returns:
         The projected batch, shaped as follows:
         * ``tensorize`` True and result is 2-D: reshaped with
           ``self.linear_mapping.n``.
         * ``tensorize`` False and result has more than 2 dims: flattened.
         * Otherwise: returned as computed.

     Raises:
         ValueError: if ``self.linear_mapping`` is not one of the supported
             types listed above.
     """
     if remove_bias and (self.bias is not None):
         # The bias may be stored either as a tensor or as a callable that
         # produces it.
         if callable(self.bias):
             output_batch = input_batch - torch_utils.reshape_torch(self.bias(), [1]+self.n, use_batch=False)
         else:
             output_batch = input_batch - torch_utils.reshape_torch(self.bias, [1]+self.n, use_batch=False)
         output_batch = torch_utils.reshape_torch(output_batch, self.linear_mapping.n)
     else:
         output_batch = torch_utils.reshape_torch(input_batch, self.linear_mapping.n)

     if isinstance(self.linear_mapping, TuckerTensor):
         svds = self.linear_mapping.get_svds()
         # The last element of get_svds() decides whether an additional
         # projection P is applied between the forward and backward mode
         # products.
         if svds[-1]:
             U, P, _, _, _ = svds
             for k in range(self.d):
                 output_batch = torch_utils.prodTenMat_torch(output_batch, U[k], k+1, 0)
             output_batch = torch_utils.flatten_torch(output_batch)
             output_batch = output_batch.mm(P.t())
             output_batch = output_batch.mm(P)
             output_batch = torch_utils.reshape_torch(output_batch, self.linear_mapping.r)
             for k in range(self.d):
                 output_batch = torch_utils.prodTenMat_torch(output_batch, U[k], k+1, 1)
         else:
             U, _, _, _, _ = svds
             # Without P, each mode is projected and reconstructed in place.
             for k in range(self.d):
                 output_batch = torch_utils.prodTenMat_torch(output_batch, U[k], k+1, 0)
                 output_batch = torch_utils.prodTenMat_torch(output_batch, U[k], k+1, 1)
     elif (
         isinstance(self.linear_mapping, CPTensor) or
         isinstance(self.linear_mapping, LROTensor)
     ):
         U, _, _ = self.linear_mapping.get_svds(coupled=True)
         output_batch = torch_utils.flatten_torch(output_batch)
         # x @ U @ U^T: an orthogonal projection onto span(U) when U has
         # orthonormal columns (presumably, as it comes from an SVD).
         output_batch = output_batch.mm(U)
         output_batch = output_batch.mm(U.t())
     elif isinstance(self.linear_mapping, TTTensor):
         # NOTE(review): this branch looks unfinished (see the "???" marker
         # below). The notes in it flag suspicious spots -- confirm intent.
         orth_list = []
         # NOTE(review): this discards the de-biased / reshaped output_batch
         # computed above, so remove_bias has no effect for TT mappings.
         output_batch = input_batch.clone()
         for k in range(self.d):
             if k > 0:
                 tmp = torch_utils.prodTenMat_torch(
                     self.linear_mapping.cores[k],
                     tmp,
                     0,
                     1
                 )
             else:
                 tmp = self.linear_mapping.cores[k]
             tmp = torch_utils.reshape_torch(tmp, [-1, self.linear_mapping.r[k+1]], use_batch=False)
             # NOTE(review): tmp is overwritten here, so the factor carried
             # over from the previous core (the two statements above) is
             # never used inside the loop.
             tmp = torch_utils.reshape_torch(
                 self.linear_mapping.cores[k], [-1, self.linear_mapping.r[k+1]], use_batch=False
             )
             u, s, v = torch.svd(tmp)
             orth_list.append(u[:, :self.linear_mapping.r[k+1]])
             # NOTE(review): the 1-D s broadcasts along the last dimension
             # (the columns) of v[:, :r].t(). Scaling the rows, as in S V^T,
             # would need s[:r].unsqueeze(1). Only the value from the last
             # core survives to the SVD after the loop.
             tmp = s*(v[:, :self.linear_mapping.r[k+1]].t())
             output_batch = torch_utils.reshape_torch(
                 output_batch,
                 [self.linear_mapping.r[k]*self.linear_mapping.n[k], -1],
                 use_batch=True
             )
             # NOTE(review): the full u is applied here, while orth_list
             # stores the truncated u[:, :r[k+1]].
             output_batch = torch_utils.prodTenMat_torch(output_batch, u, 1, 0)
         u, s, v = torch.svd(tmp)
         # NOTE(review): the original author marked the next line as
         # uncertain ("???").
         output_batch = output_batch.squeeze(2).mm(u).mm(u.t()) ##### ???
         # Map back through the stored per-core bases, last core first.
         for k in range(self.d-1, -1, -1):
             if k == self.d-1:
                 output_batch = output_batch.mm(orth_list[k].t())
             else:
                 output_batch = torch_utils.prodTenMat_torch(output_batch, orth_list[k], self.d-k, 1)
             output_batch = torch_utils.reshape_torch(
                 output_batch, self.n[k:]+[self.linear_mapping.r[k]], use_batch=True
             )
     else:
         # Unsupported mapping type. NOTE(review): consider adding a message.
         raise ValueError
     if (tensorize) and (output_batch.dim() == 2):
         return torch_utils.reshape_torch(output_batch, self.linear_mapping.n)
     if (not tensorize) and (output_batch.dim() > 2):
         return torch_utils.flatten_torch(output_batch)
     return output_batch
コード例 #18
0
ファイル: vbtd.py プロジェクト: kharyuk/vbtd
    def measure_principal_angles(self, input_batch, mode, fast=True):
        """Score how well each sample aligns with each term's source subspace.

        For every sample ``i`` and term ``k``, this computes the largest
        singular value of ``Uk^T @ Ui``:

        * ``Uk`` holds the left singular vectors of
          ``self.terms[k].get_sources(self.source_mode)``.
        * ``Ui`` holds the left singular vectors of the mode-``mode``
          unfolding of sample ``i``.

        Assuming both bases are orthonormal, this value is the cosine of the
        smallest principal angle between the two subspaces. Larger values
        therefore mean better alignment.

        Args:
            input_batch: batch of samples. It is flattened and then reshaped
                with ``self.n``, so presumably its shape is either
                ``(batch, prod(self.n))`` or ``(batch, *self.n)`` -- TODO
                confirm.
            mode: tensor mode along which every sample is unfolded. It also
                selects the truncation rank ``linear_mapping.r[mode]`` in the
                exact-SVD path.
            fast: if True, use ``torch_utils.fast_svd_torch`` with
                ``chi=1.2``. Otherwise use ``torch.svd``.

        Returns:
            Tensor of shape ``(batch_size, self.K)`` whose entry ``[i, k]`` is
            the largest singular value ``s[0]`` for sample ``i`` and term
            ``k``.

        NOTE(review): the sources are taken with ``self.source_mode`` while
        the unfolding and truncation use ``mode``. Confirm that these are
        meant to coincide.
        """
        batch_size = input_batch.shape[0]
        chi = 1.2  # forwarded to fast_svd_torch (used only when fast=True)

        input_batch = torch_utils.flatten_torch(input_batch)
        output_angles = input_batch.new_zeros([batch_size, self.K])

        # Mode-`mode` unfolding of every sample. It does not depend on the
        # term k, so it is built once rather than once per term.
        unfolded = torch_utils.reshape_torch(input_batch, self.n)
        unfolded = torch_utils.swapaxes_torch(unfolded, 1, mode + 1)
        unfolded = torch_utils.reshape_torch(unfolded, [self.n[mode], -1])
        n_cols = unfolded.shape[-1]

        # Left singular vectors of each sample's unfolding. These are also
        # independent of k, so computing them here needs batch_size SVDs
        # instead of K * batch_size.
        sample_bases = []
        for i in range(batch_size):
            if fast:
                u, _, _ = torch_utils.fast_svd_torch(unfolded[i], chi=chi)
            else:
                u, _, _ = torch.svd(unfolded[i])
            sample_bases.append(u[:, :n_cols])

        for k in range(self.K):
            sources = self.terms[k].get_sources(self.source_mode)
            if fast:
                Uk, _, _ = torch_utils.fast_svd_torch(sources, chi=chi)
            else:
                Uk, _, _ = torch.svd(sources)
                # NOTE(review): only the exact path truncates to the term's
                # rank along `mode`. Presumably fast_svd_torch already
                # returns a reduced basis -- TODO confirm.
                Uk = Uk[:, :self.terms[k].linear_mapping.r[mode]]
            Uk_t = Uk.t()
            for i, u in enumerate(sample_bases):
                # The largest singular value of Uk^T u is the cosine of the
                # smallest principal angle (for orthonormal bases).
                _, s, _ = torch.svd(Uk_t.mm(u))
                output_angles[i, k] = s[0]
        return output_angles