コード例 #1
0
    def train(self):
        """Alternate VAE and discriminator updates until ``self.max_iter``.

        Every iteration draws a pair of batches from ``self.data_loader``:

        1. VAE step on the first batch. The loss is ``recon_loss`` plus
           ``kl_divergence`` plus ``self.gamma`` times ``vae_tc_loss``, the
           mean difference between the two discriminator logits on ``z``.
        2. Discriminator step. Cross-entropy with target class 0 for the
           latents of the first batch and class 1 for the dimension-permuted
           latents of the second batch.

        Losses are written to ``self.pbar`` every ``self.print_iter``
        iterations and ``self.save_checkpoint`` is called every
        ``self.ckpt_save_iter`` iterations.
        """
        self.net_mode(train=True)

        # Class-index targets for the discriminator's cross-entropy terms.
        ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
        zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)

        out = False
        while not out:
            for x_true1, x_true2 in self.data_loader:
                self.global_iter += 1
                self.pbar.update(1)

                # ---- VAE update ----
                x_true1 = x_true1.to(self.device)
                x_recon, mu, logvar, z = self.VAE(x_true1)
                vae_recon_loss = recon_loss(x_true1, x_recon)
                vae_kld = kl_divergence(mu, logvar)

                D_z = self.D(z)
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                vae_loss = vae_recon_loss + vae_kld + self.gamma*vae_tc_loss

                self.optim_VAE.zero_grad()
                vae_loss.backward()
                self.optim_VAE.step()

                # ---- Discriminator update ----
                x_true2 = x_true2.to(self.device)
                z_prime = self.VAE(x_true2, no_dec=True)
                z_pperm = permute_dims(z_prime).detach()
                D_z_pperm = self.D(z_pperm)
                # Score a detached copy of z instead of reusing D_z: the graph
                # behind D_z references VAE parameters that optim_VAE.step()
                # has just modified in place, and back-propagating through it
                # is rejected by autograd's in-place check on recent PyTorch.
                # The discriminator has not been stepped yet, so the logits
                # (and its gradient) match what D_z would have provided.
                D_z_real = self.D(z.detach())
                D_tc_loss = 0.5*(F.cross_entropy(D_z_real, zeros) + F.cross_entropy(D_z_pperm, ones))

                # Also clears the discriminator gradients left behind by
                # vae_loss.backward() above.
                self.optim_D.zero_grad()
                D_tc_loss.backward()
                self.optim_D.step()

                if self.global_iter%self.print_iter == 0:
                    self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format(
                        self.global_iter, vae_recon_loss.item(), vae_kld.item(), vae_tc_loss.item(), D_tc_loss.item()))

                if self.global_iter%self.ckpt_save_iter == 0:
                    self.save_checkpoint(self.global_iter)

                if self.global_iter >= self.max_iter:
                    out = True
                    break

        self.pbar.write("[Training Finished]")
        self.pbar.close()
