Example #1
def main(unused_arg):
    env = catch.Catch(seed=FLAGS.seed)
    epsilon_cfg = dict(init_value=FLAGS.epsilon_begin,
                       end_value=FLAGS.epsilon_end,
                       transition_steps=FLAGS.epsilon_steps,
                       power=1.)
    agent = DQN(
        observation_spec=env.observation_spec(),
        action_spec=env.action_spec(),
        epsilon_cfg=epsilon_cfg,
        target_period=FLAGS.target_period,
        learning_rate=FLAGS.learning_rate,
    )

    accumulator = ReplayBuffer(FLAGS.replay_capacity)
    experiment.run_loop(
        agent=agent,
        environment=env,
        accumulator=accumulator,
        seed=FLAGS.seed,
        batch_size=FLAGS.batch_size,
        train_episodes=FLAGS.train_episodes,
        evaluate_every=FLAGS.evaluate_every,
        eval_episodes=FLAGS.eval_episodes,
    )
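
The epsilon_cfg keys above mirror the arguments of a polynomial decay schedule. As a hypothetical illustration (assuming the agent turns this dict into an optax.polynomial_schedule, or its older optix equivalent, which is not shown in the snippet), the exploration rate would decay linearly from init_value to end_value over transition_steps:

import optax

epsilon_cfg = dict(init_value=1.0, end_value=0.01, transition_steps=1000, power=1.)
epsilon_by_step = optax.polynomial_schedule(**epsilon_cfg)
print(epsilon_by_step(0))     # 1.0 at the start of training
print(epsilon_by_step(500))   # ~0.505, halfway through the linear decay
print(epsilon_by_step(2000))  # 0.01, clipped at end_value after transition_steps
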
Example #2
    def test_integration(self):
        env = catch.Catch()
        action_spec = env.action_spec()
        num_actions = action_spec.num_values
        obs_spec = env.observation_spec()
        agent = agent_lib.Agent(
            num_actions=num_actions,
            obs_spec=obs_spec,
            net_factory=haiku_nets.CatchNet,
        )
        unroll_length = 20
        learner = learner_lib.Learner(
            agent=agent,
            rng_key=jax.random.PRNGKey(42),
            opt=optix.sgd(1e-2),
            batch_size=1,
            discount_factor=0.99,
            frames_per_iter=unroll_length,
        )
        actor = actor_lib.Actor(
            agent=agent,
            env=env,
            learner=learner,
            unroll_length=unroll_length,
        )
        frame_count, params = actor.pull_params()
        actor.unroll_and_push(frame_count=frame_count, params=params)
        learner.run(max_iterations=1)
Example #3
def load(noise_scale, seed):
    """Load a catch experiment with the prescribed settings."""
    env = wrappers.RewardNoise(env=catch.Catch(seed=seed),
                               noise_scale=noise_scale,
                               seed=seed)
    env.bsuite_num_episodes = sweep.NUM_EPISODES
    return env
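
A brief usage sketch of load (the noise_scale and seed values here are illustrative): the returned object is a standard dm_env environment, wrapped so that each reward carries noise of the configured scale.

env = load(noise_scale=0.1, seed=42)
timestep = env.reset()
while not timestep.last():
    timestep = env.step(1)  # action 1 is the no-op move in Catch
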
Example #4
    def test_checkpointing(self):
        """Tests whether checkpointing restores the state correctly."""
        # Given an environment, and a model based on this environment.
        model = simulator.Simulator(catch.Catch())
        num_actions = model.action_spec().num_values

        model.reset()

        # Now, we save a checkpoint.
        model.save_checkpoint()

        ts = model.step(1)

        # Step the model once and load the checkpoint.
        timestep = model.step(np.random.randint(num_actions))
        model.load_checkpoint()
        self._check_equal(ts, model.step(1))

        while not timestep.last():
            timestep = model.step(np.random.randint(num_actions))

        # The model should require a reset.
        self.assertTrue(model.needs_reset)

        # Once we load the checkpoint, the model should no longer require a reset.
        model.load_checkpoint()
        self.assertFalse(model.needs_reset)

        # Further steps should agree with the original environment state.
        self._check_equal(ts, model.step(1))
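
The _check_equal helper used in this and the following simulator tests is not shown in the listing. One plausible implementation (a sketch, not the original helper) compares the dm_env TimeStep fields one by one:

import numpy as np

def check_timesteps_equal(test_case, ts_a, ts_b):
    """Asserts that two dm_env TimeSteps agree field by field."""
    test_case.assertEqual(ts_a.step_type, ts_b.step_type)
    test_case.assertEqual(ts_a.reward, ts_b.reward)
    test_case.assertEqual(ts_a.discount, ts_b.discount)
    np.testing.assert_array_equal(ts_a.observation, ts_b.observation)
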
Example #5
    def test_catch(self, policy_type: Text):
        env = catch.Catch(rows=2, seed=1)
        num_actions = env.action_spec().num_values
        model = simulator.Simulator(env)
        eval_fn = lambda _: (np.ones(num_actions) / num_actions, 0.)

        timestep = env.reset()
        model.reset()

        search_policy = search.bfs if policy_type == 'bfs' else search.puct

        root = search.mcts(observation=timestep.observation,
                           model=model,
                           search_policy=search_policy,
                           evaluation=eval_fn,
                           num_simulations=100,
                           num_actions=num_actions)

        values = np.array([c.value for c in root.children.values()])
        best_action = search.argmax(values)

        if env._paddle_x > env._ball_x:
            self.assertEqual(best_action, 0)
        if env._paddle_x == env._ball_x:
            self.assertEqual(best_action, 1)
        if env._paddle_x < env._ball_x:
            self.assertEqual(best_action, 2)
Example #6
def run_actor(
    agent: Agent,
    rng_key: jnp.ndarray,
    get_params: Callable[[], hk.Params],
    enqueue_traj: Callable[[Transition], None],
    unroll_len: int,
    num_trajectories: int,
):
    """Runs an actor to produce num_trajectories trajectories."""
    env = catch.Catch()
    state = env.reset()
    traj = []

    for i in range(num_trajectories):
        params = get_params()
        # The first rollout is one step longer.
        for _ in range(unroll_len + int(i == 0)):
            rng_key, step_key = jax.random.split(rng_key)
            state = preprocess_step(state)
            action, logits = agent.step(params, step_key, state)
            transition = Transition(state, action, logits)
            traj.append(transition)
            state = env.step(action)
            if state.step_type == dm_env.StepType.LAST:
                logging.log_every_n(logging.INFO,
                                    'Episode ended with reward: %s', 5,
                                    state.reward)

        # Stack and send the trajectory.
        stacked_traj = jax.tree_map(lambda *ts: np.stack(ts), *traj)
        enqueue_traj(stacked_traj)
        # Reset the trajectory, keeping the last timestep.
        traj = traj[-1:]
Example #7
def main(unused_arg):
    env = catch.Catch(seed=FLAGS.seed)
    agent = OnlineQ(env.observation_spec(), env.action_spec(),
                    FLAGS.learning_rate, FLAGS.epsilon)

    accumulator = TransitionAccumulator()
    experiment.run_loop(agent, env, accumulator, FLAGS.seed, 1,
                        FLAGS.train_episodes, FLAGS.evaluate_every,
                        FLAGS.eval_episodes)
Example #8
    def setUp(self):
        super(CatchTest, self).setUp()
        self.env = catch.Catch()
        self.action_spec = self.env.action_spec()
        self.num_actions = self.action_spec.num_values
        self.obs_spec = self.env.observation_spec()
        self.agent = agent_lib.Agent(
            num_actions=self.num_actions,
            obs_spec=self.obs_spec,
            net_factory=haiku_nets.CatchNet,
        )

        self.key = jax.random.PRNGKey(42)
        self.key, subkey = jax.random.split(self.key)
        self.initial_params = self.agent.initial_params(subkey)
Example #9
    def test_simulator_fidelity(self):
        """Tests whether the simulator match the ground truth."""
        env = catch.Catch()
        num_actions = env.action_spec().num_values
        model = simulator.Simulator(env)

        for _ in range(10):
            true_timestep = env.reset()
            model_timestep = model.reset()
            self._check_equal(true_timestep, model_timestep)

            while not true_timestep.last():
                action = np.random.randint(num_actions)
                true_timestep = env.step(action)
                model_timestep = model.step(action)
                self._check_equal(true_timestep, model_timestep)
Example #10
def run(*, trajectories_per_actor, num_actors, unroll_len):
    """Runs the example."""

    # Construct the agent network. We need a sample environment for its spec.
    env = catch.Catch()
    num_actions = env.action_spec().num_values
    net = hk.without_apply_rng(
        hk.transform(lambda ts: SimpleNet(num_actions)(ts)))  # pylint: disable=unnecessary-lambda

    # Construct the agent and learner.
    agent = Agent(net.apply)
    opt = optax.rmsprop(5e-3, decay=0.99, eps=1e-7)
    learner = Learner(agent, opt.update)

    # Initialize the optimizer state.
    sample_ts = env.reset()
    sample_ts = preprocess_step(sample_ts)
    ts_with_batch = jax.tree_map(lambda t: np.expand_dims(t, 0), sample_ts)
    params = jax.jit(net.init)(jax.random.PRNGKey(428), ts_with_batch)
    opt_state = opt.init(params)

    # Create accessor and queueing functions.
    current_params = lambda: params
    batch_size = 2
    q = queue.Queue(maxsize=batch_size)

    def dequeue():
        batch = []
        for _ in range(batch_size):
            batch.append(q.get())
        batch = jax.tree_map(lambda *ts: np.stack(ts, axis=1), *batch)
        return jax.device_put(batch)

    # Start the actors.
    for i in range(num_actors):
        key = jax.random.PRNGKey(i)
        args = (agent, key, current_params, q.put, unroll_len,
                trajectories_per_actor)
        threading.Thread(target=run_actor, args=args).start()

    # Run the learner.
    num_steps = num_actors * trajectories_per_actor // batch_size
    for i in range(num_steps):
        traj = dequeue()
        params, opt_state = learner.update(params, opt_state, traj)
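
For reference, a hypothetical invocation of the run() entry point above; the argument values are illustrative rather than taken from the original program's flags.

if __name__ == '__main__':
    run(trajectories_per_actor=500, num_actors=2, unroll_len=20)
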
Example #11
    def test_checkpointing(self):
        env = catch.Catch()
        num_actions = env.action_spec().num_values
        model = simulator.Simulator(env)

        env.reset()
        model.reset()

        env.step(2)
        timestep = model.step(2)
        model.save_checkpoint()
        while not timestep.last():
            timestep = model.step(np.random.randint(num_actions))
        model.load_checkpoint()

        true_timestep = env.step(2)
        model_timestep = model.step(2)
        self._check_equal(true_timestep, model_timestep)
Example #12
def main(unused_arg):
    env = catch.Catch(seed=FLAGS.seed)
    agent = OnlineQLambda(observation_spec=env.observation_spec(),
                          action_spec=env.action_spec(),
                          num_hidden_units=FLAGS.num_hidden_units,
                          epsilon=FLAGS.epsilon,
                          lambda_=FLAGS.lambda_,
                          learning_rate=FLAGS.learning_rate)

    accumulator = SequenceAccumulator(length=FLAGS.sequence_length)
    experiment.run_loop(
        agent=agent,
        environment=env,
        accumulator=accumulator,
        seed=FLAGS.seed,
        batch_size=1,
        train_episodes=FLAGS.train_episodes,
        evaluate_every=FLAGS.evaluate_every,
        eval_episodes=FLAGS.eval_episodes,
    )
Example #13
def main(unused_arg):
    env = catch.Catch(seed=FLAGS.seed)
    env = wrappers.RewardScale(env, reward_scale=FLAGS.reward_scale)
    agent = PopArtAgent(
        observation_spec=env.observation_spec(),
        action_spec=env.action_spec(),
        num_hidden_units=FLAGS.num_hidden_units,
        epsilon=FLAGS.epsilon,
        learning_rate=FLAGS.learning_rate,
        pop_art_step_size=FLAGS.pop_art_step_size,
    )

    accumulator = TransitionAccumulator()
    experiment.run_loop(
        agent=agent,
        environment=env,
        accumulator=accumulator,
        seed=FLAGS.seed,
        batch_size=1,
        train_episodes=FLAGS.train_episodes,
        evaluate_every=FLAGS.evaluate_every,
        eval_episodes=FLAGS.eval_episodes,
    )
Example #14
    def test_simulator_fidelity(self):
        """Tests whether the simulator match the ground truth."""

        # Given an environment.
        env = catch.Catch()

        # If we instantiate a simulator 'model' of this environment.
        model = simulator.Simulator(env)

        # Then the model and environment should always agree as we step them.
        num_actions = env.action_spec().num_values
        for _ in range(10):
            true_timestep = env.reset()
            self.assertTrue(model.needs_reset)
            model_timestep = model.reset()
            self.assertFalse(model.needs_reset)
            self._check_equal(true_timestep, model_timestep)

            while not true_timestep.last():
                action = np.random.randint(num_actions)
                true_timestep = env.step(action)
                model_timestep = model.step(action)
                self._check_equal(true_timestep, model_timestep)
Exemple #15
0
def main_loop(unused_arg):
    env = catch.Catch(seed=FLAGS.seed)
    rng = hk.PRNGSequence(jax.random.PRNGKey(FLAGS.seed))

    # Build and initialize Q-network.
    num_actions = env.action_spec().num_values
    network = build_network(num_actions)
    sample_input = env.observation_spec().generate_value()
    net_params = network.init(next(rng), sample_input)

    # Build and initialize optimizer.
    optimizer = optix.adam(FLAGS.learning_rate)
    opt_state = optimizer.init(net_params)

    @jax.jit
    def policy(net_params, key, obs):
        """Sample action from epsilon-greedy policy."""
        q = network.apply(net_params, obs)
        a = rlax.epsilon_greedy(epsilon=FLAGS.epsilon).sample(key, q)
        return q, a

    @jax.jit
    def eval_policy(net_params, key, obs):
        """Sample action from greedy policy."""
        q = network.apply(net_params, obs)
        return rlax.greedy().sample(key, q)

    @jax.jit
    def update(net_params, opt_state, obs_tm1, a_tm1, r_t, discount_t, q_t):
        """Update network weights wrt Q-learning loss."""
        def q_learning_loss(net_params, obs_tm1, a_tm1, r_t, discount_t, q_t):
            q_tm1 = network.apply(net_params, obs_tm1)
            td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
            return rlax.l2_loss(td_error)

        dloss_dtheta = jax.grad(q_learning_loss)(net_params, obs_tm1, a_tm1,
                                                 r_t, discount_t, q_t)
        updates, opt_state = optimizer.update(dloss_dtheta, opt_state)
        net_params = optix.apply_updates(net_params, updates)
        return net_params, opt_state

    print(f"Training agent for {FLAGS.train_episodes} episodes")
    print("Returns range [-1.0, 1.0]")
    for episode in range(FLAGS.train_episodes):
        timestep = env.reset()
        obs_tm1 = timestep.observation

        _, a_tm1 = policy(net_params, next(rng), obs_tm1)

        while not timestep.last():
            new_timestep = env.step(int(a_tm1))
            obs_t = new_timestep.observation

            # Sample action from stochastic policy.
            q_t, a_t = policy(net_params, next(rng), obs_t)

            # Update Q-values.
            r_t = new_timestep.reward
            discount_t = FLAGS.discount_factor * new_timestep.discount
            net_params, opt_state = update(net_params, opt_state, obs_tm1,
                                           a_tm1, r_t, discount_t, q_t)

            timestep = new_timestep
            obs_tm1 = obs_t
            a_tm1 = a_t

        if not episode % FLAGS.evaluate_every:
            # Evaluate agent with deterministic policy.
            returns = 0.
            for _ in range(FLAGS.eval_episodes):
                timestep = env.reset()
                obs = timestep.observation

                while not timestep.last():
                    action = eval_policy(net_params, next(rng), obs)
                    timestep = env.step(int(action))
                    obs = timestep.observation
                    returns += timestep.reward

            avg_returns = returns / FLAGS.eval_episodes
            print(f"Episode {episode:4d}: Average returns: {avg_returns:.2f}")
Example #16
    def make_object_under_test(self):
        return catch.Catch(rows=10, columns=5)
Example #17
    def make_object_under_test(self):
        env = catch.Catch()
        return wrappers.ImageObservation(env, (84, 84, 4))
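
These make_object_under_test factories are typically paired with dm_env's environment conformance suite. A minimal sketch of how such a test case might be assembled (the class name is hypothetical):

from absl.testing import absltest
from bsuite.environments import catch
from dm_env import test_utils

class CatchConformanceTest(test_utils.EnvironmentTestMixin, absltest.TestCase):

    def make_object_under_test(self):
        return catch.Catch(rows=10, columns=5)

if __name__ == '__main__':
    absltest.main()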