def main():
    """Train the segment-reward RL model (ML + policy-gradient updates).

    Builds the training/sampling graphs, trains with a joint ML op and RL op
    per batch, logs per-step summaries, and evaluates cross-entropy on the
    validation set each epoch, checkpointing the best model by valid CE.

    FIX: the per-step `tr_fer` summary previously logged raw frame accuracy;
    it now logs the frame error rate (1 - accuracy), matching the printed
    `avg_tr_fer` below and the sibling supervised trainer in this file.
    """
    print(' '.join(sys.argv))
    args = get_args()
    print(args)
    print('Hostname: {}'.format(socket.gethostname()))
    print('GPU: {}'.format(get_gpuname()))

    # Fresh run: wipe the log dir so stale checkpoints/summaries don't mix in.
    if not args.start_from_ckpt:
        if tf.gfile.Exists(args.log_dir):
            tf.gfile.DeleteRecursively(args.log_dir)
        tf.gfile.MakeDirs(args.log_dir)

    tf.get_variable_scope()._reuse = None

    # Seed both TF and NumPy so episode sampling is reproducible.
    _seed = args.base_seed + args.add_seed
    tf.set_random_seed(_seed)
    np.random.seed(_seed)

    prefix_name = os.path.join(args.log_dir, 'model')
    file_name = '%s.npz' % prefix_name

    eval_summary = OrderedDict()

    # tg: training graph, sg: sampling graph used to generate episodes.
    tg, sg = build_graph(args)
    tg_ml_cost = tf.reduce_mean(tg.ml_cost)

    global_step = tf.Variable(0, trainable=False, name="global_step")

    tvars = tf.trainable_variables()
    print([tvar.name for tvar in tvars])

    ml_opt_func = tf.train.AdamOptimizer(learning_rate=args.learning_rate,
                                         beta1=0.9, beta2=0.99)
    rl_opt_func = tf.train.AdamOptimizer(learning_rate=args.rl_learning_rate,
                                         beta1=0.9, beta2=0.99)

    if args.grad_clip:
        ml_grads, _ = tf.clip_by_global_norm(tf.gradients(tg_ml_cost, tvars),
                                             clip_norm=1.0)
    else:
        ml_grads = tf.gradients(tg_ml_cost, tvars)
    ml_op = ml_opt_func.apply_gradients(zip(ml_grads, tvars),
                                        global_step=global_step)

    tg_rl_cost = tf.reduce_mean(tg.rl_cost)
    rl_grads = tf.gradients(tg_rl_cost, tvars)
    # do not increase global step -- ml op increases it
    rl_op = rl_opt_func.apply_gradients(zip(rl_grads, tvars))

    # Stored in the graph collection so inference-time loaders can read it.
    tf.add_to_collection('n_fast_action', args.n_fast_action)

    sync_data(args)
    datasets = [args.train_dataset, args.valid_dataset, args.test_dataset]
    train_set, valid_set, test_set = [
        create_ivector_datastream(path=args.data_path,
                                  which_set=dataset,
                                  batch_size=args.n_batch,
                                  min_after_cache=args.min_after_cache,
                                  length_sort=not args.no_length_sort)
        for dataset in datasets
    ]

    init_op = tf.global_variables_initializer()
    save_op = tf.train.Saver(max_to_keep=5)
    best_save_op = tf.train.Saver(max_to_keep=5)

    with tf.name_scope("per_step_eval"):
        tr_ce = tf.placeholder(tf.float32)
        tr_ce_summary = tf.summary.scalar("tr_ce", tr_ce)
        tr_image = tf.placeholder(tf.float32)
        tr_image_summary = tf.summary.image("tr_image", tr_image)
        tr_fer = tf.placeholder(tf.float32)
        tr_fer_summary = tf.summary.scalar("tr_fer", tr_fer)
        tr_rl = tf.placeholder(tf.float32)
        tr_rl_summary = tf.summary.scalar("tr_rl", tr_rl)
        tr_rw_hist = tf.placeholder(tf.float32)
        tr_rw_hist_summary = tf.summary.histogram("tr_reward_hist", tr_rw_hist)

    with tf.name_scope("per_epoch_eval"):
        best_val_ce = tf.placeholder(tf.float32)
        val_ce = tf.placeholder(tf.float32)
        best_val_ce_summary = tf.summary.scalar("best_valid_ce", best_val_ce)
        val_ce_summary = tf.summary.scalar("valid_ce", val_ce)

    # Linear value-function baseline used to compute advantages.
    vf = LinearVF()

    with tf.Session() as sess:
        sess.run(init_op)

        if args.start_from_ckpt:
            save_op = tf.train.import_meta_graph(
                os.path.join(args.log_dir, 'model.ckpt.meta'))
            save_op.restore(sess, os.path.join(args.log_dir, 'model.ckpt'))
            print("Restore from the last checkpoint. Restarting from %d step."
                  % global_step.eval())

        summary_writer = tf.summary.FileWriter(args.log_dir, sess.graph,
                                               flush_secs=5.0)

        # Running accumulators, reset every display interval.
        tr_ce_sum = 0.
        tr_ce_count = 0
        tr_acc_sum = 0
        tr_acc_count = 0
        tr_rl_costs = []
        tr_action_entropies = []
        tr_rewards = []

        # Sentinel larger than any plausible CE; best model tracked by valid CE.
        _best_score = np.iinfo(np.int32).max

        epoch_sw = StopWatch()
        disp_sw = StopWatch()
        eval_sw = StopWatch()
        per_sw = StopWatch()

        # For each epoch
        for _epoch in xrange(args.n_epoch):
            _n_exp = 0
            epoch_sw.reset()
            disp_sw.reset()
            print('Epoch {} training'.format(_epoch + 1))

            # For each batch
            for batch in train_set.get_epoch_iterator():
                # batch fields: (x, x_mask, _, _, y, _) -- presumably
                # features, mask, labels; confirm against the datastream.
                x, x_mask, _, _, y, _ = batch
                n_batch = x.shape[0]
                _n_exp += n_batch

                # Sample an episode (skip/read actions) with segment rewards.
                new_x, new_y, actions_1hot, rewards, action_entropies, new_x_mask, new_reward_mask, output_image = \
                    gen_episode_with_seg_reward(x, x_mask, y, sess, sg, args)
                advantages, _ = compute_advantage(new_x, new_x_mask, rewards,
                                                  new_reward_mask, vf, args)

                zero_state = gen_zero_state(n_batch, args.n_hidden)
                feed_dict = {tg.seq_x_data: new_x,
                             tg.seq_x_mask: new_x_mask,
                             tg.seq_y_data: new_y,
                             tg.seq_action: actions_1hot,
                             tg.seq_advantage: advantages,
                             tg.seq_action_mask: new_reward_mask,
                             tg.seq_y_data_for_action: new_y}
                feed_init_state(feed_dict, tg.init_state, zero_state)

                # One joint step: ML op (increments global_step) + RL op.
                _tr_ml_cost, _tr_rl_cost, _, _, pred_idx = \
                    sess.run([tg.ml_cost, tg.rl_cost, ml_op, rl_op, tg.pred_idx],
                             feed_dict=feed_dict)

                tr_ce_sum += _tr_ml_cost.sum()
                tr_ce_count += new_x_mask.sum()

                # Expand predictions on kept frames back to full-length frames.
                pred_idx = expand_output(actions_1hot, x_mask, new_x_mask,
                                         pred_idx.reshape([n_batch, -1]),
                                         args.n_fast_action)
                tr_acc_sum += ((pred_idx == y) * x_mask).sum()
                tr_acc_count += x_mask.sum()

                _tr_ce_summary, _tr_fer_summary, _tr_rl_summary, _tr_image_summary, _tr_rw_hist_summary = \
                    sess.run([tr_ce_summary, tr_fer_summary, tr_rl_summary,
                              tr_image_summary, tr_rw_hist_summary],
                             feed_dict={tr_ce: _tr_ml_cost.sum() / new_x_mask.sum(),
                                        # FIX: log error rate, not accuracy.
                                        tr_fer: 1. - ((pred_idx == y) * x_mask).sum() / x_mask.sum(),
                                        tr_rl: _tr_rl_cost.sum() / new_reward_mask.sum(),
                                        tr_image: output_image,
                                        tr_rw_hist: rewards})
                summary_writer.add_summary(_tr_ce_summary, global_step.eval())
                summary_writer.add_summary(_tr_fer_summary, global_step.eval())
                summary_writer.add_summary(_tr_rl_summary, global_step.eval())
                summary_writer.add_summary(_tr_image_summary, global_step.eval())
                summary_writer.add_summary(_tr_rw_hist_summary, global_step.eval())

                tr_rl_costs.append(_tr_rl_cost.sum() / new_reward_mask.sum())
                tr_action_entropies.append(action_entropies.sum() / new_reward_mask.sum())
                tr_rewards.append(rewards.sum() / new_reward_mask.sum())

                if global_step.eval() % args.display_freq == 0:
                    avg_tr_ce = tr_ce_sum / tr_ce_count
                    avg_tr_fer = 1. - tr_acc_sum / tr_acc_count
                    avg_tr_rl_cost = np.asarray(tr_rl_costs).mean()
                    avg_tr_action_entropy = np.asarray(tr_action_entropies).mean()
                    avg_tr_reward = np.asarray(tr_rewards).mean()
                    print("TRAIN: epoch={} iter={} ml_cost(ce/frame)={:.2f} fer={:.2f} rl_cost={:.4f} reward={:.4f} action_entropy={:.2f} time_taken={:.2f}"
                          .format(_epoch, global_step.eval(), avg_tr_ce,
                                  avg_tr_fer, avg_tr_rl_cost, avg_tr_reward,
                                  avg_tr_action_entropy, disp_sw.elapsed()))
                    # Reset interval accumulators.
                    tr_ce_sum = 0.
                    tr_ce_count = 0
                    tr_acc_sum = 0.
                    tr_acc_count = 0
                    tr_rl_costs = []
                    tr_action_entropies = []
                    tr_rewards = []
                    disp_sw.reset()

            print('--')
            print('End of epoch {}'.format(_epoch + 1))
            epoch_sw.print_elapsed()
            print('Testing')

            # Evaluate the model on the validation set
            val_ce_sum = 0.
            val_ce_count = 0
            val_acc_sum = 0
            val_acc_count = 0
            val_rl_costs = []
            val_action_entropies = []
            val_rewards = []
            eval_sw.reset()
            for batch in valid_set.get_epoch_iterator():
                x, x_mask, _, _, y, _ = batch
                n_batch = x.shape[0]

                # sample_y=True additionally returns sampled labels used to
                # drive the action inputs during evaluation.
                new_x, new_y, actions_1hot, rewards, action_entropies, new_x_mask, new_reward_mask, output_image, new_y_sample = \
                    gen_episode_with_seg_reward(x, x_mask, y, sess, sg, args,
                                                sample_y=True)
                advantages, _ = compute_advantage(new_x, new_x_mask, rewards,
                                                  new_reward_mask, vf, args)

                zero_state = gen_zero_state(n_batch, args.n_hidden)
                feed_dict = {tg.seq_x_data: new_x,
                             tg.seq_x_mask: new_x_mask,
                             tg.seq_y_data: new_y,
                             tg.seq_action: actions_1hot,
                             tg.seq_advantage: advantages,
                             tg.seq_action_mask: new_reward_mask,
                             tg.seq_y_data_for_action: new_y_sample}
                feed_init_state(feed_dict, tg.init_state, zero_state)

                # Evaluation only: no train ops in the fetch list.
                _val_ml_cost, _val_rl_cost, pred_idx = sess.run(
                    [tg.ml_cost, tg.rl_cost, tg.pred_idx], feed_dict=feed_dict)

                val_ce_sum += _val_ml_cost.sum()
                val_ce_count += new_x_mask.sum()

                pred_idx = expand_output(actions_1hot, x_mask, new_x_mask,
                                         pred_idx.reshape([n_batch, -1]),
                                         args.n_fast_action)
                val_acc_sum += ((pred_idx == y) * x_mask).sum()
                val_acc_count += x_mask.sum()

                val_rl_costs.append(_val_rl_cost.sum() / new_reward_mask.sum())
                val_action_entropies.append(action_entropies.sum() / new_reward_mask.sum())
                val_rewards.append(rewards.sum() / new_reward_mask.sum())

            avg_val_ce = val_ce_sum / val_ce_count
            avg_val_fer = 1. - val_acc_sum / val_acc_count
            avg_val_rl_cost = np.asarray(val_rl_costs).mean()
            avg_val_action_entropy = np.asarray(val_action_entropies).mean()
            avg_val_reward = np.asarray(val_rewards).mean()

            print("VALID: epoch={} ml_cost(ce/frame)={:.2f} fer={:.2f} rl_cost={:.4f} reward={:.4f} action_entropy={:.2f} time_taken={:.2f}"
                  .format(_epoch, avg_val_ce, avg_val_fer, avg_val_rl_cost,
                          avg_val_reward, avg_val_action_entropy,
                          eval_sw.elapsed()))

            _val_ce_summary, = sess.run([val_ce_summary],
                                        feed_dict={val_ce: avg_val_ce})
            summary_writer.add_summary(_val_ce_summary, global_step.eval())

            insert_item2dict(eval_summary, 'val_ce', avg_val_ce)
            insert_item2dict(eval_summary, 'val_rl_cost', avg_val_rl_cost)
            insert_item2dict(eval_summary, 'val_reward', avg_val_reward)
            insert_item2dict(eval_summary, 'val_action_entropy',
                             avg_val_action_entropy)
            insert_item2dict(eval_summary, 'time', eval_sw.elapsed())
            save_npz2(file_name, eval_summary)

            # Save model: best-by-valid-CE checkpoint, then the regular one.
            if avg_val_ce < _best_score:
                _best_score = avg_val_ce
                best_ckpt = best_save_op.save(sess,
                                              os.path.join(args.log_dir,
                                                           "best_model.ckpt"),
                                              global_step=global_step)
                print("Best checkpoint stored in: %s" % best_ckpt)
            ckpt = save_op.save(sess,
                                os.path.join(args.log_dir, "model.ckpt"),
                                global_step=global_step)
            print("Checkpoint stored in: %s" % ckpt)

            _best_val_ce_summary, = sess.run([best_val_ce_summary],
                                             feed_dict={best_val_ce: _best_score})
            summary_writer.add_summary(_best_val_ce_summary, global_step.eval())

        summary_writer.close()
        print("Optimization Finished.")
