Ejemplo n.º 1
0
def main(_):
  """Train the model from ``build_graph`` and checkpoint it after every epoch.

  All settings are read from the module-level ``FLAGS``; the positional
  argument is ignored (presumably the argv list passed by ``tf.app.run`` --
  confirm).

  Per epoch:
    * every training batch is split by
      ``skip_frames_fixed(..., args.n_skip + 1)`` into sub-batches and one
      Adam step (optionally global-norm clipped to 1.0) is taken per
      sub-batch;
    * the validation set is scored for cross-entropy per frame and frame
      error rate (FER);
    * TensorBoard summaries are written to ``args.log_dir``, validation
      statistics are appended to ``<log_dir>/model.npz``, and checkpoints are
      saved as ``model.ckpt-<step>`` and, when validation CE improves,
      ``best_model.ckpt-<step>``.

  Raises:
    IOError: if ``args.start_from_ckpt`` is set but ``args.log_dir`` holds no
      checkpoint to resume from.
  """
  print(' '.join(sys.argv))
  args = FLAGS
  print(args.__flags)
  # A fresh run wipes previous logs/checkpoints; a resumed run keeps them.
  if not args.start_from_ckpt:
    if tf.gfile.Exists(args.log_dir):
      tf.gfile.DeleteRecursively(args.log_dir)
    tf.gfile.MakeDirs(args.log_dir)

  # Clear the (private) reuse flag of the root variable scope.
  tf.get_variable_scope()._reuse = None

  _seed = args.base_seed + args.add_seed
  tf.set_random_seed(_seed)
  np.random.seed(_seed)

  prefix_name = os.path.join(args.log_dir, 'model')
  file_name = '%s.npz' % prefix_name

  # Per-epoch validation statistics, dumped to ``file_name`` after each epoch.
  eval_summary = OrderedDict()

  tg = build_graph(args)
  tg_ml_cost = tf.reduce_mean(tg.ml_cost)

  global_step = tf.Variable(0, trainable=False, name="global_step")

  tvars = tf.trainable_variables()

  ml_opt_func = tf.train.AdamOptimizer(learning_rate=args.learning_rate,
                                       beta1=0.9, beta2=0.99)

  if args.grad_clip:
    ml_grads, _ = tf.clip_by_global_norm(tf.gradients(tg_ml_cost, tvars),
                                         clip_norm=1.0)
  else:
    ml_grads = tf.gradients(tg_ml_cost, tvars)
  ml_op = ml_opt_func.apply_gradients(zip(ml_grads, tvars),
                                      global_step=global_step)

  sync_data(args)
  datasets = [args.train_dataset, args.valid_dataset, args.test_dataset]
  train_set, valid_set, test_set = [
      create_ivector_datastream(path=args.data_path, which_set=dataset,
                                batch_size=args.batch_size,
                                min_after_cache=args.min_after_cache,
                                length_sort=not args.no_length_sort)
      for dataset in datasets]

  init_op = tf.global_variables_initializer()
  # Independent retention for the rolling and the best-on-validation
  # checkpoints.
  save_op = tf.train.Saver(max_to_keep=5)
  best_save_op = tf.train.Saver(max_to_keep=5)

  # Scalars computed in numpy are logged to TensorBoard through placeholders.
  with tf.name_scope("per_step_eval"):
    tr_ce = tf.placeholder(tf.float32)
    tr_ce_summary = tf.summary.scalar("tr_ce", tr_ce)

  with tf.name_scope("per_epoch_eval"):
    best_val_ce = tf.placeholder(tf.float32)
    val_ce = tf.placeholder(tf.float32)
    best_val_ce_summary = tf.summary.scalar("best_valid_ce", best_val_ce)
    val_ce_summary = tf.summary.scalar("valid_ce", val_ce)

  with tf.Session() as sess:
    sess.run(init_op)

    if args.start_from_ckpt:
      # Checkpoints are saved with ``global_step`` appended
      # (``model.ckpt-<step>``), so there is no plain ``model.ckpt`` on disk;
      # ask TensorFlow for the newest checkpoint recorded in ``log_dir``.
      # Restoring through the existing Saver fills the variables of *this*
      # graph (including ``global_step``); importing the meta graph instead
      # would add a second copy of every variable and restore into that copy.
      latest_ckpt = tf.train.latest_checkpoint(args.log_dir)
      if latest_ckpt is None:
        raise IOError("No checkpoint found in %s" % args.log_dir)
      save_op.restore(sess, latest_ckpt)
      print("Restore from the last checkpoint. "
            "Restarting from %d step." % global_step.eval())

    summary_writer = tf.summary.FileWriter(args.log_dir, sess.graph,
                                           flush_secs=5.0)

    # Running training statistics, reset after every display.
    tr_ce_sum = 0.
    tr_ce_count = 0
    tr_acc_sum = 0
    tr_acc_count = 0
    _best_score = np.iinfo(np.int32).max

    epoch_sw = StopWatch()
    disp_sw = StopWatch()
    eval_sw = StopWatch()
    # For each epoch
    for _epoch in xrange(args.n_epoch):
      epoch_sw.reset()
      disp_sw.reset()

      print('--')
      print('Epoch {} training'.format(_epoch+1))

      # For each batch
      for batch in train_set.get_epoch_iterator():
        orig_x, orig_x_mask, _, _, orig_y, _ = batch

        # Train on each frame-skipped sub-batch in turn.
        for sub_batch in skip_frames_fixed([orig_x, orig_x_mask, orig_y],
                                           args.n_skip+1):
          x, x_mask, y = sub_batch
          n_batch, _, _ = x.shape

          _feed_states = initial_states(n_batch, args.n_hidden)

          _tr_ml_cost, _seq_logit, _ = sess.run(
              [tg.ml_cost, tg.seq_logit, ml_op],
              feed_dict={tg.seq_x_data: x,
                         tg.seq_x_mask: x_mask,
                         tg.seq_y_data: y,
                         tg.init_state: _feed_states})

          tr_ce_sum += _tr_ml_cost.sum()
          tr_ce_count += x_mask.sum()
          _tr_ce_summary, = sess.run(
              [tr_ce_summary],
              feed_dict={tr_ce: _tr_ml_cost.sum() / x_mask.sum()})
          summary_writer.add_summary(_tr_ce_summary, global_step.eval())

          # Interpolate the logits along axis 1 so the predictions can be
          # compared with the original (un-skipped) labels and mask.
          _expand_seq_logit = interpolate_feat(
              _seq_logit, num_skips=args.n_skip+1, axis=1, use_bidir=True)
          _pred_idx = _expand_seq_logit.argmax(axis=2)

          tr_acc_sum += ((_pred_idx == orig_y) * orig_x_mask).sum()
          tr_acc_count += orig_x_mask.sum()

        step = global_step.eval()
        if step % args.display_freq == 0:
          avg_tr_ce = tr_ce_sum / tr_ce_count
          avg_tr_fer = 1. - float(tr_acc_sum) / tr_acc_count

          print("TRAIN: epoch={} iter={} ml_cost(ce/frame)={:.2f} fer={:.2f} time_taken={:.2f}".format(
              _epoch, step, avg_tr_ce, avg_tr_fer, disp_sw.elapsed()))

          tr_ce_sum = 0.
          tr_ce_count = 0
          tr_acc_sum = 0
          tr_acc_count = 0
          disp_sw.reset()

      print('--')
      print('End of epoch {}'.format(_epoch+1))
      epoch_sw.print_elapsed()

      print('Testing')

      # Evaluate the model on the validation set
      val_ce_sum = 0.
      val_ce_count = 0
      val_acc_sum = 0
      val_acc_count = 0

      eval_sw.reset()
      for batch in valid_set.get_epoch_iterator():
        orig_x, orig_x_mask, _, _, orig_y, _ = batch

        for sub_batch in skip_frames_fixed([orig_x, orig_x_mask, orig_y],
                                           args.n_skip+1, return_first=True):
          x, x_mask, y = sub_batch
          n_batch, _, _ = x.shape

          _feed_states = initial_states(n_batch, args.n_hidden)

          # Forward pass only: no optimizer op is run during validation.
          _val_ml_cost, _seq_logit = sess.run(
              [tg.ml_cost, tg.seq_logit],
              feed_dict={tg.seq_x_data: x,
                         tg.seq_x_mask: x_mask,
                         tg.seq_y_data: y,
                         tg.init_state: _feed_states})

          val_ce_sum += _val_ml_cost.sum()
          val_ce_count += x_mask.sum()

          _expand_seq_logit = interpolate_feat(
              _seq_logit, num_skips=args.n_skip+1, axis=1, use_bidir=True)
          _pred_idx = _expand_seq_logit.argmax(axis=2)

          val_acc_sum += ((_pred_idx == orig_y) * orig_x_mask).sum()
          val_acc_count += orig_x_mask.sum()

      avg_val_ce = val_ce_sum / val_ce_count
      avg_val_fer = 1. - float(val_acc_sum) / val_acc_count

      print("VALID: epoch={} ml_cost(ce/frame)={:.2f} fer={:.2f} time_taken={:.2f}".format(
          _epoch, avg_val_ce, avg_val_fer, eval_sw.elapsed()))

      _val_ce_summary, = sess.run([val_ce_summary],
                                  feed_dict={val_ce: avg_val_ce})
      summary_writer.add_summary(_val_ce_summary, global_step.eval())

      insert_item2dict(eval_summary, 'val_ce', avg_val_ce)
      insert_item2dict(eval_summary, 'time', eval_sw.elapsed())
      save_npz2(file_name, eval_summary)

      # Save model: the best checkpoint only when validation CE improves, the
      # rolling checkpoint every epoch (saved last, so it is what
      # ``tf.train.latest_checkpoint`` returns on resume).
      if avg_val_ce < _best_score:
        _best_score = avg_val_ce
        best_ckpt = best_save_op.save(sess,
                                      os.path.join(args.log_dir,
                                                   "best_model.ckpt"),
                                      global_step=global_step)
        print("Best checkpoint stored in: %s" % best_ckpt)
      ckpt = save_op.save(sess, os.path.join(args.log_dir, "model.ckpt"),
                          global_step=global_step)
      print("Checkpoint stored in: %s" % ckpt)

      _best_val_ce_summary, = sess.run([best_val_ce_summary],
                                       feed_dict={best_val_ce: _best_score})
      summary_writer.add_summary(_best_val_ce_summary, global_step.eval())
    summary_writer.close()

    print("Optimization Finished.")
