Пример #1
0
    def train(self):
        """Run the FactorVAE training loop until ``self.global_iter`` reaches ``self.max_iter``.

        Each iteration takes a pair of batches ``(x_true1, x_true2)`` from
        ``self.data_loader``:

        * ``x_true1`` trains the VAE on reconstruction + KL + ``self.gamma``
          times the discriminator-based total-correlation estimate.
        * ``x_true2`` is encoded without decoding and its latent dimensions are
          shuffled with ``permute_dims``. Together with the detached latents of
          ``x_true1`` it trains the discriminator ``self.D``.

        Losses are written to ``self.pbar`` every ``self.print_iter`` iterations
        and a checkpoint is saved every ``self.ckpt_save_iter`` iterations.
        """
        self.net_mode(train=True)

        # Discriminator targets: class 0 for latents from the encoder, class 1
        # for dimension-permuted latents. Sliced to the actual batch size below
        # so a smaller final batch does not break ``F.cross_entropy``.
        ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
        zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)

        out = False
        while not out:
            for x_true1, x_true2 in self.data_loader:
                self.global_iter += 1
                self.pbar.update(1)

                # ---- VAE update ----
                x_true1 = x_true1.to(self.device)
                n = x_true1.size(0)
                x_recon, mu, logvar, z = self.VAE(x_true1)
                vae_recon_loss = recon_loss(x_true1, x_recon)
                vae_kld = kl_divergence(mu, logvar)

                D_z = self.D(z)
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                vae_loss = vae_recon_loss + vae_kld + self.gamma*vae_tc_loss

                self.optim_VAE.zero_grad()
                vae_loss.backward()
                self.optim_VAE.step()

                # ---- Discriminator update ----
                # The permuted latents were always detached, so no autograd
                # graph is needed for this encoder pass.
                x_true2 = x_true2.to(self.device)
                with torch.no_grad():
                    z_pperm = permute_dims(self.VAE(x_true2, no_dec=True))

                # Re-score the *detached* latents instead of reusing ``D_z``.
                # Back-propagating through ``D_z`` would reach encoder weights
                # that ``optim_VAE.step()`` has just updated in place, which
                # autograd rejects. ``optim_VAE`` does not touch D's weights, so
                # for a deterministic D these logits equal ``D_z``.
                D_z_real = self.D(z.detach())
                D_z_pperm = self.D(z_pperm)
                D_tc_loss = 0.5*(F.cross_entropy(D_z_real, zeros[:n]) + F.cross_entropy(D_z_pperm, ones[:n]))

                self.optim_D.zero_grad()
                D_tc_loss.backward()
                self.optim_D.step()

                if self.global_iter%self.print_iter == 0:
                    self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format(
                        self.global_iter, vae_recon_loss.item(), vae_kld.item(), vae_tc_loss.item(), D_tc_loss.item()))

                if self.global_iter%self.ckpt_save_iter == 0:
                    self.save_checkpoint(self.global_iter)

                if self.global_iter >= self.max_iter:
                    out = True
                    break

        self.pbar.write("[Training Finished]")
        self.pbar.close()
Пример #2
0
    def train(self):
        """Run the relevance-weighted FactorVAE training loop until ``self.max_iter``.

        This follows the FactorVAE scheme with these changes:

        * latents are scaled by the relevance vector ``self.r`` before reaching
          the discriminator;
        * ``self.r`` is passed to ``kl_divergence``;
        * the VAE loss adds ``self.etaS * sum(|r|)`` and
          ``self.etaH * entropy(r)``;
        * ``self.r`` is stepped by its own optimiser ``self.optim_r``.

        When ``self.viz_on`` is set, line statistics, reconstructions and latent
        traversals are gathered and visualised at their respective intervals.
        """
        self.net_mode(train=True)

        # Discriminator targets: class 0 for encoded latents, class 1 for
        # permuted ones. Sliced to the actual batch size below.
        ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
        zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)

        out = False
        while not out:
            for x_true1, x_true2 in self.data_loader:
                self.global_iter += 1
                self.pbar.update(1)

                # ---- VAE + relevance update ----
                x_true1 = x_true1.to(self.device)
                n = x_true1.size(0)
                x_recon, mu, logvar, z = self.VAE(x_true1)
                vae_recon_loss = recon_loss(x_true1, x_recon)
                vae_kld = kl_divergence(mu, logvar, self.r)
                H_r = entropy(self.r)

                r_z = self.r*z
                D_z = self.D(r_z)
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                vae_loss = (vae_recon_loss + vae_kld + self.gamma*vae_tc_loss
                            + self.etaS*self.r.abs().sum() + self.etaH*H_r)

                # One backward pass provides the gradients for both the VAE and
                # ``r``. A second backward after ``optim_VAE.step()`` would
                # re-traverse a graph whose parameters were just updated in
                # place, which autograd rejects.
                self.optim_VAE.zero_grad()
                self.optim_r.zero_grad()
                vae_loss.backward()
                self.optim_VAE.step()
                self.optim_r.step()

                # ---- Discriminator update ----
                x_true2 = x_true2.to(self.device)
                with torch.no_grad():
                    z_pperm = permute_dims(self.VAE(x_true2, no_dec=True))

                # Real samples: the detached, pre-update ``r * z``. For a
                # deterministic D these logits equal ``D_z``, but no gradient is
                # pushed into the VAE/``r`` leaves that were just modified in place.
                D_z_real = self.D(r_z.detach())
                # Permuted samples use the freshly updated ``r``. It is detached
                # because this loss only trains D.
                D_z_pperm = self.D(self.r.detach()*z_pperm)
                D_tc_loss = 0.5*(F.cross_entropy(D_z_real, zeros[:n]) + F.cross_entropy(D_z_pperm, ones[:n]))

                self.optim_D.zero_grad()
                D_tc_loss.backward()
                self.optim_D.step()

                if self.global_iter%self.print_iter == 0:
                    self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format(
                        self.global_iter, vae_recon_loss.item(), vae_kld.item(), vae_tc_loss.item(), D_tc_loss.item()))

                if self.global_iter%self.ckpt_save_iter == 0:
                    self.save_checkpoint(self.global_iter)

                if self.viz_on and (self.global_iter%self.viz_ll_iter == 0):
                    # Fraction of real latents scored as class 0 plus permuted
                    # latents scored as class 1, over both batches.
                    soft_D_z = F.softmax(D_z, 1)[:, :1].detach()
                    soft_D_z_pperm = F.softmax(D_z_pperm, 1)[:, :1].detach()
                    D_acc = ((soft_D_z >= 0.5).sum() + (soft_D_z_pperm < 0.5).sum()).float()
                    D_acc /= 2*n
                    self.line_gather.insert(iter=self.global_iter,
                                            soft_D_z=soft_D_z.mean().item(),
                                            soft_D_z_pperm=soft_D_z_pperm.mean().item(),
                                            recon=vae_recon_loss.item(),
                                            kld=vae_kld.item(),
                                            acc=D_acc.item(),
                                            # clone: on CPU, .cpu() returns the same storage,
                                            # so later optimiser steps would overwrite the snapshot
                                            r_distribute=self.r.detach().cpu().clone())

                if self.viz_on and (self.global_iter%self.viz_la_iter == 0):
                    self.visualize_line()
                    self.line_gather.flush()

                if self.viz_on and (self.global_iter%self.viz_ra_iter == 0):
                    self.image_gather.insert(true=x_true1.detach().cpu(),
                                             recon=torch.sigmoid(x_recon).detach().cpu())
                    self.visualize_recon()
                    self.image_gather.flush()

                if self.viz_on and (self.global_iter%self.viz_ta_iter == 0):
                    if self.dataset.lower() == '3dchairs':
                        self.visualize_traverse(limit=2, inter=0.5)
                    else:
                        self.visualize_traverse(limit=3, inter=2/3)

                if self.global_iter >= self.max_iter:
                    out = True
                    break

        self.pbar.write("[Training Finished]")
        self.pbar.close()
