def test_integration(self):
  """End-to-end smoke test: one actor rollout feeds one learner iteration."""
  environment = catch.Catch()
  n_actions = environment.action_spec().num_values
  observation_spec = environment.observation_spec()
  test_agent = agent_lib.Agent(
      num_actions=n_actions,
      obs_spec=observation_spec,
      net_factory=haiku_nets.CatchNet,
  )
  rollout_len = 20
  test_learner = learner_lib.Learner(
      agent=test_agent,
      rng_key=jax.random.PRNGKey(42),
      opt=optix.sgd(1e-2),
      batch_size=1,
      discount_factor=0.99,
      frames_per_iter=rollout_len,
  )
  test_actor = actor_lib.Actor(
      agent=test_agent,
      env=environment,
      unroll_length=rollout_len,
  )
  # Pull the current weights, roll out one trajectory, and train on it.
  frame_count, params = test_learner.params_for_actor()
  rollout = test_actor.unroll_and_push(frame_count=frame_count, params=params)
  test_learner.enqueue_traj(rollout)
  test_learner.run(max_iterations=1)
def run_actor(
    agent: Agent,
    rng_key: jnp.ndarray,
    get_params: Callable[[], hk.Params],
    enqueue_traj: Callable[[Transition], None],
    unroll_len: int,
    num_trajectories: int,
):
  """Runs an actor to produce num_trajectories trajectories."""
  env = catch.Catch()
  timestep = env.reset()
  pending = []
  for traj_idx in range(num_trajectories):
    params = get_params()
    # The very first rollout takes one extra step so that consecutive
    # unrolls overlap by exactly one timestep.
    steps_this_unroll = unroll_len + (1 if traj_idx == 0 else 0)
    for _ in range(steps_this_unroll):
      rng_key, step_key = jax.random.split(rng_key)
      timestep = preprocess_step(timestep)
      action, logits = agent.step(params, step_key, timestep)
      pending.append(Transition(timestep, action, logits))
      timestep = env.step(action)
      if timestep.step_type == dm_env.StepType.LAST:
        logging.log_every_n(logging.INFO, 'Episode ended with reward: %s', 5,
                            timestep.reward)
    # Stack the per-step transitions along a new leading axis and hand the
    # batched trajectory to the learner.
    enqueue_traj(jax.tree_multimap(lambda *ts: np.stack(ts), *pending))
    # Carry the final timestep over as the start of the next unroll.
    pending = pending[-1:]
def test_encode_weights(self):
  """Round-trips learner weights through the proto3 encoder/decoder."""
  environment = catch.Catch()
  n_actions = environment.action_spec().num_values
  observation_spec = environment.observation_spec()
  test_agent = agent_lib.Agent(
      num_actions=n_actions,
      obs_spec=observation_spec,
      net_factory=haiku_nets.CatchNet,
  )
  rollout_len = 20
  test_learner = learner_lib.Learner(
      agent=test_agent,
      rng_key=jax.random.PRNGKey(42),
      opt=optix.sgd(1e-2),
      batch_size=1,
      discount_factor=0.99,
      frames_per_iter=rollout_len,
  )
  test_actor = actor_lib.Actor(
      agent=test_agent,
      env=environment,
      unroll_length=rollout_len,
  )
  frame_count, params = test_learner.params_for_actor()
  # Encode to proto and decode back; both halves must survive the trip.
  proto_weight = util.proto3_weight_encoder(frame_count, params)
  decoded_frame_count, decoded_params = util.proto3_weight_decoder(
      proto_weight)
  self.assertEqual(frame_count, decoded_frame_count)
  np.testing.assert_almost_equal(
      decoded_params["catch_net/linear"]["w"],
      params["catch_net/linear"]["w"])
  # The (original) weights must still be usable by an actor.
  act_out = test_actor.unroll_and_push(frame_count, params)
def load(reward_scale, seed):
  """Load a catch experiment with the prescribed settings."""
  base_env = catch.Catch(seed=seed)
  scaled_env = wrappers.RewardScale(
      env=base_env, reward_scale=reward_scale, seed=seed)
  # bsuite sweeps read the episode budget off the environment object.
  scaled_env.bsuite_num_episodes = sweep.NUM_EPISODES
  return scaled_env
def setUp(self):
  """Builds a Catch env, its specs, an agent, and initial parameters."""
  super(CatchTest, self).setUp()
  self.env = catch.Catch()
  self.action_spec = self.env.action_spec()
  self.num_actions = self.action_spec.num_values
  self.obs_spec = self.env.observation_spec()
  self.agent = agent_lib.Agent(
      num_actions=self.num_actions,
      obs_spec=self.obs_spec,
      net_factory=haiku_nets.CatchNet,
  )
  # Split a fixed seed: keep one half for later use, spend the other on init.
  self.key, init_key = jax.random.split(jax.random.PRNGKey(42))
  self.initial_params = self.agent.initial_params(init_key)
