def get_svds(self, weights=None, fast=True):
    # Per-core truncated SVDs of the TT cores, each unfolded to [r[k]*n[k], r[k+1]];
    # optional per-mode weights are applied before the decomposition.
    U, S, V = [], [], []
    for k in range(self.d):
        mat = torch_utils.reshape_torch(self.cores[k], [self.r[k]*self.n[k], -1],
                                        use_batch=False)
        if weights is not None:
            mat = mat * weights[k]
        if fast:
            u, s, v = fast_svd_torch(mat)
        else:
            u, s, v = torch.svd(mat)
        u, s, v = u[:, :self.r[k+1]], s[:self.r[k+1]], v[:, :self.r[k+1]]
        U.append(u)
        S.append(s)
        V.append(v)
    return U, S, V
def orthogonolize(self, last_core=True):
    # Left-orthogonalization sweep: QR-factorize each unfolded core and push the
    # triangular factor into the next core.
    r = None
    for k in range(self.d):
        tmp = torch_utils.reshape_torch(self.cores[k].data, [self.r[k], -1],
                                        use_batch=False)
        if k > 0:
            tmp = r.mm(tmp)
        tmp = torch_utils.reshape_torch(tmp, [self.r[k]*self.n[k], -1],
                                        use_batch=False)
        if (k == self.d-1) and (not last_core):
            self.cores[k].data = torch_utils.reshape_torch(
                tmp, [self.r[k], self.n[k], -1], use_batch=False)
            continue
        q, r = torch.qr(tmp)
        self.cores[k].data = torch_utils.reshape_torch(
            q, [self.r[k], self.n[k], -1], use_batch=False)
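
# NOTE: a self-contained sketch (not part of the class above) of the same
# left-orthogonalization sweep that `orthogonolize` performs, written for a plain
# list of TT core tensors. The helper name `tt_orthogonalize_sketch` and the use of
# torch.linalg.qr are illustrative assumptions, not project API.
def tt_orthogonalize_sketch(cores):
    # cores[k] has shape [r_k, n_k, r_{k+1}]; after the sweep every core but the
    # last is left-orthogonal in its [r_k*n_k, r_{k+1}] unfolding.
    new_cores = []
    r_mat = None
    for k, core in enumerate(cores):
        r_k, n_k, r_next = core.shape
        mat = core.reshape(r_k, n_k * r_next)
        if r_mat is not None:
            mat = r_mat @ mat               # absorb the triangular factor from the previous core
            r_k = mat.shape[0]
        mat = mat.reshape(r_k * n_k, r_next)
        if k == len(cores) - 1:
            new_cores.append(mat.reshape(r_k, n_k, r_next))
            break
        q, r_mat = torch.linalg.qr(mat)     # reduced QR; q spans the same column space
        new_cores.append(q.reshape(r_k, n_k, q.shape[1]))
    return new_cores
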
def forward(self, input=None, T=False, tensorize_output=False):
    assert (input is None) != self.sample_axis
    output = self.recover()
    if T:
        output = output.t()
    if self.sample_axis:
        # batch_size = input.shape[0]
        output = torch.einsum('ij,kj->ik', input, output)
    if tensorize_output:
        if T:
            return torch_utils.reshape_torch(output, self.r, use_batch=self.sample_axis)
        else:
            return torch_utils.reshape_torch(output, self.n, use_batch=self.sample_axis)
    return output
def recover(self, weights=None):
    nrows = self.N
    if not self.sample_axis:
        nrows = nrows // self.n[0]
    output = self.cores[0].new_ones([1, 1])
    for k in range(self.d):
        if weights is None:
            output = output.mm(
                torch_utils.reshape_torch(self.cores[k], [self.r[k], -1]))
        else:
            output = output.mm(
                torch_utils.reshape_torch(
                    self.cores[k]*weights[k].view(1, self.n[k], 1),
                    [self.r[k], -1]))
        output = torch_utils.reshape_torch(output, [-1, self.r[k]])
    if not self.sample_axis:
        output = torch_utils.flatten_torch(output, use_batch=False)
    else:
        output = torch_utils.reshape_torch(output, [self.N, self.r[-1]], use_batch=False)
    return output
def inverse_batch(self, input_batch, tensorize_output=False, fast_svd=False):
    output = torch_utils.reshape_torch(input_batch, self.n)
    for k in range(self.d):
        mat = torch_utils.reshape_torch(self.cores[k], [self.r[k]*self.n[k], -1],
                                        use_batch=False)
        if fast_svd:
            u, s, v = fast_svd_torch(mat)
        else:
            u, s, v = torch.svd(mat)
        output = torch.einsum(
            'ijk,jl->ilk',
            torch_utils.reshape_torch(output, [self.r[k]*self.n[k], -1]),
            (u/s).mm(v.t())
        )
    output = torch_utils.flatten_torch(output)
    return output
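
# NOTE: `inverse_batch` contracts the data against (u/s) @ v.t() built from the SVD
# of each unfolded core. For a full-column-rank matrix A = U S V^T this equals
# pinv(A)^T, the transpose of the Moore-Penrose pseudo-inverse. A minimal check of
# that identity; the helper name is illustrative only.
def _check_pinv_transpose_identity():
    A = torch.randn(12, 4)                      # tall, full column rank with probability 1
    u, s, v = torch.svd(A)                      # same legacy API used above
    lhs = (u / s).mm(v.t())                     # the matrix inverse_batch multiplies by
    rhs = torch.linalg.pinv(A).t()              # reference pseudo-inverse, transposed
    assert torch.allclose(lhs, rhs, atol=1e-4)
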
def inverse_batch(self, input_batch, tensorize_output=False, fast_svd=False):
    output = input_batch.clone()
    if self.use_core:
        output_batch = self.core.clone()
        for k in range(self.d):
            output_batch = torch_utils.prodTenMat_torch(
                output_batch, self.factors[k].t().mm(self.factors[k]), k+1, 1)
            output = torch_utils.prodTenMat_torch(output, self.factors[k], k+1, 0)
        output = torch_utils.flatten_torch(output, use_batch=True)
        output = output.mm(torch_utils.reshape_torch(self.core, [-1, self.r0], use_batch=False))
        output_batch = torch_utils.reshape_torch(output_batch, [-1, self.r0], use_batch=False).t()
        output_batch = output_batch.mm(
            torch_utils.reshape_torch(self.core, [-1, self.r0], use_batch=False))
        u, s, v = torch.svd(output_batch)
        output = output.mm(u/s).mm(v.t())
        return output
    for k in range(self.d):
        if fast_svd:
            u, s, v = torch_utils.fast_svd_torch(self.factors[k])
        else:
            u, s, v = torch.svd(self.factors[k])
        u, s, v = u[:, :self.r[k]], s[:self.r[k]], v[:, :self.r[k]]
        output = torch_utils.prodTenMat_torch(output, u/s, k+1, 0)
        output = torch_utils.prodTenMat_torch(output, v, k+1, 1)
    if tensorize_output and output.dim() == 2:
        output = torch_utils.reshape_torch(output, self.r)
    if not tensorize_output and output.dim() != 2:
        output = torch_utils.flatten_torch(output)
    return output
def orthogonolize_k_factors(self, mode):
    # For each requested mode, project every later term's mode-m factor (or unfolded
    # TT core) out of the column space of the earlier terms (sequential deflation).
    if isinstance(mode, int):
        mode = [mode]
    for m in mode:
        for k in range(self.K):
            if isinstance(self.terms[k].linear_mapping, tensorial.TTTensor):
                tmp = self.terms[k].linear_mapping.cores[m].permute([1, 0, 2])
                tmp = torch_utils.reshape_torch(tmp, [tmp.shape[0], -1],
                                                order='F', use_batch=False)
                uk, _, _ = torch_utils.fast_svd_torch(tmp)
            else:
                uk, _, _ = torch_utils.fast_svd_torch(
                    self.terms[k].linear_mapping.factors[m])
            for l in range(k + 1, self.K):
                if isinstance(self.terms[l].linear_mapping, tensorial.TTTensor):
                    tmp = self.terms[l].linear_mapping.cores[m].permute([1, 0, 2])
                    tmp_shape = list(tmp.shape)
                    tmp = torch_utils.reshape_torch(tmp, [tmp_shape[0], -1],
                                                    order='F', use_batch=False)
                    tmp -= torch.mm(uk, torch.mm(uk.t(), tmp))
                    tmp = torch_utils.reshape_torch(tmp, tmp_shape,
                                                    order='F', use_batch=False)
                    self.terms[l].linear_mapping.cores[m].data = tmp.permute([1, 0, 2])
                else:
                    self.terms[l].linear_mapping.factors[m].data = (
                        self.terms[l].linear_mapping.factors[m]
                        - torch.mm(uk, torch.mm(uk.t(),
                                                self.terms[l].linear_mapping.factors[m])))
def forward(self, input=None, T=False, tensorize_output=False):
    assert (input is None) != self.sample_axis
    #assert (T and input is not None)
    if T:
        assert input is not None
    if input is None:
        output = self.cores[-1].new_ones([1, 1])
    else:
        if T:
            output = torch_utils.reshape_torch(input, self.n)
        else:
            output = input.clone()
    if T:
        for k in range(self.d):
            output = torch.einsum(
                'ijk,jl->ilk',
                torch_utils.reshape_torch(output, [self.r[k]*self.n[k], -1]),
                torch_utils.reshape_torch(self.cores[k], [self.r[k]*self.n[k], -1],
                                          use_batch=False)
            )
    else:
        for k in range(self.d-1, -1, -1):
            if (k == self.d-1) and (input is None):
                continue
            output = torch.einsum(
                'ijk,lj->ilk',
                torch_utils.reshape_torch(output, [self.r[k+1], -1]),
                torch_utils.reshape_torch(self.cores[k], [self.r[k]*self.n[k], -1],
                                          use_batch=False)
            )
    if tensorize_output:
        output = torch_utils.reshape_torch(output, self.n)
    else:
        output = torch_utils.flatten_torch(output)
    return output
def forward(self, input=None, T=False, tensorize_output=False):
    assert (input is None) != self.sample_axis
    # batch_size = input.shape[0]
    if self.use_core:
        output = self.core.clone()
    else:
        if T:
            output = torch_utils.reshape_torch(input, self.n)
        else:
            output = torch_utils.reshape_torch(input, self.r)
    offset = int((not self.use_core) and self.sample_axis)
    for k in range(self.d):
        output = torch_utils.prodTenMat_torch(
            output, self.factors[k], k+offset, matrix_axis=int(not T))
    if self.use_core and self.sample_axis:
        output = torch_utils.prodTenMat_torch(output, input, self.d, matrix_axis=1)
        permutation = [self.d] + list(range(self.d))
        output = output.permute(permutation)
    if not tensorize_output:
        return torch_utils.flatten_torch(output, use_batch=self.sample_axis)
    return output
def isolate_group_factors(self, mode):
    # Remove the group term's mode-m column space from every individual term's
    # mode-m factor (or unfolded TT core).
    assert self.group_term is not None
    if isinstance(mode, int):
        mode = [mode]
    for m in mode:
        if isinstance(self.group_term.linear_mapping, tensorial.TTTensor):
            tmp = self.group_term.linear_mapping.cores[m].permute([1, 0, 2])
            tmp = torch_utils.reshape_torch(tmp, [tmp.shape[0], -1],
                                            order='F', use_batch=False)
            u, _, _ = torch_utils.fast_svd_torch(tmp)
        else:
            u, _, _ = torch_utils.fast_svd_torch(
                self.group_term.linear_mapping.factors[m])
        for k in range(self.K):
            if isinstance(self.terms[k].linear_mapping, tensorial.TTTensor):
                tmp = self.terms[k].linear_mapping.cores[m].permute([1, 0, 2])
                tmp_shape = list(tmp.shape)
                tmp = torch_utils.reshape_torch(tmp, [tmp_shape[0], -1],
                                                order='F', use_batch=False)
                tmp -= torch.mm(u, torch.mm(u.t(), tmp))
                tmp = torch_utils.reshape_torch(tmp, tmp_shape,
                                                order='F', use_batch=False)
                self.terms[k].linear_mapping.cores[m].data = tmp.permute([1, 0, 2])
            else:
                self.terms[k].linear_mapping.factors[m].data = (
                    self.terms[k].linear_mapping.factors[m]
                    - torch.mm(u, torch.mm(u.t(),
                                           self.terms[k].linear_mapping.factors[m])))
def get_bias(self, tensorize_output=True, use_batch=True):
    if self.bias is None:
        return None
    if callable(self.bias):
        result = self.bias()
    else:
        result = self.bias.clone()
    shape = []
    if tensorize_output:
        shape += self.linear_mapping.n
    else:
        shape += [-1]
    if use_batch:
        shape = [1] + shape
    if len(shape) != result.dim():
        result = torch_utils.reshape_torch(result, shape, use_batch=False)
    return result
def get_sources(self, mode):
    if isinstance(mode, int):
        mode = [mode]
    else:
        assert (np.diff(mode) == 1).all()
    #if isinstance(self.linear_mapping, TTTensor):
    #    assert (np.diff(mode) == 1).all()
    #elif isinstance(self.linear_mapping, (TuckerTensor, CPTensor, LROTensor)):
    #    assert (np.diff(mode) > 0).all()
    if isinstance(self.linear_mapping, CPTensor):
        result = self.linear_mapping.factors[0].new_ones([1, self.linear_mapping.R])
    elif isinstance(self.linear_mapping, LROTensor):
        result = self.linear_mapping.factors[0].new_ones([1, self.linear_mapping.M])
    elif isinstance(self.linear_mapping, TuckerTensor):
        result = self.linear_mapping.factors[0].new_ones([1, 1])  # use_core case?
    elif isinstance(self.linear_mapping, TTTensor):
        result = self.linear_mapping.cores[0].new_ones([1, 1, self.linear_mapping.r[0]])
    else:
        raise ValueError
    for m in mode:
        assert 0 <= m < self.d
        if isinstance(self.linear_mapping, CPTensor):
            result = torch_utils.krp_cw_torch(self.linear_mapping.factors[m], result)
        elif isinstance(self.linear_mapping, LROTensor):
            tmp = self.linear_mapping.factors[m]
            if m >= self.linear_mapping.P:
                tmp = torch.repeat_interleave(
                    tmp,
                    torch.tensor(self.linear_mapping.L,
                                 device=self.linear_mapping.factors[0].device),
                    dim=1)
            result = torch_utils.krp_cw_torch(tmp, result)
        elif isinstance(self.linear_mapping, TTTensor):
            result = torch.einsum('ijk,klm->ijlm', result, self.linear_mapping.cores[m])
            r1, n1, n2, r2 = result.shape
            result = torch_utils.reshape_torch(result, [r1, n1*n2, r2], use_batch=False)
        elif isinstance(self.linear_mapping, TuckerTensor):
            tmp = self.linear_mapping.factors[m]
            result = torch_utils.kron_torch(tmp, result)
        else:
            raise ValueError
    if isinstance(self.linear_mapping, TTTensor):
        result = torch_utils.swapunfold_torch(result, 1, use_batch=False)
    return result
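
# NOTE: for the CP branch above, `krp_cw_torch` accumulates a column-wise Khatri-Rao
# (matching-column Kronecker) product of the selected factor matrices. A generic
# stand-alone version of that product, under the assumption that the project helper
# follows the usual definition; the name `khatri_rao_cw` is illustrative.
def khatri_rao_cw(A, B):
    # A: [m, R], B: [n, R] -> [m*n, R]; column r is kron(A[:, r], B[:, r])
    m, R = A.shape
    n, R2 = B.shape
    assert R == R2
    return (A.unsqueeze(1) * B.unsqueeze(0)).reshape(m * n, R)
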
def get_svds(self, weights=None, fast=True):
    if fast:
        chi = 1.2
    core_flag = self.core is not None
    if core_flag:
        assert self.r0 is not None
        G = self.core.clone()
    U, S, V = [], [], []
    for k in range(self.d):
        if fast and (self.n[k] / self.r[k] >= chi):
            # Tall factor: the SVD of the small Gram matrix is cheaper.
            _, s, v = torch.svd(torch.mm(self.factors[k].t(), self.factors[k]))
            s = s.sqrt()
            u = self.factors[k].mm(v/s)
        else:
            u, s, v = torch.svd(self.factors[k])
        U.append(u[:, :self.r[k]])
        S.append(s[:self.r[k]].view(-1, 1))
        V.append(v[:, :self.r[k]])
        if core_flag:
            G = torch_utils.prodTenMat_torch(G, V[k]*S[k].t(), k, 0)
    if core_flag:
        permutation = [self.d] + list(range(self.d))
        G = G.permute(permutation).contiguous()
        rp = int(np.prod(self.r))
        tmp = torch_utils.reshape_torch(G, [self.r0, -1], use_batch=False)
        if fast and (self.r0 / rp >= chi):
            P, L, _ = torch.svd(tmp.t().mm(tmp))
            L = L.sqrt()
            Q = tmp.mm(P/L)
        elif fast and (rp / self.r0 >= chi):
            Q, L, _ = torch.svd(tmp.mm(tmp.t()))
            L = L.sqrt()
            P = torch.mm((Q / L).t(), tmp)
        else:
            P, L, Q = torch.svd(tmp.t())
        r = min(rp, self.r0)
        P, L, Q = P[:, :r], L[:r].view(-1, 1), Q[:, :r]
        return U, P, L, Q, core_flag
    return U, S, V, self.r, core_flag
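
# NOTE: the `fast` branches above avoid a full SVD of a tall factor by decomposing the
# small Gram matrix A^T A (its singular values are the squares of A's) and recovering
# the left factor as A V / S. A minimal numerical check of that shortcut, with
# illustrative names only.
def _check_gram_svd_shortcut():
    A = torch.randn(200, 10)                    # tall factor, n/r well above chi
    _, s2, v = torch.svd(A.t().mm(A))           # Gram matrix is only 10 x 10
    s = s2.sqrt()
    u = A.mm(v / s)                             # left singular vectors of A
    assert torch.allclose(u.mm(torch.diag(s)).mm(v.t()), A, atol=1e-3)
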
def model_with(self, input_batch, labels=None, subsample_size=None, subsample=None,
               xi_greedy=None, highlight_peak=None, normalize=False,
               isolate_group=None, orthogonolize_terms=None, expert=False):
    # Mixture-of-PPCA Pyro model: samples the latent 'epsilon', per-sample component
    # 'assignments', and the observations under a Normal or Bernoulli likelihood.
    pyro.module('terms', self.terms)
    if self.iterms is not None:
        pyro.module('iterms', self.iterms)
    batch_size = input_batch.shape[0]
    if labels is not None:
        assert len(labels) == batch_size
    if subsample_size is None:
        subsample_size = batch_size
    if subsample is None:
        subsample = torch.arange(subsample_size)
    max_hidden_dim = max(self.hidden_dims)
    if self.group_term is not None:
        max_hidden_dim = max(max_hidden_dim, self.group_hidden_dim)
    with pyro.plate('epsilon_plate', subsample_size):
        epsilon = pyro.sample(
            'epsilon',
            dist.Normal(
                input_batch.new_zeros(subsample_size, max_hidden_dim),
                1.).independent(1))
    if self.group_term is not None:
        pyro.module('group_term', self.group_term)
        if self.likelihood == 'bernoulli':
            ppca_gm_means = self.module_ppca_gm_means_sigma(
                input_batch[subsample], epsilon  #[subsample]
            )
        elif self.likelihood == 'normal':
            ppca_gm_means, ppca_gm_sigma = self.module_ppca_gm_means_sigma(
                input_batch[subsample], epsilon  #[subsample]
            )
        else:
            raise ValueError
    if self.likelihood == 'bernoulli':
        pi, ppca_means = self.module_ppca_means_sigmas_weights(
            input_batch[subsample], epsilon,  #[subsample],
            expert=expert, highlight_peak=highlight_peak)
    elif self.likelihood == 'normal':
        pi, ppca_means, ppca_sigmas = self.module_ppca_means_sigmas_weights(
            input_batch[subsample], epsilon,  #[subsample],
            expert=expert, highlight_peak=highlight_peak)
    else:
        raise ValueError
    #print(ppca_means)
    with pyro.plate('samples', batch_size, subsample_size=subsample_size,
                    subsample=subsample, device=input_batch.device) as i:
        #print(i, ppca_means.shape, ppca_sigmas.shape, ppca_gm_means.shape, ppca_gm_sigma.shape)
        assignments = pyro.sample('assignments', dist.Categorical(pi))
        if self.likelihood == 'normal':
            if assignments.dim() == 1:
                if self.group_term is None:
                    pyro.sample(
                        'obs',
                        dist.Normal(
                            ppca_means[assignments, torch.arange(subsample_size), :],
                            ppca_sigmas[assignments]
                        #).independent(1),
                        ).to_event(1),
                        obs=torch_utils.flatten_torch(
                            input_batch[i]
                        )  #*output_angles[:, k].view(-1, 1)
                    )
                else:
                    pyro.sample(
                        'obs',
                        dist.Normal(
                            (ppca_means + torch_utils.reshape_torch(
                                ppca_gm_means, [1, subsample_size, -1],
                                use_batch=False)
                             )[assignments, torch.arange(batch_size), :],
                            (torch_utils.reshape_torch(ppca_sigmas, [self.K, -1],
                                                       use_batch=False)
                             + ppca_gm_sigma[0])[assignments]  #.view(-1, 1)
                        ).independent(1),  #to_event(1),
                        obs=torch_utils.flatten_torch(
                            input_batch[i]
                        )  #*output_angles[:, k].view(-1, 1)
                    )
            else:
                if self.group_term is None:
                    pyro.sample(
                        'obs',
                        dist.Normal(ppca_means[assignments, :, :][:, 0],
                                    ppca_sigmas[assignments].view(self.K, 1, -1)
                        #).independent(1),
                        ).to_event(1),
                        obs=torch_utils.flatten_torch(
                            input_batch[i]
                        )  #*output_angles[:, k].view(-1, 1)
                    )
                else:
                    pyro.sample(
                        'obs',
                        dist.Normal(
                            (ppca_means + torch_utils.reshape_torch(
                                ppca_gm_means, [1, subsample_size, -1],
                                use_batch=False))[assignments, :, :][:, 0],
                            torch_utils.reshape_torch(
                                (ppca_sigmas.view(self.K, -1)
                                 + ppca_gm_sigma)[assignments],
                                [self.K, 1, -1], use_batch=False)
                        ).independent(1),  #to_event(1),
                        obs=torch_utils.flatten_torch(
                            input_batch[i]
                        )  #*output_angles[:, k].view(-1, 1)
                    )
        elif self.likelihood == 'bernoulli':
            if assignments.dim() == 1:
                if self.group_term is None:
                    pyro.sample(
                        'obs',
                        dist.Bernoulli(
                            ppca_means[assignments, torch.arange(subsample_size), :],
                            validate_args=False).to_event(1),
                        obs=torch_utils.flatten_torch(
                            input_batch[i]
                        )  #*output_angles[:, k].view(-1, 1)
                    )
                else:
                    pyro.sample(
                        'obs',
                        dist.Bernoulli(
                            (ppca_means + torch_utils.reshape_torch(
                                ppca_gm_means, [1, subsample_size, -1],
                                use_batch=False)
                             )[assignments, torch.arange(batch_size), :],
                            validate_args=False).to_event(1),
                        obs=torch_utils.flatten_torch(
                            input_batch[i]
                        )  #*output_angles[:, k].view(-1, 1)
                    )
            else:
                if self.group_term is None:
                    pyro.sample(
                        'obs',
                        dist.Bernoulli(ppca_means[assignments, :, :][:, 0],
                                       validate_args=False).to_event(1),
                        obs=torch_utils.flatten_torch(
                            input_batch[i]
                        )  #*output_angles[:, k].view(-1, 1)
                    )
                else:
                    pyro.sample(
                        'obs',
                        dist.Bernoulli(
                            (ppca_means + torch_utils.reshape_torch(
                                ppca_gm_means, [1, subsample_size, -1],
                                use_batch=False))[assignments, :, :][:, 0],
                            validate_args=False).to_event(1),
                        obs=torch_utils.flatten_torch(
                            input_batch[i]
                        )  #*output_angles[:, k].view(-1, 1)
                    )
        else:
            raise ValueError
def normalize(self):
    for k in range(self.d):
        tmp = torch_utils.reshape_torch(self.cores[k].data,
                                        [self.r[k]*self.n[k], -1], use_batch=False)
        tmp = tmp / torch.norm(tmp, p='fro', dim=0)
        self.cores[k].data = torch_utils.reshape_torch(
            tmp, [self.r[k], self.n[k], -1], use_batch=False)
def get_posterior_gaussian_mean_covariance(self, x_batch, noise_sigma=1, z_mu=0., z_sigma=1):
    if self.bias is not None:
        if callable(self.bias):
            output_mean = x_batch - torch_utils.reshape_torch(self.bias(), [1]+self.n, use_batch=False)
        else:
            output_mean = x_batch - torch_utils.reshape_torch(self.bias, [1]+self.n, use_batch=False)
        output_mean = torch_utils.reshape_torch(output_mean, self.linear_mapping.n)
    else:
        output_mean = torch_utils.reshape_torch(x_batch, self.linear_mapping.n)
    #output_mean = torch.mean(output_mean, dim=0, keepdim=True)
    if isinstance(self.linear_mapping, TuckerTensor):
        if not isinstance(noise_sigma, list):
            svds = self.linear_mapping.get_svds()
        else:
            svds = self.linear_mapping.get_svds(weights=[x.sqrt() for x in noise_sigma])
        if svds[-1]:
            U, P, L, Q, _ = svds
            S_cov = Q*L.t()
            for k in range(self.d):
                output_mean = torch_utils.prodTenMat_torch(output_mean, U[k], k+1, 0)
            output_mean = torch_utils.flatten_torch(output_mean)
            #output_mean = torch.mm(output_mean, P*L.t())
            output_mean = output_mean.mm(L*P.t())
            #output_mean = torch.mm(output_mean, Q.t())
            output_mean = output_mean.mm(Q)
        else:
            U, S, V, shapes_s, _ = svds
            S_cov = x_batch.new_ones([1, 1])
            for k in range(self.d):
                S_cov = torch_utils.kron_torch(S_cov, V[k]*S[k].t())
                output_mean = torch_utils.prodTenMat_torch(output_mean, U[k]*S[k].t(), k+1, 0)
                output_mean = torch_utils.prodTenMat_torch(output_mean, V[k], k+1, 1)
    elif (
        isinstance(self.linear_mapping, CPTensor)
        or isinstance(self.linear_mapping, LROTensor)
    ):
        if not isinstance(noise_sigma, list):
            U, S, V = self.linear_mapping.get_svds(coupled=True)
        else:
            U, S, V = self.linear_mapping.get_svds(weights=[x.sqrt() for x in noise_sigma], coupled=True)
        S_cov = V*S.t()
        output_mean = torch_utils.flatten_torch(output_mean)
        output_mean = output_mean.mm(U*S)
        output_mean = output_mean.mm(V.t())
    elif isinstance(self.linear_mapping, TTTensor):
        #S_cov = x_batch.new_ones([1, 1])
        for k in range(self.d):
            shape = [self.n[k], self.linear_mapping.r[k+1]]
            tmp = self.linear_mapping.cores[k]
            if isinstance(noise_sigma, list):
                tmp = tmp * noise_sigma[k].sqrt().view(1, -1, 1)
            if k > 0:
                tmp = torch.einsum('ij,iab,jac->bc', S_cov, tmp, tmp)
            else:
                tmp = torch.einsum('aib,aic->bc', tmp, tmp)
            #tmp = torch_utils.reshape_torch(tmp, shape, use_batch=False)
            #E, V = torch.eig(tmp, eigenvectors=True)
            #S_cov = (V*E[:, :1].t()).mm(V.t())
            u, s, v = torch.svd(tmp)
            S_cov = (u/s).mm(v.t())
            shape = [self.linear_mapping.r[k]*self.n[k], -1]
            tmp = self.linear_mapping.cores[k]
            output_mean = torch_utils.reshape_torch(output_mean, shape)
            output_mean = torch.einsum(
                'ijk,jl->ilk',
                output_mean,
                torch_utils.reshape_torch(tmp, shape, use_batch=False)
            )
    else:
        raise ValueError
    if not isinstance(noise_sigma, list):
        try:
            S_cov = S_cov/np.sqrt(noise_sigma)
        except:
            S_cov = S_cov/noise_sigma.sqrt()
    if not isinstance(self.linear_mapping, TTTensor):
        S_cov = S_cov.mm(S_cov.t())
    n = S_cov.shape[0]
    mask = torch.eye(n, n, device=x_batch.device).byte()
    S_cov[mask] += 1./z_sigma
    u, s, v = torch.svd(S_cov)
    S_cov = (u/s).mm(v.t())
    #E, V = torch.eig(S_cov, eigenvectors=True)
    #S_cov = (V/E[:, :1].t()).mm(V.t())
    output_mean = torch_utils.flatten_torch(output_mean)
    output_mean = output_mean + z_mu / z_sigma  ###
    output_mean = output_mean.mm(S_cov)
    S_cov = S_cov.unsqueeze(0)
    return output_mean, S_cov
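
# NOTE: the routine above appears to assemble the standard Gaussian posterior of a
# probabilistic-PCA-style model x ~ N(W z + b, sigma^2 I), z ~ N(z_mu, z_sigma I),
#   cov  = (W^T W / sigma^2 + I / z_sigma)^{-1}
#   mean = (W^T (x - b) / sigma^2 + z_mu / z_sigma) @ cov,
# but with W kept in factored (Tucker/CP/TT) form. A dense reference version of that
# textbook formula, with illustrative names, for sanity-checking small problems; it is
# a sketch of the standard identity, not the project's exact convention.
def dense_ppca_posterior(x, W, b, noise_sigma=1., z_mu=0., z_sigma=1.):
    # x: [batch, N], W: [N, R], b: [N]
    prec = W.t().mm(W) / noise_sigma + torch.eye(W.shape[1]) / z_sigma
    cov = torch.inverse(prec)
    mean = ((x - b).mm(W) / noise_sigma + z_mu / z_sigma).mm(cov)
    return mean, cov
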
def multi_project(self, input_batch, remove_bias=True, tensorize=False):
    if remove_bias and (self.bias is not None):
        if callable(self.bias):
            output_batch = input_batch - torch_utils.reshape_torch(self.bias(), [1]+self.n, use_batch=False)
        else:
            output_batch = input_batch - torch_utils.reshape_torch(self.bias, [1]+self.n, use_batch=False)
        output_batch = torch_utils.reshape_torch(output_batch, self.linear_mapping.n)
    else:
        output_batch = torch_utils.reshape_torch(input_batch, self.linear_mapping.n)
    if isinstance(self.linear_mapping, TuckerTensor):
        svds = self.linear_mapping.get_svds()
        if svds[-1]:
            U, P, _, _, _ = svds
            for k in range(self.d):
                output_batch = torch_utils.prodTenMat_torch(output_batch, U[k], k+1, 0)
            output_batch = torch_utils.flatten_torch(output_batch)
            output_batch = output_batch.mm(P.t())
            output_batch = output_batch.mm(P)
            output_batch = torch_utils.reshape_torch(output_batch, self.linear_mapping.r)
            for k in range(self.d):
                output_batch = torch_utils.prodTenMat_torch(output_batch, U[k], k+1, 1)
        else:
            U, _, _, _, _ = svds
            for k in range(self.d):
                output_batch = torch_utils.prodTenMat_torch(output_batch, U[k], k+1, 0)
                output_batch = torch_utils.prodTenMat_torch(output_batch, U[k], k+1, 1)
    elif (
        isinstance(self.linear_mapping, CPTensor)
        or isinstance(self.linear_mapping, LROTensor)
    ):
        U, _, _ = self.linear_mapping.get_svds(coupled=True)
        output_batch = torch_utils.flatten_torch(output_batch)
        output_batch = output_batch.mm(U)
        output_batch = output_batch.mm(U.t())
    elif isinstance(self.linear_mapping, TTTensor):
        orth_list = []
        output_batch = input_batch.clone()
        for k in range(self.d):
            if k > 0:
                tmp = torch_utils.prodTenMat_torch(
                    self.linear_mapping.cores[k], tmp, 0, 1)
            else:
                tmp = self.linear_mapping.cores[k]
            tmp = torch_utils.reshape_torch(tmp, [-1, self.linear_mapping.r[k+1]], use_batch=False)
            tmp = torch_utils.reshape_torch(
                self.linear_mapping.cores[k],
                [-1, self.linear_mapping.r[k+1]], use_batch=False)
            u, s, v = torch.svd(tmp)
            orth_list.append(u[:, :self.linear_mapping.r[k+1]])
            tmp = s*(v[:, :self.linear_mapping.r[k+1]].t())
            output_batch = torch_utils.reshape_torch(
                output_batch,
                [self.linear_mapping.r[k]*self.linear_mapping.n[k], -1],
                use_batch=True)
            output_batch = torch_utils.prodTenMat_torch(output_batch, u, 1, 0)
        u, s, v = torch.svd(tmp)
        output_batch = output_batch.squeeze(2).mm(u).mm(u.t())  ##### ???
        for k in range(self.d-1, -1, -1):
            if k == self.d-1:
                output_batch = output_batch.mm(orth_list[k].t())
            else:
                output_batch = torch_utils.prodTenMat_torch(output_batch, orth_list[k], self.d-k, 1)
            output_batch = torch_utils.reshape_torch(
                output_batch, self.n[k:]+[self.linear_mapping.r[k]], use_batch=True)
    else:
        raise ValueError
    if (tensorize) and (output_batch.dim() == 2):
        return torch_utils.reshape_torch(output_batch, self.linear_mapping.n)
    if (not tensorize) and (output_batch.dim() > 2):
        return torch_utils.flatten_torch(output_batch)
    return output_batch
def measure_principal_angles(self, input_batch, mode, fast=True):
    batch_size = input_batch.shape[0]
    mu_k = []
    for k in range(self.K):
        tmp = self.terms[k].get_bias(tensorize_output=False, use_batch=True)
        if tmp is None:
            tmp = input_batch.new_zeros(1, self.output_dim)
        mu_k.append(tmp)
    mu_k = torch.cat(mu_k)
    if self.group_term is not None:
        mu_g = self.group_term.get_bias(tensorize_output=False, use_batch=True)
        if mu_g is None:
            mu_g = input_batch.new_zeros(1, self.output_dim)
        projected_batch = self.group_term.multi_project(input_batch, remove_bias=False, tensorize=False)
        projected_mu = mu_k + mu_g
        projected_mu -= self.group_term.multi_project(
            torch_utils.reshape_torch(projected_mu, self.n),
            remove_bias=False, tensorize=False)
    #'''
    input_batch = torch_utils.flatten_torch(input_batch)
    output_angles = input_batch.new_zeros([batch_size, self.K])
    if fast:
        chi = 1.2
    for k in range(self.K):
        if fast:
            Uk, _, _ = torch_utils.fast_svd_torch(
                #self.terms[k].linear_mapping.factors[mode]
                self.terms[k].get_sources(self.source_mode), chi=chi)
        else:
            #Uk, _, _ = torch.svd(self.terms[k].linear_mapping.factors[mode])
            Uk, _, _ = torch.svd(self.terms[k].get_sources(self.source_mode))
        Uk = Uk[:, :self.terms[k].linear_mapping.r[mode]]
        Uk = Uk.t()
        '''
        if self.group_term is not None:
            current_batch = input_batch - projected_batch - projected_mu[k:k+1, :]
        else:
            current_batch = input_batch - mu_k[k:k+1, :]
        '''
        current_batch = torch_utils.reshape_torch(input_batch, self.n)  # current
        current_batch = torch_utils.swapaxes_torch(current_batch, 1, mode + 1)
        current_batch = torch_utils.reshape_torch(current_batch, [self.n[mode], -1])
        tmp_r = current_batch.shape[-1]
        for i in range(batch_size):
            if fast:
                u, _, _ = torch_utils.fast_svd_torch(current_batch[i], chi=chi)
            else:
                u, _, _ = torch.svd(current_batch[i])
            _, s, _ = torch.svd(Uk.mm(u[:, :tmp_r]))
            output_angles[i, k] = s[0]
    return output_angles
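
# NOTE: the loop above scores each term by the largest singular value of
# Uk @ u[:, :tmp_r], i.e. the cosine of the smallest principal angle between the
# term's source subspace and the sample's mode-`mode` column space. A minimal
# stand-alone illustration of that fact; the function name is illustrative.
def _cos_smallest_principal_angle(A, B):
    # A: [n, p], B: [n, q]; returns cos(theta_min) between range(A) and range(B)
    Qa, _ = torch.linalg.qr(A)
    Qb, _ = torch.linalg.qr(B)
    s = torch.linalg.svdvals(Qa.t().mm(Qb))     # singular values = cosines of principal angles
    return s[0]
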