def model(sess, image):
    """Build the Inception classification graph for `image`.

    The first call constructs the network and restores checkpoint weights
    into `sess`; subsequent calls rebuild the ops with variable reuse and
    skip the restore.

    Args:
        sess: active tf.Session to restore checkpoint variables into.
        image: input image tensor handed to the preprocessing step.

    Returns:
        Tuple `(logits, predictions)`: per-class logits with the background
        column removed, and the argmax class index per example.
    """
    global _inception_initialized
    # Reuse variables once the network has already been constructed.
    network_fn = _get_model(reuse=_inception_initialized)
    side = network_fn.default_image_size
    resized = _preprocess(image, side, side)
    logits, _ = network_fn(resized)
    # Drop column 0: the checkpoint's background class.
    logits = logits[:, 1:]
    predictions = tf.argmax(logits, 1)
    if not _inception_initialized:
        optimistic_restore(sess, INCEPTION_CHECKPOINT_PATH)
        _inception_initialized = True
    return logits, predictions
def model(sess, image):
    """Return (logits, predictions) from the Inception network for `image`.

    Lazily restores the checkpoint into `sess` on first use, tracked by the
    module-level `_inception_initialized` flag; later calls reuse variables.
    """
    global _inception_initialized
    network_fn = _get_model(reuse=_inception_initialized)
    dim = network_fn.default_image_size
    logits, _ = network_fn(_preprocess(image, dim, dim))
    logits = logits[:, 1:]  # ignore background class
    predictions = tf.argmax(logits, 1)
    if not _inception_initialized:
        optimistic_restore(sess, INCEPTION_CHECKPOINT_PATH)
        _inception_initialized = True
    return logits, predictions
def _init_model(sess, checkpoint_name=None):
    """Construct the PixelCNN init-mode graph and restore its checkpoint.

    Args:
        sess: active tf.Session that receives the restored variables.
        checkpoint_name: file name under DATA_DIR; defaults to
            `_PIXELCNN_CHECKPOINT_NAME` when None.
    """
    global _model_func
    global _obs_shape
    global _model_opt
    name = _PIXELCNN_CHECKPOINT_NAME if checkpoint_name is None else checkpoint_name
    checkpoint_path = os.path.join(DATA_DIR, name)
    # Build the graph once in data-dependent "init" mode (single example).
    x_init = tf.placeholder(tf.float32, (1,) + _obs_shape)
    model = _model_func(x_init, init=True, dropout_p=0.5, **_model_opt)
    # XXX need to add a scope argument to optimistic_restore and filter for
    # things that start with "{scope}/", so we can filter for "model/", because
    # the pixelcnn checkpoint has some random unscoped stuff like 'Variable'
    optimistic_restore(sess, checkpoint_path)
