Example #1
def main(argv):
  """Trains Prioritized DQN agent on Atari."""
  del argv
  logging.info('Prioritized DQN on Atari on %s.',
               jax.lib.xla_bridge.get_backend().platform)
  random_state = np.random.RandomState(FLAGS.seed)
  rng_key = jax.random.PRNGKey(
      random_state.randint(-sys.maxsize - 1, sys.maxsize + 1, dtype=np.int64))

  if FLAGS.results_csv_path:
    writer = parts.CsvWriter(FLAGS.results_csv_path)
  else:
    writer = parts.NullWriter()

  def environment_builder():
    """Creates Atari environment."""
    env = gym_atari.GymAtari(
        FLAGS.environment_name, seed=random_state.randint(1, 2**32))
    return gym_atari.RandomNoopsEnvironmentWrapper(
        env,
        min_noop_steps=1,
        max_noop_steps=30,
        seed=random_state.randint(1, 2**32),
    )

  env = environment_builder()

  logging.info('Environment: %s', FLAGS.environment_name)
  logging.info('Action spec: %s', env.action_spec())
  logging.info('Observation spec: %s', env.observation_spec())
  num_actions = env.action_spec().num_values
  network_fn = networks.double_dqn_atari_network(num_actions)
  network = hk.transform(network_fn)

  def preprocessor_builder():
    return processors.atari(
        additional_discount=FLAGS.additional_discount,
        max_abs_reward=FLAGS.max_abs_reward,
        resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
        num_action_repeats=FLAGS.num_action_repeats,
        num_pooled_frames=2,
        zero_discount_on_life_loss=True,
        num_stacked_frames=FLAGS.num_stacked_frames,
        grayscaling=True,
    )

  # Create sample network input from sample preprocessor output.
  sample_processed_timestep = preprocessor_builder()(env.reset())
  sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                          sample_processed_timestep)
  sample_network_input = sample_processed_timestep.observation
  chex.assert_shape(sample_network_input,
                    (FLAGS.environment_height, FLAGS.environment_width,
                     FLAGS.num_stacked_frames))

  exploration_epsilon_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity *
                  FLAGS.num_action_repeats),
      decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                      FLAGS.num_iterations * FLAGS.num_train_frames),
      begin_value=FLAGS.exploration_epsilon_begin_value,
      end_value=FLAGS.exploration_epsilon_end_value)

  # Note the t in the replay is not exactly aligned with the agent t.
  importance_sampling_exponent_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity),
      end_t=(FLAGS.num_iterations *
             int(FLAGS.num_train_frames / FLAGS.num_action_repeats)),
      begin_value=FLAGS.importance_sampling_exponent_begin_value,
      end_value=FLAGS.importance_sampling_exponent_end_value)

  if FLAGS.compress_state:

    def encoder(transition):
      return transition._replace(
          s_tm1=replay_lib.compress_array(transition.s_tm1),
          s_t=replay_lib.compress_array(transition.s_t))

    def decoder(transition):
      return transition._replace(
          s_tm1=replay_lib.uncompress_array(transition.s_tm1),
          s_t=replay_lib.uncompress_array(transition.s_t))
  else:
    encoder = None
    decoder = None

  replay_structure = replay_lib.Transition(
      s_tm1=None,
      a_tm1=None,
      r_t=None,
      discount_t=None,
      s_t=None,
  )

  replay = replay_lib.PrioritizedTransitionReplay(
      FLAGS.replay_capacity, replay_structure, FLAGS.priority_exponent,
      importance_sampling_exponent_schedule, FLAGS.uniform_sample_probability,
      FLAGS.normalize_weights, random_state, encoder, decoder)

  optimizer = optax.rmsprop(
      learning_rate=FLAGS.learning_rate,
      decay=0.95,
      eps=FLAGS.optimizer_epsilon,
      centered=True,
  )

  train_rng_key, eval_rng_key = jax.random.split(rng_key)

  train_agent = agent.PrioritizedDqn(
      preprocessor=preprocessor_builder(),
      sample_network_input=sample_network_input,
      network=network,
      optimizer=optimizer,
      transition_accumulator=replay_lib.TransitionAccumulator(),
      replay=replay,
      batch_size=FLAGS.batch_size,
      exploration_epsilon=exploration_epsilon_schedule,
      min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
      learn_period=FLAGS.learn_period,
      target_network_update_period=FLAGS.target_network_update_period,
      grad_error_bound=FLAGS.grad_error_bound,
      rng_key=train_rng_key,
  )
  eval_agent = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=network,
      exploration_epsilon=FLAGS.eval_exploration_epsilon,
      rng_key=eval_rng_key,
  )

  # Set up checkpointing.
  checkpoint = parts.NullCheckpoint()

  state = checkpoint.state
  state.iteration = 0
  state.train_agent = train_agent
  state.eval_agent = eval_agent
  state.random_state = random_state
  state.writer = writer
  if checkpoint.can_be_restored():
    checkpoint.restore()

  while state.iteration <= FLAGS.num_iterations:
    # New environment for each iteration to allow for determinism if preempted.
    env = environment_builder()

    logging.info('Training iteration %d.', state.iteration)
    train_seq = parts.run_loop(train_agent, env, FLAGS.max_frames_per_episode)
    num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames
    train_seq_truncated = itertools.islice(train_seq, num_train_frames)
    train_trackers = parts.make_default_trackers(train_agent)
    train_stats = parts.generate_statistics(train_trackers, train_seq_truncated)

    logging.info('Evaluation iteration %d.', state.iteration)
    eval_agent.network_params = train_agent.online_params
    eval_seq = parts.run_loop(eval_agent, env, FLAGS.max_frames_per_episode)
    eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
    eval_trackers = parts.make_default_trackers(eval_agent)
    eval_stats = parts.generate_statistics(eval_trackers, eval_seq_truncated)

    # Logging and checkpointing.
    human_normalized_score = atari_data.get_human_normalized_score(
        FLAGS.environment_name, eval_stats['episode_return'])
    capped_human_normalized_score = np.amin([1., human_normalized_score])
    log_output = [
        ('iteration', state.iteration, '%3d'),
        ('frame', state.iteration * FLAGS.num_train_frames, '%5d'),
        ('eval_episode_return', eval_stats['episode_return'], '% 2.2f'),
        ('train_episode_return', train_stats['episode_return'], '% 2.2f'),
        ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),
        ('train_num_episodes', train_stats['num_episodes'], '%3d'),
        ('eval_frame_rate', eval_stats['step_rate'], '%4.0f'),
        ('train_frame_rate', train_stats['step_rate'], '%4.0f'),
        ('train_exploration_epsilon', train_agent.exploration_epsilon, '%.3f'),
        ('train_state_value', train_stats['state_value'], '%.3f'),
        ('importance_sampling_exponent',
         train_agent.importance_sampling_exponent, '%.3f'),
        ('max_seen_priority', train_agent.max_seen_priority, '%.3f'),
        ('normalized_return', human_normalized_score, '%.3f'),
        ('capped_normalized_return', capped_human_normalized_score, '%.3f'),
        ('human_gap', 1. - capped_human_normalized_score, '%.3f'),
    ]
    log_output_str = ', '.join(('%s: ' + f) % (n, v) for n, v, f in log_output)
    logging.info(log_output_str)
    writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
    state.iteration += 1
    checkpoint.save()

  writer.close()
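
The two parts.LinearSchedule objects above anneal the exploration epsilon and the prioritized-replay importance-sampling exponent linearly over step counts. As a rough illustration of the semantics assumed here (values interpolated between begin_t and end_t, with decay_steps as shorthand for end_t - begin_t, and clipping outside that window), a minimal sketch might look like the following; the real parts.LinearSchedule may differ in details.

# Minimal sketch of a linear schedule with the call surface used above.
# Assumes the value is clamped to [begin_value, end_value] outside the window.
class LinearScheduleSketch:

  def __init__(self, begin_value, end_value, begin_t, end_t=None,
               decay_steps=None):
    if (end_t is None) == (decay_steps is None):
      raise ValueError('Exactly one of end_t or decay_steps must be given.')
    self._begin_value = begin_value
    self._end_value = end_value
    self._begin_t = begin_t
    self._end_t = begin_t + decay_steps if end_t is None else end_t

  def __call__(self, t):
    """Returns the interpolated value at step t."""
    frac = (t - self._begin_t) / (self._end_t - self._begin_t)
    frac = min(max(frac, 0.), 1.)
    return (1. - frac) * self._begin_value + frac * self._end_value

# E.g. epsilon stays at begin_value until the replay warm-up ends, then decays
# linearly to end_value over the configured fraction of training frames.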
Example #2
    def test_basic(self):
        """Tests sequence of agent and environment interactions in typical usage."""
        tape = []
        agent = test_utils.DummyAgent(tape)
        environment = test_utils.DummyEnvironment(tape, episode_length=4)

        episode_index = 0
        t = 0  # steps = t + 1
        max_steps = 14
        loop_outputs = parts.run_loop(agent,
                                      environment,
                                      max_steps_per_episode=100,
                                      yield_before_reset=True)

        for timestep_t, unused_a_t in loop_outputs:
            tape.append((episode_index, t, timestep_t is None))

            if timestep_t is None:
                tape.append('Episode begin')
                continue

            if timestep_t.last():
                tape.append('Episode end')
                episode_index += 1

            if t + 1 >= max_steps:
                tape.append('Maximum number of steps reached')
                break

            t += 1

        expected_tape = [
            (0, 0, True),
            'Episode begin',
            'Agent reset',
            'Environment reset',
            'Agent step',
            (0, 0, False),
            'Environment step (0)',
            'Agent step',
            (0, 1, False),
            'Environment step (0)',
            'Agent step',
            (0, 2, False),
            'Environment step (0)',
            'Agent step',
            (0, 3, False),
            'Environment step (0)',
            'Agent step',
            (0, 4, False),
            'Episode end',
            (1, 5, True),
            'Episode begin',
            'Agent reset',
            'Environment reset',
            'Agent step',
            (1, 5, False),
            'Environment step (0)',
            'Agent step',
            (1, 6, False),
            'Environment step (0)',
            'Agent step',
            (1, 7, False),
            'Environment step (0)',
            'Agent step',
            (1, 8, False),
            'Environment step (0)',
            'Agent step',
            (1, 9, False),
            'Episode end',
            (2, 10, True),
            'Episode begin',
            'Agent reset',
            'Environment reset',
            'Agent step',
            (2, 10, False),
            'Environment step (0)',
            'Agent step',
            (2, 11, False),
            'Environment step (0)',
            'Agent step',
            (2, 12, False),
            'Environment step (0)',
            'Agent step',
            (2, 13, False),
            'Maximum number of steps reached',
        ]
        self.assertEqual(expected_tape, tape)
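
The expected tape is easiest to read once the contract of parts.run_loop is clear: with yield_before_reset=True it yields a (None, None) pair before each environment reset, and otherwise yields (timestep, action) pairs until the episode ends. A simplified generator with that contract (illustrative only; the real parts.run_loop also relabels the final timestep when max_steps_per_episode is hit and may yield additional values) could look like:

def run_loop_sketch(agent, environment, max_steps_per_episode=0,
                    yield_before_reset=False):
    """Yields (timestep, action) pairs; (None, None) marks an imminent reset."""
    while True:  # For each episode.
        if yield_before_reset:
            yield None, None  # Lets the caller log before the episode starts.
        t = 0
        agent.reset()
        timestep_t = environment.reset()
        while True:  # For each step in the current episode.
            a_t = agent.step(timestep_t)
            yield timestep_t, a_t
            if timestep_t.last():
                break  # Back to the outer loop to start a new episode.
            t += 1
            if max_steps_per_episode > 0 and t >= max_steps_per_episode:
                break  # Simplified: no relabelling of the final timestep.
            timestep_t = environment.step(a_t)

This matches the tape above: the agent steps on every yielded timestep, including the LAST one, and the environment only steps between yields within an episode.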
Example #3
def main(argv):
    """Trains DQN agent on Atari."""
    del argv
    logging.info("DQN on Atari on %s.",
                 jax.lib.xla_bridge.get_backend().platform)
    random_state = np.random.RandomState(FLAGS.seed)
    rng_key = jax.random.PRNGKey(
        random_state.randint(-sys.maxsize - 1, sys.maxsize + 1,
                             dtype=np.int64))

    if FLAGS.results_csv_path:
        writer = parts.CsvWriter(FLAGS.results_csv_path)
    else:
        writer = parts.NullWriter()

    def environment_builder():
        """Creates Key-Door environment."""
        env = gym_key_door.GymKeyDoor(
            env_args={
                constants.MAP_ASCII_PATH: FLAGS.map_ascii_path,
                constants.MAP_YAML_PATH: FLAGS.map_yaml_path,
                constants.REPRESENTATION: constants.PIXEL,
                constants.SCALING: FLAGS.env_scaling,
                constants.EPISODE_TIMEOUT: FLAGS.max_frames_per_episode,
                constants.GRAYSCALE: False,
                constants.BATCH_DIMENSION: False,
                constants.TORCH_AXES: False,
            },
            env_shape=FLAGS.env_shape,
        )
        return gym_atari.RandomNoopsEnvironmentWrapper(
            env,
            min_noop_steps=1,
            max_noop_steps=30,
            seed=random_state.randint(1, 2**32),
        )

    env = environment_builder()

    logging.info("Environment: %s", FLAGS.environment_name)
    logging.info("Action spec: %s", env.action_spec())
    logging.info("Observation spec: %s", env.observation_spec())
    num_actions = env.action_spec().num_values
    network_fn = networks.dqn_atari_network(num_actions)
    network = hk.transform(network_fn)

    def preprocessor_builder():
        return processors.atari(
            additional_discount=FLAGS.additional_discount,
            max_abs_reward=FLAGS.max_abs_reward,
            resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
            num_action_repeats=FLAGS.num_action_repeats,
            num_pooled_frames=2,
            zero_discount_on_life_loss=True,
            num_stacked_frames=FLAGS.num_stacked_frames,
            grayscaling=True,
        )

    # Create sample network input from sample preprocessor output.
    sample_processed_timestep = preprocessor_builder()(env.reset())
    sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                            sample_processed_timestep)
    sample_network_input = sample_processed_timestep.observation
    assert sample_network_input.shape == (
        FLAGS.environment_height,
        FLAGS.environment_width,
        FLAGS.num_stacked_frames,
    )

    exploration_epsilon_schedule = parts.LinearSchedule(
        begin_t=int(FLAGS.min_replay_capacity_fraction *
                    FLAGS.replay_capacity * FLAGS.num_action_repeats),
        decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                        FLAGS.num_iterations * FLAGS.num_train_frames),
        begin_value=FLAGS.exploration_epsilon_begin_value,
        end_value=FLAGS.exploration_epsilon_end_value,
    )

    if FLAGS.compress_state:

        def encoder(transition):
            return transition._replace(
                s_tm1=replay_lib.compress_array(transition.s_tm1),
                s_t=replay_lib.compress_array(transition.s_t),
            )

        def decoder(transition):
            return transition._replace(
                s_tm1=replay_lib.uncompress_array(transition.s_tm1),
                s_t=replay_lib.uncompress_array(transition.s_t),
            )

    else:
        encoder = None
        decoder = None

    replay_structure = replay_lib.Transition(
        s_tm1=None,
        a_tm1=None,
        r_t=None,
        discount_t=None,
        s_t=None,
    )

    replay = replay_lib.TransitionReplay(FLAGS.replay_capacity,
                                         replay_structure, random_state,
                                         encoder, decoder)

    optimizer = optax.rmsprop(
        learning_rate=FLAGS.learning_rate,
        decay=0.95,
        eps=FLAGS.optimizer_epsilon,
        centered=True,
    )

    if FLAGS.shaping_function_type == constants.NO_PENALTY:
        shaping_function = shaping.NoPenalty()
    elif FLAGS.shaping_function_type == constants.HARD_CODED_PENALTY:
        shaping_function = shaping.HardCodedPenalty(
            penalty=FLAGS.shaping_multiplicative_factor)
    else:
        # Fail fast rather than hitting an undefined shaping_function below.
        raise ValueError(
            f"Unknown shaping function type: {FLAGS.shaping_function_type}")

    train_rng_key, eval_rng_key = jax.random.split(rng_key)

    train_agent = agent.Dqn(
        preprocessor=preprocessor_builder(),
        sample_network_input=sample_network_input,
        network=network,
        optimizer=optimizer,
        transition_accumulator=replay_lib.TransitionAccumulator(),
        replay=replay,
        shaping_function=shaping_function,
        batch_size=FLAGS.batch_size,
        exploration_epsilon=exploration_epsilon_schedule,
        min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
        learn_period=FLAGS.learn_period,
        target_network_update_period=FLAGS.target_network_update_period,
        grad_error_bound=FLAGS.grad_error_bound,
        rng_key=train_rng_key,
    )
    eval_agent = parts.EpsilonGreedyActor(
        preprocessor=preprocessor_builder(),
        network=network,
        exploration_epsilon=FLAGS.eval_exploration_epsilon,
        rng_key=eval_rng_key,
    )

    # Set up checkpointing.
    # checkpoint = parts.NullCheckpoint()
    checkpoint = parts.ImplementedCheckpoint(
        checkpoint_path=FLAGS.checkpoint_path)

    state = checkpoint.state
    state.iteration = 0
    state.train_agent = train_agent.get_state()
    state.eval_agent = eval_agent.get_state()
    state.random_state = random_state
    state.writer = writer.get_state()

    if checkpoint.can_be_restored():
        checkpoint.restore()
        state = checkpoint.state  # Resume from the restored iteration count.
        train_agent.set_state(state=state.train_agent)
        eval_agent.set_state(state=state.eval_agent)
        writer.set_state(state=state.writer)

    while state.iteration <= FLAGS.num_iterations:
        # New environment for each iteration to allow for determinism if preempted.
        env = environment_builder()

        logging.info("Training iteration %d.", state.iteration)
        train_seq = parts.run_loop(train_agent, env,
                                   FLAGS.max_frames_per_episode)
        num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames
        train_seq_truncated = itertools.islice(train_seq, num_train_frames)
        train_stats = parts.generate_statistics(train_seq_truncated)

        logging.info("Evaluation iteration %d.", state.iteration)
        eval_agent.network_params = train_agent.online_params
        eval_seq = parts.run_loop(eval_agent, env,
                                  FLAGS.max_frames_per_episode)
        eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
        eval_stats = parts.generate_statistics(eval_seq_truncated)

        # Logging and checkpointing.
        human_normalized_score = atari_data.get_human_normalized_score(
            FLAGS.environment_name, eval_stats["episode_return"])
        capped_human_normalized_score = np.amin([1.0, human_normalized_score])
        log_output = [
            ("iteration", state.iteration, "%3d"),
            ("frame", state.iteration * FLAGS.num_train_frames, "%5d"),
            ("eval_episode_return", eval_stats["episode_return"], "% 2.2f"),
            ("train_episode_return", train_stats["episode_return"], "% 2.2f"),
            ("eval_num_episodes", eval_stats["num_episodes"], "%3d"),
            ("train_num_episodes", train_stats["num_episodes"], "%3d"),
            ("eval_frame_rate", eval_stats["step_rate"], "%4.0f"),
            ("train_frame_rate", train_stats["step_rate"], "%4.0f"),
            ("train_exploration_epsilon", train_agent.exploration_epsilon,
             "%.3f"),
            ("normalized_return", human_normalized_score, "%.3f"),
            ("capped_normalized_return", capped_human_normalized_score,
             "%.3f"),
            ("human_gap", 1.0 - capped_human_normalized_score, "%.3f"),
        ]
        log_output_str = ", ".join(
            ("%s: " + f) % (n, v) for n, v, f in log_output)
        logging.info(log_output_str)
        writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
        state.iteration += 1
        checkpoint.save()

    writer.close()
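
For context on what the encoder/decoder hooks above are for: the uniform replay is essentially a fixed-capacity circular buffer that applies the encoder when a transition is added and the decoder when one is sampled, so compress_state trades CPU time for memory. A rough sketch under those assumptions (the actual replay_lib.TransitionReplay also stacks sampled transitions into batched arrays and may differ in other details):

import numpy as np

class UniformReplaySketch:
    """Fixed-capacity circular buffer with optional compression hooks."""

    def __init__(self, capacity, random_state, encoder=None, decoder=None):
        self._capacity = capacity
        self._random_state = random_state
        self._encoder = encoder or (lambda transition: transition)
        self._decoder = decoder or (lambda transition: transition)
        self._storage = []
        self._next_index = 0  # Slot to overwrite once the buffer is full.

    def add(self, transition):
        item = self._encoder(transition)  # E.g. compress s_tm1 and s_t.
        if len(self._storage) < self._capacity:
            self._storage.append(item)
        else:
            self._storage[self._next_index] = item
        self._next_index = (self._next_index + 1) % self._capacity

    def sample(self, batch_size):
        """Returns batch_size transitions sampled uniformly with replacement."""
        indices = self._random_state.randint(len(self._storage),
                                             size=batch_size)
        return [self._decoder(self._storage[i]) for i in indices]

    @property
    def size(self):
        return len(self._storage)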
Example #4
def main(argv):
    """Trains DQN agent on Atari."""
    del argv
    logging.info('Boostrapped DQN on Atari on %s.',
                 jax.lib.xla_bridge.get_backend().platform)
    random_state = np.random.RandomState(FLAGS.seed)
    rng_key = jax.random.PRNGKey(
        random_state.randint(-sys.maxsize - 1, sys.maxsize + 1,
                             dtype=np.int64))

    if FLAGS.results_csv_path:
        writer = parts.CsvWriter(FLAGS.results_csv_path)
    else:
        writer = parts.NullWriter()

    def environment_builder():
        """Creates Atari environment."""
        env = gym_atari.GymAtari(FLAGS.environment_name,
                                 seed=random_state.randint(1, 2**32))
        return gym_atari.RandomNoopsEnvironmentWrapper(
            env,
            min_noop_steps=1,
            max_noop_steps=30,
            seed=random_state.randint(1, 2**32),
        )

    env = environment_builder()

    logging.info('Environment: %s', FLAGS.environment_name)
    logging.info('Action spec: %s', env.action_spec())
    logging.info('Observation spec: %s', env.observation_spec())
    num_actions = env.action_spec().num_values
    network_fn = networks.bootstrapped_dqn_multi_head_network(
        num_actions,
        num_heads=FLAGS.num_heads,
        mask_probability=FLAGS.mask_probability)
    network = hk.transform(network_fn)

    def preprocessor_builder():
        return processors.atari(
            additional_discount=FLAGS.additional_discount,
            max_abs_reward=FLAGS.max_abs_reward,
            resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
            num_action_repeats=FLAGS.num_action_repeats,
            num_pooled_frames=2,
            zero_discount_on_life_loss=True,
            num_stacked_frames=FLAGS.num_stacked_frames,
            grayscaling=True,
        )

    # Create sample network input from sample preprocessor output.
    sample_processed_timestep = preprocessor_builder()(env.reset())
    sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                            sample_processed_timestep)
    sample_network_input = sample_processed_timestep.observation
    assert sample_network_input.shape == (FLAGS.environment_height,
                                          FLAGS.environment_width,
                                          FLAGS.num_stacked_frames)

    exploration_epsilon_schedule = parts.LinearSchedule(
        begin_t=int(FLAGS.min_replay_capacity_fraction *
                    FLAGS.replay_capacity * FLAGS.num_action_repeats),
        decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                        FLAGS.num_iterations * FLAGS.num_train_frames),
        begin_value=FLAGS.exploration_epsilon_begin_value,
        end_value=FLAGS.exploration_epsilon_end_value)

    if FLAGS.compress_state:

        def encoder(transition):
            return transition._replace(
                s_tm1=replay_lib.compress_array(transition.s_tm1),
                s_t=replay_lib.compress_array(transition.s_t))

        def decoder(transition):
            return transition._replace(
                s_tm1=replay_lib.uncompress_array(transition.s_tm1),
                s_t=replay_lib.uncompress_array(transition.s_t))
    else:
        encoder = None
        decoder = None

    replay_structure = replay_lib.MaskedTransition(
        s_tm1=None,
        a_tm1=None,
        r_t=None,
        discount_t=None,
        s_t=None,
        mask_t=None,
    )

    replay = replay_lib.TransitionReplay(FLAGS.replay_capacity,
                                         replay_structure, random_state,
                                         encoder, decoder)

    optimizer = optax.rmsprop(
        learning_rate=FLAGS.learning_rate,
        decay=0.95,
        eps=FLAGS.optimizer_epsilon,
        centered=True,
    )

    if FLAGS.shaping_function_type == constants.NO_PENALTY:
        shaping_function = shaping.NoPenalty()
    elif FLAGS.shaping_function_type == constants.HARD_CODED_PENALTY:
        shaping_function = shaping.HardCodedPenalty(
            penalty=FLAGS.shaping_multiplicative_factor)
    elif FLAGS.shaping_function_type == constants.UNCERTAINTY_PENALTY:
        shaping_function = shaping.UncertaintyPenalty(
            multiplicative_factor=FLAGS.shaping_multiplicative_factor)
    elif FLAGS.shaping_function_type == constants.POLICY_ENTROPY_PENALTY:
        shaping_function = shaping.PolicyEntropyPenalty(
            multiplicative_factor=FLAGS.shaping_multiplicative_factor,
            num_actions=num_actions)
    elif FLAGS.shaping_function_type == constants.MUNCHAUSEN_PENALTY:
        shaping_function = shaping.MunchausenPenalty(
            multiplicative_factor=FLAGS.shaping_multiplicative_factor,
            num_actions=num_actions)
    else:
        # Fail fast rather than hitting an undefined shaping_function below.
        raise ValueError(
            f'Unknown shaping function type: {FLAGS.shaping_function_type}')

    train_rng_key, eval_rng_key = jax.random.split(rng_key)

    train_agent = agent.BootstrappedDqn(
        preprocessor=preprocessor_builder(),
        sample_network_input=sample_network_input,
        network=network,
        optimizer=optimizer,
        transition_accumulator=replay_lib.TransitionAccumulator(),
        replay=replay,
        shaping_function=shaping_function,
        mask_probability=FLAGS.mask_probability,
        num_heads=FLAGS.num_heads,
        batch_size=FLAGS.batch_size,
        exploration_epsilon=exploration_epsilon_schedule,
        min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
        learn_period=FLAGS.learn_period,
        target_network_update_period=FLAGS.target_network_update_period,
        grad_error_bound=FLAGS.grad_error_bound,
        rng_key=train_rng_key,
    )
    eval_agent = parts.EpsilonGreedyActor(
        preprocessor=preprocessor_builder(),
        network=network,
        exploration_epsilon=FLAGS.eval_exploration_epsilon,
        rng_key=eval_rng_key,
    )

    # Set up checkpointing.
    # checkpoint = parts.NullCheckpoint()
    checkpoint = parts.ImplementedCheckpoint(
        checkpoint_path=FLAGS.checkpoint_path)

    if checkpoint.can_be_restored():
        checkpoint.restore()
        iteration = checkpoint.state.iteration
        random_state = checkpoint.state.random_state
        train_agent.set_state(state=checkpoint.state.train_agent)
        eval_agent.set_state(state=checkpoint.state.eval_agent)
        writer.set_state(state=checkpoint.state.writer)
    else:
        iteration = 0

    while iteration <= FLAGS.num_iterations:
        # New environment for each iteration to allow for determinism if preempted.
        env = environment_builder()

        logging.info('Training iteration %d.', iteration)
        train_seq = parts.run_loop(train_agent, env,
                                   FLAGS.max_frames_per_episode)
        num_train_frames = 0 if iteration == 0 else FLAGS.num_train_frames
        train_seq_truncated = itertools.islice(train_seq, num_train_frames)
        train_stats = parts.generate_statistics(train_seq_truncated)

        logging.info('Evaluation iteration %d.', iteration)
        eval_agent.network_params = train_agent.online_params
        eval_seq = parts.run_loop(eval_agent, env,
                                  FLAGS.max_frames_per_episode)
        eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
        eval_stats = parts.generate_statistics(eval_seq_truncated)

        # Logging and checkpointing.
        human_normalized_score = atari_data.get_human_normalized_score(
            FLAGS.environment_name, eval_stats['episode_return'])
        capped_human_normalized_score = np.amin([1., human_normalized_score])
        log_output = [
            ('iteration', iteration, '%3d'),
            ('frame', iteration * FLAGS.num_train_frames, '%5d'),
            ('eval_episode_return', eval_stats['episode_return'], '% 2.2f'),
            ('train_episode_return', train_stats['episode_return'], '% 2.2f'),
            ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),
            ('train_num_episodes', train_stats['num_episodes'], '%3d'),
            ('eval_frame_rate', eval_stats['step_rate'], '%4.0f'),
            ('train_frame_rate', train_stats['step_rate'], '%4.0f'),
            ('train_exploration_epsilon', train_agent.exploration_epsilon,
             '%.3f'),
            ('normalized_return', human_normalized_score, '%.3f'),
            ('capped_normalized_return', capped_human_normalized_score,
             '%.3f'),
            ('human_gap', 1. - capped_human_normalized_score, '%.3f'),
            ('train_loss', train_stats['train_loss'], '% 2.2f'),
            ('shaped_reward', train_stats['shaped_reward'], '% 2.2f'),
            ('penalties', train_stats['penalties'], '% 2.2f')
        ]
        log_output_str = ', '.join(
            ('%s: ' + f) % (n, v) for n, v, f in log_output)
        logging.info(log_output_str)
        writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))

        iteration += 1

        # update state before checkpointing
        checkpoint.state.iteration = iteration
        checkpoint.state.train_agent = train_agent.get_state()
        checkpoint.state.eval_agent = eval_agent.get_state()
        checkpoint.state.random_state = random_state
        checkpoint.state.writer = writer.get_state()
        checkpoint.save()

    writer.close()
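
The MaskedTransition.mask_t field above is what distinguishes the bootstrapped agent's replay from the plain one: each stored transition carries a binary per-head mask so that every head trains on its own bootstrap sample of the data. A hypothetical sketch of how such a mask could be drawn (the sample_bootstrap_mask helper is illustrative only; the actual agent.BootstrappedDqn may generate and apply masks differently):

import numpy as np

def sample_bootstrap_mask(random_state, num_heads, mask_probability):
    """Returns a {0, 1} vector saying which heads learn from a transition."""
    return (random_state.uniform(size=num_heads) <
            mask_probability).astype(np.float32)

rng = np.random.RandomState(1)
mask_t = sample_bootstrap_mask(rng, num_heads=10, mask_probability=0.5)
# Stored alongside the transition as MaskedTransition.mask_t; at learning time
# a head's TD error is multiplied by its mask entry, so heads with a 0 simply
# ignore this transition.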