backup_every = config['training']['backup_every']
sample_nlabels = config['training']['sample_nlabels']
dim_z = config['z_dist']['dim']
out_dir = config['training']['out_dir']
checkpoint_dir = path.join(out_dir, 'chkpts')

# Create missing directories
if not path.exists(out_dir):
    os.makedirs(out_dir)
if not path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

shutil.copyfile(sys.argv[0], out_dir + '/training_script.py')

# Logger
checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset
train_dataset, nlabels = get_dataset(
    name=config['data']['type'],
    data_dir=config['data']['train_dir'],
    size=config['data']['img_size'],
    lsun_categories=config['data']['lsun_categories_train'])

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=config['training']['nworkers'],
    shuffle=True,
    pin_memory=True,
# Shorthands
nlabels = config['data']['nlabels']
batch_size = config['test']['batch_size']
sample_size = config['test']['sample_size']
sample_nrow = config['test']['sample_nrow']

out_dir = get_out_dir(config)
checkpoint_dir = path.join(out_dir, 'chkpts')
out_dir = path.join(out_dir, 'test')

# Create missing directories
os.makedirs(out_dir, exist_ok=True)

# Logger
checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)

device = torch.device("cuda:0" if is_cuda else "cpu")

# Disable parameter transforms for testing
config['param_transforms'] = 'none'

# Create models
generator, discriminator = build_models(config)
print(generator)
print(discriminator)

# Put models on gpu if needed
generator = generator.to(device)
discriminator = discriminator.to(device)
save_every = config['training']['save_every']
backup_every = config['training']['backup_every']
sample_nlabels = config['training']['sample_nlabels']
out_dir = config['training']['out_dir']
checkpoint_dir = path.join(out_dir, 'chkpts')

# Create missing directories
if not path.exists(out_dir):
    os.makedirs(out_dir)
if not path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

# Logger
checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)

device = torch.device("cuda:0" if is_cuda else "cpu")

# Dataset
train_dataset, nlabels = get_dataset(
    name=config['data']['type'],
    data_dir=config['data']['train_dir'],
    size=config['data']['img_size'],
    lsun_categories=config['data']['lsun_categories_train'])

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
def main():
    pp = pprint.PrettyPrinter(indent=1)
    pp.pprint({
        'data': config['data'],
        'generator': config['generator'],
        'discriminator': config['discriminator'],
        'clusterer': config['clusterer'],
        'training': config['training']
    })
    is_cuda = torch.cuda.is_available()

    # Shorthands
    batch_size = config['training']['batch_size']
    log_every = config['training']['log_every']
    inception_every = config['training']['inception_every']
    backup_every = config['training']['backup_every']
    sample_nlabels = config['training']['sample_nlabels']
    nlabels = config['data']['nlabels']
    sample_nlabels = min(nlabels, sample_nlabels)

    checkpoint_dir = path.join(out_dir, 'chkpts')
    nepochs = args.nepochs

    # Create missing directories
    if not path.exists(out_dir):
        os.makedirs(out_dir)
    if not path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # Logger
    checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)

    device = torch.device("cuda:0" if is_cuda else "cpu")

    train_dataset, _ = get_dataset(
        name=config['data']['type'],
        data_dir=config['data']['train_dir'],
        size=config['data']['img_size'],
        deterministic=config['data']['deterministic'])

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=config['training']['nworkers'],
        shuffle=True,
        pin_memory=True,
        sampler=None,
        drop_last=True)

    # Create models
    generator, discriminator = build_models(config)

    # Put models on gpu if needed
    generator = generator.to(device)
    discriminator = discriminator.to(device)

    for name, module in discriminator.named_modules():
        if isinstance(module, nn.Sigmoid):
            print('Found sigmoid layer in discriminator; not compatible with BCE with logits')
            exit()

    g_optimizer, d_optimizer = build_optimizers(generator, discriminator, config)

    devices = [int(x) for x in args.devices]
    generator = nn.DataParallel(generator, device_ids=devices)
    discriminator = nn.DataParallel(discriminator, device_ids=devices)

    # Register modules to checkpoint
    checkpoint_io.register_modules(generator=generator,
                                   discriminator=discriminator,
                                   g_optimizer=g_optimizer,
                                   d_optimizer=d_optimizer)

    # Logger
    logger = Logger(log_dir=path.join(out_dir, 'logs'),
                    img_dir=path.join(out_dir, 'imgs'),
                    monitoring=config['training']['monitoring'],
                    monitoring_dir=path.join(out_dir, 'monitoring'))

    # Distributions
    ydist = get_ydist(nlabels, device=device)
    zdist = get_zdist(config['z_dist']['type'], config['z_dist']['dim'],
                      device=device)

    ntest = config['training']['ntest']
    x_test, y_test = utils.get_nsamples(train_loader, ntest)
    x_cluster, y_cluster = utils.get_nsamples(train_loader, config['clusterer']['nimgs'])
    x_test, y_test = x_test.to(device), y_test.to(device)
    z_test = zdist.sample((ntest, ))
    utils.save_images(x_test, path.join(out_dir, 'real.png'))
    logger.add_imgs(x_test, 'gt', 0)

    # Test generator
    if config['training']['take_model_average']:
        print('Taking model average')
        bad_modules = [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]
        for model in [generator, discriminator]:
            for name, module in model.named_modules():
                for bad_module in bad_modules:
                    if isinstance(module, bad_module):
                        print('Batch norm in discriminator not compatible with exponential moving average')
                        exit()
        generator_test = copy.deepcopy(generator)
        checkpoint_io.register_modules(generator_test=generator_test)
    else:
        generator_test = generator

    clusterer = get_clusterer(config)(discriminator=discriminator,
                                      x_cluster=x_cluster,
                                      x_labels=y_cluster,
                                      gt_nlabels=config['data']['nlabels'],
                                      **config['clusterer']['kwargs'])

    # Load checkpoint if it exists
    it = utils.get_most_recent(checkpoint_dir, 'model') if args.model_it == -1 else args.model_it
    it, epoch_idx, loaded_clusterer = checkpoint_io.load_models(
        it=it, load_samples='supervised' != config['clusterer']['name'])

    if loaded_clusterer is None:
        print('Initializing new clusterer. The first clustering can be quite slow.')
        clusterer.recluster(discriminator=discriminator)
        checkpoint_io.save_clusterer(clusterer, it=0)
        np.savez(os.path.join(checkpoint_dir, 'cluster_samples.npz'), x=x_cluster)
    else:
        print('Using loaded clusterer')
        clusterer = loaded_clusterer

    # Evaluator
    evaluator = Evaluator(
        generator_test,
        zdist,
        ydist,
        train_loader=train_loader,
        clusterer=clusterer,
        batch_size=batch_size,
        device=device,
        inception_nsamples=config['training']['inception_nsamples'])

    # Trainer
    trainer = Trainer(generator,
                      discriminator,
                      g_optimizer,
                      d_optimizer,
                      gan_type=config['training']['gan_type'],
                      reg_type=config['training']['reg_type'],
                      reg_param=config['training']['reg_param'])

    # Training loop
    print('Start training...')
    while it < args.nepochs * len(train_loader):
        epoch_idx += 1

        for x_real, y in train_loader:
            it += 1

            x_real, y = x_real.to(device), y.to(device)
            z = zdist.sample((batch_size, ))
            y = clusterer.get_labels(x_real, y).to(device)

            # Discriminator updates
            dloss, reg = trainer.discriminator_trainstep(x_real, y, z)
            logger.add('losses', 'discriminator', dloss, it=it)
            logger.add('losses', 'regularizer', reg, it=it)

            # Generator updates
            gloss = trainer.generator_trainstep(y, z)
            logger.add('losses', 'generator', gloss, it=it)

            if config['training']['take_model_average']:
                update_average(generator_test, generator,
                               beta=config['training']['model_average_beta'])

            # Print stats
            if it % log_every == 0:
                g_loss_last = logger.get_last('losses', 'generator')
                d_loss_last = logger.get_last('losses', 'discriminator')
                d_reg_last = logger.get_last('losses', 'regularizer')
                print('[epoch %0d, it %4d] g_loss = %.4f, d_loss = %.4f, reg=%.4f'
                      % (epoch_idx, it, g_loss_last, d_loss_last, d_reg_last))

            if it % config['training']['recluster_every'] == 0 and it > config['training']['burnin_time']:
                # print cluster distribution for online methods
                if it % 100 == 0 and config['training']['recluster_every'] <= 100:
                    print(f'[epoch {epoch_idx}, it {it}], distribution: {clusterer.get_label_distribution(x_real)}')
                clusterer.recluster(discriminator=discriminator, x_batch=x_real)

            # (i) Sample if necessary
            if it % config['training']['sample_every'] == 0:
                print('Creating samples...')
                x = evaluator.create_samples(z_test, y_test)
                x = evaluator.create_samples(z_test,
                                             clusterer.get_labels(x_test, y_test).to(device))
                logger.add_imgs(x, 'all', it)

                for y_inst in range(sample_nlabels):
                    x = evaluator.create_samples(z_test, y_inst)
                    logger.add_imgs(x, '%04d' % y_inst, it)

            # (ii) Compute inception if necessary
            if it % inception_every == 0 and it > 0:
                print('PyTorch Inception score...')
                inception_mean, inception_std = evaluator.compute_inception_score()
                logger.add('metrics', 'pt_inception_mean', inception_mean, it=it)
                logger.add('metrics', 'pt_inception_stddev', inception_std, it=it)
                print(f'[epoch {epoch_idx}, it {it}] pt_inception_mean: {inception_mean}, pt_inception_stddev: {inception_std}')

            # (iii) Backup if necessary
            if it % backup_every == 0:
                print('Saving backup...')
                checkpoint_io.save('model_%08d.pt' % it, it=it)
                checkpoint_io.save_clusterer(clusterer, int(it))
                logger.save_stats('stats_%08d.p' % it)

                if it > 0:
                    checkpoint_io.save('model.pt', it=it)
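# Note: `update_average` is called in the training loop above but not defined in this
# excerpt. Given the config key `model_average_beta`, it presumably performs an
# exponential-moving-average update of the test generator's weights. A minimal sketch
# under that assumption (hypothetical name; the repo's own helper may differ):
import torch

@torch.no_grad()
def update_average_sketch(model_tgt, model_src, beta):
    # EMA: target <- beta * target + (1 - beta) * source, parameter by parameter
    for p_tgt, p_src in zip(model_tgt.parameters(), model_src.parameters()):
        p_tgt.copy_(beta * p_tgt + (1. - beta) * p_src)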
out_dir = get_out_dir(config)
checkpoint_dir = path.join(out_dir, 'chkpts')

# Create missing directories
if not path.exists(out_dir):
    os.makedirs(out_dir)
if not path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

# Save config for this run
save_config(os.path.join(out_dir, 'config.yaml'), config)

# Logger
checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)

device = torch.device("cuda:0" if is_cuda else "cpu")

# Dataset
train_loader = get_dataloader_stamps(config, split='train')

# Create models
generator, discriminator = build_models(config)
print(generator)
print(discriminator)

# Put models on gpu if needed
generator = generator.to(device)
discriminator = discriminator.to(device)
def main():
    checkpoint_dir = os.path.join(out_dir, 'chkpts')
    batch_size = config['training']['batch_size']

    if 'cifar' in config['data']['train_dir'].lower():
        name = 'cifar10'
    elif 'stacked_mnist' == config['data']['type']:
        name = 'stacked_mnist'
    else:
        name = 'image'

    if os.path.exists(os.path.join(out_dir, 'cluster_preds.npz')):
        # if we've already computed assignments, load them and move on
        with np.load(os.path.join(out_dir, 'cluster_preds.npz')) as f:
            y_reals = f['y_reals']
            y_preds = f['y_preds']
    else:
        train_dataset, _ = get_dataset(name=name,
                                       data_dir=config['data']['train_dir'],
                                       size=config['data']['img_size'])

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            num_workers=config['training']['nworkers'],
            shuffle=True,
            pin_memory=True,
            sampler=None,
            drop_last=True)

        checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)

        print('Loading clusterer:')
        most_recent = utils.get_most_recent(
            checkpoint_dir, 'model') if args.model_it is None else args.model_it
        clusterer = checkpoint_io.load_clusterer(most_recent,
                                                 load_samples=False,
                                                 pretrained=config['pretrained'])

        if isinstance(clusterer.discriminator, nn.DataParallel):
            clusterer.discriminator = clusterer.discriminator.module

        y_preds = []
        y_reals = []

        for batch_num, (x_real, y_real) in enumerate(
                tqdm(train_loader, total=len(train_loader))):
            y_pred = clusterer.get_labels(x_real.cuda(), None)
            y_preds.append(y_pred.detach().cpu())
            y_reals.append(y_real)

        y_reals = torch.cat(y_reals).numpy()
        y_preds = torch.cat(y_preds).numpy()

        np.savez(os.path.join(out_dir, 'cluster_preds.npz'),
                 y_reals=y_reals,
                 y_preds=y_preds)

    if args.random:
        y_preds = np.random.randint(0, 100, size=y_reals.shape)

    nmi_score = nmi(y_preds, y_reals)
    purity = purity_score(y_preds, y_reals)

    print('nmi', nmi_score, 'purity', purity)
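# Note: `purity_score` is used above but not defined in this excerpt. Cluster purity
# assigns each predicted cluster its most frequent ground-truth label and reports the
# fraction of points that end up correctly labeled. A minimal sketch of that metric
# (hypothetical helper name, shown only to clarify what is being reported):
import numpy as np
from sklearn.metrics import confusion_matrix

def purity_score_sketch(y_pred, y_true):
    # rows: ground-truth labels, columns: predicted clusters
    cm = confusion_matrix(y_true, y_pred)
    # for each cluster take the count of its dominant true label, then normalize by N
    return np.sum(np.amax(cm, axis=0)) / np.sum(cm)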
# Shorthands
nlabels = config['data']['nlabels']
out_dir = config['training']['out_dir']
batch_size = config['test']['batch_size']
sample_size = config['test']['sample_size']
sample_nrow = config['test']['sample_nrow']
fid_sample_size = config['training']['fid_sample_size']

checkpoint_dir = path.join(out_dir, 'pretrained')
img_dir = path.join(out_dir, 'img')

# Create missing directories
if not path.exists(img_dir):
    os.makedirs(img_dir)

# Logger
checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)

device = torch.device("cuda:0" if is_cuda else "cpu")

generator, discriminator = build_models(config)

# Put models on gpu if needed
generator = generator.to(device)
discriminator = discriminator.to(device)

# Use multiple GPUs if possible
generator = nn.DataParallel(generator)
discriminator = nn.DataParallel(discriminator)

# Register modules to checkpoint
checkpoint_io.register_modules(
for epoch_id in range(80):
    for model in all_models:
        model_name = "/home/kunxu/Workspace/GAN_PID/{}".format(model)
        key_name = model_name
        if key_name not in all_results:
            all_results[key_name] = []

        config = load_config(os.path.join(model_name, "config.yaml"),
                             'configs/default.yaml')
        generator, discriminator = build_models(config)
        generator = torch.nn.DataParallel(generator)

        zdist = get_zdist(config['z_dist']['type'], config['z_dist']['dim'],
                          device=device)
        ydist = get_ydist(1, device=device)

        checkpoint_io = CheckpointIO(checkpoint_dir="./tmp")
        checkpoint_io.register_modules(generator_test=generator)
        evaluator = Evaluator(generator, zdist, ydist, batch_size=100, device=device)

        ckptpath = os.path.join(
            model_name, "chkpts",
            "model_{:08d}.pt".format(epoch_id * 10000 + 9999))
        print(ckptpath)
        load_dict = checkpoint_io.load(ckptpath)

        img_list = []
        for i in range(500):
            ztest = zdist.sample((100, ))
out_dir = config['training']['out_dir']
batch_size = config['test']['batch_size']
sample_size = config['test']['sample_size']
sample_nrow = config['test']['sample_nrow']

checkpoint_dir = path.join(out_dir, 'chkpts')
img_dir = path.join(out_dir, 'test', 'img')
img_all_dir = path.join(out_dir, 'test', 'img_all')

# Create missing directories
if not path.exists(img_dir):
    os.makedirs(img_dir)
if not path.exists(img_all_dir):
    os.makedirs(img_all_dir)

# Logger
checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)

# Get model file
model_file = config['test']['model_file']

# Models
device = torch.device("cuda:0" if is_cuda else "cpu")

generator, discriminator = build_models(config)
print(generator)
print(discriminator)

# Put models on gpu if needed
generator = generator.to(device)
discriminator = discriminator.to(device)
def main():
    checkpoint_dir = os.path.join(out_dir, 'chkpts')
    most_recent = utils.get_most_recent(
        checkpoint_dir, 'model') if args.model_it is None else args.model_it

    cluster_path = os.path.join(out_dir, 'clusters')
    print('Saving clusters/samples to', cluster_path)
    os.makedirs(cluster_path, exist_ok=True)
    shutil.copyfile('seeing/lightbox.html',
                    os.path.join(cluster_path, '+lightbox.html'))

    checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)
    most_recent = utils.get_most_recent(
        checkpoint_dir, 'model') if args.model_it is None else args.model_it
    clusterer = checkpoint_io.load_clusterer(most_recent,
                                             pretrained=config['pretrained'],
                                             load_samples=False)

    if isinstance(clusterer.discriminator, nn.DataParallel):
        clusterer.discriminator = clusterer.discriminator.module

    model_path = os.path.join(checkpoint_dir, 'model_%08d.pt' % most_recent)
    sampler = SeededSampler(args.config,
                            model_path=model_path,
                            clusterer_path=os.path.join(
                                checkpoint_dir, f'clusterer{most_recent}.pkl'),
                            pretrained=config['pretrained'])

    if args.show_clusters:
        clusters = [[] for _ in range(config['generator']['nlabels'])]
        train_dataset, _ = get_dataset(
            name='webp' if 'cifar' not in config['data']['train_dir'].lower() else 'cifar10',
            data_dir=config['data']['train_dir'],
            size=config['data']['img_size'])
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=config['training']['batch_size'],
            num_workers=config['training']['nworkers'],
            shuffle=True,
            pin_memory=True,
            sampler=None,
            drop_last=True)

        print('Generating clusters')
        for batch_num, (x_real, y_gt) in enumerate(train_loader):
            x_real = x_real.cuda()
            y_pred = clusterer.get_labels(x_real, y_gt)
            for i, yi in enumerate(y_pred):
                clusters[yi].append(x_real[i].cpu())

            # don't generate too many, we're only visualizing 20 per cluster
            if batch_num * config['training']['batch_size'] >= 10000:
                break
    else:
        clusters = [None] * config['generator']['nlabels']

    nimgs = 20
    nrows = 4

    for i in range(len(clusters)):
        if clusters[i] is None:
            pass
        elif len(clusters[i]) >= nimgs:
            cluster = torch.stack(clusters[i])[:nimgs]
            torchvision.utils.save_image(cluster * 0.5 + 0.5,
                                         os.path.join(cluster_path, f'{i}_real.png'),
                                         nrow=nrows)

        generated = []
        for seed in range(nimgs):
            img = sampler.conditional_sample(i, seed=seed)
            generated.append(img.detach().cpu())
        generated = torch.cat(generated)
        torchvision.utils.save_image(generated * 0.5 + 0.5,
                                     os.path.join(cluster_path, f'{i}_gen.png'),
                                     nrow=nrows)

    print('Clusters/samples can be visualized under', cluster_path)
def get_generator(self):
    '''Loads a generator according to self.model_path.'''
    exp_out_dir = os.path.join(self.rootdir, self.config['training']['out_dir'])

    # infer checkpoint directory if needed
    checkpoint_dir = os.path.join(
        exp_out_dir, 'chkpts'
    ) if self.model_path == "" or 'model' in self.pretrained else "./"
    model_name = get_most_recent(os.listdir(
        checkpoint_dir)) if self.model_path == "" else self.model_path

    checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)
    self.checkpoint_io = checkpoint_io

    generator, _ = build_models(self.config)
    generator = generator.to(self.device)
    generator = nn.DataParallel(generator)

    if self.config['training']['take_model_average']:
        generator_test = copy.deepcopy(generator)
        checkpoint_io.register_modules(generator_test=generator_test)
    else:
        generator_test = generator

    checkpoint_io.register_modules(generator=generator)

    try:
        it = checkpoint_io.load(model_name, pretrained=self.pretrained)
        assert (it != -1)
    except Exception as e:
        # try again without data parallel
        print(e)
        checkpoint_io.register_modules(generator=generator.module)
        checkpoint_io.register_modules(generator_test=generator_test.module)
        it = checkpoint_io.load(model_name, pretrained=self.pretrained)
        assert (it != -1)

    print('Loaded iteration:', it['it'])

    return generator_test
def perform_evaluation(run_name, image_type):
    out_dir = os.path.join(os.getcwd(), '..', 'output', run_name)
    checkpoint_dir = os.path.join(out_dir, 'chkpts')
    checkpoints = sorted(glob.glob(os.path.join(checkpoint_dir, '*')))
    evaluation_dict = {}

    for point in checkpoints:
        if not int(point.split('/')[-1].split('_')[1].split('.')[0]) % 10000 == 0:
            continue

        iter_num = int(point.split('/')[-1].split('_')[1].split('.')[0])
        model_file = point.split('/')[-1]

        config = load_config('../configs/fr_default.yaml', None)
        is_cuda = (torch.cuda.is_available())

        checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)
        device = torch.device("cuda:0" if is_cuda else "cpu")

        generator, discriminator = build_models(config)

        # Put models on gpu if needed
        generator = generator.to(device)
        discriminator = discriminator.to(device)

        # Use multiple GPUs if possible
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)

        generator_test_9 = copy.deepcopy(generator)
        generator_test_99 = copy.deepcopy(generator)
        generator_test_999 = copy.deepcopy(generator)
        generator_test_9999 = copy.deepcopy(generator)

        # Register modules to checkpoint
        checkpoint_io.register_modules(
            generator=generator,
            generator_test_9=generator_test_9,
            generator_test_99=generator_test_99,
            generator_test_999=generator_test_999,
            generator_test_9999=generator_test_9999,
            discriminator=discriminator,
        )

        # Load checkpoint
        load_dict = checkpoint_io.load(model_file)

        # Distributions
        ydist = get_ydist(config['data']['nlabels'], device=device)
        zdist = get_zdist(config['z_dist']['type'], config['z_dist']['dim'],
                          device=device)

        z_sample = torch.Tensor(np.load('z_data.npy')).to(device)

        # for name, model in zip(['0_', '09_', '099_', '0999_', '09999_'],
        #                        [generator, generator_test_9, generator_test_99,
        #                         generator_test_999, generator_test_9999]):
        for name, model in zip(
                ['099_', '0999_', '09999_'],
                [generator_test_99, generator_test_999, generator_test_9999]):
            # Evaluator
            evaluator = Evaluator(model, zdist, ydist, device=device)

            x_sample = []
            for i in range(10):
                x = evaluator.create_samples(z_sample[i * 1000:(i + 1) * 1000])
                x_sample.append(x)
            x_sample = torch.cat(x_sample)
            x_sample = x_sample / 2 + 0.5

            if not os.path.exists('fake_data'):
                os.makedirs('fake_data')
            for i in range(10000):
                torchvision.utils.save_image(x_sample[i, :, :, :],
                                             'fake_data/{}.png'.format(i))

            fid_score = calculate_fid_given_paths(
                ['fake_data', image_type + '_real'], 50, True, 2048)
            print(iter_num, name, fid_score)
            os.system("rm -rf " + "fake_data")

            evaluation_dict[(iter_num, name[:-1])] = {'FID': fid_score}

            if not os.path.exists('evaluation_data/' + run_name):
                os.makedirs('evaluation_data/' + run_name)
            pickle.dump(
                evaluation_dict,
                open('evaluation_data/' + run_name + '/eval_fid.p', 'wb'))
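# Example driver for the evaluation routine above. The run name and image type here
# are hypothetical placeholders; in practice they would come from the calling script
# or a command-line parser, and `<image_type>_real` must point at a directory of
# real images for the FID computation.
if __name__ == '__main__':
    perform_evaluation(run_name='my_run', image_type='celeba')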