Beispiel #1
0
def main(unused_argv):
    del unused_argv  # Unused

    tf.logging.set_verbosity(tf.logging.INFO)

    #### Tokenizer
    tokenizer = get_tokenizer()

    #### Get corpus info
    n_token = tokenizer.get_vocab_size()
    tf.logging.info("n_token %d", n_token)

    # test data
    inputs_np = [
        3933, 7752, 15179, 893, 24249, 703, 19119, 4, 2919, 335, 8511, 1094,
        43, 1661, 669, 5481, 1106, 7029, 891, 891
    ]
    type_id_np = [0] * len(inputs_np)
    inputs_np = np.array(inputs_np)[None]
    type_id_np = np.array(type_id_np)[None]

    # tensorflow graph
    inputs = tf.placeholder(tf.int64, [1, None])
    type_id = tf.placeholder(tf.int64, [1, None])
    hiddens = model_func_builder.extract_hiddens(inputs,
                                                 type_id,
                                                 n_token,
                                                 is_training=False)

    # run session
    saver = tf.train.Saver()
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=False)) as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, FLAGS.init_checkpoint)

        feed_dict = {
            inputs: inputs_np,
            type_id: type_id_np,
        }

        hiddens_np = sess.run(hiddens, feed_dict=feed_dict)
        tf.logging.info(len(hiddens_np))