コード例 #2
0
ファイル: solver.py プロジェクト: ThomasMrY/RF-VAE
    def train(self):
        """Train the VAE, the relevance vector ``self.r`` and the discriminator.

        Every iteration draws a pair of batches from ``self.data_loader``:

        1. VAE / ``r`` step on the first batch. The loss is ``recon_loss``
           plus ``kl_divergence(mu, logvar, self.r)`` plus ``self.gamma``
           times the mean difference of the two discriminator logits on
           ``self.r * z``, plus ``self.etaS`` times the L1 norm of ``self.r``
           and ``self.etaH`` times ``entropy(self.r)``. ``self.optim_VAE``
           and ``self.optim_r`` are both stepped from one backward pass.
        2. Discriminator step. Cross-entropy with target class 0 for the
           ``r``-scaled latents of the first batch and class 1 for the
           ``r``-scaled, dimension-permuted latents of the second batch.

        The loop stops once ``self.global_iter`` reaches ``self.max_iter``.
        Logging, checkpointing and the optional visualisations run at their
        respective ``*_iter`` intervals.
        """
        self.net_mode(train=True)

        # Class-index targets for the discriminator's cross-entropy terms.
        ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
        zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)

        out = False
        while not out:
            for x_true1, x_true2 in self.data_loader:
                self.global_iter += 1
                self.pbar.update(1)

                # ---- VAE and r update ----
                x_true1 = x_true1.to(self.device)
                x_recon, mu, logvar, z = self.VAE(x_true1)
                vae_recon_loss = recon_loss(x_true1, x_recon)
                vae_kld = kl_divergence(mu, logvar,self.r)
                H_r = entropy(self.r)

                D_z = self.D(self.r*z)
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                vae_loss = vae_recon_loss + vae_kld + self.gamma*vae_tc_loss + self.etaS*self.r.abs().sum() + self.etaH*H_r

                # Snapshot the discriminator input before either optimizer
                # changes self.r or the encoder, so the discriminator step
                # below scores exactly the values that produced D_z.
                D_real_input = (self.r*z).detach()

                # One backward pass supplies gradients to both optimizers.
                # Calling vae_loss.backward() again after optim_VAE.step()
                # would traverse a graph whose parameters were just modified
                # in place, which autograd rejects on recent PyTorch.
                self.optim_VAE.zero_grad()
                self.optim_r.zero_grad()
                vae_loss.backward()
                self.optim_VAE.step()
                self.optim_r.step()

                # ---- Discriminator update ----
                x_true2 = x_true2.to(self.device)
                z_prime = self.VAE(x_true2, no_dec=True)
                z_pperm = permute_dims(z_prime).detach()
                D_z_pperm = self.D((self.r*z_pperm).detach())
                D_z_real = self.D(D_real_input)
                D_tc_loss = 0.5*(F.cross_entropy(D_z_real, zeros) + F.cross_entropy(D_z_pperm, ones))

                # Also clears the discriminator gradients left behind by
                # vae_loss.backward() above.
                self.optim_D.zero_grad()
                D_tc_loss.backward()
                self.optim_D.step()

                if self.global_iter%self.print_iter == 0:
                    self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format(
                        self.global_iter, vae_recon_loss.item(), vae_kld.item(), vae_tc_loss.item(), D_tc_loss.item()))

                if self.global_iter%self.ckpt_save_iter == 0:
                    self.save_checkpoint(self.global_iter)

                if self.viz_on and (self.global_iter%self.viz_ll_iter == 0):
                    # Column 0 of the softmax is the probability of class 0,
                    # the target used for un-permuted latents above.
                    soft_D_z = F.softmax(D_z, 1)[:, :1].detach()
                    soft_D_z_pperm = F.softmax(D_z_pperm, 1)[:, :1].detach()
                    D_acc = ((soft_D_z >= 0.5).sum() + (soft_D_z_pperm < 0.5).sum()).float()
                    D_acc /= 2*self.batch_size
                    self.line_gather.insert(iter=self.global_iter,
                                            soft_D_z=soft_D_z.mean().item(),
                                            soft_D_z_pperm=soft_D_z_pperm.mean().item(),
                                            recon=vae_recon_loss.item(),
                                            kld=vae_kld.item(),
                                            acc=D_acc.item(),
                                            r_distribute=self.r.data.cpu())

                if self.viz_on and (self.global_iter%self.viz_la_iter == 0):
                    self.visualize_line()
                    self.line_gather.flush()

                if self.viz_on and (self.global_iter%self.viz_ra_iter == 0):
                    # torch.sigmoid replaces the deprecated F.sigmoid.
                    self.image_gather.insert(true=x_true1.data.cpu(),
                                             recon=torch.sigmoid(x_recon).data.cpu())
                    self.visualize_recon()
                    self.image_gather.flush()

                if self.viz_on and (self.global_iter%self.viz_ta_iter == 0):
                    if self.dataset.lower() == '3dchairs':
                        self.visualize_traverse(limit=2, inter=0.5)
                    else:
                        self.visualize_traverse(limit=3, inter=2/3)

                if self.global_iter >= self.max_iter:
                    out = True
                    break

        self.pbar.write("[Training Finished]")
        self.pbar.close()
