# ---- Example #1 ----
def main():
    """Train a speech model with joint ML (cross-entropy) and RL (policy-gradient) objectives.

    Builds the training graph (``tg``) and sampling graph (``sg``), optimizes
    the ML and RL costs with two separate Adam optimizers, streams TensorBoard
    summaries per step and per epoch, and checkpoints the model every epoch
    (keeping a separate "best" checkpoint by validation cross-entropy).
    """
    print(' '.join(sys.argv))

    args = get_args()
    print(args)

    print('Hostname: {}'.format(socket.gethostname()))
    print('GPU: {}'.format(get_gpuname()))

    # Fresh run: wipe any stale logs/checkpoints unless resuming from one.
    if not args.start_from_ckpt:
        if tf.gfile.Exists(args.log_dir):
            tf.gfile.DeleteRecursively(args.log_dir)
        tf.gfile.MakeDirs(args.log_dir)

    # NOTE(review): pokes a private attribute to clear variable reuse;
    # presumably needed before build_graph() — confirm against TF version.
    tf.get_variable_scope()._reuse = None

    # Seed both TF and NumPy so runs are reproducible per (base_seed, add_seed).
    _seed = args.base_seed + args.add_seed
    tf.set_random_seed(_seed)
    np.random.seed(_seed)

    prefix_name = os.path.join(args.log_dir, 'model')
    file_name = '%s.npz' % prefix_name  # per-epoch eval metrics archive

    eval_summary = OrderedDict()

    tg, sg = build_graph(args)
    tg_ml_cost = tf.reduce_mean(tg.ml_cost)

    global_step = tf.Variable(0, trainable=False, name="global_step")

    tvars = tf.trainable_variables()
    print([tvar.name for tvar in tvars])

    # Separate optimizers so ML and RL objectives can use different
    # learning rates over the same trainable variables.
    ml_opt_func = tf.train.AdamOptimizer(learning_rate=args.learning_rate,
                                         beta1=0.9,
                                         beta2=0.99)
    rl_opt_func = tf.train.AdamOptimizer(learning_rate=args.rl_learning_rate,
                                         beta1=0.9,
                                         beta2=0.99)

    if args.grad_clip:
        ml_grads, _ = tf.clip_by_global_norm(tf.gradients(tg_ml_cost, tvars),
                                             clip_norm=1.0)
    else:
        ml_grads = tf.gradients(tg_ml_cost, tvars)
    ml_op = ml_opt_func.apply_gradients(zip(ml_grads, tvars),
                                        global_step=global_step)

    tg_rl_cost = tf.reduce_mean(tg.rl_cost)
    rl_grads = tf.gradients(tg_rl_cost, tvars)
    # do not increase global step -- ml op increases it
    rl_op = rl_opt_func.apply_gradients(zip(rl_grads, tvars))

    # Stash the hyperparameter in the graph so it survives checkpoint reload.
    tf.add_to_collection('n_fast_action', args.n_fast_action)

    sync_data(args)
    datasets = [args.train_dataset, args.valid_dataset, args.test_dataset]
    train_set, valid_set, test_set = [
        create_ivector_datastream(path=args.data_path,
                                  which_set=dataset,
                                  batch_size=args.n_batch,
                                  min_after_cache=args.min_after_cache,
                                  length_sort=not args.no_length_sort)
        for dataset in datasets
    ]

    init_op = tf.global_variables_initializer()
    save_op = tf.train.Saver(max_to_keep=5)
    best_save_op = tf.train.Saver(max_to_keep=5)

    # Placeholders used only to feed already-computed scalars into summaries.
    with tf.name_scope("per_step_eval"):
        tr_ce = tf.placeholder(tf.float32)
        tr_ce_summary = tf.summary.scalar("tr_ce", tr_ce)
        tr_image = tf.placeholder(tf.float32)
        tr_image_summary = tf.summary.image("tr_image", tr_image)
        tr_fer = tf.placeholder(tf.float32)
        tr_fer_summary = tf.summary.scalar("tr_fer", tr_fer)
        tr_rl = tf.placeholder(tf.float32)
        tr_rl_summary = tf.summary.scalar("tr_rl", tr_rl)
        tr_rw_hist = tf.placeholder(tf.float32)
        tr_rw_hist_summary = tf.summary.histogram("tr_reward_hist", tr_rw_hist)

    with tf.name_scope("per_epoch_eval"):
        best_val_ce = tf.placeholder(tf.float32)
        val_ce = tf.placeholder(tf.float32)
        best_val_ce_summary = tf.summary.scalar("best_valid_ce", best_val_ce)
        val_ce_summary = tf.summary.scalar("valid_ce", val_ce)

    vf = LinearVF()  # linear value function baseline for advantage estimation

    with tf.Session() as sess:
        sess.run(init_op)

        if args.start_from_ckpt:
            save_op = tf.train.import_meta_graph(
                os.path.join(args.log_dir, 'model.ckpt.meta'))
            save_op.restore(sess, os.path.join(args.log_dir, 'model.ckpt'))
            print(
                "Restore from the last checkpoint. Restarting from %d step." %
                global_step.eval())

        summary_writer = tf.summary.FileWriter(args.log_dir,
                                               sess.graph,
                                               flush_secs=5.0)

        # Running accumulators, reset every args.display_freq steps.
        tr_ce_sum = 0.
        tr_ce_count = 0
        tr_acc_sum = 0
        tr_acc_count = 0
        tr_rl_costs = []
        tr_action_entropies = []
        tr_rewards = []

        _best_score = np.iinfo(np.int32).max

        epoch_sw = StopWatch()
        disp_sw = StopWatch()
        eval_sw = StopWatch()
        per_sw = StopWatch()

        # For each epoch
        for _epoch in xrange(args.n_epoch):
            _n_exp = 0

            epoch_sw.reset()
            disp_sw.reset()

            print('Epoch {} training'.format(_epoch + 1))

            # For each batch
            for batch in train_set.get_epoch_iterator():
                x, x_mask, _, _, y, _ = batch
                n_batch = x.shape[0]
                _n_exp += n_batch

                # Roll out an episode: sample skip actions and compute
                # segment-level rewards for the policy gradient.
                new_x, new_y, actions_1hot, rewards, action_entropies, new_x_mask, new_reward_mask, output_image = \
                        gen_episode_with_seg_reward(x, x_mask, y, sess, sg, args)

                advantages, _ = compute_advantage(new_x, new_x_mask, rewards,
                                                  new_reward_mask, vf, args)

                zero_state = gen_zero_state(n_batch, args.n_hidden)

                feed_dict = {
                    tg.seq_x_data: new_x,
                    tg.seq_x_mask: new_x_mask,
                    tg.seq_y_data: new_y,
                    tg.seq_action: actions_1hot,
                    tg.seq_advantage: advantages,
                    tg.seq_action_mask: new_reward_mask,
                    tg.seq_y_data_for_action: new_y
                }
                feed_init_state(feed_dict, tg.init_state, zero_state)

                # Single session call applies both the ML and RL updates.
                _tr_ml_cost, _tr_rl_cost, _, _, pred_idx = \
                    sess.run([tg.ml_cost, tg.rl_cost, ml_op, rl_op, tg.pred_idx], feed_dict=feed_dict)

                tr_ce_sum += _tr_ml_cost.sum()
                tr_ce_count += new_x_mask.sum()

                # Expand skipped-frame predictions back to the original frame
                # rate before scoring against the full-length labels.
                pred_idx = expand_output(actions_1hot, x_mask, new_x_mask,
                                         pred_idx.reshape([n_batch, -1]),
                                         args.n_fast_action)
                tr_acc_sum += ((pred_idx == y) * x_mask).sum()
                tr_acc_count += x_mask.sum()

                _tr_ce_summary, _tr_fer_summary, _tr_rl_summary, _tr_image_summary, _tr_rw_hist_summary = \
                    sess.run([tr_ce_summary, tr_fer_summary, tr_rl_summary, tr_image_summary, tr_rw_hist_summary],
                        feed_dict={tr_ce: _tr_ml_cost.sum() / new_x_mask.sum(),
                            # BUG FIX: tr_fer was fed raw accuracy; the frame
                            # error rate is 1 - accuracy (matches avg_tr_fer
                            # below and the validation-path computation).
                            tr_fer: 1. - ((pred_idx == y) * x_mask).sum() / x_mask.sum(),
                            tr_rl: _tr_rl_cost.sum() / new_reward_mask.sum(),
                            tr_image: output_image, tr_rw_hist: rewards})
                summary_writer.add_summary(_tr_ce_summary, global_step.eval())
                summary_writer.add_summary(_tr_fer_summary, global_step.eval())
                summary_writer.add_summary(_tr_rl_summary, global_step.eval())
                summary_writer.add_summary(_tr_image_summary,
                                           global_step.eval())
                summary_writer.add_summary(_tr_rw_hist_summary,
                                           global_step.eval())

                tr_rl_costs.append(_tr_rl_cost.sum() / new_reward_mask.sum())
                tr_action_entropies.append(action_entropies.sum() /
                                           new_reward_mask.sum())
                tr_rewards.append(rewards.sum() / new_reward_mask.sum())

                if global_step.eval() % args.display_freq == 0:
                    avg_tr_ce = tr_ce_sum / tr_ce_count
                    # float() guards Python-2 integer division when the mask
                    # sums are integral (consistent with the other examples).
                    avg_tr_fer = 1. - float(tr_acc_sum) / tr_acc_count
                    avg_tr_rl_cost = np.asarray(tr_rl_costs).mean()
                    avg_tr_action_entropy = np.asarray(
                        tr_action_entropies).mean()
                    avg_tr_reward = np.asarray(tr_rewards).mean()

                    print(
                        "TRAIN: epoch={} iter={} ml_cost(ce/frame)={:.2f} fer={:.2f} rl_cost={:.4f} reward={:.4f} action_entropy={:.2f} time_taken={:.2f}"
                        .format(_epoch, global_step.eval(), avg_tr_ce,
                                avg_tr_fer, avg_tr_rl_cost, avg_tr_reward,
                                avg_tr_action_entropy, disp_sw.elapsed()))

                    tr_ce_sum = 0.
                    tr_ce_count = 0
                    tr_acc_sum = 0.
                    tr_acc_count = 0
                    tr_rl_costs = []
                    tr_action_entropies = []
                    tr_rewards = []

                    disp_sw.reset()

            print('--')
            print('End of epoch {}'.format(_epoch + 1))
            epoch_sw.print_elapsed()

            print('Testing')

            # Evaluate the model on the validation set
            val_ce_sum = 0.
            val_ce_count = 0
            val_acc_sum = 0
            val_acc_count = 0
            val_rl_costs = []
            val_action_entropies = []
            val_rewards = []

            eval_sw.reset()
            for batch in valid_set.get_epoch_iterator():
                x, x_mask, _, _, y, _ = batch
                n_batch = x.shape[0]

                # sample_y=True additionally returns sampled labels used to
                # condition the action model during evaluation.
                new_x, new_y, actions_1hot, rewards, action_entropies, new_x_mask, new_reward_mask, output_image, new_y_sample = \
                        gen_episode_with_seg_reward(x, x_mask, y, sess, sg, args, sample_y=True)

                advantages, _ = compute_advantage(new_x, new_x_mask, rewards,
                                                  new_reward_mask, vf, args)

                zero_state = gen_zero_state(n_batch, args.n_hidden)

                feed_dict = {
                    tg.seq_x_data: new_x,
                    tg.seq_x_mask: new_x_mask,
                    tg.seq_y_data: new_y,
                    tg.seq_action: actions_1hot,
                    tg.seq_advantage: advantages,
                    tg.seq_action_mask: new_reward_mask,
                    tg.seq_y_data_for_action: new_y_sample
                }
                feed_init_state(feed_dict, tg.init_state, zero_state)

                # No ml_op/rl_op here: evaluation only, no parameter updates.
                _val_ml_cost, _val_rl_cost, pred_idx = sess.run(
                    [tg.ml_cost, tg.rl_cost, tg.pred_idx], feed_dict=feed_dict)

                val_ce_sum += _val_ml_cost.sum()
                val_ce_count += new_x_mask.sum()

                pred_idx = expand_output(actions_1hot, x_mask, new_x_mask,
                                         pred_idx.reshape([n_batch, -1]),
                                         args.n_fast_action)
                val_acc_sum += ((pred_idx == y) * x_mask).sum()
                val_acc_count += x_mask.sum()

                val_rl_costs.append(_val_rl_cost.sum() / new_reward_mask.sum())
                val_action_entropies.append(action_entropies.sum() /
                                            new_reward_mask.sum())
                val_rewards.append(rewards.sum() / new_reward_mask.sum())

            avg_val_ce = val_ce_sum / val_ce_count
            # float() guards Python-2 integer division (see training path).
            avg_val_fer = 1. - float(val_acc_sum) / val_acc_count
            avg_val_rl_cost = np.asarray(val_rl_costs).mean()
            avg_val_action_entropy = np.asarray(val_action_entropies).mean()
            avg_val_reward = np.asarray(val_rewards).mean()

            print(
                "VALID: epoch={} ml_cost(ce/frame)={:.2f} fer={:.2f} rl_cost={:.4f} reward={:.4f} action_entropy={:.2f} time_taken={:.2f}"
                .format(_epoch, avg_val_ce, avg_val_fer, avg_val_rl_cost,
                        avg_val_reward, avg_val_action_entropy,
                        eval_sw.elapsed()))

            _val_ce_summary, = sess.run([val_ce_summary],
                                        feed_dict={val_ce: avg_val_ce})
            summary_writer.add_summary(_val_ce_summary, global_step.eval())

            insert_item2dict(eval_summary, 'val_ce', avg_val_ce)
            insert_item2dict(eval_summary, 'val_rl_cost', avg_val_rl_cost)
            insert_item2dict(eval_summary, 'val_reward', avg_val_reward)
            insert_item2dict(eval_summary, 'val_action_entropy',
                             avg_val_action_entropy)
            insert_item2dict(eval_summary, 'time', eval_sw.elapsed())
            save_npz2(file_name, eval_summary)

            # Save model: a "best" checkpoint when validation CE improves,
            # plus an unconditional per-epoch checkpoint.
            if avg_val_ce < _best_score:
                _best_score = avg_val_ce
                best_ckpt = best_save_op.save(sess,
                                              os.path.join(
                                                  args.log_dir,
                                                  "best_model.ckpt"),
                                              global_step=global_step)
                print("Best checkpoint stored in: %s" % best_ckpt)
            ckpt = save_op.save(sess,
                                os.path.join(args.log_dir, "model.ckpt"),
                                global_step=global_step)
            print("Checkpoint stored in: %s" % ckpt)

            _best_val_ce_summary, = sess.run(
                [best_val_ce_summary], feed_dict={best_val_ce: _best_score})
            summary_writer.add_summary(_best_val_ce_summary,
                                       global_step.eval())
        summary_writer.close()

        print("Optimization Finished.")