def main(_):
    """Train the fixed skip-frame baseline (pure ML, no RL).

    Subsamples every (n_skip+1)-th frame, trains the recognizer on the
    subsampled sequences, and scores full-length FER by repeating each
    prediction n_skip+1 times.

    FIX: the `build_graph(args)` call had been commented out, leaving `tg`
    undefined and crashing with NameError on the first use below; restored.
    """
    print(' '.join(sys.argv))
    args = FLAGS
    print(args.__flags)

    # Fresh run: wipe the log dir so stale checkpoints/summaries don't mix in.
    if not args.start_from_ckpt:
        if tf.gfile.Exists(args.log_dir):
            tf.gfile.DeleteRecursively(args.log_dir)
        tf.gfile.MakeDirs(args.log_dir)

    tf.get_variable_scope()._reuse = None

    # Seed both TF and NumPy for reproducibility.
    _seed = args.base_seed + args.add_seed
    tf.set_random_seed(_seed)
    np.random.seed(_seed)

    prefix_name = os.path.join(args.log_dir, 'model')
    file_name = '%s.npz' % prefix_name

    eval_summary = OrderedDict()

    # FIX: this call was commented out but `tg` is used throughout below.
    # `test_graph` is unused in this trainer but kept for interface parity
    # with the other trainers in this file.
    tg, test_graph = build_graph(args)
    tg_ml_cost = tf.reduce_mean(tg.ml_cost)

    global_step = tf.Variable(0, trainable=False, name="global_step")

    tvars = tf.trainable_variables()

    ml_opt_func = tf.train.AdamOptimizer(learning_rate=args.learning_rate,
                                         beta1=0.9, beta2=0.99)

    if args.grad_clip:
        ml_grads, _ = tf.clip_by_global_norm(tf.gradients(tg_ml_cost, tvars),
                                             clip_norm=1.0)
    else:
        ml_grads = tf.gradients(tg_ml_cost, tvars)
    ml_op = ml_opt_func.apply_gradients(zip(ml_grads, tvars),
                                        global_step=global_step)

    # Stored in the graph collection so inference-time loaders can read them.
    tf.add_to_collection('n_skip', args.n_skip)
    tf.add_to_collection('n_hidden', args.n_hidden)

    sync_data(args)
    datasets = [args.train_dataset, args.valid_dataset, args.test_dataset]
    train_set, valid_set, test_set = [
        create_ivector_datastream(path=args.data_path,
                                  which_set=dataset,
                                  batch_size=args.batch_size,
                                  min_after_cache=args.min_after_cache,
                                  length_sort=not args.no_length_sort)
        for dataset in datasets
    ]

    init_op = tf.global_variables_initializer()
    save_op = tf.train.Saver(max_to_keep=5)
    best_save_op = tf.train.Saver(max_to_keep=5)

    with tf.name_scope("per_step_eval"):
        tr_ce = tf.placeholder(tf.float32)
        tr_ce_summary = tf.summary.scalar("tr_ce", tr_ce)

    with tf.name_scope("per_epoch_eval"):
        best_val_ce = tf.placeholder(tf.float32)
        val_ce = tf.placeholder(tf.float32)
        best_val_ce_summary = tf.summary.scalar("best_valid_ce", best_val_ce)
        val_ce_summary = tf.summary.scalar("valid_ce", val_ce)

    with tf.Session() as sess:
        sess.run(init_op)

        if args.start_from_ckpt:
            save_op = tf.train.import_meta_graph(
                os.path.join(args.log_dir, 'model.ckpt.meta'))
            save_op.restore(sess, os.path.join(args.log_dir, 'model.ckpt'))
            print("Restore from the last checkpoint. "
                  "Restarting from %d step." % global_step.eval())

        summary_writer = tf.summary.FileWriter(args.log_dir, sess.graph,
                                               flush_secs=5.0)

        # Running accumulators, reset every display interval.
        tr_ce_sum = 0.
        tr_ce_count = 0
        tr_acc_sum = 0
        tr_acc_count = 0

        # Sentinel larger than any plausible CE; best model tracked by valid CE.
        _best_score = np.iinfo(np.int32).max

        epoch_sw = StopWatch()
        disp_sw = StopWatch()
        eval_sw = StopWatch()
        per_sw = StopWatch()

        # For each epoch
        for _epoch in xrange(args.n_epoch):
            _n_exp = 0
            epoch_sw.reset()
            disp_sw.reset()
            print('--')
            print('Epoch {} training'.format(_epoch + 1))

            # For each batch
            for batch in train_set.get_epoch_iterator():
                orig_x, orig_x_mask, _, _, orig_y, _ = batch

                # Train on each fixed-offset subsampled view of the batch.
                for sub_batch in skip_frames_fixed(
                        [orig_x, orig_x_mask, orig_y], args.n_skip + 1):
                    x, x_mask, y = sub_batch
                    n_batch, _, _ = x.shape
                    _n_exp += n_batch

                    zero_state = gen_zero_state(n_batch, args.n_hidden)
                    feed_dict = {tg.seq_x_data: x,
                                 tg.seq_x_mask: x_mask,
                                 tg.seq_y_data: y}
                    feed_init_state(feed_dict, tg.init_state, zero_state)

                    _tr_ml_cost, _pred_idx, _ = sess.run(
                        [tg.ml_cost, tg.pred_idx, ml_op], feed_dict=feed_dict)

                    tr_ce_sum += _tr_ml_cost.sum()
                    tr_ce_count += x_mask.sum()
                    _tr_ce_summary, = sess.run(
                        [tr_ce_summary],
                        feed_dict={tr_ce: _tr_ml_cost.sum() / x_mask.sum()})
                    summary_writer.add_summary(_tr_ce_summary,
                                               global_step.eval())

                    # Repeat each subsampled prediction to full frame rate,
                    # then truncate to the original sequence length.
                    _, n_seq = orig_y.shape
                    _pred_idx = _pred_idx.reshape([n_batch, -1]).repeat(
                        args.n_skip + 1, axis=1)
                    _pred_idx = _pred_idx[:, :n_seq]
                    tr_acc_sum += ((_pred_idx == orig_y) * orig_x_mask).sum()
                    tr_acc_count += orig_x_mask.sum()

                if global_step.eval() % args.display_freq == 0:
                    avg_tr_ce = tr_ce_sum / tr_ce_count
                    # float() guards against integer division on Python 2.
                    avg_tr_fer = 1. - float(tr_acc_sum) / tr_acc_count
                    print("TRAIN: epoch={} iter={} ml_cost(ce/frame)={:.2f} fer={:.2f} time_taken={:.2f}"
                          .format(_epoch, global_step.eval(), avg_tr_ce,
                                  avg_tr_fer, disp_sw.elapsed()))
                    # Reset interval accumulators.
                    tr_ce_sum = 0.
                    tr_ce_count = 0
                    tr_acc_sum = 0
                    tr_acc_count = 0
                    disp_sw.reset()

            print('--')
            print('End of epoch {}'.format(_epoch + 1))
            epoch_sw.print_elapsed()
            print('Testing')

            # Evaluate the model on the validation set
            val_ce_sum = 0.
            val_ce_count = 0
            val_acc_sum = 0
            val_acc_count = 0
            eval_sw.reset()
            for batch in valid_set.get_epoch_iterator():
                orig_x, orig_x_mask, _, _, orig_y, _ = batch

                # Validation uses only the first subsampling offset.
                for sub_batch in skip_frames_fixed(
                        [orig_x, orig_x_mask, orig_y], args.n_skip + 1,
                        return_first=True):
                    x, x_mask, y = sub_batch
                    n_batch, _, _ = x.shape

                    zero_state = gen_zero_state(n_batch, args.n_hidden)
                    feed_dict = {tg.seq_x_data: x,
                                 tg.seq_x_mask: x_mask,
                                 tg.seq_y_data: y}
                    feed_init_state(feed_dict, tg.init_state, zero_state)

                    _val_ml_cost, _pred_idx = sess.run(
                        [tg.ml_cost, tg.pred_idx], feed_dict=feed_dict)

                    val_ce_sum += _val_ml_cost.sum()
                    val_ce_count += x_mask.sum()

                    _, n_seq = orig_y.shape
                    _pred_idx = _pred_idx.reshape([n_batch, -1]).repeat(
                        args.n_skip + 1, axis=1)
                    _pred_idx = _pred_idx[:, :n_seq]
                    val_acc_sum += ((_pred_idx == orig_y) * orig_x_mask).sum()
                    val_acc_count += orig_x_mask.sum()

            avg_val_ce = val_ce_sum / val_ce_count
            avg_val_fer = 1. - float(val_acc_sum) / val_acc_count

            print("VALID: epoch={} ml_cost(ce/frame)={:.2f} fer={:.2f} time_taken={:.2f}"
                  .format(_epoch, avg_val_ce, avg_val_fer, eval_sw.elapsed()))

            _val_ce_summary, = sess.run([val_ce_summary],
                                        feed_dict={val_ce: avg_val_ce})
            summary_writer.add_summary(_val_ce_summary, global_step.eval())

            insert_item2dict(eval_summary, 'val_ce', avg_val_ce)
            insert_item2dict(eval_summary, 'time', eval_sw.elapsed())
            save_npz2(file_name, eval_summary)

            # Save model: best-by-valid-CE checkpoint, then the regular one.
            if avg_val_ce < _best_score:
                _best_score = avg_val_ce
                best_ckpt = best_save_op.save(sess,
                                              os.path.join(args.log_dir,
                                                           "best_model.ckpt"),
                                              global_step=global_step)
                print("Best checkpoint stored in: %s" % best_ckpt)
            ckpt = save_op.save(sess,
                                os.path.join(args.log_dir, "model.ckpt"),
                                global_step=global_step)
            print("Checkpoint stored in: %s" % ckpt)

            _best_val_ce_summary, = sess.run([best_val_ce_summary],
                                             feed_dict={best_val_ce: _best_score})
            summary_writer.add_summary(_best_val_ce_summary, global_step.eval())

        summary_writer.close()
        print("Optimization Finished.")
