def test_make_dataset_with_sequence_length_and_batch_size(self):
    """Check the element spec when both batch and sequence sizes are given.

    Builds a dataset with `batch_size=4` and `sequence_length=6`, then
    asserts that `_check_specs` accepts an `adders.Step` whose fields are
    the environment specs prefixed with `(batch_size, sequence_length)`.
    """
    batch_size, sequence_length = 4, 6
    leading_dims = (batch_size, sequence_length)

    env_spec = specs.make_environment_spec(fakes.ContinuousEnvironment())
    dataset = reverb_dataset.make_dataset(
        client=self.tf_client,
        environment_spec=env_spec,
        batch_size=batch_size,
        sequence_length=sequence_length,
    )

    def _with_leading_dims(spec):
        # Prefix every leaf spec's shape with the batch and sequence sizes.
        return tf.TensorSpec(
            shape=leading_dims + spec.shape,
            dtype=spec.dtype,
        )

    batched = tree.map_structure(_with_leading_dims, env_spec)
    expected_spec = adders.Step(
        observation=batched.observations,
        action=batched.actions,
        reward=batched.rewards,
        discount=batched.discounts,
        start_of_episode=specs.Array(shape=leading_dims, dtype=bool),
        extras=(),
    )

    specs_match = _check_specs(expected_spec, dataset.element_spec.data)
    self.assertTrue(specs_match)
def test_make_dataset_nested_specs(self):
    """Check the element spec for an environment with dict observations.

    The observation spec is a dict of two arrays; the expected element
    data is an `adders.Step` built from the environment spec fields plus
    a scalar boolean `start_of_episode` and empty `extras`.
    """
    # NOTE(review): another method with this exact name is defined in this
    # file; if both live in the same class, the later definition replaces
    # the earlier one — confirm which is intended.
    nested_observations = {
        'obs_1': specs.Array((3, 64, 64), 'uint8'),
        'obs_2': specs.Array((10, ), 'int32'),
    }
    env_spec = specs.EnvironmentSpec(
        observations=nested_observations,
        actions=specs.BoundedArray((), 'float32', minimum=-1., maximum=1.),
        rewards=specs.Array((), 'float32'),
        discounts=specs.BoundedArray((), 'float32', minimum=0., maximum=1.),
    )

    dataset = reverb_dataset.make_dataset(
        client=self.tf_client,
        environment_spec=env_spec,
    )

    expected_spec = adders.Step(
        observation=env_spec.observations,
        action=env_spec.actions,
        reward=env_spec.rewards,
        discount=env_spec.discounts,
        start_of_episode=specs.Array(shape=(), dtype=bool),
        extras=(),
    )

    specs_match = _check_specs(expected_spec, dataset.element_spec.data)
    self.assertTrue(specs_match)
def test_make_dataset_nested_specs(self):
    """Check the element spec for an environment with dict observations.

    The observation spec is a dict of two arrays; the expected element
    data is the environment spec converted to a plain tuple.
    """
    # NOTE(review): another method with this exact name is defined in this
    # file; if both live in the same class, the later definition replaces
    # the earlier one — confirm which is intended.
    nested_observations = {
        'obs_1': specs.Array((3, 64, 64), 'uint8'),
        'obs_2': specs.Array((10, ), 'int32'),
    }
    env_spec = specs.EnvironmentSpec(
        observations=nested_observations,
        actions=specs.BoundedArray((), 'float32', minimum=-1., maximum=1.),
        rewards=specs.Array((), 'float32'),
        discounts=specs.BoundedArray((), 'float32', minimum=0., maximum=1.),
    )

    dataset = reverb_dataset.make_dataset(
        client=self.tf_client,
        environment_spec=env_spec,
    )

    expected_spec = tuple(env_spec)
    specs_match = _check_specs(expected_spec, dataset.element_spec.data)
    self.assertTrue(specs_match)
def test_make_dataset_simple(self):
    """Check the element spec of a dataset built with default options.

    Uses the spec of a fake continuous environment and expects the element
    data to be that spec converted to a plain tuple.
    """
    # NOTE(review): another method with this exact name is defined in this
    # file; if both live in the same class, the later definition replaces
    # the earlier one — confirm which is intended.
    env_spec = specs.make_environment_spec(fakes.ContinuousEnvironment())

    dataset = reverb_dataset.make_dataset(
        client=self.tf_client,
        environment_spec=env_spec,
    )

    expected_spec = tuple(env_spec)
    specs_match = _check_specs(expected_spec, dataset.element_spec.data)
    self.assertTrue(specs_match)
