Exemple #1
0
 def test_deserialize_before_header(self):
     """Checks that a restored writer still emits the header on first write.

     The state is taken from a writer that has not written anything yet, so
     the writer receiving that state must output the header row itself.
     """
     source_writer = parts.CsvWriter('test.csv')
     # Constructing a writer alone must not touch the file.
     self.fake_file.write.assert_not_called()

     restored_writer = parts.CsvWriter('test.csv')
     saved_state = source_writer.get_state()
     restored_writer.set_state(saved_state)

     row = collections.OrderedDict([('a', 1), ('b', 2)])
     restored_writer.write(row)

     expected_calls = [mock.call('a,b\r\n'), mock.call('1,2\r\n')]
     self.assertSequenceEqual(expected_calls,
                              self.fake_file.write.call_args_list)
Exemple #2
0
 def test_deserialize_after_header(self):
     """Checks that a restored writer does not repeat an existing header.

     The state is taken from a writer that already wrote header and one row,
     so the writer receiving that state must only append its data row.
     """
     header_call = mock.call('a,b\r\n')
     first_row_call = mock.call('1,2\r\n')
     second_row_call = mock.call('3,4\r\n')

     source_writer = parts.CsvWriter('test.csv')
     first_row = collections.OrderedDict([('a', 1), ('b', 2)])
     source_writer.write(first_row)
     self.assertSequenceEqual([header_call, first_row_call],
                              self.fake_file.write.call_args_list)

     restored_writer = parts.CsvWriter('test.csv')
     saved_state = source_writer.get_state()
     restored_writer.set_state(saved_state)
     second_row = collections.OrderedDict([('a', 3), ('b', 4)])
     restored_writer.write(second_row)
     self.assertSequenceEqual(
         [header_call, first_row_call, second_row_call],
         self.fake_file.write.call_args_list)
Exemple #3
0
 def test_error_new_keys(self):
     """Checks that writing a row with a previously unseen key raises."""
     writer = parts.CsvWriter('test.csv')
     initial_row = collections.OrderedDict([('a', 1), ('b', 2)])
     writer.write(initial_row)

     # 'c' was not part of the first row, so the second write must fail.
     row_with_extra_key = collections.OrderedDict(
         [('a', 3), ('b', 4), ('c', 5)])
     with self.assertRaisesRegex(ValueError, 'fields not in fieldnames'):
         writer.write(row_with_extra_key)
Exemple #4
0
 def test_file_close(self):
     """Checks that the file is closed only once writer.close() is called."""
     writer = parts.CsvWriter('test.csv')
     row = collections.OrderedDict([('a', 1), ('b', 2)])
     writer.write(row)

     fake_close = self.fake_file.close
     # Writing alone must leave the file open.
     fake_close.assert_not_called()
     writer.close()
     fake_close.assert_called_once_with()
Exemple #5
0
 def test_create_dir(self):
     """Checks that the csv file's directory is created when it is missing."""
     dirname = '/some/sub/dir'
     csv_path = dirname + '/test.csv'
     exists_patch = mock.patch('os.path.exists')
     makedirs_patch = mock.patch('os.makedirs')
     with exists_patch as fake_exists, makedirs_patch as fake_makedirs:
         # Pretend the directory is absent so the writer has to create it.
         fake_exists.return_value = False
         unused_writer = parts.CsvWriter(csv_path)
         fake_exists.assert_called_once_with(dirname)
         fake_makedirs.assert_called_once_with(dirname)
Exemple #6
0
 def test_missing_keys(self):
   """Checks that a key absent from a later row is written as an empty cell."""
   writer = parts.CsvWriter('test.csv')
   full_row = collections.OrderedDict([('a', 1), ('b', 2), ('c', 3)])
   row_without_b = collections.OrderedDict([('a', 4), ('c', 6)])
   for row in (full_row, row_without_b):
     writer.write(row)

   expected_calls = [
       mock.call('a,b,c\r\n'),
       mock.call('1,2,3\r\n'),
       mock.call('4,,6\r\n'),
   ]
   self.assertSequenceEqual(expected_calls,
                            self.fake_file.write.call_args_list)
Exemple #7
0
 def test_insertion_order_of_fields_preserved(self):
     """Tests that columns keep the key order of the first written row.

     The first row fixes the column order ('c', 'a', 'b'); a later row whose
     keys are inserted in a different order must still be written using the
     column order of the first row.
     """
     writer = parts.CsvWriter('test.csv')
     first_row = collections.OrderedDict([('c', 3), ('a', 1), ('b', 2)])
     reordered_row = collections.OrderedDict([('b', 5), ('c', 6), ('a', 4)])
     writer.write(first_row)
     writer.write(reordered_row)
     self.assertSequenceEqual([
         mock.call('c,a,b\r\n'),
         mock.call('3,1,2\r\n'),
         mock.call('6,4,5\r\n')
     ], self.fake_file.write.call_args_list)
Exemple #8
0
 def test_file_writes(self):
     """Checks header and rows are written lazily and in the right order."""
     recorded_calls = self.fake_file.write.call_args_list
     expected_calls = [mock.call('a,b\r\n'), mock.call('1,2\r\n')]

     writer = parts.CsvWriter('test.csv')
     # Nothing is written until the first row arrives.
     self.fake_file.write.assert_not_called()

     writer.write(collections.OrderedDict([('a', 1), ('b', 2)]))
     self.assertSequenceEqual(expected_calls,
                              self.fake_file.write.call_args_list)

     writer.write(collections.OrderedDict([('a', 3), ('b', 4)]))
     expected_calls = expected_calls + [mock.call('3,4\r\n')]
     self.assertSequenceEqual(expected_calls,
                              self.fake_file.write.call_args_list)
