def test_make_dataset_nested_specs(self):
    """Checks make_dataset's element spec for an env with dict observations.

    This variant compares the dataset data spec against the environment
    spec's fields flattened into a plain tuple.
    """
    # NOTE(review): a second test with this exact name exists alongside this
    # one; if both live in the same TestCase, the later definition shadows
    # this one and it never runs.
    observation_specs = {
        'obs_1': specs.Array((3, 64, 64), 'uint8'),
        'obs_2': specs.Array((10,), 'int32'),
    }
    action_spec = specs.BoundedArray((), 'float32', minimum=-1., maximum=1.)
    reward_spec = specs.Array((), 'float32')
    discount_spec = specs.BoundedArray((), 'float32', minimum=0., maximum=1.)
    environment_spec = specs.EnvironmentSpec(
        observations=observation_specs,
        actions=action_spec,
        rewards=reward_spec,
        discounts=discount_spec,
    )

    dataset = reverb_dataset.make_dataset(
        client=self.tf_client, environment_spec=environment_spec)

    expected = tuple(environment_spec)
    self.assertTrue(_check_specs(expected, dataset.element_spec.data))
def _numeric_to_spec(x: Union[float, int, np.ndarray]):
    """Builds a `specs.Array` describing a numeric value.

    Args:
      x: a NumPy array, a NumPy numeric/bool scalar (e.g. `np.float32(1.)`,
        `np.int64(3)`), or a Python `float`/`int`.

    Returns:
      A `specs.Array` with the value's shape and dtype. Python scalars get
      shape `()` and their Python type as the dtype.

    Raises:
      ValueError: if `x` is not one of the supported numeric types.
    """
    # NumPy numeric scalars expose `.shape == ()` and `.dtype` just like
    # arrays. Without this branch, types such as np.float32 / np.int64 fell
    # through to the error because they do not subclass Python float/int.
    if isinstance(x, (np.ndarray, np.number, np.bool_)):
        return specs.Array(shape=x.shape, dtype=x.dtype)
    if isinstance(x, (float, int)):
        return specs.Array(shape=(), dtype=type(x))
    raise ValueError(f'Unsupported numeric: {type(x)}')
def test_pmap_update_nested(self):
    """Pmapped updates on nested inputs give zero-mean, unit-std outputs."""
    n_devices = jax.local_device_count()
    state = running_statistics.init_state({
        'a': specs.Array((5,), jnp.float32),
        'b': specs.Array((2,), jnp.float32),
    })
    # Leading axis is split across devices; each device gets 3 rows.
    a_values = jnp.arange(15 * n_devices, dtype=jnp.float32)
    b_values = jnp.arange(6 * n_devices, dtype=jnp.float32)
    x = {
        'a': a_values.reshape(n_devices, 3, 5),
        'b': b_values.reshape(n_devices, 3, 2),
    }

    state = jax.device_put_replicated(state, jax.local_devices())
    axis_name = 'i'
    pmapped_update = jax.pmap(
        functools.partial(update_and_validate, pmap_axis_name=axis_name),
        axis_name)
    # Apply the same update twice.
    for _ in range(2):
        state = pmapped_update(state, x)

    normalized = jax.pmap(running_statistics.normalize)(x, state)
    means = tree.map_structure(lambda v: jnp.mean(v, axis=(0, 1)), normalized)
    stds = tree.map_structure(lambda v: jnp.std(v, axis=(0, 1)), normalized)
    tree.map_structure(
        lambda v: self.assert_allclose(v, jnp.zeros_like(v)), means)
    tree.map_structure(
        lambda v: self.assert_allclose(v, jnp.ones_like(v)), stds)
def test_nested_normalize(self):
    """Statistics accumulated over nested batches normalize the last batch."""
    state = running_statistics.init_state({
        'a': specs.Array((5,), jnp.float32),
        'b': specs.Array((2,), jnp.float32),
    })
    batches = [
        {
            'a': jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5),
            'b': jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2),
        },
        {
            'a': jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5) + 20,
            'b': jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2) + 8,
        },
        {
            'a': jnp.arange(40, dtype=jnp.float32).reshape(4, 2, 5),
            'b': jnp.arange(16, dtype=jnp.float32).reshape(4, 2, 2),
        },
    ]
    for batch in batches:
        state = update_and_validate(state, batch)

    # Only the final batch is normalized against the accumulated state.
    normalized = running_statistics.normalize(batches[-1], state)
    means = tree.map_structure(lambda v: jnp.mean(v, axis=(0, 1)), normalized)
    stds = tree.map_structure(lambda v: jnp.std(v, axis=(0, 1)), normalized)
    tree.map_structure(
        lambda v: self.assert_allclose(v, jnp.zeros_like(v)), means)
    tree.map_structure(
        lambda v: self.assert_allclose(v, jnp.ones_like(v)), stds)
