def test_conflicting_specs(self):
  # If two different specs are encountered, an error should be thrown.
  self.other_data_spec = SimpleSpec(
      step_type=tf.TensorSpec(shape=(1), dtype=tf.int32, name="step_type"),
      value=tf.TensorSpec(shape=(1, 5), dtype=tf.float64, name="value"))
  self.other_dataset_path = os.path.join(
      self.get_temp_dir(), "other_test_tfrecord_dataset.tfrecord")
  example_encoding_dataset.encode_spec_to_file(self.other_dataset_path,
                                               self.other_data_spec)
  with self.assertRaises(IOError):
    example_encoding_dataset.load_tfrecord_dataset(
        [self.dataset_path, self.other_dataset_path])
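
The test above assumes a fixture that already defined self.dataset_path and a
baseline data spec. A minimal sketch of what such a setup could look like; the
class name, base spec shapes, and file name are assumptions rather than the
original test's values:

import collections
import os

import tensorflow as tf
from tf_agents.utils import example_encoding_dataset

SimpleSpec = collections.namedtuple("SimpleSpec", ["step_type", "value"])


class ExampleEncodingDatasetTest(tf.test.TestCase):

  def setUp(self):
    super(ExampleEncodingDatasetTest, self).setUp()
    # Baseline spec and path that test_conflicting_specs compares against.
    self.data_spec = SimpleSpec(
        step_type=tf.TensorSpec(shape=(1,), dtype=tf.int32, name="step_type"),
        value=tf.TensorSpec(shape=(1, 2), dtype=tf.float64, name="value"))
    self.dataset_path = os.path.join(self.get_temp_dir(),
                                     "test_tfrecord_dataset.tfrecord")
    example_encoding_dataset.encode_spec_to_file(self.dataset_path,
                                                 self.data_spec)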
Example #2
def create_dataset_for_given_tfrecord(
    record_path,
    load_buffer_size=None,
    shuffle_buffer_size=1,
    num_parallel_reads=100,
    decoder=None,
):
    """Creates TFDataset based on a given TFRecord file.

  TFRecords contain agent experience (as saved by the tf_agents endpoints).

  Args:
    record_path: path to TFRecord
    load_buffer_size: Buffer size
    shuffle_buffer_size: Size of shuffle
    num_parallel_reads: How many workers to use for loading
    decoder: decoder
  Returns:
    dataset: dataset
  """
    dataset = example_encoding_dataset.load_tfrecord_dataset(
        record_path,
        buffer_size=load_buffer_size,
        as_trajectories=True,
        as_experience=True,
        add_batch_dim=False,
        compress_image=True,
        num_parallel_reads=num_parallel_reads,
        decoder=decoder,
    )
    # Shuffle within the file (if requested) and prefetch for throughput.
    if shuffle_buffer_size > 1:
        dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset
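
A hedged usage sketch for create_dataset_for_given_tfrecord; the record path
and parameter values below are made up, and the file is expected to have been
written by a TFRecordObserver so that a matching spec can be found:

import tensorflow as tf

# Made-up path to previously collected agent experience.
record_path = "/tmp/agent_experience.tfrecord"

dataset = create_dataset_for_given_tfrecord(
    record_path,
    shuffle_buffer_size=1000,
    num_parallel_reads=4)

# Inspect the structure of one element; shapes depend on the stored spec.
for experience in dataset.take(1):
  print(tf.nest.map_structure(lambda t: t.shape, experience))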
def test_with_dynamic_step_driver(self):
    env = driver_test_utils.PyEnvironmentMock()
    tf_env = tf_py_environment.TFPyEnvironment(env)
    policy = driver_test_utils.TFPolicyMock(tf_env.time_step_spec(),
                                            tf_env.action_spec())

    trajectory_spec = trajectory.from_transition(tf_env.time_step_spec(),
                                                 policy.policy_step_spec,
                                                 tf_env.time_step_spec())

    tfrecord_observer = example_encoding_dataset.TFRecordObserver(
        self.dataset_path, trajectory_spec)
    driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        policy,
        observers=[common.function(tfrecord_observer)],
        num_steps=10)
    self.evaluate(tf.compat.v1.global_variables_initializer())

    time_step = self.evaluate(tf_env.reset())
    initial_policy_state = policy.get_initial_state(batch_size=1)
    self.evaluate(
        common.function(driver.run)(time_step, initial_policy_state))
    tfrecord_observer.flush()
    tfrecord_observer.close()

    dataset = example_encoding_dataset.load_tfrecord_dataset(
        [self.dataset_path], buffer_size=2, as_trajectories=True)
    iterator = eager_utils.dataset_iterator(dataset)
    sample = self.evaluate(eager_utils.get_next(iterator))
    self.assertIsInstance(sample, trajectory.Trajectory)
def test_with_py_driver(self):
    env = driver_test_utils.PyEnvironmentMock()
    policy = driver_test_utils.PyPolicyMock(env.time_step_spec(),
                                            env.action_spec())
    trajectory_spec = trajectory.from_transition(env.time_step_spec(),
                                                 policy.policy_step_spec,
                                                 env.time_step_spec())
    trajectory_spec = tensor_spec.from_spec(trajectory_spec)

    tfrecord_observer = example_encoding_dataset.TFRecordObserver(
        self.dataset_path, trajectory_spec, py_mode=True)

    driver = py_driver.PyDriver(env,
                                policy, [tfrecord_observer],
                                max_steps=10)
    time_step = env.reset()
    driver.run(time_step)
    tfrecord_observer.flush()
    tfrecord_observer.close()

    dataset = example_encoding_dataset.load_tfrecord_dataset(
        [self.dataset_path], buffer_size=2, as_trajectories=True)

    iterator = eager_utils.dataset_iterator(dataset)
    sample = self.evaluate(eager_utils.get_next(iterator))
    self.assertIsInstance(sample, trajectory.Trajectory)
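
Both driver tests above follow the same write-then-read pattern. Stripped of
the test harness, it looks roughly like this; the output path is made up,
trajectory_spec is assumed to be a tensor-spec Trajectory as built above, and
collected_trajectories stands in for any iterable of matching Trajectory items:

# Write: the observer is a callable that serializes each trajectory it is given.
observer = example_encoding_dataset.TFRecordObserver(
    "/tmp/collect.tfrecord", trajectory_spec, py_mode=True)
for traj in collected_trajectories:
  observer(traj)
observer.flush()
observer.close()

# Read: decoding uses the spec recorded with the data.
dataset = example_encoding_dataset.load_tfrecord_dataset(
    ["/tmp/collect.tfrecord"], buffer_size=2, as_trajectories=True)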
Example #5
def create_single_tf_record_dataset(
    filename: Text,
    load_buffer_size: int = 0,
    shuffle_buffer_size: int = 10000,
    num_parallel_reads: Optional[int] = None,
    decoder: Optional[DecoderFnType] = None,
    reward_shift: float = 0.0,
    action_clipping: Optional[Tuple[float, float]] = None,
    use_trajectories: bool = True,
):
    """Create a TF dataset for a single TFRecord file.

    Args:
      filename: Path to a single TFRecord file.
      load_buffer_size: Number of bytes in the read buffer. 0 means no
        buffering.
      shuffle_buffer_size: Size of the buffer for shuffling items within a
        single TFRecord file.
      num_parallel_reads: Optional, number of parallel reads in the TFRecord
        dataset. If not specified, no parallelism.
      decoder: Optional, a custom decoder to use rather than using the default
        spec path.
      reward_shift: Value to add to the reward.
      action_clipping: Optional, (minimum, maximum) values to clip actions.
      use_trajectories: Whether to use trajectories. If false, use transitions.

    Returns:
      A TF.Dataset of experiences.
    """
    dataset = example_encoding_dataset.load_tfrecord_dataset(
        filename,
        buffer_size=load_buffer_size,
        as_experience=use_trajectories,
        as_trajectories=use_trajectories,
        add_batch_dim=False,
        num_parallel_reads=num_parallel_reads,
        decoder=decoder,
    )

    def _sample_to_experience(sample):
        dummy_info = ()
        return updated_sample(sample, reward_shift, action_clipping,
                              use_trajectories), dummy_info

    dataset = dataset.map(_sample_to_experience,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset
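
A usage sketch for create_single_tf_record_dataset; the file name, clipping
range, and reward shift are illustrative values only:

dataset = create_single_tf_record_dataset(
    "/tmp/offline_experience.tfrecord",
    shuffle_buffer_size=1000,
    num_parallel_reads=4,
    reward_shift=-1.0,
    action_clipping=(-0.3, 0.3))

# Elements are (experience, dummy_info) pairs, as produced by
# _sample_to_experience above.
experience, _ = next(iter(dataset))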
def test_load_tfrecord_dataset(self):
  # Make sure an example tfrecord file exists before attempting to load.
  self.test_tfrecord_observer()
  example_encoding_dataset.load_tfrecord_dataset([self.dataset_path],
                                                 buffer_size=2)
  example_encoding_dataset.load_tfrecord_dataset([self.dataset_path],
                                                 buffer_size=2,
                                                 as_experience=True)
  with self.assertRaises(IOError):
    example_encoding_dataset.load_tfrecord_dataset(["fake_file.tfrecord"])
def load_trajs(dataset_path, batch_size, img_pad):
  """Load trajectory data from a fixed dataset, with image augmentation."""
  traj_dataset = example_encoding_dataset.load_tfrecord_dataset(
      [dataset_path],
      buffer_size=int(1e6),
      as_experience=False,
      as_trajectories=True,
      add_batch_dim=False)
  logging.info('Traj dataset loaded from %s', dataset_path)

  traj_dataset = traj_dataset.shuffle(10000).repeat()

  image_aug_fn = lambda img: image_aug(img, img_pad)
  traj_dataset = traj_dataset.map(
      image_aug_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return traj_dataset.batch(batch_size).prefetch(
      tf.data.experimental.AUTOTUNE)
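
A short usage sketch for load_trajs; the dataset path, batch size, and padding
amount are made-up values:

traj_dataset = load_trajs("/tmp/demo_trajectories.tfrecord",
                          batch_size=64,
                          img_pad=4)

# The dataset repeats forever, so bound any inspection loop explicitly.
for batch in traj_dataset.take(2):
  print(tf.nest.map_structure(lambda t: t.shape, batch))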
def load_episodes(dataset_path, img_pad, gamma=0.99):
  """Load episode data from a fixed dataset."""
  episode_dataset = example_encoding_dataset.load_tfrecord_dataset(
      [dataset_path],
      buffer_size=None,
      as_experience=False,
      add_batch_dim=False)
  logging.info('Episode dataset loaded from %s', dataset_path)

  if dataset_path.endswith('episodes'):
    process_episode_fn = lambda x, y: process_episode(x, y, gamma)
    episode_dataset = episode_dataset.map(
        process_episode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  episode_dataset = episode_dataset.shuffle(50).repeat()
  transform_episodes_fn = lambda x, y, z: transform_episodes(x, y, z, img_pad)
  episode_dataset = episode_dataset.map(transform_episodes_fn)
  episode_dataset = episode_dataset.prefetch(tf.data.experimental.AUTOTUNE)
  logging.info('Episode dataset processed.')
  return episode_dataset
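
And similarly for load_episodes; the path and discount below are assumptions.
Note that the extra process_episode map above only runs when the path ends with
'episodes', which the made-up path here does:

episode_dataset = load_episodes("/tmp/demo_episodes",
                                img_pad=4,
                                gamma=0.99)

# Element structure is whatever transform_episodes returns; like load_trajs,
# the dataset repeats indefinitely.
first_episode = next(iter(episode_dataset))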