def test_deserialize_before_header(self):
  """Tests that header is written after deserialization if not written yet."""
  writer1 = parts.CsvWriter('test.csv')
  self.fake_file.write.assert_not_called()
  writer2 = parts.CsvWriter('test.csv')
  writer2.set_state(writer1.get_state())
  writer2.write(collections.OrderedDict([('a', 1), ('b', 2)]))
  self.assertSequenceEqual(
      [mock.call('a,b\r\n'), mock.call('1,2\r\n')],
      self.fake_file.write.call_args_list)
def test_deserialize_after_header(self):
  """Tests that no header is written unnecessarily after deserialization."""
  writer1 = parts.CsvWriter('test.csv')
  writer1.write(collections.OrderedDict([('a', 1), ('b', 2)]))
  self.assertSequenceEqual(
      [mock.call('a,b\r\n'), mock.call('1,2\r\n')],
      self.fake_file.write.call_args_list)
  writer2 = parts.CsvWriter('test.csv')
  writer2.set_state(writer1.get_state())
  writer2.write(collections.OrderedDict([('a', 3), ('b', 4)]))
  self.assertSequenceEqual(
      [mock.call('a,b\r\n'), mock.call('1,2\r\n'), mock.call('3,4\r\n')],
      self.fake_file.write.call_args_list)
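# Illustration only, not part of the test suite: the two deserialization tests
# above round-trip writer state via get_state()/set_state(). A hypothetical
# state payload sufficient for that behaviour needs the observed column order
# plus a flag recording whether the header line was already flushed.
_EXAMPLE_WRITER_STATE = {
    'fieldnames': ['a', 'b'],   # Column order as first observed.
    'header_written': True,     # True once 'a,b\r\n' has been written.
}
# With 'header_written' False, a restored writer emits the header before its
# first row (test_deserialize_before_header); with it True, the header is not
# repeated (test_deserialize_after_header).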
def test_error_new_keys(self):
  """Tests that an error is thrown when an unexpected key occurs."""
  writer = parts.CsvWriter('test.csv')
  writer.write(collections.OrderedDict([('a', 1), ('b', 2)]))
  with self.assertRaisesRegex(ValueError, 'fields not in fieldnames'):
    writer.write(collections.OrderedDict([('a', 3), ('b', 4), ('c', 5)]))
def test_file_close(self):
  """Tests that file is closed on writer.close()."""
  writer = parts.CsvWriter('test.csv')
  writer.write(collections.OrderedDict([('a', 1), ('b', 2)]))
  self.fake_file.close.assert_not_called()
  writer.close()
  self.fake_file.close.assert_called_once_with()
def test_create_dir(self):
  """Tests that a csv file dir is created if it doesn't exist yet."""
  with mock.patch('os.path.exists') as fake_exists, \
      mock.patch('os.makedirs') as fake_makedirs:
    fake_exists.return_value = False
    dirname = '/some/sub/dir'
    _ = parts.CsvWriter(dirname + '/test.csv')
    fake_exists.assert_called_once_with(dirname)
    fake_makedirs.assert_called_once_with(dirname)
def test_missing_keys(self):
  """Tests that when a key is missing, an empty value is used."""
  writer = parts.CsvWriter('test.csv')
  writer.write(collections.OrderedDict([('a', 1), ('b', 2), ('c', 3)]))
  writer.write(collections.OrderedDict([('a', 4), ('c', 6)]))
  self.assertSequenceEqual(
      [mock.call('a,b,c\r\n'), mock.call('1,2,3\r\n'), mock.call('4,,6\r\n')],
      self.fake_file.write.call_args_list)
def test_insertion_order_of_fields_preserved(self):
  """Tests that the insertion order of fields is preserved."""
  writer = parts.CsvWriter('test.csv')
  writer.write(collections.OrderedDict([('c', 3), ('a', 1), ('b', 2)]))
  writer.write(collections.OrderedDict([('b', 5), ('c', 6), ('a', 4)]))
  self.assertSequenceEqual(
      [mock.call('c,a,b\r\n'), mock.call('3,1,2\r\n'), mock.call('6,4,5\r\n')],
      self.fake_file.write.call_args_list)
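# Illustration only: the behaviours exercised above (ValueError on unexpected
# keys, empty string for missing keys, first-seen column order) are consistent
# with a writer built on the standard library's csv.DictWriter, using
# extrasaction='raise' and restval=''. A minimal standalone sketch:
#
#   import csv
#   import io
#
#   buf = io.StringIO()
#   dict_writer = csv.DictWriter(
#       buf, fieldnames=['c', 'a', 'b'], restval='', extrasaction='raise')
#   dict_writer.writeheader()
#   dict_writer.writerow({'c': 3, 'a': 1, 'b': 2})
#   dict_writer.writerow({'a': 4, 'c': 6})  # Missing 'b' -> empty value.
#   assert buf.getvalue() == 'c,a,b\r\n3,1,2\r\n4,,6\r\n'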
def test_file_writes(self):
  """Tests that file is written correctly."""
  writer = parts.CsvWriter('test.csv')
  self.fake_file.write.assert_not_called()
  writer.write(collections.OrderedDict([('a', 1), ('b', 2)]))
  self.assertSequenceEqual(
      [mock.call('a,b\r\n'), mock.call('1,2\r\n')],
      self.fake_file.write.call_args_list)
  writer.write(collections.OrderedDict([('a', 3), ('b', 4)]))
  self.assertSequenceEqual(
      [mock.call('a,b\r\n'), mock.call('1,2\r\n'), mock.call('3,4\r\n')],
      self.fake_file.write.call_args_list)
def main(argv): """Trains Prioritized DQN agent on Atari.""" del argv logging.info('Prioritized DQN on Atari on %s.', jax.lib.xla_bridge.get_backend().platform) random_state = np.random.RandomState(FLAGS.seed) rng_key = jax.random.PRNGKey( random_state.randint(-sys.maxsize - 1, sys.maxsize + 1, dtype=np.int64)) if FLAGS.results_csv_path: writer = parts.CsvWriter(FLAGS.results_csv_path) else: writer = parts.NullWriter() def environment_builder(): """Creates Atari environment.""" env = gym_atari.GymAtari( FLAGS.environment_name, seed=random_state.randint(1, 2**32)) return gym_atari.RandomNoopsEnvironmentWrapper( env, min_noop_steps=1, max_noop_steps=30, seed=random_state.randint(1, 2**32), ) env = environment_builder() logging.info('Environment: %s', FLAGS.environment_name) logging.info('Action spec: %s', env.action_spec()) logging.info('Observation spec: %s', env.observation_spec()) num_actions = env.action_spec().num_values network_fn = networks.double_dqn_atari_network(num_actions) network = hk.transform(network_fn) def preprocessor_builder(): return processors.atari( additional_discount=FLAGS.additional_discount, max_abs_reward=FLAGS.max_abs_reward, resize_shape=(FLAGS.environment_height, FLAGS.environment_width), num_action_repeats=FLAGS.num_action_repeats, num_pooled_frames=2, zero_discount_on_life_loss=True, num_stacked_frames=FLAGS.num_stacked_frames, grayscaling=True, ) # Create sample network input from sample preprocessor output. sample_processed_timestep = preprocessor_builder()(env.reset()) sample_processed_timestep = typing.cast(dm_env.TimeStep, sample_processed_timestep) sample_network_input = sample_processed_timestep.observation chex.assert_shape(sample_network_input, (FLAGS.environment_height, FLAGS.environment_width, FLAGS.num_stacked_frames)) exploration_epsilon_schedule = parts.LinearSchedule( begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity * FLAGS.num_action_repeats), decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction * FLAGS.num_iterations * FLAGS.num_train_frames), begin_value=FLAGS.exploration_epsilon_begin_value, end_value=FLAGS.exploration_epsilon_end_value) # Note the t in the replay is not exactly aligned with the agent t. 
  importance_sampling_exponent_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity),
      end_t=(FLAGS.num_iterations *
             int(FLAGS.num_train_frames / FLAGS.num_action_repeats)),
      begin_value=FLAGS.importance_sampling_exponent_begin_value,
      end_value=FLAGS.importance_sampling_exponent_end_value)

  if FLAGS.compress_state:

    def encoder(transition):
      return transition._replace(
          s_tm1=replay_lib.compress_array(transition.s_tm1),
          s_t=replay_lib.compress_array(transition.s_t))

    def decoder(transition):
      return transition._replace(
          s_tm1=replay_lib.uncompress_array(transition.s_tm1),
          s_t=replay_lib.uncompress_array(transition.s_t))
  else:
    encoder = None
    decoder = None

  replay_structure = replay_lib.Transition(
      s_tm1=None,
      a_tm1=None,
      r_t=None,
      discount_t=None,
      s_t=None,
  )

  replay = replay_lib.PrioritizedTransitionReplay(
      FLAGS.replay_capacity, replay_structure, FLAGS.priority_exponent,
      importance_sampling_exponent_schedule, FLAGS.uniform_sample_probability,
      FLAGS.normalize_weights, random_state, encoder, decoder)

  optimizer = optax.rmsprop(
      learning_rate=FLAGS.learning_rate,
      decay=0.95,
      eps=FLAGS.optimizer_epsilon,
      centered=True,
  )

  train_rng_key, eval_rng_key = jax.random.split(rng_key)

  train_agent = agent.PrioritizedDqn(
      preprocessor=preprocessor_builder(),
      sample_network_input=sample_network_input,
      network=network,
      optimizer=optimizer,
      transition_accumulator=replay_lib.TransitionAccumulator(),
      replay=replay,
      batch_size=FLAGS.batch_size,
      exploration_epsilon=exploration_epsilon_schedule,
      min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
      learn_period=FLAGS.learn_period,
      target_network_update_period=FLAGS.target_network_update_period,
      grad_error_bound=FLAGS.grad_error_bound,
      rng_key=train_rng_key,
  )
  eval_agent = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=network,
      exploration_epsilon=FLAGS.eval_exploration_epsilon,
      rng_key=eval_rng_key,
  )

  # Set up checkpointing.
  checkpoint = parts.NullCheckpoint()

  state = checkpoint.state
  state.iteration = 0
  state.train_agent = train_agent
  state.eval_agent = eval_agent
  state.random_state = random_state
  state.writer = writer
  if checkpoint.can_be_restored():
    checkpoint.restore()

  while state.iteration <= FLAGS.num_iterations:
    # New environment for each iteration to allow for determinism if preempted.
    env = environment_builder()

    logging.info('Training iteration %d.', state.iteration)
    train_seq = parts.run_loop(train_agent, env, FLAGS.max_frames_per_episode)
    num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames
    train_seq_truncated = itertools.islice(train_seq, num_train_frames)
    train_trackers = parts.make_default_trackers(train_agent)
    train_stats = parts.generate_statistics(train_trackers,
                                            train_seq_truncated)

    logging.info('Evaluation iteration %d.', state.iteration)
    eval_agent.network_params = train_agent.online_params
    eval_seq = parts.run_loop(eval_agent, env, FLAGS.max_frames_per_episode)
    eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
    eval_trackers = parts.make_default_trackers(eval_agent)
    eval_stats = parts.generate_statistics(eval_trackers, eval_seq_truncated)

    # Logging and checkpointing.
    human_normalized_score = atari_data.get_human_normalized_score(
        FLAGS.environment_name, eval_stats['episode_return'])
    capped_human_normalized_score = np.amin([1., human_normalized_score])

    log_output = [
        ('iteration', state.iteration, '%3d'),
        ('frame', state.iteration * FLAGS.num_train_frames, '%5d'),
        ('eval_episode_return', eval_stats['episode_return'], '% 2.2f'),
        ('train_episode_return', train_stats['episode_return'], '% 2.2f'),
        ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),
        ('train_num_episodes', train_stats['num_episodes'], '%3d'),
        ('eval_frame_rate', eval_stats['step_rate'], '%4.0f'),
        ('train_frame_rate', train_stats['step_rate'], '%4.0f'),
        ('train_exploration_epsilon', train_agent.exploration_epsilon, '%.3f'),
        ('train_state_value', train_stats['state_value'], '%.3f'),
        ('importance_sampling_exponent',
         train_agent.importance_sampling_exponent, '%.3f'),
        ('max_seen_priority', train_agent.max_seen_priority, '%.3f'),
        ('normalized_return', human_normalized_score, '%.3f'),
        ('capped_normalized_return', capped_human_normalized_score, '%.3f'),
        ('human_gap', 1. - capped_human_normalized_score, '%.3f'),
    ]
    log_output_str = ', '.join(('%s: ' + f) % (n, v) for n, v, f in log_output)
    logging.info(log_output_str)
    writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
    state.iteration += 1
    checkpoint.save()

  writer.close()
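# Illustration only (assumed semantics, not the parts.LinearSchedule
# implementation): the exploration-epsilon and importance-sampling schedules
# above ramp a value linearly between two timesteps and hold it constant
# outside that range. A minimal sketch:
def _linear_schedule_sketch(t, begin_t, end_t, begin_value, end_value):
  """Linearly interpolates from begin_value to end_value over [begin_t, end_t]."""
  fraction = min(max((t - begin_t) / (end_t - begin_t), 0.0), 1.0)
  return begin_value + fraction * (end_value - begin_value)

# Example: halfway through the ramp the value is halfway between the endpoints,
# and beyond end_t it stays clamped at end_value.
assert _linear_schedule_sketch(500, 0, 1000, 0.0, 1.0) == 0.5
assert _linear_schedule_sketch(2000, 0, 1000, 0.0, 1.0) == 1.0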
def main(argv): """Trains DQN agent on Atari.""" del argv logging.info("DQN on Atari on %s.", jax.lib.xla_bridge.get_backend().platform) random_state = np.random.RandomState(FLAGS.seed) rng_key = jax.random.PRNGKey( random_state.randint(-sys.maxsize - 1, sys.maxsize + 1)) if FLAGS.results_csv_path: writer = parts.CsvWriter(FLAGS.results_csv_path) else: writer = parts.NullWriter() def environment_builder(): """Creates Key-Door environment.""" env = gym_key_door.GymKeyDoor( env_args={ constants.MAP_ASCII_PATH: FLAGS.map_ascii_path, constants.MAP_YAML_PATH: FLAGS.map_yaml_path, constants.REPRESENTATION: constants.PIXEL, constants.SCALING: FLAGS.env_scaling, constants.EPISODE_TIMEOUT: FLAGS.max_frames_per_episode, constants.GRAYSCALE: False, constants.BATCH_DIMENSION: False, constants.TORCH_AXES: False, }, env_shape=FLAGS.env_shape, ) return gym_atari.RandomNoopsEnvironmentWrapper( env, min_noop_steps=1, max_noop_steps=30, seed=random_state.randint(1, 2**32), ) env = environment_builder() logging.info("Environment: %s", FLAGS.environment_name) logging.info("Action spec: %s", env.action_spec()) logging.info("Observation spec: %s", env.observation_spec()) num_actions = env.action_spec().num_values network_fn = networks.dqn_atari_network(num_actions) network = hk.transform(network_fn) def preprocessor_builder(): return processors.atari( additional_discount=FLAGS.additional_discount, max_abs_reward=FLAGS.max_abs_reward, resize_shape=(FLAGS.environment_height, FLAGS.environment_width), num_action_repeats=FLAGS.num_action_repeats, num_pooled_frames=2, zero_discount_on_life_loss=True, num_stacked_frames=FLAGS.num_stacked_frames, grayscaling=True, ) # Create sample network input from sample preprocessor output. sample_processed_timestep = preprocessor_builder()(env.reset()) sample_processed_timestep = typing.cast(dm_env.TimeStep, sample_processed_timestep) sample_network_input = sample_processed_timestep.observation assert sample_network_input.shape == ( FLAGS.environment_height, FLAGS.environment_width, FLAGS.num_stacked_frames, ) exploration_epsilon_schedule = parts.LinearSchedule( begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity * FLAGS.num_action_repeats), decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction * FLAGS.num_iterations * FLAGS.num_train_frames), begin_value=FLAGS.exploration_epsilon_begin_value, end_value=FLAGS.exploration_epsilon_end_value, ) if FLAGS.compress_state: def encoder(transition): return transition._replace( s_tm1=replay_lib.compress_array(transition.s_tm1), s_t=replay_lib.compress_array(transition.s_t), ) def decoder(transition): return transition._replace( s_tm1=replay_lib.uncompress_array(transition.s_tm1), s_t=replay_lib.uncompress_array(transition.s_t), ) else: encoder = None decoder = None replay_structure = replay_lib.Transition( s_tm1=None, a_tm1=None, r_t=None, discount_t=None, s_t=None, ) replay = replay_lib.TransitionReplay(FLAGS.replay_capacity, replay_structure, random_state, encoder, decoder) optimizer = optax.rmsprop( learning_rate=FLAGS.learning_rate, decay=0.95, eps=FLAGS.optimizer_epsilon, centered=True, ) if FLAGS.shaping_function_type == constants.NO_PENALTY: shaping_function = shaping.NoPenalty() if FLAGS.shaping_function_type == constants.HARD_CODED_PENALTY: shaping_function = shaping.HardCodedPenalty( penalty=FLAGS.shaping_multiplicative_factor) train_rng_key, eval_rng_key = jax.random.split(rng_key) train_agent = agent.Dqn( preprocessor=preprocessor_builder(), sample_network_input=sample_network_input, network=network, 
      optimizer=optimizer,
      transition_accumulator=replay_lib.TransitionAccumulator(),
      replay=replay,
      shaping_function=shaping_function,
      batch_size=FLAGS.batch_size,
      exploration_epsilon=exploration_epsilon_schedule,
      min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
      learn_period=FLAGS.learn_period,
      target_network_update_period=FLAGS.target_network_update_period,
      grad_error_bound=FLAGS.grad_error_bound,
      rng_key=train_rng_key,
  )
  eval_agent = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=network,
      exploration_epsilon=FLAGS.eval_exploration_epsilon,
      rng_key=eval_rng_key,
  )

  # Set up checkpointing.
  # checkpoint = parts.NullCheckpoint()
  checkpoint = parts.ImplementedCheckpoint(
      checkpoint_path=FLAGS.checkpoint_path)

  if checkpoint.can_be_restored():
    checkpoint.restore()
    train_agent.set_state(state=checkpoint.state.train_agent)
    eval_agent.set_state(state=checkpoint.state.eval_agent)
    writer.set_state(state=checkpoint.state.writer)

  state = checkpoint.state
  state.iteration = 0
  state.train_agent = train_agent.get_state()
  state.eval_agent = eval_agent.get_state()
  state.random_state = random_state
  state.writer = writer.get_state()

  while state.iteration <= FLAGS.num_iterations:
    # New environment for each iteration to allow for determinism if preempted.
    env = environment_builder()

    logging.info("Training iteration %d.", state.iteration)
    train_seq = parts.run_loop(train_agent, env, FLAGS.max_frames_per_episode)
    num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames
    train_seq_truncated = itertools.islice(train_seq, num_train_frames)
    train_stats = parts.generate_statistics(train_seq_truncated)

    logging.info("Evaluation iteration %d.", state.iteration)
    eval_agent.network_params = train_agent.online_params
    eval_seq = parts.run_loop(eval_agent, env, FLAGS.max_frames_per_episode)
    eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
    eval_stats = parts.generate_statistics(eval_seq_truncated)

    # Logging and checkpointing.
    human_normalized_score = atari_data.get_human_normalized_score(
        FLAGS.environment_name, eval_stats["episode_return"])
    capped_human_normalized_score = np.amin([1.0, human_normalized_score])

    log_output = [
        ("iteration", state.iteration, "%3d"),
        ("frame", state.iteration * FLAGS.num_train_frames, "%5d"),
        ("eval_episode_return", eval_stats["episode_return"], "% 2.2f"),
        ("train_episode_return", train_stats["episode_return"], "% 2.2f"),
        ("eval_num_episodes", eval_stats["num_episodes"], "%3d"),
        ("train_num_episodes", train_stats["num_episodes"], "%3d"),
        ("eval_frame_rate", eval_stats["step_rate"], "%4.0f"),
        ("train_frame_rate", train_stats["step_rate"], "%4.0f"),
        ("train_exploration_epsilon", train_agent.exploration_epsilon, "%.3f"),
        ("normalized_return", human_normalized_score, "%.3f"),
        ("capped_normalized_return", capped_human_normalized_score, "%.3f"),
        ("human_gap", 1.0 - capped_human_normalized_score, "%.3f"),
    ]
    log_output_str = ", ".join(("%s: " + f) % (n, v) for n, v, f in log_output)
    logging.info(log_output_str)
    writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
    state.iteration += 1
    checkpoint.save()

  writer.close()
def test_file_close_on_delete(self):
  """Tests that file is closed if writer gets deleted."""
  writer = parts.CsvWriter('test.csv')
  writer.write(collections.OrderedDict([('a', 1), ('b', 2)]))
  del writer
  self.fake_file.close.assert_called_once_with()
def test_file_open(self):
  """Tests that a file with correct name is opened."""
  _ = parts.CsvWriter('testabc.csv')
  self.mock_open.assert_called_once_with('testabc.csv', mock.ANY)
def main(argv): """ Train pick-up and drop-off Rainbow agents on ODySSEUS. """ del argv # Unused arguments # Metadata configuration parent_dir = pathlib.Path(__file__).parent.absolute() sim_input_conf_dir = parent_dir / 'configs' / DEFAULT_sim_scenario_name # Load configuration sim_conf = importlib.import_module('esbdqn.configs.{}.{}' .format(DEFAULT_sim_scenario_name, FLAGS.conf_filename)) # Extract a single conf pair sim_general_conf = EFFCS_SimConfGrid(sim_conf.General) \ .conf_list[0] sim_scenario_conf = EFFCS_SimConfGrid(sim_conf.Multiple_runs) \ .conf_list[0] experiment_dir = parent_dir \ / 'experiments' \ / DEFAULT_sim_scenario_name \ / FLAGS.exp_name \ / sim_general_conf['city'] if pathlib.Path.exists(experiment_dir): # Ensure configuration has not changed if not filecmp.cmp(str(sim_input_conf_dir / FLAGS.conf_filename) + '.py', str(experiment_dir / DEFAULT_conf_filename) + ".py", shallow=False): raise IOError('Configuration changed at: {}' .format(str(experiment_dir))) else: pathlib.Path.mkdir(experiment_dir, parents=True, exist_ok=True) # Copy configuration files shutil.rmtree(experiment_dir) shutil.copytree(sim_input_conf_dir, experiment_dir) # Rename to the default name conf_filepath = experiment_dir / (FLAGS.conf_filename + ".py") conf_filepath.rename(experiment_dir / (DEFAULT_conf_filename + ".py")) # Delete all other potential conf files for filename in experiment_dir.glob( DEFAULT_conf_filename + "_*.py"): filename.unlink() # Create results files results_dir = experiment_dir / 'results' pathlib.Path.mkdir(results_dir, parents=True, exist_ok=True) results_filepath = results_dir / DEFAULT_resu_filename logging.info('Rainbow agents on ODySSEUS running on %s.', jax.lib.xla_bridge.get_backend().platform.upper()) if FLAGS.checkpoint: checkpoint = PickleCheckpoint( experiment_dir / 'models', 'ODySSEUS-' + sim_general_conf['city']) else: checkpoint = parts.NullCheckpoint() checkpoint_restored = False if FLAGS.checkpoint: if checkpoint.can_be_restored(): logging.info('Restoring checkpoint...') checkpoint.restore() checkpoint_restored = True # Generate RNG key rng_state = np.random.RandomState(FLAGS.seed) if checkpoint_restored: rng_state.set_state(checkpoint.state .rng_state) rng_key = jax.random.PRNGKey( rng_state.randint(-sys.maxsize - 1, sys.maxsize + 1, dtype=np.int64)) # Generate results file writer if sim_general_conf['save_history']: writer = parts.CsvWriter(str(results_filepath)) if checkpoint_restored: writer.set_state(checkpoint.state .writer) else: writer = parts.NullWriter() def environment_builder() -> ConstrainedEnvironment: """ Create the ODySSEUS environment. """ return EscooterSimulator( (sim_general_conf, sim_scenario_conf), FLAGS.n_lives) def preprocessor_builder(): """ Create the ODySSEUS input preprocessor. """ return processor( max_abs_reward=FLAGS.max_abs_reward, zero_discount_on_life_loss=True ) env = environment_builder() logging.info('Environment: %s', FLAGS.exp_name) logging.info('Action spec: %s', env.action_spec()) logging.info('Observation spec: %s', env.observation_spec()) # Take [0] as both Rainbow have # the same number of actions num_actions = env.action_spec()[0].num_values support = jnp.linspace(-FLAGS.vmax, FLAGS.vmax, FLAGS.num_atoms) network = hk.transform(rainbow_odysseus_network( num_actions, support, FLAGS.noisy_weight_init)) # Create sample network input from reset. 
  sample_processed_timestep = preprocessor_builder()(env.reset())
  sample_processed_timestep = t.cast(dm_env.TimeStep,
                                     sample_processed_timestep)

  sample_processed_network_input = sample_processed_timestep.observation

  # Note the t in the replay is not exactly
  # aligned with the Rainbow agents t.
  importance_sampling_exponent_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity),
      end_t=(FLAGS.num_iterations * FLAGS.num_train_frames),
      begin_value=FLAGS.importance_sampling_exponent_begin_value,
      end_value=FLAGS.importance_sampling_exponent_end_value)

  if FLAGS.compress_state:

    def encoder(transition):
      return transition._replace(
          s_tm1=replay.compress_array(transition.s_tm1),
          s_t=replay.compress_array(transition.s_t))

    def decoder(transition):
      return transition._replace(
          s_tm1=replay.uncompress_array(transition.s_tm1),
          s_t=replay.uncompress_array(transition.s_t))
  else:
    encoder = None
    decoder = None

  replay_struct = replay.Transition(
      s_tm1=None,
      a_tm1=None,
      r_t=None,
      discount_t=None,
      s_t=None,
  )

  transition_accumulator = replay.NStepTransitionAccumulator(FLAGS.n_steps)

  transition_replay = replay.PrioritizedTransitionReplay(
      FLAGS.replay_capacity, replay_struct, FLAGS.priority_exponent,
      importance_sampling_exponent_schedule, FLAGS.uniform_sample_probability,
      FLAGS.normalize_weights, rng_state, encoder, decoder)

  optimizer = optax.adam(
      learning_rate=FLAGS.learning_rate, eps=FLAGS.optimizer_epsilon)

  if FLAGS.max_global_grad_norm > 0:
    optimizer = optax.chain(
        optax.clip_by_global_norm(FLAGS.max_global_grad_norm), optimizer)

  train_rng_key, eval_rng_key = jax.random.split(rng_key)

  # Create pick-up/drop-off agents
  P_train_agent = agent.Rainbow(
      preprocessor=preprocessor_builder(),
      sample_network_input=copy.deepcopy(sample_processed_network_input),
      network=copy.deepcopy(network),
      support=copy.deepcopy(support),
      optimizer=copy.deepcopy(optimizer),
      transition_accumulator=copy.deepcopy(transition_accumulator),
      replay=copy.deepcopy(transition_replay),
      batch_size=FLAGS.batch_size,
      min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
      learn_period=FLAGS.learn_period,
      target_network_update_period=FLAGS.target_network_update_period,
      rng_key=train_rng_key,
  )
  D_train_agent = agent.Rainbow(
      preprocessor=preprocessor_builder(),
      sample_network_input=copy.deepcopy(sample_processed_network_input),
      network=copy.deepcopy(network),
      support=copy.deepcopy(support),
      optimizer=copy.deepcopy(optimizer),
      transition_accumulator=copy.deepcopy(transition_accumulator),
      replay=copy.deepcopy(transition_replay),
      batch_size=FLAGS.batch_size,
      min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
      learn_period=FLAGS.learn_period,
      target_network_update_period=FLAGS.target_network_update_period,
      rng_key=train_rng_key,
  )

  P_eval_agent = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=copy.deepcopy(network),
      exploration_epsilon=0,
      rng_key=eval_rng_key,
  )
  D_eval_agent = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=copy.deepcopy(network),
      exploration_epsilon=0,
      rng_key=eval_rng_key,
  )

  if checkpoint_restored:
    P_train_agent.set_state(checkpoint.state.P_agent['train'])
    D_train_agent.set_state(checkpoint.state.D_agent['train'])
    P_eval_agent.set_state(checkpoint.state.P_agent['eval'])
    D_eval_agent.set_state(checkpoint.state.D_agent['eval'])

  state = checkpoint.state

  if not checkpoint_restored:
    state.iteration = 0
    state.P_agent = {}
    state.D_agent = {}

  state.rng_state = rng_state
  state.writer = writer

  state.P_agent['train'] = P_train_agent
  state.D_agent['train'] = D_train_agent
  state.P_agent['eval'] = P_eval_agent
  state.D_agent['eval'] = D_eval_agent

  while state.iteration < FLAGS.num_iterations:
    # Create a new environment at each new iteration
    # to allow for determinism if preempted.
    env = environment_builder()

    # Leave some spacing
    print('\n')

    logging.info('Training iteration: %d', state.iteration)

    train_trackers = make_odysseus_trackers(FLAGS.max_abs_reward)
    eval_trackers = make_odysseus_trackers(FLAGS.max_abs_reward)

    train_seq = run_loop(P_train_agent, D_train_agent, env,
                         FLAGS.max_steps_per_episode)

    num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames

    train_seq_truncated = it.islice(train_seq, num_train_frames)
    train_stats = generate_statistics(train_trackers, train_seq_truncated)

    logging.info('Evaluation iteration: %d', state.iteration)

    # Synchronize network parameters
    P_eval_agent.network_params = P_train_agent.online_params
    D_eval_agent.network_params = D_train_agent.online_params

    eval_seq = run_loop(P_eval_agent, D_eval_agent, env,
                        FLAGS.max_steps_per_episode)
    eval_seq_truncated = it.islice(eval_seq, FLAGS.num_eval_frames)
    eval_stats = generate_statistics(eval_trackers, eval_seq_truncated)

    # Logging and checkpointing
    L = [
        # Simulation metadata
        ('iteration', state.iteration, '%3d'),
        # ODySSEUS metadata
        ('n_charging_workers', sim_scenario_conf['n_workers'], '%3d'),
        ('n_relocation_workers',
         sim_scenario_conf['n_relocation_workers'], '%3d'),
        ('n_vehicles', sim_scenario_conf['n_vehicles'], '%3d'),
        ('pct_incentive_willingness',
         sim_scenario_conf['incentive_willingness'], '%2.2f'),
        ('zone_side_m', sim_general_conf['bin_side_length'], '%3d'),
        # Validation agents
        ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),
        ('eval_P_episode_return', eval_stats['episode_return'][0], '%2.2f'),
        ('eval_D_episode_return', eval_stats['episode_return'][1], '%2.2f'),
        ('eval_min_n_accepted_incentives',
         np.min(eval_stats['episodes_n_accepted_incentives']), '%2.2f'),
        ('eval_avg_n_accepted_incentives',
         np.mean(eval_stats['episodes_n_accepted_incentives']), '%2.2f'),
        ('eval_max_n_accepted_incentives',
         np.max(eval_stats['episodes_n_accepted_incentives']), '%2.2f'),
        ('eval_min_n_lives', np.min(eval_stats['episodes_n_lives']), '%2.2f'),
        ('eval_avg_n_lives', np.mean(eval_stats['episodes_n_lives']), '%2.2f'),
        ('eval_max_n_lives', np.max(eval_stats['episodes_n_lives']), '%2.2f'),
        ('eval_min_pct_satisfied_demand',
         np.min(eval_stats['pct_satisfied_demands']), '%2.2f'),
        ('eval_avg_pct_satisfied_demand',
         np.mean(eval_stats['pct_satisfied_demands']), '%2.2f'),
        ('eval_max_pct_satisfied_demand',
         np.max(eval_stats['pct_satisfied_demands']), '%2.2f'),
        # Training agents
        ('train_num_episodes', train_stats['num_episodes'], '%3d'),
        ('train_P_episode_return', train_stats['episode_return'][0], '%2.2f'),
        ('train_D_episode_return', train_stats['episode_return'][1], '%2.2f'),
        ('train_min_n_accepted_incentives',
         np.min(train_stats['episodes_n_accepted_incentives']), '%2.2f'),
        ('train_avg_n_accepted_incentives',
         np.mean(train_stats['episodes_n_accepted_incentives']), '%2.2f'),
        ('train_max_n_accepted_incentives',
         np.max(train_stats['episodes_n_accepted_incentives']), '%2.2f'),
        ('train_min_n_lives', np.min(train_stats['episodes_n_lives']), '%2.2f'),
        ('train_avg_n_lives', np.mean(train_stats['episodes_n_lives']), '%2.2f'),
        ('train_max_n_lives', np.max(train_stats['episodes_n_lives']), '%2.2f'),
        ('train_min_pct_satisfied_demand',
         np.min(train_stats['pct_satisfied_demands']), '%2.2f'),
        ('train_avg_pct_satisfied_demand',
         np.mean(train_stats['pct_satisfied_demands']), '%2.2f'),
        ('train_max_pct_satisfied_demand',
         np.max(train_stats['pct_satisfied_demands']), '%2.2f'),
        ('P_importance_sampling_exponent',
         P_train_agent.importance_sampling_exponent, '%.3f'),
        ('D_importance_sampling_exponent',
         D_train_agent.importance_sampling_exponent, '%.3f'),
        ('P_max_seen_priority', P_train_agent.max_seen_priority, '%.3f'),
        ('D_max_seen_priority', D_train_agent.max_seen_priority, '%.3f'),
    ]

    L_str = '\n'.join(('%s: ' + f) % (n, v) for n, v, f in L)
    logging.info(L_str)

    if state.iteration == FLAGS.num_iterations - 1:
      print('\n')

    writer.write(collections.OrderedDict((n, v) for n, v, _ in L))
    state.iteration += 1

    if state.iteration % FLAGS.checkpoint_period == 0:
      checkpoint.save()

  writer.close()
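# Illustration only (assumed tracker protocol, not the actual implementation):
# generate_statistics above is driven by trackers that observe every step of
# the truncated (environment, timestep, agent, action) sequence and then
# report summary statistics. A minimal sketch of that pattern:
def _generate_statistics_sketch(trackers, trajectory):
  """Feeds a step sequence through trackers and merges their summaries."""
  for tracker in trackers:
    tracker.reset()
  for environment, timestep, agent_, action in trajectory:
    for tracker in trackers:
      tracker.step(environment, timestep, agent_, action)
  stats = {}
  for tracker in trackers:
    stats.update(tracker.get())  # Each tracker returns a dict of statistics.
  return stats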
def main(argv): """Trains DQN agent on Atari.""" del argv logging.info('Boostrapped DQN on Atari on %s.', jax.lib.xla_bridge.get_backend().platform) random_state = np.random.RandomState(FLAGS.seed) rng_key = jax.random.PRNGKey( random_state.randint(-sys.maxsize - 1, sys.maxsize + 1)) if FLAGS.results_csv_path: writer = parts.CsvWriter(FLAGS.results_csv_path) else: writer = parts.NullWriter() def environment_builder(): """Creates Atari environment.""" env = gym_atari.GymAtari(FLAGS.environment_name, seed=random_state.randint(1, 2**32)) return gym_atari.RandomNoopsEnvironmentWrapper( env, min_noop_steps=1, max_noop_steps=30, seed=random_state.randint(1, 2**32), ) env = environment_builder() logging.info('Environment: %s', FLAGS.environment_name) logging.info('Action spec: %s', env.action_spec()) logging.info('Observation spec: %s', env.observation_spec()) num_actions = env.action_spec().num_values network_fn = networks.bootstrapped_dqn_multi_head_network( num_actions, num_heads=FLAGS.num_heads, mask_probability=FLAGS.mask_probability) network = hk.transform(network_fn) def preprocessor_builder(): return processors.atari( additional_discount=FLAGS.additional_discount, max_abs_reward=FLAGS.max_abs_reward, resize_shape=(FLAGS.environment_height, FLAGS.environment_width), num_action_repeats=FLAGS.num_action_repeats, num_pooled_frames=2, zero_discount_on_life_loss=True, num_stacked_frames=FLAGS.num_stacked_frames, grayscaling=True, ) # Create sample network input from sample preprocessor output. sample_processed_timestep = preprocessor_builder()(env.reset()) sample_processed_timestep = typing.cast(dm_env.TimeStep, sample_processed_timestep) sample_network_input = sample_processed_timestep.observation assert sample_network_input.shape == (FLAGS.environment_height, FLAGS.environment_width, FLAGS.num_stacked_frames) exploration_epsilon_schedule = parts.LinearSchedule( begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity * FLAGS.num_action_repeats), decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction * FLAGS.num_iterations * FLAGS.num_train_frames), begin_value=FLAGS.exploration_epsilon_begin_value, end_value=FLAGS.exploration_epsilon_end_value) if FLAGS.compress_state: def encoder(transition): return transition._replace( s_tm1=replay_lib.compress_array(transition.s_tm1), s_t=replay_lib.compress_array(transition.s_t)) def decoder(transition): return transition._replace( s_tm1=replay_lib.uncompress_array(transition.s_tm1), s_t=replay_lib.uncompress_array(transition.s_t)) else: encoder = None decoder = None replay_structure = replay_lib.MaskedTransition( s_tm1=None, a_tm1=None, r_t=None, discount_t=None, s_t=None, mask_t=None, ) replay = replay_lib.TransitionReplay(FLAGS.replay_capacity, replay_structure, random_state, encoder, decoder) optimizer = optax.rmsprop( learning_rate=FLAGS.learning_rate, decay=0.95, eps=FLAGS.optimizer_epsilon, centered=True, ) if FLAGS.shaping_function_type == constants.NO_PENALTY: shaping_function = shaping.NoPenalty() elif FLAGS.shaping_function_type == constants.HARD_CODED_PENALTY: shaping_function = shaping.HardCodedPenalty( penalty=FLAGS.shaping_multiplicative_factor) elif FLAGS.shaping_function_type == constants.UNCERTAINTY_PENALTY: shaping_function = shaping.UncertaintyPenalty( multiplicative_factor=FLAGS.shaping_multiplicative_factor) elif FLAGS.shaping_function_type == constants.POLICY_ENTROPY_PENALTY: shaping_function = shaping.PolicyEntropyPenalty( multiplicative_factor=FLAGS.shaping_multiplicative_factor, num_actions=num_actions) elif 
  elif FLAGS.shaping_function_type == constants.MUNCHAUSEN_PENALTY:
    shaping_function = shaping.MunchausenPenalty(
        multiplicative_factor=FLAGS.shaping_multiplicative_factor,
        num_actions=num_actions)

  train_rng_key, eval_rng_key = jax.random.split(rng_key)

  train_agent = agent.BootstrappedDqn(
      preprocessor=preprocessor_builder(),
      sample_network_input=sample_network_input,
      network=network,
      optimizer=optimizer,
      transition_accumulator=replay_lib.TransitionAccumulator(),
      replay=replay,
      shaping_function=shaping_function,
      mask_probability=FLAGS.mask_probability,
      num_heads=FLAGS.num_heads,
      batch_size=FLAGS.batch_size,
      exploration_epsilon=exploration_epsilon_schedule,
      min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
      learn_period=FLAGS.learn_period,
      target_network_update_period=FLAGS.target_network_update_period,
      grad_error_bound=FLAGS.grad_error_bound,
      rng_key=train_rng_key,
  )
  eval_agent = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=network,
      exploration_epsilon=FLAGS.eval_exploration_epsilon,
      rng_key=eval_rng_key,
  )

  # Set up checkpointing.
  # checkpoint = parts.NullCheckpoint()
  checkpoint = parts.ImplementedCheckpoint(
      checkpoint_path=FLAGS.checkpoint_path)

  if checkpoint.can_be_restored():
    checkpoint.restore()
    iteration = checkpoint.state.iteration
    random_state = checkpoint.state.random_state
    train_agent.set_state(state=checkpoint.state.train_agent)
    eval_agent.set_state(state=checkpoint.state.eval_agent)
    writer.set_state(state=checkpoint.state.writer)
  else:
    iteration = 0

  while iteration <= FLAGS.num_iterations:
    # New environment for each iteration to allow for determinism if preempted.
    env = environment_builder()

    logging.info('Training iteration %d.', iteration)
    train_seq = parts.run_loop(train_agent, env, FLAGS.max_frames_per_episode)
    num_train_frames = 0 if iteration == 0 else FLAGS.num_train_frames
    train_seq_truncated = itertools.islice(train_seq, num_train_frames)
    train_stats = parts.generate_statistics(train_seq_truncated)

    logging.info('Evaluation iteration %d.', iteration)
    eval_agent.network_params = train_agent.online_params
    eval_seq = parts.run_loop(eval_agent, env, FLAGS.max_frames_per_episode)
    eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
    eval_stats = parts.generate_statistics(eval_seq_truncated)

    # Logging and checkpointing.
    human_normalized_score = atari_data.get_human_normalized_score(
        FLAGS.environment_name, eval_stats['episode_return'])
    capped_human_normalized_score = np.amin([1., human_normalized_score])

    log_output = [
        ('iteration', iteration, '%3d'),
        ('frame', iteration * FLAGS.num_train_frames, '%5d'),
        ('eval_episode_return', eval_stats['episode_return'], '% 2.2f'),
        ('train_episode_return', train_stats['episode_return'], '% 2.2f'),
        ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),
        ('train_num_episodes', train_stats['num_episodes'], '%3d'),
        ('eval_frame_rate', eval_stats['step_rate'], '%4.0f'),
        ('train_frame_rate', train_stats['step_rate'], '%4.0f'),
        ('train_exploration_epsilon', train_agent.exploration_epsilon, '%.3f'),
        ('normalized_return', human_normalized_score, '%.3f'),
        ('capped_normalized_return', capped_human_normalized_score, '%.3f'),
        ('human_gap', 1. - capped_human_normalized_score, '%.3f'),
        ('train_loss', train_stats['train_loss'], '% 2.2f'),
        ('shaped_reward', train_stats['shaped_reward'], '% 2.2f'),
        ('penalties', train_stats['penalties'], '% 2.2f'),
    ]
    log_output_str = ', '.join(('%s: ' + f) % (n, v) for n, v, f in log_output)
    logging.info(log_output_str)
    writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
    iteration += 1

    # Update state before checkpointing.
    checkpoint.state.iteration = iteration
    checkpoint.state.train_agent = train_agent.get_state()
    checkpoint.state.eval_agent = eval_agent.get_state()
    checkpoint.state.random_state = random_state
    checkpoint.state.writer = writer.get_state()
    checkpoint.save()

  writer.close()
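# Illustration only (assumed semantics of the get_state()/set_state() pattern
# used above, with a hypothetical on-disk path; parts.ImplementedCheckpoint and
# PickleCheckpoint are the real classes): a checkpoint object exposes a mutable
# `state`, serializes it on save(), and restores it if a previous run left one.
def _checkpoint_round_trip_sketch():
  import os
  import pickle
  import types

  path = '/tmp/example_checkpoint.pkl'  # Hypothetical path.
  state = types.SimpleNamespace(iteration=0, writer=None)

  # Save: serialize the whole state namespace.
  with open(path, 'wb') as f:
    pickle.dump(vars(state), f)

  # Restore: only if a checkpoint file exists on disk.
  if os.path.exists(path):
    with open(path, 'rb') as f:
      vars(state).update(pickle.load(f))
  return state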