def test_make_dataset_nested_specs(self):
    """make_dataset yields `adders.Step` data for a dict-observation env."""
    environment_spec = specs.EnvironmentSpec(
        observations={
            'obs_1': specs.Array((3, 64, 64), 'uint8'),
            'obs_2': specs.Array((10,), 'int32'),
        },
        actions=specs.BoundedArray((), 'float32', minimum=-1., maximum=1.),
        rewards=specs.Array((), 'float32'),
        discounts=specs.BoundedArray((), 'float32', minimum=0., maximum=1.),
    )
    dataset = reverb_dataset.make_dataset(
        client=self.tf_client, environment_spec=environment_spec)

    # Expected: the env spec fields packed into a Step, with a scalar bool
    # start-of-episode flag and empty extras.
    step_spec = adders.Step(
        observation=environment_spec.observations,
        action=environment_spec.actions,
        reward=environment_spec.rewards,
        discount=environment_spec.discounts,
        start_of_episode=specs.Array(shape=(), dtype=bool),
        extras=(),
    )
    self.assertTrue(_check_specs(step_spec, dataset.element_spec.data))
def _make_fake_env() -> dm_env.Environment:
    """Returns a fake environment whose episodes last 10 steps.

    Observations are (10, 5) float32 arrays, actions are discrete with 3
    values, rewards are float32 scalars and discounts are float32 scalars
    bounded to [0, 1].
    """
    observation_spec = specs.Array(shape=(10, 5), dtype=np.float32)
    action_spec = specs.DiscreteArray(num_values=3)
    reward_spec = specs.Array(shape=(), dtype=np.float32)
    discount_spec = specs.BoundedArray(
        shape=(), dtype=np.float32, minimum=0., maximum=1.)
    env_spec = specs.EnvironmentSpec(
        observations=observation_spec,
        actions=action_spec,
        rewards=reward_spec,
        discounts=discount_spec,
    )
    return fakes.Environment(env_spec, episode_length=10)
def observation_spec(self) -> types.Observation:
    """Returns a mapping from each agent to an OLT observation spec.

    Every agent in `self.possible_agents` gets a spec built the same way: the
    wrapped environment's `"info_state"` and `"legal_actions"` entries are
    used as the shapes of float32 arrays, and `terminal` is a float32 array
    of shape (1,).
    """

    def _agent_spec():
        # Queried once per agent, as before.
        env_spec = self._environment.observation_spec()
        return types.OLT(
            observation=specs.Array(env_spec["info_state"], np.float32),
            legal_actions=specs.Array(env_spec["legal_actions"], np.float32),
            terminal=specs.Array((1,), np.float32),
        )

    return {agent: _agent_spec() for agent in self.possible_agents}
def test_different_structure_normalize(self):
    """Updating with a structure that differs from the spec raises TypeError."""
    state = running_statistics.init_state(
        TestNestedSpec(
            a=specs.Array((5,), jnp.float32),
            b=specs.Array((2,), jnp.float32),
        ))
    # Same field names as the spec, but packed in a dict rather than a
    # TestNestedSpec.
    mismatched = {
        'a': jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5),
        'b': jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2),
    }
    with self.assertRaises(TypeError):
        update_and_validate(state, mismatched)
def test_normalize_config(self):
    """Only the paths listed in NestStatisticsConfig are normalized."""
    x = jnp.arange(200, dtype=jnp.float32).reshape(20, 2, 5)
    y = jnp.arange(160, dtype=jnp.float32).reshape(20, 2, 4)
    z = {'a': x, 'b': y}
    z_split = [
        {'a': x_part, 'b': y_part}
        for x_part, y_part in zip(
            jnp.split(x, 5, axis=0), jnp.split(y, 5, axis=0))
    ]

    update = jax.jit(running_statistics.update, static_argnames=('config',))
    # Only the ('a',) path is configured; 'b' is expected to pass through.
    config = running_statistics.NestStatisticsConfig((('a',),))
    state = running_statistics.init_state({
        'a': specs.Array((5,), jnp.float32),
        'b': specs.Array((4,), jnp.float32),
    })
    # The first chunk exercises initialization from the first element.
    for chunk in z_split:
        state = update(state, chunk, config=config)

    normalized = jax.jit(running_statistics.normalize)(z, state)
    for key in normalized:
        mean = jnp.mean(normalized[key], axis=(0, 1))
        std = jnp.std(normalized[key], axis=(0, 1))
        if key != 'a':
            assert key == 'b'
            np.testing.assert_array_equal(
                normalized[key],
                z[key],
                err_msg=f'z:{z[key]} normalized:{normalized[key]}')
            continue
        self.assert_allclose(
            mean,
            jnp.zeros_like(mean),
            err_msg=f'key:{key} mean:{mean} normalized:{normalized[key]}')
        self.assert_allclose(
            std,
            jnp.ones_like(std),
            err_msg=f'key:{key} std:{std} normalized:{normalized[key]}')
