def get_generator(self):
    """Build the generator and restore its weights from a checkpoint.

    Checkpoint location:
        * ``<rootdir>/<training.out_dir>/chkpts`` when ``self.model_path``
          is empty or ``'model' in self.pretrained``; otherwise ``"./"``.
        * The checkpoint name is ``self.model_path``, or, when that is
          empty, whatever ``get_most_recent`` picks from the directory
          listing.

    The ``CheckpointIO`` instance that is created is also stored on
    ``self.checkpoint_io``.

    Returns:
        The ``nn.DataParallel``-wrapped generator to evaluate with: a deep
        copy registered as ``generator_test`` when
        ``config['training']['take_model_average']`` is truthy, otherwise
        the generator itself.
    """
    exp_out_dir = os.path.join(self.rootdir,
                               self.config['training']['out_dir'])

    # Infer the checkpoint directory and name if needed.
    if self.model_path == "" or 'model' in self.pretrained:
        ckpt_dir = os.path.join(exp_out_dir, 'chkpts')
    else:
        ckpt_dir = "./"

    if self.model_path == "":
        ckpt_name = get_most_recent(os.listdir(ckpt_dir))
    else:
        ckpt_name = self.model_path

    ckpt_io = CheckpointIO(checkpoint_dir=ckpt_dir)
    self.checkpoint_io = ckpt_io

    net, _ = build_models(self.config)
    net = net.to(self.device)
    net = nn.DataParallel(net)

    if self.config['training']['take_model_average']:
        net_test = copy.deepcopy(net)
        ckpt_io.register_modules(generator_test=net_test)
    else:
        net_test = net

    ckpt_io.register_modules(generator=net)

    def _restore():
        # Load the checkpoint and insist that something was actually loaded.
        loaded = ckpt_io.load(ckpt_name, pretrained=self.pretrained)
        assert (loaded != -1)
        return loaded

    try:
        it = _restore()
    except Exception as err:
        # Try again without data parallel: register the unwrapped modules
        # and reload.
        print(err)
        ckpt_io.register_modules(generator=net.module)
        ckpt_io.register_modules(generator_test=net_test.module)
        it = _restore()

    print('Loaded iteration:', it['it'])
    return net_test
# Example #2
# 0
# Training-script setup: choose the generator used for evaluation, build the
# evaluator, restore the latest checkpoint and create the LR schedulers.

# With model averaging enabled, evaluation uses a separate deep copy of the
# generator (registered with the checkpoint as 'generator_test'); otherwise
# the training generator is evaluated directly.
if config['training']['take_model_average']:
    generator_test = copy.deepcopy(generator)
    checkpoint_io.register_modules(generator_test=generator_test)
else:
    generator_test = generator

# Evaluator
evaluator = Evaluator(generator_test, zdist, ydist,
                      batch_size=batch_size, device=device)

# Train
tstart = t0 = time.time()
it = epoch_idx = -1

# Load checkpoint if it exists
# NOTE(review): -1 presumably means "no checkpoint found" -- confirm against
# CheckpointIO.load. Logger stats are only restored when a load happened.
it = checkpoint_io.load('model.pt')
if it != -1:
    logger.load_stats('stats.p')

# Reinitialize model average if needed
# NOTE(review): a factor of 0. presumably copies the current generator
# weights into generator_test -- confirm in update_average.
if (config['training']['take_model_average']
        and config['training']['model_average_reinit']):
    update_average(generator_test, generator, 0.)

# Learning rate annealing
# The restored iteration is passed as last_epoch to both schedulers.
g_scheduler = build_lr_scheduler(g_optimizer, config, last_epoch=it)
d_scheduler = build_lr_scheduler(d_optimizer, config, last_epoch=it)

# Trainer
trainer = Trainer(
    generator, discriminator, g_optimizer, d_optimizer,
                                 batch_size=config['test']['batch_size'],
                                 device=device)

    # initialize fid evaluators
    if 'fid' in eval_attr:
        cache_file = os.path.join(out_dir, 'cache_test.npz')
        test_loader = get_dataloader(config, split='test')
        evaluator.inception_eval.initialize_target(test_loader,
                                                   cache_file=cache_file)
        cache_file = os.path.join(out_dir, 'cache_test_single.npz')
        test_loader_single = get_dataloader(config, split='test', single=True)
        evaluator_single.inception_eval.initialize_target(
            test_loader_single, cache_file=cache_file)

    # Load checkpoint if existant
    load_dict = checkpoint_io.load(model_file)
    it = load_dict.get('it', -1)
    epoch_idx = load_dict.get('epoch_idx', -1)

    # Pick a random but fixed seed
    seed = torch.randint(0, 10000, (1, ))[0]

    if 'fid' in eval_attr:
        print('Computing FID score...')

        def sample(transform):
            while True:
                z = zdist.sample((evaluator.batch_size, ))
                x = evaluator.create_samples(z, param_transform=transform)
                rgb = x['img']
                del x, z