Ejemplo n.º 2
0
    def __call__(self):
        """Build the graph, train for ``FLAGS.n_epoch`` epochs, checkpoint.

        The objective is ``mean(G[0])`` where ``G = self._build_graph(FLAGS)``,
        plus ``weight_decay * sum_W sum(W ** 2)`` over the ``'weights'``
        collection when ``FLAGS.weight_decay > 0``.  It is minimized with Adam
        (beta2=0.99), optionally with global-norm gradient clipping to 1.0.

        After every epoch ``self._monitor`` evaluates the model on the
        validation set, the collected statistics are written to
        ``<log_dir>/<self._file_name>.npz``, and checkpoints are saved as
        ``model.ckpt-<step>`` (always) and ``best_model.ckpt-<step>`` (when
        the validation bits improve).

        Raises:
            IOError: if ``FLAGS.start_from_ckpt`` is set but ``FLAGS.log_dir``
                contains no checkpoint to resume from.
        """
        # Clear the (private) reuse flag of the root variable scope.
        tf.get_variable_scope()._reuse = None
        # Fix random seeds
        _seed = self.FLAGS.base_seed + self.FLAGS.add_seed
        tf.set_random_seed(_seed)
        np.random.seed(_seed)
        # Prefixed names for save files
        prefix_name = os.path.join(self.FLAGS.log_dir, self._file_name)
        file_name = '%s.npz' % prefix_name
        # Per-epoch statistics; passed to self._monitor and dumped to file_name.
        summary = OrderedDict()
        # Training objective should always come at the front
        G = self._build_graph(self.FLAGS)
        _cost = tf.reduce_mean(G[0])
        tf.add_to_collection('losses', _cost)
        if self.FLAGS.weight_decay > 0.0:
            # tf.nn.l2_loss(W) is sum(W ** 2) / 2, so 2 * l2_loss(W) is the
            # squared Frobenius norm of W.
            with tf.variable_scope('weight_norm'):
                weights_norm = tf.reduce_sum(
                    input_tensor=self.FLAGS.weight_decay * tf.stack([
                        2 * tf.nn.l2_loss(W)
                        for W in tf.get_collection('weights')
                    ]),
                    name='weights_norm')
            tf.add_to_collection('losses', weights_norm)
        _cost = tf.add_n(tf.get_collection('losses'), name='total_loss')
        # Define non-trainable variables to track the progress
        global_step = tf.Variable(0, trainable=False, name="global_step")
        # For Adam optimizer, in the original paper, we use 0.99 for beta2
        opt_func = tf.train.AdamOptimizer(
            learning_rate=self.FLAGS.learning_rate, beta1=0.9, beta2=0.99)
        if not self.FLAGS.grad_clip:
            # Without clipping
            t_step = opt_func.minimize(_cost, global_step=global_step)
        else:
            # Apply gradient clipping using global norm
            # (aggregation_method=2 is AggregationMethod.EXPERIMENTAL_ACCUMULATE_N).
            tvars = tf.trainable_variables()
            grads, _ = tf.clip_by_global_norm(
                tf.gradients(_cost, tvars, aggregation_method=2),
                clip_norm=1.0)
            t_step = opt_func.apply_gradients(zip(grads, tvars),
                                              global_step=global_step)

        # Construct dataset objects
        sync_data(self.FLAGS)
        datasets = [
            self.FLAGS.train_dataset, self.FLAGS.valid_dataset,
            self.FLAGS.test_dataset
        ]
        train_set, valid_set, test_set = [
            create_ivector_datastream(path=self.FLAGS.data_path,
                                      which_set=dataset,
                                      batch_size=self.FLAGS.batch_size)
            for dataset in datasets
        ]

        init_op = tf.global_variables_initializer()
        save_op = tf.train.Saver(max_to_keep=5)
        best_save_op = tf.train.Saver(max_to_keep=5)
        with tf.name_scope("per_step_eval"):
            # Add batch level nats to summaries
            step_summary = tf.summary.scalar("tr_nats", tf.reduce_mean(G[0]))
        # Add monitor ops if exist
        if self._add_monitor_op is not None:
            monitor_op = self._add_monitor_op(self.FLAGS)
        else:
            monitor_op = None
        S = self._define_summary()

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:
            # Initializing the variables
            sess.run(init_op)
            if self.FLAGS.start_from_ckpt:
                # Checkpoints are saved with ``global_step`` appended
                # (``model.ckpt-<step>``), so look up the newest one instead
                # of a fixed ``model.ckpt``.  Restoring through the existing
                # Saver fills this graph's variables (including global_step);
                # importing the meta graph would restore into duplicates.
                latest_ckpt = tf.train.latest_checkpoint(self.FLAGS.log_dir)
                if latest_ckpt is None:
                    raise IOError(
                        "No checkpoint found in %s" % self.FLAGS.log_dir)
                save_op.restore(sess, latest_ckpt)
                print("Restore from the last checkpoint. "
                      "Restarting from %d step." % global_step.eval())
            # Declare summary writer
            summary_writer = tf.summary.FileWriter(self.FLAGS.log_dir,
                                                   flush_secs=5.0)
            tr_costs = []
            _best_score = np.iinfo(np.int32).max
            # Keep training until max iteration
            for _epoch in xrange(self.FLAGS.n_epoch):
                _n_exp = 0
                disp_start = time.time()
                epoch_start = time.time()
                for batch in train_set.get_epoch_iterator():
                    x, x_mask, _, _, y, _ = batch
                    # Swap the first two axes so that the batch dimension is
                    # axis 1 (presumably the time-major layout the graph
                    # expects -- confirm against _build_graph).
                    x = np.transpose(x, (1, 0, 2))
                    x_mask = np.transpose(x_mask, (1, 0))
                    y = np.transpose(y, (1, 0))

                    _, n_batch, _ = x.shape
                    _feed_states = self._initial_states(
                        n_batch, self.FLAGS.n_hidden)

                    # Count the examples actually seen; the last batch of an
                    # epoch may hold fewer than FLAGS.batch_size.
                    _n_exp += n_batch
                    # Run optimization op (backprop)
                    _tr_cost, _step_summary, _feed_states = \
                            self._sess_wrapper(x, x_mask, y, t_step,
                                               step_summary, sess, G,
                                               _feed_states)
                    step = global_step.eval()
                    # Write step level logs
                    summary_writer.add_summary(_step_summary, step)
                    tr_costs.append(_tr_cost.mean())
                    if step % self.FLAGS.display_freq == 0:
                        tr_cost = np.array(tr_costs).mean()
                        print("Epoch " + str(_epoch) +
                              ", Iter " + str(step) +
                              ", Average batch loss= " + "{:.6f}".format(tr_cost) +
                              ", Elapsed time= " + "{:.5f}".format(time.time() - disp_start))
                        disp_start = time.time()
                        tr_costs = []
                # Monitor training/validation nats and bits
                _tr_nats, _tr_bits, _val_nats, _val_bits = \
                        self._monitor(G, sess, None, valid_set.get_epoch_iterator(),
                                      self.FLAGS, summary,
                                      S, summary_writer, global_step.eval(), monitor_op)
                _time_spent = time.time() - epoch_start
                # The whole message is built inside print(); appending to the
                # call's return value would evaluate None + str.
                print("Train average nats= " + "{:.6f}".format(_tr_nats) +
                      ", Train average bits= " + "{:.6f}".format(_tr_bits) +
                      ", Valid average nats= " + "{:.6f}".format(_val_nats) +
                      ", Valid average bits= " + "{:.6f}".format(_val_bits) +
                      ", Elapsed time= " + "{:.5f}".format(_time_spent) +
                      ", Observed examples= " + "{:d}".format(_n_exp))
                insert_item2dict(summary, 'time', _time_spent)
                save_npz2(file_name, summary)
                # Save model
                if _val_bits < _best_score:
                    _best_score = _val_bits
                    best_ckpt = best_save_op.save(sess,
                                                  os.path.join(
                                                      self.FLAGS.log_dir,
                                                      "best_model.ckpt"),
                                                  global_step=global_step)
                    print("Best checkpoint stored in: %s" % best_ckpt)
                ckpt = save_op.save(sess,
                                    os.path.join(self.FLAGS.log_dir,
                                                 "model.ckpt"),
                                    global_step=global_step)
                print("Checkpointed in: %s" % ckpt)
                # S[0][0] / S[1][0]: summary op and placeholder for the best
                # score, as returned by self._define_summary().
                _epoch_summary = sess.run([S[0][0]],
                                          feed_dict={S[1][0]: _best_score})
                summary_writer.add_summary(_epoch_summary[0],
                                           global_step.eval())
            summary_writer.close()
            print("Optimization Finished.")
