Example #1
 def run(self, x):
     x=Variable(x)
     #the action space is continuous
     u=self(x)
     sigma2=torch.exp(self.logstd_raw)*self.outputid
     d=MultivariateNormal(u, sigma2)
     action=d.sample()
     log_prob=d.log_prob(action)
     return action, log_prob
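This pattern (repeated with small variations in Examples #2, #3, #5 and #12) samples a continuous action from a Gaussian policy and keeps its log-probability for a policy-gradient update. A minimal self-contained sketch, assuming a toy linear policy; the names logstd_raw and outputid come from the example, while the network itself is hypothetical. Note that, as in the example, exp(logstd_raw) is used directly as the variance:

import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal

class GaussianPolicy(nn.Module):
    """Toy Gaussian policy: linear mean, state-independent log-std (sketch only)."""
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.mean = nn.Linear(obs_dim, act_dim)
        self.logstd_raw = nn.Parameter(torch.zeros(act_dim))
        self.register_buffer("outputid", torch.eye(act_dim))  # identity used to build the diagonal covariance

    def forward(self, x):
        return self.mean(x)

    def run(self, x):
        u = self(x)
        sigma2 = torch.exp(self.logstd_raw) * self.outputid   # diagonal covariance matrix
        d = MultivariateNormal(u, sigma2)
        action = d.sample()
        return action, d.log_prob(action)

policy = GaussianPolicy(4, 2)
action, log_prob = policy.run(torch.randn(4))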
Example #2
 def run(self, x):
     x = Variable(x)
     #the action space is continuous
     u = self(x)
     sigma2 = torch.exp(self.logstd_raw) * self.outputid
     d = MultivariateNormal(u, sigma2)
     action = d.sample()
     log_prob = d.log_prob(action)
     return action, log_prob
Example #3
 def run(self, x):
     x=Variable(x)
     u, logstd=self(x)
     sigma2=torch.exp(2*logstd)*self.outputid
     d=MultivariateNormal(u, sigma2) #might want to use N Gaussian instead
     action=d.sample()
     log_prob=d.log_prob(action)
     self.history_of_log_probs.append(log_prob)
     return action, log_prob
Example #4
 def reparameterize(self, logvar):
     std = torch.exp(0.5*logvar)
     eps = torch.randn_like(std)
     mu = torch.zeros(len(std[0])).cuda()
     covar_mat = torch.eye(len(std[0])).cuda()
     covar_mat = std*covar_mat
     m = MultivariateNormal(mu,covar_mat)
     log_prob_a = m.log_prob(eps.mul(std))
     return eps.mul(std), log_prob_a
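The covariance passed above is std * I rather than the usual std**2 * I; the standard diagonal parameterization can be written with scale_tril, which also lets rsample handle the loc + L * eps reparameterization itself. A minimal sketch assuming that squared-std intent (logvar is a hypothetical log-variance vector):

import torch
from torch.distributions import MultivariateNormal

logvar = torch.zeros(3)                                   # hypothetical log-variance vector
std = torch.exp(0.5 * logvar)
m = MultivariateNormal(torch.zeros(3), scale_tril=torch.diag(std))
sample = m.rsample()                                      # reparameterized draw: plays the role of eps.mul(std)
log_prob_a = m.log_prob(sample)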
Example #5
 def run(self, x):
     x=Variable(Tensor(x))
     #the action space is continuous
     u=self(x)
     sigma2=torch.exp(self.logstd_raw)*self.outputid
     d=MultivariateNormal(u, sigma2)
     action=d.sample()
     self.history_of_log_probs.append(d.log_prob(action))
     return action
Example #6
    def fit(self, X):
        N, D = X.shape

        # step1: initialization
        indices = torch.randperm(N)[:self.n_components]

        mus = X[indices]
        sigmas = get_spd(self.n_components, D)
        pis = torch.ones(1, self.n_components) * 1 / self.n_components

        normals = [
            MultivariateNormal(mus[i], sigmas[i])
            for i in range(self.n_components)
        ]
        p = torch.stack([normal.log_prob(X).exp() for normal in normals], 1)

        temp = pis * p

        prev_ll = temp.sum(1).log().sum()

        while True:
            print(prev_ll.item())
            # E step:
            gamma = temp / temp.sum(1, keepdim=True)

            # M step:
            Nk = gamma.sum(0, keepdim=True).t()

            mus = (gamma.t() @ X) / Nk

            diff = (X[:, None, :] - mus[None, :, :])[..., None]
            sigmas = (gamma[..., None, None] *
                      diff @ diff.transpose(-2, -1)).sum(0) / Nk[..., None]

            pis = (Nk / N).t()

            # evaluate the log likelihood
            normals = [
                MultivariateNormal(mus[i], sigmas[i])
                for i in range(self.n_components)
            ]
            p = torch.stack([normal.log_prob(X).exp() for normal in normals],
                            1)

            temp = pis * p

            ll = temp.sum(1).log().sum()

            if abs(ll - prev_ll) < 1e-4:
                break
            else:
                prev_ll = ll

        self.mus = mus
        self.sigmas = sigmas
        self.pis = pis
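Multiplying exponentiated log-probabilities by the mixing weights and then taking logs, as above, can underflow for high-dimensional data; the same quantities can be computed in log-space. A sketch with toy stand-ins for the example's X, normals and pis:

import torch
from torch.distributions import MultivariateNormal

N, D, K = 100, 2, 3
X = torch.randn(N, D)
normals = [MultivariateNormal(torch.randn(D), torch.eye(D)) for _ in range(K)]
pis = torch.full((1, K), 1.0 / K)

log_p = torch.stack([normal.log_prob(X) for normal in normals], 1)  # (N, K)
log_lik = torch.logsumexp(log_p + pis.log(), dim=1).sum()           # stable form of temp.sum(1).log().sum()
gamma = torch.softmax(log_p + pis.log(), dim=1)                     # responsibilities without an explicit exp()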
Example #7
 def log_prob_forward_model(self, state_tensor, action_tensor,
                            next_state_tensor):
     # calculate qdd from current state and next state
     means = self.predict_forward_model_deterministic(
         state_tensor, action_tensor)
     multivariate_normal = MultivariateNormal(means,
                                              covariance_matrix=torch.diag(
                                                  self.std.pow(2)))
     log_prob = (multivariate_normal.log_prob(next_state_tensor))
     return log_prob
Example #8
 def get_log_prob_action(self, goals, observations, actions):
     latent_and_observation = torch.cat([goals, observations], dim=2)
     action_means, action_stds = torch.split(
         self.action_decoder(latent_and_observation),
         self.action_dim,
         dim=2)
     m = MultivariateNormal(action_means, (10 * action_stds**2 + 0.001) *
                            torch.eye(self.action_dim))
     log_prob_actions = m.log_prob(actions)
     return log_prob_actions
Example #9
    def kl_divergence_prior_post(self, prior, post):
        mu1, softplus1 = prior["lambdas1"], prior["lambdas2"]
        mu2, softplus2 = post["lambdas1"], post["lambdas2"]

        stds1 = self._softplus_to_std(softplus1)
        stds2 = self._softplus_to_std(softplus2)
        q1 = MultivariateNormal(loc=mu1, scale_tril=torch.diag_embed(stds1))
        q2 = MultivariateNormal(loc=mu2, scale_tril=torch.diag_embed(stds2))
        return torch.distributions.kl.kl_divergence(
            q2, q1)  # KL(post||prior), note ordering matters
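kl_divergence has a closed form for two MultivariateNormals; a small self-contained check of the KL(post || prior) ordering used above, with hypothetical parameters in place of the lambdas:

import torch
from torch.distributions import MultivariateNormal
from torch.distributions.kl import kl_divergence

mu1, stds1 = torch.zeros(4), torch.ones(4)            # hypothetical prior parameters
mu2, stds2 = torch.randn(4), 0.5 * torch.ones(4)      # hypothetical posterior parameters
q1 = MultivariateNormal(loc=mu1, scale_tril=torch.diag_embed(stds1))
q2 = MultivariateNormal(loc=mu2, scale_tril=torch.diag_embed(stds2))
kl = kl_divergence(q2, q1)   # KL(post || prior): the first argument is the distribution the expectation is taken under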
Example #10
 def compute_log_prob_goals(self, observations, goals):
     pen_vars_slice = self.pen_vars_slice
     goal_means, goal_stds = torch.split(self.goal_decoder(observations),
                                         self.goal_dim,
                                         dim=2)
     m = MultivariateNormal(
         goal_means, (goal_stds**2 + 0.01) *
         torch.eye(self.goal_dim))  # squaring stds so as to be positive
     log_prob_goals = m.log_prob(goals)
     return log_prob_goals
Example #11
 def init_VL_sampler(self):
     from torch.distributions.multivariate_normal import MultivariateNormal as MVN
     view_mvn_path = self.cfgs.get('view_mvn_path', 'checkpoints/view_light/view_mvn.pth')
     light_mvn_path = self.cfgs.get('light_mvn_path', 'checkpoints/view_light/light_mvn.pth')
     view_mvn = torch.load(view_mvn_path)
     light_mvn = torch.load(light_mvn_path)
     self.view_mean = view_mvn['mean'].cuda()
     self.light_mean = light_mvn['mean'].cuda()
     self.view_mvn = MVN(view_mvn['mean'].cuda(), view_mvn['cov'].cuda())
     self.light_mvn = MVN(light_mvn['mean'].cuda(), light_mvn['cov'].cuda())
Example #12
 def run(self, x):
     x = Variable(x)
     u, logstd = self(x)
     sigma2 = torch.exp(2 * logstd) * self.outputid
     d = MultivariateNormal(u,
                            sigma2)  #might want to use N Gaussian instead
     action = d.sample()
     log_prob = d.log_prob(action)
     self.history_of_log_probs.append(log_prob)
     return action, log_prob
Example #13
def gauss_sample(n_sample, dim):
	z = MultivariateNormal(torch.zeros(dim), torch.eye(dim))
	sampled_z = z.sample((n_sample,))

	plt.figure(figsize = (5,5))
	plt.xlim([-4, 4])
	plt.ylim([-4, 4])
	plt.scatter(sampled_z[:,0], sampled_z[:,1], s=15)
	plt.savefig('../outputs/gauss_repara.png')
	return z
Example #14
    def forward(self, x):

        b,c,w,h = x.size()
        #Step1: embedding for each local point.
        st = time.perf_counter()
        for i in range(1000):
            x_embedded = self.embedding(x)
        # print("x_embedded = self.embedding(x): {}".format(time.perf_counter() - st))
        time1 = time.perf_counter() - st

        # Step2: Distribution
        # TODO: Learn a local point for each channel.
        # st = time.perf_counter()
        st = time.perf_counter()
        for i in range(1000):
            multiNorm = MultivariateNormal(loc=self.normal_loc,scale_tril=(self.normal_scal).diag_embed())
        # print("multiNorm = MultivariateNormal: {}".format(time.perf_counter() - st))
        time2 = time.perf_counter() - st



        st = time.perf_counter()
        for i in range(1000):
            localtion_map = self.get_location_mask(x,b,w,h,self.local_num)
        # print("localtion_map = self.get_location_mask: {}".format(time.perf_counter() - st))
        time3 = time.perf_counter() - st


        st = time.perf_counter()
        for i in range(1000):
            pdf = multiNorm.log_prob(localtion_map*self.position_scal).exp()
        # print("pdf = multiNorm.log_prob: {}".format(time.perf_counter() - st))
        time4 = time.perf_counter() - st


        #Step3: Value embedding
        st = time.perf_counter()
        for i in range(1000):
            x_value = x.expand(self.local_num,b,c,w,h).reshape(self.local_num*b,c,w,h)
            x_value = self.value_embed(x_value).reshape(self.local_num,b,c,w,h).permute(1,2,3,4,0)
        # print("Value embedding: {}".format(time.perf_counter() - st))
        time5 = time.perf_counter() - st


        #Step4: embeded_Value X possibility_density
        st = time.perf_counter()
        for i in range(1000):
            increment = (x_value*pdf.unsqueeze(dim=1)).mean(dim=-1)
        # print("increment: {}".format(time.perf_counter() - st))
        time6 = time.perf_counter() - st
        timelist = torch.Tensor([time1,time2,time3,time4,time5,time6])
        print(timelist/min(timelist))

        print("================NEXT channel: {}=============================".format(self.channel))
        return x+increment
Example #15
def main(recon_model, dyn_model, T, K, N, H, img_initial, img_goal, resz_act, step_i, KL):
    for t in range(T):
        print("***** Start Step {}".format(t))
        if t==0:
            img_cur = img_initial
        #Initialize Q with uniform distribution 
        mean = None
        cov = None
        mean_tmp = None
        cov_tmp = None
        converge = False   
        iter_count = 0     
        while not converge:
            imgs_recon, sample_actions = generate_next_pred_state_in_n_step(recon_model, dyn_model, img_cur, N, H, mean, cov)
            #Calculate binary cross entropy loss for predicted image and goal image 
            loss = loss_function_img(imgs_recon, img_goal, N)
            #Select K action sequences with lowest loss 
            loss_index = torch.argsort(loss)
            sorted_sample_actions = sample_actions[loss_index]
            #Fit multivariate gaussian distribution to K samples 
            #(see how to fit algorithm: 
            #https://stackoverflow.com/questions/27230824/fit-multivariate-gaussian-distribution-to-a-given-dataset) 
            mean = torch.mean(sorted_sample_actions[:K], dim=0).type(torch.DoubleTensor)
            cov = torch.from_numpy(np.cov(sorted_sample_actions[:K], rowvar=0)).type(torch.DoubleTensor)
            # iteration is based on convergence of Q
            if det(cov) == 0 or cov_tmp is None:  # `is None` avoids comparing a tensor against None
                mean_tmp = mean
                cov_tmp = cov
                continue
            else:
                if det(cov_tmp)==0:
                    mean_tmp = mean
                    cov_tmp = cov 
                    continue   
                else:            
                    p = MultivariateNormal(mean, cov)
                    q = MultivariateNormal(mean_tmp, cov_tmp)
                if kl_divergence(p, q) < KL: 
                    converge = True
                mean_tmp = mean
                cov_tmp = cov    
            
            print("***** At action time step {}, iteration {} *****".format(t, iter_count))
            iter_count += 1    

        #Execute action a{t}* with lowest loss 
        action_best = sorted_sample_actions[0] 
        action_loss = ((action_best.detach().cpu().numpy()-resz_act[:4])**2).mean(axis=None)
        #Observe new image I{t+1} 
        img_cur = generate_next_pred_state(recon_model, dyn_model, img_cur, action_best)
        img_loss = F.binary_cross_entropy(img_cur.view(-1, 2500), img_goal.view(-1, 2500), reduction='mean')
        print("***** Generate Next Predicted Image {}*****".format(t+1))

    print("***** End Planning *****")
    return action_loss, img_loss.detach().cpu().numpy()
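The while-loop above is a cross-entropy-method (CEM) planner: sample action sequences from a Gaussian, keep the K lowest-loss ones, refit the Gaussian, and stop once the KL divergence between successive Gaussians is small. A compact sketch of one such iteration on a generic cost function (cem_step, cost_fn, n_samples and k are hypothetical):

import torch
from torch.distributions import MultivariateNormal
from torch.distributions.kl import kl_divergence

def cem_step(cost_fn, mean, cov, n_samples=128, k=16):
    # sample candidates, rank them by cost, refit the Gaussian to the elites
    dist = MultivariateNormal(mean, cov)
    samples = dist.sample((n_samples,))
    elite = samples[torch.argsort(cost_fn(samples))[:k]]
    new_mean = elite.mean(dim=0)
    centered = elite - new_mean
    new_cov = centered.t() @ centered / (k - 1) + 1e-4 * torch.eye(mean.numel())  # jitter keeps it non-singular
    kl = kl_divergence(MultivariateNormal(new_mean, new_cov), dist)
    return new_mean, new_cov, kl

mean, cov = torch.zeros(4), torch.eye(4)
for _ in range(10):
    mean, cov, kl = cem_step(lambda a: (a ** 2).sum(dim=1), mean, cov)
    if kl < 0.05:   # converged: the sampling distribution has stopped moving
        break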
Example #16
def simulator(theta):
    N_samples = theta.shape[0]

    x = torch.zeros(N_samples, conj_model.N, dim)

    for i in range(N_samples):
        model_tmp = MultivariateNormal(theta[i], conj_model.model.covariance_matrix)
        x[i, :, :] = model_tmp.rsample(sample_shape=(conj_model.N,))

    # return calc_summary_stats(x), theta #/math.sqrt(5) # div with std of prior to nomarlize data
    return func.flatten(x)
Example #17
 def sample_weights(self, store=False):
     try:
         m = MultivariateNormal(self.mu, self.sig2)
         w = m.sample()
     except (ValueError, RuntimeError):  # covariance not positive definite
         print('Using np.random.multivariate_normal')
         w = torch.from_numpy(
             np.random.multivariate_normal(
                 self.mu.reshape(-1).numpy(), self.sig2.numpy())).float()
     if store: self.w = w
     return w
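The fallback above kicks in when the covariance fails PyTorch's positive-definiteness check. An alternative, hedged sketch that instead adds increasing diagonal jitter until the factorization succeeds (sample_weights_jitter is a hypothetical helper):

import torch
from torch.distributions import MultivariateNormal

def sample_weights_jitter(mu, sig2, jitter=1e-6, tries=5):
    # retry with a progressively larger diagonal ridge if sig2 is not positive definite
    for i in range(tries):
        try:
            return MultivariateNormal(mu, sig2 + jitter * (10 ** i) * torch.eye(len(mu))).sample()
        except (ValueError, RuntimeError):
            continue
    raise RuntimeError("covariance could not be made positive definite")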
Example #18
 def select_action(self, state, deterministic, reparameterize=False):
     mu, std = self.forward(state)
     dist = MultivariateNormal(loc=mu, scale_tril=torch.diag_embed(std))
     if deterministic:
         action = mu  # (bsize, action_dim)
     else:
         if reparameterize:
             action = dist.rsample()  # (bsize, action_dim)
         else:
             action = dist.sample()  # (bsize, action_dim)
     return action, dist
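The reparameterize flag switches between rsample (a pathwise, differentiable draw, needed when gradients must flow through the action, e.g. in SAC-style actor losses) and sample (a detached draw). A minimal check with hypothetical tensors:

import torch
from torch.distributions import MultivariateNormal

mu = torch.zeros(3, requires_grad=True)
std = torch.ones(3)
dist = MultivariateNormal(loc=mu, scale_tril=torch.diag_embed(std))

a_r = dist.rsample()   # keeps a grad_fn: gradients flow back into mu
a_s = dist.sample()    # detached draw: no gradient path to mu
print(a_r.requires_grad, a_s.requires_grad)  # True False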
Example #19
def select_state(mu, deterministic=False):
    """
    Select Δs_t from a multivariate normal with mean mu and identity covariance
    """
    if deterministic:
        return mu
    else:
        shape = mu.shape
        mu = mu.view(-1)
        gauss = MultivariateNormal(mu.view(-1), torch.eye(mu.shape[0]))
        return gauss.sample().view(shape)
Example #20
    def conditional_sample(self,
                           cond_val,
                           sample_shape=torch.Size([]),
                           cond_idx=None,
                           sample_idx=None):
        """
        Draw samples conditioning on cond_val.

        Args:
            cond_val (torch.Tensor): conditional values. Should be a 1D tensor.
            sample_shape (torch.Size): same as in 
                `Distribution.sample(sample_shape=torch.Size([]))`.
            cond_idx (torch.LongTensor): indices that correspond to cond_val.
                If None, use the last m dimensions, where m is the length of cond_val.
            sample_idx (torch.LongTensor): indices to sample from. If None, sample 
                from all remaining dimensions.

        Returns:
            Generates a sample_shape shaped sample or sample_shape shaped batch of 
                samples if the distribution parameters are batched.
        """
        m, n = *cond_val.shape, *self.event_shape

        if cond_idx is None:
            cond_idx = torch.arange(n - m, n)
        if sample_idx is None:
            sample_idx = torch.tensor(
                [i for i in range(n) if i not in set(cond_idx.tolist())])

        assert (len(cond_idx) == m and len(sample_idx) + len(cond_idx) <= n
                and not set(cond_idx.tolist()) & set(sample_idx.tolist()))

        cov_00 = self.covariance_matrix.index_select(
            dim=0, index=sample_idx).index_select(dim=1, index=sample_idx)
        cov_01 = self.covariance_matrix.index_select(
            dim=0, index=sample_idx).index_select(dim=1, index=cond_idx)
        cov_10 = self.covariance_matrix.index_select(
            dim=0, index=cond_idx).index_select(dim=1, index=sample_idx)
        cov_11 = self.covariance_matrix.index_select(
            dim=0, index=cond_idx).index_select(dim=1, index=cond_idx)

        cond_val_nscale = _standard_normal_quantile(
            cond_val)  # Phi^{-1}(u_cond)
        reg_coeff, _ = torch.solve(cov_10,
                                   cov_11)  # Sigma_{11}^{-1} Sigma_{10}
        cond_mu = torch.mv(reg_coeff.t(), cond_val_nscale)
        cond_sigma = cov_00 - torch.mm(cov_01, reg_coeff)
        cond_normal = MultivariateNormal(loc=cond_mu,
                                         covariance_matrix=cond_sigma)

        samples_nscale = cond_normal.sample(sample_shape)
        samples_uscale = _standard_normal_cdf(samples_nscale)

        return samples_uscale
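The regression coefficient and conditional moments above follow the standard Gaussian conditioning identities (with a zero prior mean): cond_mu = Sigma_01 Sigma_11^{-1} c and cond_sigma = Sigma_00 - Sigma_01 Sigma_11^{-1} Sigma_10. A small numerical sketch of the same steps on a toy 2D covariance; note that torch.solve, used above, has since been replaced by torch.linalg.solve:

import torch

cov = torch.tensor([[2.0, 0.6],
                    [0.6, 1.0]])
cond_val = torch.tensor([0.5])                  # hypothetical observed value for the conditioned dimension

cov_00, cov_01 = cov[:1, :1], cov[:1, 1:]
cov_10, cov_11 = cov[1:, :1], cov[1:, 1:]

reg_coeff = torch.linalg.solve(cov_11, cov_10)  # Sigma_11^{-1} Sigma_10
cond_mu = reg_coeff.t() @ cond_val              # conditional mean
cond_sigma = cov_00 - cov_01 @ reg_coeff        # conditional covariance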
Example #21
 def compute_noise(self, observations):
     noise_stds = self.noise_decoder(observations)
     #noise_means, noise_stds = torch.split(self.noise_decoder(observations), 1, dim=2)
     m = MultivariateNormal(
         torch.zeros_like(noise_stds), (noise_stds**2 + 0.01) *
         torch.eye(self.action_dim))  # squaring stds so as to be positive
     noise = m.sample()
     #noise = torch.clamp(noise, -1, 1)
     log_prob_noise = m.log_prob(noise)
     #noise = torch.tanh(noise)
     return noise, log_prob_noise
Example #22
def predict_KFAC_sampling(model,
                          test_loader,
                          M_W_post,
                          M_b_post,
                          U_post,
                          V_post,
                          B_post,
                          n_samples,
                          timing=False,
                          verbose=False,
                          cuda=False):
    py = []
    max_len = len(test_loader)
    if timing:
        time_sum = 0

    for batch_idx, (x, y) in enumerate(test_loader):

        if cuda:
            x, y = x.cuda(), y.cuda()

        phi = model.features(x).detach()

        mu_pred = phi @ M_W_post + M_b_post
        Cov_pred = torch.diag(phi @ V_post @ phi.t()).reshape(
            -1, 1, 1) * U_post.unsqueeze(0) + B_post.unsqueeze(0)

        post_pred = MultivariateNormal(mu_pred, Cov_pred)

        # MC-integral
        t0 = time.time()
        py_ = 0

        for _ in range(n_samples):
            f_s = post_pred.rsample()
            py_ += torch.softmax(f_s, 1)

        py_ /= n_samples
        py_ = py_.detach()

        py.append(py_)
        t1 = time.time()
        if timing:
            time_sum += (t1 - t0)

        if verbose:
            print("Batch: {}/{}".format(batch_idx, max_len))

    if timing:
        print("time used for sampling with {} samples: {}".format(
            n_samples, time_sum))

    return torch.cat(py, dim=0)
Example #23
def predict_diagonal_sampling(model,
                              test_loader,
                              M_W_post,
                              M_b_post,
                              C_W_post,
                              C_b_post,
                              n_samples,
                              verbose=False,
                              cuda=False,
                              timing=False):
    py = []
    max_len = len(test_loader)
    if timing:
        time_sum = 0

    for batch_idx, (x, y) in enumerate(test_loader):

        if cuda:
            x, y = x.cuda(), y.cuda()

        phi = model.features(x)

        mu, Sigma = get_Gaussian_output(phi, M_W_post, M_b_post, C_W_post,
                                        C_b_post)
        #print("mu size: ", mu.size())
        #print("sigma size: ", Sigma.size())

        post_pred = MultivariateNormal(mu, Sigma)

        # MC-integral
        t0 = time.time()
        py_ = 0

        for _ in range(n_samples):
            f_s = post_pred.rsample()
            py_ += torch.softmax(f_s, 1)

        py_ /= n_samples
        py_ = py_.detach()

        py.append(py_)
        t1 = time.time()
        if timing:
            time_sum += (t1 - t0)

        if verbose:
            print("Batch: {}/{}".format(batch_idx, max_len))

    if timing:
        print("time used for sampling with {} samples: {}".format(
            n_samples, time_sum))

    return torch.cat(py, dim=0)
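Examples #22 and #23 both approximate the predictive distribution by Monte Carlo: draw logit samples from the Gaussian posterior predictive and average the softmax. A self-contained sketch of that integral with hypothetical logit means and covariances:

import torch
from torch.distributions import MultivariateNormal

mu_pred = torch.randn(8, 10)                     # hypothetical logit means for 8 inputs, 10 classes
Cov_pred = 0.1 * torch.eye(10).repeat(8, 1, 1)   # hypothetical per-input logit covariances

post_pred = MultivariateNormal(mu_pred, Cov_pred)

n_samples = 100
py = sum(torch.softmax(post_pred.rsample(), 1) for _ in range(n_samples)) / n_samples
# py approximates E[softmax(f)] under the posterior predictive, as in the loops above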
Example #24
    def spectral_init(self, model):
        U = {}
        for l, m in enumerate(model.modules()):
            if isinstance(m, nn.Linear) or isinstance(
                    m, nn.Conv2d) or isinstance(m, nn.Conv1d):
                N = m.weight.shape[0]
                U_l = MultivariateNormal(torch.zeros(N), torch.eye(N)).sample()
                U_l = U_l.cuda() if self.args['device'] == torch.device(
                    'cuda') else U_l
                U[l] = U_l

        return U
Example #25
def select_mj(mu, sigma, deterministic=False):
    """
    Select Δs_t or an action from a multivariate normal with mean mu and diagonal covariance diag(sigma)
    """
    if deterministic:
        return mu
    else:
        shape = mu.shape
        mu = mu.view(-1)
        sigma = sigma.view(-1)
        gauss = MultivariateNormal(mu, torch.diag(sigma))
        return gauss.sample().view(shape)
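Because torch.diag(sigma) is used as the covariance, sigma is interpreted as a vector of per-dimension variances, so the draw is equivalent to adding independent Gaussian noise to mu. A hedged equivalence sketch:

import torch
from torch.distributions import MultivariateNormal

mu = torch.randn(5)
sigma = torch.rand(5) + 0.1                       # treated as variances by torch.diag(sigma)

mvn_draw = MultivariateNormal(mu, torch.diag(sigma)).sample()
manual_draw = mu + sigma.sqrt() * torch.randn(5)  # same distribution (not the same numbers)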
Example #26
    def _generate(self, input):
        mean = torch.tensor([input[0], input[1]])
        scale = 1.0
        s_1 = input[2]**2
        s_2 = input[3]**2
        rho = input[4].tanh()
        covariance = torch.tensor([[scale * s_1**2, scale * rho * s_1 * s_2],
                                   [scale * rho * s_1 * s_2, scale * s_2**2]])
        normal = MultivariateNormal(mean, covariance)  # a full 2x2 covariance matrix requires MultivariateNormal, not Normal
        x_out = normal.sample(torch.Size([4])).view(1, -1)

        return x_out
Example #27
def true_5gaussians_probs(on_mani, twodim):
    scale = 3
    centers = [(1, 0), (-1, 0), (0, 1), (0, -1), (0, 0)]
    centers = torch.tensor([(scale * x, scale * y) for x, y in centers])

    prob = 0
    for c in centers:
        loc = torch.ones_like(twodim) * torch.tensor(c)
        distr = MultivariateNormal(loc, .25 * torch.eye(2))
        prob += torch.exp(distr.log_prob(twodim))
    prob /= len(centers)
    return prob
Example #28
    def sample(self, means, samples=1):
        x = []

        with torch.no_grad():
            means = means.view(-1, self.dimensionality)
            mean_samples = torch.Size([samples])
            for mean in means:
                normal = MultivariateNormalDistribution(mean, self.sigma)
                x.append(normal.sample(mean_samples).view(-1, samples, self.dimensionality))
            x = torch.cat(x, dim=0).squeeze()

        return x
Example #29
def sample(pi, sigma, mu):
    """Draw samples from a MoG.
    # Original implementation
    categorical = Categorical(pi)
    pis = list(categorical.sample().data)
    sample = Variable(sigma.data.new(sigma.size(0), sigma.size(2)).normal_())
    for i, idx in enumerate(pis):
        sample[i] = sample[i].mul(sigma[i,idx]).add(mu[i,idx])
    return sample
    """
    ######################
    # new implementation #
    ######################
    categorical = Categorical(pi)
    pis = list(categorical.sample().data)
    #print('len of pis', len(pis))
    #print('pis', pis)
    print('size of sigma = ', sigma.size())
    print('size of mu = ', mu.size())
    D = mu.size(-1)
    samples = torch.zeros([len(pi), D])
    sigma_cpu_all = sigma.detach().cpu()
    mu_cpu_all = mu.detach().cpu()
    for i, idx in enumerate(pis):
        #print('i = {}'.format(i))
        sigma_cpu = sigma_cpu_all[i, idx]
        precision_mat_diag_pos = torch.matmul(sigma_cpu,
                                              torch.transpose(sigma_cpu, 0, 1))
        mu_cpu = mu_cpu_all[i, idx]
        #precision_mat = sigma[i, idx] + torch.transpose(sigma[i, idx], 0, 1)
        diagonal_mat = torch.tensor(np.zeros([D, D]))
        #precision_mat_diag_pos np.fill_diagonal_(diagonal_mat, 1e-7)
        precision_mat_diag_pos += diagonal_mat.fill_diagonal_(
            1)  # add small positive value
        #precision_mat_diag_pos = precision_mat + diagonal_mat.fill_diagonal_(1 - torch.min(torch.diagonal(precision_mat)).detach().cpu().numpy())
        #print('precision_mat = ', precision_mat_diag_pos)
        #print(precision_mat_diag_pos)
        #print(mu_cpu)
        try:
            #print('precision_mat = ', precision_mat_diag_pos)
            MVN = MultivariateNormal(loc=mu_cpu,
                                     precision_matrix=precision_mat_diag_pos)
            draw_sample = MVN.rsample()
        except (ValueError, RuntimeError):  # covariance/precision matrix not positive definite
            print(
                "Ops, your covariance matrix is very unfortunately singular, assign loss of test_loss to avoid counting"
            )
            draw_sample = -999 * torch.ones([1, D])
        #print('sample size = ', draw_sample.size())
        samples[i, :] = draw_sample
    #print('samples', samples.size())
    return samples
Example #30
def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.set_num_threads(75)
    n, d, m = args.n, args.d, args.m

    labels_list = []
    data_list = []
    means_list = []
    covs_list = []

    l = 1

    for i in range(m):

        mean = torch.randn(d) * 2
        covs = torch.randn((d, d))
        covs = covs @ covs.t() + l * torch.eye(d)

        means_list.append(mean)
        covs_list.append(covs)

        sampler = MultivariateNormal(mean, covs)

        data = sampler.sample(sample_shape=(torch.Size([n // m])))
        labels = torch.zeros((n // m, m))
        labels[:, i] = 1.0
        labels_list.append(labels)
        data_list.append(data)

    data = torch.cat(data_list, dim=0)
    labels = torch.cat(labels_list, dim=0)
    means = torch.stack(means_list)
    covs = torch.stack(covs_list)

    print('data shape', data.shape)
    print('labels shape', labels.shape)

    path = os.path.join(
        args.save_prefix,
        'multimodal_gaussian_' + str(n) + '_' + str(d) + '_' + str(m))
    os.makedirs(path)

    data_path = os.path.join(path, 'data.npy')
    labels_path = os.path.join(path, 'labels.npy')
    means_path = os.path.join(path, 'means.npy')
    covs_path = os.path.join(path, 'covs.npy')

    np.save(data_path, data.numpy())
    np.save(labels_path, labels.numpy())
    np.save(means_path, means.numpy())
    np.save(covs_path, covs.numpy())
Example #31
    def forward(self, x, S):
        x = x.view(-1, self.x_dim)
        bsz = x.size(0)

        ### get w and \alpha and L(\theta)
        mu, logvar = self.encoder(x)
        q_phi = Normal(loc=mu, scale=torch.exp(0.5 * logvar))
        z_q = q_phi.rsample((S, ))
        recon_batch = self.decoder(z_q)
        x_dist = Bernoulli(logits=recon_batch)
        log_lik = x_dist.log_prob(x).sum(-1)
        log_prior = self.prior.log_prob(z_q).sum(-1)
        log_q = q_phi.log_prob(z_q).sum(-1)
        log_w = log_lik + log_prior - log_q
        tmp_alpha = torch.logsumexp(log_w, dim=0).unsqueeze(0)
        alpha = torch.exp(log_w - tmp_alpha).detach()
        if self.version == 'v1':
            p_loss = -alpha * (log_lik + log_prior)

        ### get moment-matched proposal
        mu_r = alpha.unsqueeze(2) * z_q
        mu_r = mu_r.sum(0).detach()
        z_minus_mu_r = z_q - mu_r.unsqueeze(0)
        reshaped_diff = z_minus_mu_r.view(S * bsz, -1, 1)
        reshaped_diff_t = reshaped_diff.permute(0, 2, 1)
        outer = torch.bmm(reshaped_diff, reshaped_diff_t)
        outer = outer.view(S, bsz, self.z_dim, self.z_dim)
        Sigma_r = outer.mean(0) * S / (S - 1)
        Sigma_r = Sigma_r + torch.eye(self.z_dim).to(device) * 1e-6  ## ridging

        ### get v, \beta, and L(\phi)
        L = torch.cholesky(Sigma_r)
        r_phi = MultivariateNormal(loc=mu_r, scale_tril=L)

        z = r_phi.rsample((S, ))
        z_r = z.detach()
        recon_batch_r = self.decoder(z_r)
        x_dist_r = Bernoulli(logits=recon_batch_r)
        log_lik_r = x_dist_r.log_prob(x).sum(-1)
        log_prior_r = self.prior.log_prob(z_r).sum(-1)
        log_r = r_phi.log_prob(z_r)
        log_v = log_lik_r + log_prior_r - log_r
        tmp_beta = torch.logsumexp(log_v, dim=0).unsqueeze(0)
        beta = torch.exp(log_v - tmp_beta).detach()
        log_q = q_phi.log_prob(z_r).sum(-1)
        q_loss = -beta * log_q

        if self.version == 'v2':
            p_loss = -beta * (log_lik_r + log_prior_r)

        rem_loss = torch.sum(q_loss + p_loss, 0).sum()
        return rem_loss
Example #32
    def generate_x(self, N=25):
        """
        Sample, using your VAE: sample z from the prior and decode it
        :param N: number of samples
        :return: X (N, inp_size)
        """

        m = MultivariateNormal(torch.zeros(self.z_dim + self.w_dim),
                               torch.eye(self.z_dim + self.w_dim))
        z = m.sample(sample_shape=torch.Size([N]))

        X, _ = self.p_x(z.cuda())
        return X
Example #33
    def generate_x(self, N=25, device=torch.device("cpu")):
        """
        Sample, using your VAE: sample z from the prior and decode it
        :param N: number of samples
        :return: X (N, inp_size)
        """

        m = MultivariateNormal(torch.zeros(self.hid_dim),
                               torch.eye(self.hid_dim))
        z = m.sample(sample_shape=torch.Size([N]))

        X, _ = self.p_x(z.to(device))
        return X