def construct_ff_model(model, weights, X_NOISE, X, ACTION_LABEL, ACTION_NOISE_LABEL, optimizer):
    """Build the graph for the feed-forward (non-EBM) dynamics baseline.

    The model directly regresses the next latent state from the current one
    with an MSE loss.  An auxiliary inverse-dynamics network is built as well
    and (optionally) trained jointly with its own Adam optimizer to recover
    the action from a state pair.

    Args:
        model: forward-dynamics network exposing
            ``forward(x, weights, action_label=...)``.
        weights: dict of model weights; extended with the inverse-dynamics
            weights under scope ``"inverse_dynamics"``.
        X_NOISE: placeholder, unused here; exported for interface parity with
            the EBM ``construct_model``.
        X: trajectory placeholder — ``X[:, 0, 0]`` is read as the current
            latent state and ``X[:, 1, 0]`` as the next one.
            (assumes shape (batch, frame, object, latent_dim) — TODO confirm)
        ACTION_LABEL: placeholder with the ground-truth action.
        ACTION_NOISE_LABEL: placeholder, exported only.
        optimizer: optimizer used for the forward-prediction loss.

    Returns:
        dict of tensors/ops consumed by the training loop.  Keys that only
        have meaning for the EBM variant (energies, MCMC gradients, ...) are
        filled with zero tensors so callers can fetch them unconditionally.
    """
    target_vars = {}

    # One-step forward prediction: frame 0 -> frame 1, MSE objective.
    x_pred = model.forward(X[:, 0, 0], weights, action_label=ACTION_LABEL)
    loss_total = tf.reduce_mean(tf.square(x_pred - X[:, 1, 0]))

    # Inverse dynamics: predict the action taken between the two frames.
    dyn_model = TrajInverseDynamics(dim_input=FLAGS.latent_dim, dim_output=FLAGS.action_dim)
    weights = dyn_model.construct_weights(scope="inverse_dynamics", weights=weights)
    output_action = dyn_model.forward(X, weights)
    dyn_loss = tf.reduce_mean(tf.square(output_action - ACTION_LABEL))
    dyn_dist = tf.reduce_mean(tf.abs(output_action - ACTION_LABEL))
    target_vars['dyn_loss'] = dyn_loss
    target_vars['dyn_dist'] = dyn_dist

    # Separate optimizer so the inverse-dynamics learning rate is independent
    # of the forward-model optimizer passed in by the caller.
    dyn_optimizer = AdamOptimizer(1e-3)
    gvs = dyn_optimizer.compute_gradients(dyn_loss)
    dyn_train_op = dyn_optimizer.apply_gradients(gvs)

    gvs = optimizer.compute_gradients(loss_total)
    # Drop variables that receive no gradient, then clip the rest for
    # numerical safety before applying.
    gvs = [(k, v) for (k, v) in gvs if k is not None]
    print("Applying gradients...")
    gvs = [(tf.clip_by_value(grad, -1e5, 1e5), var) for grad, var in gvs]
    train_op = optimizer.apply_gradients(gvs)

    if not FLAGS.gt_inverse_dynamics:
        # Learned inverse dynamics: step it together with the forward model.
        train_op = tf.group(train_op, dyn_train_op)

    target_vars['train_op'] = train_op
    target_vars['loss_ml'] = tf.zeros(1)
    target_vars['loss_total'] = loss_total
    target_vars['gvs'] = gvs
    target_vars['loss_energy'] = tf.zeros(1)
    target_vars['weights'] = weights
    target_vars['X'] = X
    target_vars['X_NOISE'] = X_NOISE
    target_vars['energy_pos'] = tf.zeros(1)
    target_vars['energy_neg'] = tf.zeros(1)
    target_vars['x_grad'] = tf.zeros(1)
    target_vars['action_grad'] = tf.zeros(1)
    target_vars['x_mod'] = tf.zeros(1)
    target_vars['x_off'] = tf.zeros(1)
    target_vars['temp'] = FLAGS.temperature
    target_vars['ACTION_LABEL'] = ACTION_LABEL
    target_vars['ACTION_NOISE_LABEL'] = ACTION_NOISE_LABEL
    return target_vars
