Example #1
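An end-to-end smoke test: it wires an agent_lib.Agent, a learner_lib.Learner, and an actor_lib.Actor around catch.Catch, pushes a single 20-step unroll from the actor into the learner, and runs one learner iteration.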
def test_integration(self):
    # Build the environment and read off its action/observation specs.
    env = catch.Catch()
    action_spec = env.action_spec()
    num_actions = action_spec.num_values
    obs_spec = env.observation_spec()
    agent = agent_lib.Agent(
        num_actions=num_actions,
        obs_spec=obs_spec,
        net_factory=haiku_nets.CatchNet,
    )
    unroll_length = 20
    learner = learner_lib.Learner(
        agent=agent,
        rng_key=jax.random.PRNGKey(42),
        opt=optix.sgd(1e-2),
        batch_size=1,
        discount_factor=0.99,
        frames_per_iter=unroll_length,
    )
    actor = actor_lib.Actor(
        agent=agent,
        env=env,
        unroll_length=unroll_length,
    )
    # Generate one unroll with the current parameters and run a single
    # learner iteration on it.
    frame_count, params = learner.params_for_actor()
    act_out = actor.unroll_and_push(frame_count=frame_count, params=params)
    learner.enqueue_traj(act_out)
    learner.run(max_iterations=1)
Example #2
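An actor loop for an actor/learner setup: each iteration fetches the latest parameters, rolls the policy out in catch.Catch for unroll_len steps (one extra step on the very first unroll), stacks the collected Transitions along the time axis, and hands the trajectory to the learner via enqueue_traj, keeping the last transition so consecutive unrolls overlap by one step.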
def run_actor(
    agent: Agent,
    rng_key: jnp.ndarray,
    get_params: Callable[[], hk.Params],
    enqueue_traj: Callable[[Transition], None],
    unroll_len: int,
    num_trajectories: int,
):
  """Runs an actor to produce num_trajectories trajectories."""
  env = catch.Catch()
  state = env.reset()
  traj = []

  for i in range(num_trajectories):
    params = get_params()
    # The first rollout is one step longer.
    for _ in range(unroll_len + int(i == 0)):
      rng_key, step_key = jax.random.split(rng_key)
      state = preprocess_step(state)
      action, logits = agent.step(params, step_key, state)
      transition = Transition(state, action, logits)
      traj.append(transition)
      state = env.step(action)
      if state.step_type == dm_env.StepType.LAST:
        logging.log_every_n(logging.INFO, 'Episode ended with reward: %s', 5,
                            state.reward)

    # Stack and send the trajectory.
    stacked_traj = jax.tree_multimap(lambda *ts: np.stack(ts), *traj)
    enqueue_traj(stacked_traj)
    # Reset the trajectory, keeping the last timestep.
    traj = traj[-1:]
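The preprocess_step helper used above is not part of this snippet. A minimal sketch of such a helper, assuming its only job is to replace the None reward/discount of an initial timestep and convert every field to an array so transitions can be stacked later (an illustration, not the repository's actual implementation):

import dm_env
import jax
import numpy as np


def preprocess_step(ts: dm_env.TimeStep) -> dm_env.TimeStep:
  # The first timestep of an episode carries reward=None and discount=None;
  # substitute neutral values so the structure stays stackable.
  if ts.reward is None:
    ts = ts._replace(reward=0.)
  if ts.discount is None:
    ts = ts._replace(discount=1.)
  # Convert all leaves (step_type, reward, discount, observation) to arrays.
  return jax.tree_map(np.asarray, ts)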
Example #3
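The same Agent/Learner/Actor setup as Example #1, but here the test round-trips the learner's parameters through a proto3 weight encoder and decoder and checks that the frame count and a sample weight matrix survive the round trip.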
def test_encode_weights(self):
    env = catch.Catch()
    action_spec = env.action_spec()
    num_actions = action_spec.num_values
    obs_spec = env.observation_spec()
    agent = agent_lib.Agent(
        num_actions=num_actions,
        obs_spec=obs_spec,
        net_factory=haiku_nets.CatchNet,
    )
    unroll_length = 20
    learner = learner_lib.Learner(
        agent=agent,
        rng_key=jax.random.PRNGKey(42),
        opt=optix.sgd(1e-2),
        batch_size=1,
        discount_factor=0.99,
        frames_per_iter=unroll_length,
    )
    actor = actor_lib.Actor(
        agent=agent,
        env=env,
        unroll_length=unroll_length,
    )
    frame_count, params = learner.params_for_actor()
    # Round-trip the weights through the proto3 encoding and check that the
    # frame count and a sample weight matrix are preserved.
    proto_weight = util.proto3_weight_encoder(frame_count, params)
    decoded_frame_count, decoded_params = \
        util.proto3_weight_decoder(proto_weight)
    self.assertEqual(frame_count, decoded_frame_count)
    np.testing.assert_almost_equal(decoded_params["catch_net/linear"]["w"],
                                   params["catch_net/linear"]["w"])
    act_out = actor.unroll_and_push(frame_count, params)
Example #4
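A bsuite-style load function: it wraps catch.Catch in a RewardScale wrapper and attaches the episode budget prescribed by the experiment's sweep.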
def load(reward_scale, seed):
    """Load a catch experiment with the prescribed settings."""
    env = wrappers.RewardScale(env=catch.Catch(seed=seed),
                               reward_scale=reward_scale,
                               seed=seed)
    env.bsuite_num_episodes = sweep.NUM_EPISODES
    return env
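Once loaded, the wrapped environment is driven through the usual dm_env loop. A small usage sketch (the reward scale, seed, and fixed action are arbitrary choices for illustration; in bsuite's Catch the actions 0, 1, 2 map to left, no-op, right):

env = load(reward_scale=2.0, seed=0)
timestep = env.reset()
episode_return = 0.0
while not timestep.last():
  # Always moving right is only a placeholder policy for this sketch.
  timestep = env.step(2)
  episode_return += timestep.reward
print(episode_return)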
Example #5
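A test fixture: setUp builds the environment, reads its specs, constructs the agent, and draws an initial parameter set from a fixed PRNG key so individual test cases can reuse them.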
def setUp(self):
    super(CatchTest, self).setUp()
    self.env = catch.Catch()
    self.action_spec = self.env.action_spec()
    self.num_actions = self.action_spec.num_values
    self.obs_spec = self.env.observation_spec()
    self.agent = agent_lib.Agent(
        num_actions=self.num_actions,
        obs_spec=self.obs_spec,
        net_factory=haiku_nets.CatchNet,
    )

    self.key = jax.random.PRNGKey(42)
    self.key, subkey = jax.random.split(self.key)
    self.initial_params = self.agent.initial_params(subkey)
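With this fixture in place, a test case can step the agent directly. A hypothetical sketch of such a case, assuming Agent.step(params, key, timestep) returns an (action, logits) pair as in the actor loop of Example #2, and reusing a preprocess_step helper like the one sketched there:

def test_step_produces_valid_action(self):
    # Hypothetical test case; the Agent.step contract is assumed here, not
    # taken from the source.
    timestep = preprocess_step(self.env.reset())
    self.key, step_key = jax.random.split(self.key)
    action, logits = self.agent.step(self.initial_params, step_key, timestep)
    self.assertIn(int(action), range(self.num_actions))
    self.assertEqual(logits.shape[-1], self.num_actions)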
Example #6
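A single-process actor/learner main: it builds a Haiku-transformed SimpleNet and an RMSProp optimizer, initializes parameters on a batched sample timestep, starts two actor threads that push 20-step unrolls into a queue, and trains by dequeuing batches of two trajectories stacked along the batch axis.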
def main(_):

  # Construct the agent network. We need a sample environment for its spec.
  env = catch.Catch()
  num_actions = env.action_spec().num_values
  net = hk.transform(lambda ts: SimpleNet(num_actions)(ts))  # pylint: disable=unnecessary-lambda

  # Construct the agent and learner.
  agent = Agent(net.apply)
  opt = optix.rmsprop(1e-1, decay=0.99, eps=0.1)
  learner = Learner(agent, opt.update)

  # Initialize the optimizer state.
  sample_ts = env.reset()
  sample_ts = preprocess_step(sample_ts)
  ts_with_batch = jax.tree_map(lambda t: np.expand_dims(t, 0), sample_ts)
  params = jax.jit(net.init)(jax.random.PRNGKey(428), ts_with_batch)
  opt_state = opt.init(params)

  # Create accessor and queueing functions.
  current_params = lambda: params
  batch_size = 2
  q = queue.Queue(maxsize=batch_size)

  def dequeue():
    batch = []
    for _ in range(batch_size):
      batch.append(q.get())
    batch = jax.tree_multimap(lambda *ts: np.stack(ts, axis=1), *batch)
    return jax.device_put(batch)

  # Start the actors.
  num_actors = 2
  trajectories_per_actor = 500
  unroll_len = 20
  for i in range(num_actors):
    key = jax.random.PRNGKey(i)
    args = (agent, key, current_params, q.put, unroll_len,
            trajectories_per_actor)
    threading.Thread(target=run_actor, args=args).start()

  # Run the learner.
  num_steps = num_actors * trajectories_per_actor // batch_size
  for i in range(num_steps):
    traj = dequeue()
    params, opt_state = learner.update(params, opt_state, traj)
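SimpleNet, Agent and Learner are defined elsewhere. For orientation only, a policy-and-value network of the kind such a setup is typically paired with could look like the sketch below; the architecture, the module name, and the assumption that the learner wants per-step logits plus a value estimate are all illustrative rather than taken from the source:

import haiku as hk
import jax.numpy as jnp


class SimpleNet(hk.Module):
  """Illustrative policy/value network over flattened Catch boards."""

  def __init__(self, num_actions: int):
    super().__init__(name='simple_net')
    self._num_actions = num_actions

  def __call__(self, timestep):
    # Flatten the board observation and share a small MLP torso.
    hidden = hk.Sequential([hk.Flatten(), hk.nets.MLP([64])])(
        timestep.observation)
    logits = hk.Linear(self._num_actions)(hidden)
    value = hk.Linear(1)(hidden)
    return logits, jnp.squeeze(value, axis=-1)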
Example #7
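A complete online Q-learning loop on catch.Catch: a Haiku Q-network, an optix.adam optimizer, an rlax epsilon-greedy behaviour policy with a greedy evaluation policy, and a jitted update that minimizes the rlax.q_learning TD error under an l2 loss.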
def main_loop(unused_arg):
  env = catch.Catch(seed=FLAGS.seed)
  rng = hk.PRNGSequence(jax.random.PRNGKey(FLAGS.seed))

  # Build and initialize Q-network.
  num_actions = env.action_spec().num_values
  network = build_network(num_actions)
  sample_input = env.observation_spec().generate_value()
  net_params = network.init(next(rng), sample_input)

  # Build and initialize optimizer.
  optimizer = optix.adam(FLAGS.learning_rate)
  opt_state = optimizer.init(net_params)

  @jax.jit
  def policy(net_params, key, obs):
    """Sample action from epsilon-greedy policy."""
    q = network.apply(net_params, obs)
    a = rlax.epsilon_greedy(epsilon=FLAGS.epsilon).sample(key, q)
    return q, a

  @jax.jit
  def eval_policy(net_params, key, obs):
    """Sample action from greedy policy."""
    q = network.apply(net_params, obs)
    return rlax.greedy().sample(key, q)

  @jax.jit
  def update(net_params, opt_state, obs_tm1, a_tm1, r_t, discount_t, q_t):
    """Update network weights wrt Q-learning loss."""

    def q_learning_loss(net_params, obs_tm1, a_tm1, r_t, discount_t, q_t):
      q_tm1 = network.apply(net_params, obs_tm1)
      td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
      return rlax.l2_loss(td_error)

    dloss_dtheta = jax.grad(q_learning_loss)(net_params, obs_tm1, a_tm1, r_t,
                                             discount_t, q_t)
    updates, opt_state = optimizer.update(dloss_dtheta, opt_state)
    net_params = optix.apply_updates(net_params, updates)
    return net_params, opt_state

  print(f"Training agent for {FLAGS.train_episodes} episodes...")
  for _ in range(FLAGS.train_episodes):
    timestep = env.reset()
    obs_tm1 = timestep.observation

    _, a_tm1 = policy(net_params, next(rng), obs_tm1)

    while not timestep.last():
      new_timestep = env.step(int(a_tm1))
      obs_t = new_timestep.observation

      # Sample action from agent policy.
      q_t, a_t = policy(net_params, next(rng), obs_t)

      # Update Q-values.
      r_t = new_timestep.reward
      discount_t = FLAGS.discount_factor * new_timestep.discount
      net_params, opt_state = update(net_params, opt_state, obs_tm1, a_tm1, r_t,
                                     discount_t, q_t)

      timestep = new_timestep
      obs_tm1 = obs_t
      a_tm1 = a_t

  print(f"Evaluating agent for {FLAGS.eval_episodes} episodes...")
  returns = 0.
  for _ in range(FLAGS.eval_episodes):
    timestep = env.reset()
    obs = timestep.observation

    while not timestep.last():
      action = eval_policy(net_params, next(rng), obs)
      timestep = env.step(int(action))
      obs = timestep.observation
      returns += timestep.reward

  avg_returns = returns / FLAGS.eval_episodes
  print(f"Done! Average returns: {avg_returns} (range [-1.0, 1.0])")
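build_network is defined elsewhere. Since network.apply(net_params, obs) above is called without an RNG argument, it is presumably a Haiku transform without apply-time randomness; a minimal sketch along those lines (the hidden-layer size is an arbitrary choice, not the original's):

import haiku as hk
import jax.numpy as jnp


def build_network(num_actions: int) -> hk.Transformed:
  """Feed-forward Q-network over a single flattened Catch observation."""
  def q_net(obs):
    # Flatten the (rows, columns) board into one feature vector and map it
    # to a vector of num_actions Q-values.
    flat = jnp.ravel(obs)
    return hk.nets.MLP([32, num_actions])(flat)
  return hk.without_apply_rng(hk.transform(q_net))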
Example #8
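A conformance-test hook: the object under test is catch.Catch wrapped in an ImageObservation wrapper with target shape (84, 84, 4).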
def make_object_under_test(self):
    env = catch.Catch()
    return wrappers.ImageObservation(env, (84, 84, 4))
Example #9
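The same conformance-test hook, this time for a bare catch.Catch with a non-default 10x5 board.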
def make_object_under_test(self):
  return catch.Catch(rows=10, columns=5)
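A quick interactive check of the non-default board size (a sketch; it assumes the observation is the rows-by-columns board, which is Catch's standard observation):

env = catch.Catch(rows=10, columns=5)
timestep = env.reset()
print(timestep.observation.shape)  # Expected: (10, 5).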