Example #1
0
 def test_integration(self):
     """Smoke test: one actor unroll is pushed, then the learner runs one iteration.

     Wires a Catch environment, a CatchNet-backed agent, a learner and a single
     actor together, and checks that the end-to-end path runs without raising.
     """
     environment = catch.Catch()
     n_actions = environment.action_spec().num_values
     catch_agent = agent_lib.Agent(
         num_actions=n_actions,
         obs_spec=environment.observation_spec(),
         net_factory=haiku_nets.CatchNet,
     )

     # The learner consumes exactly one unroll's worth of frames per iteration.
     steps_per_unroll = 20
     # NOTE(review): `optix` is the older JAX optimizer module that `optax`
     # superseded; consider migrating if this file's imports allow it.
     the_learner = learner_lib.Learner(
         agent=catch_agent,
         rng_key=jax.random.PRNGKey(42),
         opt=optix.sgd(1e-2),
         batch_size=1,
         discount_factor=0.99,
         frames_per_iter=steps_per_unroll,
     )
     the_actor = actor_lib.Actor(
         agent=catch_agent,
         env=environment,
         learner=the_learner,
         unroll_length=steps_per_unroll,
     )

     # Fetch the learner's current parameters, act with them, and push the
     # resulting trajectory so the learner has data for its single step.
     frames_so_far, current_params = the_actor.pull_params()
     the_actor.unroll_and_push(frame_count=frames_so_far, params=current_params)
     the_learner.run(max_iterations=1)
Example #2
0
def main(_):
    """Train an agent on Catch with NUM_ACTORS actor threads feeding one learner.

    The learner runs for enough updates to consume MAX_ENV_FRAMES frames. The
    actor threads are always signalled to stop and joined afterwards, even if
    the learner raises, so a failure surfaces instead of hanging the process.

    Args:
      _: Unused positional argument (presumably the argv slot passed by an
        absl-style `app.run`; confirm against the entry point).
    """
    # A thunk that builds a new environment.
    # Substitute your environment here!
    build_env = catch.Catch

    # Construct the agent. We need a sample environment for its spec.
    env_for_spec = build_env()
    num_actions = env_for_spec.action_spec().num_values
    agent = agent_lib.Agent(num_actions, env_for_spec.observation_spec(),
                            haiku_nets.CatchNet)

    # Construct the optimizer.
    # Number of learner updates needed to consume MAX_ENV_FRAMES frames.
    max_updates = int(MAX_ENV_FRAMES / FRAMES_PER_ITER)
    opt = optax.rmsprop(5e-3, decay=0.99, eps=1e-7)

    # Construct the learner.
    learner = learner_lib.Learner(
        agent,
        jax.random.PRNGKey(428),
        opt,
        BATCH_SIZE,
        DISCOUNT_FACTOR,
        FRAMES_PER_ITER,
        max_abs_reward=1.,
        logger=util.AbslLogger(),  # Provide your own logger here.
    )

    # Construct the actors on different threads.
    # stop_signal in a list so the reference is shared.
    actor_threads = []
    stop_signal = [False]
    for i in range(NUM_ACTORS):
        actor = actor_lib.Actor(
            agent,
            build_env(),
            UNROLL_LENGTH,
            learner,
            rng_seed=i,
            logger=util.AbslLogger(),  # Provide your own logger here.
        )
        args = (actor, stop_signal)
        actor_threads.append(threading.Thread(target=run_actor, args=args))

    # Start the actors and learner. Only threads that actually started are
    # tracked, because joining a never-started Thread raises RuntimeError.
    started_threads = []
    try:
        for t in actor_threads:
            t.start()
            started_threads.append(t)
        learner.run(max_updates)
    finally:
        # Stop. This runs even if the learner raised. The actor threads are
        # non-daemon (created without daemon=True from the main thread), so
        # leaving them running would keep the interpreter alive.
        # run_actor presumably polls stop_signal[0] — confirm in its definition.
        stop_signal[0] = True
        for t in started_threads:
            t.join()
Example #3
0
    def setUp(self):
        """Create the shared fixtures: a Catch env, its specs, an agent and initial params.

        A fixed PRNG seed (42) is split once. One half is kept on the fixture as
        `self.key` for later use, and the other initialises the agent's parameters.
        """
        super().setUp()
        environment = catch.Catch()
        self.env = environment
        spec_for_actions = environment.action_spec()
        self.action_spec = spec_for_actions
        self.num_actions = spec_for_actions.num_values
        self.obs_spec = environment.observation_spec()
        self.agent = agent_lib.Agent(
            num_actions=self.num_actions,
            obs_spec=self.obs_spec,
            net_factory=haiku_nets.CatchNet,
        )

        self.key, init_key = jax.random.split(jax.random.PRNGKey(42))
        self.initial_params = self.agent.initial_params(init_key)