Пример #3
0
    def train(self):
        """Run the FactorVAE training loop and record training/metric history.

        Per iteration:

        * the VAE is updated on ``recon + KL + gamma * TC``;
        * the discriminator ``self.D`` is updated to tell encoded latents from
          dimension-permuted ones.

        Every 100 iterations the current losses are appended to ``metrics``.
        Every 1500 iterations ``self.disentanglement_metric()`` is evaluated and
        appended too. At each ``self.ckpt_save_iter`` checkpoint the model and
        the accumulated metrics are saved and the buffer is cleared.
        """
        self.net_mode(train=True)

        # Discriminator targets: class 0 for encoded latents, class 1 for
        # permuted ones. Sliced to the actual batch size below.
        ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
        zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)
        metrics = []
        out = False
        while not out:
            for x_true1, x_true2 in self.data_loader:
                self.global_iter += 1
                self.pbar.update(1)

                # ---- VAE update ----
                # Each optimiser steps right after its own backward pass.
                # Deferring the steps to the next iteration let
                # ``D_tc_loss.backward()`` accumulate the discriminator's
                # gradient into the encoder before ``optim_VAE.step()``, which
                # partly trained the VAE to help the discriminator.
                x_true1 = x_true1.to(self.device)
                n = x_true1.size(0)
                x_recon, mu, logvar, z = self.VAE(x_true1)
                vae_recon_loss = recon_loss(x_true1, x_recon)
                vae_kld = kl_divergence(mu, logvar)

                D_z = self.D(z)
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                vae_loss = vae_recon_loss + vae_kld + self.gamma*vae_tc_loss

                self.optim_VAE.zero_grad()
                vae_loss.backward()
                self.optim_VAE.step()

                # ---- Discriminator update ----
                x_true2 = x_true2.to(self.device)
                with torch.no_grad():
                    z_pperm = permute_dims(self.VAE(x_true2, no_dec=True))

                # Detached latents: the D loss must not reach the encoder, whose
                # weights were just updated in place by ``optim_VAE.step()``.
                D_z_real = self.D(z.detach())
                D_z_pperm = self.D(z_pperm)
                D_tc_loss = 0.5*(F.cross_entropy(D_z_real, zeros[:n]) + F.cross_entropy(D_z_pperm, ones[:n]))

                self.optim_D.zero_grad()
                D_tc_loss.backward()
                self.optim_D.step()

                # Saving the training metrics
                if self.global_iter % 100 == 0:
                    metrics.append({'its': self.global_iter,
                                    'vae_loss': vae_loss.item(),
                                    'D_loss': D_tc_loss.item(),
                                    'recon_loss': vae_recon_loss.item(),
                                    'tc_loss': vae_tc_loss.item()})

                # Saving the disentanglement metrics results
                if self.global_iter % 1500 == 0:
                    score = self.disentanglement_metric()
                    metrics.append({'its': self.global_iter, 'metric_score': score})
                    self.net_mode(train=True)  # To continue the training again

                if self.global_iter%self.print_iter == 0:
                    self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format(
                        self.global_iter, vae_recon_loss.item(), vae_kld.item(), vae_tc_loss.item(), D_tc_loss.item()))

                if self.global_iter%self.ckpt_save_iter == 0:
                    self.save_checkpoint(str(self.global_iter)+".pth")
                    self.save_metrics(metrics)
                    metrics = []

                if self.global_iter >= self.max_iter:
                    out = True
                    break

        # Persist metrics collected after the last checkpoint, which would
        # otherwise be dropped.
        # NOTE(review): assumes save_metrics appends rather than overwrites — confirm.
        if metrics:
            self.save_metrics(metrics)

        self.pbar.write("[Training Finished]")
        self.pbar.close()
Пример #4
0
    def train(self):
        """Train the VAE with total-correlation and synergy penalties.

        Each step:

        1. Updates the VAE on ``recon + KL + gamma * TC + alpha * syn_loss``.
           The TC term comes from the discriminator ``self.D``. ``syn_loss`` is
           a KL term on ``mu``/``logvar`` with the entries zeroed wherever a
           Bernoulli sample of ``self.D_syn``'s output is 0.
        2. Updates ``self.D_syn`` with ``recon_loss`` against the target
           returned by ``greedy_policy_Smax_discount``.
        3. Updates ``self.D`` to tell encoded latents from dimension-permuted
           ones.

        Runs ``ceil(self.steps / len(self.dataloader))`` epochs. Losses are
        printed every ``self.args.log_interval`` steps and a traversal image is
        written every ``self.args.save_interval`` steps.
        """
        self.net_mode(train=True)

        # Discriminator targets: class 0 for encoded latents, class 1 for
        # permuted ones. Sliced to the actual batch size below.
        ones = torch.ones(self.batch_size,
                          dtype=torch.long,
                          device=self.device)
        zeros = torch.zeros(self.batch_size,
                            dtype=torch.long,
                            device=self.device)

        # Round the epoch count up so at least ``self.steps`` steps run
        # (ceil must apply to the ratio, not to ``steps`` alone).
        epochs = int(np.ceil(self.steps / len(self.dataloader)))
        print("number of epochs {}".format(epochs))

        step = 0

        for _ in range(epochs):
            for x_true1, x_true2 in self.dataloader:
                step += 1

                # VAE
                x_true1 = x_true1.unsqueeze(1).to(self.device)
                n = x_true1.size(0)
                x_recon, mu, logvar, z = self.VAE(x_true1)

                # Reconstruction and KL
                vae_recon_loss = recon_loss(x_true1, x_recon)
                vae_kl = kl_div(mu, logvar)

                # Total Correlation
                D_z = self.D(z)
                tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                # Synergy term. D_syn's output is only used as a 0/1 mask for
                # the VAE loss and in D_syn's own loss, so D_syn is fed detached
                # statistics. Its loss must not back-propagate into the VAE,
                # whose weights are updated in place before that loss's backward.
                best_ai = self.D_syn(mu.detach(), logvar.detach())
                best_ai_labels = torch.bernoulli(best_ai)

                # Build out-of-place copies with the unselected entries zeroed.
                # Writing into ``mu`` in place would corrupt a tensor already
                # used by ``kl_div`` above and break backward.
                # NOTE(review): assumes kl_div gives zero divergence for entries
                # with mu=0, logvar=0, so unselected dimensions add nothing — confirm.
                unselected = best_ai_labels == 0
                mu_syn = mu.masked_fill(unselected, 0)
                logvar_syn = logvar.masked_fill(unselected, 0)

                if len(mu_syn.size()) == 1:
                    syn_loss = kl_div_uni_dim(mu_syn, logvar_syn).mean()
                else:
                    syn_loss = kl_div(mu_syn, logvar_syn)

                # VAE loss
                vae_loss = vae_recon_loss + vae_kl + self.gamma * tc_loss + self.alpha * syn_loss

                # Optimise VAE
                self.optim_VAE.zero_grad()
                vae_loss.backward()
                self.optim_VAE.step()

                # TODO Check the best greedy policy
                # Discriminator Syn: target from the greedy policy, detached so
                # it acts as a constant label.
                real_seq = greedy_policy_Smax_discount(self.z_dim, mu, logvar,
                                                       0.8).detach()
                d_syn_loss = recon_loss(real_seq, best_ai)

                # Optimise Discriminator Syn (only D_syn's parameters receive
                # gradients, since its inputs were detached above)
                self.optim_D_syn.zero_grad()
                d_syn_loss.backward()
                self.optim_D_syn.step()

                # Discriminator TC. The permuted latents are constants for this
                # loss, so no autograd graph is built for this encoder pass.
                x_true2 = x_true2.unsqueeze(1).to(self.device)
                with torch.no_grad():
                    z_prime = self.VAE(x_true2, decode=False)[3]
                    z_perm = permute_dims(z_prime)
                # Score detached latents: back-propagating through ``D_z`` would
                # reach encoder weights just modified in place by the VAE step.
                D_z_real = self.D(z.detach())
                D_z_perm = self.D(z_perm)

                # Discriminator TC loss
                d_loss = 0.5 * (F.cross_entropy(D_z_real, zeros[:n]) +
                                F.cross_entropy(D_z_perm, ones[:n]))

                # Optimise Discriminator TC
                self.optim_D.zero_grad()
                d_loss.backward()
                self.optim_D.step()

                # Logging
                if step % self.args.log_interval == 0:

                    print("Step {}".format(step))
                    print("Recons. Loss = " + "{:.4f}".format(vae_recon_loss))
                    print("KL Loss = " + "{:.4f}".format(vae_kl))
                    print("TC Loss = " + "{:.4f}".format(tc_loss))
                    print("Syn Loss = " + "{:.4f}".format(syn_loss))
                    print("Factor VAE Loss = " + "{:.4f}".format(vae_loss))
                    print("D loss = " + "{:.4f}".format(d_loss))
                    print("best_ai {}".format(best_ai))

                # Saving
                if not step % self.args.save_interval:
                    filename = 'traversal_' + str(step) + '.png'
                    filepath = os.path.join(self.args.output_dir, filename)
                    traverse(self.net_mode, self.VAE, self.test_imgs, filepath)