def train_ddpg(dataset,
               policy,
               actor_optimizer=None,
               critic_optimizer=None,
               pack_transition_fn=None,
               ddpg_graph_fn=None,
               log_dir=None,
               master='local',
               task=0,
               training_steps=None,
               max_training_steps=100000,
               reuse=False,
               init_checkpoint=None,
               update_target_every_n_steps=50,
               log_every_n_steps=None,
               save_checkpoint_steps=500,
               save_summaries_steps=500):
  """Self-contained learning loop for offline Q-learning.

  Code inspired by OpenAI Baselines' deepq.build_train. This function is
  compatible with discrete Q-learning graphs, continuous Q learning graphs, and
  SARSA.

  Args:
    dataset: tf.data.Dataset providing transitions.
    policy: Instance of TFDQNPolicy class that provides functor for building the
      critic function.
    actor_optimizer: Optional instance of an optimizer for the actor network.
      If not specified, creates an AdamOptimizer using the default constructor.
    critic_optimizer: Optional instance of an optimizer for the critic network.
      If not specified, creates an AdamOptimizer using the default constructor.
    pack_transition_fn: Optional function that performs additional processing
      of the transition. This is a convenience method for ad-hoc manipulation of
      transition data passed to the learning function after parsing.
    ddpg_graph_fn: Function used to construct training objectives w.r.t. critic
      outputs.
    log_dir: Where to save model checkpoints and tensorboard summaries.
    master: Optional address of master worker. Specify this when doing
      distributed training.
    task: Optional worker task for distributed training. Defaults to solo master
      task on a single machine.
    training_steps: Optional number of steps to run training before terminating
      early. Max_training_steps remains unchanged - training will terminate
      after max_training_steps whether or not training_steps is specified.
    max_training_steps: maximum number of training iters.
    reuse: If True, reuse existing variables for all declared variables by this
      function.
    init_checkpoint: Optional checkpoint to restore prior to training. If not
      provided, variables are initialized using global_variables_initializer().
    update_target_every_n_steps: How many global steps (training) between
      copying the Q network weights (scope='q_func') to target network
      (scope='target_q_func').
    log_every_n_steps: How many global steps between logging loss tensors.
    save_checkpoint_steps: How many global steps between saving TF variables
      to a checkpoint file.
    save_summaries_steps: How many global steps between saving TF summaries.

  Returns:
    (int) Current `global_step` reached after training for training_steps, or
    `max_training_steps` if `global_step` has reached `max_training_steps`.

  """
  data_iterator = dataset.make_one_shot_iterator()

  transition = data_iterator.get_next()
  if pack_transition_fn:
    transition = pack_transition_fn(transition)

  if actor_optimizer is None:
    actor_optimizer = tf.train.AdamOptimizer()
  if critic_optimizer is None:
    critic_optimizer = tf.train.AdamOptimizer()

  a_func = policy.get_a_func(is_training=True, reuse=reuse)
  q_func = policy.get_q_func(is_training=True, reuse=reuse)
  actor_loss, critic_loss, all_summaries = ddpg_graph_fn(
      a_func, q_func, transition)

  a_func_vars = tf.contrib.framework.get_trainable_variables(scope='a_func')
  q_func_vars = framework.get_trainable_variables(scope='q_func')
  target_q_func_vars = framework.get_trainable_variables(scope='target_q_func')

  # with tf.variable_scope('ddpg', use_resource=True):
  global_step = tf.train.get_or_create_global_step()

  # CRITIC OPTIMIZATION
  # Only optimize q_func and update its batchnorm params.
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='q_func')
  critic_train_op = tf.contrib.training.create_train_op(
      critic_loss,
      critic_optimizer,
      global_step=global_step,
      update_ops=update_ops,
      summarize_gradients=True,
      variables_to_train=q_func_vars,
  )

  # ACTOR OPTIMIZATION
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='a_func')
  actor_train_op = tf.contrib.training.create_train_op(
      actor_loss,
      actor_optimizer,
      global_step=None,
      summarize_gradients=True,
      variables_to_train=a_func_vars,
  )
  # Combine losses to train both actor and critic simultaneously.
  train_op = critic_train_op + actor_train_op

  chief_hooks = []
  hooks = []
  # Save summaries periodically.
  if save_summaries_steps is not None:
    chief_hooks.append(tf.train.SummarySaverHook(
        save_steps=save_summaries_steps,
        output_dir=log_dir, summary_op=all_summaries))

  # Stop after training_steps
  if max_training_steps:
    hooks.append(tf.train.StopAtStepHook(last_step=max_training_steps))

  # Report if loss tensor is NaN.
  hooks.append(tf.train.NanTensorHook(actor_loss))
  hooks.append(tf.train.NanTensorHook(critic_loss))

  if log_every_n_steps is not None:
    tensor_dict = {
        'global_step': global_step,
        'actor loss': actor_loss,
        'critic_loss': critic_loss
    }
    chief_hooks.append(
        tf.train.LoggingTensorHook(tensor_dict, every_n_iter=log_every_n_steps))

    # Measure how fast we are training per sec and save to summary.
    chief_hooks.append(tf.train.StepCounterHook(
        every_n_steps=log_every_n_steps, output_dir=log_dir))

  # If target network exists, periodically update target Q network with new
  # weights (frozen target network). We hack this by
  # abusing a LoggingTensorHook for this.
  if target_q_func_vars and update_target_every_n_steps is not None:
    update_target_expr = []
    for var, var_t in zip(sorted(q_func_vars, key=lambda v: v.name),
                          sorted(target_q_func_vars, key=lambda v: v.name)):
      update_target_expr.append(var_t.assign(var))
    update_target_expr = tf.group(*update_target_expr)

    with tf.control_dependencies([update_target_expr]):
      update_target = tf.constant(0)
    chief_hooks.append(
        tf.train.LoggingTensorHook({'update_target': update_target},
                                   every_n_iter=update_target_every_n_steps))

  # Save checkpoints periodically, save all of them.
  saver = tf.train.Saver(max_to_keep=None)
  chief_hooks.append(tf.train.CheckpointSaverHook(
      log_dir, save_steps=save_checkpoint_steps, saver=saver,
      checkpoint_basename='model.ckpt'))

  # Save our experiment params to checkpoint dir.
  chief_hooks.append(gin.tf.GinConfigSaverHook(log_dir, summarize_config=True))

  session_config = tf.ConfigProto(log_device_placement=True)

  init_fn = None
  if init_checkpoint:
    assign_fn = tf.contrib.framework.assign_from_checkpoint_fn(
        init_checkpoint, framework.get_model_variables())
    init_fn = lambda _, sess: assign_fn(sess)
  scaffold = tf.train.Scaffold(saver=saver, init_fn=init_fn)
  with tf.train.MonitoredTrainingSession(
      master=master,
      is_chief=(task == 0),
      config=session_config,
      checkpoint_dir=log_dir,
      scaffold=scaffold,
      hooks=hooks,
      chief_only_hooks=chief_hooks) as sess:
    np_step = 0
    while not sess.should_stop():
      np_step, _ = sess.run([global_step, train_op])
      if training_steps and np_step % training_steps == 0:
        break
    done = np_step >= max_training_steps
  return np_step, done
