Example #1
0
File: vbtd.py  Project: kharyuk/vbtd
    def module_ppca_gm_means_sigma_guide(self, input_batch, epsilon):
        """Guide for the group (shared) PPCA term.

        Registers point-mass (``Delta``) sample sites for the group noise
        scale and for ``alpha_gm``:

        * The noise-scale sites exist only when
          ``self.likelihood == 'normal'``. There is a single
          ``ppca_gm_sigma`` when the group term is isotropic, otherwise one
          ``ppca_gm_sigma_{i}`` per mode.
        * ``alpha_gm`` is always registered.

        The latent code is then drawn around the posterior Gaussian mean
        returned by ``self.group_term.get_posterior_gaussian_mean_covariance``
        (``zk_mean + eps @ zk_cov``) and mapped back through
        ``self.group_term``.

        Args:
            input_batch: batch of observations.
            epsilon: noise tensor; its first ``self.group_hidden_dim`` columns
                perturb the posterior mean.

        Returns:
            Tuple ``(ppca_gm_means, ppca_gm_sigma)``. When the likelihood is
            not ``'normal'``, ``ppca_gm_sigma`` is a ``(1, 1)`` tensor of ones.
        """
        if self.likelihood == 'normal':
            if self.group_isotropic:
                ppca_gm_sigma_p = pyro.param('ppca_gm_sigma_p',
                                             input_batch.new_ones(1, 1),
                                             constraint=constraints.positive)
                ppca_gm_sigma = pyro.sample(
                    'ppca_gm_sigma', dist.Delta(ppca_gm_sigma_p).to_event(1))
            else:
                # One positive scale vector per mode; they are folded into the
                # full noise scale with torch_utils.krp_cw_torch(column=False).
                ppca_gm_sigma = input_batch.new_ones(1, 1)
                ppca_gm_sigma_list = []
                for i in range(self.d):
                    ppca_gm_sigma_p = pyro.param(
                        f'ppca_gm_sigma_{i}_p',
                        input_batch.new_ones(1, self.n[i]),
                        constraint=constraints.positive)
                    sigma_i = pyro.sample(
                        f'ppca_gm_sigma_{i}',
                        dist.Delta(ppca_gm_sigma_p).to_event(1))
                    ppca_gm_sigma_list.append(sigma_i)
                    ppca_gm_sigma = torch_utils.krp_cw_torch(
                        sigma_i, ppca_gm_sigma, column=False)
        else:
            # No noise-scale sites for non-normal likelihoods: use unit scales.
            ppca_gm_sigma = input_batch.new_ones(1, 1)
            ppca_gm_sigma_list = [
                input_batch.new_ones(1, self.n[i]) for i in range(self.d)
            ]

        alpha_gm_p = pyro.param(
            'alpha_gm_p', input_batch.new_ones([1, self.group_hidden_dim]))
        alpha_gm = pyro.sample('alpha_gm', dist.Delta(alpha_gm_p).to_event(1))

        if self.group_iterm is None:
            z_mu = self.group_term.linear_mapping.inverse_batch(input_batch)
        else:
            z_mu = self.group_iterm(torch_utils.flatten_torch(input_batch),
                                    T=True)

        if self.group_isotropic:
            # ppca_gm_sigma is always a tensor here (every branch above sets it).
            noise_sigma = ppca_gm_sigma[0]
        else:
            # NOTE(review): these are (1, n_i)-shaped tensors, whereas
            # module_ppca_means_sigmas_weights_guide passes per-component rows
            # (x[k]); confirm the shape get_posterior_gaussian_mean_covariance
            # expects.
            noise_sigma = list(ppca_gm_sigma_list)
        zk_mean, zk_cov = self.group_term.get_posterior_gaussian_mean_covariance(
            input_batch,
            noise_sigma=noise_sigma,
            z_mu=z_mu,
            z_sigma=alpha_gm[0])
        hidden_dim = self.group_hidden_dim
        ppca_gm_means = self.group_term(
            zk_mean + epsilon[:, :hidden_dim].mm(
                zk_cov.view(hidden_dim, hidden_dim)))
        return ppca_gm_means, ppca_gm_sigma
