Example #1
    backup_every = config['training']['backup_every']
    sample_nlabels = config['training']['sample_nlabels']
    dim_z = config['z_dist']['dim']

    out_dir = config['training']['out_dir']
    checkpoint_dir = path.join(out_dir, 'chkpts')

    # Create missing directories
    if not path.exists(out_dir):
        os.makedirs(out_dir)
    if not path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    shutil.copyfile(sys.argv[0], out_dir + '/training_script.py')

    # Logger
    checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset
    train_dataset, nlabels = get_dataset(
        name=config['data']['type'],
        data_dir=config['data']['train_dir'],
        size=config['data']['img_size'],
        lsun_categories=config['data']['lsun_categories_train'])
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=config['training']['nworkers'],
        shuffle=True,
        pin_memory=True,
    )
Example #2
    # Shorthands
    nlabels = config['data']['nlabels']
    batch_size = config['test']['batch_size']
    sample_size = config['test']['sample_size']
    sample_nrow = config['test']['sample_nrow']

    out_dir = get_out_dir(config)
    checkpoint_dir = path.join(out_dir, 'chkpts')
    out_dir = path.join(out_dir, 'test')

    # Create missing directories
    os.makedirs(out_dir, exist_ok=True)

    # Logger
    checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)

    device = torch.device("cuda:0" if is_cuda else "cpu")

    # Disable parameter transforms for testing
    config['param_transforms'] = 'none'

    # Create models
    generator, discriminator = build_models(config)
    print(generator)
    print(discriminator)

    # Put models on gpu if needed
    generator = generator.to(device)
    discriminator = discriminator.to(device)
Example #3
save_every = config['training']['save_every']
backup_every = config['training']['backup_every']
sample_nlabels = config['training']['sample_nlabels']

out_dir = config['training']['out_dir']
checkpoint_dir = path.join(out_dir, 'chkpts')

# Create missing directories
if not path.exists(out_dir):
    os.makedirs(out_dir)
if not path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

# Logger
checkpoint_io = CheckpointIO(
    checkpoint_dir=checkpoint_dir
)

device = torch.device("cuda:0" if is_cuda else "cpu")


# Dataset
train_dataset, nlabels = get_dataset(
    name=config['data']['type'],
    data_dir=config['data']['train_dir'],
    size=config['data']['img_size'],
    lsun_categories=config['data']['lsun_categories_train']
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
)
Example #4
def main():
    pp = pprint.PrettyPrinter(indent=1)
    pp.pprint({
        'data': config['data'],
        'generator': config['generator'],
        'discriminator': config['discriminator'],
        'clusterer': config['clusterer'],
        'training': config['training']
    })
    is_cuda = torch.cuda.is_available()

    # Shorthands
    batch_size = config['training']['batch_size']
    log_every = config['training']['log_every']
    inception_every = config['training']['inception_every']
    backup_every = config['training']['backup_every']
    sample_nlabels = config['training']['sample_nlabels']
    nlabels = config['data']['nlabels']
    sample_nlabels = min(nlabels, sample_nlabels)

    checkpoint_dir = path.join(out_dir, 'chkpts')
    nepochs = args.nepochs

    # Create missing directories
    if not path.exists(out_dir):
        os.makedirs(out_dir)
    if not path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # Logger
    checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)

    device = torch.device("cuda:0" if is_cuda else "cpu")

    train_dataset, _ = get_dataset(
        name=config['data']['type'],
        data_dir=config['data']['train_dir'],
        size=config['data']['img_size'],
        deterministic=config['data']['deterministic'])

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=config['training']['nworkers'],
        shuffle=True,
        pin_memory=True,
        sampler=None,
        drop_last=True)

    # Create models
    generator, discriminator = build_models(config)

    # Put models on gpu if needed
    generator = generator.to(device)
    discriminator = discriminator.to(device)

    for name, module in discriminator.named_modules():
        if isinstance(module, nn.Sigmoid):
            print('Found sigmoid layer in discriminator; not compatible with BCE with logits')
            exit()

    g_optimizer, d_optimizer = build_optimizers(generator, discriminator, config)

    devices = [int(x) for x in args.devices]
    generator = nn.DataParallel(generator, device_ids=devices)
    discriminator = nn.DataParallel(discriminator, device_ids=devices)

    # Register modules to checkpoint
    checkpoint_io.register_modules(generator=generator,
                                   discriminator=discriminator,
                                   g_optimizer=g_optimizer,
                                   d_optimizer=d_optimizer)

    # Logger
    logger = Logger(log_dir=path.join(out_dir, 'logs'),
                    img_dir=path.join(out_dir, 'imgs'),
                    monitoring=config['training']['monitoring'],
                    monitoring_dir=path.join(out_dir, 'monitoring'))

    # Distributions
    ydist = get_ydist(nlabels, device=device)
    zdist = get_zdist(config['z_dist']['type'], config['z_dist']['dim'], device=device)

    ntest = config['training']['ntest']
    x_test, y_test = utils.get_nsamples(train_loader, ntest)
    x_cluster, y_cluster = utils.get_nsamples(train_loader, config['clusterer']['nimgs'])
    x_test, y_test = x_test.to(device), y_test.to(device)
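    # A fixed (z, y) evaluation batch keeps the sample grids comparable across iterations.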
    z_test = zdist.sample((ntest, ))
    utils.save_images(x_test, path.join(out_dir, 'real.png'))
    logger.add_imgs(x_test, 'gt', 0)

    # Test generator
    if config['training']['take_model_average']:
        print('Taking model average')
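        # Weight averaging does not track BatchNorm running statistics, so the script refuses EMA when BN layers are present.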
        bad_modules = [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]
        for model in [generator, discriminator]:
            for name, module in model.named_modules():
                for bad_module in bad_modules:
                    if isinstance(module, bad_module):
                        print('Batch norm in discriminator not compatible with exponential moving average')
                        exit()
        generator_test = copy.deepcopy(generator)
        checkpoint_io.register_modules(generator_test=generator_test)
    else:
        generator_test = generator

    clusterer = get_clusterer(config)(discriminator=discriminator,
                                      x_cluster=x_cluster,
                                      x_labels=y_cluster,
                                      gt_nlabels=config['data']['nlabels'],
                                      **config['clusterer']['kwargs'])

    # Load checkpoint if it exists
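    # model_it == -1 means: resume from the most recent checkpoint on disk.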
    it = utils.get_most_recent(checkpoint_dir, 'model') if args.model_it == -1 else args.model_it
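    # Stored cluster samples only need to be reloaded when the clusterer is not the supervised one.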
    it, epoch_idx, loaded_clusterer = checkpoint_io.load_models(it=it, load_samples='supervised' != config['clusterer']['name'])

    if loaded_clusterer is None:
        print('Initializing new clusterer. The first clustering can be quite slow.')
        clusterer.recluster(discriminator=discriminator)
        checkpoint_io.save_clusterer(clusterer, it=0)
        np.savez(os.path.join(checkpoint_dir, 'cluster_samples.npz'), x=x_cluster)
    else:
        print('Using loaded clusterer')
        clusterer = loaded_clusterer

    # Evaluator
    evaluator = Evaluator(
        generator_test,
        zdist,
        ydist,
        train_loader=train_loader,
        clusterer=clusterer,
        batch_size=batch_size,
        device=device,
        inception_nsamples=config['training']['inception_nsamples'])

    # Trainer
    trainer = Trainer(generator,
                      discriminator,
                      g_optimizer,
                      d_optimizer,
                      gan_type=config['training']['gan_type'],
                      reg_type=config['training']['reg_type'],
                      reg_param=config['training']['reg_param'])

    # Training loop
    print('Start training...')
    while it < args.nepochs * len(train_loader):
        epoch_idx += 1

        for x_real, y in train_loader:
            it += 1

            x_real, y = x_real.to(device), y.to(device)
            z = zdist.sample((batch_size, ))
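            # Condition on cluster assignments instead of the raw dataset labels.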
            y = clusterer.get_labels(x_real, y).to(device)

            # Discriminator updates
            dloss, reg = trainer.discriminator_trainstep(x_real, y, z)
            logger.add('losses', 'discriminator', dloss, it=it)
            logger.add('losses', 'regularizer', reg, it=it)

            # Generator updates
            gloss = trainer.generator_trainstep(y, z)
            logger.add('losses', 'generator', gloss, it=it)

            if config['training']['take_model_average']:
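                # Exponential moving average of generator weights for more stable samples.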
                update_average(generator_test, generator, beta=config['training']['model_average_beta'])

            # Print stats
            if it % log_every == 0:
                g_loss_last = logger.get_last('losses', 'generator')
                d_loss_last = logger.get_last('losses', 'discriminator')
                d_reg_last = logger.get_last('losses', 'regularizer')
                print('[epoch %0d, it %4d] g_loss = %.4f, d_loss = %.4f, reg=%.4f'
                      % (epoch_idx, it, g_loss_last, d_loss_last, d_reg_last))

            if it % config['training']['recluster_every'] == 0 and it > config['training']['burnin_time']:
                # print cluster distribution for online methods
                if it % 100 == 0 and config['training']['recluster_every'] <= 100:
                    print(f'[epoch {epoch_idx}, it {it}], distribution: {clusterer.get_label_distribution(x_real)}')
                clusterer.recluster(discriminator=discriminator, x_batch=x_real)

            # (i) Sample if necessary
            if it % config['training']['sample_every'] == 0:
                print('Creating samples...')
                x = evaluator.create_samples(z_test, clusterer.get_labels(x_test, y_test).to(device))
                logger.add_imgs(x, 'all', it)

                for y_inst in range(sample_nlabels):
                    x = evaluator.create_samples(z_test, y_inst)
                    logger.add_imgs(x, '%04d' % y_inst, it)

            # (ii) Compute inception if necessary
            if it % inception_every == 0 and it > 0:
                print('PyTorch Inception score...')
                inception_mean, inception_std = evaluator.compute_inception_score()
                logger.add('metrics', 'pt_inception_mean', inception_mean, it=it)
                logger.add('metrics', 'pt_inception_stddev', inception_std, it=it)
                print(f'[epoch {epoch_idx}, it {it}] pt_inception_mean: {inception_mean}, pt_inception_stddev: {inception_std}')

            # (iii) Backup if necessary
            if it % backup_every == 0:
                print('Saving backup...')
                checkpoint_io.save('model_%08d.pt' % it, it=it)
                checkpoint_io.save_clusterer(clusterer, int(it))
                logger.save_stats('stats_%08d.p' % it)

                if it > 0: checkpoint_io.save('model.pt', it=it)
Example #5
  
  out_dir = get_out_dir(config)
  checkpoint_dir = path.join(out_dir, 'chkpts')
  
  # Create missing directories
  if not path.exists(out_dir):
    os.makedirs(out_dir)
  if not path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
  
  # Save config for this run
  save_config(os.path.join(out_dir, 'config.yaml'), config)
  
  # Logger
  checkpoint_io = CheckpointIO(
    checkpoint_dir=checkpoint_dir
  )
  
  device = torch.device("cuda:0" if is_cuda else "cpu")
  
  # Dataset
  train_loader = get_dataloader_stamps(config, split='train')
  
  # Create models
  generator, discriminator = build_models(config)
  print(generator)
  print(discriminator)

  # Put models on gpu if needed
  generator = generator.to(device)
  discriminator = discriminator.to(device)
Example #6
def main():
    checkpoint_dir = os.path.join(out_dir, 'chkpts')
    batch_size = config['training']['batch_size']

    if 'cifar' in config['data']['train_dir'].lower():
        name = 'cifar10'
    elif 'stacked_mnist' == config['data']['type']:
        name = 'stacked_mnist'
    else:
        name = 'image'

    if os.path.exists(os.path.join(out_dir, 'cluster_preds.npz')):
        # if we've already computed assignments, load them and move on
        with np.load(os.path.join(out_dir, 'cluster_preds.npz')) as f:
            y_reals = f['y_reals']
            y_preds = f['y_preds']
    else:
        train_dataset, _ = get_dataset(name=name,
                                       data_dir=config['data']['train_dir'],
                                       size=config['data']['img_size'])

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            num_workers=config['training']['nworkers'],
            shuffle=True,
            pin_memory=True,
            sampler=None,
            drop_last=True)

        checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)

        print('Loading clusterer:')
        most_recent = utils.get_most_recent(
            checkpoint_dir,
            'model') if args.model_it is None else args.model_it
        clusterer = checkpoint_io.load_clusterer(
            most_recent, load_samples=False, pretrained=config['pretrained'])

        if isinstance(clusterer.discriminator, nn.DataParallel):
            clusterer.discriminator = clusterer.discriminator.module

        y_preds = []
        y_reals = []

        for batch_num, (x_real, y_real) in enumerate(
                tqdm(train_loader, total=len(train_loader))):
            y_pred = clusterer.get_labels(x_real.cuda(), None)
            y_preds.append(y_pred.detach().cpu())
            y_reals.append(y_real)

        y_reals = torch.cat(y_reals).numpy()
        y_preds = torch.cat(y_preds).numpy()

        np.savez(os.path.join(out_dir, 'cluster_preds.npz'),
                 y_reals=y_reals,
                 y_preds=y_preds)

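    # Optional baseline: random assignments give a floor for the NMI/purity scores.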
    if args.random:
        y_preds = np.random.randint(0, 100, size=y_reals.shape)

    nmi_score = nmi(y_preds, y_reals)
    purity = purity_score(y_preds, y_reals)
    print('nmi', nmi_score, 'purity', purity)
Example #7
# Shorthands
nlabels = config['data']['nlabels']
out_dir = config['training']['out_dir']
batch_size = config['test']['batch_size']
sample_size = config['test']['sample_size']
sample_nrow = config['test']['sample_nrow']
fid_sample_size = config['training']['fid_sample_size']
checkpoint_dir = path.join(out_dir, 'pretrained')
img_dir = path.join(out_dir, 'img')

# Create missing directories
if not path.exists(img_dir):
    os.makedirs(img_dir)

# Logger
checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)

device = torch.device("cuda:0" if is_cuda else "cpu")

generator, discriminator = build_models(config)

# Put models on gpu if needed
generator = generator.to(device)
discriminator = discriminator.to(device)

# Use multiple GPUs if possible
generator = nn.DataParallel(generator)
discriminator = nn.DataParallel(discriminator)

# Register modules to checkpoint
checkpoint_io.register_modules(generator=generator, discriminator=discriminator)  # call truncated in the excerpt; these arguments are assumed
Example #8
for epoch_id in range(80):
    for model in all_models:
        model_name = "/home/kunxu/Workspace/GAN_PID/{}".format(model)
        key_name = model_name
        if key_name not in all_results:
            all_results[key_name] = []

        config = load_config(os.path.join(model_name, "config.yaml"),
                             'configs/default.yaml')
        generator, discriminator = build_models(config)
        generator = torch.nn.DataParallel(generator)
        zdist = get_zdist(config['z_dist']['type'],
                          config['z_dist']['dim'],
                          device=device)
        ydist = get_ydist(1, device=device)
        checkpoint_io = CheckpointIO(checkpoint_dir="./tmp")
        checkpoint_io.register_modules(generator_test=generator)
        evaluator = Evaluator(generator,
                              zdist,
                              ydist,
                              batch_size=100,
                              device=device)

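        # Checkpoint filenames encode the iteration; this evaluates every 10000th one (offset 9999).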
        ckptpath = os.path.join(
            model_name, "chkpts",
            "model_{:08d}.pt".format(epoch_id * 10000 + 9999))
        print(ckptpath)
        load_dict = checkpoint_io.load(ckptpath)
        img_list = []
        for i in range(500):
            ztest = zdist.sample((100, ))
Example #9
out_dir = config['training']['out_dir']
batch_size = config['test']['batch_size']
sample_size = config['test']['sample_size']
sample_nrow = config['test']['sample_nrow']
checkpoint_dir = path.join(out_dir, 'chkpts')
img_dir = path.join(out_dir, 'test', 'img')
img_all_dir = path.join(out_dir, 'test', 'img_all')

# Create missing directories
if not path.exists(img_dir):
    os.makedirs(img_dir)
if not path.exists(img_all_dir):
    os.makedirs(img_all_dir)

# Logger
checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)

# Get model file
model_file = config['test']['model_file']

# Models
device = torch.device("cuda:0" if is_cuda else "cpu")

generator, discriminator = build_models(config)
print(generator)
print(discriminator)

# Put models on gpu if needed
generator = generator.to(device)
discriminator = discriminator.to(device)
Example #10
def main():
    checkpoint_dir = os.path.join(out_dir, 'chkpts')

    most_recent = utils.get_most_recent(
        checkpoint_dir, 'model') if args.model_it is None else args.model_it

    cluster_path = os.path.join(out_dir, 'clusters')
    print('Saving clusters/samples to', cluster_path)

    os.makedirs(cluster_path, exist_ok=True)

    shutil.copyfile('seeing/lightbox.html',
                    os.path.join(cluster_path, '+lightbox.html'))

    checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)

    clusterer = checkpoint_io.load_clusterer(most_recent,
                                             pretrained=config['pretrained'],
                                             load_samples=False)

    if isinstance(clusterer.discriminator, nn.DataParallel):
        clusterer.discriminator = clusterer.discriminator.module

    model_path = os.path.join(checkpoint_dir, 'model_%08d.pt' % most_recent)
    sampler = SeededSampler(args.config,
                            model_path=model_path,
                            clusterer_path=os.path.join(
                                checkpoint_dir, f'clusterer{most_recent}.pkl'),
                            pretrained=config['pretrained'])

    if args.show_clusters:
        clusters = [[] for _ in range(config['generator']['nlabels'])]
        train_dataset, _ = get_dataset(
            name='webp' if 'cifar' not in config['data']['train_dir'].lower()
            else 'cifar10',
            data_dir=config['data']['train_dir'],
            size=config['data']['img_size'])

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=config['training']['batch_size'],
            num_workers=config['training']['nworkers'],
            shuffle=True,
            pin_memory=True,
            sampler=None,
            drop_last=True)

        print('Generating clusters')
        for batch_num, (x_real, y_gt) in enumerate(train_loader):
            x_real = x_real.cuda()
            y_pred = clusterer.get_labels(x_real, y_gt)

            for i, yi in enumerate(y_pred):
                clusters[yi].append(x_real[i].cpu())

            # don't generate too many, we're only visualizing 20 per cluster
            if batch_num * config['training']['batch_size'] >= 10000:
                break
    else:
        clusters = [None] * config['generator']['nlabels']

    nimgs = 20
    nrows = 4

    for i in range(len(clusters)):
        if clusters[i] is not None and len(clusters[i]) >= nimgs:
            cluster = torch.stack(clusters[i])[:nimgs]

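            # Undo the [-1, 1] input normalization before saving the image grid.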
            torchvision.utils.save_image(cluster * 0.5 + 0.5,
                                         os.path.join(cluster_path,
                                                      f'{i}_real.png'),
                                         nrow=nrows)
        generated = []
        for seed in range(nimgs):
            img = sampler.conditional_sample(i, seed=seed)
            generated.append(img.detach().cpu())
        generated = torch.cat(generated)

        torchvision.utils.save_image(generated * 0.5 + 0.5,
                                     os.path.join(cluster_path,
                                                  f'{i}_gen.png'),
                                     nrow=nrows)

    print('Clusters/samples can be visualized under', cluster_path)
Example #11
    def get_generator(self):
        ''' loads a generator according to self.model_path '''

        exp_out_dir = os.path.join(self.rootdir,
                                   self.config['training']['out_dir'])
        # infer checkpoint if needed
        checkpoint_dir = os.path.join(
            exp_out_dir, 'chkpts'
        ) if self.model_path == "" or 'model' in self.pretrained else "./"
        model_name = get_most_recent(os.listdir(
            checkpoint_dir)) if self.model_path == "" else self.model_path

        checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)
        self.checkpoint_io = checkpoint_io

        generator, _ = build_models(self.config)
        generator = generator.to(self.device)
        generator = nn.DataParallel(generator)

        if self.config['training']['take_model_average']:
            generator_test = copy.deepcopy(generator)
            checkpoint_io.register_modules(generator_test=generator_test)
        else:
            generator_test = generator

        checkpoint_io.register_modules(generator=generator)

        try:
            it = checkpoint_io.load(model_name, pretrained=self.pretrained)
            assert (it != -1)
        except Exception as e:
            # try again without data parallel
            print(e)
            checkpoint_io.register_modules(generator=generator.module)
            checkpoint_io.register_modules(
                generator_test=generator_test.module)
            it = checkpoint_io.load(model_name, pretrained=self.pretrained)
            assert (it != -1)

        print('Loaded iteration:', it['it'])
        return generator_test
Example #12
def perform_evaluation(run_name, image_type):

    out_dir = os.path.join(os.getcwd(), '..', 'output', run_name)
    checkpoint_dir = os.path.join(out_dir, 'chkpts')
    checkpoints = sorted(glob.glob(os.path.join(checkpoint_dir, '*')))
    evaluation_dict = {}

    for point in checkpoints:
        model_file = os.path.basename(point)
        iter_num = int(model_file.split('_')[1].split('.')[0])
        if iter_num % 10000 != 0:
            continue

        config = load_config('../configs/fr_default.yaml', None)
        is_cuda = (torch.cuda.is_available())
        checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)
        device = torch.device("cuda:0" if is_cuda else "cpu")

        generator, discriminator = build_models(config)

        # Put models on gpu if needed
        generator = generator.to(device)
        discriminator = discriminator.to(device)

        # Use multiple GPUs if possible
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)

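        # Four model-average snapshots; the name suffixes suggest EMA decay rates 0.9, 0.99, 0.999, 0.9999.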
        generator_test_9 = copy.deepcopy(generator)
        generator_test_99 = copy.deepcopy(generator)
        generator_test_999 = copy.deepcopy(generator)
        generator_test_9999 = copy.deepcopy(generator)

        # Register modules to checkpoint
        checkpoint_io.register_modules(
            generator=generator,
            generator_test_9=generator_test_9,
            generator_test_99=generator_test_99,
            generator_test_999=generator_test_999,
            generator_test_9999=generator_test_9999,
            discriminator=discriminator,
        )

        # Load checkpoint
        load_dict = checkpoint_io.load(model_file)

        # Distributions
        ydist = get_ydist(config['data']['nlabels'], device=device)
        zdist = get_zdist(config['z_dist']['type'],
                          config['z_dist']['dim'],
                          device=device)
        z_sample = torch.Tensor(np.load('z_data.npy')).to(device)

        #for name, model in zip(['0_', '09_', '099_', '0999_', '09999_'], [generator, generator_test_9, generator_test_99, generator_test_999, generator_test_9999]):
        for name, model in zip(
            ['099_', '0999_', '09999_'],
            [generator_test_99, generator_test_999, generator_test_9999]):

            # Evaluator
            evaluator = Evaluator(model, zdist, ydist, device=device)

            x_sample = []

            for i in range(10):
                x = evaluator.create_samples(z_sample[i * 1000:(i + 1) * 1000])
                x_sample.append(x)

            x_sample = torch.cat(x_sample)
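            # Map samples from [-1, 1] back to [0, 1] for image output.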
            x_sample = x_sample / 2 + 0.5

            if not os.path.exists('fake_data'):
                os.makedirs('fake_data')

            for i in range(10000):
                torchvision.utils.save_image(x_sample[i, :, :, :],
                                             'fake_data/{}.png'.format(i))

            fid_score = calculate_fid_given_paths(
                ['fake_data', image_type + '_real'], 50, True, 2048)
            print(iter_num, name, fid_score)

            os.system("rm -rf " + "fake_data")

            evaluation_dict[(iter_num, name[:-1])] = {'FID': fid_score}

            if not os.path.exists('evaluation_data/' + run_name):
                os.makedirs('evaluation_data/' + run_name)

            pickle.dump(
                evaluation_dict,
                open('evaluation_data/' + run_name + '/eval_fid.p', 'wb'))