Пример #5
0
    def train(self):

        self.net_mode(train=True)

        ones = torch.ones(self.batch_size,
                          dtype=torch.long,
                          device=self.device)
        zeros = torch.zeros(self.batch_size,
                            dtype=torch.long,
                            device=self.device)

        epochs = int(np.ceil(self.steps) / len(self.dataloader))
        print("number of epochs {}".format(epochs))

        step = 0
        # dict of init opt weights
        #dict_init = {a: defaultdict(list) for a in range(10)}
        # dict of VAE opt weights
        #dict_VAE = {a:defaultdict(list) for a in range(10)}

        weights_names = [
            'encoder.2.weight', 'encoder.10.weight', 'decoder.0.weight',
            'decoder.7.weight', 'net.4.weight'
        ]

        dict_VAE = defaultdict(list)
        dict_weight = {a: [] for a in weights_names}

        for e in range(epochs):
            #for e in range():

            for x_true1, x_true2 in self.dataloader:

                #if step == 1: break

                step += 1
                """

                # TRACKING OF GRADS
                print("GRADS")
                for name, params in self.VAE.named_parameters():

                    if name == 'encoder.2.weight':
                        #size : 32,32,4,4
                        print("Grads: Before VAE optim step {}".format(step))
                        #if params.grad != None:
                        if step != 1:
                            if np.array_equal(dict_VAE[name], params.grad.numpy()) == False :
                            #if dict_VAE[name] != tuple(params.grad.numpy()):
                                print("Change in gradients {}".format(name))
                                #dict_init[step][name] = params.grad.numpy()
                                dict_VAE[name] = params.grad.numpy().copy()
                            else:
                                print("No change in gradients {}".format(name))
                            #print("name {}, params grad {}".format(name, params.grad[0, 0, :, :]))

                        else:
                            print("name {}, params grad {}".format(name, params.grad))
                            #dict_init[step][name] = None
                            dict_VAE[name] = None

                    if name == 'encoder.10.weight':
                        #size : 32,32,4,4
                        #print("Before VAE optim  encoder step {}".format(step))
                        #if params.grad != None:
                        if step != 1:
                            if np.array_equal(dict_VAE[name], params.grad.numpy()) == False :
                            #if dict_VAE[name] != tuple(params.grad.numpy()):
                                print("Change in gradients {}".format(name))
                                #dict_init[step][name] = params.grad.numpy()
                                dict_VAE[name] = params.grad.numpy().copy()
                            else:
                                print("No change in gradients {}".format(name))
                            #print("name {}, params grad {}".format(name, params.grad[0, 0, :, :]))

                        else:
                            print("name {}, params grad {}".format(name, params.grad))
                            #dict_init[step][name] = None
                            dict_VAE[name] = None

                    if name == 'decoder.0.weight':

                        #print("Before VAE optim  decoder step {}".format(step))
                        #if params.grad != None:
                        if step != 1:

                            if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            #if dict_VAE[name] != tuple(params.grad.numpy()):
                                print("Change in gradients {}".format(name))
                                dict_VAE[name] = params.grad.numpy().copy()
                            else:
                                print("No change in gradients {}".format(name))
                            #print("name {}, params grad {}".format(name, params.grad[:5, :2]))
                        else:
                            print("name {}, params grad {}".format(name, params.grad))
                            #dict_init[step][name] = None
                            dict_VAE[name] = None

                    if name == 'decoder.7.weight':

                        #print("Before VAE optim  decoder step {}".format(step))
                        #if params.grad != None:
                        if step != 1:

                            if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            #if dict_VAE[name] != tuple(params.grad.numpy()):
                                print("Change in gradients {}".format(name))
                                dict_VAE[name] = params.grad.numpy().copy()
                            else:
                                print("No change in gradients {}".format(name))
                            #print("name {}, params grad {}".format(name, params.grad[1, 1, :, :]))
                        else:
                            print("name {}, params grad {}".format(name, params.grad))
                            #dict_init[step][name] = None
                            dict_VAE[name] = None

                for name, params in self.D.named_parameters():

                    if name == 'net.4.weight':

                        #print("Before VAE optim  discrim step {}".format(step))
                        #if params.grad != None:
                        if step != 1:

                            if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            #if dict_VAE[name] != tuple(params.grad.numpy()):
                                print("Change in gradients {}".format(name))
                                dict_VAE[name] = params.grad.numpy().copy()
                            else:
                                print("No change in gradients {}".format(name))
                            #print("name {}, params grad {}".format(name, params.grad[1, 1, :, :]))
                        else:
                            print("name {}, params grad {}".format(name, params.grad))
                            #dict_init[step][name] = None
                            dict_VAE[name] = None
                        print()

                """

                # VAE
                x_true1 = x_true1.unsqueeze(1).to(self.device)
                #print("x_true1 size {}".format(x_true1.size()))

                x_recon, mu, logvar, z = self.VAE(x_true1)

                # Reconstruction and KL
                vae_recon_loss = recon_loss(x_true1, x_recon)
                #print("vae recon loss {}".format(vae_recon_loss))
                vae_kl = kl_div(mu, logvar)
                #print("vae kl loss {}".format(vae_kl))

                # Total Correlation
                D_z = self.D(z)
                #print("D_z size {}".format(D_z.size()))
                tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()
                #print("tc loss {}".format(tc_loss))

                # VAE loss
                vae_loss = vae_recon_loss + vae_kl + self.gamma * tc_loss
                #print("Total VAE loss {}".format(vae_loss))

                #print("Weights: Before VAE, step {}".format(step))
                #print("encoder.2.weight {}".format(self.optim_VAE.param_groups[0]['params'][2][0, 0, :, :]))
                #print("net.4.weight {}".format(self.optim_D.param_groups[0]['params'][4][:5, :5]))

                # Optimise VAE
                self.optim_VAE.zero_grad()  #zero gradients the buffer, grads
                """
                print("after zero grad step {}".format(step))
                #print("encoder.2.weight {}".format(self.optim_VAE.param_groups[0]['params'][2][0, 0, :, :]))
                #print("encoder.10.weight {}".format(self.optim_VAE.param_groups[0]['params'][10][:, :2]))
                #print("net.4.weight {}".format(self.optim_D.param_groups[0]['params'][4][:5, :5]))

                # check if the VAE is optimizing the encoder and decoder
                for name, params in self.VAE.named_parameters():

                    if name == 'encoder.2.weight':
                        # size : 32,32,4,4
                        if step == 1:
                            print("name {}, params grad {}".format(name, params.grad))
                        #else:
                            #print("name {}, params grad {}".format(name, params.grad[0, 0, :, :]))

                    if name == 'encoder.10.weight':
                        # size : 32,32,4,4
                        if step == 1:
                            print("name {}, params grad {}".format(name, params.grad))
                        #else:
                            #print("name {}, params grad {}".format(name, params.grad[:, :2]))

                for name, params in self.D.named_parameters():

                    if name == 'net.4.weight':
                        # size : 32,32,4,4
                        if step == 1:

                            print("name {}, params grad {}".format(name, params.grad))
                        #else:

                            #print("name {}, params grad {}".format(name, params.grad[:5, :5]))

                """

                vae_loss.backward(
                    retain_graph=True)  # grad parameters are populated
                """
                print()
                print("after backward step {}".format(step))
                #print("encoder.2.weight {}".format(self.optim_VAE.param_groups[0]['params'][2][0, 0, :, :]))
                #print("encoder.10.weight {}".format(self.optim_VAE.param_groups[0]['params'][10][:, :2]))
                #print("net.4.weight {}".format(self.optim_D.param_groups[0]['params'][4][:5, :5]))
                # check if the VAE is optimizing the encoder and decoder
                for name, params in self.VAE.named_parameters():
                    if name == 'encoder.2.weight':
                        # size : 32,32,4,4
                        #print("name {}, params grad {}".format(name, params.grad[0, 0, :, :]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))

                    if name == 'encoder.10.weight':
                        # size : 32,32,4,4
                        #print("name {}, params grad {}".format(name, params.grad[:, :2]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))

                for name, params in self.D.named_parameters():
                    if name == 'net.4.weight':
                        # size : 1000,1000
                        #print("name {}, params grad {}".format(name, params.grad[:5, :5]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))"""

                self.optim_VAE.step()  #Does the step
                #print()
                #print("after VAE update step {}".format(step))
                #print("encoder.2.weight size {}".format(self.optim_VAE.param_groups[0]['params'][2].size()))
                #print("encoder.2.weight {}".format(self.optim_VAE.param_groups[0]['params'][2][0,0,:,:]))
                #print("encoder.10.weight {}".format(self.optim_VAE.param_groups[0]['params'][10][:, :2]))
                #print("net.4.weight {}".format(self.optim_D.param_groups[0]['params'][4][:5, :5]))
                """

                # check if the VAE is optimizing the encoder and decoder
                for name, params in self.VAE.named_parameters():
                    if name == 'encoder.2.weight':
                        #size : 32,32,4,4
                        print("After VAE optim step {}".format(step))
                        #print("name {}, params grad {}".format(name, params.grad[0, 0, :, :]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                        #if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))


                    if name == 'encoder.10.weight':
                        #size : 20, 128
                        #print("size of {}: {}".format(name, params.grad.size()))
                        #print("After VAE optim  encoder step {}".format(step))
                        #print("name {}, params grad {}".format(name, params.grad[:, :2]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                        #if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))


                    if name == 'decoder.0.weight':
                        #128,10
                        #print("After VAE optim  decoder linear step {}".format(step))
                        #print("size of {}: {}".format(name, params.grad.size()))
                        #print("name {}, params grad {}".format(name, params.grad[:3, :]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                        #if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))


                    if name == 'decoder.7.weight':
                        #print("After VAE optim  decoder step {}".format(step))
                        #print("name {}, params grad {}".format(name, params.grad[1, 1, :, :]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                        #if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))


                for name, params in self.D.named_parameters():

                    if name == 'net.4.weight':
                        #print("After VAE optim  discriminator step {}".format(step))
                        #print("name {}, params grad {}".format(name, params.grad[:5,:5]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                        #if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))
                        print()

                """

                #print()
                #print("Before Syn step {}".format(step))
                #print("encoder.2.weight size {}".format(self.optim_VAE.param_groups[0]['params'][2].size()))
                #print("encoder.2.weight {}".format(self.optim_VAE.param_groups[0]['params'][2][0, 0, :, :]))
                #print("encoder.10.weight {}".format(self.optim_VAE.param_groups[0]['params'][10][:, :2]))
                #print("net.4.weight {}".format(self.optim_D.param_groups[0]['params'][4][:5, :5]))

                ##################
                #Synergy Max

                # Step 1: compute the argmax of D kl (q(ai | x(i)) || )
                best_ai = greedy_policy_Smax_discount(self.z_dim,
                                                      mu,
                                                      logvar,
                                                      alpha=self.omega)

                # Step 2: compute the Imax
                mu_syn = mu[:, best_ai]
                logvar_syn = logvar[:, best_ai]

                if len(mu_syn.size()) == 1:
                    I_max = kl_div_uni_dim(mu_syn, logvar_syn).mean()
                    # print("here")
                else:
                    I_max = kl_div(mu_syn, logvar_syn)

                #I_max1 = I_max_batch(best_ai, mu, logvar)
                #print("I_max step{}".format(I_max, step))

                # Step 3: Use it in the loss
                syn_loss = self.alpha * I_max
                #print("syn_loss step {}".format(syn_loss, step))

                # Step 4: Optimise Syn term
                self.optim_VAE.zero_grad(
                )  # set zeros all the gradients of VAE network
                """
                #print()
                print("after zeros Syn step {}".format(step))
                #print("encoder.2.weight {}".format(self.optim_VAE.param_groups[0]['params'][2][0, 0, :, :]))
                #print("encoder.10.weight {}".format(self.optim_VAE.param_groups[0]['params'][10][:, :2]))
                #print("net.4.weight {}".format(self.optim_D.param_groups[0]['params'][4][:5, :5]))

                for name, params in self.VAE.named_parameters():
                    if name == 'encoder.2.weight':
                        # size : 32,32,4,4
                        #print("name {}, params grad {}".format(name, params.grad[0, 0, :, :]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))

                    if name == 'encoder.10.weight':
                        # size : 32,32,4,4
                        #print("name {}, params grad {}".format(name, params.grad[:, :2]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))

                    if name == 'decoder.0.weight':
                        # size : 32,32,4,4
                        #print("name {}, params grad {}".format(name, params.grad[:, :2]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))

                    if name == 'decoder.7.weight':
                        # size : 32,32,4,4
                        #print("name {}, params grad {}".format(name, params.grad[:, :2]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))


                for name, params in self.D.named_parameters():
                    if name == 'net.4.weight':
                        # size : 1000,1000
                        #print("name {}, params grad {}".format(name, params.grad[:5, :5]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))

                """

                syn_loss.backward(retain_graph=True)  #backprop the gradients
                """

                print()
                print("after Syn backward step {}".format(step))
                #print("encoder.2.weight {}".format(self.optim_VAE.param_groups[0]['params'][2][0, 0, :, :]))
                #print("encoder.10.weight {}".format(self.optim_VAE.param_groups[0]['params'][10][:, :2]))
                #print("net.4.weight {}".format(self.optim_D.param_groups[0]['params'][4][:5, :5]))

                for name, params in self.VAE.named_parameters():
                    if name == 'encoder.2.weight':
                        # size : 32,32,4,4
                        #print("name {}, params grad {}".format(name, params.grad[0, 0, :, :]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))

                    if name == 'encoder.10.weight':
                        # size : 32,32,4,4
                        #print("name {}, params grad {}".format(name, params.grad[:, :2]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))

                    if name == 'decoder.0.weight':
                        # size : 32,32,4,4
                        #print("name {}, params grad {}".format(name, params.grad[:, :2]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))

                    if name == 'decoder.7.weight':
                        # size : 32,32,4,4
                        #print("name {}, params grad {}".format(name, params.grad[:, :2]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))



                for name, params in self.D.named_parameters():
                    if name == 'net.4.weight':
                        # size : 1000,1000
                        #print("name {}, params grad {}".format(name, params.grad[:5, :5]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))

                """

                self.optim_VAE.step(
                )  #Does the update in VAE network parameters

                #print()
                #print("after Syn update step {}".format(step))
                #print("encoder.2.weight {}".format(self.optim_VAE.param_groups[0]['params'][2][0, 0, :, :]))
                #print("encoder.10.weight {}".format(self.optim_VAE.param_groups[0]['params'][10][:, :2]))
                #print("net.4.weight {}".format(self.optim_D.param_groups[0]['params'][4][:5, :5]))

                ###################
                """
                # check if the VAE is optimizing the encoder and decoder
                for name, params in self.VAE.named_parameters():
                    if name == 'encoder.2.weight':
                        # size : 32,32,4,4
                        print("After Syn optim step {}".format(step))
                        # print("name {}, params grad {}".format(name, params.grad[0, 0, :, :]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))


                    if name == 'encoder.10.weight':
                        # size :
                        #print("After Syn optim  encoder step {}".format(step))
                        #print("name {}, params grad {}".format(name, params.grad[:, :]))
                        #print("name {}, params grad {}".format(name, params.grad[:, :2]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))

                            dim_changes = []
                            for dim in range(20):
                                if np.array_equal(dict_VAE[name][dim, :2], params.grad.numpy()[dim, :2]) == False:
                                    dim_changes.append(dim)
                            print("Changes in dimensions: {}".format(dim_changes))

                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))


                    if name == 'decoder.0.weight':
                        # 1024, 128
                        #print("After Syn optim  decoder linear step {}".format(step))
                        #print("name {}, params grad {}".format(name, params.grad[:5, :2]))
                        #print("name {}, params grad {}".format(name, params.grad[:3, :]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))


                    if name == 'decoder.7.weight':
                        #print("After Syn optim  decoder step {}".format(step))
                        # print("name {}, params grad {}".format(name, params.grad[1, 1, :, :]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))


                for name, params in self.D.named_parameters():

                    if name == 'net.4.weight':
                        #print("After Syn optim  discriminator step {}".format(step))
                        #print("name {}, params grad {}".format(name, params.grad[:5,:5]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))
                        print()
                """
                # Discriminator
                x_true2 = x_true2.unsqueeze(1).to(self.device)
                z_prime = self.VAE(x_true2, decode=False)[3]
                z_perm = permute_dims(z_prime).detach(
                )  ## detaches the output from the graph. no gradient will be backproped along this variable.
                D_z_perm = self.D(z_perm)

                # Discriminator loss
                d_loss = 0.5 * (F.cross_entropy(D_z, zeros) +
                                F.cross_entropy(D_z_perm, ones))
                #print("d_loss {}".format(d_loss))

                #print("dict VAE {}".format(dict_VAE['encoder.2.weight'][0, 0, :, :]))

                #print("before Disc, step {}".format(step))
                #print("encoder.2.weight {}".format(self.optim_VAE.param_groups[0]['params'][2][0, 0, :, :]))
                #print("net.4.weight {}".format(self.optim_D.param_groups[0]['params'][4][:5, :5]))

                # Optimise Discriminator
                self.optim_D.zero_grad()
                """

                print("after zero grad Disc step {}".format(step))
                #print("encoder.2.weight {}".format(self.optim_VAE.param_groups[0]['params'][2][0, 0, :, :]))
                #print("net.4.weight {}".format(self.optim_D.param_groups[0]['params'][4][:5, :5]))

                for name, params in self.VAE.named_parameters():
                    if name == 'encoder.2.weight':
                        # size : 32,32,4,4
                        #print("name {}, params grad {}".format(name, params.grad[0, 0, :, :]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))

                    if name == 'encoder.10.weight':
                        # size : 32,32,4,4
                        #print("name {}, params grad {}".format(name, params.grad[0, 0, :, :]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))

                    if name == 'decoder.0.weight':
                        # size : 32,32,4,4
                        #print("name {}, params grad {}".format(name, params.grad[0, 0, :, :]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))

                    if name == 'decoder.7.weight':
                        # size : 32,32,4,4
                        #print("name {}, params grad {}".format(name, params.grad[0, 0, :, :]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))

                for name, params in self.D.named_parameters():
                    if name == 'net.4.weight':
                        # size : 1000,1000
                        #print("name {}, params grad {}".format(name, params.grad[:5, :5]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))

                """

                d_loss.backward()
                """
                print()
                print("after backward Disc step {}".format(step))
                #print("encoder.2.weight {}".format(self.optim_VAE.param_groups[0]['params'][2][0, 0, :, :]))
                #print("net.4.weight {}".format(self.optim_D.param_groups[0]['params'][4][:5, :5]))
                # check if the VAE is optimizing the encoder and decoder
                for name, params in self.VAE.named_parameters():
                    if name == 'encoder.2.weight':
                        # size : 32,32,4,4
                        #print("name {}, params grad {}".format(name, params.grad[0, 0, :, :]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))

                    if name == 'encoder.10.weight':
                        # size : 32,32,4,4
                        #print("name {}, params grad {}".format(name, params.grad[0, 0, :, :]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))

                    if name == 'decoder.0.weight':
                        # size : 32,32,4,4
                        #print("name {}, params grad {}".format(name, params.grad[0, 0, :, :]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))

                    if name == 'decoder.7.weight':
                        # size : 32,32,4,4
                        #print("name {}, params grad {}".format(name, params.grad[0, 0, :, :]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))

                for name, params in self.D.named_parameters():
                    if name == 'net.4.weight':
                        # size : 1000,1000
                        #print("name {}, params grad {}".format(name, params.grad[:5, :5]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))
                """

                self.optim_D.step()

                #print("dict VAE {}".format(dict_VAE['encoder.2.weight'][0, 0, :, :]))
                """
                print()
                print("after update disc step {}".format(step))
                #print("encoder.2.weight size {}".format(self.optim_VAE.param_groups[0]['params'][2].size()))
                #print("encoder.2.weight {}".format(self.optim_VAE.param_groups[0]['params'][2][0, 0, :, :]))
                #print("net.4.weight {}".format(self.optim_D.param_groups[0]['params'][4][:5, :5]))

                for name, params in self.VAE.named_parameters():

                    if name == 'encoder.2.weight':
                        #size : 32,32,4,4
                        print("After Discriminator optim  encoder step {}".format(step))
                        #print("name {}, params grad {}".format(name, params.grad[0, 0, :, :]))
                        #print("dict VAE {}".format(dict_VAE[name][0, 0, :, :]))
                        #if np.isclose(dict_VAE[name], params.grad.numpy(), rtol=1e-05, atol=1e-08, equal_nan=False): "Works"
                        #if np.all(abs(dict_VAE[step][name] - params.grad.numpy())) < 1e-7 == False:
                        #if dict_VAE[name] != tuple(params.grad.numpy()):
                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))
                        #print("dict VAE {}".format(dict_VAE[name][0, 0, :, :]))

                    if name == 'encoder.10.weight':
                        # size :
                        #print("After Syn optim  encoder step {}".format(step))
                        # print("name {}, params grad {}".format(name, params.grad[0, 0, :, :]))
                        #print("name {}, params grad {}".format(name, params.grad[:, :2]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                            # if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()


                        else:
                            print("No change in gradients {}".format(name))


                    if name == 'decoder.0.weight':
                        #1024, 128
                        #print("After Discriminator optim  decoder linear step {}".format(step))
                        #print("name {}, params grad {}".format(name, params.grad[:5, :2]))
                        #print("name {}, params grad {}".format(name, params.grad[:3, :]))


                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                        #if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))


                    if name == 'decoder.7.weight':
                        #print("After Discriminator optim  decoder step {}".format(step))
                        #print("name {}, params grad {}".format(name, params.grad[1, 1, :, :]))
                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                        #if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))


                for name, params in self.D.named_parameters():

                    if name == 'net.4.weight':
                        #print("After Discriminator optim  decoder step {}".format(step))
                        #print("name {}, params grad {}".format(name, params.grad[:5, :5]))

                        if np.array_equal(dict_VAE[name], params.grad.numpy()) == False:
                        #if dict_VAE[name] != tuple(params.grad.numpy()):
                            print("Change in gradients {}".format(name))
                            dict_VAE[name] = params.grad.numpy().copy()
                        else:
                            print("No change in gradients {}".format(name))
                        print()"""

                # Logging
                if step % self.args.log_interval == 0:

                    print("Step {}".format(step))
                    print("Recons. Loss = " + "{:.4f}".format(vae_recon_loss))
                    print("KL Loss = " + "{:.4f}".format(vae_kl))
                    print("TC Loss = " + "{:.4f}".format(tc_loss))
                    print("Factor VAE Loss = " + "{:.4f}".format(vae_loss))
                    print("D loss = " + "{:.4f}".format(d_loss))
                    print("best_ai {}".format(best_ai))
                    print("I_max {}".format(I_max))
                    print("Syn loss {:.4f}".format(syn_loss))

                # Saving
                if not step % self.args.save_interval:
                    filename = 'alpha_' + str(
                        self.alpha) + '_traversal_' + str(step) + '.png'
                    filepath = os.path.join(self.args.output_dir, filename)
                    traverse(self.net_mode, self.VAE, self.test_imgs, filepath)
