예제 #1
0
    ##############
    # print args #
    ##############
    print(args)

    sw = StopWatch()

    if not args.no_copy:
        print('Loading data streams from {}'.format(args.data_path))
        print('Copying data to local machine...')
        rsync = Rsync(args.tmpdir)
        rsync.sync(args.data_path)
        args.data_path = os.path.join(args.tmpdir,
                                      os.path.basename(args.data_path))
        sw.print_elapsed()

    ####################
    # load data stream #
    ####################
    train_datastream = get_datastream(path=args.data_path,
                                      which_set=args.train_dataset,
                                      batch_size=args.batch_size)
    valid_datastream = get_datastream(path=args.data_path,
                                      which_set=args.valid_dataset,
                                      batch_size=args.batch_size)
    test_datastream = get_datastream(path=args.data_path,
                                     which_set=args.test_dataset,
                                     batch_size=args.batch_size)

    #################
예제 #2
0
    ##############
    # print args #
    ##############
    print(args)

    sw = StopWatch()

    if not args.no_copy:
        print('Loading data streams from {}'.format(args.data_path))
        print('Copying data to local machine...')
        rsync = Rsync(args.tmpdir)
        rsync.sync(args.data_path)
        args.data_path = os.path.join(args.tmpdir,
                                      os.path.basename(args.data_path))
        sw.print_elapsed()

    ####################
    # load data stream #
    ####################
    train_datastream = get_datastream(path=args.data_path,
                                      which_set=args.train_dataset,
                                      batch_size=args.batch_size)
    valid_datastream = get_datastream(path=args.data_path,
                                      which_set=args.valid_dataset,
                                      batch_size=args.batch_size)
    test_datastream = get_datastream(path=args.data_path,
                                     which_set=args.test_dataset,
                                     batch_size=args.batch_size)

    #################
예제 #3
0
def main():
    """Run joint ML/RL training with per-epoch validation and checkpointing.

    Settings come from ``get_args()``.  The function builds the training
    graph ``tg`` and sampling graph ``sg``, creates one Adam optimizer for the
    ML cost and one for the RL cost, and then for every epoch:

    * generates an episode for each training batch, applies both updates and
      writes per-step summaries (cross-entropy, frame error rate, RL cost,
      image, reward histogram);
    * evaluates on the validation stream, appends the averages to the
      ``model.npz`` summary file, always saves ``model.ckpt`` and saves
      ``best_model.ckpt`` whenever the validation cross-entropy improves.

    Unless ``args.start_from_ckpt`` is set, ``args.log_dir`` is deleted and
    recreated before training starts.
    """
    print(' '.join(sys.argv))

    args = get_args()
    print(args)

    print('Hostname: {}'.format(socket.gethostname()))
    print('GPU: {}'.format(get_gpuname()))

    # A fresh run starts from an empty log directory.
    if not args.start_from_ckpt:
        if tf.gfile.Exists(args.log_dir):
            tf.gfile.DeleteRecursively(args.log_dir)
        tf.gfile.MakeDirs(args.log_dir)

    tf.get_variable_scope()._reuse = None

    _seed = args.base_seed + args.add_seed
    tf.set_random_seed(_seed)
    np.random.seed(_seed)

    prefix_name = os.path.join(args.log_dir, 'model')
    file_name = '%s.npz' % prefix_name

    eval_summary = OrderedDict()

    tg, sg = build_graph(args)
    tg_ml_cost = tf.reduce_mean(tg.ml_cost)

    global_step = tf.Variable(0, trainable=False, name="global_step")

    tvars = tf.trainable_variables()
    print([tvar.name for tvar in tvars])

    ml_opt_func = tf.train.AdamOptimizer(learning_rate=args.learning_rate,
                                         beta1=0.9,
                                         beta2=0.99)
    rl_opt_func = tf.train.AdamOptimizer(learning_rate=args.rl_learning_rate,
                                         beta1=0.9,
                                         beta2=0.99)

    if args.grad_clip:
        ml_grads, _ = tf.clip_by_global_norm(tf.gradients(tg_ml_cost, tvars),
                                             clip_norm=1.0)
    else:
        ml_grads = tf.gradients(tg_ml_cost, tvars)
    ml_op = ml_opt_func.apply_gradients(zip(ml_grads, tvars),
                                        global_step=global_step)

    tg_rl_cost = tf.reduce_mean(tg.rl_cost)
    rl_grads = tf.gradients(tg_rl_cost, tvars)
    # do not increase global step -- ml op increases it
    rl_op = rl_opt_func.apply_gradients(zip(rl_grads, tvars))

    tf.add_to_collection('n_fast_action', args.n_fast_action)

    sync_data(args)
    datasets = [args.train_dataset, args.valid_dataset, args.test_dataset]
    # test_set is created but not used anywhere in this function.
    train_set, valid_set, test_set = [
        create_ivector_datastream(path=args.data_path,
                                  which_set=dataset,
                                  batch_size=args.n_batch,
                                  min_after_cache=args.min_after_cache,
                                  length_sort=not args.no_length_sort)
        for dataset in datasets
    ]

    init_op = tf.global_variables_initializer()
    save_op = tf.train.Saver(max_to_keep=5)
    best_save_op = tf.train.Saver(max_to_keep=5)

    # Placeholders that carry numpy-side statistics into summaries.
    with tf.name_scope("per_step_eval"):
        tr_ce = tf.placeholder(tf.float32)
        tr_ce_summary = tf.summary.scalar("tr_ce", tr_ce)
        tr_image = tf.placeholder(tf.float32)
        tr_image_summary = tf.summary.image("tr_image", tr_image)
        tr_fer = tf.placeholder(tf.float32)
        tr_fer_summary = tf.summary.scalar("tr_fer", tr_fer)
        tr_rl = tf.placeholder(tf.float32)
        tr_rl_summary = tf.summary.scalar("tr_rl", tr_rl)
        tr_rw_hist = tf.placeholder(tf.float32)
        tr_rw_hist_summary = tf.summary.histogram("tr_reward_hist", tr_rw_hist)

    with tf.name_scope("per_epoch_eval"):
        best_val_ce = tf.placeholder(tf.float32)
        val_ce = tf.placeholder(tf.float32)
        best_val_ce_summary = tf.summary.scalar("best_valid_ce", best_val_ce)
        val_ce_summary = tf.summary.scalar("valid_ce", val_ce)

    vf = LinearVF()

    with tf.Session() as sess:
        sess.run(init_op)

        if args.start_from_ckpt:
            # NOTE(review): checkpoints below are saved with a global-step
            # suffix, but this restores the un-suffixed 'model.ckpt' path --
            # confirm that this resolves to an existing checkpoint.
            save_op = tf.train.import_meta_graph(
                os.path.join(args.log_dir, 'model.ckpt.meta'))
            save_op.restore(sess, os.path.join(args.log_dir, 'model.ckpt'))
            print(
                "Restore from the last checkpoint. Restarting from %d step." %
                global_step.eval())

        summary_writer = tf.summary.FileWriter(args.log_dir,
                                               sess.graph,
                                               flush_secs=5.0)

        # Running statistics, reset every time they are displayed.
        tr_ce_sum = 0.
        tr_ce_count = 0
        tr_acc_sum = 0
        tr_acc_count = 0
        tr_rl_costs = []
        tr_action_entropies = []
        tr_rewards = []

        _best_score = np.iinfo(np.int32).max

        epoch_sw = StopWatch()
        disp_sw = StopWatch()
        eval_sw = StopWatch()

        # For each epoch
        for _epoch in xrange(args.n_epoch):
            epoch_sw.reset()
            disp_sw.reset()

            print('Epoch {} training'.format(_epoch + 1))

            # For each batch
            for batch in train_set.get_epoch_iterator():
                x, x_mask, _, _, y, _ = batch
                n_batch = x.shape[0]

                (new_x, new_y, actions_1hot, rewards, action_entropies,
                 new_x_mask, new_reward_mask, output_image) = \
                    gen_episode_with_seg_reward(x, x_mask, y, sess, sg, args)

                advantages, _ = compute_advantage(new_x, new_x_mask, rewards,
                                                  new_reward_mask, vf, args)

                zero_state = gen_zero_state(n_batch, args.n_hidden)

                feed_dict = {
                    tg.seq_x_data: new_x,
                    tg.seq_x_mask: new_x_mask,
                    tg.seq_y_data: new_y,
                    tg.seq_action: actions_1hot,
                    tg.seq_advantage: advantages,
                    tg.seq_action_mask: new_reward_mask,
                    tg.seq_y_data_for_action: new_y
                }
                feed_init_state(feed_dict, tg.init_state, zero_state)

                _tr_ml_cost, _tr_rl_cost, _, _, pred_idx = sess.run(
                    [tg.ml_cost, tg.rl_cost, ml_op, rl_op, tg.pred_idx],
                    feed_dict=feed_dict)

                # ml_op has just advanced global_step and nothing else in this
                # iteration changes it, so read it from the session only once.
                step = global_step.eval()

                n_new_frames = new_x_mask.sum()
                n_actions = new_reward_mask.sum()
                batch_ml_cost = _tr_ml_cost.sum()
                batch_rl_cost = _tr_rl_cost.sum() / n_actions

                tr_ce_sum += batch_ml_cost
                tr_ce_count += n_new_frames

                pred_idx = expand_output(actions_1hot, x_mask, new_x_mask,
                                         pred_idx.reshape([n_batch, -1]),
                                         args.n_fast_action)
                n_correct = ((pred_idx == y) * x_mask).sum()
                n_frames = x_mask.sum()
                tr_acc_sum += n_correct
                tr_acc_count += n_frames

                # Frame error rate is 1 - accuracy, matching avg_tr_fer below;
                # float() guards against integer division under Python 2.
                batch_fer = 1. - float(n_correct) / n_frames

                step_summaries = sess.run(
                    [tr_ce_summary, tr_fer_summary, tr_rl_summary,
                     tr_image_summary, tr_rw_hist_summary],
                    feed_dict={
                        tr_ce: batch_ml_cost / n_new_frames,
                        tr_fer: batch_fer,
                        tr_rl: batch_rl_cost,
                        tr_image: output_image,
                        tr_rw_hist: rewards
                    })
                for step_summary in step_summaries:
                    summary_writer.add_summary(step_summary, step)

                tr_rl_costs.append(batch_rl_cost)
                tr_action_entropies.append(action_entropies.sum() / n_actions)
                tr_rewards.append(rewards.sum() / n_actions)

                if step % args.display_freq == 0:
                    avg_tr_ce = tr_ce_sum / tr_ce_count
                    avg_tr_fer = 1. - float(tr_acc_sum) / tr_acc_count
                    avg_tr_rl_cost = np.asarray(tr_rl_costs).mean()
                    avg_tr_action_entropy = np.asarray(
                        tr_action_entropies).mean()
                    avg_tr_reward = np.asarray(tr_rewards).mean()

                    print(
                        "TRAIN: epoch={} iter={} ml_cost(ce/frame)={:.2f} fer={:.2f} rl_cost={:.4f} reward={:.4f} action_entropy={:.2f} time_taken={:.2f}"
                        .format(_epoch, step, avg_tr_ce,
                                avg_tr_fer, avg_tr_rl_cost, avg_tr_reward,
                                avg_tr_action_entropy, disp_sw.elapsed()))

                    tr_ce_sum = 0.
                    tr_ce_count = 0
                    tr_acc_sum = 0.
                    tr_acc_count = 0
                    tr_rl_costs = []
                    tr_action_entropies = []
                    tr_rewards = []

                    disp_sw.reset()

            print('--')
            print('End of epoch {}'.format(_epoch + 1))
            epoch_sw.print_elapsed()

            print('Testing')

            # Evaluate the model on the validation set
            val_ce_sum = 0.
            val_ce_count = 0
            val_acc_sum = 0
            val_acc_count = 0
            val_rl_costs = []
            val_action_entropies = []
            val_rewards = []

            eval_sw.reset()
            for batch in valid_set.get_epoch_iterator():
                x, x_mask, _, _, y, _ = batch
                n_batch = x.shape[0]

                # With sample_y=True the helper returns one extra value.
                (new_x, new_y, actions_1hot, rewards, action_entropies,
                 new_x_mask, new_reward_mask, output_image, new_y_sample) = \
                    gen_episode_with_seg_reward(x, x_mask, y, sess, sg, args,
                                                sample_y=True)

                advantages, _ = compute_advantage(new_x, new_x_mask, rewards,
                                                  new_reward_mask, vf, args)

                zero_state = gen_zero_state(n_batch, args.n_hidden)

                feed_dict = {
                    tg.seq_x_data: new_x,
                    tg.seq_x_mask: new_x_mask,
                    tg.seq_y_data: new_y,
                    tg.seq_action: actions_1hot,
                    tg.seq_advantage: advantages,
                    tg.seq_action_mask: new_reward_mask,
                    tg.seq_y_data_for_action: new_y_sample
                }
                feed_init_state(feed_dict, tg.init_state, zero_state)

                # No update ops here: validation only reads costs/predictions.
                _val_ml_cost, _val_rl_cost, pred_idx = sess.run(
                    [tg.ml_cost, tg.rl_cost, tg.pred_idx], feed_dict=feed_dict)

                n_actions = new_reward_mask.sum()

                val_ce_sum += _val_ml_cost.sum()
                val_ce_count += new_x_mask.sum()

                pred_idx = expand_output(actions_1hot, x_mask, new_x_mask,
                                         pred_idx.reshape([n_batch, -1]),
                                         args.n_fast_action)
                val_acc_sum += ((pred_idx == y) * x_mask).sum()
                val_acc_count += x_mask.sum()

                val_rl_costs.append(_val_rl_cost.sum() / n_actions)
                val_action_entropies.append(action_entropies.sum() / n_actions)
                val_rewards.append(rewards.sum() / n_actions)

            avg_val_ce = val_ce_sum / val_ce_count
            avg_val_fer = 1. - float(val_acc_sum) / val_acc_count
            avg_val_rl_cost = np.asarray(val_rl_costs).mean()
            avg_val_action_entropy = np.asarray(val_action_entropies).mean()
            avg_val_reward = np.asarray(val_rewards).mean()

            print(
                "VALID: epoch={} ml_cost(ce/frame)={:.2f} fer={:.2f} rl_cost={:.4f} reward={:.4f} action_entropy={:.2f} time_taken={:.2f}"
                .format(_epoch, avg_val_ce, avg_val_fer, avg_val_rl_cost,
                        avg_val_reward, avg_val_action_entropy,
                        eval_sw.elapsed()))

            # Validation does not run ml_op, so the step is constant for the
            # rest of this epoch; read it once for both epoch summaries.
            epoch_step = global_step.eval()

            _val_ce_summary, = sess.run([val_ce_summary],
                                        feed_dict={val_ce: avg_val_ce})
            summary_writer.add_summary(_val_ce_summary, epoch_step)

            insert_item2dict(eval_summary, 'val_ce', avg_val_ce)
            insert_item2dict(eval_summary, 'val_rl_cost', avg_val_rl_cost)
            insert_item2dict(eval_summary, 'val_reward', avg_val_reward)
            insert_item2dict(eval_summary, 'val_action_entropy',
                             avg_val_action_entropy)
            insert_item2dict(eval_summary, 'time', eval_sw.elapsed())
            save_npz2(file_name, eval_summary)

            # Save model
            if avg_val_ce < _best_score:
                _best_score = avg_val_ce
                best_ckpt = best_save_op.save(sess,
                                              os.path.join(
                                                  args.log_dir,
                                                  "best_model.ckpt"),
                                              global_step=global_step)
                print("Best checkpoint stored in: %s" % best_ckpt)
            ckpt = save_op.save(sess,
                                os.path.join(args.log_dir, "model.ckpt"),
                                global_step=global_step)
            print("Checkpoint stored in: %s" % ckpt)

            _best_val_ce_summary, = sess.run(
                [best_val_ce_summary], feed_dict={best_val_ce: _best_score})
            summary_writer.add_summary(_best_val_ce_summary, epoch_step)
        summary_writer.close()

        print("Optimization Finished.")