Ejemplo n.º 3
0
def main():
    """Jointly train the prediction model (ML loss) and the action model (RL loss).

    Settings come from ``get_args()``.  ``build_graph`` returns a training
    graph ``tg`` and a sampling graph ``sg``.  For every training batch:

      * an episode is generated with ``sg`` by ``gen_episode_with_seg_reward``;
      * advantages are computed by ``compute_advantage`` against a
        ``LinearVF`` baseline;
      * one Adam step is taken on ``mean(tg.ml_cost)`` (optionally clipped to
        global norm 1.0; advances ``global_step``) and one on
        ``mean(tg.rl_cost)`` (does not advance ``global_step``).

    After every epoch the validation set is scored the same way without
    updates.  Statistics are appended to ``<log_dir>/model.npz``.
    Checkpoints are saved as ``model.ckpt-<step>``, plus
    ``best_model.ckpt-<step>`` when the validation CE improves.

    Raises:
        IOError: if ``args.start_from_ckpt`` is set but ``args.log_dir``
            contains no checkpoint to resume from.
    """
    print(' '.join(sys.argv))

    args = get_args()
    print(args)

    print('Hostname: {}'.format(socket.gethostname()))
    print('GPU: {}'.format(get_gpuname()))

    # A fresh run wipes previous logs/checkpoints; a resumed run keeps them.
    if not args.start_from_ckpt:
        if tf.gfile.Exists(args.log_dir):
            tf.gfile.DeleteRecursively(args.log_dir)
        tf.gfile.MakeDirs(args.log_dir)

    # Clear the (private) reuse flag of the root variable scope.
    tf.get_variable_scope()._reuse = None

    _seed = args.base_seed + args.add_seed
    tf.set_random_seed(_seed)
    np.random.seed(_seed)

    prefix_name = os.path.join(args.log_dir, 'model')
    file_name = '%s.npz' % prefix_name

    # Per-epoch validation statistics, dumped to ``file_name`` after each epoch.
    eval_summary = OrderedDict()

    tg, sg = build_graph(args)
    tg_ml_cost = tf.reduce_mean(tg.ml_cost)

    global_step = tf.Variable(0, trainable=False, name="global_step")

    tvars = tf.trainable_variables()
    print([tvar.name for tvar in tvars])

    ml_opt_func = tf.train.AdamOptimizer(learning_rate=args.learning_rate,
                                         beta1=0.9,
                                         beta2=0.99)
    rl_opt_func = tf.train.AdamOptimizer(learning_rate=args.rl_learning_rate,
                                         beta1=0.9,
                                         beta2=0.99)

    if args.grad_clip:
        ml_grads, _ = tf.clip_by_global_norm(tf.gradients(tg_ml_cost, tvars),
                                             clip_norm=1.0)
    else:
        ml_grads = tf.gradients(tg_ml_cost, tvars)
    ml_op = ml_opt_func.apply_gradients(zip(ml_grads, tvars),
                                        global_step=global_step)

    tg_rl_cost = tf.reduce_mean(tg.rl_cost)
    rl_grads = tf.gradients(tg_rl_cost, tvars)
    # do not increase global step -- ml op increases it
    rl_op = rl_opt_func.apply_gradients(zip(rl_grads, tvars))

    tf.add_to_collection('n_fast_action', args.n_fast_action)

    sync_data(args)
    datasets = [args.train_dataset, args.valid_dataset, args.test_dataset]
    train_set, valid_set, test_set = [
        create_ivector_datastream(path=args.data_path,
                                  which_set=dataset,
                                  batch_size=args.n_batch,
                                  min_after_cache=args.min_after_cache,
                                  length_sort=not args.no_length_sort)
        for dataset in datasets
    ]

    init_op = tf.global_variables_initializer()
    save_op = tf.train.Saver(max_to_keep=5)
    best_save_op = tf.train.Saver(max_to_keep=5)

    # Values computed in numpy are logged to TensorBoard through placeholders.
    with tf.name_scope("per_step_eval"):
        tr_ce = tf.placeholder(tf.float32)
        tr_ce_summary = tf.summary.scalar("tr_ce", tr_ce)
        tr_image = tf.placeholder(tf.float32)
        tr_image_summary = tf.summary.image("tr_image", tr_image)
        tr_fer = tf.placeholder(tf.float32)
        tr_fer_summary = tf.summary.scalar("tr_fer", tr_fer)
        tr_rl = tf.placeholder(tf.float32)
        tr_rl_summary = tf.summary.scalar("tr_rl", tr_rl)
        tr_rw_hist = tf.placeholder(tf.float32)
        tr_rw_hist_summary = tf.summary.histogram("tr_reward_hist", tr_rw_hist)

    with tf.name_scope("per_epoch_eval"):
        best_val_ce = tf.placeholder(tf.float32)
        val_ce = tf.placeholder(tf.float32)
        best_val_ce_summary = tf.summary.scalar("best_valid_ce", best_val_ce)
        val_ce_summary = tf.summary.scalar("valid_ce", val_ce)

    # Linear value-function baseline used by compute_advantage.
    vf = LinearVF()

    with tf.Session() as sess:
        sess.run(init_op)

        if args.start_from_ckpt:
            # Checkpoints are saved with ``global_step`` appended
            # (``model.ckpt-<step>``), so look up the newest one instead of a
            # fixed ``model.ckpt``.  Restoring through the existing Saver
            # fills this graph's variables (including global_step); importing
            # the meta graph would restore into duplicate variables.
            latest_ckpt = tf.train.latest_checkpoint(args.log_dir)
            if latest_ckpt is None:
                raise IOError("No checkpoint found in %s" % args.log_dir)
            save_op.restore(sess, latest_ckpt)
            print(
                "Restore from the last checkpoint. Restarting from %d step." %
                global_step.eval())

        summary_writer = tf.summary.FileWriter(args.log_dir,
                                               sess.graph,
                                               flush_secs=5.0)

        # Running training statistics, reset after every display.
        tr_ce_sum = 0.
        tr_ce_count = 0
        tr_acc_sum = 0
        tr_acc_count = 0
        tr_rl_costs = []
        tr_action_entropies = []
        tr_rewards = []

        _best_score = np.iinfo(np.int32).max

        epoch_sw = StopWatch()
        disp_sw = StopWatch()
        eval_sw = StopWatch()

        # For each epoch
        for _epoch in xrange(args.n_epoch):
            epoch_sw.reset()
            disp_sw.reset()

            print('Epoch {} training'.format(_epoch + 1))

            # For each batch
            for batch in train_set.get_epoch_iterator():
                x, x_mask, _, _, y, _ = batch
                n_batch = x.shape[0]

                new_x, new_y, actions_1hot, rewards, action_entropies, new_x_mask, new_reward_mask, output_image = \
                        gen_episode_with_seg_reward(x, x_mask, y, sess, sg, args)

                advantages, _ = compute_advantage(new_x, new_x_mask, rewards,
                                                  new_reward_mask, vf, args)

                zero_state = gen_zero_state(n_batch, args.n_hidden)

                feed_dict = {
                    tg.seq_x_data: new_x,
                    tg.seq_x_mask: new_x_mask,
                    tg.seq_y_data: new_y,
                    tg.seq_action: actions_1hot,
                    tg.seq_advantage: advantages,
                    tg.seq_action_mask: new_reward_mask,
                    tg.seq_y_data_for_action: new_y
                }
                feed_init_state(feed_dict, tg.init_state, zero_state)

                _tr_ml_cost, _tr_rl_cost, _, _, pred_idx = \
                    sess.run([tg.ml_cost, tg.rl_cost, ml_op, rl_op, tg.pred_idx], feed_dict=feed_dict)

                tr_ce_sum += _tr_ml_cost.sum()
                tr_ce_count += new_x_mask.sum()

                # Map the predictions back onto the original frames so they
                # can be scored against y / x_mask.
                pred_idx = expand_output(actions_1hot, x_mask, new_x_mask,
                                         pred_idx.reshape([n_batch, -1]),
                                         args.n_fast_action)
                batch_correct = ((pred_idx == y) * x_mask).sum()
                batch_frames = x_mask.sum()
                tr_acc_sum += batch_correct
                tr_acc_count += batch_frames

                # "tr_fer" is the frame *error* rate, i.e. 1 - accuracy, the
                # same quantity printed as fer below.  float() guards against
                # integer division when the mask is integral (Python 2).
                _tr_ce_summary, _tr_fer_summary, _tr_rl_summary, _tr_image_summary, _tr_rw_hist_summary = \
                    sess.run([tr_ce_summary, tr_fer_summary, tr_rl_summary, tr_image_summary, tr_rw_hist_summary],
                        feed_dict={tr_ce: _tr_ml_cost.sum() / new_x_mask.sum(),
                            tr_fer: 1. - float(batch_correct) / batch_frames,
                            tr_rl: _tr_rl_cost.sum() / new_reward_mask.sum(),
                            tr_image: output_image, tr_rw_hist: rewards})
                step = global_step.eval()
                summary_writer.add_summary(_tr_ce_summary, step)
                summary_writer.add_summary(_tr_fer_summary, step)
                summary_writer.add_summary(_tr_rl_summary, step)
                summary_writer.add_summary(_tr_image_summary, step)
                summary_writer.add_summary(_tr_rw_hist_summary, step)

                tr_rl_costs.append(_tr_rl_cost.sum() / new_reward_mask.sum())
                tr_action_entropies.append(action_entropies.sum() /
                                           new_reward_mask.sum())
                tr_rewards.append(rewards.sum() / new_reward_mask.sum())

                if step % args.display_freq == 0:
                    avg_tr_ce = tr_ce_sum / tr_ce_count
                    avg_tr_fer = 1. - float(tr_acc_sum) / tr_acc_count
                    avg_tr_rl_cost = np.asarray(tr_rl_costs).mean()
                    avg_tr_action_entropy = np.asarray(
                        tr_action_entropies).mean()
                    avg_tr_reward = np.asarray(tr_rewards).mean()

                    print(
                        "TRAIN: epoch={} iter={} ml_cost(ce/frame)={:.2f} fer={:.2f} rl_cost={:.4f} reward={:.4f} action_entropy={:.2f} time_taken={:.2f}"
                        .format(_epoch, step, avg_tr_ce,
                                avg_tr_fer, avg_tr_rl_cost, avg_tr_reward,
                                avg_tr_action_entropy, disp_sw.elapsed()))

                    tr_ce_sum = 0.
                    tr_ce_count = 0
                    tr_acc_sum = 0.
                    tr_acc_count = 0
                    tr_rl_costs = []
                    tr_action_entropies = []
                    tr_rewards = []

                    disp_sw.reset()

            print('--')
            print('End of epoch {}'.format(_epoch + 1))
            epoch_sw.print_elapsed()

            print('Testing')

            # Evaluate the model on the validation set
            val_ce_sum = 0.
            val_ce_count = 0
            val_acc_sum = 0
            val_acc_count = 0
            val_rl_costs = []
            val_action_entropies = []
            val_rewards = []

            eval_sw.reset()
            for batch in valid_set.get_epoch_iterator():
                x, x_mask, _, _, y, _ = batch
                n_batch = x.shape[0]

                new_x, new_y, actions_1hot, rewards, action_entropies, new_x_mask, new_reward_mask, output_image, new_y_sample = \
                        gen_episode_with_seg_reward(x, x_mask, y, sess, sg, args, sample_y=True)

                advantages, _ = compute_advantage(new_x, new_x_mask, rewards,
                                                  new_reward_mask, vf, args)

                zero_state = gen_zero_state(n_batch, args.n_hidden)

                feed_dict = {
                    tg.seq_x_data: new_x,
                    tg.seq_x_mask: new_x_mask,
                    tg.seq_y_data: new_y,
                    tg.seq_action: actions_1hot,
                    tg.seq_advantage: advantages,
                    tg.seq_action_mask: new_reward_mask,
                    tg.seq_y_data_for_action: new_y_sample
                }
                feed_init_state(feed_dict, tg.init_state, zero_state)

                # Forward pass only: no optimizer op is run during validation.
                _val_ml_cost, _val_rl_cost, pred_idx = sess.run(
                    [tg.ml_cost, tg.rl_cost, tg.pred_idx], feed_dict=feed_dict)

                val_ce_sum += _val_ml_cost.sum()
                val_ce_count += new_x_mask.sum()

                pred_idx = expand_output(actions_1hot, x_mask, new_x_mask,
                                         pred_idx.reshape([n_batch, -1]),
                                         args.n_fast_action)
                val_acc_sum += ((pred_idx == y) * x_mask).sum()
                val_acc_count += x_mask.sum()

                val_rl_costs.append(_val_rl_cost.sum() / new_reward_mask.sum())
                val_action_entropies.append(action_entropies.sum() /
                                            new_reward_mask.sum())
                val_rewards.append(rewards.sum() / new_reward_mask.sum())

            avg_val_ce = val_ce_sum / val_ce_count
            avg_val_fer = 1. - float(val_acc_sum) / val_acc_count
            avg_val_rl_cost = np.asarray(val_rl_costs).mean()
            avg_val_action_entropy = np.asarray(val_action_entropies).mean()
            avg_val_reward = np.asarray(val_rewards).mean()

            print(
                "VALID: epoch={} ml_cost(ce/frame)={:.2f} fer={:.2f} rl_cost={:.4f} reward={:.4f} action_entropy={:.2f} time_taken={:.2f}"
                .format(_epoch, avg_val_ce, avg_val_fer, avg_val_rl_cost,
                        avg_val_reward, avg_val_action_entropy,
                        eval_sw.elapsed()))

            _val_ce_summary, = sess.run([val_ce_summary],
                                        feed_dict={val_ce: avg_val_ce})
            summary_writer.add_summary(_val_ce_summary, global_step.eval())

            insert_item2dict(eval_summary, 'val_ce', avg_val_ce)
            insert_item2dict(eval_summary, 'val_rl_cost', avg_val_rl_cost)
            insert_item2dict(eval_summary, 'val_reward', avg_val_reward)
            insert_item2dict(eval_summary, 'val_action_entropy',
                             avg_val_action_entropy)
            insert_item2dict(eval_summary, 'time', eval_sw.elapsed())
            save_npz2(file_name, eval_summary)

            # Save model: the best checkpoint only when validation CE improves,
            # the rolling checkpoint every epoch (saved last, so it is what
            # ``tf.train.latest_checkpoint`` returns on resume).
            if avg_val_ce < _best_score:
                _best_score = avg_val_ce
                best_ckpt = best_save_op.save(sess,
                                              os.path.join(
                                                  args.log_dir,
                                                  "best_model.ckpt"),
                                              global_step=global_step)
                print("Best checkpoint stored in: %s" % best_ckpt)
            ckpt = save_op.save(sess,
                                os.path.join(args.log_dir, "model.ckpt"),
                                global_step=global_step)
            print("Checkpoint stored in: %s" % ckpt)

            _best_val_ce_summary, = sess.run(
                [best_val_ce_summary], feed_dict={best_val_ce: _best_score})
            summary_writer.add_summary(_best_val_ce_summary,
                                       global_step.eval())
        summary_writer.close()

        print("Optimization Finished.")
