Code example #1
    def test_with_dynamic_step_driver(self):
        env = driver_test_utils.PyEnvironmentMock()
        tf_env = tf_py_environment.TFPyEnvironment(env)
        policy = driver_test_utils.TFPolicyMock(tf_env.time_step_spec(),
                                                tf_env.action_spec())

        trajectory_spec = trajectory.from_transition(tf_env.time_step_spec(),
                                                     policy.policy_step_spec,
                                                     tf_env.time_step_spec())

        # Observer that serializes each collected trajectory to a TFRecord file.
        tfrecord_observer = example_encoding_dataset.TFRecordObserver(
            self.dataset_path, trajectory_spec)
        driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            policy,
            observers=[common.function(tfrecord_observer)],
            num_steps=10)
        self.evaluate(tf.compat.v1.global_variables_initializer())

        time_step = self.evaluate(tf_env.reset())
        initial_policy_state = policy.get_initial_state(batch_size=1)
        self.evaluate(
            common.function(driver.run)(time_step, initial_policy_state))
        tfrecord_observer.flush()
        tfrecord_observer.close()

        dataset = example_encoding_dataset.load_tfrecord_dataset(
            [self.dataset_path], buffer_size=2, as_trajectories=True)
        iterator = eager_utils.dataset_iterator(dataset)
        sample = self.evaluate(eager_utils.get_next(iterator))
        self.assertIsInstance(sample, trajectory.Trajectory)
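
The dataset returned by load_tfrecord_dataset is a regular tf.data.Dataset, so it can be batched and prefetched before being fed to a training loop. A minimal sketch of that extra step (the batch size is illustrative and not part of the original test):

        # Hypothetical continuation: batch and prefetch the decoded trajectories.
        dataset = example_encoding_dataset.load_tfrecord_dataset(
            [self.dataset_path], buffer_size=2, as_trajectories=True)
        batched = dataset.batch(32).prefetch(tf.data.experimental.AUTOTUNE)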
Code example #2
    def test_with_py_driver(self):
        env = driver_test_utils.PyEnvironmentMock()
        policy = driver_test_utils.PyPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        trajectory_spec = trajectory.from_transition(env.time_step_spec(),
                                                     policy.policy_step_spec,
                                                     env.time_step_spec())
        trajectory_spec = tensor_spec.from_spec(trajectory_spec)

        # py_mode=True because this observer is driven by a PyDriver on
        # numpy data rather than from inside a TF graph.
        tfrecord_observer = example_encoding_dataset.TFRecordObserver(
            self.dataset_path, trajectory_spec, py_mode=True)

        driver = py_driver.PyDriver(env,
                                    policy, [tfrecord_observer],
                                    max_steps=10)
        time_step = env.reset()
        driver.run(time_step)
        tfrecord_observer.flush()
        tfrecord_observer.close()

        dataset = example_encoding_dataset.load_tfrecord_dataset(
            [self.dataset_path], buffer_size=2, as_trajectories=True)

        iterator = eager_utils.dataset_iterator(dataset)
        sample = self.evaluate(eager_utils.get_next(iterator))
        self.assertIsInstance(sample, trajectory.Trajectory)
Code example #3
    def test_tfrecord_observer(self):
        tfrecord_observer = example_encoding_dataset.TFRecordObserver(
            self.dataset_path, self.simple_data_spec)
        # Draw a random sample from the simple spec
        sample = tensor_spec.sample_spec_nest(self.simple_data_spec,
                                              np.random.RandomState(0),
                                              outer_dims=(1, ))
        # Write to file using __call__() function
        for _ in range(3):
            tfrecord_observer(sample)
        # Manually flush
        tfrecord_observer.flush()
        tfrecord_observer.close()
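
The observer in example #3 only writes; the records can be read back with the same helper shown in examples #1 and #2. A minimal read-back sketch, reusing the fixture names above and passing as_trajectories=False because simple_data_spec is not a Trajectory spec (the exact flag value is an assumption):

        # Hypothetical read-back of the records written above.
        dataset = example_encoding_dataset.load_tfrecord_dataset(
            [self.dataset_path], buffer_size=2, as_trajectories=False)
        iterator = eager_utils.dataset_iterator(dataset)
        sample = self.evaluate(eager_utils.get_next(iterator))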
Code example #4
def collect_and_save_data(env_name,
                          model_dir,
                          trial_suffix,
                          max_episode_len,
                          root_dir,
                          total_episodes,
                          episodes_per_seed):

  saved_model_dir = utils.get_expanded_dir(
      model_dir, env_name, trial_suffix)
  saved_model_dir = os.path.join(saved_model_dir, 'policies/greedy_policy')
  max_train_step = 500 * max_episode_len  # 500k / action_repeat
  policy = utils.load_policy(saved_model_dir, max_train_step)
  trajectory_spec, episode_spec = utils.create_tensor_specs(
      policy.collect_data_spec, max_episode_len)
  episode2_spec = process_data.get_episode_spec(
      trajectory_spec, max_episode_len)

  root_dir = utils.get_expanded_dir(
      root_dir, env_name, trial_suffix, check=False)
  tf_episode_observer = example_encoding_dataset.TFRecordObserver(
      os.path.join(root_dir, 'episodes'), episode_spec, py_mode=True)
  tf_episode2_observer = example_encoding_dataset.TFRecordObserver(
      os.path.join(root_dir, 'episodes2'), episode2_spec, py_mode=True)

  num_seeds = total_episodes // episodes_per_seed
  max_steps = (max_episode_len + 1) * episodes_per_seed
  for seed in range(INIT_DATA_SEED, INIT_DATA_SEED + num_seeds):
    logging.info('Collection and saving for seed %d ..', seed)
    episodes, paired_episodes = utils.collect_pair_episodes(
        policy, env_name, random_seed=seed, max_steps=max_steps,
        max_episodes=episodes_per_seed)
    for episode_tuple in zip(episodes, paired_episodes):
      tf_episode_observer.write(episode_tuple)
      # Write (obs1, obs2, metric) tuples
      processed_episode_tuple = process_data.process_episode(
          episode_tuple[0], episode_tuple[1], gamma=process_data.GAMMA)
      tf_episode2_observer.write(processed_episode_tuple)
  tf_episode_observer.close()
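
A hypothetical invocation of collect_and_save_data; every argument value below is a placeholder for illustration and is not taken from the original project:

collect_and_save_data(
    env_name='cartpole-swingup',    # placeholder environment name
    model_dir='/tmp/saved_models',  # placeholder directory of trained policies
    trial_suffix='1',
    max_episode_len=1000,
    root_dir='/tmp/episode_data',   # placeholder output directory
    total_episodes=1000,
    episodes_per_seed=10)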
Code example #5
File: file_utils.py Project: morgandu/agents
def write_samples_to_tfrecord(dataset_dict: Dict[str, types.Array],
                              collect_data_spec: trajectory.Transition,
                              dataset_path: str,
                              start_episode: int,
                              end_episode: int,
                              use_trajectories: bool = True) -> None:
    """Creates and writes samples to a TFRecord file."""
    tfrecord_observer = example_encoding_dataset.TFRecordObserver(
        dataset_path, collect_data_spec, py_mode=True)
    states = dataset_dict['states']
    actions = dataset_dict['actions']
    discounts = dataset_dict['discounts']
    rewards = dataset_dict['rewards']
    num_episodes = len(dataset_dict['episode_start_index'])

    for episode_i in range(start_episode, end_episode):
        episode_start_index = dataset_dict['episode_start_index'][episode_i]
        # If this is the last episode, end at the final step.
        if episode_i == (num_episodes - 1):
            episode_end_index = len(states)
        else:
            # Otherwise, end before the next episode.
            episode_end_index = dataset_dict['episode_start_index'][
                episode_i + 1]

        for step_i in range(int(episode_start_index), int(episode_end_index)):
            # Set step type.
            if step_i == episode_end_index - 1:
                step_type = ts.StepType.LAST
            elif step_i == episode_start_index:
                step_type = ts.StepType.FIRST
            else:
                step_type = ts.StepType.MID

            # Set next state.
            # If at the last step in the episode, create a dummy next step.
            if step_type == ts.StepType.LAST:
                next_state = np.zeros_like(states[step_i])
                next_step_type = ts.StepType.FIRST
            else:
                next_state = states[step_i + 1]
                next_step_type = (
                    ts.StepType.LAST if step_i == episode_end_index - 2
                    else ts.StepType.MID)

            if use_trajectories:
                sample = create_trajectory(state=states[step_i],
                                           action=actions[step_i],
                                           discount=discounts[step_i],
                                           reward=rewards[step_i],
                                           step_type=step_type,
                                           next_step_type=next_step_type)
            else:
                sample = create_transition(state=states[step_i],
                                           action=actions[step_i],
                                           next_state=next_state,
                                           discount=discounts[step_i],
                                           reward=rewards[step_i],
                                           step_type=step_type,
                                           next_step_type=next_step_type)
            tfrecord_observer(sample)

    tfrecord_observer.close()
    logging.info('Wrote episodes [%d-%d] to %s', start_episode, end_episode,
                 dataset_path)
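
A hypothetical sharding loop built on write_samples_to_tfrecord. It assumes the episode data has already been loaded into a dict with the keys the function reads ('states', 'actions', 'discounts', 'rewards', 'episode_start_index'), and that a matching collect_data_spec is available elsewhere (for example, the collect_data_spec of a loaded policy as in example #4). The .npz path and shard size are placeholders:

# Hypothetical usage: write the dataset in shards of 100 episodes each.
# `collect_data_spec` is assumed to be defined elsewhere, e.g. policy.collect_data_spec.
dataset_dict = dict(np.load('/tmp/episodes.npz', allow_pickle=True))
num_episodes = len(dataset_dict['episode_start_index'])
episodes_per_shard = 100
for shard_i, start in enumerate(range(0, num_episodes, episodes_per_shard)):
    end = min(start + episodes_per_shard, num_episodes)
    write_samples_to_tfrecord(
        dataset_dict=dataset_dict,
        collect_data_spec=collect_data_spec,
        dataset_path='/tmp/shards/shard_%d.tfrecord' % shard_i,
        start_episode=start,
        end_episode=end,
        use_trajectories=True)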