# ---- Example #2 ----
def main(_):
    """Train a frame-skipping RNN classifier (TF1 style).

    Subsamples input frames with ``skip_frames_fixed`` (every ``n_skip + 1``-th
    frame), trains on the subsampled sequences with a cross-entropy objective,
    scores predictions against the original full-rate labels by repeating each
    prediction, and checkpoints the best model by validation cross-entropy.
    The ``_`` argument is the unused positional arg passed by ``tf.app.run``
    -- presumably; confirm against the caller.
    """
    print(' '.join(sys.argv))
    args = FLAGS
    print(args.__flags)
    # Fresh run: wipe any stale logs/checkpoints unless resuming from one.
    if not args.start_from_ckpt:
        if tf.gfile.Exists(args.log_dir):
            tf.gfile.DeleteRecursively(args.log_dir)
        tf.gfile.MakeDirs(args.log_dir)

    # NOTE(review): pokes a private attribute to clear variable reuse;
    # presumably needed before build_graph() -- confirm against TF version.
    tf.get_variable_scope()._reuse = None

    # Seed both TF and NumPy for reproducibility per (base_seed, add_seed).
    _seed = args.base_seed + args.add_seed
    tf.set_random_seed(_seed)
    np.random.seed(_seed)

    prefix_name = os.path.join(args.log_dir, 'model')
    file_name = '%s.npz' % prefix_name  # per-epoch eval metrics archive

    eval_summary = OrderedDict()  #

    tg, test_graph = build_graph(args)
    tg_ml_cost = tf.reduce_mean(tg.ml_cost)

    global_step = tf.Variable(0, trainable=False, name="global_step")

    tvars = tf.trainable_variables()

    ml_opt_func = tf.train.AdamOptimizer(learning_rate=args.learning_rate,
                                         beta1=0.9,
                                         beta2=0.99)

    # Optional global-norm gradient clipping (clip_norm fixed at 1.0).
    if args.grad_clip:
        ml_grads, _ = tf.clip_by_global_norm(tf.gradients(tg_ml_cost, tvars),
                                             clip_norm=1.0)
    else:
        ml_grads = tf.gradients(tg_ml_cost, tvars)
    ml_op = ml_opt_func.apply_gradients(zip(ml_grads, tvars),
                                        global_step=global_step)

    # Stash hyperparameters in graph collections so they survive
    # checkpoint/meta-graph reload.
    tf.add_to_collection('n_skip', args.n_skip)
    tf.add_to_collection('n_hidden', args.n_hidden)

    sync_data(args)
    datasets = [args.train_dataset, args.valid_dataset, args.test_dataset]
    train_set, valid_set, test_set = [
        create_ivector_datastream(path=args.data_path,
                                  which_set=dataset,
                                  batch_size=args.batch_size,
                                  min_after_cache=args.min_after_cache,
                                  length_sort=not args.no_length_sort)
        for dataset in datasets
    ]

    init_op = tf.global_variables_initializer()
    save_op = tf.train.Saver(max_to_keep=5)
    best_save_op = tf.train.Saver(max_to_keep=5)

    # Placeholders used only to feed already-computed scalars into summaries.
    with tf.name_scope("per_step_eval"):
        tr_ce = tf.placeholder(tf.float32)
        tr_ce_summary = tf.summary.scalar("tr_ce", tr_ce)

    with tf.name_scope("per_epoch_eval"):
        best_val_ce = tf.placeholder(tf.float32)
        val_ce = tf.placeholder(tf.float32)
        best_val_ce_summary = tf.summary.scalar("best_valid_ce", best_val_ce)
        val_ce_summary = tf.summary.scalar("valid_ce", val_ce)

    with tf.Session() as sess:
        sess.run(init_op)

        if args.start_from_ckpt:
            save_op = tf.train.import_meta_graph(
                os.path.join(args.log_dir, 'model.ckpt.meta'))
            save_op.restore(sess, os.path.join(args.log_dir, 'model.ckpt'))
            print("Restore from the last checkpoint. "
                  "Restarting from %d step." % global_step.eval())

        summary_writer = tf.summary.FileWriter(args.log_dir,
                                               sess.graph,
                                               flush_secs=5.0)

        # Running accumulators, reset every args.display_freq steps.
        tr_ce_sum = 0.
        tr_ce_count = 0
        tr_acc_sum = 0
        tr_acc_count = 0
        _best_score = np.iinfo(np.int32).max

        epoch_sw = StopWatch()
        disp_sw = StopWatch()
        eval_sw = StopWatch()
        per_sw = StopWatch()
        # For each epoch
        for _epoch in xrange(args.n_epoch):
            _n_exp = 0

            epoch_sw.reset()
            disp_sw.reset()

            print('--')
            print('Epoch {} training'.format(_epoch + 1))

            # For each batch
            for batch in train_set.get_epoch_iterator():
                orig_x, orig_x_mask, _, _, orig_y, _ = batch

                # Each original batch yields sub-batches, one per skip offset.
                for sub_batch in skip_frames_fixed(
                    [orig_x, orig_x_mask, orig_y], args.n_skip + 1):
                    x, x_mask, y = sub_batch
                    n_batch, _, _ = x.shape
                    _n_exp += n_batch

                    zero_state = gen_zero_state(n_batch, args.n_hidden)

                    feed_dict = {
                        tg.seq_x_data: x,
                        tg.seq_x_mask: x_mask,
                        tg.seq_y_data: y
                    }
                    feed_init_state(feed_dict, tg.init_state, zero_state)

                    # One training step: cost, predictions, and the update op.
                    _tr_ml_cost, _pred_idx, _ = sess.run(
                        [tg.ml_cost, tg.pred_idx, ml_op], feed_dict=feed_dict)

                    tr_ce_sum += _tr_ml_cost.sum()
                    tr_ce_count += x_mask.sum()
                    _tr_ce_summary, = sess.run(
                        [tr_ce_summary],
                        feed_dict={tr_ce: _tr_ml_cost.sum() / x_mask.sum()})
                    summary_writer.add_summary(_tr_ce_summary,
                                               global_step.eval())

                    # Repeat each subsampled prediction (n_skip + 1) times to
                    # restore the original frame rate, then truncate to the
                    # original sequence length before scoring.
                    _, n_seq = orig_y.shape
                    _pred_idx = _pred_idx.reshape([n_batch,
                                                   -1]).repeat(args.n_skip + 1,
                                                               axis=1)
                    _pred_idx = _pred_idx[:, :n_seq]

                    tr_acc_sum += ((_pred_idx == orig_y) * orig_x_mask).sum()
                    tr_acc_count += orig_x_mask.sum()

                if global_step.eval() % args.display_freq == 0:
                    avg_tr_ce = tr_ce_sum / tr_ce_count
                    # float() guards Python-2 integer division.
                    avg_tr_fer = 1. - float(tr_acc_sum) / tr_acc_count

                    print(
                        "TRAIN: epoch={} iter={} ml_cost(ce/frame)={:.2f} fer={:.2f} time_taken={:.2f}"
                        .format(_epoch, global_step.eval(), avg_tr_ce,
                                avg_tr_fer, disp_sw.elapsed()))

                    tr_ce_sum = 0.
                    tr_ce_count = 0
                    tr_acc_sum = 0
                    tr_acc_count = 0
                    disp_sw.reset()

            print('--')
            print('End of epoch {}'.format(_epoch + 1))
            epoch_sw.print_elapsed()

            print('Testing')

            # Evaluate the model on the validation set
            val_ce_sum = 0.
            val_ce_count = 0
            val_acc_sum = 0
            val_acc_count = 0

            eval_sw.reset()
            for batch in valid_set.get_epoch_iterator():
                orig_x, orig_x_mask, _, _, orig_y, _ = batch

                # return_first=True: evaluate only the first skip offset
                # rather than all of them (unlike the training loop).
                for sub_batch in skip_frames_fixed(
                    [orig_x, orig_x_mask, orig_y],
                        args.n_skip + 1,
                        return_first=True):
                    x, x_mask, y = sub_batch
                    n_batch, _, _ = x.shape

                    zero_state = gen_zero_state(n_batch, args.n_hidden)

                    feed_dict = {
                        tg.seq_x_data: x,
                        tg.seq_x_mask: x_mask,
                        tg.seq_y_data: y
                    }
                    feed_init_state(feed_dict, tg.init_state, zero_state)

                    # No ml_op here: evaluation only, no parameter updates.
                    _val_ml_cost, _pred_idx = sess.run(
                        [tg.ml_cost, tg.pred_idx], feed_dict=feed_dict)

                    val_ce_sum += _val_ml_cost.sum()
                    val_ce_count += x_mask.sum()

                    # Same expand-and-truncate scoring as in training.
                    _, n_seq = orig_y.shape
                    _pred_idx = _pred_idx.reshape([n_batch,
                                                   -1]).repeat(args.n_skip + 1,
                                                               axis=1)
                    _pred_idx = _pred_idx[:, :n_seq]

                    val_acc_sum += ((_pred_idx == orig_y) * orig_x_mask).sum()
                    val_acc_count += orig_x_mask.sum()

            avg_val_ce = val_ce_sum / val_ce_count
            avg_val_fer = 1. - float(val_acc_sum) / val_acc_count

            print(
                "VALID: epoch={} ml_cost(ce/frame)={:.2f} fer={:.2f} time_taken={:.2f}"
                .format(_epoch, avg_val_ce, avg_val_fer, eval_sw.elapsed()))

            _val_ce_summary, = sess.run([val_ce_summary],
                                        feed_dict={val_ce: avg_val_ce})
            summary_writer.add_summary(_val_ce_summary, global_step.eval())

            insert_item2dict(eval_summary, 'val_ce', avg_val_ce)
            insert_item2dict(eval_summary, 'time', eval_sw.elapsed())
            save_npz2(file_name, eval_summary)

            # Save model: a "best" checkpoint when validation CE improves,
            # plus an unconditional per-epoch checkpoint.
            if avg_val_ce < _best_score:
                _best_score = avg_val_ce
                best_ckpt = best_save_op.save(sess,
                                              os.path.join(
                                                  args.log_dir,
                                                  "best_model.ckpt"),
                                              global_step=global_step)
                print("Best checkpoint stored in: %s" % best_ckpt)
            ckpt = save_op.save(sess,
                                os.path.join(args.log_dir, "model.ckpt"),
                                global_step=global_step)
            print("Checkpoint stored in: %s" % ckpt)

            _best_val_ce_summary, = sess.run(
                [best_val_ce_summary], feed_dict={best_val_ce: _best_score})
            summary_writer.add_summary(_best_val_ce_summary,
                                       global_step.eval())
        summary_writer.close()

        print("Optimization Finished.")