예제 #4
0
def main(_):
  """Train the frame-skipping model with ML updates and validate every epoch.

  Settings are read from the module-level ``FLAGS``.  For every training batch
  the input is split into sub-batches by ``skip_frames_fixed`` and one Adam
  step is applied per sub-batch.  The logits are expanded again with
  ``interpolate_feat`` and compared against the original labels to track
  accuracy.  After each epoch the validation stream is evaluated, the results
  are appended to ``model.npz``, and ``model.ckpt`` (always) and
  ``best_model.ckpt`` (on improved validation cross-entropy) are saved.

  Unless ``args.start_from_ckpt`` is set, ``args.log_dir`` is deleted and
  recreated before training starts.

  Args:
    _: unused positional argument.
  """
  print(' '.join(sys.argv))
  args = FLAGS
  print(args.__flags)

  # A fresh run starts from an empty log directory.
  if not args.start_from_ckpt:
    if tf.gfile.Exists(args.log_dir):
      tf.gfile.DeleteRecursively(args.log_dir)
    tf.gfile.MakeDirs(args.log_dir)

  tf.get_variable_scope()._reuse = None

  _seed = args.base_seed + args.add_seed
  tf.set_random_seed(_seed)
  np.random.seed(_seed)

  prefix_name = os.path.join(args.log_dir, 'model')
  file_name = '%s.npz' % prefix_name

  eval_summary = OrderedDict()

  tg = build_graph(args)
  tg_ml_cost = tf.reduce_mean(tg.ml_cost)

  global_step = tf.Variable(0, trainable=False, name="global_step")

  tvars = tf.trainable_variables()

  ml_opt_func = tf.train.AdamOptimizer(learning_rate=args.learning_rate,
                                       beta1=0.9, beta2=0.99)

  if args.grad_clip:
    ml_grads, _ = tf.clip_by_global_norm(tf.gradients(tg_ml_cost, tvars),
                                         clip_norm=1.0)
  else:
    ml_grads = tf.gradients(tg_ml_cost, tvars)
  ml_op = ml_opt_func.apply_gradients(zip(ml_grads, tvars),
                                      global_step=global_step)

  sync_data(args)
  datasets = [args.train_dataset, args.valid_dataset, args.test_dataset]
  # test_set is created but not used anywhere in this function.
  train_set, valid_set, test_set = [
      create_ivector_datastream(path=args.data_path, which_set=dataset,
                                batch_size=args.batch_size,
                                min_after_cache=args.min_after_cache,
                                length_sort=not args.no_length_sort)
      for dataset in datasets]

  init_op = tf.global_variables_initializer()
  save_op = tf.train.Saver(max_to_keep=5)
  best_save_op = tf.train.Saver(max_to_keep=5)

  # Placeholders that carry numpy-side statistics into summaries.
  with tf.name_scope("per_step_eval"):
    tr_ce = tf.placeholder(tf.float32)
    tr_ce_summary = tf.summary.scalar("tr_ce", tr_ce)

  with tf.name_scope("per_epoch_eval"):
    best_val_ce = tf.placeholder(tf.float32)
    val_ce = tf.placeholder(tf.float32)
    best_val_ce_summary = tf.summary.scalar("best_valid_ce", best_val_ce)
    val_ce_summary = tf.summary.scalar("valid_ce", val_ce)

  with tf.Session() as sess:
    sess.run(init_op)

    if args.start_from_ckpt:
      # NOTE(review): checkpoints below are saved with a global-step suffix,
      # but this restores the un-suffixed 'model.ckpt' path -- confirm that
      # this resolves to an existing checkpoint.
      save_op = tf.train.import_meta_graph(os.path.join(args.log_dir,
                                                        'model.ckpt.meta'))
      save_op.restore(sess, os.path.join(args.log_dir, 'model.ckpt'))
      print("Restore from the last checkpoint. "
            "Restarting from %d step." % global_step.eval())

    summary_writer = tf.summary.FileWriter(args.log_dir, sess.graph,
                                           flush_secs=5.0)

    # Running statistics, reset every time they are displayed.
    tr_ce_sum = 0.
    tr_ce_count = 0
    tr_acc_sum = 0
    tr_acc_count = 0
    _best_score = np.iinfo(np.int32).max

    epoch_sw = StopWatch()
    disp_sw = StopWatch()
    eval_sw = StopWatch()

    # For each epoch
    for _epoch in xrange(args.n_epoch):
      epoch_sw.reset()
      disp_sw.reset()

      print('--')
      print('Epoch {} training'.format(_epoch+1))

      # For each batch
      for batch in train_set.get_epoch_iterator():
        orig_x, orig_x_mask, _, _, orig_y, _ = batch

        # Get skipped frames
        for sub_batch in skip_frames_fixed([orig_x, orig_x_mask, orig_y],
                                           args.n_skip+1):
          x, x_mask, y = sub_batch
          n_batch, _, _ = x.shape

          _feed_states = initial_states(n_batch, args.n_hidden)

          _tr_ml_cost, _seq_logit, _ = sess.run(
              [tg.ml_cost, tg.seq_logit, ml_op],
              feed_dict={tg.seq_x_data: x,
                         tg.seq_x_mask: x_mask,
                         tg.seq_y_data: y,
                         tg.init_state: _feed_states})

          sub_ml_cost = _tr_ml_cost.sum()
          sub_n_frames = x_mask.sum()
          tr_ce_sum += sub_ml_cost
          tr_ce_count += sub_n_frames

          _tr_ce_summary, = sess.run(
              [tr_ce_summary],
              feed_dict={tr_ce: sub_ml_cost / sub_n_frames})
          summary_writer.add_summary(_tr_ce_summary, global_step.eval())

          # Expand the skipped-frame logits again so they can be compared
          # with the original label sequence.
          _expand_seq_logit = interpolate_feat(_seq_logit,
                                               num_skips=args.n_skip+1,
                                               axis=1, use_bidir=True)
          _pred_idx = _expand_seq_logit.argmax(axis=2)

          tr_acc_sum += ((_pred_idx == orig_y) * orig_x_mask).sum()
          tr_acc_count += orig_x_mask.sum()

        # NOTE(review): this check runs once per batch while global_step
        # advances once per sub-batch, so a multiple of display_freq can be
        # stepped over without a report -- confirm this is intended.
        step = global_step.eval()
        if step % args.display_freq == 0:
          avg_tr_ce = tr_ce_sum / tr_ce_count
          avg_tr_fer = 1. - float(tr_acc_sum) / tr_acc_count

          print("TRAIN: epoch={} iter={} ml_cost(ce/frame)={:.2f} fer={:.2f} time_taken={:.2f}".format(
              _epoch, step, avg_tr_ce, avg_tr_fer, disp_sw.elapsed()))

          tr_ce_sum = 0.
          tr_ce_count = 0
          tr_acc_sum = 0
          tr_acc_count = 0
          disp_sw.reset()

      print('--')
      print('End of epoch {}'.format(_epoch+1))
      epoch_sw.print_elapsed()

      print('Testing')

      # Evaluate the model on the validation set
      val_ce_sum = 0.
      val_ce_count = 0
      val_acc_sum = 0
      val_acc_count = 0

      eval_sw.reset()
      for batch in valid_set.get_epoch_iterator():
        orig_x, orig_x_mask, _, _, orig_y, _ = batch

        for sub_batch in skip_frames_fixed([orig_x, orig_x_mask, orig_y],
                                           args.n_skip+1, return_first=True):
          x, x_mask, y = sub_batch
          n_batch, _, _ = x.shape

          _feed_states = initial_states(n_batch, args.n_hidden)

          # No update op here: validation only reads cost and logits.
          _val_ml_cost, _seq_logit = sess.run(
              [tg.ml_cost, tg.seq_logit],
              feed_dict={tg.seq_x_data: x,
                         tg.seq_x_mask: x_mask,
                         tg.seq_y_data: y,
                         tg.init_state: _feed_states})

          val_ce_sum += _val_ml_cost.sum()
          val_ce_count += x_mask.sum()

          _expand_seq_logit = interpolate_feat(_seq_logit,
                                               num_skips=args.n_skip+1,
                                               axis=1, use_bidir=True)
          _pred_idx = _expand_seq_logit.argmax(axis=2)

          val_acc_sum += ((_pred_idx == orig_y) * orig_x_mask).sum()
          val_acc_count += orig_x_mask.sum()

      avg_val_ce = val_ce_sum / val_ce_count
      avg_val_fer = 1. - float(val_acc_sum) / val_acc_count

      print("VALID: epoch={} ml_cost(ce/frame)={:.2f} fer={:.2f} time_taken={:.2f}".format(
          _epoch, avg_val_ce, avg_val_fer, eval_sw.elapsed()))

      # Validation does not run ml_op, so the step is constant for the rest
      # of this epoch; read it once for both epoch summaries.
      epoch_step = global_step.eval()

      _val_ce_summary, = sess.run([val_ce_summary],
                                  feed_dict={val_ce: avg_val_ce})
      summary_writer.add_summary(_val_ce_summary, epoch_step)

      insert_item2dict(eval_summary, 'val_ce', avg_val_ce)
      insert_item2dict(eval_summary, 'time', eval_sw.elapsed())
      save_npz2(file_name, eval_summary)

      # Save model
      if avg_val_ce < _best_score:
        _best_score = avg_val_ce
        best_ckpt = best_save_op.save(sess, os.path.join(args.log_dir,
                                                         "best_model.ckpt"),
                                      global_step=global_step)
        print("Best checkpoint stored in: %s" % best_ckpt)
      ckpt = save_op.save(sess, os.path.join(args.log_dir, "model.ckpt"),
                          global_step=global_step)
      print("Checkpoint stored in: %s" % ckpt)

      _best_val_ce_summary, = sess.run([best_val_ce_summary],
                                       feed_dict={best_val_ce: _best_score})
      summary_writer.add_summary(_best_val_ce_summary, epoch_step)
    summary_writer.close()

    print("Optimization Finished.")