def main():
    """Build and run the compositional EBM training graph for the
    cubes-family datasets (cubes/color/pos/pairs/continual/cross/celeba).

    Selects a dataset and matching conditioning placeholders from FLAGS,
    constructs either a feed-forward generator baseline (joint_baseline) or
    the EBM with Langevin-dynamics negative sampling (optionally with a
    learned attention mask), then trains and/or tests.
    """
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    logger = TensorBoardOutputFormat(logdir)

    config = tf.ConfigProto()
    sess = tf.Session(config=config)
    LABEL = None
    print("Loading data...")

    # --- Dataset selection: each branch sets dataset, label placeholders
    # and label_size to match that dataset's conditioning vector. ---
    if FLAGS.dataset == 'cubes':
        dataset = Cubes(cond_idx=FLAGS.cond_idx)
        test_dataset = dataset

        # cond_idx picks which attribute is conditioned on; label width varies.
        if FLAGS.cond_idx == 0:
            label_size = 2
        elif FLAGS.cond_idx == 1:
            label_size = 1
        elif FLAGS.cond_idx == 2:
            label_size = 3
        elif FLAGS.cond_idx == 3:
            label_size = 20

        LABEL = tf.placeholder(shape=(None, label_size), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, label_size), dtype=tf.float32)
    elif FLAGS.dataset == 'color':
        dataset = CubesColor()
        test_dataset = dataset
        LABEL = tf.placeholder(shape=(None, 301), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 301), dtype=tf.float32)
        label_size = 301
    elif FLAGS.dataset == 'pos':
        dataset = CubesPos()
        test_dataset = dataset
        LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        label_size = 2
    elif FLAGS.dataset == "pairs":
        dataset = Pairs(cond_idx=0)
        test_dataset = dataset
        LABEL = tf.placeholder(shape=(None, 6), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 6), dtype=tf.float32)
        label_size = 6
    elif FLAGS.dataset == "continual":
        dataset = CubesContinual()
        test_dataset = dataset

        if FLAGS.prelearn_model_shape:
            LABEL = tf.placeholder(shape=(None, 20), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 20), dtype=tf.float32)
            label_size = 20
        else:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            label_size = 2
    elif FLAGS.dataset == "cross":
        dataset = CubesCrossProduct(FLAGS.ratio, cond_size=FLAGS.cond_size, cond_pos=FLAGS.cond_pos, joint_baseline=FLAGS.joint_baseline)
        test_dataset = dataset

        # NOTE(review): if neither cond_size nor cond_pos nor joint_baseline
        # is set, label_size stays undefined and the code below would raise —
        # presumably callers always set one of them; verify.
        if FLAGS.cond_size:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            label_size = 1
        elif FLAGS.cond_pos:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            label_size = 2

        if FLAGS.joint_baseline:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            label_size = 3
    elif FLAGS.dataset == 'celeba':
        dataset = CelebA(cond_idx=FLAGS.celeba_cond_idx)
        test_dataset = dataset
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        model = ResNet128(num_channels=channel_num, num_filters=64, classes=2)

    if FLAGS.joint_baseline:
        # --- Baseline: a plain generator trained with MSE, no EBM sampling. ---
        # Other stuff for joint model
        optimizer = AdamOptimizer(FLAGS.lr, beta1=0.99, beta2=0.999)

        X = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        X_NOISE = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        ATTENTION_MASK = tf.placeholder(shape=(None, 64, 64, FLAGS.cond_func), dtype=tf.float32)
        NOISE = tf.placeholder(shape=(None, 128), dtype=tf.float32)
        HIER_LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)

        channel_num = 3
        model = CubesNetGen(num_channels=channel_num, label_size=label_size)
        weights = model.construct_weights('context_0')
        output = model.forward(NOISE, weights, reuse=False, label=LABEL)
        print(output.get_shape())
        mse_loss = tf.reduce_mean(tf.square(output - X))
        gvs = optimizer.compute_gradients(mse_loss)
        train_op = optimizer.apply_gradients(gvs)
        gvs = [(k, v) for (k, v) in gvs if k is not None]

        # Many of these entries are placeholders (zeros) so the shared train
        # loop can fetch the same keys for both the baseline and the EBM.
        target_vars = {}
        target_vars['train_op'] = train_op
        target_vars['X'] = X
        target_vars['X_NOISE'] = X_NOISE
        target_vars['ATTENTION_MASK'] = ATTENTION_MASK
        target_vars['eps_begin'] = tf.zeros(1)
        target_vars['gvs'] = gvs
        target_vars['energy_pos'] = tf.zeros(1)
        target_vars['energy_neg'] = tf.zeros(1)
        target_vars['loss_energy'] = tf.zeros(1)
        target_vars['loss_ml'] = tf.zeros(1)
        target_vars['total_loss'] = mse_loss
        target_vars['attention_mask'] = tf.zeros(1)
        target_vars['attention_grad'] = tf.zeros(1)
        target_vars['x_off'] = tf.reduce_mean(tf.abs(output - X))
        target_vars['x_mod'] = tf.zeros(1)
        target_vars['x_grad'] = tf.zeros(1)
        target_vars['NOISE'] = NOISE
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['HIER_LABEL'] = HIER_LABEL

        data_loader = DataLoader(
            dataset,
            batch_size=FLAGS.batch_size,
            num_workers=FLAGS.data_workers,
            drop_last=True,
            shuffle=True)
    else:
        # --- EBM path: energy model + Langevin-dynamics negative sampling. ---
        print("label size here ", label_size)
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        HEIR_LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        ATTENTION_MASK = tf.placeholder(shape=(None, 64, 64, FLAGS.cond_func), dtype=tf.float32)

        if FLAGS.dataset != "celeba":
            model = CubesNet(num_channels=channel_num, label_size=label_size)
        heir_model = HeirNet(num_channels=FLAGS.cond_func)

        # Optionally restore previously-learned concept models; each is loaded
        # under its own variable scope (context_1, context_2) from a checkpoint
        # that was saved under context_0, hence the name remapping below.
        models_pretrain = []
        if FLAGS.prelearn_model:
            model_prelearn = CubesNet(num_channels=channel_num, label_size=FLAGS.prelearn_label)
            weights = model_prelearn.construct_weights('context_1')
            LABEL_PRELEARN = tf.placeholder(shape=(None, FLAGS.prelearn_label), dtype=tf.float32)
            models_pretrain.append((model_prelearn, weights, LABEL_PRELEARN))

            cubes_logdir = osp.join(FLAGS.logdir, FLAGS.prelearn_exp)
            if (FLAGS.prelearn_iter != -1 or not FLAGS.train):
                model_file = osp.join(cubes_logdir, 'model_{}'.format(FLAGS.prelearn_iter))
                resume_itr = FLAGS.resume_iter
                # saver.restore(sess, model_file)

                v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(1))
                # Map checkpoint names (context_0) onto this graph's context_1
                # variables; [:-2] strips the ':0' tensor suffix.
                v_map = {(v.name.replace('context_{}'.format(1), 'context_0')[:-2]): v for v in v_list}
                saver = tf.train.Saver(v_map)
                saver.restore(sess, model_file)

        if FLAGS.prelearn_model_shape:
            model_prelearn = CubesNet(num_channels=channel_num, label_size=FLAGS.prelearn_label_shape)
            weights = model_prelearn.construct_weights('context_2')
            LABEL_PRELEARN = tf.placeholder(shape=(None, FLAGS.prelearn_label_shape), dtype=tf.float32)
            models_pretrain.append((model_prelearn, weights, LABEL_PRELEARN))

            cubes_logdir = osp.join(FLAGS.logdir, FLAGS.prelearn_exp_shape)
            if (FLAGS.prelearn_iter_shape != -1 or not FLAGS.train):
                model_file = osp.join(cubes_logdir, 'model_{}'.format(FLAGS.prelearn_iter_shape))
                resume_itr = FLAGS.resume_iter
                # saver.restore(sess, model_file)

                v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(2))
                v_map = {(v.name.replace('context_{}'.format(2), 'context_0')[:-2]): v for v in v_list}
                saver = tf.train.Saver(v_map)
                saver.restore(sess, model_file)

        print("Done loading...")

        data_loader = DataLoader(
            dataset,
            batch_size=FLAGS.batch_size,
            num_workers=FLAGS.data_workers,
            drop_last=True,
            shuffle=True)

        batch_size = FLAGS.batch_size

        weights = model.construct_weights('context_0')
        if FLAGS.heir_mask:
            weights = heir_model.construct_weights('heir_0', weights=weights)

        Y = tf.placeholder(shape=(None), dtype=tf.int32)

        # Varibles to run in training
        X_SPLIT = tf.split(X, FLAGS.num_gpus)
        X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
        LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
        LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
        LABEL_SPLIT_INIT = list(LABEL_SPLIT)
        attention_mask = ATTENTION_MASK
        tower_grads = []
        tower_gen_grads = []
        x_mod_list = []

        optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.99)

        # Build one tower per GPU; gradients are averaged afterwards.
        for j in range(FLAGS.num_gpus):
            x_mod = X_SPLIT[j]
            if FLAGS.comb_mask:
                # Refine the attention mask itself by Langevin dynamics
                # before computing the positive-phase energy.
                steps = tf.constant(0)
                c = lambda i, x: tf.less(i, FLAGS.num_steps)

                def langevin_attention_step(counter, attention_mask):
                    attention_mask = attention_mask + tf.random_normal(tf.shape(attention_mask), mean=0.0, stddev=0.01)
                    energy_noise = energy_start = model.forward(
                        x_mod,
                        weights,
                        attention_mask,
                        label=LABEL_SPLIT[j],
                        reuse=True,
                        stop_at_grad=False,
                        stop_batch=True)

                    if FLAGS.heir_mask:
                        energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                        energy_noise = energy_noise + energy_heir

                    attention_grad = tf.gradients(
                        FLAGS.temperature * energy_noise, [attention_mask])[0]
                    energy_noise_old = energy_noise

                    # Clip gradient norm for now
                    attention_mask = attention_mask - (FLAGS.attention_lr) * attention_grad
                    # Smooth the mask each step to keep it spatially coherent.
                    attention_mask = tf.layers.average_pooling2d(attention_mask, (3, 3), 1, padding='SAME')
                    attention_mask = tf.stop_gradient(attention_mask)

                    counter = counter + 1
                    return counter, attention_mask

                steps, attention_mask = tf.while_loop(c, langevin_attention_step, (steps, attention_mask))

                # attention_mask = tf.Print(attention_mask, [attention_mask])

                energy_pos = model.forward(
                    X_SPLIT[j],
                    weights,
                    tf.stop_gradient(attention_mask),
                    label=LABEL_POS_SPLIT[j],
                    stop_at_grad=False)

                if FLAGS.heir_mask:
                    energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                    energy_pos = energy_heir + energy_pos
            else:
                energy_pos = model.forward(
                    X_SPLIT[j],
                    weights,
                    attention_mask,
                    label=LABEL_POS_SPLIT[j],
                    stop_at_grad=False)

                if FLAGS.heir_mask:
                    energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                    energy_pos = energy_heir + energy_pos

            print("Building graph...")
            x_mod = x_orig = X_NOISE_SPLIT[j]

            x_grads = []
            loss_energys = []

            eps_begin = tf.zeros(1)

            steps = tf.constant(0)
            c_cond = lambda i, x, y: tf.less(i, FLAGS.num_steps)

            # Langevin sampler for negative samples: jointly updates the
            # image x_mod and (when comb_mask) the attention mask.
            def langevin_step(counter, x_mod, attention_mask):
                lr = FLAGS.step_lr

                x_mod = x_mod + tf.random_normal(
                    tf.shape(x_mod),
                    mean=0.0,
                    stddev=0.001 * FLAGS.rescale * FLAGS.noise_scale)
                attention_mask = attention_mask + tf.random_normal(tf.shape(attention_mask), mean=0.0, stddev=0.01)

                energy_noise = model.forward(
                    x_mod,
                    weights,
                    attention_mask,
                    label=LABEL_SPLIT[j],
                    reuse=True,
                    stop_at_grad=False,
                    stop_batch=True)

                # Compose energies with any pre-learned concept models.
                if FLAGS.prelearn_model:
                    for m_i, w_i, l_i in models_pretrain:
                        energy_noise = energy_noise + m_i.forward(
                            x_mod,
                            w_i,
                            attention_mask,
                            label=l_i,
                            reuse=True,
                            stop_at_grad=False,
                            stop_batch=True)

                if FLAGS.heir_mask:
                    energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                    energy_noise = energy_heir + energy_noise

                x_grad, attention_grad = tf.gradients(
                    FLAGS.temperature * energy_noise, [x_mod, attention_mask])

                if not FLAGS.comb_mask:
                    attention_grad = tf.zeros(1)

                energy_noise_old = energy_noise

                # Optional projection of the gradient before the update step.
                if FLAGS.proj_norm != 0.0:
                    if FLAGS.proj_norm_type == 'l2':
                        x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                    elif FLAGS.proj_norm_type == 'li':
                        x_grad = tf.clip_by_value(
                            x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
                    else:
                        print("Other types of projection are not supported!!!")
                        assert False

                # Clip gradient norm for now
                x_last = x_mod - (lr) * x_grad

                if FLAGS.comb_mask:
                    attention_mask = attention_mask - FLAGS.attention_lr * attention_grad
                    attention_mask = tf.layers.average_pooling2d(attention_mask, (3, 3), 1, padding='SAME')
                    attention_mask = tf.stop_gradient(attention_mask)

                x_mod = x_last
                x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

                counter = counter + 1
                return counter, x_mod, attention_mask

            steps, x_mod, attention_mask = tf.while_loop(c_cond, langevin_step, (steps, x_mod, attention_mask))

            attention_mask = tf.stop_gradient(attention_mask)
            # attention_mask = tf.Print(attention_mask, [attention_mask])

            energy_eval = model.forward(x_mod, weights, attention_mask, label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)
            x_grad, attention_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod, attention_mask])
            x_grads.append(x_grad)

            # Negative-phase energy on the (stopped) sampled images.
            energy_neg = model.forward(
                tf.stop_gradient(x_mod),
                weights,
                tf.stop_gradient(attention_mask),
                label=LABEL_SPLIT[j],
                stop_at_grad=False,
                reuse=True)

            if FLAGS.heir_mask:
                energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                energy_neg = energy_heir + energy_neg

            temp = FLAGS.temperature

            x_off = tf.reduce_mean(
                tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

            loss_energy = model.forward(
                x_mod,
                weights,
                attention_mask,
                reuse=True,
                label=LABEL,
                stop_grad=True)

            print("Finished processing loop construction ...")

            target_vars = {}

            if FLAGS.antialias:
                # NOTE(review): stride_3 is not defined in this function —
                # presumably a module-level antialiasing kernel; verify.
                antialias = tf.tile(stride_3, (1, 1, tf.shape(x_mod)[3], tf.shape(x_mod)[3]))
                inp = tf.nn.conv2d(x_mod, antialias, [1, 2, 2, 1], padding='SAME')

            test_x_mod = x_mod

            if FLAGS.cclass or FLAGS.model_cclass:
                label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
                label_prob = label_sum / tf.reduce_sum(label_sum)
                label_ent = -tf.reduce_sum(label_prob * tf.math.log(label_prob + 1e-7))
            else:
                label_ent = tf.zeros(1)

            target_vars['label_ent'] = label_ent

            if FLAGS.train:
                # Three alternative EBM objectives; all contrast the energy
                # of data (energy_pos) against sampled negatives (energy_neg).
                if FLAGS.objective == 'logsumexp':
                    pos_term = temp * energy_pos
                    energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                    coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                    norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                    pos_loss = tf.reduce_mean(temp * energy_pos)
                    neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                    loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
                elif FLAGS.objective == 'cd':
                    pos_loss = tf.reduce_mean(temp * energy_pos)
                    neg_loss = -tf.reduce_mean(temp * energy_neg)
                    loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
                elif FLAGS.objective == 'softplus':
                    loss_ml = FLAGS.ml_coeff * \
                        tf.nn.softplus(temp * (energy_pos - energy_neg))

                loss_total = tf.reduce_mean(loss_ml)

                if not FLAGS.zero_kl:
                    loss_total = loss_total + tf.reduce_mean(loss_energy)

                # L2 penalty on energies keeps their magnitudes bounded.
                loss_total = loss_total + \
                    FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))

                print("Started gradient computation...")
                gvs = optimizer.compute_gradients(loss_total)
                gvs = [(k, v) for (k, v) in gvs if k is not None]
                print("Applying gradients...")
                tower_grads.append(gvs)
                print("Finished applying gradients.")

            # Overwritten on each GPU iteration; the last tower's tensors win.
            target_vars['loss_ml'] = loss_ml
            target_vars['total_loss'] = loss_total
            target_vars['loss_energy'] = loss_energy
            target_vars['weights'] = weights
            target_vars['gvs'] = gvs

            target_vars['X'] = X
            target_vars['Y'] = Y
            target_vars['LABEL'] = LABEL
            target_vars['HIER_LABEL'] = HEIR_LABEL
            target_vars['LABEL_POS'] = LABEL_POS
            target_vars['X_NOISE'] = X_NOISE
            target_vars['energy_pos'] = energy_pos
            target_vars['attention_grad'] = attention_grad

            if len(x_grads) >= 1:
                target_vars['x_grad'] = x_grads[-1]
                target_vars['x_grad_first'] = x_grads[0]
            else:
                target_vars['x_grad'] = tf.zeros(1)
                target_vars['x_grad_first'] = tf.zeros(1)

            target_vars['x_mod'] = x_mod
            target_vars['x_off'] = x_off
            target_vars['temp'] = temp
            target_vars['energy_neg'] = energy_neg
            target_vars['test_x_mod'] = test_x_mod
            target_vars['eps_begin'] = eps_begin
            target_vars['ATTENTION_MASK'] = ATTENTION_MASK
            target_vars['models_pretrain'] = models_pretrain
            if FLAGS.comb_mask:
                target_vars['attention_mask'] = tf.nn.softmax(attention_mask)
            else:
                target_vars['attention_mask'] = tf.zeros(1)

        if FLAGS.train:
            grads = average_gradients(tower_grads)
            train_op = optimizer.apply_gradients(grads)
            target_vars['train_op'] = train_op

    # sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(
        max_to_keep=30, keep_checkpoint_every_n_hours=6)

    # Count trainable parameters for logging.
    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    if (FLAGS.resume_iter != -1 or not FLAGS.train):
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        # saver.restore(sess, model_file)
        optimistic_restore(sess, model_file)

    print("Initializing variables...")

    print("Start broadcast")
    print("End broadcast")

    if FLAGS.train:
        train(target_vars, saver, sess, logger, data_loader, resume_itr, logdir)

    test(target_vars, saver, sess, logger, data_loader)