# ---- Example #3 (fragment: enclosing function not fully shown) ----
            for accu in accu_list:
                accu.reset()

            # For each batch
            for batch in train_set.get_epoch_iterator():
                orig_x, orig_x_mask, _, _, orig_y, _ = batch

                sub_batch, start_idx = utils.skip_frames_fixed(
                    [orig_x, orig_x_mask, orig_y],
                    args.n_skip + 1,
                    return_start_idx=True)
                x, x_mask, y = sub_batch
                n_batch, _, _ = x.shape

                zero_state = gen_zero_state(n_batch, args.n_hidden)

                feed_dict = {
                    tg.seq_x_data: x,
                    tg.seq_x_mask: x_mask,
                    tg.seq_y_data: y
                }
                feed_init_state(feed_dict, tg.init_state, zero_state)

                ml_cost, _ = sess.run([tg.ml_cost, ml_op], feed_dict=feed_dict)
                orig_count, comp_count = orig_x_mask.sum(), x_mask.sum()

                ce.add(ml_cost.sum(), comp_count)
                cr.add(float(comp_count) / orig_count, 1)

                if global_step.eval() % args.display_freq == 0:
# ---- Example #4 (truncated at end of excerpt) ----
def main(_):
    print(' '.join(sys.argv))
    args = FLAGS
    print(args.__flags)
    print('Hostname: {}'.format(socket.gethostname()))
    print('GPU: {}'.format(get_gpuname()))

    if not args.start_from_ckpt:
        if tf.gfile.Exists(args.log_dir):
            tf.gfile.DeleteRecursively(args.log_dir)
        tf.gfile.MakeDirs(args.log_dir)

    tf.get_variable_scope()._reuse = None

    _seed = args.base_seed + args.add_seed
    tf.set_random_seed(_seed)
    np.random.seed(_seed)

    prefix_name = os.path.join(args.log_dir, 'model')
    file_name = '%s.npz' % prefix_name

    eval_summary = OrderedDict()

    tg, test_graph = build_graph(args)
    tg_ml_cost = tf.reduce_mean(tg.ml_cost)

    global_step = tf.Variable(0, trainable=False, name="global_step")

    tvars = tf.trainable_variables()
    print([tvar.name for tvar in tvars])
    ml_tvars = [tvar for tvar in tvars if "action_logit" not in tvar.name]
    rl_tvars = [tvar for tvar in tvars if "action_logit" in tvar.name]

    ml_opt_func = tf.train.AdamOptimizer(learning_rate=args.learning_rate,
                                         beta1=0.9,
                                         beta2=0.99)
    rl_opt_func = tf.train.AdamOptimizer(learning_rate=args.rl_learning_rate,
                                         beta1=0.9,
                                         beta2=0.99)

    if args.grad_clip:
        ml_grads, _ = tf.clip_by_global_norm(tf.gradients(
            tg_ml_cost, ml_tvars),
                                             clip_norm=1.0)
    else:
        ml_grads = tf.gradients(tg_ml_cost, ml_tvars)
    ml_op = ml_opt_func.apply_gradients(zip(ml_grads, ml_tvars),
                                        global_step=global_step)

    tg_rl_cost = tf.reduce_mean(tg.rl_cost)
    rl_grads = tf.gradients(tg_rl_cost, rl_tvars)
    rl_op = rl_opt_func.apply_gradients(zip(rl_grads, rl_tvars))

    tf.add_to_collection('fast_action', args.fast_action)
    tf.add_to_collection('fast_action', args.n_fast_action)

    sync_data(args)
    datasets = [args.train_dataset, args.valid_dataset, args.test_dataset]
    train_set, valid_set, test_set = [
        create_ivector_datastream(path=args.data_path,
                                  which_set=dataset,
                                  batch_size=args.n_batch,
                                  min_after_cache=args.min_after_cache,
                                  length_sort=not args.no_length_sort)
        for dataset in datasets
    ]

    init_op = tf.global_variables_initializer()
    save_op = tf.train.Saver(max_to_keep=5)
    best_save_op = tf.train.Saver(max_to_keep=5)

    with tf.name_scope("per_step_eval"):
        tr_ce = tf.placeholder(tf.float32)
        tr_ce_summary = tf.summary.scalar("tr_ce", tr_ce)
        tr_fer = tf.placeholder(tf.float32)
        tr_fer_summary = tf.summary.scalar("tr_fer", tr_fer)
        tr_ce2 = tf.placeholder(tf.float32)
        tr_ce2_summary = tf.summary.scalar("tr_rl", tr_ce2)

        tr_image = tf.placeholder(tf.float32)
        tr_image_summary = tf.summary.image("tr_image", tr_image)

    with tf.name_scope("per_epoch_eval"):
        val_fer = tf.placeholder(tf.float32)
        val_fer_summary = tf.summary.scalar("val_fer", val_fer)
        best_val_fer = tf.placeholder(tf.float32)
        best_val_fer_summary = tf.summary.scalar("best_valid_fer",
                                                 best_val_fer)
        val_image = tf.placeholder(tf.float32)
        val_image_summary = tf.summary.image("val_image", val_image)

    vf = LinearVF()

    with tf.Session() as sess:
        sess.run(init_op)

        if args.start_from_ckpt:
            save_op = tf.train.import_meta_graph(
                os.path.join(args.log_dir, 'model.ckpt.meta'))
            save_op.restore(sess, os.path.join(args.log_dir, 'model.ckpt'))
            print(
                "Restore from the last checkpoint. Restarting from %d step." %
                global_step.eval())

        summary_writer = tf.summary.FileWriter(args.log_dir,
                                               sess.graph,
                                               flush_secs=5.0)

        tr_ce_sum = 0.
        tr_ce_count = 0
        tr_acc_sum = 0
        tr_acc_count = 0
        tr_ce2_sum = 0.
        tr_ce2_count = 0

        _best_score = np.iinfo(np.int32).max

        epoch_sw = StopWatch()
        disp_sw = StopWatch()
        eval_sw = StopWatch()
        per_sw = StopWatch()

        # For each epoch
        for _epoch in xrange(args.n_epoch):
            _n_exp = 0

            epoch_sw.reset()
            disp_sw.reset()

            print('Epoch {} training'.format(_epoch + 1))

            # For each batch
            for batch in train_set.get_epoch_iterator():
                x, x_mask, _, _, y, _ = batch
                n_batch = x.shape[0]
                _n_exp += n_batch

                if args.no_sampling:
                    new_x, new_y, actions, actions_1hot, new_x_mask = gen_supervision(
                        x, x_mask, y, args)

                    zero_state = gen_zero_state(n_batch, args.n_hidden)

                    feed_dict = {
                        tg.seq_x_data: new_x,
                        tg.seq_x_mask: new_x_mask,
                        tg.seq_y_data: new_y,
                        tg.seq_jump_data: actions
                    }
                    feed_init_state(feed_dict, tg.init_state, zero_state)

                    _tr_ml_cost, _tr_rl_cost, _, _ = \
                        sess.run([tg.ml_cost, tg.rl_cost, ml_op, rl_op], feed_dict=feed_dict)

                    tr_ce_sum += _tr_ml_cost.sum()
                    tr_ce_count += new_x_mask.sum()
                    tr_ce2_sum += _tr_rl_cost.sum()
                    tr_ce2_count += new_x_mask[:, :-1].sum()

                    actions_1hot, label_probs, new_mask, output_image = \
                        skip_rnn_forward_supervised(x, x_mask, sess, test_graph,
                        args.fast_action, args.n_fast_action, y)

                    pred_idx = expand_output(actions_1hot, x_mask, new_mask,
                                             label_probs.argmax(axis=-1))
                    tr_acc_sum += ((pred_idx == y) * x_mask).sum()
                    tr_acc_count += x_mask.sum()

                    _tr_ce_summary, _tr_fer_summary, _tr_ce2_summary, _tr_image_summary = \
                        sess.run([tr_ce_summary, tr_fer_summary, tr_ce2_summary, tr_image_summary],
                            feed_dict={tr_ce: _tr_ml_cost.sum() / new_x_mask.sum(),
                                tr_fer: 1 - ((pred_idx == y) * x_mask).sum() / x_mask.sum(),
                                tr_ce2: _tr_rl_cost.sum() / new_x_mask[:,:-1].sum(),
                                tr_image: output_image})
                    summary_writer.add_summary(_tr_ce_summary,
                                               global_step.eval())
                    summary_writer.add_summary(_tr_fer_summary,
                                               global_step.eval())
                    summary_writer.add_summary(_tr_ce2_summary,
                                               global_step.eval())
                    summary_writer.add_summary(_tr_image_summary,
                                               global_step.eval())

                else:

                    # --- Step 1: train the jump-prediction ("RL") head with
                    # supervision derived from the labels. gen_supervision
                    # produces the subsampled inputs, target jump actions and
                    # the corresponding mask.
                    new_x, _, actions, _, new_x_mask = gen_supervision(
                        x, x_mask, y, args)
                    zero_state = gen_zero_state(n_batch, args.n_hidden)
                    feed_dict = {
                        tg.seq_x_data: new_x,
                        tg.seq_x_mask: new_x_mask,
                        tg.seq_jump_data: actions
                    }
                    feed_init_state(feed_dict, tg.init_state, zero_state)
                    # rl_op updates only via rl_grads and does not advance
                    # global_step (see graph construction in HEAD).
                    _tr_rl_cost, _ = sess.run([tg.rl_cost, rl_op],
                                              feed_dict=feed_dict)

                    # Running totals for the periodic display below.
                    # [:, :-1] drops the last step, which has no jump target.
                    tr_ce2_sum += _tr_rl_cost.sum()
                    tr_ce2_count += new_x_mask[:, :-1].sum()

                    _tr_ce2_summary, = sess.run([tr_ce2_summary],
                                                feed_dict={
                                                    tr_ce2:
                                                    _tr_rl_cost.sum() /
                                                    new_x_mask[:, :-1].sum()
                                                })

                    # --- Step 2: let the current model choose the jumps, then
                    # train the label-prediction head on the resulting
                    # (subsampled) episode.
                    new_x, new_y, actions_1hot, label_probs, new_x_mask, output_image = gen_episode_supervised(
                        x, y, x_mask, sess, test_graph, args.fast_action,
                        args.n_fast_action)

                    feed_dict = {
                        tg.seq_x_data: new_x,
                        tg.seq_x_mask: new_x_mask,
                        tg.seq_y_data: new_y
                    }
                    # NOTE(review): zero_state from Step 1 is reused here;
                    # batch size is unchanged, so the shapes still match.
                    feed_init_state(feed_dict, tg.init_state, zero_state)

                    # ml_op applies the (optionally clipped) ML gradients and
                    # increments global_step.
                    _tr_ml_cost, _ = sess.run([tg.ml_cost, ml_op],
                                              feed_dict=feed_dict)

                    tr_ce_sum += _tr_ml_cost.sum()
                    tr_ce_count += new_x_mask.sum()

                    # Forward pass with the freshly updated model to measure
                    # frame accuracy on this batch.
                    actions_1hot, label_probs, new_mask, output_image = \
                        skip_rnn_forward_supervised(x, x_mask, sess, test_graph,
                        args.fast_action, args.n_fast_action, y)

                    # Expand the subsampled predictions back to the original
                    # frame rate so they can be compared against y frame-wise.
                    pred_idx = expand_output(actions_1hot, x_mask, new_mask,
                                             label_probs.argmax(axis=-1))
                    tr_acc_sum += ((pred_idx == y) * x_mask).sum()
                    tr_acc_count += x_mask.sum()

                    _tr_ce_summary, _tr_fer_summary, _tr_image_summary = \
                        sess.run([tr_ce_summary, tr_fer_summary, tr_image_summary],
                            feed_dict={tr_ce: _tr_ml_cost.sum() / new_x_mask.sum(),
                                tr_fer: 1 - ((pred_idx == y) * x_mask).sum() / x_mask.sum(),
                                tr_image: output_image})
                    summary_writer.add_summary(_tr_ce_summary,
                                               global_step.eval())
                    summary_writer.add_summary(_tr_fer_summary,
                                               global_step.eval())
                    # NOTE(review): _tr_ce2_summary was computed in Step 1,
                    # before ml_op ran — it is logged here at the post-update
                    # global step. Confirm this ordering is intended.
                    summary_writer.add_summary(_tr_ce2_summary,
                                               global_step.eval())
                    summary_writer.add_summary(_tr_image_summary,
                                               global_step.eval())

                # Periodic console report: averages over all batches seen
                # since the last display, then reset the accumulators.
                if global_step.eval() % args.display_freq == 0:
                    avg_tr_ce = tr_ce_sum / tr_ce_count
                    avg_tr_fer = 1. - tr_acc_sum / tr_acc_count
                    avg_tr_ce2 = tr_ce2_sum / tr_ce2_count

                    print(
                        "TRAIN: epoch={} iter={} ml_cost(ce/frame)={:.2f} fer={:.2f} rl_cost={:.4f} time_taken={:.2f}"
                        .format(_epoch, global_step.eval(), avg_tr_ce,
                                avg_tr_fer, avg_tr_ce2, disp_sw.elapsed()))

                    # Reset running sums/counts for the next display window.
                    tr_ce_sum = 0.
                    tr_ce_count = 0
                    tr_acc_sum = 0.
                    tr_acc_count = 0
                    tr_ce2_sum = 0.
                    tr_ce2_count = 0

                    # Restart the display stopwatch.
                    disp_sw.reset()

            print('--')
            print('End of epoch {}'.format(_epoch + 1))
            epoch_sw.print_elapsed()

            print('Testing')

            # Evaluate the model on the validation set
            val_acc_sum = 0
            val_acc_count = 0

            eval_sw.reset()
            for batch in valid_set.get_epoch_iterator():
                x, x_mask, _, _, y, _ = batch
                n_batch = x.shape[0]

                # Inference-only forward pass (no parameter updates).
                actions_1hot, label_probs, new_mask, output_image = \
                    skip_rnn_forward_supervised(x, x_mask, sess, test_graph,
                    args.fast_action, args.n_fast_action, y)

                # Expand subsampled predictions to full frame rate and
                # accumulate masked frame accuracy, mirroring training.
                pred_idx = expand_output(actions_1hot, x_mask, new_mask,
                                         label_probs.argmax(axis=-1))
                val_acc_sum += ((pred_idx == y) * x_mask).sum()
                val_acc_count += x_mask.sum()

            # Frame error rate = 1 - accuracy over all valid frames.
            avg_val_fer = 1. - val_acc_sum / val_acc_count

            print("VALID: epoch={} fer={:.2f} time_taken={:.2f}".format(
                _epoch, avg_val_fer, eval_sw.elapsed()))

            # Log validation FER and a sample image.
            # NOTE(review): output_image here is from the *last* validation
            # batch only — presumably intended as a representative sample.
            _val_fer_summary, _val_image_summary = sess.run(
                [val_fer_summary, val_image_summary],
                feed_dict={
                    val_fer: avg_val_fer,
                    val_image: output_image
                })
            summary_writer.add_summary(_val_fer_summary, global_step.eval())
            summary_writer.add_summary(_val_image_summary, global_step.eval())

            # Persist per-epoch metrics to the .npz side file (file_name is
            # built from args.log_dir in HEAD).
            insert_item2dict(eval_summary, 'val_fer', avg_val_fer)
            insert_item2dict(eval_summary, 'time', eval_sw.elapsed())
            save_npz2(file_name, eval_summary)

            # Save model
            # Keep a separate "best" checkpoint whenever validation FER
            # improves (_best_score is tracked across epochs, defined above).
            if avg_val_fer < _best_score:
                _best_score = avg_val_fer
                best_ckpt = best_save_op.save(sess,
                                              os.path.join(
                                                  args.log_dir,
                                                  "best_model.ckpt"),
                                              global_step=global_step)
                print("Best checkpoint stored in: %s" % best_ckpt)
            # Always save the latest checkpoint regardless of improvement.
            ckpt = save_op.save(sess,
                                os.path.join(args.log_dir, "model.ckpt"),
                                global_step=global_step)
            print("Checkpoint stored in: %s" % ckpt)

            # Log the best-so-far validation FER as its own summary curve.
            _best_val_fer_summary, = sess.run(
                [best_val_fer_summary], feed_dict={best_val_fer: _best_score})
            summary_writer.add_summary(_best_val_fer_summary,
                                       global_step.eval())
        # Flush and close the TensorBoard writer after the final epoch.
        summary_writer.close()

        print("Optimization Finished.")