예제 #5
0
                        [s.s for s in tr_summary],
                        feed_dict={
                            tr_summary.ce.ph: ce.avg(),
                            tr_summary.cr.ph: cr.avg(),
                            tr_summary.image.ph: output_image
                        })
                    for s in summaries:
                        summary_writer.add_summary(s, global_step.eval())

                    for accu in accu_list:
                        accu.reset()
                    disp_sw.reset()

            print('--')
            print('End of epoch {}'.format(_epoch))
            epoch_sw.print_elapsed()

            print('Testing')

            # Evaluate the model on the validation set
            for accu in accu_list:
                accu.reset()
            eval_sw.reset()
            for batch in valid_set.get_epoch_iterator():
                orig_x, orig_x_mask, _, _, orig_y, _ = batch

                sub_batch, start_idx = utils.skip_frames_fixed(
                    [orig_x, orig_x_mask, orig_y],
                    args.n_skip + 1,
                    return_first=True,
                    return_start_idx=True)
예제 #6
0
파일: train.py 프로젝트: gunkisu/asr
def train_model():
    """Train the model described by the module-level ``FLAGS``.

    Builds the graph, sums the ML, RL and baseline costs (plus optional L2
    weight decay) into one objective and optimises it with Adam.  Every
    ``FLAGS.display_freq`` steps the running averages are printed and a
    ``last_model.ckpt`` checkpoint is written.  Every
    ``FLAGS.evaluation_freq`` steps the validation set is evaluated and
    ``best_model.ckpt`` is written when validation accuracy improves.
    """
    sw = StopWatch()

    # Fix random seeds
    rand_seed = FLAGS.base_seed + FLAGS.add_seed
    tf.set_random_seed(rand_seed)
    np.random.seed(rand_seed)

    # Get module graph
    model_graph = build_graph(FLAGS)

    # Get model parameter
    model_param = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

    # Set weight decay (only variables whose name contains 'W' and neither
    # 'action' nor 'baseline')
    # NOTE(review): tf.nn.l2_loss already halves the sum of squares, so the
    # extra 0.5 yields an overall factor of 1/4 -- confirm this is intended.
    if FLAGS.weight_decay > 0.0:
        l2_cost = tf.add_n([
            0.5 * tf.nn.l2_loss(W) for W in model_param if 'W' in W.name
            and 'action' not in W.name and 'baseline' not in W.name
        ])
        l2_cost *= FLAGS.weight_decay
    else:
        l2_cost = 0.0

    # Set total cost
    model_total_cost = model_graph.ml_cost + model_graph.rl_cost + model_graph.bl_cost + l2_cost

    # Define global training step
    global_step = tf.contrib.framework.get_or_create_global_step()

    # Set ml optimizer (Adam optimizer; in the original paper, we use 0.99
    # for beta2)
    model_opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate,
                                       name='model_optimizer')
    # aggregation_method=2 presumably selects
    # tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N -- TODO confirm.
    model_grad = tf.gradients(ys=model_total_cost,
                              xs=model_param,
                              aggregation_method=2)

    # Set gradient clipping
    if FLAGS.grad_clip > 0.0:
        model_grad, _ = tf.clip_by_global_norm(t_list=model_grad,
                                               clip_norm=FLAGS.grad_clip)

    model_update = model_opt.apply_gradients(
        grads_and_vars=zip(model_grad, model_param),
        global_step=global_step)

    # Set dataset (sync_data(FLAGS))
    datasets = [FLAGS.train_dataset, FLAGS.valid_dataset, FLAGS.test_dataset]
    # test_set is created but not used anywhere in this function.
    train_set, valid_set, test_set = [
        create_ivector_datastream(path=FLAGS.data_path,
                                  which_set=dataset,
                                  batch_size=FLAGS.batch_size)
        for dataset in datasets
    ]

    # Set variable initializer
    init_op = tf.global_variables_initializer()

    # Set last model saver
    last_save_op = tf.train.Saver(max_to_keep=5)

    # Set best model saver
    best_save_op = tf.train.Saver(max_to_keep=5)

    # Get hardware config
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # Set session
    with tf.Session(config=config) as sess:
        # Get summary
        merged_summary = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train',
                                             sess.graph)

        # Init model
        sess.run(init_op)

        # Load checkpoint
        if FLAGS.start_from_ckpt:
            # NOTE(review): checkpoints below are saved with a global-step
            # suffix, but this restores the un-suffixed 'last_model.ckpt'
            # path -- confirm that this resolves to an existing checkpoint.
            last_save_op = tf.train.import_meta_graph(
                os.path.join(FLAGS.log_dir, 'last_model.ckpt.meta'))
            last_save_op.restore(
                sess, os.path.join(FLAGS.log_dir, 'last_model.ckpt'))
            print(
                "Restore from the last checkpoint. Restarting from %d step." %
                global_step.eval())

        # Per-batch statistics, cleared after every display
        accr_history = []
        loss_history = []
        ml_cost_history = []
        rl_cost_history = []
        bl_cost_history = []
        sum_cost_history = []

        best_accr = 0.0
        sw.reset()
        # For each epoch
        for e_idx in xrange(FLAGS.n_epoch):
            # for each batch (update)
            for batch_data in train_set.get_epoch_iterator():
                # Get data x, y
                x_data, x_mask, _, _, y_data, _ = batch_data

                # Roll axis (swap the first two axes)
                x_data = x_data.transpose((1, 0, 2))
                x_mask = x_mask.transpose((1, 0))
                y_data = y_data.transpose((1, 0))

                # Update model
                mean_accr, mean_loss, ml_cost, rl_cost, bl_cost, read_ratio, summary_output \
                    = updater(model_graph=model_graph,
                              model_updater=model_update,
                              x_data=x_data,
                              x_mask=x_mask,
                              y_data=y_data,
                              summary=merged_summary,
                              session=sess)

                # The global step is read from the session once per batch;
                # it is assumed that only the update above advances it.
                step = global_step.eval()

                # write summary
                train_writer.add_summary(summary_output, step)

                accr_history.append(mean_accr)
                loss_history.append(mean_loss)
                ml_cost_history.append(ml_cost)
                rl_cost_history.append(rl_cost)
                bl_cost_history.append(bl_cost)
                sum_cost_history.append(ml_cost + rl_cost + bl_cost)

                # Display results
                if step % FLAGS.display_freq == 0:
                    mean_accr = np.array(accr_history).mean()
                    mean_loss = np.array(loss_history).mean()
                    mean_ml_cost = np.array(ml_cost_history).mean()
                    mean_rl_cost = np.array(rl_cost_history).mean()
                    mean_bl_cost = np.array(bl_cost_history).mean()
                    mean_sum_cost = np.array(sum_cost_history).mean()
                    print(
                        "====================================================")
                    print("Epoch " + str(e_idx) + ", Total Iter " +
                          str(step))
                    print(
                        "----------------------------------------------------")
                    print("Average FER: {:.2f}%".format(
                        (1.0 - mean_accr) * 100))
                    print("Average CCE: {:.6f}".format(mean_loss))
                    print("Average  ML: {:.6f}".format(mean_ml_cost))
                    if FLAGS.use_skim:
                        print("Average  RL: {:.6f}".format(mean_rl_cost))
                        print("Average  BL: {:.6f}".format(mean_bl_cost))
                        print("Average SUM: {:.6f}".format(mean_sum_cost))
                    # read_ratio is the value from the latest batch only.
                    print("Read ratio: ", read_ratio)
                    sw.print_elapsed()
                    sw.reset()
                    last_ckpt = last_save_op.save(sess,
                                                  os.path.join(
                                                      FLAGS.log_dir,
                                                      "last_model.ckpt"),
                                                  global_step=global_step)
                    print("Last checkpointed in: %s" % last_ckpt)
                    accr_history = []
                    loss_history = []
                    ml_cost_history = []
                    rl_cost_history = []
                    bl_cost_history = []
                    sum_cost_history = []

                # Evaluate model
                if step % FLAGS.evaluation_freq == 0:
                    # Monitor validation loss, accr
                    valid_accr, valid_cce = evaluation(model_graph=model_graph,
                                                       session=sess,
                                                       dataset=valid_set)
                    # Save model
                    if best_accr < valid_accr:
                        best_accr = valid_accr
                        best_ckpt = best_save_op.save(sess,
                                                      os.path.join(
                                                          FLAGS.log_dir,
                                                          "best_model.ckpt"),
                                                      global_step=global_step)
                        print("Best checkpoint stored in: %s" % best_ckpt)

                    print(
                        "----------------------------------------------------")
                    print("Validation evaluation")
                    print(
                        "----------------------------------------------------")
                    print("FER: {:.2f}%".format((1.0 - valid_accr) * 100.))
                    print("CCE: {:.6f}".format(valid_cce))
                    print(
                        "----------------------------------------------------")
                    print("Best FER: {:.2f}%".format((1.0 - best_accr) * 100.))

        print("Optimization Finished.")
