def unsup_finetune(model, FLAGS_model):
    # Linear probe: train a linear classifier on frozen EBM features.
    train_dataset = Cifar10(FLAGS, train=False)
    train_dataloader = DataLoader(train_dataset, batch_size=FLAGS.batch_size,
                                  num_workers=FLAGS.data_workers, shuffle=True, drop_last=False)

    classifier = ModelLinear(FLAGS_model).cuda()
    optimizer = optim.Adam(classifier.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    running_accuracy = []
    count = 0

    for i in range(100):
        for data_corrupt, data, label_gt in tqdm(train_dataloader):
            data = data.permute(0, 3, 1, 2).float().cuda()
            target = label_gt.long().cuda()

            # Features come from the frozen energy model; only the linear head is trained.
            with torch.no_grad():
                model_feat = model.compute_feat(data, None)

            model_feat = F.normalize(model_feat, dim=-1, p=2)
            output = classifier(model_feat)
            loss = criterion(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            running_accuracy.append(acc1.item())
            running_accuracy = running_accuracy[-10:]

            print(count, loss, np.mean(running_accuracy))
            count += 1
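# The accuracy() helper used above is not shown in this excerpt; it is assumed to be the
# usual top-k accuracy utility. A minimal sketch under that assumption (hypothetical
# implementation, not necessarily the one in this repo):
def accuracy(output, target, topk=(1,)):
    # Top-k accuracy: percentage of samples whose true label appears among the
    # k highest-scoring predictions, returned once per requested k.
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res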
def energyevalmix(model):
    # Out-of-distribution detection: compare energies of CIFAR-10 test images against an
    # alternative dataset and report the AUROC of the resulting separation.
    # dataset = Cifar100(FLAGS, train=False)
    # dataset = Svhn(train=False)
    # dataset = Textures(train=True)
    # dataset = Cifar10(FLAGS, train=False)
    dataset = CelebaSmall()
    test_dataloader = DataLoader(dataset, batch_size=FLAGS.batch_size,
                                 num_workers=FLAGS.data_workers, shuffle=True, drop_last=False)

    train_dataset = Cifar10(FLAGS, train=False)
    train_dataloader = DataLoader(train_dataset, batch_size=FLAGS.batch_size,
                                  num_workers=FLAGS.data_workers, shuffle=True, drop_last=False)

    test_iter = iter(test_dataloader)

    probs = []
    labels = []
    negs = []
    pos = []

    for data_corrupt, data, label_gt in tqdm(train_dataloader):
        _, data_mix, _ = next(test_iter)

        # data_mix = data_mix[:data.shape[0]]
        # data_mix_permute = torch.cat([data_mix[1:data.shape[0]], data_mix[:1]], dim=0)
        # data_mix = (data_mix + data_mix_permute) / 2.

        data = data.permute(0, 3, 1, 2).float().cuda()
        data_mix = data_mix.permute(0, 3, 1, 2).float().cuda()

        pos_energy = model.forward(data, None).detach().cpu().numpy().mean(axis=-1)
        neg_energy = model.forward(data_mix, None).detach().cpu().numpy().mean(axis=-1)
        print("pos_energy", pos_energy.mean())
        print("neg_energy", neg_energy.mean())

        # Negated energies act as the detection score: higher means more in-distribution.
        probs.extend(list(-1 * pos_energy))
        probs.extend(list(-1 * neg_energy))
        pos.extend(list(-1 * pos_energy))
        negs.extend(list(-1 * neg_energy))
        labels.extend([1] * pos_energy.shape[0])
        labels.extend([0] * neg_energy.shape[0])

    pos, negs = np.array(pos), np.array(negs)
    np.save("pos.npy", pos)
    np.save("neg.npy", negs)
    auroc = sk.roc_auc_score(labels, probs)
    print("Roc score of {}".format(auroc))
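# As a quick visual sanity check on the separation summarized by the AUROC above, the saved
# energy arrays can be plotted as overlapping histograms. A small optional sketch, assuming
# matplotlib is available and that pos.npy / neg.npy were written by energyevalmix():
def plot_energy_histograms(pos_path="pos.npy", neg_path="neg.npy", out="energy_histogram.png"):
    import matplotlib.pyplot as plt

    pos_scores = np.load(pos_path)   # negated energies of in-distribution images
    neg_scores = np.load(neg_path)   # negated energies of the comparison dataset

    plt.hist(pos_scores, bins=100, alpha=0.5, density=True, label="in-distribution")
    plt.hist(neg_scores, bins=100, alpha=0.5, density=True, label="out-of-distribution")
    plt.xlabel("-energy")
    plt.legend()
    plt.savefig(out)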
def dataset_iterator(args):
    # Use an elif chain so the trailing else only fires for unknown datasets
    # (with independent ifs, every dataset other than 'raise' would hit the ValueError).
    if args.dataset == 'mnist':
        train_gen, dev_gen, test_gen = Mnist.load(args.batch_size, args.batch_size)
    elif args.dataset == 'cifar10':
        data_dir = '../../../images/cifar-10-batches-py/'
        train_gen, dev_gen = Cifar10.load(args.batch_size, data_dir)
        test_gen = None
    elif args.dataset == 'imagenet':
        data_dir = '../../../images/imagenet12/imagenet_val_png/'
        train_gen, dev_gen = Imagenet.load(args.batch_size, data_dir)
        test_gen = None
    elif args.dataset == 'raise':
        data_dir = '../../../images/raise/'
        train_gen, dev_gen = Raise.load(args.batch_size, data_dir)
        test_gen = None
    else:
        raise ValueError("unknown dataset: {}".format(args.dataset))
    return (train_gen, dev_gen, test_gen)
def main(): print("Local rank: ", hvd.local_rank(), hvd.size()) logdir = osp.join(FLAGS.logdir, FLAGS.exp) if hvd.rank() == 0: if not osp.exists(logdir): os.makedirs(logdir) logger = TensorBoardOutputFormat(logdir) else: logger = None LABEL = None print("Loading data...") if FLAGS.dataset == 'cifar10': dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale) test_dataset = Cifar10(train=False, rescale=FLAGS.rescale) channel_num = 3 X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32) if FLAGS.large_model: model = ResNet32Large(num_channels=channel_num, num_filters=128, train=True) elif FLAGS.larger_model: model = ResNet32Larger(num_channels=channel_num, num_filters=128) elif FLAGS.wider_model: model = ResNet32Wider(num_channels=channel_num, num_filters=192) else: model = ResNet32(num_channels=channel_num, num_filters=128) elif FLAGS.dataset == 'imagenet': dataset = Imagenet(train=True) test_dataset = Imagenet(train=False) channel_num = 3 X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32) model = ResNet32Wider(num_channels=channel_num, num_filters=256) elif FLAGS.dataset == 'imagenetfull': channel_num = 3 X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32) X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32) LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32) model = ResNet128(num_channels=channel_num, num_filters=64) elif FLAGS.dataset == 'mnist': dataset = Mnist(rescale=FLAGS.rescale) test_dataset = dataset channel_num = 1 X_NOISE = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32) X = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32) LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32) model = MnistNet(num_channels=channel_num, num_filters=FLAGS.num_filters) elif FLAGS.dataset == 'dsprites': dataset = DSprites(cond_shape=FLAGS.cond_shape, cond_size=FLAGS.cond_size, cond_pos=FLAGS.cond_pos, cond_rot=FLAGS.cond_rot) test_dataset = dataset channel_num = 1 X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32) X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32) if FLAGS.dpos_only: LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32) elif FLAGS.dsize_only: LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32) elif FLAGS.drot_only: LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32) elif FLAGS.cond_size: LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32) elif FLAGS.cond_shape: LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32) elif FLAGS.cond_pos: LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32) elif FLAGS.cond_rot: LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32) LABEL_POS = 
tf.placeholder(shape=(None, 2), dtype=tf.float32) else: LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32) model = DspritesNet(num_channels=channel_num, num_filters=FLAGS.num_filters, cond_size=FLAGS.cond_size, cond_shape=FLAGS.cond_shape, cond_pos=FLAGS.cond_pos, cond_rot=FLAGS.cond_rot) print("Done loading...") if FLAGS.dataset == "imagenetfull": # In the case of full imagenet, use custom_tensorflow dataloader data_loader = TFImagenetLoader('train', FLAGS.batch_size, hvd.rank(), hvd.size(), rescale=FLAGS.rescale) else: data_loader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, drop_last=True, shuffle=True) batch_size = FLAGS.batch_size weights = [model.construct_weights('context_0')] Y = tf.placeholder(shape=(None), dtype=tf.int32) # Varibles to run in training X_SPLIT = tf.split(X, FLAGS.num_gpus) X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus) LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus) LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus) LABEL_SPLIT_INIT = list(LABEL_SPLIT) tower_grads = [] tower_gen_grads = [] x_mod_list = [] optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999) optimizer = hvd.DistributedOptimizer(optimizer) for j in range(FLAGS.num_gpus): if FLAGS.model_cclass: ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus label_tensor = tf.Variable(tf.convert_to_tensor(np.reshape( np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)), (FLAGS.batch_size * 10, 10)), dtype=tf.float32), trainable=False, dtype=tf.float32) x_split = tf.tile( tf.reshape(X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)), (1, 10, 1, 1, 1)) x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3)) energy_pos = model.forward(x_split, weights[0], label=label_tensor, stop_at_grad=False) energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10)) energy_partition_est = tf.reduce_logsumexp(energy_pos_full, axis=1, keepdims=True) uniform = tf.random_uniform(tf.shape(energy_pos_full)) label_tensor = tf.argmax(-energy_pos_full - tf.log(-tf.log(uniform)) - energy_partition_est, axis=1) label = tf.one_hot(label_tensor, 10, dtype=tf.float32) label = tf.Print(label, [label_tensor, energy_pos_full]) LABEL_SPLIT[j] = label energy_pos = tf.concat(energy_pos, axis=0) else: energy_pos = [ model.forward(X_SPLIT[j], weights[0], label=LABEL_POS_SPLIT[j], stop_at_grad=False) ] energy_pos = tf.concat(energy_pos, axis=0) print("Building graph...") x_mod = x_orig = X_NOISE_SPLIT[j] x_grads = [] energy_negs = [] loss_energys = [] energy_negs.extend([ model.forward(tf.stop_gradient(x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True) ]) eps_begin = tf.zeros(1) steps = tf.constant(0) c = lambda i, x: tf.less(i, FLAGS.num_steps) def langevin_step(counter, x_mod): x_mod = x_mod + tf.random_normal( tf.shape(x_mod), mean=0.0, stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale) energy_noise = energy_start = tf.concat([ model.forward(x_mod, weights[0], label=LABEL_SPLIT[j], reuse=True, stop_at_grad=False, stop_batch=True) ], axis=0) x_grad, label_grad = tf.gradients(FLAGS.temperature * energy_noise, [x_mod, LABEL_SPLIT[j]]) energy_noise_old = energy_noise lr = FLAGS.step_lr if FLAGS.proj_norm != 0.0: if FLAGS.proj_norm_type == 'l2': x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm) elif FLAGS.proj_norm_type == 'li': x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm) else: print("Other types of projection are not supported!!!") assert False # Clip gradient norm for now if FLAGS.hmc: # Step 
size should be tuned to get around 65% acceptance def energy(x): return FLAGS.temperature * \ model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True) x_last = hmc(x_mod, 15., 10, energy) else: x_last = x_mod - (lr) * x_grad x_mod = x_last x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale) counter = counter + 1 return counter, x_mod steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod)) energy_eval = model.forward(x_mod, weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True) x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0] x_grads.append(x_grad) energy_negs.append( model.forward(tf.stop_gradient(x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)) test_x_mod = x_mod temp = FLAGS.temperature energy_neg = energy_negs[-1] x_off = tf.reduce_mean( tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j])) loss_energy = model.forward(x_mod, weights[0], reuse=True, label=LABEL, stop_grad=True) print("Finished processing loop construction ...") target_vars = {} if FLAGS.cclass or FLAGS.model_cclass: label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0) label_prob = label_sum / tf.reduce_sum(label_sum) label_ent = -tf.reduce_sum( label_prob * tf.math.log(label_prob + 1e-7)) else: label_ent = tf.zeros(1) target_vars['label_ent'] = label_ent if FLAGS.train: if FLAGS.objective == 'logsumexp': pos_term = temp * energy_pos energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg)) coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced)) norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4 pos_loss = tf.reduce_mean(temp * energy_pos) neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss)) elif FLAGS.objective == 'cd': pos_loss = tf.reduce_mean(temp * energy_pos) neg_loss = -tf.reduce_mean(temp * energy_neg) loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss)) elif FLAGS.objective == 'softplus': loss_ml = FLAGS.ml_coeff * \ tf.nn.softplus(temp * (energy_pos - energy_neg)) loss_total = tf.reduce_mean(loss_ml) if not FLAGS.zero_kl: loss_total = loss_total + tf.reduce_mean(loss_energy) loss_total = loss_total + \ FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg)))) print("Started gradient computation...") gvs = optimizer.compute_gradients(loss_total) gvs = [(k, v) for (k, v) in gvs if k is not None] print("Applying gradients...") tower_grads.append(gvs) print("Finished applying gradients.") target_vars['loss_ml'] = loss_ml target_vars['total_loss'] = loss_total target_vars['loss_energy'] = loss_energy target_vars['weights'] = weights target_vars['gvs'] = gvs target_vars['X'] = X target_vars['Y'] = Y target_vars['LABEL'] = LABEL target_vars['LABEL_POS'] = LABEL_POS target_vars['X_NOISE'] = X_NOISE target_vars['energy_pos'] = energy_pos target_vars['energy_start'] = energy_negs[0] if len(x_grads) >= 1: target_vars['x_grad'] = x_grads[-1] target_vars['x_grad_first'] = x_grads[0] else: target_vars['x_grad'] = tf.zeros(1) target_vars['x_grad_first'] = tf.zeros(1) target_vars['x_mod'] = x_mod target_vars['x_off'] = x_off target_vars['temp'] = temp target_vars['energy_neg'] = energy_neg target_vars['test_x_mod'] = test_x_mod target_vars['eps_begin'] = eps_begin if FLAGS.train: grads = average_gradients(tower_grads) train_op = optimizer.apply_gradients(grads) target_vars['train_op'] = train_op config = tf.ConfigProto() if hvd.size() > 1: config.gpu_options.visible_device_list = str(hvd.local_rank()) sess = 
tf.Session(config=config) saver = loader = tf.train.Saver(max_to_keep=30, keep_checkpoint_every_n_hours=6) total_parameters = 0 for variable in tf.trainable_variables(): # shape is an array of tf.Dimension shape = variable.get_shape() variable_parameters = 1 for dim in shape: variable_parameters *= dim.value total_parameters += variable_parameters print("Model has a total of {} parameters".format(total_parameters)) sess.run(tf.global_variables_initializer()) resume_itr = 0 if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0: model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter)) resume_itr = FLAGS.resume_iter # saver.restore(sess, model_file) optimistic_restore(sess, model_file) sess.run(hvd.broadcast_global_variables(0)) print("Initializing variables...") print("Start broadcast") print("End broadcast") if FLAGS.train: print("Training phase") train(target_vars, saver, sess, logger, data_loader, resume_itr, logdir) print("Testing phase") test(target_vars, saver, sess, logger, data_loader)
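# The core of the training graph above is the Langevin-dynamics inner loop that turns noise
# into negative samples (noise injection, energy-gradient step, clipping to the data range).
# A self-contained PyTorch sketch of the same update for a generic energy_fn, included only
# to make the TF while_loop easier to follow; the step size and noise scale here are
# illustrative placeholders, not the values used above:
def langevin_sample(energy_fn, x_init, num_steps=60, step_lr=10.0, noise_std=0.005):
    # Gradient-based MCMC: x <- clip(x + noise - step_lr * dE/dx), repeated num_steps times.
    x = x_init.clone().detach()
    for _ in range(num_steps):
        x = x + noise_std * torch.randn_like(x)
        x.requires_grad_(True)
        energy = energy_fn(x).sum()
        grad, = torch.autograd.grad(energy, x)
        x = (x - step_lr * grad).detach()
        x = torch.clamp(x, 0.0, 1.0)
    return x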
def compute_inception(sess, target_vars):
    X_START = target_vars['X_START']
    Y_GT = target_vars['Y_GT']
    X_finals = target_vars['X_finals']
    NOISE_SCALE = target_vars['NOISE_SCALE']
    energy_noise = target_vars['energy_noise']

    size = FLAGS.im_number
    num_steps = size // 1000

    images = []
    test_ims = []
    test_images = []

    if FLAGS.dataset == "cifar10":
        test_dataset = Cifar10(full=True, noise=False)
    elif FLAGS.dataset == "celeba":
        dataset = CelebA()
    elif FLAGS.dataset == "imagenet" or FLAGS.dataset == "imagenetfull":
        test_dataset = Imagenet(train=False)

    if FLAGS.dataset != "imagenetfull":
        test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size,
                                     num_workers=4, shuffle=True, drop_last=False)
    else:
        test_dataloader = TFImagenetLoader('test', FLAGS.batch_size, 0, 1)

    for data_corrupt, data, label_gt in tqdm(test_dataloader):
        data = data.numpy()
        test_ims.extend(list(rescale_im(data)))

        if FLAGS.dataset == "imagenetfull" and len(test_ims) > 60000:
            test_ims = test_ims[:60000]
            break

    # n = min(len(images), len(test_ims))
    print(len(test_ims))
    # fid = get_fid_score(test_ims[:30000], test_ims[-30000:])
    # print("Base FID of score {}".format(fid))

    if FLAGS.dataset == "cifar10":
        classes = 10
    else:
        classes = 1000

    if FLAGS.dataset == "imagenetfull":
        n = 128
    else:
        n = 32

    for j in range(num_steps):
        itr = int(1000 / 500 * FLAGS.repeat_scale)
        data_buffer = InceptionReplayBuffer(1000)
        curr_index = 0

        identity = np.eye(classes)

        test_steps = range(300, itr, 20)

        for i in tqdm(range(itr)):
            model_index = curr_index % len(X_finals)
            x_final = X_finals[model_index]

            noise_scale = [1]
            if len(data_buffer) < 1000:
                x_init = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
                label = np.random.randint(0, classes, (FLAGS.batch_size))
                label = identity[label]
                x_new = sess.run([x_final], {X_START: x_init, Y_GT: label,
                                             NOISE_SCALE: noise_scale})[0]
                data_buffer.add(x_new, label)
            else:
                (x_init, label), idx = data_buffer.sample(FLAGS.batch_size)
                keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.99)
                label_keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.9)
                label_corrupt = np.random.randint(0, classes, (FLAGS.batch_size))
                label_corrupt = identity[label_corrupt]
                x_init_corrupt = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))

                if i < itr - FLAGS.nomix:
                    x_init[keep_mask] = x_init_corrupt[keep_mask]
                    label[label_keep_mask] = label_corrupt[label_keep_mask]
                # else:
                #     noise_scale = [0.7]

                x_new, e_noise = sess.run([x_final, energy_noise],
                                          {X_START: x_init, Y_GT: label, NOISE_SCALE: noise_scale})
                data_buffer.set_elms(idx, x_new, label)

            curr_index += 1

        ims = np.array(data_buffer._storage[:1000])
        ims = rescale_im(ims)

        test_images.extend(list(ims))

    saveim = osp.join(FLAGS.logdir, FLAGS.exp, "test{}.png".format(FLAGS.resume_iter))

    row = 15
    col = 20
    ims = ims[:row * col]
    if FLAGS.dataset != "imagenetfull":
        im_panel = ims.reshape((row, col, 32, 32, 3)).transpose(
            (0, 2, 1, 3, 4)).reshape((32 * row, 32 * col, 3))
    else:
        im_panel = ims.reshape((row, col, 128, 128, 3)).transpose(
            (0, 2, 1, 3, 4)).reshape((128 * row, 128 * col, 3))
    imsave(saveim, im_panel)

    splits = max(1, len(test_images) // 5000)
    score, std = get_inception_score(test_images, splits=splits)
    print("Inception score of {} with std of {}".format(score, std))

    # FID score
    # n = min(len(images), len(test_ims))
    fid = get_fid_score(test_images, test_ims)
    print("FID of score {}".format(fid))
def main():
    # Initialize dataset
    if FLAGS.dataset == 'cifar10':
        dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3
        dim_input = 32 * 32 * 3
    elif FLAGS.dataset == 'imagenet':
        dataset = ImagenetClass()
        channel_num = 3
        dim_input = 64 * 64 * 3
    elif FLAGS.dataset == 'mnist':
        dataset = Mnist(train=False, rescale=FLAGS.rescale)
        channel_num = 1
        dim_input = 28 * 28 * 1
    elif FLAGS.dataset == 'dsprites':
        dataset = DSprites()
        channel_num = 1
        dim_input = 64 * 64 * 1
    elif FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
        dataset = Box2D()
        dim_output = 1

    data_loader = DataLoader(dataset, batch_size=FLAGS.batch_size,
                             num_workers=FLAGS.data_workers, drop_last=False, shuffle=True)

    if FLAGS.dataset == 'mnist':
        model = MnistNet(num_channels=channel_num)
    elif FLAGS.dataset == 'cifar10':
        if FLAGS.large_model:
            model = ResNet32Large(num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(num_filters=192)
        else:
            model = ResNet32(num_channels=channel_num, num_filters=128)
    elif FLAGS.dataset == 'dsprites':
        model = DspritesNet(num_channels=channel_num, num_filters=FLAGS.num_filters)

    weights = model.construct_weights('context_{}'.format(0))

    config = tf.ConfigProto()
    sess = tf.Session(config=config)
    saver = loader = tf.train.Saver(max_to_keep=10)

    sess.run(tf.global_variables_initializer())
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
    resume_itr = FLAGS.resume_iter

    if FLAGS.resume_iter != "-1":
        optimistic_restore(sess, model_file)
    else:
        print("WARNING, YOU ARE NOT LOADING A SAVE FILE")
        # saver.restore(sess, model_file)

    chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(
        model, weights, FLAGS.batch_size, temp=FLAGS.temperature)
    print("Finished constructing ancestral sample ...................")

    if FLAGS.dataset != "gauss":
        comb_weights_cum = []
        batch_size = tf.shape(x_init)[0]
        label_tiled = tf.tile(label_default, (batch_size, 1))
        e_compute = -FLAGS.temperature * model.forward(x_init, weights, label=label_tiled)
        e_pos_list = []

        for data_corrupt, data, label_gt in tqdm(data_loader):
            e_pos = sess.run([e_compute], {x_init: data})[0]
            e_pos_list.extend(list(e_pos))

        print(len(e_pos_list))
        print("Positive sample probability ", np.mean(e_pos_list), np.std(e_pos_list))

    if FLAGS.dataset == "2d":
        alr = 0.0045
    elif FLAGS.dataset == "gauss":
        alr = 0.0085
    elif FLAGS.dataset == "mnist":
        alr = 0.0065
        # 90: alr = 0.0035
    else:
        # alr = 0.0125
        if FLAGS.rescale == 8:
            alr = 0.0085
        else:
            alr = 0.0045

    # for i in range(1):
    tot_weight = 0
    for j in tqdm(range(1, FLAGS.pdist + 1)):
        if j == 1:
            if FLAGS.dataset == "cifar10":
                x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3))
            elif FLAGS.dataset == "gauss":
                x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim))
            elif FLAGS.dataset == "mnist":
                x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28))
            else:
                x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 2))

        alpha_prev = (j - 1) / FLAGS.pdist
        alpha_new = j / FLAGS.pdist

        cweight, x_curr = sess.run(
            [chain_weights, x],
            {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr,
             approx_lr: alr * (5 ** (2.5 * -alpha_prev))})

        tot_weight = tot_weight + cweight

    print("Total values of lower value based off forward sampling",
          np.mean(tot_weight), np.std(tot_weight))

    tot_weight = 0

    for j in tqdm(range(FLAGS.pdist, 0, -1)):
        alpha_new = (j - 1) / FLAGS.pdist
        alpha_prev = j / FLAGS.pdist

        cweight, x_curr = sess.run(
            [chain_weights, x],
            {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr,
             approx_lr: alr * (5 ** (2.5 * -alpha_prev))})

        tot_weight = tot_weight - cweight

    print("Total values of upper value based off backward sampling",
          np.mean(tot_weight), np.std(tot_weight))
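# The two loops above follow the usual annealed-importance-sampling bookkeeping over a
# schedule alpha_j = j / pdist: summing the per-step chain weights on the forward pass gives
# a stochastic lower estimate, and negating the weights accumulated on the reversed schedule
# gives the matching upper estimate. A schematic of just that bookkeeping, with hypothetical
# chain_weight and transition callables standing in for the ancestral_sample graph:
def ais_bounds(chain_weight, transition, x_init, pdist):
    # chain_weight(x, a_prev, a_new): per-chain log importance weight (hypothetical).
    # transition(x, a): one MCMC step targeting the intermediate distribution at
    # inverse temperature a (hypothetical).
    x, lower = x_init, 0.0
    for j in range(1, pdist + 1):
        a_prev, a_new = (j - 1) / pdist, j / pdist
        lower = lower + chain_weight(x, a_prev, a_new)
        x = transition(x, a_new)

    upper = 0.0
    for j in range(pdist, 0, -1):
        a_prev, a_new = j / pdist, (j - 1) / pdist
        upper = upper - chain_weight(x, a_prev, a_new)
        x = transition(x, a_new)

    return np.mean(lower), np.mean(upper)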
        return inception_score

    def test_unconditional(self, data_loader):
        print("Testing phase")
        inception_score = test(self.target_vars, self.saver, self.sess, self.logger, data_loader)
        return inception_score


if __name__ == "__main__":
    print("Loading data...")
    path = "/home/abhi/Documents/courses/UofT/CSC2506/project/data/cifar10"
    train_dataset = Cifar10(train=True, augment=FLAGS.augment, rescale=FLAGS.rescale, path=path)
    train_dataset_1 = torch.utils.data.Subset(train_dataset, list(range(0, 1000, 2)))
    print("Length of train_dataset: %d" % len(train_dataset_1))
    data_loader = DataLoader(train_dataset_1, batch_size=FLAGS.batch_size,
                             num_workers=FLAGS.data_workers, drop_last=True, shuffle=True)
    print("Done loading...")

    base_functions = [elu, gelu, linear, relu, selu, sigmoid, softplus, swish, tanh,
                      atan, cos, erf, sin, sqrt]
    base_operations = [maximum, minimum, add, subtract]
    evo = EvolutionaryAlgorithm(base_functions, base_operations,
                                min_depth=1, max_depth=3, pop_size=10)
    initial_gen = evo.create_generation()
    custom_act = compile(initial_gen[0], pset=evo.pset)

    ebm_prob = EBMProbML(tf.nn.leaky_relu)
    # ebm_prob = EBMProbML(custom_act)
    train_inc_score = ebm_prob.train_unconditional(data_loader)
def main(): if FLAGS.dataset == "cifar10": dataset = Cifar10(train=True, noise=False) test_dataset = Cifar10(train=False, noise=False) else: dataset = Imagenet(train=True) test_dataset = Imagenet(train=False) if FLAGS.svhn: dataset = Svhn(train=True) test_dataset = Svhn(train=False) if FLAGS.task == 'latent': dataset = DSprites() test_dataset = dataset dataloader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=True) test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=True) hidden_dim = 128 if FLAGS.large_model: model = ResNet32Large(num_filters=hidden_dim) elif FLAGS.larger_model: model = ResNet32Larger(num_filters=hidden_dim) elif FLAGS.wider_model: if FLAGS.dataset == 'imagenet': model = ResNet32Wider(num_filters=196, train=False) else: model = ResNet32Wider(num_filters=256, train=False) else: model = ResNet32(num_filters=hidden_dim) if FLAGS.task == 'latent': model = DspritesNet() weights = model.construct_weights('context_{}'.format(0)) total_parameters = 0 for variable in tf.compat.v1.trainable_variables(): # shape is an array of tf.Dimension shape = variable.get_shape() variable_parameters = 1 for dim in shape: variable_parameters *= dim.value total_parameters += variable_parameters print("Model has a total of {} parameters".format(total_parameters)) config = tf.compat.v1.ConfigProto() sess = tf.compat.v1.InteractiveSession() if FLAGS.task == 'latent': X = tf.compat.v1.placeholder(shape=(None, 64, 64), dtype=tf.float32) else: X = tf.compat.v1.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) if FLAGS.dataset == "cifar10": Y = tf.compat.v1.placeholder(shape=(None, 10), dtype=tf.float32) Y_GT = tf.compat.v1.placeholder(shape=(None, 10), dtype=tf.float32) elif FLAGS.dataset == "imagenet": Y = tf.compat.v1.placeholder(shape=(None, 1000), dtype=tf.float32) Y_GT = tf.compat.v1.placeholder(shape=(None, 1000), dtype=tf.float32) target_vars = {'X': X, 'Y': Y, 'Y_GT': Y_GT} if FLAGS.task == 'label': construct_label(weights, X, Y, Y_GT, model, target_vars) elif FLAGS.task == 'labelfinetune': construct_finetune_label( weights, X, Y, Y_GT, model, target_vars, ) elif FLAGS.task == 'energyeval' or FLAGS.task == 'mixenergy': construct_energy(weights, X, Y, Y_GT, model, target_vars) elif FLAGS.task == 'anticorrupt' or FLAGS.task == 'boxcorrupt' or FLAGS.task == 'crossclass' or FLAGS.task == 'cycleclass' or FLAGS.task == 'democlass' or FLAGS.task == 'nearestneighbor': construct_steps(weights, X, Y_GT, model, target_vars) elif FLAGS.task == 'latent': construct_latent(weights, X, Y_GT, model, target_vars) sess.run(tf.compat.v1.global_variables_initializer()) saver = loader = tf.compat.v1.train.Saver(max_to_keep=10) savedir = osp.join('cachedir', FLAGS.exp) logdir = osp.join(FLAGS.logdir, FLAGS.exp) if not osp.exists(logdir): os.makedirs(logdir) initialize() if FLAGS.resume_iter != -1: model_file = osp.join(savedir, 'model_{}'.format(FLAGS.resume_iter)) resume_itr = FLAGS.resume_iter if FLAGS.task == 'label' or FLAGS.task == 'boxcorrupt' or FLAGS.task == 'labelfinetune' or FLAGS.task == "energyeval" or FLAGS.task == "crossclass" or FLAGS.task == "mixenergy": optimistic_restore(sess, model_file) # saver.restore(sess, model_file) else: # optimistic_restore(sess, model_file) saver.restore(sess, model_file) if FLAGS.task == 'label': if FLAGS.labelgrid: vals = [] if FLAGS.lnorm == -1: for i in range(31): accuracies = label(dataloader, test_dataloader, target_vars, sess, 
l1val=i) vals.append(accuracies) elif FLAGS.lnorm == 2: for i in range(0, 100, 5): accuracies = label(dataloader, test_dataloader, target_vars, sess, l2val=i) vals.append(accuracies) np.save("result_{}_{}.npy".format(FLAGS.lnorm, FLAGS.exp), vals) else: label(dataloader, test_dataloader, target_vars, sess) elif FLAGS.task == 'labelfinetune': labelfinetune(dataloader, test_dataloader, target_vars, sess, savedir, saver, l1val=FLAGS.lival, l2val=FLAGS.l2val) elif FLAGS.task == 'energyeval': energyeval(dataloader, test_dataloader, target_vars, sess) elif FLAGS.task == 'mixenergy': energyevalmix(dataloader, test_dataloader, target_vars, sess) elif FLAGS.task == 'anticorrupt': anticorrupt(test_dataloader, weights, model, target_vars, logdir, sess) elif FLAGS.task == 'boxcorrupt': # boxcorrupt(test_dataloader, weights, model, target_vars, logdir, sess) boxcorrupt(test_dataloader, dataloader, weights, model, target_vars, logdir, sess) elif FLAGS.task == 'crossclass': crossclass(test_dataloader, weights, model, target_vars, logdir, sess) elif FLAGS.task == 'cycleclass': cycleclass(test_dataloader, weights, model, target_vars, logdir, sess) elif FLAGS.task == 'democlass': democlass(test_dataloader, weights, model, target_vars, logdir, sess) elif FLAGS.task == 'nearestneighbor': # print(dir(dataset)) # print(type(dataset)) nearest_neighbor(dataset.data.train_data / 255, sess, target_vars, logdir) elif FLAGS.task == 'latent': latent(test_dataloader, weights, model, target_vars, sess)
from datetime import datetime

from torch import nn, optim
from torch.utils import data

from data import Cifar10, FashionMNIST
from model import MLP, ConvNet, OneLayer
from utils import train, test
from config import Config as config

print(f"{datetime.now().ctime()} - Start Loading Dataset...")

if config.dataset == "cifar10":
    train_dataset = Cifar10(config.root, train=True)
    test_dataset = Cifar10(config.root, train=False)
elif config.dataset == "fashionmnist":
    train_dataset = FashionMNIST(config.root, train=True)
    test_dataset = FashionMNIST(config.root, train=False)

train_dataloader = data.DataLoader(train_dataset, config.batch_size, shuffle=True, num_workers=2)
test_dataloader = data.DataLoader(test_dataset, config.batch_size, shuffle=False, num_workers=2)

print(f"{datetime.now().ctime()} - Finish Loading Dataset")

print(f"{datetime.now().ctime()} - Start Creating Net, Criterion, Optimizer and Scheduler...")

if config.model == "mlp":
    net = MLP(config.cifar10_input_size, config.num_classes)
elif config.model == "convnet":
from config import DefaultConfig
from data import Cifar10
import models
from Trainer import Trainer
import csv
import matplotlib.pyplot as plt
import os
from torchvision import datasets, transforms

opt = DefaultConfig()

if not os.path.exists(opt.train_data_root):
    from precifar import *
    Precifar()

train_data = Cifar10(opt.train_data_root, train=True)
val_data = Cifar10(opt.train_data_root, train=False)
test_data = Cifar10(opt.test_data_root, test=True)

Model = getattr(models, opt.model)()
if opt.load_model_path:
    Model.load(opt.load_model_path)

Cifar_Trainer = Trainer(Model, opt)
Cifar_Trainer.train(train_data, val_data)
Model.save(opt.save_model_path)

results, confusion_matrix, accuracy = Cifar_Trainer.test(test_data)

with open(opt.result_file, 'wt', newline='', encoding='utf-8') as csvfile:
    try:
def main(): print("Local rank: ", hvd.local_rank(), hvd.size()) FLAGS.exp = FLAGS.exp + '_' + FLAGS.divergence logdir = osp.join(FLAGS.logdir, FLAGS.exp) if hvd.rank() == 0: if not osp.exists(logdir): os.makedirs(logdir) logger = TensorBoardOutputFormat(logdir) else: logger = None print("Loading data...") dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale) test_dataset = Cifar10(train=False, rescale=FLAGS.rescale) channel_num = 3 X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32) if FLAGS.large_model: model = ResNet32Large( num_channels=channel_num, num_filters=128, train=True) model_dis = ResNet32Large( num_channels=channel_num, num_filters=128, train=True) elif FLAGS.larger_model: model = ResNet32Larger( num_channels=channel_num, num_filters=128) model_dis = ResNet32Larger( num_channels=channel_num, num_filters=128) elif FLAGS.wider_model: model = ResNet32Wider( num_channels=channel_num, num_filters=256) model_dis = ResNet32Wider( num_channels=channel_num, num_filters=256) else: model = ResNet32( num_channels=channel_num, num_filters=128) model_dis = ResNet32( num_channels=channel_num, num_filters=128) print("Done loading...") grad_exp, conjugate_grad_exp = get_divergence_funcs(FLAGS.divergence) data_loader = DataLoader( dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, drop_last=True, shuffle=True) weights = [model.construct_weights('context_energy'), model_dis.construct_weights('context_dis')] Y = tf.placeholder(shape=(None), dtype=tf.int32) # Varibles to run in training X_SPLIT = tf.split(X, FLAGS.num_gpus) X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus) LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus) LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus) LABEL_SPLIT_INIT = list(LABEL_SPLIT) tower_grads = [] tower_grads_dis = [] tower_grads_l2 = [] tower_grads_dis_l2 = [] optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999) optimizer = hvd.DistributedOptimizer(optimizer) optimizer_dis = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999) optimizer_dis = hvd.DistributedOptimizer(optimizer_dis) for j in range(FLAGS.num_gpus): energy_pos = [ model.forward( X_SPLIT[j], weights[0], label=LABEL_POS_SPLIT[j], stop_at_grad=False)] energy_pos = tf.concat(energy_pos, axis=0) score_pos = [ model_dis.forward( X_SPLIT[j], weights[1], label=LABEL_POS_SPLIT[j], stop_at_grad=False)] score_pos = tf.concat(score_pos, axis=0) print("Building graph...") x_mod = x_orig = X_NOISE_SPLIT[j] x_grads = [] energy_negs = [] loss_energys = [] energy_negs.extend([model.forward(tf.stop_gradient( x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)]) eps_begin = tf.zeros(1) steps = tf.constant(0) c = lambda i, x: tf.less(i, FLAGS.num_steps) def langevin_step(counter, x_mod): x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale) energy_noise = energy_start = tf.concat( [model.forward( x_mod, weights[0], label=LABEL_SPLIT[j], reuse=True, stop_at_grad=False, stop_batch=True)], axis=0) x_grad, label_grad = tf.gradients(energy_noise, [x_mod, LABEL_SPLIT[j]]) energy_noise_old = energy_noise lr = FLAGS.step_lr if FLAGS.proj_norm != 0.0: if FLAGS.proj_norm_type == 'l2': x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm) elif FLAGS.proj_norm_type == 'li': x_grad = tf.clip_by_value( x_grad, -FLAGS.proj_norm, 
FLAGS.proj_norm) else: print("Other types of projection are not supported!!!") assert False # Clip gradient norm for now if FLAGS.hmc: # Step size should be tuned to get around 65% acceptance def energy(x): return FLAGS.temperature * \ model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True) x_last = hmc(x_mod, 15., 10, energy) else: x_last = x_mod - (lr) * x_grad x_mod = x_last x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale) counter = counter + 1 return counter, x_mod steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod)) energy_eval = model.forward(x_mod, weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True) x_grad = tf.gradients(energy_eval, [x_mod])[0] x_grads.append(x_grad) energy_negs.append( model.forward( tf.stop_gradient(x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)) score_neg = model_dis.forward( tf.stop_gradient(x_mod), weights[1], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True) test_x_mod = x_mod temp = FLAGS.temperature energy_neg = energy_negs[-1] x_off = tf.reduce_mean( tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j])) loss_energy = model.forward( x_mod, weights[0], reuse=True, label=LABEL, stop_grad=True) print("Finished processing loop construction ...") target_vars = {} if FLAGS.cclass or FLAGS.model_cclass: label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0) label_prob = label_sum / tf.reduce_sum(label_sum) label_ent = -tf.reduce_sum(label_prob * tf.math.log(label_prob + 1e-7)) else: label_ent = tf.zeros(1) target_vars['label_ent'] = label_ent if FLAGS.train: loss_dis = - (tf.reduce_mean(grad_exp(score_pos + energy_pos)) - tf.reduce_mean(conjugate_grad_exp(score_neg + energy_neg))) loss_dis = loss_dis + FLAGS.l2_coeff * (tf.reduce_mean(tf.square(score_pos)) + tf.reduce_mean(tf.square(score_neg))) l2_dis = FLAGS.l2_coeff * (tf.reduce_mean(tf.square(score_pos)) + tf.reduce_mean(tf.square(score_neg))) loss_model = tf.reduce_mean(grad_exp(score_pos + energy_pos)) + \ tf.reduce_mean(energy_neg * tf.stop_gradient(conjugate_grad_exp(score_neg + energy_neg))) - \ tf.reduce_mean(energy_neg) * tf.stop_gradient(tf.reduce_mean(conjugate_grad_exp(score_neg + energy_neg))) - \ tf.reduce_mean(conjugate_grad_exp(score_neg + energy_neg)) loss_model = loss_model + FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg))) l2_model = FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg))) print("Started gradient computation...") model_vars = [var for var in tf.trainable_variables() if 'context_energy' in var.name] print("model var number", len(model_vars)) dis_vars = [var for var in tf.trainable_variables() if 'context_dis' in var.name] print("discriminator var number", len(dis_vars)) gvs = optimizer.compute_gradients(loss_model, model_vars) gvs = [(k, v) for (k, v) in gvs if k is not None] tower_grads.append(gvs) gvs = optimizer.compute_gradients(l2_model, model_vars) gvs = [(k, v) for (k, v) in gvs if k is not None] tower_grads_l2.append(gvs) gvs_dis = optimizer_dis.compute_gradients(loss_dis, dis_vars) gvs_dis = [(k, v) for (k, v) in gvs_dis if k is not None] tower_grads_dis.append(gvs_dis) gvs_dis = optimizer_dis.compute_gradients(l2_dis, dis_vars) gvs_dis = [(k, v) for (k, v) in gvs_dis if k is not None] tower_grads_dis_l2.append(gvs_dis) print("Finished applying gradients.") target_vars['total_loss'] = loss_model target_vars['loss_energy'] = loss_energy target_vars['weights'] = weights target_vars['gvs'] = gvs target_vars['X'] = X 
target_vars['Y'] = Y target_vars['LABEL'] = LABEL target_vars['LABEL_POS'] = LABEL_POS target_vars['X_NOISE'] = X_NOISE target_vars['energy_pos'] = energy_pos target_vars['energy_start'] = energy_negs[0] if len(x_grads) >= 1: target_vars['x_grad'] = x_grads[-1] target_vars['x_grad_first'] = x_grads[0] else: target_vars['x_grad'] = tf.zeros(1) target_vars['x_grad_first'] = tf.zeros(1) target_vars['x_mod'] = x_mod target_vars['x_off'] = x_off target_vars['temp'] = temp target_vars['energy_neg'] = energy_neg target_vars['test_x_mod'] = test_x_mod target_vars['eps_begin'] = eps_begin target_vars['score_neg'] = score_neg target_vars['score_pos'] = score_pos if FLAGS.train: grads_model = average_gradients(tower_grads) train_op_model = optimizer.apply_gradients(grads_model) target_vars['train_op_model'] = train_op_model grads_model_l2 = average_gradients(tower_grads_l2) train_op_model_l2 = optimizer.apply_gradients(grads_model_l2) target_vars['train_op_model_l2'] = train_op_model_l2 grads_model_dis = average_gradients(tower_grads_dis) train_op_dis = optimizer_dis.apply_gradients(grads_model_dis) target_vars['train_op_dis'] = train_op_dis grads_model_dis_l2 = average_gradients(tower_grads_dis_l2) train_op_dis_l2 = optimizer_dis.apply_gradients(grads_model_dis_l2) target_vars['train_op_dis_l2'] = train_op_dis_l2 config = tf.ConfigProto() if hvd.size() > 1: config.gpu_options.visible_device_list = str(hvd.local_rank()) sess = tf.Session(config=config) saver = loader = tf.train.Saver(max_to_keep=500) total_parameters = 0 for variable in tf.trainable_variables(): # shape is an array of tf.Dimension shape = variable.get_shape() variable_parameters = 1 for dim in shape: variable_parameters *= dim.value total_parameters += variable_parameters print("Model has a total of {} parameters".format(total_parameters)) sess.run(tf.global_variables_initializer()) resume_itr = 0 if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0: model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter)) resume_itr = FLAGS.resume_iter saver.restore(sess, model_file) # optimistic_restore(sess, model_file) sess.run(hvd.broadcast_global_variables(0)) print("Initializing variables...") print("Start broadcast") print("End broadcast") if FLAGS.train: train(target_vars, saver, sess, logger, data_loader, resume_itr, logdir) test(target_vars, saver, sess, logger, data_loader)
def main():
    # Initialize dataset
    dataset = Cifar10(FLAGS, train=False, rescale=FLAGS.rescale)
    data_loader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=4,
                             drop_last=False, shuffle=True)

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    model_path = osp.join(logdir, "model_best.pth")
    checkpoint = torch.load(model_path)
    FLAGS_model = checkpoint['FLAGS']

    model = ResNetModel(FLAGS_model).eval().cuda()

    if FLAGS.ema:
        model.load_state_dict(checkpoint['ema_model_state_dict_0'])
    else:
        model.load_state_dict(checkpoint['model_state_dict_0'])

    print("Finished constructing ancestral sample ...................")

    e_pos_list = []

    for data_corrupt, data, label_gt in tqdm(data_loader):
        data = data.permute(0, 3, 1, 2).contiguous().cuda()
        energy = model.forward(data, None)
        energy = -FLAGS.temperature * energy.squeeze().detach().cpu().numpy()
        e_pos_list.extend(list(energy))

    print(len(e_pos_list))
    print("Positive sample probability ", np.mean(e_pos_list), np.std(e_pos_list))

    # alr = 0.0065
    alr = 10.0

    # for i in range(1):
    tot_weight = 0
    for j in tqdm(range(1, FLAGS.pdist + 1)):
        if j == 1:
            x_curr = torch.rand(FLAGS.batch_size, 3, 32, 32).cuda()

        alpha_prev = (j - 1) / FLAGS.pdist
        alpha_new = j / FLAGS.pdist

        cweight, x_curr = ancestral_sample(model, x_curr, alpha_prev, alpha_new, FLAGS,
                                           FLAGS.batch_size, FLAGS.pdist,
                                           temp=FLAGS.temperature, approx_lr=alr)
        tot_weight = tot_weight + cweight.detach()
        x_curr = x_curr.detach()

    tot_weight = tot_weight.detach().cpu().float().numpy()
    print("Total values of lower value based off forward sampling",
          np.mean(tot_weight), np.std(tot_weight))

    tot_weight = 0
    x_curr = x_curr.detach()

    for j in tqdm(range(FLAGS.pdist, 0, -1)):
        alpha_new = (j - 1) / FLAGS.pdist
        alpha_prev = j / FLAGS.pdist

        cweight, x_curr = ancestral_sample(model, x_curr, alpha_prev, alpha_new, FLAGS,
                                           FLAGS.batch_size, FLAGS.pdist,
                                           temp=FLAGS.temperature, approx_lr=alr)
        tot_weight = tot_weight - cweight.detach()
        x_curr = x_curr.detach()

    tot_weight = tot_weight.detach().cpu().float().numpy()
    print("Total values of upper value based off backward sampling",
          np.mean(tot_weight), np.std(tot_weight))
def main_single(gpu, FLAGS):
    if FLAGS.slurm:
        init_distributed_mode(FLAGS)

    os.environ['MASTER_ADDR'] = FLAGS.master_addr
    os.environ['MASTER_PORT'] = FLAGS.port

    rank_idx = FLAGS.node_rank * FLAGS.gpus + gpu
    world_size = FLAGS.nodes * FLAGS.gpus
    print("Values of args: ", FLAGS)

    if world_size > 1:
        if FLAGS.slurm:
            dist.init_process_group(backend='nccl', init_method='env://',
                                    world_size=world_size, rank=rank_idx)
        else:
            dist.init_process_group(backend='nccl', init_method='tcp://localhost:1700',
                                    world_size=world_size, rank=rank_idx)

    if FLAGS.dataset == "cifar10":
        train_dataset = Cifar10(FLAGS)
        valid_dataset = Cifar10(FLAGS, train=False, augment=False)
        test_dataset = Cifar10(FLAGS, train=False, augment=False)
    elif FLAGS.dataset == "stl":
        train_dataset = STLDataset(FLAGS)
        valid_dataset = STLDataset(FLAGS, train=False)
        test_dataset = STLDataset(FLAGS, train=False)
    elif FLAGS.dataset == "object":
        train_dataset = ObjectDataset(FLAGS.cond_idx)
        valid_dataset = ObjectDataset(FLAGS.cond_idx)
        test_dataset = ObjectDataset(FLAGS.cond_idx)
    elif FLAGS.dataset == "imagenet":
        train_dataset = ImageNet()
        valid_dataset = ImageNet()
        test_dataset = ImageNet()
    elif FLAGS.dataset == "mnist":
        train_dataset = Mnist(train=True)
        valid_dataset = Mnist(train=False)
        test_dataset = Mnist(train=False)
    elif FLAGS.dataset == "celeba":
        train_dataset = CelebAHQ(cond_idx=FLAGS.cond_idx)
        valid_dataset = CelebAHQ(cond_idx=FLAGS.cond_idx)
        test_dataset = CelebAHQ(cond_idx=FLAGS.cond_idx)
    elif FLAGS.dataset == "lsun":
        train_dataset = LSUNBed(cond_idx=FLAGS.cond_idx)
        valid_dataset = LSUNBed(cond_idx=FLAGS.cond_idx)
        test_dataset = LSUNBed(cond_idx=FLAGS.cond_idx)
    else:
        assert False

    train_dataloader = DataLoader(train_dataset, num_workers=FLAGS.data_workers,
                                  batch_size=FLAGS.batch_size, shuffle=True, drop_last=True)
    valid_dataloader = DataLoader(valid_dataset, num_workers=FLAGS.data_workers,
                                  batch_size=FLAGS.batch_size, shuffle=True, drop_last=True)
    test_dataloader = DataLoader(test_dataset, num_workers=FLAGS.data_workers,
                                 batch_size=FLAGS.batch_size, shuffle=True, drop_last=True)

    FLAGS_OLD = FLAGS

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    best_inception = 0.0

    if FLAGS.resume_iter != 0:
        model_path = osp.join(logdir, "model_{}.pth".format(FLAGS.resume_iter))
        checkpoint = torch.load(model_path)
        best_inception = checkpoint['best_inception']
        FLAGS = checkpoint['FLAGS']

        # Restore the checkpoint's flags but keep the current run's overrides.
        FLAGS.resume_iter = FLAGS_OLD.resume_iter
        FLAGS.nodes = FLAGS_OLD.nodes
        FLAGS.gpus = FLAGS_OLD.gpus
        FLAGS.node_rank = FLAGS_OLD.node_rank
        FLAGS.master_addr = FLAGS_OLD.master_addr
        FLAGS.train = FLAGS_OLD.train
        FLAGS.num_steps = FLAGS_OLD.num_steps
        FLAGS.step_lr = FLAGS_OLD.step_lr
        FLAGS.batch_size = FLAGS_OLD.batch_size
        FLAGS.ensembles = FLAGS_OLD.ensembles
        FLAGS.kl_coeff = FLAGS_OLD.kl_coeff
        FLAGS.repel_im = FLAGS_OLD.repel_im
        FLAGS.save_interval = FLAGS_OLD.save_interval

        for key in dir(FLAGS):
            if "__" not in key:
                FLAGS_OLD[key] = getattr(FLAGS, key)

        FLAGS = FLAGS_OLD

    if FLAGS.dataset == "cifar10":
        model_fn = ResNetModel
    elif FLAGS.dataset == "stl":
        model_fn = ResNetModel
    elif FLAGS.dataset == "object":
        model_fn = CelebAModel
    elif FLAGS.dataset == "mnist":
        model_fn = MNISTModel
    elif FLAGS.dataset == "celeba":
        model_fn = CelebAModel
    elif FLAGS.dataset == "lsun":
        model_fn = CelebAModel
    elif FLAGS.dataset == "imagenet":
        model_fn = ImagenetModel
    else:
        assert False

    models = [model_fn(FLAGS).train() for i in range(FLAGS.ensembles)]
    models_ema = [model_fn(FLAGS).train() for i in range(FLAGS.ensembles)]

    torch.cuda.set_device(gpu)

    if FLAGS.cuda:
        models = [model.cuda(gpu) for model in models]
        # Move the EMA copies too so the later EMA updates stay on one device.
        models_ema = [model_ema.cuda(gpu) for model_ema in models_ema]

    if FLAGS.gpus > 1:
        sync_model(models)

    parameters = []
    for model in models:
        parameters.extend(list(model.parameters()))

    optimizer = Adam(parameters, lr=FLAGS.lr, betas=(0.0, 0.9), eps=1e-8)
    ema_model(models, models_ema, mu=0.0)

    logger = TensorBoardOutputFormat(logdir)

    it = FLAGS.resume_iter

    if not osp.exists(logdir):
        os.makedirs(logdir)

    checkpoint = None
    if FLAGS.resume_iter != 0:
        model_path = osp.join(logdir, "model_{}.pth".format(FLAGS.resume_iter))
        checkpoint = torch.load(model_path)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        for i, (model, model_ema) in enumerate(zip(models, models_ema)):
            model.load_state_dict(checkpoint['model_state_dict_{}'.format(i)])
            model_ema.load_state_dict(checkpoint['ema_model_state_dict_{}'.format(i)])

    print("New Values of args: ", FLAGS)

    pytorch_total_params = sum([p.numel() for p in model.parameters() if p.requires_grad])
    print("Number of parameters for models", pytorch_total_params)

    train(models, models_ema, optimizer, logger, train_dataloader,
          FLAGS.resume_iter, logdir, FLAGS, gpu, best_inception)
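# ema_model() above is assumed to perform the usual exponential-moving-average update of the
# EMA copies toward the live models (so mu=0.0 simply copies the current weights, which is
# how the call above initializes them). A minimal sketch of such a helper under that
# assumption; hypothetical implementation, not necessarily the one in this repo:
def ema_model(models, models_ema, mu=0.999):
    # ema <- mu * ema + (1 - mu) * model, parameter by parameter.
    with torch.no_grad():
        for model, model_ema in zip(models, models_ema):
            for p, p_ema in zip(model.parameters(), model_ema.parameters()):
                p_ema.copy_(mu * p_ema + (1.0 - mu) * p)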
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--policy-lr", type=float, default=5e-4, help="learning rate")
    parser.add_argument("--num-gpu-core", type=int, default=1, help="Number of GPU cores to use")
    parser.add_argument("--num-cpu-core", type=int, default=4, help="Number of CPU cores to use")
    parser.add_argument("--inter-op-parallelism-threads", type=int, default=4,
                        help="0 means the system picks an appropriate number.")
    parser.add_argument("--intra-op-parallelism-threads", type=int, default=4,
                        help="0 means the system picks an appropriate number.")
    parser.add_argument("--use-fp16", default=False, action='store_true',
                        help="whether to use float16 as the default dtype")
    args = parser.parse_args()

    print("Available devices:")
    print(device_lib.list_local_devices())

    config_dict = load_cfg_from_yaml('config/CIFAR10/R_50.yaml')
    config_dict = merge_cfg_from_args(config_dict, args)

    lr = config_dict["policy_lr"]
    intra_op_threads = config_dict['intra_op_parallelism_threads']
    inter_op_threads = config_dict['inter_op_parallelism_threads']
    use_fp16 = config_dict['use_fp16']
    cpu_num = config_dict["num_cpu_core"]
    gpu_num = config_dict["num_gpu_core"]
    MAX_EPOCH = config_dict['TRAIN']['MAX_EPOCH']
    device_id = -1

    config = tf.ConfigProto(
        # device_count limits the number of CPUs being used, not the number of cores or threads.
        device_count={'CPU': cpu_num, 'GPU': gpu_num},
        inter_op_parallelism_threads=cpu_num,  # parallelism between independent operations
        intra_op_parallelism_threads=cpu_num,  # parallelism within a single operation, e.g. reduce_sum
        log_device_placement=True,             # log the GPU or CPU device assigned to each operation
        allow_soft_placement=True,             # use soft constraints for device placement
    )
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.9

    with tf.Session(config=config) as sess:
        # input for resnet
        tf_x = tf.placeholder(dtype=data_type(use_fp16), shape=[None, 32, 32, 3], name='tf_x')
        tf_y = tf.placeholder(dtype=data_type(use_fp16), shape=[None, 10], name='tf_y')

        # output of resnet
        with slim.arg_scope(resnet_v2.resnet_arg_scope()):
            with tf.device(next_device(device_id, config_dict, use_cpu=False)):
                resnet_out, end_points = resnet_v2.resnet_v2_50(tf_x, num_classes=10, is_training=False)

        # loss
        cross_entropy_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=resnet_out, labels=tf_y))
        # tf_loss_summary = tf.summary.scalar('loss', cross_entropy_loss)

        # optimizer
        tf_lr = tf.placeholder(dtype=data_type(use_fp16), shape=None, name='learning_rate')  # more flexible learning rate
        optimizer = tf.train.AdamOptimizer(tf_lr)
        grads_and_vars = optimizer.compute_gradients(cross_entropy_loss)
        train_step = optimizer.minimize(cross_entropy_loss)

        correct_prediction = tf.equal(tf.argmax(resnet_out, 1), tf.argmax(tf_y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
        # tf_accuracy_summary = tf.summary.scalar('accuracy', accuracy)

        # initialize variables
        sess.run(tf.global_variables_initializer())

        # A name scope groups related summaries together; summaries sharing the same
        # name_scope are displayed on the same row in TensorBoard.
        with tf.name_scope('performance'):
            # Whenever you need to record the loss, feed the mean loss to this placeholder.
            tf_loss_ph = tf.placeholder(dtype=data_type(use_fp16), shape=None, name='loss_summary')
            # Create a scalar summary object for the loss so it can be displayed.
            tf_loss_summary = tf.summary.scalar('loss', tf_loss_ph)

            # Whenever you need to record the accuracy, feed the mean accuracy to this placeholder.
            tf_accuracy_ph = tf.placeholder(dtype=data_type(use_fp16), shape=None, name='accuracy_summary')
            # Create a scalar summary object for the accuracy so it can be displayed.
            tf_accuracy_summary = tf.summary.scalar('accuracy', tf_accuracy_ph)

        # Gradient norm summary: tf_gradnorm_summary tracks the l2 norm of the gradients of
        # the last layer. The gradient norm is a good indicator of whether the weights are
        # being properly updated: a very small norm can indicate vanishing gradients, a very
        # large one exploding gradients.
        t_layer = len(grads_and_vars) - 2  # index of the last layer
        for i_layer, (g, v) in enumerate(grads_and_vars):
            if i_layer == t_layer:
                with tf.name_scope('gradients_norm'):
                    tf_last_grad_norm = tf.sqrt(tf.reduce_mean(g ** 2))
                    tf_gradnorm_summary = tf.summary.scalar('grad_norm', tf_last_grad_norm)
                break

        # A histogram summary for each weight in each layer of ResNet50.
        all_summaries = []
        for weight_name in end_points:
            try:
                name_scope = weight_name.split("bottleneck_v2/")[0] + weight_name.split("bottleneck_v2/")[1]
            except Exception:
                name_scope = weight_name.split("bottleneck_v2/")[0]
            with tf.name_scope(name_scope):
                weight = end_points[weight_name]
                tf_w_hist = tf.summary.histogram('weights_hist', tf.reshape(weight, [-1]))
                all_summaries.append([tf_w_hist])

        # Merge all parameter histogram summaries together.
        tf_weight_summaries = tf.summary.merge(all_summaries)

        # Merge all summaries together. The following two statements are equivalent:
        # [1] merge_op = tf.summary.merge_all()
        # [2] merge_op = tf.summary.merge([tf_loss_summary, tf_accuracy_summary, tf_gradnorm_summary, ...])
        merge_op = tf.summary.merge([tf_loss_summary, tf_accuracy_summary])

        # Separate TensorBoard writers: one log file and one folder each for train and test.
        train_writer = tf.summary.FileWriter('logs/train', sess.graph)
        test_writer = tf.summary.FileWriter('logs/test', sess.graph)

        # get cifar10 dataset
        cifar10_data = Cifar10('dataset/CIFAR10/', config=config_dict)
        test_images, test_labels = cifar10_data.test_data()

        # training
        start_time = time.time()
        batch_counter = 0
        while cifar10_data.epochs_completed < MAX_EPOCH:
            batch_xs, batch_ys = cifar10_data.next_train_batch()
            batch_counter += 1
            sess.run(train_step, feed_dict={tf_x: batch_xs, tf_y: batch_ys, tf_lr: lr})

            # calculate the train accuracy for one batch
            if batch_counter % 100 == 0:
                train_accuracy, loss = sess.run(
                    [accuracy, cross_entropy_loss],
                    feed_dict={tf_x: batch_xs, tf_y: batch_ys, tf_lr: lr})
                train_summary = sess.run(
                    merge_op, feed_dict={tf_loss_ph: loss, tf_accuracy_ph: train_accuracy})
                train_writer.add_summary(train_summary, batch_counter)
                print("----- epoch {} batch {} training accuracy {} loss {}".format(
                    cifar10_data.epochs_completed, batch_counter, train_accuracy, loss))
                end_time = time.time()
                print("time: {}".format(end_time - start_time))
                start_time = end_time

            # calculate the gradient norm summary and weight histogram
            if batch_counter % 100 == 0:
                gn_summary, wb_summary = sess.run(
                    [tf_gradnorm_summary, tf_weight_summaries],
                    feed_dict={tf_x: batch_xs, tf_y: batch_ys, tf_lr: lr})
                train_writer.add_summary(gn_summary, batch_counter)
                train_writer.add_summary(wb_summary, batch_counter)

            # calculate the test accuracy
            if batch_counter % 1000 == 0:
                test_accuracy, test_loss = sess.run(
                    [accuracy, cross_entropy_loss],
                    feed_dict={tf_x: test_images, tf_y: test_labels, tf_lr: lr})
                test_summary = sess.run(
                    merge_op, feed_dict={tf_loss_ph: test_loss, tf_accuracy_ph: test_accuracy})
                test_writer.add_summary(test_summary, batch_counter)
                print("----- test accuracy {} test loss {}".format(test_accuracy, test_loss))

        # Overall test accuracy
        overall_accuracy = accuracy.eval(feed_dict={tf_x: test_images, tf_y: test_labels, tf_lr: lr})
        print("\nOverall test accuracy {}".format(overall_accuracy))

        save_cfg_to_yaml(config_dict, 'logs/current_config.yaml')
def compute_inception(model):
    size = FLAGS.im_number
    num_steps = size // 1000

    images = []
    test_ims = []

    if FLAGS.dataset == "cifar10":
        test_dataset = Cifar10(FLAGS)
    elif FLAGS.dataset == "celeba":
        test_dataset = CelebAHQ()
    elif FLAGS.dataset == "mnist":
        test_dataset = Mnist(train=True)

    test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size,
                                 num_workers=4, shuffle=True, drop_last=False)

    if FLAGS.dataset == "cifar10":
        for data_corrupt, data, label_gt in tqdm(test_dataloader):
            data = data.numpy()
            test_ims.extend(list(rescale_im(data)))

            if len(test_ims) > 10000:
                break
    elif FLAGS.dataset == "mnist":
        for data_corrupt, data, label_gt in tqdm(test_dataloader):
            data = data.numpy()
            test_ims.extend(list(np.tile(rescale_im(data), (1, 1, 3))))

            if len(test_ims) > 10000:
                break

    test_ims = test_ims[:10000]

    classes = 10
    print(FLAGS.batch_size)

    data_buffer = None

    for j in range(num_steps):
        itr = int(1000 / 500 * FLAGS.repeat_scale)

        if data_buffer is None:
            data_buffer = InceptionReplayBuffer(1000)
        curr_index = 0

        identity = np.eye(classes)

        if FLAGS.dataset == "celeba":
            n = 128
            c = 3
        elif FLAGS.dataset == "mnist":
            n = 28
            c = 1
        else:
            n = 32
            c = 3

        for i in tqdm(range(itr)):
            noise_scale = [1]
            if len(data_buffer) < 1000:
                x_init = np.random.uniform(0, 1, (FLAGS.batch_size, c, n, n))
                label = np.random.randint(0, classes, (FLAGS.batch_size))

                x_init = torch.Tensor(x_init).cuda()
                label = identity[label]
                label = torch.Tensor(label).cuda()

                x_new, _ = gen_image(label, FLAGS, model, x_init, FLAGS.num_steps)
                x_new = x_new.detach().cpu().numpy()
                label = label.detach().cpu().numpy()
                data_buffer.add(x_new, label)
            else:
                if i < itr - FLAGS.nomix:
                    (x_init, label), idx = data_buffer.sample(FLAGS.batch_size,
                                                              transform=FLAGS.transform)
                else:
                    if FLAGS.dataset == "celeba":
                        n = 20
                    else:
                        n = 2

                    ix = i % n
                    # for i in range(n):
                    start_idx = (1000 // n) * ix
                    end_idx = (1000 // n) * (ix + 1)
                    (x_init, label) = data_buffer._encode_sample(
                        list(range(start_idx, end_idx)), transform=False)
                    idx = list(range(start_idx, end_idx))

                x_init = torch.Tensor(x_init).cuda()
                label = torch.Tensor(label).cuda()
                x_new, energy = gen_image(label, FLAGS, model, x_init, FLAGS.num_steps)
                energy = energy.cpu().detach().numpy()
                x_new = x_new.cpu().detach().numpy()
                label = label.cpu().detach().numpy()
                data_buffer.set_elms(idx, x_new, label)

                if FLAGS.im_number != 50000:
                    print(np.mean(energy), np.std(energy))

            curr_index += 1

        ims = np.array(data_buffer._storage[:1000])
        ims = rescale_im(ims).transpose((0, 2, 3, 1))

        if FLAGS.dataset == "mnist":
            ims = np.tile(ims, (1, 1, 1, 3))

        images.extend(list(ims))

    random.shuffle(images)
    saveim = osp.join('sandbox_cachedir', FLAGS.exp, "test{}.png".format(FLAGS.idx))

    if FLAGS.dataset == "cifar10":
        rix = np.random.permutation(1000)[:100]
        ims = ims[rix]
        im_panel = ims.reshape((10, 10, 32, 32, 3)).transpose(
            (0, 2, 1, 3, 4)).reshape((320, 320, 3))
        imsave(saveim, im_panel)
        print("Saved image!!!!")

        splits = max(1, len(images) // 5000)
        score, std = get_inception_score(images, splits=splits)
        print("Inception score of {} with std of {}".format(score, std))

        # FID score
        n = min(len(images), len(test_ims))
        fid = get_fid_score(images, test_ims)
        print("FID of score {}".format(fid))
    elif FLAGS.dataset == "mnist":
        # ims = ims[:100]
        # im_panel = ims.reshape((10, 10, 32, 32, 3)).transpose((0, 2, 1, 3, 4)).reshape((320, 320, 3))
        # imsave(saveim, im_panel)

        ims = ims[:100]
        im_panel = ims.reshape((10, 10, 28, 28, 3)).transpose(
            (0, 2, 1, 3, 4)).reshape((280, 280, 3))
        imsave(saveim, im_panel)
        print("Saved image!!!!")

        splits = max(1, len(images) // 5000)
        # score, std = get_inception_score(images, splits=splits)
        # print("Inception score of {} with std of {}".format(score, std))

        # FID score
        n = min(len(images), len(test_ims))
        fid = get_fid_score(images, test_ims)
        print("FID of score {}".format(fid))
    elif FLAGS.dataset == "celeba":
        ims = ims[:25]
        im_panel = ims.reshape((5, 5, 128, 128, 3)).transpose(
            (0, 2, 1, 3, 4)).reshape((5 * 128, 5 * 128, 3))
        imsave(saveim, im_panel)