コード例 #3
0
ファイル: vanilla.py プロジェクト: alfjesus3/FACT-2021
    def train(self):
        """Alternate VAE and discriminator updates and record metrics.

        Every iteration draws a pair of batches from ``self.data_loader``,
        performs one VAE step (``recon_loss`` + ``kl_divergence`` +
        ``self.gamma`` * mean logit difference of the discriminator on ``z``)
        and one discriminator step (cross-entropy with class 0 for latents of
        the first batch and class 1 for dimension-permuted latents of the
        second batch).

        Training losses are buffered every 100 iterations and the value of
        ``self.disentanglement_metric()`` every 1500 iterations. The buffer is
        flushed through ``self.save_metrics`` whenever a checkpoint is saved
        (every ``self.ckpt_save_iter`` iterations). The loop stops once
        ``self.global_iter`` reaches ``self.max_iter``.
        """
        self.net_mode(train=True)

        # Class-index targets for the discriminator's cross-entropy terms.
        ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
        zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)
        metrics = []
        out = False
        while not out:
            for x_true1, x_true2 in self.data_loader:
                self.global_iter += 1
                self.pbar.update(1)

                # ---- VAE update ----
                x_true1 = x_true1.to(self.device)
                x_recon, mu, logvar, z = self.VAE(x_true1)
                vae_recon_loss = recon_loss(x_true1, x_recon)
                vae_kld = kl_divergence(mu, logvar)

                D_z = self.D(z)
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                vae_loss = vae_recon_loss + vae_kld + self.gamma*vae_tc_loss

                # Step right away. Deferring the step to the next iteration
                # let D_tc_loss.backward() add discriminator-loss gradients to
                # the encoder before they were applied, never applied the last
                # iteration's gradients, and saved checkpoints that were one
                # update behind.
                self.optim_VAE.zero_grad()
                vae_loss.backward()
                self.optim_VAE.step()

                # ---- Discriminator update ----
                x_true2 = x_true2.to(self.device)
                z_prime = self.VAE(x_true2, no_dec=True)
                z_pperm = permute_dims(z_prime).detach()
                D_z_pperm = self.D(z_pperm)
                # Score a detached copy of z so this backward pass neither
                # reaches the VAE nor traverses a graph whose parameters were
                # just modified in place by optim_VAE.step().
                D_z_real = self.D(z.detach())
                D_tc_loss = 0.5*(F.cross_entropy(D_z_real, zeros) + F.cross_entropy(D_z_pperm, ones))

                # Also clears the discriminator gradients left behind by
                # vae_loss.backward() above.
                self.optim_D.zero_grad()
                D_tc_loss.backward()
                self.optim_D.step()

                # Saving the training metrics
                if self.global_iter % 100 == 0:
                    metrics.append({'its': self.global_iter,
                                    'vae_loss': vae_loss.item(),
                                    'D_loss': D_tc_loss.item(),
                                    'recon_loss': vae_recon_loss.item(),
                                    'tc_loss': vae_tc_loss.item()})

                # Saving the disentanglement metrics results
                if self.global_iter % 1500 == 0:
                    score = self.disentanglement_metric()
                    metrics.append({'its': self.global_iter, 'metric_score': score})
                    self.net_mode(train=True)  # To continue the training again

                if self.global_iter%self.print_iter == 0:
                    self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format(
                        self.global_iter, vae_recon_loss.item(), vae_kld.item(), vae_tc_loss.item(), D_tc_loss.item()))

                if self.global_iter%self.ckpt_save_iter == 0:
                    self.save_checkpoint(str(self.global_iter)+".pth")
                    self.save_metrics(metrics)
                    metrics = []

                if self.global_iter >= self.max_iter:
                    out = True
                    break

        self.pbar.write("[Training Finished]")
        self.pbar.close()