Example #2
0
File: tensorial.py  Project: kharyuk/vbtd
 def get_sources(self, mode):
     """Assemble the source basis of ``self.linear_mapping`` over ``mode``.

     Args:
         mode: a single mode index (``int`` or NumPy integer), or a sequence
             of consecutive, increasing mode indices, each in ``[0, self.d)``.

     Returns:
         The selected factors/cores combined, one mode at a time:

         * CP / LRO mappings use ``torch_utils.krp_cw_torch``. LRO factors
           with index ``>= linear_mapping.P`` are first expanded along
           dim 1 with ``torch.repeat_interleave`` by ``linear_mapping.L``.
         * Tucker mappings use ``torch_utils.kron_torch``.
         * TT mappings contract the selected cores with ``einsum`` and pass
           the result through ``torch_utils.swapunfold_torch``.

     Raises:
         ValueError: if ``self.linear_mapping`` is not a CPTensor, LROTensor,
             TuckerTensor or TTTensor.
         AssertionError: if the modes are not consecutive and increasing, or
             an index is out of range.
     """
     if isinstance(mode, (int, np.integer)):
         mode = [int(mode)]
     else:
         # Only runs of adjacent modes are supported.
         assert (np.diff(mode) == 1).all()
     mapping = self.linear_mapping
     if isinstance(mapping, CPTensor):
         result = mapping.factors[0].new_ones([1, mapping.R])
     elif isinstance(mapping, LROTensor):
         result = mapping.factors[0].new_ones([1, mapping.M])
     elif isinstance(mapping, TuckerTensor):
         # TODO: handle the use_core case?
         result = mapping.factors[0].new_ones([1, 1])
     elif isinstance(mapping, TTTensor):
         result = mapping.cores[0].new_ones([1, 1, mapping.r[0]])
     else:
         raise ValueError(
             f'unsupported linear_mapping type: {type(mapping).__name__}')
     for m in mode:
         assert 0 <= m < self.d
         if isinstance(mapping, CPTensor):
             result = torch_utils.krp_cw_torch(mapping.factors[m], result)
         elif isinstance(mapping, LROTensor):
             tmp = mapping.factors[m]
             if m >= mapping.P:
                 # Build the repeats tensor on the factor's own device.
                 tmp = torch.repeat_interleave(
                     tmp, torch.tensor(mapping.L, device=tmp.device), dim=1)
             result = torch_utils.krp_cw_torch(tmp, result)
         elif isinstance(mapping, TTTensor):
             result = torch.einsum('ijk,klm->ijlm', result, mapping.cores[m])
             r1, n1, n2, r2 = result.shape
             result = torch_utils.reshape_torch(
                 result, [r1, n1 * n2, r2], use_batch=False)
         else:  # TuckerTensor: the only type left after the check above.
             result = torch_utils.kron_torch(mapping.factors[m], result)
     if isinstance(mapping, TTTensor):
         result = torch_utils.swapunfold_torch(result, 1, use_batch=False)
     return result
Example #3
0
File: vbtd.py  Project: kharyuk/vbtd
    def module_ppca_gm_means_sigma(self, input_batch, epsilon):
        """Model for the group (shared) PPCA term.

        Registers ``LogNormal(0, 1)`` sample sites:

        * for the group noise scale, only when ``self.likelihood == 'normal'``
          (``ppca_gm_sigma`` if isotropic, else one ``ppca_gm_sigma_{i}`` per
          mode);
        * for ``alpha_gm``, always.

        Means are computed in two steps:

        1. Project the input via ``self.group_term.multi_project`` (when there
           is no ``self.group_iterm``), or via
           ``self.group_term(self.group_iterm(flatten(input), T=True))``.
        2. Add ``self.group_term(epsilon[:, :H] * alpha_gm[:, :H])``, where
           ``H = self.group_hidden_dim``.

        Args:
            input_batch: batch of observations.
            epsilon: noise tensor; its first ``self.group_hidden_dim`` columns
                are used.

        Returns:
            ``sigmoid(means)`` for the ``'bernoulli'`` likelihood, or
            ``(means, ppca_gm_sigma)`` for the ``'normal'`` likelihood.

        Raises:
            ValueError: for any other ``self.likelihood``.
        """
        if self.likelihood == 'normal':
            if self.group_isotropic:
                ppca_gm_sigma = pyro.sample(
                    'ppca_gm_sigma',
                    dist.LogNormal(0, input_batch.new_ones(1, 1)).to_event(1))
            else:
                # Per-mode scales, folded with krp_cw_torch(column=False).
                ppca_gm_sigma = input_batch.new_ones(1, 1)
                for i in range(self.d):
                    sigma_i = pyro.sample(
                        f'ppca_gm_sigma_{i}',
                        dist.LogNormal(0, input_batch.new_ones(
                            1, self.n[i])).to_event(1))
                    ppca_gm_sigma = torch_utils.krp_cw_torch(
                        sigma_i, ppca_gm_sigma, column=False)

        alpha_gm = pyro.sample(
            'alpha_gm',
            dist.LogNormal(0, input_batch.new_ones(
                [1, self.group_hidden_dim])).to_event(1))

        if self.group_iterm is None:
            ppca_gm_means = self.group_term.multi_project(
                input_batch, remove_bias=False, tensorize=False)
        else:
            ppca_gm_means = self.group_term(
                self.group_iterm(torch_utils.flatten_torch(input_batch),
                                 T=True))

        hidden_dim = self.group_hidden_dim
        ppca_gm_means += self.group_term(epsilon[:, :hidden_dim] *
                                         alpha_gm[:, :hidden_dim])

        if self.likelihood == 'bernoulli':
            return ppca_gm_means.sigmoid()
        if self.likelihood == 'normal':
            return ppca_gm_means, ppca_gm_sigma
        raise ValueError(f'unsupported likelihood: {self.likelihood!r}')