Пример #6
0
    def train(self):
        """Train the FactorVAE: a VAE plus a total-correlation discriminator ``self.D``.

        Runs ``ceil(self.steps / len(self.dataloader))`` full passes over
        ``self.dataloader``. Each iteration yields two batches
        ``(x_true1, x_true2)``:

        * ``x_true1`` drives the VAE update: reconstruction loss + KL +
          ``self.gamma`` * the TC estimate read off ``self.D``.
        * ``x_true2`` is encoded and its latents are dimension-permuted
          (``permute_dims``) to give the discriminator's second class.

        Every ``self.args.log_interval`` steps the losses are printed. Every
        ``self.args.save_interval`` steps a latent traversal image is saved to
        ``self.args.output_dir``. Every ``self.args.gt_interval`` steps a
        ground-truth vs. predicted plot is saved there as well.
        """
        self.net_mode(train=True)

        # Discriminator targets: label 0 for latents of x_true1, label 1 for permuted latents.
        ones = torch.ones(self.batch_size,
                          dtype=torch.long,
                          device=self.device)
        zeros = torch.zeros(self.batch_size,
                            dtype=torch.long,
                            device=self.device)

        # Round the number of epochs *up* so that at least ``self.steps``
        # optimisation steps are performed. The ceil must wrap the quotient;
        # ceil-ing the integer step count and then truncating the division
        # loses up to one epoch.
        epochs = int(np.ceil(self.steps / len(self.dataloader)))
        print("number of epochs {}".format(epochs))

        step = 0

        for _epoch in range(epochs):
            for x_true1, x_true2 in self.dataloader:
                step += 1

                # VAE
                x_true1 = x_true1.unsqueeze(1).to(self.device)
                x_recon, mu, logvar, z = self.VAE(x_true1)

                # Reconstruction and KL
                vae_recon_loss = recon_loss(x_true1, x_recon)
                vae_kl = kl_div(mu, logvar)

                # Total Correlation
                D_z = self.D(z)
                tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                # VAE loss
                vae_loss = vae_recon_loss + vae_kl + self.gamma * tc_loss

                # Optimise VAE. No later backward pass reuses this graph, so it
                # does not need to be retained.
                self.optim_VAE.zero_grad()
                vae_loss.backward()
                self.optim_VAE.step()

                # Discriminator
                x_true2 = x_true2.unsqueeze(1).to(self.device)
                z_prime = self.VAE(x_true2, decode=False)[3]
                # Detach so that no gradient is backpropagated into the VAE.
                z_perm = permute_dims(z_prime).detach()
                D_z_perm = self.D(z_perm)

                # Re-score the x_true1 latents on a detached copy instead of
                # reusing ``D_z``. The graph behind ``D_z`` goes through VAE
                # weights that optim_VAE.step() has just updated in place, so
                # backpropagating d_loss through it trips autograd's in-place
                # modification check on recent PyTorch. At best it would only
                # push gradients into the VAE that are discarded.
                # NOTE(review): assumes self.D has no stochastic/stateful layers
                # (dropout, batch-norm) so this equals D_z numerically - confirm.
                D_z_detached = self.D(z.detach())

                # Discriminator loss
                d_loss = 0.5 * (F.cross_entropy(D_z_detached, zeros) +
                                F.cross_entropy(D_z_perm, ones))

                # Optimise Discriminator
                self.optim_D.zero_grad()
                d_loss.backward()
                self.optim_D.step()

                # Logging
                if step % self.args.log_interval == 0:

                    print("Step {}".format(step))
                    print("Recons. Loss = " + "{:.4f}".format(vae_recon_loss))
                    print("KL Loss = " + "{:.4f}".format(vae_kl))
                    print("TC Loss = " + "{:.4f}".format(tc_loss))
                    print("Factor VAE Loss = " + "{:.4f}".format(vae_loss))
                    print("D loss = " + "{:.4f}".format(d_loss))

                # Saving
                if not step % self.args.save_interval:
                    filename = 'traversal_' + str(step) + '.png'
                    filepath = os.path.join(self.args.output_dir, filename)
                    traverse(self.net_mode, self.VAE, self.test_imgs, filepath)

                # Saving plot gt vs predicted
                if not step % self.args.gt_interval:
                    filename = 'gt_' + str(step) + '.png'
                    filepath = os.path.join(self.args.output_dir, filename)
                    plot_gt_shapes(self.net_mode, self.VAE, self.dataloader_gt,
                                   filepath)