コード例 #4
0
ファイル: ad_factorvae.py プロジェクト: alfjesus3/FACT-2021
    def train(self):
        """Train the VAE and discriminator with an extra attention term.

        A ``GradCamDissen`` wrapper around ``self.VAE`` / ``self.D`` runs the
        forward pass on the first batch of each pair. The VAE loss is
        ``recon_loss`` + ``kl_divergence`` + ``self.gamma`` * mean logit
        difference of the discriminator on ``z``, plus ``self.lambdaa`` times
        the average ``attention_disentanglement`` over the map pairs returned
        by ``self.select_attention_maps``. The discriminator loss is
        cross-entropy with class 0 for latents of the first batch and class 1
        for dimension-permuted latents of the second batch.

        Training losses are buffered every 100 iterations and the value of
        ``self.disentanglement_metric()`` every 1500 iterations; the buffer is
        flushed through ``self.save_metrics`` at every checkpoint. The loop
        stops once ``self.global_iter`` reaches ``self.max_iter``.
        """

        # NOTE(review): cuda=True is hard-coded while the tensors below follow
        # self.device -- confirm this wrapper also works on CPU-only runs.
        gcam = GradCamDissen(self.VAE,
                             self.D,
                             target_layer='encode.1',
                             cuda=True)
        self.net_mode(train=True)

        # Class-index targets for the discriminator's cross-entropy terms.
        ones = torch.ones(self.batch_size,
                          dtype=torch.long,
                          device=self.device)
        zeros = torch.zeros(self.batch_size,
                            dtype=torch.long,
                            device=self.device)
        # Constants handed to gcam.backward on every iteration; they are
        # never updated inside this loop.
        mu_avg, logvar_avg = 0, 1
        metrics = []
        out = False
        while not out:
            for batch_idx, (x1, x2) in enumerate(self.data_loader):
                self.global_iter += 1
                self.pbar.update(1)

                # The optimizers are stepped here, at the top of the loop, so
                # they apply the gradients produced by the previous iteration
                # (the in-place step calls further down are commented out).
                # NOTE(review): this means the gradients of the final
                # iteration are never applied, and that D_tc_loss.backward()
                # below may add discriminator-loss gradients to the VAE
                # parameters before they are stepped (D_z depends on z) --
                # confirm whether that is intended.
                self.optim_VAE.step()
                self.optim_D.step()

                x1 = x1.to(self.device)
                x1_rec, mu, logvar, z = gcam.forward(x1)
                # For Standard FactorVAE loss
                vae_recon_loss = recon_loss(x1, x1_rec)
                vae_kld = kl_divergence(mu, logvar)
                D_z = self.D(z)
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                factorVae_loss = vae_recon_loss + vae_kld + self.gamma * vae_tc_loss
                # For attention disentanglement loss
                gcam.backward(mu, logvar, mu_avg, logvar_avg)
                att_loss = 0
                # NOTE(review): the attention loss is accumulated under
                # torch.no_grad(), so it presumably contributes no gradient to
                # vae_loss -- confirm this is intended.
                with torch.no_grad():
                    gcam_maps = gcam.generate()
                    selected = self.select_attention_maps(gcam_maps)
                    for (sel1, sel2) in selected:
                        att_loss += attention_disentanglement(sel1, sel2)
                # NOTE(review): raises ZeroDivisionError if `selected` is empty.
                att_loss /= len(
                    selected)  # Averaging the loss across all pairs of maps

                vae_loss = factorVae_loss + self.lambdaa * att_loss
                self.optim_VAE.zero_grad()
                # retain_graph=True keeps the graph behind D_z alive for
                # D_tc_loss.backward() below.
                vae_loss.backward(retain_graph=True)
                #self.optim_VAE.step()

                x2 = x2.to(self.device)
                z_prime = self.VAE(x2, no_dec=True)
                z_pperm = permute_dims(z_prime).detach()
                D_z_pperm = self.D(z_pperm)
                D_tc_loss = 0.5 * (F.cross_entropy(D_z, zeros) +
                                   F.cross_entropy(D_z_pperm, ones))

                # Clears the discriminator gradients left by vae_loss.backward().
                self.optim_D.zero_grad()
                D_tc_loss.backward()
                #self.optim_D.step()

                # Saving the training metrics
                if self.global_iter % 100 == 0:
                    metrics.append({
                        'its':
                        self.global_iter,
                        'vae_loss':
                        vae_loss.detach().to(torch.device("cpu")).item(),
                        'D_loss':
                        D_tc_loss.detach().to(torch.device("cpu")).item(),
                        'recon_loss':
                        vae_recon_loss.detach().to(torch.device("cpu")).item(),
                        'tc_loss':
                        vae_tc_loss.detach().to(torch.device("cpu")).item()
                    })

                # Saving the disentanglement metrics results
                if self.global_iter % 1500 == 0:
                    score = self.disentanglement_metric()
                    metrics.append({
                        'its': self.global_iter,
                        'metric_score': score
                    })
                    self.net_mode(train=True)  # To continue the training again

                if self.global_iter % self.print_iter == 0:
                    self.pbar.write(
                        '[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'
                        .format(self.global_iter, vae_recon_loss.item(),
                                vae_kld.item(), vae_tc_loss.item(),
                                D_tc_loss.item()))

                # Checkpoints and the buffered metrics are written together;
                # the buffer is emptied afterwards.
                if self.global_iter % self.ckpt_save_iter == 0:
                    self.save_checkpoint(str(self.global_iter) + ".pth")
                    self.save_metrics(metrics)
                    metrics = []

                if self.global_iter >= self.max_iter:
                    out = True
                    break

        self.pbar.write("[Training Finished]")
        self.pbar.close()
