Example #1
    def test_integration(self):
        # Build the environment and read off its specs.
        env = catch.Catch()
        action_spec = env.action_spec()
        num_actions = action_spec.num_values
        obs_spec = env.observation_spec()
        agent = agent_lib.Agent(
            num_actions=num_actions,
            obs_spec=obs_spec,
            net_factory=haiku_nets.CatchNet,
        )
        unroll_length = 20
        learner = learner_lib.Learner(
            agent=agent,
            rng_key=jax.random.PRNGKey(42),
            opt=optix.sgd(1e-2),
            batch_size=1,
            discount_factor=0.99,
            frames_per_iter=unroll_length,
        )
        actor = actor_lib.Actor(
            agent=agent,
            env=env,
            unroll_length=unroll_length,
        )
        # Run one unroll through the actor and one learner update on it.
        frame_count, params = learner.params_for_actor()
        act_out = actor.unroll_and_push(frame_count=frame_count, params=params)
        learner.enqueue_traj(act_out)
        learner.run(max_iterations=1)
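
All of the examples on this page share the same imports, which the scraped snippets omit. Below is a minimal sketch of what they appear to assume: the environment is taken to be bsuite's Catch, optix is the pre-optax optimizer package that lived under jax.experimental, and agent_lib, actor_lib, learner_lib, haiku_nets, and util are local modules of the example project, so adjust their import paths to wherever those files live in your tree.

# Imports assumed by the examples on this page; the paths for the local
# example modules (agent, actor, learner, haiku_nets, util) are guesses.
import threading

import jax
import numpy as np
from bsuite.environments import catch
from jax.experimental import optix  # pre-optax; newer JAX uses optax

import actor as actor_lib
import agent as agent_lib
import haiku_nets
import learner as learner_lib
import util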
Example #2
def setup_actors(num_actors):
    """Setup actor threads for the execution."""
    # A thunk that builds a new environment.
    # Substitute your environment here!
    build_env = catch.Catch

    # Construct the agent. We need a sample environment for its spec.
    env_for_spec = build_env()
    num_actions = env_for_spec.action_spec().num_values
    agent = agent_lib.Agent(num_actions, env_for_spec.observation_spec(),
                            haiku_nets.CatchNet)

    # Construct the actors on different threads. stop_signal (see the sketch
    # after this function) is kept in a list so that every thread shares one
    # mutable reference and the main thread can flip it to stop them.
    actor_threads = []
    for i in range(num_actors):
        actor = actor_lib.Actor(
            agent,
            build_env(),
            UNROLL_LENGTH,
            rng_seed=i,
            logger=util.AbslLogger(),  # Provide your own logger here.
        )
        args = (actor, )
        actor_threads.append(threading.Thread(target=run_actor, args=args))
    return actor_threads
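
Each thread is handed a run_actor target that the scraped snippets never show. The following driver loop is a hypothetical sketch built only from calls that appear elsewhere on this page; it assumes a module-level learner (set up by setup_learner in Example #4) and the shared stop_signal list mentioned in the comment above, neither of which is part of the scraped code.

stop_signal = [False]  # shared mutable flag; flip to True to stop the actors


def run_actor(actor):
    # Hypothetical driver loop: repeatedly fetch the latest parameters,
    # unroll the actor, and hand the trajectory to the learner.
    while not stop_signal[0]:
        frame_count, params = learner.params_for_actor()
        act_out = actor.unroll_and_push(frame_count=frame_count, params=params)
        learner.enqueue_traj(act_out)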
Example #3
    def test_encode_weights(self):
        env = catch.Catch()
        action_spec = env.action_spec()
        num_actions = action_spec.num_values
        obs_spec = env.observation_spec()
        agent = agent_lib.Agent(
            num_actions=num_actions,
            obs_spec=obs_spec,
            net_factory=haiku_nets.CatchNet,
        )
        unroll_length = 20
        learner = learner_lib.Learner(
            agent=agent,
            rng_key=jax.random.PRNGKey(42),
            opt=optix.sgd(1e-2),
            batch_size=1,
            discount_factor=0.99,
            frames_per_iter=unroll_length,
        )
        actor = actor_lib.Actor(
            agent=agent,
            env=env,
            unroll_length=unroll_length,
        )
        # Round-trip the weights through the proto3 encoding and check that
        # the frame count and parameters survive unchanged.
        frame_count, params = learner.params_for_actor()
        proto_weight = util.proto3_weight_encoder(frame_count, params)
        decoded_frame_count, decoded_params = (
            util.proto3_weight_decoder(proto_weight))
        self.assertEqual(frame_count, decoded_frame_count)
        np.testing.assert_almost_equal(decoded_params["catch_net/linear"]["w"],
                                       params["catch_net/linear"]["w"])
        # Smoke-test that the actor can still unroll with these params.
        act_out = actor.unroll_and_push(frame_count, params)
Example #4
def setup_learner():
    """Set up the learner for the distributed setting."""
    # A thunk that builds a new environment.
    # Substitute your environment here!
    build_env = catch.Catch

    # Construct the agent. We need a sample environment for its spec.
    env_for_spec = build_env()
    num_actions = env_for_spec.action_spec().num_values
    agent = agent_lib.Agent(num_actions, env_for_spec.observation_spec(),
                            haiku_nets.CatchNet)

    # Construct the optimizer.
    opt = optix.rmsprop(1e-1, decay=0.99, eps=0.1)

    # Construct the learner.
    learner = learner_lib.Learner(
        agent,
        jax.random.PRNGKey(428),
        opt,
        BATCH_SIZE,
        DISCOUNT_FACTOR,
        FRAMES_PER_ITER,
        max_abs_reward=1.,
        logger=util.AbslLogger(),  # Provide your own logger here.
    )
    return learner
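
setup_learner and setup_actors (Example #2) are naturally tied together by a small main function. The sketch below is hypothetical: NUM_ACTORS and MAX_ITERATIONS stand in for configuration constants (like UNROLL_LENGTH, BATCH_SIZE, and FRAMES_PER_ITER above) that the scraped snippets never define, and it leans on the run_actor sketch after Example #2.

NUM_ACTORS = 2        # assumed constant, like UNROLL_LENGTH above
MAX_ITERATIONS = 100  # assumed constant


def main():
    global learner  # run_actor (sketched above) reads this module-level name
    learner = setup_learner()
    actor_threads = setup_actors(NUM_ACTORS)
    for thread in actor_threads:
        thread.start()
    learner.run(max_iterations=MAX_ITERATIONS)
    # Signal the actors to stop, then wait for them to finish.
    stop_signal[0] = True
    for thread in actor_threads:
        thread.join()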
Example #5
    def setUp(self):
        super().setUp()
        self.env = catch.Catch()
        self.action_spec = self.env.action_spec()
        self.num_actions = self.action_spec.num_values
        self.obs_spec = self.env.observation_spec()
        self.agent = agent_lib.Agent(
            num_actions=self.num_actions,
            obs_spec=self.obs_spec,
            net_factory=haiku_nets.CatchNet,
        )

        # Split off a subkey for initialisation so self.key stays fresh
        # for later use in the tests.
        self.key = jax.random.PRNGKey(42)
        self.key, subkey = jax.random.split(self.key)
        self.initial_params = self.agent.initial_params(subkey)
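
A natural follow-on test for this setUp, sketched here as an assumption rather than taken from the source, checks that initial_params is deterministic for a fixed PRNG key. It uses only names that appear elsewhere on this page: Agent.initial_params from this setUp and the "catch_net/linear" parameter key from Example #3.

    def test_initial_params_deterministic(self):
        # Re-derive the same subkey as setUp and expect identical weights.
        _, subkey = jax.random.split(jax.random.PRNGKey(42))
        params_again = self.agent.initial_params(subkey)
        np.testing.assert_array_equal(
            params_again["catch_net/linear"]["w"],
            self.initial_params["catch_net/linear"]["w"])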