# Example #4
# 0
        shuffle=True,
        pin_memory=True,
        sampler=None,
        drop_last=True)
    fid_real_samples, _ = utils.get_nsamples(train_loader, fid_sample_size)

# Build the evaluator, restore the checkpoint, optionally report the
# inception/FID scores, and write individual sample images.
evaluator = Evaluator(generator_test,
                      zdist,
                      ydist,
                      batch_size=batch_size,
                      device=device,
                      fid_real_samples=fid_real_samples,
                      fid_sample_size=fid_sample_size)

# Load checkpoint if it exists
# NOTE(review): the return value is bound to `it` but not used in the
# statements visible here.
it = checkpoint_io.load('CelebA_HQ_vgan_model.pt')

# Inception score
# This evaluator variant returns three values: mean, std and FID.
if config['test']['compute_inception']:
    print('Computing inception score...')
    inception_mean, inception_std, fid = evaluator.compute_inception_score()
    print('Inception score: %.4f +- %.4f, FID: %.2f' %
          (inception_mean, inception_std, fid))

# Samples
# One latent vector is drawn per iteration and each result is saved to its
# own zero-padded file (00000.png, 00001.png, ...) in img_dir.
print('Creating samples...')
for i in range(sample_size):
    ztest = zdist.sample((1, ))
    x = evaluator.create_samples(ztest)
    utils.save_images(x, path.join(img_dir, '%05d.png' % i))
# Example #5
# 0
        generator, discriminator = build_models(config)
        generator = torch.nn.DataParallel(generator)
        zdist = get_zdist(config['z_dist']['type'],
                          config['z_dist']['dim'],
                          device=device)
        ydist = get_ydist(1, device=device)
        checkpoint_io = CheckpointIO(checkpoint_dir="./tmp")
        checkpoint_io.register_modules(generator_test=generator)
        evaluator = Evaluator(generator,
                              zdist,
                              ydist,
                              batch_size=100,
                              device=device)

        ckptpath = os.path.join(
            model_name, "chkpts",
            "model_{:08d}.pt".format(epoch_id * 10000 + 9999))
        print(ckptpath)
        load_dict = checkpoint_io.load(ckptpath)
        img_list = []
        for i in range(500):
            ztest = zdist.sample((100, ))
            x = evaluator.create_samples(ztest)
            img_list.append(x.cpu().numpy())
        img_list = np.concatenate(img_list, axis=0)
        m, s = evaluation(img_list)
        all_results[key_name].append([float(m), float(s)])

# Persist the collected results as a pickle file; the context manager closes
# the file handle even if serialization fails.
with open("./output/cifar_inception_plot.pkl", 'wb') as f:
    pickle.dump(all_results, f)
# Example #6
# 0
# Test-script body: build the label/latent distributions and the evaluator,
# restore a checkpoint, optionally report the inception score, and save one
# grid of samples.

# Distributions
ydist = get_ydist(nlabels, device=device)
zdist = get_zdist(config['z_dist']['type'],
                  config['z_dist']['dim'],
                  device=device)

# Evaluator
evaluator = Evaluator(generator_test,
                      zdist,
                      ydist,
                      batch_size=batch_size,
                      device=device)

# Load checkpoint if it exists
# The loader result is used as a dict; missing counters fall back to -1.
load_dict = checkpoint_io.load(args.oldmodel)
it = load_dict.get('it', -1)
epoch_idx = load_dict.get('epoch_idx', -1)

# Inception score
if config['test']['compute_inception']:
    print('Computing inception score...')
    inception_mean, inception_std = evaluator.compute_inception_score()
    print('Inception score: %.4f +- %.4f' % (inception_mean, inception_std))

# Samples
# The output file is named after the restored iteration; with the -1
# fallback above, '%08d' renders as '-0000001'.
print('Creating samples...')
ztest = zdist.sample((sample_size, ))
x = evaluator.create_samples(ztest)
utils.save_images(x, path.join(img_all_dir, '%08d.png' % it), nrow=sample_nrow)
# Example #7
# 0
def _checkpoint_iteration(checkpoint_path):
    """Return the iteration number encoded in a checkpoint file name.

    The number is read from the second ``_``-separated field of the base
    name, up to the first ``.`` (``model_00010000.pt`` -> ``10000``).

    Args:
        checkpoint_path: Path (or bare file name) of a checkpoint entry.

    Returns:
        The iteration as an ``int``, or ``None`` when the name has no such
        field or the field is not an integer (e.g. ``model.pt``).
    """
    fields = os.path.basename(checkpoint_path).split('_')
    if len(fields) < 2:
        return None
    try:
        return int(fields[1].split('.')[0])
    except ValueError:
        return None


