Esempio n. 1
0
    def test_make_dataset_with_sequence_length_and_batch_size(self):
        """Element specs carry (batch_size, sequence_length) leading dims."""
        batch_size = 4
        sequence_length = 6
        leading_dims = (batch_size, sequence_length)

        env_spec = specs.make_environment_spec(fakes.ContinuousEnvironment())
        dataset = reverb_dataset.make_dataset(
            client=self.tf_client,
            environment_spec=env_spec,
            batch_size=batch_size,
            sequence_length=sequence_length,
        )

        # Prefix every leaf of the environment spec with the leading dims.
        prefixed = tree.map_structure(
            lambda leaf: tf.TensorSpec(
                shape=leading_dims + leaf.shape, dtype=leaf.dtype),
            env_spec,
        )

        # Re-pack the prefixed leaves into the Step layout the test compares
        # against; start_of_episode is a boolean spec with the same leading dims.
        expected = adders.Step(
            observation=prefixed.observations,
            action=prefixed.actions,
            reward=prefixed.rewards,
            discount=prefixed.discounts,
            start_of_episode=specs.Array(shape=leading_dims, dtype=bool),
            extras=(),
        )

        actual = dataset.element_spec.data
        self.assertTrue(_check_specs(expected, actual))
Esempio n. 2
0
    def test_make_dataset_nested_specs(self):
        """Checks the dataset spec against a Step built from a dict observation spec."""
        observation_spec = {
            'obs_1': specs.Array((3, 64, 64), 'uint8'),
            'obs_2': specs.Array((10, ), 'int32'),
        }
        action_spec = specs.BoundedArray(
            (), 'float32', minimum=-1., maximum=1.)
        reward_spec = specs.Array((), 'float32')
        discount_spec = specs.BoundedArray(
            (), 'float32', minimum=0., maximum=1.)
        env_spec = specs.EnvironmentSpec(
            observations=observation_spec,
            actions=action_spec,
            rewards=reward_spec,
            discounts=discount_spec,
        )

        dataset = reverb_dataset.make_dataset(
            client=self.tf_client,
            environment_spec=env_spec,
        )

        # Without batching or sequencing the expected Step reuses the
        # environment's own specs, plus a scalar boolean start_of_episode.
        expected = adders.Step(
            observation=env_spec.observations,
            action=env_spec.actions,
            reward=env_spec.rewards,
            discount=env_spec.discounts,
            start_of_episode=specs.Array(shape=(), dtype=bool),
            extras=(),
        )

        actual = dataset.element_spec.data
        self.assertTrue(_check_specs(expected, actual))
Esempio n. 3
0
    def test_make_dataset_nested_specs(self):
        """Checks the dataset spec against the environment spec flattened to a tuple."""
        observation_spec = {
            'obs_1': specs.Array((3, 64, 64), 'uint8'),
            'obs_2': specs.Array((10, ), 'int32'),
        }
        action_spec = specs.BoundedArray(
            (), 'float32', minimum=-1., maximum=1.)
        reward_spec = specs.Array((), 'float32')
        discount_spec = specs.BoundedArray(
            (), 'float32', minimum=0., maximum=1.)
        env_spec = specs.EnvironmentSpec(
            observations=observation_spec,
            actions=action_spec,
            rewards=reward_spec,
            discounts=discount_spec,
        )

        dataset = reverb_dataset.make_dataset(
            client=self.tf_client,
            environment_spec=env_spec,
        )

        # The expected element layout is the spec's fields in iteration order.
        expected = tuple(env_spec)
        actual = dataset.element_spec.data
        self.assertTrue(_check_specs(expected, actual))
Esempio n. 4
0
    def test_make_dataset_simple(self):
        """Default dataset spec matches the environment spec as a plain tuple."""
        env_spec = specs.make_environment_spec(fakes.ContinuousEnvironment())
        dataset = reverb_dataset.make_dataset(
            client=self.tf_client,
            environment_spec=env_spec,
        )

        expected = tuple(env_spec)
        actual = dataset.element_spec.data
        self.assertTrue(_check_specs(expected, actual))