예제 #7
0
def main(_):
    """Train the skip-RNN model with joint ML / RL updates and validate it.

    Builds the training and sampling graphs from ``FLAGS`` and then, for
    every epoch: samples skip episodes for each training batch, updates the
    ML parameters and the RL ("action") parameters, writes TensorBoard
    summaries, evaluates on the validation set and stores checkpoints
    (latest, best CE, best FER) under ``args.log_dir``.

    Args:
        _: Ignored positional argument.

    Raises:
        IOError: If ``args.start_from_ckpt`` is set but no checkpoint can be
            found in ``args.log_dir``.
    """
    # Print settings
    print(' '.join(sys.argv))
    args = FLAGS
    # items() (not iteritems()) keeps this loop valid on Python 2 and 3
    for k, v in args.__flags.items():
        print(k, v)

    # Start from an empty log directory unless resuming from a checkpoint
    if not args.start_from_ckpt:
        if tf.gfile.Exists(args.log_dir):
            tf.gfile.DeleteRecursively(args.log_dir)
        tf.gfile.MakeDirs(args.log_dir)

    # NOTE(review): clears the private reuse flag of the current variable
    # scope; the reason is not documented in this file -- confirm before
    # touching.
    tf.get_variable_scope()._reuse = None

    # Set random seed
    _seed = args.base_seed + args.add_seed
    tf.set_random_seed(_seed)
    np.random.seed(_seed)

    # Set save file name
    prefix_name = os.path.join(args.log_dir, 'model')
    file_name = '%s.npz' % prefix_name

    # Set evaluation summary
    eval_summary = OrderedDict()

    # Build model graph (tg is used for training, sg for episode sampling)
    tg, sg = build_graph(args)

    # Set linear regressor for baseline
    vf = LinearVF()

    # Set global step
    # NOTE(review): ml_op and rl_op below both receive global_step, so one
    # RL training batch presumably advances it by two -- confirm intended.
    global_step = tf.Variable(0, trainable=False, name="global_step")

    # Get ml/rl related parameters (split on "action" in the variable name)
    tvars = tf.trainable_variables()
    ml_vars = [tvar for tvar in tvars if "action" not in tvar.name]
    rl_vars = [tvar for tvar in tvars if "action" in tvar.name]

    # Set optimizer
    ml_opt_func = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    rl_opt_func = tf.train.AdamOptimizer(learning_rate=args.rl_learning_rate)

    # Set model ml cost (sum over all and divide it by batch_size)
    ml_cost = tf.reduce_sum(tg.seq_ml_cost)
    ml_cost /= tf.to_float(tf.shape(tg.seq_x_data)[0])
    ml_cost += args.ml_l2 * 0.5 * tf.add_n(
        [tf.reduce_sum(tf.square(var)) for var in ml_vars])

    # Set model rl cost (sum over all and divide it by batch_size, also entropy cost)
    rl_cost = tg.seq_rl_cost - args.ent_weight * tg.seq_a_ent
    rl_cost = tf.reduce_sum(rl_cost)
    rl_cost /= tf.to_float(tf.shape(tg.seq_x_data)[0])

    # Set rl cost used for monitoring only (sum over all and divide it by the
    # sum of the action mask, no entropy term)
    real_rl_cost = tf.reduce_sum(tg.seq_real_rl_cost)
    real_rl_cost /= tf.reduce_sum(tg.seq_a_mask)

    # Gradient clipping for ML
    ml_grads = tf.gradients(ml_cost, ml_vars)
    if args.grad_clip > 0.0:
        ml_grads, _ = tf.clip_by_global_norm(t_list=ml_grads,
                                             clip_norm=args.grad_clip)

    # Gradient for RL (not clipped)
    rl_grads = tf.gradients(rl_cost, rl_vars)

    # ML optimization
    ml_op = ml_opt_func.apply_gradients(grads_and_vars=zip(ml_grads, ml_vars),
                                        global_step=global_step,
                                        name='ml_op')

    # RL optimization
    rl_op = rl_opt_func.apply_gradients(grads_and_vars=zip(rl_grads, rl_vars),
                                        global_step=global_step,
                                        name='rl_op')

    # Sync dataset
    sync_data(args)

    # Get dataset
    train_set = create_ivector_datastream(path=args.data_path,
                                          which_set=args.train_dataset,
                                          batch_size=args.batch_size,
                                          min_after_cache=args.min_after_cache,
                                          length_sort=not args.no_length_sort)
    valid_set = create_ivector_datastream(path=args.data_path,
                                          which_set=args.valid_dataset,
                                          batch_size=args.batch_size,
                                          min_after_cache=args.min_after_cache,
                                          length_sort=not args.no_length_sort)

    # Set param init op
    init_op = tf.global_variables_initializer()

    # Set save op (one saver for the latest model, one for the best models)
    save_op = tf.train.Saver(max_to_keep=5)
    best_save_op = tf.train.Saver(max_to_keep=5)

    # Set per-step logging
    with tf.name_scope("per_step_eval"):
        # For ML cost (ce)
        tr_ce = tf.placeholder(tf.float32)
        tr_ce_summary = tf.summary.scalar("train_ml_cost", tr_ce)

        # For output visualization
        tr_image = tf.placeholder(tf.float32)
        tr_image_summary = tf.summary.image("train_image", tr_image)

        # For ML FER
        tr_fer = tf.placeholder(tf.float32)
        tr_fer_summary = tf.summary.scalar("train_fer", tr_fer)

        # For RL cost
        tr_rl = tf.placeholder(tf.float32)
        tr_rl_summary = tf.summary.scalar("train_rl", tr_rl)

        # For RL reward
        tr_reward = tf.placeholder(tf.float32)
        tr_reward_summary = tf.summary.scalar("train_reward", tr_reward)

        # For RL entropy
        tr_ent = tf.placeholder(tf.float32)
        tr_ent_summary = tf.summary.scalar("train_entropy", tr_ent)

        # For RL reward histogram
        tr_rw_hist = tf.placeholder(tf.float32)
        tr_rw_hist_summary = tf.summary.histogram("train_reward_hist",
                                                  tr_rw_hist)

        # For RL skip count
        tr_skip_cnt = tf.placeholder(tf.float32)
        tr_skip_cnt_summary = tf.summary.scalar("train_skip_cnt", tr_skip_cnt)

    # Set per-epoch logging
    with tf.name_scope("per_epoch_eval"):
        # For best valid ML cost (full)
        best_val_ce = tf.placeholder(tf.float32)
        best_val_ce_summary = tf.summary.scalar("best_valid_ce", best_val_ce)

        # For best valid FER
        best_val_fer = tf.placeholder(tf.float32)
        best_val_fer_summary = tf.summary.scalar("best_valid_fer",
                                                 best_val_fer)

        # For valid ML cost (full)
        val_ce = tf.placeholder(tf.float32)
        val_ce_summary = tf.summary.scalar("valid_ce", val_ce)

        # For valid FER
        val_fer = tf.placeholder(tf.float32)
        val_fer_summary = tf.summary.scalar("valid_fer", val_fer)

        # For output visualization
        val_image = tf.placeholder(tf.float32)
        val_image_summary = tf.summary.image("valid_image", val_image)

        # For RL skip count
        val_skip_cnt = tf.placeholder(tf.float32)
        val_skip_cnt_summary = tf.summary.scalar("valid_skip_cnt",
                                                 val_skip_cnt)

    # Set module
    gen_episodes = improve_skip_rnn_act_parallel

    # Init session
    with tf.Session() as sess:
        # Init model
        sess.run(init_op)

        # Load from checkpoint
        if args.start_from_ckpt:
            # Checkpoints are written with a "-<global_step>" suffix, so the
            # newest one has to be looked up instead of using a fixed file
            # name. The graph is already built above, therefore the existing
            # saver restores straight into it (no meta-graph import needed).
            last_ckpt = tf.train.latest_checkpoint(args.log_dir)
            if last_ckpt is None:
                raise IOError("No checkpoint found in %s to restart from." %
                              args.log_dir)
            save_op.restore(sess, last_ckpt)
            print(
                "Restore from the last checkpoint. Restarting from %d step." %
                global_step.eval())

        # Summary writer
        summary_writer = tf.summary.FileWriter(args.log_dir,
                                               sess.graph,
                                               flush_secs=5.0)

        # For train tracking (running sums, reset after every display)
        tr_ce_sum = 0.
        tr_ce_count = 0
        tr_acc_sum = 0.
        tr_acc_count = 0
        tr_rl_sum = 0.
        tr_rl_count = 0
        tr_ent_sum = 0.
        tr_ent_count = 0
        tr_reward_sum = 0.
        tr_reward_count = 0
        tr_skip_sum = 0.
        tr_skip_count = 0

        # Best validation scores so far (lower is better for both)
        _best_ce = np.iinfo(np.int32).max
        _best_fer = 1.00

        # For time measure
        epoch_sw = StopWatch()
        disp_sw = StopWatch()
        eval_sw = StopWatch()

        # For each epoch
        for _epoch in range(args.n_epoch):
            # Reset timer
            epoch_sw.reset()
            disp_sw.reset()
            print('Epoch {} training'.format(_epoch + 1))

            # Set rl skipping flag
            use_rl_skipping = True

            # For each batch (update)
            for batch_data in train_set.get_epoch_iterator():
                ##################
                # Sampling Phase #
                ##################
                # Get batch data
                seq_x_data, seq_x_mask, _, _, seq_y_data, _ = batch_data

                # Use skipping
                if use_rl_skipping:
                    # Transpose axis
                    seq_x_data = np.transpose(seq_x_data, (1, 0, 2))
                    seq_x_mask = np.transpose(seq_x_mask, (1, 0))
                    seq_y_data = np.transpose(seq_y_data, (1, 0))

                    # Number of samples
                    batch_size = seq_x_data.shape[1]
                    # Sample actions (episode generation)
                    [
                        skip_x_data, skip_h_data, skip_x_mask, skip_y_data,
                        skip_a_data, skip_a_mask, skip_rewards, result_image
                    ] = gen_episodes(seq_x_data=seq_x_data,
                                     seq_x_mask=seq_x_mask,
                                     seq_y_data=seq_y_data,
                                     sess=sess,
                                     sample_graph=sg,
                                     args=args,
                                     use_sampling=True)

                    # Compute skip ratio
                    tr_skip_sum += skip_x_mask.sum() / seq_x_mask.sum()
                    tr_skip_count += 1.0

                    # Compute baseline and refine reward
                    skip_advantage, skip_disc_rewards = compute_advantage(
                        seq_h_data=skip_h_data,
                        seq_r_data=skip_rewards,
                        seq_r_mask=skip_a_mask,
                        vf=vf,
                        args=args,
                        final_cost=args.use_final_reward)

                    # Without a baseline the discounted rewards are used
                    # directly as the advantage
                    if args.use_baseline is False:
                        skip_advantage = skip_disc_rewards

                    ##################
                    # Training Phase #
                    ##################
                    # Update model (ML and RL parameters in the same run)
                    [
                        _tr_ml_cost, _tr_rl_cost, _, _, _tr_act_ent,
                        _tr_pred_logit
                    ] = sess.run(
                        [
                            ml_cost, real_rl_cost, ml_op, rl_op, tg.seq_a_ent,
                            tg.seq_label_logits
                        ],
                        feed_dict={
                            tg.seq_x_data: skip_x_data,
                            tg.seq_x_mask: skip_x_mask,
                            tg.seq_y_data: skip_y_data,
                            tg.seq_a_data: skip_a_data,
                            tg.seq_a_mask: skip_a_mask,
                            tg.seq_advantage: skip_advantage,
                            tg.seq_reward: skip_disc_rewards
                        })

                    # global_step only changes in the update above, so read
                    # it once and reuse it for every summary of this batch
                    cur_step = global_step.eval()

                    # Undo the transpose done before sampling
                    seq_x_mask = np.transpose(seq_x_mask, (1, 0))
                    seq_y_data = np.transpose(seq_y_data, (1, 0))

                    # Get full sequence prediction
                    _tr_pred_full = expand_pred_idx(
                        seq_skip_1hot=skip_a_data,
                        seq_skip_mask=skip_a_mask,
                        seq_prd_idx=_tr_pred_logit.argmax(axis=-1).reshape(
                            [batch_size, -1]),
                        seq_x_mask=seq_y_data)

                    # Update history
                    tr_ce_sum += _tr_ml_cost.sum() * batch_size
                    tr_ce_count += skip_x_mask.sum()

                    tr_acc_sum += ((_tr_pred_full == seq_y_data) *
                                   seq_x_mask).sum()
                    tr_acc_count += seq_x_mask.sum()

                    tr_rl_sum += _tr_rl_cost.sum()
                    tr_rl_count += 1.0

                    tr_ent_sum += _tr_act_ent.sum()
                    tr_ent_count += skip_a_mask.sum()

                    tr_reward_sum += (skip_rewards * skip_a_mask).sum()
                    tr_reward_count += skip_a_mask.sum()

                    ################
                    # Write result #
                    ################
                    [
                        _tr_rl_summary, _tr_image_summary, _tr_ent_summary,
                        _tr_reward_summary, _tr_rw_hist_summary,
                        _tr_skip_cnt_summary
                    ] = sess.run(
                        [
                            tr_rl_summary, tr_image_summary, tr_ent_summary,
                            tr_reward_summary, tr_rw_hist_summary,
                            tr_skip_cnt_summary
                        ],
                        feed_dict={
                            tr_rl:
                            _tr_rl_cost.sum(),
                            tr_image:
                            result_image,
                            tr_ent: (_tr_act_ent.sum() / skip_a_mask.sum()),
                            tr_reward: ((skip_rewards * skip_a_mask).sum() /
                                        skip_a_mask.sum()),
                            tr_rw_hist:
                            skip_rewards,
                            tr_skip_cnt:
                            skip_x_mask.sum() / seq_x_mask.sum()
                        })

                    summary_writer.add_summary(_tr_rl_summary, cur_step)
                    summary_writer.add_summary(_tr_image_summary, cur_step)
                    summary_writer.add_summary(_tr_ent_summary, cur_step)
                    summary_writer.add_summary(_tr_reward_summary, cur_step)
                    summary_writer.add_summary(_tr_rw_hist_summary, cur_step)
                    summary_writer.add_summary(_tr_skip_cnt_summary, cur_step)
                else:
                    # Number of samples
                    batch_size = seq_x_data.shape[0]

                    ##################
                    # Training Phase #
                    ##################
                    [_tr_ml_cost, _, _tr_pred_full] = sess.run(
                        [ml_cost, ml_op, tg.seq_label_logits],
                        feed_dict={
                            tg.seq_x_data: seq_x_data,
                            tg.seq_x_mask: seq_x_mask,
                            tg.seq_y_data: seq_y_data
                        })
                    _tr_pred_full = np.reshape(_tr_pred_full.argmax(axis=1),
                                               seq_y_data.shape)

                    # Read the step once for all summaries of this batch
                    cur_step = global_step.eval()

                    # Update history
                    tr_ce_sum += _tr_ml_cost.sum() * batch_size
                    tr_ce_count += seq_x_mask.sum()

                    tr_acc_sum += ((_tr_pred_full == seq_y_data) *
                                   seq_x_mask).sum()
                    tr_acc_count += seq_x_mask.sum()

                    # No frames are skipped, so the shared summary code
                    # below normalizes by the full mask
                    skip_x_mask = seq_x_mask

                ################
                # Write result #
                ################
                # Frame accuracy of this batch; the summary is named
                # "train_fer", so the error rate (1 - accuracy) is logged to
                # agree with the printed FER and the validation FER summary
                _tr_acc = (((_tr_pred_full == seq_y_data) * seq_x_mask).sum() /
                           seq_x_mask.sum())
                [_tr_ce_summary, _tr_fer_summary] = sess.run(
                    [tr_ce_summary, tr_fer_summary],
                    feed_dict={
                        tr_ce:
                        (_tr_ml_cost.sum() * batch_size) / skip_x_mask.sum(),
                        tr_fer:
                        1. - _tr_acc
                    })
                summary_writer.add_summary(_tr_ce_summary, cur_step)
                summary_writer.add_summary(_tr_fer_summary, cur_step)

                # Display results
                if cur_step % args.display_freq == 0:
                    # Get average results
                    avg_tr_ce = tr_ce_sum / tr_ce_count
                    avg_tr_fer = 1. - tr_acc_sum / tr_acc_count
                    if use_rl_skipping:
                        avg_tr_rl = tr_rl_sum / tr_rl_count
                        avg_tr_ent = tr_ent_sum / tr_ent_count
                        avg_tr_reward = tr_reward_sum / tr_reward_count
                        avg_tr_skip = tr_skip_sum / tr_skip_count
                        print(
                            "TRAIN: epoch={} iter={} "
                            "ml_cost(ce/frame)={:.2f} fer={:.2f} "
                            "rl_cost={:.4f} reward={:.4f} action_entropy={:.2f} "
                            "skip_ratio={:.2f} "
                            "time_taken={:.2f}".format(_epoch,
                                                       cur_step,
                                                       avg_tr_ce, avg_tr_fer,
                                                       avg_tr_rl,
                                                       avg_tr_reward,
                                                       avg_tr_ent, avg_tr_skip,
                                                       disp_sw.elapsed()))
                    else:
                        print("TRAIN: epoch={} iter={} "
                              "ml_cost(ce/frame)={:.2f} fer={:.2f} "
                              "time_taken={:.2f}".format(
                                  _epoch, cur_step, avg_tr_ce,
                                  avg_tr_fer, disp_sw.elapsed()))

                    # Reset average results
                    tr_ce_sum = 0.
                    tr_ce_count = 0
                    tr_acc_sum = 0.
                    tr_acc_count = 0
                    tr_rl_sum = 0.
                    tr_rl_count = 0
                    tr_ent_sum = 0.
                    tr_ent_count = 0
                    tr_reward_sum = 0.
                    tr_reward_count = 0
                    tr_skip_sum = 0.
                    tr_skip_count = 0

                    disp_sw.reset()

            # End of epoch
            print('--')
            print('End of epoch {}'.format(_epoch + 1))
            epoch_sw.print_elapsed()

            # Evaluation
            print('Testing')

            # Evaluate the model on the validation set
            val_ce_sum = 0.
            val_ce_count = 0
            val_acc_sum = 0.
            val_acc_count = 0
            val_rl_sum = 0.
            val_rl_count = 0
            val_ent_sum = 0.
            val_ent_count = 0
            val_reward_sum = 0.
            val_reward_count = 0
            val_skip_sum = 0.
            val_skip_count = 0
            eval_sw.reset()

            # For each batch in Valid
            for batch_data in valid_set.get_epoch_iterator():
                ##################
                # Sampling Phase #
                ##################
                # Get batch data
                seq_x_data, seq_x_mask, _, _, seq_y_data, _ = batch_data

                if use_rl_skipping:
                    # Transpose axis
                    seq_x_data = np.transpose(seq_x_data, (1, 0, 2))
                    seq_x_mask = np.transpose(seq_x_mask, (1, 0))
                    seq_y_data = np.transpose(seq_y_data, (1, 0))

                    # Number of samples
                    batch_size = seq_x_data.shape[1]

                    # Generate episodes with use_sampling disabled
                    [
                        skip_x_data, skip_h_data, skip_x_mask, skip_y_data,
                        skip_a_data, skip_a_mask, skip_rewards, result_image
                    ] = gen_episodes(seq_x_data=seq_x_data,
                                     seq_x_mask=seq_x_mask,
                                     seq_y_data=seq_y_data,
                                     sess=sess,
                                     sample_graph=sg,
                                     args=args,
                                     use_sampling=False)

                    # Compute skip ratio
                    val_skip_sum += skip_x_mask.sum() / seq_x_mask.sum()
                    val_skip_count += 1.0

                    # Compute baseline and refine reward
                    skip_advantage, skip_disc_rewards = compute_advantage(
                        seq_h_data=skip_h_data,
                        seq_r_data=skip_rewards,
                        seq_r_mask=skip_a_mask,
                        vf=vf,
                        args=args,
                        final_cost=args.use_final_reward)

                    if args.use_baseline is False:
                        skip_advantage = skip_disc_rewards

                    #################
                    # Forward Phase #
                    #################
                    # No optimizer op is run here, parameters stay untouched
                    [
                        _val_ml_cost, _val_rl_cost, _val_pred_logit,
                        _val_action_ent
                    ] = sess.run(
                        [
                            ml_cost, real_rl_cost, tg.seq_label_logits,
                            tg.seq_a_ent
                        ],
                        feed_dict={
                            tg.seq_x_data: skip_x_data,
                            tg.seq_x_mask: skip_x_mask,
                            tg.seq_y_data: skip_y_data,
                            tg.seq_a_data: skip_a_data,
                            tg.seq_a_mask: skip_a_mask,
                            tg.seq_advantage: skip_advantage,
                            tg.seq_reward: skip_disc_rewards
                        })

                    # Undo the transpose done before sampling
                    seq_x_mask = np.transpose(seq_x_mask, (1, 0))
                    seq_y_data = np.transpose(seq_y_data, (1, 0))

                    # Get full sequence prediction
                    _val_pred_full = expand_pred_idx(
                        seq_skip_1hot=skip_a_data,
                        seq_skip_mask=skip_a_mask,
                        seq_prd_idx=_val_pred_logit.argmax(axis=-1).reshape(
                            [batch_size, -1]),
                        seq_x_mask=seq_y_data)

                    # Update history
                    val_ce_sum += _val_ml_cost.sum() * batch_size
                    val_ce_count += skip_x_mask.sum()

                    val_acc_sum += ((_val_pred_full == seq_y_data) *
                                    seq_x_mask).sum()
                    val_acc_count += seq_x_mask.sum()

                    val_rl_sum += _val_rl_cost.sum()
                    val_rl_count += 1.0

                    val_ent_sum += _val_action_ent.sum()
                    val_ent_count += skip_a_mask.sum()

                    val_reward_sum += (skip_rewards * skip_a_mask).sum()
                    val_reward_count += skip_a_mask.sum()
                else:
                    # Number of samples
                    batch_size = seq_x_data.shape[0]

                    #################
                    # Forward Phase #
                    #################
                    [_val_ml_cost, _val_pred_full] = sess.run(
                        [ml_cost, tg.seq_label_logits],
                        feed_dict={
                            tg.seq_x_data: seq_x_data,
                            tg.seq_x_mask: seq_x_mask,
                            tg.seq_y_data: seq_y_data
                        })
                    _val_pred_full = np.reshape(_val_pred_full.argmax(axis=1),
                                                seq_y_data.shape)

                    # Update history
                    val_ce_sum += _val_ml_cost.sum() * batch_size
                    val_ce_count += seq_x_mask.sum()

                    val_acc_sum += ((_val_pred_full == seq_y_data) *
                                    seq_x_mask).sum()
                    val_acc_count += seq_x_mask.sum()

            # Aggregate over all valid data
            avg_val_ce = val_ce_sum / val_ce_count
            avg_val_fer = 1. - val_acc_sum / val_acc_count

            if use_rl_skipping:
                avg_val_rl = val_rl_sum / val_rl_count
                avg_val_ent = val_ent_sum / val_ent_count
                avg_val_reward = val_reward_sum / val_reward_count
                avg_val_skip = val_skip_sum / val_skip_count

                print("VALID: epoch={} "
                      "ml_cost(ce/frame)={:.2f} fer={:.2f} "
                      "rl_cost={:.4f} reward={:.4f} action_entropy={:.2f} "
                      "skip_ratio={:.2f} "
                      "time_taken={:.2f}".format(_epoch, avg_val_ce,
                                                 avg_val_fer, avg_val_rl,
                                                 avg_val_reward,
                                                 avg_val_ent, avg_val_skip,
                                                 eval_sw.elapsed()))
            else:
                print("VALID: epoch={} "
                      "ml_cost(ce/frame)={:.2f} fer={:.2f} "
                      "time_taken={:.2f}".format(_epoch,
                                                 avg_val_ce, avg_val_fer,
                                                 eval_sw.elapsed()))

            ################
            # Write result #
            ################
            # No update op runs until the next epoch, so the step is fixed
            cur_step = global_step.eval()

            [_val_ce_summary, _val_fer_summary] = sess.run(
                [val_ce_summary, val_fer_summary],
                feed_dict={
                    val_ce: avg_val_ce,
                    val_fer: avg_val_fer
                })

            # Skip ratio and result image only exist when episodes were
            # generated; guarding avoids a NameError on the non-RL path
            if use_rl_skipping:
                [_val_skip_cnt_summary, _val_img_summary] = sess.run(
                    [val_skip_cnt_summary, val_image_summary],
                    feed_dict={
                        val_skip_cnt: avg_val_skip,
                        val_image: result_image
                    })
                summary_writer.add_summary(_val_skip_cnt_summary, cur_step)
                summary_writer.add_summary(_val_img_summary, cur_step)
            summary_writer.add_summary(_val_ce_summary, cur_step)
            summary_writer.add_summary(_val_fer_summary, cur_step)

            insert_item2dict(eval_summary, 'val_ce', avg_val_ce)
            insert_item2dict(eval_summary, 'val_fer', avg_val_fer)
            insert_item2dict(eval_summary, 'time', eval_sw.elapsed())
            save_npz2(file_name, eval_summary)

            # Save best model
            if avg_val_ce < _best_ce:
                _best_ce = avg_val_ce
                best_ckpt = best_save_op.save(sess=sess,
                                              save_path=os.path.join(
                                                  args.log_dir,
                                                  "best_model(ce).ckpt"),
                                              global_step=global_step)
                print("Best checkpoint based on CE stored in: %s" % best_ckpt)

            if avg_val_fer < _best_fer:
                _best_fer = avg_val_fer
                best_ckpt = best_save_op.save(sess=sess,
                                              save_path=os.path.join(
                                                  args.log_dir,
                                                  "best_model(fer).ckpt"),
                                              global_step=global_step)
                print("Best checkpoint based on FER stored in: %s" % best_ckpt)

            # Save model (written last, so it is the newest checkpoint that
            # a later restart picks up)
            ckpt = save_op.save(sess=sess,
                                save_path=os.path.join(args.log_dir,
                                                       "model.ckpt"),
                                global_step=global_step)
            print("Checkpoint stored in: %s" % ckpt)

            # Write result
            [_best_val_ce_summary, _best_val_fer_summary
             ] = sess.run([best_val_ce_summary, best_val_fer_summary],
                          feed_dict={
                              best_val_ce: _best_ce,
                              best_val_fer: _best_fer
                          })
            summary_writer.add_summary(_best_val_ce_summary, cur_step)
            summary_writer.add_summary(_best_val_fer_summary, cur_step)

        # Done of training
        summary_writer.close()
        print("Optimization Finished.")