Example #4
0
File: tensorial.py  Project: kharyuk/vbtd
 def recover(self, weights=None):
     """Reconstruct the tensor from its factors.

     Factors are combined from index ``self.d - 1`` down to ``0`` with
     ``torch_utils.krp_cw_torch``. Right after factor ``self.P`` is folded in,
     the columns are expanded with ``torch.repeat_interleave`` by ``self.L``.

     When ``self.sample_axis`` is False, factor 0 is not folded in. It is
     applied at the end as ``factors[0] @ output.T`` and the result is
     flattened with ``torch_utils.flatten_torch``.

     Args:
         weights: optional per-mode weights. When given, ``weights[k]`` is
             reshaped to ``(self.n[k], 1)`` and multiplies factor ``k``
             row-wise before it is combined.

     Returns:
         The reconstructed tensor as produced by the steps above.
     """
     output = self.factors[0].new_ones([1, self.R])
     for k in range(self.d - 1, -1, -1):
         if k == 0 and not self.sample_axis:
             # Factor 0 is applied separately below via a matrix product.
             break
         factor = self.factors[k]
         if weights is not None:
             factor = factor * weights[k].view(self.n[k], 1)
         output = torch_utils.krp_cw_torch(output, factor)
         if k == self.P:
             output = torch.repeat_interleave(
                 output,
                 torch.tensor(self.L, device=self.factors[0].device),
                 dim=1)
     if not self.sample_axis:
         output = self.factors[0].mm(output.t())
         output = torch_utils.flatten_torch(output, use_batch=False)
     return output
