def compute_inception_score(n):
    all_samples = []

    for i in range(int(n / 100)):
        all_samples.append(session.run(samples_100))

    all_samples = np.concatenate(all_samples, axis=0)
    all_samples = all_samples.reshape((-1, 3, 32, 32))
    all_samples = scale_value(all_samples, [-1.0, 1.0])
    print(all_samples.shape)

    return get_inception_score(all_samples)
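# `scale_value` is not defined in this file. A minimal sketch consistent
# with how it is called above (mapping an array onto a target interval such
# as [-1, 1]) might look like the following; the helper in the original
# codebase may differ in detail.
def scale_value_sketch(x, value_range):
    lo, hi = value_range
    x_min, x_max = x.min(), x.max()
    # Normalize to [0, 1], then stretch to [lo, hi].
    x = (x - x_min) / max(x_max - x_min, 1e-12)
    return x * (hi - lo) + lo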
def compute_metric(generator_model):
    global best_icp
    sample_size = 20000
    noise = np.random.normal(size=(sample_size, 100))
    art_images = generator_model.predict(noise)
    art_images = scale_value(art_images, [-1.0, 1.0])
    art_images = np.transpose(art_images, (0, 3, 1, 2))
    (icp_mean, icp_std) = get_inception_score(art_images)
    if icp_mean > best_icp:
        best_icp = icp_mean
    print('Inception score: ', icp_mean)
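# A hypothetical smoke test for `compute_metric`: a stub with the same
# `.predict` interface (noise of shape (batch, 100) in, NHWC images out),
# so the plumbing can be exercised without a trained model.
# `get_inception_score` and the module-level `best_icp` are still assumed
# from the surrounding codebase.
class _StubGenerator:
    def predict(self, noise):
        # Produce fake NHWC images in [0, 1] with the expected shape.
        return np.random.uniform(0, 1, (noise.shape[0], 32, 32, 3))

# best_icp = 0.0
# compute_metric(_StubGenerator())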
def test(target_vars, saver, sess, logger, dataloader):
    X_NOISE = target_vars['X_NOISE']
    X = target_vars['X']
    Y = target_vars['Y']
    LABEL = target_vars['LABEL']
    energy_start = target_vars['energy_start']
    # Evaluation uses the test-time sampling chain, not the training one.
    x_mod = target_vars['test_x_mod']
    energy_neg = target_vars['energy_neg']

    np.random.seed(1)
    random.seed(1)

    output = [x_mod, energy_start, energy_neg]

    dataloader_iterator = iter(dataloader)
    data_corrupt, data, label = next(dataloader_iterator)
    data_corrupt, data, label = \
        data_corrupt.numpy(), data.numpy(), label.numpy()

    orig_im = try_im = data_corrupt

    if FLAGS.cclass:
        try_im, energy_orig, energy = sess.run(
            output, {X_NOISE: orig_im, Y: label[0:1], LABEL: label})
    else:
        try_im, energy_orig, energy = sess.run(
            output, {X_NOISE: orig_im, Y: label[0:1]})

    orig_im = rescale_im(orig_im)
    try_im = rescale_im(try_im)
    actual_im = rescale_im(data)

    for i, (im, energy_i, t_im, energy, label_i, actual_im_i) in enumerate(
            zip(orig_im, energy_orig, try_im, energy, label, actual_im)):
        label_i = np.array(label_i)

        shape = im.shape[1:]
        new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
        size = shape[1]
        new_im[:, :size] = im
        new_im[:, size:2 * size] = t_im

        if FLAGS.cclass:
            label_i = np.where(label_i == 1)[0][0]
            if FLAGS.dataset == 'cifar10':
                log_image(
                    new_im, logger,
                    '{}_{:.4f}_now_{:.4f}_{}'.format(
                        i, energy_i[0], energy[0], cifar10_map[label_i]),
                    step=i)
            else:
                log_image(
                    new_im, logger,
                    '{}_{:.4f}_now_{:.4f}_{}'.format(
                        i, energy_i[0], energy[0], label_i),
                    step=i)
        else:
            log_image(
                new_im, logger,
                '{}_{:.4f}_now_{:.4f}'.format(i, energy_i[0], energy[0]),
                step=i)

    test_ims = list(try_im)
    real_ims = list(actual_im)

    # Accumulate roughly 50k generated and real images for scoring.
    for i in tqdm(range(50000 // FLAGS.batch_size + 1)):
        try:
            data_corrupt, data, label = next(dataloader_iterator)
        except BaseException:
            dataloader_iterator = iter(dataloader)
            data_corrupt, data, label = next(dataloader_iterator)

        data_corrupt, data, label = \
            data_corrupt.numpy(), data.numpy(), label.numpy()

        if FLAGS.cclass:
            try_im, energy_orig, energy = sess.run(
                output, {X_NOISE: data_corrupt, Y: label[0:1], LABEL: label})
        else:
            try_im, energy_orig, energy = sess.run(
                output, {X_NOISE: data_corrupt, Y: label[0:1]})

        try_im = rescale_im(try_im)
        real_im = rescale_im(data)

        test_ims.extend(list(try_im))
        real_ims.extend(list(real_im))

    score, std = get_inception_score(test_ims)
    print("Inception score of {} with std of {}".format(score, std))
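# `rescale_im`, used throughout these functions, is assumed from the
# surrounding codebase: it maps float images (roughly in [0, 1], or
# [0, FLAGS.rescale]) to uint8 for logging and Inception scoring. A
# plausible minimal version, not necessarily the original:
def rescale_im_sketch(im):
    return np.clip(im * 255, 0, 255).astype(np.uint8)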
def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir):
    X = target_vars['X']
    Y = target_vars['Y']
    X_NOISE = target_vars['X_NOISE']
    train_op = target_vars['train_op']
    energy_pos = target_vars['energy_pos']
    energy_neg = target_vars['energy_neg']
    loss_energy = target_vars['loss_energy']
    loss_ml = target_vars['loss_ml']
    loss_total = target_vars['total_loss']
    gvs = target_vars['gvs']
    x_grad = target_vars['x_grad']
    x_grad_first = target_vars['x_grad_first']
    x_off = target_vars['x_off']
    temp = target_vars['temp']
    x_mod = target_vars['x_mod']
    LABEL = target_vars['LABEL']
    LABEL_POS = target_vars['LABEL_POS']
    weights = target_vars['weights']
    test_x_mod = target_vars['test_x_mod']
    eps = target_vars['eps_begin']
    label_ent = target_vars['label_ent']

    if FLAGS.use_attention:
        gamma = weights[0]['atten']['gamma']
    else:
        gamma = tf.zeros(1)

    val_output = [test_x_mod]
    gvs_dict = dict(gvs)

    log_output = [
        train_op, energy_pos, energy_neg, eps, loss_energy, loss_ml,
        loss_total, x_grad, x_off, x_mod, gamma, x_grad_first, label_ent,
        *gvs_dict.keys()]
    output = [train_op, x_mod]

    replay_buffer = ReplayBuffer(10000)
    itr = resume_iter
    x_mod = None
    gd_steps = 1

    dataloader_iterator = iter(dataloader)
    best_inception = 0.0

    for epoch in range(FLAGS.epoch_num):
        print("Training epoch: %d" % epoch)
        for data_corrupt, data, label in dataloader:
            data_corrupt = data_corrupt.numpy()
            data_corrupt_init = data_corrupt.copy()

            data = data.numpy()
            label = label.numpy()
            label_init = label.copy()

            if FLAGS.mixup:
                idx = np.random.permutation(data.shape[0])
                lam = np.random.beta(1, 1, size=(data.shape[0], 1, 1, 1))
                data = data * lam + data[idx] * (1 - lam)

            if FLAGS.replay_batch and (x_mod is not None):
                replay_buffer.add(compress_x_mod(x_mod))

                if len(replay_buffer) > FLAGS.batch_size:
                    replay_batch = replay_buffer.sample(FLAGS.batch_size)
                    replay_batch = decompress_x_mod(replay_batch)
                    replay_mask = (np.random.uniform(
                        0, FLAGS.rescale, FLAGS.batch_size) > 0.05)
                    data_corrupt[replay_mask] = replay_batch[replay_mask]

            if FLAGS.pcd:
                if x_mod is not None:
                    data_corrupt = x_mod

            feed_dict = {X_NOISE: data_corrupt, X: data, Y: label}

            if FLAGS.cclass:
                feed_dict[LABEL] = label
                feed_dict[LABEL_POS] = label_init

            if itr % FLAGS.log_interval == 0:
                _, e_pos, e_neg, eps, loss_e, loss_ml, loss_total, x_grad, \
                    x_off, x_mod, gamma, x_grad_first, label_ent, *grads = \
                    sess.run(log_output, feed_dict)

                kvs = {}
                kvs['e_pos'] = e_pos.mean()
                kvs['e_pos_std'] = e_pos.std()
                kvs['e_neg'] = e_neg.mean()
                kvs['e_diff'] = kvs['e_pos'] - kvs['e_neg']
                kvs['e_neg_std'] = e_neg.std()
                kvs['temp'] = temp
                kvs['loss_e'] = loss_e.mean()
                kvs['eps'] = eps.mean()
                kvs['label_ent'] = label_ent
                kvs['loss_ml'] = loss_ml.mean()
                kvs['loss_total'] = loss_total.mean()
                kvs['x_grad'] = np.abs(x_grad).mean()
                kvs['x_grad_first'] = np.abs(x_grad_first).mean()
                kvs['x_off'] = x_off.mean()
                kvs['iter'] = itr
                kvs['gamma'] = gamma

                for v, k in zip(grads, [v.name for v in gvs_dict.values()]):
                    kvs[k] = np.abs(v).max()

                string = "Obtained a total of "
                for key, value in kvs.items():
                    string += "{}: {}, ".format(key, value)

                if hvd.rank() == 0:
                    print(string)
                    logger.writekvs(kvs)
            else:
                _, x_mod = sess.run(output, feed_dict)

            if itr % FLAGS.save_interval == 0 and hvd.rank() == 0:
                saver.save(
                    sess,
                    osp.join(FLAGS.logdir, FLAGS.exp, 'model_{}'.format(itr)))

            if itr % FLAGS.test_interval == 0 and hvd.rank() == 0 \
                    and FLAGS.dataset != '2d':
                try_im = x_mod
                orig_im = data_corrupt.squeeze()
                actual_im = rescale_im(data)

                orig_im = rescale_im(orig_im)
                try_im = rescale_im(try_im).squeeze()

                for i, (im, t_im, actual_im_i) in enumerate(
                        zip(orig_im[:20], try_im[:20], actual_im)):
                    shape = orig_im.shape[1:]
                    new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
                    size = shape[1]
                    new_im[:, :size] = im
                    new_im[:, size:2 * size] = t_im
                    new_im[:, 2 * size:] = actual_im_i

                    log_image(new_im, logger,
                              'train_gen_{}'.format(itr), step=i)

                test_im = x_mod

                try:
                    data_corrupt, data, label = next(dataloader_iterator)
                except BaseException:
                    dataloader_iterator = iter(dataloader)
                    data_corrupt, data, label = next(dataloader_iterator)

                data_corrupt = data_corrupt.numpy()

                if FLAGS.replay_batch and (x_mod is not None) \
                        and len(replay_buffer) > 0:
                    replay_batch = replay_buffer.sample(FLAGS.batch_size)
                    replay_batch = decompress_x_mod(replay_batch)
                    replay_mask = (np.random.uniform(
                        0, 1, (FLAGS.batch_size)) > 0.05)
                    data_corrupt[replay_mask] = replay_batch[replay_mask]

                if FLAGS.dataset in ('cifar10', 'imagenet', 'imagenetfull'):
                    n = 128

                    if FLAGS.dataset == "imagenetfull":
                        n = 32

                    if len(replay_buffer) > n:
                        data_corrupt = decompress_x_mod(
                            replay_buffer.sample(n))
                    elif FLAGS.dataset == 'imagenetfull':
                        data_corrupt = np.random.uniform(
                            0, FLAGS.rescale, (n, 128, 128, 3))
                    else:
                        data_corrupt = np.random.uniform(
                            0, FLAGS.rescale, (n, 32, 32, 3))

                    if FLAGS.dataset == 'cifar10':
                        label = np.eye(10)[np.random.randint(0, 10, (n))]
                    else:
                        label = np.eye(1000)[np.random.randint(0, 1000, (n))]

                feed_dict[X_NOISE] = data_corrupt
                feed_dict[X] = data

                if FLAGS.cclass:
                    feed_dict[LABEL] = label

                test_x_mod = sess.run(val_output, feed_dict)

                try_im = test_x_mod
                orig_im = data_corrupt.squeeze()
                actual_im = rescale_im(data.numpy())

                orig_im = rescale_im(orig_im)
                try_im = rescale_im(try_im).squeeze()

                for i, (im, t_im, actual_im_i) in enumerate(
                        zip(orig_im[:20], try_im[:20], actual_im)):
                    shape = orig_im.shape[1:]
                    new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
                    size = shape[1]
                    new_im[:, :size] = im
                    new_im[:, size:2 * size] = t_im
                    new_im[:, 2 * size:] = actual_im_i
                    log_image(new_im, logger,
                              'val_gen_{}'.format(itr), step=i)

                score, std = get_inception_score(list(try_im), splits=1)
                print("Inception score of {} with std of {}".format(
                    score, std))

                kvs = {}
                kvs['inception_score'] = score
                kvs['inception_score_std'] = std
                logger.writekvs(kvs)

                if score > best_inception:
                    best_inception = score
                    saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp,
                                              'model_best'))

            if itr > 60000 and FLAGS.dataset == "mnist":
                assert False  # hard stop after 60k iterations on MNIST

            itr += 1
            print("Training iteration: %d" % itr)

    saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp,
                              'model_{}'.format(itr)))
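# The replay buffer and the compress/decompress helpers used above are
# assumed from the surrounding codebase. The idea: store negative samples
# as uint8 to save memory, and decompress back to floats (with a little
# dequantization noise) when reusing them as MCMC initializations. A
# minimal sketch consistent with the call sites, not the original code:
def compress_x_mod_sketch(x_mod):
    return (255 * np.clip(x_mod, 0, FLAGS.rescale) /
            FLAGS.rescale).astype(np.uint8)

def decompress_x_mod_sketch(x_mod):
    return x_mod * FLAGS.rescale / 255. + \
        np.random.uniform(0, FLAGS.rescale / 255., x_mod.shape)

class ReplayBufferSketch(object):
    def __init__(self, size):
        self._storage = []
        self._maxsize = size
        self._next_idx = 0

    def __len__(self):
        return len(self._storage)

    def add(self, ims):
        for im in ims:
            if self._next_idx >= len(self._storage):
                self._storage.append(im)
            else:
                self._storage[self._next_idx] = im  # overwrite oldest entry
            self._next_idx = (self._next_idx + 1) % self._maxsize

    def sample(self, batch_size):
        idxes = np.random.randint(0, len(self._storage), batch_size)
        return np.array([self._storage[i] for i in idxes])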
def compute_inception(sess, target_vars):
    X_START = target_vars['X_START']
    Y_GT = target_vars['Y_GT']
    X_finals = target_vars['X_finals']
    NOISE_SCALE = target_vars['NOISE_SCALE']
    energy_noise = target_vars['energy_noise']

    size = FLAGS.im_number
    num_steps = size // 1000

    images = []
    test_ims = []
    test_images = []

    if FLAGS.dataset == "cifar10":
        test_dataset = Cifar10(full=True, noise=False)
    elif FLAGS.dataset == "celeba":
        test_dataset = CelebA()
    elif FLAGS.dataset == "imagenet" or FLAGS.dataset == "imagenetfull":
        test_dataset = Imagenet(train=False)

    if FLAGS.dataset != "imagenetfull":
        test_dataloader = DataLoader(
            test_dataset, batch_size=FLAGS.batch_size, num_workers=4,
            shuffle=True, drop_last=False)
    else:
        test_dataloader = TFImagenetLoader('test', FLAGS.batch_size, 0, 1)

    for data_corrupt, data, label_gt in tqdm(test_dataloader):
        data = data.numpy()
        test_ims.extend(list(rescale_im(data)))

        if FLAGS.dataset == "imagenetfull" and len(test_ims) > 60000:
            test_ims = test_ims[:60000]
            break

    print(len(test_ims))

    if FLAGS.dataset == "cifar10":
        classes = 10
    else:
        classes = 1000

    if FLAGS.dataset == "imagenetfull":
        n = 128
    else:
        n = 32

    for j in range(num_steps):
        itr = int(1000 / 500 * FLAGS.repeat_scale)
        data_buffer = InceptionReplayBuffer(1000)
        curr_index = 0

        identity = np.eye(classes)
        test_steps = range(300, itr, 20)

        for i in tqdm(range(itr)):
            # Cycle through the sampling chains of the ensemble
            model_index = curr_index % len(X_finals)
            x_final = X_finals[model_index]

            noise_scale = [1]
            if len(data_buffer) < 1000:
                x_init = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
                label = np.random.randint(0, classes, (FLAGS.batch_size))
                label = identity[label]
                x_new = sess.run(
                    [x_final],
                    {X_START: x_init, Y_GT: label,
                     NOISE_SCALE: noise_scale})[0]
                data_buffer.add(x_new, label)
            else:
                (x_init, label), idx = data_buffer.sample(FLAGS.batch_size)
                keep_mask = (np.random.uniform(
                    0, 1, (FLAGS.batch_size)) > 0.99)
                label_keep_mask = (np.random.uniform(
                    0, 1, (FLAGS.batch_size)) > 0.9)
                label_corrupt = np.random.randint(
                    0, classes, (FLAGS.batch_size))
                label_corrupt = identity[label_corrupt]
                x_init_corrupt = np.random.uniform(
                    0, 1, (FLAGS.batch_size, n, n, 3))

                if i < itr - FLAGS.nomix:
                    x_init[keep_mask] = x_init_corrupt[keep_mask]
                    label[label_keep_mask] = label_corrupt[label_keep_mask]

                x_new, e_noise = sess.run(
                    [x_final, energy_noise],
                    {X_START: x_init, Y_GT: label, NOISE_SCALE: noise_scale})
                data_buffer.set_elms(idx, x_new, label)

            curr_index += 1

        ims = np.array(data_buffer._storage[:1000])
        ims = rescale_im(ims)
        test_images.extend(list(ims))

    saveim = osp.join(FLAGS.logdir, FLAGS.exp,
                      "test{}.png".format(FLAGS.resume_iter))

    row = 15
    col = 20
    ims = ims[:row * col]

    if FLAGS.dataset != "imagenetfull":
        im_panel = ims.reshape((row, col, 32, 32, 3)).transpose(
            (0, 2, 1, 3, 4)).reshape((32 * row, 32 * col, 3))
    else:
        im_panel = ims.reshape((row, col, 128, 128, 3)).transpose(
            (0, 2, 1, 3, 4)).reshape((128 * row, 128 * col, 3))

    imsave(saveim, im_panel)

    splits = max(1, len(test_images) // 5000)
    score, std = get_inception_score(test_images, splits=splits)
    print("Inception score of {} with std of {}".format(score, std))

    # FID score
    fid = get_fid_score(test_images, test_ims)
    print("FID score of {}".format(fid))
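# The `X_finals` ops above are assumed to be the endpoints of Langevin
# sampling chains built at graph-construction time. A minimal TF1-style
# sketch of one such chain; `energy_fn`, the step size, and the noise
# standard deviation are illustrative assumptions, not the original
# implementation:
def build_langevin_chain_sketch(energy_fn, x_start, label,
                                num_steps=20, step_lr=10.0):
    x = x_start
    for _ in range(num_steps):
        # Inject Gaussian noise, then descend the energy gradient.
        x = x + tf.random_normal(tf.shape(x), mean=0.0, stddev=0.005)
        energy = energy_fn(x, label)
        x_grad = tf.gradients(energy, [x])[0]
        x = x - step_lr * x_grad
        x = tf.clip_by_value(x, 0.0, 1.0)  # keep samples in image range
    return x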
def run(config):
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config}

    # Optionally, get the configuration from the state dict. This allows for
    # recovery of the config provided only a state dict and experiment name,
    # and can be convenient for writing less verbose sample shell scripts.
    if config['config_from_name']:
        utils.load_weights(None, None, state_dict, config['weights_root'],
                           config['experiment_name'], config['load_weights'],
                           None, strict=False, load_optim=False)
        # Ignore items which we might want to overwrite from the command line
        for item in state_dict['config']:
            if item not in ['z_var', 'base_root', 'batch_size',
                            'G_batch_size', 'use_ema', 'G_eval_mode']:
                config[item] = state_dict['config'][item]

    # Update config (see train.py for explanation)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['n_channels'] = utils.nchannels_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    config = utils.update_config_roots(config)
    config['skip_init'] = True
    config['no_optim'] = True
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Set up cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model -- this line allows us to dynamically select
    # different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    G = model.Generator(**config).cuda()
    utils.count_parameters(G)

    # In some cases we need to load D
    if True or config['get_test_error'] or config['get_train_error'] \
            or config['get_self_error'] or config['get_generator_error']:
        disc_config = config.copy()
        if config['mh_csc_loss'] or config['mh_loss']:
            disc_config['output_dim'] = disc_config['n_classes'] + 1
        D = model.Discriminator(**disc_config).to(device)

        def get_n_correct_from_D(x, y):
            """Gets the "classifications" from D.

            y: the correct labels.

            In the case of projection discrimination we have to pass in
            all the labels as conditionings to get the class-specific
            affinity.
            """
            x = x.to(device)
            if config['model'] == 'BigGAN':  # projection discrimination case
                if not config['get_self_error']:
                    y = y.to(device)
                yhat = D(x, y)
                for i in range(1, config['n_classes']):
                    yhat_ = D(x, ((y + i) % config['n_classes']))
                    yhat = torch.cat([yhat, yhat_], 1)
                preds_ = yhat.data.max(1)[1].cpu()
                return preds_.eq(0).cpu().sum()
            else:  # the mh gan case
                if not config['get_self_error']:
                    y = y.to(device)
                yhat = D(x)
                preds_ = yhat[:, :config['n_classes']].data.max(1)[1]
                return preds_.eq(y.data).cpu().sum()

    # Load weights
    print('Loading weights...')
    # Here is where we deal with the ema -- load ema weights or load
    # normal weights
    utils.load_weights(G if not (config['use_ema']) else None, D, state_dict,
                       config['weights_root'], experiment_name,
                       config['load_weights'],
                       G if config['ema'] and config['use_ema'] else None,
                       strict=False, load_optim=False)

    # Update batch size setting used for G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                               device=device, fp16=config['G_fp16'],
                               z_var=config['z_var'])

    if config['G_eval_mode']:
        print('Putting G in eval mode...')
        G.eval()
    else:
        print('G is in %s mode...' % ('training' if G.training else 'eval'))

    sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)

    brief_expt_name = config['experiment_name'][-30:]

    # Always load the results dict
    HIST_FNAME = 'scoring_hist.npy'

    def load_or_make_hist(d):
        """Make or load the history file in directory `d`."""
        if not os.path.isdir(d):
            raise Exception('%s is not a valid directory' % d)
        f = os.path.join(d, HIST_FNAME)
        if os.path.isfile(f):
            return np.load(f, allow_pickle=True).item()
        else:
            return defaultdict(dict)

    hist_dir = os.path.join(config['weights_root'], config['experiment_name'])
    hist = load_or_make_hist(hist_dir)

    if config['get_test_error'] or config['get_train_error']:
        loaders = utils.get_data_loaders(**{
            **config,
            'batch_size': config['batch_size'],
            'start_itr': state_dict['itr'],
            'use_test_set': config['get_test_error']})
        acc_type = 'Test' if config['get_test_error'] else 'Train'

        pbar = tqdm(loaders[0])
        loader_total = len(loaders[0]) * config['batch_size']
        sample_todo = min(config['sample_num_error'], loader_total)
        print('Getting %s error across %i examples' % (acc_type, sample_todo))
        correct = 0
        total = 0

        with torch.no_grad():
            for i, (x, y) in enumerate(pbar):
                correct += get_n_correct_from_D(x, y)
                total += config['batch_size']
                if loader_total > total and \
                        total >= config['sample_num_error']:
                    print('Quitting early...')
                    break

        accuracy = float(correct) / float(total)
        hist = load_or_make_hist(hist_dir)
        hist[state_dict['itr']][acc_type] = accuracy
        np.save(os.path.join(hist_dir, HIST_FNAME), hist)

        print('[%s][%06d] %s accuracy: %f.' %
              (brief_expt_name, state_dict['itr'], acc_type, accuracy * 100))

    if config['get_self_error']:
        n_used_imgs = config['sample_num_error']
        correct = 0
        imageSize = config['resolution']
        x = np.empty((n_used_imgs, imageSize, imageSize, 3), dtype=np.uint8)
        for l in tqdm(range(n_used_imgs // G_batch_size),
                      desc='Generating [%s][%06d]' %
                           (brief_expt_name, state_dict['itr'])):
            with torch.no_grad():
                images, y = sample()
                correct += get_n_correct_from_D(images, y)

        accuracy = float(correct) / float(n_used_imgs)
        print('[%s][%06d] %s accuracy: %f.' %
              (brief_expt_name, state_dict['itr'], 'Self', accuracy * 100))
        hist = load_or_make_hist(hist_dir)
        hist[state_dict['itr']]['Self'] = accuracy
        np.save(os.path.join(hist_dir, HIST_FNAME), hist)

    if config['get_generator_error']:
        if config['dataset'] == 'C10':
            from classification.models.densenet import DenseNet121
            from torchvision import transforms
            compnet = DenseNet121()
            compnet = torch.nn.DataParallel(compnet)
            checkpoint = torch.load(os.path.join(
                '/fs/vulcan-scratch/ilyak/locDoc/experiments/classifiers/cifar/densenet121',
                'ckpt_47.t7'))
            compnet.load_state_dict(checkpoint['net'])
            compnet = compnet.to(device)
            compnet.eval()
            minimal_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
            ])
        elif config['dataset'] == 'C100':
            from classification.models.densenet import DenseNet121
            from torchvision import transforms
            compnet = DenseNet121(num_classes=100)
            compnet = torch.nn.DataParallel(compnet)
            checkpoint = torch.load(os.path.join(
                '/scratch0/ilya/locDoc/classifiers/cifar100/densenet121',
                'ckpt.copy.t7'))
            compnet.load_state_dict(checkpoint['net'])
            compnet = compnet.to(device)
            compnet.eval()
            minimal_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.507, 0.487, 0.441),
                                     (0.267, 0.256, 0.276)),
            ])
        elif config['dataset'] == 'STL48':
            from classification.models.wideresnet import WideResNet48
            from torchvision import transforms
            checkpoint = torch.load(os.path.join(
                '/fs/vulcan-scratch/ilyak/locDoc/experiments/classifiers/stl/mixmatch_48',
                'model_best.pth.tar'))
            compnet = WideResNet48(num_classes=10)
            compnet = compnet.to(device)
            for param in compnet.parameters():
                param.detach_()
            compnet.load_state_dict(checkpoint['ema_state_dict'])
            compnet.eval()
            minimal_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
            ])
        else:
            raise ValueError('Dataset %s has no comparison network.' %
                             config['dataset'])

        n_used_imgs = 10000
        correct = 0
        mean_label = np.zeros(config['n_classes'])
        imageSize = config['resolution']
        x = np.empty((n_used_imgs, imageSize, imageSize, 3), dtype=np.uint8)
        for l in tqdm(range(n_used_imgs // G_batch_size),
                      desc='Generating [%s][%06d]' %
                           (brief_expt_name, state_dict['itr'])):
            with torch.no_grad():
                images, y = sample()
                fake = images.data.cpu().numpy()
                fake = np.floor((fake + 1) * 255 / 2.0).astype(np.uint8)
                fake_input = np.zeros(fake.shape)
                for bi in range(fake.shape[0]):
                    fake_input[bi] = minimal_trans(
                        np.moveaxis(fake[bi], 0, -1))
                images.data.copy_(torch.from_numpy(fake_input))
                lab = compnet(images).max(1)[1]
                mean_label += np.bincount(lab.data.cpu(),
                                          minlength=config['n_classes'])
                correct += int((lab == y).sum().cpu())

        accuracy = float(correct) / float(n_used_imgs)
        mean_label_normalized = mean_label / float(n_used_imgs)

        print('[%s][%06d] %s accuracy: %f.' %
              (brief_expt_name, state_dict['itr'], 'Generator',
               accuracy * 100))
        hist = load_or_make_hist(hist_dir)
        hist[state_dict['itr']]['Generator'] = accuracy
        hist[state_dict['itr']]['Mean_Label'] = mean_label_normalized
        np.save(os.path.join(hist_dir, HIST_FNAME), hist)

    if config['accumulate_stats']:
        print('Accumulating standing stats across %d accumulations...' %
              config['num_standing_accumulations'])
        utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])

    # Sample a number of images and save them to an NPZ, for use with
    # TF-Inception
    if config['sample_npz']:
        # Lists to hold images and labels for images
        x, y = [], []
        print('Sampling %d images and saving them to npz...' %
              config['sample_num_npz'])
        for i in trange(int(np.ceil(config['sample_num_npz'] /
                                    float(G_batch_size)))):
            with torch.no_grad():
                images, labels = sample()
            x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)]
            y += [labels.cpu().numpy()]
        x = np.concatenate(x, 0)[:config['sample_num_npz']]
        y = np.concatenate(y, 0)[:config['sample_num_npz']]
        print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
        npz_filename = '%s/%s/samples.npz' % (config['samples_root'],
                                              experiment_name)
        print('Saving npz to %s...' % npz_filename)
        np.savez(npz_filename, **{'x': x, 'y': y})

    if config['official_FID']:
        f = np.load(config['dataset_is_fid'])
        # This is for using the statistics file downloaded from
        # https://github.com/bioinf-jku/TTUR
        # mdata, sdata = f['mu'][:], f['sigma'][:]
        # This one is for my format files
        mdata, sdata = f['mfid'], f['sfid']

    # Sample a number of images and stick them in memory, for use with
    # TF-Inception official_IS and official_FID
    data_gen_necessary = False
    if config['sample_np_mem']:
        is_saved = int('IS' in hist[state_dict['itr']])
        is_todo = int(config['official_IS'])
        fid_saved = int('FID' in hist[state_dict['itr']])
        fid_todo = int(config['official_FID'])
        data_gen_necessary = config['overwrite'] or (is_todo > is_saved) \
            or (fid_todo > fid_saved)
    if config['sample_np_mem'] and data_gen_necessary:
        n_used_imgs = 50000
        imageSize = config['resolution']
        x = np.empty((n_used_imgs, imageSize, imageSize, 3), dtype=np.uint8)
        for l in tqdm(range(n_used_imgs // G_batch_size),
                      desc='Generating [%s][%06d]' %
                           (brief_expt_name, state_dict['itr'])):
            start = l * G_batch_size
            end = start + G_batch_size
            with torch.no_grad():
                images, labels = sample()
            fake = np.uint8(255 * (images.cpu().numpy() + 1) / 2.)
            x[start:end] = np.moveaxis(fake, 1, -1)

    if config['official_IS']:
        if (not ('IS' in hist[state_dict['itr']])) or config['overwrite']:
            mis, sis = iscore.get_inception_score(x)
            print('[%s][%06d] IS mu: %f. IS sigma: %f.' %
                  (brief_expt_name, state_dict['itr'], mis, sis))
            hist = load_or_make_hist(hist_dir)
            hist[state_dict['itr']]['IS'] = [mis, sis]
            np.save(os.path.join(hist_dir, HIST_FNAME), hist)
        else:
            mis, sis = hist[state_dict['itr']]['IS']
            print('[%s][%06d] Already done (skipping...): '
                  'IS mu: %f. IS sigma: %f.' %
                  (brief_expt_name, state_dict['itr'], mis, sis))

    if config['official_FID']:
        import tensorflow as tf

        def fid_ms_for_imgs(images, mem_fraction=0.5):
            gpu_options = tf.GPUOptions(
                per_process_gpu_memory_fraction=mem_fraction)
            inception_path = fid.check_or_download_inception(None)
            # Load the Inception graph into the current TF graph
            fid.create_inception_graph(inception_path)
            with tf.Session(config=tf.ConfigProto(
                    gpu_options=gpu_options)) as sess:
                sess.run(tf.global_variables_initializer())
                mu_gen, sigma_gen = fid.calculate_activation_statistics(
                    images, sess, batch_size=100)
            return mu_gen, sigma_gen

        if (not ('FID' in hist[state_dict['itr']])) or config['overwrite']:
            m1, s1 = fid_ms_for_imgs(x)
            fid_value = fid.calculate_frechet_distance(m1, s1, mdata, sdata)
            print('[%s][%06d] FID: %f' %
                  (brief_expt_name, state_dict['itr'], fid_value))
            hist = load_or_make_hist(hist_dir)
            hist[state_dict['itr']]['FID'] = fid_value
            np.save(os.path.join(hist_dir, HIST_FNAME), hist)
        else:
            fid_value = hist[state_dict['itr']]['FID']
            print('[%s][%06d] Already done (skipping...): FID: %f' %
                  (brief_expt_name, state_dict['itr'], fid_value))

    # Prepare sample sheets
    if config['sample_sheets']:
        print('Preparing conditional sample sheets...')
        folder_number = config['sample_sheet_folder_num']
        if folder_number == -1:
            folder_number = config['load_weights']
        utils.sample_sheet(
            G,
            classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
            num_classes=config['n_classes'],
            samples_per_class=10,
            parallel=config['parallel'],
            samples_root=config['samples_root'],
            experiment_name=experiment_name,
            folder_number=folder_number,
            z_=z_)

    # Sample interp sheets
    if config['sample_interps']:
        print('Preparing interp sheets...')
        folder_number = config['sample_sheet_folder_num']
        if folder_number == -1:
            folder_number = config['load_weights']
        for fix_z, fix_y in zip([False, False, True], [False, True, False]):
            utils.interp_sheet(G,
                               num_per_sheet=16,
                               num_midpoints=8,
                               num_classes=config['n_classes'],
                               parallel=config['parallel'],
                               samples_root=config['samples_root'],
                               experiment_name=experiment_name,
                               folder_number=int(folder_number),
                               sheet_number=0,
                               fix_z=fix_z,
                               fix_y=fix_y,
                               device='cuda')

    # Sample random sheet
    if config['sample_random']:
        print('Preparing random sample sheet...')
        images, labels = sample()
        torchvision.utils.save_image(
            images.float(),
            '%s/%s/%s.jpg' % (config['samples_root'], experiment_name,
                              config['load_weights']),
            nrow=int(G_batch_size ** 0.5),
            normalize=True)

    # Prepare a simple function to get metrics, used for truncation curves
    def get_metrics():
        # Get Inception Score and FID
        get_inception_metrics = inception_utils.prepare_inception_metrics(
            config['dataset'], config['parallel'], config['no_fid'])
        sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_,
                                   config=config)
        IS_mean, IS_std, FID = get_inception_metrics(
            sample, config['num_inception_images'], num_splits=10,
            prints=False)
        # Prepare output string
        outstring = 'Using %s weights ' % ('ema' if config['use_ema']
                                           else 'non-ema')
        outstring += 'in %s mode, ' % ('eval' if config['G_eval_mode']
                                       else 'training')
        outstring += 'with noise variance %3.3f, ' % z_.var
        outstring += 'over %d images, ' % config['num_inception_images']
        if config['accumulate_stats'] or not config['G_eval_mode']:
            outstring += 'with batch size %d, ' % G_batch_size
        if config['accumulate_stats']:
            outstring += 'using %d standing stat accumulations, ' % \
                config['num_standing_accumulations']
        outstring += ('Itr %d: PYTORCH UNOFFICIAL Inception Score is '
                      '%3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' %
                      (state_dict['itr'], IS_mean, IS_std, FID))
        print(outstring)

    if config['sample_inception_metrics']:
        print('Calculating Inception metrics...')
        get_metrics()

    # Sample truncation curve stuff. This is basically the same as the
    # inception metrics code
    if config['sample_trunc_curves']:
        start, step, end = [float(item) for item in
                            config['sample_trunc_curves'].split('_')]
        print('Getting truncation values for variance in range '
              '(%3.3f:%3.3f:%3.3f)...' % (start, step, end))
        for var in np.arange(start, end + step, step):
            z_.var = var
            # Optionally comment this out if you want to run with standing
            # stats accumulated at one z variance setting
            if config['accumulate_stats']:
                utils.accumulate_standing_stats(
                    G, z_, y_, config['n_classes'],
                    config['num_standing_accumulations'])
            get_metrics()
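# For reference, `fid.calculate_frechet_distance` above computes the
# Frechet distance between two Gaussians fitted to Inception activations:
# d^2 = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2 sqrt(C1 C2)). A minimal
# NumPy/SciPy sketch of that formula (the TTUR implementation adds
# numerical-stability handling omitted here):
import numpy as np
from scipy import linalg

def frechet_distance_sketch(mu1, sigma1, mu2, sigma2):
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    covmean = covmean.real  # discard tiny imaginary parts from sqrtm
    return diff.dot(diff) + np.trace(sigma1 + sigma2 - 2.0 * covmean)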
def train(target_vars, saver, sess, logger, dataloaders, test_dataloaders,
          resume_iter, logdir):
    X = target_vars['X']
    Y = target_vars['Y']
    X_NOISE = target_vars['X_NOISE']
    train_op = target_vars['train_op']
    energy_pos = target_vars['energy_pos']
    energy_neg = target_vars['energy_neg']
    loss_energy = target_vars['loss_energy']
    loss_ml = target_vars['loss_ml']
    loss_total = target_vars['total_loss']
    gvs = target_vars['gvs']
    x_grad = target_vars['x_grad']
    x_grad_first = target_vars['x_grad_first']
    x_off = target_vars['x_off']
    temp = target_vars['temp']
    x_mod = target_vars['x_mod']
    LABEL = target_vars['LABEL']
    LABEL_POS = target_vars['LABEL_POS']
    weights = target_vars['weights']
    test_x_mod = target_vars['test_x_mod']
    eps = target_vars['eps_begin']
    label_ent = target_vars['label_ent']

    set_seed(0)
    np.random.seed(0)
    random.seed(0)

    if FLAGS.use_attention:
        gamma = weights[0]['atten']['gamma']
    else:
        gamma = tf.zeros(1)

    val_output = [test_x_mod]
    gvs_dict = dict(gvs)

    log_output = [
        train_op, energy_pos, energy_neg, eps, loss_energy, loss_ml,
        loss_total, x_grad, x_off, x_mod, gamma, x_grad_first, label_ent,
        *gvs_dict.keys()]
    output = [train_op, x_mod]

    replay_buffer = ReplayBuffer(10000)
    itr = resume_iter
    x_mod = None
    gd_steps = 1

    err_message = ('Total number of epochs should be divisible by the '
                   'number of CL tasks.')
    assert FLAGS.epoch_num % FLAGS.num_tasks == 0, err_message
    epochs_per_task = FLAGS.epoch_num // FLAGS.num_tasks // FLAGS.num_cycles

    for task_index, dataloader in enumerate(dataloaders):
        dataloader_iterator = iter(dataloader)
        best_inception = 0.0

        for epoch in range(1, epochs_per_task + 1):
            for data_corrupt, data, label in dataloader:
                print('Iter: {}; Epoch: {}/{}; Task: {}/{}'.format(
                    itr, epoch + (task_index * epochs_per_task),
                    FLAGS.epoch_num, task_index + 1, FLAGS.num_tasks))

                data_corrupt = data_corrupt.numpy()
                data_corrupt_init = data_corrupt.copy()

                data = data.numpy()
                label = label.numpy()
                label_init = label.copy()

                if FLAGS.mixup:
                    idx = np.random.permutation(data.shape[0])
                    lam = np.random.beta(1, 1, size=(data.shape[0], 1, 1, 1))
                    data = data * lam + data[idx] * (1 - lam)

                if FLAGS.replay_batch and (x_mod is not None):
                    replay_buffer.add(compress_x_mod(x_mod))

                    if len(replay_buffer) > FLAGS.batch_size:
                        replay_batch = replay_buffer.sample(FLAGS.batch_size)
                        replay_batch = decompress_x_mod(replay_batch)
                        replay_mask = (np.random.uniform(
                            0, FLAGS.rescale, FLAGS.batch_size) > 0.05)
                        data_corrupt[replay_mask] = replay_batch[replay_mask]

                if FLAGS.pcd:
                    if x_mod is not None:
                        data_corrupt = x_mod

                feed_dict = {X_NOISE: data_corrupt, X: data, Y: label}

                if FLAGS.cclass:
                    feed_dict[LABEL] = label
                    feed_dict[LABEL_POS] = label_init

                if itr % FLAGS.log_interval == 0:
                    _, e_pos, e_neg, eps, loss_e, loss_ml, loss_total, \
                        x_grad, x_off, x_mod, gamma, x_grad_first, \
                        label_ent, *grads = sess.run(log_output, feed_dict)

                    kvs = {}
                    kvs['e_pos'] = e_pos.mean()
                    kvs['e_pos_std'] = e_pos.std()
                    kvs['e_neg'] = e_neg.mean()
                    kvs['e_diff'] = kvs['e_pos'] - kvs['e_neg']
                    kvs['e_neg_std'] = e_neg.std()
                    kvs['temp'] = temp
                    kvs['loss_e'] = loss_e.mean()
                    kvs['eps'] = eps.mean()
                    kvs['label_ent'] = label_ent
                    kvs['loss_ml'] = loss_ml.mean()
                    kvs['loss_total'] = loss_total.mean()
                    kvs['x_grad'] = np.abs(x_grad).mean()
                    kvs['x_grad_first'] = np.abs(x_grad_first).mean()
                    kvs['x_off'] = x_off.mean()
                    kvs['iter'] = itr
                    kvs['gamma'] = gamma

                    for v, k in zip(grads,
                                    [v.name for v in gvs_dict.values()]):
                        kvs[k] = np.abs(v).max()

                    string = "Obtained a total of "
                    for key, value in kvs.items():
                        string += "{}: {}, ".format(key, value)

                    if hvd.rank() == 0:
                        print(string)
                        logger.writekvs(kvs)
                        for key, value in kvs.items():
                            neptune.log_metric(key, x=itr, y=value)
                else:
                    _, x_mod = sess.run(output, feed_dict)

                if itr % FLAGS.save_interval == 0 and hvd.rank() == 0:
                    saver.save(
                        sess,
                        osp.join(FLAGS.logdir, FLAGS.exp,
                                 'model_{}'.format(itr)))

                if itr % FLAGS.test_interval == 0 and hvd.rank() == 0 \
                        and FLAGS.dataset != '2d':
                    if FLAGS.dataset == 'cifar10':
                        cifar10_map = {0: 'airplane', 1: 'automobile',
                                       2: 'bird', 3: 'cat', 4: 'deer',
                                       5: 'dog', 6: 'frog', 7: 'horse',
                                       8: 'ship', 9: 'truck'}
                        imgs = data
                        labels = np.argmax(label, axis=1)
                        for idx, img in enumerate(imgs[:20, :, :, :]):
                            neptune.log_image(
                                'input_images',
                                rescale_im(imgs[idx]),
                                description=str(int(labels[idx])) + ': ' +
                                cifar10_map[int(labels[idx])])

                    if FLAGS.evaluate:
                        print('Test.')
                        train_acc = test_accuracy(target_vars, saver, sess,
                                                  logger, test_dataloaders[0])
                        test_acc = test_accuracy(target_vars, saver, sess,
                                                 logger, test_dataloaders[1])
                        neptune.log_metric('train_accuracy',
                                           x=itr, y=train_acc)
                        neptune.log_metric('test_accuracy',
                                           x=itr, y=test_acc)

                    try_im = x_mod
                    orig_im = data_corrupt.squeeze()
                    actual_im = rescale_im(data)

                    orig_im = rescale_im(orig_im)
                    try_im = rescale_im(try_im).squeeze()

                    for i, (im, t_im, actual_im_i) in enumerate(
                            zip(orig_im[:20], try_im[:20], actual_im)):
                        shape = orig_im.shape[1:]
                        new_im = np.zeros((shape[0], shape[1] * 3,
                                           *shape[2:]))
                        size = shape[1]
                        new_im[:, :size] = im
                        new_im[:, size:2 * size] = t_im
                        new_im[:, 2 * size:] = actual_im_i

                        log_image(new_im, logger,
                                  'train_gen_{}'.format(itr), step=i)
                        neptune.log_image(
                            'train_gen', x=new_im,
                            description='train_gen_iter:{}_idx:{}'.format(
                                itr, i))

                    test_im = x_mod

                    try:
                        data_corrupt, data, label = next(dataloader_iterator)
                    except BaseException:
                        dataloader_iterator = iter(dataloader)
                        data_corrupt, data, label = next(dataloader_iterator)

                    data_corrupt = data_corrupt.numpy()

                    if FLAGS.replay_batch and (x_mod is not None) \
                            and len(replay_buffer) > 0:
                        replay_batch = replay_buffer.sample(FLAGS.batch_size)
                        replay_batch = decompress_x_mod(replay_batch)
                        replay_mask = (np.random.uniform(
                            0, 1, (FLAGS.batch_size)) > 0.05)
                        data_corrupt[replay_mask] = replay_batch[replay_mask]

                    if FLAGS.dataset in ('cifar10', 'imagenet',
                                         'imagenetfull'):
                        n = 128

                        if FLAGS.dataset == "imagenetfull":
                            n = 32

                        if len(replay_buffer) > n:
                            data_corrupt = decompress_x_mod(
                                replay_buffer.sample(n))
                        elif FLAGS.dataset == 'imagenetfull':
                            data_corrupt = np.random.uniform(
                                0, FLAGS.rescale, (n, 128, 128, 3))
                        else:
                            data_corrupt = np.random.uniform(
                                0, FLAGS.rescale, (n, 32, 32, 3))

                        if FLAGS.dataset == 'cifar10':
                            label = np.eye(10)[np.random.randint(0, 10, (n))]
                        else:
                            label = np.eye(1000)[
                                np.random.randint(0, 1000, (n))]

                    feed_dict[X_NOISE] = data_corrupt
                    feed_dict[X] = data

                    if FLAGS.cclass:
                        feed_dict[LABEL] = label

                    test_x_mod = sess.run(val_output, feed_dict)

                    try_im = test_x_mod
                    orig_im = data_corrupt.squeeze()
                    actual_im = rescale_im(data.numpy())

                    orig_im = rescale_im(orig_im)
                    try_im = rescale_im(try_im).squeeze()

                    for i, (im, t_im, actual_im_i) in enumerate(
                            zip(orig_im[:20], try_im[:20], actual_im)):
                        shape = orig_im.shape[1:]
                        new_im = np.zeros((shape[0], shape[1] * 3,
                                           *shape[2:]))
                        size = shape[1]
                        new_im[:, :size] = im
                        new_im[:, size:2 * size] = t_im
                        new_im[:, 2 * size:] = actual_im_i
                        log_image(new_im, logger,
                                  'val_gen_{}'.format(itr), step=i)
                        neptune.log_image(
                            'val_gen', new_im,
                            description='val_gen_iter:{}_idx:{}'.format(
                                itr, i))

                    score, std = get_inception_score(list(try_im), splits=1)
                    print("Inception score of {} with std of {}".format(
                        score, std))

                    kvs = {}
                    kvs['inception_score'] = score
                    kvs['inception_score_std'] = std
                    logger.writekvs(kvs)
                    for key, value in kvs.items():
                        neptune.log_metric(key, x=itr, y=value)

                    if score > best_inception:
                        best_inception = score
                        saver.save(sess,
                                   osp.join(FLAGS.logdir, FLAGS.exp,
                                            'model_best'))

                if itr > 600000 and FLAGS.dataset == "mnist":
                    assert False  # hard stop after 600k iterations on MNIST

                itr += 1

    saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp,
                              'model_{}'.format(itr)))
def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir):
    X = target_vars['X']
    Y = target_vars['Y']
    X_NOISE = target_vars['X_NOISE']
    train_op_model = target_vars['train_op_model']
    train_op_dis = target_vars['train_op_dis']
    energy_pos = target_vars['energy_pos']
    energy_neg = target_vars['energy_neg']
    score_pos = target_vars['score_pos']
    score_neg = target_vars['score_neg']
    loss_energy = target_vars['loss_energy']
    loss_total = target_vars['total_loss']
    gvs = target_vars['gvs']
    x_grad = target_vars['x_grad']
    x_grad_first = target_vars['x_grad_first']
    x_off = target_vars['x_off']
    temp = target_vars['temp']
    x_mod = target_vars['x_mod']
    LABEL = target_vars['LABEL']
    LABEL_POS = target_vars['LABEL_POS']
    weights = target_vars['weights']
    test_x_mod = target_vars['test_x_mod']
    eps = target_vars['eps_begin']
    label_ent = target_vars['label_ent']
    train_op_model_l2 = target_vars['train_op_model_l2']
    train_op_dis_l2 = target_vars['train_op_dis_l2']

    output = [train_op_model, x_mod]

    if FLAGS.use_attention:
        gamma = weights[0]['atten']['gamma']
    else:
        gamma = tf.zeros(1)

    val_output = [test_x_mod]
    gvs_dict = dict(gvs)

    replay_buffer = ReplayBuffer(10000)
    itr = resume_iter
    x_mod = None
    gd_steps = 1

    dataloader_iterator = iter(dataloader)
    best_inception = 0.0
    save_interval = FLAGS.save_interval

    for epoch in range(FLAGS.epoch_num):
        for data_corrupt, data, label in dataloader:
            data_corrupt = data_corrupt.numpy()
            data_corrupt_init = data_corrupt.copy()

            data = data.numpy()
            label = label.numpy()
            label_init = label.copy()

            if FLAGS.mixup:
                idx = np.random.permutation(data.shape[0])
                lam = np.random.beta(1, 1, size=(data.shape[0], 1, 1, 1))
                data = data * lam + data[idx] * (1 - lam)

            if FLAGS.replay_batch and (x_mod is not None):
                replay_buffer.add(compress_x_mod(x_mod))

                if len(replay_buffer) > FLAGS.batch_size:
                    replay_batch = replay_buffer.sample(FLAGS.batch_size)
                    replay_batch = decompress_x_mod(replay_batch)
                    replay_mask = (np.random.uniform(
                        0, FLAGS.rescale, FLAGS.batch_size) > 0.05)
                    data_corrupt[replay_mask] = replay_batch[replay_mask]

            if FLAGS.pcd:
                if x_mod is not None:
                    data_corrupt = x_mod

            feed_dict = {X_NOISE: data_corrupt, X: data, Y: label}

            if FLAGS.cclass:
                feed_dict[LABEL] = label
                feed_dict[LABEL_POS] = label_init

            if itr > 10:
                # Train discriminator
                _ = sess.run(train_op_dis, feed_dict)

                # Train model
                _, x_mod = sess.run(output, feed_dict)
            else:
                # Warm-up phase: L2-regularized updates only
                _, _ = sess.run([train_op_dis_l2, train_op_model_l2],
                                feed_dict)

            energy_neg_, energy_pos_, score_neg_, score_pos_ = sess.run(
                [energy_neg, energy_pos, score_neg, score_pos], feed_dict)
            print(np.mean(energy_neg_), np.mean(energy_pos_),
                  np.mean(score_neg_), np.mean(score_pos_))

            if itr > 30000:
                save_interval = 100

            if itr and itr % FLAGS.test_interval == 0 and hvd.rank() == 0 \
                    and FLAGS.dataset != '2d':
                try_im = x_mod
                orig_im = data_corrupt.squeeze()
                actual_im = rescale_im(data)

                orig_im = rescale_im(orig_im)
                try_im = rescale_im(try_im).squeeze()

                for i, (im, t_im, actual_im_i) in enumerate(
                        zip(orig_im[:20], try_im[:20], actual_im)):
                    shape = orig_im.shape[1:]
                    new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
                    size = shape[1]
                    new_im[:, :size] = im
                    new_im[:, size:2 * size] = t_im
                    new_im[:, 2 * size:] = actual_im_i

                    log_image(new_im, logger,
                              'train_gen_{}'.format(itr), step=i)

                test_im = x_mod

                try:
                    data_corrupt, data, label = next(dataloader_iterator)
                except BaseException:
                    dataloader_iterator = iter(dataloader)
                    data_corrupt, data, label = next(dataloader_iterator)

                data_corrupt = data_corrupt.numpy()

                if FLAGS.replay_batch and (x_mod is not None) \
                        and len(replay_buffer) > 0:
                    replay_batch = replay_buffer.sample(FLAGS.batch_size)
                    replay_batch = decompress_x_mod(replay_batch)
                    replay_mask = (np.random.uniform(
                        0, 1, (FLAGS.batch_size)) > 0.05)
                    data_corrupt[replay_mask] = replay_batch[replay_mask]

                if FLAGS.dataset in ('cifar10', 'imagenet', 'imagenetfull',
                                     'celeba'):
                    n = 128

                    if FLAGS.dataset == "imagenetfull":
                        n = 32

                    if len(replay_buffer) > n:
                        data_corrupt = decompress_x_mod(
                            replay_buffer.sample(n))
                    elif FLAGS.dataset == 'imagenetfull':
                        data_corrupt = np.random.uniform(
                            0, FLAGS.rescale, (n, 128, 128, 3))
                    else:
                        data_corrupt = np.random.uniform(
                            0, FLAGS.rescale, (n, 32, 32, 3))

                    if FLAGS.dataset == 'cifar10':
                        label = np.eye(10)[np.random.randint(0, 10, (n))]
                    elif FLAGS.dataset == 'celeba':
                        label = np.array([1] * n).reshape((n, 1))
                    else:
                        label = np.eye(1000)[
                            np.random.randint(0, 1000, (n))]

                feed_dict[X_NOISE] = data_corrupt
                feed_dict[X] = data

                if FLAGS.cclass:
                    feed_dict[LABEL] = label

                test_x_mod = sess.run(val_output, feed_dict)

                try_im = test_x_mod
                orig_im = data_corrupt.squeeze()
                actual_im = rescale_im(data.numpy())

                orig_im = rescale_im(orig_im)
                try_im = rescale_im(try_im).squeeze()

                for i, (im, t_im, actual_im_i) in enumerate(
                        zip(orig_im[:20], try_im[:20], actual_im)):
                    shape = orig_im.shape[1:]
                    new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
                    size = shape[1]
                    new_im[:, :size] = im
                    new_im[:, size:2 * size] = t_im
                    new_im[:, 2 * size:] = actual_im_i
                    log_image(new_im, logger,
                              'val_gen_{}'.format(itr), step=i)

                score, std = get_inception_score(list(try_im), splits=1)
                print("Iteration {}: Inception score of {} with std of {}"
                      .format(itr, score, std))

                kvs = {}
                kvs['inception_score'] = score
                kvs['inception_score_std'] = std
                logger.writekvs(kvs)

                if score > best_inception:
                    best_inception = score
                    saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp,
                                              'model_best'))

                saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp,
                                          'model_{}'.format(itr)))

            if itr > 60000 and FLAGS.dataset == "mnist":
                assert False  # hard stop after 60k iterations on MNIST

            itr += 1

    saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp,
                              'model_{}'.format(itr)))
def train(models, models_ema, optimizer, logger, dataloader, resume_iter,
          logdir, FLAGS, rank_idx, best_inception):
    torch.cuda.set_device(rank_idx)

    if FLAGS.replay_batch:
        if FLAGS.reservoir:
            replay_buffer = ReservoirBuffer(FLAGS.buffer_size,
                                            FLAGS.transform, FLAGS.dataset)
        else:
            replay_buffer = ReplayBuffer(FLAGS.buffer_size, FLAGS.transform,
                                         FLAGS.dataset)

    if rank_idx == 0:
        from inception import get_inception_score

    itr = resume_iter
    im_neg = None
    gd_steps = 1

    optimizer.zero_grad()

    num_steps = FLAGS.num_steps

    if FLAGS.cuda:
        dev = torch.device("cuda:{}".format(rank_idx))
    else:
        dev = torch.device("cpu")

    for epoch in range(FLAGS.epoch_num):
        tock = time.time()
        for data_corrupt, data, label in dataloader:
            label = label.float().cuda(rank_idx)
            data = data.permute(0, 3, 1, 2).float().contiguous()

            # Generate samples from noise to evaluate the Inception score
            if itr % FLAGS.save_interval == 0:
                if FLAGS.dataset == "cifar10":
                    data_corrupt = torch.Tensor(
                        np.random.uniform(0.0, 1.0, (128, 32, 32, 3)))
                    repeat = 128 // FLAGS.batch_size + 1
                    label = torch.cat([label] * repeat, axis=0)
                    label = label[:128]
                elif FLAGS.dataset == "celeba":
                    data_corrupt = torch.Tensor(np.random.uniform(
                        0.0, 1.0, (data.shape[0], 128, 128, 3)))
                    label = label[:data.shape[0]]
                    data_corrupt = data_corrupt[:label.shape[0]]
                elif FLAGS.dataset == "stl":
                    data_corrupt = torch.Tensor(
                        np.random.uniform(0.0, 1.0, (32, 48, 48, 3)))
                    label = label[:32]
                    data_corrupt = data_corrupt[:label.shape[0]]
                elif FLAGS.dataset in ("lsun", "imagenet", "object"):
                    data_corrupt = torch.Tensor(
                        np.random.uniform(0.0, 1.0, (32, 128, 128, 3)))
                    label = label[:32]
                    data_corrupt = data_corrupt[:label.shape[0]]
                elif FLAGS.dataset == "mnist":
                    data_corrupt = torch.Tensor(
                        np.random.uniform(0.0, 1.0, (32, 28, 28, 1)))
                    label = label[:32]
                    data_corrupt = data_corrupt[:label.shape[0]]
                else:
                    assert False

            data_corrupt = torch.Tensor(data_corrupt.float()).permute(
                0, 3, 1, 2).float().contiguous()

            data = data.cuda(rank_idx)
            data_corrupt = data_corrupt.cuda(rank_idx)

            if FLAGS.replay_batch and len(replay_buffer) >= FLAGS.batch_size:
                replay_batch, idxs = replay_buffer.sample(
                    data_corrupt.size(0))
                replay_batch = decompress_x_mod(replay_batch)
                replay_mask = (np.random.uniform(
                    0, 1, data_corrupt.size(0)) > 0.05)
                data_corrupt[replay_mask] = torch.Tensor(
                    replay_batch[replay_mask]).cuda(rank_idx)
            else:
                idxs = None

            # Pick a random model from the ensemble for this step
            ix = random.randint(0, len(models) - 1)
            model = models[ix]

            if FLAGS.hmc:
                if itr % FLAGS.save_interval == 0:
                    im_neg, im_samples, x_grad, v = gen_hmc_image(
                        label, FLAGS, model, data_corrupt, num_steps,
                        sample=True)
                else:
                    im_neg, x_grad, v = gen_hmc_image(
                        label, FLAGS, model, data_corrupt, num_steps)
            else:
                if itr % FLAGS.save_interval == 0:
                    im_neg, im_neg_kl, im_samples, x_grad = gen_image(
                        label, FLAGS, model, data_corrupt, num_steps,
                        sample=True)
                else:
                    im_neg, im_neg_kl, x_grad = gen_image(
                        label, FLAGS, model, data_corrupt, num_steps)

            energy_pos = model.forward(data, label[:data.size(0)])
            energy_neg = model.forward(im_neg.clone(), label)

            if FLAGS.replay_batch and (im_neg is not None):
                replay_buffer.add(
                    compress_x_mod(im_neg.detach().cpu().numpy()))

            # Contrastive-divergence-style maximum-likelihood objective
            loss = energy_pos.mean() - energy_neg.mean()
            # loss = loss + (torch.pow(energy_pos, 2).mean() +
            #                torch.pow(energy_neg, 2).mean())

            if FLAGS.kl:
                # KL term: backprop through the sampler but not the model
                model.requires_grad_(False)
                loss_kl = model.forward(im_neg_kl, label)
                model.requires_grad_(True)
                loss = loss + FLAGS.kl_coeff * loss_kl.mean()

                if FLAGS.repel_im:
                    bs = im_neg_kl.size(0)

                    if FLAGS.dataset in ["celeba", "imagenet", "object",
                                         "lsun", "stl"]:
                        im_neg_kl = im_neg_kl[:, :, :, :].contiguous()

                    im_flat = torch.clamp(im_neg_kl.view(bs, -1), 0, 1)

                    if FLAGS.dataset == "cifar10":
                        if len(replay_buffer) > 1000:
                            compare_batch, idxs = replay_buffer.sample(
                                100, no_transform=False)
                            compare_batch = decompress_x_mod(compare_batch)
                            compare_batch = torch.Tensor(
                                compare_batch).cuda(rank_idx)
                            compare_flat = compare_batch.view(100, -1)

                            # Repel each sample from its nearest buffer entry
                            dist_matrix = torch.norm(
                                im_flat[:, None, :] -
                                compare_flat[None, :, :], p=2, dim=-1)
                            loss_repel = torch.log(
                                dist_matrix.min(dim=1)[0]).mean()
                            loss = loss - 0.3 * loss_repel
                        else:
                            loss_repel = torch.zeros(1)
                    else:
                        if len(replay_buffer) > 1000:
                            compare_batch, idxs = replay_buffer.sample(
                                100, no_transform=False, downsample=True)
                            compare_batch = decompress_x_mod(compare_batch)
                            compare_batch = torch.Tensor(
                                compare_batch).cuda(rank_idx)
                            compare_flat = compare_batch.view(100, -1)
                            dist_matrix = torch.norm(
                                im_flat[:, None, :] -
                                compare_flat[None, :, :], p=2, dim=-1)
                            loss_repel = torch.log(
                                dist_matrix.min(dim=1)[0]).mean()
                        else:
                            loss_repel = torch.zeros(1).cuda(rank_idx)

                        loss = loss - 0.3 * loss_repel
                else:
                    loss_repel = torch.zeros(1)
            else:
                loss_kl = torch.zeros(1)
                loss_repel = torch.zeros(1)

            if FLAGS.hmc:
                # Penalize alignment between HMC momentum and image gradient
                v_flat = v.view(v.size(0), -1)
                im_grad_flat = x_grad.view(x_grad.size(0), -1)
                dot_product = F.normalize(v_flat, dim=1) * \
                    F.normalize(im_grad_flat, dim=1)
                hmc_loss = torch.abs(dot_product.sum(dim=1)).mean()
                loss = loss + 0.01 * hmc_loss
            else:
                hmc_loss = torch.zeros(1)

            if FLAGS.log_grad and len(replay_buffer) > 1000:
                # Compare gradient magnitudes of the ML and KL terms
                loss_kl = loss_kl - 0.1 * loss_repel
                loss_kl = loss_kl.mean()
                loss_ml = energy_pos.mean() - energy_neg.mean()

                loss_ml.backward(retain_graph=True)
                ele = []
                for param in model.parameters():
                    if param.grad is not None:
                        ele.append(torch.norm(param.grad.data))
                ele = torch.stack(ele, dim=0)
                ml_grad = torch.mean(ele)
                model.zero_grad()

                loss_kl.backward(retain_graph=True)
                ele = []
                for param in model.parameters():
                    if param.grad is not None:
                        ele.append(torch.norm(param.grad.data))
                ele = torch.stack(ele, dim=0)
                kl_grad = torch.mean(ele)
                model.zero_grad()
            else:
                ml_grad = None
                kl_grad = None

            loss.backward()

            if FLAGS.gpus > 1:
                average_gradients(models)

            [clip_grad_norm(model.parameters(), 0.5) for model in models]
            optimizer.step()
            optimizer.zero_grad()

            ema_model(models, models_ema)

            if torch.isnan(energy_pos.mean()):
                assert False  # diverged: positive energies are NaN

            if torch.abs(energy_pos.mean()) > 10.0:
                assert False  # diverged: positive energies exploded

            if itr % FLAGS.log_interval == 0 and rank_idx == 0:
                tick = time.time()
                kvs = {}
                kvs['e_pos'] = energy_pos.mean().item()
                kvs['e_pos_std'] = energy_pos.std().item()
                kvs['e_neg'] = energy_neg.mean().item()
                kvs['kl_mean'] = loss_kl.mean().item()
                kvs['loss_repel'] = loss_repel.mean().item()
                kvs['e_neg_std'] = energy_neg.std().item()
                kvs['e_diff'] = kvs['e_pos'] - kvs['e_neg']
                kvs['x_grad'] = np.abs(x_grad.detach().cpu().numpy()).mean()
                kvs['iter'] = itr
                kvs['hmc_loss'] = hmc_loss.item()
                kvs['num_steps'] = num_steps
                kvs['t_diff'] = tick - tock

                if FLAGS.replay_batch:
                    kvs['length'] = len(replay_buffer)

                if ml_grad is not None:
                    kvs['kl_grad'] = kl_grad
                    kvs['ml_grad'] = ml_grad

                string = "Obtained a total of "
                for key, value in kvs.items():
                    string += "{}: {}, ".format(key, value)

                print(string)
                logger.writekvs(kvs)
                tock = tick

            if itr % FLAGS.save_interval == 0 and rank_idx == 0 \
                    and (FLAGS.save_interval != 0):
                model_path = osp.join(logdir, "model_{}.pth".format(itr))
                ckpt = {'optimizer_state_dict': optimizer.state_dict(),
                        'FLAGS': FLAGS,
                        'best_inception': best_inception}

                for i in range(FLAGS.ensembles):
                    ckpt['model_state_dict_{}'.format(i)] = \
                        models[i].state_dict()
                    ckpt['ema_model_state_dict_{}'.format(i)] = \
                        models_ema[i].state_dict()

                torch.save(ckpt, model_path)

            if itr % FLAGS.save_interval == 0 and rank_idx == 0:
                # Keep every 10th intermediate sample from the chain
                im_samples = im_samples[::10]
                im_samples_total = torch.stack(im_samples, dim=1).detach() \
                    .cpu().permute(0, 1, 3, 4, 2).numpy()

                try_im = im_neg
                orig_im = data_corrupt
                actual_im = rescale_im(
                    data.detach().permute(0, 2, 3, 1).cpu().numpy())

                orig_im = rescale_im(
                    orig_im.detach().permute(0, 2, 3, 1).cpu().numpy())
                try_im = rescale_im(
                    try_im.detach().permute(0, 2, 3, 1)
                    .cpu().numpy()).squeeze()
                im_samples_total = rescale_im(im_samples_total)

                for i, (im, sample_im, actual_im_i) in enumerate(
                        zip(orig_im[:20], im_samples_total[:20], actual_im)):
                    shape = orig_im.shape[1:]
                    new_im = np.zeros(
                        (shape[0], shape[1] * (2 + sample_im.shape[0]),
                         *shape[2:]))
                    size = shape[1]
                    new_im[:, :size] = im

                    # Use `j` here so the panel index `i` passed to
                    # log_image below is not clobbered.
                    for j, sample_j in enumerate(sample_im):
                        new_im[:, (j + 1) * size:(j + 2) * size] = sample_j

                    new_im[:, -size:] = actual_im_i
                    log_image(new_im, logger,
                              'train_gen_{}'.format(itr), step=i)

                if rank_idx == 0:
                    score, std = get_inception_score(list(try_im), splits=1)
                    print("Inception score of {} with std of {}".format(
                        score, std))
                    kvs = {}
                    kvs['inception_score'] = score
                    kvs['inception_score_std'] = std
                    logger.writekvs(kvs)

                    if score > best_inception:
                        model_path = osp.join(logdir, "model_best.pth")
                        torch.save(ckpt, model_path)
                        best_inception = score

            itr += 1
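# `average_gradients` above is assumed to all-reduce gradients across
# workers when training on multiple GPUs. A common minimal implementation
# with torch.distributed (assuming the process group has already been
# initialized) looks like this; the codebase's version may differ:
import torch.distributed as dist

def average_gradients_sketch(models):
    world_size = float(dist.get_world_size())
    for model in models:
        for param in model.parameters():
            if param.grad is not None:
                # Sum gradients from all workers, then normalize.
                dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
                param.grad.data /= world_size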
def conceptcombineeval(model_list, select_idx):
    dataset = ImageNet()
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=64, shuffle=True, num_workers=4)

    n = 64
    labels = []

    # One random 64-class label batch per selected model. Note the classes
    # are resampled here rather than taken from select_idx itself.
    for _ in select_idx:
        six = np.random.permutation(1000)[:n]
        print(six)
        label_batch = np.eye(1000)[six]
        label = torch.Tensor(label_batch).cuda()
        labels.append(label)

    def get_color_distortion(s=1.0):
        # s is the strength of color distortion.
        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s,
                                              0.4 * s)
        rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
        rnd_gray = transforms.RandomGrayscale(p=0.2)
        color_distort = transforms.Compose([rnd_color_jitter, rnd_gray])
        return color_distort

    color_transform = get_color_distortion(0.5)
    im_size = 128

    transform = transforms.Compose([
        transforms.RandomResizedCrop(im_size, scale=(0.3, 1.0)),
        transforms.RandomHorizontalFlip(),
        color_transform,
        transforms.ToTensor()])

    gt_ims = []
    fake_ims = []

    im = None

    for _, data, label in tqdm(dataloader):
        print(label)
        gt_ims.extend(list((data.numpy() * 255).astype(np.uint8)))

        if im is None:
            im = torch.rand(n, 3, 128, 128).cuda()

        im_noise = torch.randn_like(im).detach()

        # First get good initializations for sampling
        for _ in range(5):
            for step in range(60):
                im_noise.normal_()
                im = im + 0.001 * im_noise
                im.requires_grad_(requires_grad=True)

                # Sum the energies of all models (product of experts)
                energy = 0
                for model, model_label in zip(model_list, labels):
                    energy = model.forward(im, model_label) + energy

                im_grad = torch.autograd.grad([energy.sum()], [im])[0]
                im = im - FLAGS.step_lr * im_grad
                im = im.detach()
                im = torch.clamp(im, 0, 1)

            # Re-augment the current samples before the next round
            im = im.detach().cpu().numpy().transpose((0, 2, 3, 1))
            im = (im * 255).astype(np.uint8)
            ims = []
            for i in range(im.shape[0]):
                im_i = np.array(transform(Image.fromarray(np.array(im[i]))))
                ims.append(im_i)
            im = torch.Tensor(np.array(ims)).cuda()

        # Then refine the images
        for step in range(FLAGS.num_steps):
            im_noise.normal_()
            im = im + 0.001 * im_noise
            im.requires_grad_(requires_grad=True)

            energy = 0
            for model, model_label in zip(model_list, labels):
                energy = model.forward(im, model_label) + energy

            print("step: ", step, energy.mean())
            im_grad = torch.autograd.grad([energy.sum()], [im])[0]
            im = im - FLAGS.step_lr * im_grad
            im = im.detach()
            im = torch.clamp(im, 0, 1)

        im_cpu = im.detach().cpu()
        fake_ims.extend(list(
            (im_cpu.numpy().transpose((0, 2, 3, 1)) * 255).astype(np.uint8)))

        if len(gt_ims) > 50000:
            break

    splits = max(1, len(fake_ims) // 5000)
    score, std = get_inception_score(fake_ims, splits=splits)
    print("Inception score {}, with std {}".format(score, std))

    get_fid_score(gt_ims, fake_ims)
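# Concept combination above follows a product-of-experts view: sampling from
# the sum of per-model energies corresponds to the (unnormalized) product of
# their densities. A minimal sketch of one combined Langevin step; the step
# size and noise level are illustrative assumptions:
import torch

def combined_langevin_step_sketch(models, labels, im, step_lr,
                                  noise_std=0.001):
    im = (im + noise_std * torch.randn_like(im)).requires_grad_(True)
    # Summing energies over models multiplies the corresponding densities.
    energy = sum(model.forward(im, label)
                 for model, label in zip(models, labels))
    im_grad = torch.autograd.grad([energy.sum()], [im])[0]
    return torch.clamp((im - step_lr * im_grad).detach(), 0, 1)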
def ComputeInception(images):
    # Map images from [-1, 1] to [0, 255] uint8 before scoring.
    images = ((images + 1) / 2.0) * 255.0
    images = images.astype(np.uint8)
    IS = inception.get_inception_score(images)
    return IS
def compute_inception(model):
    size = FLAGS.im_number
    num_steps = size // 1000

    images = []
    test_ims = []

    if FLAGS.dataset == "cifar10":
        test_dataset = Cifar10(FLAGS)
    elif FLAGS.dataset == "celeba":
        test_dataset = CelebAHQ()
    elif FLAGS.dataset == "mnist":
        test_dataset = Mnist(train=True)

    test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size,
                                 num_workers=4, shuffle=True,
                                 drop_last=False)

    if FLAGS.dataset == "cifar10":
        for data_corrupt, data, label_gt in tqdm(test_dataloader):
            data = data.numpy()
            test_ims.extend(list(rescale_im(data)))

            if len(test_ims) > 10000:
                break
    elif FLAGS.dataset == "mnist":
        for data_corrupt, data, label_gt in tqdm(test_dataloader):
            data = data.numpy()
            test_ims.extend(list(np.tile(rescale_im(data), (1, 1, 3))))

            if len(test_ims) > 10000:
                break

    test_ims = test_ims[:10000]

    classes = 10
    print(FLAGS.batch_size)

    data_buffer = None

    for j in range(num_steps):
        itr = int(1000 / 500 * FLAGS.repeat_scale)

        if data_buffer is None:
            data_buffer = InceptionReplayBuffer(1000)

        curr_index = 0
        identity = np.eye(classes)

        if FLAGS.dataset == "celeba":
            n = 128
            c = 3
        elif FLAGS.dataset == "mnist":
            n = 28
            c = 1
        else:
            n = 32
            c = 3

        for i in tqdm(range(itr)):
            noise_scale = [1]
            if len(data_buffer) < 1000:
                x_init = np.random.uniform(0, 1, (FLAGS.batch_size, c, n, n))
                label = np.random.randint(0, classes, (FLAGS.batch_size))

                x_init = torch.Tensor(x_init).cuda()
                label = identity[label]
                label = torch.Tensor(label).cuda()

                x_new, _ = gen_image(label, FLAGS, model, x_init,
                                     FLAGS.num_steps)
                x_new = x_new.detach().cpu().numpy()
                label = label.detach().cpu().numpy()
                data_buffer.add(x_new, label)
            else:
                if i < itr - FLAGS.nomix:
                    (x_init, label), idx = data_buffer.sample(
                        FLAGS.batch_size, transform=FLAGS.transform)
                else:
                    # Partition the buffer; use a separate name so the
                    # image size n is not clobbered.
                    if FLAGS.dataset == "celeba":
                        n_parts = 20
                    else:
                        n_parts = 2

                    ix = i % n_parts
                    start_idx = (1000 // n_parts) * ix
                    end_idx = (1000 // n_parts) * (ix + 1)
                    (x_init, label) = data_buffer._encode_sample(
                        list(range(start_idx, end_idx)), transform=False)
                    idx = list(range(start_idx, end_idx))

                x_init = torch.Tensor(x_init).cuda()
                label = torch.Tensor(label).cuda()
                x_new, energy = gen_image(label, FLAGS, model, x_init,
                                          FLAGS.num_steps)
                energy = energy.cpu().detach().numpy()
                x_new = x_new.cpu().detach().numpy()
                label = label.cpu().detach().numpy()
                data_buffer.set_elms(idx, x_new, label)

                if FLAGS.im_number != 50000:
                    print(np.mean(energy), np.std(energy))

            curr_index += 1

        ims = np.array(data_buffer._storage[:1000])
        ims = rescale_im(ims).transpose((0, 2, 3, 1))

        if FLAGS.dataset == "mnist":
            ims = np.tile(ims, (1, 1, 1, 3))

        images.extend(list(ims))

    random.shuffle(images)
    saveim = osp.join('sandbox_cachedir', FLAGS.exp,
                      "test{}.png".format(FLAGS.idx))

    if FLAGS.dataset == "cifar10":
        rix = np.random.permutation(1000)[:100]
        ims = ims[rix]
        im_panel = ims.reshape((10, 10, 32, 32, 3)).transpose(
            (0, 2, 1, 3, 4)).reshape((320, 320, 3))
        imsave(saveim, im_panel)
        print("Saved image!")

        splits = max(1, len(images) // 5000)
        score, std = get_inception_score(images, splits=splits)
        print("Inception score of {} with std of {}".format(score, std))

        # FID score
        n = min(len(images), len(test_ims))
        fid = get_fid_score(images, test_ims)
        print("FID score of {}".format(fid))
    elif FLAGS.dataset == "mnist":
        ims = ims[:100]
        im_panel = ims.reshape((10, 10, 28, 28, 3)).transpose(
            (0, 2, 1, 3, 4)).reshape((280, 280, 3))
        imsave(saveim, im_panel)
        print("Saved image!")

        # Inception scoring is skipped for MNIST; only FID is reported.
        n = min(len(images), len(test_ims))
        fid = get_fid_score(images, test_ims)
        print("FID score of {}".format(fid))
    elif FLAGS.dataset == "celeba":
        ims = ims[:25]
        im_panel = ims.reshape((5, 5, 128, 128, 3)).transpose(
            (0, 2, 1, 3, 4)).reshape((5 * 128, 5 * 128, 3))
        imsave(saveim, im_panel)
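# The panel-saving code above tiles a batch of images into a (rows x cols)
# grid via a reshape/transpose/reshape. A generic sketch of that trick,
# for reference:
import numpy as np

def make_image_grid_sketch(ims, rows, cols):
    # ims: (rows * cols, H, W, C) uint8 array.
    n, h, w, c = ims.shape
    assert n == rows * cols
    return ims.reshape(rows, cols, h, w, c) \
              .transpose(0, 2, 1, 3, 4) \
              .reshape(rows * h, cols * w, c)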