def testAverageReturnMultiMetricTimeMisalignment(self, run_mode,
                                                 num_trajectories, reward_spec,
                                                 expected_result):
  with run_mode():
    trajectories = self._create_misaligned_trajectories()
    multi_trajectories = []
    for traj in trajectories:
      # Duplicate the scalar reward so each trajectory carries two reward
      # entries, matching the multi-reward `reward_spec` under test.
      if isinstance(reward_spec, list):
        new_reward = [traj.reward, traj.reward]
      else:
        new_reward = tf.stack([traj.reward, traj.reward], axis=1)
      new_traj = trajectory.Trajectory(
          step_type=traj.step_type,
          observation=traj.observation,
          action=traj.action,
          policy_info=traj.policy_info,
          next_step_type=traj.next_step_type,
          reward=new_reward,
          discount=traj.discount)
      multi_trajectories.append(new_traj)
    metric = tf_metrics.AverageReturnMultiMetric(reward_spec, batch_size=2)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.evaluate(metric.init_variables())
    for i in range(num_trajectories):
      self.evaluate(metric(multi_trajectories[i]))
    self.assertAllEqual(expected_result, self.evaluate(metric.result()))
    # After reset, both reward entries should report a zero average return.
    self.evaluate(metric.reset())
    self.assertAllEqual([0.0, 0.0], self.evaluate(metric.result()))
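# The snippet below is a minimal usage sketch of `AverageReturnMultiMetric`
# outside the test harness, run eagerly on a single hand-built bandit-style
# trajectory (FIRST -> LAST in one step). The batch size, observation shape,
# and reward values are illustrative assumptions, not taken from the test.
import tensorflow as tf

from tf_agents.metrics import tf_metrics
from tf_agents.trajectories import time_step as ts
from tf_agents.trajectories import trajectory

batch_size = 2
# One spec covering two reward entries, mirroring the stacked reward above.
reward_spec = tf.TensorSpec([2], tf.float32, name='reward')
metric = tf_metrics.AverageReturnMultiMetric(reward_spec, batch_size=batch_size)

traj = trajectory.Trajectory(
    step_type=tf.fill([batch_size], ts.StepType.FIRST),
    observation=tf.zeros([batch_size, 3], dtype=tf.float32),
    action=tf.zeros([batch_size], dtype=tf.int32),
    policy_info=(),
    next_step_type=tf.fill([batch_size], ts.StepType.LAST),
    reward=tf.constant([[1.0, 2.0], [3.0, 4.0]]),
    discount=tf.ones([batch_size], dtype=tf.float32))

metric(traj)
print(metric.result())  # Average return per reward entry over the batch.
metric.reset()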
def train(root_dir,
          agent,
          environment,
          training_loops,
          steps_per_loop,
          additional_metrics=(),
          training_data_spec_transformation_fn=None):
  """Perform `training_loops` iterations of training.

  Checkpoint results.

  If one or more baseline_reward_fns are provided, the regret is computed
  against each one of them. Here is an example baseline_reward_fn:

  def baseline_reward_fn(observation, per_action_reward_fns):
    rewards = ...  # compute reward for each arm
    optimal_action_reward = ...  # take the maximum reward
    return optimal_action_reward

  Args:
    root_dir: path to the directory where checkpoints and metrics will be
      written.
    agent: an instance of `TFAgent`.
    environment: an instance of `TFEnvironment`.
    training_loops: an integer indicating how many training loops should be
      run.
    steps_per_loop: an integer indicating how many driver steps should be
      executed and presented to the trainer during each training loop.
    additional_metrics: Tuple of metric objects to log, in addition to default
      metrics `NumberOfEpisodes`, `AverageReturnMetric`, and
      `AverageEpisodeLengthMetric`.
    training_data_spec_transformation_fn: Optional function that transforms the
      data items before they get to the replay buffer.
  """

  # TODO(b/127641485): create evaluation loop with configurable metrics.
  if training_data_spec_transformation_fn is None:
    data_spec = agent.policy.trajectory_spec
  else:
    data_spec = training_data_spec_transformation_fn(
        agent.policy.trajectory_spec)
  replay_buffer = get_replay_buffer(data_spec, environment.batch_size,
                                    steps_per_loop)

  # `step_metric` records the number of individual rounds of bandit
  # interaction; that is, (number of trajectories) * batch_size.
  step_metric = tf_metrics.EnvironmentSteps()
  metrics = [
      tf_metrics.NumberOfEpisodes(),
      tf_metrics.AverageEpisodeLengthMetric(batch_size=environment.batch_size)
  ] + list(additional_metrics)

  if isinstance(environment.reward_spec(), dict):
    metrics += [
        tf_metrics.AverageReturnMultiMetric(
            reward_spec=environment.reward_spec(),
            batch_size=environment.batch_size)
    ]
  else:
    metrics += [
        tf_metrics.AverageReturnMetric(batch_size=environment.batch_size)
    ]

  if training_data_spec_transformation_fn is not None:
    add_batch_fn = lambda data: replay_buffer.add_batch(  # pylint: disable=g-long-lambda
        training_data_spec_transformation_fn(data))
  else:
    add_batch_fn = replay_buffer.add_batch

  observers = [add_batch_fn, step_metric] + metrics

  driver = dynamic_step_driver.DynamicStepDriver(
      env=environment,
      policy=agent.collect_policy,
      num_steps=steps_per_loop * environment.batch_size,
      observers=observers)

  training_loop = get_training_loop_fn(driver, replay_buffer, agent,
                                       steps_per_loop)
  checkpoint_manager = restore_and_get_checkpoint_manager(
      root_dir, agent, metrics, step_metric)
  saver = policy_saver.PolicySaver(agent.policy)

  summary_writer = tf.summary.create_file_writer(root_dir)
  summary_writer.set_as_default()
  for _ in range(training_loops):
    training_loop()
    metric_utils.log_metrics(metrics)
    for metric in metrics:
      metric.tf_summaries(train_step=step_metric.result())
    checkpoint_manager.save()
    saver.save(os.path.join(root_dir, 'policy_%d' % step_metric.result()))
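# The end-to-end sketch below wires `train` to a small stationary stochastic
# bandit environment and a LinearUCB agent, following the shapes used in the
# TF-Agents bandits tutorial. The constants, the `LinearReward` helper, the
# context sampler, and `root_dir` are illustrative assumptions, and `train`
# (defined above) is assumed to be in scope.
import functools

import numpy as np
import tensorflow as tf

from tf_agents.bandits.agents import lin_ucb_agent
from tf_agents.bandits.environments import stationary_stochastic_py_environment as sspe
from tf_agents.environments import tf_py_environment

BATCH_SIZE = 8    # assumed
CONTEXT_DIM = 7   # assumed
NUM_ACTIONS = 5   # assumed


def context_sampling_fn(batch_size):
  """Returns a batch of uniform random contexts."""
  return np.random.uniform(
      -1.0, 1.0, [batch_size, CONTEXT_DIM]).astype(np.float32)


class LinearReward(object):
  """Deterministic linear reward for a single arm (illustrative)."""

  def __init__(self, theta):
    self._theta = theta

  def __call__(self, context):
    return np.dot(context, self._theta)


reward_fns = [
    LinearReward(np.random.uniform(-1.0, 1.0, CONTEXT_DIM).astype(np.float32))
    for _ in range(NUM_ACTIONS)
]

environment = tf_py_environment.TFPyEnvironment(
    sspe.StationaryStochasticPyEnvironment(
        functools.partial(context_sampling_fn, BATCH_SIZE),
        reward_fns,
        batch_size=BATCH_SIZE))

agent = lin_ucb_agent.LinearUCBAgent(
    time_step_spec=environment.time_step_spec(),
    action_spec=environment.action_spec(),
    dtype=tf.float32)

# Runs a few training loops, checkpointing and saving a policy each loop.
train(
    root_dir='/tmp/linucb_example',  # hypothetical output directory
    agent=agent,
    environment=environment,
    training_loops=5,
    steps_per_loop=2)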