def main():
    """Build and run distributed (Horovod) EBM training on image datasets
    (cifar10 / imagenet / imagenetfull / mnist / dsprites).

    Constructs the energy model, a per-GPU Langevin-dynamics negative
    sampler, the contrastive training objective, then trains and/or tests.
    """
    print("Local rank: ", hvd.local_rank(), hvd.size())

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    # Only rank 0 writes logs/checkpoints; other workers get a null logger.
    if hvd.rank() == 0:
        if not osp.exists(logdir):
            os.makedirs(logdir)
        logger = TensorBoardOutputFormat(logdir)
    else:
        logger = None

    LABEL = None
    print("Loading data...")

    # --- Dataset selection: sets input/label placeholders and the model
    # architecture matching that dataset's resolution and class count. ---
    if FLAGS.dataset == 'cifar10':
        dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale)
        test_dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3

        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        if FLAGS.large_model:
            model = ResNet32Large(
                num_channels=channel_num,
                num_filters=128,
                train=True)
        elif FLAGS.larger_model:
            model = ResNet32Larger(
                num_channels=channel_num,
                num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(
                num_channels=channel_num,
                num_filters=192)
        else:
            model = ResNet32(
                num_channels=channel_num,
                num_filters=128)
    elif FLAGS.dataset == 'imagenet':
        dataset = Imagenet(train=True)
        test_dataset = Imagenet(train=False)
        channel_num = 3

        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet32Wider(
            num_channels=channel_num,
            num_filters=256)
    elif FLAGS.dataset == 'imagenetfull':
        channel_num = 3

        X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet128(
            num_channels=channel_num,
            num_filters=64)
    elif FLAGS.dataset == 'mnist':
        dataset = Mnist(rescale=FLAGS.rescale)
        test_dataset = dataset
        channel_num = 1

        X_NOISE = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        model = MnistNet(
            num_channels=channel_num,
            num_filters=FLAGS.num_filters)
    elif FLAGS.dataset == 'dsprites':
        dataset = DSprites(
            cond_shape=FLAGS.cond_shape,
            cond_size=FLAGS.cond_size,
            cond_pos=FLAGS.cond_pos,
            cond_rot=FLAGS.cond_rot)
        test_dataset = dataset
        channel_num = 1

        X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

        # Conditioning vector width depends on which latent factor is used.
        if FLAGS.dpos_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.dsize_only:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.drot_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_size:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.cond_shape:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
        elif FLAGS.cond_pos:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_rot:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        else:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)

        model = DspritesNet(
            num_channels=channel_num,
            num_filters=FLAGS.num_filters,
            cond_size=FLAGS.cond_size,
            cond_shape=FLAGS.cond_shape,
            cond_pos=FLAGS.cond_pos,
            cond_rot=FLAGS.cond_rot)

    print("Done loading...")

    if FLAGS.dataset == "imagenetfull":
        # In the case of full imagenet, use custom_tensorflow dataloader
        data_loader = TFImagenetLoader('train', FLAGS.batch_size, hvd.rank(), hvd.size(), rescale=FLAGS.rescale)
    else:
        data_loader = DataLoader(
            dataset,
            batch_size=FLAGS.batch_size,
            num_workers=FLAGS.data_workers,
            drop_last=True,
            shuffle=True)

    batch_size = FLAGS.batch_size

    weights = [model.construct_weights('context_0')]

    Y = tf.placeholder(shape=(None), dtype=tf.int32)

    # Varibles to run in training
    X_SPLIT = tf.split(X, FLAGS.num_gpus)
    X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
    LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
    LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
    LABEL_SPLIT_INIT = list(LABEL_SPLIT)
    tower_grads = []
    tower_gen_grads = []
    x_mod_list = []

    optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
    # Wrap for allreduce gradient averaging across Horovod workers.
    optimizer = hvd.DistributedOptimizer(optimizer)

    # Build one tower per GPU; gradients are averaged afterwards.
    for j in range(FLAGS.num_gpus):
        if FLAGS.model_cclass:
            # Class-conditional model: replicate each image across all 10
            # class labels and Gumbel-sample a label from the energies.
            ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus
            label_tensor = tf.Variable(
                tf.convert_to_tensor(
                    np.reshape(
                        np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)),
                        (FLAGS.batch_size * 10, 10)),
                    dtype=tf.float32),
                trainable=False,
                dtype=tf.float32)
            x_split = tf.tile(
                tf.reshape(X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)),
                (1, 10, 1, 1, 1))
            x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3))
            energy_pos = model.forward(
                x_split,
                weights[0],
                label=label_tensor,
                stop_at_grad=False)

            energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10))
            energy_partition_est = tf.reduce_logsumexp(energy_pos_full, axis=1, keepdims=True)
            uniform = tf.random_uniform(tf.shape(energy_pos_full))
            # Gumbel-max trick over the (negated) energies.
            label_tensor = tf.argmax(-energy_pos_full - tf.log(-tf.log(uniform)) - energy_partition_est, axis=1)
            label = tf.one_hot(label_tensor, 10, dtype=tf.float32)
            label = tf.Print(label, [label_tensor, energy_pos_full])
            LABEL_SPLIT[j] = label
            energy_pos = tf.concat(energy_pos, axis=0)
        else:
            energy_pos = [
                model.forward(
                    X_SPLIT[j],
                    weights[0],
                    label=LABEL_POS_SPLIT[j],
                    stop_at_grad=False)]
            energy_pos = tf.concat(energy_pos, axis=0)

        print("Building graph...")
        x_mod = x_orig = X_NOISE_SPLIT[j]

        x_grads = []

        energy_negs = []
        loss_energys = []

        # Energy of the un-modified noise input (the sampler's start point).
        energy_negs.extend([
            model.forward(
                tf.stop_gradient(x_mod),
                weights[0],
                label=LABEL_SPLIT[j],
                stop_at_grad=False,
                reuse=True)])
        eps_begin = tf.zeros(1)

        steps = tf.constant(0)
        c = lambda i, x: tf.less(i, FLAGS.num_steps)

        # One step of Langevin dynamics (or HMC) on the negative samples.
        def langevin_step(counter, x_mod):
            x_mod = x_mod + tf.random_normal(
                tf.shape(x_mod),
                mean=0.0,
                stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)

            energy_noise = energy_start = tf.concat([
                model.forward(
                    x_mod,
                    weights[0],
                    label=LABEL_SPLIT[j],
                    reuse=True,
                    stop_at_grad=False,
                    stop_batch=True)], axis=0)

            x_grad, label_grad = tf.gradients(FLAGS.temperature * energy_noise, [x_mod, LABEL_SPLIT[j]])
            energy_noise_old = energy_noise

            lr = FLAGS.step_lr

            # Optional projection of the gradient before the update step.
            if FLAGS.proj_norm != 0.0:
                if FLAGS.proj_norm_type == 'l2':
                    x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                elif FLAGS.proj_norm_type == 'li':
                    x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
                else:
                    print("Other types of projection are not supported!!!")
                    assert False

            # Clip gradient norm for now
            if FLAGS.hmc:
                # Step size should be tuned to get around 65% acceptance
                def energy(x):
                    return FLAGS.temperature * \
                        model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True)

                x_last = hmc(x_mod, 15., 10, energy)
            else:
                x_last = x_mod - (lr) * x_grad

            x_mod = x_last

            x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

            counter = counter + 1

            return counter, x_mod

        steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod))

        energy_eval = model.forward(x_mod, weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)
        x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0]
        x_grads.append(x_grad)

        # Negative-phase energy on the (stopped) sampled images.
        energy_negs.append(
            model.forward(
                tf.stop_gradient(x_mod),
                weights[0],
                label=LABEL_SPLIT[j],
                stop_at_grad=False,
                reuse=True))

        test_x_mod = x_mod

        temp = FLAGS.temperature

        energy_neg = energy_negs[-1]
        x_off = tf.reduce_mean(
            tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

        loss_energy = model.forward(
            x_mod,
            weights[0],
            reuse=True,
            label=LABEL,
            stop_grad=True)

        print("Finished processing loop construction ...")

        target_vars = {}

        if FLAGS.cclass or FLAGS.model_cclass:
            label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
            label_prob = label_sum / tf.reduce_sum(label_sum)
            label_ent = -tf.reduce_sum(
                label_prob * tf.math.log(label_prob + 1e-7))
        else:
            label_ent = tf.zeros(1)

        target_vars['label_ent'] = label_ent

        if FLAGS.train:
            # Three alternative EBM objectives; all contrast the energy of
            # data (energy_pos) against sampled negatives (energy_neg).
            if FLAGS.objective == 'logsumexp':
                pos_term = temp * energy_pos
                energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'cd':
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = -tf.reduce_mean(temp * energy_neg)
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'softplus':
                loss_ml = FLAGS.ml_coeff * \
                    tf.nn.softplus(temp * (energy_pos - energy_neg))

            loss_total = tf.reduce_mean(loss_ml)

            if not FLAGS.zero_kl:
                loss_total = loss_total + tf.reduce_mean(loss_energy)

            # L2 penalty on energies keeps their magnitudes bounded.
            loss_total = loss_total + \
                FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))

            print("Started gradient computation...")
            gvs = optimizer.compute_gradients(loss_total)
            gvs = [(k, v) for (k, v) in gvs if k is not None]
            print("Applying gradients...")
            tower_grads.append(gvs)
            print("Finished applying gradients.")

        # Overwritten on each GPU iteration; the last tower's tensors win.
        target_vars['loss_ml'] = loss_ml
        target_vars['total_loss'] = loss_total
        target_vars['loss_energy'] = loss_energy
        target_vars['weights'] = weights
        target_vars['gvs'] = gvs

        target_vars['X'] = X
        target_vars['Y'] = Y
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['X_NOISE'] = X_NOISE
        target_vars['energy_pos'] = energy_pos
        target_vars['energy_start'] = energy_negs[0]

        if len(x_grads) >= 1:
            target_vars['x_grad'] = x_grads[-1]
            target_vars['x_grad_first'] = x_grads[0]
        else:
            target_vars['x_grad'] = tf.zeros(1)
            target_vars['x_grad_first'] = tf.zeros(1)

        target_vars['x_mod'] = x_mod
        target_vars['x_off'] = x_off
        target_vars['temp'] = temp
        target_vars['energy_neg'] = energy_neg
        target_vars['test_x_mod'] = test_x_mod
        target_vars['eps_begin'] = eps_begin

    if FLAGS.train:
        grads = average_gradients(tower_grads)
        train_op = optimizer.apply_gradients(grads)
        target_vars['train_op'] = train_op

    config = tf.ConfigProto()

    if hvd.size() > 1:
        # Pin each worker to its own GPU.
        config.gpu_options.visible_device_list = str(hvd.local_rank())

    sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(max_to_keep=30, keep_checkpoint_every_n_hours=6)

    # Count trainable parameters for logging.
    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    # Only rank 0 restores from checkpoint; weights are then broadcast to
    # the other workers below.
    if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        # saver.restore(sess, model_file)
        optimistic_restore(sess, model_file)

    sess.run(hvd.broadcast_global_variables(0))
    print("Initializing variables...")

    print("Start broadcast")
    print("End broadcast")

    if FLAGS.train:
        print("Training phase")
        train(target_vars, saver, sess, logger, data_loader, resume_itr, logdir)
    print("Testing phase")
    test(target_vars, saver, sess, logger, data_loader)