Ejemplo n.º 4
0
 def __call__(self):
     """Build the model, optionally resume, then train and checkpoint every epoch.

     ``self._build_graph(FLAGS)`` supplies:

       * ``f_prop(x, x_mask)``, which returns the batch cost;
       * ``f_update(learning_rate)``, which applies one update;
       * ``f_log_prob``, used by ``self._monitor``;
       * the parameter dict ``tparams`` and ``opt_tparams``;
       * the recurrent ``states`` (reset every epoch);
       * ``st_slope`` (printed as the "ST slope").

     Files written under ``FLAGS.log_dir`` with prefix ``self._file_name``:
       * ``<prefix>.npz`` / ``<prefix>.grads.npz`` -- latest ``tparams`` /
         ``opt_tparams`` plus the step counters and running summary, after
         every epoch;
       * ``<prefix>.best.npz`` / ``<prefix>.best.grads.npz`` -- the same for
         the epoch with the lowest validation bits so far.

     When ``FLAGS.start_from_ckpt`` is set *and* ``<prefix>.npz`` exists, the
     parameters, summary and step counters are restored from it.  The
     restored model's validation bits must then reproduce the last recorded
     value before training continues.  If the file is missing, training
     starts from scratch.

     Raises:
         ValueError: if the restored model's validation bits differ from the
             last value stored in the checkpoint.
     """
     # Fix random seeds
     _seed = self.FLAGS.base_seed + self.FLAGS.add_seed
     np.random.seed(seed=_seed)
     # Prefixed names for save files
     prefix_name = os.path.join(self.FLAGS.log_dir, self._file_name)
     file_name = '%s.npz' % prefix_name
     best_file_name = '%s.best.npz' % prefix_name
     opt_file_name = '%s.grads.npz' % prefix_name
     best_opt_file_name = '%s.best.grads.npz' % prefix_name
     # Resume only when requested *and* a checkpoint exists.  Every
     # resume-specific step below keys off this single flag, so a missing
     # file consistently means "start fresh".
     restored = bool(self.FLAGS.start_from_ckpt) and os.path.exists(file_name)
     if restored:
         self._ckpt_file_name = file_name
     # Declare summary
     summary = OrderedDict()
     # Initialize the variables
     f_prop, f_update, f_log_prob, f_debug, tparams, opt_tparams, \
         states, st_slope = self._build_graph(self.FLAGS)
     if restored:
         tparams = init_tparams_with_restored_value(tparams, file_name)
         model = np.load(file_name)
         # NOTE(review): only keys containing 'summary' or 'time' are carried
         # over; the sanity check below reads summary['val_bits'], so this
         # relies on how save_npz names its keys -- confirm.
         for k, v in model.items():
             if 'summary' in k:
                 summary[k] = list(v)
             if 'time' in k:
                 summary[k] = list(v)
         global_step = model['global_step']
         epoch_step = model['epoch_step']
         batch_step = model['batch_step']
         print("Restore from the last checkpoint. "
               "Restarting from %d step." % global_step)
     else:
         global_step = 0
         epoch_step = 0
         batch_step = 0
     # Construct dataset objects
     train_set = TextIterator(which_set='train',
                              max_seq_len=self.FLAGS.max_seq_len,
                              batch_size=self.FLAGS.batch_size,
                              shuffle_every_epoch=1)
     if self.FLAGS.eval_train:
         train_infer_set = TextIterator(which_set='train',
                                        max_seq_len=self.FLAGS.max_seq_len,
                                        batch_size=self.FLAGS.batch_size,
                                        shuffle_every_epoch=0)
     else:
         train_infer_set = None
     valid_set = TextIterator(which_set='valid',
                              max_seq_len=self.FLAGS.max_seq_len,
                              batch_size=self.FLAGS.batch_size,
                              shuffle_every_epoch=0)
     if restored:
         # Sanity check: the restored parameters must reproduce the last
         # recorded validation bits.
         _summary = self._monitor(f_log_prob, self.FLAGS, valid_set, None,
                                  states)
         _val_bits = _summary['val_bits']
         if _val_bits != summary['val_bits'][-1]:
             raise ValueError(
                 "Sanity check failed, check values do not match.")
         # Fast-forward the training iterator past the batches already
         # consumed in the interrupted epoch.  If that fails (e.g. the
         # iterator is exhausted), restart the epoch from the beginning.
         # NOTE(review): this skips batch_step + 1 batches although
         # batch_step batches were processed -- confirm the extra one is
         # intended.
         try:
             for _ in xrange(batch_step + 1):
                 train_set.next()
         except Exception:
             batch_step = 0
     best_params = None
     tr_costs = []
     _best_score = np.iinfo(np.int32).max
     # Keep training until max iteration
     print("Starting the optimization")
     for _epoch in xrange(self.FLAGS.n_epoch):
         reset_state(states)
         _n_exp = 0
         disp_start = time.time()
         epoch_start = time.time()
         # Only the first epoch after a resume continues the restored
         # in-epoch batch counter; every other epoch counts from zero.
         if not (restored and _epoch == 0):
             batch_step = 0
         if self.FLAGS.use_slope_anneal:
             if _epoch <= self.FLAGS.n_anneal_epoch:
                 # Linear ramp from 1 (epoch 0) to n_slope (epoch n_anneal_epoch).
                 new_slope = float(1. + (self.FLAGS.n_slope - 1) /
                                   float(self.FLAGS.n_anneal_epoch) *
                                   _epoch)
                 st_slope.set_value(new_slope)
                 print("Changed the ST slope to : %f" %
                       st_slope.get_value())
         for x in train_set:
             x, x_mask = gen_mask(x, max_seq_len=self.FLAGS.max_seq_len)
             _n_exp += self.FLAGS.batch_size
             # Run f-prop and optimization functions (backprop)
             cost = f_prop(x, x_mask)
             f_update(self.FLAGS.learning_rate)
             tr_costs.append(cost)
             if np.mod(global_step, self.FLAGS.display_freq) == 0:
                 tr_cost = np.array(tr_costs).mean()
                 print("Epoch " + str(_epoch) + \
                       ", Iter " + str(global_step) + \
                       ", Average batch loss= " + "{:.6f}".format(tr_cost) + \
                       ", Elapsed time= " + "{:.5f}".format(time.time() - disp_start))
                 disp_start = time.time()
                 tr_costs = []
             batch_step += 1
             global_step += 1
         # Monitor training/validation nats and bits
         _summary = self._monitor(f_log_prob, self.FLAGS, valid_set,
                                  train_infer_set, states)
         feed_dict(summary, _summary)
         # Wall-clock time of the whole epoch (training + monitoring); used
         # both for the log line and for the 'time' series in the summary.
         _time_spent = time.time() - epoch_start
         # The whole message is built inside print(); appending to the call's
         # return value would evaluate None + str.
         print("Train average nats= " + "{:.6f}".format(_summary['tr_nats']) +
               ", Train average bits= " + "{:.6f}".format(_summary['tr_bits']) +
               ", Valid average nats= " + "{:.6f}".format(_summary['val_nats']) +
               ", Valid average bits= " + "{:.6f}".format(_summary['val_bits']) +
               ", Elapsed time= " + "{:.5f}".format(_time_spent) +
               ", Observed examples= " + "{:d}".format(_n_exp))
         insert_item2dict(summary, 'time', _time_spent)
         # Save model
         _val_bits = summary['val_bits'][-1]
         if _val_bits < _best_score:
             _best_score = _val_bits
             # Save the best model
             best_params = unzip(tparams)
             if self.FLAGS.use_slope_anneal:
                 best_params['st_slope'] = st_slope.get_value()
             save_npz(best_file_name, global_step, epoch_step, batch_step,
                      best_params, summary)
             # Save the gradients of best model
             best_opt_params = unzip(opt_tparams)
             save_npz2(best_opt_file_name, best_opt_params)
             print("Best checkpoint stored in: %s" % best_file_name)
         # Save the latest model
         params = unzip(tparams)
         if self.FLAGS.use_slope_anneal:
             params['st_slope'] = st_slope.get_value()
         save_npz(file_name, global_step, epoch_step, batch_step, params,
                  summary)
         # Save the gradients of latest model
         opt_params = unzip(opt_tparams)
         save_npz2(opt_file_name, opt_params)
         print("Checkpointed in: %s" % file_name)
     print("Optimization Finished.")