Exemple #9
0
def main(argv):
  """Trains Prioritized DQN agent on Atari.

  Builds the environment, network, prioritized replay, optimizer and the
  train/eval agents from FLAGS, then repeatedly runs a training phase followed
  by an evaluation phase while `state.iteration <= FLAGS.num_iterations`. One
  row of statistics is logged and passed to the results writer per iteration.

  Args:
    argv: Command-line arguments; unused and deleted immediately.
  """
  del argv
  logging.info('Prioritized DQN on Atari on %s.',
               jax.lib.xla_bridge.get_backend().platform)
  # One NumPy random state seeded from FLAGS.seed; it provides the JAX key
  # seed, the environment seeds and is also handed to the replay below.
  random_state = np.random.RandomState(FLAGS.seed)
  rng_key = jax.random.PRNGKey(
      random_state.randint(-sys.maxsize - 1, sys.maxsize + 1, dtype=np.int64))

  # Results go to a CSV file only when a path was given.
  if FLAGS.results_csv_path:
    writer = parts.CsvWriter(FLAGS.results_csv_path)
  else:
    writer = parts.NullWriter()

  def environment_builder():
    """Creates Atari environment."""
    # Each call draws two fresh seeds from random_state (env, then wrapper).
    env = gym_atari.GymAtari(
        FLAGS.environment_name, seed=random_state.randint(1, 2**32))
    return gym_atari.RandomNoopsEnvironmentWrapper(
        env,
        min_noop_steps=1,
        max_noop_steps=30,
        seed=random_state.randint(1, 2**32),
    )

  env = environment_builder()

  logging.info('Environment: %s', FLAGS.environment_name)
  logging.info('Action spec: %s', env.action_spec())
  logging.info('Observation spec: %s', env.observation_spec())
  num_actions = env.action_spec().num_values
  network_fn = networks.double_dqn_atari_network(num_actions)
  network = hk.transform(network_fn)

  # Returns a new preprocessor per call; it is called three times below (for
  # the sample input, the train agent and the eval agent).
  def preprocessor_builder():
    return processors.atari(
        additional_discount=FLAGS.additional_discount,
        max_abs_reward=FLAGS.max_abs_reward,
        resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
        num_action_repeats=FLAGS.num_action_repeats,
        num_pooled_frames=2,
        zero_discount_on_life_loss=True,
        num_stacked_frames=FLAGS.num_stacked_frames,
        grayscaling=True,
    )

  # Create sample network input from sample preprocessor output.
  sample_processed_timestep = preprocessor_builder()(env.reset())
  sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                          sample_processed_timestep)
  sample_network_input = sample_processed_timestep.observation
  # Sanity check: observation must be (height, width, stacked frames).
  chex.assert_shape(sample_network_input,
                    (FLAGS.environment_height, FLAGS.environment_width,
                     FLAGS.num_stacked_frames))

  # Exploration epsilon schedule; begin_t scales the minimum replay fill by
  # the number of action repeats, decay_steps is a fraction of all train
  # frames.
  exploration_epsilon_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity *
                  FLAGS.num_action_repeats),
      decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                      FLAGS.num_iterations * FLAGS.num_train_frames),
      begin_value=FLAGS.exploration_epsilon_begin_value,
      end_value=FLAGS.exploration_epsilon_end_value)

  # Note the t in the replay is not exactly aligned with the agent t.
  importance_sampling_exponent_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity),
      end_t=(FLAGS.num_iterations *
             int(FLAGS.num_train_frames / FLAGS.num_action_repeats)),
      begin_value=FLAGS.importance_sampling_exponent_begin_value,
      end_value=FLAGS.importance_sampling_exponent_end_value)

  # Optionally compress the two observation fields of each stored transition.
  if FLAGS.compress_state:

    def encoder(transition):
      return transition._replace(
          s_tm1=replay_lib.compress_array(transition.s_tm1),
          s_t=replay_lib.compress_array(transition.s_t))

    def decoder(transition):
      return transition._replace(
          s_tm1=replay_lib.uncompress_array(transition.s_tm1),
          s_t=replay_lib.uncompress_array(transition.s_t))
  else:
    encoder = None
    decoder = None

  # Template transition with all fields empty, passed to the replay.
  replay_structure = replay_lib.Transition(
      s_tm1=None,
      a_tm1=None,
      r_t=None,
      discount_t=None,
      s_t=None,
  )

  replay = replay_lib.PrioritizedTransitionReplay(
      FLAGS.replay_capacity, replay_structure, FLAGS.priority_exponent,
      importance_sampling_exponent_schedule, FLAGS.uniform_sample_probability,
      FLAGS.normalize_weights, random_state, encoder, decoder)

  optimizer = optax.rmsprop(
      learning_rate=FLAGS.learning_rate,
      decay=0.95,
      eps=FLAGS.optimizer_epsilon,
      centered=True,
  )

  # Independent JAX keys for the training and evaluation agents.
  train_rng_key, eval_rng_key = jax.random.split(rng_key)

  train_agent = agent.PrioritizedDqn(
      preprocessor=preprocessor_builder(),
      sample_network_input=sample_network_input,
      network=network,
      optimizer=optimizer,
      transition_accumulator=replay_lib.TransitionAccumulator(),
      replay=replay,
      batch_size=FLAGS.batch_size,
      exploration_epsilon=exploration_epsilon_schedule,
      min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
      learn_period=FLAGS.learn_period,
      target_network_update_period=FLAGS.target_network_update_period,
      grad_error_bound=FLAGS.grad_error_bound,
      rng_key=train_rng_key,
  )
  eval_agent = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=network,
      exploration_epsilon=FLAGS.eval_exploration_epsilon,
      rng_key=eval_rng_key,
  )

  # Set up checkpointing.
  checkpoint = parts.NullCheckpoint()

  # Everything attached to `state` is what checkpoint.save()/restore() is
  # given to work with; the defaults are set first and restore() runs after.
  state = checkpoint.state
  state.iteration = 0
  state.train_agent = train_agent
  state.eval_agent = eval_agent
  state.random_state = random_state
  state.writer = writer
  if checkpoint.can_be_restored():
    checkpoint.restore()

  # The bound is inclusive, so the loop body also runs for
  # state.iteration == FLAGS.num_iterations.
  while state.iteration <= FLAGS.num_iterations:
    # New environment for each iteration to allow for determinism if preempted.
    env = environment_builder()

    logging.info('Training iteration %d.', state.iteration)
    train_seq = parts.run_loop(train_agent, env, FLAGS.max_frames_per_episode)
    # Iteration 0 consumes zero training frames, so it is evaluation only.
    num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames
    train_seq_truncated = itertools.islice(train_seq, num_train_frames)
    train_trackers = parts.make_default_trackers(train_agent)
    train_stats = parts.generate_statistics(train_trackers, train_seq_truncated)

    logging.info('Evaluation iteration %d.', state.iteration)
    # Evaluate with the parameters the train agent currently holds.
    eval_agent.network_params = train_agent.online_params
    eval_seq = parts.run_loop(eval_agent, env, FLAGS.max_frames_per_episode)
    eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
    eval_trackers = parts.make_default_trackers(eval_agent)
    eval_stats = parts.generate_statistics(eval_trackers, eval_seq_truncated)

    # Logging and checkpointing.
    human_normalized_score = atari_data.get_human_normalized_score(
        FLAGS.environment_name, eval_stats['episode_return'])
    # Cap the normalized score at 1.
    capped_human_normalized_score = np.amin([1., human_normalized_score])
    # Each entry is (column name, value, printf-style format for the log line).
    log_output = [
        ('iteration', state.iteration, '%3d'),
        ('frame', state.iteration * FLAGS.num_train_frames, '%5d'),
        ('eval_episode_return', eval_stats['episode_return'], '% 2.2f'),
        ('train_episode_return', train_stats['episode_return'], '% 2.2f'),
        ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),
        ('train_num_episodes', train_stats['num_episodes'], '%3d'),
        ('eval_frame_rate', eval_stats['step_rate'], '%4.0f'),
        ('train_frame_rate', train_stats['step_rate'], '%4.0f'),
        ('train_exploration_epsilon', train_agent.exploration_epsilon, '%.3f'),
        ('train_state_value', train_stats['state_value'], '%.3f'),
        ('importance_sampling_exponent',
         train_agent.importance_sampling_exponent, '%.3f'),
        ('max_seen_priority', train_agent.max_seen_priority, '%.3f'),
        ('normalized_return', human_normalized_score, '%.3f'),
        ('capped_normalized_return', capped_human_normalized_score, '%.3f'),
        ('human_gap', 1. - capped_human_normalized_score, '%.3f'),
    ]
    log_output_str = ', '.join(('%s: ' + f) % (n, v) for n, v, f in log_output)
    logging.info(log_output_str)
    # The writer receives the raw values, keyed by column name, in list order.
    writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
    state.iteration += 1
    checkpoint.save()

  writer.close()