예제 #8
0
    # Build the TBPTT training function; trainer_tbptt also returns the
    # trainer parameters alongside it.
    print('Building trainer')
    training_fn, trainer_params = trainer_tbptt(
              input_data=input_data,
              input_mask=input_mask,
              target_data=target_data,
              target_mask=target_mask,
              network=network,
              updater=adam,
              learning_rate=args.learn_rate,
              tbptt_layers=tbptt_layers, 
              is_first_win=is_first_win,
              delay=args.delay,
              context=args.num_tbptt_steps,
              load_updater_params=pretrain_update_params_val, 
              ivector_data=ivector_data)
    # Report how long building the trainer took
    sw.print_elapsed()

    # Restart the stopwatch and build the matching prediction function; it
    # gets the same data/network arguments as the trainer but no updater,
    # learning rate or updater parameters.
    sw.reset()
    print('Building predictor')
    predict_fn = predictor_tbptt(
        input_data=input_data,
        input_mask=input_mask,
        target_data=target_data,
        target_mask=target_mask, 
        network=network, 
        tbptt_layers=tbptt_layers, 
        is_first_win=is_first_win,
        delay=args.delay,
        context=args.num_tbptt_steps,
        ivector_data=ivector_data)
    sw.print_elapsed()
예제 #9
0
def main(_):
    """Train the model and validate it after every epoch.

    Builds the training and test graphs with ``build_graph``, sets up one Adam
    optimizer for the label-prediction cost (``tg.ml_cost``) and one for the
    jump-prediction cost (``tg.rl_cost``), and then runs ``args.n_epoch``
    training epochs. After each epoch the validation error (``val_fer``) is
    computed, appended to the ``model.npz`` evaluation summary, written to
    the summary log, and checkpoints are saved under ``args.log_dir``.

    When ``args.start_from_ckpt`` is set, the variables are restored from the
    most recent checkpoint recorded in ``args.log_dir`` before training
    continues.

    Args:
        _: Unused positional argument.

    Raises:
        IOError: If ``args.start_from_ckpt`` is set but no checkpoint is
            recorded in ``args.log_dir``.
    """
    print(' '.join(sys.argv))
    args = FLAGS
    print(args.__flags)
    print('Hostname: {}'.format(socket.gethostname()))
    print('GPU: {}'.format(get_gpuname()))

    # A fresh run starts from an empty log directory; a resumed run must keep
    # it because the checkpoints to restore live there.
    if not args.start_from_ckpt:
        if tf.gfile.Exists(args.log_dir):
            tf.gfile.DeleteRecursively(args.log_dir)
        tf.gfile.MakeDirs(args.log_dir)

    tf.get_variable_scope()._reuse = None

    _seed = args.base_seed + args.add_seed
    tf.set_random_seed(_seed)
    np.random.seed(_seed)

    prefix_name = os.path.join(args.log_dir, 'model')
    file_name = '%s.npz' % prefix_name

    eval_summary = OrderedDict()

    tg, test_graph = build_graph(args)
    tg_ml_cost = tf.reduce_mean(tg.ml_cost)

    global_step = tf.Variable(0, trainable=False, name="global_step")

    # Variables whose name contains "action_logit" are trained on the rl cost;
    # all other trainable variables are trained on the ml cost.
    tvars = tf.trainable_variables()
    print([tvar.name for tvar in tvars])
    ml_tvars = [tvar for tvar in tvars if "action_logit" not in tvar.name]
    rl_tvars = [tvar for tvar in tvars if "action_logit" in tvar.name]

    ml_opt_func = tf.train.AdamOptimizer(learning_rate=args.learning_rate,
                                         beta1=0.9,
                                         beta2=0.99)
    rl_opt_func = tf.train.AdamOptimizer(learning_rate=args.rl_learning_rate,
                                         beta1=0.9,
                                         beta2=0.99)

    if args.grad_clip:
        ml_grads, _ = tf.clip_by_global_norm(tf.gradients(
            tg_ml_cost, ml_tvars),
                                             clip_norm=1.0)
    else:
        ml_grads = tf.gradients(tg_ml_cost, ml_tvars)
    # Only the ml update advances global_step.
    ml_op = ml_opt_func.apply_gradients(zip(ml_grads, ml_tvars),
                                        global_step=global_step)

    tg_rl_cost = tf.reduce_mean(tg.rl_cost)
    rl_grads = tf.gradients(tg_rl_cost, rl_tvars)
    rl_op = rl_opt_func.apply_gradients(zip(rl_grads, rl_tvars))

    # NOTE(review): both values are added to the same 'fast_action'
    # collection; confirm that readers of the collection expect this.
    tf.add_to_collection('fast_action', args.fast_action)
    tf.add_to_collection('fast_action', args.n_fast_action)

    sync_data(args)
    datasets = [args.train_dataset, args.valid_dataset, args.test_dataset]
    train_set, valid_set, test_set = [
        create_ivector_datastream(path=args.data_path,
                                  which_set=dataset,
                                  batch_size=args.n_batch,
                                  min_after_cache=args.min_after_cache,
                                  length_sort=not args.no_length_sort)
        for dataset in datasets
    ]

    init_op = tf.global_variables_initializer()
    save_op = tf.train.Saver(max_to_keep=5)
    best_save_op = tf.train.Saver(max_to_keep=5)

    with tf.name_scope("per_step_eval"):
        tr_ce = tf.placeholder(tf.float32)
        tr_ce_summary = tf.summary.scalar("tr_ce", tr_ce)
        tr_fer = tf.placeholder(tf.float32)
        tr_fer_summary = tf.summary.scalar("tr_fer", tr_fer)
        tr_ce2 = tf.placeholder(tf.float32)
        tr_ce2_summary = tf.summary.scalar("tr_rl", tr_ce2)

        tr_image = tf.placeholder(tf.float32)
        tr_image_summary = tf.summary.image("tr_image", tr_image)

    with tf.name_scope("per_epoch_eval"):
        val_fer = tf.placeholder(tf.float32)
        val_fer_summary = tf.summary.scalar("val_fer", val_fer)
        best_val_fer = tf.placeholder(tf.float32)
        best_val_fer_summary = tf.summary.scalar("best_valid_fer",
                                                 best_val_fer)
        val_image = tf.placeholder(tf.float32)
        val_image_summary = tf.summary.image("val_image", val_image)

    vf = LinearVF()

    with tf.Session() as sess:
        sess.run(init_op)

        if args.start_from_ckpt:
            # Restore into the variables of the graph built above, using the
            # saver that already covers them. Importing the meta graph here
            # would add a second copy of the graph and restore into that copy.
            # Checkpoints are written with a global-step suffix, so the path
            # has to be looked up rather than hard-coded.
            ckpt_path = tf.train.latest_checkpoint(args.log_dir)
            if ckpt_path is None:
                raise IOError("No checkpoint found in %s" % args.log_dir)
            save_op.restore(sess, ckpt_path)
            print(
                "Restore from the last checkpoint. Restarting from %d step." %
                global_step.eval())

        summary_writer = tf.summary.FileWriter(args.log_dir,
                                               sess.graph,
                                               flush_secs=5.0)

        # Running totals, reset every time the training status is printed.
        tr_ce_sum = 0.
        tr_ce_count = 0
        tr_acc_sum = 0
        tr_acc_count = 0
        tr_ce2_sum = 0.
        tr_ce2_count = 0

        # Sentinel that the first validation score is compared against.
        _best_score = np.iinfo(np.int32).max

        epoch_sw = StopWatch()
        disp_sw = StopWatch()
        eval_sw = StopWatch()
        per_sw = StopWatch()

        # For each epoch
        for _epoch in range(args.n_epoch):
            epoch_sw.reset()
            disp_sw.reset()

            print('Epoch {} training'.format(_epoch + 1))

            # For each batch
            for batch in train_set.get_epoch_iterator():
                x, x_mask, _, _, y, _ = batch
                n_batch = x.shape[0]

                # Supervised jump targets are needed in both training modes.
                sup_x, sup_y, sup_actions, _, sup_mask = gen_supervision(
                    x, x_mask, y, args)
                zero_state = gen_zero_state(n_batch, args.n_hidden)
                # Normaliser for the rl cost: the supervision mask without
                # its last column.
                rl_steps = sup_mask[:, :-1].sum()

                if args.no_sampling:
                    # Train label and jump prediction together on the
                    # supervised sequence.
                    feed_dict = {
                        tg.seq_x_data: sup_x,
                        tg.seq_x_mask: sup_mask,
                        tg.seq_y_data: sup_y,
                        tg.seq_jump_data: sup_actions
                    }
                    feed_init_state(feed_dict, tg.init_state, zero_state)

                    _tr_ml_cost, _tr_rl_cost, _, _ = sess.run(
                        [tg.ml_cost, tg.rl_cost, ml_op, rl_op],
                        feed_dict=feed_dict)
                    ml_frames = sup_mask.sum()
                else:
                    # train jump prediction part
                    feed_dict = {
                        tg.seq_x_data: sup_x,
                        tg.seq_x_mask: sup_mask,
                        tg.seq_jump_data: sup_actions
                    }
                    feed_init_state(feed_dict, tg.init_state, zero_state)
                    _tr_rl_cost, _ = sess.run([tg.rl_cost, rl_op],
                                              feed_dict=feed_dict)

                    # generate jumps from the model
                    ep_x, ep_y, _, _, ep_mask, _ = gen_episode_supervised(
                        x, y, x_mask, sess, test_graph, args.fast_action,
                        args.n_fast_action)

                    feed_dict = {
                        tg.seq_x_data: ep_x,
                        tg.seq_x_mask: ep_mask,
                        tg.seq_y_data: ep_y
                    }
                    feed_init_state(feed_dict, tg.init_state, zero_state)

                    # train label prediction part
                    _tr_ml_cost, _ = sess.run([tg.ml_cost, ml_op],
                                              feed_dict=feed_dict)
                    ml_frames = ep_mask.sum()

                ml_cost_sum = _tr_ml_cost.sum()
                rl_cost_sum = _tr_rl_cost.sum()
                tr_ce_sum += ml_cost_sum
                tr_ce_count += ml_frames
                tr_ce2_sum += rl_cost_sum
                tr_ce2_count += rl_steps

                # Measure training accuracy with the test graph on the
                # original (unsubsampled) batch.
                actions_1hot, label_probs, new_mask, output_image = \
                    skip_rnn_forward_supervised(x, x_mask, sess, test_graph,
                                                args.fast_action,
                                                args.n_fast_action, y)

                pred_idx = expand_output(actions_1hot, x_mask, new_mask,
                                         label_probs.argmax(axis=-1))
                n_correct = ((pred_idx == y) * x_mask).sum()
                n_frames = x_mask.sum()
                tr_acc_sum += n_correct
                tr_acc_count += n_frames

                # global_step only changes when ml_op runs, so one read per
                # iteration serves every summary and the display check.
                step = global_step.eval()

                step_summaries = sess.run(
                    [tr_ce_summary, tr_fer_summary, tr_ce2_summary,
                     tr_image_summary],
                    feed_dict={
                        tr_ce: ml_cost_sum / ml_frames,
                        tr_fer: 1 - n_correct / n_frames,
                        tr_ce2: rl_cost_sum / rl_steps,
                        tr_image: output_image
                    })
                for step_summary in step_summaries:
                    summary_writer.add_summary(step_summary, step)

                if step % args.display_freq == 0:
                    avg_tr_ce = tr_ce_sum / tr_ce_count
                    avg_tr_fer = 1. - tr_acc_sum / tr_acc_count
                    avg_tr_ce2 = tr_ce2_sum / tr_ce2_count

                    print(
                        "TRAIN: epoch={} iter={} ml_cost(ce/frame)={:.2f} fer={:.2f} rl_cost={:.4f} time_taken={:.2f}"
                        .format(_epoch, step, avg_tr_ce,
                                avg_tr_fer, avg_tr_ce2, disp_sw.elapsed()))

                    tr_ce_sum = 0.
                    tr_ce_count = 0
                    tr_acc_sum = 0.
                    tr_acc_count = 0
                    tr_ce2_sum = 0.
                    tr_ce2_count = 0

                    disp_sw.reset()

            print('--')
            print('End of epoch {}'.format(_epoch + 1))
            epoch_sw.print_elapsed()

            print('Testing')

            # Evaluate the model on the validation set
            val_acc_sum = 0
            val_acc_count = 0

            eval_sw.reset()
            for batch in valid_set.get_epoch_iterator():
                x, x_mask, _, _, y, _ = batch

                actions_1hot, label_probs, new_mask, output_image = \
                    skip_rnn_forward_supervised(x, x_mask, sess, test_graph,
                                                args.fast_action,
                                                args.n_fast_action, y)

                pred_idx = expand_output(actions_1hot, x_mask, new_mask,
                                         label_probs.argmax(axis=-1))
                val_acc_sum += ((pred_idx == y) * x_mask).sum()
                val_acc_count += x_mask.sum()

            avg_val_fer = 1. - val_acc_sum / val_acc_count

            print("VALID: epoch={} fer={:.2f} time_taken={:.2f}".format(
                _epoch, avg_val_fer, eval_sw.elapsed()))

            # No training op runs during validation, so the step is fixed for
            # the rest of this epoch.
            step = global_step.eval()

            # The logged image is the one produced for the last validation
            # batch.
            _val_fer_summary, _val_image_summary = sess.run(
                [val_fer_summary, val_image_summary],
                feed_dict={
                    val_fer: avg_val_fer,
                    val_image: output_image
                })
            summary_writer.add_summary(_val_fer_summary, step)
            summary_writer.add_summary(_val_image_summary, step)

            insert_item2dict(eval_summary, 'val_fer', avg_val_fer)
            insert_item2dict(eval_summary, 'time', eval_sw.elapsed())
            save_npz2(file_name, eval_summary)

            # Save model
            if avg_val_fer < _best_score:
                _best_score = avg_val_fer
                best_ckpt = best_save_op.save(sess,
                                              os.path.join(
                                                  args.log_dir,
                                                  "best_model.ckpt"),
                                              global_step=global_step)
                print("Best checkpoint stored in: %s" % best_ckpt)
            # The regular checkpoint is written after the best one.
            ckpt = save_op.save(sess,
                                os.path.join(args.log_dir, "model.ckpt"),
                                global_step=global_step)
            print("Checkpoint stored in: %s" % ckpt)

            _best_val_fer_summary, = sess.run(
                [best_val_fer_summary], feed_dict={best_val_fer: _best_score})
            summary_writer.add_summary(_best_val_fer_summary, step)
        summary_writer.close()

        print("Optimization Finished.")