def test_integration(self):
    """Smoke test: one actor unroll followed by one learner iteration on Catch.

    Wires an agent, a learner and a single actor together on a Catch
    environment. The actor pulls parameters from the learner and pushes one
    unroll of `unroll_length` steps. The learner then runs a single iteration.
    The test passes if nothing raises; it makes no assertions on the results.
    """
    # Use optax instead of the deprecated `optix` module (presumably
    # jax.experimental.optix, which was removed from JAX in favour of optax).
    # main() passes an optax optimizer to the same Learner.
    import optax  # pylint: disable=g-import-not-at-top

    env = catch.Catch()
    action_spec = env.action_spec()
    num_actions = action_spec.num_values
    obs_spec = env.observation_spec()

    agent = agent_lib.Agent(
        num_actions=num_actions,
        obs_spec=obs_spec,
        net_factory=haiku_nets.CatchNet,
    )

    unroll_length = 20
    learner = learner_lib.Learner(
        agent=agent,
        rng_key=jax.random.PRNGKey(42),
        opt=optax.sgd(1e-2),
        batch_size=1,
        discount_factor=0.99,
        # One unroll supplies exactly one iteration's worth of frames.
        frames_per_iter=unroll_length,
    )
    actor = actor_lib.Actor(
        agent=agent,
        env=env,
        learner=learner,
        unroll_length=unroll_length,
    )

    # Run the actor synchronously (no threads): pull params, push one unroll.
    frame_count, params = actor.pull_params()
    actor.unroll_and_push(frame_count=frame_count, params=params)
    learner.run(max_iterations=1)
def main(_):
    """Train an agent on Catch with NUM_ACTORS actor threads feeding one learner.

    The single positional argument is ignored (presumably argv from an app
    runner — TODO confirm). The learner runs on the calling thread for
    MAX_ENV_FRAMES // FRAMES_PER_ITER updates. Each actor runs `run_actor` on
    its own thread. Once the learner finishes, or if it raises, the shared
    stop flag is set and all actor threads are joined.
    """
    # A thunk that builds a new environment.
    # Substitute your environment here!
    build_env = catch.Catch

    # Construct the agent. We need a sample environment for its spec.
    env_for_spec = build_env()
    num_actions = env_for_spec.action_spec().num_values
    agent = agent_lib.Agent(num_actions, env_for_spec.observation_spec(),
                            haiku_nets.CatchNet)

    # Number of learner iterations: floor division, since the learner can
    # only run a whole number of iterations.
    max_updates = int(MAX_ENV_FRAMES // FRAMES_PER_ITER)

    # Construct the optimizer.
    opt = optax.rmsprop(5e-3, decay=0.99, eps=1e-7)

    # Construct the learner.
    learner = learner_lib.Learner(
        agent,
        jax.random.PRNGKey(428),
        opt,
        BATCH_SIZE,
        DISCOUNT_FACTOR,
        FRAMES_PER_ITER,
        max_abs_reward=1.,
        logger=util.AbslLogger(),  # Provide your own logger here.
    )

    # Construct the actors on different threads.
    # stop_signal in a list so the reference is shared.
    actor_threads = []
    stop_signal = [False]
    for i in range(NUM_ACTORS):
        actor = actor_lib.Actor(
            agent,
            build_env(),
            UNROLL_LENGTH,
            learner,
            rng_seed=i,
            logger=util.AbslLogger(),  # Provide your own logger here.
        )
        args = (actor, stop_signal)
        actor_threads.append(threading.Thread(target=run_actor, args=args))

    # Start the actors and learner.
    for t in actor_threads:
        t.start()
    try:
        learner.run(max_updates)
    finally:
        # Stop. This runs even if the learner raised. The actor threads are
        # non-daemon, so leaving them running would keep the process alive.
        # run_actor presumably polls stop_signal[0] — confirm against its
        # definition.
        stop_signal[0] = True
        for t in actor_threads:
            t.join()
def setUp(self):
    """Create a Catch env, a CatchNet agent for it, and initial agent params.

    Attributes set on the test case: env, action_spec, num_actions, obs_spec,
    agent, key (the PRNG key left after splitting off the init key) and
    initial_params.
    """
    super().setUp()

    self.env = catch.Catch()
    self.action_spec = self.env.action_spec()
    self.num_actions = self.action_spec.num_values
    self.obs_spec = self.env.observation_spec()

    self.agent = agent_lib.Agent(
        net_factory=haiku_nets.CatchNet,
        num_actions=self.num_actions,
        obs_spec=self.obs_spec,
    )

    # Split a fixed seed: one half is kept on the test case for later use,
    # the other half initialises the agent's parameters.
    root_key = jax.random.PRNGKey(42)
    self.key, init_key = jax.random.split(root_key)
    self.initial_params = self.agent.initial_params(init_key)