Example #1
    def step(self, timestep: dm_env.TimeStep) -> parts.Action:
        """Selects action given timestep and potentially learns."""
        self._frame_t += 1

        timestep = self._preprocessor(timestep)

        if timestep is None:  # Repeat action.
            action = self._action
        else:
            action = self._action = self._act(timestep)

            for transition in self._transition_accumulator.step(
                    timestep, action):
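                # Sample a per-head bootstrap mask marking which ensemble
                # heads will train on this transition.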
                mask = self._get_random_mask(self._rng_key)
                masked_transition = replay_lib.MaskedTransition(
                    s_tm1=transition.s_tm1,
                    a_tm1=transition.a_tm1,
                    r_t=transition.r_t,
                    discount_t=transition.discount_t,
                    s_t=transition.s_t,
                    mask_t=mask)
                self._replay.add(masked_transition)

        if self._replay.size < self._min_replay_capacity:
            return action

        if self._frame_t % self._learn_period == 0:
            self._learn()

        if self._frame_t % self._target_network_update_period == 0:
            self._target_params = self._online_params

        return action
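
The implementation of _get_random_mask is not shown in this example. In bootstrapped DQN the mask marks which ensemble heads a transition is used to train, so a minimal sketch, assuming one independent Bernoulli(mask_probability) draw per head (names and shapes here are illustrative, not the agent's actual method):

import jax
import jax.numpy as jnp

def get_random_mask(rng_key, num_heads, mask_probability):
    # One {0, 1} indicator per head for a single transition.
    return jax.random.bernoulli(
        rng_key, p=mask_probability, shape=(num_heads,)).astype(jnp.float32)

Note that self._rng_key is passed unchanged on every call in the step method above, so the agent presumably splits or advances the key elsewhere; that detail is omitted here.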
Example #2
def main(argv):
    """Trains DQN agent on Atari."""
    del argv
    logging.info("Boostrapped DQN on Key-Door on %s.",
                 jax.lib.xla_bridge.get_backend().platform)
    random_state = np.random.RandomState(FLAGS.seed)
    rng_key = jax.random.PRNGKey(
        random_state.randint(-sys.maxsize - 1, sys.maxsize + 1))

    if FLAGS.results_csv_path:
        writer = parts.CsvWriter(FLAGS.results_csv_path)
    else:
        writer = parts.NullWriter()

    def environment_builder():
        """Creates Key-Door environment."""
        env = gym_key_door.GymKeyDoor(
            env_args={
                constants.MAP_ASCII_PATH: FLAGS.map_ascii_path,
                constants.MAP_YAML_PATH: FLAGS.map_yaml_path,
                constants.REPRESENTATION: constants.PIXEL,
                constants.SCALING: FLAGS.env_scaling,
                constants.EPISODE_TIMEOUT: FLAGS.max_frames_per_episode,
                constants.GRAYSCALE: False,
                constants.BATCH_DIMENSION: False,
                constants.TORCH_AXES: False,
            },
            env_shape=FLAGS.env_shape,
        )
        return gym_atari.RandomNoopsEnvironmentWrapper(
            env,
            min_noop_steps=1,
            max_noop_steps=30,
            seed=random_state.randint(1, 2**32),
        )

    env = environment_builder()

    logging.info("Environment: %s", FLAGS.environment_name)
    logging.info("Action spec: %s", env.action_spec())
    logging.info("Observation spec: %s", env.observation_spec())
    num_actions = env.action_spec().num_values
    network_fn = networks.bootstrapped_dqn_multi_head_network(
        num_actions,
        num_heads=FLAGS.num_heads,
        mask_probability=FLAGS.mask_probability)
    network = hk.transform(network_fn)

    def preprocessor_builder():
        return processors.atari(
            additional_discount=FLAGS.additional_discount,
            max_abs_reward=FLAGS.max_abs_reward,
            resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
            num_action_repeats=FLAGS.num_action_repeats,
            num_pooled_frames=2,
            zero_discount_on_life_loss=True,
            num_stacked_frames=FLAGS.num_stacked_frames,
            grayscaling=True,
        )

    # Create sample network input from sample preprocessor output.
    sample_processed_timestep = preprocessor_builder()(env.reset())
    sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                            sample_processed_timestep)
    sample_network_input = sample_processed_timestep.observation
    assert sample_network_input.shape == (
        FLAGS.environment_height,
        FLAGS.environment_width,
        FLAGS.num_stacked_frames,
    )

    exploration_epsilon_schedule = parts.LinearSchedule(
        begin_t=int(FLAGS.min_replay_capacity_fraction *
                    FLAGS.replay_capacity * FLAGS.num_action_repeats),
        decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                        FLAGS.num_iterations * FLAGS.num_train_frames),
        begin_value=FLAGS.exploration_epsilon_begin_value,
        end_value=FLAGS.exploration_epsilon_end_value,
    )
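    # The schedule anneals epsilon linearly from begin_value to end_value over
    # decay_steps frames (see the sketch after this example).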

    if FLAGS.compress_state:

        def encoder(transition):
            return transition._replace(
                s_tm1=replay_lib.compress_array(transition.s_tm1),
                s_t=replay_lib.compress_array(transition.s_t),
            )

        def decoder(transition):
            return transition._replace(
                s_tm1=replay_lib.uncompress_array(transition.s_tm1),
                s_t=replay_lib.uncompress_array(transition.s_t),
            )

    else:
        encoder = None
        decoder = None

    replay_structure = replay_lib.MaskedTransition(
        s_tm1=None,
        a_tm1=None,
        r_t=None,
        discount_t=None,
        s_t=None,
        mask_t=None,
    )
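    # The None-valued fields above only define the layout of stored
    # transitions; the replay buffer fills in actual values on add().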

    replay = replay_lib.TransitionReplay(FLAGS.replay_capacity,
                                         replay_structure, random_state,
                                         encoder, decoder)

    optimizer = optax.rmsprop(
        learning_rate=FLAGS.learning_rate,
        decay=0.95,
        eps=FLAGS.optimizer_epsilon,
        centered=True,
    )

    if FLAGS.shaping_function_type == constants.NO_PENALTY:
        shaping_function = shaping.NoPenalty()
    elif FLAGS.shaping_function_type == constants.HARD_CODED_PENALTY:
        shaping_function = shaping.HardCodedPenalty(
            penalty=FLAGS.shaping_multiplicative_factor)
    elif FLAGS.shaping_function_type == constants.UNCERTAINTY_PENALTY:
        shaping_function = shaping.UncertaintyPenalty(
            multiplicative_factor=FLAGS.shaping_multiplicative_factor)
    elif FLAGS.shaping_function_type == constants.POLICY_ENTROPY_PENALTY:
        shaping_function = shaping.PolicyEntropyPenalty(
            multiplicative_factor=FLAGS.shaping_multiplicative_factor,
            num_actions=num_actions,
        )
    elif FLAGS.shaping_function_type == constants.MUNCHAUSEN_PENALTY:
        shaping_function = shaping.MunchausenPenalty(
            multiplicative_factor=FLAGS.shaping_multiplicative_factor,
            num_actions=num_actions,
        )
    else:
        raise ValueError(
            f"Unknown shaping function type: {FLAGS.shaping_function_type}")

    train_rng_key, eval_rng_key = jax.random.split(rng_key)

    train_agent = agent.BootstrappedDqn(
        preprocessor=preprocessor_builder(),
        sample_network_input=sample_network_input,
        network=network,
        optimizer=optimizer,
        transition_accumulator=replay_lib.TransitionAccumulator(),
        replay=replay,
        shaping_function=shaping_function,
        mask_probability=FLAGS.mask_probability,
        num_heads=FLAGS.num_heads,
        batch_size=FLAGS.batch_size,
        exploration_epsilon=exploration_epsilon_schedule,
        min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
        learn_period=FLAGS.learn_period,
        target_network_update_period=FLAGS.target_network_update_period,
        grad_error_bound=FLAGS.grad_error_bound,
        rng_key=train_rng_key,
    )
    eval_agent = parts.EpsilonGreedyActor(
        preprocessor=preprocessor_builder(),
        network=network,
        exploration_epsilon=FLAGS.eval_exploration_epsilon,
        rng_key=eval_rng_key,
    )

    # Set up checkpointing.
    # checkpoint = parts.NullCheckpoint()
    checkpoint = parts.ImplementedCheckpoint(
        checkpoint_path=FLAGS.checkpoint_path)

    if checkpoint.can_be_restored():
        checkpoint.restore()
        iteration = checkpoint.state.iteration
        random_state = checkpoint.state.random_state
        train_agent.set_state(state=checkpoint.state.train_agent)
        eval_agent.set_state(state=checkpoint.state.eval_agent)
        writer.set_state(state=checkpoint.state.writer)
    else:
        iteration = 0

    while iteration <= FLAGS.num_iterations:
        # New environment for each iteration to allow for determinism if preempted.
        env = environment_builder()

        logging.info("Training iteration %d.", iteration)
        train_seq = parts.run_loop(train_agent, env,
                                   FLAGS.max_frames_per_episode)
        num_train_frames = 0 if iteration == 0 else FLAGS.num_train_frames
        train_seq_truncated = itertools.islice(train_seq, num_train_frames)
        train_stats = parts.generate_statistics(train_seq_truncated)

        logging.info("Evaluation iteration %d.", iteration)
        eval_agent.network_params = train_agent.online_params
        eval_seq = parts.run_loop(eval_agent, env,
                                  FLAGS.max_frames_per_episode)
        eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
        eval_stats = parts.generate_statistics(eval_seq_truncated)

        # Logging and checkpointing.
        human_normalized_score = atari_data.get_human_normalized_score(
            FLAGS.environment_name, eval_stats["episode_return"])
        capped_human_normalized_score = np.amin([1.0, human_normalized_score])
        log_output = [
            ("iteration", iteration, "%3d"),
            ("frame", iteration * FLAGS.num_train_frames, "%5d"),
            ("eval_episode_return", eval_stats["episode_return"], "% 2.2f"),
            ("train_episode_return", train_stats["episode_return"], "% 2.2f"),
            ("eval_num_episodes", eval_stats["num_episodes"], "%3d"),
            ("train_num_episodes", train_stats["num_episodes"], "%3d"),
            ("eval_frame_rate", eval_stats["step_rate"], "%4.0f"),
            ("train_frame_rate", train_stats["step_rate"], "%4.0f"),
            ("train_exploration_epsilon", train_agent.exploration_epsilon,
             "%.3f"),
            ("normalized_return", human_normalized_score, "%.3f"),
            ("capped_normalized_return", capped_human_normalized_score,
             "%.3f"),
            ("human_gap", 1.0 - capped_human_normalized_score, "%.3f"),
            ("train_loss", train_stats["train_loss"], "% 2.2f"),
            ("shaped_reward", train_stats["shaped_reward"], "% 2.2f"),
            ("penalties", train_stats["penalties"], "% 2.2f"),
        ]
        log_output_str = ", ".join(
            ("%s: " + f) % (n, v) for n, v, f in log_output)
        logging.info(log_output_str)
        writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))

        iteration += 1

        # update state before checkpointing
        checkpoint.state.iteration = iteration
        checkpoint.state.train_agent = train_agent.get_state()
        checkpoint.state.eval_agent = eval_agent.get_state()
        checkpoint.state.random_state = random_state
        checkpoint.state.writer = writer.get_state()
        checkpoint.save()

    writer.close()
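
The exploration schedule in this example anneals epsilon linearly: it holds begin_value until begin_t, then interpolates towards end_value over decay_steps, and holds end_value afterwards. A minimal stand-alone sketch of that behaviour (an illustration, not the parts.LinearSchedule implementation itself):

import numpy as np

def linear_schedule(step, begin_t, decay_steps, begin_value, end_value):
    # Fraction of the decay completed so far, clipped to [0, 1].
    fraction = np.clip((step - begin_t) / decay_steps, 0.0, 1.0)
    return (1.0 - fraction) * begin_value + fraction * end_value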
Example #3
def main(argv):
  """Trains DQN agent on Atari."""
  del argv
  logging.info('Bootstrapped DQN on Atari on %s.',
               jax.lib.xla_bridge.get_backend().platform)
  random_state = np.random.RandomState(FLAGS.seed)
  rng_key = jax.random.PRNGKey(
      random_state.randint(-sys.maxsize - 1, sys.maxsize + 1))

  if FLAGS.results_csv_path:
    writer = parts.CsvWriter(FLAGS.results_csv_path)
  else:
    writer = parts.NullWriter()

  def environment_builder():
    """Creates Atari environment."""
    env = gym_atari.GymAtari(
        FLAGS.environment_name, seed=random_state.randint(1, 2**32))
    return gym_atari.RandomNoopsEnvironmentWrapper(
        env,
        min_noop_steps=1,
        max_noop_steps=30,
        seed=random_state.randint(1, 2**32),
    )

  env = environment_builder()

  logging.info('Environment: %s', FLAGS.environment_name)
  logging.info('Action spec: %s', env.action_spec())
  logging.info('Observation spec: %s', env.observation_spec())
  num_actions = env.action_spec().num_values
  network_fn = networks.bootstrapped_dqn_multi_head_network(
      num_actions,
      num_heads=FLAGS.num_heads,
      mask_probability=FLAGS.mask_probability)
  network = hk.transform(network_fn)

  def preprocessor_builder():
    return processors.atari(
        additional_discount=FLAGS.additional_discount,
        max_abs_reward=FLAGS.max_abs_reward,
        resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
        num_action_repeats=FLAGS.num_action_repeats,
        num_pooled_frames=2,
        zero_discount_on_life_loss=True,
        num_stacked_frames=FLAGS.num_stacked_frames,
        grayscaling=True,
    )

  # Create sample network input from sample preprocessor output.
  sample_processed_timestep = preprocessor_builder()(env.reset())
  sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                          sample_processed_timestep)
  sample_network_input = sample_processed_timestep.observation
  assert sample_network_input.shape == (FLAGS.environment_height,
                                        FLAGS.environment_width,
                                        FLAGS.num_stacked_frames)

  exploration_epsilon_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity *
                  FLAGS.num_action_repeats),
      decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                      FLAGS.num_iterations * FLAGS.num_train_frames),
      begin_value=FLAGS.exploration_epsilon_begin_value,
      end_value=FLAGS.exploration_epsilon_end_value)

  if FLAGS.compress_state:

    def encoder(transition):
      return transition._replace(
          s_tm1=replay_lib.compress_array(transition.s_tm1),
          s_t=replay_lib.compress_array(transition.s_t))

    def decoder(transition):
      return transition._replace(
          s_tm1=replay_lib.uncompress_array(transition.s_tm1),
          s_t=replay_lib.uncompress_array(transition.s_t))
  else:
    encoder = None
    decoder = None

  replay_structure = replay_lib.MaskedTransition(
      s_tm1=None,
      a_tm1=None,
      r_t=None,
      discount_t=None,
      s_t=None,
      mask_t=None,
  )

  replay = replay_lib.TransitionReplay(FLAGS.replay_capacity, replay_structure,
                                       random_state, encoder, decoder)

  optimizer = optax.rmsprop(
      learning_rate=FLAGS.learning_rate,
      decay=0.95,
      eps=FLAGS.optimizer_epsilon,
      centered=True,
  )

  if FLAGS.shaping_function_type == constants.NO_PENALTY:
    shaping_function = shaping.NoPenalty()
  elif FLAGS.shaping_function_type == constants.HARD_CODED_PENALTY:
    shaping_function = shaping.HardCodedPenalty(
        penalty=FLAGS.shaping_multiplicative_factor)
  elif FLAGS.shaping_function_type == constants.UNCERTAINTY_PENALTY:
    shaping_function = shaping.UncertaintyPenalty(
        multiplicative_factor=FLAGS.shaping_multiplicative_factor)
  elif FLAGS.shaping_function_type == constants.POLICY_ENTROPY_PENALTY:
    shaping_function = shaping.PolicyEntropyPenalty(
        multiplicative_factor=FLAGS.shaping_multiplicative_factor,
        num_actions=num_actions)
  else:
    raise ValueError(
        f'Unknown shaping function type: {FLAGS.shaping_function_type}')

  train_rng_key, eval_rng_key = jax.random.split(rng_key)

  train_agent = agent.BootstrappedDqn(
      preprocessor=preprocessor_builder(),
      sample_network_input=sample_network_input,
      network=network,
      optimizer=optimizer,
      transition_accumulator=replay_lib.TransitionAccumulator(),
      replay=replay,
      shaping_function=shaping_function,
      mask_probability=FLAGS.mask_probability,
      num_heads=FLAGS.num_heads,
      batch_size=FLAGS.batch_size,
      exploration_epsilon=exploration_epsilon_schedule,
      min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
      learn_period=FLAGS.learn_period,
      target_network_update_period=FLAGS.target_network_update_period,
      grad_error_bound=FLAGS.grad_error_bound,
      rng_key=train_rng_key,
  )
  eval_agent = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=network,
      exploration_epsilon=FLAGS.eval_exploration_epsilon,
      rng_key=eval_rng_key,
  )

  # Set up checkpointing.
  # checkpoint = parts.NullCheckpoint()
  checkpoint = parts.ImplementedCheckpoint(
      checkpoint_path=FLAGS.checkpoint_path)

  if checkpoint.can_be_restored():
    # Resume from the saved agent, writer and iteration state.
    checkpoint.restore()
    train_agent.set_state(state=checkpoint.state.train_agent)
    eval_agent.set_state(state=checkpoint.state.eval_agent)
    writer.set_state(state=checkpoint.state.writer)
  else:
    checkpoint.state.iteration = 0

  state = checkpoint.state
  state.train_agent = train_agent.get_state()
  state.eval_agent = eval_agent.get_state()
  state.random_state = random_state
  state.writer = writer.get_state()

  while state.iteration <= FLAGS.num_iterations:
    # New environment for each iteration to allow for determinism if preempted.
    env = environment_builder()

    logging.info('Training iteration %d.', state.iteration)
    train_seq = parts.run_loop(train_agent, env, FLAGS.max_frames_per_episode)
    num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames
    train_seq_truncated = itertools.islice(train_seq, num_train_frames)
    train_stats = parts.generate_statistics(train_seq_truncated)

    logging.info('Evaluation iteration %d.', state.iteration)
    eval_agent.network_params = train_agent.online_params
    eval_seq = parts.run_loop(eval_agent, env, FLAGS.max_frames_per_episode)
    eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
    eval_stats = parts.generate_statistics(eval_seq_truncated)

    # Logging and checkpointing.
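    # The human-normalized score rescales the evaluation return against
    # published random and human baselines (see the note after this example).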
    human_normalized_score = atari_data.get_human_normalized_score(
        FLAGS.environment_name, eval_stats['episode_return'])
    capped_human_normalized_score = np.amin([1., human_normalized_score])
    log_output = [
        ('iteration', state.iteration, '%3d'),
        ('frame', state.iteration * FLAGS.num_train_frames, '%5d'),
        ('eval_episode_return', eval_stats['episode_return'], '% 2.2f'),
        ('train_episode_return', train_stats['episode_return'], '% 2.2f'),
        ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),
        ('train_num_episodes', train_stats['num_episodes'], '%3d'),
        ('eval_frame_rate', eval_stats['step_rate'], '%4.0f'),
        ('train_frame_rate', train_stats['step_rate'], '%4.0f'),
        ('train_exploration_epsilon', train_agent.exploration_epsilon, '%.3f'),
        ('normalized_return', human_normalized_score, '%.3f'),
        ('capped_normalized_return', capped_human_normalized_score, '%.3f'),
        ('human_gap', 1. - capped_human_normalized_score, '%.3f'),
    ]
    log_output_str = ', '.join(('%s: ' + f) % (n, v) for n, v, f in log_output)
    logging.info(log_output_str)
    writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
    # Update state before checkpointing so a restore resumes from this point.
    state.iteration += 1
    state.train_agent = train_agent.get_state()
    state.eval_agent = eval_agent.get_state()
    state.random_state = random_state
    state.writer = writer.get_state()
    checkpoint.save()

  writer.close()
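
atari_data.get_human_normalized_score reports the evaluation return rescaled so that 0.0 corresponds to random play and 1.0 to the human baseline for the given game. Conceptually (the per-game baselines are looked up inside atari_data; this helper is only illustrative):

def human_normalized_score(agent_return, random_return, human_return):
    # 0.0 = random-level play, 1.0 = human-level play.
    return (agent_return - random_return) / (human_return - random_return)

The capped variant logged above, np.amin([1., human_normalized_score]), clips the score at human level so that games where the agent is super-human do not dominate aggregate statistics.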