def main():
    """Train/evaluate a conditional EBM on the cubes-family datasets.

    Builds placeholders and a per-GPU Langevin-sampling graph driven by
    module-level FLAGS (flag definitions are outside this chunk), then runs
    `train(...)` and/or `test(...)`.
    """
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    logger = TensorBoardOutputFormat(logdir)
    config = tf.ConfigProto()
    sess = tf.Session(config=config)
    LABEL = None
    print("Loading data...")

    # Dataset selection: each branch fixes the conditioning-label width
    # (`label_size`) and creates matching LABEL / LABEL_POS placeholders.
    if FLAGS.dataset == 'cubes':
        dataset = Cubes(cond_idx=FLAGS.cond_idx)
        test_dataset = dataset
        # NOTE(review): cond_idx outside 0..3 leaves `label_size` unset and
        # would raise NameError below — confirm the flag is range-checked.
        if FLAGS.cond_idx == 0:
            label_size = 2
        elif FLAGS.cond_idx == 1:
            label_size = 1
        elif FLAGS.cond_idx == 2:
            label_size = 3
        elif FLAGS.cond_idx == 3:
            label_size = 20

        LABEL = tf.placeholder(shape=(None, label_size), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, label_size), dtype=tf.float32)
    elif FLAGS.dataset == 'color':
        dataset = CubesColor()
        test_dataset = dataset
        LABEL = tf.placeholder(shape=(None, 301), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 301), dtype=tf.float32)
        label_size = 301
    elif FLAGS.dataset == 'pos':
        dataset = CubesPos()
        test_dataset = dataset
        LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        label_size = 2
    elif FLAGS.dataset == "pairs":
        dataset = Pairs(cond_idx=0)
        test_dataset = dataset
        LABEL = tf.placeholder(shape=(None, 6), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 6), dtype=tf.float32)
        label_size = 6
    elif FLAGS.dataset == "continual":
        dataset = CubesContinual()
        test_dataset = dataset

        if FLAGS.prelearn_model_shape:
            LABEL = tf.placeholder(shape=(None, 20), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 20), dtype=tf.float32)
            label_size = 20
        else:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            label_size = 2
    elif FLAGS.dataset == "cross":
        dataset = CubesCrossProduct(FLAGS.ratio, cond_size=FLAGS.cond_size, cond_pos=FLAGS.cond_pos, joint_baseline=FLAGS.joint_baseline)
        test_dataset = dataset

        # NOTE(review): if none of cond_size / cond_pos / joint_baseline is
        # set, LABEL stays None and label_size unset — confirm flag usage.
        if FLAGS.cond_size:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            label_size = 1
        elif FLAGS.cond_pos:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            label_size = 2

        if FLAGS.joint_baseline:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            label_size = 3
    elif FLAGS.dataset == 'celeba':
        dataset = CelebA(cond_idx=FLAGS.celeba_cond_idx)
        test_dataset = dataset
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)

        model = ResNet128(
            num_channels=channel_num,
            num_filters=64,
            classes=2)

    if FLAGS.joint_baseline:
        # Other stuff for joint model
        # Joint baseline: a feed-forward generator trained with plain MSE;
        # the energy-related entries in target_vars are zero stubs.
        optimizer = AdamOptimizer(FLAGS.lr, beta1=0.99, beta2=0.999)

        X = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        X_NOISE = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        ATTENTION_MASK = tf.placeholder(shape=(None, 64, 64, FLAGS.cond_func), dtype=tf.float32)
        NOISE = tf.placeholder(shape=(None, 128), dtype=tf.float32)
        HIER_LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)

        channel_num = 3
        model = CubesNetGen(num_channels=channel_num, label_size=label_size)
        weights = model.construct_weights('context_0')
        output = model.forward(NOISE, weights, reuse=False, label=LABEL)
        print(output.get_shape())
        mse_loss = tf.reduce_mean(tf.square(output - X))
        gvs = optimizer.compute_gradients(mse_loss)
        train_op = optimizer.apply_gradients(gvs)
        gvs = [(k, v) for (k, v) in gvs if k is not None]

        target_vars = {}
        target_vars['train_op'] = train_op
        target_vars['X'] = X
        target_vars['X_NOISE'] = X_NOISE
        target_vars['ATTENTION_MASK'] = ATTENTION_MASK
        target_vars['eps_begin'] = tf.zeros(1)
        target_vars['gvs'] = gvs
        target_vars['energy_pos'] = tf.zeros(1)
        target_vars['energy_neg'] = tf.zeros(1)
        target_vars['loss_energy'] = tf.zeros(1)
        target_vars['loss_ml'] = tf.zeros(1)
        target_vars['total_loss'] = mse_loss
        target_vars['attention_mask'] = tf.zeros(1)
        target_vars['attention_grad'] = tf.zeros(1)
        target_vars['x_off'] = tf.reduce_mean(tf.abs(output - X))
        target_vars['x_mod'] = tf.zeros(1)
        target_vars['x_grad'] = tf.zeros(1)
        target_vars['NOISE'] = NOISE
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['HIER_LABEL'] = HIER_LABEL

        data_loader = DataLoader(
            dataset,
            batch_size=FLAGS.batch_size,
            num_workers=FLAGS.data_workers,
            drop_last=True,
            shuffle=True)
    else:
        print("label size here ", label_size)
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        # NOTE(review): spelled HEIR_LABEL here but HIER_LABEL in the joint
        # branch; both end up under the 'HIER_LABEL' key — confirm intended.
        HEIR_LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        ATTENTION_MASK = tf.placeholder(shape=(None, 64, 64, FLAGS.cond_func), dtype=tf.float32)

        if FLAGS.dataset != "celeba":
            model = CubesNet(num_channels=channel_num, label_size=label_size)

        heir_model = HeirNet(num_channels=FLAGS.cond_func)

        models_pretrain = []

        # Optionally load previously-trained component EBMs (scoped
        # context_1 / context_2) whose checkpoints were saved under
        # context_0 — hence the variable-name remapping for the Saver.
        if FLAGS.prelearn_model:
            model_prelearn = CubesNet(num_channels=channel_num, label_size=FLAGS.prelearn_label)
            weights = model_prelearn.construct_weights('context_1')
            LABEL_PRELEARN = tf.placeholder(shape=(None, FLAGS.prelearn_label), dtype=tf.float32)
            models_pretrain.append((model_prelearn, weights, LABEL_PRELEARN))

            cubes_logdir = osp.join(FLAGS.logdir, FLAGS.prelearn_exp)
            if (FLAGS.prelearn_iter != -1 or not FLAGS.train):
                model_file = osp.join(cubes_logdir, 'model_{}'.format(FLAGS.prelearn_iter))
                resume_itr = FLAGS.resume_iter
                # saver.restore(sess, model_file)

                v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(1))
                v_map = {(v.name.replace('context_{}'.format(1), 'context_0')[:-2]): v for v in v_list}
                saver = tf.train.Saver(v_map)
                saver.restore(sess, model_file)

        if FLAGS.prelearn_model_shape:
            model_prelearn = CubesNet(num_channels=channel_num, label_size=FLAGS.prelearn_label_shape)
            weights = model_prelearn.construct_weights('context_2')
            LABEL_PRELEARN = tf.placeholder(shape=(None, FLAGS.prelearn_label_shape), dtype=tf.float32)
            models_pretrain.append((model_prelearn, weights, LABEL_PRELEARN))

            cubes_logdir = osp.join(FLAGS.logdir, FLAGS.prelearn_exp_shape)
            if (FLAGS.prelearn_iter_shape != -1 or not FLAGS.train):
                model_file = osp.join(cubes_logdir, 'model_{}'.format(FLAGS.prelearn_iter_shape))
                resume_itr = FLAGS.resume_iter
                # saver.restore(sess, model_file)

                v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(2))
                v_map = {(v.name.replace('context_{}'.format(2), 'context_0')[:-2]): v for v in v_list}
                saver = tf.train.Saver(v_map)
                saver.restore(sess, model_file)

        print("Done loading...")

        data_loader = DataLoader(
            dataset,
            batch_size=FLAGS.batch_size,
            num_workers=FLAGS.data_workers,
            drop_last=True,
            shuffle=True)

        batch_size = FLAGS.batch_size

        weights = model.construct_weights('context_0')

        if FLAGS.heir_mask:
            weights = heir_model.construct_weights('heir_0', weights=weights)

        Y = tf.placeholder(shape=(None), dtype=tf.int32)

        # Varibles to run in training
        # Split the batch placeholders across GPUs (one graph tower each).
        X_SPLIT = tf.split(X, FLAGS.num_gpus)
        X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
        LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
        LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
        LABEL_SPLIT_INIT = list(LABEL_SPLIT)
        attention_mask = ATTENTION_MASK
        tower_grads = []
        tower_gen_grads = []
        x_mod_list = []

        optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.99)

        for j in range(FLAGS.num_gpus):
            x_mod = X_SPLIT[j]
            if FLAGS.comb_mask:
                # Refine the attention mask by Langevin dynamics on the
                # energy before computing the positive-sample energy.
                steps = tf.constant(0)
                c = lambda i, x: tf.less(i, FLAGS.num_steps)

                def langevin_attention_step(counter, attention_mask):
                    attention_mask = attention_mask + tf.random_normal(tf.shape(attention_mask), mean=0.0, stddev=0.01)
                    energy_noise = energy_start = model.forward(
                        x_mod,
                        weights,
                        attention_mask,
                        label=LABEL_SPLIT[j],
                        reuse=True,
                        stop_at_grad=False,
                        stop_batch=True)

                    if FLAGS.heir_mask:
                        energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                        energy_noise = energy_noise + energy_heir

                    attention_grad = tf.gradients(
                        FLAGS.temperature * energy_noise, [attention_mask])[0]
                    energy_noise_old = energy_noise

                    # Clip gradient norm for now
                    attention_mask = attention_mask - (FLAGS.attention_lr) * attention_grad
                    attention_mask = tf.layers.average_pooling2d(attention_mask, (3, 3), 1, padding='SAME')
                    attention_mask = tf.stop_gradient(attention_mask)

                    counter = counter + 1

                    return counter, attention_mask

                steps, attention_mask = tf.while_loop(c, langevin_attention_step, (steps, attention_mask))

                # attention_mask = tf.Print(attention_mask, [attention_mask])

                energy_pos = model.forward(
                    X_SPLIT[j],
                    weights,
                    tf.stop_gradient(attention_mask),
                    label=LABEL_POS_SPLIT[j],
                    stop_at_grad=False)

                if FLAGS.heir_mask:
                    energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                    energy_pos = energy_heir + energy_pos

            else:
                energy_pos = model.forward(
                    X_SPLIT[j],
                    weights,
                    attention_mask,
                    label=LABEL_POS_SPLIT[j],
                    stop_at_grad=False)

                if FLAGS.heir_mask:
                    energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                    energy_pos = energy_heir + energy_pos

            print("Building graph...")
            x_mod = x_orig = X_NOISE_SPLIT[j]

            x_grads = []
            loss_energys = []

            eps_begin = tf.zeros(1)

            steps = tf.constant(0)
            c_cond = lambda i, x, y: tf.less(i, FLAGS.num_steps)

            # Negative-sample generation: Langevin dynamics over the image
            # (and, when comb_mask is set, jointly over the attention mask).
            def langevin_step(counter, x_mod, attention_mask):
                lr = FLAGS.step_lr

                x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
                                                 mean=0.0,
                                                 stddev=0.001 * FLAGS.rescale * FLAGS.noise_scale)
                attention_mask = attention_mask + tf.random_normal(tf.shape(attention_mask), mean=0.0, stddev=0.01)

                energy_noise = model.forward(
                    x_mod,
                    weights,
                    attention_mask,
                    label=LABEL_SPLIT[j],
                    reuse=True,
                    stop_at_grad=False,
                    stop_batch=True)

                # Compose energies with any pre-learned component models.
                if FLAGS.prelearn_model:
                    for m_i, w_i, l_i in models_pretrain:
                        energy_noise = energy_noise + m_i.forward(
                            x_mod,
                            w_i,
                            attention_mask,
                            label=l_i,
                            reuse=True,
                            stop_at_grad=False,
                            stop_batch=True)

                if FLAGS.heir_mask:
                    energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                    energy_noise = energy_heir + energy_noise

                x_grad, attention_grad = tf.gradients(
                    FLAGS.temperature * energy_noise, [x_mod, attention_mask])

                if not FLAGS.comb_mask:
                    attention_grad = tf.zeros(1)

                energy_noise_old = energy_noise

                # Optional projection of the sampling gradient.
                if FLAGS.proj_norm != 0.0:
                    if FLAGS.proj_norm_type == 'l2':
                        x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                    elif FLAGS.proj_norm_type == 'li':
                        x_grad = tf.clip_by_value(
                            x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
                    else:
                        print("Other types of projection are not supported!!!")
                        assert False

                # Clip gradient norm for now
                x_last = x_mod - (lr) * x_grad

                if FLAGS.comb_mask:
                    attention_mask = attention_mask - FLAGS.attention_lr * attention_grad
                    attention_mask = tf.layers.average_pooling2d(attention_mask, (3, 3), 1, padding='SAME')
                    attention_mask = tf.stop_gradient(attention_mask)

                x_mod = x_last
                x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

                counter = counter + 1

                return counter, x_mod, attention_mask

            steps, x_mod, attention_mask = tf.while_loop(c_cond, langevin_step, (steps, x_mod, attention_mask))

            attention_mask = tf.stop_gradient(attention_mask)
            # attention_mask = tf.Print(attention_mask, [attention_mask])

            energy_eval = model.forward(x_mod, weights, attention_mask, label=LABEL_SPLIT[j],
                                        stop_at_grad=False, reuse=True)
            x_grad, attention_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod, attention_mask])
            x_grads.append(x_grad)

            energy_neg = model.forward(
                tf.stop_gradient(x_mod),
                weights,
                tf.stop_gradient(attention_mask),
                label=LABEL_SPLIT[j],
                stop_at_grad=False,
                reuse=True)

            if FLAGS.heir_mask:
                energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                energy_neg = energy_heir + energy_neg

            temp = FLAGS.temperature

            x_off = tf.reduce_mean(
                tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

            loss_energy = model.forward(
                x_mod,
                weights,
                attention_mask,
                reuse=True,
                label=LABEL,
                stop_grad=True)

            print("Finished processing loop construction ...")

            target_vars = {}

            if FLAGS.antialias:
                antialias = tf.tile(stride_3, (1, 1, tf.shape(x_mod)[3], tf.shape(x_mod)[3]))
                inp = tf.nn.conv2d(x_mod, antialias, [1, 2, 2, 1], padding='SAME')

            test_x_mod = x_mod

            if FLAGS.cclass or FLAGS.model_cclass:
                label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
                label_prob = label_sum / tf.reduce_sum(label_sum)
                label_ent = -tf.reduce_sum(label_prob * tf.math.log(label_prob + 1e-7))
            else:
                label_ent = tf.zeros(1)

            target_vars['label_ent'] = label_ent

            if FLAGS.train:
                # Training objectives: importance-weighted logsumexp,
                # contrastive divergence, or softplus margin.
                if FLAGS.objective == 'logsumexp':
                    pos_term = temp * energy_pos
                    energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                    coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                    norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                    pos_loss = tf.reduce_mean(temp * energy_pos)
                    neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                    loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
                elif FLAGS.objective == 'cd':
                    pos_loss = tf.reduce_mean(temp * energy_pos)
                    neg_loss = -tf.reduce_mean(temp * energy_neg)
                    loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
                elif FLAGS.objective == 'softplus':
                    loss_ml = FLAGS.ml_coeff * \
                        tf.nn.softplus(temp * (energy_pos - energy_neg))

                loss_total = tf.reduce_mean(loss_ml)

                if not FLAGS.zero_kl:
                    loss_total = loss_total + tf.reduce_mean(loss_energy)

                # L2 regularization on both energies keeps them bounded.
                loss_total = loss_total + \
                    FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))

                print("Started gradient computation...")
                gvs = optimizer.compute_gradients(loss_total)
                gvs = [(k, v) for (k, v) in gvs if k is not None]

                print("Applying gradients...")

                tower_grads.append(gvs)

                print("Finished applying gradients.")

                target_vars['loss_ml'] = loss_ml
                target_vars['total_loss'] = loss_total
                target_vars['loss_energy'] = loss_energy
                target_vars['weights'] = weights
                target_vars['gvs'] = gvs

            target_vars['X'] = X
            target_vars['Y'] = Y
            target_vars['LABEL'] = LABEL
            target_vars['HIER_LABEL'] = HEIR_LABEL
            target_vars['LABEL_POS'] = LABEL_POS
            target_vars['X_NOISE'] = X_NOISE
            target_vars['energy_pos'] = energy_pos
            target_vars['attention_grad'] = attention_grad

            if len(x_grads) >= 1:
                target_vars['x_grad'] = x_grads[-1]
                target_vars['x_grad_first'] = x_grads[0]
            else:
                target_vars['x_grad'] = tf.zeros(1)
                target_vars['x_grad_first'] = tf.zeros(1)

            target_vars['x_mod'] = x_mod
            target_vars['x_off'] = x_off
            target_vars['temp'] = temp
            target_vars['energy_neg'] = energy_neg
            target_vars['test_x_mod'] = test_x_mod
            target_vars['eps_begin'] = eps_begin
            target_vars['ATTENTION_MASK'] = ATTENTION_MASK
            target_vars['models_pretrain'] = models_pretrain
            if FLAGS.comb_mask:
                target_vars['attention_mask'] = tf.nn.softmax(attention_mask)
            else:
                target_vars['attention_mask'] = tf.zeros(1)

        if FLAGS.train:
            # Average gradients across towers and build the training op.
            grads = average_gradients(tower_grads)
            train_op = optimizer.apply_gradients(grads)
            target_vars['train_op'] = train_op

    # sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(
        max_to_keep=30, keep_checkpoint_every_n_hours=6)

    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    if (FLAGS.resume_iter != -1 or not FLAGS.train):
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        # saver.restore(sess, model_file)
        optimistic_restore(sess, model_file)

    print("Initializing variables...")

    print("Start broadcast")
    print("End broadcast")

    if FLAGS.train:
        train(target_vars, saver, sess, logger, data_loader, resume_itr, logdir)

    test(target_vars, saver, sess, logger, data_loader)