Example #5
0
File: vbtd.py  Project: kharyuk/vbtd
    def module_ppca_means_sigmas_weights_guide(self,
                                               input_batch,
                                               epsilon,
                                               expert=False,
                                               xi_greedy=None,
                                               highlight_peak=None):
        """Guide for the K-component PPCA mixture.

        For each component ``k`` the latent code is drawn around the
        posterior Gaussian mean returned by
        ``self.terms[k].get_posterior_gaussian_mean_covariance``
        (``zk_mean + eps @ zk_cov``) and mapped through ``self.terms[k]`` to
        give that component's mean. Responsibilities ``gamma`` are the
        softmax over components of (mixing log-weight + per-sample
        log-likelihood).

        NOTE(review): earlier comments described ``ppca_sigma`` as
        ``sigma_k^2`` (a variance), yet it is passed as ``scale`` to
        ``dist.Normal`` below. Confirm the intended parameterization.

        Args:
            input_batch: batch of observations; ``shape[0]`` is the batch size.
            epsilon: noise tensor; its first ``self.hidden_dims[k]`` columns
                are used for component ``k``.
            expert: if True, mixing log-weights come from
                ``self.measure_principal_angles`` instead of a ``pi`` sample
                site.
            xi_greedy: optional mixing coefficient. When given, the mixing
                log-weights are blended with a LogNormal noise draw.
            highlight_peak: optional multiplier applied to the mixing
                log-weights before ``log_softmax``.

        Returns:
            ``(gamma, ppca_means, ppca_sigmas)``:

            * ``gamma`` has shape ``(batch_size, self.K)``.
            * ``ppca_means`` has shape
              ``(self.K, batch_size, self.output_dim)``.
            * ``ppca_sigmas`` is ``None`` unless the likelihood is
              ``'normal'``.

        Raises:
            ValueError: inside the per-component loop, if ``self.likelihood``
                is neither ``'bernoulli'`` nor ``'normal'``.
        """
        max_hidden_dim = max(self.hidden_dims)
        batch_size = input_batch.shape[0]

        gamma = input_batch.new_zeros(batch_size, self.K)
        if expert:
            output_angles = self.measure_principal_angles(
                input_batch, mode=self.source_mode)
        else:
            pi_p = pyro.param('pi_p',
                              input_batch.new_ones(self.K) / self.K,
                              constraint=constraints.positive)
            output_angles = pyro.sample('pi', dist.Dirichlet(pi_p)).log()
        if highlight_peak is not None:
            output_angles = highlight_peak * output_angles
        output_angles = output_angles.log_softmax(dim=-1)

        if xi_greedy is not None:
            # Not a Pyro sample site: a plain random draw used for exploration.
            phi = dist.LogNormal(
                input_batch.new_zeros([batch_size, self.K]),
                input_batch.new_ones([batch_size, self.K])).to_event(1).sample()
            output_angles = (1 - xi_greedy) * output_angles + xi_greedy * phi

        ppca_means = input_batch.new_zeros(
            [self.K, batch_size, self.output_dim])
        if self.likelihood == 'normal':
            with pyro.plate('ppca_sigma_plate', self.K):
                if self.terms_isotropic:
                    ppca_sigmas_p = pyro.param(
                        'ppca_sigmas_p',
                        input_batch.new_ones(self.K, 1),
                        constraint=constraints.positive)
                    ppca_sigmas = pyro.sample(
                        'ppca_sigmas', dist.Delta(ppca_sigmas_p).to_event(1))
                else:
                    # Per-mode scales, folded with krp_cw_torch(column=False).
                    ppca_sigmas = input_batch.new_ones(self.K, 1)
                    ppca_sigmas_list = []
                    for i in range(self.d):
                        ppca_sigmas_p = pyro.param(
                            f'ppca_sigmas_{i}_p',
                            input_batch.new_ones(self.K, self.n[i]),
                            constraint=constraints.positive)
                        sigma_i = pyro.sample(
                            f'ppca_sigmas_{i}',
                            dist.Delta(ppca_sigmas_p).to_event(1))
                        ppca_sigmas_list.append(sigma_i)
                        ppca_sigmas = torch_utils.krp_cw_torch(
                            sigma_i, ppca_sigmas, column=False)
        else:
            ppca_sigmas = None
            ppca_sigmas_list = [
                input_batch.new_ones(self.K, self.n[i]) for i in range(self.d)
            ]

        with pyro.plate('alpha_plate', self.K):
            alpha_p = pyro.param('alpha_p',
                                 input_batch.new_ones([self.K,
                                                       max_hidden_dim]),
                                 constraint=constraints.positive)
            alpha = pyro.sample('alpha', dist.Delta(alpha_p).to_event(1))

        for k in range(self.K):
            if self.iterms is None:
                z_mu = self.terms[k].linear_mapping.inverse_batch(input_batch)
            else:
                z_mu = self.iterms[k](torch_utils.flatten_torch(input_batch),
                                      T=True)
            if self.terms_isotropic:
                noise_sigma = ppca_sigmas[k] if ppca_sigmas is not None else 1
            else:
                noise_sigma = [x[k] for x in ppca_sigmas_list]
            zk_mean, zk_cov = self.terms[
                k].get_posterior_gaussian_mean_covariance(
                    input_batch,
                    noise_sigma=noise_sigma,
                    z_mu=z_mu,
                    z_sigma=alpha[k])
            hidden_dim = self.hidden_dims[k]
            ppca_means[k, :, :] = self.terms[k](
                zk_mean + epsilon[:, :hidden_dim].mm(
                    zk_cov.view(hidden_dim, hidden_dim)))
            if self.likelihood == 'bernoulli':
                gamma[:, k] = dist.Bernoulli(
                    ppca_means[k].sigmoid(),
                    validate_args=False).to_event(1).log_prob(
                        torch_utils.flatten_torch(input_batch))
            elif self.likelihood == 'normal':
                gamma[:, k] = dist.Normal(
                    loc=ppca_means[k],
                    scale=ppca_sigmas[k]).to_event(1).log_prob(
                        torch_utils.flatten_torch(input_batch))
            else:
                raise ValueError(
                    f'unsupported likelihood: {self.likelihood!r}')
        gamma = (output_angles + gamma).softmax(dim=-1)
        return gamma, ppca_means, ppca_sigmas
