Example #1
    def loss(self, x, y, y_de, beta, lamda):

        latent, mu_latent, var_latent, \
        qmu5_1, qvar5_1, qmu4_2, qvar4_2, qmu4_1, qvar4_1, qmu3_2, qvar3_2, qmu3_1, qvar3_1, \
        qmu2_2, qvar2_2, qmu2_1, qvar2_1, qmu1_2, qvar1_2, qmu1_1, qvar1_1, \
        predict, predict_test, yh, \
        x_re, \
        pmu5_1, pvar5_1, pmu4_2, pvar4_2, pmu4_1, pvar4_1, pmu3_2, pvar3_2, pmu3_1, pvar3_1, \
        pmu2_2, pvar2_2, pmu2_1, pvar2_1, pmu1_2, pvar1_2, pmu1_1, pvar1_1 = self.lnet(x, y_de)

        rec = reconstruction_function(x_re, x)

        pm, pv = torch.zeros(mu_latent.shape).cuda(), torch.ones(
            var_latent.shape).cuda()
        # print("mu1", mu1)
        kl_latent = ut.kl_normal(mu_latent, var_latent, pm, pv, yh)
        kl5_1 = ut.kl_normal(qmu5_1, qvar5_1, pmu5_1, pvar5_1, 0)
        kl4_2 = ut.kl_normal(qmu4_2, qvar4_2, pmu4_2, pvar4_2, 0)
        kl4_1 = ut.kl_normal(qmu4_1, qvar4_1, pmu4_1, pvar4_1, 0)
        kl3_2 = ut.kl_normal(qmu3_2, qvar3_2, pmu3_2, pvar3_2, 0)
        kl3_1 = ut.kl_normal(qmu3_1, qvar3_1, pmu3_1, pvar3_1, 0)
        kl2_2 = ut.kl_normal(qmu2_2, qvar2_2, pmu2_2, pvar2_2, 0)
        kl2_1 = ut.kl_normal(qmu2_1, qvar2_1, pmu2_1, pvar2_1, 0)
        kl1_2 = ut.kl_normal(qmu1_2, qvar1_2, pmu1_2, pvar1_2, 0)
        kl1_1 = ut.kl_normal(qmu1_1, qvar1_1, pmu1_1, pvar1_1, 0)

        kl = beta * torch.mean(kl_latent + kl5_1 + kl4_2 + kl4_1 + kl3_2 +
                               kl3_1 + kl2_2 + kl2_1 + kl1_2 + kl1_1)

        ce = nllloss(predict, y)

        nelbo = rec + kl + lamda * ce
        # nelbo = rec
        return nelbo, mu_latent, predict, predict_test, x_re, rec, kl, lamda * ce
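Every KL term above goes through `ut.kl_normal`, whose definition is not shown on this page. A minimal sketch consistent with the five-argument calls; treating the fifth argument (`yh`, often 0) as an additive offset is an assumption about the original utility:

import torch

def kl_normal(qm, qv, pm, pv, yh=0):
    # KL(N(qm, qv) || N(pm, pv)) for diagonal Gaussians, summed over the
    # last dimension; `yh` is treated as a scalar offset (assumption).
    element_wise = 0.5 * (torch.log(pv) - torch.log(qv)
                          + qv / pv
                          + (qm - pm).pow(2) / pv
                          - 1)
    return element_wise.sum(-1) + yh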
Example #2
 def forward(self, inputs):
     x, y_class = inputs['x'], inputs['y_class']
     m, v = self.encoder(x)
     flow_loss = 0.0
     if self.use_deterministic_encoder:
         y = self.decoder(m)
         kl_loss = torch.zeros(1, device=x.device)  # no KL term for a deterministic encoder
     elif self.use_flow:
         z = ut.sample_gaussian(m, v)
         decoder_input = z if not self.use_encoding_in_decoder else \
             torch.cat((z, m), dim=-1)  # BUGBUG: ideally the encodings from before the mu/sigma heads should be used here.
         # decoder_input, log_jacobians = self.flow(decoder_input)
         # flow_loss = self.bound(decoder_input, log_jacobians)
         flow_decoder_input = torch.zeros_like(decoder_input)
         for i in range(self.z_dim):
             flow = self.flow[i]
             single_input = decoder_input[:, i].unsqueeze(1)
             single_output, log_jacobians = flow(single_input)
             flow_decoder_input[:, i] = single_output.squeeze(1)
             flow_loss += self.bound(single_output, log_jacobians)
         flow_loss /= self.z_dim
         y = self.decoder(flow_decoder_input)
         #compute KL divergence loss:
         p_m = self.z_prior[0].expand(m.size())
         p_v = self.z_prior[1].expand(v.size())
         kl_loss = ut.kl_normal(m, v, p_m, p_v)
     else:
         z = ut.sample_gaussian(m, v)
         decoder_input = z if not self.use_encoding_in_decoder else \
             torch.cat((z, m), dim=-1)  # BUGBUG: ideally the encodings from before the mu/sigma heads should be used here.
         y = self.decoder(decoder_input)
         #compute KL divergence loss:
         p_m = self.z_prior[0].expand(m.size())
         p_v = self.z_prior[1].expand(v.size())
         kl_loss = ut.kl_normal(m, v, p_m, p_v)
     #compute reconstruction loss
     if self.loss_type == 'chamfer':  # `is` compares identity; string comparison needs `==`
         x_reconst = CD_loss(y, x)
     else:
         raise NotImplementedError(f"loss_type {self.loss_type!r} is not supported")
     # mean or sum
     if self.loss_sum_mean == "mean":
         x_reconst = x_reconst.mean()
         kl_loss = kl_loss.mean()
     else:
         x_reconst = x_reconst.sum()
         kl_loss = kl_loss.sum()
     nelbo = x_reconst + kl_loss + flow_loss
     ret = {
         'nelbo': nelbo,
         'kl_loss': kl_loss,
         'x_reconst': x_reconst,
         'flow_loss': flow_loss
     }
     # classifier network
     mv = torch.cat((m, v), dim=1)
     y_logits = self.z_classifer(mv)
     z_cl_loss = self.z_classifer.cross_entropy_loss(y_logits, y_class)
     ret['z_cl_loss'] = z_cl_loss
     return ret
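`ut.sample_gaussian` above is presumably the reparameterization trick. A minimal sketch, assuming it takes a mean and a variance (the original is not shown here):

import torch

def sample_gaussian(m, v):
    # Reparameterization trick: z = m + sqrt(v) * eps with eps ~ N(0, I),
    # so gradients can flow through m and v.
    eps = torch.randn_like(m)
    return m + torch.sqrt(v) * eps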
Example #3
 def forward(self, inputs):
     x, y_class = inputs['x'], inputs['y_class']
     m, v = self.encoder(x)
     if self.use_deterministic_encoder:
         z = m  # reuse the mean so the classifier below still has an input
         y = self.decoder(m)
         kl_loss = torch.zeros(1, device=x.device)
     else:
         z = ut.sample_gaussian(m, v).to(device)
         y = self.decoder(z)
         #compute KL divergence loss:
         p_m = self.z_prior[0].expand(m.size())
         p_v = self.z_prior[1].expand(v.size())
         kl_loss = ut.kl_normal(m, v, p_m, p_v)
     #compute reconstruction loss
     if self.loss_type == 'chamfer':  # `==`, not `is`, for string comparison
         x_reconst = CD_loss(y, x)
     else:
         raise NotImplementedError(f"loss_type {self.loss_type!r} is not supported")
     
     x_reconst = x_reconst.mean()
     kl_loss = kl_loss.mean()
     #compute classifier loss
     y_logits = self.z_classifer(z)
     cl_loss = self.z_classifer.cross_entropy_loss(y_logits, y_class)
     nelbo = x_reconst + kl_loss 
     loss = nelbo + cl_loss
     ret = {'loss': loss, 'nelbo': nelbo, 'kl_loss': kl_loss, 'x_reconst': x_reconst, 'cl_loss': cl_loss}
     return ret
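Examples #2 and #3 both reconstruct point clouds with `CD_loss`, whose definition is not shown. A common symmetric Chamfer-distance formulation, sketched here as an assumption about what `CD_loss` computes:

import torch

def CD_loss(y, x):
    # Symmetric Chamfer distance between point clouds of shape
    # (batch, n_points, 3); returns one loss value per batch element.
    d = torch.cdist(y, x)  # pairwise distances, (B, Ny, Nx)
    return d.min(dim=2).values.mean(dim=1) + d.min(dim=1).values.mean(dim=1)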
Example #4
    def _vae_loss(self):
        kl = kl_normal(self.q_z_mean, self.q_z_log_var)
        # tf.summary.scalar('kl divergence', kl)

        # Bernoulli reconstruction
        reconstruction = tf.reduce_sum(
            self.p_x.log_prob(slim.flatten(self.x)), 1)
        # tf.summary.scalar('reconstruction', reconstruction)

        # Mean-squared error reconstruction
        # d = (slim.flatten(self.input_) - self.logits)
        # d2 = tf.multiply(d, d) * 2.0
        # reconstruction = -tf.reduce_sum(d2, 1)

        elbo = reconstruction - self.beta * kl
        return tf.reduce_mean(-elbo)
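This TensorFlow variant calls `kl_normal` with a mean and a log-variance and no explicit prior, which suggests a closed-form KL against a standard normal. A sketch under that assumption:

import tensorflow as tf

def kl_normal(mean, log_var):
    # KL(N(mean, exp(log_var)) || N(0, I)), summed over the latent axis.
    return 0.5 * tf.reduce_sum(
        tf.square(mean) + tf.exp(log_var) - log_var - 1.0, axis=1)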
Example #5
    def _vae_loss(self):
        # TODO: should the KL divergences be weighted to factor in the
        # number of variables? For instance, kl_normal uses reduce_sum;
        # should it be divided by the number of normal variables?
        discrete_kl = kl_categorical(self.q_category)
        normal_kl = kl_normal(self.q_z_mean, self.q_z_log_var)
        # reconstruction = tf.reduce_sum(
        #    self.p_x.log_prob(slim.flatten(self.input_)), 1)

        d = (slim.flatten(self.input_) - self.logits)
        d2 = tf.multiply(d, d) * 2.0
        reconstruction = -tf.reduce_sum(d2, 1)

        self.discrete_kl = discrete_kl
        self.normal_kl = normal_kl
        self.reconstruction = reconstruction

        elbo = reconstruction - discrete_kl - normal_kl
        return tf.reduce_mean(-elbo)
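`kl_categorical` is likewise not shown. A common choice in models with discrete latents is the KL of the inferred categorical against a uniform prior; a sketch under the assumption that `q_category` holds the categorical probabilities:

import tensorflow as tf

def kl_categorical(q_probs, eps=1e-8):
    # KL(q || Uniform(K)) = sum_k q_k * (log q_k + log K), per example.
    k = tf.cast(tf.shape(q_probs)[-1], tf.float32)
    return tf.reduce_sum(
        q_probs * (tf.math.log(q_probs + eps) + tf.math.log(k)), axis=-1)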
Example #6
def compute_loss(batch, grid, mask, z_params_full, z_params_masked, h, w,
                 decoder):
    ## compute loss
    z_full = sample_z(z_params_full)  # size bsize * hidden
    z_full = z_full.unsqueeze(1).expand(-1, h * w, -1)

    # resize context to have one context per input coordinate
    grid_input = grid.view(1, h * w, -1).expand(batch.size(0), -1, -1)
    target_input = torch.cat([z_full, grid_input], dim=-1)

    reconstructed_image_mean, reconstructed_image_variance = decoder(
        target_input)  # bsize,h*w,1
    reconstruction_loss = -(
        log_normal(x=batch.view(batch.size(0), 3, h * w).transpose(1, 2),
                   m=reconstructed_image_mean,
                   v=reconstructed_image_variance) *
        (1 - mask.view(-1, h * w))).sum(dim=1).mean()

    kl_loss = kl_normal(z_params_full, z_params_masked).mean()
    return reconstruction_loss, kl_loss, reconstructed_image_mean, reconstructed_image_variance
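Here `kl_normal` receives (mean, variance) parameter tuples rather than four separate tensors, and `log_normal` evaluates the decoder's Gaussian likelihood. Minimal sketches consistent with those call signatures (assumptions, since the originals are not shown):

import math
import torch

def kl_normal(q_params, p_params):
    # KL between diagonal Gaussians given as (mean, variance) pairs.
    (qm, qv), (pm, pv) = q_params, p_params
    return 0.5 * (torch.log(pv) - torch.log(qv)
                  + (qv + (qm - pm).pow(2)) / pv - 1).sum(-1)

def log_normal(x, m, v):
    # Diagonal-Gaussian log-density, summed over the last dimension.
    return -0.5 * ((x - m).pow(2) / v + torch.log(v)
                   + math.log(2 * math.pi)).sum(-1)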
Example #7
def compute_loss(batch, grid, mask, z_params_full, z_params_masked, h, w,
                 decoder):
    ## compute loss
    z_full = sample_z(z_params_full)  # size bsize * hidden
    z_full = z_full.unsqueeze(1).expand(-1, h * w, -1)

    # resize context to have one context per input coordinate
    grid_input = grid.view(1, h * w, -1).expand(batch.size(0), -1, -1)
    target_input = torch.cat([z_full, grid_input], dim=-1)

    reconstructed_image = decoder(target_input)  # bsize,h*w,1

    reconstruction_loss = (
        F.binary_cross_entropy(reconstructed_image,
                               batch.view(batch.size(0), h * w, 1),
                               reduction='none') *
        (1 - mask.view(-1, h * w, 1))).sum(dim=1).mean()

    kl_loss = kl_normal(z_params_full, z_params_masked).mean()

    return reconstruction_loss, kl_loss, reconstructed_image
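Both `compute_loss` variants draw the latent with `sample_z` from a parameter tuple. A plausible sketch, assuming `z_params` is a (mean, variance) pair:

import torch

def sample_z(z_params):
    # Draw z = m + sqrt(v) * eps from a (mean, variance) parameter pair.
    m, v = z_params
    return m + torch.sqrt(v) * torch.randn_like(m)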
Example #8
    def loss(self, x, y, y_de, beta, lamda, args):

        latent, mu_latent, var_latent, \
        qmu5_1, qvar5_1, qmu4_2, qvar4_2, qmu4_1, qvar4_1, qmu3_2, qvar3_2, qmu3_1, qvar3_1, \
        qmu2_2, qvar2_2, qmu2_1, qvar2_1, qmu1_2, qvar1_2, qmu1_1, qvar1_1, \
        predict, predict_test, yh, \
        x_re, \
        pmu5_1, pvar5_1, pmu4_2, pvar4_2, pmu4_1, pvar4_1, pmu3_2, pvar3_2, pmu3_1, pvar3_1, \
        pmu2_2, pvar2_2, pmu2_1, pvar2_1, pmu1_2, pvar1_2, pmu1_1, pvar1_1 = self.lnet(x, y_de, args)

        rec = reconstruction_function(x_re, x)

        # split z and y if encode_z
        if args.encode_z:
            z_latent_mu, y_latent_mu = mu_latent.split([args.encode_z, 32],
                                                       dim=1)
            z_latent_var, y_latent_var = var_latent.split([args.encode_z, 32],
                                                          dim=1)
            pm_z, pv_z = torch.zeros(z_latent_mu.shape).cuda(), torch.ones(
                z_latent_var.shape).cuda()
        else:
            y_latent_mu = mu_latent
            y_latent_var = var_latent

        pm, pv = torch.zeros(y_latent_mu.shape).cuda(), torch.ones(
            y_latent_var.shape).cuda()
        # print("mu1", mu1)
        kl_latent = ut.kl_normal(y_latent_mu, y_latent_var, pm, pv, yh)
        kl5_1 = ut.kl_normal(qmu5_1, qvar5_1, pmu5_1, pvar5_1, 0)
        kl4_2 = ut.kl_normal(qmu4_2, qvar4_2, pmu4_2, pvar4_2, 0)
        kl4_1 = ut.kl_normal(qmu4_1, qvar4_1, pmu4_1, pvar4_1, 0)
        kl3_2 = ut.kl_normal(qmu3_2, qvar3_2, pmu3_2, pvar3_2, 0)
        kl3_1 = ut.kl_normal(qmu3_1, qvar3_1, pmu3_1, pvar3_1, 0)
        kl2_2 = ut.kl_normal(qmu2_2, qvar2_2, pmu2_2, pvar2_2, 0)
        kl2_1 = ut.kl_normal(qmu2_1, qvar2_1, pmu2_1, pvar2_1, 0)
        kl1_2 = ut.kl_normal(qmu1_2, qvar1_2, pmu1_2, pvar1_2, 0)
        kl1_1 = ut.kl_normal(qmu1_1, qvar1_1, pmu1_1, pvar1_1, 0)

        kl_all = kl_latent + kl5_1 + kl4_2 + kl4_1 + kl3_2 + kl3_1 + kl2_2 + kl2_1 + kl1_2 + kl1_1

        if args.encode_z:
            kl_all += args.beta_z * ut.kl_normal(z_latent_mu, z_latent_var,
                                                 pm_z, pv_z, 0)

        kl = beta * torch.mean(kl_all)

        ce = nllloss(predict, y)

        nelbo = rec + kl + lamda * ce

        if args.contrastive_loss:
            contra_loss = self.contra_loss
            nelbo += contra_loss
        # nelbo = rec
        return nelbo, y_latent_mu, predict, predict_test, x_re, rec, kl, lamda * ce
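The `split` call above partitions the latent statistics into an unsupervised part of width `args.encode_z` and a 32-dimensional label part. A standalone illustration, with `encode_z = 4` as a made-up value:

import torch

mu_latent = torch.randn(8, 36)               # batch of 8; 4 + 32 latent dims
z_mu, y_mu = mu_latent.split([4, 32], dim=1)
print(z_mu.shape, y_mu.shape)                # torch.Size([8, 4]) torch.Size([8, 32])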
Example #9
    def loss_calc(self, sampled_batched):

        # input data
        image = self.alpha_vision * sampled_batched["image"].to(self.device)
        force = self.alpha_force * sampled_batched["force"].to(self.device)
        proprio = self.alpha_proprio * sampled_batched["proprio"].to(
            self.device)
        depth = self.alpha_depth * sampled_batched["depth"].to(
            self.device).transpose(1, 3).transpose(2, 3)

        action = sampled_batched["action"].to(self.device)

        contact_label = sampled_batched["contact_next"].to(self.device)
        optical_flow_label = sampled_batched["flow"].to(self.device)
        optical_flow_mask_label = sampled_batched["flow_mask"].to(self.device)

        # unpaired data for sampled point
        unpaired_image = self.alpha_vision * sampled_batched[
            "unpaired_image"].to(self.device)
        unpaired_force = self.alpha_force * sampled_batched[
            "unpaired_force"].to(self.device)
        unpaired_proprio = self.alpha_proprio * sampled_batched[
            "unpaired_proprio"].to(self.device)
        unpaired_depth = self.alpha_depth * sampled_batched[
            "unpaired_depth"].to(self.device).transpose(1, 3).transpose(2, 3)

        # labels to predict
        gt_ee_pos_delta = sampled_batched["ee_yaw_next"].to(self.device)

        if self.deterministic:
            paired_out, contact_out, flow2, optical_flow2_mask, ee_delta_out, mm_feat = self.model(
                image, force, proprio, depth, action)
            kl = torch.tensor([0]).to(self.device).type(torch.cuda.FloatTensor)
        else:
            paired_out, contact_out, flow2, optical_flow2_mask, ee_delta_out, mm_feat, mu_z, var_z, mu_prior, var_prior = self.model(
                image, force, proprio, depth, action)
            kl = self.alpha_kl * torch.mean(
                kl_normal(mu_z, var_z, mu_prior.squeeze(0),
                          var_prior.squeeze(0)))

        flow_loss = self.alpha_optical_flow * realEPE(
            flow2, optical_flow_label, self.device)

        # Scene flow losses

        b, _, h, w = optical_flow_label.size()

        # nn.functional.upsample is deprecated; interpolate is the current equivalent
        optical_flow_mask = nn.functional.interpolate(optical_flow2_mask,
                                                      size=(h, w),
                                                      mode="bilinear",
                                                      align_corners=False)

        flow_mask_loss = self.alpha_optical_flow_mask * self.loss_optical_flow_mask(
            optical_flow_mask, optical_flow_mask_label)

        contact_loss = self.alpha_contact * self.loss_contact_next(
            contact_out, contact_label)

        ee_delta_loss = self.alpha_ee_fut * self.loss_ee_pos(
            ee_delta_out, gt_ee_pos_delta)

        paired_loss = self.alpha_pair * self.loss_is_paired(
            paired_out,
            torch.ones(paired_out.size(0), 1).to(self.device))

        unpaired_total_losses = self.model(unpaired_image, unpaired_force,
                                           unpaired_proprio, unpaired_depth,
                                           action)
        unpaired_out = unpaired_total_losses[0]
        unpaired_loss = self.alpha_pair * self.loss_is_paired(
            unpaired_out,
            torch.zeros(unpaired_out.size(0), 1).to(self.device))

        loss = (contact_loss + paired_loss + unpaired_loss + ee_delta_loss +
                kl + flow_loss + flow_mask_loss)

        contact_pred = nn.Sigmoid()(contact_out).detach()
        contact_accuracy = compute_accuracy(contact_pred,
                                            contact_label.detach())

        paired_pred = nn.Sigmoid()(paired_out).detach()
        paired_accuracy = compute_accuracy(
            paired_pred,
            torch.ones(paired_pred.size()[0], 1, device=self.device))

        unpaired_pred = nn.Sigmoid()(unpaired_out).detach()
        unpaired_accuracy = compute_accuracy(
            unpaired_pred,
            torch.zeros(unpaired_pred.size()[0], 1, device=self.device))

        is_paired_accuracy = (paired_accuracy + unpaired_accuracy) / 2.0

        # logging
        is_paired_loss = paired_loss + unpaired_loss

        return (
            loss,
            mm_feat,
            (
                flow_loss,
                contact_loss,
                is_paired_loss,
                contact_accuracy,
                is_paired_accuracy,
                ee_delta_loss,
                kl,
            ),
            (flow2, optical_flow_label, image),
        )
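`compute_accuracy` above compares thresholded sigmoid outputs against binary labels. A minimal sketch of what it presumably does; the 0.5 threshold is an assumption:

import torch

def compute_accuracy(pred, target, threshold=0.5):
    # Fraction of thresholded predictions that match the binary labels.
    return ((pred > threshold).float() == target).float().mean()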
Example #10
        #encoding. dim: batch, dim_z
        qm, qv = ut.gaussian_parameters(dict_model['EncoderVAE'](set_rep),
                                        dim=1)
        #sample z
        z = ut.sample_gaussian(qm, qv, device=device)  #batch_size, dim_z
        #z to rep
        rep_m, rep_v = ut.gaussian_parameters(dict_model['LatentToRep'](z))

        log_likelihood = ut.log_normal(set_rep, rep_m, rep_v)  #dim: batch
        lb_1 = log_likelihood.mean()  #scalar

        #KL divergence
        m = z_prior_m.expand(P, dim_z)
        v = z_prior_v.expand(P, dim_z)
        lb_2 = -ut.kl_normal(qm, qv, m, v).mean()  #scalar

        loss = -1 * (lb_1 + lb_2)
        loss.backward()
        optimizer.step()

        #reconstruct and plot
        if i % iter_rec == 0:
            X_rec = dict_model['DecoderAE'](ut.sample_gaussian(
                rep_m, rep_v, device=device)).reshape(P, -1, 3)
            for j in range(5):
                fig = ut.plot_3d_point_cloud(
                    X_hold[j, :, :].detach().cpu().numpy(),
                    X_rec[j, :, :].detach().cpu().numpy())
                fig.savefig(dir_fig + args.name + '_iter_' + str(i) + '_pic_' +
                            str(j + 1) + '.png')
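`ut.gaussian_parameters` above splits a network output into a mean and a positive variance. A common implementation, sketched here as an assumption about the original utility:

import torch
import torch.nn.functional as F

def gaussian_parameters(h, dim=-1):
    # First half along `dim` is the mean; softplus of the second half
    # (plus a small epsilon) gives a strictly positive variance.
    m, h_v = torch.split(h, h.size(dim) // 2, dim=dim)
    v = F.softplus(h_v) + 1e-8
    return m, v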