Ejemplo n.º 5
0
def main(_):
    """Jointly train a skip-RNN label classifier (ML) and its skip policy (RL).

    All settings are read from the module-level ``FLAGS``; the positional
    argument is unused.

    Per epoch:

    * every training batch is turned into skip episodes by ``gen_episodes``
      (sampled actions); advantages come from ``compute_advantage`` with a
      ``LinearVF`` baseline (or raw discounted rewards when
      ``args.use_baseline`` is false). One ML update (label cost + L2) and
      one RL update (policy cost minus entropy bonus) then run in the same
      ``sess.run``, and only the ML update advances ``global_step``;
    * per-step TensorBoard summaries are written, and running averages are
      printed every ``args.display_freq`` global steps;
    * the validation set is evaluated with ``use_sampling=False``; CE / FER
      history is saved to ``<log_dir>/model.npz``, and the latest, best-CE
      and best-FER checkpoints are written to ``args.log_dir``.

    With ``args.start_from_ckpt`` the log directory is kept and the most
    recent checkpoint in it is restored; otherwise it is wiped and recreated.
    """
    # Print settings
    print(' '.join(sys.argv))
    args = FLAGS
    for k, v in args.__flags.iteritems():
        print(k, v)

    # Start from an empty log directory unless resuming from a checkpoint
    if not args.start_from_ckpt:
        if tf.gfile.Exists(args.log_dir):
            tf.gfile.DeleteRecursively(args.log_dir)
        tf.gfile.MakeDirs(args.log_dir)

    # NOTE(review): clears the root variable scope's reuse flag through a
    # private TF attribute; the reason is not visible from here -- confirm.
    tf.get_variable_scope()._reuse = None

    # Set random seed
    _seed = args.base_seed + args.add_seed
    tf.set_random_seed(_seed)
    np.random.seed(_seed)

    # Per-epoch validation history is written to <log_dir>/model.npz
    prefix_name = os.path.join(args.log_dir, 'model')
    file_name = '%s.npz' % prefix_name
    eval_summary = OrderedDict()

    # Build training graph (tg) and sampling graph (sg)
    tg, sg = build_graph(args)

    # Linear regressor used as the RL baseline
    vf = LinearVF()

    # Set global step
    global_step = tf.Variable(0, trainable=False, name="global_step")

    # Variables whose name contains "action" are trained by RL, the rest by ML
    tvars = tf.trainable_variables()
    ml_vars = [tvar for tvar in tvars if "action" not in tvar.name]
    rl_vars = [tvar for tvar in tvars if "action" in tvar.name]

    # Set optimizers
    ml_opt_func = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    rl_opt_func = tf.train.AdamOptimizer(learning_rate=args.rl_learning_rate)

    # ML cost: sum over everything, divide by the leading dimension of
    # seq_x_data (the batch size), then add an L2 penalty on the ML variables.
    ml_cost = tf.reduce_sum(tg.seq_ml_cost)
    ml_cost /= tf.to_float(tf.shape(tg.seq_x_data)[0])
    ml_cost += args.ml_l2 * 0.5 * tf.add_n(
        [tf.reduce_sum(tf.square(var)) for var in ml_vars])

    # RL cost: policy cost minus weighted action entropy, summed and divided
    # by the batch size.
    rl_cost = tg.seq_rl_cost - args.ent_weight * tg.seq_a_ent
    rl_cost = tf.reduce_sum(rl_cost)
    rl_cost /= tf.to_float(tf.shape(tg.seq_x_data)[0])

    # RL cost used only for reporting: summed and normalized by the total of
    # the action mask.
    real_rl_cost = tf.reduce_sum(tg.seq_real_rl_cost)
    real_rl_cost /= tf.reduce_sum(tg.seq_a_mask)

    # ML gradients, optionally clipped by global norm
    ml_grads = tf.gradients(ml_cost, ml_vars)
    if args.grad_clip > 0.0:
        ml_grads, _ = tf.clip_by_global_norm(t_list=ml_grads,
                                             clip_norm=args.grad_clip)

    # RL gradients
    rl_grads = tf.gradients(rl_cost, rl_vars)

    # ML optimization: the only op that advances global_step
    ml_op = ml_opt_func.apply_gradients(grads_and_vars=zip(ml_grads, ml_vars),
                                        global_step=global_step,
                                        name='ml_op')

    # RL optimization. It always runs in the same sess.run as ml_op, so it
    # must not also receive global_step -- that counted every update twice.
    rl_op = rl_opt_func.apply_gradients(grads_and_vars=zip(rl_grads, rl_vars),
                                        name='rl_op')

    # Sync dataset
    sync_data(args)

    # Get dataset
    train_set = create_ivector_datastream(path=args.data_path,
                                          which_set=args.train_dataset,
                                          batch_size=args.batch_size,
                                          min_after_cache=args.min_after_cache,
                                          length_sort=not args.no_length_sort)
    valid_set = create_ivector_datastream(path=args.data_path,
                                          which_set=args.valid_dataset,
                                          batch_size=args.batch_size,
                                          min_after_cache=args.min_after_cache,
                                          length_sort=not args.no_length_sort)

    # Set param init op
    init_op = tf.global_variables_initializer()

    # Savers: latest checkpoints, and best-so-far (CE / FER) checkpoints
    save_op = tf.train.Saver(max_to_keep=5)
    best_save_op = tf.train.Saver(max_to_keep=5)

    # Set per-step logging
    with tf.name_scope("per_step_eval"):
        # For ML cost (ce)
        tr_ce = tf.placeholder(tf.float32)
        tr_ce_summary = tf.summary.scalar("train_ml_cost", tr_ce)

        # For output visualization
        tr_image = tf.placeholder(tf.float32)
        tr_image_summary = tf.summary.image("train_image", tr_image)

        # For ML FER
        tr_fer = tf.placeholder(tf.float32)
        tr_fer_summary = tf.summary.scalar("train_fer", tr_fer)

        # For RL cost
        tr_rl = tf.placeholder(tf.float32)
        tr_rl_summary = tf.summary.scalar("train_rl", tr_rl)

        # For RL reward
        tr_reward = tf.placeholder(tf.float32)
        tr_reward_summary = tf.summary.scalar("train_reward", tr_reward)

        # For RL entropy
        tr_ent = tf.placeholder(tf.float32)
        tr_ent_summary = tf.summary.scalar("train_entropy", tr_ent)

        # For RL reward histogram
        tr_rw_hist = tf.placeholder(tf.float32)
        tr_rw_hist_summary = tf.summary.histogram("train_reward_hist",
                                                  tr_rw_hist)

        # For RL skip count
        tr_skip_cnt = tf.placeholder(tf.float32)
        tr_skip_cnt_summary = tf.summary.scalar("train_skip_cnt", tr_skip_cnt)

    # Set per-epoch logging
    with tf.name_scope("per_epoch_eval"):
        # For best valid ML cost (full)
        best_val_ce = tf.placeholder(tf.float32)
        best_val_ce_summary = tf.summary.scalar("best_valid_ce", best_val_ce)

        # For best valid FER
        best_val_fer = tf.placeholder(tf.float32)
        best_val_fer_summary = tf.summary.scalar("best_valid_fer",
                                                 best_val_fer)

        # For valid ML cost (full)
        val_ce = tf.placeholder(tf.float32)
        val_ce_summary = tf.summary.scalar("valid_ce", val_ce)

        # For valid FER
        val_fer = tf.placeholder(tf.float32)
        val_fer_summary = tf.summary.scalar("valid_fer", val_fer)

        # For output visualization
        val_image = tf.placeholder(tf.float32)
        val_image_summary = tf.summary.image("valid_image", val_image)

        # For RL skip count
        val_skip_cnt = tf.placeholder(tf.float32)
        val_skip_cnt_summary = tf.summary.scalar("valid_skip_cnt",
                                                 val_skip_cnt)

    # Episode generator
    gen_episodes = improve_skip_rnn_act_parallel

    # Init session
    with tf.Session() as sess:
        # Init model
        sess.run(init_op)

        # Resume into the variables of the graph built above. Checkpoints are
        # saved as model.ckpt-<global_step> (see save_op.save below), so ask
        # TF for the latest one rather than a fixed "model.ckpt" path, and do
        # not import_meta_graph, which would add a second copy of the graph.
        if args.start_from_ckpt:
            ckpt_path = tf.train.latest_checkpoint(args.log_dir)
            if ckpt_path is None:
                raise IOError('No checkpoint found in %s' % args.log_dir)
            save_op.restore(sess, ckpt_path)
            print(
                "Restore from the last checkpoint. Restarting from %d step." %
                global_step.eval())

        # Summary writer
        summary_writer = tf.summary.FileWriter(args.log_dir,
                                               sess.graph,
                                               flush_secs=5.0)

        # Running sums for the periodic training display
        tr_ce_sum = 0.
        tr_ce_count = 0
        tr_acc_sum = 0.
        tr_acc_count = 0
        tr_rl_sum = 0.
        tr_rl_count = 0
        tr_ent_sum = 0.
        tr_ent_count = 0
        tr_reward_sum = 0.
        tr_reward_count = 0
        tr_skip_sum = 0.
        tr_skip_count = 0

        # Best validation scores so far (lower is better for both)
        _best_ce = float('inf')
        _best_fer = 1.00

        # For time measure
        epoch_sw = StopWatch()
        disp_sw = StopWatch()
        eval_sw = StopWatch()

        # For each epoch
        for _epoch in xrange(args.n_epoch):
            # Reset timer
            epoch_sw.reset()
            disp_sw.reset()
            print('Epoch {} training'.format(_epoch + 1))

            # RL skipping is always on; the non-skipping branches below are
            # kept for reference but are currently unreachable.
            use_rl_skipping = True

            # For each batch (update)
            for batch_data in train_set.get_epoch_iterator():
                # Get batch data
                seq_x_data, seq_x_mask, _, _, seq_y_data, _ = batch_data

                # Serialized summaries to write for this step
                step_summaries = []

                if use_rl_skipping:
                    ##################
                    # Sampling Phase #
                    ##################
                    # Swap the first two axes for episode generation
                    seq_x_data = np.transpose(seq_x_data, (1, 0, 2))
                    seq_x_mask = np.transpose(seq_x_mask, (1, 0))
                    seq_y_data = np.transpose(seq_y_data, (1, 0))

                    # Number of samples (axis 1 after the transpose)
                    batch_size = seq_x_data.shape[1]

                    # Sample actions (episode generation)
                    [
                        skip_x_data, skip_h_data, skip_x_mask, skip_y_data,
                        skip_a_data, skip_a_mask, skip_rewards, result_image
                    ] = gen_episodes(seq_x_data=seq_x_data,
                                     seq_x_mask=seq_x_mask,
                                     seq_y_data=seq_y_data,
                                     sess=sess,
                                     sample_graph=sg,
                                     args=args,
                                     use_sampling=True)

                    # Skip ratio: skip_x_mask total over seq_x_mask total
                    _skip_ratio = skip_x_mask.sum() / seq_x_mask.sum()
                    tr_skip_sum += _skip_ratio
                    tr_skip_count += 1.0

                    # Compute baseline and refine reward
                    skip_advantage, skip_disc_rewards = compute_advantage(
                        seq_h_data=skip_h_data,
                        seq_r_data=skip_rewards,
                        seq_r_mask=skip_a_mask,
                        vf=vf,
                        args=args,
                        final_cost=args.use_final_reward)

                    if not args.use_baseline:
                        skip_advantage = skip_disc_rewards

                    ##################
                    # Training Phase #
                    ##################
                    # One ML and one RL update in a single run
                    [
                        _tr_ml_cost, _tr_rl_cost, _, _, _tr_act_ent,
                        _tr_pred_logit
                    ] = sess.run(
                        [
                            ml_cost, real_rl_cost, ml_op, rl_op, tg.seq_a_ent,
                            tg.seq_label_logits
                        ],
                        feed_dict={
                            tg.seq_x_data: skip_x_data,
                            tg.seq_x_mask: skip_x_mask,
                            tg.seq_y_data: skip_y_data,
                            tg.seq_a_data: skip_a_data,
                            tg.seq_a_mask: skip_a_mask,
                            tg.seq_advantage: skip_advantage,
                            tg.seq_reward: skip_disc_rewards
                        })

                    # Restore the batch-major layout of mask and labels
                    seq_x_mask = np.transpose(seq_x_mask, (1, 0))
                    seq_y_data = np.transpose(seq_y_data, (1, 0))

                    # Get full sequence prediction
                    _tr_pred_full = expand_pred_idx(
                        seq_skip_1hot=skip_a_data,
                        seq_skip_mask=skip_a_mask,
                        seq_prd_idx=_tr_pred_logit.argmax(axis=-1).reshape(
                            [batch_size, -1]),
                        seq_x_mask=seq_x_mask)

                    # Update history
                    tr_ce_sum += _tr_ml_cost.sum() * batch_size
                    tr_ce_count += skip_x_mask.sum()

                    tr_acc_sum += ((_tr_pred_full == seq_y_data) *
                                   seq_x_mask).sum()
                    tr_acc_count += seq_x_mask.sum()

                    tr_rl_sum += _tr_rl_cost.sum()
                    tr_rl_count += 1.0

                    tr_ent_sum += _tr_act_ent.sum()
                    tr_ent_count += skip_a_mask.sum()

                    tr_reward_sum += (skip_rewards * skip_a_mask).sum()
                    tr_reward_count += skip_a_mask.sum()

                    # Per-step RL summaries
                    step_summaries += sess.run(
                        [
                            tr_rl_summary, tr_image_summary, tr_ent_summary,
                            tr_reward_summary, tr_rw_hist_summary,
                            tr_skip_cnt_summary
                        ],
                        feed_dict={
                            tr_rl: _tr_rl_cost.sum(),
                            tr_image: result_image,
                            tr_ent: _tr_act_ent.sum() / skip_a_mask.sum(),
                            tr_reward: ((skip_rewards * skip_a_mask).sum() /
                                        skip_a_mask.sum()),
                            tr_rw_hist: skip_rewards,
                            tr_skip_cnt: _skip_ratio
                        })
                else:
                    # Number of samples
                    batch_size = seq_x_data.shape[0]

                    ##################
                    # Training Phase #
                    ##################
                    [_tr_ml_cost, _, _tr_pred_full] = sess.run(
                        [ml_cost, ml_op, tg.seq_label_logits],
                        feed_dict={
                            tg.seq_x_data: seq_x_data,
                            tg.seq_x_mask: seq_x_mask,
                            tg.seq_y_data: seq_y_data
                        })
                    _tr_pred_full = np.reshape(_tr_pred_full.argmax(axis=1),
                                               seq_y_data.shape)

                    # Update history
                    tr_ce_sum += _tr_ml_cost.sum() * batch_size
                    tr_ce_count += seq_x_mask.sum()

                    tr_acc_sum += ((_tr_pred_full == seq_y_data) *
                                   seq_x_mask).sum()
                    tr_acc_count += seq_x_mask.sum()

                    skip_x_mask = seq_x_mask

                # Per-step ML summaries. "train_fer" is an error rate
                # (1 - frame accuracy), matching avg_tr_fer below.
                _batch_acc = (float(((_tr_pred_full == seq_y_data) *
                                     seq_x_mask).sum()) / seq_x_mask.sum())
                step_summaries += sess.run(
                    [tr_ce_summary, tr_fer_summary],
                    feed_dict={
                        tr_ce:
                        (_tr_ml_cost.sum() * batch_size) / skip_x_mask.sum(),
                        tr_fer: 1. - _batch_acc
                    })

                # Read the step counter once per iteration
                step = global_step.eval()
                for _summ in step_summaries:
                    summary_writer.add_summary(_summ, step)

                # Display results
                if step % args.display_freq == 0:
                    # Get average results
                    avg_tr_ce = tr_ce_sum / tr_ce_count
                    avg_tr_fer = 1. - tr_acc_sum / tr_acc_count
                    if use_rl_skipping:
                        avg_tr_rl = tr_rl_sum / tr_rl_count
                        avg_tr_ent = tr_ent_sum / tr_ent_count
                        avg_tr_reward = tr_reward_sum / tr_reward_count
                        avg_tr_skip = tr_skip_sum / tr_skip_count
                        print(
                            "TRAIN: epoch={} iter={} "
                            "ml_cost(ce/frame)={:.2f} fer={:.2f} "
                            "rl_cost={:.4f} reward={:.4f} action_entropy={:.2f} "
                            "skip_ratio={:.2f} "
                            "time_taken={:.2f}".format(_epoch + 1, step,
                                                       avg_tr_ce, avg_tr_fer,
                                                       avg_tr_rl,
                                                       avg_tr_reward,
                                                       avg_tr_ent, avg_tr_skip,
                                                       disp_sw.elapsed()))
                    else:
                        print("TRAIN: epoch={} iter={} "
                              "ml_cost(ce/frame)={:.2f} fer={:.2f} "
                              "time_taken={:.2f}".format(
                                  _epoch + 1, step, avg_tr_ce,
                                  avg_tr_fer, disp_sw.elapsed()))

                    # Reset average results
                    tr_ce_sum = 0.
                    tr_ce_count = 0
                    tr_acc_sum = 0.
                    tr_acc_count = 0
                    tr_rl_sum = 0.
                    tr_rl_count = 0
                    tr_ent_sum = 0.
                    tr_ent_count = 0
                    tr_reward_sum = 0.
                    tr_reward_count = 0
                    tr_skip_sum = 0.
                    tr_skip_count = 0

                    disp_sw.reset()

            # End of epoch
            print('--')
            print('End of epoch {}'.format(_epoch + 1))
            epoch_sw.print_elapsed()

            # Evaluation
            print('Testing')

            # Evaluate the model on the validation set
            val_ce_sum = 0.
            val_ce_count = 0
            val_acc_sum = 0.
            val_acc_count = 0
            val_rl_sum = 0.
            val_rl_count = 0
            val_ent_sum = 0.
            val_ent_count = 0
            val_reward_sum = 0.
            val_reward_count = 0
            val_skip_sum = 0.
            val_skip_count = 0
            eval_sw.reset()

            # For each batch in Valid
            for batch_data in valid_set.get_epoch_iterator():
                # Get batch data
                seq_x_data, seq_x_mask, _, _, seq_y_data, _ = batch_data

                if use_rl_skipping:
                    ##################
                    # Sampling Phase #
                    ##################
                    # Swap the first two axes for episode generation
                    seq_x_data = np.transpose(seq_x_data, (1, 0, 2))
                    seq_x_mask = np.transpose(seq_x_mask, (1, 0))
                    seq_y_data = np.transpose(seq_y_data, (1, 0))

                    # Number of samples (axis 1 after the transpose)
                    batch_size = seq_x_data.shape[1]

                    # Generate actions without sampling
                    [
                        skip_x_data, skip_h_data, skip_x_mask, skip_y_data,
                        skip_a_data, skip_a_mask, skip_rewards, result_image
                    ] = gen_episodes(seq_x_data=seq_x_data,
                                     seq_x_mask=seq_x_mask,
                                     seq_y_data=seq_y_data,
                                     sess=sess,
                                     sample_graph=sg,
                                     args=args,
                                     use_sampling=False)

                    # Compute skip ratio
                    val_skip_sum += skip_x_mask.sum() / seq_x_mask.sum()
                    val_skip_count += 1.0

                    # Compute baseline and refine reward
                    skip_advantage, skip_disc_rewards = compute_advantage(
                        seq_h_data=skip_h_data,
                        seq_r_data=skip_rewards,
                        seq_r_mask=skip_a_mask,
                        vf=vf,
                        args=args,
                        final_cost=args.use_final_reward)

                    if not args.use_baseline:
                        skip_advantage = skip_disc_rewards

                    #################
                    # Forward Phase #
                    #################
                    [
                        _val_ml_cost, _val_rl_cost, _val_pred_logit,
                        _val_action_ent
                    ] = sess.run(
                        [
                            ml_cost, real_rl_cost, tg.seq_label_logits,
                            tg.seq_a_ent
                        ],
                        feed_dict={
                            tg.seq_x_data: skip_x_data,
                            tg.seq_x_mask: skip_x_mask,
                            tg.seq_y_data: skip_y_data,
                            tg.seq_a_data: skip_a_data,
                            tg.seq_a_mask: skip_a_mask,
                            tg.seq_advantage: skip_advantage,
                            tg.seq_reward: skip_disc_rewards
                        })

                    # Restore the batch-major layout of mask and labels
                    seq_x_mask = np.transpose(seq_x_mask, (1, 0))
                    seq_y_data = np.transpose(seq_y_data, (1, 0))

                    # Get full sequence prediction
                    _val_pred_full = expand_pred_idx(
                        seq_skip_1hot=skip_a_data,
                        seq_skip_mask=skip_a_mask,
                        seq_prd_idx=_val_pred_logit.argmax(axis=-1).reshape(
                            [batch_size, -1]),
                        seq_x_mask=seq_x_mask)

                    # Update history
                    val_ce_sum += _val_ml_cost.sum() * batch_size
                    val_ce_count += skip_x_mask.sum()

                    val_acc_sum += ((_val_pred_full == seq_y_data) *
                                    seq_x_mask).sum()
                    val_acc_count += seq_x_mask.sum()

                    val_rl_sum += _val_rl_cost.sum()
                    val_rl_count += 1.0

                    val_ent_sum += _val_action_ent.sum()
                    val_ent_count += skip_a_mask.sum()

                    val_reward_sum += (skip_rewards * skip_a_mask).sum()
                    val_reward_count += skip_a_mask.sum()
                else:
                    # Number of samples
                    batch_size = seq_x_data.shape[0]

                    #################
                    # Forward Phase #
                    #################
                    [_val_ml_cost, _val_pred_full] = sess.run(
                        [ml_cost, tg.seq_label_logits],
                        feed_dict={
                            tg.seq_x_data: seq_x_data,
                            tg.seq_x_mask: seq_x_mask,
                            tg.seq_y_data: seq_y_data
                        })
                    _val_pred_full = np.reshape(_val_pred_full.argmax(axis=1),
                                                seq_y_data.shape)

                    # Update history
                    val_ce_sum += _val_ml_cost.sum() * batch_size
                    val_ce_count += seq_x_mask.sum()

                    val_acc_sum += ((_val_pred_full == seq_y_data) *
                                    seq_x_mask).sum()
                    val_acc_count += seq_x_mask.sum()

            # Aggregate over all valid data
            avg_val_ce = val_ce_sum / val_ce_count
            avg_val_fer = 1. - val_acc_sum / val_acc_count

            if use_rl_skipping:
                avg_val_rl = val_rl_sum / val_rl_count
                avg_val_ent = val_ent_sum / val_ent_count
                avg_val_reward = val_reward_sum / val_reward_count
                avg_val_skip = val_skip_sum / val_skip_count

                print("VALID: epoch={} "
                      "ml_cost(ce/frame)={:.2f} fer={:.2f} "
                      "rl_cost={:.4f} reward={:.4f} action_entropy={:.2f} "
                      "skip_ratio={:.2f} "
                      "time_taken={:.2f}".format(_epoch + 1, avg_val_ce,
                                                 avg_val_fer, avg_val_rl,
                                                 avg_val_reward,
                                                 avg_val_ent, avg_val_skip,
                                                 eval_sw.elapsed()))
            else:
                print("VALID: epoch={} "
                      "ml_cost(ce/frame)={:.2f} fer={:.2f} "
                      "time_taken={:.2f}".format(_epoch + 1,
                                                 avg_val_ce, avg_val_fer,
                                                 eval_sw.elapsed()))

            ################
            # Write result #
            ################
            val_summaries = sess.run(
                [val_ce_summary, val_fer_summary],
                feed_dict={
                    val_ce: avg_val_ce,
                    val_fer: avg_val_fer
                })
            # Skip ratio and the image only exist on the RL-skipping path;
            # result_image is the one from the last validation batch.
            if use_rl_skipping:
                val_summaries += sess.run(
                    [val_skip_cnt_summary, val_image_summary],
                    feed_dict={
                        val_skip_cnt: avg_val_skip,
                        val_image: result_image
                    })

            step = global_step.eval()
            for _summ in val_summaries:
                summary_writer.add_summary(_summ, step)

            insert_item2dict(eval_summary, 'val_ce', avg_val_ce)
            insert_item2dict(eval_summary, 'val_fer', avg_val_fer)
            insert_item2dict(eval_summary, 'time', eval_sw.elapsed())
            save_npz2(file_name, eval_summary)

            # Save best model
            if avg_val_ce < _best_ce:
                _best_ce = avg_val_ce
                best_ckpt = best_save_op.save(sess=sess,
                                              save_path=os.path.join(
                                                  args.log_dir,
                                                  "best_model(ce).ckpt"),
                                              global_step=global_step)
                print("Best checkpoint based on CE stored in: %s" % best_ckpt)

            if avg_val_fer < _best_fer:
                _best_fer = avg_val_fer
                best_ckpt = best_save_op.save(sess=sess,
                                              save_path=os.path.join(
                                                  args.log_dir,
                                                  "best_model(fer).ckpt"),
                                              global_step=global_step)
                print("Best checkpoint based on FER stored in: %s" % best_ckpt)

            # Save the latest model (written as model.ckpt-<global_step>)
            ckpt = save_op.save(sess=sess,
                                save_path=os.path.join(args.log_dir,
                                                       "model.ckpt"),
                                global_step=global_step)
            print("Checkpoint stored in: %s" % ckpt)

            # Write best-so-far scores
            [_best_val_ce_summary, _best_val_fer_summary
             ] = sess.run([best_val_ce_summary, best_val_fer_summary],
                          feed_dict={
                              best_val_ce: _best_ce,
                              best_val_fer: _best_fer
                          })
            summary_writer.add_summary(_best_val_ce_summary, step)
            summary_writer.add_summary(_best_val_fer_summary, step)

        # Done of training
        summary_writer.close()
        print("Optimization Finished.")
