Example #1
def create_optax_optim(name,
                       learning_rate=None,
                       momentum=0.9,
                       weight_decay=0,
                       **kwargs):
    """ Optimizer Factory

    Args:
        learning_rate (float): specify learning rate or leave up to scheduler / optim if None
        weight_decay (float): weight decay to apply to all params, not applied if 0
        **kwargs: optional / optimizer specific params that override defaults

    Regarding the kwargs, I've tried to keep the param naming coming in via kwargs from the
    config file consistent so there is less variation. Names of common args such as eps,
    beta1, beta2, etc. will be remapped where possible (even if the optimizer impl uses a
    different name) and removed when not needed. Some common params to use in config files:
        eps (float): default stability / regularization epsilon value
        beta1 (float): moving average / momentum coefficient for gradient
        beta2 (float): moving average / momentum coefficient for gradient magnitude (squared grad)
    """
    name = name.lower()
    opt_args = dict(learning_rate=learning_rate, **kwargs)
    _rename(opt_args, ('beta1', 'beta2'), ('b1', 'b2'))
    if name == 'sgd' or name == 'momentum' or name == 'nesterov':
        _erase(opt_args, ('eps', ))
        if name == 'momentum':
            optimizer = optax.sgd(momentum=momentum, **opt_args)
        elif name == 'nesterov':
            optimizer = optax.sgd(momentum=momentum, nesterov=True, **opt_args)
        else:
            assert name == 'sgd'
            optimizer = optax.sgd(momentum=0, **opt_args)
    elif name == 'adabelief':
        optimizer = optax.adabelief(**opt_args)
    elif name == 'adam' or name == 'adamw':
        if name == 'adamw':
            optimizer = optax.adamw(weight_decay=weight_decay, **opt_args)
        else:
            optimizer = optax.adam(**opt_args)
    elif name == 'lamb':
        optimizer = optax.lamb(weight_decay=weight_decay, **opt_args)
    elif name == 'lars':
        optimizer = lars(weight_decay=weight_decay, **opt_args)
    elif name == 'rmsprop':
        optimizer = optax.rmsprop(momentum=momentum, **opt_args)
    elif name == 'rmsproptf':
        optimizer = optax.rmsprop(momentum=momentum,
                                  initial_scale=1.0,
                                  **opt_args)
    else:
        assert False, f"Invalid optimizer name specified ({name})"

    return optimizer
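
A minimal usage sketch of the factory above (assuming the surrounding module's `_rename`, `_erase`, and `lars` helpers are available); it shows the `beta1`/`beta2` to `b1`/`b2` remapping and the TF1-style RMSProp branch:

import optax

# `beta1`/`beta2` from a config dict are remapped to optax's `b1`/`b2` names.
tx_adam = create_optax_optim('adam', learning_rate=1e-3,
                             beta1=0.9, beta2=0.999, eps=1e-8)

# 'rmsproptf' selects RMSProp with a TF1-style accumulator (initial_scale=1.0).
tx_rms = create_optax_optim('rmsproptf', learning_rate=1e-3, eps=1e-3)

# Either way, the result is a standard optax GradientTransformation.
opt_state = tx_adam.init({'w': 1.0})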
Example #2
 def __init__(
     self,
     learning_rate: float = 0.001,
     decay: float = 0.9,
     eps: float = 1e-8,
     initial_scale: float = 0,
     centered: bool = False,
     momentum: float or None = None,
     nesterov: bool = False,
 ):
     super(RMSProp, self).__init__(learning_rate=learning_rate)
     self._decay = decay
     self._eps = eps
     self._initial_scale = initial_scale
     self._centered = centered
     self._momentum = momentum
     self._nesterov = nesterov
     self._optimizer = optax.rmsprop(
         learning_rate=learning_rate,
         decay=decay,
         eps=eps,
         initial_scale=initial_scale,
         centered=centered,
         momentum=momentum,
         nesterov=nesterov,
     )
     self._optimizer_update = jit(self._optimizer.update)
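
For reference, a bare-optax sketch of the step the jitted `update` above performs: compute gradients, transform them with RMSProp, then apply the resulting updates.

import jax
import jax.numpy as jnp
import optax

def loss_fn(params, x, y):
    return jnp.mean((x @ params['w'] - y) ** 2)

params = {'w': jnp.zeros((3,))}
tx = optax.rmsprop(learning_rate=1e-3, decay=0.9, eps=1e-8)
opt_state = tx.init(params)

x, y = jnp.ones((8, 3)), jnp.ones((8,))
grads = jax.grad(loss_fn)(params, x, y)
updates, opt_state = jax.jit(tx.update)(grads, opt_state)   # same call the wrapper jits
params = optax.apply_updates(params, updates)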
Example #3
    def test_example_restore(self):
        class MLP(eg.Module):
            @eg.compact
            def __call__(self, x):
                x = eg.Linear(10)(x)
                x = jax.lax.stop_gradient(x)
                return x

        # This callback will stop the training when there is no improvement in
        # the loss for three consecutive epochs.
        model = eg.Model(
            module=MLP(),
            loss=eg.losses.MeanSquaredError(),
            optimizer=optax.rmsprop(0.01),
        )
        history = model.fit(
            inputs=np.ones((5, 20)),
            labels=np.zeros((5, 10)),
            epochs=10,
            batch_size=1,
            callbacks=[
                eg.callbacks.EarlyStopping(monitor="loss",
                                           patience=3,
                                           restore_best_weights=True)
            ],
            verbose=0,
        )
        assert len(history.history["loss"]) == 4  # Only 4 epochs are run.
Example #4
def main(_):
    # A thunk that builds a new environment.
    # Substitute your environment here!
    build_env = catch.Catch

    # Construct the agent. We need a sample environment for its spec.
    env_for_spec = build_env()
    num_actions = env_for_spec.action_spec().num_values
    agent = agent_lib.Agent(num_actions, env_for_spec.observation_spec(),
                            haiku_nets.CatchNet)

    # Construct the optimizer.
    max_updates = MAX_ENV_FRAMES / FRAMES_PER_ITER
    opt = optax.rmsprop(5e-3, decay=0.99, eps=1e-7)

    # Construct the learner.
    learner = learner_lib.Learner(
        agent,
        jax.random.PRNGKey(428),
        opt,
        BATCH_SIZE,
        DISCOUNT_FACTOR,
        FRAMES_PER_ITER,
        max_abs_reward=1.,
        logger=util.AbslLogger(),  # Provide your own logger here.
    )

    # Construct the actors on different threads.
    # stop_signal in a list so the reference is shared.
    actor_threads = []
    stop_signal = [False]
    for i in range(NUM_ACTORS):
        actor = actor_lib.Actor(
            agent,
            build_env(),
            UNROLL_LENGTH,
            learner,
            rng_seed=i,
            logger=util.AbslLogger(),  # Provide your own logger here.
        )
        args = (actor, stop_signal)
        actor_threads.append(threading.Thread(target=run_actor, args=args))

    # Start the actors and learner.
    for t in actor_threads:
        t.start()
    learner.run(int(max_updates))

    # Stop.
    stop_signal[0] = True
    for t in actor_threads:
        t.join()
Example #5
def get_optimizer(optimizer_name: OptimizerName,
                  learning_rate: float,
                  momentum: float = 0.0,
                  adam_beta1: float = 0.9,
                  adam_beta2: float = 0.999,
                  adam_epsilon: float = 1e-8,
                  rmsprop_decay: float = 0.9,
                  rmsprop_epsilon: float = 1e-8,
                  adagrad_init_accumulator: float = 0.1,
                  adagrad_epsilon: float = 1e-6) -> Optimizer:
  """Given parameters, returns the corresponding optimizer.

  Args:
    optimizer_name: One of SGD, MOMENTUM, ADAM, RMSPROP, ADAGRAD.
    learning_rate: Learning rate for all optimizers.
    momentum: Momentum parameter for MOMENTUM.
    adam_beta1: beta1 parameter for ADAM.
    adam_beta2: beta2 parameter for ADAM.
    adam_epsilon: epsilon parameter for ADAM.
    rmsprop_decay: decay parameter for RMSPROP.
    rmsprop_epsilon: epsilon parameter for RMSPROP.
    adagrad_init_accumulator: initial accumulator for ADAGRAD.
    adagrad_epsilon: epsilon parameter for ADAGRAD.

  Returns:
    Returns the Optimizer with the specified properties.

  Raises:
    ValueError: if the optimizer name is not one of SGD, MOMENTUM, ADAM,
      RMSPROP, or ADAGRAD.
  """
  if optimizer_name == OptimizerName.SGD:
    return Optimizer(*optax.sgd(learning_rate))
  elif optimizer_name == OptimizerName.MOMENTUM:
    return Optimizer(*optax.sgd(learning_rate, momentum))
  elif optimizer_name == OptimizerName.ADAM:
    return Optimizer(*optax.adam(
        learning_rate, b1=adam_beta1, b2=adam_beta2, eps=adam_epsilon))
  elif optimizer_name == OptimizerName.RMSPROP:
    return Optimizer(
        *optax.rmsprop(learning_rate, decay=rmsprop_decay, eps=rmsprop_epsilon))
  elif optimizer_name == OptimizerName.ADAGRAD:
    return Optimizer(*optax.adagrad(
        learning_rate,
        initial_accumulator_value=adagrad_init_accumulator,
        eps=adagrad_epsilon))
  else:
    raise ValueError(f'Unsupported optimizer_name {optimizer_name}.')
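
The `Optimizer(*optax.sgd(...))` pattern above works because an optax `GradientTransformation` is a `NamedTuple` of `(init, update)` callables, so it can be unpacked positionally into whatever container the project uses. A minimal sketch with a hypothetical two-field `Optimizer` (the real definition is not shown in the snippet):

from typing import Callable, NamedTuple
import optax

class Optimizer(NamedTuple):
    # Hypothetical stand-in for the project's Optimizer container.
    init: Callable
    update: Callable

tx = optax.sgd(0.1, 0.9)   # GradientTransformation(init, update)
opt = Optimizer(*tx)       # unpacks positionally, as in the branches above

params = {'w': 0.0}
state = opt.init(params)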
Example #6
def make_optimizer(optimizer_type):
    """Constructs optimizer."""
    if optimizer_type == 'rmsprop':
        learning_rate = 0.00025
        epsilon = 0.01 / (32**2)
        optimizer = optax.rmsprop(learning_rate=learning_rate,
                                  decay=0.95,
                                  eps=epsilon,
                                  centered=True)
    elif optimizer_type == 'adam':
        learning_rate = 0.00005
        epsilon = 0.01 / 32
        optimizer = optax.adam(learning_rate=learning_rate, eps=epsilon)
    else:
        raise ValueError('Unknown optimizer "{}"'.format(optimizer_type))
    return optimizer
Example #7
def run(*, trajectories_per_actor, num_actors, unroll_len):
    """Runs the example."""

    # Construct the agent network. We need a sample environment for its spec.
    env = catch.Catch()
    num_actions = env.action_spec().num_values
    net = hk.without_apply_rng(
        hk.transform(lambda ts: SimpleNet(num_actions)(ts)))  # pylint: disable=unnecessary-lambda

    # Construct the agent and learner.
    agent = Agent(net.apply)
    opt = optax.rmsprop(5e-3, decay=0.99, eps=1e-7)
    learner = Learner(agent, opt.update)

    # Initialize the optimizer state.
    sample_ts = env.reset()
    sample_ts = preprocess_step(sample_ts)
    ts_with_batch = jax.tree_map(lambda t: np.expand_dims(t, 0), sample_ts)
    params = jax.jit(net.init)(jax.random.PRNGKey(428), ts_with_batch)
    opt_state = opt.init(params)

    # Create accessor and queueing functions.
    current_params = lambda: params
    batch_size = 2
    q = queue.Queue(maxsize=batch_size)

    def dequeue():
        batch = []
        for _ in range(batch_size):
            batch.append(q.get())
        batch = jax.tree_map(lambda *ts: np.stack(ts, axis=1), *batch)
        return jax.device_put(batch)

    # Start the actors.
    for i in range(num_actors):
        key = jax.random.PRNGKey(i)
        args = (agent, key, current_params, q.put, unroll_len,
                trajectories_per_actor)
        threading.Thread(target=run_actor, args=args).start()

    # Run the learner.
    num_steps = num_actors * trajectories_per_actor // batch_size
    for i in range(num_steps):
        traj = dequeue()
        params, opt_state = learner.update(params, opt_state, traj)
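
The `dequeue` helper above batches trajectories leaf-by-leaf; here is the tree-stacking idiom on its own, with two toy trajectories of length T=4 (shapes are illustrative):

import jax
import numpy as np

traj_a = {'obs': np.zeros((4, 2)), 'reward': np.zeros((4,))}
traj_b = {'obs': np.ones((4, 2)), 'reward': np.ones((4,))}

# Stack corresponding leaves along a new axis=1, giving time-major [T, B, ...] arrays.
batch = jax.tree_map(lambda *ts: np.stack(ts, axis=1), traj_a, traj_b)
print(batch['obs'].shape)     # (4, 2, 2)
print(batch['reward'].shape)  # (4, 2)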
Example #8
def make(observation_spec: specs.Array,
         action_spec: specs.DiscreteArray,
         rnn_hidden_size: int = 32,
         encoding_hidden_size: List[int] = [256, 128, 64],
         buffer_length: int = 120,
         discount: float = .5,
         td_lambda: float = .9,
         entropy_cost: float = 1.,
         critic_cost: float = 1.,
         seed: int = 0):
    """Creates a default agent."""
    initial_rnn_state = jnp.zeros((1, rnn_hidden_size), dtype=jnp.float32)

    def network(inputs: List[jnp.ndarray], state) -> ModelOutput:
        observation = hk.Flatten()(inputs[0]).reshape((1, -1))
        previous_reward = inputs[1].reshape((1, 1))
        previous_action = inputs[2].reshape((1, -1))

        torso = hk.nets.MLP(encoding_hidden_size)
        gru = hk.GRU(rnn_hidden_size)
        policy_head = hk.Linear(action_spec.num_values)
        value_head = hk.Linear(1)

        input_embedding = jnp.concatenate(
            [observation, previous_reward, previous_action], -1)
        input_embedding = torso(input_embedding)
        embedding, state = gru(input_embedding, state)
        logits = policy_head(embedding)
        value = value_head(embedding)

        return (logits, jnp.squeeze(value, axis=-1), embedding, embedding,
                embedding), state

    return ActorCriticRNN(observation_spec=observation_spec,
                          action_spec=action_spec,
                          network=network,
                          initial_rnn_state=initial_rnn_state,
                          optimizer=optax.rmsprop(1e-3),
                          rng=hk.PRNGSequence(seed),
                          buffer_length=buffer_length,
                          discount=discount,
                          td_lambda=td_lambda,
                          entropy_cost=entropy_cost,
                          critic_cost=critic_cost)
Example #9
def rmsprop(learning_rate: ScalarOrSchedule,
            decay: float = 0.9,
            eps: float = 1e-8,
            initial_scale: float = 0.,
            centered: bool = False,
            momentum: Optional[float] = None,
            nesterov: bool = False) -> Optimizer:
  """A flexible RMSProp optimiser.

  RMSProp is an SGD variant with learning rate adaptation. The `learning_rate`
  used for each weight is scaled by a suitable estimate of the magnitude of the
  gradients on previous steps. Several variants of RMSProp can be found
  in the literature. This alias provides an easy to configure RMSProp
  optimiser that can be used to switch between several of these variants.

  References:
    [Tieleman and Hinton, 2012](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
    [Graves, 2013](https://arxiv.org/abs/1308.0850)

  Args:
    learning_rate: This is a fixed global scaling factor.
    decay: The decay used to track the magnitude of previous gradients.
    eps: A small numerical constant to avoid dividing by zero when rescaling.
    initial_scale: Initialisation of accumulators tracking the magnitude of
      previous updates. PyTorch uses `0`, TF1 uses `1`. When reproducing results
      from a paper, verify the value used by the authors.
    centered: Whether the second moment or the variance of the past gradients is
      used to rescale the latest gradients.
    momentum: The `decay` rate used by the momentum term. When set to `None`,
      momentum is not used at all.
    nesterov: Whether nesterov momentum is used.

  Returns:
    The corresponding `Optimizer`.
  """
  return create_optimizer_from_optax(
      optax.rmsprop(
          learning_rate=learning_rate,
          decay=decay,
          eps=eps,
          initial_scale=initial_scale,
          centered=centered,
          momentum=momentum,
          nesterov=nesterov))
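
A short sketch of how the flags documented above select between the common RMSProp variants (values are illustrative):

import optax

pytorch_style = optax.rmsprop(1e-3)                               # initial_scale=0., uncentered
tf1_style     = optax.rmsprop(1e-3, initial_scale=1.0)            # TF1 starts the accumulator at 1
centered      = optax.rmsprop(1e-3, centered=True, momentum=0.9)  # variance-based rescaling + momentum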
Example #10
def create_opt(name='adamw',
               learning_rate=6.25e-5,
               beta1=0.9,
               beta2=0.999,
               eps=1.5e-4,
               weight_decay=0.0,
               centered=False):
    """Create an optimizer for training.

  Currently, only the Adam and RMSProp optimizers are supported.

  Args:
    name: str, name of the optimizer to create.
    learning_rate: float, learning rate to use in the optimizer.
    beta1: float, beta1 parameter for the optimizer.
    beta2: float, beta2 parameter for the optimizer.
    eps: float, epsilon parameter for the optimizer.
    weight_decay: float, weight decay parameter for the optimizer.
    centered: bool, centered parameter for RMSProp.

  Returns:
    An optax optimizer.
  """
    if name == 'adam':
        logging.info(
            'Creating AdamW optimizer with settings lr=%f, beta1=%f, '
            'beta2=%f, eps=%f, weight decay=%f', learning_rate, beta1, beta2,
            eps, weight_decay)
        return optax.adam(learning_rate, b1=beta1, b2=beta2, eps=eps)
    elif name == 'rmsprop':
        logging.info(
            'Creating RMSProp optimizer with settings lr=%f, beta2=%f, '
            'eps=%f', learning_rate, beta2, eps)
        return optax.rmsprop(learning_rate,
                             decay=beta2,
                             eps=eps,
                             centered=centered)
    else:
        raise ValueError('Unsupported optimizer {}'.format(name))
Example #11
def make(observation_spec: specs.Array,
         action_spec: specs.DiscreteArray,
         hidden_size: List[int] = [256, 128, 64],
         buffer_length: int = 120,
         discount: float = .5,
         td_lambda: float = .9,
         entropy_cost: float = 1.,
         critic_cost: float = 1.,
         seed: int = 0):
    """Creates a default agent."""
    def network(observation: Observation, previous_reward: PreviousReward,
                previous_action: PreviousAction) -> ModelOutput:
        observation = hk.Flatten()(observation)
        previous_reward = hk.Flatten()(previous_reward)
        previous_action = hk.Flatten()(previous_action)

        torso = hk.nets.MLP(hidden_size)
        policy_head = hk.Linear(action_spec.num_values)
        value_head = hk.Linear(1)

        embedding = torso(
            jnp.concatenate([observation, previous_reward, previous_action],
                            -1))
        logits = policy_head(embedding)
        value = value_head(embedding)
        return logits, jnp.squeeze(value,
                                   axis=-1), embedding, embedding, embedding

    return ActorCritic(observation_spec=observation_spec,
                       action_spec=action_spec,
                       network=network,
                       optimizer=optax.rmsprop(1e-3),
                       rng=hk.PRNGSequence(seed),
                       buffer_length=buffer_length,
                       discount=discount,
                       td_lambda=td_lambda,
                       entropy_cost=entropy_cost,
                       critic_cost=critic_cost)
Example #12
def RmsProp(
    learning_rate: float = 0.001,
    beta: float = 0.9,
    epscut: float = 1.0e-7,
    centered: bool = False,
):
    r"""RMSProp optimizer.

    RMSProp is a well-known update algorithm proposed by Geoff Hinton in his
    `Neural Networks course notes
    <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.
    It corrects the problem with AdaGrad by using an exponentially weighted
    moving average over past squared gradients instead of a cumulative sum.
    After initializing the vector :math:`\mathbf{s}` to zero, :math:`s_k` and
    the parameters :math:`p_k` are updated as

    .. math:: s^\prime_k = \beta s_k + (1-\beta) G_k(\mathbf{p})^2 \\
              p^\prime_k = p_k - \frac{\eta}{\sqrt{s^\prime_k}+\epsilon} G_k(\mathbf{p})


    Constructs a new ``RmsProp`` optimizer.

    Args:
       learning_rate: The learning rate :math:`\eta`
       beta: Exponential decay rate.
       epscut: Small cutoff value.
       centered: whether to center the moving average.

    Examples:
       RmsProp optimizer.

       >>> from netket.optimizer import RmsProp
       >>> op = RmsProp(learning_rate=0.02)
    """
    from optax import rmsprop

    return rmsprop(learning_rate, decay=beta, eps=epscut, centered=centered)
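
To make the update rule concrete, a direct transcription of the two equations above for a single parameter vector (illustration only; optax's implementation differs in small details such as where `eps` enters):

import jax.numpy as jnp

def rmsprop_step(p, s, g, eta=0.02, beta=0.9, epscut=1e-7):
    # s' = beta * s + (1 - beta) * g^2 ;  p' = p - eta / (sqrt(s') + eps) * g
    s_new = beta * s + (1.0 - beta) * g**2
    p_new = p - eta / (jnp.sqrt(s_new) + epscut) * g
    return p_new, s_new

p = jnp.array([1.0, -2.0])
s = jnp.zeros_like(p)
g = jnp.array([0.5, 0.1])
p, s = rmsprop_step(p, s, g)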
Example #13
    def test_example(self):
        class MLP(elegy.Module):
            def call(self, x):
                x = elegy.nn.Linear(10)(x)
                x = jax.lax.stop_gradient(x)
                return x

        callback = elegy.callbacks.EarlyStopping(monitor="loss", patience=3)
        # This callback will stop the training when there is no improvement in
        # the loss for three consecutive epochs.
        model = elegy.Model(
            module=MLP(),
            loss=elegy.losses.MeanSquaredError(),
            optimizer=optax.rmsprop(0.01),
        )
        history = model.fit(
            x=np.ones((5, 20)),
            y=np.zeros((5, 10)),
            epochs=10,
            batch_size=1,
            callbacks=[callback],
            verbose=0,
        )
        assert len(history.history["loss"]) == 4  # Only 4 epochs are run.
Example #14
def main(argv):
  """Trains Prioritized DQN agent on Atari."""
  del argv
  logging.info('Prioritized DQN on Atari on %s.',
               jax.lib.xla_bridge.get_backend().platform)
  random_state = np.random.RandomState(FLAGS.seed)
  rng_key = jax.random.PRNGKey(
      random_state.randint(-sys.maxsize - 1, sys.maxsize + 1, dtype=np.int64))

  if FLAGS.results_csv_path:
    writer = parts.CsvWriter(FLAGS.results_csv_path)
  else:
    writer = parts.NullWriter()

  def environment_builder():
    """Creates Atari environment."""
    env = gym_atari.GymAtari(
        FLAGS.environment_name, seed=random_state.randint(1, 2**32))
    return gym_atari.RandomNoopsEnvironmentWrapper(
        env,
        min_noop_steps=1,
        max_noop_steps=30,
        seed=random_state.randint(1, 2**32),
    )

  env = environment_builder()

  logging.info('Environment: %s', FLAGS.environment_name)
  logging.info('Action spec: %s', env.action_spec())
  logging.info('Observation spec: %s', env.observation_spec())
  num_actions = env.action_spec().num_values
  network_fn = networks.double_dqn_atari_network(num_actions)
  network = hk.transform(network_fn)

  def preprocessor_builder():
    return processors.atari(
        additional_discount=FLAGS.additional_discount,
        max_abs_reward=FLAGS.max_abs_reward,
        resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
        num_action_repeats=FLAGS.num_action_repeats,
        num_pooled_frames=2,
        zero_discount_on_life_loss=True,
        num_stacked_frames=FLAGS.num_stacked_frames,
        grayscaling=True,
    )

  # Create sample network input from sample preprocessor output.
  sample_processed_timestep = preprocessor_builder()(env.reset())
  sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                          sample_processed_timestep)
  sample_network_input = sample_processed_timestep.observation
  chex.assert_shape(sample_network_input,
                    (FLAGS.environment_height, FLAGS.environment_width,
                     FLAGS.num_stacked_frames))

  exploration_epsilon_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity *
                  FLAGS.num_action_repeats),
      decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                      FLAGS.num_iterations * FLAGS.num_train_frames),
      begin_value=FLAGS.exploration_epsilon_begin_value,
      end_value=FLAGS.exploration_epsilon_end_value)

  # Note the t in the replay is not exactly aligned with the agent t.
  importance_sampling_exponent_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity),
      end_t=(FLAGS.num_iterations *
             int(FLAGS.num_train_frames / FLAGS.num_action_repeats)),
      begin_value=FLAGS.importance_sampling_exponent_begin_value,
      end_value=FLAGS.importance_sampling_exponent_end_value)

  if FLAGS.compress_state:

    def encoder(transition):
      return transition._replace(
          s_tm1=replay_lib.compress_array(transition.s_tm1),
          s_t=replay_lib.compress_array(transition.s_t))

    def decoder(transition):
      return transition._replace(
          s_tm1=replay_lib.uncompress_array(transition.s_tm1),
          s_t=replay_lib.uncompress_array(transition.s_t))
  else:
    encoder = None
    decoder = None

  replay_structure = replay_lib.Transition(
      s_tm1=None,
      a_tm1=None,
      r_t=None,
      discount_t=None,
      s_t=None,
  )

  replay = replay_lib.PrioritizedTransitionReplay(
      FLAGS.replay_capacity, replay_structure, FLAGS.priority_exponent,
      importance_sampling_exponent_schedule, FLAGS.uniform_sample_probability,
      FLAGS.normalize_weights, random_state, encoder, decoder)

  optimizer = optax.rmsprop(
      learning_rate=FLAGS.learning_rate,
      decay=0.95,
      eps=FLAGS.optimizer_epsilon,
      centered=True,
  )

  train_rng_key, eval_rng_key = jax.random.split(rng_key)

  train_agent = agent.PrioritizedDqn(
      preprocessor=preprocessor_builder(),
      sample_network_input=sample_network_input,
      network=network,
      optimizer=optimizer,
      transition_accumulator=replay_lib.TransitionAccumulator(),
      replay=replay,
      batch_size=FLAGS.batch_size,
      exploration_epsilon=exploration_epsilon_schedule,
      min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
      learn_period=FLAGS.learn_period,
      target_network_update_period=FLAGS.target_network_update_period,
      grad_error_bound=FLAGS.grad_error_bound,
      rng_key=train_rng_key,
  )
  eval_agent = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=network,
      exploration_epsilon=FLAGS.eval_exploration_epsilon,
      rng_key=eval_rng_key,
  )

  # Set up checkpointing.
  checkpoint = parts.NullCheckpoint()

  state = checkpoint.state
  state.iteration = 0
  state.train_agent = train_agent
  state.eval_agent = eval_agent
  state.random_state = random_state
  state.writer = writer
  if checkpoint.can_be_restored():
    checkpoint.restore()

  while state.iteration <= FLAGS.num_iterations:
    # New environment for each iteration to allow for determinism if preempted.
    env = environment_builder()

    logging.info('Training iteration %d.', state.iteration)
    train_seq = parts.run_loop(train_agent, env, FLAGS.max_frames_per_episode)
    num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames
    train_seq_truncated = itertools.islice(train_seq, num_train_frames)
    train_trackers = parts.make_default_trackers(train_agent)
    train_stats = parts.generate_statistics(train_trackers, train_seq_truncated)

    logging.info('Evaluation iteration %d.', state.iteration)
    eval_agent.network_params = train_agent.online_params
    eval_seq = parts.run_loop(eval_agent, env, FLAGS.max_frames_per_episode)
    eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
    eval_trackers = parts.make_default_trackers(eval_agent)
    eval_stats = parts.generate_statistics(eval_trackers, eval_seq_truncated)

    # Logging and checkpointing.
    human_normalized_score = atari_data.get_human_normalized_score(
        FLAGS.environment_name, eval_stats['episode_return'])
    capped_human_normalized_score = np.amin([1., human_normalized_score])
    log_output = [
        ('iteration', state.iteration, '%3d'),
        ('frame', state.iteration * FLAGS.num_train_frames, '%5d'),
        ('eval_episode_return', eval_stats['episode_return'], '% 2.2f'),
        ('train_episode_return', train_stats['episode_return'], '% 2.2f'),
        ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),
        ('train_num_episodes', train_stats['num_episodes'], '%3d'),
        ('eval_frame_rate', eval_stats['step_rate'], '%4.0f'),
        ('train_frame_rate', train_stats['step_rate'], '%4.0f'),
        ('train_exploration_epsilon', train_agent.exploration_epsilon, '%.3f'),
        ('train_state_value', train_stats['state_value'], '%.3f'),
        ('importance_sampling_exponent',
         train_agent.importance_sampling_exponent, '%.3f'),
        ('max_seen_priority', train_agent.max_seen_priority, '%.3f'),
        ('normalized_return', human_normalized_score, '%.3f'),
        ('capped_normalized_return', capped_human_normalized_score, '%.3f'),
        ('human_gap', 1. - capped_human_normalized_score, '%.3f'),
    ]
    log_output_str = ', '.join(('%s: ' + f) % (n, v) for n, v, f in log_output)
    logging.info(log_output_str)
    writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
    state.iteration += 1
    checkpoint.save()

  writer.close()
Example #15
def main(argv):
    """Trains DQN agent on Atari."""
    del argv
    logging.info("DQN on Atari on %s.",
                 jax.lib.xla_bridge.get_backend().platform)
    random_state = np.random.RandomState(FLAGS.seed)
    rng_key = jax.random.PRNGKey(
        random_state.randint(-sys.maxsize - 1, sys.maxsize + 1))

    if FLAGS.results_csv_path:
        writer = parts.CsvWriter(FLAGS.results_csv_path)
    else:
        writer = parts.NullWriter()

    def environment_builder():
        """Creates Key-Door environment."""
        env = gym_key_door.GymKeyDoor(
            env_args={
                constants.MAP_ASCII_PATH: FLAGS.map_ascii_path,
                constants.MAP_YAML_PATH: FLAGS.map_yaml_path,
                constants.REPRESENTATION: constants.PIXEL,
                constants.SCALING: FLAGS.env_scaling,
                constants.EPISODE_TIMEOUT: FLAGS.max_frames_per_episode,
                constants.GRAYSCALE: False,
                constants.BATCH_DIMENSION: False,
                constants.TORCH_AXES: False,
            },
            env_shape=FLAGS.env_shape,
        )
        return gym_atari.RandomNoopsEnvironmentWrapper(
            env,
            min_noop_steps=1,
            max_noop_steps=30,
            seed=random_state.randint(1, 2**32),
        )

    env = environment_builder()

    logging.info("Environment: %s", FLAGS.environment_name)
    logging.info("Action spec: %s", env.action_spec())
    logging.info("Observation spec: %s", env.observation_spec())
    num_actions = env.action_spec().num_values
    network_fn = networks.dqn_atari_network(num_actions)
    network = hk.transform(network_fn)

    def preprocessor_builder():
        return processors.atari(
            additional_discount=FLAGS.additional_discount,
            max_abs_reward=FLAGS.max_abs_reward,
            resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
            num_action_repeats=FLAGS.num_action_repeats,
            num_pooled_frames=2,
            zero_discount_on_life_loss=True,
            num_stacked_frames=FLAGS.num_stacked_frames,
            grayscaling=True,
        )

    # Create sample network input from sample preprocessor output.
    sample_processed_timestep = preprocessor_builder()(env.reset())
    sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                            sample_processed_timestep)
    sample_network_input = sample_processed_timestep.observation
    assert sample_network_input.shape == (
        FLAGS.environment_height,
        FLAGS.environment_width,
        FLAGS.num_stacked_frames,
    )

    exploration_epsilon_schedule = parts.LinearSchedule(
        begin_t=int(FLAGS.min_replay_capacity_fraction *
                    FLAGS.replay_capacity * FLAGS.num_action_repeats),
        decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                        FLAGS.num_iterations * FLAGS.num_train_frames),
        begin_value=FLAGS.exploration_epsilon_begin_value,
        end_value=FLAGS.exploration_epsilon_end_value,
    )

    if FLAGS.compress_state:

        def encoder(transition):
            return transition._replace(
                s_tm1=replay_lib.compress_array(transition.s_tm1),
                s_t=replay_lib.compress_array(transition.s_t),
            )

        def decoder(transition):
            return transition._replace(
                s_tm1=replay_lib.uncompress_array(transition.s_tm1),
                s_t=replay_lib.uncompress_array(transition.s_t),
            )

    else:
        encoder = None
        decoder = None

    replay_structure = replay_lib.Transition(
        s_tm1=None,
        a_tm1=None,
        r_t=None,
        discount_t=None,
        s_t=None,
    )

    replay = replay_lib.TransitionReplay(FLAGS.replay_capacity,
                                         replay_structure, random_state,
                                         encoder, decoder)

    optimizer = optax.rmsprop(
        learning_rate=FLAGS.learning_rate,
        decay=0.95,
        eps=FLAGS.optimizer_epsilon,
        centered=True,
    )

    if FLAGS.shaping_function_type == constants.NO_PENALTY:
        shaping_function = shaping.NoPenalty()
    elif FLAGS.shaping_function_type == constants.HARD_CODED_PENALTY:
        shaping_function = shaping.HardCodedPenalty(
            penalty=FLAGS.shaping_multiplicative_factor)

    train_rng_key, eval_rng_key = jax.random.split(rng_key)

    train_agent = agent.Dqn(
        preprocessor=preprocessor_builder(),
        sample_network_input=sample_network_input,
        network=network,
        optimizer=optimizer,
        transition_accumulator=replay_lib.TransitionAccumulator(),
        replay=replay,
        shaping_function=shaping_function,
        batch_size=FLAGS.batch_size,
        exploration_epsilon=exploration_epsilon_schedule,
        min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
        learn_period=FLAGS.learn_period,
        target_network_update_period=FLAGS.target_network_update_period,
        grad_error_bound=FLAGS.grad_error_bound,
        rng_key=train_rng_key,
    )
    eval_agent = parts.EpsilonGreedyActor(
        preprocessor=preprocessor_builder(),
        network=network,
        exploration_epsilon=FLAGS.eval_exploration_epsilon,
        rng_key=eval_rng_key,
    )

    # Set up checkpointing.
    # checkpoint = parts.NullCheckpoint()
    checkpoint = parts.ImplementedCheckpoint(
        checkpoint_path=FLAGS.checkpoint_path)

    if checkpoint.can_be_restored():
        checkpoint.restore()
        train_agent.set_state(state=checkpoint.state.train_agent)
        eval_agent.set_state(state=checkpoint.state.eval_agent)
        writer.set_state(state=checkpoint.state.writer)

    state = checkpoint.state
    state.iteration = 0
    state.train_agent = train_agent.get_state()
    state.eval_agent = eval_agent.get_state()
    state.random_state = random_state
    state.writer = writer.get_state()

    while state.iteration <= FLAGS.num_iterations:
        # New environment for each iteration to allow for determinism if preempted.
        env = environment_builder()

        logging.info("Training iteration %d.", state.iteration)
        train_seq = parts.run_loop(train_agent, env,
                                   FLAGS.max_frames_per_episode)
        num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames
        train_seq_truncated = itertools.islice(train_seq, num_train_frames)
        train_stats = parts.generate_statistics(train_seq_truncated)

        logging.info("Evaluation iteration %d.", state.iteration)
        eval_agent.network_params = train_agent.online_params
        eval_seq = parts.run_loop(eval_agent, env,
                                  FLAGS.max_frames_per_episode)
        eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
        eval_stats = parts.generate_statistics(eval_seq_truncated)

        # Logging and checkpointing.
        human_normalized_score = atari_data.get_human_normalized_score(
            FLAGS.environment_name, eval_stats["episode_return"])
        capped_human_normalized_score = np.amin([1.0, human_normalized_score])
        log_output = [
            ("iteration", state.iteration, "%3d"),
            ("frame", state.iteration * FLAGS.num_train_frames, "%5d"),
            ("eval_episode_return", eval_stats["episode_return"], "% 2.2f"),
            ("train_episode_return", train_stats["episode_return"], "% 2.2f"),
            ("eval_num_episodes", eval_stats["num_episodes"], "%3d"),
            ("train_num_episodes", train_stats["num_episodes"], "%3d"),
            ("eval_frame_rate", eval_stats["step_rate"], "%4.0f"),
            ("train_frame_rate", train_stats["step_rate"], "%4.0f"),
            ("train_exploration_epsilon", train_agent.exploration_epsilon,
             "%.3f"),
            ("normalized_return", human_normalized_score, "%.3f"),
            ("capped_normalized_return", capped_human_normalized_score,
             "%.3f"),
            ("human_gap", 1.0 - capped_human_normalized_score, "%.3f"),
        ]
        log_output_str = ", ".join(
            ("%s: " + f) % (n, v) for n, v, f in log_output)
        logging.info(log_output_str)
        writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
        state.iteration += 1
        checkpoint.save()

    writer.close()
Example #16
def main(debug: bool = False, eager: bool = False, logdir: str = "runs"):

    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    X_train, _1, X_test, _2 = dataget.image.mnist(global_cache=True).get()

    print("X_train:", X_train.shape, X_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)

    class MLP(elegy.Module):
        """Standard LeNet-300-100 MLP network."""

        def __init__(self, n1: int = 300, n2: int = 100, **kwargs):
            super().__init__(**kwargs)
            self.n1 = n1
            self.n2 = n2

        def call(self, image: jnp.ndarray):
            image = image.astype(jnp.float32) / 255.0
            x = elegy.nn.Flatten()(image)
            x = elegy.nn.sequential(
                elegy.nn.Linear(self.n1),
                jax.nn.relu,
                elegy.nn.Linear(self.n2),
                jax.nn.relu,
                elegy.nn.Linear(self.n1),
                jax.nn.relu,
                elegy.nn.Linear(x.shape[-1]),
                jax.nn.sigmoid,
            )(x)
            return x.reshape(image.shape) * 255

    class MeanSquaredError(elegy.losses.MeanSquaredError):
        # we request `x` instead of `y_true` since we don't require labels in autoencoders
        def call(self, x, y_pred):
            return super().call(x, y_pred)

    model = elegy.Model(
        module=MLP(n1=256, n2=64),
        loss=MeanSquaredError(),
        optimizer=optax.rmsprop(0.001),
        run_eagerly=eager,
    )

    model.summary(X_train[:64])

    # Notice we are not passing `y`
    history = model.fit(
        x=X_train,
        epochs=20,
        batch_size=64,
        validation_data=(X_test,),
        shuffle=True,
        callbacks=[elegy.callbacks.TensorBoard(logdir=logdir, update_freq=300)],
    )

    plot_history(history)

    # get random samples
    idxs = np.random.randint(0, 10000, size=(5,))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # plot and save results
    with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:

        figure = plt.figure(figsize=(12, 12))
        for i in range(5):
            plt.subplot(2, 5, i + 1)
            plt.imshow(x_sample[i], cmap="gray")
            plt.subplot(2, 5, 5 + i + 1)
            plt.imshow(y_pred[i], cmap="gray")

        # tbwriter.add_figure("AutoEncoder images", figure, 20)

    plt.show()

    print(
        "\n\n\nMetrics and images can be explored using tensorboard using:",
        f"\n \t\t\t tensorboard --logdir {logdir}",
    )
Example #17
 def update(
         self, gradient: Weights, state: GenericGradientState,
         parameters: Optional[Weights]
 ) -> Tuple[Weights, GenericGradientState]:
     return GenericGradientState.wrap(*rmsprop(
         **asdict(self)).update(gradient, state.data, parameters))
Example #18
 def init(self, parameters: Weights) -> GenericGradientState:
     return GenericGradientState(rmsprop(**asdict(self)).init(parameters))
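
The two fragments above pass `**asdict(self)` straight into `optax.rmsprop`, which only works when the dataclass fields mirror the optimizer's keyword arguments. A hypothetical sketch of that pattern (names are illustrative, not from the snippet):

from dataclasses import dataclass, asdict
import optax

@dataclass
class RMSPropConfig:
    # Fields intentionally match optax.rmsprop's keyword arguments.
    learning_rate: float = 1e-3
    decay: float = 0.9
    eps: float = 1e-8

cfg = RMSPropConfig()
tx = optax.rmsprop(**asdict(cfg))
state = tx.init({'w': 0.0})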
Example #19
def main():
    start_time = time.time()
    parser = argparse.ArgumentParser()
    add_args(parser)
    args = parser.parse_args()
    print(args)
    print("Is jax using @jit decorators?",
          not jax.config.read("jax_disable_jit"))
    rng_seq = hk.PRNGSequence(args.random_seed)
    p_log_prob = hk.transform(lambda x, z: Model(
        args.latent_size, args.hidden_size, MNIST_IMAGE_SHAPE)(x=x, z=z))
    if args.variational == "mean-field":
        variational = VariationalMeanField
    elif args.variational == "flow":
        variational = VariationalFlow
    q_sample_and_log_prob = hk.transform(lambda x, num_samples: variational(
        args.latent_size, args.hidden_size)(x, num_samples))
    p_params = p_log_prob.init(
        next(rng_seq),
        z=np.zeros((1, args.latent_size), dtype=np.float32),
        x=np.zeros((1, *MNIST_IMAGE_SHAPE), dtype=np.float32),
    )
    q_params = q_sample_and_log_prob.init(
        next(rng_seq),
        x=np.zeros((1, *MNIST_IMAGE_SHAPE), dtype=np.float32),
        num_samples=1,
    )
    optimizer = optax.rmsprop(args.learning_rate)
    params = (p_params, q_params)
    opt_state = optimizer.init(params)

    @jax.jit
    def objective_fn(params: hk.Params, rng_key: PRNGKey,
                     batch: Batch) -> jnp.ndarray:
        """Objective function is negative ELBO."""
        x = batch["image"]
        p_params, q_params = params
        z, log_q_z = q_sample_and_log_prob.apply(q_params,
                                                 rng_key,
                                                 x=x,
                                                 num_samples=1)
        log_p_x_z = p_log_prob.apply(p_params, rng_key, x=x, z=z)
        elbo = log_p_x_z - log_q_z
        # average elbo over number of samples
        elbo = elbo.mean(axis=0)
        # sum elbo over batch
        elbo = elbo.sum(axis=0)
        return -elbo

    @jax.jit
    def train_step(params: hk.Params, rng_key: PRNGKey,
                   opt_state: optax.OptState,
                   batch: Batch) -> Tuple[hk.Params, optax.OptState]:
        """Single update step to maximize the ELBO."""
        grads = jax.grad(objective_fn)(params, rng_key, batch)
        updates, new_opt_state = optimizer.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state

    @jax.jit
    def importance_weighted_estimate(
            params: hk.Params, rng_key: PRNGKey,
            batch: Batch) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Estimate marginal log p(x) using importance sampling."""
        x = batch["image"]
        p_params, q_params = params
        z, log_q_z = q_sample_and_log_prob.apply(
            q_params, rng_key, x=x, num_samples=args.num_importance_samples)
        log_p_x_z = p_log_prob.apply(p_params, rng_key, x, z)
        elbo = log_p_x_z - log_q_z
        # importance sampling of approximate marginal likelihood with q(z)
        # as the proposal, and logsumexp in the sample dimension
        log_p_x = jax.nn.logsumexp(elbo, axis=0) - jnp.log(jnp.shape(elbo)[0])
        # sum over the elements of the minibatch
        log_p_x = log_p_x.sum(0)
        # average elbo over number of samples
        elbo = elbo.mean(axis=0)
        # sum elbo over batch
        elbo = elbo.sum(axis=0)
        return elbo, log_p_x

    def evaluate(
        dataset: Generator[Batch, None, None],
        params: hk.Params,
        rng_seq: hk.PRNGSequence,
    ) -> Tuple[float, float]:
        total_elbo = 0.0
        total_log_p_x = 0.0
        dataset_size = 0
        for batch in dataset:
            elbo, log_p_x = importance_weighted_estimate(
                params, next(rng_seq), batch)
            total_elbo += elbo
            total_log_p_x += log_p_x
            dataset_size += len(batch["image"])
        return total_elbo / dataset_size, total_log_p_x / dataset_size

    train_ds = load_dataset(tfds.Split.TRAIN,
                            args.batch_size,
                            args.random_seed,
                            repeat=True)
    test_ds = load_dataset(tfds.Split.TEST, args.batch_size, args.random_seed)

    def print_progress(step: int, examples_per_sec: float):
        valid_ds = load_dataset(tfds.Split.VALIDATION, args.batch_size,
                                args.random_seed)
        elbo, log_p_x = evaluate(valid_ds, params, rng_seq)
        train_elbo = (-objective_fn(params, next(rng_seq), next(train_ds)) /
                      args.batch_size)
        print(f"Step {step:<10d}\t"
              f"Train ELBO estimate: {train_elbo:<5.3f}\t"
              f"Validation ELBO estimate: {elbo:<5.3f}\t"
              f"Validation log p(x) estimate: {log_p_x:<5.3f}\t"
              f"Speed: {examples_per_sec:<5.2e} examples/s")

    t0 = time.time()
    for step in range(args.training_steps):
        if step % args.log_interval == 0:
            t1 = time.time()
            examples_per_sec = args.log_interval * args.batch_size / (t1 - t0)
            print_progress(step, examples_per_sec)
            t0 = t1
        params, opt_state = train_step(params, next(rng_seq), opt_state,
                                       next(train_ds))

    test_ds = load_dataset(tfds.Split.TEST, args.batch_size, args.random_seed)
    elbo, log_p_x = evaluate(test_ds, params, rng_seq)
    print(f"Step {step:<10d}\t"
          f"Test ELBO estimate: {elbo:<5.3f}\t"
          f"Test log p(x) estimate: {log_p_x:<5.3f}\t")
    print(f"Total time: {(time.time() - start_time) / 60:.3f} minutes")
Example #20
def main(argv):
    """Trains DQN agent on Atari."""
    del argv
    logging.info('Bootstrapped DQN on Atari on %s.',
                 jax.lib.xla_bridge.get_backend().platform)
    random_state = np.random.RandomState(FLAGS.seed)
    rng_key = jax.random.PRNGKey(
        random_state.randint(-sys.maxsize - 1, sys.maxsize + 1))

    if FLAGS.results_csv_path:
        writer = parts.CsvWriter(FLAGS.results_csv_path)
    else:
        writer = parts.NullWriter()

    def environment_builder():
        """Creates Atari environment."""
        env = gym_atari.GymAtari(FLAGS.environment_name,
                                 seed=random_state.randint(1, 2**32))
        return gym_atari.RandomNoopsEnvironmentWrapper(
            env,
            min_noop_steps=1,
            max_noop_steps=30,
            seed=random_state.randint(1, 2**32),
        )

    env = environment_builder()

    logging.info('Environment: %s', FLAGS.environment_name)
    logging.info('Action spec: %s', env.action_spec())
    logging.info('Observation spec: %s', env.observation_spec())
    num_actions = env.action_spec().num_values
    network_fn = networks.bootstrapped_dqn_multi_head_network(
        num_actions,
        num_heads=FLAGS.num_heads,
        mask_probability=FLAGS.mask_probability)
    network = hk.transform(network_fn)

    def preprocessor_builder():
        return processors.atari(
            additional_discount=FLAGS.additional_discount,
            max_abs_reward=FLAGS.max_abs_reward,
            resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
            num_action_repeats=FLAGS.num_action_repeats,
            num_pooled_frames=2,
            zero_discount_on_life_loss=True,
            num_stacked_frames=FLAGS.num_stacked_frames,
            grayscaling=True,
        )

    # Create sample network input from sample preprocessor output.
    sample_processed_timestep = preprocessor_builder()(env.reset())
    sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                            sample_processed_timestep)
    sample_network_input = sample_processed_timestep.observation
    assert sample_network_input.shape == (FLAGS.environment_height,
                                          FLAGS.environment_width,
                                          FLAGS.num_stacked_frames)

    exploration_epsilon_schedule = parts.LinearSchedule(
        begin_t=int(FLAGS.min_replay_capacity_fraction *
                    FLAGS.replay_capacity * FLAGS.num_action_repeats),
        decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                        FLAGS.num_iterations * FLAGS.num_train_frames),
        begin_value=FLAGS.exploration_epsilon_begin_value,
        end_value=FLAGS.exploration_epsilon_end_value)

    if FLAGS.compress_state:

        def encoder(transition):
            return transition._replace(
                s_tm1=replay_lib.compress_array(transition.s_tm1),
                s_t=replay_lib.compress_array(transition.s_t))

        def decoder(transition):
            return transition._replace(
                s_tm1=replay_lib.uncompress_array(transition.s_tm1),
                s_t=replay_lib.uncompress_array(transition.s_t))
    else:
        encoder = None
        decoder = None

    replay_structure = replay_lib.MaskedTransition(
        s_tm1=None,
        a_tm1=None,
        r_t=None,
        discount_t=None,
        s_t=None,
        mask_t=None,
    )

    replay = replay_lib.TransitionReplay(FLAGS.replay_capacity,
                                         replay_structure, random_state,
                                         encoder, decoder)

    optimizer = optax.rmsprop(
        learning_rate=FLAGS.learning_rate,
        decay=0.95,
        eps=FLAGS.optimizer_epsilon,
        centered=True,
    )

    if FLAGS.shaping_function_type == constants.NO_PENALTY:
        shaping_function = shaping.NoPenalty()
    elif FLAGS.shaping_function_type == constants.HARD_CODED_PENALTY:
        shaping_function = shaping.HardCodedPenalty(
            penalty=FLAGS.shaping_multiplicative_factor)
    elif FLAGS.shaping_function_type == constants.UNCERTAINTY_PENALTY:
        shaping_function = shaping.UncertaintyPenalty(
            multiplicative_factor=FLAGS.shaping_multiplicative_factor)
    elif FLAGS.shaping_function_type == constants.POLICY_ENTROPY_PENALTY:
        shaping_function = shaping.PolicyEntropyPenalty(
            multiplicative_factor=FLAGS.shaping_multiplicative_factor,
            num_actions=num_actions)
    elif FLAGS.shaping_function_type == constants.MUNCHAUSEN_PENALTY:
        shaping_function = shaping.MunchausenPenalty(
            multiplicative_factor=FLAGS.shaping_multiplicative_factor,
            num_actions=num_actions)

    train_rng_key, eval_rng_key = jax.random.split(rng_key)

    train_agent = agent.BootstrappedDqn(
        preprocessor=preprocessor_builder(),
        sample_network_input=sample_network_input,
        network=network,
        optimizer=optimizer,
        transition_accumulator=replay_lib.TransitionAccumulator(),
        replay=replay,
        shaping_function=shaping_function,
        mask_probability=FLAGS.mask_probability,
        num_heads=FLAGS.num_heads,
        batch_size=FLAGS.batch_size,
        exploration_epsilon=exploration_epsilon_schedule,
        min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
        learn_period=FLAGS.learn_period,
        target_network_update_period=FLAGS.target_network_update_period,
        grad_error_bound=FLAGS.grad_error_bound,
        rng_key=train_rng_key,
    )
    eval_agent = parts.EpsilonGreedyActor(
        preprocessor=preprocessor_builder(),
        network=network,
        exploration_epsilon=FLAGS.eval_exploration_epsilon,
        rng_key=eval_rng_key,
    )

    # Set up checkpointing.
    # checkpoint = parts.NullCheckpoint()
    checkpoint = parts.ImplementedCheckpoint(
        checkpoint_path=FLAGS.checkpoint_path)

    if checkpoint.can_be_restored():
        checkpoint.restore()
        iteration = checkpoint.state.iteration
        random_state = checkpoint.state.random_state
        train_agent.set_state(state=checkpoint.state.train_agent)
        eval_agent.set_state(state=checkpoint.state.eval_agent)
        writer.set_state(state=checkpoint.state.writer)
    else:
        iteration = 0

    while iteration <= FLAGS.num_iterations:
        # New environment for each iteration to allow for determinism if preempted.
        env = environment_builder()

        logging.info('Training iteration %d.', iteration)
        train_seq = parts.run_loop(train_agent, env,
                                   FLAGS.max_frames_per_episode)
        num_train_frames = 0 if iteration == 0 else FLAGS.num_train_frames
        train_seq_truncated = itertools.islice(train_seq, num_train_frames)
        train_stats = parts.generate_statistics(train_seq_truncated)

        logging.info('Evaluation iteration %d.', iteration)
        eval_agent.network_params = train_agent.online_params
        eval_seq = parts.run_loop(eval_agent, env,
                                  FLAGS.max_frames_per_episode)
        eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
        eval_stats = parts.generate_statistics(eval_seq_truncated)

        # Logging and checkpointing.
        human_normalized_score = atari_data.get_human_normalized_score(
            FLAGS.environment_name, eval_stats['episode_return'])
        capped_human_normalized_score = np.amin([1., human_normalized_score])
        log_output = [
            ('iteration', iteration, '%3d'),
            ('frame', iteration * FLAGS.num_train_frames, '%5d'),
            ('eval_episode_return', eval_stats['episode_return'], '% 2.2f'),
            ('train_episode_return', train_stats['episode_return'], '% 2.2f'),
            ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),
            ('train_num_episodes', train_stats['num_episodes'], '%3d'),
            ('eval_frame_rate', eval_stats['step_rate'], '%4.0f'),
            ('train_frame_rate', train_stats['step_rate'], '%4.0f'),
            ('train_exploration_epsilon', train_agent.exploration_epsilon,
             '%.3f'), ('normalized_return', human_normalized_score, '%.3f'),
            ('capped_normalized_return',
             capped_human_normalized_score, '%.3f'),
            ('human_gap', 1. - capped_human_normalized_score, '%.3f'),
            ('train_loss', train_stats['train_loss'], '% 2.2f'),
            ('shaped_reward', train_stats['shaped_reward'], '% 2.2f'),
            ('penalties', train_stats['penalties'], '% 2.2f')
        ]
        log_output_str = ', '.join(
            ('%s: ' + f) % (n, v) for n, v, f in log_output)
        logging.info(log_output_str)
        writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))

        iteration += 1

        # update state before checkpointing
        checkpoint.state.iteration = iteration
        checkpoint.state.train_agent = train_agent.get_state()
        checkpoint.state.eval_agent = eval_agent.get_state()
        checkpoint.state.random_state = random_state
        checkpoint.state.writer = writer.get_state()
        checkpoint.save()

    writer.close()
Example #21
def main(
    debug: bool = False,
    eager: bool = False,
    logdir: str = "runs",
    steps_per_epoch: int = 200,
    epochs: int = 100,
    batch_size: int = 64,
):

    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    X_train = np.stack(dataset["train"]["image"])
    X_test = np.stack(dataset["test"]["image"])

    print("X_train:", X_train.shape, X_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)

    model = eg.Model(
        module=MLP(n1=256, n2=64),
        loss=MeanSquaredError(),
        optimizer=optax.rmsprop(0.001),
        eager=eager,
    )

    model.summary(X_train[:64])

    # Notice we are not passing `y`
    history = model.fit(
        inputs=X_train,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        validation_data=(X_test, ),
        shuffle=True,
        callbacks=[eg.callbacks.TensorBoard(logdir=logdir, update_freq=300)],
    )

    eg.utils.plot_history(history)

    # get random samples
    idxs = np.random.randint(0, 10000, size=(5, ))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # plot and save results
    with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:

        figure = plt.figure(figsize=(12, 12))
        for i in range(5):
            plt.subplot(2, 5, i + 1)
            plt.imshow(x_sample[i], cmap="gray")
            plt.subplot(2, 5, 5 + i + 1)
            plt.imshow(y_pred[i], cmap="gray")

    plt.show()