def perform_evaluation(run_name, image_type):
    """Compute FID scores for the averaged generators of a training run.

    Every entry in ``<cwd>/../output/<run_name>/chkpts`` whose file name
    carries an iteration number that is a multiple of 10000 is loaded.
    For each such checkpoint the ``generator_test_99``, ``generator_test_999``
    and ``generator_test_9999`` copies are evaluated: 10 batches of 1000
    latent vectors from ``z_data.npy`` are turned into images, written to a
    temporary ``fake_data`` directory, scored against ``<image_type>_real``
    with ``calculate_fid_given_paths``, and the directory is removed again.

    Results are accumulated in a dict keyed by ``(iteration, name)`` with
    values ``{'FID': score}`` and re-written to
    ``evaluation_data/<run_name>/eval_fid.p`` after every score, so partial
    results survive an interrupted run.

    Entries whose names do not encode an iteration (for example a plain
    ``model.pt``) are skipped instead of aborting the whole evaluation.

    Args:
        run_name: Name of the run directory below ``../output``.
        image_type: Prefix of the directory with real images
            (``<image_type>_real``).

    Returns:
        None. Scores are printed and pickled.
    """
    # Local import so the block does not depend on a file-level import.
    import shutil

    n_batches = 10
    batch_size = 1000
    fake_dir = 'fake_data'
    result_dir = os.path.join('evaluation_data', run_name)

    out_dir = os.path.join(os.getcwd(), '..', 'output', run_name)
    checkpoint_dir = os.path.join(out_dir, 'chkpts')
    checkpoints = sorted(glob.glob(os.path.join(checkpoint_dir, '*')))
    evaluation_dict = {}

    for point in checkpoints:
        iter_num = _checkpoint_iteration(point)
        if iter_num is None or iter_num % 10000 != 0:
            continue

        model_file = os.path.basename(point)

        config = load_config('../configs/fr_default.yaml', None)
        is_cuda = torch.cuda.is_available()
        checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)
        device = torch.device("cuda:0" if is_cuda else "cpu")

        generator, discriminator = build_models(config)

        # Put models on gpu if needed
        generator = generator.to(device)
        discriminator = discriminator.to(device)

        # Use multiple GPUs if possible
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)

        # One copy per averaged generator stored in the checkpoint.
        # NOTE(review): the suffixes presumably denote averaging decay
        # rates (0.9, 0.99, ...) -- confirm against the training script.
        generator_test_9 = copy.deepcopy(generator)
        generator_test_99 = copy.deepcopy(generator)
        generator_test_999 = copy.deepcopy(generator)
        generator_test_9999 = copy.deepcopy(generator)

        # Register modules to checkpoint
        checkpoint_io.register_modules(
            generator=generator,
            generator_test_9=generator_test_9,
            generator_test_99=generator_test_99,
            generator_test_999=generator_test_999,
            generator_test_9999=generator_test_9999,
            discriminator=discriminator,
        )

        # Load checkpoint (the returned metadata is not needed here)
        checkpoint_io.load(model_file)

        # Distributions
        ydist = get_ydist(config['data']['nlabels'], device=device)
        zdist = get_zdist(config['z_dist']['type'],
                          config['z_dist']['dim'],
                          device=device)
        z_sample = torch.Tensor(np.load('z_data.npy')).to(device)

        # Only three of the loaded generators are scored; `generator` and
        # `generator_test_9` are registered and loaded but not evaluated.
        candidates = zip(
            ['099_', '0999_', '09999_'],
            [generator_test_99, generator_test_999, generator_test_9999])
        for name, model in candidates:

            # Evaluator
            evaluator = Evaluator(model, zdist, ydist, device=device)

            x_sample = torch.cat([
                evaluator.create_samples(
                    z_sample[i * batch_size:(i + 1) * batch_size])
                for i in range(n_batches)
            ])
            # NOTE(review): presumably maps images from [-1, 1] to [0, 1]
            # before saving -- confirm the generator's output range.
            x_sample = x_sample / 2 + 0.5

            os.makedirs(fake_dir, exist_ok=True)

            for i in range(n_batches * batch_size):
                torchvision.utils.save_image(
                    x_sample[i, :, :, :],
                    os.path.join(fake_dir, '{}.png'.format(i)))

            # NOTE(review): the positional arguments are presumably batch
            # size, CUDA flag and feature dimensionality -- confirm against
            # calculate_fid_given_paths.
            fid_score = calculate_fid_given_paths(
                [fake_dir, image_type + '_real'], 50, True, 2048)
            print(iter_num, name, fid_score)

            # Portable replacement for `rm -rf fake_data`.
            shutil.rmtree(fake_dir, ignore_errors=True)

            # Strip the trailing underscore from the label for the key.
            evaluation_dict[(iter_num, name[:-1])] = {'FID': fid_score}

            os.makedirs(result_dir, exist_ok=True)

            # Re-dump after every score; the context manager closes the
            # handle, which the original bare open() never did.
            with open(os.path.join(result_dir, 'eval_fid.p'), 'wb') as fh:
                pickle.dump(evaluation_dict, fh)