Example #1
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.height,
                                                args.width,
                                                args.train_resizing,
                                                random_horizontal_flip=True,
                                                random_color_jitter=False,
                                                random_gray_scale=False)
    val_transform = utils.get_val_transform(args.height, args.width)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    working_dir = osp.dirname(osp.abspath(__file__))
    root = osp.join(working_dir, args.root)

    # source dataset
    source_dataset = datasets.__dict__[args.source](
        root=osp.join(root, args.source.lower()))
    sampler = RandomDomainMultiInstanceSampler(
        source_dataset.train,
        batch_size=args.batch_size,
        n_domains_per_batch=2,
        num_instances=args.num_instances)
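    # The sampler presumably yields PK-style batches: identities drawn from
    # n_domains_per_batch source domains, with num_instances images per identity,
    # so the triplet loss always sees several instances of each person
    # (an assumption based on the sampler's name and arguments).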
    train_loader = DataLoader(convert_to_pytorch_dataset(
        source_dataset.train,
        root=source_dataset.images_dir,
        transform=train_transform),
                              batch_size=args.batch_size,
                              num_workers=args.workers,
                              sampler=sampler,
                              pin_memory=True,
                              drop_last=True)
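    # ForeverDataIterator presumably wraps the loader so that next() restarts it
    # transparently once the underlying DataLoader is exhausted, letting the
    # training loop draw a fixed number of iterations per epoch.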
    train_iter = ForeverDataIterator(train_loader)
    val_loader = DataLoader(convert_to_pytorch_dataset(
        list(set(source_dataset.query) | set(source_dataset.gallery)),
        root=source_dataset.images_dir,
        transform=val_transform),
                            batch_size=args.batch_size,
                            num_workers=args.workers,
                            shuffle=False,
                            pin_memory=True)

    # target dataset
    target_dataset = datasets.__dict__[args.target](
        root=osp.join(root, args.target.lower()))
    test_loader = DataLoader(convert_to_pytorch_dataset(
        list(set(target_dataset.query) | set(target_dataset.gallery)),
        root=target_dataset.images_dir,
        transform=val_transform),
                             batch_size=args.batch_size,
                             num_workers=args.workers,
                             shuffle=False,
                             pin_memory=True)

    # create model
    num_classes = source_dataset.num_train_pids
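    # mix_layers / mix_p / mix_alpha presumably configure MixStyle-like mixing of
    # per-instance feature statistics inside the selected ResNet stages, a common
    # domain-generalization trick; this is an assumption based on the argument names.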
    backbone = models.__dict__[args.arch](mix_layers=args.mix_layers,
                                          mix_p=args.mix_p,
                                          mix_alpha=args.mix_alpha,
                                          resnet_class=ReidResNet,
                                          pretrained=True)
    model = ReIdentifier(backbone, num_classes,
                         finetune=args.finetune).to(device)
    model = DataParallel(model)

    # define optimizer and learning rate scheduler
    optimizer = Adam(model.module.get_parameters(base_lr=args.lr,
                                                 rate=args.rate),
                     args.lr,
                     weight_decay=args.weight_decay)
    lr_scheduler = WarmupMultiStepLR(optimizer,
                                     args.milestones,
                                     gamma=0.1,
                                     warmup_factor=0.1,
                                     warmup_steps=args.warmup_steps)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                map_location='cpu')
        model.load_state_dict(checkpoint)

    # analyze the model
    if args.phase == 'analysis':
        # plot t-SNE
        utils.visualize_tsne(source_loader=val_loader,
                             target_loader=test_loader,
                             model=model,
                             filename=osp.join(logger.visualize_directory,
                                               'analysis', 'TSNE.pdf'),
                             device=device)
        # visualize ranked results
        visualize_ranked_results(test_loader,
                                 model,
                                 target_dataset.query,
                                 target_dataset.gallery,
                                 device,
                                 visualize_dir=logger.visualize_directory,
                                 width=args.width,
                                 height=args.height,
                                 rerank=args.rerank)
        return

    if args.phase == 'test':
        print("Test on source domain:")
        validate(val_loader,
                 model,
                 source_dataset.query,
                 source_dataset.gallery,
                 device,
                 cmc_flag=True,
                 rerank=args.rerank)
        print("Test on target domain:")
        validate(test_loader,
                 model,
                 target_dataset.query,
                 target_dataset.gallery,
                 device,
                 cmc_flag=True,
                 rerank=args.rerank)
        return

    # define loss function
    criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)
    criterion_triplet = SoftTripletLoss(margin=args.margin).to(device)

    # start training
    best_val_mAP = 0.
    best_test_mAP = 0.
    for epoch in range(args.epochs):
        # print learning rate
        print(lr_scheduler.get_lr())

        # train for one epoch
        train(train_iter, model, criterion_ce, criterion_triplet, optimizer,
              epoch, args)

        # update learning rate
        lr_scheduler.step()

        if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1):

            # evaluate on validation set
            print("Validation on source domain...")
            _, val_mAP = validate(val_loader,
                                  model,
                                  source_dataset.query,
                                  source_dataset.gallery,
                                  device,
                                  cmc_flag=True)

            # remember best mAP and save checkpoint
            torch.save(model.state_dict(),
                       logger.get_checkpoint_path('latest'))
            if val_mAP > best_val_mAP:
                shutil.copy(logger.get_checkpoint_path('latest'),
                            logger.get_checkpoint_path('best'))
            best_val_mAP = max(val_mAP, best_val_mAP)

            # evaluate on test set
            print("Test on target domain...")
            _, test_mAP = validate(test_loader,
                                   model,
                                   target_dataset.query,
                                   target_dataset.gallery,
                                   device,
                                   cmc_flag=True,
                                   rerank=args.rerank)
            best_test_mAP = max(test_mAP, best_test_mAP)

    # evaluate on test set
    model.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    print("Test on target domain:")
    _, test_mAP = validate(test_loader,
                           model,
                           target_dataset.query,
                           target_dataset.gallery,
                           device,
                           cmc_flag=True,
                           rerank=args.rerank)
    print("test mAP on target = {}".format(test_mAP))
    print("oracle mAP on target = {}".format(best_test_mAP))
    logger.close()
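
A minimal, hypothetical entry point for this script could look like the sketch below. The flag names simply mirror the attributes that main() reads from args; every default value, flag spelling, and choice list is an illustrative assumption, not the real script's configuration.

if __name__ == '__main__':
    # Hypothetical argument parser: names mirror the attributes used in main(); defaults are placeholders.
    parser = argparse.ArgumentParser(description='Cross-domain re-ID baseline (sketch)')
    # dataset parameters
    parser.add_argument('--root', default='data', help='root directory of all datasets')
    parser.add_argument('-s', '--source', default='Market1501', help='source dataset name')
    parser.add_argument('-t', '--target', default='DukeMTMC', help='target dataset name')
    parser.add_argument('--train-resizing', default='default')
    parser.add_argument('--height', type=int, default=256, help='input image height')
    parser.add_argument('--width', type=int, default=128, help='input image width')
    # model parameters
    parser.add_argument('-a', '--arch', default='resnet50')
    parser.add_argument('--mix-layers', nargs='+', default=['layer1', 'layer2'])
    parser.add_argument('--mix-p', type=float, default=0.5)
    parser.add_argument('--mix-alpha', type=float, default=0.1)
    parser.add_argument('--no-finetune', dest='finetune', action='store_false',
                        help='train the backbone from scratch instead of finetuning')
    # training parameters
    parser.add_argument('-b', '--batch-size', type=int, default=64)
    parser.add_argument('--num-instances', type=int, default=4, help='instances per identity in a batch')
    parser.add_argument('-j', '--workers', type=int, default=4)
    parser.add_argument('--lr', type=float, default=3.5e-4)
    parser.add_argument('--rate', type=float, default=0.1, help='lr multiplier for the pretrained backbone')
    parser.add_argument('--weight-decay', type=float, default=5e-4)
    parser.add_argument('--margin', type=float, default=0.0, help='margin of the triplet loss')
    parser.add_argument('--milestones', nargs='+', type=int, default=[40, 70])
    parser.add_argument('--warmup-steps', type=int, default=10)
    parser.add_argument('--epochs', type=int, default=80)
    parser.add_argument('--eval-step', type=int, default=10)
    parser.add_argument('--seed', type=int, default=None)
    parser.add_argument('--rerank', action='store_true', help='apply re-ranking at evaluation time')
    parser.add_argument('--log', default='logs/baseline')
    parser.add_argument('--phase', default='train', choices=['train', 'test', 'analysis'])
    main(parser.parse_args())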
Example #2
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.height,
                                                args.width,
                                                args.train_resizing,
                                                random_horizontal_flip=True,
                                                random_color_jitter=False,
                                                random_gray_scale=False,
                                                random_erasing=True)
    val_transform = utils.get_val_transform(args.height, args.width)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    working_dir = osp.dirname(osp.abspath(__file__))
    source_root = osp.join(working_dir, args.source_root)
    target_root = osp.join(working_dir, args.target_root)

    # source dataset
    source_dataset = datasets.__dict__[args.source](
        root=osp.join(source_root, args.source.lower()))
    val_loader = DataLoader(convert_to_pytorch_dataset(
        list(set(source_dataset.query) | set(source_dataset.gallery)),
        root=source_dataset.images_dir,
        transform=val_transform),
                            batch_size=args.batch_size,
                            num_workers=args.workers,
                            shuffle=False,
                            pin_memory=True)

    # target dataset
    target_dataset = datasets.__dict__[args.target](
        root=osp.join(target_root, args.target.lower()))
    cluster_loader = DataLoader(convert_to_pytorch_dataset(
        target_dataset.train,
        root=target_dataset.images_dir,
        transform=val_transform),
                                batch_size=args.batch_size,
                                num_workers=args.workers,
                                shuffle=False,
                                pin_memory=True)
    test_loader = DataLoader(convert_to_pytorch_dataset(
        list(set(target_dataset.query) | set(target_dataset.gallery)),
        root=target_dataset.images_dir,
        transform=val_transform),
                             batch_size=args.batch_size,
                             num_workers=args.workers,
                             shuffle=False,
                             pin_memory=True)

    # create model
    num_classes = args.num_clusters
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    model = ReIdentifier(backbone,
                         num_classes,
                         finetune=args.finetune,
                         pool_layer=pool_layer).to(device)
    model = DataParallel(model)

    # load pretrained weights
    pretrained_model = torch.load(args.pretrained_model_path)
    utils.copy_state_dict(model, pretrained_model)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                map_location='cpu')
        utils.copy_state_dict(model, checkpoint['model'])

    # analyze the model
    if args.phase == 'analysis':
        # plot t-SNE
        utils.visualize_tsne(source_loader=val_loader,
                             target_loader=test_loader,
                             model=model,
                             filename=osp.join(logger.visualize_directory,
                                               'analysis', 'TSNE.pdf'),
                             device=device)
        # visualize ranked results
        visualize_ranked_results(test_loader,
                                 model,
                                 target_dataset.query,
                                 target_dataset.gallery,
                                 device,
                                 visualize_dir=logger.visualize_directory,
                                 width=args.width,
                                 height=args.height,
                                 rerank=args.rerank)
        return

    if args.phase == 'test':
        print("Test on Source domain:")
        validate(val_loader,
                 model,
                 source_dataset.query,
                 source_dataset.gallery,
                 device,
                 cmc_flag=True,
                 rerank=args.rerank)
        print("Test on target domain:")
        validate(test_loader,
                 model,
                 target_dataset.query,
                 target_dataset.gallery,
                 device,
                 cmc_flag=True,
                 rerank=args.rerank)
        return

    # define loss function
    criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)
    criterion_triplet = SoftTripletLoss(margin=args.margin).to(device)

    # optionally resume from a checkpoint
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        utils.copy_state_dict(model, checkpoint['model'])
        args.start_epoch = checkpoint['epoch'] + 1

    # start training
    best_test_mAP = 0.
    for epoch in range(args.start_epoch, args.epochs):
        # run clustering algorithm and generate pseudo labels
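        # run_kmeans / run_dbscan presumably: extract features for all target-domain
        # training images with the current model, cluster them (k-means with a fixed
        # args.num_clusters, or DBSCAN, which also determines the number of clusters),
        # assign the cluster ids as pseudo identity labels, and return an iterator over
        # a loader built with train_transform. This reading is an assumption based on
        # how the return values are used below.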
        if args.clustering_algorithm == 'kmeans':
            train_target_iter = run_kmeans(cluster_loader, model,
                                           target_dataset, train_transform,
                                           args)
        elif args.clustering_algorithm == 'dbscan':
            train_target_iter, num_classes = run_dbscan(
                cluster_loader, model, target_dataset, train_transform, args)

        # define cross entropy loss with current number of classes
        criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)

        # define optimizer
        optimizer = Adam(model.module.get_parameters(base_lr=args.lr,
                                                     rate=args.rate),
                         args.lr,
                         weight_decay=args.weight_decay)

        # train for one epoch
        train(train_target_iter, model, optimizer, criterion_ce,
              criterion_triplet, epoch, args)

        if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1):
            # remember best mAP and save checkpoint
            torch.save({
                'model': model.state_dict(),
                'epoch': epoch
            }, logger.get_checkpoint_path(epoch))
            print("Test on target domain...")
            _, test_mAP = validate(test_loader,
                                   model,
                                   target_dataset.query,
                                   target_dataset.gallery,
                                   device,
                                   cmc_flag=True,
                                   rerank=args.rerank)
            if test_mAP > best_test_mAP:
                shutil.copy(logger.get_checkpoint_path(epoch),
                            logger.get_checkpoint_path('best'))
            best_test_mAP = max(test_mAP, best_test_mAP)

    print("best mAP on target = {}".format(best_test_mAP))
    logger.close()
Example #3
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.height,
                                                args.width,
                                                args.train_resizing,
                                                random_horizontal_flip=True,
                                                random_color_jitter=False,
                                                random_gray_scale=False,
                                                random_erasing=True)
    val_transform = utils.get_val_transform(args.height, args.width)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    working_dir = osp.dirname(osp.abspath(__file__))
    source_root = osp.join(working_dir, args.source_root)
    target_root = osp.join(working_dir, args.target_root)

    # source dataset
    source_dataset = datasets.__dict__[args.source](
        root=osp.join(source_root, args.source.lower()))
    sampler = RandomMultipleGallerySampler(source_dataset.train,
                                           args.num_instances)
    train_source_loader = DataLoader(convert_to_pytorch_dataset(
        source_dataset.train,
        root=source_dataset.images_dir,
        transform=train_transform),
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     sampler=sampler,
                                     pin_memory=True,
                                     drop_last=True)
    train_source_iter = ForeverDataIterator(train_source_loader)
    cluster_source_loader = DataLoader(convert_to_pytorch_dataset(
        source_dataset.train,
        root=source_dataset.images_dir,
        transform=val_transform),
                                       batch_size=args.batch_size,
                                       num_workers=args.workers,
                                       shuffle=False,
                                       pin_memory=True)
    val_loader = DataLoader(convert_to_pytorch_dataset(
        list(set(source_dataset.query) | set(source_dataset.gallery)),
        root=source_dataset.images_dir,
        transform=val_transform),
                            batch_size=args.batch_size,
                            num_workers=args.workers,
                            shuffle=False,
                            pin_memory=True)

    # target dataset
    target_dataset = datasets.__dict__[args.target](
        root=osp.join(target_root, args.target.lower()))
    cluster_target_loader = DataLoader(convert_to_pytorch_dataset(
        target_dataset.train,
        root=target_dataset.images_dir,
        transform=val_transform),
                                       batch_size=args.batch_size,
                                       num_workers=args.workers,
                                       shuffle=False,
                                       pin_memory=True)
    test_loader = DataLoader(convert_to_pytorch_dataset(
        list(set(target_dataset.query) | set(target_dataset.gallery)),
        root=target_dataset.images_dir,
        transform=val_transform),
                             batch_size=args.batch_size,
                             num_workers=args.workers,
                             shuffle=False,
                             pin_memory=True)

    n_s_classes = source_dataset.num_train_pids
    args.n_classes = n_s_classes + len(target_dataset.train)
    args.n_s_classes = n_s_classes
    args.n_t_classes = len(target_dataset.train)

    # create model
    backbone = models.__dict__[args.arch](pretrained=True)
    pool_layer = nn.Identity() if args.no_pool else None
    model = ReIdentifier(backbone,
                         args.n_classes,
                         finetune=args.finetune,
                         pool_layer=pool_layer)
    features_dim = model.features_dim

    idm_bn_names = filter_layers(args.stage)
    convert_dsbn_idm(model, idm_bn_names, idm=False)
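    # convert_dsbn_idm presumably replaces the BatchNorm layers named above with
    # domain-specific BN and inserts IDM (Intermediate Domain Module) blocks at the
    # stage selected by args.stage; this is an assumption based on the function name.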

    model = model.to(device)
    model = DataParallel(model)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                map_location='cpu')
        utils.copy_state_dict(model, checkpoint['model'])

    # analyze the model
    if args.phase == 'analysis':
        # plot t-SNE
        utils.visualize_tsne(source_loader=val_loader,
                             target_loader=test_loader,
                             model=model,
                             filename=osp.join(logger.visualize_directory,
                                               'analysis', 'TSNE.pdf'),
                             device=device)
        # visualize ranked results
        visualize_ranked_results(test_loader,
                                 model,
                                 target_dataset.query,
                                 target_dataset.gallery,
                                 device,
                                 visualize_dir=logger.visualize_directory,
                                 width=args.width,
                                 height=args.height,
                                 rerank=args.rerank)
        return

    if args.phase == 'test':
        print("Test on target domain:")
        validate(test_loader,
                 model,
                 target_dataset.query,
                 target_dataset.gallery,
                 device,
                 cmc_flag=True,
                 rerank=args.rerank)
        return

    # create XBM
    dataset_size = len(source_dataset.train) + len(target_dataset.train)
    memory_size = int(args.ratio * dataset_size)
    xbm = XBM(memory_size, features_dim)
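    # XBM is presumably a cross-batch memory: a FIFO queue of the last memory_size
    # feature/label pairs, mined for extra positives and negatives during training so
    # the contrastive signal is not limited to a single mini-batch.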

    # initialize source-domain class centroids
    source_feature_dict = extract_reid_feature(cluster_source_loader,
                                               model,
                                               device,
                                               normalize=True)
    source_features_per_id = {}
    for f, pid, _ in source_dataset.train:
        if pid not in source_features_per_id:
            source_features_per_id[pid] = []
        source_features_per_id[pid].append(source_feature_dict[f].unsqueeze(0))
    source_centers = [
        torch.cat(source_features_per_id[pid], 0).mean(0)
        for pid in sorted(source_features_per_id.keys())
    ]
    source_centers = torch.stack(source_centers, 0)
    source_centers = F.normalize(source_centers, dim=1)
    model.module.head.weight.data[0:n_s_classes].copy_(
        source_centers.to(device))
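    # The rows of the classification head that correspond to source identities are thus
    # initialized with the L2-normalized mean feature of each identity, so the classifier
    # starts out aligned with the source-domain feature space rather than from random weights.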

    # free memory
    del source_centers, cluster_source_loader, source_features_per_id

    # define optimizer and lr scheduler
    optimizer = Adam(model.module.get_parameters(base_lr=args.lr,
                                                 rate=args.rate),
                     args.lr,
                     weight_decay=args.weight_decay)
    lr_scheduler = StepLR(optimizer, step_size=args.step_size, gamma=0.1)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        utils.copy_state_dict(model, checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    # start training
    best_test_mAP = 0.
    for epoch in range(args.start_epoch, args.epochs):
        # run clustering algorithm and generate pseudo labels
        train_target_iter = run_dbscan(cluster_target_loader, model,
                                       target_dataset, train_transform, args)

        # train for one epoch
        print(lr_scheduler.get_lr())
        train(train_source_iter, train_target_iter, model, optimizer, xbm,
              epoch, args)

        if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1):
            # remember best mAP and save checkpoint
            torch.save(
                {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch
                }, logger.get_checkpoint_path(epoch))
            print("Test on target domain...")
            _, test_mAP = validate(test_loader,
                                   model,
                                   target_dataset.query,
                                   target_dataset.gallery,
                                   device,
                                   cmc_flag=True,
                                   rerank=args.rerank)
            if test_mAP > best_test_mAP:
                shutil.copy(logger.get_checkpoint_path(epoch),
                            logger.get_checkpoint_path('best'))
            best_test_mAP = max(test_mAP, best_test_mAP)

        # update lr
        lr_scheduler.step()

    print("best mAP on target = {}".format(best_test_mAP))
    logger.close()
Example #4
def run_exp(argsdict):
    # Example usage of the provided code with the recommended hyperparameters for training GANs.
    data_root = './'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    n_iter = 50  # N training iterations
    n_critic_updates = 1  # N critic updates per generator update
    train_batch_size = argsdict['batch_size']
    lr = 1e-4
    beta1 = 0.5
    beta2 = 0.9
    z_dim = 25
    hidden_dim = (400, 400)

    if argsdict['dataset'] in ['svhn']:
        image_shape = (3, 32, 32)
        encoding = 'tanh'
    elif argsdict['dataset'] in ['CIFAR']:
        image_shape = (3, 32, 32)
        encoding = 'sigmoid'
    elif argsdict['dataset'] in ['MNIST']:
        image_shape = (1, 28, 28)
        encoding = 'sigmoid'
    elif argsdict['dataset'] in ['Gaussian']:
        image_shape = (1, 28, 28)
        encoding = 'tanh'
        # Sample random means (mu) for the Gaussian mixture components
        mus = []
        for gaus in range(argsdict['number_gaussians']):
            mus.append([random.randint(0, 27) for _ in range(argsdict['Gauss_size'])])
        mus = torch.tensor(mus)
        argsdict['mus'] = mus

    # Use the GPU if you have one
    if torch.cuda.is_available():
        print("Using the GPU")
        device = torch.device("cuda")
    else:
        print("WARNING: You are about to run on cpu, and this will likely run out \
          of memory. \n You can try setting batch_size=1 to reduce memory usage")
        device = torch.device("cpu")

    train_loader, valid_loader, test_loader, num_samples = get_data(argsdict)
    print(device)
    print(num_samples)
    generator = Generator(image_shape, hidden_dim[0], hidden_dim[1], z_dim, encoding).to(device)
    # generator = Generatorsvhn(z_dim, hidden_dim).to(device)
    critic = Critic(image_shape, 400, 400).to(device)
    # critic = Criticsvhn(argsdict['hidden_discri_size']).to(device)

    # TODO: Adding betas seems to drive the total variation to 0; investigate why.
    # TODO: In the report, discuss how finicky the whole setup is.
    optim_critic = optim.Adam(critic.parameters(), lr=lr)  # , betas=(beta1, beta2))
    optim_generator = optim.Adam(generator.parameters(), lr=lr)  # , betas=(beta1, beta2))

    losses = Divergence(argsdict['divergence'])
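    # Divergence presumably bundles the critic loss (D_loss), the generator loss (G_loss
    # and its modified variant), and a real/fake accuracy helper (RealFake) for the chosen
    # divergence; this is inferred from how the object is used below.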
    if argsdict['use_cuda']:
        Fix_Noise = Variable(torch.normal(torch.zeros(25, z_dim), torch.ones(25, z_dim))).cuda()
    else:
        Fix_Noise = Variable(torch.normal(torch.zeros(25, z_dim), torch.ones(25, z_dim)))

    losses_Generator = []
    losses_Discriminator = []
    real_statistics = []
    fake_statistics = []

    # COMPLETE TRAINING PROCEDURE
    for epoch in range(n_iter):
        G_losses, D_losses = [], []
        real_stat, fake_stat = [], []
        if argsdict['visualize']:
            real_imgs = torch.zeros([num_samples, image_shape[1], image_shape[2]])
        for i_batch, sample_batch in enumerate(train_loader):
            optim_critic.zero_grad()
            if argsdict['use_cuda']:
                real_img, label_batch = sample_batch[0].cuda(), sample_batch[1]
            else:
                real_img, label_batch = sample_batch[0], sample_batch[1]
            if argsdict['visualize']:
                start = i_batch * train_batch_size
                real_imgs[start:start + train_batch_size] = real_img.squeeze(1)
            # fake img
            if argsdict['use_cuda']:
                noise = Variable(
                    torch.normal(torch.zeros(train_batch_size, z_dim), torch.ones(train_batch_size, z_dim))).cuda()
            else:
                noise = Variable(
                    torch.normal(torch.zeros(train_batch_size, z_dim), torch.ones(train_batch_size, z_dim)))
            fake_img = generator(noise)
            # Critic scores on real and generated images, and the critic loss
            DX_score = critic(real_img)
            DG_score = critic(fake_img)
            loss_D = losses.D_loss(DX_score, DG_score)
            fake, real = losses.RealFake(DG_score, DX_score)
            real_stat.append(real)
            fake_stat.append(fake)
            loss_D.backward()
            # D_grad=critic.x[0].weight.grad.detach()
            optim_critic.step()

            # Clip critic weights (WGAN-style weight clipping to keep the critic approximately Lipschitz)
            for p in critic.parameters():
                p.data.clamp_(-0.1, 0.1)

            # train the generator every n_critic_updates iterations
            D_losses.append(loss_D.item())
            if i_batch % n_critic_updates == 0:
                optim_generator.zero_grad()

                gen_img = generator(noise)
                if argsdict['modified_loss']:
                    DG_score = critic(gen_img)
                    # Modified generator objective: maximize instead of minimizing (the "non-saturating" trick)
                    loss_G = losses.G_loss_modified_sec_32(DG_score)
                else:
                    DG_score = critic(gen_img)
                    loss_G = losses.G_loss(DG_score)
                loss_G.backward()
                optim_generator.step()
                # Record the generator loss only on iterations where the generator was
                # actually updated; otherwise loss_G would be stale or undefined.
                G_losses.append(loss_G.item())
        # print(G_losses)
        # print(D_losses)
        # print(D_grad)
        print("Epoch[%d/%d], G Loss: %.4f, D Loss: %.4f"
              % (epoch, n_iter, np.mean(G_losses), np.mean(D_losses)))
        print(
            f"Classified on average {round(np.mean(real_stat), 2)} real examples correctly and {round(np.mean(fake_stat), 2)} fake examples correctly")
        losses_Generator.append(np.mean(G_losses))
        losses_Discriminator.append(np.mean(D_losses))
        real_statistics.append(np.mean(real_stat))
        fake_statistics.append(np.mean(fake_stat))
        if argsdict['dataset'] == 'Gaussian':
            # A bit hacky, but this recreates the data loaders to reset their iterators
            train_loader, valid_loader, test_loader, num_samples = get_data(argsdict)
        if argsdict['visualize']:
            if argsdict['use_cuda']:
                noise = Variable(torch.normal(torch.zeros(500, z_dim), torch.ones(500, z_dim))).cuda()
            else:
                noise = Variable(torch.normal(torch.zeros(500, z_dim), torch.ones(500, z_dim)))
            fake_imgs = generator(noise)
            visualize_tsne(fake_imgs, real_imgs[:500], argsdict, epoch)
        with torch.no_grad():
            img = generator(Fix_Noise)
        # Saving Images
        if argsdict['modified_loss']:
            save_image(img.view(-1, image_shape[0], image_shape[1], image_shape[2]),
                       f"{argsdict['dataset']}_IMGS/{argsdict['divergence']}/GRID_trick32%d.png" % epoch, nrow=5,
                       normalize=True)
        else:
            save_image(img.view(-1, image_shape[0], image_shape[1], image_shape[2]),
                       f"{argsdict['dataset']}_IMGS/{argsdict['divergence']}/GRID%d.png" % epoch, nrow=5,
                       normalize=True)
        with open(f"{argsdict['dataset']}_IMGS/{argsdict['divergence']}/Losses.txt", "w") as f:
            json.dump({'Gen_Loss': losses_Generator, 'Discri_Loss': losses_Discriminator, 'real_stat': real_statistics,
                       'fake_stat': fake_statistics}, f)

        # Update the losses plot every 5 epochs
        if epoch % 5 == 0 and epoch != 0:
            plot_losses(argsdict, epoch + 1, show_plot=0)

    plot_losses(argsdict, n_iter)
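
A hypothetical driver for run_exp could assemble argsdict like the sketch below; the keys are exactly the lookups the function performs above, while every value is an illustrative assumption.

if __name__ == '__main__':
    # Hypothetical configuration: keys mirror the argsdict lookups in run_exp, values are placeholders.
    argsdict = {
        'dataset': 'MNIST',                       # one of 'svhn', 'CIFAR', 'MNIST', 'Gaussian'
        'batch_size': 64,
        'divergence': 'KL',                       # whichever divergences the Divergence class supports
        'modified_loss': True,                    # use the modified generator objective
        'use_cuda': torch.cuda.is_available(),
        'visualize': False,                       # t-SNE of real vs. generated samples each epoch
        'number_gaussians': 4,                    # only read when dataset == 'Gaussian'
        'Gauss_size': 2,                          # only read when dataset == 'Gaussian'
    }
    run_exp(argsdict)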