Пример #7
0
    def train(self):
        """Train FactorVAE with an extra Grad-CAM attention-disentanglement term.

        Loops over ``self.data_loader`` until ``self.global_iter`` reaches
        ``self.max_iter``. Each iteration:

        1. Encodes/decodes ``x1`` through a ``GradCamDissen`` wrapper.
        2. Builds the FactorVAE loss: reconstruction + KL + ``self.gamma`` * TC
           estimate from ``self.D``.
        3. Adds ``self.lambdaa`` * the averaged attention-disentanglement loss
           over the attention-map pairs chosen by ``self.select_attention_maps``.
        4. Updates the VAE.
        5. Updates the discriminator on the x1 latents vs. the permuted
           latents of ``x2``.

        Side effects:

        * A training-metrics dict is appended every 100 iterations.
        * A disentanglement score is appended every 1500 iterations.
        * Progress is written to ``self.pbar`` every ``self.print_iter``
          iterations.
        * A checkpoint and the accumulated metrics are saved every
          ``self.ckpt_save_iter`` iterations.
        """
        gcam = GradCamDissen(self.VAE,
                             self.D,
                             target_layer='encode.1',
                             cuda=True)
        self.net_mode(train=True)

        # Discriminator targets: label 0 for latents of x1, label 1 for permuted latents.
        ones = torch.ones(self.batch_size,
                          dtype=torch.long,
                          device=self.device)
        zeros = torch.zeros(self.batch_size,
                            dtype=torch.long,
                            device=self.device)
        mu_avg, logvar_avg = 0, 1
        metrics = []
        out = False
        while not out:
            for x1, x2 in self.data_loader:
                self.global_iter += 1
                self.pbar.update(1)

                x1 = x1.to(self.device)
                x1_rec, mu, logvar, z = gcam.forward(x1)
                # For Standard FactorVAE loss
                vae_recon_loss = recon_loss(x1, x1_rec)
                vae_kld = kl_divergence(mu, logvar)
                D_z = self.D(z)
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                factorVae_loss = vae_recon_loss + vae_kld + self.gamma * vae_tc_loss
                # For attention disentanglement loss
                gcam.backward(mu, logvar, mu_avg, logvar_avg)
                att_loss = 0
                # NOTE(review): att_loss is accumulated under torch.no_grad(), so
                # it carries no autograd history. Adding it to vae_loss changes
                # the logged value but not the gradients. Confirm this is intended.
                with torch.no_grad():
                    gcam_maps = gcam.generate()
                    selected = self.select_attention_maps(gcam_maps)
                    for (sel1, sel2) in selected:
                        att_loss += attention_disentanglement(sel1, sel2)
                # Average across all pairs of maps; with no pairs att_loss stays 0
                # instead of raising ZeroDivisionError.
                n_pairs = len(selected)
                if n_pairs > 0:
                    att_loss /= n_pairs

                vae_loss = factorVae_loss + self.lambdaa * att_loss
                self.optim_VAE.zero_grad()
                vae_loss.backward()
                # Step immediately. The discriminator loss below no longer
                # backpropagates through the VAE, so this in-place update cannot
                # invalidate a graph that is still needed. The VAE's gradients
                # also no longer absorb the discriminator loss, which the old
                # deferred step at the top of the next iteration allowed.
                self.optim_VAE.step()

                x2 = x2.to(self.device)
                z_prime = self.VAE(x2, no_dec=True)
                z_pperm = permute_dims(z_prime).detach()
                D_z_pperm = self.D(z_pperm)
                # Re-score the x1 latents on a detached copy so that D_tc_loss
                # only trains the discriminator.
                # NOTE(review): assumes self.D has no stochastic/stateful layers
                # (dropout, batch-norm) so this equals D_z numerically - confirm.
                D_z_detached = self.D(z.detach())
                D_tc_loss = 0.5 * (F.cross_entropy(D_z_detached, zeros) +
                                   F.cross_entropy(D_z_pperm, ones))

                self.optim_D.zero_grad()
                D_tc_loss.backward()
                self.optim_D.step()

                # Saving the training metrics
                if self.global_iter % 100 == 0:
                    metrics.append({
                        'its': self.global_iter,
                        'vae_loss': vae_loss.item(),
                        'D_loss': D_tc_loss.item(),
                        'recon_loss': vae_recon_loss.item(),
                        'tc_loss': vae_tc_loss.item()
                    })

                # Saving the disentanglement metrics results
                if self.global_iter % 1500 == 0:
                    score = self.disentanglement_metric()
                    metrics.append({
                        'its': self.global_iter,
                        'metric_score': score
                    })
                    self.net_mode(train=True)  # To continue the training again

                if self.global_iter % self.print_iter == 0:
                    self.pbar.write(
                        '[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'
                        .format(self.global_iter, vae_recon_loss.item(),
                                vae_kld.item(), vae_tc_loss.item(),
                                D_tc_loss.item()))

                if self.global_iter % self.ckpt_save_iter == 0:
                    self.save_checkpoint(str(self.global_iter) + ".pth")
                    self.save_metrics(metrics)
                    metrics = []

                if self.global_iter >= self.max_iter:
                    out = True
                    break

        self.pbar.write("[Training Finished]")
        self.pbar.close()