Exemple #10
0
def main(argv):
    """Trains a DQN agent with reward shaping on the Key-Door environment.

    Builds the Key-Door environment, network, replay, optimizer, shaping
    function and the train/eval agents from FLAGS, optionally restores a
    checkpoint, then repeatedly runs a training phase followed by an
    evaluation phase while `state.iteration <= FLAGS.num_iterations`. One row
    of statistics is logged and passed to the results writer per iteration,
    and a checkpoint is saved after every iteration.

    Args:
      argv: Command-line arguments; unused and deleted immediately.

    Raises:
      ValueError: If FLAGS.shaping_function_type is not one of the supported
        shaping function types.
    """
    del argv
    logging.info("DQN on Atari on %s.",
                 jax.lib.xla_bridge.get_backend().platform)
    random_state = np.random.RandomState(FLAGS.seed)
    # Request int64 explicitly: the bounds span the full 64-bit range, which
    # does not fit NumPy's default integer type on every platform.
    rng_key = jax.random.PRNGKey(
        random_state.randint(-sys.maxsize - 1, sys.maxsize + 1,
                             dtype=np.int64))

    # Results go to a CSV file only when a path was given.
    if FLAGS.results_csv_path:
        writer = parts.CsvWriter(FLAGS.results_csv_path)
    else:
        writer = parts.NullWriter()

    def environment_builder():
        """Creates Key-Door environment."""
        env = gym_key_door.GymKeyDoor(
            env_args={
                constants.MAP_ASCII_PATH: FLAGS.map_ascii_path,
                constants.MAP_YAML_PATH: FLAGS.map_yaml_path,
                constants.REPRESENTATION: constants.PIXEL,
                constants.SCALING: FLAGS.env_scaling,
                constants.EPISODE_TIMEOUT: FLAGS.max_frames_per_episode,
                constants.GRAYSCALE: False,
                constants.BATCH_DIMENSION: False,
                constants.TORCH_AXES: False,
            },
            env_shape=FLAGS.env_shape,
        )
        return gym_atari.RandomNoopsEnvironmentWrapper(
            env,
            min_noop_steps=1,
            max_noop_steps=30,
            seed=random_state.randint(1, 2**32),
        )

    env = environment_builder()

    logging.info("Environment: %s", FLAGS.environment_name)
    logging.info("Action spec: %s", env.action_spec())
    logging.info("Observation spec: %s", env.observation_spec())
    num_actions = env.action_spec().num_values
    network_fn = networks.dqn_atari_network(num_actions)
    network = hk.transform(network_fn)

    def preprocessor_builder():
        """Creates a new observation preprocessor configured from FLAGS."""
        return processors.atari(
            additional_discount=FLAGS.additional_discount,
            max_abs_reward=FLAGS.max_abs_reward,
            resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
            num_action_repeats=FLAGS.num_action_repeats,
            num_pooled_frames=2,
            zero_discount_on_life_loss=True,
            num_stacked_frames=FLAGS.num_stacked_frames,
            grayscaling=True,
        )

    # Create sample network input from sample preprocessor output.
    sample_processed_timestep = preprocessor_builder()(env.reset())
    sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                            sample_processed_timestep)
    sample_network_input = sample_processed_timestep.observation
    assert sample_network_input.shape == (
        FLAGS.environment_height,
        FLAGS.environment_width,
        FLAGS.num_stacked_frames,
    )

    exploration_epsilon_schedule = parts.LinearSchedule(
        begin_t=int(FLAGS.min_replay_capacity_fraction *
                    FLAGS.replay_capacity * FLAGS.num_action_repeats),
        decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                        FLAGS.num_iterations * FLAGS.num_train_frames),
        begin_value=FLAGS.exploration_epsilon_begin_value,
        end_value=FLAGS.exploration_epsilon_end_value,
    )

    # Optionally compress the two observation fields of stored transitions.
    if FLAGS.compress_state:

        def encoder(transition):
            return transition._replace(
                s_tm1=replay_lib.compress_array(transition.s_tm1),
                s_t=replay_lib.compress_array(transition.s_t),
            )

        def decoder(transition):
            return transition._replace(
                s_tm1=replay_lib.uncompress_array(transition.s_tm1),
                s_t=replay_lib.uncompress_array(transition.s_t),
            )

    else:
        encoder = None
        decoder = None

    replay_structure = replay_lib.Transition(
        s_tm1=None,
        a_tm1=None,
        r_t=None,
        discount_t=None,
        s_t=None,
    )

    replay = replay_lib.TransitionReplay(FLAGS.replay_capacity,
                                         replay_structure, random_state,
                                         encoder, decoder)

    optimizer = optax.rmsprop(
        learning_rate=FLAGS.learning_rate,
        decay=0.95,
        eps=FLAGS.optimizer_epsilon,
        centered=True,
    )

    # Fail with a clear message on an unknown type instead of leaving
    # `shaping_function` unbound until the agent is constructed.
    if FLAGS.shaping_function_type == constants.NO_PENALTY:
        shaping_function = shaping.NoPenalty()
    elif FLAGS.shaping_function_type == constants.HARD_CODED_PENALTY:
        shaping_function = shaping.HardCodedPenalty(
            penalty=FLAGS.shaping_multiplicative_factor)
    else:
        raise ValueError("Unsupported shaping_function_type: %r" %
                         (FLAGS.shaping_function_type,))

    train_rng_key, eval_rng_key = jax.random.split(rng_key)

    train_agent = agent.Dqn(
        preprocessor=preprocessor_builder(),
        sample_network_input=sample_network_input,
        network=network,
        optimizer=optimizer,
        transition_accumulator=replay_lib.TransitionAccumulator(),
        replay=replay,
        shaping_function=shaping_function,
        batch_size=FLAGS.batch_size,
        exploration_epsilon=exploration_epsilon_schedule,
        min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
        learn_period=FLAGS.learn_period,
        target_network_update_period=FLAGS.target_network_update_period,
        grad_error_bound=FLAGS.grad_error_bound,
        rng_key=train_rng_key,
    )
    eval_agent = parts.EpsilonGreedyActor(
        preprocessor=preprocessor_builder(),
        network=network,
        exploration_epsilon=FLAGS.eval_exploration_epsilon,
        rng_key=eval_rng_key,
    )

    # Set up checkpointing.
    checkpoint = parts.ImplementedCheckpoint(
        checkpoint_path=FLAGS.checkpoint_path)

    checkpoint_restored = checkpoint.can_be_restored()
    if checkpoint_restored:
        checkpoint.restore()
        train_agent.set_state(state=checkpoint.state.train_agent)
        eval_agent.set_state(state=checkpoint.state.eval_agent)
        writer.set_state(state=checkpoint.state.writer)

    state = checkpoint.state
    # Only start from iteration 0 on a fresh run; after a restore the
    # iteration held in the checkpoint state is kept so training resumes
    # (presumably restore() repopulates `state.iteration` -- verify against
    # ImplementedCheckpoint).
    if not checkpoint_restored:
        state.iteration = 0
    state.train_agent = train_agent.get_state()
    state.eval_agent = eval_agent.get_state()
    # NOTE(review): a restored random state is not applied to `random_state`;
    # confirm whether that is intended.
    state.random_state = random_state
    state.writer = writer.get_state()

    # The bound is inclusive, so the loop body also runs for
    # state.iteration == FLAGS.num_iterations.
    while state.iteration <= FLAGS.num_iterations:
        # New environment for each iteration to allow for determinism if preempted.
        env = environment_builder()

        logging.info("Training iteration %d.", state.iteration)
        train_seq = parts.run_loop(train_agent, env,
                                   FLAGS.max_frames_per_episode)
        # Iteration 0 consumes zero training frames: evaluation only.
        num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames
        train_seq_truncated = itertools.islice(train_seq, num_train_frames)
        train_stats = parts.generate_statistics(train_seq_truncated)

        logging.info("Evaluation iteration %d.", state.iteration)
        eval_agent.network_params = train_agent.online_params
        eval_seq = parts.run_loop(eval_agent, env,
                                  FLAGS.max_frames_per_episode)
        eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
        eval_stats = parts.generate_statistics(eval_seq_truncated)

        # Logging and checkpointing.
        human_normalized_score = atari_data.get_human_normalized_score(
            FLAGS.environment_name, eval_stats["episode_return"])
        capped_human_normalized_score = np.amin([1.0, human_normalized_score])
        # Each entry is (column name, value, printf-style log format).
        log_output = [
            ("iteration", state.iteration, "%3d"),
            ("frame", state.iteration * FLAGS.num_train_frames, "%5d"),
            ("eval_episode_return", eval_stats["episode_return"], "% 2.2f"),
            ("train_episode_return", train_stats["episode_return"], "% 2.2f"),
            ("eval_num_episodes", eval_stats["num_episodes"], "%3d"),
            ("train_num_episodes", train_stats["num_episodes"], "%3d"),
            ("eval_frame_rate", eval_stats["step_rate"], "%4.0f"),
            ("train_frame_rate", train_stats["step_rate"], "%4.0f"),
            ("train_exploration_epsilon", train_agent.exploration_epsilon,
             "%.3f"),
            ("normalized_return", human_normalized_score, "%.3f"),
            ("capped_normalized_return", capped_human_normalized_score,
             "%.3f"),
            ("human_gap", 1.0 - capped_human_normalized_score, "%.3f"),
        ]
        log_output_str = ", ".join(
            ("%s: " + f) % (n, v) for n, v, f in log_output)
        logging.info(log_output_str)
        writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
        state.iteration += 1
        # Re-take the snapshots so the checkpoint holds the state reached in
        # this iteration rather than the one captured before the loop.
        state.train_agent = train_agent.get_state()
        state.eval_agent = eval_agent.get_state()
        state.writer = writer.get_state()
        checkpoint.save()

    writer.close()
