def main(unused_arg):
  env = catch.Catch(seed=FLAGS.seed)
  epsilon_cfg = dict(
      init_value=FLAGS.epsilon_begin,
      end_value=FLAGS.epsilon_end,
      transition_steps=FLAGS.epsilon_steps,
      power=1.)
  agent = DQN(
      observation_spec=env.observation_spec(),
      action_spec=env.action_spec(),
      epsilon_cfg=epsilon_cfg,
      target_period=FLAGS.target_period,
      learning_rate=FLAGS.learning_rate,
  )
  accumulator = ReplayBuffer(FLAGS.replay_capacity)
  experiment.run_loop(
      agent=agent,
      environment=env,
      accumulator=accumulator,
      seed=FLAGS.seed,
      batch_size=FLAGS.batch_size,
      train_episodes=FLAGS.train_episodes,
      evaluate_every=FLAGS.evaluate_every,
      eval_episodes=FLAGS.eval_episodes,
  )
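# The `epsilon_cfg` dict above matches the signature of a polynomial schedule.
# A minimal sketch of the exploration schedule the agent presumably builds
# from it (an assumption about DQN's internals, shown for illustration only):
import optax

def make_epsilon_schedule(epsilon_cfg):
  # With power=1. the schedule decays linearly from init_value to end_value
  # over transition_steps steps, then holds end_value.
  return optax.polynomial_schedule(**epsilon_cfg)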
def test_integration(self):
  env = catch.Catch()
  action_spec = env.action_spec()
  num_actions = action_spec.num_values
  obs_spec = env.observation_spec()
  agent = agent_lib.Agent(
      num_actions=num_actions,
      obs_spec=obs_spec,
      net_factory=haiku_nets.CatchNet,
  )
  unroll_length = 20
  learner = learner_lib.Learner(
      agent=agent,
      rng_key=jax.random.PRNGKey(42),
      opt=optix.sgd(1e-2),
      batch_size=1,
      discount_factor=0.99,
      frames_per_iter=unroll_length,
  )
  actor = actor_lib.Actor(
      agent=agent,
      env=env,
      learner=learner,
      unroll_length=unroll_length,
  )
  frame_count, params = actor.pull_params()
  actor.unroll_and_push(frame_count=frame_count, params=params)
  learner.run(max_iterations=1)
def load(noise_scale, seed):
  """Load a catch experiment with the prescribed settings."""
  env = wrappers.RewardNoise(
      env=catch.Catch(seed=seed),
      noise_scale=noise_scale,
      seed=seed)
  env.bsuite_num_episodes = sweep.NUM_EPISODES
  return env
def test_checkpointing(self): """Tests whether checkpointing restores the state correctly.""" # Given an environment, and a model based on this environment. model = simulator.Simulator(catch.Catch()) num_actions = model.action_spec().num_values model.reset() # Now, we save a checkpoint. model.save_checkpoint() ts = model.step(1) # Step the model once and load the checkpoint. timestep = model.step(np.random.randint(num_actions)) model.load_checkpoint() self._check_equal(ts, model.step(1)) while not timestep.last(): timestep = model.step(np.random.randint(num_actions)) # The model should require a reset. self.assertTrue(model.needs_reset) # Once we load checkpoint, the model should no longer require reset. model.load_checkpoint() self.assertFalse(model.needs_reset) # Further steps should agree with the original environment state. self._check_equal(ts, model.step(1))
def test_catch(self, policy_type: Text):
  env = catch.Catch(rows=2, seed=1)
  num_actions = env.action_spec().num_values
  model = simulator.Simulator(env)
  eval_fn = lambda _: (np.ones(num_actions) / num_actions, 0.)

  timestep = env.reset()
  model.reset()

  search_policy = search.bfs if policy_type == 'bfs' else search.puct
  root = search.mcts(
      observation=timestep.observation,
      model=model,
      search_policy=search_policy,
      evaluation=eval_fn,
      num_simulations=100,
      num_actions=num_actions)
  values = np.array([c.value for c in root.children.values()])
  best_action = search.argmax(values)

  # Actions in Catch are 0 (left), 1 (stay), 2 (right); the best action
  # should move the paddle toward the ball.
  if env._paddle_x > env._ball_x:
    self.assertEqual(best_action, 0)
  if env._paddle_x == env._ball_x:
    self.assertEqual(best_action, 1)
  if env._paddle_x < env._ball_x:
    self.assertEqual(best_action, 2)
def run_actor(
    agent: Agent,
    rng_key: jnp.ndarray,
    get_params: Callable[[], hk.Params],
    enqueue_traj: Callable[[Transition], None],
    unroll_len: int,
    num_trajectories: int,
):
  """Runs an actor to produce num_trajectories trajectories."""
  env = catch.Catch()
  state = env.reset()
  traj = []

  for i in range(num_trajectories):
    params = get_params()
    # The first rollout is one step longer.
    for _ in range(unroll_len + int(i == 0)):
      rng_key, step_key = jax.random.split(rng_key)
      state = preprocess_step(state)
      action, logits = agent.step(params, step_key, state)
      transition = Transition(state, action, logits)
      traj.append(transition)
      state = env.step(action)
      if state.step_type == dm_env.StepType.LAST:
        logging.log_every_n(logging.INFO,
                            'Episode ended with reward: %s', 5, state.reward)

    # Stack and send the trajectory.
    stacked_traj = jax.tree_map(lambda *ts: np.stack(ts), *traj)
    enqueue_traj(stacked_traj)
    # Reset the trajectory, keeping the last timestep.
    traj = traj[-1:]
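# `preprocess_step` is defined elsewhere in the example. A minimal sketch,
# assuming it only needs to replace the `None` reward/discount that dm_env
# returns on the first timestep of an episode and cast fields to numpy arrays:
def preprocess_step(ts: dm_env.TimeStep) -> dm_env.TimeStep:
  if ts.reward is None:
    ts = ts._replace(reward=0.)
  if ts.discount is None:
    ts = ts._replace(discount=1.)
  return jax.tree_map(np.asarray, ts)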
def main(unused_arg):
  env = catch.Catch(seed=FLAGS.seed)
  agent = OnlineQ(env.observation_spec(), env.action_spec(),
                  FLAGS.learning_rate, FLAGS.epsilon)
  accumulator = TransitionAccumulator()
  experiment.run_loop(agent, env, accumulator, FLAGS.seed, 1,
                      FLAGS.train_episodes, FLAGS.evaluate_every,
                      FLAGS.eval_episodes)
def setUp(self):
  super(CatchTest, self).setUp()
  self.env = catch.Catch()
  self.action_spec = self.env.action_spec()
  self.num_actions = self.action_spec.num_values
  self.obs_spec = self.env.observation_spec()
  self.agent = agent_lib.Agent(
      num_actions=self.num_actions,
      obs_spec=self.obs_spec,
      net_factory=haiku_nets.CatchNet,
  )
  self.key = jax.random.PRNGKey(42)
  self.key, subkey = jax.random.split(self.key)
  self.initial_params = self.agent.initial_params(subkey)
def test_simulator_fidelity(self):
  """Tests whether the simulator matches the ground truth."""
  env = catch.Catch()
  num_actions = env.action_spec().num_values
  model = simulator.Simulator(env)

  for _ in range(10):
    true_timestep = env.reset()
    model_timestep = model.reset()
    self._check_equal(true_timestep, model_timestep)

    while not true_timestep.last():
      action = np.random.randint(num_actions)
      true_timestep = env.step(action)
      model_timestep = model.step(action)
      self._check_equal(true_timestep, model_timestep)
def run(*, trajectories_per_actor, num_actors, unroll_len):
  """Runs the example."""

  # Construct the agent network. We need a sample environment for its spec.
  env = catch.Catch()
  num_actions = env.action_spec().num_values
  net = hk.without_apply_rng(
      hk.transform(lambda ts: SimpleNet(num_actions)(ts)))  # pylint: disable=unnecessary-lambda

  # Construct the agent and learner.
  agent = Agent(net.apply)
  opt = optax.rmsprop(5e-3, decay=0.99, eps=1e-7)
  learner = Learner(agent, opt.update)

  # Initialize the optimizer state.
  sample_ts = env.reset()
  sample_ts = preprocess_step(sample_ts)
  ts_with_batch = jax.tree_map(lambda t: np.expand_dims(t, 0), sample_ts)
  params = jax.jit(net.init)(jax.random.PRNGKey(428), ts_with_batch)
  opt_state = opt.init(params)

  # Create accessor and queueing functions.
  current_params = lambda: params
  batch_size = 2
  q = queue.Queue(maxsize=batch_size)

  def dequeue():
    batch = []
    for _ in range(batch_size):
      batch.append(q.get())
    batch = jax.tree_map(lambda *ts: np.stack(ts, axis=1), *batch)
    return jax.device_put(batch)

  # Start the actors.
  for i in range(num_actors):
    key = jax.random.PRNGKey(i)
    args = (agent, key, current_params, q.put, unroll_len,
            trajectories_per_actor)
    threading.Thread(target=run_actor, args=args).start()

  # Run the learner.
  num_steps = num_actors * trajectories_per_actor // batch_size
  for _ in range(num_steps):
    traj = dequeue()
    params, opt_state = learner.update(params, opt_state, traj)
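# `SimpleNet` is not shown in this excerpt. A sketch of a policy-and-value
# network consistent with how `Agent` and the lambda above use it; the layer
# sizes and exact structure here are assumptions, not the example's actual
# definition:
class SimpleNet(hk.Module):
  """Feeds the flattened board through an MLP; returns (logits, value)."""

  def __init__(self, num_actions: int):
    super().__init__()
    self._num_actions = num_actions

  def __call__(self, timestep):
    hidden = hk.Sequential([
        hk.Flatten(),
        hk.Linear(128),
        jax.nn.relu,
    ])(timestep.observation)
    logits = hk.Linear(self._num_actions)(hidden)
    value = hk.Linear(1)(hidden)
    return logits, jnp.squeeze(value, axis=-1)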
def test_checkpointing(self):
  env = catch.Catch()
  num_actions = env.action_spec().num_values
  model = simulator.Simulator(env)
  env.reset()
  model.reset()

  env.step(2)
  timestep = model.step(2)
  model.save_checkpoint()

  while not timestep.last():
    timestep = model.step(np.random.randint(num_actions))

  model.load_checkpoint()
  true_timestep = env.step(2)
  model_timestep = model.step(2)
  self._check_equal(true_timestep, model_timestep)
def main(unused_arg):
  env = catch.Catch(seed=FLAGS.seed)
  agent = OnlineQLambda(
      observation_spec=env.observation_spec(),
      action_spec=env.action_spec(),
      num_hidden_units=FLAGS.num_hidden_units,
      epsilon=FLAGS.epsilon,
      lambda_=FLAGS.lambda_,
      learning_rate=FLAGS.learning_rate)
  accumulator = SequenceAccumulator(length=FLAGS.sequence_length)
  experiment.run_loop(
      agent=agent,
      environment=env,
      accumulator=accumulator,
      seed=FLAGS.seed,
      batch_size=1,
      train_episodes=FLAGS.train_episodes,
      evaluate_every=FLAGS.evaluate_every,
      eval_episodes=FLAGS.eval_episodes,
  )
def main(unused_arg):
  env = catch.Catch(seed=FLAGS.seed)
  env = wrappers.RewardScale(env, reward_scale=FLAGS.reward_scale)
  agent = PopArtAgent(
      observation_spec=env.observation_spec(),
      action_spec=env.action_spec(),
      num_hidden_units=FLAGS.num_hidden_units,
      epsilon=FLAGS.epsilon,
      learning_rate=FLAGS.learning_rate,
      pop_art_step_size=FLAGS.pop_art_step_size,
  )
  accumulator = TransitionAccumulator()
  experiment.run_loop(
      agent=agent,
      environment=env,
      accumulator=accumulator,
      seed=FLAGS.seed,
      batch_size=1,
      train_episodes=FLAGS.train_episodes,
      evaluate_every=FLAGS.evaluate_every,
      eval_episodes=FLAGS.eval_episodes,
  )
def test_simulator_fidelity(self):
  """Tests whether the simulator matches the ground truth."""
  # Given an environment.
  env = catch.Catch()

  # If we instantiate a simulator 'model' of this environment.
  model = simulator.Simulator(env)

  # Then the model and environment should always agree as we step them.
  num_actions = env.action_spec().num_values
  for _ in range(10):
    true_timestep = env.reset()
    self.assertTrue(model.needs_reset)
    model_timestep = model.reset()
    self.assertFalse(model.needs_reset)
    self._check_equal(true_timestep, model_timestep)

    while not true_timestep.last():
      action = np.random.randint(num_actions)
      true_timestep = env.step(action)
      model_timestep = model.step(action)
      self._check_equal(true_timestep, model_timestep)
def main_loop(unused_arg):
  env = catch.Catch(seed=FLAGS.seed)
  rng = hk.PRNGSequence(jax.random.PRNGKey(FLAGS.seed))

  # Build and initialize Q-network.
  num_actions = env.action_spec().num_values
  network = build_network(num_actions)
  sample_input = env.observation_spec().generate_value()
  net_params = network.init(next(rng), sample_input)

  # Build and initialize optimizer.
  optimizer = optix.adam(FLAGS.learning_rate)
  opt_state = optimizer.init(net_params)

  @jax.jit
  def policy(net_params, key, obs):
    """Sample action from epsilon-greedy policy."""
    q = network.apply(net_params, obs)
    a = rlax.epsilon_greedy(epsilon=FLAGS.epsilon).sample(key, q)
    return q, a

  @jax.jit
  def eval_policy(net_params, key, obs):
    """Sample action from greedy policy."""
    q = network.apply(net_params, obs)
    return rlax.greedy().sample(key, q)

  @jax.jit
  def update(net_params, opt_state, obs_tm1, a_tm1, r_t, discount_t, q_t):
    """Update network weights wrt Q-learning loss."""

    def q_learning_loss(net_params, obs_tm1, a_tm1, r_t, discount_t, q_t):
      q_tm1 = network.apply(net_params, obs_tm1)
      td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
      return rlax.l2_loss(td_error)

    dloss_dtheta = jax.grad(q_learning_loss)(net_params, obs_tm1, a_tm1, r_t,
                                             discount_t, q_t)
    updates, opt_state = optimizer.update(dloss_dtheta, opt_state)
    net_params = optix.apply_updates(net_params, updates)
    return net_params, opt_state

  print(f"Training agent for {FLAGS.train_episodes} episodes")
  print("Returns range [-1.0, 1.0]")

  for episode in range(FLAGS.train_episodes):
    timestep = env.reset()
    obs_tm1 = timestep.observation
    _, a_tm1 = policy(net_params, next(rng), obs_tm1)

    while not timestep.last():
      new_timestep = env.step(int(a_tm1))
      obs_t = new_timestep.observation

      # Sample action from stochastic policy.
      q_t, a_t = policy(net_params, next(rng), obs_t)

      # Update Q-values.
      r_t = new_timestep.reward
      discount_t = FLAGS.discount_factor * new_timestep.discount
      net_params, opt_state = update(net_params, opt_state, obs_tm1, a_tm1,
                                     r_t, discount_t, q_t)

      timestep = new_timestep
      obs_tm1 = obs_t
      a_tm1 = a_t

    if not episode % FLAGS.evaluate_every:
      # Evaluate agent with deterministic policy.
      returns = 0.
      for _ in range(FLAGS.eval_episodes):
        timestep = env.reset()
        obs = timestep.observation
        while not timestep.last():
          action = eval_policy(net_params, next(rng), obs)
          timestep = env.step(int(action))
          obs = timestep.observation
        returns += timestep.reward

      avg_returns = returns / FLAGS.eval_episodes
      print(f"Episode {episode:4d}: Average returns: {avg_returns:.2f}")
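# `build_network` is defined elsewhere in the example. A minimal sketch that
# matches its use above (init takes a rng key, apply does not); the hidden
# width of 50 is an assumption:
def build_network(num_actions: int) -> hk.Transformed:

  def q_net(obs):
    flat = jnp.ravel(obs)  # flatten the Catch board to a feature vector
    return hk.nets.MLP([50, num_actions])(flat)

  return hk.without_apply_rng(hk.transform(q_net))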
def make_object_under_test(self):
  return catch.Catch(rows=10, columns=5)
def make_object_under_test(self):
  env = catch.Catch()
  return wrappers.ImageObservation(env, (84, 84, 4))