コード例 #5
0
def main(args):
    """Train a label-conditioned FactorVAE on the selected CelebA dataset.

    Reads from ``args``: ``seed``, ``epochs``, ``batch_size``, ``lr_vae``,
    ``lr_D``, ``z_dim``, ``num_labels``, ``ckpt_dir``, ``ckpt_load``,
    ``ckpt_iter_epoch``, ``output_dir``, ``name``, ``recon_weight``,
    ``gamma``, ``print_iter`` and ``output_iter``.

    Per batch, the VAE is updated on ``recon_loss * recon_weight`` plus the
    ``gamma``-scaled mean logit difference of the discriminator on ``z`` (the
    KL term is computed for logging only). The discriminator is updated on
    cross-entropy with class 0 for ``z`` and class 1 for dimension-permuted
    latents of the second sample. Every ``output_iter`` iterations a grid of
    ten decoded samples with random binary labels is written to
    ``<output_dir>/<name>``. Checkpoints go to ``<ckpt_dir>/<name>``.
    """
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    pbar = tqdm(total=args.epochs)

    dataset = get_celeba_selected_dataset()
    data_loader = DataLoader(dataset=dataset,
                             batch_size=args.batch_size,
                             shuffle=True)

    lr_vae = args.lr_vae
    lr_D = args.lr_D
    vae = CelebaFactorVAE(args.z_dim, args.num_labels).to(device)
    optim_vae = torch.optim.Adam(vae.parameters(), lr=args.lr_vae)

    D = Discriminator(args.z_dim, args.num_labels).to(device)
    optim_D = torch.optim.Adam(D.parameters(), lr=args.lr_D, betas=(0.5, 0.9))

    # Checkpoint
    ckpt_dir = os.path.join(args.ckpt_dir, args.name)
    mkdirs(ckpt_dir)
    # NOTE(review): start_epoch stays 0 even after a checkpoint is loaded --
    # confirm whether resuming should continue from the saved epoch.
    start_epoch = 0
    if args.ckpt_load:
        load_checkpoint(pbar, ckpt_dir, D, vae, optim_D, optim_vae, lr_vae,
                        lr_D)
        print("confirming lr after loading checkpoint: ",
              optim_vae.param_groups[0]['lr'])

    # Output (constant for the whole run, so resolved once)
    output_dir = os.path.join(args.output_dir, args.name)
    mkdirs(output_dir)

    # Class-index targets for the discriminator's cross-entropy terms.
    ones = torch.ones(args.batch_size, dtype=torch.long, device=device)
    zeros = torch.zeros(args.batch_size, dtype=torch.long, device=device)

    # Layout of the periodic sample sheet: ten images on a 5 x 2 grid.
    n_samples, grid_rows, grid_cols = 10, 5, 2

    for epoch in range(start_epoch, args.epochs):
        pbar.update(1)

        for iteration, (x, y, x2, y2) in enumerate(data_loader):

            x, y, x2, y2 = x.to(device), y.to(device), x2.to(device), y2.to(
                device)

            recon_x, mean, log_var, z = vae(x, y)

            # The fixed-size target tensors above only fit full batches.
            if z.shape[0] != args.batch_size:
                print("passed a batch in epoch {}, iteration {}!".format(
                    epoch, iteration))
                continue

            D_z = D(z)

            vae_recon_loss = recon_loss(x, recon_x) * args.recon_weight
            vae_kld = kl_divergence(mean, log_var)  # logged only
            vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean() * args.gamma
            vae_loss = vae_recon_loss + vae_tc_loss  #+ vae_kld

            optim_vae.zero_grad()
            vae_loss.backward()

            z_prime = vae(x2, y2, no_dec=True)
            z_pperm = permute_dims(z_prime).detach()
            D_z_pperm = D(z_pperm)
            # Score a detached copy of z: back-propagating the discriminator
            # loss through D_z would add its gradient to the VAE parameters
            # right before optim_vae.step() applies them.
            D_z_real = D(z.detach())
            D_tc_loss = 0.5 * (F.cross_entropy(D_z_real, zeros) +
                               F.cross_entropy(D_z_pperm, ones))

            # Also clears the discriminator gradients left behind by
            # vae_loss.backward() above.
            optim_D.zero_grad()
            D_tc_loss.backward()
            optim_vae.step()
            optim_D.step()

            if iteration % args.print_iter == 0:
                pbar.write(
                    '[epoch {}/{}, iter {}/{}] vae_recon_loss:{:.4f} vae_kld:{:.4f} vae_tc_loss:{:.4f} D_tc_loss:{:.4f}'
                    .format(epoch, args.epochs, iteration,
                            len(data_loader) - 1, vae_recon_loss.item(),
                            vae_kld.item(), vae_tc_loss.item(),
                            D_tc_loss.item()))

            if iteration % args.output_iter == 0 and iteration != 0:
                # Decode random latents with random binary label vectors.
                # No gradients are needed for this preview.
                with torch.no_grad():
                    c = torch.randint(low=0, high=2,
                                      size=(n_samples,
                                            args.num_labels)).to(device)
                    z_inf = torch.rand([n_samples, args.z_dim]).to(device)
                    z_inf = torch.cat((z_inf, c.to(z_inf.dtype)), dim=1)
                    z_inf = z_inf.reshape(-1, args.num_labels + args.z_dim,
                                          1, 1)
                    samples = vae.decode(z_inf)

                plt.figure(figsize=(10, 20))
                for idx in range(n_samples):
                    # row, col, index starting from 1
                    plt.subplot(grid_rows, grid_cols, idx + 1)
                    plt.text(0,
                             0,
                             "c={}".format(c[idx]),
                             color='black',
                             backgroundcolor='white',
                             fontsize=10)

                    # (3, 218, 178) -> (218, 178, 3) for imshow.
                    image = samples[idx].view(3, 218, 178).permute(1, 2, 0)
                    plt.imshow((image.cpu().numpy() * 255).astype(np.uint8))
                    plt.axis('off')

                plt.savefig(os.path.join(
                    output_dir, "E{:d}||{:d}.png".format(epoch, iteration)),
                            dpi=300)
                plt.clf()
                plt.close('all')

        # NOTE(review): this condition is also true for epoch 0, so the
        # learning rates are divided by 10 after the very first epoch --
        # confirm that is intended.
        if epoch % 8 == 0:
            optim_vae.param_groups[0]['lr'] /= 10
            optim_D.param_groups[0]['lr'] /= 10
            print("\nnew learning rate at epoch {} is {}!".format(
                epoch, optim_vae.param_groups[0]['lr']))

        if epoch % args.ckpt_iter_epoch == 0:
            save_checkpoint(pbar, epoch, D, vae, optim_D, optim_vae, ckpt_dir,
                            epoch)

    pbar.write("[Training Finished]")
    pbar.close()