def test_multiple_inputs_and_outputs(self):
    """A 3-input function with a 2-tuple output gets a 2-tuple output spec."""

    def transformation(aa, bb, cc):
        all_three = tf.concat([aa, bb, cc], axis=-1)
        last_two = tf.concat([bb, cc], axis=-1)
        return all_three, last_two

    model = tf2_utils.to_sonnet_module(transformation)
    dtype = np.float32
    input_spec = [specs.Array(shape=(size,), dtype=dtype) for size in (2, 3, 4)]
    expected_output_spec = (
        tf.TensorSpec(shape=(9,), dtype=dtype),
        tf.TensorSpec(shape=(7,), dtype=dtype),
    )

    output_spec = tf2_utils.create_variables(model, input_spec)
    # The wrapped function has no variables of its own.
    self.assertEqual(model.variables, ())
    self.assertEqual(output_spec, expected_output_spec)
def test_make_dataset_with_variable_length_instances(self):
    """Zero-sized spec dims become None in the dataset's element shapes."""
    # With convert_zero_size_to_none=True, the 0-sized leading dim is
    # expected to show up as None.
    observation_spec = specs.Array((0, 64, 64), 'uint8')
    environment_spec = specs.EnvironmentSpec(
        observations=observation_spec,
        actions=specs.BoundedArray((), 'float32', minimum=-1., maximum=1.),
        rewards=specs.Array((), 'float32'),
        discounts=specs.BoundedArray((), 'float32', minimum=0., maximum=1.),
    )
    dataset = reverb_dataset.make_dataset(
        server_address=self.server_address,
        environment_spec=environment_spec,
        convert_zero_size_to_none=True,
    )
    observation_shape = dataset.element_spec.data[0].shape.as_list()
    self.assertSequenceEqual(observation_shape, [None, 64, 64])
def test_none_output(self):
    """A module returning None yields a None output spec and no variables."""
    input_spec = specs.Array(shape=(10,), dtype=np.float32)
    model = tf2_utils.to_sonnet_module(lambda x: None)
    output_spec = tf2_utils.create_variables(model, [input_spec])
    self.assertEqual(model.variables, ())
    self.assertEqual(output_spec, None)
def test_scalar_output(self):
    """Summing a (10,) input yields a scalar float32 output spec."""
    input_spec = specs.Array(shape=(10,), dtype=np.float32)
    model = tf2_utils.to_sonnet_module(tf.reduce_sum)
    output_spec = tf2_utils.create_variables(model, [input_spec])
    self.assertEqual(model.variables, ())
    self.assertEqual(output_spec, tf.TensorSpec(shape=(), dtype=tf.float32))
def test_rnn_snapshot(self):
    """Snapshotter saves/restores RNNs, including ones wrapped after creation."""
    net = snt.LSTM(10)
    spec = specs.Array([10], dtype=np.float32)
    tf2_utils.create_variables(net, [spec])
    # Adding postprocessing without rerunning create_variables should still
    # produce a saveable, restorable network.
    wrapped_net = snt.DeepRNN([net, lambda x: x])

    def unroll_one_step(module, inputs):
        """Runs one step from the initial state; returns outputs, state, grads."""
        with tf.GradientTape() as tape:
            outputs, next_state = module(inputs, module.initial_state(1))
            loss = tf.math.reduce_sum(outputs)
        grads = tape.gradient(loss, module.trainable_variables)
        return outputs, next_state, grads

    for original in (net, wrapped_net):
        snapshotter = tf2_savers.Snapshotter(
            {'net': original}, directory=self.get_tempdir())
        snapshotter.save()
        restored = tf.saved_model.load(
            os.path.join(snapshotter.directory, 'net'))

        inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))
        outputs1, state1, grads1 = unroll_one_step(original, inputs)
        outputs2, state2, grads2 = unroll_one_step(restored, inputs)

        assert np.allclose(outputs1, outputs2)
        assert np.allclose(tree.flatten(state1), tree.flatten(state2))
        assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))