Exemple #11
0
 def test_file_close_on_delete(self):
     """Checks that deleting the writer closes its file."""
     row = collections.OrderedDict([('a', 1), ('b', 2)])
     writer = parts.CsvWriter('test.csv')
     writer.write(row)
     # Drop the only reference to the writer; the file must get closed.
     del writer
     self.fake_file.close.assert_called_once_with()
Exemple #12
0
 def test_file_open(self):
     """Checks that the writer opens the file under the name it was given."""
     filename = 'testabc.csv'
     unused_writer = parts.CsvWriter(filename)
     # Only the file name is pinned; the open mode may be anything.
     self.mock_open.assert_called_once_with('testabc.csv', mock.ANY)
Exemple #13
0
def main(argv):
    """
    Train pick-up and drop-off Rainbow agents on ODySSEUS.

    Prepares the experiment directory (copying the simulator configuration
    on first use, or verifying it is unchanged otherwise), optionally
    restores a checkpoint, builds the environment, network, replay,
    optimizer and one train/eval agent pair for pick-up (P) and drop-off
    (D), then alternates training and evaluation while
    `state.iteration < FLAGS.num_iterations`. One row of statistics is
    logged and passed to the results writer per iteration; a checkpoint is
    saved every `FLAGS.checkpoint_period` iterations.

    Args:
        argv: Command-line arguments; unused and deleted immediately.

    Raises:
        IOError: If the experiment directory already exists and its stored
            configuration file differs from the requested one.
    """
    del argv # Unused arguments

    # Metadata configuration
    parent_dir = pathlib.Path(__file__).parent.absolute()

    sim_input_conf_dir = parent_dir / 'configs' / DEFAULT_sim_scenario_name

    # Load configuration
    sim_conf = importlib.import_module('esbdqn.configs.{}.{}'
                                       .format(DEFAULT_sim_scenario_name,
                                               FLAGS.conf_filename))

    # Extract a single conf pair
    sim_general_conf = EFFCS_SimConfGrid(sim_conf.General).conf_list[0]
    sim_scenario_conf = EFFCS_SimConfGrid(sim_conf.Multiple_runs).conf_list[0]

    experiment_dir = (parent_dir
                      / 'experiments'
                      / DEFAULT_sim_scenario_name
                      / FLAGS.exp_name
                      / sim_general_conf['city'])

    if experiment_dir.exists():
        # Ensure configuration has not changed
        if not filecmp.cmp(str(sim_input_conf_dir
                               / FLAGS.conf_filename)   + '.py',
                           str(experiment_dir
                               / DEFAULT_conf_filename) + ".py",
                           shallow=False):
            raise IOError('Configuration changed at: {}'
                          .format(str(experiment_dir)))
    else:
        # Create the parent directories, then remove the leaf again so
        # that copytree can create it itself.
        experiment_dir.mkdir(parents=True, exist_ok=True)

        # Copy configuration files
        shutil.rmtree(experiment_dir)
        shutil.copytree(sim_input_conf_dir, experiment_dir)

        # Rename to the default name
        conf_filepath = experiment_dir / (FLAGS.conf_filename + ".py")
        conf_filepath.rename(experiment_dir
                             / (DEFAULT_conf_filename + ".py"))

        # Delete all other potential conf files
        for filename in experiment_dir.glob(
                DEFAULT_conf_filename + "_*.py"):
            filename.unlink()

    # Create results files
    results_dir = experiment_dir / 'results'
    results_dir.mkdir(parents=True, exist_ok=True)

    results_filepath = results_dir / DEFAULT_resu_filename

    logging.info('Rainbow agents on ODySSEUS running on %s.',
                 jax.lib.xla_bridge.get_backend().platform.upper())

    if FLAGS.checkpoint:
        checkpoint = PickleCheckpoint(
            experiment_dir / 'models',
            'ODySSEUS-' + sim_general_conf['city'])
    else:
        checkpoint = parts.NullCheckpoint()

    checkpoint_restored = False

    # Restoring is only attempted when checkpointing is enabled.
    if FLAGS.checkpoint and checkpoint.can_be_restored():
        logging.info('Restoring checkpoint...')

        checkpoint.restore()
        checkpoint_restored = True

    # Generate RNG key
    rng_state = np.random.RandomState(FLAGS.seed)

    if checkpoint_restored:
        rng_state.set_state(checkpoint.state
                                      .rng_state)

    rng_key = jax.random.PRNGKey(
        rng_state.randint(-sys.maxsize - 1,
                          sys.maxsize + 1,
                          dtype=np.int64))

    # Generate results file writer
    if sim_general_conf['save_history']:
        writer = parts.CsvWriter(str(results_filepath))

        if checkpoint_restored:
            writer.set_state(checkpoint.state
                                       .writer)
    else:
        writer = parts.NullWriter()

    def environment_builder() -> ConstrainedEnvironment:
        """
        Create the ODySSEUS environment.
        """
        return EscooterSimulator(
                        (sim_general_conf,
                         sim_scenario_conf),
                    FLAGS.n_lives)

    def preprocessor_builder():
        """
        Create the ODySSEUS input preprocessor.
        """
        return processor(
            max_abs_reward=FLAGS.max_abs_reward,
            zero_discount_on_life_loss=True
        )

    env = environment_builder()

    logging.info('Environment: %s', FLAGS.exp_name)
    logging.info('Action spec: %s', env.action_spec())
    logging.info('Observation spec: %s', env.observation_spec())

    # Take [0] as both Rainbow have
    # the same number of actions
    num_actions = env.action_spec()[0].num_values
    support = jnp.linspace(-FLAGS.vmax, FLAGS.vmax,
                           FLAGS.num_atoms)

    network = hk.transform(rainbow_odysseus_network(
                           num_actions, support,
                           FLAGS.noisy_weight_init))

    # Create sample network input from reset.
    sample_processed_timestep = preprocessor_builder()(env.reset())
    sample_processed_timestep = t.cast(dm_env.TimeStep,
                                       sample_processed_timestep)

    sample_processed_network_input = sample_processed_timestep.observation

    # Note the t in the replay is not exactly
    # aligned with the Rainbow agents t.
    importance_sampling_exponent_schedule = parts.LinearSchedule(
        begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity),
        end_t=(FLAGS.num_iterations * FLAGS.num_train_frames),
        begin_value=FLAGS.importance_sampling_exponent_begin_value,
        end_value=FLAGS.importance_sampling_exponent_end_value)

    # Optionally compress the two observation fields of stored transitions.
    if FLAGS.compress_state:
        def encoder(transition):
            return transition._replace(
                s_tm1=replay.compress_array(transition.s_tm1),
                s_t=replay.compress_array(transition.s_t))

        def decoder(transition):
            return transition._replace(
                s_tm1=replay.uncompress_array(transition.s_tm1),
                s_t=replay.uncompress_array(transition.s_t))
    else:
        encoder = None
        decoder = None

    replay_struct = replay.Transition(
        s_tm1=None,
        a_tm1=None,
        r_t=None,
        discount_t=None,
        s_t=None,
    )

    transition_accumulator = replay.NStepTransitionAccumulator(FLAGS.n_steps)

    transition_replay = replay.PrioritizedTransitionReplay(
        FLAGS.replay_capacity, replay_struct,
        FLAGS.priority_exponent,
        importance_sampling_exponent_schedule,
        FLAGS.uniform_sample_probability,
        FLAGS.normalize_weights,
        rng_state, encoder, decoder)

    optimizer = optax.adam(
        learning_rate=FLAGS.learning_rate,
        eps=FLAGS.optimizer_epsilon)

    # Gradient clipping is only applied for a positive norm bound.
    if FLAGS.max_global_grad_norm > 0:
        optimizer = optax.chain(
            optax.clip_by_global_norm(
                FLAGS.max_global_grad_norm),
            optimizer)

    train_rng_key, eval_rng_key = jax.random.split(rng_key)

    # Create pick-up/drop-off agents. Each agent receives its own deep copy
    # of the shared building blocks (network, replay, accumulator, ...).
    # NOTE(review): both train agents are given the same `train_rng_key`
    # (and both eval agents the same `eval_rng_key`); confirm whether
    # identical random streams for P and D are intended.
    P_train_agent = agent.Rainbow(
        preprocessor=preprocessor_builder(),
        sample_network_input=copy.deepcopy(sample_processed_network_input),
        network=copy.deepcopy(network),
        support=copy.deepcopy(support),
        optimizer=copy.deepcopy(optimizer),
        transition_accumulator=copy.deepcopy(transition_accumulator),
        replay=copy.deepcopy(transition_replay),
        batch_size=FLAGS.batch_size,
        min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
        learn_period=FLAGS.learn_period,
        target_network_update_period=FLAGS.target_network_update_period,
        rng_key=train_rng_key,
    )

    D_train_agent = agent.Rainbow(
        preprocessor=preprocessor_builder(),
        sample_network_input=copy.deepcopy(sample_processed_network_input),
        network=copy.deepcopy(network),
        support=copy.deepcopy(support),
        optimizer=copy.deepcopy(optimizer),
        transition_accumulator=copy.deepcopy(transition_accumulator),
        replay=copy.deepcopy(transition_replay),
        batch_size=FLAGS.batch_size,
        min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
        learn_period=FLAGS.learn_period,
        target_network_update_period=FLAGS.target_network_update_period,
        rng_key=train_rng_key,
    )

    P_eval_agent = parts.EpsilonGreedyActor(
        preprocessor=preprocessor_builder(),
        network=copy.deepcopy(network),
        exploration_epsilon=0,
        rng_key=eval_rng_key,
    )

    D_eval_agent = parts.EpsilonGreedyActor(
        preprocessor=preprocessor_builder(),
        network=copy.deepcopy(network),
        exploration_epsilon=0,
        rng_key=eval_rng_key,
    )

    if checkpoint_restored:
        P_train_agent.set_state(checkpoint.state.P_agent['train'])
        D_train_agent.set_state(checkpoint.state.D_agent['train'])

        P_eval_agent.set_state(checkpoint.state.P_agent['eval'])
        D_eval_agent.set_state(checkpoint.state.D_agent['eval'])

    state = checkpoint.state

    # Keep the restored iteration counter when resuming.
    if not checkpoint_restored:
        state.iteration = 0

    state.P_agent = {}
    state.D_agent = {}

    state.rng_state = rng_state
    state.writer = writer

    state.P_agent['train'] = P_train_agent
    state.D_agent['train'] = D_train_agent

    state.P_agent['eval'] = P_eval_agent
    state.D_agent['eval'] = D_eval_agent

    while state.iteration < FLAGS.num_iterations:
        # Create a new environment at each new iteration
        # to allow for determinism if preempted.
        env = environment_builder()

        # Leave some spacing
        print('\n')

        logging.info('Training iteration: %d', state.iteration)

        train_trackers = make_odysseus_trackers(FLAGS.max_abs_reward)
        eval_trackers  = make_odysseus_trackers(FLAGS.max_abs_reward)

        train_seq = run_loop(P_train_agent, D_train_agent,
                             env, FLAGS.max_steps_per_episode)

        # Iteration 0 consumes zero training frames: evaluation only.
        num_train_frames = (0 if state.iteration == 0
                            else FLAGS.num_train_frames)

        train_seq_truncated = it.islice(train_seq, num_train_frames)

        train_stats = generate_statistics(train_trackers,
                                          train_seq_truncated)

        logging.info('Evaluation iteration: %d', state.iteration)

        # Synchronize network parameters: each eval agent evaluates the
        # parameters of its own train agent (the drop-off eval agent
        # previously received the pick-up agent's parameters).
        P_eval_agent.network_params = P_train_agent.online_params
        D_eval_agent.network_params = D_train_agent.online_params

        eval_seq = run_loop(P_eval_agent, D_eval_agent,
                            env, FLAGS.max_steps_per_episode)

        eval_seq_truncated = it.islice(eval_seq, FLAGS.num_eval_frames)

        eval_stats = generate_statistics(eval_trackers,
                                         eval_seq_truncated)

        # Logging and checkpointing. Each entry is
        # (column name, value, printf-style log format).
        log_output = [
            # Simulation metadata
            ('iteration', state.iteration, '%3d'),

            # ODySSEUS metadata
            ('n_charging_workers', sim_scenario_conf['n_workers'], '%3d'),
            ('n_relocation_workers', sim_scenario_conf['n_relocation_workers'], '%3d'),
            ('n_vehicles', sim_scenario_conf['n_vehicles'], '%3d'),
            ('pct_incentive_willingness', sim_scenario_conf['incentive_willingness'], '%2.2f'),
            ('zone_side_m', sim_general_conf['bin_side_length'], '%3d'),

            # Validation agents
            ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),

            ('eval_P_episode_return', eval_stats['episode_return'][0], '%2.2f'),
            ('eval_D_episode_return', eval_stats['episode_return'][1], '%2.2f'),

            ('eval_min_n_accepted_incentives',
             np.min(eval_stats['episodes_n_accepted_incentives']), '%2.2f'),
            ('eval_avg_n_accepted_incentives',
             np.mean(eval_stats['episodes_n_accepted_incentives']), '%2.2f'),
            ('eval_max_n_accepted_incentives',
             np.max(eval_stats['episodes_n_accepted_incentives']), '%2.2f'),

            ('eval_min_n_lives',
             np.min(eval_stats['episodes_n_lives']), '%2.2f'),
            ('eval_avg_n_lives',
             np.mean(eval_stats['episodes_n_lives']), '%2.2f'),
            ('eval_max_n_lives',
             np.max(eval_stats['episodes_n_lives']), '%2.2f'),

            ('eval_min_pct_satisfied_demand',
             np.min(eval_stats['pct_satisfied_demands']), '%2.2f'),
            ('eval_avg_pct_satisfied_demand',
             np.mean(eval_stats['pct_satisfied_demands']), '%2.2f'),
            ('eval_max_pct_satisfied_demand',
             np.max(eval_stats['pct_satisfied_demands']), '%2.2f'),

            # Training agents
            ('train_num_episodes', train_stats['num_episodes'], '%3d'),

            ('train_P_episode_return', train_stats['episode_return'][0], '%2.2f'),
            ('train_D_episode_return', train_stats['episode_return'][1], '%2.2f'),

            ('train_min_n_accepted_incentives',
             np.min(train_stats['episodes_n_accepted_incentives']), '%2.2f'),
            ('train_avg_n_accepted_incentives',
             np.mean(train_stats['episodes_n_accepted_incentives']), '%2.2f'),
            ('train_max_n_accepted_incentives',
             np.max(train_stats['episodes_n_accepted_incentives']), '%2.2f'),

            ('train_min_n_lives',
             np.min(train_stats['episodes_n_lives']), '%2.2f'),
            ('train_avg_n_lives',
             np.mean(train_stats['episodes_n_lives']), '%2.2f'),
            # NOTE(review): 'mac' looks like a typo for 'max'; the column
            # name is left unchanged so existing result files stay
            # compatible.
            ('train_mac_n_lives',
             np.max(train_stats['episodes_n_lives']), '%2.2f'),

            ('train_min_pct_satisfied_demand',
             np.min(train_stats['pct_satisfied_demands']), '%2.2f'),
            ('train_avg_pct_satisfied_demand',
             np.mean(train_stats['pct_satisfied_demands']), '%2.2f'),
            ('train_max_pct_satisfied_demand',
             np.max(train_stats['pct_satisfied_demands']), '%2.2f'),

            ('P_importance_sampling_exponent',
             P_train_agent.importance_sampling_exponent, '%.3f'),
            ('D_importance_sampling_exponent',
             D_train_agent.importance_sampling_exponent, '%.3f'),

            ('P_max_seen_priority', P_train_agent.max_seen_priority, '%.3f'),
            ('D_max_seen_priority', D_train_agent.max_seen_priority, '%.3f'),
        ]

        log_output_str = '\n'.join(
            ('%s: ' + f) % (n, v) for n, v, f in log_output)

        logging.info(log_output_str)

        # Extra spacing after the last iteration's log output.
        if state.iteration == FLAGS.num_iterations - 1:
            print('\n')

        writer.write(collections.OrderedDict(
            (n, v) for n, v, _ in log_output))

        state.iteration += 1

        if state.iteration % FLAGS.checkpoint_period == 0:
            checkpoint.save()

    writer.close()