def main():
    """Build and run EBM (or feed-forward baseline) training for latent-space
    trajectory models on RL environments (maze / point / reacher).

    Spawns a vector of subprocess environments, builds the model graph via
    construct_(ff_)model / construct_(ff_)plan_model, then hands everything
    to the training loop.
    """
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)

    if not osp.exists(logdir):
        os.makedirs(logdir)
    logger = TensorBoardOutputFormat(logdir)

    datasource = FLAGS.datasource

    def make_env(rank):
        # Returns a thunk so SubprocVecEnv can construct the env in its
        # own subprocess; rank seeds the env and names its monitor file.
        def _thunk():
            # Make the environments non stoppable for now
            if datasource == "maze":
                env = Maze(end=[0.7, -0.8], start=[-0.85, -0.85], random_starts=False)
            elif datasource == "point":
                env = Point(end=[0.5, 0.5], start=[0.0, 0.0], random_starts=True)
            elif datasource == "reacher":
                env = Reacher(end=[0.7, 0.5], eps=0.01)
            env.seed(rank)
            env = Monitor(env, os.path.join("/tmp", str(rank)), allow_early_resets=True)
            return env
        return _thunk

    env = SubprocVecEnv(
        [make_env(i + FLAGS.seed) for i in range(FLAGS.num_env)])

    if FLAGS.datasource == 'point' or FLAGS.datasource == 'maze' or FLAGS.datasource == 'reacher':
        # Forward-dynamics model: feed-forward baseline or latent EBM.
        if FLAGS.ff_model:
            model = TrajFFDynamics(dim_input=FLAGS.latent_dim, dim_output=FLAGS.latent_dim)
        else:
            model = TrajNetLatentFC(dim_input=FLAGS.latent_dim)

        # Trajectory placeholders: (batch, frame, object, latent_dim).
        X_NOISE = tf.placeholder(shape=(None, FLAGS.total_frame, FLAGS.input_objects, FLAGS.latent_dim), dtype=tf.float32)
        X = tf.placeholder(shape=(None, FLAGS.total_frame, FLAGS.input_objects, FLAGS.latent_dim), dtype=tf.float32)

        if FLAGS.cond:
            ACTION_LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        else:
            ACTION_LABEL = None

        ACTION_NOISE_LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)

        # Planning placeholders: a start state, plan_steps intermediate
        # states to optimize, and a goal state.
        ACTION_PLAN = tf.placeholder(shape=(None, FLAGS.plan_steps + 1, 2), dtype=tf.float32)
        X_START = tf.placeholder(shape=(None, 1, FLAGS.input_objects, FLAGS.latent_dim), dtype=tf.float32)
        X_PLAN = tf.placeholder(shape=(None, FLAGS.plan_steps, FLAGS.input_objects, FLAGS.latent_dim), dtype=tf.float32)

        if FLAGS.datasource == 'reacher':
            # Reacher goals are 2-D end-effector positions, not full latents.
            X_END = tf.placeholder(shape=(None, 1, FLAGS.input_objects, 2), dtype=tf.float32)
        else:
            X_END = tf.placeholder(shape=(None, 1, FLAGS.input_objects, FLAGS.latent_dim), dtype=tf.float32)
    else:
        raise AssertionError("Unsupported data source")

    weights = model.construct_weights(action_size=FLAGS.action_dim)
    optimizer = AdamOptimizer(1e-2, beta1=0.0, beta2=0.999)

    # Build both the training graph and the planning graph, accumulating
    # tensors into the same target_vars dict.
    if FLAGS.ff_model:
        target_vars = construct_ff_model(model, weights, X_NOISE, X, ACTION_LABEL, ACTION_NOISE_LABEL, optimizer)
        target_vars = construct_ff_plan_model(model, weights, X_PLAN, X_START, X_END, ACTION_PLAN, target_vars=target_vars)
    else:
        target_vars = construct_model(model, weights, X_NOISE, X, ACTION_LABEL, ACTION_NOISE_LABEL, optimizer)
        target_vars = construct_plan_model(model, weights, X_PLAN, X_START, X_END, ACTION_PLAN, target_vars=target_vars)

    sess = tf.InteractiveSession()
    saver = loader = tf.train.Saver(max_to_keep=10, keep_checkpoint_every_n_hours=2)

    tf.global_variables_initializer().run()
    print("Initializing variables...")

    if FLAGS.resume_iter != -1 or not FLAGS.train:
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        saver.restore(sess, model_file)

    train(target_vars, saver, sess, logger, FLAGS.resume_iter, env)