def __init__(self,
             *,
             num_actions: int = 1,
             num_observations: int = 1,
             action_dtype=np.int32,
             obs_dtype=np.int32,
             obs_shape: Sequence[int] = (),
             discount_spec: Optional[types.NestedSpec] = None,
             reward_spec: Optional[types.NestedSpec] = None,
             **kwargs):
    """Initializes a fake environment with discrete actions and observations.

    Args:
      num_actions: number of values of the `DiscreteArray` action spec.
      num_observations: observations are bounded to
        [0, num_observations - 1], expressed in `obs_dtype`.
      action_dtype: dtype of the action spec.
      obs_dtype: dtype of the observation spec; also called to build the
        bounds, so it must be callable on an int (e.g. `np.int32`).
      obs_shape: shape of the observation spec.
      discount_spec: defaults to a float32 scalar bounded to [0, 1].
      reward_spec: defaults to an unbounded float32 scalar.
      **kwargs: forwarded to the parent constructor.
    """
    reward_spec = (
        specs.Array((), np.float32) if reward_spec is None else reward_spec)
    discount_spec = (
        specs.BoundedArray((), np.float32, 0.0, 1.0)
        if discount_spec is None else discount_spec)

    actions = specs.DiscreteArray(num_actions, dtype=action_dtype)
    observations = specs.BoundedArray(
        shape=obs_shape,
        dtype=obs_dtype,
        minimum=obs_dtype(0),
        maximum=obs_dtype(num_observations - 1),
    )
    environment_spec = specs.EnvironmentSpec(
        observations=observations,
        actions=actions,
        rewards=reward_spec,
        discounts=discount_spec,
    )
    super().__init__(spec=environment_spec, **kwargs)
def signature(cls,
              environment_spec: specs.EnvironmentSpec,
              extras_spec: types.NestedSpec = ()):
    """Builds a Reverb table signature describing a single `Step`.

    Signatures let Reverb validate the dtypes and shapes of inserted data;
    see Reverb's documentation for how they are used.

    Args:
      environment_spec: A `specs.EnvironmentSpec` whose fields are nested
        structures with leaf nodes that have `.shape` and `.dtype` attributes.
        It should come from the environment that generates the data inserted
        into the Reverb table.
      extras_spec: A nested structure with leaf nodes that have `.shape` and
        `.dtype` attributes. Its structure (and shapes/dtypes) must match the
        `extras` passed into `ReverbAdder.add`.

    Returns:
      A `Step` whose leaf nodes are `tf.TensorSpec` objects.
    """
    # Besides the environment fields, every step carries a scalar boolean
    # start-of-episode flag.
    start_of_episode_spec = specs.Array(shape=(), dtype=bool)
    step_spec = Step(
        observation=environment_spec.observations,
        action=environment_spec.actions,
        reward=environment_spec.rewards,
        discount=environment_spec.discounts,
        start_of_episode=start_of_episode_spec,
        extras=extras_spec,
    )
    return tree.map_structure_with_path(spec_like_to_tensor_spec, step_spec)
def test_make_dataset_with_sequence_length_and_batch_size(self):
    """Batched sequence datasets prepend (batch, time) dims to every spec."""
    sequence_length = 6
    batch_size = 4
    environment_spec = specs.make_environment_spec(
        fakes.ContinuousEnvironment())
    dataset = reverb_dataset.make_dataset(
        client=self.tf_client,
        environment_spec=environment_spec,
        batch_size=batch_size,
        sequence_length=sequence_length,
    )

    leading_dims = (batch_size, sequence_length)

    def make_tensor_spec(spec):
        return tf.TensorSpec(shape=leading_dims + spec.shape, dtype=spec.dtype)

    batched = tree.map_structure(make_tensor_spec, environment_spec)
    expected_spec = adders.Step(
        observation=batched.observations,
        action=batched.actions,
        reward=batched.rewards,
        discount=batched.discounts,
        start_of_episode=specs.Array(shape=leading_dims, dtype=bool),
        extras=(),
    )
    self.assertTrue(_check_specs(expected_spec, dataset.element_spec.data))