def main():
    """Train/evaluate an EBM on cifar10/imagenet/mnist/dsprites with Horovod.

    Rank 0 owns the logger and checkpoint restore; every rank builds the
    same per-GPU Langevin-sampling graph and participates in the
    distributed optimizer. FLAGS definitions live outside this chunk.
    """
    print("Local rank: ", hvd.local_rank(), hvd.size())

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    if hvd.rank() == 0:
        if not osp.exists(logdir):
            os.makedirs(logdir)
        logger = TensorBoardOutputFormat(logdir)
    else:
        logger = None

    LABEL = None
    print("Loading data...")
    # Dataset selection fixes input/label placeholder shapes and the model.
    if FLAGS.dataset == 'cifar10':
        dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale)
        test_dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3

        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        if FLAGS.large_model:
            model = ResNet32Large(
                num_channels=channel_num,
                num_filters=128,
                train=True)
        elif FLAGS.larger_model:
            model = ResNet32Larger(
                num_channels=channel_num,
                num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(
                num_channels=channel_num,
                num_filters=192)
        else:
            model = ResNet32(
                num_channels=channel_num,
                num_filters=128)
    elif FLAGS.dataset == 'imagenet':
        dataset = Imagenet(train=True)
        test_dataset = Imagenet(train=False)
        channel_num = 3

        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet32Wider(
            num_channels=channel_num,
            num_filters=256)
    elif FLAGS.dataset == 'imagenetfull':
        channel_num = 3

        X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet128(
            num_channels=channel_num,
            num_filters=64)
    elif FLAGS.dataset == 'mnist':
        dataset = Mnist(rescale=FLAGS.rescale)
        test_dataset = dataset
        channel_num = 1
        X_NOISE = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        model = MnistNet(
            num_channels=channel_num,
            num_filters=FLAGS.num_filters)
    elif FLAGS.dataset == 'dsprites':
        dataset = DSprites(
            cond_shape=FLAGS.cond_shape,
            cond_size=FLAGS.cond_size,
            cond_pos=FLAGS.cond_pos,
            cond_rot=FLAGS.cond_rot)
        test_dataset = dataset
        channel_num = 1

        X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

        # Label width depends on which latent factor is conditioned on.
        if FLAGS.dpos_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.dsize_only:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.drot_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_size:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.cond_shape:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
        elif FLAGS.cond_pos:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_rot:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        else:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)

        model = DspritesNet(
            num_channels=channel_num,
            num_filters=FLAGS.num_filters,
            cond_size=FLAGS.cond_size,
            cond_shape=FLAGS.cond_shape,
            cond_pos=FLAGS.cond_pos,
            cond_rot=FLAGS.cond_rot)

    print("Done loading...")

    if FLAGS.dataset == "imagenetfull":
        # In the case of full imagenet, use custom_tensorflow dataloader
        data_loader = TFImagenetLoader('train', FLAGS.batch_size, hvd.rank(), hvd.size(), rescale=FLAGS.rescale)
    else:
        data_loader = DataLoader(
            dataset,
            batch_size=FLAGS.batch_size,
            num_workers=FLAGS.data_workers,
            drop_last=True,
            shuffle=True)

    batch_size = FLAGS.batch_size

    weights = [model.construct_weights('context_0')]

    Y = tf.placeholder(shape=(None), dtype=tf.int32)

    # Varibles to run in training
    # Split the batch placeholders across GPUs (one graph tower each).
    X_SPLIT = tf.split(X, FLAGS.num_gpus)
    X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
    LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
    LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
    LABEL_SPLIT_INIT = list(LABEL_SPLIT)
    tower_grads = []
    tower_gen_grads = []
    x_mod_list = []

    optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
    optimizer = hvd.DistributedOptimizer(optimizer)

    for j in range(FLAGS.num_gpus):

        if FLAGS.model_cclass:
            # Tile each image across all 10 class labels and Gumbel-sample
            # a label from the resulting per-class energies.
            ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus
            label_tensor = tf.Variable(
                tf.convert_to_tensor(
                    np.reshape(
                        np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)),
                        (FLAGS.batch_size * 10, 10)),
                    dtype=tf.float32),
                trainable=False,
                dtype=tf.float32)
            x_split = tf.tile(
                tf.reshape(X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)), (1, 10, 1, 1, 1))
            x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3))
            energy_pos = model.forward(
                x_split,
                weights[0],
                label=label_tensor,
                stop_at_grad=False)

            energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10))
            energy_partition_est = tf.reduce_logsumexp(
                energy_pos_full, axis=1, keepdims=True)
            uniform = tf.random_uniform(tf.shape(energy_pos_full))
            label_tensor = tf.argmax(-energy_pos_full -
                                     tf.log(-tf.log(uniform)) - energy_partition_est, axis=1)
            label = tf.one_hot(label_tensor, 10, dtype=tf.float32)
            label = tf.Print(label, [label_tensor, energy_pos_full])
            LABEL_SPLIT[j] = label
            energy_pos = tf.concat(energy_pos, axis=0)
        else:
            energy_pos = [
                model.forward(
                    X_SPLIT[j],
                    weights[0],
                    label=LABEL_POS_SPLIT[j],
                    stop_at_grad=False)]
            energy_pos = tf.concat(energy_pos, axis=0)

        print("Building graph...")
        x_mod = x_orig = X_NOISE_SPLIT[j]

        x_grads = []

        energy_negs = []
        loss_energys = []

        energy_negs.extend([model.forward(tf.stop_gradient(
            x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)])
        eps_begin = tf.zeros(1)

        steps = tf.constant(0)
        c = lambda i, x: tf.less(i, FLAGS.num_steps)

        # Negative-sample generation: Langevin dynamics (or HMC) on x_mod.
        def langevin_step(counter, x_mod):
            x_mod = x_mod + tf.random_normal(
                tf.shape(x_mod),
                mean=0.0,
                stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)

            energy_noise = energy_start = tf.concat(
                [model.forward(
                        x_mod,
                        weights[0],
                        label=LABEL_SPLIT[j],
                        reuse=True,
                        stop_at_grad=False,
                        stop_batch=True)],
                axis=0)

            x_grad, label_grad = tf.gradients(
                FLAGS.temperature * energy_noise, [x_mod, LABEL_SPLIT[j]])
            energy_noise_old = energy_noise

            lr = FLAGS.step_lr

            # Optional projection of the sampling gradient.
            if FLAGS.proj_norm != 0.0:
                if FLAGS.proj_norm_type == 'l2':
                    x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                elif FLAGS.proj_norm_type == 'li':
                    x_grad = tf.clip_by_value(
                        x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
                else:
                    print("Other types of projection are not supported!!!")
                    assert False

            # Clip gradient norm for now
            if FLAGS.hmc:
                # Step size should be tuned to get around 65% acceptance
                def energy(x):
                    return FLAGS.temperature * \
                        model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True)

                x_last = hmc(x_mod, 15., 10, energy)
            else:
                x_last = x_mod - (lr) * x_grad

            x_mod = x_last
            x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

            counter = counter + 1

            return counter, x_mod

        steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod))

        energy_eval = model.forward(x_mod, weights[0], label=LABEL_SPLIT[j],
                                    stop_at_grad=False, reuse=True)
        x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0]
        x_grads.append(x_grad)

        energy_negs.append(
            model.forward(
                tf.stop_gradient(x_mod),
                weights[0],
                label=LABEL_SPLIT[j],
                stop_at_grad=False,
                reuse=True))

        test_x_mod = x_mod

        temp = FLAGS.temperature

        energy_neg = energy_negs[-1]
        x_off = tf.reduce_mean(
            tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

        loss_energy = model.forward(
            x_mod,
            weights[0],
            reuse=True,
            label=LABEL,
            stop_grad=True)

        print("Finished processing loop construction ...")

        target_vars = {}

        if FLAGS.cclass or FLAGS.model_cclass:
            label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
            label_prob = label_sum / tf.reduce_sum(label_sum)
            label_ent = -tf.reduce_sum(
                label_prob * tf.math.log(label_prob + 1e-7))
        else:
            label_ent = tf.zeros(1)

        target_vars['label_ent'] = label_ent

        if FLAGS.train:
            # Training objectives: importance-weighted logsumexp,
            # contrastive divergence, or softplus margin.
            if FLAGS.objective == 'logsumexp':
                pos_term = temp * energy_pos
                energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'cd':
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = -tf.reduce_mean(temp * energy_neg)
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'softplus':
                loss_ml = FLAGS.ml_coeff * \
                    tf.nn.softplus(temp * (energy_pos - energy_neg))

            loss_total = tf.reduce_mean(loss_ml)

            if not FLAGS.zero_kl:
                loss_total = loss_total + tf.reduce_mean(loss_energy)

            # L2 regularization on both energies keeps them bounded.
            loss_total = loss_total + \
                FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))

            print("Started gradient computation...")
            gvs = optimizer.compute_gradients(loss_total)
            gvs = [(k, v) for (k, v) in gvs if k is not None]

            print("Applying gradients...")

            tower_grads.append(gvs)

            print("Finished applying gradients.")

            target_vars['loss_ml'] = loss_ml
            target_vars['total_loss'] = loss_total
            target_vars['loss_energy'] = loss_energy
            target_vars['weights'] = weights
            target_vars['gvs'] = gvs

        target_vars['X'] = X
        target_vars['Y'] = Y
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['X_NOISE'] = X_NOISE
        target_vars['energy_pos'] = energy_pos
        target_vars['energy_start'] = energy_negs[0]

        if len(x_grads) >= 1:
            target_vars['x_grad'] = x_grads[-1]
            target_vars['x_grad_first'] = x_grads[0]
        else:
            target_vars['x_grad'] = tf.zeros(1)
            target_vars['x_grad_first'] = tf.zeros(1)

        target_vars['x_mod'] = x_mod
        target_vars['x_off'] = x_off
        target_vars['temp'] = temp
        target_vars['energy_neg'] = energy_neg
        target_vars['test_x_mod'] = test_x_mod
        target_vars['eps_begin'] = eps_begin

    if FLAGS.train:
        # Average gradients across towers and build the training op.
        grads = average_gradients(tower_grads)
        train_op = optimizer.apply_gradients(grads)
        target_vars['train_op'] = train_op

    config = tf.ConfigProto()

    if hvd.size() > 1:
        config.gpu_options.visible_device_list = str(hvd.local_rank())

    sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(max_to_keep=30, keep_checkpoint_every_n_hours=6)

    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    # Only rank 0 restores; the broadcast below syncs the other ranks.
    if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        # saver.restore(sess, model_file)
        optimistic_restore(sess, model_file)

    sess.run(hvd.broadcast_global_variables(0))
    print("Initializing variables...")

    print("Start broadcast")
    print("End broadcast")

    if FLAGS.train:
        print("Training phase")
        train(target_vars, saver, sess, logger, data_loader, resume_itr, logdir)
    print("Testing phase")
    test(target_vars, saver, sess, logger, data_loader)