def construct_model(model, weights, X_NOISE, X, ACTION_LABEL, ACTION_NOISE_LABEL, optimizer):
    """Build the contrastive-divergence training graph for the trajectory EBM.

    Runs FLAGS.num_steps of Langevin MCMC (jointly over the trajectory
    ``x_mod`` and, when FLAGS.cond, the action label) starting from
    X_NOISE, then forms the maximum-likelihood loss

        E[energy(X)] - E[energy(x_mod)]  +  L2 penalty on both energies,

    optionally trains an inverse-dynamics network on ground-truth pairs,
    and returns a dict of tensors/ops used by the training loop.

    Returns:
        dict: target_vars with train_op, losses, energies, diagnostics,
        and the input placeholders.
    """
    target_vars = {}
    x_mods = []

    # Energy of real data (positive phase).
    energy_pos = model.forward(X, weights, action_label=ACTION_LABEL)
    # Initial energy of the noise sample before MCMC refinement.
    energy_noise = energy_start = model.forward(X_NOISE, weights, reuse=True, stop_at_grad=True, action_label=ACTION_LABEL)

    x_mod = X_NOISE

    x_grads = []
    x_ees = []

    if not FLAGS.gt_inverse_dynamics:
        # Learned inverse dynamics; NOTE: rebinds `weights` to the merged
        # weight dict returned by construct_weights.
        dyn_model = TrajInverseDynamics(dim_input=FLAGS.latent_dim, dim_output=FLAGS.action_dim)
        weights = dyn_model.construct_weights(scope="inverse_dynamics", weights=weights)

    steps = tf.constant(0)
    # while_loop condition: run FLAGS.num_steps Langevin steps.
    c = lambda i, x, y: tf.less(i, FLAGS.num_steps)

    def mcmc_step(counter, x_mod, action_label):
        # One Langevin step: inject Gaussian noise, then descend the
        # (temperature-scaled) energy w.r.t. both sample and action.
        x_mod = x_mod + tf.random_normal(
            tf.shape(x_mod),
            mean=0.0,
            stddev=0.01)
        action_label = action_label + tf.random_normal(
            tf.shape(action_label), mean=0.0, stddev=0.01)

        energy_noise = model.forward(x_mod, weights, action_label=action_label, reuse=True)

        lr = FLAGS.step_lr

        x_grad = tf.gradients(FLAGS.temperature * energy_noise, [x_mod])[0]
        x_mod = x_mod - lr * x_grad

        # NOTE(review): this second tf.gradients call differentiates
        # w.r.t. the *updated* x_mod, which energy_noise does not depend
        # on; only action_grad is used afterwards — confirm intended.
        if FLAGS.cond:
            x_grad, action_grad = tf.gradients(
                FLAGS.temperature * energy_noise, [x_mod, action_label])
        else:
            x_grad, action_grad = tf.gradients(
                FLAGS.temperature * energy_noise, [x_mod])[0], tf.zeros(1)

        action_label = action_label - FLAGS.step_lr * action_grad

        # Keep samples and actions in the data range.
        x_mod = tf.clip_by_value(x_mod, -1.2, 1.2)
        action_label = tf.clip_by_value(action_label, -1.2, 1.2)

        counter = counter + 1

        return counter, x_mod, action_label

    steps, x_mod, action_label = tf.while_loop(
        c, mcmc_step, (steps, x_mod, ACTION_NOISE_LABEL))

    target_vars['x_mod'] = x_mod

    temp = FLAGS.temperature

    # Energy of samples with gradients stopped inside the model
    # (diagnostic / KL-style term).
    loss_energy = temp * model.forward(
        x_mod, weights, reuse=True, action_label=action_label, stop_grad=True)

    # Negative samples do not backprop into the sampler.
    x_mod = tf.stop_gradient(x_mod)
    action_label = tf.stop_gradient(action_label)

    energy_neg = model.forward(x_mod, weights, action_label=action_label, reuse=True)

    if FLAGS.cond:
        x_grad, action_grad = tf.gradients(FLAGS.temperature * energy_neg, [x_mod, action_label])
    else:
        x_grad, action_grad = tf.gradients(FLAGS.temperature * energy_neg, [x_mod])[0], tf.zeros(1)

    # Mean squared distance between final samples and real data.
    x_off = tf.reduce_mean(tf.square(x_mod - X))

    if FLAGS.train:
        # Contrastive divergence loss plus an L2 energy regularizer.
        pos_loss = tf.reduce_mean(temp * energy_pos)
        neg_loss = -tf.reduce_mean(temp * energy_neg)
        loss_ml = (pos_loss + tf.reduce_sum(neg_loss))
        loss_total = tf.reduce_mean(loss_ml)
        loss_total = loss_total + \
            (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))

    if not FLAGS.gt_inverse_dynamics:
        # Supervised inverse-dynamics loss on real transitions, trained
        # with its own Adam optimizer.
        output_action = dyn_model.forward(X, weights)
        dyn_loss = tf.reduce_mean(tf.square(output_action - ACTION_LABEL))
        dyn_dist = tf.reduce_mean(tf.abs(output_action - ACTION_LABEL))
        target_vars['dyn_loss'] = dyn_loss
        target_vars['dyn_dist'] = dyn_dist

        dyn_optimizer = AdamOptimizer(1e-3)
        gvs = dyn_optimizer.compute_gradients(dyn_loss)
        dyn_train_op = dyn_optimizer.apply_gradients(gvs)
    else:
        target_vars['dyn_loss'] = tf.zeros(1)
        target_vars['dyn_dist'] = tf.zeros(1)

    if FLAGS.train:
        print("Started gradient computation...")
        gvs = optimizer.compute_gradients(loss_total)
        # Drop variables with no gradient.
        gvs = [(k, v) for (k, v) in gvs if k is not None]
        print("Applying gradients...")
        grads, vs = zip(*gvs)

        def filter_grad(g, v):
            # Clip each gradient element to a large but finite range.
            return tf.clip_by_value(g, -1e5, 1e5)

        capped_gvs = [(filter_grad(grad, var), var) for grad, var in gvs]
        gvs = capped_gvs
        train_op = optimizer.apply_gradients(gvs)

        if not FLAGS.gt_inverse_dynamics:
            train_op = tf.group(train_op, dyn_train_op)

        target_vars['train_op'] = train_op

        print("Finished applying gradients.")

    # NOTE(review): loss_ml / loss_total / gvs are only defined when
    # FLAGS.train is set, and dyn_model only when gt_inverse_dynamics is
    # off — these lookups raise NameError otherwise. Confirm the callers
    # never hit those configurations.
    target_vars['loss_ml'] = loss_ml
    target_vars['loss_total'] = loss_total
    target_vars['gvs'] = gvs
    target_vars['loss_energy'] = loss_energy

    target_vars['weights'] = weights
    target_vars['X'] = X
    target_vars['X_NOISE'] = X_NOISE
    target_vars['energy_pos'] = energy_pos
    target_vars['energy_neg'] = energy_neg
    target_vars['x_grad'] = x_grad
    target_vars['action_grad'] = action_grad
    target_vars['x_mod'] = x_mod
    target_vars['x_off'] = x_off
    target_vars['temp'] = temp
    target_vars['ACTION_LABEL'] = ACTION_LABEL
    target_vars['ACTION_NOISE_LABEL'] = ACTION_NOISE_LABEL
    target_vars['idyn_model'] = dyn_model

    return target_vars
