Example #1
  def testAverageReturnMultiMetricTimeMisalignment(
      self, run_mode, num_trajectories, reward_spec, expected_result):
    with run_mode():
      trajectories = self._create_misaligned_trajectories()
      multi_trajectories = []
      for traj in trajectories:
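        # Duplicate the per-step reward so it matches the two-element reward
        # spec, either as a list of tensors or as a stacked tensor,
        # depending on the spec type.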
        if isinstance(reward_spec, list):
          new_reward = [traj.reward, traj.reward]
        else:
          new_reward = tf.stack([traj.reward, traj.reward], axis=1)
        new_traj = trajectory.Trajectory(
            step_type=traj.step_type,
            observation=traj.observation,
            action=traj.action,
            policy_info=traj.policy_info,
            next_step_type=traj.next_step_type,
            reward=new_reward,
            discount=traj.discount)
        multi_trajectories.append(new_traj)

      metric = tf_metrics.AverageReturnMultiMetric(reward_spec, batch_size=2)
      self.evaluate(tf.compat.v1.global_variables_initializer())
      self.evaluate(metric.init_variables())
      for i in range(num_trajectories):
        self.evaluate(metric(multi_trajectories[i]))

      self.assertAllEqual(expected_result, self.evaluate(metric.result()))
      self.evaluate(metric.reset())
      self.assertAllEqual([0.0, 0.0], self.evaluate(metric.result()))
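
A minimal eager-mode sketch of `AverageReturnMultiMetric` used outside the test harness. The reward spec, the placeholder observation/action tensors, and the expected numbers are illustrative assumptions, not values from the original test.

import tensorflow as tf

from tf_agents.metrics import tf_metrics
from tf_agents.trajectories import time_step as ts
from tf_agents.trajectories import trajectory

# A two-element reward spec, analogous to the duplicated reward above.
reward_spec = tf.TensorSpec(shape=(2,), dtype=tf.float32, name='reward')
metric = tf_metrics.AverageReturnMultiMetric(reward_spec, batch_size=2)

# One single-step episode per batch element (FIRST -> LAST). The observation
# and action are placeholders; the metric only reads step types and rewards.
traj = trajectory.Trajectory(
    step_type=tf.fill([2], ts.StepType.FIRST),
    observation=tf.zeros([2, 1]),
    action=tf.zeros([2], dtype=tf.int32),
    policy_info=(),
    next_step_type=tf.fill([2], ts.StepType.LAST),
    reward=tf.constant([[3.0, 5.0], [1.0, 2.0]]),
    discount=tf.ones([2]))

metric(traj)
# Averages the completed-episode returns across the batch; should be
# roughly [2.0, 3.5] here.
print(metric.result().numpy())
metric.reset()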
Example #2
def train(root_dir,
          agent,
          environment,
          training_loops,
          steps_per_loop,
          additional_metrics=(),
          training_data_spec_transformation_fn=None):
    """Perform `training_loops` iterations of training.

  Checkpoint results.

  If one or more baseline_reward_fns are provided, the regret is computed
  against each one of them. Here is example baseline_reward_fn:

  def baseline_reward_fn(observation, per_action_reward_fns):
   rewards = ... # compute reward for each arm
   optimal_action_reward = ... # take the maximum reward
   return optimal_action_reward

  Args:
    root_dir: path to the directory where checkpoints and metrics will be
      written.
    agent: an instance of `TFAgent`.
    environment: an instance of `TFEnvironment`.
    training_loops: an integer indicating how many training loops should be run.
    steps_per_loop: an integer indicating how many driver steps should be
      executed and presented to the trainer during each training loop.
    additional_metrics: Tuple of metric objects to log, in addition to default
      metrics `NumberOfEpisodes`, `AverageReturnMetric`, and
      `AverageEpisodeLengthMetric`.
    training_data_spec_transformation_fn: Optional function that transforms the
    data items before they get to the replay buffer.
  """

    # TODO(b/127641485): create evaluation loop with configurable metrics.
    if training_data_spec_transformation_fn is None:
        data_spec = agent.policy.trajectory_spec
    else:
        data_spec = training_data_spec_transformation_fn(
            agent.policy.trajectory_spec)
    replay_buffer = get_replay_buffer(data_spec, environment.batch_size,
                                      steps_per_loop)

    # `step_metric` records the number of individual rounds of bandit interaction;
    # that is, (number of trajectories) * batch_size.
    step_metric = tf_metrics.EnvironmentSteps()
    metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.AverageEpisodeLengthMetric(
            batch_size=environment.batch_size)
    ] + list(additional_metrics)

    if isinstance(environment.reward_spec(), dict):
        metrics += [
            tf_metrics.AverageReturnMultiMetric(
                reward_spec=environment.reward_spec(),
                batch_size=environment.batch_size)
        ]
    else:
        metrics += [
            tf_metrics.AverageReturnMetric(batch_size=environment.batch_size)
        ]

    if training_data_spec_transformation_fn is not None:
        add_batch_fn = lambda data: replay_buffer.add_batch(  # pylint: disable=g-long-lambda
            training_data_spec_transformation_fn(data))
    else:
        add_batch_fn = replay_buffer.add_batch

    observers = [add_batch_fn, step_metric] + metrics

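    # Each driver run collects steps_per_loop * batch_size environment steps,
    # i.e. steps_per_loop rounds of the batched environment (cf. `step_metric`
    # above).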
    driver = dynamic_step_driver.DynamicStepDriver(
        env=environment,
        policy=agent.collect_policy,
        num_steps=steps_per_loop * environment.batch_size,
        observers=observers)

    training_loop = get_training_loop_fn(driver, replay_buffer, agent,
                                         steps_per_loop)
    checkpoint_manager = restore_and_get_checkpoint_manager(
        root_dir, agent, metrics, step_metric)
    saver = policy_saver.PolicySaver(agent.policy)

    summary_writer = tf.summary.create_file_writer(root_dir)
    summary_writer.set_as_default()
    for _ in range(training_loops):
        training_loop()
        metric_utils.log_metrics(metrics)
        for metric in metrics:
            metric.tf_summaries(train_step=step_metric.result())
        checkpoint_manager.save()
        saver.save(os.path.join(root_dir, 'policy_%d' % step_metric.result()))
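
For context, here is a minimal sketch of a `training_data_spec_transformation_fn` that could be passed to `train()` above. The name `drop_policy_info` and the commented invocation are hypothetical; note that `train()` applies the function both to `agent.policy.trajectory_spec` and to every collected batch, so it must accept specs and trajectories alike.

# Hypothetical transformation fn: clears policy_info before data reaches the
# replay buffer. Because `train()` also applies it to the trajectory spec,
# the replay buffer is built with a matching spec.
def drop_policy_info(traj_or_spec):
  return traj_or_spec.replace(policy_info=())

# Hypothetical invocation (agent and environment construction omitted):
# train(root_dir='/tmp/bandit_training', agent=my_agent, environment=my_env,
#       training_loops=100, steps_per_loop=2,
#       training_data_spec_transformation_fn=drop_policy_info)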