Ejemplo n.º 6
0
def main(_):
    """Train the skip-RNN model with separate ML (label) and RL (jump) updates.

    Builds the training/test graphs, trains for ``args.n_epoch`` epochs over
    the training stream, evaluates frame error rate (FER) on the validation
    stream after every epoch, and writes TensorBoard summaries, an ``.npz``
    log of validation results, the latest checkpoint (``model.ckpt-<step>``)
    and the best-FER checkpoint (``best_model.ckpt-<step>``) to
    ``args.log_dir``.

    Args:
        _: Unused positional argument (presumably the argv list handed over
            by ``tf.app.run`` -- confirm).

    Raises:
        IOError: If ``args.start_from_ckpt`` is set but ``args.log_dir``
            contains no checkpoint to resume from.
    """
    print(' '.join(sys.argv))
    args = FLAGS
    print(args.__flags)
    print('Hostname: {}'.format(socket.gethostname()))
    print('GPU: {}'.format(get_gpuname()))

    # A fresh run wipes the log directory; a resumed run keeps it so the
    # latest checkpoint can be restored below.
    if not args.start_from_ckpt:
        if tf.gfile.Exists(args.log_dir):
            tf.gfile.DeleteRecursively(args.log_dir)
        tf.gfile.MakeDirs(args.log_dir)

    tf.get_variable_scope()._reuse = None

    _seed = args.base_seed + args.add_seed
    tf.set_random_seed(_seed)
    np.random.seed(_seed)

    prefix_name = os.path.join(args.log_dir, 'model')
    file_name = '%s.npz' % prefix_name

    eval_summary = OrderedDict()

    tg, test_graph = build_graph(args)
    tg_ml_cost = tf.reduce_mean(tg.ml_cost)

    global_step = tf.Variable(0, trainable=False, name="global_step")

    # Variables whose name contains "action_logit" are trained by the RL
    # optimizer; every other trainable variable by the ML optimizer.
    tvars = tf.trainable_variables()
    print([tvar.name for tvar in tvars])
    ml_tvars = [tvar for tvar in tvars if "action_logit" not in tvar.name]
    rl_tvars = [tvar for tvar in tvars if "action_logit" in tvar.name]

    ml_opt_func = tf.train.AdamOptimizer(learning_rate=args.learning_rate,
                                         beta1=0.9,
                                         beta2=0.99)
    rl_opt_func = tf.train.AdamOptimizer(learning_rate=args.rl_learning_rate,
                                         beta1=0.9,
                                         beta2=0.99)

    ml_grads = tf.gradients(tg_ml_cost, ml_tvars)
    if args.grad_clip:
        ml_grads, _ = tf.clip_by_global_norm(ml_grads, clip_norm=1.0)
    # Only the ML update advances global_step, so it counts ML steps.
    ml_op = ml_opt_func.apply_gradients(zip(ml_grads, ml_tvars),
                                        global_step=global_step)

    tg_rl_cost = tf.reduce_mean(tg.rl_cost)
    rl_grads = tf.gradients(tg_rl_cost, rl_tvars)
    rl_op = rl_opt_func.apply_gradients(zip(rl_grads, rl_tvars))

    # NOTE(review): both values go into the same 'fast_action' collection
    # (flag first, count second); presumably read back by position -- confirm.
    tf.add_to_collection('fast_action', args.fast_action)
    tf.add_to_collection('fast_action', args.n_fast_action)

    sync_data(args)
    datasets = [args.train_dataset, args.valid_dataset, args.test_dataset]
    train_set, valid_set, test_set = [
        create_ivector_datastream(path=args.data_path,
                                  which_set=dataset,
                                  batch_size=args.n_batch,
                                  min_after_cache=args.min_after_cache,
                                  length_sort=not args.no_length_sort)
        for dataset in datasets
    ]

    init_op = tf.global_variables_initializer()
    save_op = tf.train.Saver(max_to_keep=5)
    best_save_op = tf.train.Saver(max_to_keep=5)

    with tf.name_scope("per_step_eval"):
        tr_ce = tf.placeholder(tf.float32)
        tr_ce_summary = tf.summary.scalar("tr_ce", tr_ce)
        tr_fer = tf.placeholder(tf.float32)
        tr_fer_summary = tf.summary.scalar("tr_fer", tr_fer)
        tr_ce2 = tf.placeholder(tf.float32)
        tr_ce2_summary = tf.summary.scalar("tr_rl", tr_ce2)

        tr_image = tf.placeholder(tf.float32)
        tr_image_summary = tf.summary.image("tr_image", tr_image)

    with tf.name_scope("per_epoch_eval"):
        val_fer = tf.placeholder(tf.float32)
        val_fer_summary = tf.summary.scalar("val_fer", val_fer)
        best_val_fer = tf.placeholder(tf.float32)
        best_val_fer_summary = tf.summary.scalar("best_valid_fer",
                                                 best_val_fer)
        val_image = tf.placeholder(tf.float32)
        val_image_summary = tf.summary.image("val_image", val_image)

    # NOTE(review): vf is never used below; kept in case constructing
    # LinearVF has side effects the rest of the file relies on.
    vf = LinearVF()

    with tf.Session() as sess:
        sess.run(init_op)

        if args.start_from_ckpt:
            # Checkpoints are saved with global_step appended
            # ("model.ckpt-<step>"), so a literal "model.ckpt" path never
            # exists. Ask TF for the newest one, and restore through the
            # existing saver so save_op keeps tracking the variables that are
            # actually being trained (import_meta_graph would add a second
            # copy of the graph).
            ckpt_path = tf.train.latest_checkpoint(args.log_dir)
            if ckpt_path is None:
                raise IOError(
                    'No checkpoint found in {}'.format(args.log_dir))
            save_op.restore(sess, ckpt_path)
            print(
                "Restore from the last checkpoint. Restarting from %d step." %
                global_step.eval())

        summary_writer = tf.summary.FileWriter(args.log_dir,
                                               sess.graph,
                                               flush_secs=5.0)

        def _write_summaries(summaries):
            """Add serialized summaries, in order, at the current step."""
            step = global_step.eval()
            for summary in summaries:
                summary_writer.add_summary(summary, step)

        def _predict_frames(x, x_mask, y):
            """Run the test graph and expand its outputs back to every frame.

            Returns ``(pred_idx, output_image)``, where ``pred_idx`` has been
            passed through ``expand_output`` so it can be compared with ``y``.
            """
            actions_1hot, label_probs, new_mask, output_image = \
                skip_rnn_forward_supervised(x, x_mask, sess, test_graph,
                                            args.fast_action,
                                            args.n_fast_action, y)
            pred_idx = expand_output(actions_1hot, x_mask, new_mask,
                                     label_probs.argmax(axis=-1))
            return pred_idx, output_image

        tr_ce_sum = 0.
        tr_ce_count = 0
        tr_acc_sum = 0.
        tr_acc_count = 0
        tr_ce2_sum = 0.
        tr_ce2_count = 0

        _best_score = np.iinfo(np.int32).max

        epoch_sw = StopWatch()
        disp_sw = StopWatch()
        eval_sw = StopWatch()
        per_sw = StopWatch()

        # For each epoch
        for _epoch in xrange(args.n_epoch):
            _n_exp = 0

            epoch_sw.reset()
            disp_sw.reset()

            print('Epoch {} training'.format(_epoch + 1))

            # For each batch
            for batch in train_set.get_epoch_iterator():
                x, x_mask, _, _, y, _ = batch
                n_batch = x.shape[0]
                _n_exp += n_batch

                if args.no_sampling:
                    # Train both parts jointly on ground-truth jump supervision.
                    new_x, new_y, actions, _, new_x_mask = gen_supervision(
                        x, x_mask, y, args)

                    zero_state = gen_zero_state(n_batch, args.n_hidden)

                    feed_dict = {
                        tg.seq_x_data: new_x,
                        tg.seq_x_mask: new_x_mask,
                        tg.seq_y_data: new_y,
                        tg.seq_jump_data: actions
                    }
                    feed_init_state(feed_dict, tg.init_state, zero_state)

                    _tr_ml_cost, _tr_rl_cost, _, _ = sess.run(
                        [tg.ml_cost, tg.rl_cost, ml_op, rl_op],
                        feed_dict=feed_dict)

                    tr_ce_sum += _tr_ml_cost.sum()
                    tr_ce_count += new_x_mask.sum()
                    tr_ce2_sum += _tr_rl_cost.sum()
                    tr_ce2_count += new_x_mask[:, :-1].sum()

                    pred_idx, output_image = _predict_frames(x, x_mask, y)
                    n_correct = ((pred_idx == y) * x_mask).sum()
                    n_frames = x_mask.sum()
                    tr_acc_sum += n_correct
                    tr_acc_count += n_frames

                    summaries = sess.run(
                        [tr_ce_summary, tr_fer_summary, tr_ce2_summary,
                         tr_image_summary],
                        feed_dict={
                            tr_ce: _tr_ml_cost.sum() / new_x_mask.sum(),
                            # float() avoids Python 2 integer division.
                            tr_fer: 1. - float(n_correct) / n_frames,
                            tr_ce2:
                            _tr_rl_cost.sum() / new_x_mask[:, :-1].sum(),
                            tr_image: output_image
                        })
                    _write_summaries(summaries)

                else:
                    # 1) Train the jump-prediction part on ground-truth jumps.
                    new_x, _, actions, _, new_x_mask = gen_supervision(
                        x, x_mask, y, args)
                    zero_state = gen_zero_state(n_batch, args.n_hidden)
                    feed_dict = {
                        tg.seq_x_data: new_x,
                        tg.seq_x_mask: new_x_mask,
                        tg.seq_jump_data: actions
                    }
                    feed_init_state(feed_dict, tg.init_state, zero_state)
                    _tr_rl_cost, _ = sess.run([tg.rl_cost, rl_op],
                                              feed_dict=feed_dict)

                    tr_ce2_sum += _tr_rl_cost.sum()
                    tr_ce2_count += new_x_mask[:, :-1].sum()

                    _tr_ce2_summary, = sess.run(
                        [tr_ce2_summary],
                        feed_dict={
                            tr_ce2:
                            _tr_rl_cost.sum() / new_x_mask[:, :-1].sum()
                        })

                    # 2) Generate jumps from the model, then train the
                    # label-prediction part on the resulting sequences.
                    new_x, new_y, _, _, new_x_mask, _ = gen_episode_supervised(
                        x, y, x_mask, sess, test_graph, args.fast_action,
                        args.n_fast_action)

                    feed_dict = {
                        tg.seq_x_data: new_x,
                        tg.seq_x_mask: new_x_mask,
                        tg.seq_y_data: new_y
                    }
                    feed_init_state(feed_dict, tg.init_state, zero_state)

                    _tr_ml_cost, _ = sess.run([tg.ml_cost, ml_op],
                                              feed_dict=feed_dict)

                    tr_ce_sum += _tr_ml_cost.sum()
                    tr_ce_count += new_x_mask.sum()

                    pred_idx, output_image = _predict_frames(x, x_mask, y)
                    n_correct = ((pred_idx == y) * x_mask).sum()
                    n_frames = x_mask.sum()
                    tr_acc_sum += n_correct
                    tr_acc_count += n_frames

                    _tr_ce_summary, _tr_fer_summary, _tr_image_summary = \
                        sess.run(
                            [tr_ce_summary, tr_fer_summary, tr_image_summary],
                            feed_dict={
                                tr_ce: _tr_ml_cost.sum() / new_x_mask.sum(),
                                # float() avoids Python 2 integer division.
                                tr_fer: 1. - float(n_correct) / n_frames,
                                tr_image: output_image
                            })
                    _write_summaries([_tr_ce_summary, _tr_fer_summary,
                                      _tr_ce2_summary, _tr_image_summary])

                if global_step.eval() % args.display_freq == 0:
                    avg_tr_ce = tr_ce_sum / tr_ce_count
                    avg_tr_fer = 1. - float(tr_acc_sum) / tr_acc_count
                    avg_tr_ce2 = tr_ce2_sum / tr_ce2_count

                    print(
                        "TRAIN: epoch={} iter={} ml_cost(ce/frame)={:.2f} fer={:.2f} rl_cost={:.4f} time_taken={:.2f}"
                        .format(_epoch, global_step.eval(), avg_tr_ce,
                                avg_tr_fer, avg_tr_ce2, disp_sw.elapsed()))

                    tr_ce_sum = 0.
                    tr_ce_count = 0
                    tr_acc_sum = 0.
                    tr_acc_count = 0
                    tr_ce2_sum = 0.
                    tr_ce2_count = 0

                    disp_sw.reset()

            print('--')
            print('End of epoch {}'.format(_epoch + 1))
            epoch_sw.print_elapsed()

            print('Testing')

            # Evaluate the model on the validation set
            val_acc_sum = 0.
            val_acc_count = 0

            eval_sw.reset()
            for batch in valid_set.get_epoch_iterator():
                x, x_mask, _, _, y, _ = batch
                pred_idx, output_image = _predict_frames(x, x_mask, y)
                val_acc_sum += ((pred_idx == y) * x_mask).sum()
                val_acc_count += x_mask.sum()

            avg_val_fer = 1. - float(val_acc_sum) / val_acc_count

            print("VALID: epoch={} fer={:.2f} time_taken={:.2f}".format(
                _epoch, avg_val_fer, eval_sw.elapsed()))

            # The image summary shows the last validation batch only.
            _val_fer_summary, _val_image_summary = sess.run(
                [val_fer_summary, val_image_summary],
                feed_dict={
                    val_fer: avg_val_fer,
                    val_image: output_image
                })
            _write_summaries([_val_fer_summary, _val_image_summary])

            insert_item2dict(eval_summary, 'val_fer', avg_val_fer)
            insert_item2dict(eval_summary, 'time', eval_sw.elapsed())
            save_npz2(file_name, eval_summary)

            # Save model: best-by-FER checkpoint first, then the regular one
            # (saved last so it stays the newest entry in the checkpoint
            # state file that latest_checkpoint reads on resume).
            if avg_val_fer < _best_score:
                _best_score = avg_val_fer
                best_ckpt = best_save_op.save(sess,
                                              os.path.join(
                                                  args.log_dir,
                                                  "best_model.ckpt"),
                                              global_step=global_step)
                print("Best checkpoint stored in: %s" % best_ckpt)
            ckpt = save_op.save(sess,
                                os.path.join(args.log_dir, "model.ckpt"),
                                global_step=global_step)
            print("Checkpoint stored in: %s" % ckpt)

            _best_val_fer_summary, = sess.run(
                [best_val_fer_summary], feed_dict={best_val_fer: _best_score})
            _write_summaries([_best_val_fer_summary])
        summary_writer.close()

        print("Optimization Finished.")