def gentest(sess, kvs, data, latents, save_exp_dir):
    """Compositional-generalization test for dSprites EBMs.

    Builds a sampler that alternates Langevin steps on the position
    model and on a second factor model (shape / rotation / size,
    depending on FLAGS), optionally fine-tunes the joint model on the
    training distribution, then evaluates MSE on a held-out slice of
    the factor product and saves a side-by-side sample image.

    Args:
        sess: live TF session (graph ops are added to it here).
        kvs: dict of placeholders, models and weights built by the caller.
        data: full dSprites image array.
        latents: matching latent-factor array (columns: shape, scale,
            rotation, position — assumed layout; confirm against loader).
        save_exp_dir: directory for checkpoints and output images.
    """
    X_NOISE = kvs['X_NOISE']
    LABEL_SIZE = kvs['LABEL_SIZE']
    LABEL_SHAPE = kvs['LABEL_SHAPE']
    LABEL_POS = kvs['LABEL_POS']
    LABEL_ROT = kvs['LABEL_ROT']
    model_size = kvs['model_size']
    model_shape = kvs['model_shape']
    model_pos = kvs['model_pos']
    model_rot = kvs['model_rot']
    weight_size = kvs['weight_size']
    weight_shape = kvs['weight_shape']
    weight_pos = kvs['weight_pos']
    weight_rot = kvs['weight_rot']
    X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
    datafull = data

    # Test combination of generalization where we use slices of both training
    x_final = X_NOISE
    x_mod_size = X_NOISE
    x_mod_pos = X_NOISE

    # Unrolled (not while_loop) alternating Langevin sampler: each
    # iteration takes one gradient step on the position energy, then one
    # on the second conditioned energy.
    for i in range(FLAGS.num_steps):
        # use cond_pos
        energies = []
        x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
        e_noise = model_pos.forward(x_final, weight_pos, label=LABEL_POS)
        # energies.append(e_noise)
        # NOTE(review): gradient is taken at x_final (previous iterate)
        # but applied to the noised x_mod_pos — confirm intended.
        x_grad = tf.gradients(e_noise, [x_final])[0]
        x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
        x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
        x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)

        if FLAGS.joint_shape:
            # use cond_shape
            e_noise = model_shape.forward(x_mod_pos, weight_shape, label=LABEL_SHAPE)
        elif FLAGS.joint_rot:
            e_noise = model_rot.forward(x_mod_pos, weight_rot, label=LABEL_ROT)
        else:
            # use cond_size
            e_noise = model_size.forward(x_mod_pos, weight_size, label=LABEL_SIZE)

        # energies.append(e_noise)
        # energy_stack = tf.concat(energies, axis=1)
        # energy_stack = tf.reduce_logsumexp(-1*energy_stack, axis=1)
        # energy_stack = tf.reduce_sum(energy_stack, axis=1)

        x_grad = tf.gradients(e_noise, [x_mod_pos])[0]
        x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
        x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)

        # for x_mod_size
        # use cond_size
        # e_noise = model_size.forward(x_mod_size, weight_size, label=LABEL_SIZE)
        # x_grad = tf.gradients(e_noise, [x_mod_size])[0]
        # x_mod_size = x_mod_size + tf.random_normal(tf.shape(x_mod_size), mean=0.0, stddev=0.005)
        # x_mod_size = x_mod_size - FLAGS.step_lr * x_grad
        # x_mod_size = tf.clip_by_value(x_mod_size, 0, 1)
        # # use cond_pos
        # e_noise = model_pos.forward(x_mod_size, weight_pos, label=LABEL_POS)
        # x_grad = tf.gradients(e_noise, [x_mod_size])[0]
        # x_mod_size = x_mod_size + tf.random_normal(tf.shape(x_mod_size), mean=0.0, stddev=0.005)
        # x_mod_size = x_mod_size - FLAGS.step_lr * tf.stop_gradient(x_grad)
        # x_mod_size = tf.clip_by_value(x_mod_size, 0, 1)

        x_mod = x_mod_pos
        x_final = x_mod

    # Joint energies used for the fine-tuning loss: sum of the two
    # component models' energies on samples (negative) and data (positive).
    if FLAGS.joint_shape:
        loss_kl = model_shape.forward(x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True) + \
            model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)

        energy_pos = model_shape.forward(X, weight_shape, reuse=True, label=LABEL_SHAPE) + \
            model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
        energy_neg = model_shape.forward(tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE) + \
            model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
    elif FLAGS.joint_rot:
        loss_kl = model_rot.forward(x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True) + \
            model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)

        energy_pos = model_rot.forward(X, weight_rot, reuse=True, label=LABEL_ROT) + \
            model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
        energy_neg = model_rot.forward(tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT) + \
            model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
    else:
        loss_kl = model_size.forward(x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True) + \
            model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)

        energy_pos = model_size.forward(X, weight_size, reuse=True, label=LABEL_SIZE) + \
            model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
        energy_neg = model_size.forward(tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE) + \
            model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)

    # Softmax-style weighting of negatives (neg_loss is computed but not
    # used in loss_total below).
    energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
    coeff = tf.stop_gradient(tf.exp(-energy_neg_reduced))
    norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
    neg_loss = coeff * (-1*energy_neg) / norm_constant

    loss_ml = tf.reduce_mean(energy_pos) - tf.reduce_mean(energy_neg)
    loss_total = loss_ml + tf.reduce_mean(loss_kl) + 1 * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg)))

    optimizer = AdamOptimizer(1e-3, beta1=0.0, beta2=0.999)
    gvs = optimizer.compute_gradients(loss_total)
    gvs = [(k, v) for (k, v) in gvs if k is not None]
    train_op = optimizer.apply_gradients(gvs)

    # Only the freshly created optimizer slots need initialization; the
    # model weights are already loaded in `sess`.
    vs = optimizer.variables()
    sess.run(tf.variables_initializer(vs))

    dataloader = DataLoader(DSpritesGen(data, latents), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)

    x_off = tf.reduce_mean(tf.square(x_mod - X))

    itr = 0
    saver = tf.train.Saver()
    x_mod = None

    if FLAGS.train:
        replay_buffer = ReplayBuffer(10000)
        for _ in range(1):
            for data_corrupt, data, label_size, label_pos in tqdm(dataloader):
                data_corrupt = data_corrupt.numpy()[:, :, :]
                data = data.numpy()[:, :, :]

                # Mix ~5% of the batch with past samples from the replay
                # buffer to stabilize the sampler.
                if x_mod is not None:
                    replay_buffer.add(x_mod)
                    replay_batch = replay_buffer.sample(FLAGS.batch_size)
                    replay_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95)
                    data_corrupt[replay_mask] = replay_batch[replay_mask]

                if FLAGS.joint_shape:
                    feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SHAPE: label_size, LABEL_POS: label_pos}
                elif FLAGS.joint_rot:
                    feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_ROT: label_size, LABEL_POS: label_pos}
                else:
                    feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SIZE: label_size, LABEL_POS: label_pos}

                _, off_value, e_pos, e_neg, x_mod = sess.run([train_op, x_off, energy_pos, energy_neg, x_final], feed_dict=feed_dict)
                itr += 1

                if itr % 10 == 0:
                    print("x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(off_value, e_pos.mean(), e_neg.mean(), itr))

                if itr == FLAGS.break_steps:
                    break

        saver.save(sess, osp.join(save_exp_dir, 'model_gentest'))

    saver.restore(sess, osp.join(save_exp_dir, 'model_gentest'))

    l = latents

    # Select the held-out slice of the factor product for evaluation
    # (combinations never seen jointly in training).
    if FLAGS.joint_shape:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5)
    elif FLAGS.joint_rot:
        mask_gen = (l[:, 1] == 1) * (l[:, 2] == 0.5)
    else:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))

    data_gen = datafull[mask_gen]
    latents_gen = latents[mask_gen]

    losses = []

    for dat, latent in zip(np.array_split(data_gen, 120), np.array_split(latents_gen, 120)):
        x = 0.5 + np.random.randn(*dat.shape)

        if FLAGS.joint_shape:
            feed_dict = {LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
        elif FLAGS.joint_rot:
            feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1), LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
        else:
            feed_dict = {LABEL_SIZE: latent[:, 2:3], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}

        # Run the sampler twice, re-feeding its output as the new init.
        for i in range(2):
            x = sess.run([x_final], feed_dict=feed_dict)[0]
            feed_dict[X_NOISE] = x

        loss = sess.run([x_off], feed_dict=feed_dict)[0]
        losses.append(loss)

    print("Mean MSE loss of {} ".format(np.mean(losses)))

    # Generate a 10-image qualitative sample next to ground truth.
    data_try = data_gen[:10]
    data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64)
    latent_scale = latents_gen[:10, 2:3]
    latent_pos = latents_gen[:10, 4:]

    # NOTE(review): `latent` here is the loop variable leaking out of the
    # evaluation loop above (its last chunk), not latents_gen — confirm
    # this is intended for the joint_shape / joint_rot branches.
    if FLAGS.joint_shape:
        feed_dict = {X_NOISE: data_init, LABEL_SHAPE: np.eye(3)[latent[:10, 1].astype(np.int32)-1], LABEL_POS: latent_pos}
    elif FLAGS.joint_rot:
        feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:10, 3:4]), np.sin(latent[:10, 3:4])], axis=1), LABEL_POS: latent[:10, 4:], X_NOISE: data_init}
    else:
        feed_dict = {X_NOISE: data_init, LABEL_SIZE: latent_scale, LABEL_POS: latent_pos}

    x_output = sess.run([x_final], feed_dict=feed_dict)[0]

    if FLAGS.joint_shape:
        im_name = "size_shape_combine_gentest.png"
    else:
        im_name = "size_scale_combine_gentest.png"

    # Pad each 64x64 image into a 66x66 white frame, then tile samples
    # and ground truth side by side.
    x_output_wrap = np.ones((10, 66, 66))
    data_try_wrap = np.ones((10, 66, 66))

    x_output_wrap[:, 1:-1, 1:-1] = x_output
    data_try_wrap[:, 1:-1, 1:-1] = data_try

    im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
    impath = osp.join(save_exp_dir, im_name)
    imsave(impath, im_output)
    print("Successfully saved images at {}".format(impath))
def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
    """Train and evaluate a supervised generator baseline on dSprites.

    A DspritesNetGen maps a random latent plus a factor label directly
    to an image with an MSE reconstruction loss (no energy model, no
    sampling). Evaluates MSE on the held-out factor combinations and
    saves a qualitative comparison image.

    Args:
        frac: fraction parameter forwarded to DSpritesGen and used to
            namespace this baseline's variable scope.

    Returns:
        Mean MSE over the held-out generalization split.
    """
    # tf.reset_default_graph()

    if FLAGS.joint_shape:
        model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=5)
        LABEL = tf.placeholder(shape=(None, 5), dtype=tf.float32)
    else:
        model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=3)
        LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)

    # Scope includes `frac` so multiple baselines can coexist in one graph.
    weights_baseline = model_baseline.construct_weights('context_baseline_{}'.format(frac))

    X_feed = tf.placeholder(shape=(None, 2*FLAGS.num_filters), dtype=tf.float32)
    X_label = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

    X_out = model_baseline.forward(X_feed, weights_baseline, label=LABEL)
    loss_sq = tf.reduce_mean(tf.square(X_out - X_label))

    optimizer = AdamOptimizer(1e-3)
    gvs = optimizer.compute_gradients(loss_sq)
    gvs = [(k, v) for (k, v) in gvs if k is not None]
    train_op = optimizer.apply_gradients(gvs)

    dataloader = DataLoader(DSpritesGen(data, latents, frac=frac), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)

    datafull = data
    itr = 0
    saver = tf.train.Saver()
    vs = optimizer.variables()
    sess.run(tf.global_variables_initializer())

    if FLAGS.train:
        for _ in range(5):
            for data_corrupt, data, label_size, label_pos in tqdm(dataloader):
                data_corrupt = data_corrupt.numpy()
                label_size, label_pos = label_size.numpy(), label_pos.numpy()

                # The loader's corrupted image is discarded; the generator
                # input is a fresh random latent of size 2*num_filters.
                data_corrupt = np.random.randn(data_corrupt.shape[0], 2*FLAGS.num_filters)
                label_comb = np.concatenate([label_size, label_pos], axis=1)

                feed_dict = {X_feed: data_corrupt, X_label: data, LABEL: label_comb}

                output = [loss_sq, train_op]
                loss, _ = sess.run(output, feed_dict=feed_dict)

                itr += 1

        saver.save(sess, osp.join(save_exp_dir, 'model_genbaseline'))

    saver.restore(sess, osp.join(save_exp_dir, 'model_genbaseline'))

    l = latents

    # Held-out factor combinations (same masks as the EBM gentest).
    if FLAGS.joint_shape:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5)
    else:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))

    data_gen = datafull[mask_gen]
    latents_gen = latents[mask_gen]

    losses = []

    for dat, latent in zip(np.array_split(data_gen, 10), np.array_split(latents_gen, 10)):
        data_init = np.random.randn(dat.shape[0], 2*FLAGS.num_filters)

        if FLAGS.joint_shape:
            latent_size = np.eye(3)[latent[:, 1].astype(np.int32) - 1]
            latent_pos = latent[:, 4:6]
            latent = np.concatenate([latent_size, latent_pos], axis=1)
            feed_dict = {X_feed: data_init, LABEL: latent, X_label: dat}
        else:
            # Columns 2 (scale), 4 and 5 (position) form the label.
            feed_dict = {X_feed: data_init, LABEL: latent[:, [2,4,5]], X_label: dat}

        loss = sess.run([loss_sq], feed_dict=feed_dict)[0]
        # print(loss)
        losses.append(loss)

    print("Overall MSE for generalization of {} for fraction of {}".format(np.mean(losses), frac))

    # Qualitative 10-image comparison, padded into 66x66 frames.
    data_try = data_gen[:10]
    data_init = np.random.randn(10, 2*FLAGS.num_filters)

    # NOTE(review): in the joint_shape branch `latent` is the leaked loop
    # variable from above (already concatenated) — confirm intended.
    if FLAGS.joint_shape:
        latent_scale = np.eye(3)[latent[:10, 1].astype(np.int32) - 1]
        latent_pos = latents_gen[:10, 4:]
    else:
        latent_scale = latents_gen[:10, 2:3]
        latent_pos = latents_gen[:10, 4:]

    latent_tot = np.concatenate([latent_scale, latent_pos], axis=1)

    feed_dict = {X_feed: data_init, LABEL: latent_tot}
    x_output = sess.run([X_out], feed_dict=feed_dict)[0]
    x_output = np.clip(x_output, 0, 1)

    im_name = "size_scale_combine_genbaseline.png"

    x_output_wrap = np.ones((10, 66, 66))
    data_try_wrap = np.ones((10, 66, 66))

    x_output_wrap[:, 1:-1, 1:-1] = x_output
    data_try_wrap[:, 1:-1, 1:-1] = data_try

    im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
    impath = osp.join(save_exp_dir, im_name)
    imsave(impath, im_output)

    print("Successfully saved images at {}".format(impath))

    return np.mean(losses)
