Example #1
# Excerpted from tf_agents.utils.eager_utils; the module-level imports and the
# `_USE_GLOBAL_STEP` sentinel below are assumptions reconstructed from that
# module. `Future`, `create_train_op`, and `add_gradients_summaries` are
# defined elsewhere in the same module.
from absl import logging
import tensorflow as tf

from tf_agents.utils import common

_USE_GLOBAL_STEP = 0

def create_train_step(loss,
                      optimizer,
                      global_step=_USE_GLOBAL_STEP,
                      total_loss_fn=None,
                      update_ops=None,
                      variables_to_train=None,
                      transform_grads_fn=None,
                      summarize_gradients=False,
                      gate_gradients=tf.compat.v1.train.Optimizer.GATE_OP,
                      aggregation_method=None,
                      check_numerics=True):
  """Creates a train_step that evaluates the gradients and returns the loss.

  Args:
    loss: A (possibly nested tuple of) `Tensor` or function representing
      the loss.
    optimizer: A tf.Optimizer to use for computing the gradients.
    global_step: A `Tensor` representing the global step variable. If left as
      `_USE_GLOBAL_STEP`, then tf.compat.v1.train.get_or_create_global_step()
      is used.
    total_loss_fn: Function to call on the loss value to obtain the final
      scalar to minimize.
    update_ops: An optional list of updates to execute. If `update_ops` is
      `None`, then the update ops are set to the contents of the
      `tf.GraphKeys.UPDATE_OPS` collection. If `update_ops` is not `None`, but
      it doesn't contain all of the update ops in `tf.GraphKeys.UPDATE_OPS`,
      a warning will be displayed.
    variables_to_train: An optional list of variables to train, or a callable
      that returns such a list. If `None`, defaults to all trainable
      variables.
    transform_grads_fn: A function which takes a single argument, a list of
      (gradient, variable) pairs, performs any requested gradient updates,
      such as gradient clipping or multipliers, and returns the updated list.
    summarize_gradients: Whether or not to add summaries for each gradient.
    gate_gradients: How to gate the computation of gradients. See tf.Optimizer.
    aggregation_method: Specifies the method used to combine gradient terms.
      Valid values are defined in the class `AggregationMethod`.
    check_numerics: Whether or not to validate the loss with
      tf.debugging.check_numerics.

  Returns:
    In graph mode: A (possibly nested tuple of) `Tensor` that, when evaluated,
      calculates the current loss, computes the gradients, applies the
      optimizer, and returns the current loss.
    In eager mode: The (possibly nested tuple of) loss values from a single
      train step: calling this function computes the loss, then computes and
      applies the gradients, and returns the original loss values.

  Raises:
    ValueError: If `total_loss_fn` is not callable, or if `loss` is not
      callable in eager mode.
    RuntimeError: If resource variables are not enabled.
  """
  if total_loss_fn is None:
    total_loss_fn = lambda x: x
  if not callable(total_loss_fn):
    raise ValueError('`total_loss_fn` should be a function.')
  if not common.resource_variables_enabled():
    raise RuntimeError(common.MISSING_RESOURCE_VARIABLES_ERROR)
  if not tf.executing_eagerly():
    if callable(loss):
      loss = loss()
    if callable(variables_to_train):
      variables_to_train = variables_to_train()
    # Calculate loss first, then calculate train op, then return the original
    # loss conditioned on executing the train op.
    with tf.control_dependencies(tf.nest.flatten(loss)):
      loss = tf.nest.map_structure(
          lambda t: tf.identity(t, 'loss_pre_train'), loss)
    train_op = create_train_op(
        total_loss_fn(loss),
        optimizer,
        global_step=global_step,
        update_ops=update_ops,
        variables_to_train=variables_to_train,
        transform_grads_fn=transform_grads_fn,
        summarize_gradients=summarize_gradients,
        gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        check_numerics=check_numerics)

    with tf.control_dependencies([train_op]):
      return tf.nest.map_structure(
          lambda t: tf.identity(t, 'loss_post_train'), loss)

  if global_step is _USE_GLOBAL_STEP:
    global_step = tf.compat.v1.train.get_or_create_global_step()

  if not callable(loss):
    raise ValueError('`loss` should be a function in eager mode.')

  if not isinstance(loss, Future):
    logging.warning('loss should be an instance of eager_utils.Future')

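  # Record the forward pass on a tape so that gradients of the transformed
  # loss can be computed against every variable the loss function touches.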
  with tf.GradientTape() as tape:
    loss_value = loss()
    total_loss_value = total_loss_fn(loss_value)
  if variables_to_train is None:
    variables_to_train = tape.watched_variables()
  elif callable(variables_to_train):
    variables_to_train = variables_to_train()
  variables_to_train = tf.nest.flatten(variables_to_train)
  grads = tape.gradient(total_loss_value, variables_to_train)
  # Materialize the pairs: in Python 3 `zip` returns a one-shot iterator, and
  # it may be consumed below (e.g. by gradient summaries) before
  # `apply_gradients` sees it.
  grads_and_vars = list(zip(grads, variables_to_train))

  if transform_grads_fn:
    grads_and_vars = transform_grads_fn(grads_and_vars)

  if summarize_gradients:
    with tf.name_scope('summarize_grads'):
      add_gradients_summaries(grads_and_vars, global_step)

  if check_numerics:
    with tf.name_scope('train_op'):
      tf.debugging.check_numerics(total_loss_value, 'Loss is inf or nan')

  optimizer.apply_gradients(grads_and_vars, global_step=global_step)

  return loss_value
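
A minimal eager-mode usage sketch follows. Everything named here except
create_train_step itself (the model, data, optimizer, and the clip threshold)
is a hypothetical stand-in chosen for illustration:

# Hypothetical usage of create_train_step in eager mode.
model = tf.keras.layers.Dense(1)                  # placeholder model
x = tf.random.normal([8, 4])                      # placeholder inputs
y = tf.random.normal([8, 1])                      # placeholder targets
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3)

def loss():
  # Must be callable in eager mode so the forward pass runs under the tape.
  return tf.reduce_mean(tf.square(model(x) - y))

def clip_grads(grads_and_vars):
  # Example transform_grads_fn: clip every gradient to a max norm of 1.0.
  return [(tf.clip_by_norm(g, 1.0), v) for g, v in grads_and_vars]

for _ in range(100):
  # Each call performs one optimization step and returns the loss value.
  loss_value = create_train_step(
      loss, optimizer, transform_grads_fn=clip_grads)
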
Example #2
# Excerpted from a session-based (TF1-style) TF-Agents training example; the
# imports below are assumptions reconstructed from the surrounding module.
# `utils` is taken to be gin's TF utilities (for GinConfigSaverHook), and
# `build_replay_buffer` is defined elsewhere in the same module.
import os
import time

from absl import logging
import tensorflow as tf

from gin.tf import utils
from tf_agents.drivers import dynamic_step_driver
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.utils import common

def train(
    root_dir,
    agent,
    environment,
    training_loops,
    steps_per_loop=1,
    additional_metrics=(),
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10,
    policy_checkpoint_interval=10,
    log_interval=10,
    summary_interval=10):
  """A training driver."""

  if not common.resource_variables_enabled():
    raise RuntimeError(common.MISSING_RESOURCE_VARIABLES_ERROR)

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(train_dir)
  train_summary_writer.set_as_default()

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(batch_size=environment.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            batch_size=environment.batch_size),
    ] + list(additional_metrics)

    # Add to replay buffer and other agent specific observers.
    replay_buffer = build_replay_buffer(agent, environment.batch_size,
                                        steps_per_loop)
    agent_observers = [replay_buffer.add_batch] + train_metrics

    driver = dynamic_step_driver.DynamicStepDriver(
        env=environment,
        policy=agent.policy,
        num_steps=steps_per_loop * environment.batch_size,
        observers=agent_observers)

    collect_op, _ = driver.run()
    batch_size = driver.env.batch_size
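    # Read back everything collected this loop as one deterministic batch;
    # `get_single_element` below assumes the dataset holds exactly one element.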
    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size,
        num_steps=steps_per_loop,
        single_deterministic_pass=True)
    trajectories, unused_info = tf.data.experimental.get_single_element(dataset)
    train_op = agent.train(experience=trajectories)
    clear_replay_op = replay_buffer.clear()

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        max_to_keep=1,
        agent=agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        max_to_keep=None,
        policy=agent.policy,
        global_step=global_step)

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=train_metrics[:2]))

    init_agent_op = agent.initialize()

    config_saver = utils.GinConfigSaverHook(train_dir, summarize_config=True)
    config_saver.begin()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      common.initialize_uninitialized_variables(sess)

      config_saver.after_create_session(sess)

      global_step_call = sess.make_callable(global_step)
      global_step_val = global_step_call()

      sess.run(train_summary_writer.init())
      sess.run(collect_op)

      if global_step_val == 0:
        # Save an initial checkpoint so the evaluator runs for global_step=0.
        policy_checkpointer.save(global_step=global_step_val)
        sess.run(init_agent_op)

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_ops])
      clear_replay_call = sess.make_callable(clear_replay_op)

      timed_at_step = global_step_val
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec',
          data=steps_per_second_ph,
          step=global_step)

      for _ in range(training_loops):
        # Collect and train.
        start_time = time.time()
        collect_call()
        total_loss, _ = train_step_call()
        clear_replay_call()
        global_step_val = global_step_call()

        time_acc += time.time() - start_time

        total_loss = total_loss.loss

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val, total_loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)
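
A hypothetical invocation sketch; `make_environment` and `make_agent` are
placeholders for whatever constructors the surrounding project provides, and
only train itself comes from the example above:

environment = make_environment()  # e.g. a TFPyEnvironment-wrapped bandit env
agent = make_agent(environment)   # an agent whose specs match the environment

train(
    root_dir='/tmp/bandit_train',
    agent=agent,
    environment=environment,
    training_loops=100,
    steps_per_loop=2,
    log_interval=10,
    summary_interval=10)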