Пример #8
0
def _save_inference_grid(vae, args, device, output_dir, epoch, iteration,
                         num_samples=10):
    """Decode random (latent, label) pairs and save them as a figure grid.

    Draws ``num_samples`` label vectors with entries in {0, 1} and the same
    number of latent vectors from ``torch.rand``, concatenates them, decodes
    with ``vae.decode`` and writes a 2-column grid annotated with each label
    vector to ``output_dir/E{epoch}||{iteration}.png``.
    """
    # One random binary label vector per sample.
    c = torch.randint(low=0, high=2, size=(num_samples, args.num_labels),
                      device=device)
    # NOTE(review): latents are uniform on [0, 1) as in the original; a
    # standard-normal prior would use torch.randn instead -- confirm intent.
    z_inf = torch.rand(num_samples, args.z_dim, device=device)
    z_inf = torch.cat((z_inf, c.to(z_inf.dtype)), dim=1)
    z_inf = z_inf.reshape(-1, args.num_labels + args.z_dim, 1, 1)

    # NOTE(review): vae is still in training mode here; if it contains
    # BatchNorm layers this decode also updates their running statistics.
    with torch.no_grad():  # visualisation only: no autograd graph needed
        samples = vae.decode(z_inf)

    fig = plt.figure(figsize=(10, 20))
    n_rows = (num_samples + 1) // 2
    for idx in range(num_samples):
        plt.subplot(n_rows, 2, idx + 1)  # row, col, index starting from 1
        plt.text(0,
                 0,
                 "c={}".format(c[idx]),
                 color='black',
                 backgroundcolor='white',
                 fontsize=10)
        # (3, H, W) -> (H, W, 3) for imshow.
        image = samples[idx].view(3, 218, 178).permute(1, 2, 0)
        # Clamp before scaling so out-of-range values cannot wrap around in
        # the uint8 cast.
        # NOTE(review): if vae.decode returns logits rather than values in
        # [0, 1], apply torch.sigmoid first -- confirm.
        pixels = (image.clamp(0, 1).cpu().numpy() * 255).astype(np.uint8)
        plt.imshow(pixels)
        plt.axis('off')

    plt.savefig(os.path.join(output_dir,
                             "E{:d}||{:d}.png".format(epoch, iteration)),
                dpi=300)
    plt.close(fig)