def main():
    """Distributed (Horovod) f-divergence EBM training on CIFAR-10.

    Builds an energy model plus a discriminator ("dis") model, runs a
    Langevin (or HMC) sampler per GPU tower, forms divergence-based
    losses from ``get_divergence_funcs(FLAGS.divergence)``, averages
    tower gradients, and launches train/test.

    NOTE(review): this file defines ``main`` more than once in the
    visible chunk; this later definition shadows the earlier one.
    """
    print("Local rank: ", hvd.local_rank(), hvd.size())

    # Tag the experiment name with the divergence used.
    FLAGS.exp = FLAGS.exp + '_' + FLAGS.divergence

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)

    # Only rank 0 creates the log directory and writes TensorBoard logs.
    if hvd.rank() == 0:
        if not osp.exists(logdir):
            os.makedirs(logdir)
        logger = TensorBoardOutputFormat(logdir)
    else:
        logger = None

    print("Loading data...")
    dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale)
    test_dataset = Cifar10(train=False, rescale=FLAGS.rescale)
    channel_num = 3

    X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
    X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
    LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
    LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

    # Energy model and discriminator share an architecture family.
    if FLAGS.large_model:
        model = ResNet32Large(
            num_channels=channel_num,
            num_filters=128,
            train=True)
        model_dis = ResNet32Large(
            num_channels=channel_num,
            num_filters=128,
            train=True)
    elif FLAGS.larger_model:
        model = ResNet32Larger(
            num_channels=channel_num,
            num_filters=128)
        model_dis = ResNet32Larger(
            num_channels=channel_num,
            num_filters=128)
    elif FLAGS.wider_model:
        model = ResNet32Wider(
            num_channels=channel_num,
            num_filters=256)
        model_dis = ResNet32Wider(
            num_channels=channel_num,
            num_filters=256)
    else:
        model = ResNet32(
            num_channels=channel_num,
            num_filters=128)
        model_dis = ResNet32(
            num_channels=channel_num,
            num_filters=128)

    print("Done loading...")

    # f-divergence: grad of the convex generator and of its conjugate.
    grad_exp, conjugate_grad_exp = get_divergence_funcs(FLAGS.divergence)

    data_loader = DataLoader(
        dataset,
        batch_size=FLAGS.batch_size,
        num_workers=FLAGS.data_workers,
        drop_last=True,
        shuffle=True)

    # weights[0]: energy model scope, weights[1]: discriminator scope.
    weights = [model.construct_weights('context_energy'), model_dis.construct_weights('context_dis')]

    Y = tf.placeholder(shape=(None), dtype=tf.int32)

    # Varibles to run in training
    X_SPLIT = tf.split(X, FLAGS.num_gpus)
    X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
    LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
    LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
    LABEL_SPLIT_INIT = list(LABEL_SPLIT)

    tower_grads = []
    tower_grads_dis = []
    tower_grads_l2 = []
    tower_grads_dis_l2 = []

    optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
    optimizer = hvd.DistributedOptimizer(optimizer)
    optimizer_dis = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
    optimizer_dis = hvd.DistributedOptimizer(optimizer_dis)

    # One graph replica per GPU tower; per-tower tensors (energy_pos,
    # score_neg, x_mod, target_vars, ...) deliberately leak out of the
    # loop and end up holding the last tower's values.
    for j in range(FLAGS.num_gpus):
        energy_pos = [
            model.forward(
                X_SPLIT[j],
                weights[0],
                label=LABEL_POS_SPLIT[j],
                stop_at_grad=False)]
        energy_pos = tf.concat(energy_pos, axis=0)

        score_pos = [
            model_dis.forward(
                X_SPLIT[j],
                weights[1],
                label=LABEL_POS_SPLIT[j],
                stop_at_grad=False)]
        score_pos = tf.concat(score_pos, axis=0)

        print("Building graph...")
        x_mod = x_orig = X_NOISE_SPLIT[j]

        x_grads = []

        energy_negs = []
        loss_energys = []

        # Energy of the raw noise before sampling (used as energy_start).
        energy_negs.extend([model.forward(tf.stop_gradient(
            x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)])
        eps_begin = tf.zeros(1)

        steps = tf.constant(0)
        c = lambda i, x: tf.less(i, FLAGS.num_steps)

        def langevin_step(counter, x_mod):
            # One Langevin (or HMC) step on the energy landscape.
            x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
                                             mean=0.0,
                                             stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)

            energy_noise = energy_start = tf.concat(
                [model.forward(
                        x_mod,
                        weights[0],
                        label=LABEL_SPLIT[j],
                        reuse=True,
                        stop_at_grad=False,
                        stop_batch=True)],
                axis=0)

            x_grad, label_grad = tf.gradients(
                energy_noise, [x_mod, LABEL_SPLIT[j]])
            energy_noise_old = energy_noise

            lr = FLAGS.step_lr

            # Optional gradient projection (l2 norm or elementwise clip).
            if FLAGS.proj_norm != 0.0:
                if FLAGS.proj_norm_type == 'l2':
                    x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                elif FLAGS.proj_norm_type == 'li':
                    x_grad = tf.clip_by_value(
                        x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
                else:
                    print("Other types of projection are not supported!!!")
                    assert False

            # Clip gradient norm for now
            if FLAGS.hmc:
                # Step size should be tuned to get around 65% acceptance
                def energy(x):
                    return FLAGS.temperature * \
                        model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True)

                x_last = hmc(x_mod, 15., 10, energy)
            else:
                x_last = x_mod - (lr) * x_grad

            x_mod = x_last
            # Keep samples inside the (rescaled) pixel range.
            x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

            counter = counter + 1

            return counter, x_mod

        steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod))

        energy_eval = model.forward(x_mod, weights[0], label=LABEL_SPLIT[j],
                                    stop_at_grad=False, reuse=True)
        x_grad = tf.gradients(energy_eval, [x_mod])[0]
        x_grads.append(x_grad)

        # Energy of the final negative samples (gradient stopped).
        energy_negs.append(
            model.forward(
                tf.stop_gradient(x_mod),
                weights[0],
                label=LABEL_SPLIT[j],
                stop_at_grad=False,
                reuse=True))

        score_neg = model_dis.forward(
            tf.stop_gradient(x_mod),
            weights[1],
            label=LABEL_SPLIT[j],
            stop_at_grad=False,
            reuse=True)

        test_x_mod = x_mod

        temp = FLAGS.temperature

        energy_neg = energy_negs[-1]

        # L1 distance between samples and real data (diagnostic).
        x_off = tf.reduce_mean(
            tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

        loss_energy = model.forward(
            x_mod,
            weights[0],
            reuse=True,
            label=LABEL,
            stop_grad=True)

        print("Finished processing loop construction ...")

        target_vars = {}

        # Entropy of the label distribution in the batch (diagnostic).
        if FLAGS.cclass or FLAGS.model_cclass:
            label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
            label_prob = label_sum / tf.reduce_sum(label_sum)
            label_ent = -tf.reduce_sum(label_prob * tf.math.log(label_prob + 1e-7))
        else:
            label_ent = tf.zeros(1)

        target_vars['label_ent'] = label_ent

        if FLAGS.train:
            # Discriminator: maximize the f-divergence lower bound.
            loss_dis = - (tf.reduce_mean(grad_exp(score_pos + energy_pos)) - tf.reduce_mean(conjugate_grad_exp(score_neg + energy_neg)))
            loss_dis = loss_dis + FLAGS.l2_coeff * (tf.reduce_mean(tf.square(score_pos)) + tf.reduce_mean(tf.square(score_neg)))
            l2_dis = FLAGS.l2_coeff * (tf.reduce_mean(tf.square(score_pos)) + tf.reduce_mean(tf.square(score_neg)))

            # Energy model: f-divergence loss with a variance-reduction
            # (baseline-subtracted) term on the negative energies.
            loss_model = tf.reduce_mean(grad_exp(score_pos + energy_pos)) + \
                tf.reduce_mean(energy_neg * tf.stop_gradient(conjugate_grad_exp(score_neg + energy_neg))) - \
                tf.reduce_mean(energy_neg) * tf.stop_gradient(tf.reduce_mean(conjugate_grad_exp(score_neg + energy_neg))) - \
                tf.reduce_mean(conjugate_grad_exp(score_neg + energy_neg))
            loss_model = loss_model + FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg)))
            l2_model = FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg)))

            print("Started gradient computation...")
            model_vars = [var for var in tf.trainable_variables() if 'context_energy' in var.name]
            print("model var number", len(model_vars))
            dis_vars = [var for var in tf.trainable_variables() if 'context_dis' in var.name]
            print("discriminator var number", len(dis_vars))

            gvs = optimizer.compute_gradients(loss_model, model_vars)
            gvs = [(k, v) for (k, v) in gvs if k is not None]
            tower_grads.append(gvs)

            gvs = optimizer.compute_gradients(l2_model, model_vars)
            gvs = [(k, v) for (k, v) in gvs if k is not None]
            tower_grads_l2.append(gvs)

            gvs_dis = optimizer_dis.compute_gradients(loss_dis, dis_vars)
            gvs_dis = [(k, v) for (k, v) in gvs_dis if k is not None]
            tower_grads_dis.append(gvs_dis)

            gvs_dis = optimizer_dis.compute_gradients(l2_dis, dis_vars)
            gvs_dis = [(k, v) for (k, v) in gvs_dis if k is not None]
            tower_grads_dis_l2.append(gvs_dis)

            print("Finished applying gradients.")

            target_vars['total_loss'] = loss_model
            target_vars['loss_energy'] = loss_energy
            target_vars['weights'] = weights
            target_vars['gvs'] = gvs

        target_vars['X'] = X
        target_vars['Y'] = Y
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['X_NOISE'] = X_NOISE
        target_vars['energy_pos'] = energy_pos
        target_vars['energy_start'] = energy_negs[0]

        if len(x_grads) >= 1:
            target_vars['x_grad'] = x_grads[-1]
            target_vars['x_grad_first'] = x_grads[0]
        else:
            target_vars['x_grad'] = tf.zeros(1)
            target_vars['x_grad_first'] = tf.zeros(1)

        target_vars['x_mod'] = x_mod
        target_vars['x_off'] = x_off
        target_vars['temp'] = temp
        target_vars['energy_neg'] = energy_neg
        target_vars['test_x_mod'] = test_x_mod
        target_vars['eps_begin'] = eps_begin
        target_vars['score_neg'] = score_neg
        target_vars['score_pos'] = score_pos

    if FLAGS.train:
        # Average per-tower gradients and build the four train ops
        # (model/dis, each with its separate L2-only op).
        grads_model = average_gradients(tower_grads)
        train_op_model = optimizer.apply_gradients(grads_model)
        target_vars['train_op_model'] = train_op_model

        grads_model_l2 = average_gradients(tower_grads_l2)
        train_op_model_l2 = optimizer.apply_gradients(grads_model_l2)
        target_vars['train_op_model_l2'] = train_op_model_l2

        grads_model_dis = average_gradients(tower_grads_dis)
        train_op_dis = optimizer_dis.apply_gradients(grads_model_dis)
        target_vars['train_op_dis'] = train_op_dis

        grads_model_dis_l2 = average_gradients(tower_grads_dis_l2)
        train_op_dis_l2 = optimizer_dis.apply_gradients(grads_model_dis_l2)
        target_vars['train_op_dis_l2'] = train_op_dis_l2

    config = tf.ConfigProto()

    # Pin each Horovod process to its local GPU.
    if hvd.size() > 1:
        config.gpu_options.visible_device_list = str(hvd.local_rank())

    sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(max_to_keep=500)

    # Count trainable parameters for logging.
    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    # Rank 0 restores the checkpoint; the broadcast below syncs the rest.
    if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        saver.restore(sess, model_file)
        # optimistic_restore(sess, model_file)

    sess.run(hvd.broadcast_global_variables(0))
    print("Initializing variables...")

    print("Start broadcast")
    print("End broadcast")

    if FLAGS.train:
        train(target_vars, saver, sess, logger, data_loader, resume_itr, logdir)

    test(target_vars, saver, sess, logger, data_loader)
