def main():
    # Seed every RNG (NumPy, torch CPU, torch CUDA) for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    model = Classifier().cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    def filter_background(x):
        # Zero out every pixel position where any channel falls below 0.3,
        # treating dark regions as background.
        x[:, (x < 0.3).any(dim=0)] = 0.0
        return x

    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        # filter_background,
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_dset = ImagePairs(root=args.train_root,
                            transform=transform,
                            n_frames_apart=args.n_frames_apart,
                            include_neg=True)
    test_dset = ImagePairs(root=args.test_root,
                           transform=transform,
                           n_frames_apart=args.n_frames_apart,
                           include_neg=True)
    train_loader = data.DataLoader(train_dset,
                                   batch_size=args.bs,
                                   shuffle=True,
                                   num_workers=2)
    test_loader = data.DataLoader(test_dset,
                                  batch_size=args.bs,
                                  shuffle=False,
                                  num_workers=2)

    from torchvision.utils import save_image
    # Sanity check: save one training batch, undoing the (0.5, 0.5)
    # normalization, plus the same images with the background filter applied.
    imgs = next(iter(train_loader))[0][0] * 0.5 + 0.5
    no_bg = imgs.clone()
    no_bg[(no_bg < 0.3).any(dim=1, keepdim=True).repeat(1, 3, 1, 1)] = 0.0
    save_image(imgs, 'train_img.png')
    save_image(no_bg, 'train_img_nobg.png')

    for epoch in range(args.epochs):
        train(model, optimizer, train_loader, epoch)
        test(model, test_loader, epoch)

        # Checkpoint after every epoch (overwrites the same file).
        torch.save(model.state_dict(), 'classifier.pt')
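
# `main()` reads a module-level `args` that this excerpt never defines. A
# hedged sketch of the argparse setup it implies (flag names are taken from
# the attributes used above; defaults are illustrative, not from the source):
#
# import argparse
#
# if __name__ == '__main__':
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--seed', type=int, default=0)
#     parser.add_argument('--lr', type=float, default=1e-3)
#     parser.add_argument('--bs', type=int, default=64)
#     parser.add_argument('--epochs', type=int, default=50)
#     parser.add_argument('--n_frames_apart', type=int, default=1)
#     parser.add_argument('--train_root', required=True)
#     parser.add_argument('--test_root', required=True)
#     args = parser.parse_args()
#     main()
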
def calculate_gram_distances(gram_fn, imagenet_dir, target_dir, output_dir):
    create_folder(output_dir)
    image_transforms = Compose([ToTensor(), ImagenetNorm()])

    # compare model pre-trained on ImageNet vs fine-tuned on the dog breed dataset
    for layer in LAYERS:
        img_dir1 = os.path.join(imagenet_dir, 'features', layer)
        img_dir2 = os.path.join(target_dir, 'features', layer)
        image_pairs = ImagePairs(img_dir1, img_dir2)

        output_path = os.path.join(output_dir, f'imagenet-vs-dogs-{layer}.csv')
        compare_images_with_gram(image_pairs, gram_fn, output_path, DEVICE, image_transforms, layer)
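
# A hedged usage sketch: `gram_fn` is assumed to be a Gram-matrix comparison
# module defined elsewhere, IMAGENET_DIR and DOG_DIR module-level constants
# this excerpt assumes, and the output folder name is illustrative.
#
# calculate_gram_distances(gram_fn, IMAGENET_DIR, DOG_DIR,
#                          os.path.join(DOG_DIR, 'gram_distances'))
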
def calculate_cosine_similarities(domain_dir, output_dir):
    """Calculate the cosine similarities for each of the target layers in the

    Parameters
    ----------
    domain_dir: str
        The path to the target domain dir (e.g. DOG_DIR for the dog classifier directory)
    output_dir: str
        The path to the output directory (will be created if it does not exist)
    """
    cos_sim_fn = CosineSimResnet50().to(DEVICE)
    create_folder(output_dir)
    image_transforms = Compose([ToTensor(), ImagenetNorm()])
    # compare model pre-trained on ImageNet vs fine-tuned on the target dataset
    for layer in LAYERS:
        img_dir1 = os.path.join(IMAGENET_DIR, 'features', layer)
        img_dir2 = os.path.join(domain_dir, 'features', layer)
        image_pairs = ImagePairs(img_dir1, img_dir2)
        output_path = os.path.join(output_dir, f'resnet50-cosine-sim-{layer}.csv')
        compare_images_with_cosine_sim(image_pairs, cos_sim_fn, output_path, DEVICE, image_transforms, layer)
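
# A hedged usage sketch (DOG_DIR is a module-level constant this excerpt
# assumes; the output folder name is illustrative):
#
# calculate_cosine_similarities(DOG_DIR, os.path.join(DOG_DIR, 'cosine_sim'))
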
Example #4
def get_dataloaders():
    def filter_background(x):
        # Zero out every pixel position where any channel falls below 0.3,
        # treating dark regions as background.
        x[:, (x < 0.3).any(dim=0)] = 0.0
        return x

    def dilate(x):
        # Grey-dilate the single-channel image to thicken thin structures
        # (grey_dilation comes from scipy.ndimage).
        x = x.squeeze(0).numpy()
        x = grey_dilation(x, size=3)
        x = x[None, :, :]
        return torch.from_numpy(x)

    if args.thanard_dset:
        transform = transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
        ])
    else:
        transform = transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            filter_background,
            lambda x: x.mean(dim=0)[None, :, :],  # collapse RGB to one channel
            dilate,
            transforms.Normalize((0.5, ), (0.5, )),
        ])

    train_dset = ImagePairs(root=join(args.root, 'train_data'),
                            include_actions=args.include_actions,
                            thanard_dset=args.thanard_dset,
                            transform=transform,
                            n_frames_apart=args.k)
    train_loader = data.DataLoader(train_dset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=2,
                                   drop_last=True)

    test_dset = ImagePairs(root=join(args.root, 'test_data'),
                           include_actions=args.include_actions,
                           thanard_dset=args.thanard_dset,
                           transform=transform,
                           n_frames_apart=args.k)
    test_loader = data.DataLoader(test_dset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=2,
                                  drop_last=True)

    neg_train_dset = ImageFolder(join(args.root, 'train_data'),
                                 transform=transform)
    neg_train_loader = data.DataLoader(neg_train_dset,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       pin_memory=True,
                                       num_workers=2)  # for training decoder
    neg_train_inf = infinite_loader(
        data.DataLoader(neg_train_dset,
                        batch_size=args.n,
                        shuffle=True,
                        pin_memory=True,
                        num_workers=2,
                        drop_last=True))  # to get negative samples

    neg_test_dset = ImageFolder(join(args.root, 'test_data'),
                                transform=transform)
    neg_test_loader = data.DataLoader(neg_test_dset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      pin_memory=True,
                                      num_workers=2)
    neg_test_inf = infinite_loader(
        data.DataLoader(neg_test_dset,
                        batch_size=args.n,
                        shuffle=True,
                        pin_memory=True,
                        num_workers=2,
                        drop_last=True))

    start_dset = ImageFolder(join(args.root, 'seq_data', 'start'),
                             transform=transform)
    goal_dset = ImageFolder(join(args.root, 'seq_data', 'goal'),
                            transform=transform)

    start_images = torch.stack(
        [start_dset[i][0] for i in range(len(start_dset))], dim=0)
    goal_images = torch.stack([goal_dset[i][0] for i in range(len(goal_dset))],
                              dim=0)

    n = min(start_images.shape[0], goal_images.shape[0])
    start_images, goal_images = start_images[:n], goal_images[:n]

    return train_loader, test_loader, neg_train_loader, neg_test_loader, neg_train_inf, neg_test_inf, start_images, goal_images
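
# `infinite_loader` is used above but not defined in this excerpt. A minimal
# sketch of what it is assumed to do: cycle a DataLoader forever so negative
# batches can be drawn on demand with next().
#
# def infinite_loader(loader):
#     while True:
#         for batch in loader:
#             yield batch
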
Example #5
    def train(self):
        # Set up training.
        real_o = Variable(torch.FloatTensor(self.batch_size, 3, 64, 64).cuda(),
                          requires_grad=False)
        real_o_next = Variable(torch.FloatTensor(self.batch_size, 3, 64,
                                                 64).cuda(),
                               requires_grad=False)
        label = Variable(torch.FloatTensor(self.batch_size).cuda(),
                         requires_grad=False)
        z = Variable(torch.FloatTensor(self.batch_size,
                                       self.rand_z_dim).cuda(),
                     requires_grad=False)

        criterionD = nn.BCELoss().cuda()

        optimD = optim.Adam([{
            'params': self.D.parameters()
        }],
                            lr=self.lr_d,
                            betas=(0.5, 0.999))
        optimG = optim.Adam([{
            'params': self.G.parameters()
        }, {
            'params': self.Q.parameters()
        }, {
            'params': self.T.parameters()
        }],
                            lr=self.lr_g,
                            betas=(0.5, 0.999))
        ############################################
        # Load rope dataset and apply transformations
        rope_path = os.path.realpath(self.data_dir)

        def filter_background(x):
            x[:, (x < 0.3).any(dim=0)] = 0.0
            return x

        def dilate(x):
            x = x.squeeze(0).numpy()
            x = grey_dilation(x, size=3)
            x = x[None, :, :]
            return torch.from_numpy(x)

        trans = [
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            filter_background,
            # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]

        if not self.fcn:
            # If fcn is enabled, the grayscale conversion and normalization
            # are applied inside the training loop instead.
            # trans.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
            if self.gray:
                # Apply grayscale transformation.
                trans.append(lambda x: x.mean(dim=0)[None, :, :])
                trans.append(dilate)
                trans.append(transforms.Normalize((0.5, ), (0.5, )))

        trans_comp = transforms.Compose(trans)
        # Image 1 and image 2 are k steps apart.
        dataset = ImagePairs(root=rope_path,
                             transform=trans_comp,
                             n_frames_apart=self.k)
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=self.batch_size,
                                                 shuffle=True,
                                                 num_workers=2,
                                                 drop_last=True)
        from torchvision.utils import save_image
        imgs = next(iter(dataloader))[0][0]
        save_image(imgs * 0.5 + 0.5, 'train_img.png')
        ############################################
        # Load eval plan dataset
        planning_data_dir = self.planning_data_dir
        dataset_start = dset.ImageFolder(root=os.path.join(
            planning_data_dir, 'start'),
                                         transform=trans_comp)
        dataset_goal = dset.ImageFolder(root=os.path.join(
            planning_data_dir, 'goal'),
                                        transform=trans_comp)
        data_start_loader = torch.utils.data.DataLoader(dataset_start,
                                                        batch_size=1,
                                                        shuffle=False,
                                                        num_workers=1,
                                                        drop_last=True)
        data_goal_loader = torch.utils.data.DataLoader(dataset_goal,
                                                       batch_size=1,
                                                       shuffle=False,
                                                       num_workers=1,
                                                       drop_last=True)
        ############################################
        for epoch in range(self.n_epochs + 1):
            self.G.train()
            self.D.train()
            self.Q.train()
            self.T.train()
            for num_iters, batch_data in enumerate(dataloader, 0):
                # Real data
                o = batch_data[0]
                o_next = batch_data[1]
                bs = o.size(0)

                real_o.data.resize_(o.size())
                real_o_next.data.resize_(o_next.size())
                label.data.resize_(bs)

                real_o.data.copy_(o)
                real_o_next.data.copy_(o_next)
                if self.fcn:
                    real_o = self.apply_fcn_mse(o)
                    real_o_next = self.apply_fcn_mse(o_next)
                    if real_o.abs().max() > 1:
                        import ipdb
                        ipdb.set_trace()
                    assert real_o.abs().max() <= 1

                if epoch == 0:
                    break
                ############################################
                # D Loss (Update D)
                optimD.zero_grad()
                # Real data
                probs_real = self.D(real_o, real_o_next)
                label.data.fill_(1)
                loss_real = criterionD(probs_real, label)
                loss_real.backward()

                # Fake data
                z, c, c_next = self._noise_sample(z, bs)
                fake_o, fake_o_next = self.G(z, c, c_next)
                probs_fake = self.D(fake_o.detach(), fake_o_next.detach())
                label.data.fill_(0)
                loss_fake = criterionD(probs_fake, label)
                loss_fake.backward()

                D_loss = loss_real + loss_fake

                optimD.step()
                ############################################
                # G loss (Update G)
                optimG.zero_grad()

                probs_fake_2 = self.D(fake_o, fake_o_next)
                label.data.fill_(1)
                G_loss = criterionD(probs_fake_2, label)

                # Q loss (Update G, T, Q)
                ent_loss = -self.P.log_prob(c).mean(0)
                crossent_loss = -self.Q.log_prob(fake_o, c).mean(0)
                crossent_loss_next = -self.Q.log_prob(fake_o_next,
                                                      c_next).mean(0)
                # trans_prob = self.T.get_prob(Variable(torch.eye(self.dis_c_dim).cuda()))
                ent_loss_next = -self.T.log_prob(c, None, c_next).mean(0)
                mi_loss = crossent_loss - ent_loss
                mi_loss_next = crossent_loss_next - ent_loss_next
                Q_loss = mi_loss + mi_loss_next

                # T loss (Update T)
                Q_c_given_x, Q_c_given_x_var = (
                    i.detach() for i in self.Q.forward(real_o))
                t_mu, t_variance = self.T.get_mu_and_var(c)
                t_diff = t_mu - c
                # Keep the variance small.
                # TODO: add loss on t_diff
                T_loss = (t_variance**2).sum(1).mean(0)

                (G_loss + self.infow * Q_loss +
                 self.transw * T_loss).backward()
                optimG.step()
                #############################################
                # Logging (iteration)
                if num_iters % 100 == 0:
                    self.log_dict['Dloss'] = D_loss.item()
                    self.log_dict['Gloss'] = G_loss.item()
                    self.log_dict['Qloss'] = Q_loss.item()
                    self.log_dict['Tloss'] = T_loss.item()
                    self.log_dict['mi_loss'] = mi_loss.item()
                    self.log_dict['mi_loss_next'] = mi_loss_next.item()
                    self.log_dict['ent_loss'] = ent_loss.item()
                    self.log_dict['ent_loss_next'] = ent_loss_next.item()
                    self.log_dict['crossent_loss'] = crossent_loss.item()
                    self.log_dict[
                        'crossent_loss_next'] = crossent_loss_next.item()
                    self.log_dict['D(real)'] = probs_real.data.mean()
                    self.log_dict['D(fake)_before'] = probs_fake.data.mean()
                    self.log_dict['D(fake)_after'] = probs_fake_2.data.mean()

                    write_stats_from_var(self.log_dict, Q_c_given_x,
                                         'Q_c_given_real_x_mu')
                    write_stats_from_var(self.log_dict,
                                         Q_c_given_x,
                                         'Q_c_given_real_x_mu',
                                         idx=0)
                    write_stats_from_var(self.log_dict, Q_c_given_x_var,
                                         'Q_c_given_real_x_variance')
                    write_stats_from_var(self.log_dict,
                                         Q_c_given_x_var,
                                         'Q_c_given_real_x_variance',
                                         idx=0)

                    write_stats_from_var(self.log_dict, t_mu, 't_mu')
                    write_stats_from_var(self.log_dict, t_mu, 't_mu', idx=0)
                    write_stats_from_var(self.log_dict, t_diff, 't_diff')
                    write_stats_from_var(self.log_dict,
                                         t_diff,
                                         't_diff',
                                         idx=0)
                    write_stats_from_var(self.log_dict, t_variance,
                                         't_variance')
                    write_stats_from_var(self.log_dict,
                                         t_variance,
                                         't_variance',
                                         idx=0)

                    print('\n#######################'
                          '\nEpoch/Iter:%d/%d; '
                          '\nDloss: %.3f; '
                          '\nGloss: %.3f; '
                          '\nQloss: %.3f, %.3f; '
                          '\nT_loss: %.3f; '
                          '\nEnt: %.3f, %.3f; '
                          '\nCross Ent: %.3f, %.3f; '
                          '\nD(x): %.3f; '
                          '\nD(G(z)): b %.3f, a %.3f;'
                          '\n0_Q_c_given_rand_x_mean: %.3f'
                          '\n0_Q_c_given_rand_x_std: %.3f'
                          '\n0_Q_c_given_fixed_x_std: %.3f'
                          '\nt_diff_abs_mean: %.3f'
                          '\nt_std_mean: %.3f' % (
                              epoch,
                              num_iters,
                              D_loss.item(),
                              G_loss.item(),
                              mi_loss.item(),
                              mi_loss_next.item(),
                              T_loss.item(),
                              ent_loss.item(),
                              ent_loss_next.item(),
                              crossent_loss.item(),
                              crossent_loss_next.item(),
                              probs_real.data.mean(),
                              probs_fake.data.mean(),
                              probs_fake_2.data.mean(),
                              Q_c_given_x[:, 0].cpu().numpy().mean(),
                              Q_c_given_x[:, 0].cpu().numpy().std(),
                              np.sqrt(Q_c_given_x_var[:,
                                                      0].cpu().numpy().mean()),
                              t_diff.data.abs().mean(),
                              t_variance.data.sqrt().mean(),
                          ))
            #############################################
            # Start evaluation from here.
            self.G.eval()
            self.D.eval()
            self.Q.eval()
            self.T.eval()
            #############################################
            # Save images
            # Plot fake data
            x_save, x_next_save = self.G(*self.eval_input,
                                         self.get_c_next(epoch))
            save_image(x_save.data,
                       os.path.join(self.out_dir, 'gen',
                                    'curr_samples_%03d.png' % epoch),
                       nrow=self.test_num_codes,
                       normalize=True)
            save_image(x_next_save.data,
                       os.path.join(self.out_dir, 'gen',
                                    'next_samples_%03d.png' % epoch),
                       nrow=self.test_num_codes,
                       normalize=True)
            save_image((x_save - x_next_save).data,
                       os.path.join(self.out_dir, 'gen',
                                    'diff_samples_%03d.png' % epoch),
                       nrow=self.test_num_codes,
                       normalize=True)
            # Plot real data.
            if epoch % 10 == 0:
                save_image(real_o.data,
                           os.path.join(self.out_dir, 'real',
                                        'real_samples_%d.png' % epoch),
                           nrow=self.test_num_codes,
                           normalize=True)
                save_image(real_o_next.data,
                           os.path.join(self.out_dir, 'real',
                                        'real_samples_next_%d.png' % epoch),
                           nrow=self.test_num_codes,
                           normalize=True)
            #############################################
            # Save parameters
            if epoch % 5 == 0:
                if not os.path.exists('%s/var' % self.out_dir):
                    os.makedirs('%s/var' % self.out_dir)
                for i in [self.G, self.D, self.Q, self.T]:
                    torch.save(
                        i.state_dict(),
                        os.path.join(self.out_dir, 'var', '%s_%d' % (
                            i.__class__.__name__,
                            epoch,
                        )))
            #############################################
            # Logging (epoch)
            for k, v in self.log_dict.items():
                log_value(k, v, epoch)

            if epoch > 0:
                # tf logger
                # log_value('avg|x_next - x|', (x_next_save.data - x_save.data).abs().mean(dim=0).sum(), epoch + 1)
                # self.logger.histo_summary("Q_c_given_x", Q_c_given_x.data.cpu().numpy().reshape(-1), step=epoch)
                # self.logger.histo_summary("Q_c0_given_x", Q_c_given_x[:, 0].data.cpu().numpy(), step=epoch)
                # self.logger.histo_summary("Q_c_given_x_var", Q_c_given_x_var.cpu().numpy().reshape(-1), step=epoch)
                # self.logger.histo_summary("Q_c0_given_x_var", Q_c_given_x_var[:, 0].data.cpu().numpy(), step=epoch)

                # csv log
                with open(os.path.join(self.out_dir, 'progress.csv'),
                          'a') as csv_file:
                    writer = csv.writer(csv_file)
                    if epoch == 1:
                        writer.writerow(["epoch"] + list(self.log_dict.keys()))
                    writer.writerow([
                        "%.3f" % _tmp
                        for _tmp in [epoch] + list(self.log_dict.values())
                    ])
            #############################################
            # Do planning?
            if self.plan_length <= 0 or epoch not in self.planning_epoch:
                continue
            print("\n#######################" "\nPlanning")
            #############################################
            # Showing plans on real images using best code.
            # Min l2 distance from start and goal real images.
            self.plan_hack(data_start_loader, data_goal_loader, epoch, 'L2')

            # Min classifier distance from start and goal real images.
            self.plan_hack(data_start_loader, data_goal_loader, epoch,
                           'classifier')

    def train(self):
        # Set up training.
        real_o = Variable(torch.FloatTensor(self.batch_size, 3, 64, 64).cuda(),
                          requires_grad=False)
        real_o_next = Variable(torch.FloatTensor(self.batch_size, 3, 64, 64).cuda(),
                               requires_grad=False)
        label = Variable(torch.FloatTensor(self.batch_size).cuda(),
                         requires_grad=False)
        z = Variable(torch.FloatTensor(self.batch_size, self.rand_z_dim).cuda(),
                     requires_grad=False)

        criterionD = nn.BCELoss().cuda()

        optimD = optim.Adam([{'params': self.D.parameters()}], lr=self.lr_d,
                            betas=(0.5, 0.999))
        optimG = optim.Adam([{'params': self.G.parameters()},
                             {'params': self.Q.parameters()},
                             {'params': self.T.parameters()}], lr=self.lr_g,
                            betas=(0.5, 0.999))
        ############################################
        # Load rope dataset and apply transformations
        rope_path = os.path.realpath(self.data_dir)

        trans = [
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]

        if not self.fcn:
            # If fcn is enabled, the grayscale conversion and normalization
            # are applied inside the training loop instead.
            trans.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
            if self.gray:
                # Apply grayscale transformation.
                trans.append(lambda x: x.mean(dim=0)[None, :, :])

        trans_comp = transforms.Compose(trans)
        # Image 1 and image 2 are k steps apart.
        dataset = ImagePairs(root=rope_path,
                             transform=trans_comp,
                             n_frames_apart=self.k)
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=self.batch_size,
                                                 shuffle=True,
                                                 num_workers=2,
                                                 drop_last=True)
        ############################################
        # Load eval plan dataset
        planning_data_dir = self.planning_data_dir
        dataset_start = dset.ImageFolder(root=os.path.join(planning_data_dir, 'start'),
                                         transform=trans_comp)
        dataset_goal = dset.ImageFolder(root=os.path.join(planning_data_dir, 'goal'),
                                        transform=trans_comp)
        data_start_loader = torch.utils.data.DataLoader(dataset_start,
                                                        batch_size=1,
                                                        shuffle=False,
                                                        num_workers=1,
                                                        drop_last=True)
        data_goal_loader = torch.utils.data.DataLoader(dataset_goal,
                                                       batch_size=1,
                                                       shuffle=False,
                                                       num_workers=1,
                                                       drop_last=True)
        ############################################
        for epoch in range(self.n_epochs + 1):
            self.G.train()
            self.D.train()
            self.Q.train()
            self.T.train()
            for num_iters, batch_data in enumerate(dataloader, 0):
                # Real data
                o, _ = batch_data[0]
                o_next, _ = batch_data[1]
                bs = o.size(0)

                real_o.data.resize_(o.size())
                real_o_next.data.resize_(o_next.size())
                label.data.resize_(bs)

                real_o.data.copy_(o)
                real_o_next.data.copy_(o_next)

                # Plot real data:
                if epoch % 10 == 0:
                    save_image(real_o.data,
                               os.path.join(self.out_dir, 'real', 'real_preFcn_samples_%d.png' % epoch),
                               nrow=self.test_num_codes,
                               normalize=True)
                    save_image(real_o_next.data,
                               os.path.join(self.out_dir, 'real', 'real_preFcn_samples_next_%d.png' % epoch),
                               nrow=self.test_num_codes,
                               normalize=True)

                if self.fcn:
                    # apply_fcn_mse returns a grayscale image in [-1, 1]; each
                    # pixel encodes the probability of belonging to the object.
                    real_o = self.apply_fcn_mse(o)
                    real_o_next = self.apply_fcn_mse(o_next)
                    if real_o.abs().max() > 1:
                        import ipdb
                        ipdb.set_trace()
                    assert real_o.abs().max() <= 1

                if epoch == 0:
                    break
                ############################################
                # D Loss (Update D)
                optimD.zero_grad()
                # Real data
                probs_real = self.D(real_o, real_o_next)
                label.data.fill_(1)  # real data is labeled 1
                loss_real = criterionD(probs_real, label)
                loss_real.backward()  # gradients flow only through the discriminator (real data -> D -> loss)

                # Fake data
                z, c, c_next = self._noise_sample(z, bs)  # z and c are normally distributed; c_next is Gaussian with mean c and a default variance
                fake_o, fake_o_next = self.G(z, c, c_next)
                probs_fake = self.D(fake_o.detach(), fake_o_next.detach())
                label.data.fill_(0)  # fake data is labeled 0
                loss_fake = criterionD(probs_fake, label)
                loss_fake.backward()  # detach() stops gradients at the generator, so only D is updated (noise -> G -> detach -> D -> loss)

                D_loss = loss_real + loss_fake

                optimD.step()  # update the discriminator's weights
                ############################################
                # G loss (Update G)
                optimG.zero_grad()

                probs_fake_2 = self.D(fake_o, fake_o_next)
                label.data.fill_(1)  # the generator wants D to output 1 (i.e. "real")
                G_loss = criterionD(probs_fake_2, label)

                # Q loss (Update G, T, Q)
                ent_loss = -self.P.log_prob(c).mean(0)  # always equals log(2); only the size of c is used
                crossent_loss = -self.Q.log_prob(fake_o, c).mean(0)
                # fake_o is passed through the network Q, which outputs (mu, var);
                # that defines a density into which c is plugged, giving the
                # likelihood of each c given its fake_o; we then take the mean.
                crossent_loss_next = -self.Q.log_prob(fake_o_next, c_next).mean(0)
                # trans_prob = self.T.get_prob(Variable(torch.eye(self.dis_c_dim).cuda()))
                ent_loss_next = -self.T.log_prob(c, None, c_next).mean(0)
                mi_loss = crossent_loss - ent_loss
                mi_loss_next = crossent_loss_next - ent_loss_next
                Q_loss = mi_loss + mi_loss_next

                # T loss (Update T)
                Q_c_given_x, Q_c_given_x_var = (i.detach() for i in self.Q.forward(real_o))
                t_mu, t_variance = self.T.get_mu_and_var(c)
                t_diff = t_mu - c
                # Keep the variance small.
                # TODO: add loss on t_diff
                T_loss = (t_variance ** 2).sum(1).mean(0)

                (G_loss +
                 self.infow * Q_loss +
                 self.transw * T_loss).backward()
                optimG.step()
                #############################################
                # Logging (iteration)
                if num_iters % 100 == 0:
                    # Leftover GPU-temperature debugging hook.
                    os.system('nvidia-settings -q gpucoretemp')

                    self.log_dict['Dloss'] = D_loss.item()
                    self.log_dict['Gloss'] = G_loss.item()
                    self.log_dict['Qloss'] = Q_loss.item()
                    self.log_dict['Tloss'] = T_loss.item()
                    self.log_dict['mi_loss'] = mi_loss.item()
                    self.log_dict['mi_loss_next'] = mi_loss_next.item()
                    self.log_dict['ent_loss'] = ent_loss.item()
                    self.log_dict['ent_loss_next'] = ent_loss_next.item()
                    self.log_dict['crossent_loss'] = crossent_loss.item()
                    self.log_dict['crossent_loss_next'] = crossent_loss_next.item()
                    self.log_dict['D(real)'] = probs_real.data.mean()
                    self.log_dict['D(fake)_before'] = probs_fake.data.mean()
                    self.log_dict['D(fake)_after'] = probs_fake_2.data.mean()

                    write_stats_from_var(self.log_dict, Q_c_given_x, 'Q_c_given_real_x_mu')
                    write_stats_from_var(self.log_dict, Q_c_given_x, 'Q_c_given_real_x_mu', idx=0)
                    write_stats_from_var(self.log_dict, Q_c_given_x_var, 'Q_c_given_real_x_variance')
                    write_stats_from_var(self.log_dict, Q_c_given_x_var, 'Q_c_given_real_x_variance', idx=0)

                    write_stats_from_var(self.log_dict, t_mu, 't_mu')
                    write_stats_from_var(self.log_dict, t_mu, 't_mu', idx=0)
                    write_stats_from_var(self.log_dict, t_diff, 't_diff')
                    write_stats_from_var(self.log_dict, t_diff, 't_diff', idx=0)
                    write_stats_from_var(self.log_dict, t_variance, 't_variance')
                    write_stats_from_var(self.log_dict, t_variance, 't_variance', idx=0)

                    print('\n#######################'
                          '\nEpoch/Iter:%d/%d; '
                          '\nDloss: %.3f; '
                          '\nGloss: %.3f; '
                          '\nQloss: %.3f, %.3f; '
                          '\nT_loss: %.3f; '
                          '\nEnt: %.3f, %.3f; '
                          '\nCross Ent: %.3f, %.3f; '
                          '\nD(x): %.3f; '
                          '\nD(G(z)): b %.3f, a %.3f;'
                          '\n0_Q_c_given_rand_x_mean: %.3f'
                          '\n0_Q_c_given_rand_x_std: %.3f'
                          '\n0_Q_c_given_fixed_x_std: %.3f'
                          '\nt_diff_abs_mean: %.3f'
                          '\nt_std_mean: %.3f'
                          % (epoch, num_iters,
                             D_loss.item(),
                             G_loss.item(),
                             mi_loss.item(), mi_loss_next.item(),
                             T_loss.item(),
                             ent_loss.item(), ent_loss_next.item(),
                             crossent_loss.item(), crossent_loss_next.item(),
                             probs_real.data.mean(),
                             probs_fake.data.mean(), probs_fake_2.data.mean(),
                             Q_c_given_x[:, 0].data.mean(),
                             Q_c_given_x[:, 0].data.std(),
                             np.sqrt(Q_c_given_x_var[:, 0].data.mean().item()),
                             t_diff.data.abs().mean(),
                             t_variance.data.sqrt().mean(),
                             ))
            #############################################
            # Start evaluation from here.
            self.G.eval()
            self.D.eval()
            self.Q.eval()
            self.T.eval()
            #############################################
            # Save images
            # Plot fake data
            x_save, x_next_save = self.G(*self.eval_input, self.get_c_next(epoch))
            save_image(x_save.data,
                       os.path.join(self.out_dir, 'gen', 'curr_samples_%03d.png' % epoch),
                       nrow=self.test_num_codes,
                       normalize=True)
            save_image(x_next_save.data,
                       os.path.join(self.out_dir, 'gen', 'next_samples_%03d.png' % epoch),
                       nrow=self.test_num_codes,
                       normalize=True)
            save_image((x_save - x_next_save).data,
                       os.path.join(self.out_dir, 'gen', 'diff_samples_%03d.png' % epoch),
                       nrow=self.test_num_codes,
                       normalize=True)
            # Plot real data.
            if epoch % 10 == 0:
                save_image(real_o.data,
                           os.path.join(self.out_dir, 'real', 'real_samples_%d.png' % epoch),
                           nrow=self.test_num_codes,
                           normalize=True)
                save_image(real_o_next.data,
                           os.path.join(self.out_dir, 'real', 'real_samples_next_%d.png' % epoch),
                           nrow=self.test_num_codes,
                           normalize=True)
            #############################################
            # Save parameters
            if epoch % 5 == 0:
                if not os.path.exists('%s/var' % self.out_dir):
                    os.makedirs('%s/var' % self.out_dir)
                for i in [self.G, self.D, self.Q, self.T]:
                    torch.save(i.state_dict(),
                               os.path.join(self.out_dir,
                                            'var',
                                            '%s_%d' % (i.__class__.__name__, epoch,
                                                       )))
            #############################################
            # Logging (epoch)
            for k, v in self.log_dict.items():
                log_value(k, v, epoch)

            if epoch > 0:
                # tf logger
                # log_value('avg|x_next - x|', (x_next_save.data - x_save.data).abs().mean(dim=0).sum(), epoch + 1)
                # self.logger.histo_summary("Q_c_given_x", Q_c_given_x.data.cpu().numpy().reshape(-1), step=epoch)
                # self.logger.histo_summary("Q_c0_given_x", Q_c_given_x[:, 0].data.cpu().numpy(), step=epoch)
                # self.logger.histo_summary("Q_c_given_x_var", Q_c_given_x_var.cpu().numpy().reshape(-1), step=epoch)
                # self.logger.histo_summary("Q_c0_given_x_var", Q_c_given_x_var[:, 0].data.cpu().numpy(), step=epoch)

                # csv log
                with open(os.path.join(self.out_dir, 'progress.csv'), 'a') as csv_file:
                    writer = csv.writer(csv_file)
                    if epoch == 1:
                        writer.writerow(["epoch"] + list(self.log_dict.keys()))
                    writer.writerow(["%.3f" % _tmp for _tmp in [epoch] + list(self.log_dict.values())])
            #############################################
            # Do planning?
            if self.plan_length <= 0 or epoch not in self.planning_epoch:
                continue
            print("\n#######################"
                  "\nPlanning")
            #############################################
            # Showing plans on real images using best code.
            # Min l2 distance from start and goal real images.
            self.plan_hack(data_start_loader,
                           data_goal_loader,
                           epoch,
                           'L2')

            # Min classifier distance from start and goal real images.
            self.plan_hack(data_start_loader,
                           data_goal_loader,
                           epoch,
                           'classifier')