def make_replay( capacity=10, structure=None, priority_exponent=0.8, importance_sampling_exponent=lambda t: 0.6, uniform_sample_probability=0.1, normalize_weights=True, seed=None, ): return replay_lib.PrioritizedTransitionReplay( capacity=capacity, structure=structure or ReplayStructure(value=None), priority_exponent=priority_exponent, importance_sampling_exponent=importance_sampling_exponent, uniform_sample_probability=uniform_sample_probability, normalize_weights=normalize_weights, random_state=np.random.RandomState(seed), )
def main(argv): """Trains Prioritized DQN agent on Atari.""" del argv logging.info('Prioritized DQN on Atari on %s.', jax.lib.xla_bridge.get_backend().platform) random_state = np.random.RandomState(FLAGS.seed) rng_key = jax.random.PRNGKey( random_state.randint(-sys.maxsize - 1, sys.maxsize + 1, dtype=np.int64)) if FLAGS.results_csv_path: writer = parts.CsvWriter(FLAGS.results_csv_path) else: writer = parts.NullWriter() def environment_builder(): """Creates Atari environment.""" env = gym_atari.GymAtari( FLAGS.environment_name, seed=random_state.randint(1, 2**32)) return gym_atari.RandomNoopsEnvironmentWrapper( env, min_noop_steps=1, max_noop_steps=30, seed=random_state.randint(1, 2**32), ) env = environment_builder() logging.info('Environment: %s', FLAGS.environment_name) logging.info('Action spec: %s', env.action_spec()) logging.info('Observation spec: %s', env.observation_spec()) num_actions = env.action_spec().num_values network_fn = networks.double_dqn_atari_network(num_actions) network = hk.transform(network_fn) def preprocessor_builder(): return processors.atari( additional_discount=FLAGS.additional_discount, max_abs_reward=FLAGS.max_abs_reward, resize_shape=(FLAGS.environment_height, FLAGS.environment_width), num_action_repeats=FLAGS.num_action_repeats, num_pooled_frames=2, zero_discount_on_life_loss=True, num_stacked_frames=FLAGS.num_stacked_frames, grayscaling=True, ) # Create sample network input from sample preprocessor output. sample_processed_timestep = preprocessor_builder()(env.reset()) sample_processed_timestep = typing.cast(dm_env.TimeStep, sample_processed_timestep) sample_network_input = sample_processed_timestep.observation chex.assert_shape(sample_network_input, (FLAGS.environment_height, FLAGS.environment_width, FLAGS.num_stacked_frames)) exploration_epsilon_schedule = parts.LinearSchedule( begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity * FLAGS.num_action_repeats), decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction * FLAGS.num_iterations * FLAGS.num_train_frames), begin_value=FLAGS.exploration_epsilon_begin_value, end_value=FLAGS.exploration_epsilon_end_value) # Note the t in the replay is not exactly aligned with the agent t. importance_sampling_exponent_schedule = parts.LinearSchedule( begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity), end_t=(FLAGS.num_iterations * int(FLAGS.num_train_frames / FLAGS.num_action_repeats)), begin_value=FLAGS.importance_sampling_exponent_begin_value, end_value=FLAGS.importance_sampling_exponent_end_value) if FLAGS.compress_state: def encoder(transition): return transition._replace( s_tm1=replay_lib.compress_array(transition.s_tm1), s_t=replay_lib.compress_array(transition.s_t)) def decoder(transition): return transition._replace( s_tm1=replay_lib.uncompress_array(transition.s_tm1), s_t=replay_lib.uncompress_array(transition.s_t)) else: encoder = None decoder = None replay_structure = replay_lib.Transition( s_tm1=None, a_tm1=None, r_t=None, discount_t=None, s_t=None, ) replay = replay_lib.PrioritizedTransitionReplay( FLAGS.replay_capacity, replay_structure, FLAGS.priority_exponent, importance_sampling_exponent_schedule, FLAGS.uniform_sample_probability, FLAGS.normalize_weights, random_state, encoder, decoder) optimizer = optax.rmsprop( learning_rate=FLAGS.learning_rate, decay=0.95, eps=FLAGS.optimizer_epsilon, centered=True, ) train_rng_key, eval_rng_key = jax.random.split(rng_key) train_agent = agent.PrioritizedDqn( preprocessor=preprocessor_builder(), sample_network_input=sample_network_input, network=network, optimizer=optimizer, transition_accumulator=replay_lib.TransitionAccumulator(), replay=replay, batch_size=FLAGS.batch_size, exploration_epsilon=exploration_epsilon_schedule, min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction, learn_period=FLAGS.learn_period, target_network_update_period=FLAGS.target_network_update_period, grad_error_bound=FLAGS.grad_error_bound, rng_key=train_rng_key, ) eval_agent = parts.EpsilonGreedyActor( preprocessor=preprocessor_builder(), network=network, exploration_epsilon=FLAGS.eval_exploration_epsilon, rng_key=eval_rng_key, ) # Set up checkpointing. checkpoint = parts.NullCheckpoint() state = checkpoint.state state.iteration = 0 state.train_agent = train_agent state.eval_agent = eval_agent state.random_state = random_state state.writer = writer if checkpoint.can_be_restored(): checkpoint.restore() while state.iteration <= FLAGS.num_iterations: # New environment for each iteration to allow for determinism if preempted. env = environment_builder() logging.info('Training iteration %d.', state.iteration) train_seq = parts.run_loop(train_agent, env, FLAGS.max_frames_per_episode) num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames train_seq_truncated = itertools.islice(train_seq, num_train_frames) train_trackers = parts.make_default_trackers(train_agent) train_stats = parts.generate_statistics(train_trackers, train_seq_truncated) logging.info('Evaluation iteration %d.', state.iteration) eval_agent.network_params = train_agent.online_params eval_seq = parts.run_loop(eval_agent, env, FLAGS.max_frames_per_episode) eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames) eval_trackers = parts.make_default_trackers(eval_agent) eval_stats = parts.generate_statistics(eval_trackers, eval_seq_truncated) # Logging and checkpointing. human_normalized_score = atari_data.get_human_normalized_score( FLAGS.environment_name, eval_stats['episode_return']) capped_human_normalized_score = np.amin([1., human_normalized_score]) log_output = [ ('iteration', state.iteration, '%3d'), ('frame', state.iteration * FLAGS.num_train_frames, '%5d'), ('eval_episode_return', eval_stats['episode_return'], '% 2.2f'), ('train_episode_return', train_stats['episode_return'], '% 2.2f'), ('eval_num_episodes', eval_stats['num_episodes'], '%3d'), ('train_num_episodes', train_stats['num_episodes'], '%3d'), ('eval_frame_rate', eval_stats['step_rate'], '%4.0f'), ('train_frame_rate', train_stats['step_rate'], '%4.0f'), ('train_exploration_epsilon', train_agent.exploration_epsilon, '%.3f'), ('train_state_value', train_stats['state_value'], '%.3f'), ('importance_sampling_exponent', train_agent.importance_sampling_exponent, '%.3f'), ('max_seen_priority', train_agent.max_seen_priority, '%.3f'), ('normalized_return', human_normalized_score, '%.3f'), ('capped_normalized_return', capped_human_normalized_score, '%.3f'), ('human_gap', 1. - capped_human_normalized_score, '%.3f'), ] log_output_str = ', '.join(('%s: ' + f) % (n, v) for n, v, f in log_output) logging.info(log_output_str) writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output)) state.iteration += 1 checkpoint.save() writer.close()
def main(argv): """ Train pick-up and drop-off Rainbow agents on ODySSEUS. """ del argv # Unused arguments # Metadata configuration parent_dir = pathlib.Path(__file__).parent.absolute() sim_input_conf_dir = parent_dir / 'configs' / DEFAULT_sim_scenario_name # Load configuration sim_conf = importlib.import_module('esbdqn.configs.{}.{}' .format(DEFAULT_sim_scenario_name, FLAGS.conf_filename)) # Extract a single conf pair sim_general_conf = EFFCS_SimConfGrid(sim_conf.General) \ .conf_list[0] sim_scenario_conf = EFFCS_SimConfGrid(sim_conf.Multiple_runs) \ .conf_list[0] experiment_dir = parent_dir \ / 'experiments' \ / DEFAULT_sim_scenario_name \ / FLAGS.exp_name \ / sim_general_conf['city'] if pathlib.Path.exists(experiment_dir): # Ensure configuration has not changed if not filecmp.cmp(str(sim_input_conf_dir / FLAGS.conf_filename) + '.py', str(experiment_dir / DEFAULT_conf_filename) + ".py", shallow=False): raise IOError('Configuration changed at: {}' .format(str(experiment_dir))) else: pathlib.Path.mkdir(experiment_dir, parents=True, exist_ok=True) # Copy configuration files shutil.rmtree(experiment_dir) shutil.copytree(sim_input_conf_dir, experiment_dir) # Rename to the default name conf_filepath = experiment_dir / (FLAGS.conf_filename + ".py") conf_filepath.rename(experiment_dir / (DEFAULT_conf_filename + ".py")) # Delete all other potential conf files for filename in experiment_dir.glob( DEFAULT_conf_filename + "_*.py"): filename.unlink() # Create results files results_dir = experiment_dir / 'results' pathlib.Path.mkdir(results_dir, parents=True, exist_ok=True) results_filepath = results_dir / DEFAULT_resu_filename logging.info('Rainbow agents on ODySSEUS running on %s.', jax.lib.xla_bridge.get_backend().platform.upper()) if FLAGS.checkpoint: checkpoint = PickleCheckpoint( experiment_dir / 'models', 'ODySSEUS-' + sim_general_conf['city']) else: checkpoint = parts.NullCheckpoint() checkpoint_restored = False if FLAGS.checkpoint: if checkpoint.can_be_restored(): logging.info('Restoring checkpoint...') checkpoint.restore() checkpoint_restored = True # Generate RNG key rng_state = np.random.RandomState(FLAGS.seed) if checkpoint_restored: rng_state.set_state(checkpoint.state .rng_state) rng_key = jax.random.PRNGKey( rng_state.randint(-sys.maxsize - 1, sys.maxsize + 1, dtype=np.int64)) # Generate results file writer if sim_general_conf['save_history']: writer = parts.CsvWriter(str(results_filepath)) if checkpoint_restored: writer.set_state(checkpoint.state .writer) else: writer = parts.NullWriter() def environment_builder() -> ConstrainedEnvironment: """ Create the ODySSEUS environment. """ return EscooterSimulator( (sim_general_conf, sim_scenario_conf), FLAGS.n_lives) def preprocessor_builder(): """ Create the ODySSEUS input preprocessor. """ return processor( max_abs_reward=FLAGS.max_abs_reward, zero_discount_on_life_loss=True ) env = environment_builder() logging.info('Environment: %s', FLAGS.exp_name) logging.info('Action spec: %s', env.action_spec()) logging.info('Observation spec: %s', env.observation_spec()) # Take [0] as both Rainbow have # the same number of actions num_actions = env.action_spec()[0].num_values support = jnp.linspace(-FLAGS.vmax, FLAGS.vmax, FLAGS.num_atoms) network = hk.transform(rainbow_odysseus_network( num_actions, support, FLAGS.noisy_weight_init)) # Create sample network input from reset. sample_processed_timestep = preprocessor_builder()(env.reset()) sample_processed_timestep = t.cast(dm_env.TimeStep, sample_processed_timestep) sample_processed_network_input = sample_processed_timestep.observation # Note the t in the replay is not exactly # aligned with the Rainbow agents t. importance_sampling_exponent_schedule = parts.LinearSchedule( begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity), end_t=(FLAGS.num_iterations * FLAGS.num_train_frames), begin_value=FLAGS.importance_sampling_exponent_begin_value, end_value=FLAGS.importance_sampling_exponent_end_value) if FLAGS.compress_state: def encoder(transition): return transition._replace( s_tm1=replay.compress_array(transition.s_tm1), s_t=replay.compress_array(transition.s_t)) def decoder(transition): return transition._replace( s_tm1=replay.uncompress_array(transition.s_tm1), s_t=replay.uncompress_array(transition.s_t)) else: encoder = None decoder = None replay_struct = replay.Transition( s_tm1=None, a_tm1=None, r_t=None, discount_t=None, s_t=None, ) transition_accumulator = replay.NStepTransitionAccumulator(FLAGS.n_steps) transition_replay = replay.PrioritizedTransitionReplay( FLAGS.replay_capacity, replay_struct, FLAGS.priority_exponent, importance_sampling_exponent_schedule, FLAGS.uniform_sample_probability, FLAGS.normalize_weights, rng_state, encoder, decoder) optimizer = optax.adam( learning_rate=FLAGS.learning_rate, eps=FLAGS.optimizer_epsilon) if FLAGS.max_global_grad_norm > 0: optimizer = optax.chain( optax.clip_by_global_norm( FLAGS.max_global_grad_norm), optimizer) train_rng_key, eval_rng_key = jax.random.split(rng_key) # Create pick-up/drop-off agents P_train_agent = agent.Rainbow( preprocessor=preprocessor_builder(), sample_network_input=copy.deepcopy(sample_processed_network_input), network=copy.deepcopy(network), support=copy.deepcopy(support), optimizer=copy.deepcopy(optimizer), transition_accumulator=copy.deepcopy(transition_accumulator), replay=copy.deepcopy(transition_replay), batch_size=FLAGS.batch_size, min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction, learn_period=FLAGS.learn_period, target_network_update_period=FLAGS.target_network_update_period, rng_key=train_rng_key, ) D_train_agent = agent.Rainbow( preprocessor=preprocessor_builder(), sample_network_input=copy.deepcopy(sample_processed_network_input), network=copy.deepcopy(network), support=copy.deepcopy(support), optimizer=copy.deepcopy(optimizer), transition_accumulator=copy.deepcopy(transition_accumulator), replay=copy.deepcopy(transition_replay), batch_size=FLAGS.batch_size, min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction, learn_period=FLAGS.learn_period, target_network_update_period=FLAGS.target_network_update_period, rng_key=train_rng_key, ) P_eval_agent = parts.EpsilonGreedyActor( preprocessor=preprocessor_builder(), network=copy.deepcopy(network), exploration_epsilon=0, rng_key=eval_rng_key, ) D_eval_agent = parts.EpsilonGreedyActor( preprocessor=preprocessor_builder(), network=copy.deepcopy(network), exploration_epsilon=0, rng_key=eval_rng_key, ) if checkpoint_restored: P_train_agent.set_state(checkpoint.state.P_agent['train']) D_train_agent.set_state(checkpoint.state.D_agent['train']) P_eval_agent.set_state(checkpoint.state.P_agent['eval']) D_eval_agent.set_state(checkpoint.state.D_agent['eval']) state = checkpoint.state if not checkpoint_restored: state.iteration = 0 state.P_agent = {} state.D_agent = {} state.rng_state = rng_state state.writer = writer state.P_agent['train'] = P_train_agent state.D_agent['train'] = D_train_agent state.P_agent['eval'] = P_eval_agent state.D_agent['eval'] = D_eval_agent while state.iteration < FLAGS.num_iterations: # Create a new environment at each new iteration # to allow for determinism if preempted. env = environment_builder() # Leave some spacing print('\n') logging.info('Training iteration: %d', state.iteration) train_trackers = make_odysseus_trackers(FLAGS.max_abs_reward) eval_trackers = make_odysseus_trackers(FLAGS.max_abs_reward) train_seq = run_loop(P_train_agent, D_train_agent, env, FLAGS.max_steps_per_episode) num_train_frames = 0 \ if state.iteration == 0 \ else FLAGS.num_train_frames train_seq_truncated = it.islice(train_seq, num_train_frames) train_stats = generate_statistics(train_trackers, train_seq_truncated) logging.info('Evaluation iteration: %d', state.iteration) # Synchronize network parameters P_eval_agent.network_params = P_train_agent.online_params D_eval_agent.network_params = P_train_agent.online_params eval_seq = run_loop(P_eval_agent, D_eval_agent, env, FLAGS.max_steps_per_episode) eval_seq_truncated = it.islice(eval_seq, FLAGS.num_eval_frames) eval_stats = generate_statistics(eval_trackers, eval_seq_truncated) # Logging and checkpointing L = [ # Simulation metadata ('iteration', state.iteration, '%3d'), # ODySSEUS metadata ('n_charging_workers', sim_scenario_conf['n_workers'], '%3d'), ('n_relocation_workers', sim_scenario_conf['n_relocation_workers'], '%3d'), ('n_vehicles', sim_scenario_conf['n_vehicles'], '%3d'), ('pct_incentive_willingness', sim_scenario_conf['incentive_willingness'], '%2.2f'), ('zone_side_m', sim_general_conf['bin_side_length'], '%3d'), # Validation agents ('eval_num_episodes', eval_stats['num_episodes'], '%3d'), ('eval_P_episode_return', eval_stats['episode_return'][0], '%2.2f'), ('eval_D_episode_return', eval_stats['episode_return'][1], '%2.2f'), ('eval_min_n_accepted_incentives', np.min(eval_stats['episodes_n_accepted_incentives']), '%2.2f'), ('eval_avg_n_accepted_incentives', np.mean(eval_stats['episodes_n_accepted_incentives']), '%2.2f'), ('eval_max_n_accepted_incentives', np.max(eval_stats['episodes_n_accepted_incentives']), '%2.2f'), ('eval_min_n_lives', np.min(eval_stats['episodes_n_lives']), '%2.2f'), ('eval_avg_n_lives', np.mean(eval_stats['episodes_n_lives']), '%2.2f'), ('eval_max_n_lives', np.max(eval_stats['episodes_n_lives']), '%2.2f'), ('eval_min_pct_satisfied_demand', np.min(eval_stats['pct_satisfied_demands']), '%2.2f'), ('eval_avg_pct_satisfied_demand', np.mean(eval_stats['pct_satisfied_demands']), '%2.2f'), ('eval_max_pct_satisfied_demand', np.max(eval_stats['pct_satisfied_demands']), '%2.2f'), # Training agents ('train_num_episodes', train_stats['num_episodes'], '%3d'), ('train_P_episode_return', train_stats['episode_return'][0], '%2.2f'), ('train_D_episode_return', train_stats['episode_return'][1], '%2.2f'), ('train_min_n_accepted_incentives', np.min(train_stats['episodes_n_accepted_incentives']), '%2.2f'), ('train_avg_n_accepted_incentives', np.mean(train_stats['episodes_n_accepted_incentives']), '%2.2f'), ('train_max_n_accepted_incentives', np.max(train_stats['episodes_n_accepted_incentives']), '%2.2f'), ('train_min_n_lives', np.min(train_stats['episodes_n_lives']), '%2.2f'), ('train_avg_n_lives', np.mean(train_stats['episodes_n_lives']), '%2.2f'), ('train_mac_n_lives', np.max(train_stats['episodes_n_lives']), '%2.2f'), ('train_min_pct_satisfied_demand', np.min(train_stats['pct_satisfied_demands']), '%2.2f'), ('train_avg_pct_satisfied_demand', np.mean(train_stats['pct_satisfied_demands']), '%2.2f'), ('train_max_pct_satisfied_demand', np.max(train_stats['pct_satisfied_demands']), '%2.2f'), ('P_importance_sampling_exponent', P_train_agent.importance_sampling_exponent, '%.3f'), ('D_importance_sampling_exponent', D_train_agent.importance_sampling_exponent, '%.3f'), ('P_max_seen_priority', P_train_agent.max_seen_priority, '%.3f'), ('D_max_seen_priority', D_train_agent.max_seen_priority, '%.3f'), ] L_str = '\n'.join(('%s: ' + f) % (n, v) for n, v, f in L) logging.info(L_str) if state.iteration == \ FLAGS.num_iterations - 1: print('\n') writer.write(collections.OrderedDict( (n, v) for n, v, _ in L)) state.iteration += 1 if state.iteration \ % FLAGS.checkpoint_period == 0: checkpoint.save() writer.close()