Esempio n. 5
0
    def test_make_dataset_transition_adder(self):
        """With transition_adder=True the spec gains a trailing observation entry."""
        env_spec = specs.make_environment_spec(fakes.ContinuousEnvironment())
        dataset = reverb_dataset.make_dataset(
            client=self.tf_client,
            environment_spec=env_spec,
            transition_adder=True,
        )

        # Expected layout: the environment spec's fields followed by the
        # observation spec a second time.
        expected = (*env_spec, env_spec.observations)

        actual = dataset.element_spec.data
        self.assertTrue(_check_specs(expected, actual))
Esempio n. 6
0
    def test_make_dataset_simple(self):
        """Default dataset spec matches a Step built from the environment spec."""
        env_spec = specs.make_environment_spec(fakes.ContinuousEnvironment())
        dataset = reverb_dataset.make_dataset(
            client=self.tf_client,
            environment_spec=env_spec,
        )

        # Expected Step reuses the environment's specs and adds a scalar
        # boolean start_of_episode with no extras.
        expected = adders.Step(
            observation=env_spec.observations,
            action=env_spec.actions,
            reward=env_spec.rewards,
            discount=env_spec.discounts,
            start_of_episode=specs.Array(shape=(), dtype=bool),
            extras=(),
        )

        actual = dataset.element_spec.data
        self.assertTrue(_check_specs(expected, actual))
Esempio n. 7
0
  def test_make_dataset_with_variable_length_instances(self):
    """Zero-sized spec dims show up as None when conversion is requested."""
    observation_spec = specs.Array((0, 64, 64), 'uint8')
    action_spec = specs.BoundedArray((), 'float32', minimum=-1., maximum=1.)
    reward_spec = specs.Array((), 'float32')
    discount_spec = specs.BoundedArray((), 'float32', minimum=0., maximum=1.)
    env_spec = specs.EnvironmentSpec(
        observations=observation_spec,
        actions=action_spec,
        rewards=reward_spec,
        discounts=discount_spec,
    )

    dataset = reverb_dataset.make_dataset(
        server_address=self.server_address,
        environment_spec=env_spec,
        convert_zero_size_to_none=True,
    )

    # The first data element was declared with a leading 0 dimension; the
    # resulting shape should report None in that position.
    first_shape = dataset.element_spec.data[0].shape.as_list()
    self.assertSequenceEqual(first_shape, [None, 64, 64])
Esempio n. 8
0
    def test_make_dataset_with_batch_size(self):
        """Batched dataset specs gain an unspecified (None) leading dimension."""
        batch_size = 4
        env_spec = specs.make_environment_spec(fakes.ContinuousEnvironment())
        dataset = reverb_dataset.make_dataset(
            client=self.tf_client,
            environment_spec=env_spec,
            batch_size=batch_size,
        )

        # NOTE(review): the expected leading dim is None rather than batch_size;
        # presumably the final batch may be smaller -- confirm against
        # make_dataset.
        with_batch_dim = tree.map_structure(
            lambda leaf: tf.TensorSpec(
                shape=(None, ) + leaf.shape, dtype=leaf.dtype),
            env_spec,
        )

        expected = tuple(with_batch_dim)
        actual = dataset.element_spec.data
        self.assertTrue(_check_specs(expected, actual))
Esempio n. 9
0
    def test_make_dataset_transition_adder(self):
        """With transition_adder=True the dataset spec matches a Transition."""
        env_spec = specs.make_environment_spec(fakes.ContinuousEnvironment())
        dataset = reverb_dataset.make_dataset(
            server_address=self.server_address,
            environment_spec=env_spec,
            transition_adder=True,
        )

        # The observation spec is used for both the current and the next
        # observation; no extras are expected.
        expected = types.Transition(
            observation=env_spec.observations,
            action=env_spec.actions,
            reward=env_spec.rewards,
            discount=env_spec.discounts,
            next_observation=env_spec.observations,
            extras=(),
        )

        actual = dataset.element_spec.data
        self.assertTrue(_check_specs(expected, actual))