def main(args):
    """Train ``CelebaFactorVAE`` against a FactorVAE discriminator on CelebA.

    For each full batch, the VAE is updated with the
    ``args.recon_weight``-scaled reconstruction loss plus the
    ``args.gamma``-scaled total-correlation estimate from ``D``. The KL term
    is computed and printed but, as before, not added to the VAE loss.
    ``D`` is trained to classify latents of ``x`` as class 0 and
    ``permute_dims`` of the latents of ``x2`` as class 1.

    Both learning rates are divided by 10 every 8 epochs. Sample grids are
    written every ``args.output_iter`` iterations, and checkpoints are saved
    every ``args.ckpt_iter_epoch`` epochs.

    Args:
        args: parsed options. Reads seed, epochs, batch_size, lr_vae, lr_D,
            z_dim, num_labels, ckpt_dir, name, ckpt_load, output_dir,
            recon_weight, gamma, print_iter, output_iter and
            ckpt_iter_epoch.
    """
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    pbar = tqdm(total=args.epochs)

    dataset = get_celeba_selected_dataset()
    data_loader = DataLoader(dataset=dataset,
                             batch_size=args.batch_size,
                             shuffle=True)

    lr_vae = args.lr_vae
    lr_D = args.lr_D
    vae = CelebaFactorVAE(args.z_dim, args.num_labels).to(device)
    optim_vae = torch.optim.Adam(vae.parameters(), lr=lr_vae)

    D = Discriminator(args.z_dim, args.num_labels).to(device)
    optim_D = torch.optim.Adam(D.parameters(), lr=lr_D, betas=(0.5, 0.9))

    # Checkpoint
    ckpt_dir = os.path.join(args.ckpt_dir, args.name)
    mkdirs(ckpt_dir)
    # NOTE(review): load_checkpoint's return value is ignored, so a resumed
    # run restarts its epoch counter at 0 -- confirm whether the saved epoch
    # is available from load_checkpoint.
    start_epoch = 0
    if args.ckpt_load:
        load_checkpoint(pbar, ckpt_dir, D, vae, optim_D, optim_vae, lr_vae,
                        lr_D)
        print("confirming lr after loading checkpoint: ",
              optim_vae.param_groups[0]['lr'])

    # Output
    output_dir = os.path.join(args.output_dir, args.name)
    mkdirs(output_dir)

    # Discriminator targets: class 0 = latents of x, class 1 = permuted ones.
    ones = torch.ones(args.batch_size, dtype=torch.long, device=device)
    zeros = torch.zeros(args.batch_size, dtype=torch.long, device=device)

    for epoch in range(start_epoch, args.epochs):
        pbar.update(1)

        for iteration, (x, y, x2, y2) in enumerate(data_loader):
            # `ones`/`zeros` are sized for full batches. Skip a short batch
            # before spending a forward pass on it.
            if x.size(0) != args.batch_size:
                print("passed a batch in epoch {}, iteration {}!".format(
                    epoch, iteration))
                continue

            x, y = x.to(device), y.to(device)
            x2, y2 = x2.to(device), y2.to(device)

            recon_x, mean, log_var, z = vae(x, y)
            D_z = D(z)

            vae_recon_loss = recon_loss(x, recon_x) * args.recon_weight
            vae_kld = kl_divergence(mean, log_var)  # printed only
            vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean() * args.gamma
            vae_loss = vae_recon_loss + vae_tc_loss

            optim_vae.zero_grad()
            vae_loss.backward()

            z_prime = vae(x2, y2, no_dec=True)
            z_pperm = permute_dims(z_prime).detach()
            # Re-score z with its graph cut so the discriminator loss trains
            # only D. Back-propagating it through D_z would also accumulate
            # into the VAE's gradients before optim_vae.step().
            D_z_detached = D(z.detach())
            D_z_pperm = D(z_pperm)
            D_tc_loss = 0.5 * (F.cross_entropy(D_z_detached, zeros) +
                               F.cross_entropy(D_z_pperm, ones))

            # Also clears the gradients vae_loss deposited into D.
            optim_D.zero_grad()
            D_tc_loss.backward()
            optim_vae.step()
            optim_D.step()

            if iteration % args.print_iter == 0:
                pbar.write(
                    '[epoch {}/{}, iter {}/{}] vae_recon_loss:{:.4f} vae_kld:{:.4f} vae_tc_loss:{:.4f} D_tc_loss:{:.4f}'
                    .format(epoch, args.epochs, iteration,
                            len(data_loader) - 1, vae_recon_loss.item(),
                            vae_kld.item(), vae_tc_loss.item(),
                            D_tc_loss.item()))

            if iteration % args.output_iter == 0 and iteration != 0:
                _save_inference_grid(vae, args, device, output_dir, epoch,
                                     iteration)

        # Decay both learning rates 10x every 8 epochs. The epoch > 0 guard
        # stops the very first epoch from already dropping the rate.
        if epoch > 0 and epoch % 8 == 0:
            for optim in (optim_vae, optim_D):
                for group in optim.param_groups:
                    group['lr'] /= 10
            print("\nnew learning rate at epoch {} is {}!".format(
                epoch, optim_vae.param_groups[0]['lr']))

        if epoch % args.ckpt_iter_epoch == 0:
            save_checkpoint(pbar, epoch, D, vae, optim_D, optim_vae, ckpt_dir,
                            epoch)

    pbar.write("[Training Finished]")
    pbar.close()
