def test_feedforward(self):
    """Smoke-tests a feed-forward actor by driving it through an environment loop."""
    environment = _make_fake_env()
    env_spec = specs.make_environment_spec(environment)

    def policy(inputs: jnp.ndarray):
        # Flatten the observation, project it to one value per action and
        # pick the greedy (argmax) action.
        layers = [
            hk.Flatten(),
            hk.Linear(env_spec.actions.num_values),
            lambda x: jnp.argmax(x, axis=-1),
        ]
        return hk.Sequential(layers)(inputs)

    policy = hk.transform(policy, apply_rng=True)

    # Initialise the parameters from a zero-valued, batched observation.
    key_sequence = hk.PRNGSequence(1)
    batched_zero_obs = utils.add_batch_dim(
        utils.zeros_like(env_spec.observations))
    params = policy.init(next(key_sequence), batched_zero_obs)

    # Serve the parameters to the actor through a fake variable source.
    source = fakes.VariableSource(params)
    client = variable_utils.VariableClient(source, 'policy')

    actor = actors.FeedForwardActor(
        policy.apply,
        rng=hk.PRNGSequence(1),
        variable_client=client)

    environment_loop.EnvironmentLoop(environment, actor).run(20)
def test_update(self):
    """Checks that updating the client copies the learner weights into the actor."""
    # Two separately constructed MLPs with the same layer sizes.
    actor_model = snt.nets.MLP([50, 30])
    learner_model = snt.nets.MLP([50, 30])

    # Build the variables of both models up front (actor first, then learner).
    input_spec = tf.TensorSpec(shape=(28, ), dtype=tf.float32)
    for model in (actor_model, learner_model):
        tf2_utils.create_variables(model, [input_spec])

    # The actor's variables are what the client writes into; the learner's
    # variables (converted to numpy) are what the fake source hands out.
    actor_variables = actor_model.variables
    np_learner_variables = list(
        map(tf2_utils.to_numpy, learner_model.variables))
    variable_source = fakes.VariableSource(np_learner_variables)
    variable_client = tf2_variable_utils.VariableClient(
        variable_source, {'policy': actor_variables})

    # A random batch of test input shared by both comparisons below.
    x = tf.random.normal(shape=(8, 28))

    def outputs_match():
        # Evaluate both models on the same batch and compare numerically.
        actor_output = actor_model(x).numpy()
        learner_output = learner_model(x).numpy()
        return np.allclose(actor_output, learner_output)

    # Before the copy the two models disagree.
    self.assertFalse(outputs_match())

    # Pull the variables from the source into the actor model.
    variable_client.update_and_wait()

    # After the copy the two models agree.
    self.assertTrue(outputs_match())
def test_update(self):
    """Checks that `update_and_wait` makes the client params equal the source params."""
    transformed = hk.without_apply_rng(hk.transform(dummy_network))
    init_fn = transformed[0]
    dummy_input = jnp.zeros(shape=(1, 32))
    params = init_fn(jax.random.PRNGKey(1), dummy_input)

    source = fakes.VariableSource(params)
    client = variable_utils.VariableClient(source, key='policy')
    client.update_and_wait()

    # Compare every leaf of the fetched parameters with the originals.
    tree.map_structure(
        np.testing.assert_array_equal, client.params, params)
def test_multiple_keys(self):
    """Checks that a client given a list of keys receives one entry per key, in order."""
    transformed = hk.without_apply_rng(hk.transform(dummy_network))
    init_fn = transformed[0]
    params = init_fn(jax.random.PRNGKey(1), jnp.zeros(shape=(1, 32)))
    steps = jnp.zeros(shape=1)

    # The source serves two named entries rather than a single default one.
    source = fakes.VariableSource(
        {'network': params, 'steps': steps}, use_default_key=False)
    client = variable_utils.VariableClient(
        source, key=['network', 'steps'])
    client.update_and_wait()

    # The fetched values are expected in the same order as the requested keys.
    tree.map_structure(
        np.testing.assert_array_equal, client.params[0], params)
    tree.map_structure(
        np.testing.assert_array_equal, client.params[1], steps)
