def test_integration(self):
    """Smoke-test one actor unroll followed by one learner iteration on Catch.

    No values are asserted; the test passes as long as the actor -> learner
    hand-off and a single `learner.run` iteration complete without raising.
    """
    unroll_length = 20

    # Environment and an agent sized from the environment's specs.
    env = catch.Catch()
    agent = agent_lib.Agent(
        num_actions=env.action_spec().num_values,
        obs_spec=env.observation_spec(),
        net_factory=haiku_nets.CatchNet,
    )

    # The learner consumes exactly one unroll per iteration (batch of one).
    learner = learner_lib.Learner(
        agent=agent,
        rng_key=jax.random.PRNGKey(42),
        opt=optix.sgd(1e-2),
        batch_size=1,
        discount_factor=0.99,
        frames_per_iter=unroll_length,
    )
    actor = actor_lib.Actor(
        agent=agent,
        env=env,
        unroll_length=unroll_length,
    )

    # Fetch the learner's current params, unroll with them, and feed the
    # resulting trajectory back to the learner.
    frame_count, params = learner.params_for_actor()
    trajectory = actor.unroll_and_push(frame_count=frame_count, params=params)
    learner.enqueue_traj(trajectory)
    learner.run(max_iterations=1)
def setup_actors(num_actors):
    """Create (but do not start) one thread per actor.

    Args:
      num_actors: Number of actors, and therefore threads, to construct.

    Returns:
      A list of `threading.Thread` objects. Each targets `run_actor` with a
      single, freshly built actor as its only argument.
    """
    # Environment factory. Substitute your environment here!
    env_factory = catch.Catch

    # A throwaway environment instance is only needed to read the specs
    # that size the agent; every actor below shares this one agent.
    spec_env = env_factory()
    agent = agent_lib.Agent(
        spec_env.action_spec().num_values,
        spec_env.observation_spec(),
        haiku_nets.CatchNet,
    )

    def _make_thread(seed):
        # Each actor gets its own environment and uses its index as RNG seed.
        actor = actor_lib.Actor(
            agent,
            env_factory(),
            UNROLL_LENGTH,
            rng_seed=seed,
            logger=util.AbslLogger(),  # Provide your own logger here.
        )
        return threading.Thread(target=run_actor, args=(actor,))

    return [_make_thread(seed) for seed in range(num_actors)]
def test_encode_weights(self):
    """Check that learner params survive a proto3 encode/decode round trip.

    The frame count must match exactly, and the "w" entry of the
    "catch_net/linear" module must match to `assert_almost_equal` precision.
    """
    unroll_length = 20

    # Environment and an agent sized from the environment's specs.
    env = catch.Catch()
    agent = agent_lib.Agent(
        num_actions=env.action_spec().num_values,
        obs_spec=env.observation_spec(),
        net_factory=haiku_nets.CatchNet,
    )
    learner = learner_lib.Learner(
        agent=agent,
        rng_key=jax.random.PRNGKey(42),
        opt=optix.sgd(1e-2),
        batch_size=1,
        discount_factor=0.99,
        frames_per_iter=unroll_length,
    )
    actor = actor_lib.Actor(
        agent=agent,
        env=env,
        unroll_length=unroll_length,
    )

    # Encode the learner's current params, then decode them again.
    frame_count, params = learner.params_for_actor()
    encoded = util.proto3_weight_encoder(frame_count, params)
    decoded_frame_count, decoded_params = util.proto3_weight_decoder(encoded)

    self.assertEqual(frame_count, decoded_frame_count)
    decoded_w = decoded_params["catch_net/linear"]["w"]
    original_w = params["catch_net/linear"]["w"]
    np.testing.assert_almost_equal(decoded_w, original_w)

    # Finally run one unroll with the original (not the decoded) params;
    # the result is not inspected.
    actor.unroll_and_push(frame_count, params)
def setup_learner():
    """Build and return the learner used in the distributed setting.

    Returns:
      A `learner_lib.Learner` wrapping a Catch agent, configured with an
      RMSProp optimizer and the module-level BATCH_SIZE, DISCOUNT_FACTOR and
      FRAMES_PER_ITER settings.
    """
    # Environment factory. Substitute your environment here!
    env_factory = catch.Catch

    # A sample environment is only needed to read the specs for the agent.
    spec_env = env_factory()
    agent = agent_lib.Agent(
        spec_env.action_spec().num_values,
        spec_env.observation_spec(),
        haiku_nets.CatchNet,
    )

    optimizer = optix.rmsprop(1e-1, decay=0.99, eps=0.1)
    rng_key = jax.random.PRNGKey(428)

    return learner_lib.Learner(
        agent,
        rng_key,
        optimizer,
        BATCH_SIZE,
        DISCOUNT_FACTOR,
        FRAMES_PER_ITER,
        max_abs_reward=1.,
        logger=util.AbslLogger(),  # Provide your own logger here.
    )
def setUp(self):
    """Prepare a Catch environment, a matching agent, and initial params.

    After this runs the test case exposes `env`, `action_spec`,
    `num_actions`, `obs_spec`, `agent`, `key` and `initial_params`.
    """
    super(CatchTest, self).setUp()

    # Environment and the specs that size the agent.
    env = catch.Catch()
    action_spec = env.action_spec()
    num_actions = action_spec.num_values
    obs_spec = env.observation_spec()

    agent = agent_lib.Agent(
        num_actions=num_actions,
        obs_spec=obs_spec,
        net_factory=haiku_nets.CatchNet,
    )

    # Split the seed key: one half is kept on the test case, the other is
    # consumed to initialise the agent's parameters.
    root_key = jax.random.PRNGKey(42)
    carry_key, init_key = jax.random.split(root_key)

    self.env = env
    self.action_spec = action_spec
    self.num_actions = num_actions
    self.obs_spec = obs_spec
    self.agent = agent
    self.key = carry_key
    self.initial_params = agent.initial_params(init_key)