def main(_):
  """Spins up actor threads and runs the learner loop on the main thread."""
  # A sample environment supplies the specs needed to build the network.
  env = catch.Catch()
  num_actions = env.action_spec().num_values
  net = hk.transform(lambda ts: SimpleNet(num_actions)(ts))  # pylint: disable=unnecessary-lambda

  # Agent wraps the network's apply fn; Learner wraps the optimizer update.
  agent = Agent(net.apply)
  opt = optix.rmsprop(1e-1, decay=0.99, eps=0.1)
  learner = Learner(agent, opt.update)

  # Initialize parameters and optimizer state from a batched sample timestep.
  sample_ts = preprocess_step(env.reset())
  batched_ts = jax.tree_map(lambda t: np.expand_dims(t, 0), sample_ts)
  params = jax.jit(net.init)(jax.random.PRNGKey(428), batched_ts)
  opt_state = opt.init(params)

  # Actors fetch the latest params through this closure (late-bound, so they
  # see every reassignment of `params` below) and push trajectories onto `q`.
  batch_size = 2
  q = queue.Queue(maxsize=batch_size)

  def current_params():
    return params

  def dequeue():
    # Stack batch_size trajectories along axis 1 (time-major layout).
    batch = [q.get() for _ in range(batch_size)]
    stacked = jax.tree_multimap(lambda *ts: np.stack(ts, axis=1), *batch)
    return jax.device_put(stacked)

  # Start the actors.
  num_actors = 2
  trajectories_per_actor = 500
  unroll_len = 20
  for actor_id in range(num_actors):
    actor_args = (agent, jax.random.PRNGKey(actor_id), current_params, q.put,
                  unroll_len, trajectories_per_actor)
    threading.Thread(target=run_actor, args=actor_args).start()

  # Run the learner until every actor trajectory has been consumed.
  num_steps = num_actors * trajectories_per_actor // batch_size
  for _ in range(num_steps):
    params, opt_state = learner.update(params, opt_state, dequeue())
def main_loop(unused_arg):
  """Trains a Q-learning agent on Catch, then evaluates it greedily.

  Args:
    unused_arg: absl.app passes argv here; it is ignored.
  """
  env = catch.Catch(seed=FLAGS.seed)
  rng = hk.PRNGSequence(jax.random.PRNGKey(FLAGS.seed))

  # Build and initialize Q-network.
  num_actions = env.action_spec().num_values
  network = build_network(num_actions)
  sample_input = env.observation_spec().generate_value()
  net_params = network.init(next(rng), sample_input)

  # Build and initialize optimizer.
  optimizer = optix.adam(FLAGS.learning_rate)
  opt_state = optimizer.init(net_params)

  @jax.jit
  def policy(net_params, key, obs):
    """Sample action from epsilon-greedy policy."""
    q = network.apply(net_params, obs)
    a = rlax.epsilon_greedy(epsilon=FLAGS.epsilon).sample(key, q)
    # Returns both the Q-values (used later as the bootstrap target) and
    # the sampled action.
    return q, a

  @jax.jit
  def eval_policy(net_params, key, obs):
    """Sample action from greedy policy."""
    q = network.apply(net_params, obs)
    return rlax.greedy().sample(key, q)

  @jax.jit
  def update(net_params, opt_state, obs_tm1, a_tm1, r_t, discount_t, q_t):
    """Update network weights wrt Q-learning loss."""

    def q_learning_loss(net_params, obs_tm1, a_tm1, r_t, discount_t, q_t):
      q_tm1 = network.apply(net_params, obs_tm1)
      td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
      return rlax.l2_loss(td_error)

    # Differentiate the TD loss wrt the parameters and apply one optimizer
    # step.
    dloss_dtheta = jax.grad(q_learning_loss)(net_params, obs_tm1, a_tm1, r_t,
                                             discount_t, q_t)
    updates, opt_state = optimizer.update(dloss_dtheta, opt_state)
    net_params = optix.apply_updates(net_params, updates)
    return net_params, opt_state

  print(f"Training agent for {FLAGS.train_episodes} episodes...")
  for _ in range(FLAGS.train_episodes):
    timestep = env.reset()
    obs_tm1 = timestep.observation
    _, a_tm1 = policy(net_params, next(rng), obs_tm1)

    while not timestep.last():
      new_timestep = env.step(int(a_tm1))
      obs_t = new_timestep.observation

      # Sample action from agent policy.
      q_t, a_t = policy(net_params, next(rng), obs_t)

      # Update Q-values.
      r_t = new_timestep.reward
      # The env discount is scaled by the agent's discount factor; at episode
      # end the env discount is what zeroes the bootstrap term.
      discount_t = FLAGS.discount_factor * new_timestep.discount
      net_params, opt_state = update(net_params, opt_state, obs_tm1, a_tm1,
                                     r_t, discount_t, q_t)

      # Shift the (obs, action) pair one step forward for the next
      # transition.
      timestep = new_timestep
      obs_tm1 = obs_t
      a_tm1 = a_t

  print(f"Evaluating agent for {FLAGS.eval_episodes} episodes...")
  returns = 0.
  for _ in range(FLAGS.eval_episodes):
    timestep = env.reset()
    obs = timestep.observation
    while not timestep.last():
      action = eval_policy(net_params, next(rng), obs)
      timestep = env.step(int(action))
      obs = timestep.observation
    # Catch pays its reward on the final step only.
    returns += timestep.reward

  avg_returns = returns / FLAGS.eval_episodes
  print(f"Done! Average returns: {avg_returns} (range [-1.0, 1.0])")
def make_object_under_test(self):
  """Returns a Catch env wrapped to emit 84x84x4 image observations."""
  base_env = catch.Catch()
  return wrappers.ImageObservation(base_env, (84, 84, 4))
def make_object_under_test(self):
  """Returns a non-square Catch environment (10 rows by 5 columns)."""
  board_shape = dict(rows=10, columns=5)
  return catch.Catch(**board_shape)