def main(): print("Local rank: ", hvd.local_rank(), hvd.size()) logdir = osp.join(FLAGS.logdir, FLAGS.exp) if hvd.rank() == 0: if not osp.exists(logdir): os.makedirs(logdir) logger = TensorBoardOutputFormat(logdir) else: logger = None LABEL = None print("Loading data...") if FLAGS.dataset == 'cifar10': dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale) test_dataset = Cifar10(train=False, rescale=FLAGS.rescale) channel_num = 3 X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32) if FLAGS.large_model: model = ResNet32Large(num_channels=channel_num, num_filters=128, train=True) elif FLAGS.larger_model: model = ResNet32Larger(num_channels=channel_num, num_filters=128) elif FLAGS.wider_model: model = ResNet32Wider(num_channels=channel_num, num_filters=192) else: model = ResNet32(num_channels=channel_num, num_filters=128) elif FLAGS.dataset == 'imagenet': dataset = Imagenet(train=True) test_dataset = Imagenet(train=False) channel_num = 3 X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32) model = ResNet32Wider(num_channels=channel_num, num_filters=256) elif FLAGS.dataset == 'imagenetfull': channel_num = 3 X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32) X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32) LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32) model = ResNet128(num_channels=channel_num, num_filters=64) elif FLAGS.dataset == 'mnist': dataset = Mnist(rescale=FLAGS.rescale) test_dataset = dataset channel_num = 1 X_NOISE = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32) X = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32) LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32) model = MnistNet(num_channels=channel_num, num_filters=FLAGS.num_filters) elif FLAGS.dataset == 'dsprites': dataset = DSprites(cond_shape=FLAGS.cond_shape, cond_size=FLAGS.cond_size, cond_pos=FLAGS.cond_pos, cond_rot=FLAGS.cond_rot) test_dataset = dataset channel_num = 1 X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32) X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32) if FLAGS.dpos_only: LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32) elif FLAGS.dsize_only: LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32) elif FLAGS.drot_only: LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32) elif FLAGS.cond_size: LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32) elif FLAGS.cond_shape: LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32) elif FLAGS.cond_pos: LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32) elif FLAGS.cond_rot: LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32) LABEL_POS = 
tf.placeholder(shape=(None, 2), dtype=tf.float32) else: LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32) model = DspritesNet(num_channels=channel_num, num_filters=FLAGS.num_filters, cond_size=FLAGS.cond_size, cond_shape=FLAGS.cond_shape, cond_pos=FLAGS.cond_pos, cond_rot=FLAGS.cond_rot) print("Done loading...") if FLAGS.dataset == "imagenetfull": # In the case of full imagenet, use custom_tensorflow dataloader data_loader = TFImagenetLoader('train', FLAGS.batch_size, hvd.rank(), hvd.size(), rescale=FLAGS.rescale) else: data_loader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, drop_last=True, shuffle=True) batch_size = FLAGS.batch_size weights = [model.construct_weights('context_0')] Y = tf.placeholder(shape=(None), dtype=tf.int32) # Varibles to run in training X_SPLIT = tf.split(X, FLAGS.num_gpus) X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus) LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus) LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus) LABEL_SPLIT_INIT = list(LABEL_SPLIT) tower_grads = [] tower_gen_grads = [] x_mod_list = [] optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999) optimizer = hvd.DistributedOptimizer(optimizer) for j in range(FLAGS.num_gpus): if FLAGS.model_cclass: ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus label_tensor = tf.Variable(tf.convert_to_tensor(np.reshape( np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)), (FLAGS.batch_size * 10, 10)), dtype=tf.float32), trainable=False, dtype=tf.float32) x_split = tf.tile( tf.reshape(X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)), (1, 10, 1, 1, 1)) x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3)) energy_pos = model.forward(x_split, weights[0], label=label_tensor, stop_at_grad=False) energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10)) energy_partition_est = tf.reduce_logsumexp(energy_pos_full, axis=1, keepdims=True) uniform = tf.random_uniform(tf.shape(energy_pos_full)) label_tensor = tf.argmax(-energy_pos_full - tf.log(-tf.log(uniform)) - energy_partition_est, axis=1) label = tf.one_hot(label_tensor, 10, dtype=tf.float32) label = tf.Print(label, [label_tensor, energy_pos_full]) LABEL_SPLIT[j] = label energy_pos = tf.concat(energy_pos, axis=0) else: energy_pos = [ model.forward(X_SPLIT[j], weights[0], label=LABEL_POS_SPLIT[j], stop_at_grad=False) ] energy_pos = tf.concat(energy_pos, axis=0) print("Building graph...") x_mod = x_orig = X_NOISE_SPLIT[j] x_grads = [] energy_negs = [] loss_energys = [] energy_negs.extend([ model.forward(tf.stop_gradient(x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True) ]) eps_begin = tf.zeros(1) steps = tf.constant(0) c = lambda i, x: tf.less(i, FLAGS.num_steps) def langevin_step(counter, x_mod): x_mod = x_mod + tf.random_normal( tf.shape(x_mod), mean=0.0, stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale) energy_noise = energy_start = tf.concat([ model.forward(x_mod, weights[0], label=LABEL_SPLIT[j], reuse=True, stop_at_grad=False, stop_batch=True) ], axis=0) x_grad, label_grad = tf.gradients(FLAGS.temperature * energy_noise, [x_mod, LABEL_SPLIT[j]]) energy_noise_old = energy_noise lr = FLAGS.step_lr if FLAGS.proj_norm != 0.0: if FLAGS.proj_norm_type == 'l2': x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm) elif FLAGS.proj_norm_type == 'li': x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm) else: print("Other types of projection are not supported!!!") assert False # Clip gradient norm for now if FLAGS.hmc: # Step 
size should be tuned to get around 65% acceptance def energy(x): return FLAGS.temperature * \ model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True) x_last = hmc(x_mod, 15., 10, energy) else: x_last = x_mod - (lr) * x_grad x_mod = x_last x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale) counter = counter + 1 return counter, x_mod steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod)) energy_eval = model.forward(x_mod, weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True) x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0] x_grads.append(x_grad) energy_negs.append( model.forward(tf.stop_gradient(x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)) test_x_mod = x_mod temp = FLAGS.temperature energy_neg = energy_negs[-1] x_off = tf.reduce_mean( tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j])) loss_energy = model.forward(x_mod, weights[0], reuse=True, label=LABEL, stop_grad=True) print("Finished processing loop construction ...") target_vars = {} if FLAGS.cclass or FLAGS.model_cclass: label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0) label_prob = label_sum / tf.reduce_sum(label_sum) label_ent = -tf.reduce_sum( label_prob * tf.math.log(label_prob + 1e-7)) else: label_ent = tf.zeros(1) target_vars['label_ent'] = label_ent if FLAGS.train: if FLAGS.objective == 'logsumexp': pos_term = temp * energy_pos energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg)) coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced)) norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4 pos_loss = tf.reduce_mean(temp * energy_pos) neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss)) elif FLAGS.objective == 'cd': pos_loss = tf.reduce_mean(temp * energy_pos) neg_loss = -tf.reduce_mean(temp * energy_neg) loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss)) elif FLAGS.objective == 'softplus': loss_ml = FLAGS.ml_coeff * \ tf.nn.softplus(temp * (energy_pos - energy_neg)) loss_total = tf.reduce_mean(loss_ml) if not FLAGS.zero_kl: loss_total = loss_total + tf.reduce_mean(loss_energy) loss_total = loss_total + \ FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg)))) print("Started gradient computation...") gvs = optimizer.compute_gradients(loss_total) gvs = [(k, v) for (k, v) in gvs if k is not None] print("Applying gradients...") tower_grads.append(gvs) print("Finished applying gradients.") target_vars['loss_ml'] = loss_ml target_vars['total_loss'] = loss_total target_vars['loss_energy'] = loss_energy target_vars['weights'] = weights target_vars['gvs'] = gvs target_vars['X'] = X target_vars['Y'] = Y target_vars['LABEL'] = LABEL target_vars['LABEL_POS'] = LABEL_POS target_vars['X_NOISE'] = X_NOISE target_vars['energy_pos'] = energy_pos target_vars['energy_start'] = energy_negs[0] if len(x_grads) >= 1: target_vars['x_grad'] = x_grads[-1] target_vars['x_grad_first'] = x_grads[0] else: target_vars['x_grad'] = tf.zeros(1) target_vars['x_grad_first'] = tf.zeros(1) target_vars['x_mod'] = x_mod target_vars['x_off'] = x_off target_vars['temp'] = temp target_vars['energy_neg'] = energy_neg target_vars['test_x_mod'] = test_x_mod target_vars['eps_begin'] = eps_begin if FLAGS.train: grads = average_gradients(tower_grads) train_op = optimizer.apply_gradients(grads) target_vars['train_op'] = train_op config = tf.ConfigProto() if hvd.size() > 1: config.gpu_options.visible_device_list = str(hvd.local_rank()) sess = 
tf.Session(config=config) saver = loader = tf.train.Saver(max_to_keep=30, keep_checkpoint_every_n_hours=6) total_parameters = 0 for variable in tf.trainable_variables(): # shape is an array of tf.Dimension shape = variable.get_shape() variable_parameters = 1 for dim in shape: variable_parameters *= dim.value total_parameters += variable_parameters print("Model has a total of {} parameters".format(total_parameters)) sess.run(tf.global_variables_initializer()) resume_itr = 0 if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0: model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter)) resume_itr = FLAGS.resume_iter # saver.restore(sess, model_file) optimistic_restore(sess, model_file) sess.run(hvd.broadcast_global_variables(0)) print("Initializing variables...") print("Start broadcast") print("End broadcast") if FLAGS.train: print("Training phase") train(target_vars, saver, sess, logger, data_loader, resume_itr, logdir) print("Testing phase") test(target_vars, saver, sess, logger, data_loader)
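
# ---------------------------------------------------------------------------
# Reference sketch (illustration only, not called by the graph above): the
# 'cd' objective built in main() is the contrastive-divergence-style maximum
# likelihood surrogate, mean energy of data minus mean energy of Langevin
# samples, plus an L2 penalty on both energies. The NumPy version below is a
# minimal sketch; `energy_pos` / `energy_neg` are assumed to be 1-D arrays of
# per-example energies, and `ml_coeff`, `l2_coeff`, `temp` mirror the flags of
# the same names but are plain arguments here.
# ---------------------------------------------------------------------------
def _cd_loss_sketch(energy_pos, energy_neg, temp=1.0, ml_coeff=1.0, l2_coeff=1.0):
    # Push down the energy of real data and push up the energy of samples.
    pos_loss = np.mean(temp * energy_pos)
    neg_loss = -np.mean(temp * energy_neg)
    loss_ml = ml_coeff * (pos_loss + neg_loss)
    # L2 regularization keeps the energy magnitudes from drifting.
    loss_l2 = l2_coeff * (np.mean(energy_pos ** 2) + np.mean(energy_neg ** 2))
    return loss_ml + loss_l2
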
def main(model_list):

    if FLAGS.dataset == "imagenetfull":
        model = ResNet128(num_filters=64)
    elif FLAGS.dataset == "imagenet":
        model = ResNet32Wider(num_channels=3, num_filters=256)
    elif FLAGS.large_model:
        model = ResNet32Large(num_filters=128)
    elif FLAGS.larger_model:
        model = ResNet32Larger(num_filters=128)
    elif FLAGS.wider_model:
        model = ResNet32Wider(num_filters=256, train=False)
    else:
        model = ResNet32(num_filters=128)

    # config = tf.ConfigProto()
    sess = tf.InteractiveSession()
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)

    weights = []

    for i, model_num in enumerate(model_list):
        weight = model.construct_weights('context_{}'.format(i))
        initialize()

        save_file = osp.join(logdir, 'model_{}'.format(model_num))

        v_list = tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES,
            scope='context_{}'.format(i))
        v_map = {(v.name.replace('context_{}'.format(i), 'context_energy')[:-2]): v
                 for v in v_list}
        saver = tf.train.Saver(v_map)

        try:
            saver.restore(sess, save_file)
        except BaseException:
            optimistic_remap_restore(sess, save_file, i)

        weights.append(weight)

    if FLAGS.dataset == "imagenetfull":
        X_START = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
    else:
        X_START = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)

    if FLAGS.dataset == "cifar10":
        Y_GT = tf.placeholder(shape=(None, 10), dtype=tf.float32)
    else:
        Y_GT = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

    NOISE_SCALE = tf.placeholder(shape=1, dtype=tf.float32)

    X_finals = []

    # Separate sampling loops, one per set of loaded weights
    for weight in weights:
        X = X_START

        steps = tf.constant(0)
        c = lambda i, x: tf.less(i, FLAGS.num_steps)

        def langevin_step(counter, X):
            scale_rate = 1

            X = X + tf.random_normal(
                tf.shape(X),
                mean=0.0,
                stddev=scale_rate * FLAGS.noise_scale * NOISE_SCALE)

            energy_noise = model.forward(X, weight, label=Y_GT, reuse=True)
            x_grad = tf.gradients(energy_noise, [X])[0]

            if FLAGS.proj_norm != 0.0:
                x_grad = tf.clip_by_value(
                    x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)

            X = X - FLAGS.step_lr * x_grad * scale_rate
            X = tf.clip_by_value(X, 0, 1)

            counter = counter + 1

            return counter, X

        steps, X = tf.while_loop(c, langevin_step, (steps, X))

        energy_noise = model.forward(X, weight, label=Y_GT, reuse=True)
        X_final = X
        X_finals.append(X_final)

    target_vars = {}
    target_vars['X_START'] = X_START
    target_vars['Y_GT'] = Y_GT
    target_vars['X_finals'] = X_finals
    target_vars['NOISE_SCALE'] = NOISE_SCALE
    target_vars['energy_noise'] = energy_noise

    compute_inception(sess, target_vars)
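
# ---------------------------------------------------------------------------
# Reference sketch (illustration only): each langevin_step above is one step
# of noisy gradient descent on the model energy followed by clipping to the
# valid pixel range. The NumPy version below assumes a caller-supplied
# `energy_grad(x)` returning dE/dx; `step_lr`, `noise_std`, and `num_steps`
# stand in for FLAGS.step_lr, the noise stddev, and FLAGS.num_steps and are
# not flags or helpers from this repo.
# ---------------------------------------------------------------------------
def _langevin_sample_sketch(x, energy_grad, step_lr, noise_std, num_steps, clip_max=1.0):
    for _ in range(num_steps):
        # Inject Gaussian noise, then move downhill in energy.
        x = x + np.random.normal(0.0, noise_std, size=x.shape)
        x = x - step_lr * energy_grad(x)
        # Keep samples inside the data range, mirroring tf.clip_by_value.
        x = np.clip(x, 0.0, clip_max)
    return x
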
def main(): if FLAGS.dataset == "cifar10": dataset = Cifar10(train=True, noise=False) test_dataset = Cifar10(train=False, noise=False) else: dataset = Imagenet(train=True) test_dataset = Imagenet(train=False) if FLAGS.svhn: dataset = Svhn(train=True) test_dataset = Svhn(train=False) if FLAGS.task == 'latent': dataset = DSprites() test_dataset = dataset dataloader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=True) test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=True) hidden_dim = 128 if FLAGS.large_model: model = ResNet32Large(num_filters=hidden_dim) elif FLAGS.larger_model: model = ResNet32Larger(num_filters=hidden_dim) elif FLAGS.wider_model: if FLAGS.dataset == 'imagenet': model = ResNet32Wider(num_filters=196, train=False) else: model = ResNet32Wider(num_filters=256, train=False) else: model = ResNet32(num_filters=hidden_dim) if FLAGS.task == 'latent': model = DspritesNet() weights = model.construct_weights('context_{}'.format(0)) total_parameters = 0 for variable in tf.compat.v1.trainable_variables(): # shape is an array of tf.Dimension shape = variable.get_shape() variable_parameters = 1 for dim in shape: variable_parameters *= dim.value total_parameters += variable_parameters print("Model has a total of {} parameters".format(total_parameters)) config = tf.compat.v1.ConfigProto() sess = tf.compat.v1.InteractiveSession() if FLAGS.task == 'latent': X = tf.compat.v1.placeholder(shape=(None, 64, 64), dtype=tf.float32) else: X = tf.compat.v1.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) if FLAGS.dataset == "cifar10": Y = tf.compat.v1.placeholder(shape=(None, 10), dtype=tf.float32) Y_GT = tf.compat.v1.placeholder(shape=(None, 10), dtype=tf.float32) elif FLAGS.dataset == "imagenet": Y = tf.compat.v1.placeholder(shape=(None, 1000), dtype=tf.float32) Y_GT = tf.compat.v1.placeholder(shape=(None, 1000), dtype=tf.float32) target_vars = {'X': X, 'Y': Y, 'Y_GT': Y_GT} if FLAGS.task == 'label': construct_label(weights, X, Y, Y_GT, model, target_vars) elif FLAGS.task == 'labelfinetune': construct_finetune_label( weights, X, Y, Y_GT, model, target_vars, ) elif FLAGS.task == 'energyeval' or FLAGS.task == 'mixenergy': construct_energy(weights, X, Y, Y_GT, model, target_vars) elif FLAGS.task == 'anticorrupt' or FLAGS.task == 'boxcorrupt' or FLAGS.task == 'crossclass' or FLAGS.task == 'cycleclass' or FLAGS.task == 'democlass' or FLAGS.task == 'nearestneighbor': construct_steps(weights, X, Y_GT, model, target_vars) elif FLAGS.task == 'latent': construct_latent(weights, X, Y_GT, model, target_vars) sess.run(tf.compat.v1.global_variables_initializer()) saver = loader = tf.compat.v1.train.Saver(max_to_keep=10) savedir = osp.join('cachedir', FLAGS.exp) logdir = osp.join(FLAGS.logdir, FLAGS.exp) if not osp.exists(logdir): os.makedirs(logdir) initialize() if FLAGS.resume_iter != -1: model_file = osp.join(savedir, 'model_{}'.format(FLAGS.resume_iter)) resume_itr = FLAGS.resume_iter if FLAGS.task == 'label' or FLAGS.task == 'boxcorrupt' or FLAGS.task == 'labelfinetune' or FLAGS.task == "energyeval" or FLAGS.task == "crossclass" or FLAGS.task == "mixenergy": optimistic_restore(sess, model_file) # saver.restore(sess, model_file) else: # optimistic_restore(sess, model_file) saver.restore(sess, model_file) if FLAGS.task == 'label': if FLAGS.labelgrid: vals = [] if FLAGS.lnorm == -1: for i in range(31): accuracies = label(dataloader, test_dataloader, target_vars, sess, 
l1val=i) vals.append(accuracies) elif FLAGS.lnorm == 2: for i in range(0, 100, 5): accuracies = label(dataloader, test_dataloader, target_vars, sess, l2val=i) vals.append(accuracies) np.save("result_{}_{}.npy".format(FLAGS.lnorm, FLAGS.exp), vals) else: label(dataloader, test_dataloader, target_vars, sess) elif FLAGS.task == 'labelfinetune': labelfinetune(dataloader, test_dataloader, target_vars, sess, savedir, saver, l1val=FLAGS.lival, l2val=FLAGS.l2val) elif FLAGS.task == 'energyeval': energyeval(dataloader, test_dataloader, target_vars, sess) elif FLAGS.task == 'mixenergy': energyevalmix(dataloader, test_dataloader, target_vars, sess) elif FLAGS.task == 'anticorrupt': anticorrupt(test_dataloader, weights, model, target_vars, logdir, sess) elif FLAGS.task == 'boxcorrupt': # boxcorrupt(test_dataloader, weights, model, target_vars, logdir, sess) boxcorrupt(test_dataloader, dataloader, weights, model, target_vars, logdir, sess) elif FLAGS.task == 'crossclass': crossclass(test_dataloader, weights, model, target_vars, logdir, sess) elif FLAGS.task == 'cycleclass': cycleclass(test_dataloader, weights, model, target_vars, logdir, sess) elif FLAGS.task == 'democlass': democlass(test_dataloader, weights, model, target_vars, logdir, sess) elif FLAGS.task == 'nearestneighbor': # print(dir(dataset)) # print(type(dataset)) nearest_neighbor(dataset.data.train_data / 255, sess, target_vars, logdir) elif FLAGS.task == 'latent': latent(test_dataloader, weights, model, target_vars, sess)
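
# ---------------------------------------------------------------------------
# Reference sketch (illustration only): construct_label() and label() are
# defined elsewhere in this file, but the 'label' task amounts to scoring an
# image under every candidate one-hot label and predicting the label with the
# lowest energy, i.e. p(y|x) proportional to exp(-E(x, y)). The version below
# is a minimal sketch of that idea, not the repo's implementation; the
# caller-supplied `energy_fn(x, y_onehot)` returning a scalar is assumed.
# ---------------------------------------------------------------------------
def _energy_classify_sketch(x, energy_fn, num_classes=10):
    energies = np.array(
        [energy_fn(x, np.eye(num_classes)[k]) for k in range(num_classes)])
    # Softmax over negative energies gives class probabilities; the argmin of
    # the energies is the predicted label.
    logits = -energies
    probs = np.exp(logits - np.max(logits))
    probs = probs / probs.sum()
    return int(np.argmin(energies)), probs
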
def main():
    # Initialize dataset
    if FLAGS.dataset == 'cifar10':
        dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3
        dim_input = 32 * 32 * 3
    elif FLAGS.dataset == 'imagenet':
        dataset = ImagenetClass()
        channel_num = 3
        dim_input = 64 * 64 * 3
    elif FLAGS.dataset == 'mnist':
        dataset = Mnist(train=False, rescale=FLAGS.rescale)
        channel_num = 1
        dim_input = 28 * 28 * 1
    elif FLAGS.dataset == 'dsprites':
        dataset = DSprites()
        channel_num = 1
        dim_input = 64 * 64 * 1
    elif FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
        dataset = Box2D()
        dim_output = 1

    data_loader = DataLoader(
        dataset,
        batch_size=FLAGS.batch_size,
        num_workers=FLAGS.data_workers,
        drop_last=False,
        shuffle=True)

    if FLAGS.dataset == 'mnist':
        model = MnistNet(num_channels=channel_num)
    elif FLAGS.dataset == 'cifar10':
        if FLAGS.large_model:
            model = ResNet32Large(num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(num_filters=192)
        else:
            model = ResNet32(num_channels=channel_num, num_filters=128)
    elif FLAGS.dataset == 'dsprites':
        model = DspritesNet(
            num_channels=channel_num, num_filters=FLAGS.num_filters)

    weights = model.construct_weights('context_{}'.format(0))

    config = tf.ConfigProto()
    sess = tf.Session(config=config)
    saver = loader = tf.train.Saver(max_to_keep=10)

    sess.run(tf.global_variables_initializer())
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
    resume_itr = FLAGS.resume_iter

    if FLAGS.resume_iter != "-1":
        optimistic_restore(sess, model_file)
    else:
        print("WARNING, YOU ARE NOT LOADING A SAVE FILE")
    # saver.restore(sess, model_file)

    chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(
        model, weights, FLAGS.batch_size, temp=FLAGS.temperature)
    print("Finished constructing ancestral sample ...................")

    if FLAGS.dataset != "gauss":
        comb_weights_cum = []
        batch_size = tf.shape(x_init)[0]
        label_tiled = tf.tile(label_default, (batch_size, 1))
        e_compute = -FLAGS.temperature * model.forward(
            x_init, weights, label=label_tiled)
        e_pos_list = []

        for data_corrupt, data, label_gt in tqdm(data_loader):
            e_pos = sess.run([e_compute], {x_init: data})[0]
            e_pos_list.extend(list(e_pos))

        print(len(e_pos_list))
        print("Positive sample probability ",
              np.mean(e_pos_list), np.std(e_pos_list))

    if FLAGS.dataset == "2d":
        alr = 0.0045
    elif FLAGS.dataset == "gauss":
        alr = 0.0085
    elif FLAGS.dataset == "mnist":
        alr = 0.0065
        # 90 alr = 0.0035
    else:
        # alr = 0.0125
        if FLAGS.rescale == 8:
            alr = 0.0085
        else:
            alr = 0.0045

    # for i in range(1):
    tot_weight = 0
    for j in tqdm(range(1, FLAGS.pdist + 1)):
        if j == 1:
            if FLAGS.dataset == "cifar10":
                x_curr = np.random.uniform(
                    0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3))
            elif FLAGS.dataset == "gauss":
                x_curr = np.random.uniform(
                    0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim))
            elif FLAGS.dataset == "mnist":
                x_curr = np.random.uniform(
                    0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28))
            else:
                x_curr = np.random.uniform(
                    0, FLAGS.rescale, size=(FLAGS.batch_size, 2))

        alpha_prev = (j - 1) / FLAGS.pdist
        alpha_new = j / FLAGS.pdist
        cweight, x_curr = sess.run(
            [chain_weights, x],
            {a_prev: alpha_prev,
             a_new: alpha_new,
             x_init: x_curr,
             approx_lr: alr * (5 ** (2.5 * -alpha_prev))})
        tot_weight = tot_weight + cweight

    print("Total values of lower value based off forward sampling",
          np.mean(tot_weight), np.std(tot_weight))

    tot_weight = 0

    for j in tqdm(range(FLAGS.pdist, 0, -1)):
        alpha_new = (j - 1) / FLAGS.pdist
        alpha_prev = j / FLAGS.pdist
        cweight, x_curr = sess.run(
            [chain_weights, x],
            {a_prev: alpha_prev,
             a_new: alpha_new,
             x_init: x_curr,
             approx_lr: alr * (5 ** (2.5 * -alpha_prev))})
        tot_weight = tot_weight - cweight

    print("Total values of upper value based off backward sampling",
          np.mean(tot_weight), np.std(tot_weight))
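
# ---------------------------------------------------------------------------
# Reference sketch (illustration only): ancestral_sample() is defined
# elsewhere, but the two loops above follow the usual AIS recipe, annealing an
# interpolation parameter alpha from 0 to 1 (forward pass, lower bound) and
# from 1 to 0 (backward pass, upper bound) while summing per-step log-weights.
# Assuming intermediate densities of the form f_alpha(x) = exp(-alpha * E(x)),
# the per-step log-weight is (alpha_new - alpha_prev) * (-E(x)). The sketch
# below uses a caller-supplied `energy_fn` and a `transition(x, alpha)` MCMC
# kernel; both names are placeholders, not functions from this repo.
# ---------------------------------------------------------------------------
def _ais_log_weight_sketch(x, energy_fn, transition, num_dists):
    log_w = np.zeros(x.shape[0])
    for j in range(1, num_dists + 1):
        alpha_prev = (j - 1) / num_dists
        alpha_new = j / num_dists
        # Accumulate the log importance weight for this annealing step ...
        log_w += (alpha_new - alpha_prev) * (-energy_fn(x))
        # ... then move the chain under the new intermediate distribution.
        x = transition(x, alpha_new)
    return log_w
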
def main(): print("Local rank: ", hvd.local_rank(), hvd.size()) FLAGS.exp = FLAGS.exp + '_' + FLAGS.divergence logdir = osp.join(FLAGS.logdir, FLAGS.exp) if hvd.rank() == 0: if not osp.exists(logdir): os.makedirs(logdir) logger = TensorBoardOutputFormat(logdir) else: logger = None print("Loading data...") dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale) test_dataset = Cifar10(train=False, rescale=FLAGS.rescale) channel_num = 3 X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32) LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32) if FLAGS.large_model: model = ResNet32Large( num_channels=channel_num, num_filters=128, train=True) model_dis = ResNet32Large( num_channels=channel_num, num_filters=128, train=True) elif FLAGS.larger_model: model = ResNet32Larger( num_channels=channel_num, num_filters=128) model_dis = ResNet32Larger( num_channels=channel_num, num_filters=128) elif FLAGS.wider_model: model = ResNet32Wider( num_channels=channel_num, num_filters=256) model_dis = ResNet32Wider( num_channels=channel_num, num_filters=256) else: model = ResNet32( num_channels=channel_num, num_filters=128) model_dis = ResNet32( num_channels=channel_num, num_filters=128) print("Done loading...") grad_exp, conjugate_grad_exp = get_divergence_funcs(FLAGS.divergence) data_loader = DataLoader( dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, drop_last=True, shuffle=True) weights = [model.construct_weights('context_energy'), model_dis.construct_weights('context_dis')] Y = tf.placeholder(shape=(None), dtype=tf.int32) # Varibles to run in training X_SPLIT = tf.split(X, FLAGS.num_gpus) X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus) LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus) LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus) LABEL_SPLIT_INIT = list(LABEL_SPLIT) tower_grads = [] tower_grads_dis = [] tower_grads_l2 = [] tower_grads_dis_l2 = [] optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999) optimizer = hvd.DistributedOptimizer(optimizer) optimizer_dis = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999) optimizer_dis = hvd.DistributedOptimizer(optimizer_dis) for j in range(FLAGS.num_gpus): energy_pos = [ model.forward( X_SPLIT[j], weights[0], label=LABEL_POS_SPLIT[j], stop_at_grad=False)] energy_pos = tf.concat(energy_pos, axis=0) score_pos = [ model_dis.forward( X_SPLIT[j], weights[1], label=LABEL_POS_SPLIT[j], stop_at_grad=False)] score_pos = tf.concat(score_pos, axis=0) print("Building graph...") x_mod = x_orig = X_NOISE_SPLIT[j] x_grads = [] energy_negs = [] loss_energys = [] energy_negs.extend([model.forward(tf.stop_gradient( x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)]) eps_begin = tf.zeros(1) steps = tf.constant(0) c = lambda i, x: tf.less(i, FLAGS.num_steps) def langevin_step(counter, x_mod): x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale) energy_noise = energy_start = tf.concat( [model.forward( x_mod, weights[0], label=LABEL_SPLIT[j], reuse=True, stop_at_grad=False, stop_batch=True)], axis=0) x_grad, label_grad = tf.gradients(energy_noise, [x_mod, LABEL_SPLIT[j]]) energy_noise_old = energy_noise lr = FLAGS.step_lr if FLAGS.proj_norm != 0.0: if FLAGS.proj_norm_type == 'l2': x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm) elif FLAGS.proj_norm_type == 'li': x_grad = tf.clip_by_value( x_grad, -FLAGS.proj_norm, 
FLAGS.proj_norm) else: print("Other types of projection are not supported!!!") assert False # Clip gradient norm for now if FLAGS.hmc: # Step size should be tuned to get around 65% acceptance def energy(x): return FLAGS.temperature * \ model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True) x_last = hmc(x_mod, 15., 10, energy) else: x_last = x_mod - (lr) * x_grad x_mod = x_last x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale) counter = counter + 1 return counter, x_mod steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod)) energy_eval = model.forward(x_mod, weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True) x_grad = tf.gradients(energy_eval, [x_mod])[0] x_grads.append(x_grad) energy_negs.append( model.forward( tf.stop_gradient(x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)) score_neg = model_dis.forward( tf.stop_gradient(x_mod), weights[1], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True) test_x_mod = x_mod temp = FLAGS.temperature energy_neg = energy_negs[-1] x_off = tf.reduce_mean( tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j])) loss_energy = model.forward( x_mod, weights[0], reuse=True, label=LABEL, stop_grad=True) print("Finished processing loop construction ...") target_vars = {} if FLAGS.cclass or FLAGS.model_cclass: label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0) label_prob = label_sum / tf.reduce_sum(label_sum) label_ent = -tf.reduce_sum(label_prob * tf.math.log(label_prob + 1e-7)) else: label_ent = tf.zeros(1) target_vars['label_ent'] = label_ent if FLAGS.train: loss_dis = - (tf.reduce_mean(grad_exp(score_pos + energy_pos)) - tf.reduce_mean(conjugate_grad_exp(score_neg + energy_neg))) loss_dis = loss_dis + FLAGS.l2_coeff * (tf.reduce_mean(tf.square(score_pos)) + tf.reduce_mean(tf.square(score_neg))) l2_dis = FLAGS.l2_coeff * (tf.reduce_mean(tf.square(score_pos)) + tf.reduce_mean(tf.square(score_neg))) loss_model = tf.reduce_mean(grad_exp(score_pos + energy_pos)) + \ tf.reduce_mean(energy_neg * tf.stop_gradient(conjugate_grad_exp(score_neg + energy_neg))) - \ tf.reduce_mean(energy_neg) * tf.stop_gradient(tf.reduce_mean(conjugate_grad_exp(score_neg + energy_neg))) - \ tf.reduce_mean(conjugate_grad_exp(score_neg + energy_neg)) loss_model = loss_model + FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg))) l2_model = FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg))) print("Started gradient computation...") model_vars = [var for var in tf.trainable_variables() if 'context_energy' in var.name] print("model var number", len(model_vars)) dis_vars = [var for var in tf.trainable_variables() if 'context_dis' in var.name] print("discriminator var number", len(dis_vars)) gvs = optimizer.compute_gradients(loss_model, model_vars) gvs = [(k, v) for (k, v) in gvs if k is not None] tower_grads.append(gvs) gvs = optimizer.compute_gradients(l2_model, model_vars) gvs = [(k, v) for (k, v) in gvs if k is not None] tower_grads_l2.append(gvs) gvs_dis = optimizer_dis.compute_gradients(loss_dis, dis_vars) gvs_dis = [(k, v) for (k, v) in gvs_dis if k is not None] tower_grads_dis.append(gvs_dis) gvs_dis = optimizer_dis.compute_gradients(l2_dis, dis_vars) gvs_dis = [(k, v) for (k, v) in gvs_dis if k is not None] tower_grads_dis_l2.append(gvs_dis) print("Finished applying gradients.") target_vars['total_loss'] = loss_model target_vars['loss_energy'] = loss_energy target_vars['weights'] = weights target_vars['gvs'] = gvs target_vars['X'] = X 
target_vars['Y'] = Y target_vars['LABEL'] = LABEL target_vars['LABEL_POS'] = LABEL_POS target_vars['X_NOISE'] = X_NOISE target_vars['energy_pos'] = energy_pos target_vars['energy_start'] = energy_negs[0] if len(x_grads) >= 1: target_vars['x_grad'] = x_grads[-1] target_vars['x_grad_first'] = x_grads[0] else: target_vars['x_grad'] = tf.zeros(1) target_vars['x_grad_first'] = tf.zeros(1) target_vars['x_mod'] = x_mod target_vars['x_off'] = x_off target_vars['temp'] = temp target_vars['energy_neg'] = energy_neg target_vars['test_x_mod'] = test_x_mod target_vars['eps_begin'] = eps_begin target_vars['score_neg'] = score_neg target_vars['score_pos'] = score_pos if FLAGS.train: grads_model = average_gradients(tower_grads) train_op_model = optimizer.apply_gradients(grads_model) target_vars['train_op_model'] = train_op_model grads_model_l2 = average_gradients(tower_grads_l2) train_op_model_l2 = optimizer.apply_gradients(grads_model_l2) target_vars['train_op_model_l2'] = train_op_model_l2 grads_model_dis = average_gradients(tower_grads_dis) train_op_dis = optimizer_dis.apply_gradients(grads_model_dis) target_vars['train_op_dis'] = train_op_dis grads_model_dis_l2 = average_gradients(tower_grads_dis_l2) train_op_dis_l2 = optimizer_dis.apply_gradients(grads_model_dis_l2) target_vars['train_op_dis_l2'] = train_op_dis_l2 config = tf.ConfigProto() if hvd.size() > 1: config.gpu_options.visible_device_list = str(hvd.local_rank()) sess = tf.Session(config=config) saver = loader = tf.train.Saver(max_to_keep=500) total_parameters = 0 for variable in tf.trainable_variables(): # shape is an array of tf.Dimension shape = variable.get_shape() variable_parameters = 1 for dim in shape: variable_parameters *= dim.value total_parameters += variable_parameters print("Model has a total of {} parameters".format(total_parameters)) sess.run(tf.global_variables_initializer()) resume_itr = 0 if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0: model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter)) resume_itr = FLAGS.resume_iter saver.restore(sess, model_file) # optimistic_restore(sess, model_file) sess.run(hvd.broadcast_global_variables(0)) print("Initializing variables...") print("Start broadcast") print("End broadcast") if FLAGS.train: train(target_vars, saver, sess, logger, data_loader, resume_itr, logdir) test(target_vars, saver, sess, logger, data_loader)
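
# ---------------------------------------------------------------------------
# Reference sketch (illustration only): loss_dis above has the form of the
# f-GAN-style variational objective, E_data[grad_exp(T(x))] -
# E_model[conjugate_grad_exp(T(x))], negated for minimization, where
# T(x) = score(x) + energy(x) and the (grad_exp, conjugate_grad_exp) pair
# comes from get_divergence_funcs for the chosen divergence. The NumPy version
# below is a minimal sketch that takes those two functions and the per-example
# statistics as plain arrays; `l2_coeff` mirrors FLAGS.l2_coeff.
# ---------------------------------------------------------------------------
def _f_divergence_dis_loss_sketch(score_pos, energy_pos, score_neg, energy_neg,
                                  grad_exp, conjugate_grad_exp, l2_coeff=0.0):
    t_pos = score_pos + energy_pos
    t_neg = score_neg + energy_neg
    # Variational bound on the f-divergence, negated so it can be minimized.
    loss = -(np.mean(grad_exp(t_pos)) - np.mean(conjugate_grad_exp(t_neg)))
    # L2 penalty on the discriminator scores, as in loss_dis above.
    loss += l2_coeff * (np.mean(score_pos ** 2) + np.mean(score_neg ** 2))
    return loss
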