Example #1
    def test_endian_encodings(self):
        spec = {
            "a": array_spec.ArraySpec((2, ), np.int16),
            "b": array_spec.ArraySpec((2, ), np.int32),
            "c": array_spec.ArraySpec((2, ), np.float32),
        }

        serializer = example_encoding.get_example_serializer(spec)
        decoder = example_encoding.get_example_decoder(spec)

        # Little endian encoding.
        le_sample = {
            "a": np.array([100, 25000]).astype("<i2"),
            "b": np.array([-5, 80000000]).astype("<i4"),
            "c": np.array([12.5, np.pi]).astype("<f4")
        }

        example_proto = serializer(le_sample)
        recovered = self.evaluate(decoder(example_proto))
        tf.nest.map_structure(np.testing.assert_almost_equal, le_sample,
                              recovered)

        # Big endian encoding.
        be_sample = {
            "a": np.array([100, 25000]).astype(">i2"),
            "b": np.array([-5, 80000000]).astype(">i4"),
            "c": np.array([12.5, np.pi]).astype(">f4")
        }

        example_proto = serializer(be_sample)
        recovered = self.evaluate(decoder(example_proto))
        tf.nest.map_structure(np.testing.assert_almost_equal, be_sample,
                              recovered)
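
A minimal standalone sketch of the same round trip outside the test harness, assuming TF2 eager execution (the spec and sample values here are illustrative):

import numpy as np
import tensorflow as tf
from tf_agents.specs import array_spec
from tf_agents.utils import example_encoding

spec = {"a": array_spec.ArraySpec((2,), np.int16)}
serializer = example_encoding.get_example_serializer(spec)
decoder = example_encoding.get_example_decoder(spec)

# Serialize a little-endian sample; decoding yields tensors with the same values.
sample = {"a": np.array([100, 25000]).astype("<i2")}
recovered = decoder(serializer(sample))
tf.nest.map_structure(np.testing.assert_almost_equal, sample, recovered)
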
Example #2
def load_tfrecord_dataset(dataset_files,
                          buffer_size=1000,
                          as_experience=False,
                          as_trajectories=False):
    """Loads a TFRecord dataset from file, sequencing samples as Trajectories.

  Args:
    dataset_files: List of paths to one or more datasets.
    buffer_size: (int) Number of bytes in the read buffer. 0 means no buffering.
    as_experience: (bool) Returns the dataset as pairs of Trajectories. Samples
      will be shaped as if they had been pulled from a replay buffer with
      `num_steps=2`. These samples can be fed directly to an agent's `train`
      method.
    as_trajectories: (bool) Remaps the data into trajectory objects. This should
      be enabled when the resulting types must be trajectories as expected by
      agents.

  Returns:
    A dataset of type tf.data.Dataset. Samples follow the dataset's spec nested
    structure. Samples are generated with a leading batch dim of 1
    (or 2 if as_experience is enabled).

  Raises:
    IOError: One or more of the dataset files does not exist.
  """

    specs = []
    for dataset_file in dataset_files:
        spec_path = dataset_file + _SPEC_FILE_EXTENSION
        dataset_spec = parse_encoded_spec_from_file(spec_path)
        specs.append(dataset_spec)
        if not all([dataset_spec == spec for spec in specs]):
            raise IOError('One or more of the encoding specs do not match.')
    decoder = example_encoding.get_example_decoder(specs[0])
    logging.info('Loading TFRecord dataset...')
    dataset = tf.data.TFRecordDataset(dataset_files,
                                      buffer_size=buffer_size,
                                      num_parallel_reads=len(dataset_files))

    def decode_fn(proto):
        """Decodes a proto object."""
        return decoder(proto)

    def decode_and_batch_fn(proto):
        """Decodes a proto object, and batch output tensors."""
        sample = decoder(proto)
        return nest_utils.batch_nested_tensors(sample)

    if as_experience:
        dataset = dataset.map(decode_fn).batch(2)
    else:
        dataset = dataset.map(decode_and_batch_fn)

    if as_trajectories:
        as_trajectories_fn = lambda sample: trajectory.Trajectory(*sample)
        dataset = dataset.map(as_trajectories_fn)
    return dataset
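
A usage sketch for the function above (in TF-Agents it is exposed as tf_agents.utils.example_encoding_dataset.load_tfrecord_dataset); the dataset path below is hypothetical and is assumed to have a matching `.spec` file next to it:

import tensorflow as tf
from tf_agents.utils import example_encoding_dataset

# Hypothetical shard with a companion 'data.tfrecord.spec' file alongside it.
dataset_files = ['/tmp/collect/data.tfrecord']

dataset = example_encoding_dataset.load_tfrecord_dataset(
    dataset_files,
    buffer_size=1000,
    as_experience=True,    # adjacent pairs of steps, as if num_steps=2
    as_trajectories=True)  # remap decoded tuples into Trajectory objects

for experience in dataset.take(1):
    print(tf.nest.map_structure(lambda t: t.shape, experience))
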
Example #3
    def test_serialize_deserialize(self, dtype):
        spec = example_nested_spec(dtype)
        serializer = example_encoding.get_example_serializer(spec)
        decoder = example_encoding.get_example_decoder(spec)

        sample = array_spec.sample_spec_nest(spec, np.random.RandomState(0))
        example_proto = serializer(sample)

        recovered = self.evaluate(decoder(example_proto))
        tf.nest.map_structure(np.testing.assert_almost_equal, sample,
                              recovered)
Example #4
    def test_compress_image(self):
        if not common.has_eager_been_enabled():
            self.skipTest("Image compression only supported in TF2.x")

        gin.parse_config_files_and_bindings([], """
    _get_feature_encoder.compress_image=True
    _get_feature_parser.compress_image=True
    """)
        spec = {"image": array_spec.ArraySpec((128, 128, 3), np.uint8)}
        serializer = example_encoding.get_example_serializer(spec)
        decoder = example_encoding.get_example_decoder(spec)

        sample = {"image": 128 * np.ones([128, 128, 3], dtype=np.uint8)}
        example_proto = serializer(sample)

        recovered = self.evaluate(decoder(example_proto))
        tf.nest.map_structure(np.testing.assert_almost_equal, sample,
                              recovered)
Example #5
def create_tf_record_dataset(
    filenames: MutableSequence[Text],
    batch_size: int,
    shuffle_buffer_size_per_record: int = 100,
    shuffle_buffer_size: int = 100,
    load_buffer_size: int = 100000000,
    num_shards: int = 50,
    cycle_length: int = tf.data.experimental.AUTOTUNE,
    block_length: int = 10,
    num_parallel_reads: Optional[int] = None,
    num_parallel_calls: int = tf.data.experimental.AUTOTUNE,
    num_prefetch: int = 10,
    strategy: Optional[tf.distribute.Strategy] = None,
    reward_shift: float = 0.0,
    action_clipping: Optional[Tuple[float, float]] = None,
    use_trajectories: bool = True,
):
    """Create a TF dataset from a list of filenames.

  A dataset is created for each record file and these are interleaved together
  to create the final dataset.

  Args:
    filenames: List of filenames of a TFRecord dataset containing TF Examples.
    batch_size: The batch size of tensors in the returned dataset.
    shuffle_buffer_size_per_record: The buffer size used for shuffling within a
      Record file.
    shuffle_buffer_size: The shuffle buffer size for the interleaved dataset.
    load_buffer_size: Number of bytes in the read buffer. 0 means no buffering.
    num_shards: The number of shards, each consisting of 1 or more record file
      datasets, that are then interleaved together.
    cycle_length: The number of input elements processed concurrently while
      interleaving.
    block_length: The number of consecutive elements to produce from each input
      element before cycling to another input element.
    num_parallel_reads: Optional, number of parallel reads in the TFRecord
      dataset. If not specified, len(filenames) will be used.
    num_parallel_calls: Number of parallel calls for interleave.
    num_prefetch: Number of batches to prefetch.
    strategy: Optional, `tf.distribute.Strategy` being used in training.
    reward_shift: Value to add to the reward.
    action_clipping: Optional, (minimum, maximum) values to clip actions.
    use_trajectories: Whether to use trajectories. If false, use transitions.

  Returns:
    A tf.data.Dataset containing batches of nested Tensors.
  """
    initial_len = len(filenames)
    remainder = initial_len % num_shards
    for _ in range(num_shards - remainder):
        filenames.append(filenames[np.random.randint(low=0, high=initial_len)])
    filenames = np.array(filenames)
    np.random.shuffle(filenames)
    filenames = np.array_split(filenames, num_shards)

    record_file_ds = tf.data.Dataset.from_tensor_slices(filenames)
    record_file_ds = record_file_ds.repeat().shuffle(len(filenames))

    spec_path = filenames[0][0] + '.spec'
    record_spec = example_encoding_dataset.parse_encoded_spec_from_file(
        spec_path)
    decoder = example_encoding.get_example_decoder(record_spec)

    example_ds = record_file_ds.interleave(
        functools.partial(create_single_tf_record_dataset,
                          load_buffer_size=load_buffer_size,
                          shuffle_buffer_size=shuffle_buffer_size_per_record,
                          num_parallel_reads=num_parallel_reads,
                          decoder=decoder,
                          reward_shift=reward_shift,
                          action_clipping=action_clipping,
                          use_trajectories=use_trajectories),
        cycle_length=cycle_length,
        block_length=block_length,
        num_parallel_calls=num_parallel_calls,
    )
    example_ds = example_ds.shuffle(shuffle_buffer_size)

    use_tpu = isinstance(
        strategy,
        (tf.distribute.experimental.TPUStrategy, tf.distribute.TPUStrategy))
    example_ds = example_ds.batch(
        batch_size, drop_remainder=use_tpu).prefetch(num_prefetch)
    return example_ds
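
A sketch of how the function above might be called; the shard paths are hypothetical, each is assumed to have a companion '.spec' file, and create_single_tf_record_dataset is assumed to be defined in the same module:

# Hypothetical TFRecord shard paths, each with a companion '.spec' file.
filenames = ['/tmp/offline_data/shard-{:05d}.tfrecord'.format(i) for i in range(8)]

dataset = create_tf_record_dataset(
    filenames,
    batch_size=256,
    num_shards=4,                  # interleave four groups of record files
    reward_shift=0.0,
    action_clipping=(-1.0, 1.0),   # clip actions to [-1, 1]
    use_trajectories=False)        # yield transitions instead of trajectories
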
Example #6
    remainder = initial_len % num_shards
    for _ in range(num_shards - remainder):
        record_paths.append(record_paths[np.random.randint(low=0,
                                                           high=initial_len)])
    record_paths = np.array(record_paths)
    np.random.shuffle(record_paths)
    record_paths = np.array_split(record_paths, num_shards)

    record_file_ds = tf.data.Dataset.from_tensor_slices(record_paths)
    record_file_ds = record_file_ds.repeat().shuffle(len(record_paths))

    spec_path = (record_paths[0][0] +
                 example_encoding_dataset._SPEC_FILE_EXTENSION)  # pylint: disable=protected-access
    record_spec = example_encoding_dataset.parse_encoded_spec_from_file(
        spec_path)
    decoder = example_encoding.get_example_decoder(record_spec,
                                                   compress_image=True)

    example_ds = record_file_ds.interleave(
        partial(
            create_dataset_for_given_tfrecord,
            load_buffer_size=100000000,
            shuffle_buffer_size=shuffle_buffer_size_per_tfrecord,
            num_parallel_reads=1,
            decoder=decoder,
        ),
        cycle_length=10,
        block_length=block_length,
        num_parallel_calls=10,
    )
    example_ds = example_ds.shuffle(shuffle_buffer_size)
    example_ds = example_ds.batch(batch_size)
Example #7
def load_tfrecord_dataset(dataset_files,
                          buffer_size=1000,
                          as_experience=False,
                          as_trajectories=False,
                          add_batch_dim=True,
                          decoder=None,
                          num_parallel_reads=None,
                          compress_image=False,
                          spec=None):
    """Loads a TFRecord dataset from file, sequencing samples as Trajectories.

  Args:
    dataset_files: List of paths to one or more datasets.
    buffer_size: (int) Number of bytes in the read buffer. 0 means no buffering.
    as_experience: (bool) Returns the dataset as pairs of Trajectories. Samples
      will be shaped as if they had been pulled from a replay buffer with
      `num_steps=2`. These samples can be fed directly to an agent's `train`
      method.
    as_trajectories: (bool) Remaps the data into trajectory objects. This should
      be enabled when the resulting types must be trajectories as expected by
      agents.
    add_batch_dim: (bool) If True, the data will have a batch dim of 1 to
      conform with the expected tensor batch convention. Set to False if you
      want to batch the data yourself.
    decoder: Optional, a custom decoder to use rather than using the default
      spec path.
    num_parallel_reads: Optional, number of parallel reads in the TFRecord
      dataset. If not specified, len(dataset_files) will be used.
    compress_image: Whether to decompress image tensors. Any uint8 tensor of
      rank 3 with shape (w, h, c) is assumed to be an image; if it was
      compressed by the encoder, it must be decompressed here.
    spec: Optional, the dataspec of the TFRecord dataset to be parsed. If not
      provided, parses the dataspec of the TFRecord directly.

  Returns:
    A dataset of type tf.data.Dataset. Samples follow the dataset's spec nested
    structure. Samples are generated with a leading batch dim of 1
    (or 2 if as_experience is enabled).

  Raises:
    IOError: One or more of the dataset files does not exist.
  """

    if not decoder:
        if spec is None:
            specs = []
            for dataset_file in dataset_files:
                spec_path = dataset_file + _SPEC_FILE_EXTENSION
                dataset_spec = parse_encoded_spec_from_file(spec_path)
                specs.append(dataset_spec)
                if not all([dataset_spec == spec for spec in specs]):
                    raise IOError(
                        'One or more of the encoding specs do not match.')
            spec = specs[0]

        decoder = example_encoding.get_example_decoder(
            spec, compress_image=compress_image)

    logging.info('Loading TFRecord dataset...')
    if not num_parallel_reads:
        num_parallel_reads = len(dataset_files)
    dataset = tf.data.TFRecordDataset(dataset_files,
                                      buffer_size=buffer_size,
                                      num_parallel_reads=num_parallel_reads)

    def decode_fn(proto):
        """Decodes a proto object."""
        return decoder(proto)

    def decode_and_batch_fn(proto):
        """Decodes a proto object, and batch output tensors."""
        sample = decoder(proto)
        return nest_utils.batch_nested_tensors(sample)

    if as_experience:
        dataset = dataset.map(decode_fn).batch(2, drop_remainder=True)
    elif add_batch_dim:
        dataset = dataset.map(decode_and_batch_fn)
    else:
        dataset = dataset.map(decode_fn)

    if as_trajectories:
        as_trajectories_fn = lambda sample: trajectory.Trajectory(*sample)
        dataset = dataset.map(as_trajectories_fn)
    return dataset
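
A sketch of the extra options in this version, continuing in the same module as the function above: the spec is supplied up front and the leading batch dim is skipped so batching can be done manually (the path is hypothetical):

# Hypothetical dataset file with a companion spec file next to it.
dataset_files = ['/tmp/eval/data.tfrecord']
spec = parse_encoded_spec_from_file(dataset_files[0] + _SPEC_FILE_EXTENSION)

dataset = load_tfrecord_dataset(
    dataset_files,
    spec=spec,             # skip reading the spec from every dataset file
    compress_image=True,   # decompress uint8 (w, h, c) image tensors
    add_batch_dim=False)   # decode raw samples; batch explicitly below
dataset = dataset.shuffle(1000).batch(64)
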
Example #8
def load_tfrecord_dataset_sequence(path_to_shards,
                                   buffer_size_per_shard=100,
                                   seq_len=1,
                                   deterministic=False,
                                   compress_image=True,
                                   for_rnn=False,
                                   check_all_specs=False):
    """A version of load_tfrecord_dataset that returns fixed length sequences.

  Note that sequences are padded on the first frame to reach seq_len, so a
  sequence of [0, 1, 2] with seq_len = 2 will produce samples [0, 1], [1, 2],
  [0, 0], [0, 1], [1, 2], etc.

  Args:
    path_to_shards: Path to TFRecord shards.
    buffer_size_per_shard: per-shard TFRecordReader buffer size.
    seq_len: fixed length output sequence.
    deterministic: If True, maintain deterministic sampling of shards (typically
      for testing).
    compress_image: Whether to decompress image tensors. Any uint8 tensor of
      rank 3 with shape (w, h, c) is assumed to be an image; if it was
      compressed by the encoder, it must be decompressed here.
    for_rnn: If True, sequences are filtered with filter_episodes_rnn() instead
      of filter_episodes().
    check_all_specs: If True, check that the spec of every shard matches;
      otherwise only the first shard's spec is read.

  Returns:
    tf.data.Dataset object.
  """
    specs = []
    check_shards = path_to_shards if check_all_specs else path_to_shards[:1]
    for dataset_file in check_shards:
        spec_path = dataset_file + example_encoding_dataset._SPEC_FILE_EXTENSION  # pylint: disable=protected-access
        dataset_spec = example_encoding_dataset.parse_encoded_spec_from_file(
            spec_path)
        specs.append(dataset_spec)
        if not all([dataset_spec == spec for spec in specs]):
            raise ValueError('One or more of the encoding specs do not match.')
    decoder = example_encoding.get_example_decoder(
        specs[0], batched=True, compress_image=compress_image)

    # Note: window cannot be called on TFRecordDataset(shards) directly as it
    # interleaves samples across the shards. Instead, we'll sample windows on
    # shards independently using interleave.
    def interleave_func(shard):
        dataset = tf.data.TFRecordDataset(
            shard, buffer_size=buffer_size_per_shard).cache().repeat()
        dataset = dataset.window(seq_len,
                                 shift=1,
                                 stride=1,
                                 drop_remainder=True)
        return dataset.flat_map(
            lambda window: window.batch(seq_len, drop_remainder=True))

    dataset = tf.data.Dataset.from_tensor_slices(path_to_shards).repeat()
    num_parallel_calls = None if deterministic else len(path_to_shards)
    dataset = dataset.interleave(interleave_func,
                                 deterministic=deterministic,
                                 cycle_length=len(path_to_shards),
                                 block_length=1,
                                 num_parallel_calls=num_parallel_calls)

    # flat_map doesn't work with Dict[str, tf.Tensor], so for now decode after
    # the window sample (this causes unnecessary decode of protos).
    # TODO(tompson): It would be more efficient to decode before window.
    dataset = dataset.map(decoder, num_parallel_calls=num_parallel_calls)

    # We now have decoded sequences, each sample containing adjacent frames
    # within a single shard. However, the window may span multiple episodes, so
    # we need to filter these.

    if for_rnn:
        return dataset.map(filter_episodes_rnn,
                           num_parallel_calls=num_parallel_calls)
    else:
        dataset = dataset.map(filter_episodes,
                              num_parallel_calls=num_parallel_calls)

        # Set observation shape.
        def set_shape_obs(traj):
            def set_elem_shape(obs):
                obs_shape = obs.get_shape()
                return tf.ensure_shape(obs, [seq_len] + obs_shape[1:])

            observation = tf.nest.map_structure(set_elem_shape,
                                                traj.observation)
            return traj._replace(observation=observation)

        dataset = dataset.map(set_shape_obs,
                              num_parallel_calls=num_parallel_calls)
        return dataset
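
A sketch of calling the sequence loader above; the shard paths are hypothetical and each is assumed to have a companion spec file written by example_encoding_dataset:

import tensorflow as tf

# Hypothetical episode shards, each with a matching spec file alongside it.
path_to_shards = ['/tmp/episodes/shard-00000.tfrecord',
                  '/tmp/episodes/shard-00001.tfrecord']

dataset = load_tfrecord_dataset_sequence(
    path_to_shards,
    seq_len=4,             # each sample holds 4 adjacent frames from one shard
    deterministic=True,    # keep shard sampling deterministic (e.g. for tests)
    check_all_specs=True)  # verify every shard's spec matches

dataset = dataset.batch(32).prefetch(tf.data.experimental.AUTOTUNE)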