# Imports assumed by the snippets below (the local module aliases are a
# guess; adjust the paths to your project layout).
import threading
from unittest import mock

from absl.testing import absltest
from bsuite.environments import catch
import dm_env
import jax
import numpy as np
import optax
import tree

import actor as actor_lib
import agent as agent_lib
import haiku_nets
import learner as learner_lib
import util


def test_integration(self):
  """End-to-end smoke test: one actor unroll feeding one learner iteration."""
  env = catch.Catch()
  action_spec = env.action_spec()
  num_actions = action_spec.num_values
  obs_spec = env.observation_spec()
  agent = agent_lib.Agent(
      num_actions=num_actions,
      obs_spec=obs_spec,
      net_factory=haiku_nets.CatchNet,
  )
  unroll_length = 20
  learner = learner_lib.Learner(
      agent=agent,
      rng_key=jax.random.PRNGKey(42),
      opt=optax.sgd(1e-2),  # optax replaces the deprecated optix.
      batch_size=1,
      discount_factor=0.99,
      frames_per_iter=unroll_length,
  )
  actor = actor_lib.Actor(
      agent=agent,
      env=env,
      learner=learner,
      unroll_length=unroll_length,
  )
  frame_count, params = actor.pull_params()
  actor.unroll_and_push(frame_count=frame_count, params=params)
  learner.run(max_iterations=1)
def main(_):
  # A thunk that builds a new environment.
  # Substitute your environment here!
  build_env = catch.Catch

  # Construct the agent. We need a sample environment for its spec.
  env_for_spec = build_env()
  num_actions = env_for_spec.action_spec().num_values
  agent = agent_lib.Agent(num_actions, env_for_spec.observation_spec(),
                          haiku_nets.CatchNet)

  # Construct the optimizer.
  max_updates = MAX_ENV_FRAMES / FRAMES_PER_ITER
  opt = optax.rmsprop(5e-3, decay=0.99, eps=1e-7)

  # Construct the learner.
  learner = learner_lib.Learner(
      agent,
      jax.random.PRNGKey(428),
      opt,
      BATCH_SIZE,
      DISCOUNT_FACTOR,
      FRAMES_PER_ITER,
      max_abs_reward=1.,
      logger=util.AbslLogger(),  # Provide your own logger here.
  )

  # Construct the actors on different threads.
  # stop_signal in a list so the reference is shared.
  actor_threads = []
  stop_signal = [False]
  for i in range(NUM_ACTORS):
    actor = actor_lib.Actor(
        agent,
        build_env(),
        UNROLL_LENGTH,
        learner,
        rng_seed=i,
        logger=util.AbslLogger(),  # Provide your own logger here.
    )
    args = (actor, stop_signal)
    actor_threads.append(threading.Thread(target=run_actor, args=args))

  # Start the actors and learner.
  for t in actor_threads:
    t.start()
  learner.run(int(max_updates))

  # Stop.
  stop_signal[0] = True
  for t in actor_threads:
    t.join()
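# `main` references `run_actor` and several module-level constants that this
# excerpt does not define. Below is a minimal sketch; the constant values are
# illustrative assumptions, not the example's actual settings.
MAX_ENV_FRAMES = 20_000_000
FRAMES_PER_ITER = 64
BATCH_SIZE = 2
DISCOUNT_FACTOR = 0.99
NUM_ACTORS = 2
UNROLL_LENGTH = 20


def run_actor(actor: actor_lib.Actor, stop_signal: list):
  """Actor loop: pull the latest params, unroll, push, until told to stop."""
  while not stop_signal[0]:
    frame_count, params = actor.pull_params()
    actor.unroll_and_push(frame_count=frame_count, params=params)


# The tests below rely on fixtures (self.agent, self.env, self.key,
# self.initial_params, self.obs_spec, self.num_actions) built in a setUp
# method that is not part of this excerpt. A minimal sketch, assuming the
# agent exposes an initial_params(rng_key) initializer:
def setUp(self):
  super().setUp()
  self.env = catch.Catch()
  self.obs_spec = self.env.observation_spec()
  self.num_actions = self.env.action_spec().num_values
  self.agent = agent_lib.Agent(
      num_actions=self.num_actions,
      obs_spec=self.obs_spec,
      net_factory=haiku_nets.CatchNet,
  )
  self.key = jax.random.PRNGKey(42)
  self.key, subkey = jax.random.split(self.key)
  self.initial_params = self.agent.initial_params(subkey)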
def test_sync_params(self):
  """pull_params should return the learner's current frame count and params."""
  mock_learner = mock.MagicMock()
  frame_count = 428
  params = self.initial_params
  mock_learner.params_for_actor.return_value = frame_count, params
  traj_len = 10
  actor = actor_lib.Actor(
      agent=self.agent,
      env=self.env,
      learner=mock_learner,
      unroll_length=traj_len,
  )
  received_frame_count, received_params = actor.pull_params()
  self.assertEqual(received_frame_count, frame_count)
  tree.assert_same_structure(received_params, params)
  tree.map_structure(np.testing.assert_array_almost_equal, received_params,
                     params)
def test_unroll(self):
  """A length-T unroll should yield time-major outputs with T + 1 steps."""
  mock_learner = mock.MagicMock()
  traj_len = 10
  actor = actor_lib.Actor(
      agent=self.agent,
      env=self.env,
      learner=mock_learner,
      unroll_length=traj_len,
  )
  self.key, subkey = jax.random.split(self.key)
  act_out = actor.unroll(
      rng_key=subkey,
      frame_count=0,
      params=self.initial_params,
      unroll_length=traj_len)
  self.assertIsInstance(act_out, util.Transition)
  self.assertIsInstance(act_out.timestep, dm_env.TimeStep)
  # Every leaf carries traj_len + 1 entries along the leading (time) axis:
  # the unroll plus the boundary step shared with the next trajectory.
  self.assertLen(act_out.timestep.reward.shape, 1)
  self.assertEqual(act_out.timestep.reward.shape, (traj_len + 1,))
  self.assertLen(act_out.timestep.discount.shape, 1)
  self.assertEqual(act_out.timestep.discount.shape, (traj_len + 1,))
  self.assertLen(act_out.timestep.step_type.shape, 1)
  self.assertEqual(act_out.timestep.step_type.shape, (traj_len + 1,))
  self.assertLen(act_out.timestep.observation.shape, 3)
  self.assertEqual(act_out.timestep.observation.shape,
                   (traj_len + 1,) + self.obs_spec.shape)
  self.assertIsInstance(act_out.agent_out, agent_lib.AgentOutput)
  self.assertLen(act_out.agent_out.action.shape, 1)
  self.assertEqual(act_out.agent_out.action.shape, (traj_len + 1,))
  self.assertLen(act_out.agent_out.policy_logits.shape, 2)
  self.assertEqual(act_out.agent_out.policy_logits.shape,
                   (traj_len + 1, self.num_actions))
  self.assertLen(act_out.agent_out.values.shape, 1)
  self.assertEqual(act_out.agent_out.values.shape, (traj_len + 1,))
  self.assertEqual(act_out.agent_state.shape, (traj_len + 1,))
def test_unroll_and_push(self):
  """unroll_and_push should enqueue a full trajectory on the learner."""
  traj_len = 3
  mock_learner = mock.create_autospec(learner_lib.Learner, instance=True)
  actor = actor_lib.Actor(
      agent=self.agent,
      env=self.env,
      learner=mock_learner,
      unroll_length=traj_len,
  )
  actor.unroll_and_push(0, self.initial_params)
  mock_learner.enqueue_traj.assert_called_once()
  act_out = mock_learner.enqueue_traj.call_args[0][0]
  self.assertIsInstance(act_out, util.Transition)
  self.assertIsInstance(act_out.timestep, dm_env.TimeStep)
  self.assertLen(act_out.timestep.reward.shape, 1)
  self.assertEqual(act_out.timestep.reward.shape, (traj_len + 1,))
  self.assertLen(act_out.timestep.discount.shape, 1)
  self.assertEqual(act_out.timestep.discount.shape, (traj_len + 1,))
  self.assertLen(act_out.timestep.step_type.shape, 1)
  self.assertEqual(act_out.timestep.step_type.shape, (traj_len + 1,))
  self.assertLen(act_out.timestep.observation.shape, 3)
  self.assertEqual(act_out.timestep.observation.shape,
                   (traj_len + 1,) + self.obs_spec.shape)
  self.assertIsInstance(act_out.agent_out, agent_lib.AgentOutput)
  self.assertLen(act_out.agent_out.action.shape, 1)
  self.assertEqual(act_out.agent_out.action.shape, (traj_len + 1,))
  self.assertLen(act_out.agent_out.policy_logits.shape, 2)
  self.assertEqual(act_out.agent_out.policy_logits.shape,
                   (traj_len + 1, self.num_actions))
  self.assertLen(act_out.agent_out.values.shape, 1)
  self.assertEqual(act_out.agent_out.values.shape, (traj_len + 1,))
  self.assertEqual(act_out.agent_state.shape, (traj_len + 1,))
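# The tests above use absltest-style assertions (assertLen and friends); a
# conventional entry point, assuming absl.testing, would be:
if __name__ == '__main__':
  absltest.main()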