def test_recurrent(self, has_extras):
    """Smoke-tests a recurrent actor, with or without extras, in an environment loop.

    Args:
      has_extras: when true, the policy also returns the action values as a
        one-element tuple of extras alongside the selected action.
    """
    environment = _make_fake_env()
    env_spec = specs.make_environment_spec(environment)
    output_size = env_spec.actions.num_values
    obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
    rng = hk.PRNGSequence(1)

    @_transform_without_rng
    def network(inputs: jnp.ndarray, state: hk.LSTMState):
        core = hk.DeepRNN(
            [hk.Reshape([-1], preserve_dims=1), hk.LSTM(output_size)])
        return core(inputs, state)

    @_transform_without_rng
    def initial_state(batch_size: Optional[int] = None):
        core = hk.DeepRNN(
            [hk.Reshape([-1], preserve_dims=1), hk.LSTM(output_size)])
        return core.initial_state(batch_size)

    # Materialise the initial recurrent state for a batch of one, then use it
    # to initialise the network parameters.
    initial_state_params = initial_state.init(next(rng))
    initial_state = initial_state.apply(initial_state_params, 1)
    params = network.init(next(rng), obs, initial_state)

    def policy(
            params: jnp.ndarray,
            key: jnp.ndarray,
            observation: jnp.ndarray,
            core_state: hk.LSTMState) -> Tuple[jnp.ndarray, hk.LSTMState]:
        del key  # The test policy is deterministic, so the key is unused.
        action_values, core_state = network.apply(
            params, observation, core_state)
        actions = jnp.argmax(action_values, axis=-1)
        if not has_extras:
            return actions, core_state
        return (actions, (action_values, )), core_state

    source = fakes.VariableSource(params)
    client = variable_utils.VariableClient(source, 'policy')
    actor = actors.RecurrentActor(
        policy,
        jax.random.PRNGKey(1),
        initial_state,
        client,
        has_extras=has_extras)

    environment_loop.EnvironmentLoop(environment, actor).run(20)
def test_update_and_wait(self):
    """Checks that `update_and_wait` makes the actor model match the learner model."""
    # Learner side: a fake source serving the learner's variables as numpy.
    learner_values = tf2_utils.to_numpy(self._learner_model.variables)
    source = fakes.VariableSource(learner_values)

    # Actor side: a client that writes into the actor model's variables.
    client = tf2_variable_utils.VariableClient(
        source, {'policy': self._actor_model.variables})

    # A random batch on which both models are evaluated.
    x = tf.random.normal(shape=(_BATCH_SIZE, _INPUT_SIZE))

    # The models disagree until the variables have been copied over.
    self.assertNotAllClose(self._actor_model(x), self._learner_model(x))

    client.update_and_wait()

    # Once the client has been updated the outputs coincide.
    self.assertAllClose(self._actor_model(x), self._learner_model(x))
def test_feedforward(self, has_extras):
    """Smoke-tests a generic feed-forward actor, with or without extras.

    Args:
      has_extras: when true, the policy returns the action values as a
        one-element tuple of extras alongside the selected action, and the
        matching "with extras" actor-core builder is used.
    """
    environment = _make_fake_env()
    env_spec = specs.make_environment_spec(environment)

    def policy(inputs: jnp.ndarray):
        torso = hk.Sequential([
            hk.Flatten(),
            hk.Linear(env_spec.actions.num_values),
        ])
        action_values = torso(inputs)
        action = jnp.argmax(action_values, axis=-1)
        if not has_extras:
            return action
        return action, (action_values, )

    policy = hk.transform(policy)

    # Initialise the parameters from a zero-valued, batched observation.
    key_sequence = hk.PRNGSequence(1)
    batched_zero_obs = utils.add_batch_dim(
        utils.zeros_like(env_spec.observations))
    params = policy.init(next(key_sequence), batched_zero_obs)

    source = fakes.VariableSource(params)
    client = variable_utils.VariableClient(source, 'policy')

    # Pick the actor-core builder that matches the policy's return structure.
    build_actor_core = (
        actor_core_lib.batched_feed_forward_with_extras_to_actor_core
        if has_extras
        else actor_core_lib.batched_feed_forward_to_actor_core)
    actor_core = build_actor_core(policy.apply)

    actor = actors.GenericActor(
        actor_core,
        random_key=jax.random.PRNGKey(1),
        variable_client=client)

    environment_loop.EnvironmentLoop(environment, actor).run(20)