Exemple #14
0
def main(argv):
    """Trains a Bootstrapped DQN agent with reward shaping on Atari.

    All configuration is read from the module-level ``FLAGS``. Every iteration
    runs a training phase followed by an evaluation phase on a freshly built
    environment, logs the collected statistics, writes them to the results
    writer and saves a checkpoint.

    Args:
        argv: Command-line arguments; unused.

    Raises:
        ValueError: If ``FLAGS.shaping_function_type`` is not one of the
            shaping function constants handled below.
    """
    del argv
    logging.info('Boostrapped DQN on Atari on %s.',
                 jax.lib.xla_bridge.get_backend().platform)
    random_state = np.random.RandomState(FLAGS.seed)
    # Explicit 64-bit dtype: the bounds span the whole int64 range, which does
    # not fit numpy's default integer on platforms where it is 32 bits wide.
    rng_key = jax.random.PRNGKey(
        random_state.randint(-sys.maxsize - 1, sys.maxsize + 1,
                             dtype=np.int64))

    if FLAGS.results_csv_path:
        writer = parts.CsvWriter(FLAGS.results_csv_path)
    else:
        writer = parts.NullWriter()

    def environment_builder():
        """Creates Atari environment."""
        env = gym_atari.GymAtari(FLAGS.environment_name,
                                 seed=random_state.randint(1, 2**32))
        return gym_atari.RandomNoopsEnvironmentWrapper(
            env,
            min_noop_steps=1,
            max_noop_steps=30,
            seed=random_state.randint(1, 2**32),
        )

    env = environment_builder()

    logging.info('Environment: %s', FLAGS.environment_name)
    logging.info('Action spec: %s', env.action_spec())
    logging.info('Observation spec: %s', env.observation_spec())
    num_actions = env.action_spec().num_values
    network_fn = networks.bootstrapped_dqn_multi_head_network(
        num_actions,
        num_heads=FLAGS.num_heads,
        mask_probability=FLAGS.mask_probability)
    network = hk.transform(network_fn)

    def preprocessor_builder():
        """Creates a new Atari preprocessor configured from FLAGS."""
        return processors.atari(
            additional_discount=FLAGS.additional_discount,
            max_abs_reward=FLAGS.max_abs_reward,
            resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
            num_action_repeats=FLAGS.num_action_repeats,
            num_pooled_frames=2,
            zero_discount_on_life_loss=True,
            num_stacked_frames=FLAGS.num_stacked_frames,
            grayscaling=True,
        )

    # Create sample network input from sample preprocessor output.
    sample_processed_timestep = preprocessor_builder()(env.reset())
    sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                            sample_processed_timestep)
    sample_network_input = sample_processed_timestep.observation
    assert sample_network_input.shape == (FLAGS.environment_height,
                                          FLAGS.environment_width,
                                          FLAGS.num_stacked_frames)

    exploration_epsilon_schedule = parts.LinearSchedule(
        begin_t=int(FLAGS.min_replay_capacity_fraction *
                    FLAGS.replay_capacity * FLAGS.num_action_repeats),
        decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                        FLAGS.num_iterations * FLAGS.num_train_frames),
        begin_value=FLAGS.exploration_epsilon_begin_value,
        end_value=FLAGS.exploration_epsilon_end_value)

    # Optionally compress the stored observations of each transition; the
    # decoder undoes exactly what the encoder did.
    if FLAGS.compress_state:

        def encoder(transition):
            return transition._replace(
                s_tm1=replay_lib.compress_array(transition.s_tm1),
                s_t=replay_lib.compress_array(transition.s_t))

        def decoder(transition):
            return transition._replace(
                s_tm1=replay_lib.uncompress_array(transition.s_tm1),
                s_t=replay_lib.uncompress_array(transition.s_t))
    else:
        encoder = None
        decoder = None

    replay_structure = replay_lib.MaskedTransition(
        s_tm1=None,
        a_tm1=None,
        r_t=None,
        discount_t=None,
        s_t=None,
        mask_t=None,
    )

    replay = replay_lib.TransitionReplay(FLAGS.replay_capacity,
                                         replay_structure, random_state,
                                         encoder, decoder)

    optimizer = optax.rmsprop(
        learning_rate=FLAGS.learning_rate,
        decay=0.95,
        eps=FLAGS.optimizer_epsilon,
        centered=True,
    )

    if FLAGS.shaping_function_type == constants.NO_PENALTY:
        shaping_function = shaping.NoPenalty()
    elif FLAGS.shaping_function_type == constants.HARD_CODED_PENALTY:
        shaping_function = shaping.HardCodedPenalty(
            penalty=FLAGS.shaping_multiplicative_factor)
    elif FLAGS.shaping_function_type == constants.UNCERTAINTY_PENALTY:
        shaping_function = shaping.UncertaintyPenalty(
            multiplicative_factor=FLAGS.shaping_multiplicative_factor)
    elif FLAGS.shaping_function_type == constants.POLICY_ENTROPY_PENALTY:
        shaping_function = shaping.PolicyEntropyPenalty(
            multiplicative_factor=FLAGS.shaping_multiplicative_factor,
            num_actions=num_actions)
    elif FLAGS.shaping_function_type == constants.MUNCHAUSEN_PENALTY:
        shaping_function = shaping.MunchausenPenalty(
            multiplicative_factor=FLAGS.shaping_multiplicative_factor,
            num_actions=num_actions)
    else:
        # Fail here with a clear message instead of hitting an unbound
        # `shaping_function` when the agent is constructed.
        raise ValueError('Unknown shaping_function_type: %r' %
                         (FLAGS.shaping_function_type,))

    train_rng_key, eval_rng_key = jax.random.split(rng_key)

    train_agent = agent.BootstrappedDqn(
        preprocessor=preprocessor_builder(),
        sample_network_input=sample_network_input,
        network=network,
        optimizer=optimizer,
        transition_accumulator=replay_lib.TransitionAccumulator(),
        replay=replay,
        shaping_function=shaping_function,
        mask_probability=FLAGS.mask_probability,
        num_heads=FLAGS.num_heads,
        batch_size=FLAGS.batch_size,
        exploration_epsilon=exploration_epsilon_schedule,
        min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
        learn_period=FLAGS.learn_period,
        target_network_update_period=FLAGS.target_network_update_period,
        grad_error_bound=FLAGS.grad_error_bound,
        rng_key=train_rng_key,
    )
    eval_agent = parts.EpsilonGreedyActor(
        preprocessor=preprocessor_builder(),
        network=network,
        exploration_epsilon=FLAGS.eval_exploration_epsilon,
        rng_key=eval_rng_key,
    )

    # Set up checkpointing.
    checkpoint = parts.ImplementedCheckpoint(
        checkpoint_path=FLAGS.checkpoint_path)

    if checkpoint.can_be_restored():
        checkpoint.restore()
        iteration = checkpoint.state.iteration
        # NOTE(review): `replay` above was built with the pre-restore
        # `random_state` object; rebinding the name here only changes what
        # `environment_builder` reads. Confirm that this is intended.
        random_state = checkpoint.state.random_state
        train_agent.set_state(state=checkpoint.state.train_agent)
        eval_agent.set_state(state=checkpoint.state.eval_agent)
        writer.set_state(state=checkpoint.state.writer)
    else:
        iteration = 0

    try:
        # Iteration 0 consumes no training frames (see below), hence `<=`:
        # it is followed by `num_iterations` iterations that do train.
        while iteration <= FLAGS.num_iterations:
            # New environment for each iteration to allow for determinism if
            # preempted.
            env = environment_builder()

            logging.info('Training iteration %d.', iteration)
            train_seq = parts.run_loop(train_agent, env,
                                       FLAGS.max_frames_per_episode)
            # The first iteration takes zero frames from the training loop,
            # so only the evaluation phase does any work.
            num_train_frames = 0 if iteration == 0 else FLAGS.num_train_frames
            train_seq_truncated = itertools.islice(train_seq,
                                                   num_train_frames)
            train_stats = parts.generate_statistics(train_seq_truncated)

            logging.info('Evaluation iteration %d.', iteration)
            # Evaluate with the parameters the training agent currently uses.
            eval_agent.network_params = train_agent.online_params
            eval_seq = parts.run_loop(eval_agent, env,
                                      FLAGS.max_frames_per_episode)
            eval_seq_truncated = itertools.islice(eval_seq,
                                                  FLAGS.num_eval_frames)
            eval_stats = parts.generate_statistics(eval_seq_truncated)

            # Logging and checkpointing.
            human_normalized_score = atari_data.get_human_normalized_score(
                FLAGS.environment_name, eval_stats['episode_return'])
            capped_human_normalized_score = np.amin(
                [1., human_normalized_score])
            # Each entry is (column name, value, printf-style format).
            log_output = [
                ('iteration', iteration, '%3d'),
                ('frame', iteration * FLAGS.num_train_frames, '%5d'),
                ('eval_episode_return', eval_stats['episode_return'],
                 '% 2.2f'),
                ('train_episode_return', train_stats['episode_return'],
                 '% 2.2f'),
                ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),
                ('train_num_episodes', train_stats['num_episodes'], '%3d'),
                ('eval_frame_rate', eval_stats['step_rate'], '%4.0f'),
                ('train_frame_rate', train_stats['step_rate'], '%4.0f'),
                ('train_exploration_epsilon',
                 train_agent.exploration_epsilon, '%.3f'),
                ('normalized_return', human_normalized_score, '%.3f'),
                ('capped_normalized_return',
                 capped_human_normalized_score, '%.3f'),
                ('human_gap', 1. - capped_human_normalized_score, '%.3f'),
                ('train_loss', train_stats['train_loss'], '% 2.2f'),
                ('shaped_reward', train_stats['shaped_reward'], '% 2.2f'),
                ('penalties', train_stats['penalties'], '% 2.2f'),
            ]
            log_output_str = ', '.join(
                ('%s: ' + f) % (n, v) for n, v, f in log_output)
            logging.info(log_output_str)
            writer.write(
                collections.OrderedDict((n, v) for n, v, _ in log_output))

            iteration += 1

            # Update state before checkpointing, so that a restored run
            # resumes at the next iteration.
            checkpoint.state.iteration = iteration
            checkpoint.state.train_agent = train_agent.get_state()
            checkpoint.state.eval_agent = eval_agent.get_state()
            checkpoint.state.random_state = random_state
            checkpoint.state.writer = writer.get_state()
            checkpoint.save()
    finally:
        # Close the results writer even when an iteration raises.
        writer.close()