def main():
    """Estimate the EBM partition function via annealed importance sampling.

    Restores a trained model, computes positive-sample energies on the
    dataset, then runs forward and backward annealing chains through
    `ancestral_sample` to bound the log partition estimate.
    """
    # Initialize dataset
    if FLAGS.dataset == 'cifar10':
        dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3
        dim_input = 32 * 32 * 3
    elif FLAGS.dataset == 'imagenet':
        dataset = ImagenetClass()
        channel_num = 3
        dim_input = 64 * 64 * 3
    elif FLAGS.dataset == 'mnist':
        dataset = Mnist(train=False, rescale=FLAGS.rescale)
        channel_num = 1
        dim_input = 28 * 28 * 1
    elif FLAGS.dataset == 'dsprites':
        dataset = DSprites()
        channel_num = 1
        dim_input = 64 * 64 * 1
    elif FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
        dataset = Box2D()
        dim_output = 1

    data_loader = DataLoader(
        dataset,
        batch_size=FLAGS.batch_size,
        num_workers=FLAGS.data_workers,
        drop_last=False,
        shuffle=True)

    if FLAGS.dataset == 'mnist':
        model = MnistNet(num_channels=channel_num)
    elif FLAGS.dataset == 'cifar10':
        if FLAGS.large_model:
            model = ResNet32Large(num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(num_filters=192)
        else:
            model = ResNet32(num_channels=channel_num, num_filters=128)
    elif FLAGS.dataset == 'dsprites':
        model = DspritesNet(
            num_channels=channel_num,
            num_filters=FLAGS.num_filters)

    weights = model.construct_weights('context_{}'.format(0))

    config = tf.ConfigProto()
    sess = tf.Session(config=config)
    saver = loader = tf.train.Saver(max_to_keep=10)

    sess.run(tf.global_variables_initializer())
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
    resume_itr = FLAGS.resume_iter

    # NOTE(review): compares against the string "-1"; other mains in this
    # file compare FLAGS.resume_iter to the integer -1 — confirm the flag's
    # declared type, otherwise this branch is always taken.
    if FLAGS.resume_iter != "-1":
        optimistic_restore(sess, model_file)
    else:
        print("WARNING, YOU ARE NOT LOADING A SAVE FILE")
    # saver.restore(sess, model_file)

    chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(
        model, weights, FLAGS.batch_size, temp=FLAGS.temperature)
    print("Finished constructing ancestral sample ...................")

    if FLAGS.dataset != "gauss":
        comb_weights_cum = []
        batch_size = tf.shape(x_init)[0]
        label_tiled = tf.tile(label_default, (batch_size, 1))
        e_compute = -FLAGS.temperature * model.forward(
            x_init, weights, label=label_tiled)
        e_pos_list = []

        # Average energy of real data under the restored model.
        for data_corrupt, data, label_gt in tqdm(data_loader):
            e_pos = sess.run([e_compute], {x_init: data})[0]
            e_pos_list.extend(list(e_pos))

        print(len(e_pos_list))
        print("Positive sample probability ", np.mean(e_pos_list), np.std(e_pos_list))

    # Per-dataset annealing learning rate (hand-tuned constants).
    if FLAGS.dataset == "2d":
        alr = 0.0045
    elif FLAGS.dataset == "gauss":
        alr = 0.0085
    elif FLAGS.dataset == "mnist":
        alr = 0.0065
        #90 alr = 0.0035
    else:
        # alr = 0.0125
        if FLAGS.rescale == 8:
            alr = 0.0085
        else:
            alr = 0.0045

    # for i in range(1):
    # Forward annealing pass: accumulate chain weights from alpha 0 -> 1.
    tot_weight = 0
    for j in tqdm(range(1, FLAGS.pdist + 1)):
        if j == 1:
            # Initialize the chain from uniform noise of the dataset shape.
            if FLAGS.dataset == "cifar10":
                x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3))
            elif FLAGS.dataset == "gauss":
                x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim))
            elif FLAGS.dataset == "mnist":
                x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28))
            else:
                x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 2))

        alpha_prev = (j - 1) / FLAGS.pdist
        alpha_new = j / FLAGS.pdist
        cweight, x_curr = sess.run(
            [chain_weights, x],
            {a_prev: alpha_prev,
             a_new: alpha_new,
             x_init: x_curr,
             approx_lr: alr * (5**(2.5 * -alpha_prev))})
        tot_weight = tot_weight + cweight

    print("Total values of lower value based off forward sampling", np.mean(tot_weight), np.std(tot_weight))

    # Backward annealing pass (alpha 1 -> 0) subtracts chain weights,
    # continuing from the final forward-chain state.
    tot_weight = 0

    for j in tqdm(range(FLAGS.pdist, 0, -1)):
        alpha_new = (j - 1) / FLAGS.pdist
        alpha_prev = j / FLAGS.pdist
        cweight, x_curr = sess.run(
            [chain_weights, x],
            {a_prev: alpha_prev,
             a_new: alpha_new,
             x_init: x_curr,
             approx_lr: alr * (5**(2.5 * -alpha_prev))})
        tot_weight = tot_weight - cweight

    print("Total values of upper value based off backward sampling", np.mean(tot_weight), np.std(tot_weight))