# NOTE(review): orphaned fragment — this is the body of a training-loop
# iteration whose enclosing function/loop is not visible here. It references
# names defined elsewhere (_epoch, sess, tg, ml_op, ce, cr, args, disp_sw,
# global_step, orig_x/orig_x_mask/orig_y, utils, mixer, gen_zero_state,
# feed_init_state). Comments below describe only what these lines do.

# Subsample one fixed-offset view of the batch; start_idx records which
# offset was chosen so the visualization below can align to it.
sub_batch, start_idx = utils.skip_frames_fixed(
    [orig_x, orig_x_mask, orig_y], args.n_skip + 1, return_start_idx=True)
x, x_mask, y = sub_batch
n_batch, _, _ = x.shape

zero_state = gen_zero_state(n_batch, args.n_hidden)
feed_dict = {
    tg.seq_x_data: x,
    tg.seq_x_mask: x_mask,
    tg.seq_y_data: y
}
feed_init_state(feed_dict, tg.init_state, zero_state)

# One ML training step on the subsampled sequences.
ml_cost, _ = sess.run([tg.ml_cost, ml_op], feed_dict=feed_dict)

# ce: running cross-entropy meter; cr: running compression-ratio meter
# (kept frames / original frames). Presumably average-meter objects with
# add()/avg() — confirm against their definitions.
orig_count, comp_count = orig_x_mask.sum(), x_mask.sum()
ce.add(ml_cost.sum(), comp_count)
cr.add(float(comp_count) / orig_count, 1)

