Example #1
def test(epoch, netG, test_data):
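    # Render a sketch for every test image, save the cropped outputs under
    # <opt.output>/<epoch>/, then score that folder against the ground-truth
    # sketch directory with FID.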
    if not os.path.exists('%s/%d' % (opt.output, epoch)):
        os.mkdir('%s/%d' % (opt.output, epoch))

    print('test_data len = %d' % len(test_data))
    for i, image in enumerate(test_data):
        imgA = image['A']
        A_path = image['A_path']
        real_A = imgA.cuda()
        fake_B = netG(real_A)
        name = A_path[0].split('/')
        vutils.save_image(fake_B[:, :, 3:253, 28:228],
                          '%s/%d/%s' % (opt.output, epoch, name[-1]),
                          normalize=True,
                          scale_each=True)

    fid_value = calculate_fid_given_paths(
        '%s/%d' % (opt.output, epoch),
        '/home/lixiang/dataset/photosketch/sketch/test', 8, opt.gpuid != '',
        2048)
    print(str(epoch) + " saved")
    with open(opt.FID, 'a') as f:
        f.write('%s:%s\n' % (epoch, fid_value))
    print('calculated FID: %.4f' % fid_value)
Example #2
def evaluate_model_samples(args, simulator, x_gen):
    """ Evaluate model samples and save results """

    logger.info("Calculating likelihood of generated samples")
    try:
        if simulator.parameter_dim() is None:
            log_likelihood_gen = simulator.log_density(x_gen)
        else:
            params = simulator.default_parameters(true_param_id=args.trueparam)
            params = np.asarray([params for _ in range(args.generate)])
            log_likelihood_gen = simulator.log_density(x_gen,
                                                       parameters=params)
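        # Replace NaN log-likelihoods with a sentinel value before saving.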
        log_likelihood_gen[np.isnan(log_likelihood_gen)] = -1.0e-12
        np.save(create_filename("results", "samples_likelihood", args),
                log_likelihood_gen)
    except IntractableLikelihoodError:
        logger.info("True simulator likelihood is intractable for dataset %s",
                    args.dataset)

    logger.info("Calculating distance from manifold of generated samples")
    try:
        distances_gen = simulator.distance_from_manifold(x_gen)
        np.save(create_filename("results", "samples_manifold_distance", args),
                distances_gen)
    except NotImplementedError:
        logger.info("Cannot calculate distance from manifold for dataset %s",
                    args.dataset)

    if simulator.is_image():
        if calculate_fid_given_paths is None:
            logger.warning(
                "Cannot compute FID score, did not find FID implementation")
            return

        logger.info("Calculating FID score of generated samples")
        # The FID script needs an image folder
        with tempfile.TemporaryDirectory() as gen_dir:
            logger.debug(
                f"Storing generated images in temporary folder {gen_dir}")
            array_to_image_folder(x_gen, gen_dir)

            true_dir = create_filename("dataset", None, args) + "/test"
            os.makedirs(os.path.dirname(true_dir), exist_ok=True)
            if not os.path.exists(f"{true_dir}/0.jpg"):
                array_to_image_folder(
                    simulator.load_dataset(train=False,
                                           numpy=True,
                                           dataset_dir=create_filename(
                                               "dataset", None, args),
                                           true_param_id=args.trueparam)[0],
                    true_dir)

            logger.debug("Beginning FID calculation with batchsize 50")
            fid = calculate_fid_given_paths([gen_dir, true_dir], 50, "", 2048)
            logger.info(f"FID = {fid}")

            np.save(create_filename("results", "samples_fid", args), [fid])
Example #3
    def mytest(self, step):
        num = 200
        z = tensor2var(torch.randn(num, self.z_dim))  # (num, z_dim) latent batch

        fake_images, gf1, gf2 = self.G(z)  # 1*3*64*64
        fake_images = fake_images.data
        # inception_score(fake_images)
        # fake_images = fake_images.resize_((1, 3, 218, 178))
        for n in range(num):
            save_image(fake_images[n],
                       '/home/xinzi/dataset_k40/test_celeba/%d.jpg' % n)
            str0 = '/home/xinzi/dataset_k40/test_celeba/' + str(n) + '.jpg'
            im = Image.open(str0)
            im = im.resize((178, 218))
            im.save(str0)
        str1 = '/home/xinzi/dataset_k40/celebAtemp'
        times = 5
        fid_sum = 0
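        # Average FID over several random 40,000-image subsets of path1 to
        # reduce sampling noise in the estimate.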
        for i in range(times):
            shutil.rmtree(str1, ignore_errors=True)
            os.mkdir(str1)

            random_copyfile(self.path1, str1, 40000)

            # args.path0 = str1
            """
                dims = 64
            """
            fid_value = calculate_fid_given_paths(str1, self.path2, 100,
                                                  self.gpu != '', self.dims)
            fid_sum = fid_sum + fid_value
            print('FID: ', fid_value)
        print(float(fid_sum / times))
        with open('FID.txt', 'a') as f:
            f.write('\n')
            f.write(str(float(fid_sum / times)))
Example #4
def validate(test_dataloader, netG_A2B, netG_B2A, dataset):
    now = datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    temp_buffer = "fid_buffer_{}".format(timestamp)
    if not os.path.exists(temp_buffer):
        os.mkdir(temp_buffer)
        os.mkdir(os.path.join(temp_buffer, 'A'))
        os.mkdir(os.path.join(temp_buffer, 'B'))
        
    for i, batch in enumerate(test_dataloader):
        # Set model input
        real_A = batch['A'].cuda()
        real_B = batch['B'].cuda()

        # Generate output
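        # netG outputs lie in [-1, 1]; rescale them to [0, 1] before saving.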
        fake_B = 0.5*(netG_A2B(real_A).data + 1.0)
        fake_A = 0.5*(netG_B2A(real_B).data + 1.0)

        # Save image files
        save_image(fake_A, os.path.join(temp_buffer, 'A', '%04d.png' % (i+1)))
        save_image(fake_B, os.path.join(temp_buffer, 'B', '%04d.png' % (i+1)))
        
    fid_value_A = calculate_fid_given_paths(['datasets/{}/train/A'.format(dataset), os.path.join(temp_buffer, 'A')],
                                            1, True, 2048)
    fid_value_B = calculate_fid_given_paths(['datasets/{}/train/B'.format(dataset), os.path.join(temp_buffer, 'B')],
                                            1, True, 2048)

    os.system('rm -rf {}'.format(temp_buffer))
    print("FID A: {}, FID B: {}".format(fid_value_A, fid_value_B))
    return fid_value_A, fid_value_B
Example #5
def run(config):
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    if config['resume']:
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'
    utils.seed_rng(config['seed'])
    utils.prepare_root(config)
    torch.backends.cudnn.benchmark = True
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name'] else utils.name_from_config(config))
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    G3 = model.Generator(**config).to(device)
    D3 = model.Discriminator(**config).to(device)
    if config['ema']:
        G_ema = model.Generator(**{**config, 'skip_init': True, 'no_optim': True}).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None
    if config['G_fp16']:
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        D = D.half()
    GD = model.G_D(G, D, config['conditional'])
    GD3 = model.G_D(G3, D3, config['conditional'])
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0, 'best_IS': 0, 'best_FID': 999999, 'config': config}
    if config['resume']:
        utils.load_weights(G, D, state_dict, config['weights_root'], experiment_name, config['load_weights'] if config['load_weights'] else None, G_ema if config['ema'] else None)
    #utils.load_weights(G, D, state_dict, '../Task3_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'best0', G_ema if config['ema'] else None)
    #utils.load_weights(G, D, state_dict, '../Task1_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'best0', G_ema if config['ema'] else None)
    #utils.load_weights(G3, D3, state_dict, '../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'last0', G_ema if config['ema'] else None)
    #utils.load_weights(G3, D3, state_dict, '../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'best0', G_ema if config['ema'] else None)
    utils.load_weights(G3, D3, state_dict, '../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'last0', G_ema if config['ema'] else None)
    utils.load_weights(G, D, state_dict, '../Task3_CIFAR_MNIST_KLWGAN_Simulation_Experiment/weights', 'C10Ukl5', 'best0', G_ema if config['ema'] else None)
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'], experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    test_log = utils.MetricsLogger(test_metrics_fname, reinitialize=(not config['resume']))
    train_log = utils.MyLogger(train_metrics_fname, reinitialize=(not config['resume']), logstyle=config['logstyle'])
    utils.write_metadata(config['logs_root'], experiment_name, config, state_dict)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] * config['num_D_accumulations'])
    # Use: config['abnormal_class']
    #print(config['abnormal_class'])
    abnormal_class = config['abnormal_class']
    select_dataset = config['select_dataset']
    #print(config['select_dataset'])
    #print(select_dataset)
    loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size, 'start_itr': state_dict['itr'], 'abnormal_class': abnormal_class, 'select_dataset': select_dataset})
    # Usage: --select_dataset cifar10 --abnormal_class 0 --shuffle --batch_size 64 --parallel --num_G_accumulations 1 --num_D_accumulations 1 --num_epochs 500 --num_D_steps 4 --G_lr 2e-4 --D_lr 2e-4 --dataset C10 --data_root ../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/data/ --G_ortho 0.0 --G_attn 0 --D_attn 0 --G_init N02 --D_init N02 --ema --use_ema --ema_start 1000 --start_eval 50 --test_every 5000 --save_every 2000 --num_best_copies 5 --num_save_copies 2 --loss_type kl_5 --seed 2 --which_best FID --model BigGAN --experiment_name C10Ukl5
    # Use: --select_dataset mnist --abnormal_class 1 --shuffle --batch_size 64 --parallel --num_G_accumulations 1 --num_D_accumulations 1 --num_epochs 500 --num_D_steps 4 --G_lr 2e-4 --D_lr 2e-4 --dataset C10 --data_root ../Task2_CIFAR_MNIST_KLWGAN_Simulation_Experiment/data/ --G_ortho 0.0 --G_attn 0 --D_attn 0 --G_init N02 --D_init N02 --ema --use_ema --ema_start 1000 --start_eval 50 --test_every 5000 --save_every 2000 --num_best_copies 5 --num_save_copies 2 --loss_type kl_5 --seed 2 --which_best FID --model BigGAN --experiment_name C10Ukl5
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], device=device, fp16=config['G_fp16'])
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], device=device, fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    if not config['conditional']:
        fixed_y.zero_()
        y_.zero_()
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G3, D3, GD3, G3, D3, GD3, G, D, GD, z_, y_, ema, state_dict, config)
    else:
        train = train_fns.dummy_training_function()
    sample = functools.partial(utils.sample, G=(G_ema if config['ema'] and config['use_ema'] else G), z_=z_, y_=y_, config=config)
    if config['dataset'] == 'C10U' or config['dataset'] == 'C10':
        data_moments = 'fid_stats_cifar10_train.npz'
        #'../Task1_CIFAR_MNIST_KLWGAN_Simulation_Experiment/fid_stats_cifar10_train.npz'
        #data_moments = '../Task1_CIFAR_MNIST_KLWGAN_Simulation_Experiment/fid_stats_cifar10_train.npz'
    else:
        print("Cannot find the data set.")
        sys.exit()
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        if config['pbar'] == 'mine':
            pbar = utils.progress(loaders[0], displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            state_dict['itr'] += 1
            G.eval()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            print('')
            # Random seed
            #print(config['seed'])
            if epoch==0 and i==0:
                print(config['seed'])
            metrics = train(x, y)
            # We double the learning rate if we double the batch size.
            train_log.log(itr=int(state_dict['itr']), **metrics)
            if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']), **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})
            if config['pbar'] == 'mine':
                print(', '.join(['itr: %d' % state_dict['itr']] + ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]), end=' ')
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y, state_dict, config, experiment_name)
            experiment_name = (config['experiment_name'] if config['experiment_name'] else utils.name_from_config(config))
            if (not (state_dict['itr'] % config['test_every'])) and (epoch >= config['start_eval']):
                if config['G_eval_mode']:
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                utils.sample_inception(
                    G_ema if config['ema'] and config['use_ema'] else G, config, str(epoch))
                folder_number = str(epoch)
                sample_moments = '%s/%s/%s/samples.npz' % (config['samples_root'], experiment_name, folder_number)
                FID = fid_score.calculate_fid_given_paths([data_moments, sample_moments], batch_size=50, cuda=True, dims=2048)
                train_fns.update_FID(G, D, G_ema, state_dict, config, FID, experiment_name, test_log)
        state_dict['epoch'] += 1
    #utils.save_weights(G, D, state_dict, config['weights_root'], experiment_name, 'be01Bes01Best%d' % state_dict['save_best_num'], G_ema if config['ema'] else None)
    utils.save_weights(G, D, state_dict, config['weights_root'], experiment_name, 'last%d' % 0, G_ema if config['ema'] else None)
Example #6
File: train.py Project: dvlab-research/SMR
                with torch.no_grad():
                    Ae = netE(Xa)
                    Xer, Ae = diffRender.render(**Ae)

                    Ai = deep_copy(Ae)
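                    # Randomize the azimuth so the re-render yields a rotated
                    # view for the rotation-FID metric below.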
                    Ai['azimuths'] = - torch.empty((Xa.shape[0]), dtype=torch.float32).uniform_(-opt.azi_scope/2, opt.azi_scope/2).cuda()
                    Xir, Ai = diffRender.render(**Ai)

                    for i in range(len(paths)):
                        path = paths[i]
                        image_name = os.path.basename(path)
                        rec_path = os.path.join(rec_dir, image_name)
                        output_Xer = to_pil_image(Xer[i, :3].detach().cpu())
                        output_Xer.save(rec_path, 'JPEG', quality=100)

                        inter_path = os.path.join(inter_dir, image_name)
                        output_Xir = to_pil_image(Xir[i, :3].detach().cpu())
                        output_Xir.save(inter_path, 'JPEG', quality=100)

                        ori_path = os.path.join(ori_dir, image_name)
                        output_Xa = to_pil_image(Xa[i, :3].detach().cpu())
                        output_Xa.save(ori_path, 'JPEG', quality=100)
            fid_recon = calculate_fid_given_paths([ori_dir, rec_dir], 32, True)
            print('Test recon fid: %0.2f' % fid_recon)
            summary_writer.add_scalar('Test/fid_recon', fid_recon, epoch)

            fid_inter = calculate_fid_given_paths([ori_dir, inter_dir], 32, True)
            print('Test rotation fid: %0.2f' % fid_inter)
            summary_writer.add_scalar('Test/fid_inter', fid_inter, epoch)
            netE.train()
Example #7
plt.subplot(1, 2, 2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
plt.show()

for d in ('./checkpoint', './dataset', './img', './img/real', './img/fake'):
    os.makedirs(d, exist_ok=True)

fake_dataset = netG(fixed_noise)
fake_image_path_list = save_image_list(fake_dataset, False)

real_dataset = real_batch[0].to(device)[:1000]
real_image_path_list = save_image_list(real_dataset, True)

fid_value = calculate_fid_given_paths(
    [real_image_path_list, fake_image_path_list], 50, True, 2048)
print(f'FID score: {fid_value}')
Example #8
m1, s1 = fid_score.getRealm1s1(paths, batch_size, cuda, dims, model)
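# Compute the real-data statistics (m1, s1) once up front; each FID call below then only processes the fake images.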

print('finish prepare FID')

# %% calc FID per dir

# aug_type_vec   = ['translationX']
# lambda_vec     = np.arange(0.1, 1.1, 0.1).round(2)
matplotlib.use('Agg')
for aug_type in aug_type_vec:
    fid_vec = []
    for lam in lambda_vec:
        print(aug_type, lam)
        folder_name = aug_type + '_' + str(lam)  # .replace('.','p')
        paths = ['output/real_samples', 'weights/'+folder_name+'/fake_samples/']
        fid   = fid_score.calculate_fid_given_paths(paths, batch_size, cuda, dims, model, m1, s1)
        with open("output/"+folder_name+"/result_FID.txt", "w") as file:
            file.write(str(fid))
        fid_vec.append(fid)

    plt.figure()  # fresh figure, so curves from previous aug types don't accumulate
    plt.plot(lambda_vec, np.array(fid_vec))
    plt.title(aug_type)
    plt.xlabel("lambda_aug")
    plt.ylabel("fid score")
    plt.savefig("output/"+aug_type+"_result_FID.png")
print("FINISH calc FID per dir")

# %% read FID per dir

# matplotlib.use('Agg')
Example #9
                torch.save(netG.state_dict(),
                           f"./checkpoint/netG_{version}_epoch_{epoch+1}.pth")
                torch.save(netD.state_dict(),
                           f"./checkpoint/netD_{version}_epoch_{epoch+1}.pth")
                print(f"Save checkpoint at epoch {epoch+1}")
                # f"./checkpoint/netG_{version}_epoch_{epoch+1}.pth"

    # TEST
    print("Start Test")
    TB = 1000
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=TB,
                                             shuffle=True,
                                             num_workers=workers)

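    # Grab a single batch of real images to serve as the reference set.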
    for i, (data, _) in enumerate(dataloader):
        real_dataset = data
        break

    noise = torch.randn(TB, nz, 1, 1).cuda()
    fake_dataset = netG(noise)

    print("Generate real/fake images")
    real_image_path_list = save_image_list(real_dataset, True)
    fake_image_path_list = save_image_list(fake_dataset, False)

    print("Evaluate FID score")
    fid_value = calculate_fid_given_paths(
        [real_image_path_list, fake_image_path_list], 1000, False, 2048)

    print(f'FID score: {fid_value}')
Example #10
    def fid_score(self):
        return calculate_fid_given_paths([self._data_path, self._output_path],
                                         32, self._cuda, 2048)
Example #11
    def get_fid(self):
        """Calculate the FID score"""
        # Load the trained generator.
        self.restore_model(self.test_iters)
        data_loader = self.labelled_loader

        if self.dataset == 'CelebA':
            attr_list = self.selected_attrs
        elif self.dataset == 'RaFD':
            attr_list = self.selected_emots

        real_dir = os.path.join(self.result_dir, 'real')
        if not os.path.isdir(real_dir):
            os.mkdir(real_dir)
        fake_dir = os.path.join(self.result_dir, 'fake_all')
        if not os.path.isdir(fake_dir):
            os.mkdir(fake_dir)
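            # Generation below runs only when fake_dir is first created; an
            # existing folder is reused as a cache of saved images.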

            with torch.no_grad():
                r_count = 0
                f_count = 0
                for i, (x_real, c_org) in enumerate(data_loader):

                    # Prepare input images and target domain labels.
                    x_real = x_real.to(self.device)
                    c_trg_list = self.create_labels(c_org, self.c_dim,
                                                    self.dataset,
                                                    self.selected_attrs)
                    if self.dataset == 'RaFD':
                        c_org = self.label2onehot(c_org, self.c_dim)

                    # Save real images.
                    for r in range(x_real.size(0)):
                        r_count += 1
                        c_o = ''.join(
                            c_org[r].numpy().astype('int8').astype('str'))
                        real_path = os.path.join(
                            real_dir, 'r_{}_{}.jpg'.format(r_count, c_o))
                        save_image(self.denorm(x_real[r].data.cpu()),
                                   real_path)

                    # Translate real images into fake ones and save them.
                    f_count_old = f_count
                    for c_trg, attr_name in zip(c_trg_list, attr_list):
                        x_fake = self.G(x_real, c_trg)
                        if f_count > f_count_old:
                            f_count = f_count_old
                        for f in range(x_fake.size(0)):
                            f_count += 1
                            c_t = ''.join(c_trg[f].cpu().numpy().astype(
                                'int8').astype('str'))
                            fake_path = os.path.join(
                                fake_dir,
                                attr_name + '_{}_{}.jpg'.format(f_count, c_t))
                            save_image(self.denorm(x_fake[f].data.cpu()),
                                       fake_path)

        fid_path = os.path.join(self.result_dir, 'fid_scores.txt')
        fake_paths_list = []
        for attr_name in attr_list:
            fake_paths_list.append(
                list(pathlib.Path(fake_dir).glob('{}*'.format(attr_name))))
        fake_paths_list.append(fake_dir)
        fids = calculate_fid_given_paths([real_dir, fake_paths_list],
                                         128,
                                         True,
                                         2048,
                                         device=self.device_num)
        with open(fid_path, 'w') as f:
            for attr_name, fid in zip(attr_list, fids[:-1]):
                f.write('{}: {:.4f}\n'.format(attr_name, fid))
            f.write('Average FID: {:.4f}\n'.format(np.mean(fids[:-1])))
            f.write('Overall FID: {:.4f}\n'.format(fids[-1]))
        print('Overall FID: {:.4f}\n'.format(fids[-1]))

Example #12
    p_small_G = "MNIST_DCGAN_results/RGB/small_G_" + str(small_size)
    if not os.path.isdir(p_small_G):
        os.mkdir(p_small_G)
    p_mnist = "MNIST_DCGAN_results/RGB/mnist"
    if os.path.exists(p_mnist+".npz"):
        p_mnist = p_mnist+".npz"
    #path = [p_mnist, p_small_G]
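    # Convert each generated sample to a 64x64 3-channel JPEG; the Inception
    # network behind FID expects RGB inputs.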
    for i in range(0,1000):
        src=skimage.transform.resize(images_small_numpy[i][0], (64, 64))
        imageio.imwrite("temp.jpg",skimage.img_as_ubyte(src))
        src = cv2.imread("temp.jpg", 0)
        src_RGB = cv2.cvtColor(src, cv2.COLOR_GRAY2RGB)
        cv2.imwrite(p_small_G + '/'+str(i)+".jpg", src_RGB)
    
    fid = fid_score.calculate_fid_given_paths([p_mnist, p_small_G], 50, True, 2048)
    if fid < best_FID:
        best_FID = fid
        print("best_FID:" + str(fid))
        state = {'small_G': small_G.state_dict(), 'G_optimizer': G_optimizer.state_dict(), 'train_hist': train_hist, 'epoch': epoch+1, 'total_ptime': total_ptime, 'best_FID': best_FID}
        torch.save(state, "MNIST_DCGAN_results/FID/best_state_small_" + str(small_size) + ".pkl")
    state = {'small_G': small_G.state_dict(), 'G_optimizer': G_optimizer.state_dict(), 'train_hist': train_hist, 'epoch': epoch+1, 'total_ptime': total_ptime, 'best_FID': best_FID}
    torch.save(state, "MNIST_DCGAN_results/FID/state_small_" + str(small_size) + ".pkl")
    
# end_time = time.time()
# total_ptime = end_time - start_time
Example #13
import torch

from fid_score import calculate_fid_given_paths  # The code is downloaded from GitHub
fid_value = calculate_fid_given_paths(['./img/real', './img/fake'], 100, True,
                                      2048)

print(f'FID score: {fid_value}')
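
A minimal end-to-end sketch of the pattern most examples on this page share: write generated images into a folder, then score that folder against a folder of real images. It assumes the pytorch-fid style signature calculate_fid_given_paths(paths, batch_size, cuda, dims) seen in Examples #7, #13, and #16; fid_for_generator, nz, and the folder names are illustrative only, not taken from any one project.

import os

import torch
from torchvision.utils import save_image

from fid_score import calculate_fid_given_paths


def fid_for_generator(netG, real_dir, fake_dir='./img/fid_fake', n=1000, nz=100):
    """Generate n images with netG, save them, and score them against real_dir."""
    os.makedirs(fake_dir, exist_ok=True)
    netG.eval()
    with torch.no_grad():
        for i in range(n):
            noise = torch.randn(1, nz, 1, 1).cuda()
            fake = netG(noise)
            # normalize=True rescales the tensor into [0, 1] before writing.
            save_image(fake, os.path.join(fake_dir, '%d.png' % i), normalize=True)
    netG.train()
    return calculate_fid_given_paths([real_dir, fake_dir], 50, True, 2048)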
Example #14
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)

    for i, image in enumerate(test_loader):
        imgA = image[0]
        # imgB = image['A']

        real_A = Variable(imgA.cuda())
        # real_B = Variable(imgB.cuda())

        fake_B = net_G(real_A)

        # fakeB[i, :, :, :] = fake_B.data
        # realB[i, :, :, :] = real_B.data

        print("%d.jpg generate completed" % (i+88))


        # vutils.save_image(fake_B[:, :, 3:253, 28:228],
        #                   '%s/fakeB/%d.png' % (opt.output, i),
        #                   normalize=True,
        #                   scale_each=True)
        vutils.save_image(fake_B[:, :, 3:253, 28:228],
                          '%s/%d.png' % (save_dir, i),
                          normalize=True,
                          scale_each=True)

    fid_value = calculate_fid_given_paths(save_dir,
                                          opt.dataroot + '/CUHK/Testing Images/sketches',
                                          8, opt.gpuid != '', 2048)
    print(fid_value)
    # vutils.save_image(realB,
    #                   '%s/realB_8.png' % (opt.outf),
    #                   normalize=True,
    #                   scale_each=True)
Example #15
def compress(big_size, small_size):
    # training parameters
    batch_size = 128
    lr = 0.0002
    train_epoch = 20
    
    # data_loader
    img_size = 64
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=transform),
        batch_size=batch_size, shuffle=True)

    G = generator(big_size)
    small_G = generator(small_size)
    # D = discriminator(128)
    G.weight_init(mean=0.0, std=0.02)
    small_G.weight_init(mean=0.0, std=0.02)
    # D.weight_init(mean=0.0, std=0.02)
    G.cuda()
    small_G.cuda()
    # D.cuda()
    
    # Binary Cross Entropy loss
    # BCE_loss = nn.BCELoss()
    L1_loss = nn.L1Loss()
    
    # Adam optimizer
    G_optimizer = optim.Adam(small_G.parameters(), lr=lr, betas=(0.5, 0.999))
    # D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    
    # results save folder
    if not os.path.isdir('MNIST_DCGAN_results'):
        os.mkdir('MNIST_DCGAN_results')
    # if not os.path.isdir('MNIST_DCGAN_results/Random_results'):
    #     os.mkdir('MNIST_DCGAN_results/Random_results')
    # if not os.path.isdir('MNIST_DCGAN_results/Fixed_results'):
    #     os.mkdir('MNIST_DCGAN_results/Fixed_results')
    if not os.path.isdir('MNIST_DCGAN_results/Compress/'):
        os.mkdir('MNIST_DCGAN_results/Compress/')
    if not os.path.isdir('MNIST_DCGAN_results/Compress/'+ str(big_size) + '_' + str(small_size)):
        os.mkdir('MNIST_DCGAN_results/Compress/'+ str(big_size) + '_' + str(small_size))
    if not os.path.isdir('MNIST_DCGAN_results/Compress/'+str(big_size) + '_' + str(small_size)+ '/Training'):
        os.mkdir('MNIST_DCGAN_results/Compress/'+str(big_size) + '_' + str(small_size)+ '/Training')
    if not os.path.isdir('MNIST_DCGAN_results/Compress/'+ str(big_size) + '_' + str(small_size)+ '/Validation'):
        os.mkdir('MNIST_DCGAN_results/Compress/'+ str(big_size) + '_' + str(small_size)+ '/Validation')
    if not os.path.isdir('MNIST_DCGAN_results/FID/'):
        os.mkdir('MNIST_DCGAN_results/FID/')
    
    if os.path.exists("MNIST_DCGAN_results/iter/best_state_"+ str(big_size)+ ".pkl"):
        print("load")
        checkpoint = torch.load("MNIST_DCGAN_results/iter/best_state_"+ str(big_size)+ ".pkl")
        G.load_state_dict(checkpoint['G'])
    if os.path.exists("MNIST_DCGAN_results/iter/state_"+ str(small_size)+ ".pkl"):
        checkpoint = torch.load("MNIST_DCGAN_results/iter/state_" + str(small_size)+ ".pkl")
        small_G.load_state_dict(checkpoint['G'])
        
    #     D.load_state_dict(checkpoint['D'])
        G_optimizer.load_state_dict(checkpoint['G_optimizer'])
    #     D_optimizer.load_state_dict(checkpoint['D_optimizer'])
        train_hist = checkpoint['train_hist']
        total_ptime =  checkpoint['total_ptime']
        start_epoch = checkpoint['epoch']
        best_FID = checkpoint['best_FID']
        print("start from epoch" + str(start_epoch))
        #num_iter = start_epoch

    else:
        start_epoch = 0
        total_ptime = 0
        train_hist = {}
        best_FID = sys.maxsize
        #train_hist['D_losses'] = []
        train_hist['G_losses'] = []
        train_hist['per_epoch_ptimes'] = []
        train_hist['total_ptime'] = []
        #num_iter = start_epoch 
    
    print('training start!')
    # start_time = time.time()
    for epoch in range(start_epoch, train_epoch):
    #     D_losses = []
        G_losses = []
        epoch_start_time = time.time()
        num_iter = 0
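        # Distillation: small_G is trained to match the frozen big generator's
        # outputs on shared noise via an L1 loss.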
        for x_, _ in train_loader:
            mini_batch = x_.size()[0]
            G.zero_grad()
            z_ = torch.randn((mini_batch, 100)).view(-1, 100, 1, 1)
            z_ = Variable(z_.cuda())
            G_result = G(z_)
            small_G_result = small_G(z_)
            G_train_loss = L1_loss(small_G_result, G_result.detach())
            G_train_loss.backward()
            G_optimizer.step()

            G_losses.append(G_train_loss.item())

            num_iter += 1
        epoch_end_time = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time
        print(time.strftime("%Y-%m-%d %H:%M:%S",time.localtime()))

        print('[%d/%d] - ptime: %.2f, loss_g: %.3f' % ((epoch + 1), train_epoch, per_epoch_ptime, 
                                                                  torch.mean(torch.FloatTensor(G_losses))))
        #p = 'MNIST_DCGAN_results/Compress/Random_results/MNIST_DCGAN_' + str(epoch + 1) 
        
        fixed_p = 'MNIST_DCGAN_results/Compress/'+ str(big_size) + '_' + str(small_size)+ '/Validation/MNIST_DCGAN_' + str(epoch + 1) 
        #show_result((epoch+1), save=True, path=p, isFix=False)
        show_result(G, small_G, (epoch+1), save=True, path=fixed_p, isFix=True)
        #train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses)))
        train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses)))
        train_hist['per_epoch_ptimes'].append(per_epoch_ptime)
        total_ptime += per_epoch_ptime
        torch.cuda.empty_cache()
        
        
        for h in range(1):

            #export jpg
            z_ = torch.randn((10000, 100)).view(-1, 100, 1, 1)
            with torch.no_grad():
                z_ = Variable(z_.cuda())
            test_images_small = small_G(z_)
            images_small_numpy = test_images_small.cpu().data.numpy()
            torch.cuda.empty_cache()
            p_small_G = "MNIST_DCGAN_results/RGB/" +str(big_size) + "_"+ str(small_size)
            if not os.path.isdir(p_small_G):
                os.mkdir(p_small_G)
            p_mnist = "MNIST_DCGAN_results/RGB/mnist"
            if os.path.exists(p_mnist+".npz"):
                p_mnist = p_mnist+".npz"
            #path = [p_mnist, p_small_G]
            for i in range(0,10000):
                
                src=skimage.transform.resize(images_small_numpy[i][0], (64, 64))
                imageio.imwrite("temp.jpg",skimage.img_as_ubyte(src))
                src = cv2.imread("temp.jpg", 0)
                src_RGB = cv2.cvtColor(src, cv2.COLOR_GRAY2RGB)
                cv2.imwrite(p_small_G + '/'+str(h*1 + i)+".jpg", src_RGB)
        
        fid = fid_score.calculate_fid_given_paths([p_mnist, p_small_G], 50, True, 2048)
        #shutil.rmtree(p_small_G)
        if fid < best_FID:
            best_FID = fid
            print("FID update:" + str(fid))
            state = {'G': small_G.state_dict(), 'G_optimizer': G_optimizer.state_dict(),'train_hist' : train_hist, 'epoch': epoch+1, 'total_ptime' : total_ptime, 'best_FID' : best_FID}
            torch.save(state, "MNIST_DCGAN_results/iter/best_state_"+ str(small_size)+ ".pkl")
        else:
            print("FID:" + str(fid))
        state = {'G': small_G.state_dict(), 'G_optimizer': G_optimizer.state_dict(),'train_hist' : train_hist, 'epoch': epoch+1, 'total_ptime' : total_ptime, 'best_FID' : best_FID}
        torch.save(state, "MNIST_DCGAN_results/iter/state_"+ str(small_size)+ ".pkl")
        torch.cuda.empty_cache()

    # end_time = time.time()
    # total_ptime = end_time - start_time

    print("G_" + str(big_size) + '_' + str(small_size)+"_best_FID:" + str(best_FID))    
    train_hist['total_ptime'].append(total_ptime)

    print("Avg per epoch ptime: %.2f, total %d epochs ptime: %.2f" % (torch.mean(torch.FloatTensor(train_hist['per_epoch_ptimes'])), train_epoch, total_ptime))
    print("Training finish!... save training results")
Example #16
        torch.save(netG.state_dict(),
                   '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
        torch.save(netD.state_dict(),
                   '%s/netD_epoch_%d.pth' % (opt.outf, epoch))

if True:
    batch_size = opt.batchSize
    netG.load_state_dict(
        torch.load('%s/netG_epoch_%d.pth' % (opt.outf, opt.niter - 1)))

    # test netG, and calculate FID score
    netG.eval()
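    # Generate fake_images_number batches; each image gets a unique index in one flat folder.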
    for i in range(fake_images_number):
        noise = torch.randn(batch_size, nz, 1, 1).cuda()
        fake = netG(noise)
        for j in range(fake.shape[0]):
            vutils.save_image(fake.detach()[j, ...],
                              fake_folder +
                              '/{}.png'.format(j + i * batch_size),
                              normalize=True)
    netG.train()

    # calculate FID score
    fid_value = calculate_fid_given_paths([real_folder, fake_folder],
                                          opt.batchSize // 2, True)
    FIDs.append(fid_value)

    print('FID: {}'.format(fid_value))

    df = pd.DataFrame(FIDs)
    df.to_csv('FID_{}.csv'.format(opt.outf))
Example #17
def perform_evaluation(run_name, image_type):

    out_dir = os.path.join(os.getcwd(), '..', 'output', run_name)
    checkpoint_dir = os.path.join(out_dir, 'chkpts')
    checkpoints = sorted(glob.glob(os.path.join(checkpoint_dir, '*')))
    evaluation_dict = {}

    for point in checkpoints:
        iter_num = int(point.split('/')[-1].split('_')[1].split('.')[0])
        if iter_num % 10000 != 0:
            continue

        model_file = point.split('/')[-1]

        config = load_config('../configs/fr_default.yaml', None)
        is_cuda = (torch.cuda.is_available())
        checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)
        device = torch.device("cuda:0" if is_cuda else "cpu")

        generator, discriminator = build_models(config)

        # Put models on gpu if needed
        generator = generator.to(device)
        discriminator = discriminator.to(device)

        # Use multiple GPUs if possible
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)

        generator_test_9 = copy.deepcopy(generator)
        generator_test_99 = copy.deepcopy(generator)
        generator_test_999 = copy.deepcopy(generator)
        generator_test_9999 = copy.deepcopy(generator)

        # Register modules to checkpoint
        checkpoint_io.register_modules(
            generator=generator,
            generator_test_9=generator_test_9,
            generator_test_99=generator_test_99,
            generator_test_999=generator_test_999,
            generator_test_9999=generator_test_9999,
            discriminator=discriminator,
        )

        # Load checkpoint
        load_dict = checkpoint_io.load(model_file)

        # Distributions
        ydist = get_ydist(config['data']['nlabels'], device=device)
        zdist = get_zdist(config['z_dist']['type'],
                          config['z_dist']['dim'],
                          device=device)
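        # A fixed latent bank loaded from disk, so every checkpoint is scored on identical noise.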
        z_sample = torch.Tensor(np.load('z_data.npy')).to(device)

        #for name, model in zip(['0_', '09_', '099_', '0999_', '09999_'], [generator, generator_test_9, generator_test_99, generator_test_999, generator_test_9999]):
        for name, model in zip(
            ['099_', '0999_', '09999_'],
            [generator_test_99, generator_test_999, generator_test_9999]):

            # Evaluator
            evaluator = Evaluator(model, zdist, ydist, device=device)

            x_sample = []

            for i in range(10):
                x = evaluator.create_samples(z_sample[i * 1000:(i + 1) * 1000])
                x_sample.append(x)

            x_sample = torch.cat(x_sample)
            x_sample = x_sample / 2 + 0.5

            if not os.path.exists('fake_data'):
                os.makedirs('fake_data')

            for i in range(10000):
                torchvision.utils.save_image(x_sample[i, :, :, :],
                                             'fake_data/{}.png'.format(i))

            fid_score = calculate_fid_given_paths(
                ['fake_data', image_type + '_real'], 50, True, 2048)
            print(iter_num, name, fid_score)

            os.system("rm -rf " + "fake_data")

            evaluation_dict[(iter_num, name[:-1])] = {'FID': fid_score}

            if not os.path.exists('evaluation_data/' + run_name):
                os.makedirs('evaluation_data/' + run_name)

            with open('evaluation_data/' + run_name + '/eval_fid.p', 'wb') as fp:
                pickle.dump(evaluation_dict, fp)
Example #18
def run(config):
    config['resolution'] = imsize_dict[config['dataset']]
    config['n_classes'] = nclass_dict[config['dataset']]
    config['G_activation'] = activation_dict[config['G_nl']]
    config['D_activation'] = activation_dict[config['D_nl']]
    if config['resume']:
        config['skip_init'] = True
    config = update_config_roots(config)
    device = 'cuda'
    utils_Task1_KLWGAN_Simulation_Experiment.seed_rng(config['seed'])
    utils_Task1_KLWGAN_Simulation_Experiment.prepare_root(config)
    torch.backends.cudnn.benchmark = True
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name'] else utils_Task1_KLWGAN_Simulation_Experiment.name_from_config(config))
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)
    if config['ema']:
        G_ema = model.Generator(**{**config, 'skip_init': True, 'no_optim': True}).to(device)
        ema = utils_Task1_KLWGAN_Simulation_Experiment.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None
    if config['G_fp16']:
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        D = D.half()
    GD = model.G_D(G, D, config['conditional'])
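    # The state dict tracks training counters (itr, epoch) and the best IS/FID seen so far.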
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0, 'best_IS': 0, 'best_FID': 999999, 'config': config}
    if config['resume']:
        utils_Task1_KLWGAN_Simulation_Experiment.load_weights(G, D, state_dict, config['weights_root'], experiment_name, config['load_weights'] if config['load_weights'] else None, G_ema if config['ema'] else None)
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'], experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    test_log = utils_Task1_KLWGAN_Simulation_Experiment.MetricsLogger(test_metrics_fname, reinitialize=(not config['resume']))
    train_log = utils_Task1_KLWGAN_Simulation_Experiment.MyLogger(train_metrics_fname, reinitialize=(not config['resume']), logstyle=config['logstyle'])
    utils_Task1_KLWGAN_Simulation_Experiment.write_metadata(config['logs_root'], experiment_name, config, state_dict)
    D_batch_size = (config['batch_size'] * config['num_D_steps'] * config['num_D_accumulations'])
    # Use: config['abnormal_class']
    #print(config['abnormal_class'])
    abnormal_class = config['abnormal_class']
    select_dataset = config['select_dataset']
    #print(config['select_dataset'])
    #print(select_dataset)
    loaders = utils_Task1_KLWGAN_Simulation_Experiment.get_data_loaders(**{**config, 'batch_size': D_batch_size, 'start_itr': state_dict['itr'], 'abnormal_class': abnormal_class, 'select_dataset': select_dataset})
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils_Task1_KLWGAN_Simulation_Experiment.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], device=device, fp16=config['G_fp16'])
    fixed_z, fixed_y = utils_Task1_KLWGAN_Simulation_Experiment.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'], device=device, fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    if not config['conditional']:
        fixed_y.zero_()
        y_.zero_()
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema, state_dict, config)
    else:
        train = train_fns.dummy_training_function()
    sample = functools.partial(utils_Task1_KLWGAN_Simulation_Experiment.sample, G=(G_ema if config['ema'] and config['use_ema'] else G), z_=z_, y_=y_, config=config)
    if config['dataset'] == 'C10U' or config['dataset'] == 'C10':
        data_moments = 'fid_stats_cifar10_train.npz'
        #'../Task1_CIFAR_MNIST_KLWGAN_Simulation_Experiment/fid_stats_cifar10_train.npz'
        #data_moments = '../Task1_CIFAR_MNIST_KLWGAN_Simulation_Experiment/fid_stats_cifar10_train.npz'
    else:
        print("Cannot find the dataset.")
        sys.exit()
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        if config['pbar'] == 'mine':
            pbar = utils_Task1_KLWGAN_Simulation_Experiment.progress(loaders[0], displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            state_dict['itr'] += 1
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            print('')
            # Random seed
            #print(config['seed'])
            if epoch==0 and i==0:
                print(config['seed'])
            metrics = train(x, y)
            # We double the learning rate if we double the batch size.
            train_log.log(itr=int(state_dict['itr']), **metrics)
            if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']), **{**utils_Task1_KLWGAN_Simulation_Experiment.get_SVs(G, 'G'), **utils_Task1_KLWGAN_Simulation_Experiment.get_SVs(D, 'D')})
            if config['pbar'] == 'mine':
                print(', '.join(['itr: %d' % state_dict['itr']] + ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]), end=' ')
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y, state_dict, config, experiment_name)
            experiment_name = (config['experiment_name'] if config['experiment_name'] else utils_Task1_KLWGAN_Simulation_Experiment.name_from_config(config))
            if (not (state_dict['itr'] % config['test_every'])) and (epoch >= config['start_eval']):
                if config['G_eval_mode']:
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                utils_Task1_KLWGAN_Simulation_Experiment.sample_inception(G_ema if config['ema'] and config['use_ema'] else G, config, str(epoch))
                folder_number = str(epoch)
                sample_moments = '%s/%s/%s/samples.npz' % (config['samples_root'], experiment_name, folder_number)
                #FID = fid_score.calculate_fid_given_paths([data_moments, sample_moments], batch_size=50, cuda=True, dims=2048)
                #train_fns.update_FID(G, D, G_ema, state_dict, config, FID, experiment_name, test_log)
                # Use the files train_fns.py and utils_Task1_KLWGAN_Simulation_Experiment.py
                # Use the functions update_FID() and save_weights()
                # Save the lowest FID score
                FID = fid_score.calculate_fid_given_paths([data_moments, sample_moments], batch_size=50, cuda=True, dims=2048)
                train_fns.update_FID(G, D, G_ema, state_dict, config, FID, experiment_name, test_log)
                # FID also from: https://github.com/DarthSid95/RumiGANs/blob/main/gan_metrics.py
                # Implicit generative models and GANs generate sharp, low-FID, realistic, and high-quality images.
                # We use implicit generative models and GANs for the challenging task of anomaly detection in high-dimensional spaces.
        state_dict['epoch'] += 1
    # Save the last model
    utils_Task1_KLWGAN_Simulation_Experiment.save_weights(G, D, state_dict, config['weights_root'], experiment_name, 'last%d' % 0, G_ema if config['ema'] else None)
Example #19
File: train.py Project: sailfish009/f-wgan
def run(config):

    # Update the config dict as necessary
    # This is for convenience, to add settings derived from the user-specified
    # configuration into the config-dict (e.g. inferring the number of classes
    # and size of the images from the dataset, passing in a pytorch object
    # for the activation specified as a string)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    # By default, skip init if resuming training.
    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True
    config = utils.update_config_roots(config)
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Prepare root folders if necessary
    utils.prepare_root(config)

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    # Next, build the model
    G = model.Generator(**config).to(device)
    D = model.Discriminator(**config).to(device)

    # If using EMA, prepare it
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(
            config['ema_decay']))
        G_ema = model.Generator(**{**config, 'skip_init': True,
                                   'no_optim': True}).to(device)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None

    # FP16?
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
        # Consider automatically reducing SN_eps?
    GD = model.G_D(G, D, config['conditional'])  # setting conditional to false
    print(G)
    print(D)
    print('Number of params in G: {} D: {}'.format(
        *[sum([p.data.nelement() for p in net.parameters()]) for net in [G, D]]))
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}

    # If loading from a pre-trained model, load weights
    if config['resume']:
        print('Loading weights...')
        utils.load_weights(G, D, state_dict,
                           config['weights_root'], experiment_name,
                           config['load_weights'] if config['load_weights'] else None,
                           G_ema if config['ema'] else None)

    # If parallel, parallelize the GD module
    if config['parallel']:
        GD = nn.DataParallel(GD)
        if config['cross_replica']:
            patch_replication_callback(GD)

    # Prepare loggers for stats; metrics holds test metrics,
    # lmetrics holds any desired training metrics.
    test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                              experiment_name)
    train_metrics_fname = '%s/%s' % (config['logs_root'], experiment_name)
    print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
    test_log = utils.MetricsLogger(test_metrics_fname,
                                   reinitialize=(not config['resume']))
    print('Training Metrics will be saved to {}'.format(train_metrics_fname))
    train_log = utils.MyLogger(train_metrics_fname,
                               reinitialize=(not config['resume']),
                               logstyle=config['logstyle'])
    # Write metadata
    utils.write_metadata(config['logs_root'],
                         experiment_name, config, state_dict)
    # Prepare data; the Discriminator's batch size is all that needs to be passed
    # to the dataloader, as G doesn't require dataloading.
    # Note that at every loader iteration we pass in enough data to complete
    # a full D iteration (regardless of number of D steps and accumulations)
    D_batch_size = (config['batch_size'] * config['num_D_steps']
                    * config['num_D_accumulations'])
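    # e.g. with batch_size=64, num_D_steps=4, num_D_accumulations=1, the
    # loader yields 64 * 4 * 1 = 256 images per iteration.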
    loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                        'start_itr': state_dict['itr']})

    # Prepare noise and randomly sampled label arrays
    # Allow for different batch sizes in G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                               device=device, fp16=config['G_fp16'])
    # Prepare a fixed z & y to see individual sample evolution throghout training
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z,
                                         config['n_classes'], device=device,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()

    if not config['conditional']:
        fixed_y.zero_()
        y_.zero_()
    # Loaders are loaded, prepare the training function
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_,
                                                ema, state_dict, config)
    # Else, assume debugging and use the dummy train fn
    else:
        train = train_fns.dummy_training_function()
    # Prepare Sample function for use with inception metrics
    sample = functools.partial(utils.sample,
                               G=(G_ema if config['ema'] and config['use_ema']
                                   else G),
                               z_=z_, y_=y_, config=config)

    print('Beginning training at epoch %d...' % state_dict['epoch'])
    print('Total training epochs ',  config['num_epochs'])
    print("the dataset is ", config['dataset'], )
    if config['dataset'] == 'C10U' or config['dataset'] == 'C10':
        data_moments = 'fid_stats_cifar10_train.npz'
    else:
        print("cannot find the dataset")
        sys.exit()

    print("the data moments is ", data_moments)
    # Train for specified number of epochs, although we mostly track G iterations.
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        # Which progressbar to use? TQDM or my own?
        if config['pbar'] == 'mine':
            pbar = utils.progress(
                loaders[0], displaytype='s1k' if config['use_multiepoch_sampler'] else 'eta')
        else:
            pbar = tqdm(loaders[0])
        for i, (x, y) in enumerate(pbar):
            # Increment the iteration counter
            state_dict['itr'] += 1
            # Make sure G and D are in training mode, just in case they got set to eval
            # For D, which typically doesn't have BN, this shouldn't matter much.
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device).half(), y.to(device)
            else:
                x, y = x.to(device), y.to(device)
            metrics = train(x, y)
            train_log.log(itr=int(state_dict['itr']), **metrics)

            # Every sv_log_interval, log singular values
            if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
                train_log.log(itr=int(state_dict['itr']),
                              **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})

            # If using my progbar, print metrics.
            if config['pbar'] == 'mine':
                print(', '.join(['itr: %d' % state_dict['itr']]
                                + ['%s : %+4.3f' % (key, metrics[key])
                                    for key in metrics]), end=' ')

            # Save weights and copies as configured at specified interval
            if not (state_dict['itr'] % config['save_every']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                                          state_dict, config, experiment_name)

            # Test every specified interval
            # First load celeba moments

            experiment_name = (config['experiment_name'] if config['experiment_name']
                               else utils.name_from_config(config))
            if (not (state_dict['itr'] % config['test_every'])) and (epoch >= config['start_eval']):
                if config['G_eval_mode']:
                    print('Switching G to eval mode...')
                    G.eval()
                    if config['ema']:
                        G_ema.eval()
                # sampling images and saving to samples/experiments/epoch
                utils.sample_inception(
                    G_ema if config['ema'] and config['use_ema'] else G, config, str(epoch))
                # Get saved sample path
                folder_number = str(epoch)
                sample_moments = '%s/%s/%s/samples.npz' % (
                    config['samples_root'], experiment_name, folder_number)
                # Calculate FID
                FID = fid_score.calculate_fid_given_paths([data_moments, sample_moments],
                                                          batch_size=50, cuda=True, dims=2048)
                print("FID calculated")
                train_fns.update_FID(
                    G, D, G_ema, state_dict, config, FID, experiment_name, test_log)
                # train_fns.test(G, D, G_ema, z_, y_, state_dict, config, sample,
                #                get_inception_metrics, experiment_name, test_log)
        # Increment epoch counter at end of epoch
        state_dict['epoch'] += 1