def main():
    """Dispatch one EBM evaluation task selected by FLAGS.task.

    Builds the dataset and energy model, constructs the task-specific graph
    (label classification, energy evaluation, corruption recovery, latent
    analysis, ...), restores a checkpoint, then runs the chosen evaluation
    routine. Relies on module globals: FLAGS, dataset/model classes, the
    construct_* and task functions, osp/os/np/tf, and optimistic_restore.
    """
    # Dataset selection; later flags override earlier choices.
    if FLAGS.dataset == "cifar10":
        dataset = Cifar10(train=True, noise=False)
        test_dataset = Cifar10(train=False, noise=False)
    else:
        dataset = Imagenet(train=True)
        test_dataset = Imagenet(train=False)

    if FLAGS.svhn:
        dataset = Svhn(train=True)
        test_dataset = Svhn(train=False)

    if FLAGS.task == 'latent':
        dataset = DSprites()
        test_dataset = dataset

    dataloader = DataLoader(dataset, batch_size=FLAGS.batch_size,
                            num_workers=FLAGS.data_workers, shuffle=True,
                            drop_last=True)
    test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size,
                                 num_workers=FLAGS.data_workers, shuffle=True,
                                 drop_last=True)

    # Model selection by capacity flags.
    hidden_dim = 128
    if FLAGS.large_model:
        model = ResNet32Large(num_filters=hidden_dim)
    elif FLAGS.larger_model:
        model = ResNet32Larger(num_filters=hidden_dim)
    elif FLAGS.wider_model:
        if FLAGS.dataset == 'imagenet':
            model = ResNet32Wider(num_filters=196, train=False)
        else:
            model = ResNet32Wider(num_filters=256, train=False)
    else:
        model = ResNet32(num_filters=hidden_dim)

    if FLAGS.task == 'latent':
        model = DspritesNet()

    weights = model.construct_weights('context_{}'.format(0))

    # Count trainable parameters for logging.
    total_parameters = 0
    for variable in tf.compat.v1.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    config = tf.compat.v1.ConfigProto()
    # NOTE(review): `config` is built but not passed to the session — confirm
    # whether InteractiveSession(config=config) was intended.
    sess = tf.compat.v1.InteractiveSession()

    # Input placeholders; latent task uses 64x64 grayscale, otherwise CIFAR RGB.
    if FLAGS.task == 'latent':
        X = tf.compat.v1.placeholder(shape=(None, 64, 64), dtype=tf.float32)
    else:
        X = tf.compat.v1.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)

    # Label placeholders (predicted Y and ground truth Y_GT), one-hot sized.
    if FLAGS.dataset == "cifar10":
        Y = tf.compat.v1.placeholder(shape=(None, 10), dtype=tf.float32)
        Y_GT = tf.compat.v1.placeholder(shape=(None, 10), dtype=tf.float32)
    elif FLAGS.dataset == "imagenet":
        Y = tf.compat.v1.placeholder(shape=(None, 1000), dtype=tf.float32)
        Y_GT = tf.compat.v1.placeholder(shape=(None, 1000), dtype=tf.float32)

    target_vars = {'X': X, 'Y': Y, 'Y_GT': Y_GT}

    # Build the task-specific evaluation graph; each construct_* populates
    # target_vars with the tensors the matching run routine needs.
    if FLAGS.task == 'label':
        construct_label(weights, X, Y, Y_GT, model, target_vars)
    elif FLAGS.task == 'labelfinetune':
        construct_finetune_label(weights, X, Y, Y_GT, model, target_vars,)
    elif FLAGS.task == 'energyeval' or FLAGS.task == 'mixenergy':
        construct_energy(weights, X, Y, Y_GT, model, target_vars)
    elif FLAGS.task == 'anticorrupt' or FLAGS.task == 'boxcorrupt' or FLAGS.task == 'crossclass' or FLAGS.task == 'cycleclass' or FLAGS.task == 'democlass' or FLAGS.task == 'nearestneighbor':
        construct_steps(weights, X, Y_GT, model, target_vars)
    elif FLAGS.task == 'latent':
        construct_latent(weights, X, Y_GT, model, target_vars)

    sess.run(tf.compat.v1.global_variables_initializer())
    saver = loader = tf.compat.v1.train.Saver(max_to_keep=10)
    savedir = osp.join('cachedir', FLAGS.exp)
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    if not osp.exists(logdir):
        os.makedirs(logdir)
    initialize()

    # Restore: some tasks tolerate partial restores (optimistic_restore),
    # others need an exact match (saver.restore).
    if FLAGS.resume_iter != -1:
        model_file = osp.join(savedir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        if FLAGS.task == 'label' or FLAGS.task == 'boxcorrupt' or FLAGS.task == 'labelfinetune' or FLAGS.task == "energyeval" or FLAGS.task == "crossclass" or FLAGS.task == "mixenergy":
            optimistic_restore(sess, model_file)
            # saver.restore(sess, model_file)
        else:
            # optimistic_restore(sess, model_file)
            saver.restore(sess, model_file)

    # Run the selected evaluation routine.
    if FLAGS.task == 'label':
        if FLAGS.labelgrid:
            # Sweep the regularization strength and save the accuracy grid.
            vals = []
            if FLAGS.lnorm == -1:
                for i in range(31):
                    accuracies = label(dataloader, test_dataloader,
                                       target_vars, sess, l1val=i)
                    vals.append(accuracies)
            elif FLAGS.lnorm == 2:
                for i in range(0, 100, 5):
                    accuracies = label(dataloader, test_dataloader,
                                       target_vars, sess, l2val=i)
                    vals.append(accuracies)
            np.save("result_{}_{}.npy".format(FLAGS.lnorm, FLAGS.exp), vals)
        else:
            label(dataloader, test_dataloader, target_vars, sess)
    elif FLAGS.task == 'labelfinetune':
        # NOTE(review): `FLAGS.lival` looks like a typo for `FLAGS.l1val`
        # (cf. l2val just below) — confirm against the flag definitions.
        labelfinetune(dataloader, test_dataloader, target_vars, sess,
                      savedir, saver, l1val=FLAGS.lival, l2val=FLAGS.l2val)
    elif FLAGS.task == 'energyeval':
        energyeval(dataloader, test_dataloader, target_vars, sess)
    elif FLAGS.task == 'mixenergy':
        energyevalmix(dataloader, test_dataloader, target_vars, sess)
    elif FLAGS.task == 'anticorrupt':
        anticorrupt(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'boxcorrupt':
        # boxcorrupt(test_dataloader, weights, model, target_vars, logdir, sess)
        boxcorrupt(test_dataloader, dataloader, weights, model, target_vars,
                   logdir, sess)
    elif FLAGS.task == 'crossclass':
        crossclass(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'cycleclass':
        cycleclass(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'democlass':
        democlass(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'nearestneighbor':
        # print(dir(dataset))
        # print(type(dataset))
        nearest_neighbor(dataset.data.train_data / 255, sess, target_vars,
                         logdir)
    elif FLAGS.task == 'latent':
        latent(test_dataloader, weights, model, target_vars, sess)
def main(argv=()):
    """Train a staged multi-GPU detection model on Visual Genome data.

    Builds init/train/val input pipelines, instantiates a shared model
    template across GPUs, then runs a training loop with periodic summary
    logging, validation, detection visualization, and checkpointing.
    Relies on module globals: FLAGS, utils, models, queues, os/np/tf,
    and defaultdict.
    """
    del argv
    batch_size = FLAGS.batch_size_per_gpu * FLAGS.num_gpus

    # Three data streams: a small "init" batch for data-dependent init,
    # plus the regular train and val pipelines.
    data_stream_init = utils.setup_data_stream_genome(
        "train", batch_size=FLAGS.init_batch_size, image_res=FLAGS.image_res,)
    (image_init_batch, class_init_batch, box_init_batch) = data_stream_init
    data_stream_train = utils.setup_data_stream_genome(
        "train", batch_size=batch_size, image_res=FLAGS.image_res)
    (image_train_batch, class_train_batch, box_train_batch) = data_stream_train
    data_stream_val = utils.setup_data_stream_genome(
        "val", batch_size=batch_size, image_res=FLAGS.image_res)
    (image_val_batch, class_val_batch, box_val_batch) = data_stream_val

    def model_template(images, labels, boxes, stage):
        # Thin wrapper so tf.make_template can share variables across calls.
        return models.model_detection(images, labels, boxes, stage)

    model_factory = tf.make_template("detection", model_template)

    # HACK: mode flags are smuggled to the model via attributes on the tf
    # module itself (tf.GLOBAL) rather than through arguments.
    tf.GLOBAL = {}
    # Init pass (data-dependent initialization, no dropout) on CPU.
    tf.GLOBAL["init"] = True
    tf.GLOBAL["dropout"] = 0.0
    with tf.device("/cpu:0"):
        _ = model_factory(image_init_batch, [class_init_batch],
                          box_init_batch, 0)

    ## Train graph: split the batch across GPUs and average the losses.
    tf.GLOBAL["init"] = False
    tf.GLOBAL["dropout"] = 0.5
    imgs_train = tf.split(image_train_batch, FLAGS.num_gpus, 0)
    class_train = tf.split(class_train_batch, FLAGS.num_gpus, 0)
    boxes_train = tf.split(box_train_batch, FLAGS.num_gpus, 0)
    min_stage = tf.placeholder(shape=[], dtype=tf.int32)
    # Each step trains a random stage in [min_stage, 5).
    stage_train = tf.random_uniform([], min_stage, 5, dtype=tf.int32)
    loss_train = 0.0
    for i in range(FLAGS.num_gpus):
        with tf.device("gpu:%i" % i if FLAGS.mode == "gpu" else "/cpu:0"):
            _, loss = model_factory(imgs_train[i], [class_train[i]],
                                    boxes_train[i], stage_train)
            loss_train = loss_train + loss
    loss_train /= FLAGS.num_gpus

    # Optimization
    learning_rate = tf.Variable(0.0001)
    update_lr = learning_rate.assign(FLAGS.decay * learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate, 0.95, 0.9995)
    train_step = optimizer.minimize(loss_train,
                                    colocate_gradients_with_ops=True)
    # Per-stage training loss summaries, fed through a placeholder.
    train_bpd_ph = tf.placeholder(shape=[], dtype=tf.float32)
    summary_train = {
        i: tf.summary.scalar("train_bpd_stage%i" % i, train_bpd_ph)
        for i in range(5)
    }

    ## Val graph (no dropout); also collects per-GPU predictions.
    tf.GLOBAL["init"] = False
    tf.GLOBAL["dropout"] = 0.0
    imgs_val = tf.split(image_val_batch, FLAGS.num_gpus, 0)
    class_val = tf.split(class_val_batch, FLAGS.num_gpus, 0)
    boxes_val = tf.split(box_val_batch, FLAGS.num_gpus, 0)
    stage_val = tf.random_uniform([], 0, 5, dtype=tf.int32)
    loss_val = 0.0
    label_p_val, point_p_val = [], []
    for i in range(FLAGS.num_gpus):
        with tf.device("gpu:%i" % i if FLAGS.mode == "gpu" else "/cpu:0"):
            [label_p_v, point_p_v], loss = model_factory(
                imgs_val[i], [class_val[i]], boxes_val[i], stage_val)
            loss_val = loss_val + loss
            label_p_val.append(label_p_v)
            point_p_val.append(point_p_v)
    loss_val /= FLAGS.num_gpus
    # Re-join the per-GPU prediction lists along the batch axis.
    label_p_val = [tf.concat(l, axis=0) for l in zip(*label_p_val)]
    point_p_val = [tf.concat(l, axis=0) for l in zip(*point_p_val)]
    val_bpd_ph = tf.placeholder(shape=[], dtype=tf.float32)
    summary_val = {
        i: tf.summary.scalar("val_bpd_stage%i" % i, val_bpd_ph)
        for i in range(5)
    }

    # Counters
    global_step, val_step = tf.Variable(1), tf.Variable(1)
    update_global_step = global_step.assign_add(1)
    update_val_step = val_step.assign_add(1)

    ## Inits: partition variables into image_parser / detector / rest, so
    ## they can be initialized in separate sess.run calls.
    var_init_1 = [
        v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if v.name.find("image_parser") >= 0
    ]
    var_init_2 = [
        v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if v.name.find("detector") >= 0
    ]
    var_rest = list(
        set(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)) -
        set(var_init_1 + var_init_2))
    init_ops = [
        tf.initialize_variables(v_l)
        for v_l in [var_init_1, var_init_2, var_rest]
    ]

    ####
    image_summary_placeholder = tf.placeholder(dtype=tf.float32)
    image_summary_sample_val = tf.summary.image("validation_samples",
                                                image_summary_placeholder,
                                                max_outputs=256)
    saver = tf.train.Saver()
    # tf.get_default_graph().finalize()

    with tf.Session() as sess:
        with queues.QueueRunners(sess):
            default_model_meta = os.path.join(FLAGS.tb_log_dir, "main",
                                              "model.ckpt.meta")
            default_model_file = os.path.join(FLAGS.tb_log_dir, "main",
                                              "model.ckpt")
            rerun = False
            # Resume from an existing checkpoint if one is present.
            if tf.gfile.Exists(default_model_meta):
                print("Model is loading...")
                saver.restore(sess, default_model_file)
                rerun = True
            else:
                # Initialization (Due to the bug in tensorflow it is split
                # into multiple steps)
                _ = [sess.run(init_op) for init_op in init_ops]
                if FLAGS.use_pretrained:
                    # NOTE(review): restore path is an empty string — looks
                    # like a placeholder to be filled in; confirm.
                    utils.optimistic_restore(sess, "")
                sess.run(global_step.assign(1))
                sess.run(val_step.assign(1))
                sess.run(learning_rate.assign(0.0001))

            # Summary writers
            summary_writer_main = tf.summary.FileWriter(
                "%s/%s" % (FLAGS.tb_log_dir, "main"), sess.graph)

            # Visualize validation ground truth once per fresh run.
            # NOTE(review): imgs_sample is only defined when `not rerun`, but
            # the validation-visualization branch below reads it — on a
            # resumed run that branch would raise NameError; confirm.
            if not rerun:
                (imgs_sample, box_cls_sample, boxes_sample) = sess.run(
                    [image_val_batch, class_val_batch, box_val_batch])
                boxes_sample = np.concatenate([
                    boxes_sample[..., :2][:, None],
                    boxes_sample[..., 2:][:, None]
                ], 1)
                imgs_with_box = utils.visualize(imgs_sample, box_cls_sample,
                                                boxes_sample, utils.LABEL_MAP)
                s = sess.run(
                    image_summary_sample_val,
                    {image_summary_placeholder: np.array(imgs_with_box)})
                summary_writer_main.add_summary(s, 0)

            # Run training
            n_iter_train = (utils.SST_COUNTS["train"] // batch_size
                            if FLAGS.iter_cap <= 0 else FLAGS.iter_cap)
            n_iter_val = (utils.SST_COUNTS["val"] // batch_size
                          if FLAGS.iter_cap <= 0 else FLAGS.iter_cap)
            max_iter = FLAGS.num_epochs * n_iter_train
            buf_loss = defaultdict(list)  # per-stage loss buffer for averaging
            val_i = 0
            while True and (not FLAGS.run_test):
                # Training step
                (_, loss_v, stage_v, train_i, val_i) = sess.run([
                    train_step, loss_train, stage_train, global_step, val_step
                ], {min_stage: 0})
                buf_loss[stage_v].append(loss_v)
                # Update global counter and learning rate
                sess.run([update_global_step, update_lr])
                # Log training error
                if train_i % FLAGS.log_training_loss == 0:
                    for i in range(5):
                        s = sess.run(summary_train[i],
                                     {train_bpd_ph: np.mean(buf_loss[i])})
                        summary_writer_main.add_summary(s, train_i)
                    buf_loss = defaultdict(list)
                # Log val error and visualize samples
                if train_i % FLAGS.log_val_loss == 0:
                    buf_loss = defaultdict(list)
                    for i in range(n_iter_val):
                        loss_v, stage_v = sess.run([loss_val, stage_val])
                        buf_loss[stage_v].append(loss_v)
                    for i in range(5):
                        s = sess.run(summary_val[i],
                                     {val_bpd_ph: np.mean(buf_loss[i])})
                        summary_writer_main.add_summary(s, val_i)
                    buf_loss = defaultdict(list)
                    # Sample detections
                    label_np = np.zeros((batch_size, 41))
                    boxes_np = np.zeros((batch_size, 56, 56, 4))
                    # stage 0: predict the class, then one-hot it.
                    l = sess.run(
                        label_p_val, {
                            image_val_batch: imgs_sample,
                            class_val_batch: label_np,
                            box_val_batch: boxes_np,
                            stage_val: 0
                        })[0]
                    l = np.argmax(l, axis=1)
                    label_np[range(batch_size), l] = 1
                    # stage 1..4: predict each box corner channel as the
                    # argmax heat-map location.
                    for ii in range(4):
                        l = sess.run(
                            point_p_val, {
                                image_val_batch: imgs_sample,
                                class_val_batch: label_np,
                                box_val_batch: boxes_np,
                                stage_val: ii + 1
                            })[ii]
                        l = (l == np.amax(l, axis=(1, 2),
                                          keepdims=True)).astype("int32")
                        boxes_np[:, :, :, ii:ii + 1] = l
                    # vis
                    boxes_np = np.concatenate([
                        boxes_np[..., :2][:, None], boxes_np[..., 2:][:, None]
                    ], 1)
                    imgs_with_box = utils.visualize(imgs_sample, label_np,
                                                    boxes_np, utils.LABEL_MAP)
                    # NOTE(review): creates a new summary op per validation
                    # round inside the loop — grows the graph over time.
                    image_summary_det = tf.summary.image(
                        "detection_samples%i" % val_i,
                        image_summary_placeholder,
                        max_outputs=256)
                    s = sess.run(
                        image_summary_det,
                        {image_summary_placeholder: np.array(imgs_with_box)})
                    summary_writer_main.add_summary(s, 0)
                    # Save model (latest + per-validation-round snapshot).
                    saver.save(
                        sess,
                        os.path.join(FLAGS.tb_log_dir, "main", "model.ckpt"))
                    saver.save(
                        sess,
                        os.path.join(FLAGS.tb_log_dir, "main",
                                     "model%i.ckpt" % val_i))
                    sess.run([update_val_step])
                # Terminate
                if train_i > max_iter:
                    break
            if FLAGS.run_test:
                pass
def train(args):
    """Train the StabNet video-stabilization network (with a PWC-Net flow
    sub-network) from TFRecords.

    Uses two Adam optimizers with independently scheduled learning rates:
    `opt` for the 'stabnet' variables and `pwcnet_opt` for the 'pwcnet'
    variables (the flow net is frozen for the first pwc_freeze_epoch epochs).
    Relies on module globals: config, checkpoint_path, timestamp, beta1/beta2,
    per_gpu_batch_size, pwc_opt, the *_epoch / *_ratio schedule constants,
    debug, and helpers (RecordReader, build_model, get_variables_with_name,
    average_gradients, linear_lr, get_config, optimistic_restore).
    """
    global num_gpu
    global batch_size
    num_gpu = 1  # single-GPU training is hard-coded here
    batch_size = per_gpu_batch_size * num_gpu
    lr_init = config.TRAIN.lr_init
    pwc_lr_init = config.TRAIN.pwc_lr_init
    record_reader = RecordReader(config.TRAIN.tf_records_path)

    # Two optimizers with separately schedulable learning-rate variables.
    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
        opt = tf.train.AdamOptimizer(lr_v, beta1=beta1, beta2=beta2)
        pwc_lr_v = tf.Variable(pwc_lr_init, trainable=False)
        pwcnet_opt = tf.train.AdamOptimizer(pwc_lr_v, beta1=beta1, beta2=beta2)

    # Pretrained VGG-19 weights for the perceptual loss inside build_model.
    vgg_data_dict = np.load(config.TRAIN.vgg19_npy_path,
                            encoding='latin1').item()

    first_img_t, mid_img_t, end_img_t, s_img_t = record_reader.read_and_decode()
    first_img_t_batch, mid_img_t_batch, end_img_t_batch, s_img_t_batch = tf.train.shuffle_batch(
        [first_img_t, mid_img_t, end_img_t, s_img_t],
        batch_size=batch_size,
        capacity=12000,
        min_after_dequeue=160,
        num_threads=4)

    # Tower-style graph construction (variables shared after the first tower).
    reuse_all = False
    tower_grads, tower_pwc_grads = [], []
    tower_loss = []
    for d in range(0, num_gpu):
        print("dealing {}th gpu".format(d))
        with tf.device('/gpu:%s' % d):
            with tf.name_scope('%s_%s' % ('tower', d)):
                print("build model!!!")
                tot_loss_gpu, summary \
                    = build_model(first_img_t_batch, mid_img_t_batch,
                                  end_img_t_batch, s_img_t_batch,
                                  vgg_data_dict, reuse_all=reuse_all)
                # Fetch the variable lists once (they are shared across
                # towers once reuse kicks in).
                if not reuse_all:
                    vars_trainable = get_variables_with_name(
                        name='stabnet', exclude_name='pwcnet', train_only=True)
                grads = opt.compute_gradients(tot_loss_gpu,
                                              var_list=vars_trainable)
                pwc_vars_trainable = get_variables_with_name(
                    name='pwcnet', exclude_name='stabnet', train_only=True)
                # NOTE(review): pwc gradients are computed via `opt`, not
                # `pwcnet_opt`; compute_gradients itself is optimizer-state
                # free so this is likely harmless (they are *applied* with
                # pwcnet_opt below), but confirm the intent.
                pwc_grads = opt.compute_gradients(
                    tot_loss_gpu, var_list=pwc_vars_trainable)
                # Gradient clipping by norm for both variable groups.
                for i, (g, v) in enumerate(grads):
                    if g is not None:
                        grads[i] = (tf.clip_by_norm(g, 5), v)
                for i, (g, v) in enumerate(pwc_grads):
                    if g is not None:
                        pwc_grads[i] = (tf.clip_by_norm(g, 5), v)
                tower_grads.append(grads)
                tower_pwc_grads.append(pwc_grads)
                tower_loss.append(tot_loss_gpu)
                reuse_all = True

    # Average tower losses/gradients and attach the batch-norm update ops.
    # (Both branches are identical apart from pinning the ops to gpu:0 in the
    # single-GPU case.)
    if num_gpu == 1:
        with tf.device('/gpu:0'):
            mse_loss = tf.reduce_mean(tf.stack(tower_loss, 0), 0)
            mean_grads = average_gradients(tower_grads)
            mean_pwc_grads = average_gradients(tower_pwc_grads)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           scope='.*?stabnet')
            update_pwc_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               scope='.*?pwcnet')
            with tf.control_dependencies(update_ops):
                minimize_op = opt.apply_gradients(mean_grads)
            with tf.control_dependencies(update_pwc_ops):
                minimize_pwc_op = pwcnet_opt.apply_gradients(
                    mean_pwc_grads)
    else:
        mse_loss = tf.reduce_mean(tf.stack(tower_loss, 0), 0)
        mean_grads = average_gradients(tower_grads)
        mean_pwc_grads = average_gradients(tower_pwc_grads)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       scope='.*?stabnet')
        update_pwc_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           scope='.*?pwcnet')
        with tf.control_dependencies(update_ops):
            minimize_op = opt.apply_gradients(mean_grads)
        with tf.control_dependencies(update_pwc_ops):
            minimize_pwc_op = pwcnet_opt.apply_gradients(mean_pwc_grads)

    print('trainable variables:')
    print(vars_trainable)
    print('pwc trainable variables:')
    print(pwc_vars_trainable)

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    if debug:
        sess = tf_debug.LocalCLIDebugWrapperSession(sess)
    # sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(max_to_keep=200)

    # Checkpoint/summary directory name encodes the run configuration.
    lr_str = timestamp + ' ' + get_config(config) + ',gn:{}'.format(num_gpu)
    if not os.path.exists(checkpoint_path + lr_str):
        os.makedirs(checkpoint_path + lr_str)

    if args.pretrained:
        print('restore path from : ',
              checkpoint_path + args.lr_str + '/stab.ckpt-' + str(args.modeli))
        saver.restore(
            sess,
            checkpoint_path + args.lr_str + '/stab.ckpt-' + str(args.modeli))

    summary_ops = tf.summary.merge(summary)
    summary_writer = tf.summary.FileWriter(
        checkpoint_path + lr_str + '/summary', sess.graph)
    len_train = config.TRAIN.len_train
    n_epoch = 250
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # Load the pretrained PWC-Net weights *after* the global init so they
    # are not clobbered by the initializer.
    optimistic_restore(sess, pwc_opt.ckpt_path)

    for epoch in range(0, n_epoch):
        # Learning-rate schedule for the flow net: frozen, then constant,
        # then linearly decayed. Later `if`s override earlier ones.
        if epoch < pwc_freeze_epoch:
            pwc_lr_init = 0.0
            sess.run(tf.assign(pwc_lr_v,
                               pwc_lr_init))  # freeze the optical flow net
            log = ' ** pwc net new learning rate: %f ' % (pwc_lr_init)
            print(log)
        if epoch >= pwc_freeze_epoch:
            pwc_lr_init = config.TRAIN.pwc_lr_init
            cur_lr = pwc_lr_init
            sess.run(tf.assign(pwc_lr_v, cur_lr))
            log = ' ** pwc net new learning rate: %f ' % (cur_lr)
            print(log)
        if epoch >= pwc_lr_stable_epoch:
            pwc_lr_init = config.TRAIN.pwc_lr_init
            cur_lr = linear_lr(pwc_lr_init, pwc_decay_ratio,
                               epoch - pwc_lr_stable_epoch)
            sess.run(tf.assign(pwc_lr_v, cur_lr))
            log = ' ** pwc net new learning rate: %f ' % (cur_lr)
            print(log)
        # Stab-net learning rate: linear decay after lr_stable_epoch.
        if epoch >= lr_stable_epoch:
            lr_init = config.TRAIN.lr_init
            cur_lr = linear_lr(lr_init, decay_ratio, epoch - lr_stable_epoch)
            sess.run(tf.assign(lr_v, cur_lr))
            log = ' ** stab net new learning rate: %f' % (cur_lr)
            print(log)
        sys.stdout.flush()

        epoch_time = time.time()
        for it in range(int(len_train / batch_size)):
            errM, _, _, summary = sess.run(
                [mse_loss, minimize_op, minimize_pwc_op, summary_ops])
            # Log every 10 global steps.
            if (it + int(len_train / batch_size) * epoch) % 10 == 0:
                summary_writer.add_summary(
                    summary, it + int(len_train / batch_size) * epoch)
                print("Epoch [%2d/%2d] %4d time: %4.4fs, loss: %5.5f" %
                      (epoch, n_epoch, it, time.time() - epoch_time, errM))
                sys.stdout.flush()
                epoch_time = time.time()
            # Checkpoint every 1000 global steps.
            if (it + int(len_train / batch_size) * epoch) % 1000 == 0:
                saver.save(sess,
                           checkpoint_path + lr_str + '/stab.ckpt',
                           global_step=(it +
                                        int(len_train / batch_size) * epoch))

    coord.request_stop()
    coord.join(threads)
    sess.close()