コード例 #6
0
def disentanglement_score(model, device, dataset, z_dim, L=100, n_votes=800, batch_size=2048, verbose=False):
    """Score a model with a majority-vote disentanglement metric.

    The whole ``dataset`` is encoded once. Latent dimensions whose empirical
    ``kl_divergence`` (dimension-wise) does not exceed ``1e-5`` are dropped
    and the remaining ones are divided by their standard deviation over the
    dataset. For every factor, ``n_votes // n_factors`` votes are cast: a
    value of that factor is drawn, ``L`` samples sharing it are picked, and
    the latent dimension with the smallest variance across them receives
    the vote. The score is the fraction of votes that agree with the
    majority factor of their dimension.

    Args:
        model: object whose ``encode(batch)`` output, flattened per sample,
            holds ``z_dim`` columns followed by the second argument passed
            to ``kl_divergence``.
        device: device the input batches are moved to before encoding.
        dataset: dataset yielding ``(input, _)`` pairs and exposing
            ``latents_classes`` (one row of factor indices per sample; the
            last row is taken as the per-factor maximum).
        z_dim: number of latent dimensions.
        L: number of samples drawn per vote.
        n_votes: requested total number of votes.
        batch_size: batch size used while encoding the dataset.
        verbose: print intermediate statistics and the vote matrix.

    Returns:
        A scalar ``torch.Tensor`` in ``[0, 1]``. It is ``0`` when no latent
        dimension is useful or no vote can be cast.
    """
    factors = dataset.latents_classes
    max_factors = factors[-1]
    n_factors = len(max_factors)

    n_votes_per_factor = n_votes // n_factors
    # Votes actually cast; smaller than n_votes when it is not divisible by
    # the number of factors. Used as the denominator of the final score.
    total_votes = n_votes_per_factor * n_factors

    test_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=2,
                             pin_memory=True,
                             drop_last=False)

    # Encode every sample; no autograd graph is needed for evaluation.
    encoded = []
    with torch.no_grad():
        for batch, _ in test_loader:
            latents = model.encode(batch.to(device)).detach().cpu().flatten(start_dim=1)
            encoded.append(latents)
    all_latents = torch.cat(encoded)

    # Compute KL divergence per latent dimension
    emp_mean_kl = kl_divergence(all_latents[:, :z_dim], all_latents[:, z_dim:], dim_wise=True)

    # Remove the dimensions that collapsed to the prior
    kl_tol = 1e-5
    useful_dims = np.where(emp_mean_kl.numpy() > kl_tol)[0]
    u_dim = len(useful_dims)

    # With no informative dimension (or no vote to cast) there is nothing to
    # vote on; argmin over an empty variance vector would raise.
    if u_dim == 0 or total_votes == 0:
        return torch.tensor(0.0)

    # Compute scales for useful dims
    scales = torch.std(all_latents[:, useful_dims], dim=0)

    if verbose:
        print("Remaining dimensions")
        print(u_dim)
        print("Empirical mean for kl dimension-wise:")
        print(list(emp_mean_kl))
        print("Useful dimensions:", list(useful_dims), " - Total:", useful_dims.shape[0])
        print("Empirical Scales:", list(scales))

    v_matrix = torch.zeros((u_dim, n_factors))

    # Fix a factor k
    for k_fixed in range(n_factors):
        for _ in range(n_votes_per_factor):
            # Fix a value for this factor and draw L samples that share it
            fixed_factor = np.random.randint(0, max_factors[k_fixed]+1)
            candidates = np.where(factors[:, k_fixed] == fixed_factor)[0]
            sample_indices = np.random.choice(candidates, size=L)

            # Normalise the useful dimensions of the drawn latents
            norm_latents = all_latents[sample_indices][:, useful_dims]/scales
            # The dimension with the lowest empirical variance gets the vote
            # for the factor that was held fixed.
            d_j = torch.argmin(torch.var(norm_latents, dim=0))
            v_matrix[d_j, k_fixed] += 1

    if verbose:
        print("Votes:")
        print(v_matrix.numpy())

    # Since both inputs and outputs lie in a discrete space, the optimal
    # classifier is the majority-vote classifier; its accuracy is the score.
    return torch.sum(torch.max(v_matrix, dim=1).values)/total_votes
