def main(argv):
  """Trains Prioritized DQN agent on Atari."""
  del argv
  logging.info('Prioritized DQN on Atari on %s.',
               jax.lib.xla_bridge.get_backend().platform)
  random_state = np.random.RandomState(FLAGS.seed)
  rng_key = jax.random.PRNGKey(
      random_state.randint(-sys.maxsize - 1, sys.maxsize + 1, dtype=np.int64))

  if FLAGS.results_csv_path:
    writer = parts.CsvWriter(FLAGS.results_csv_path)
  else:
    writer = parts.NullWriter()

  def environment_builder():
    """Creates Atari environment."""
    env = gym_atari.GymAtari(
        FLAGS.environment_name, seed=random_state.randint(1, 2**32))
    return gym_atari.RandomNoopsEnvironmentWrapper(
        env,
        min_noop_steps=1,
        max_noop_steps=30,
        seed=random_state.randint(1, 2**32),
    )

  env = environment_builder()
  logging.info('Environment: %s', FLAGS.environment_name)
  logging.info('Action spec: %s', env.action_spec())
  logging.info('Observation spec: %s', env.observation_spec())
  num_actions = env.action_spec().num_values
  network_fn = networks.double_dqn_atari_network(num_actions)
  network = hk.transform(network_fn)

  def preprocessor_builder():
    return processors.atari(
        additional_discount=FLAGS.additional_discount,
        max_abs_reward=FLAGS.max_abs_reward,
        resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
        num_action_repeats=FLAGS.num_action_repeats,
        num_pooled_frames=2,
        zero_discount_on_life_loss=True,
        num_stacked_frames=FLAGS.num_stacked_frames,
        grayscaling=True,
    )

  # Create sample network input from sample preprocessor output.
  sample_processed_timestep = preprocessor_builder()(env.reset())
  sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                          sample_processed_timestep)
  sample_network_input = sample_processed_timestep.observation
  chex.assert_shape(sample_network_input,
                    (FLAGS.environment_height, FLAGS.environment_width,
                     FLAGS.num_stacked_frames))

  exploration_epsilon_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity *
                  FLAGS.num_action_repeats),
      decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                      FLAGS.num_iterations * FLAGS.num_train_frames),
      begin_value=FLAGS.exploration_epsilon_begin_value,
      end_value=FLAGS.exploration_epsilon_end_value)

  # Note the t in the replay is not exactly aligned with the agent t.
  importance_sampling_exponent_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity),
      end_t=(FLAGS.num_iterations *
             int(FLAGS.num_train_frames / FLAGS.num_action_repeats)),
      begin_value=FLAGS.importance_sampling_exponent_begin_value,
      end_value=FLAGS.importance_sampling_exponent_end_value)

  if FLAGS.compress_state:

    def encoder(transition):
      return transition._replace(
          s_tm1=replay_lib.compress_array(transition.s_tm1),
          s_t=replay_lib.compress_array(transition.s_t))

    def decoder(transition):
      return transition._replace(
          s_tm1=replay_lib.uncompress_array(transition.s_tm1),
          s_t=replay_lib.uncompress_array(transition.s_t))
  else:
    encoder = None
    decoder = None

  replay_structure = replay_lib.Transition(
      s_tm1=None,
      a_tm1=None,
      r_t=None,
      discount_t=None,
      s_t=None,
  )

  replay = replay_lib.PrioritizedTransitionReplay(
      FLAGS.replay_capacity, replay_structure, FLAGS.priority_exponent,
      importance_sampling_exponent_schedule, FLAGS.uniform_sample_probability,
      FLAGS.normalize_weights, random_state, encoder, decoder)

  optimizer = optax.rmsprop(
      learning_rate=FLAGS.learning_rate,
      decay=0.95,
      eps=FLAGS.optimizer_epsilon,
      centered=True,
  )

  train_rng_key, eval_rng_key = jax.random.split(rng_key)

  train_agent = agent.PrioritizedDqn(
      preprocessor=preprocessor_builder(),
      sample_network_input=sample_network_input,
      network=network,
      optimizer=optimizer,
      transition_accumulator=replay_lib.TransitionAccumulator(),
      replay=replay,
      batch_size=FLAGS.batch_size,
      exploration_epsilon=exploration_epsilon_schedule,
      min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
      learn_period=FLAGS.learn_period,
      target_network_update_period=FLAGS.target_network_update_period,
      grad_error_bound=FLAGS.grad_error_bound,
      rng_key=train_rng_key,
  )
  eval_agent = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=network,
      exploration_epsilon=FLAGS.eval_exploration_epsilon,
      rng_key=eval_rng_key,
  )

  # Set up checkpointing.
  checkpoint = parts.NullCheckpoint()
  state = checkpoint.state
  state.iteration = 0
  state.train_agent = train_agent
  state.eval_agent = eval_agent
  state.random_state = random_state
  state.writer = writer
  if checkpoint.can_be_restored():
    checkpoint.restore()

  while state.iteration <= FLAGS.num_iterations:
    # New environment for each iteration to allow for determinism if preempted.
    env = environment_builder()

    logging.info('Training iteration %d.', state.iteration)
    train_seq = parts.run_loop(train_agent, env, FLAGS.max_frames_per_episode)
    num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames
    train_seq_truncated = itertools.islice(train_seq, num_train_frames)
    train_trackers = parts.make_default_trackers(train_agent)
    train_stats = parts.generate_statistics(train_trackers,
                                            train_seq_truncated)

    logging.info('Evaluation iteration %d.', state.iteration)
    eval_agent.network_params = train_agent.online_params
    eval_seq = parts.run_loop(eval_agent, env, FLAGS.max_frames_per_episode)
    eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
    eval_trackers = parts.make_default_trackers(eval_agent)
    eval_stats = parts.generate_statistics(eval_trackers, eval_seq_truncated)

    # Logging and checkpointing.
    human_normalized_score = atari_data.get_human_normalized_score(
        FLAGS.environment_name, eval_stats['episode_return'])
    capped_human_normalized_score = np.amin([1., human_normalized_score])
    log_output = [
        ('iteration', state.iteration, '%3d'),
        ('frame', state.iteration * FLAGS.num_train_frames, '%5d'),
        ('eval_episode_return', eval_stats['episode_return'], '% 2.2f'),
        ('train_episode_return', train_stats['episode_return'], '% 2.2f'),
        ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),
        ('train_num_episodes', train_stats['num_episodes'], '%3d'),
        ('eval_frame_rate', eval_stats['step_rate'], '%4.0f'),
        ('train_frame_rate', train_stats['step_rate'], '%4.0f'),
        ('train_exploration_epsilon', train_agent.exploration_epsilon, '%.3f'),
        ('train_state_value', train_stats['state_value'], '%.3f'),
        ('importance_sampling_exponent',
         train_agent.importance_sampling_exponent, '%.3f'),
        ('max_seen_priority', train_agent.max_seen_priority, '%.3f'),
        ('normalized_return', human_normalized_score, '%.3f'),
        ('capped_normalized_return', capped_human_normalized_score, '%.3f'),
        ('human_gap', 1. - capped_human_normalized_score, '%.3f'),
    ]
    log_output_str = ', '.join(('%s: ' + f) % (n, v) for n, v, f in log_output)
    logging.info(log_output_str)
    writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
    state.iteration += 1
    checkpoint.save()

  writer.close()
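Scripts like the one above are launched through absl, which parses the FLAGS referenced throughout main(). The launcher below is a minimal sketch in the style of dqn_zoo, not part of the original listing; the JAX config update is an assumption about this codebase's defaults.

if __name__ == '__main__':
  # Assumed launcher sketch: absl parses command-line flags into FLAGS
  # before handing control to main().
  from absl import app
  from jax.config import config

  config.update('jax_numpy_rank_promotion', 'raise')  # Error on rank mismatch.
  app.run(main)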
def test_basic(self):
  """Tests sequence of agent and environment interactions in typical usage."""
  tape = []
  agent = test_utils.DummyAgent(tape)
  environment = test_utils.DummyEnvironment(tape, episode_length=4)
  episode_index = 0
  t = 0  # steps = t + 1
  max_steps = 14
  loop_outputs = parts.run_loop(
      agent, environment, max_steps_per_episode=100, yield_before_reset=True)
  for timestep_t, unused_a_t in loop_outputs:
    tape.append((episode_index, t, timestep_t is None))
    if timestep_t is None:
      tape.append('Episode begin')
      continue
    if timestep_t.last():
      tape.append('Episode end')
      episode_index += 1
    if t + 1 >= max_steps:
      tape.append('Maximum number of steps reached')
      break
    t += 1

  expected_tape = [
      (0, 0, True), 'Episode begin', 'Agent reset', 'Environment reset',
      'Agent step',
      (0, 0, False), 'Environment step (0)', 'Agent step',
      (0, 1, False), 'Environment step (0)', 'Agent step',
      (0, 2, False), 'Environment step (0)', 'Agent step',
      (0, 3, False), 'Environment step (0)', 'Agent step',
      (0, 4, False), 'Episode end',
      (1, 5, True), 'Episode begin', 'Agent reset', 'Environment reset',
      'Agent step',
      (1, 5, False), 'Environment step (0)', 'Agent step',
      (1, 6, False), 'Environment step (0)', 'Agent step',
      (1, 7, False), 'Environment step (0)', 'Agent step',
      (1, 8, False), 'Environment step (0)', 'Agent step',
      (1, 9, False), 'Episode end',
      (2, 10, True), 'Episode begin', 'Agent reset', 'Environment reset',
      'Agent step',
      (2, 10, False), 'Environment step (0)', 'Agent step',
      (2, 11, False), 'Environment step (0)', 'Agent step',
      (2, 12, False), 'Environment step (0)', 'Agent step',
      (2, 13, False), 'Maximum number of steps reached',
  ]
  self.assertEqual(expected_tape, tape)
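The expected tape pins down the run_loop contract: with yield_before_reset=True the generator yields a None timestep before each reset, and the final timestep of an episode is yielded after one last agent step. The generator below is a simplified sketch of that contract, written to match the two-tuple yields the test unpacks; it is an assumption, and the real parts.run_loop may yield richer tuples and handle truncation differently.

def run_loop_sketch(agent, environment, max_steps_per_episode=0,
                    yield_before_reset=False):
  """Simplified sketch of the generator contract the test above exercises."""
  while True:  # For each episode.
    if yield_before_reset:
      yield None, None  # Marks 'Episode begin' before agent/env are reset.

    t = 0
    agent.reset()
    timestep_t = environment.reset()

    while True:  # For each step within the current episode.
      a_t = agent.step(timestep_t)
      yield timestep_t, a_t

      timestep_t = environment.step(a_t)
      t += 1

      if timestep_t.last():
        # One final agent step on the terminal timestep; its action is unused.
        yield timestep_t, agent.step(timestep_t)
        break
      if max_steps_per_episode > 0 and t >= max_steps_per_episode:
        break  # Truncate overly long episodes.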
def main(argv):
  """Trains DQN agent on the Key-Door environment."""
  del argv
  logging.info("DQN on Key-Door on %s.",
               jax.lib.xla_bridge.get_backend().platform)
  random_state = np.random.RandomState(FLAGS.seed)
  rng_key = jax.random.PRNGKey(
      random_state.randint(-sys.maxsize - 1, sys.maxsize + 1, dtype=np.int64))

  if FLAGS.results_csv_path:
    writer = parts.CsvWriter(FLAGS.results_csv_path)
  else:
    writer = parts.NullWriter()

  def environment_builder():
    """Creates Key-Door environment."""
    env = gym_key_door.GymKeyDoor(
        env_args={
            constants.MAP_ASCII_PATH: FLAGS.map_ascii_path,
            constants.MAP_YAML_PATH: FLAGS.map_yaml_path,
            constants.REPRESENTATION: constants.PIXEL,
            constants.SCALING: FLAGS.env_scaling,
            constants.EPISODE_TIMEOUT: FLAGS.max_frames_per_episode,
            constants.GRAYSCALE: False,
            constants.BATCH_DIMENSION: False,
            constants.TORCH_AXES: False,
        },
        env_shape=FLAGS.env_shape,
    )
    return gym_atari.RandomNoopsEnvironmentWrapper(
        env,
        min_noop_steps=1,
        max_noop_steps=30,
        seed=random_state.randint(1, 2**32),
    )

  env = environment_builder()
  logging.info("Environment: %s", FLAGS.environment_name)
  logging.info("Action spec: %s", env.action_spec())
  logging.info("Observation spec: %s", env.observation_spec())
  num_actions = env.action_spec().num_values
  network_fn = networks.dqn_atari_network(num_actions)
  network = hk.transform(network_fn)

  def preprocessor_builder():
    return processors.atari(
        additional_discount=FLAGS.additional_discount,
        max_abs_reward=FLAGS.max_abs_reward,
        resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
        num_action_repeats=FLAGS.num_action_repeats,
        num_pooled_frames=2,
        zero_discount_on_life_loss=True,
        num_stacked_frames=FLAGS.num_stacked_frames,
        grayscaling=True,
    )

  # Create sample network input from sample preprocessor output.
  sample_processed_timestep = preprocessor_builder()(env.reset())
  sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                          sample_processed_timestep)
  sample_network_input = sample_processed_timestep.observation
  assert sample_network_input.shape == (
      FLAGS.environment_height,
      FLAGS.environment_width,
      FLAGS.num_stacked_frames,
  )

  exploration_epsilon_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity *
                  FLAGS.num_action_repeats),
      decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                      FLAGS.num_iterations * FLAGS.num_train_frames),
      begin_value=FLAGS.exploration_epsilon_begin_value,
      end_value=FLAGS.exploration_epsilon_end_value,
  )

  if FLAGS.compress_state:

    def encoder(transition):
      return transition._replace(
          s_tm1=replay_lib.compress_array(transition.s_tm1),
          s_t=replay_lib.compress_array(transition.s_t),
      )

    def decoder(transition):
      return transition._replace(
          s_tm1=replay_lib.uncompress_array(transition.s_tm1),
          s_t=replay_lib.uncompress_array(transition.s_t),
      )
  else:
    encoder = None
    decoder = None

  replay_structure = replay_lib.Transition(
      s_tm1=None,
      a_tm1=None,
      r_t=None,
      discount_t=None,
      s_t=None,
  )

  replay = replay_lib.TransitionReplay(FLAGS.replay_capacity, replay_structure,
                                       random_state, encoder, decoder)

  optimizer = optax.rmsprop(
      learning_rate=FLAGS.learning_rate,
      decay=0.95,
      eps=FLAGS.optimizer_epsilon,
      centered=True,
  )

  if FLAGS.shaping_function_type == constants.NO_PENALTY:
    shaping_function = shaping.NoPenalty()
  elif FLAGS.shaping_function_type == constants.HARD_CODED_PENALTY:
    shaping_function = shaping.HardCodedPenalty(
        penalty=FLAGS.shaping_multiplicative_factor)
  else:
    raise ValueError(
        f"Unknown shaping function type: {FLAGS.shaping_function_type}")

  train_rng_key, eval_rng_key = jax.random.split(rng_key)

  train_agent = agent.Dqn(
      preprocessor=preprocessor_builder(),
      sample_network_input=sample_network_input,
      network=network,
      optimizer=optimizer,
      transition_accumulator=replay_lib.TransitionAccumulator(),
      replay=replay,
      shaping_function=shaping_function,
      batch_size=FLAGS.batch_size,
      exploration_epsilon=exploration_epsilon_schedule,
      min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
      learn_period=FLAGS.learn_period,
      target_network_update_period=FLAGS.target_network_update_period,
      grad_error_bound=FLAGS.grad_error_bound,
      rng_key=train_rng_key,
  )
  eval_agent = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=network,
      exploration_epsilon=FLAGS.eval_exploration_epsilon,
      rng_key=eval_rng_key,
  )

  # Set up checkpointing.
  # checkpoint = parts.NullCheckpoint()
  checkpoint = parts.ImplementedCheckpoint(
      checkpoint_path=FLAGS.checkpoint_path)

  state = checkpoint.state
  if checkpoint.can_be_restored():
    # Keep the restored iteration count rather than resetting it to zero.
    checkpoint.restore()
    train_agent.set_state(state=state.train_agent)
    eval_agent.set_state(state=state.eval_agent)
    writer.set_state(state=state.writer)
  else:
    state.iteration = 0
    state.train_agent = train_agent.get_state()
    state.eval_agent = eval_agent.get_state()
    state.random_state = random_state
    state.writer = writer.get_state()

  while state.iteration <= FLAGS.num_iterations:
    # New environment for each iteration to allow for determinism if preempted.
    env = environment_builder()

    logging.info("Training iteration %d.", state.iteration)
    train_seq = parts.run_loop(train_agent, env, FLAGS.max_frames_per_episode)
    num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames
    train_seq_truncated = itertools.islice(train_seq, num_train_frames)
    train_stats = parts.generate_statistics(train_seq_truncated)

    logging.info("Evaluation iteration %d.", state.iteration)
    eval_agent.network_params = train_agent.online_params
    eval_seq = parts.run_loop(eval_agent, env, FLAGS.max_frames_per_episode)
    eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
    eval_stats = parts.generate_statistics(eval_seq_truncated)

    # Logging and checkpointing.
    human_normalized_score = atari_data.get_human_normalized_score(
        FLAGS.environment_name, eval_stats["episode_return"])
    capped_human_normalized_score = np.amin([1.0, human_normalized_score])
    log_output = [
        ("iteration", state.iteration, "%3d"),
        ("frame", state.iteration * FLAGS.num_train_frames, "%5d"),
        ("eval_episode_return", eval_stats["episode_return"], "% 2.2f"),
        ("train_episode_return", train_stats["episode_return"], "% 2.2f"),
        ("eval_num_episodes", eval_stats["num_episodes"], "%3d"),
        ("train_num_episodes", train_stats["num_episodes"], "%3d"),
        ("eval_frame_rate", eval_stats["step_rate"], "%4.0f"),
        ("train_frame_rate", train_stats["step_rate"], "%4.0f"),
        ("train_exploration_epsilon", train_agent.exploration_epsilon, "%.3f"),
        ("normalized_return", human_normalized_score, "%.3f"),
        ("capped_normalized_return", capped_human_normalized_score, "%.3f"),
        ("human_gap", 1.0 - capped_human_normalized_score, "%.3f"),
    ]
    log_output_str = ", ".join(
        ("%s: " + f) % (n, v) for n, v, f in log_output)
    logging.info(log_output_str)
    writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
    state.iteration += 1

    # Update state before checkpointing so the saved snapshot is current.
    state.train_agent = train_agent.get_state()
    state.eval_agent = eval_agent.get_state()
    state.random_state = random_state
    state.writer = writer.get_state()
    checkpoint.save()

  writer.close()
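The script above exercises parts.ImplementedCheckpoint only through its state attribute, can_be_restored(), restore(), and save(). The class below is a minimal pickle-based sketch of that contract; it is hypothetical, and the real implementation presumably adds more robustness (versioning, error handling, storage backends).

import os
import pickle
import types


class CheckpointSketch:
  """Minimal sketch of the checkpoint contract used above (an assumption)."""

  def __init__(self, checkpoint_path):
    self._path = checkpoint_path
    self.state = types.SimpleNamespace()  # Mutable attribute bag.

  def can_be_restored(self):
    return os.path.exists(self._path)

  def restore(self):
    with open(self._path, 'rb') as f:
      self.state = pickle.load(f)

  def save(self):
    # Write to a temporary file, then rename, so a preemption mid-write
    # never leaves a corrupt checkpoint behind.
    tmp_path = self._path + '.tmp'
    with open(tmp_path, 'wb') as f:
      pickle.dump(self.state, f)
    os.replace(tmp_path, self._path)  # Atomic on POSIX.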
def main(argv):
  """Trains Bootstrapped DQN agent on Atari."""
  del argv
  logging.info('Bootstrapped DQN on Atari on %s.',
               jax.lib.xla_bridge.get_backend().platform)
  random_state = np.random.RandomState(FLAGS.seed)
  rng_key = jax.random.PRNGKey(
      random_state.randint(-sys.maxsize - 1, sys.maxsize + 1, dtype=np.int64))

  if FLAGS.results_csv_path:
    writer = parts.CsvWriter(FLAGS.results_csv_path)
  else:
    writer = parts.NullWriter()

  def environment_builder():
    """Creates Atari environment."""
    env = gym_atari.GymAtari(
        FLAGS.environment_name, seed=random_state.randint(1, 2**32))
    return gym_atari.RandomNoopsEnvironmentWrapper(
        env,
        min_noop_steps=1,
        max_noop_steps=30,
        seed=random_state.randint(1, 2**32),
    )

  env = environment_builder()
  logging.info('Environment: %s', FLAGS.environment_name)
  logging.info('Action spec: %s', env.action_spec())
  logging.info('Observation spec: %s', env.observation_spec())
  num_actions = env.action_spec().num_values
  network_fn = networks.bootstrapped_dqn_multi_head_network(
      num_actions,
      num_heads=FLAGS.num_heads,
      mask_probability=FLAGS.mask_probability)
  network = hk.transform(network_fn)

  def preprocessor_builder():
    return processors.atari(
        additional_discount=FLAGS.additional_discount,
        max_abs_reward=FLAGS.max_abs_reward,
        resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
        num_action_repeats=FLAGS.num_action_repeats,
        num_pooled_frames=2,
        zero_discount_on_life_loss=True,
        num_stacked_frames=FLAGS.num_stacked_frames,
        grayscaling=True,
    )

  # Create sample network input from sample preprocessor output.
  sample_processed_timestep = preprocessor_builder()(env.reset())
  sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                          sample_processed_timestep)
  sample_network_input = sample_processed_timestep.observation
  assert sample_network_input.shape == (FLAGS.environment_height,
                                        FLAGS.environment_width,
                                        FLAGS.num_stacked_frames)

  exploration_epsilon_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity *
                  FLAGS.num_action_repeats),
      decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                      FLAGS.num_iterations * FLAGS.num_train_frames),
      begin_value=FLAGS.exploration_epsilon_begin_value,
      end_value=FLAGS.exploration_epsilon_end_value)

  if FLAGS.compress_state:

    def encoder(transition):
      return transition._replace(
          s_tm1=replay_lib.compress_array(transition.s_tm1),
          s_t=replay_lib.compress_array(transition.s_t))

    def decoder(transition):
      return transition._replace(
          s_tm1=replay_lib.uncompress_array(transition.s_tm1),
          s_t=replay_lib.uncompress_array(transition.s_t))
  else:
    encoder = None
    decoder = None

  replay_structure = replay_lib.MaskedTransition(
      s_tm1=None,
      a_tm1=None,
      r_t=None,
      discount_t=None,
      s_t=None,
      mask_t=None,
  )

  replay = replay_lib.TransitionReplay(FLAGS.replay_capacity, replay_structure,
                                       random_state, encoder, decoder)

  optimizer = optax.rmsprop(
      learning_rate=FLAGS.learning_rate,
      decay=0.95,
      eps=FLAGS.optimizer_epsilon,
      centered=True,
  )

  if FLAGS.shaping_function_type == constants.NO_PENALTY:
    shaping_function = shaping.NoPenalty()
  elif FLAGS.shaping_function_type == constants.HARD_CODED_PENALTY:
    shaping_function = shaping.HardCodedPenalty(
        penalty=FLAGS.shaping_multiplicative_factor)
  elif FLAGS.shaping_function_type == constants.UNCERTAINTY_PENALTY:
    shaping_function = shaping.UncertaintyPenalty(
        multiplicative_factor=FLAGS.shaping_multiplicative_factor)
  elif FLAGS.shaping_function_type == constants.POLICY_ENTROPY_PENALTY:
    shaping_function = shaping.PolicyEntropyPenalty(
        multiplicative_factor=FLAGS.shaping_multiplicative_factor,
        num_actions=num_actions)
  elif FLAGS.shaping_function_type == constants.MUNCHAUSEN_PENALTY:
    shaping_function = shaping.MunchausenPenalty(
        multiplicative_factor=FLAGS.shaping_multiplicative_factor,
        num_actions=num_actions)
  else:
    raise ValueError(
        f'Unknown shaping function type: {FLAGS.shaping_function_type}')

  train_rng_key, eval_rng_key = jax.random.split(rng_key)

  train_agent = agent.BootstrappedDqn(
      preprocessor=preprocessor_builder(),
      sample_network_input=sample_network_input,
      network=network,
      optimizer=optimizer,
      transition_accumulator=replay_lib.TransitionAccumulator(),
      replay=replay,
      shaping_function=shaping_function,
      mask_probability=FLAGS.mask_probability,
      num_heads=FLAGS.num_heads,
      batch_size=FLAGS.batch_size,
      exploration_epsilon=exploration_epsilon_schedule,
      min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
      learn_period=FLAGS.learn_period,
      target_network_update_period=FLAGS.target_network_update_period,
      grad_error_bound=FLAGS.grad_error_bound,
      rng_key=train_rng_key,
  )
  eval_agent = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=network,
      exploration_epsilon=FLAGS.eval_exploration_epsilon,
      rng_key=eval_rng_key,
  )

  # Set up checkpointing.
  # checkpoint = parts.NullCheckpoint()
  checkpoint = parts.ImplementedCheckpoint(
      checkpoint_path=FLAGS.checkpoint_path)

  if checkpoint.can_be_restored():
    checkpoint.restore()
    iteration = checkpoint.state.iteration
    random_state = checkpoint.state.random_state
    train_agent.set_state(state=checkpoint.state.train_agent)
    eval_agent.set_state(state=checkpoint.state.eval_agent)
    writer.set_state(state=checkpoint.state.writer)
  else:
    iteration = 0

  while iteration <= FLAGS.num_iterations:
    # New environment for each iteration to allow for determinism if preempted.
    env = environment_builder()

    logging.info('Training iteration %d.', iteration)
    train_seq = parts.run_loop(train_agent, env, FLAGS.max_frames_per_episode)
    num_train_frames = 0 if iteration == 0 else FLAGS.num_train_frames
    train_seq_truncated = itertools.islice(train_seq, num_train_frames)
    train_stats = parts.generate_statistics(train_seq_truncated)

    logging.info('Evaluation iteration %d.', iteration)
    eval_agent.network_params = train_agent.online_params
    eval_seq = parts.run_loop(eval_agent, env, FLAGS.max_frames_per_episode)
    eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
    eval_stats = parts.generate_statistics(eval_seq_truncated)

    # Logging and checkpointing.
    human_normalized_score = atari_data.get_human_normalized_score(
        FLAGS.environment_name, eval_stats['episode_return'])
    capped_human_normalized_score = np.amin([1., human_normalized_score])
    log_output = [
        ('iteration', iteration, '%3d'),
        ('frame', iteration * FLAGS.num_train_frames, '%5d'),
        ('eval_episode_return', eval_stats['episode_return'], '% 2.2f'),
        ('train_episode_return', train_stats['episode_return'], '% 2.2f'),
        ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),
        ('train_num_episodes', train_stats['num_episodes'], '%3d'),
        ('eval_frame_rate', eval_stats['step_rate'], '%4.0f'),
        ('train_frame_rate', train_stats['step_rate'], '%4.0f'),
        ('train_exploration_epsilon', train_agent.exploration_epsilon, '%.3f'),
        ('normalized_return', human_normalized_score, '%.3f'),
        ('capped_normalized_return', capped_human_normalized_score, '%.3f'),
        ('human_gap', 1. - capped_human_normalized_score, '%.3f'),
        ('train_loss', train_stats['train_loss'], '% 2.2f'),
        ('shaped_reward', train_stats['shaped_reward'], '% 2.2f'),
        ('penalties', train_stats['penalties'], '% 2.2f'),
    ]
    log_output_str = ', '.join(
        ('%s: ' + f) % (n, v) for n, v, f in log_output)
    logging.info(log_output_str)
    writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
    iteration += 1

    # Update state before checkpointing.
    checkpoint.state.iteration = iteration
    checkpoint.state.train_agent = train_agent.get_state()
    checkpoint.state.eval_agent = eval_agent.get_state()
    checkpoint.state.random_state = random_state
    checkpoint.state.writer = writer.get_state()
    checkpoint.save()

  writer.close()
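All three scripts build parts.LinearSchedule, passing begin_t with either decay_steps (exploration epsilon) or end_t (importance sampling exponent). The class below is a minimal sketch of the assumed semantics: hold begin_value until begin_t, interpolate linearly to end_value at end_t (= begin_t + decay_steps), then hold end_value. It is an illustration, not the actual parts implementation.

class LinearScheduleSketch:
  """Minimal sketch (an assumption) of the linear schedule used above."""

  def __init__(self, begin_value, end_value, begin_t, end_t=None,
               decay_steps=None):
    if (end_t is None) == (decay_steps is None):
      raise ValueError('Exactly one of end_t or decay_steps must be given.')
    self._begin_t = begin_t
    self._end_t = begin_t + decay_steps if end_t is None else end_t
    self._begin_value = begin_value
    self._end_value = end_value

  def __call__(self, t):
    # Clamp the interpolation fraction to [0, 1] outside [begin_t, end_t].
    frac = (t - self._begin_t) / (self._end_t - self._begin_t)
    frac = min(max(frac, 0.), 1.)
    return (1. - frac) * self._begin_value + frac * self._end_value


# Example: epsilon stays at 1.0 for the first 1000 steps, then decays
# linearly to 0.1 over the next 9000 steps, and stays at 0.1 afterwards.
epsilon = LinearScheduleSketch(
    begin_value=1.0, end_value=0.1, begin_t=1000, decay_steps=9000)
assert epsilon(0) == 1.0 and epsilon(10_000) == 0.1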