def main():
    """Train/evaluate a VQA model with a tree-LSTM relation sub-network.

    Parses CLI args, builds data loaders and the network, restores transfer-
    learning weights for the tree-generation sub-net, then runs the epoch
    loop, saving the best and per-epoch checkpoints (or, in --test mode, a
    submission-format results.json). Relies on module globals: data, model,
    config, utils, writer, exp_setting, run, torch/optim/lr_scheduler/cudnn,
    os, json.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('name', nargs='*')
    parser.add_argument('--eval', dest='eval_only', action='store_true')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--resume', nargs='*')
    args = parser.parse_args()
    if args.test:
        args.eval_only = True  # test mode never trains

    # Snapshot of the model source, stored inside each checkpoint.
    src = open('model.py').read()

    if args.name:
        name = ' '.join(args.name)
    else:
        from datetime import datetime
        name = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
    target_name = os.path.join('logs', '{}'.format(name))
    writer.add_text('Log Name:', name)
    if not args.test:
        # target_name won't be used in test mode
        print('will save to {}'.format(target_name))

    if args.resume:
        logs = torch.load(' '.join(args.resume))
        # hacky way to tell the VQA classes that they should use the vocab
        # without passing more params around
        #data.preloaded_vocab = logs['vocab']

    cudnn.benchmark = True

    if not args.eval_only:
        train_loader = data.get_loader(train=True)
    if not args.test:
        val_loader = data.get_loader(val=True)
    else:
        val_loader = data.get_loader(test=True)

    net = model.Net(val_loader.dataset.num_tokens).cuda()

    # restore transfer learning
    # 'data/vgrel-29.tar' for 36
    # 'data/vgrel-19.tar' for 10-100
    if config.output_size == 36:
        print("load data/vgrel-29(transfer36).tar")
        ckpt = torch.load('data/vgrel-29(transfer36).tar')
    else:
        print("load data/vgrel-19(transfer110).tar")
        ckpt = torch.load('data/vgrel-19(transfer110).tar')
    # Partial (name-matching) restore into the tree-generation sub-network.
    utils.optimistic_restore(net.tree_lstm.gen_tree_net, ckpt['state_dict'])

    if config.use_rl:
        # RL fine-tuning: train only the tree-generation sub-net.
        for p in net.parameters():
            p.requires_grad = False
        for p in net.tree_lstm.gen_tree_net.parameters():
            p.requires_grad = True

    optimizer = optim.Adam([p for p in net.parameters() if p.requires_grad],
                           lr=config.initial_lr)
    scheduler = lr_scheduler.ExponentialLR(optimizer,
                                           0.5**(1 / config.lr_halflife))

    start_epoch = 0
    if args.resume:
        net.load_state_dict(logs['weights'])
        #optimizer.load_state_dict(logs['optimizer'])
        start_epoch = int(logs['epoch']) + 1

    tracker = utils.Tracker()
    config_as_dict = {
        k: v for k, v in vars(config).items() if not k.startswith('__')
    }
    print(config_as_dict)

    best_accuracy = -1
    for i in range(start_epoch, config.epochs):
        if not args.eval_only:
            run(net, train_loader, optimizer, scheduler, tracker, train=True,
                prefix='train', epoch=i)
        # NOTE(review): `i % 1 != 0` is always False, so validation is only
        # skipped for epochs 1..19 — confirm whether `i % 1` was meant to be
        # a larger modulus.
        if i % 1 != 0 or (i > 0 and i < 20):
            r = [[-1], [-1], [-1]]  # placeholder [answers, accuracies, idx]
        else:
            r = run(net, val_loader, optimizer, scheduler, tracker,
                    train=False, prefix='val', epoch=i,
                    has_answers=not args.test)
        if not args.test:
            # Full training checkpoint: weights + optimizer + eval results.
            results = {
                'name': name,
                'tracker': tracker.to_dict(),
                'config': config_as_dict,
                'weights': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': i,
                'eval': {
                    'answers': r[0],
                    'accuracies': r[1],
                    'idx': r[2],
                },
                'vocab': val_loader.dataset.vocab,
                'src': src,
                'setting': exp_setting,
            }
            current_ac = sum(r[1]) / len(r[1])
            if current_ac > best_accuracy:
                best_accuracy = current_ac
                print('update best model, current: ', current_ac)
                torch.save(results, target_name + '_best.pth')
            if i % 1 == 0:  # NOTE(review): always true — saves every epoch.
                torch.save(results, target_name + '_' + str(i) + '.pth')
        else:
            # in test mode, save a results file in the format accepted by the
            # submission server
            answer_index_to_string = {
                a: s for s, a in val_loader.dataset.answer_to_index.items()
            }
            results = []
            for answer, index in zip(r[0], r[2]):
                answer = answer_index_to_string[answer.item()]
                qid = val_loader.dataset.question_ids[index]
                entry = {
                    'question_id': qid,
                    'answer': answer,
                }
                results.append(entry)
            with open('results.json', 'w') as fd:
                json.dump(results, fd)
        if args.eval_only:
            break
# Fragment of a flow-warp smoke test: the setup that loads/defines img_first,
# img_mid, img_end and img_s precedes this chunk and is not visible here.
# Batches the frames, builds the stabilization model in inference mode, runs
# the warp outputs once, and writes them to PNG files (BGR channel flip for
# OpenCV, values presumably in [0, 1] — TODO confirm upstream preprocessing).
img_mid = np.expand_dims(img_mid, 0)
img_end = np.expand_dims(img_end, 0)
# warped = test_flow_warp(img_first, img_s)
img_int, img_out, [warped_first, warped_end, warped_mid] = training_stab_model(
    img_first, img_s, img_end, img_mid, training=False)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False))
sess.run(tf.global_variables_initializer())
# Must run after the global initializer, otherwise the restored PWC-Net
# weights would be overwritten.
optimistic_restore(sess, pwc_opt.ckpt_path)
[warped_first, warped_end] = sess.run([warped_first, warped_end])
# import pdb; pdb.set_trace();
# Drop the batch dim and flip RGB -> BGR for cv2.imwrite.
warped_first = warped_first[0][:, :, ::-1]
warped_end = warped_end[0][:, :, ::-1]
# warped_np = np.clip(warped_np, 0, 1.)
cv2.imwrite('warped_first.png',
            np.array(warped_first * 255).astype(np.uint8))
cv2.imwrite('warped_end.png', np.array(warped_end * 255).astype(np.uint8))
# test_training_model()
# test_testing_model()
# input_tensor_batch = tf.random_uniform(shape=[2, 16, 16, 3])
# # out = resnet(input_tensor_batch, 5)
# out = make_unet(input_tensor_batch)