def test_make_dataset_transition_adder(self):
    """Check the element spec when `transition_adder=True` is requested.

    The expected element data is the environment spec fields, in order,
    followed by the observation spec a second time.
    """
    # NOTE(review): another method with this exact name is defined in this
    # file; if both live in the same class, the later definition replaces
    # the earlier one — confirm which is intended.
    env_spec = specs.make_environment_spec(fakes.ContinuousEnvironment())

    dataset = reverb_dataset.make_dataset(
        client=self.tf_client,
        environment_spec=env_spec,
        transition_adder=True,
    )

    # Append the observation spec once more after the regular spec fields.
    expected_spec = tuple(env_spec) + (env_spec.observations, )

    specs_match = _check_specs(expected_spec, dataset.element_spec.data)
    self.assertTrue(specs_match)
def test_make_dataset_simple(self):
    """Check the element spec of a dataset built with default options.

    Uses the spec of a fake continuous environment and expects the element
    data to be an `adders.Step` built from that spec's fields plus a scalar
    boolean `start_of_episode` and empty `extras`.
    """
    # NOTE(review): another method with this exact name is defined in this
    # file; if both live in the same class, the later definition replaces
    # the earlier one — confirm which is intended.
    env_spec = specs.make_environment_spec(fakes.ContinuousEnvironment())

    dataset = reverb_dataset.make_dataset(
        client=self.tf_client,
        environment_spec=env_spec,
    )

    expected_spec = adders.Step(
        observation=env_spec.observations,
        action=env_spec.actions,
        reward=env_spec.rewards,
        discount=env_spec.discounts,
        start_of_episode=specs.Array(shape=(), dtype=bool),
        extras=(),
    )

    specs_match = _check_specs(expected_spec, dataset.element_spec.data)
    self.assertTrue(specs_match)
def test_make_dataset_with_variable_length_instances(self):
    """Zero-sized spec dimensions should surface as `None` in the dataset.

    The observation spec has a leading dimension of 0. With
    `convert_zero_size_to_none=True`, the first entry of the element data
    is expected to have shape `[None, 64, 64]`.
    """
    env_spec = specs.EnvironmentSpec(
        observations=specs.Array((0, 64, 64), 'uint8'),
        actions=specs.BoundedArray((), 'float32', minimum=-1., maximum=1.),
        rewards=specs.Array((), 'float32'),
        discounts=specs.BoundedArray((), 'float32', minimum=0., maximum=1.),
    )

    dataset = reverb_dataset.make_dataset(
        server_address=self.server_address,
        environment_spec=env_spec,
        convert_zero_size_to_none=True,
    )

    first_entry_shape = dataset.element_spec.data[0].shape.as_list()
    self.assertSequenceEqual(first_entry_shape, [None, 64, 64])
def test_make_dataset_with_batch_size(self):
    """Check the element spec when only `batch_size` is given.

    Each expected leaf spec is the environment's spec with a leading
    `None` dimension prepended to its shape.
    """
    batch_size = 4
    env_spec = specs.make_environment_spec(fakes.ContinuousEnvironment())

    dataset = reverb_dataset.make_dataset(
        client=self.tf_client,
        environment_spec=env_spec,
        batch_size=batch_size,
    )

    def _with_unknown_batch_dim(spec):
        # The expected leading dimension is None, not the batch size.
        return tf.TensorSpec(shape=(None, ) + spec.shape, dtype=spec.dtype)

    batched = tree.map_structure(_with_unknown_batch_dim, env_spec)
    expected_spec = tuple(batched)

    specs_match = _check_specs(expected_spec, dataset.element_spec.data)
    self.assertTrue(specs_match)
def test_make_dataset_transition_adder(self):
    """Check the element spec when `transition_adder=True` is requested.

    The expected element data is a `types.Transition` built from the
    environment spec fields, with the observation spec reused for
    `next_observation` and empty `extras`.
    """
    # NOTE(review): another method with this exact name is defined in this
    # file; if both live in the same class, the later definition replaces
    # the earlier one — confirm which is intended.
    env_spec = specs.make_environment_spec(fakes.ContinuousEnvironment())

    dataset = reverb_dataset.make_dataset(
        server_address=self.server_address,
        environment_spec=env_spec,
        transition_adder=True,
    )

    expected_spec = types.Transition(
        observation=env_spec.observations,
        action=env_spec.actions,
        reward=env_spec.rewards,
        discount=env_spec.discounts,
        next_observation=env_spec.observations,
        extras=(),
    )

    specs_match = _check_specs(expected_spec, dataset.element_spec.data)
    self.assertTrue(specs_match)