Example #6
0
File: vbtd.py  Project: kharyuk/vbtd
    def module_ppca_means_sigmas_weights(self,
                                         input_batch,
                                         epsilon,
                                         expert=False,
                                         highlight_peak=None):
        """Model for the K-component PPCA mixture.

        Mixing weights ``pi`` are computed in one of two ways:

        * when ``expert`` is True, from ``self.measure_principal_angles``
          (optionally scaled by ``highlight_peak``), followed by a softmax;
        * otherwise, from a ``Dirichlet(1/K)`` sample site ``pi``. If
          ``highlight_peak`` is given, it is re-sharpened as
          ``softmax(highlight_peak * log(pi))``.

        Registers ``LogNormal(0, 1)`` sample sites:

        * for the noise scales, only when ``self.likelihood == 'normal'``
          (``ppca_sigmas``, or one ``ppca_sigmas_{i}`` per mode);
        * for ``alpha``.

        Component ``k``'s mean is a projection of the input through
        ``self.terms[k]`` plus
        ``self.terms[k](epsilon[:, :H_k] * alpha[k, :H_k])``, where
        ``H_k = self.hidden_dims[k]``.

        NOTE(review): earlier comments described ``ppca_sigma`` as
        ``sigma_k^2``; confirm whether callers treat it as a variance or a
        scale.

        Args:
            input_batch: batch of observations; ``shape[0]`` is the batch size.
            epsilon: noise tensor; its first ``self.hidden_dims[k]`` columns
                are used for component ``k``.
            expert: use ``self.measure_principal_angles`` for the mixing
                weights instead of a ``pi`` sample site.
            highlight_peak: optional multiplier that sharpens the mixing
                weights before the softmax.

        Returns:
            ``(pi, sigmoid(ppca_means))`` for the ``'bernoulli'`` likelihood,
            or ``(pi, ppca_means, ppca_sigmas)`` for the ``'normal'``
            likelihood. ``ppca_means`` has shape
            ``(self.K, batch_size, self.output_dim)``.

        Raises:
            ValueError: for any other ``self.likelihood``.
        """
        max_hidden_dim = max(self.hidden_dims)
        batch_size = input_batch.shape[0]
        if expert:
            pi = self.measure_principal_angles(input_batch,
                                               mode=self.source_mode)
            if highlight_peak is not None:
                pi = highlight_peak * pi
            pi = pi.softmax(dim=-1)
        else:
            pi = pyro.sample(
                'pi', dist.Dirichlet(input_batch.new_ones(self.K) / self.K))
            if highlight_peak is not None:
                pi = (highlight_peak * pi.log()).softmax(dim=-1)

        ppca_means = input_batch.new_zeros(
            [self.K, batch_size, self.output_dim])
        if self.likelihood == 'normal':
            with pyro.plate('ppca_sigma_plate', self.K):
                if self.terms_isotropic:
                    ppca_sigmas = pyro.sample(
                        'ppca_sigmas',
                        dist.LogNormal(0, input_batch.new_ones(
                            self.K, 1)).to_event(1))
                else:
                    # Per-mode scales, folded with krp_cw_torch(column=False).
                    ppca_sigmas = input_batch.new_ones(self.K, 1)
                    for i in range(self.d):
                        sigma_i = pyro.sample(
                            f'ppca_sigmas_{i}',
                            dist.LogNormal(0, input_batch.new_ones(
                                self.K, self.n[i])).to_event(1))
                        ppca_sigmas = torch_utils.krp_cw_torch(
                            sigma_i, ppca_sigmas, column=False)

        with pyro.plate('alpha_plate', self.K):
            alpha = pyro.sample(
                'alpha',
                dist.LogNormal(0, input_batch.new_ones(
                    [self.K, max_hidden_dim])).to_event(1))

        for k in range(self.K):
            hidden_dim = self.hidden_dims[k]
            if self.iterms is None:
                ppca_means[k, :, :] += self.terms[k].multi_project(
                    input_batch, remove_bias=False, tensorize=False)
            else:
                ppca_means[k, :, :] += self.terms[k](self.iterms[k](
                    torch_utils.flatten_torch(input_batch), T=True))
            ppca_means[k, :, :] += self.terms[k](
                epsilon[:, :hidden_dim] * alpha[k:k + 1, :hidden_dim])

        if self.likelihood == 'bernoulli':
            return pi, ppca_means.sigmoid()
        if self.likelihood == 'normal':
            return pi, ppca_means, ppca_sigmas
        raise ValueError(f'unsupported likelihood: {self.likelihood!r}')