Beispiel #3
0
def main(unused_argv):
    del unused_argv  # Unused

    tf.logging.set_verbosity(tf.logging.INFO)

    #### Validate FLAGS
    if FLAGS.save_steps == 0:
        FLAGS.save_steps = None

    assert FLAGS.seq_len > 0

    #### Tokenizer
    tokenizer = get_tokenizer()

    #### Get corpus info
    n_token = tokenizer.get_vocab_size()
    tf.logging.info("n_token %d", n_token)

    if FLAGS.do_train:
        # Get train input function
        train_input_fn = get_input_fn("train")

        # Get train cache function
        train_cache_fn = get_cache_fn(FLAGS.mem_len)
    else:
        train_cache_fn = None

    if FLAGS.do_eval:
        assert FLAGS.num_hosts == 1
        # Get eval input function
        eval_input_fn = get_input_fn(FLAGS.eval_split)
        tf.logging.info("num of eval batches %d", FLAGS.eval_steps)

        # Get eval cache function
        eval_cache_fn = get_cache_fn(FLAGS.mem_len)
    else:
        eval_cache_fn = None

    ##### Get model function
    model_fn = get_model_fn(n_token)

    ##### Create TPUEstimator
    # TPU Configuration
    if not run_internal and FLAGS.use_tpu:
        tpu_cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
        tpu_cluster = None

    per_host_input = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster,
        master=FLAGS.master,
        model_dir=FLAGS.model_dir,
        session_config=tf.ConfigProto(allow_soft_placement=True),
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations,
            per_host_input_for_training=per_host_input),
        keep_checkpoint_max=FLAGS.max_save,
        save_checkpoints_secs=None,
        save_checkpoints_steps=FLAGS.save_steps)

    # warm start
    warm_start_from = None
    if FLAGS.warm_start_path is not None:
        warm_start_from = tf.estimator.WarmStartSettings(
            ckpt_to_initialize_from=FLAGS.warm_start_path)

    # TPU Estimator
    estimator = tpu_estimator.TPUEstimator(
        model_fn=model_fn,
        train_cache_fn=train_cache_fn,
        eval_cache_fn=eval_cache_fn,
        use_tpu=FLAGS.use_tpu,
        config=run_config,
        params={},
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        eval_on_tpu=FLAGS.use_tpu,
        warm_start_from=warm_start_from)

    #### Training
    if FLAGS.do_train:
        estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)

    #### Evaluation
    if FLAGS.do_eval:
        if FLAGS.eval_ckpt_path is not None:
            if FLAGS.eval_ckpt_path.endswith("latest"):
                ckpt_dir = os.path.dirname(FLAGS.eval_ckpt_path)
                FLAGS.eval_ckpt_path = tf.train.latest_checkpoint(ckpt_dir)

            ret = estimator.evaluate(input_fn=eval_input_fn,
                                     steps=FLAGS.eval_steps,
                                     checkpoint_path=FLAGS.eval_ckpt_path)
            tf.logging.info("=" * 200)
            log_str = "Eval results | "
            for key, val in ret.items():
                log_str += "{} {} | ".format(key, val)
            tf.logging.info(log_str)
            tf.logging.info("=" * 200)
        else:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.model_dir)
            eval_results = []
            for eval_checkpoint in ckpt_state.all_model_checkpoint_paths:
                if not tf.gfile.Exists(eval_checkpoint + ".index"):
                    continue
                global_step = int(eval_checkpoint.split("-")[-1])
                if (global_step < FLAGS.start_eval_steps
                        or global_step > FLAGS.train_steps):
                    continue
                tf.logging.info("Evaluate ckpt %d", global_step)
                ret = estimator.evaluate(input_fn=eval_input_fn,
                                         steps=FLAGS.eval_steps,
                                         checkpoint_path=eval_checkpoint)
                eval_results.append(ret)

            eval_results.sort(key=lambda x: x["perplexity"])

            tf.logging.info("=" * 200)
            log_str = "Best results | "
            for key, val in eval_results[0].items():
                log_str += "{} {} | ".format(key, val)
            tf.logging.info(log_str)
            tf.logging.info("=" * 200)
Beispiel #4
0
def main(unused_argv):
    del unused_argv  # Unused

    tf.logging.set_verbosity(tf.logging.INFO)

    tokenizer = tokenization.get_tokenizer()

    ##### Get train cache function
    train_cache_fn = get_cache_fn(FLAGS.mem_len)
    eval_cache_fn = get_cache_fn(FLAGS.mem_len)

    ##### Get model function
    model_fn = get_model_fn(tokenizer.get_vocab_size())

    ##### Create TPUEstimator
    # TPU Configuration
    if not run_internal and FLAGS.use_tpu:
        tpu_cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
    else:
        tpu_cluster = None

    per_host_input = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster,
        master=FLAGS.master,
        model_dir=FLAGS.model_dir,
        session_config=tf.ConfigProto(allow_soft_placement=True),
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations,
            per_host_input_for_training=per_host_input),
        keep_checkpoint_max=FLAGS.max_save,
        save_checkpoints_secs=None,
        save_checkpoints_steps=FLAGS.save_steps)

    # warm start
    warm_start_from = None
    if FLAGS.warm_start_path is not None:
        warm_start_from = tf.estimator.WarmStartSettings(
            ckpt_to_initialize_from=FLAGS.warm_start_path)

    # TPU Estimator
    estimator = tpu_estimator.TPUEstimator(
        model_fn=model_fn,
        train_cache_fn=train_cache_fn,
        eval_cache_fn=eval_cache_fn,
        use_tpu=FLAGS.use_tpu,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        eval_on_tpu=FLAGS.use_tpu,
        warm_start_from=warm_start_from)

    ##### Training
    if FLAGS.do_train:
        # Get train input function
        train_input_fn = get_input_fn("train")

        estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)

    #### Evaluation
    if FLAGS.do_eval:
        # Get eval input function
        eval_input_fn = get_input_fn(FLAGS.eval_split)

        estimator.evaluate(input_fn=eval_input_fn,
                           steps=FLAGS.eval_steps,
                           checkpoint_path=FLAGS.eval_ckpt_path)