コード例 #7
0
ファイル: solver.py プロジェクト: FrankBrongers/FactorVAE
    def train(self):
        """Train the VAE and discriminator, optionally tracking a score.

        Every iteration draws a pair of batches from ``self.data_loader``.
        The VAE loss is ``recon_loss`` + ``kl_divergence`` + ``self.gamma``
        times the mean logit difference of the discriminator on ``z``, plus
        ``self.lamb`` times ``self.get_ad_loss(z)``. The discriminator loss is
        cross-entropy with class 0 for latents of the first batch and class 1
        for dimension-permuted latents of the second batch. Both losses are
        back-propagated before either optimizer is stepped.

        Every ``self.print_iter`` iterations the losses are written to
        ``self.pbar`` (together with ``disentanglement_score`` when
        ``self.dis_score`` is set) and, when ``self.results_save`` is set,
        appended to ``self.outputs``. After the loop those outputs are
        written with ``save_args_outputs``.
        """
        self.net_mode(train=True)

        # Class-index targets for the discriminator's cross-entropy terms.
        ones = torch.ones(self.batch_size,
                          dtype=torch.long,
                          device=self.device)
        zeros = torch.zeros(self.batch_size,
                            dtype=torch.long,
                            device=self.device)

        out = False
        while not out:
            for x_true1, x_true2 in self.data_loader:
                self.global_iter += 1
                self.pbar.update(1)

                x_true1 = x_true1.to(self.device)
                x_recon, mu, logvar, z = self.VAE(x_true1)
                vae_recon_loss = recon_loss(x_true1, x_recon)
                vae_ad_loss = self.get_ad_loss(z)
                vae_kld = kl_divergence(mu, logvar)

                D_z = self.D(z)
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                vae_loss = vae_recon_loss + vae_kld + self.gamma * vae_tc_loss + self.lamb * vae_ad_loss

                x_true2 = x_true2.to(self.device)
                z_prime = self.VAE(x_true2, no_dec=True)
                z_pperm = permute_dims(z_prime).detach()
                D_z_pperm = self.D(z_pperm)
                # Score a detached copy of z: back-propagating the
                # discriminator loss through D_z would add its gradient to
                # the encoder right before optim_VAE.step() applies it.
                D_z_real = self.D(z.detach())
                D_tc_loss = 0.5 * (F.cross_entropy(D_z_real, zeros) +
                                   F.cross_entropy(D_z_pperm, ones))

                self.optim_VAE.zero_grad()
                vae_loss.backward()

                # Also clears the discriminator gradients left behind by
                # vae_loss.backward() above.
                self.optim_D.zero_grad()
                D_tc_loss.backward()

                self.optim_VAE.step()
                self.optim_D.step()

                if self.global_iter % self.print_iter == 0:
                    if self.dis_score:
                        dis_score = disentanglement_score(
                            self.VAE.eval(), self.device, self.dataset,
                            self.z_dim, self.L, self.vote_count,
                            self.dis_batch_size)
                        # Scoring switched the VAE to eval mode.
                        self.VAE.train()
                    else:
                        dis_score = torch.tensor(0)

                    self.pbar.write(
                        '[{}] vae_recon_loss:{:.3f} vae_kld:{:.3f} vae_tc_loss:{:.3f} ad_loss:{:.3f} D_tc_loss:{:.3f} dis_score:{:.3f}'
                        .format(self.global_iter, vae_recon_loss.item(),
                                vae_kld.item(), vae_tc_loss.item(),
                                vae_ad_loss.item(), D_tc_loss.item(),
                                dis_score.item()))

                    if self.results_save:
                        self.outputs['vae_recon_loss'].append(
                            vae_recon_loss.item())
                        self.outputs['vae_kld'].append(vae_kld.item())
                        self.outputs['vae_tc_loss'].append(vae_tc_loss.item())
                        self.outputs['D_tc_loss'].append(D_tc_loss.item())
                        self.outputs['ad_loss'].append(vae_ad_loss.item())
                        self.outputs['dis_score'].append(dis_score.item())
                        self.outputs['iteration'].append(self.global_iter)

                if self.global_iter % self.ckpt_save_iter == 0:
                    self.save_checkpoint(self.global_iter)

                if self.global_iter >= self.max_iter:
                    out = True
                    break

        self.pbar.write("[Training Finished]")
        self.pbar.close()

        if self.results_save:
            save_args_outputs(self.results_dir, self.args, self.outputs)