def test_snapshot_distribution(self):
    """Snapshotter round-trips a network whose output is a distribution."""
    net1 = snt.Sequential([
        networks.LayerNormMLP([10, 10]),
        networks.MultivariateNormalDiagHead(1),
    ])
    spec = specs.Array([10], dtype=np.float32)
    tf2_utils.create_variables(net1, [spec])

    snapshotter = tf2_savers.Snapshotter(
        {'net': net1}, directory=self.get_tempdir())
    snapshotter.save()
    net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net'))

    inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

    def gradients_of(module):
        """Gradient of sum(mean + variance) w.r.t. the module's variables."""
        with tf.GradientTape() as tape:
            dist = module(inputs)
            loss = tf.math.reduce_sum(dist.mean() + dist.variance())
        return tape.gradient(loss, module.trainable_variables)

    grads1 = gradients_of(net1)
    grads2 = gradients_of(net2)
    assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))
def __init__(
    self,
    rank: int,
    config_factory: List[Callable[[], WobConfig]],
    keep_pristine: bool,
    verbose_level: Optional[int] = 0,
):
    """Sets up the environment and derives observation specs from a reset.

    Args:
      rank: stored as `self.rank`.
      config_factory: zero-argument callables producing `WobConfig`s; every
        entry must be callable (checked with an assert).
      keep_pristine: stored as `self._keep_pristine`.
      verbose_level: stored as `self._verbose_level`.
    """
    self.rank = rank

    # Task info (for a multi-task env this is set again in reset_task()).
    self._configs = config_factory
    self._keep_pristine = keep_pristine
    assert all(callable(factory) for factory in self._configs)

    # Sample at least one valid task config (default seed = 0).
    self.reset_task(task_index=0)

    # Episode bookkeeping; set before reset() below in case it reads them.
    self.env_done = False
    self.step_count = 0
    # TODO: Do we need action mask?
    self.action_mask = False
    self._verbose_level = verbose_level

    # One spec per observation key, shaped/typed after an initial reset.
    first_observation = self.reset()
    self._observation_specs = {
        key: specs.Array(shape=value.shape, dtype=value.dtype, name=key)
        for key, value in first_observation.items()
    }
def test_init_normalize(self):
    """Normalizing with a freshly initialized state leaves inputs unchanged."""
    x = jnp.arange(200, dtype=jnp.float32).reshape(20, 2, 5)
    fresh_state = running_statistics.init_state(specs.Array((5,), jnp.float32))
    self.assert_allclose(running_statistics.normalize(x, fresh_state), x)
def setUp(self):
    """Creates placeholder params and a float-typed environment spec."""
    super().setUp()
    self.state_dims = 8
    self.action_dims = 4
    self.params = {
        name: jnp.ones((3,)) for name in ('world', 'policy', 'value')
    }
    self.env_spec = specs.EnvironmentSpec(
        observations=specs.Array(shape=(self.state_dims,), dtype=float),
        actions=specs.Array(shape=(self.action_dims,), dtype=float),
        rewards=specs.Array(shape=(1,), dtype=float, name='reward'),
        discounts=specs.BoundedArray(
            shape=(),
            dtype=float,
            minimum=0.,
            maximum=1.,
            name='discount',
        ),
    )
def observation_spec(self) -> types.NestedSpec:
    """Returns the wrapped observation spec plus a `trial_remaining_steps` entry.

    The extra entry is a scalar int32 spec named 'trial_remaining_steps'.

    The wrapped environment's spec mapping is shallow-copied before the entry
    is added. Updating it in place would permanently alter the inner
    environment's spec whenever its `observation_spec()` returns a cached
    object.
    """
    import copy  # Local import: only needed to copy the wrapped spec.

    observation_spec = copy.copy(self._environment.observation_spec())
    observation_spec.update({
        'trial_remaining_steps': specs.Array(
            shape=(), dtype=np.int32, name='trial_remaining_steps')
    })
    return observation_spec
def test_feedforward(self, recurrent: bool):
    """Linear(42) on 10-d input creates [42] and [10, 42] variables.

    The same holds when the layer is wrapped in a DeepRNN.
    """
    linear = snt.Linear(42)
    model = snt.DeepRNN([linear]) if recurrent else linear
    tf2_utils.create_variables(
        model, [specs.Array(shape=(10,), dtype=np.float32)])
    variables: Sequence[tf.Variable] = model.variables
    self.assertSequenceEqual(
        [v.shape.as_list() for v in variables], [[42], [10, 42]])