def construct_steps_dual(weights_pos, weights_color, model_pos, model_color, target_vars):
    """Build a sampler/trainer that composes a position EBM and a color EBM.

    Runs a STEPS-long while_loop of Langevin steps on the *sum* of the
    two models' energies (langevin_merge_step; the alternating
    langevin_step variant is defined but unused here), then builds a
    contrastive loss between energies of the generated samples and of
    the reference batch X_feed, and stores results in target_vars.
    """
    steps = tf.constant(0)
    STEPS = target_vars['STEPS']
    c = lambda i, x: tf.less(i, STEPS)
    X, Y_first, Y_second = target_vars['X'], target_vars['Y_first'], target_vars['Y_second']
    X_feed = target_vars['X_feed']
    attention_mask = tf.zeros(1)

    def langevin_step(counter, x_mod):
        # Alternating variant: one step on the position energy, then one
        # on the color energy. Defined but not passed to tf.while_loop.
        x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
                                         mean=0.0,
                                         stddev=0.001)
        # latent = latent + tf.random_normal(tf.shape(latent), mean=0.0, stddev=0.01)

        energy_noise = model_pos.forward(
            x_mod,
            weights_pos,
            label=Y_first,
            reuse=True,
            stop_at_grad=False,
            stop_batch=True,
            attention_mask=attention_mask)

        # energy_noise = tf.Print(energy_noise, [energy_noise, x_mod, latent])
        x_grad = tf.gradients(energy_noise, [x_mod])[0]
        x_mod = x_mod - FLAGS.step_lr * x_grad
        x_mod = tf.clip_by_value(x_mod, 0, 1)

        x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
                                         mean=0.0,
                                         stddev=0.001)

        energy_noise = 1.0 * model_color.forward(
            x_mod,
            weights_color,
            label=Y_second,
            reuse=True,
            stop_at_grad=False,
            stop_batch=True,
            attention_mask=attention_mask)

        x_grad = tf.gradients(energy_noise, [x_mod])[0]
        x_mod = x_mod - FLAGS.step_lr * x_grad
        x_mod = tf.clip_by_value(x_mod, 0, 1)

        counter = counter + 1
        # counter = tf.Print(counter, [counter], message="step")

        return counter, x_mod

    def langevin_merge_step(counter, x_mod):
        # Joint variant: one step on the summed position + color energy.
        x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
                                         mean=0.0,
                                         stddev=0.001)
        # latent = latent + tf.random_normal(tf.shape(latent), mean=0.0, stddev=0.01)

        energy_noise = 1 * model_pos.forward(
            x_mod,
            weights_pos,
            label=Y_first,
            reuse=True,
            stop_at_grad=False,
            stop_batch=True,
            attention_mask=attention_mask) + \
            1.0 * model_color.forward(
            x_mod,
            weights_color,
            label=Y_second,
            reuse=True,
            stop_at_grad=False,
            stop_batch=True,
            attention_mask=attention_mask)

        # energy_noise = tf.Print(energy_noise, [energy_noise, x_mod, latent])
        x_grad = tf.gradients(energy_noise, [x_mod])[0]
        x_mod = x_mod - FLAGS.step_lr * x_grad
        x_mod = tf.clip_by_value(x_mod, 0, 1)

        counter = counter + 1

        return counter, x_mod

    # Sampling starts from X and uses the joint (merged) step.
    steps, x_mod = tf.while_loop(c, langevin_merge_step, (steps, X))

    # Energies of the generated samples (gradients stopped) ...
    energy_pos = model_pos.forward(
        tf.stop_gradient(x_mod),
        weights_pos,
        label=Y_first,
        reuse=True,
        stop_at_grad=False,
        stop_batch=True,
        attention_mask=attention_mask)

    energy_color = model_color.forward(
        tf.stop_gradient(x_mod),
        weights_color,
        label=Y_second,
        reuse=True,
        stop_at_grad=False,
        stop_batch=True,
        attention_mask=attention_mask)

    # ... and of the reference batch X_feed.
    energy_plus_pos = model_pos.forward(
        X_feed,
        weights_pos,
        label=Y_first,
        reuse=True,
        stop_at_grad=False,
        stop_batch=True,
        attention_mask=attention_mask)

    energy_plus_color = model_color.forward(
        X_feed,
        weights_color,
        label=Y_second,
        reuse=True,
        stop_at_grad=False,
        stop_batch=True,
        attention_mask=attention_mask)

    # Contrastive terms: push down energy of X_feed, push up energy of
    # samples, plus an L2 penalty on all four energies.
    energy_neg = -tf.reduce_mean(tf.reduce_mean(energy_pos) + tf.reduce_mean(energy_color))
    energy_plus = tf.reduce_mean(tf.reduce_mean(energy_plus_pos) + tf.reduce_mean(energy_plus_color))

    loss_l2 = tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_color)) + tf.reduce_mean(tf.square(energy_plus_pos)) + tf.reduce_mean(tf.square(energy_plus_color))

    loss_total = energy_plus + energy_neg + loss_l2

    optimizer = AdamOptimizer(1e-4, beta1=0.0, beta2=0.99)
    gvs = optimizer.compute_gradients(loss_total)
    train_op = optimizer.apply_gradients(gvs)

    x_off = tf.reduce_mean(tf.abs(X_feed - x_mod))

    target_vars['X_final'] = x_mod
    target_vars['x_off'] = x_off
    target_vars['train_op'] = train_op
    target_vars['energy_neg'] = -energy_neg
    target_vars['energy_pos'] = energy_plus