if global_step.eval() % args.display_freq == 0:
    print(
        "TRAIN: epoch={} iter={} ml_cost(ce/frame)={:.3f} compression={:.2f} time_taken={:.2f}"
        .format(_epoch, global_step.eval(), ce.avg(), cr.avg(),
                disp_sw.elapsed()))
    # NOTE(review): in the flattened source the nesting of this statement is
    # ambiguous; placed inside the display branch since the image appears to
    # be generated for periodic logging — confirm against the original file.
    output_image = mixer.gen_output_image_subsample(
        orig_x, orig_y, args.n_skip, start_idx)
def main(_):
    """Train the supervised skip-RNN model.

    Two regimes, selected by --no_sampling:
      * no_sampling: train label prediction (ml_op) and jump prediction
        (rl_op) jointly on teacher-forced supervision.
      * otherwise: alternate — first train the jump predictor on supervision,
        then generate jumps from the current model and train the label
        predictor on the resulting episodes.
    Validation reports full-length frame error rate (FER); the best model is
    tracked by lowest validation FER.
    """
    print(' '.join(sys.argv))
    args = FLAGS
    print(args.__flags)
    print('Hostname: {}'.format(socket.gethostname()))
    print('GPU: {}'.format(get_gpuname()))

    # Fresh run: wipe the log dir so stale checkpoints/summaries don't mix in.
    if not args.start_from_ckpt:
        if tf.gfile.Exists(args.log_dir):
            tf.gfile.DeleteRecursively(args.log_dir)
        tf.gfile.MakeDirs(args.log_dir)

    tf.get_variable_scope()._reuse = None

    # Seed both TF and NumPy for reproducibility.
    _seed = args.base_seed + args.add_seed
    tf.set_random_seed(_seed)
    np.random.seed(_seed)

    prefix_name = os.path.join(args.log_dir, 'model')
    file_name = '%s.npz' % prefix_name

    eval_summary = OrderedDict()

    # tg: training graph; test_graph: forward-only graph used for episode
    # generation and evaluation.
    tg, test_graph = build_graph(args)
    tg_ml_cost = tf.reduce_mean(tg.ml_cost)

    global_step = tf.Variable(0, trainable=False, name="global_step")

    tvars = tf.trainable_variables()
    print([tvar.name for tvar in tvars])

    # Split variables: the jump-prediction ("action_logit") parameters get
    # the rl optimizer; everything else gets the ml optimizer.
    ml_tvars = [tvar for tvar in tvars if "action_logit" not in tvar.name]
    rl_tvars = [tvar for tvar in tvars if "action_logit" in tvar.name]

    ml_opt_func = tf.train.AdamOptimizer(learning_rate=args.learning_rate,
                                         beta1=0.9, beta2=0.99)
    rl_opt_func = tf.train.AdamOptimizer(learning_rate=args.rl_learning_rate,
                                         beta1=0.9, beta2=0.99)

    if args.grad_clip:
        ml_grads, _ = tf.clip_by_global_norm(tf.gradients(
            tg_ml_cost, ml_tvars), clip_norm=1.0)
    else:
        ml_grads = tf.gradients(tg_ml_cost, ml_tvars)
    # Only ml_op advances global_step.
    ml_op = ml_opt_func.apply_gradients(zip(ml_grads, ml_tvars),
                                        global_step=global_step)

    tg_rl_cost = tf.reduce_mean(tg.rl_cost)
    rl_grads = tf.gradients(tg_rl_cost, rl_tvars)
    rl_op = rl_opt_func.apply_gradients(zip(rl_grads, rl_tvars))

    # NOTE(review): both values are appended under the same collection key
    # 'fast_action' (readers would get a two-element list); the second was
    # possibly intended to be keyed 'n_fast_action' — confirm against the
    # checkpoint-loading code before changing.
    tf.add_to_collection('fast_action', args.fast_action)
    tf.add_to_collection('fast_action', args.n_fast_action)

    sync_data(args)
    datasets = [args.train_dataset, args.valid_dataset, args.test_dataset]
    train_set, valid_set, test_set = [
        create_ivector_datastream(path=args.data_path,
                                  which_set=dataset,
                                  batch_size=args.n_batch,
                                  min_after_cache=args.min_after_cache,
                                  length_sort=not args.no_length_sort)
        for dataset in datasets
    ]

    init_op = tf.global_variables_initializer()
    save_op = tf.train.Saver(max_to_keep=5)
    best_save_op = tf.train.Saver(max_to_keep=5)

    with tf.name_scope("per_step_eval"):
        tr_ce = tf.placeholder(tf.float32)
        tr_ce_summary = tf.summary.scalar("tr_ce", tr_ce)
        tr_fer = tf.placeholder(tf.float32)
        tr_fer_summary = tf.summary.scalar("tr_fer", tr_fer)
        # tr_ce2 holds the jump-prediction cost; logged under the "tr_rl" tag.
        tr_ce2 = tf.placeholder(tf.float32)
        tr_ce2_summary = tf.summary.scalar("tr_rl", tr_ce2)
        tr_image = tf.placeholder(tf.float32)
        tr_image_summary = tf.summary.image("tr_image", tr_image)

    with tf.name_scope("per_epoch_eval"):
        val_fer = tf.placeholder(tf.float32)
        val_fer_summary = tf.summary.scalar("val_fer", val_fer)
        best_val_fer = tf.placeholder(tf.float32)
        best_val_fer_summary = tf.summary.scalar("best_valid_fer", best_val_fer)
        val_image = tf.placeholder(tf.float32)
        val_image_summary = tf.summary.image("val_image", val_image)

    # NOTE(review): vf is constructed but not used in this trainer.
    vf = LinearVF()

    with tf.Session() as sess:
        sess.run(init_op)

        if args.start_from_ckpt:
            save_op = tf.train.import_meta_graph(
                os.path.join(args.log_dir, 'model.ckpt.meta'))
            save_op.restore(sess, os.path.join(args.log_dir, 'model.ckpt'))
            print("Restore from the last checkpoint. Restarting from %d step."
                  % global_step.eval())

        summary_writer = tf.summary.FileWriter(args.log_dir, sess.graph,
                                               flush_secs=5.0)

        # Running accumulators, reset every display interval.
        tr_ce_sum = 0.
        tr_ce_count = 0
        tr_acc_sum = 0
        tr_acc_count = 0
        tr_ce2_sum = 0.
        tr_ce2_count = 0

        # Sentinel larger than any plausible FER; best tracked by valid FER.
        _best_score = np.iinfo(np.int32).max

        epoch_sw = StopWatch()
        disp_sw = StopWatch()
        eval_sw = StopWatch()
        per_sw = StopWatch()

        # For each epoch
        for _epoch in xrange(args.n_epoch):
            _n_exp = 0
            epoch_sw.reset()
            disp_sw.reset()
            print('Epoch {} training'.format(_epoch + 1))

            # For each batch
            for batch in train_set.get_epoch_iterator():
                x, x_mask, _, _, y, _ = batch
                n_batch = x.shape[0]
                _n_exp += n_batch

                if args.no_sampling:
                    # Joint path: teacher-forced supervision for both the
                    # label predictor and the jump predictor.
                    new_x, new_y, actions, actions_1hot, new_x_mask = gen_supervision(
                        x, x_mask, y, args)

                    zero_state = gen_zero_state(n_batch, args.n_hidden)
                    feed_dict = {
                        tg.seq_x_data: new_x,
                        tg.seq_x_mask: new_x_mask,
                        tg.seq_y_data: new_y,
                        tg.seq_jump_data: actions
                    }
                    feed_init_state(feed_dict, tg.init_state, zero_state)

                    # One joint step: label (ml_op) + jump (rl_op) updates.
                    _tr_ml_cost, _tr_rl_cost, _, _ = \
                        sess.run([tg.ml_cost, tg.rl_cost, ml_op, rl_op],
                                 feed_dict=feed_dict)

                    tr_ce_sum += _tr_ml_cost.sum()
                    tr_ce_count += new_x_mask.sum()
                    tr_ce2_sum += _tr_rl_cost.sum()
                    # Jump targets exist for all but the last step, hence [:, :-1].
                    tr_ce2_count += new_x_mask[:, :-1].sum()

                    # Forward pass with model-chosen jumps to measure FER.
                    actions_1hot, label_probs, new_mask, output_image = \
                        skip_rnn_forward_supervised(x, x_mask, sess, test_graph,
                                                    args.fast_action,
                                                    args.n_fast_action, y)
                    # Expand kept-frame predictions back to full frame rate.
                    pred_idx = expand_output(actions_1hot, x_mask, new_mask,
                                             label_probs.argmax(axis=-1))
                    tr_acc_sum += ((pred_idx == y) * x_mask).sum()
                    tr_acc_count += x_mask.sum()

                    _tr_ce_summary, _tr_fer_summary, _tr_ce2_summary, _tr_image_summary = \
                        sess.run([tr_ce_summary, tr_fer_summary, tr_ce2_summary, tr_image_summary],
                                 feed_dict={tr_ce: _tr_ml_cost.sum() / new_x_mask.sum(),
                                            tr_fer: 1 - ((pred_idx == y) * x_mask).sum() / x_mask.sum(),
                                            tr_ce2: _tr_rl_cost.sum() / new_x_mask[:,:-1].sum(),
                                            tr_image: output_image})
                    summary_writer.add_summary(_tr_ce_summary, global_step.eval())
                    summary_writer.add_summary(_tr_fer_summary, global_step.eval())
                    summary_writer.add_summary(_tr_ce2_summary, global_step.eval())
                    summary_writer.add_summary(_tr_image_summary, global_step.eval())
                else:
                    # Alternating path.
                    # train jump prediction part
                    new_x, _, actions, _, new_x_mask = gen_supervision(
                        x, x_mask, y, args)

                    zero_state = gen_zero_state(n_batch, args.n_hidden)
                    feed_dict = {
                        tg.seq_x_data: new_x,
                        tg.seq_x_mask: new_x_mask,
                        tg.seq_jump_data: actions
                    }
                    feed_init_state(feed_dict, tg.init_state, zero_state)

                    _tr_rl_cost, _ = sess.run([tg.rl_cost, rl_op],
                                              feed_dict=feed_dict)

                    tr_ce2_sum += _tr_rl_cost.sum()
                    tr_ce2_count += new_x_mask[:, :-1].sum()

                    _tr_ce2_summary, = sess.run([tr_ce2_summary],
                                                feed_dict={
                                                    tr_ce2:
                                                    _tr_rl_cost.sum() / new_x_mask[:, :-1].sum()
                                                })

                    # generate jumps from the model
                    new_x, new_y, actions_1hot, label_probs, new_x_mask, output_image = gen_episode_supervised(
                        x, y, x_mask, sess, test_graph, args.fast_action,
                        args.n_fast_action)

                    feed_dict = {
                        tg.seq_x_data: new_x,
                        tg.seq_x_mask: new_x_mask,
                        tg.seq_y_data: new_y
                    }
                    # Reuses zero_state from the jump-training step above
                    # (same n_batch, so the shapes match).
                    feed_init_state(feed_dict, tg.init_state, zero_state)

                    # train label prediction part
                    _tr_ml_cost, _ = sess.run([tg.ml_cost, ml_op],
                                              feed_dict=feed_dict)

                    tr_ce_sum += _tr_ml_cost.sum()
                    tr_ce_count += new_x_mask.sum()

                    actions_1hot, label_probs, new_mask, output_image = \
                        skip_rnn_forward_supervised(x, x_mask, sess, test_graph,
                                                    args.fast_action,
                                                    args.n_fast_action, y)
                    pred_idx = expand_output(actions_1hot, x_mask, new_mask,
                                             label_probs.argmax(axis=-1))
                    tr_acc_sum += ((pred_idx == y) * x_mask).sum()
                    tr_acc_count += x_mask.sum()

                    _tr_ce_summary, _tr_fer_summary, _tr_image_summary = \
                        sess.run([tr_ce_summary, tr_fer_summary, tr_image_summary],
                                 feed_dict={tr_ce: _tr_ml_cost.sum() / new_x_mask.sum(),
                                            tr_fer: 1 - ((pred_idx == y) * x_mask).sum() / x_mask.sum(),
                                            tr_image: output_image})
                    summary_writer.add_summary(_tr_ce_summary, global_step.eval())
                    summary_writer.add_summary(_tr_fer_summary, global_step.eval())
                    # _tr_ce2_summary was produced by the jump-training step
                    # earlier in this branch.
                    summary_writer.add_summary(_tr_ce2_summary, global_step.eval())
                    summary_writer.add_summary(_tr_image_summary, global_step.eval())

                if global_step.eval() % args.display_freq == 0:
                    avg_tr_ce = tr_ce_sum / tr_ce_count
                    avg_tr_fer = 1. - tr_acc_sum / tr_acc_count
                    avg_tr_ce2 = tr_ce2_sum / tr_ce2_count

                    print("TRAIN: epoch={} iter={} ml_cost(ce/frame)={:.2f} fer={:.2f} rl_cost={:.4f} time_taken={:.2f}"
                          .format(_epoch, global_step.eval(), avg_tr_ce,
                                  avg_tr_fer, avg_tr_ce2, disp_sw.elapsed()))

                    # Reset interval accumulators.
                    tr_ce_sum = 0.
                    tr_ce_count = 0
                    tr_acc_sum = 0.
                    tr_acc_count = 0
                    tr_ce2_sum = 0.
                    tr_ce2_count = 0
                    disp_sw.reset()

            print('--')
            print('End of epoch {}'.format(_epoch + 1))
            epoch_sw.print_elapsed()
            print('Testing')

            # Evaluate the model on the validation set
            val_acc_sum = 0
            val_acc_count = 0
            eval_sw.reset()
            for batch in valid_set.get_epoch_iterator():
                x, x_mask, _, _, y, _ = batch
                n_batch = x.shape[0]

                actions_1hot, label_probs, new_mask, output_image = \
                    skip_rnn_forward_supervised(x, x_mask, sess, test_graph,
                                                args.fast_action,
                                                args.n_fast_action, y)
                pred_idx = expand_output(actions_1hot, x_mask, new_mask,
                                         label_probs.argmax(axis=-1))
                val_acc_sum += ((pred_idx == y) * x_mask).sum()
                val_acc_count += x_mask.sum()

            avg_val_fer = 1. - val_acc_sum / val_acc_count

            print("VALID: epoch={} fer={:.2f} time_taken={:.2f}".format(
                _epoch, avg_val_fer, eval_sw.elapsed()))

            # output_image here is from the last validation batch only.
            _val_fer_summary, _val_image_summary = sess.run(
                [val_fer_summary, val_image_summary],
                feed_dict={
                    val_fer: avg_val_fer,
                    val_image: output_image
                })
            summary_writer.add_summary(_val_fer_summary, global_step.eval())
            summary_writer.add_summary(_val_image_summary, global_step.eval())

            insert_item2dict(eval_summary, 'val_fer', avg_val_fer)
            insert_item2dict(eval_summary, 'time', eval_sw.elapsed())
            save_npz2(file_name, eval_summary)

            # Save model: best-by-valid-FER checkpoint, then the regular one.
            if avg_val_fer < _best_score:
                _best_score = avg_val_fer
                best_ckpt = best_save_op.save(sess,
                                              os.path.join(args.log_dir,
                                                           "best_model.ckpt"),
                                              global_step=global_step)
                print("Best checkpoint stored in: %s" % best_ckpt)
            ckpt = save_op.save(sess,
                                os.path.join(args.log_dir, "model.ckpt"),
                                global_step=global_step)
            print("Checkpoint stored in: %s" % ckpt)

            _best_val_fer_summary, = sess.run(
                [best_val_fer_summary], feed_dict={best_val_fer: _best_score})
            summary_writer.add_summary(_best_val_fer_summary,
                                       global_step.eval())

        summary_writer.close()
        print("Optimization Finished.")