def test_int_not_normalized(self):
    """Integer-typed inputs come out of normalize unchanged."""
    x = jnp.arange(5, dtype=jnp.int32)
    state = running_statistics.init_state(specs.Array((), jnp.int32))
    state = update_and_validate(state, x)
    np.testing.assert_array_equal(running_statistics.normalize(x, state), x)
def _broadcast_specs(*args: acme_specs.Array) -> acme_specs.Array:
    """Like np.broadcast, but for specs.Array.

    Args:
      *args: one or more specs.Array instances.

    Returns:
      A specs.Array whose shape is the broadcast of all input shapes and whose
      dtype is `np.result_type` of all input dtypes.

    Raises:
      ValueError: if the shapes cannot be broadcast together, or if no specs
        are given.
    """
    # Broadcast the shapes directly rather than materializing a full value per
    # spec (generate_value). This avoids allocating arrays only to read their
    # shape, and avoids np.broadcast's cap on the number of operands.
    shape = np.broadcast_shapes(*(spec.shape for spec in args))
    dtype = np.result_type(*(spec.dtype for spec in args))
    return acme_specs.Array(shape=shape, dtype=dtype)
def test_validation(self):
    """update_and_validate rejects inputs that don't fit a (1, 2, 3) spec."""
    state = running_statistics.init_state(specs.Array((1, 2, 3), jnp.float32))
    bad_inputs = (
        jnp.arange(12, dtype=jnp.float32).reshape(2, 2, 3),
        jnp.arange(3, dtype=jnp.float32).reshape(1, 1, 3),
    )
    for x in bad_inputs:
        with self.assertRaises(AssertionError):
            update_and_validate(state, x)
def test_output_spec_feedforward(self, recurrent: bool):
    """Linear(42) gives a (42,) float32 output spec.

    When wrapped in a DeepRNN, the output spec is paired with ().
    """
    input_spec = specs.Array(shape=(10,), dtype=np.float32)
    linear = snt.Linear(42)
    output = tf.TensorSpec(shape=(42,), dtype=tf.float32)
    if recurrent:
        model, expected_spec = snt.DeepRNN([linear]), (output, ())
    else:
        model, expected_spec = linear, output
    self.assertEqual(
        tf2_utils.create_variables(model, [input_spec]), expected_spec)
def test_multiple_outputs(self):
    """PolicyValueHead(42) yields (42,) and scalar output specs.

    It also creates four variables: [42], [10, 42], [1] and [10, 1].
    """
    model = PolicyValueHead(42)
    output_spec = tf2_utils.create_variables(
        model, [specs.Array(shape=(10,), dtype=np.float32)])

    variables: Sequence[tf.Variable] = model.variables
    self.assertSequenceEqual(
        [v.shape.as_list() for v in variables],
        [[42], [10, 42], [1], [10, 1]])
    self.assertSequenceEqual(
        output_spec,
        (tf.TensorSpec(shape=(42,), dtype=tf.float32),
         tf.TensorSpec(shape=(), dtype=tf.float32)))
def observation_spec(self) -> types.Observation:
    """Returns one OLT spec per agent in `self._possible_agents`.

    The observation and legal-action specs are converted from the
    `"observation"` and `"action_mask"` entries of `self.observation_space`.
    `terminal` is a float32 array of shape (1,).
    """
    observation_specs = {}
    for agent in self._possible_agents:
        observation_specs[agent] = types.OLT(
            observation=_convert_to_spec(self.observation_space["observation"]),
            legal_actions=_convert_to_spec(
                self.observation_space["action_mask"]),
            terminal=specs.Array((1,), np.float32),
        )
    return observation_specs
def observation_spec(self) -> types.Observation:
    """Returns a mapping from each agent in `self.possible_agents` to an OLT spec.

    The observation spec comes from the agent's entry in the wrapped
    environment's `observation_spaces`. Legal actions come from its entry in
    `action_spaces`. `terminal` is a float32 array of shape (1,).
    """

    def _spec_for(agent):
        return types.OLT(
            observation=_convert_to_spec(
                self._environment.observation_spaces[agent]),
            # NOTE(review): legal_actions is derived from the action space,
            # not an action mask; confirm this matches the legal-action
            # arrays this wrapper actually emits.
            legal_actions=_convert_to_spec(
                self._environment.action_spaces[agent]),
            terminal=specs.Array((1,), np.float32),
        )

    return {agent: _spec_for(agent) for agent in self.possible_agents}