Пример #9
0
    def train(self):
        """Run FactorVAE training with an extra ``ad_loss`` term until ``max_iter``.

        Each iteration:
          * updates ``self.VAE`` on ``x_true1`` with reconstruction + KL +
            ``self.gamma`` * total-correlation estimate from ``self.D`` +
            ``self.lamb`` * ``self.get_ad_loss(z)``;
          * updates ``self.D`` to classify latents of ``x_true1`` as class 0
            and ``permute_dims`` of the latents of ``x_true2`` as class 1.

        Every ``self.print_iter`` iterations, the losses are written to the
        progress bar. When ``self.dis_score`` is set, a disentanglement score
        is written as well. If ``self.results_save`` is set, these values are
        also appended to ``self.outputs``, which are saved with
        ``save_args_outputs`` once training finishes.
        """
        self.net_mode(train=True)

        # Discriminator targets: class 0 = joint latents, 1 = permuted.
        ones = torch.ones(self.batch_size,
                          dtype=torch.long,
                          device=self.device)
        zeros = torch.zeros(self.batch_size,
                            dtype=torch.long,
                            device=self.device)

        out = False
        while not out:
            for x_true1, x_true2 in self.data_loader:
                self.global_iter += 1
                self.pbar.update(1)

                x_true1 = x_true1.to(self.device)
                x_recon, mu, logvar, z = self.VAE(x_true1)
                vae_recon_loss = recon_loss(x_true1, x_recon)
                vae_ad_loss = self.get_ad_loss(z)
                vae_kld = kl_divergence(mu, logvar)

                D_z = self.D(z)
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                vae_loss = (vae_recon_loss + vae_kld +
                            self.gamma * vae_tc_loss +
                            self.lamb * vae_ad_loss)

                x_true2 = x_true2.to(self.device)
                z_prime = self.VAE(x_true2, no_dec=True)
                z_pperm = permute_dims(z_prime).detach()
                D_z_pperm = self.D(z_pperm)
                # Score z with its graph cut so the discriminator loss trains
                # only D. Back-propagating it through D_z would add its
                # gradient into the VAE's gradients before optim_VAE.step().
                D_z_detached = self.D(z.detach())
                D_tc_loss = 0.5 * (F.cross_entropy(D_z_detached, zeros) +
                                   F.cross_entropy(D_z_pperm, ones))

                # The two losses no longer share a graph, so no retain_graph.
                self.optim_VAE.zero_grad()
                vae_loss.backward()

                # Also clears the gradients vae_loss deposited into D.
                self.optim_D.zero_grad()
                D_tc_loss.backward()

                self.optim_VAE.step()
                self.optim_D.step()

                if self.global_iter % self.print_iter == 0:
                    if self.dis_score:
                        dis_score = disentanglement_score(
                            self.VAE.eval(), self.device, self.dataset,
                            self.z_dim, self.L, self.vote_count,
                            self.dis_batch_size)
                        self.VAE.train()
                    else:
                        dis_score = torch.tensor(0)

                    self.pbar.write(
                        '[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} ad_loss:{:.3f} D_tc_loss:{:.3f} dis_score:{:.3f}'
                        .format(self.global_iter, vae_recon_loss.item(),
                                vae_kld.item(), vae_tc_loss.item(),
                                vae_ad_loss.item(), D_tc_loss.item(),
                                dis_score.item()))

                    if self.results_save:
                        self.outputs['vae_recon_loss'].append(
                            vae_recon_loss.item())
                        self.outputs['vae_kld'].append(vae_kld.item())
                        self.outputs['vae_tc_loss'].append(vae_tc_loss.item())
                        self.outputs['D_tc_loss'].append(D_tc_loss.item())
                        self.outputs['ad_loss'].append(vae_ad_loss.item())
                        self.outputs['dis_score'].append(dis_score.item())
                        self.outputs['iteration'].append(self.global_iter)

                if self.global_iter % self.ckpt_save_iter == 0:
                    self.save_checkpoint(self.global_iter)

                if self.global_iter >= self.max_iter:
                    out = True
                    break

        self.pbar.write("[Training Finished]")
        self.pbar.close()

        if self.results_save:
            save_args_outputs(self.results_dir, self.args, self.outputs)
Пример #10
0
    def train(self):
        """Run FactorVAE training until ``self.max_iter`` iterations.

        Each iteration:
          * updates ``self.VAE`` on ``x_true1`` with reconstruction + KL +
            ``self.gamma`` * total-correlation estimate from ``self.D``;
          * updates ``self.D`` to classify latents of ``x_true1`` as class 0
            and ``permute_dims`` of the latents of ``x_true2`` as class 1.

        Losses are written every 547 iterations (counted by
        ``self.test_count``), and checkpoints every ``self.ckpt_save_iter``.
        When ``self.viz_on`` is set, visualisation data is gathered as well.
        At the end, the VAE weights are saved to
        ``model1/0531_128_2_gamma2.pth``.
        """
        import os  # local: only used to create the final weights directory

        self.net_mode(train=True)

        # Discriminator targets: class 0 = joint latents, 1 = permuted.
        ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
        zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)

        out = False
        while not out:
            for x_true1, x_true2 in self.data_loader:
                self.global_iter += 1
                self.pbar.update(1)
                if self.dataset == 'mnist':
                    x_true1 = x_true1.view(x_true1.shape[0], -1)
                x_true1 = x_true1.to(self.device)
                x_recon, mu, logvar, z = self.VAE(x_true1)
                # Flattened copy of the input, used as the reconstruction target.
                x = x_true1.view(x_true1.shape[0], -1)

                vae_recon_loss = recon_loss(x, x_recon)
                vae_kld = kl_divergence(mu, logvar)
                D_z = self.D(z)
                # Total-correlation estimate from the discriminator logits.
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                vae_loss = vae_recon_loss + vae_kld + self.gamma * vae_tc_loss
                self.optim_VAE.zero_grad()
                # The discriminator loss below does not reuse this graph, so
                # retain_graph is unnecessary.
                vae_loss.backward()
                self.optim_VAE.step()

                x_true2 = x_true2.to(self.device)
                # no_dec=True: only the latent code of x_true2 is needed.
                z_prime = self.VAE(x_true2, no_dec=True)
                z_pperm = permute_dims(z_prime).detach()
                D_z_pperm = self.D(z_pperm)
                # Score z with its graph cut. Back-propagating through D_z
                # would reach encoder weights that optim_VAE.step() just
                # modified in place: a RuntimeError in current PyTorch, and
                # gradient contamination of the VAE in older versions.
                D_z_detached = self.D(z.detach())
                # GAN-style discriminator loss: real (joint) latents -> 0,
                # permuted latents -> 1.
                D_tc_loss = 0.5 * (F.cross_entropy(D_z_detached, zeros) +
                                   F.cross_entropy(D_z_pperm, ones))

                # Also clears the gradients vae_loss deposited into D.
                self.optim_D.zero_grad()
                D_tc_loss.backward()
                self.optim_D.step()

                if self.test_count % 547 == 0:
                    self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format(
                        self.global_iter, vae_recon_loss.item(), vae_tc_loss.item(), D_tc_loss.item()))
                    self.test_count = 0

                if self.global_iter % self.ckpt_save_iter == 0:
                    self.save_checkpoint(self.global_iter)

                if self.viz_on and (self.global_iter % self.viz_ll_iter == 0):
                    soft_D_z = F.softmax(D_z, 1)[:, :1].detach()
                    soft_D_z_pperm = F.softmax(D_z_pperm, 1)[:, :1].detach()
                    D_acc = ((soft_D_z >= 0.5).sum() + (soft_D_z_pperm < 0.5).sum()).float()
                    D_acc /= 2 * self.batch_size
                    self.line_gather.insert(iter=self.global_iter,
                                            soft_D_z=soft_D_z.mean().item(),
                                            soft_D_z_pperm=soft_D_z_pperm.mean().item(),
                                            recon=vae_recon_loss.item(),
                                            acc=D_acc.item())

                if self.viz_on and (self.global_iter % self.viz_la_iter == 0):
                    self.visualize_line()
                    self.line_gather.flush()

                if self.viz_on and (self.global_iter % self.viz_ra_iter == 0):
                    # torch.sigmoid replaces the deprecated F.sigmoid.
                    self.image_gather.insert(true=x_true1.detach().cpu(),
                                             recon=torch.sigmoid(x_recon).detach().cpu())
                    self.visualize_recon()
                    self.image_gather.flush()

                if self.viz_on and (self.global_iter % self.viz_ta_iter == 0):
                    if self.dataset.lower() == '3dchairs':
                        self.visualize_traverse(limit=2, inter=0.5)
                    else:
                        # Latent traversal is intentionally disabled for the
                        # other datasets.
                        print("ignore")

                if self.global_iter >= self.max_iter:
                    out = True
                    break
                self.test_count += 1

        self.pbar.write("[Training Finished]")
        save_path = "model1/0531_128_2_gamma2.pth"
        # Ensure the target directory exists so the final save cannot fail
        # after the whole training run has completed.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(self.VAE.state_dict(), save_path)
        self.pbar.close()