Example #1
  def test_integration(self):
    # Build the Catch environment and read the specs the agent needs.
    env = catch.Catch()
    action_spec = env.action_spec()
    num_actions = action_spec.num_values
    obs_spec = env.observation_spec()
    agent = agent_lib.Agent(
        num_actions=num_actions,
        obs_spec=obs_spec,
        net_factory=haiku_nets.CatchNet,
    )
    # The learner consumes fixed-length unrolls pushed by the actor.
    unroll_length = 20
    learner = learner_lib.Learner(
        agent=agent,
        rng_key=jax.random.PRNGKey(42),
        opt=optax.sgd(1e-2),
        batch_size=1,
        discount_factor=0.99,
        frames_per_iter=unroll_length,
    )
    actor = actor_lib.Actor(
        agent=agent,
        env=env,
        learner=learner,
        unroll_length=unroll_length,
    )
    # Push one unroll from the actor, then run a single learner iteration.
    frame_count, params = actor.pull_params()
    actor.unroll_and_push(frame_count=frame_count, params=params)
    learner.run(max_iterations=1)
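The snippets in this section omit their imports. Below is a sketch of what they appear to rely on, judging only by the names they use; the local module paths (an IMPALA-style actor/learner layout) and the bsuite Catch environment are assumptions, so adjust them to your own project.

import threading
from unittest import mock

import dm_env
import jax
import numpy as np
import optax
import tree
from absl.testing import absltest
from bsuite.environments import catch

# Assumed layout: replace "impala" with wherever actor.py, agent.py,
# haiku_nets.py, learner.py and util.py live in your codebase.
from impala import actor as actor_lib
from impala import agent as agent_lib
from impala import haiku_nets
from impala import learner as learner_lib
from impala import util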
Example #2
def main(_):
    # A thunk that builds a new environment.
    # Substitute your environment here!
    build_env = catch.Catch

    # Construct the agent. We need a sample environment for its spec.
    env_for_spec = build_env()
    num_actions = env_for_spec.action_spec().num_values
    agent = agent_lib.Agent(num_actions, env_for_spec.observation_spec(),
                            haiku_nets.CatchNet)

    # Construct the optimizer.
    max_updates = MAX_ENV_FRAMES / FRAMES_PER_ITER
    opt = optax.rmsprop(5e-3, decay=0.99, eps=1e-7)

    # Construct the learner.
    learner = learner_lib.Learner(
        agent,
        jax.random.PRNGKey(428),
        opt,
        BATCH_SIZE,
        DISCOUNT_FACTOR,
        FRAMES_PER_ITER,
        max_abs_reward=1.,
        logger=util.AbslLogger(),  # Provide your own logger here.
    )

    # Construct the actors, each running on its own thread.
    # stop_signal lives in a list so every thread shares the same mutable flag.
    actor_threads = []
    stop_signal = [False]
    for i in range(NUM_ACTORS):
        actor = actor_lib.Actor(
            agent,
            build_env(),
            UNROLL_LENGTH,
            learner,
            rng_seed=i,
            logger=util.AbslLogger(),  # Provide your own logger here.
        )
        args = (actor, stop_signal)
        actor_threads.append(threading.Thread(target=run_actor, args=args))

    # Start the actors and learner.
    for t in actor_threads:
        t.start()
    learner.run(int(max_updates))

    # Stop.
    stop_signal[0] = True
    for t in actor_threads:
        t.join()
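main() refers to run_actor and several module-level constants that are defined elsewhere in the script. The sketch below shows plausible definitions: the constant values are illustrative assumptions, and run_actor simply repeats the pull_params / unroll_and_push cycle demonstrated in Example #1 until the stop flag flips.

# Illustrative values only; the original script's settings may differ.
NUM_ACTORS = 2
UNROLL_LENGTH = 20
BATCH_SIZE = 2
DISCOUNT_FACTOR = 0.99
FRAMES_PER_ITER = BATCH_SIZE * UNROLL_LENGTH
MAX_ENV_FRAMES = 20_000


def run_actor(actor: actor_lib.Actor, stop_signal: list):
    """Keeps pulling fresh params and pushing unrolls until told to stop."""
    while not stop_signal[0]:
        frame_count, params = actor.pull_params()
        actor.unroll_and_push(frame_count=frame_count, params=params)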
Example #3
  def test_sync_params(self):
    # The learner is mocked out; pull_params should simply relay whatever
    # (frame_count, params) pair the learner hands back.
    mock_learner = mock.MagicMock()
    frame_count = 428
    params = self.initial_params
    mock_learner.params_for_actor.return_value = frame_count, params
    traj_len = 10
    actor = actor_lib.Actor(
        agent=self.agent,
        env=self.env,
        learner=mock_learner,
        unroll_length=traj_len,
    )
    received_frame_count, received_params = actor.pull_params()
    self.assertEqual(received_frame_count, frame_count)
    tree.assert_same_structure(received_params, params)
    tree.map_structure(np.testing.assert_array_almost_equal,
                       received_params, params)
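Examples #3 through #5 reference fixtures (self.agent, self.env, self.initial_params, self.key, self.obs_spec, self.num_actions) created in the test case's setUp, which is not shown here. A minimal sketch of such a setUp, assuming the agent exposes an initial_params(rng_key) initializer (that method name is an assumption, not something these snippets show):

  def setUp(self):
    super().setUp()
    self.env = catch.Catch()
    self.num_actions = self.env.action_spec().num_values
    self.obs_spec = self.env.observation_spec()
    self.agent = agent_lib.Agent(
        num_actions=self.num_actions,
        obs_spec=self.obs_spec,
        net_factory=haiku_nets.CatchNet,
    )
    self.key = jax.random.PRNGKey(1)
    self.key, subkey = jax.random.split(self.key)
    # Assumed initializer; use whatever your Agent provides to create params.
    self.initial_params = self.agent.initial_params(subkey)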
Example #4
  def test_unroll(self):
    mock_learner = mock.MagicMock()
    traj_len = 10
    actor = actor_lib.Actor(
        agent=self.agent,
        env=self.env,
        learner=mock_learner,
        unroll_length=traj_len,
    )
    self.key, subkey = jax.random.split(self.key)
    act_out = actor.unroll(
        rng_key=subkey,
        frame_count=0,
        params=self.initial_params,
        unroll_length=traj_len)

    self.assertIsInstance(act_out, util.Transition)
    self.assertIsInstance(act_out.timestep, dm_env.TimeStep)
    self.assertLen(act_out.timestep.reward.shape, 1)
    self.assertEqual(act_out.timestep.reward.shape, (traj_len + 1,))
    self.assertLen(act_out.timestep.discount.shape, 1)
    self.assertEqual(act_out.timestep.discount.shape, (traj_len + 1,))
    self.assertLen(act_out.timestep.step_type.shape, 1)
    self.assertEqual(act_out.timestep.step_type.shape, (traj_len + 1,))

    self.assertLen(act_out.timestep.observation.shape, 3)
    self.assertEqual(act_out.timestep.observation.shape,
                     (traj_len + 1,) + self.obs_spec.shape)

    self.assertIsInstance(act_out.agent_out, agent_lib.AgentOutput)
    self.assertLen(act_out.agent_out.action.shape, 1)
    self.assertEqual(act_out.agent_out.action.shape, (traj_len + 1,))

    self.assertLen(act_out.agent_out.policy_logits.shape, 2)
    self.assertEqual(act_out.agent_out.policy_logits.shape,
                     (traj_len + 1, self.num_actions))

    self.assertLen(act_out.agent_out.values.shape, 1)
    self.assertEqual(act_out.agent_out.values.shape, (traj_len + 1,))

    self.assertEqual(act_out.agent_state.shape, (traj_len + 1,))
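Example #4 drives actor.unroll directly. In normal operation the actor's unroll_and_push is expected to wrap this call and hand the resulting util.Transition to the learner, which is exactly what Example #5 below checks via learner.enqueue_traj. A small sketch of that handoff, using only calls these tests exercise (the variables and the explicit rng key are illustrative):

    # Hypothetical glue mirroring what unroll_and_push is asserted to do;
    # `actor`, `learner` and `traj_len` are assumed to be in scope.
    frame_count, params = actor.pull_params()
    act_out = actor.unroll(
        rng_key=jax.random.PRNGKey(0),
        frame_count=frame_count,
        params=params,
        unroll_length=traj_len)
    learner.enqueue_traj(act_out)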
Example #5
    def test_unroll_and_push(self):
        traj_len = 3
        mock_learner = mock.create_autospec(learner_lib.Learner, instance=True)
        actor = actor_lib.Actor(
            agent=self.agent,
            env=self.env,
            learner=mock_learner,
            unroll_length=traj_len,
        )
        actor.unroll_and_push(0, self.initial_params)

        mock_learner.enqueue_traj.assert_called_once()
        act_out = mock_learner.enqueue_traj.call_args[0][0]

        self.assertIsInstance(act_out, util.Transition)
        self.assertIsInstance(act_out.timestep, dm_env.TimeStep)
        self.assertLen(act_out.timestep.reward.shape, 1)
        self.assertEqual(act_out.timestep.reward.shape, (traj_len + 1,))
        self.assertLen(act_out.timestep.discount.shape, 1)
        self.assertEqual(act_out.timestep.discount.shape, (traj_len + 1,))
        self.assertLen(act_out.timestep.step_type.shape, 1)
        self.assertEqual(act_out.timestep.step_type.shape, (traj_len + 1,))

        self.assertLen(act_out.timestep.observation.shape, 3)
        self.assertEqual(act_out.timestep.observation.shape,
                         (traj_len + 1,) + self.obs_spec.shape)

        self.assertIsInstance(act_out.agent_out, agent_lib.AgentOutput)
        self.assertLen(act_out.agent_out.action.shape, 1)
        self.assertEqual(act_out.agent_out.action.shape, (traj_len + 1,))

        self.assertLen(act_out.agent_out.policy_logits.shape, 2)
        self.assertEqual(act_out.agent_out.policy_logits.shape,
                         (traj_len + 1, self.num_actions))

        self.assertLen(act_out.agent_out.values.shape, 1)
        self.assertEqual(act_out.agent_out.values.shape, (traj_len + 1,))

        self.assertEqual(act_out.agent_state.shape, (traj_len + 1,))
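The test methods above are shown without their enclosing class. Given the assertLen calls, they presumably live in an absl TestCase; a sketch of that scaffolding (the class name is a placeholder):

class ActorTest(absltest.TestCase):
  """Would hold setUp plus the test methods from Examples #1 and #3-#5."""


if __name__ == '__main__':
  absltest.main()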