コード例 #8
0
ファイル: solver.py プロジェクト: OZA15015/custom_FVAE
    def train(self):
        """Alternate VAE and discriminator updates until ``self.max_iter``.

        Every iteration draws a pair of batches from ``self.data_loader``.
        The VAE loss is ``recon_loss`` (on the flattened first batch) +
        ``kl_divergence`` + ``self.gamma`` times the mean logit difference of
        the discriminator on ``z``. The discriminator loss is cross-entropy
        with class 0 for latents of the first batch and class 1 for
        dimension-permuted latents of the second batch.

        Losses are written to ``self.pbar`` whenever ``self.test_count`` is a
        multiple of 547 (the counter is then reset). Checkpoints and the
        optional visualisations run at their ``*_iter`` intervals. After the
        loop the VAE ``state_dict`` is saved to
        ``model1/0531_128_2_gamma2.pth``.
        """
        self.net_mode(train=True)

        # Class-index targets for the discriminator's cross-entropy terms
        # (original author's note: like a GAN discriminator, real vs. fake).
        ones = torch.ones(self.batch_size, dtype=torch.long, device=self.device)
        zeros = torch.zeros(self.batch_size, dtype=torch.long, device=self.device)

        out = False
        while not out:
            for x_true1, x_true2 in self.data_loader:
                self.global_iter += 1
                self.pbar.update(1)

                # ---- VAE update ----
                if self.dataset == 'mnist':
                    x_true1 = x_true1.view(x_true1.shape[0], -1)
                x_true1 = x_true1.to(self.device)
                x_recon, mu, logvar, z = self.VAE(x_true1)
                # Flattened view of the input used as reconstruction target.
                x = x_true1.view(x_true1.shape[0], -1)

                vae_recon_loss = recon_loss(x, x_recon)
                vae_kld = kl_divergence(mu, logvar)
                D_z = self.D(z)
                vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()

                vae_loss = vae_recon_loss + vae_kld + self.gamma*vae_tc_loss

                self.optim_VAE.zero_grad()
                vae_loss.backward()
                self.optim_VAE.step()

                # ---- Discriminator update ----
                # NOTE(review): unlike x_true1, x_true2 is not flattened for
                # 'mnist' -- confirm the VAE accepts both layouts.
                x_true2 = x_true2.to(self.device)
                z_prime = self.VAE(x_true2, no_dec=True)
                z_pperm = permute_dims(z_prime).detach()
                D_z_pperm = self.D(z_pperm)
                # Score a detached copy of z instead of reusing D_z: the graph
                # behind D_z references VAE parameters that optim_VAE.step()
                # has just modified in place, and back-propagating through it
                # is rejected by autograd's in-place check on recent PyTorch.
                D_z_real = self.D(z.detach())
                D_tc_loss = 0.5*(F.cross_entropy(D_z_real, zeros) + F.cross_entropy(D_z_pperm, ones))

                # Also clears the discriminator gradients left behind by
                # vae_loss.backward() above.
                self.optim_D.zero_grad()
                D_tc_loss.backward()
                self.optim_D.step()

                # Progress line on a private counter instead of print_iter.
                if self.test_count % 547 == 0:
                    self.pbar.write('[{}] vae_recon_loss:{:.3f} vae_tc_loss:{:.3f} D_tc_loss:{:.3f}'.format(
                        self.global_iter, vae_recon_loss.item(), vae_tc_loss.item(), D_tc_loss.item()))
                    self.test_count = 0

                if self.global_iter%self.ckpt_save_iter == 0:
                    self.save_checkpoint(self.global_iter)

                if self.viz_on and (self.global_iter%self.viz_ll_iter == 0):
                    # Column 0 of the softmax is the probability of class 0,
                    # the target used for un-permuted latents above.
                    soft_D_z = F.softmax(D_z, 1)[:, :1].detach()
                    soft_D_z_pperm = F.softmax(D_z_pperm, 1)[:, :1].detach()
                    D_acc = ((soft_D_z >= 0.5).sum() + (soft_D_z_pperm < 0.5).sum()).float()
                    D_acc /= 2*self.batch_size
                    self.line_gather.insert(iter=self.global_iter,
                                            soft_D_z=soft_D_z.mean().item(),
                                            soft_D_z_pperm=soft_D_z_pperm.mean().item(),
                                            recon=vae_recon_loss.item(),
                                            acc=D_acc.item())

                if self.viz_on and (self.global_iter%self.viz_la_iter == 0):
                    self.visualize_line()
                    self.line_gather.flush()

                if self.viz_on and (self.global_iter%self.viz_ra_iter == 0):
                    # torch.sigmoid replaces the deprecated F.sigmoid.
                    self.image_gather.insert(true=x_true1.data.cpu(),
                                             recon=torch.sigmoid(x_recon).data.cpu())
                    self.visualize_recon()
                    self.image_gather.flush()

                if self.viz_on and (self.global_iter%self.viz_ta_iter == 0):
                    if self.dataset.lower() == '3dchairs':
                        self.visualize_traverse(limit=2, inter=0.5)
                    else:
                        # Traversal plots are skipped for other datasets.
                        print("ignore")

                if self.global_iter >= self.max_iter:
                    out = True
                    break
                self.test_count += 1

        self.pbar.write("[Training Finished]")
        torch.save(self.VAE.state_dict(), "model1/0531_128_2_gamma2.pth")
        self.pbar.close()