def test_recurrent(self):
    """Smoke-tests a recurrent actor by driving it through an environment loop."""
    environment = _make_fake_env()
    env_spec = specs.make_environment_spec(environment)
    output_size = env_spec.actions.num_values
    obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
    rng = hk.PRNGSequence(1)

    @hk.transform
    def network(inputs: jnp.ndarray, state: hk.LSTMState):
        core = hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])
        return core(inputs, state)

    @hk.transform
    def initial_state(batch_size: int):
        core = hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])
        return core.initial_state(batch_size)

    # Materialise the initial recurrent state for a batch of one, then use it
    # to initialise the network parameters.
    initial_state_params = initial_state.init(next(rng), 1)
    initial_state = initial_state.apply(initial_state_params, 1)
    params = network.init(next(rng), obs, initial_state)

    def policy(
            params: jnp.ndarray,
            key: jnp.ndarray,
            observation: jnp.ndarray,
            core_state: hk.LSTMState
    ) -> Tuple[jnp.ndarray, hk.LSTMState]:
        del key  # The test policy is deterministic, so the key is unused.
        action_values, core_state = network.apply(
            params, observation, core_state)
        greedy_action = jnp.argmax(action_values, axis=-1)
        return greedy_action, core_state

    source = fakes.VariableSource(params)
    client = variable_utils.VariableClient(source, 'policy')
    actor = actors.RecurrentActor(
        policy, hk.PRNGSequence(1), initial_state, client)

    environment_loop.EnvironmentLoop(environment, actor).run(20)
def test_update(self):
    """Checks the periodic, asynchronous `update()` path of the variable client.

    A barrier shared with the fake variable source holds back the variables
    until the test body calls `wait()`, so the test can observe the client
    before the request is made, while it is pending, and after it completes.
    """
    # The barrier releases once two threads have called wait(). The (fake)
    # variable source calls it inside get_variables(); the test body is the
    # second caller, further down.
    barrier = threading.Barrier(2)

    # Learner side: a fake source serving the learner's variables as numpy.
    np_learner_variables = tf2_utils.to_numpy(
        self._learner_model.variables)
    variable_source = fakes.VariableSource(np_learner_variables, barrier)

    # Actor side: a client that only issues a request every _UPDATE_PERIOD
    # calls to update().
    variable_client = tf2_variable_utils.VariableClient(
        variable_source,
        {'policy': self._actor_model.variables},
        update_period=_UPDATE_PERIOD)

    # A random batch of test input.
    x = tf.random.normal(shape=(_BATCH_SIZE, _INPUT_SIZE))

    # Run each model once so its variables exist; the outputs are not needed.
    self._learner_model(x)
    self._actor_model(x)

    # Phase 1: calls leading up to the update period.
    for _ in range(_UPDATE_PERIOD):
        # The models still hold different variables.
        self.assertNotAllClose(self._actor_model.variables,
                               self._learner_model.variables)

        # No request for variables has been issued yet.
        self.assertIsNone(variable_client._future)
        variable_client.update()

    # The final call above must have issued a request and reset the internal
    # call counter.
    self.assertIsNotNone(variable_client._future)
    self.assertEqual(variable_client._call_counter, 0)
    pending_future = variable_client._future

    # Phase 2: the request is pending because the barrier has not released.
    for _ in range(_UPDATE_PERIOD):
        # The variables cannot have been copied yet.
        self.assertNotAllClose(self._actor_model.variables,
                               self._learner_model.variables)
        variable_client.update()

        # The client must keep the same outstanding request.
        self.assertEqual(variable_client._future, pending_future)

    # Phase 3: release the barrier so the source can hand over the variables.
    barrier.wait()

    # Keep updating until the client has consumed the completed request.
    while variable_client._future is not None:
        variable_client.update()

    # The actor's variables now match the learner's.
    self.assertAllClose(self._actor_model.variables,
                        self._learner_model.variables)