def module_ppca_gm_means_sigma_guide(self, input_batch, epsilon):
    if self.likelihood == 'normal':
        if self.group_isotropic:
            ppca_gm_sigma_p = pyro.param(
                'ppca_gm_sigma_p',
                input_batch.new_ones(1, 1),
                constraint=constraints.positive)
            ppca_gm_sigma = pyro.sample(
                'ppca_gm_sigma',
                dist.Delta(ppca_gm_sigma_p).independent(1))
        else:
            # Per-mode noise scales, combined into a full-size scale vector
            # via a row-wise Khatri-Rao product.
            ppca_gm_sigma = input_batch.new_ones(1, 1)
            ppca_gm_sigma_list = []
            for i in range(self.d):
                ppca_gm_sigma_p = pyro.param(
                    f'ppca_gm_sigma_{i}_p',
                    input_batch.new_ones(1, self.n[i]),
                    constraint=constraints.positive)
                ppca_gm_sigma_list.append(
                    pyro.sample(
                        f'ppca_gm_sigma_{i}',
                        dist.Delta(ppca_gm_sigma_p).independent(1)))
                ppca_gm_sigma = torch_utils.krp_cw_torch(
                    ppca_gm_sigma_list[i], ppca_gm_sigma, column=False)
    else:
        ppca_gm_sigma = input_batch.new_ones(1, 1)
        ppca_gm_sigma_list = [
            input_batch.new_ones(1, self.n[i]) for i in range(self.d)
        ]
    alpha_gm_p = pyro.param(
        'alpha_gm_p', input_batch.new_ones([1, self.group_hidden_dim]))
    alpha_gm = pyro.sample('alpha_gm',
                           dist.Delta(alpha_gm_p).independent(1))
    if self.group_iterm is None:
        z_mu = self.group_term.linear_mapping.inverse_batch(input_batch)
    else:
        z_mu = self.group_iterm(torch_utils.flatten_torch(input_batch),
                                T=True)
    if self.group_isotropic:
        zk_mean, zk_cov = self.group_term.get_posterior_gaussian_mean_covariance(
            input_batch,
            noise_sigma=ppca_gm_sigma[0]
            if ppca_gm_sigma is not None else input_batch.new_ones(1),
            z_mu=z_mu,
            z_sigma=alpha_gm[0])
    else:
        zk_mean, zk_cov = self.group_term.get_posterior_gaussian_mean_covariance(
            input_batch,
            noise_sigma=list(ppca_gm_sigma_list),
            z_mu=z_mu,
            z_sigma=alpha_gm[0])
    # Reparameterized posterior draw: z = mean + eps @ cov, mapped through
    # the group term to produce the PPCA means.
    ppca_gm_means = self.group_term(
        zk_mean + epsilon[:, :self.group_hidden_dim].mm(
            zk_cov.view(self.group_hidden_dim, self.group_hidden_dim)))
    return ppca_gm_means, ppca_gm_sigma
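# The guide above follows a common Pyro pattern: each variational "sample"
# is a Delta distribution centered on a learnable pyro.param, so SVI fits a
# point (MAP-style) estimate for sigma and alpha rather than a full
# posterior. A minimal self-contained sketch of that pattern (names here
# are illustrative, not part of this module):
import pyro
import pyro.distributions as dist
import torch
from torch.distributions import constraints

def delta_guide_sketch(batch):
    sigma_p = pyro.param('sigma_p_sketch', batch.new_ones(1, 1),
                         constraint=constraints.positive)
    # .to_event(1) is the current spelling of the deprecated .independent(1)
    # used in this file; both declare the last dim as an event dimension.
    return pyro.sample('sigma_sketch', dist.Delta(sigma_p).to_event(1))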
def get_sources(self, mode):
    if isinstance(mode, int):
        mode = [mode]
    else:
        # TT cores require strictly contiguous modes; for Tucker, CP, and
        # LRO mappings, strictly increasing modes would be sufficient.
        assert (np.diff(mode) == 1).all()
    if isinstance(self.linear_mapping, CPTensor):
        result = self.linear_mapping.factors[0].new_ones(
            [1, self.linear_mapping.R])
    elif isinstance(self.linear_mapping, LROTensor):
        result = self.linear_mapping.factors[0].new_ones(
            [1, self.linear_mapping.M])
    elif isinstance(self.linear_mapping, TuckerTensor):
        result = self.linear_mapping.factors[0].new_ones([1, 1])  # use_core case?
    elif isinstance(self.linear_mapping, TTTensor):
        result = self.linear_mapping.cores[0].new_ones(
            [1, 1, self.linear_mapping.r[0]])
    else:
        raise ValueError('unsupported linear mapping type')
    for m in mode:
        assert 0 <= m < self.d
        if isinstance(self.linear_mapping, CPTensor):
            result = torch_utils.krp_cw_torch(
                self.linear_mapping.factors[m], result)
        elif isinstance(self.linear_mapping, LROTensor):
            tmp = self.linear_mapping.factors[m]
            if m >= self.linear_mapping.P:
                # Trailing LRO factors are shared: tile each column L times
                # to match the full column count.
                tmp = torch.repeat_interleave(
                    tmp,
                    torch.tensor(self.linear_mapping.L, device=tmp.device),
                    dim=1)
            result = torch_utils.krp_cw_torch(tmp, result)
        elif isinstance(self.linear_mapping, TTTensor):
            # Contract the next TT core onto the running result and merge
            # the two exposed mode dimensions.
            result = torch.einsum('ijk,klm->ijlm', result,
                                  self.linear_mapping.cores[m])
            r1, n1, n2, r2 = result.shape
            result = torch_utils.reshape_torch(
                result, [r1, n1 * n2, r2], use_batch=False)
        elif isinstance(self.linear_mapping, TuckerTensor):
            result = torch_utils.kron_torch(
                self.linear_mapping.factors[m], result)
        else:
            raise ValueError('unsupported linear mapping type')
    if isinstance(self.linear_mapping, TTTensor):
        result = torch_utils.swapunfold_torch(result, 1, use_batch=False)
    return result
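# torch_utils.krp_cw_torch is used throughout as a column-wise Khatri-Rao
# product. A minimal sketch of the assumed semantics (the real helper may
# differ in argument order or options): for a (m x R) and b (n x R), column
# r of the result is kron(a[:, r], b[:, r]); column=False applies the same
# operation row-wise (the face-splitting product), as used for the
# per-mode noise scales above.
import torch

def krp_cw_sketch(a: torch.Tensor, b: torch.Tensor,
                  column: bool = True) -> torch.Tensor:
    if not column:
        return krp_cw_sketch(a.t(), b.t()).t()
    (m, r), (n, r2) = a.shape, b.shape
    assert r == r2, 'factors must agree in the Khatri-Rao dimension'
    # Per-column outer products, then merge the two row dimensions.
    return torch.einsum('ir,jr->ijr', a, b).reshape(m * n, r)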
def module_ppca_gm_means_sigma(self, input_batch, epsilon):
    if self.likelihood == 'normal':
        if self.group_isotropic:
            ppca_gm_sigma = pyro.sample(
                'ppca_gm_sigma',
                dist.LogNormal(0, input_batch.new_ones(1, 1)).independent(1))
        else:
            ppca_gm_sigma_list = []
            ppca_gm_sigma = input_batch.new_ones(1, 1)
            for i in range(self.d):
                ppca_gm_sigma_list.append(
                    pyro.sample(
                        f'ppca_gm_sigma_{i}',
                        dist.LogNormal(
                            0, input_batch.new_ones(
                                1, self.n[i])).independent(1)))
                ppca_gm_sigma = torch_utils.krp_cw_torch(
                    ppca_gm_sigma_list[i], ppca_gm_sigma, column=False)
    alpha_gm = pyro.sample(
        'alpha_gm',
        dist.LogNormal(
            0, input_batch.new_ones(
                [1, self.group_hidden_dim])).independent(1))
    if self.group_iterm is None:
        ppca_gm_means = self.group_term.multi_project(
            input_batch, remove_bias=False, tensorize=False)
    else:
        ppca_gm_means = self.group_term(
            self.group_iterm(torch_utils.flatten_torch(input_batch), T=True))
    # Add the reparameterized noise term, mapped through the group term.
    ppca_gm_means += self.group_term(
        epsilon[:, :self.group_hidden_dim] *
        alpha_gm[:, :self.group_hidden_dim])
    if self.likelihood == 'bernoulli':
        return ppca_gm_means.sigmoid()
    if self.likelihood == 'normal':
        return ppca_gm_means, ppca_gm_sigma
    raise ValueError
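# The two additions above exploit linearity of the group term: with
# z = z_mu + eps * alpha and a linear map W, W z = W z_mu + W (eps * alpha),
# so the deterministic projection and the noise term can be mapped
# separately. A toy check of that identity with an explicit matrix:
import torch

W = torch.randn(8, 3)
z_mu, eps, alpha = torch.randn(5, 3), torch.randn(5, 3), torch.rand(1, 3)
lhs = (z_mu + eps * alpha) @ W.t()
rhs = z_mu @ W.t() + (eps * alpha) @ W.t()
assert torch.allclose(lhs, rhs, atol=1e-5)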
def recover(self, weights=None):
    output = self.factors[0].new_ones([1, self.R])
    for k in range(self.d - 1, -1, -1):
        if k == 0 and not self.sample_axis:
            break
        factor = self.factors[k]
        if weights is not None:
            factor = factor * weights[k].view(self.n[k], 1)
        output = torch_utils.krp_cw_torch(output, factor)
        if k == self.P:
            # Shared trailing factors have been consumed: tile each column
            # L times so the column count matches the full rank R.
            output = torch.repeat_interleave(
                output,
                torch.tensor(self.L, device=self.factors[0].device),
                dim=1)
    if not self.sample_axis:
        output = self.factors[0].mm(output.t())
        output = torch_utils.flatten_torch(output, use_batch=False)
    return output
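# recover() folds the factors together with Khatri-Rao products from the
# last mode backwards, then applies the first factor. A standalone sanity
# check of that identity for a plain rank-R CP tensor (shapes and names
# here are illustrative):
import torch

def _krp(a, b):  # column-wise Khatri-Rao, as sketched after get_sources
    return torch.einsum('ir,jr->ijr', a, b).reshape(-1, a.shape[1])

n1, n2, n3, R = 2, 3, 4, 5
A = [torch.randn(n, R) for n in (n1, n2, n3)]
out = torch.ones(1, R)
for k in (2, 1):                      # modes d-1 .. 1, as in recover()
    out = _krp(out, A[k])
X1 = A[0] @ out.t()                   # mode-1 unfolding, last mode slowest
full = torch.einsum('ir,jr,kr->ijk', A[0], A[1], A[2])
assert torch.allclose(X1, full.permute(0, 2, 1).reshape(n1, n3 * n2),
                      atol=1e-4)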
def module_ppca_means_sigmas_weights_guide(self,
                                           input_batch,
                                           epsilon,
                                           expert=False,
                                           xi_greedy=None,
                                           highlight_peak=None):
    # Guide: zk = zk_mean (reparameterized with eps ~ N(0, I)),
    # ppca_mean = Wk zk, ppca_sigma = sigma_k^2.
    max_hidden_dim = max(self.hidden_dims)
    batch_size = input_batch.shape[0]
    gamma = input_batch.new_zeros(batch_size, self.K)
    if expert:
        output_angles = self.measure_principal_angles(
            input_batch, mode=self.source_mode)
    else:
        pi_p = pyro.param('pi_p',
                          input_batch.new_ones(self.K) / self.K,
                          constraint=constraints.positive)
        output_angles = pyro.sample('pi', dist.Dirichlet(pi_p)).log()
    if highlight_peak is not None:
        # Temperature scaling in the log domain to sharpen or flatten
        # the mixture weights.
        output_angles = highlight_peak * output_angles
        output_angles = output_angles.log_softmax(dim=-1)
    if xi_greedy is not None:
        # Greedy-style mixing of the weights with log-normal noise.
        phi = dist.LogNormal(
            input_batch.new_zeros([batch_size, self.K]),
            input_batch.new_ones([batch_size, self.K])).to_event(1).sample()
        output_angles = (1 - xi_greedy) * output_angles + xi_greedy * phi
    ppca_means = input_batch.new_zeros(
        [self.K, batch_size, self.output_dim])
    if self.likelihood == 'normal':
        with pyro.plate('ppca_sigma_plate', self.K):
            if self.terms_isotropic:
                ppca_sigmas_p = pyro.param(
                    'ppca_sigmas_p',
                    input_batch.new_ones(self.K, 1),
                    constraint=constraints.positive)
                ppca_sigmas = pyro.sample(
                    'ppca_sigmas',
                    dist.Delta(ppca_sigmas_p).independent(1))
            else:
                ppca_sigmas = input_batch.new_ones(self.K, 1)
                ppca_sigmas_list = []
                for i in range(self.d):
                    ppca_sigmas_p = pyro.param(
                        f'ppca_sigmas_{i}_p',
                        input_batch.new_ones(self.K, self.n[i]),
                        constraint=constraints.positive)
                    ppca_sigmas_list.append(
                        pyro.sample(
                            f'ppca_sigmas_{i}',
                            dist.Delta(ppca_sigmas_p).independent(1)))
                    ppca_sigmas = torch_utils.krp_cw_torch(
                        ppca_sigmas_list[i], ppca_sigmas, column=False)
    else:
        ppca_sigmas = None
        ppca_sigmas_list = [
            input_batch.new_ones(self.K, self.n[i]) for i in range(self.d)
        ]
    with pyro.plate('alpha_plate', self.K):
        alpha_p = pyro.param('alpha_p',
                             input_batch.new_ones([self.K, max_hidden_dim]),
                             constraint=constraints.positive)
        alpha = pyro.sample('alpha', dist.Delta(alpha_p).independent(1))
    for k in range(self.K):
        if self.iterms is None:
            z_mu = self.terms[k].linear_mapping.inverse_batch(input_batch)
        else:
            z_mu = self.iterms[k](torch_utils.flatten_torch(input_batch),
                                  T=True)
        if self.terms_isotropic:
            zk_mean, zk_cov = self.terms[
                k].get_posterior_gaussian_mean_covariance(
                    input_batch,
                    noise_sigma=ppca_sigmas[k]
                    if ppca_sigmas is not None else 1,
                    z_mu=z_mu,
                    z_sigma=alpha[k])
        else:
            zk_mean, zk_cov = self.terms[
                k].get_posterior_gaussian_mean_covariance(
                    input_batch,
                    noise_sigma=[x[k] for x in ppca_sigmas_list],
                    z_mu=z_mu,
                    z_sigma=alpha[k])
        ppca_means[k, :, :] = self.terms[k](
            zk_mean + epsilon[:, :self.hidden_dims[k]].mm(
                zk_cov.view(self.hidden_dims[k], self.hidden_dims[k])))
        if self.likelihood == 'bernoulli':
            gamma[:, k] = dist.Bernoulli(
                ppca_means[k].sigmoid(),
                validate_args=False).to_event(1).log_prob(
                    torch_utils.flatten_torch(input_batch))
        elif self.likelihood == 'normal':
            gamma[:, k] = dist.Normal(
                loc=ppca_means[k],
                scale=ppca_sigmas[k]).to_event(1).log_prob(
                    torch_utils.flatten_torch(input_batch))
        else:
            raise ValueError
    # Responsibilities: softmax of log prior weights plus per-component
    # log-likelihoods.
    gamma = (output_angles + gamma).softmax(dim=-1)
    return gamma, ppca_means, ppca_sigmas
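# The responsibilities returned by the guide are standard mixture-model
# posteriors: gamma[b, k] is proportional to pi_k * p(x_b | component k),
# assembled in the log domain for numerical stability. A toy standalone
# version with made-up components:
import torch
from torch.distributions import Normal

log_pi = torch.tensor([0.5, 0.3, 0.2]).log()             # K = 3 weights
x = torch.randn(4, 6)                                     # batch of 4
means = [torch.zeros(6), torch.ones(6), -torch.ones(6)]  # toy components
log_lik = torch.stack(
    [Normal(mu, 1.0).log_prob(x).sum(dim=-1) for mu in means], dim=-1)
gamma = (log_pi + log_lik).softmax(dim=-1)                # rows sum to 1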
def module_ppca_means_sigmas_weights(self,
                                     input_batch,
                                     epsilon,
                                     expert=False,
                                     highlight_peak=None):
    # Model: zk = zk_mean (reparameterized with eps ~ N(0, I)),
    # ppca_mean = Wk zk, ppca_sigma = sigma_k^2.
    max_hidden_dim = max(self.hidden_dims)
    batch_size = input_batch.shape[0]
    if expert:
        pi = self.measure_principal_angles(input_batch,
                                           mode=self.source_mode)
        if highlight_peak is not None:
            pi = highlight_peak * pi
        pi = pi.softmax(dim=-1)
    else:
        pi = pyro.sample(
            'pi', dist.Dirichlet(input_batch.new_ones(self.K) / self.K))
        if highlight_peak is not None:
            pi = highlight_peak * pi.log()
            pi = pi.softmax(dim=-1)
    ppca_means = input_batch.new_zeros(
        [self.K, batch_size, self.output_dim])
    if self.likelihood == 'normal':
        with pyro.plate('ppca_sigma_plate', self.K):
            if self.terms_isotropic:
                ppca_sigmas = pyro.sample(
                    'ppca_sigmas',
                    dist.LogNormal(
                        0, input_batch.new_ones(self.K, 1)).independent(1))
            else:
                ppca_sigmas_list = []
                ppca_sigmas = input_batch.new_ones(self.K, 1)
                for i in range(self.d):
                    ppca_sigmas_list.append(
                        pyro.sample(
                            f'ppca_sigmas_{i}',
                            dist.LogNormal(
                                0, input_batch.new_ones(
                                    self.K, self.n[i])).independent(1)))
                    ppca_sigmas = torch_utils.krp_cw_torch(
                        ppca_sigmas_list[i], ppca_sigmas, column=False)
    with pyro.plate('alpha_plate', self.K):
        alpha = pyro.sample(
            'alpha',
            dist.LogNormal(
                0, input_batch.new_ones(
                    [self.K, max_hidden_dim])).independent(1))
    for k in range(self.K):
        if self.iterms is None:
            ppca_means[k, :, :] += self.terms[k].multi_project(
                input_batch, remove_bias=False, tensorize=False)
        else:
            ppca_means[k, :, :] += self.terms[k](self.iterms[k](
                torch_utils.flatten_torch(input_batch), T=True))
        ppca_means[k, :, :] += self.terms[k](
            epsilon[:, :self.hidden_dims[k]] *
            alpha[k:k + 1, :self.hidden_dims[k]])
    if self.likelihood == 'bernoulli':
        return pi, ppca_means.sigmoid()
    if self.likelihood == 'normal':
        return pi, ppca_means, ppca_sigmas
    raise ValueError
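# In both weight modules, highlight_peak acts as an inverse temperature on
# the mixture weights: values above 1 sharpen pi toward its largest entry,
# values below 1 flatten it. A quick illustration:
import torch

pi = torch.tensor([0.6, 0.3, 0.1])
for peak in (0.5, 1.0, 3.0):
    print(peak, (peak * pi.log()).softmax(dim=-1))
# peak = 3.0 concentrates most of the mass on the first component.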