Beispiel #1
0
    def make_learner(
        self,
        random_key: networks_lib.PRNGKey,
        networks: d4pg_networks.D4PGNetworks,
        dataset: Iterator[reverb.ReplaySample],
        logger_fn: loggers.LoggerFactory,
        environment_spec: specs.EnvironmentSpec,
        replay_client: Optional[reverb.Client] = None,
        counter: Optional[counting.Counter] = None,
    ) -> core.Learner:
        """Build a `D4PGLearner` configured from `self._config`.

        Args:
            random_key: PRNG key forwarded to the learner.
            networks: provides the policy and critic networks.
            dataset: iterator of replay samples the learner consumes.
            logger_fn: factory called with the label 'learner'.
            environment_spec: unused by this builder.
            replay_client: unused by this builder.
            counter: optional counter forwarded to the learner.

        Returns:
            The constructed learner.
        """
        del environment_spec, replay_client

        cfg = self._config

        def _build_optimizer():
            # Adam at the configured learning rate; when clipping is enabled
            # the gradients are first clipped to a global norm of 40.
            adam = optax.adam(cfg.learning_rate)
            if cfg.clipping:
                return optax.chain(optax.clip_by_global_norm(40.), adam)
            return adam

        # Policy and critic get separate, identically configured optimizers.
        policy_optimizer = _build_optimizer()
        critic_optimizer = _build_optimizer()

        # The learner updates the parameters (and initializes them).
        learner_logger = logger_fn('learner')
        return learning.D4PGLearner(
            policy_network=networks.policy_network,
            critic_network=networks.critic_network,
            random_key=random_key,
            policy_optimizer=policy_optimizer,
            critic_optimizer=critic_optimizer,
            clipping=cfg.clipping,
            discount=cfg.discount,
            target_update_period=cfg.target_update_period,
            iterator=dataset,
            counter=counter,
            logger=learner_logger,
            num_sgd_steps_per_step=cfg.num_sgd_steps_per_step)
Beispiel #2
0
  def make_learner(
      self,
      random_key: networks_lib.PRNGKey,
      networks: impala_networks.IMPALANetworks,
      dataset: Iterator[reverb.ReplaySample],
      logger_fn: loggers.LoggerFactory,
      environment_spec: specs.EnvironmentSpec,
      replay_client: Optional[reverb.Client] = None,
      counter: Optional[counting.Counter] = None,
  ) -> core.Learner:
    """Build an `IMPALALearner` configured from `self._config`.

    Args:
      random_key: PRNG key forwarded to the learner.
      networks: IMPALA networks forwarded to the learner.
      dataset: iterator of replay samples the learner consumes.
      logger_fn: factory called with the label 'learner'.
      environment_spec: unused by this builder.
      replay_client: unused by this builder.
      counter: optional counter forwarded to the learner.

    Returns:
      The constructed learner.
    """
    del environment_spec, replay_client

    cfg = self._config

    # Gradients are clipped by global norm before the Adam update.
    gradient_clip = optax.clip_by_global_norm(cfg.max_gradient_norm)
    adam = optax.adam(
        cfg.learning_rate,
        b1=cfg.adam_momentum_decay,
        b2=cfg.adam_variance_decay,
        eps=cfg.adam_eps,
        eps_root=cfg.adam_eps_root)
    optimizer = optax.chain(gradient_clip, adam)

    learner_logger = logger_fn('learner')
    return learning.IMPALALearner(
        networks=networks,
        iterator=dataset,
        optimizer=optimizer,
        random_key=random_key,
        discount=cfg.discount,
        entropy_cost=cfg.entropy_cost,
        baseline_cost=cfg.baseline_cost,
        max_abs_reward=cfg.max_abs_reward,
        counter=counter,
        logger=learner_logger,
    )
Beispiel #3
0
def main():
    """Build the model and optimizer, then run `MAX_STEPS` training updates.

    Hyperparameters (`batch_size`, `sequence_length`, `d_model`, `num_heads`,
    `num_layers`, `dropout_rate`, `grad_clip_value`, `learning_rate`) and
    `MAX_STEPS` are read from names defined outside this function.
    """
    # Create the dataset.
    train_dataset, vocab_size = load(batch_size, sequence_length)
    # Set up the model, loss, and updater.
    forward_fn = build_forward_fn(vocab_size, d_model, num_heads, num_layers,
                                  dropout_rate)
    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    # Clip gradients by global norm before the Adam update.
    optimizer = optax.chain(optax.clip_by_global_norm(grad_clip_value),
                            optax.adam(learning_rate, b1=0.9, b2=0.99))

    updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

    # Initialize parameters.
    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    logging.info('Starting train loop...')
    # Neither the step index nor the per-step metrics are consumed here, so
    # they are discarded rather than bound to unused names.
    for _ in range(MAX_STEPS):
        data = next(train_dataset)
        state, _ = updater.update(state, data)
Beispiel #4
0
def test_batch_overfit(train_dataset):
    """Check that 100 updates drive the training loss below 0.1.

    Args:
        train_dataset: iterator yielding training batches.
    """
    # Model hyperparameters.
    vocab_size = 100
    d_model = 32
    num_heads = 8
    num_layers = 1
    # Regularisation / optimisation hyperparameters.
    dropout_rate = 0.01
    grad_clip_value = 0.25
    learning_rate = 2e-2
    max_iter = 100

    # Model, loss and updater.
    transformed = hk.transform(
        build_forward_fn(vocab_size, d_model, num_heads, num_layers,
                         dropout_rate, max_iter))
    loss_fn = functools.partial(lm_loss_fn, transformed.apply, vocab_size)

    gradient_clip = optax.clip_by_global_norm(grad_clip_value)
    adam = optax.adam(learning_rate, b1=0.9, b2=0.99)
    updater = Updater(transformed.init, loss_fn,
                      optax.chain(gradient_clip, adam))

    # Initialise parameters from the first batch.
    state = updater.init(jax.random.PRNGKey(428), next(train_dataset))

    for _ in range(100):
        state, metrics = updater.update(state, next(train_dataset))

    # After 100 updates the reported step counter is expected to read 99.
    assert metrics['loss'] < 0.1
    assert metrics['step'] == 99
Beispiel #5
0
def main(_):
    """Train the language model, logging metrics every `LOG_EVERY` steps.

    Configuration is read from `FLAGS`; the positional argument is ignored.
    """
    # Dataset.
    train_dataset, vocab_size = dataset.load(FLAGS.batch_size,
                                             FLAGS.sequence_length)

    # Model and loss.
    transformed = hk.transform(
        build_forward_fn(vocab_size, FLAGS.d_model, FLAGS.num_heads,
                         FLAGS.num_layers, FLAGS.dropout_rate))
    loss_fn = functools.partial(lm_loss_fn, transformed.apply, vocab_size)

    # Optimizer: global-norm gradient clipping followed by Adam.
    gradient_clip = optax.clip_by_global_norm(FLAGS.grad_clip_value)
    adam = optax.adam(FLAGS.learning_rate, b1=0.9, b2=0.99)
    optimizer = optax.chain(gradient_clip, adam)

    # The plain updater is wrapped by the checkpointing updater.
    updater = CheckpointingUpdater(
        Updater(transformed.init, loss_fn, optimizer), FLAGS.checkpoint_dir)

    # Parameter initialisation from the first batch.
    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    state = updater.init(rng, next(train_dataset))

    logging.info('Starting train loop...')
    prev_time = time.time()
    for step in range(MAX_STEPS):
        state, metrics = updater.update(state, next(train_dataset))
        # We use JAX runahead to mask data preprocessing and JAX dispatch
        # overheads. Reading values from state/metrics too often blocks the
        # runahead and makes those overheads more prominent, so metrics are
        # only touched on logging steps.
        if step % LOG_EVERY != 0:
            continue
        elapsed = time.time() - prev_time
        prev_time = time.time()
        metrics.update({'steps_per_sec': LOG_EVERY / elapsed})
        logging.info({name: float(value) for name, value in metrics.items()})
Beispiel #6
0
    def make_learner(
        self,
        random_key: networks_lib.PRNGKey,
        networks: td3_networks.TD3Networks,
        dataset: Iterator[reverb.ReplaySample],
        replay_client: Optional[reverb.Client] = None,
        counter: Optional[counting.Counter] = None,
    ) -> core.Learner:
        """Build a `TD3Learner` configured from `self._config`.

        Args:
            random_key: PRNG key forwarded to the learner.
            networks: TD3 networks forwarded to the learner.
            dataset: iterator of replay samples the learner consumes.
            replay_client: not used in this method.
            counter: optional counter forwarded to the learner.

        Returns:
            The constructed learner.
        """
        cfg = self._config

        # The two critics use separate Adam instances with the same rate.
        critic_optimizer = optax.adam(cfg.critic_learning_rate)
        twin_critic_optimizer = optax.adam(cfg.critic_learning_rate)
        policy_optimizer = optax.adam(cfg.policy_learning_rate)

        # Only the policy optimizer is (optionally) preceded by clipping.
        policy_clip_norm = cfg.policy_gradient_clipping
        if policy_clip_norm is not None:
            policy_optimizer = optax.chain(
                optax.clip_by_global_norm(policy_clip_norm),
                policy_optimizer)

        learner_logger = self._logger_fn()
        return learning.TD3Learner(
            networks=networks,
            random_key=random_key,
            discount=cfg.discount,
            target_sigma=cfg.target_sigma,
            noise_clip=cfg.noise_clip,
            policy_optimizer=policy_optimizer,
            critic_optimizer=critic_optimizer,
            twin_critic_optimizer=twin_critic_optimizer,
            num_sgd_steps_per_step=cfg.num_sgd_steps_per_step,
            bc_alpha=cfg.bc_alpha,
            iterator=dataset,
            logger=learner_logger,
            counter=counter)
Beispiel #7
0
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: networks_lib.FeedForwardNetwork,
        config: DQNConfig,
    ):
        """Wire together replay, learner and actor for the agent.

        Args:
            environment_spec: spec used to build the replay and to give the
                learner its observation spec.
            network: feed-forward network; its `apply` output is treated as
                action values by the acting policy.
            config: hyperparameters for replay, optimisation and exploration.
        """
        # Data is communicated via reverb replay.
        replay_components = replay.make_reverb_prioritized_nstep_replay(
            environment_spec=environment_spec,
            n_step=config.n_step,
            batch_size=config.batch_size,
            max_replay_size=config.max_replay_size,
            min_replay_size=config.min_replay_size,
            priority_exponent=config.priority_exponent,
            discount=config.discount,
        )
        self._server = replay_components.server

        # Global-norm gradient clipping followed by Adam.
        gradient_clip = optax.clip_by_global_norm(config.max_gradient_norm)
        optimizer = optax.chain(gradient_clip,
                                optax.adam(config.learning_rate))

        # One seed, split into independent learner and actor keys.
        root_key = jax.random.PRNGKey(config.seed)
        key_learner, key_actor = jax.random.split(root_key)

        # The learner updates the parameters (and initializes them).
        learner = learning.DQNLearner(
            network=network,
            obs_spec=environment_spec.observations,
            random_key=key_learner,
            optimizer=optimizer,
            discount=config.discount,
            importance_sampling_exponent=config.importance_sampling_exponent,
            target_update_period=config.target_update_period,
            iterator=replay_components.data_iterator,
            replay_client=replay_components.client,
        )

        # The actor selects actions epsilon-greedily from the action values.
        def epsilon_greedy_policy(params: networks_lib.Params,
                                  key: jnp.ndarray,
                                  observation: jnp.ndarray) -> jnp.ndarray:
            q_values = network.apply(params, observation)
            behaviour = rlax.epsilon_greedy(config.epsilon)
            return behaviour.sample(key, q_values)

        variable_client = variable_utils.VariableClient(learner, '')
        actor = actors.FeedForwardActor(
            policy=epsilon_greedy_policy,
            rng=hk.PRNGSequence(key_actor),
            variable_client=variable_client,
            adder=replay_components.adder)

        min_observations = max(config.batch_size, config.min_replay_size)
        observations_per_step = config.batch_size / config.samples_per_insert
        super().__init__(
            actor=actor,
            learner=learner,
            min_observations=min_observations,
            observations_per_step=observations_per_step,
        )
Beispiel #8
0
def train(*, data_folder, batch_size, epochs, learning_rate, weight_decay,
          seed, max_norm, text_vocab, text_dim, text_depth, text_heads,
          audio_dim, audio_depth, audio_heads):
    """Train a CLAP model on paired text/spectrogram data.

    All arguments are keyword-only. The loss is printed after every
    optimisation step; nothing is returned.
    """
    rng_key = random.PRNGKey(seed)

    # Shuffled loader that drops the last incomplete batch.
    loader = DataLoader(
        PairTextSpectrogramDataset(data_folder),
        batch_size=batch_size,
        collate_fn=pair_text_spectrogram_dataset_collate_fn,
        drop_last=True,
        shuffle=True)

    model = CLAP(
        text_vocab=text_vocab,
        text_dim=text_dim,
        text_depth=text_depth,
        text_heads=text_heads,
        audio_dim=audio_dim,
        audio_depth=audio_depth,
        audio_heads=audio_heads)

    def exclude_bias(params):
        # Weight-decay mask: True for every leaf whose ndim is not 1, so 1-D
        # leaves (presumably the biases) are left out of the decay.
        return tree_util.tree_map(lambda leaf: leaf.ndim != 1, params)

    # Clip -> Adam scaling -> masked weight decay -> negative learning rate.
    optim = chain(
        clip_by_global_norm(max_norm),
        scale_by_adam(eps=1e-4),
        add_decayed_weights(weight_decay, exclude_bias),
        scale(-learning_rate))

    # One batch is drawn up front to initialise the parameters.
    audio, audio_mask, text, text_mask = next(iter(loader))
    params = model.init(rng_key, text, audio, text_mask, audio_mask)
    optim_state = optim.init(params)

    @jit
    @value_and_grad
    def loss_and_grads(params, text, audio, text_mask, audio_mask):
        # The model's apply output is used directly as the loss.
        return model.apply(params, text, audio, text_mask, audio_mask)

    for _ in range(epochs):
        for audio, audio_mask, text, text_mask in loader:
            loss, grads = loss_and_grads(params, text, audio, text_mask,
                                         audio_mask)
            updates, optim_state = optim.update(grads, optim_state, params)
            params = apply_updates(params, updates)
            print(f'loss: {loss}')
Beispiel #9
0
    def make_learner(
        self,
        random_key: networks_lib.PRNGKey,
        networks: ppo_networks.PPONetworks,
        dataset: Iterator[reverb.ReplaySample],
        logger_fn: loggers.LoggerFactory,
        environment_spec: specs.EnvironmentSpec,
        replay_client: Optional[reverb.Client] = None,
        counter: Optional[counting.Counter] = None,
    ) -> core.Learner:
        """Build a `PPOLearner` configured from `self._config`.

        The configured learning rate may be a constant or a callable
        schedule; both end up as a negatively scaled Adam update.

        Args:
            random_key: PRNG key forwarded to the learner.
            networks: PPO networks forwarded to the learner.
            dataset: iterator of replay samples the learner consumes.
            logger_fn: factory called with the label 'learner'.
            environment_spec: unused by this builder.
            replay_client: unused by this builder.
            counter: optional counter forwarded to the learner.

        Returns:
            The constructed learner.
        """
        del environment_spec, replay_client

        cfg = self._config

        # Shared prefix: global-norm clipping, then Adam rescaling.
        transforms = [
            optax.clip_by_global_norm(cfg.max_gradient_norm),
            optax.scale_by_adam(eps=cfg.adam_epsilon),
        ]
        if callable(cfg.learning_rate):
            # Schedule: scale by the step-dependent rate, then negate.
            transforms.append(optax.scale_by_schedule(cfg.learning_rate))
            transforms.append(optax.scale(-1))
        else:
            # Constant rate: fold the negation into a single scale.
            transforms.append(optax.scale(-cfg.learning_rate))
        optimizer = optax.chain(*transforms)

        learner_logger = logger_fn('learner')
        return learning.PPOLearner(
            ppo_networks=networks,
            iterator=dataset,
            discount=cfg.discount,
            entropy_cost=cfg.entropy_cost,
            value_cost=cfg.value_cost,
            max_abs_reward=cfg.max_abs_reward,
            ppo_clipping_epsilon=cfg.ppo_clipping_epsilon,
            clip_value=cfg.clip_value,
            gae_lambda=cfg.gae_lambda,
            counter=counter,
            random_key=random_key,
            optimizer=optimizer,
            num_epochs=cfg.num_epochs,
            num_minibatches=cfg.num_minibatches,
            logger=learner_logger,
        )
Beispiel #10
0
    def optimizer(self, learning_rate):
        """Build the gradient transformation described by `self.config.optimizer`.

        The optax optimizer named in the config is preceded by global-norm
        clipping, and the result is wrapped by
        `optimizers.maybe_skip_gradient_update`.

        Args:
            learning_rate: passed as the first argument to the optax
                optimizer constructor.

        Returns:
            The wrapped gradient transformation.
        """
        opt_cfg = self.config.optimizer

        gradient_clip = optax.clip_by_global_norm(opt_cfg.gradient_clip_norm)
        # Look the optimizer constructor up on the optax module by name.
        make_base = getattr(optax, opt_cfg.name)
        base = make_base(learning_rate, **opt_cfg.args)

        clipped = optax.chain(gradient_clip, base)
        return optimizers.maybe_skip_gradient_update(
            clipped,
            opt_cfg.gradient_skip_norm,
        )
def create_train_state(rng, model, img_size, lr_schedule_fn, weight_decay,
                       max_norm):
    """Initialise model parameters and wrap them in a `TrainState`.

    Args:
        rng: PRNG key used for `model.init`.
        model: object exposing `init` and `apply`.
        img_size: side length of the square dummy input used for init.
        lr_schedule_fn: schedule passed to `optax.scale_by_schedule`.
        weight_decay: coefficient passed to `optax.additive_weight_decay`.
        max_norm: global-norm clipping threshold.

    Returns:
        The created `TrainState`.
    """
    # NOTE(review): the chain contains no `optax.scale(-1)`; presumably
    # `lr_schedule_fn` already returns negative step sizes -- confirm.
    transforms = (
        optax.clip_by_global_norm(max_norm),
        optax.scale_by_adam(),
        optax.additive_weight_decay(weight_decay),
        optax.scale_by_schedule(lr_schedule_fn),
    )
    tx = optax.chain(*transforms)

    # Parameters are initialised from a single all-ones input with three
    # channels in the last axis.
    dummy_input = jax.numpy.ones((1, img_size, img_size, 3))
    params = model.init(rng, dummy_input, is_training=False)

    return TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx,
    )
Beispiel #12
0
def make_optimizer(optimizer_config, lr_schedule):
  """Construct the optax optimizer with given LR schedule.

  Args:
    optimizer_config: config providing `max_norm`, `optimizer` ('adam' or
      'lamb'), `weight_decay`, the matching `*_kwargs`, and optionally
      `decay_pos_embs`.
    lr_schedule: schedule passed to `optax.scale_by_schedule`.

  Returns:
    The chained gradient transformation.

  Raises:
    ValueError: if `optimizer_config.optimizer` is neither 'adam' nor 'lamb'.
  """
  # Learned position embeddings are decayed by default; they are excluded
  # only when `decay_pos_embs` is set and falsy.
  skip_pos_embs_decay = not (
      optimizer_config.get('decay_pos_embs') is None or
      optimizer_config.decay_pos_embs)
  if skip_pos_embs_decay:
    weight_decay_exclude_names = ['pos_embs', 'b']
  else:
    weight_decay_exclude_names = ['b']

  transforms = []
  if optimizer_config.max_norm > 0:
    transforms.append(optax.clip_by_global_norm(optimizer_config.max_norm))

  optimizer_name = optimizer_config.optimizer
  if optimizer_name == 'adam':
    # See: https://arxiv.org/abs/1412.6980
    adam_scaling_kwargs = optimizer_config.adam_kwargs
  elif optimizer_name == 'lamb':
    # See: https://arxiv.org/abs/1904.00962
    adam_scaling_kwargs = optimizer_config.lamb_kwargs
  else:
    raise ValueError(f'Undefined optimizer {optimizer_config.optimizer}')

  # Both optimizers share Adam scaling followed by weight decay.
  transforms.append(optax.scale_by_adam(**adam_scaling_kwargs))
  transforms.append(
      add_weight_decay(
          optimizer_config.weight_decay,
          exclude_names=weight_decay_exclude_names))
  if optimizer_name == 'lamb':
    # LAMB additionally rescales by the trust ratio.
    transforms.append(optax.scale_by_trust_ratio())

  # Scale by the (negative) learning rate.
  transforms.append(optax.scale_by_schedule(lr_schedule))
  transforms.append(optax.scale(-1))

  return optax.chain(*transforms)
Beispiel #13
0
def main():
  """Parse command-line hyperparameters and train the pointer network.

  Optionally initialises Weights & Biases, builds the train/validation
  iterators from the convex-hull data file, and runs `Trainer.fit` for the
  requested number of steps with an attention-plot callback every 1000 steps.
  """
  parser = ArgumentParser()
  parser.add_argument('-b', '--batch-size', default=32, type=int)
  parser.add_argument('-d', '--rnn-hidden-size', default=256, type=int)
  parser.add_argument('-f', '--data-file', default='/tmp/convex_hull.dat',
                      type=str)
  parser.add_argument('-l', '--lr', default=1e-3, type=float)
  parser.add_argument('-r', '--resume-training', default=False,
                      action='store_true')
  parser.add_argument('-t', '--training-steps', default=100_000, type=int)
  parser.add_argument('-w', '--use-wandb', default=False,
                      action='store_true')
  parser.add_argument('-wd', '--wd', default=1e-2, type=float)

  hparams = parser.parse_args()
  if hparams.use_wandb:
    wandb.init(project='pointer-networks', dir='/tmp')
    wandb.config.update(hparams)

  print(hparams)

  dataloader = ConvexHullDataLoader(data_filepath=hparams.data_file)
  loss_fn = partial(_loss_fn, hparams=hparams)
  # Clip the raw gradients *before* AdamW. Chaining the clip after AdamW
  # would clip the final (learning-rate-scaled) updates instead, which leaves
  # the gradients fed into Adam's moment estimates unbounded.
  # NOTE(review): reordering the chain also reorders the optimizer-state
  # tuple; confirm that checkpoints loaded via --resume-training were not
  # saved with the previous order.
  optimizer = optax.chain(
      optax.clip_by_global_norm(10.),
      optax.adamw(hparams.lr, weight_decay=hparams.wd),
  )
  train_iter = dataloader.data_iter(hparams.batch_size, 'train')
  val_iter = dataloader.data_iter(hparams.batch_size, 'val')
  # The trainer and the plotting callback receive the wandb module only when
  # logging to W&B was requested.
  wandb_obj = wandb if hparams.use_wandb else None

  trainer = Trainer(
      train_loss_fn=loss_fn,
      train_data_iter=train_iter,
      val_loss_fn=loss_fn,
      val_data_iter=val_iter,
      optimizer=optimizer,
      wandb=wandb_obj,
      resume=hparams.resume_training
  )

  plot_att_fn = partial(_plot_attention, hparams=hparams, wandb=wandb_obj)
  trainer.register_callback(1000, plot_att_fn)

  trainer.fit(total_steps=hparams.training_steps)
Beispiel #14
0
  def test_graph_network_learning(self, spatial_dimension, dtype):
    """Check that four optimizer steps cut the graph-network loss by over 5%."""
    key = random.PRNGKey(0)
    R_key, dr0_key, params_key = random.split(key, 3)

    d, _ = space.free()

    # Six random configurations of three particles and their target offsets.
    R = random.uniform(R_key, (6, 3, spatial_dimension), dtype=dtype)
    dr0 = random.uniform(dr0_key, (6, 3, 3), dtype=dtype)

    def single_target_energy(positions, offsets):
      # Sum of squared deviations of the pairwise distances from the offsets.
      pair_distances = space.distance(
          space.map_product(d)(positions, positions))
      return np.sum((pair_distances - offsets) ** 2)

    E_gt = vmap(single_target_energy)

    cutoff = 0.2
    init_fn, energy_fn = energy.graph_network(d, cutoff)
    params = init_fn(params_key, R[0])

    # Model energy over the leading axis of the positions, shared parameters.
    batched_energy = vmap(energy_fn, (None, 0))

    @jit
    def loss(params, R):
      residual = batched_energy(params, R) - E_gt(R, dr0)
      return np.mean(residual ** 2)

    opt = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-4))

    @jit
    def update(params, opt_state, R):
      gradients = grad(loss)(params, R)
      updates, opt_state = opt.update(gradients, opt_state)
      return optax.apply_updates(params, updates), opt_state

    opt_state = opt.init(params)

    l0 = loss(params, R)
    for _ in range(4):
      params, opt_state = update(params, opt_state, R)

    assert loss(params, R) < l0 * 0.95
Beispiel #15
0
def configure_update_step(learning_rate: float, loss: Callable):
    """Configure an optax training update step.

    Args:
        learning_rate: learning rate for Adam.
        loss: callable `loss(params, positions, labels)` to differentiate
            with respect to `params`.

    Returns:
        A pair `(update_step, opt)`. `update_step(params_and_opt_state,
        batches)` scans over `batches`, applying one optimizer update per
        batch, and returns the final `(params, opt_state)` carry. `opt` is
        the optax transformation (global-norm clip at 1.0, then Adam).
    """
    opt = optax.chain(optax.clip_by_global_norm(1.0),
                      optax.adam(learning_rate))
    loss_grad = grad(loss)

    @jit
    def _update_step(params, opt_state, positions, labels):
        gradients = loss_grad(params, positions, labels)
        updates, new_opt_state = opt.update(gradients, opt_state)
        return optax.apply_updates(params, updates), new_opt_state

    def _scan_body(carry, batch):
        params, opt_state = carry
        positions, labels = batch
        # The per-step scan output is a dummy 0; only the carry matters.
        return _update_step(params, opt_state, positions, labels), 0

    @jit
    def update_step(params_and_opt_state, batches):
        final_carry, _ = lax.scan(_scan_body, params_and_opt_state, batches)
        return final_carry

    return update_step, opt
Beispiel #16
0
def create_train_state(config, rng, learning_rate_fn, example_batch):
    """Create the model and its initial training state.

    Args:
        config: Configuration for model and training.
        rng: JAX PRNG key.
        learning_rate_fn: Learning rate function passed to AdamW.
        example_batch: Batch used for model initialization.

    Returns:
        A `(model, state)` tuple, where `state` is the initialized
        `TrainState` carrying the parameters and the optimizer.
    """
    model, variables = create_model(config, rng, example_batch)
    params = variables['params']
    parameter_overview.log_parameter_overview(params)

    train_cfg = config.train
    adamw = optax.adamw(learning_rate=learning_rate_fn,
                        b1=0.9,
                        b2=.98,
                        eps=1e-9,
                        weight_decay=train_cfg.weight_decay)

    # Clipping by global norm takes precedence over clipping by value.
    # NOTE(review): value clipping is only enabled for thresholds above 1
    # (not above 0 like the norm threshold) -- confirm this is intended.
    if train_cfg.grad_max_norm > 0:
        gradient_clip = optax.clip_by_global_norm(train_cfg.grad_max_norm)
    elif train_cfg.grad_max_val > 1:
        gradient_clip = optax.clip(train_cfg.grad_max_val)
    else:
        gradient_clip = None

    tx = adamw if gradient_clip is None else optax.chain(gradient_clip, adamw)

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx,
    )
    return model, state
Beispiel #17
0
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        forward_fn: impala_types.PolicyValueFn,
        unroll_init_fn: impala_types.PolicyValueInitFn,
        unroll_fn: impala_types.PolicyValueFn,
        initial_state_init_fn: impala_types.RecurrentStateInitFn,
        initial_state_fn: impala_types.RecurrentStateFn,
        config: impala_config.IMPALAConfig,
        counter: Optional[counting.Counter] = None,
        logger: Optional[loggers.Logger] = None,
    ):
        """Wire together the reverb queue, learner and actor.

        Args:
            environment_spec: spec used to build the queue and to size the
                logits entry of the extras spec.
            forward_fn: policy/value function used by the actor.
            unroll_init_fn: forwarded to `IMPALANetworks`.
            unroll_fn: forwarded to `IMPALANetworks`.
            initial_state_init_fn: builds the recurrent-state parameters.
            initial_state_fn: maps those parameters to an initial state.
            config: agent hyperparameters.
            counter: optional counter forwarded to the learner.
            logger: optional logger; a terminal logger labelled 'agent' is
                stored on the instance when omitted.
        """
        networks = impala_networks.IMPALANetworks(
            forward_fn=forward_fn,
            unroll_init_fn=unroll_init_fn,
            unroll_fn=unroll_fn,
            initial_state_init_fn=initial_state_init_fn,
            initial_state_fn=initial_state_fn,
        )

        self._config = config
        self._logger = logger or loggers.TerminalLogger('agent')

        # One seed; first split off the key for the initial recurrent state,
        # then split the remainder between learner and actor.
        seed_key = jax.random.PRNGKey(config.seed)
        key, key_initial_state = jax.random.split(seed_key)
        key_learner, key_actor = jax.random.split(key)

        # Extras stored alongside each step: recurrent state and logits.
        num_actions = environment_spec.actions.num_values
        initial_state_params = initial_state_init_fn(key_initial_state)
        extra_spec = {
            'core_state': initial_state_fn(initial_state_params),
            'logits': np.ones(shape=(num_actions, ), dtype=np.float32),
        }

        # Data is handled by the reverb replay queue.
        reverb_queue = replay.make_reverb_online_queue(
            environment_spec=environment_spec,
            extra_spec=extra_spec,
            max_queue_size=config.max_queue_size,
            sequence_length=config.sequence_length,
            sequence_period=config.sequence_period,
            batch_size=config.batch_size,
        )
        self._server = reverb_queue.server
        self._can_sample = reverb_queue.can_sample

        # Learner: global-norm gradient clipping followed by Adam.
        gradient_clip = optax.clip_by_global_norm(config.max_gradient_norm)
        optimizer = optax.chain(gradient_clip,
                                optax.adam(config.learning_rate))
        # The learner receives the caller's `logger` argument as-is (possibly
        # None), not the `self._logger` fallback.
        self._learner = learning.IMPALALearner(
            networks=networks,
            iterator=reverb_queue.data_iterator,
            random_key=key_learner,
            counter=counter,
            logger=logger,
            optimizer=optimizer,
            discount=config.discount,
            entropy_cost=config.entropy_cost,
            baseline_cost=config.baseline_cost,
            max_abs_reward=config.max_abs_reward,
        )

        # Actor: fetches the 'policy' variables from the learner and runs the
        # forward function jitted for the CPU backend.
        variable_client = variable_utils.VariableClient(self._learner,
                                                        key='policy')
        cpu_forward_fn = jax.jit(forward_fn, backend='cpu')
        self._actor = acting.IMPALAActor(
            forward_fn=cpu_forward_fn,
            initial_state_init_fn=initial_state_init_fn,
            initial_state_fn=initial_state_fn,
            rng=hk.PRNGSequence(key_actor),
            adder=reverb_queue.adder,
            variable_client=variable_client,
        )
def train_and_evaluate(config, workdir, vocab_filepath):
    """Runs a training and evaluation loop.

    Builds the tokenizer, model, optimizer and datasets from `config`,
    restores state via `_restore_state_from_checkpoint`, then runs pmapped
    train steps from the restored start step up to `config.num_train_steps`,
    periodically saving checkpoints and writing train/eval summaries.

    Args:
      config: Model and training configuration.
      workdir: Working directory for checkpoints and TensorBoard summaries. If
        this contains a checkpoint, training will be resumed from the latest
        checkpoint.
      vocab_filepath: Absolute path to SentencePiece vocab model.

    Raises:
      ValueError: If training or eval batch sizes won't fit number of hosts and
        devices, or config is underspecified.
    """
    # Update config before config validation.
    with config.unlocked():
        # Numeric floating point type to use for model computations.
        config.dtype = jnp.float32

    train_utils.validate_config(config)

    # Only process 0 writes TensorBoard summaries; other processes keep None.
    if jax.process_index() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "train"))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "eval"))
    else:
        train_summary_writer = None
        eval_summary_writer = None

    tokenizer = spm.SentencePieceProcessor()
    tokenizer.Load(vocab_filepath)
    tokenizer.SetEncodeExtraOptions("")
    # Note: [CLS] and [SEP] will be added by the data pipeline, not the tokenizer.

    # Vocabulary size and padding id are taken from the loaded tokenizer.
    with config.unlocked():
        config.vocab_size = tokenizer.GetPieceSize()
        config.pad_id = tokenizer.pad_id()

    # Freeze the config; it is not modified past this point.
    config = ml_collections.FrozenConfigDict(config)

    model = models.PreTrainingModel(config=config)
    rng = random.PRNGKey(config.seed)
    rng, init_rng = random.split(rng)
    params = _init_params(model, init_rng, config)

    # Learning rate: constant base rate with linear warmup and linear decay;
    # the decay spans the steps remaining after warmup.
    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors="constant * linear_warmup * linear_decay",
        base_learning_rate=config.learning_rate,
        warmup_steps=config.num_warmup_steps,
        decay_steps=config.num_train_steps - config.num_warmup_steps,
    )

    # AdamW, optionally preceded by global-norm gradient clipping when
    # `config.clipped_grad_norm` is truthy.
    tx = optax.adamw(learning_rate_fn,
                     b1=0.9,
                     b2=0.999,
                     eps=1e-6,
                     weight_decay=0.01)
    if config.clipped_grad_norm:
        tx = optax.chain(optax.clip_by_global_norm(config.clipped_grad_norm),
                         tx)

    # jit state creation to ensure arrays are created on same device as input
    # (i.e. CPU).
    state_cpu = jax.jit(
        functools.partial(FlaxTrainState.create,
                          apply_fn=model.apply,
                          params=params,
                          tx=tx))()

    # We access model params only via state.params
    del params

    # With more than one expert, parameters whose name matches ".*expert.*"
    # are treated as sharded; otherwise every parameter counts as not sharded.
    if config.num_experts > 1:
        sharded_match_fn = core_utils.match_fn(r".*expert.*")
        not_sharded_match_fn = lambda name: not sharded_match_fn(name)
    else:
        sharded_match_fn = None
        not_sharded_match_fn = lambda name: True

    state, start_step = _restore_state_from_checkpoint(workdir, state_cpu,
                                                       sharded_match_fn,
                                                       not_sharded_match_fn,
                                                       config)
    train_ds, eval_ds = _init_train_and_eval_ds(tokenizer, config)
    train_iter = iter(train_ds)

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    rngs = random.split(rng, jax.local_device_count())

    loss_and_metrics_fn = functools.partial(
        _compute_loss_and_metrics,
        model=model,
        is_experts_model=config.num_experts > 1,
        auxiliary_loss_factor=config.auxiliary_loss_factor,
        router_z_loss_factor=config.router_z_loss_factor)
    train_step = functools.partial(
        train_utils.pmap_train_step,
        loss_and_metrics_fn=loss_and_metrics_fn,
        axis_name="batch",
        sharded_match_fn=sharded_match_fn,
        gradient_accum_steps=config.gradient_accum_steps)
    p_train_step = jax.pmap(train_step, axis_name="batch")

    eval_step = functools.partial(_compute_eval_stats, model=model)
    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    # `seconds` accumulates measured train-step time over the whole run (only
    # when `config.measure_step_speed` is set); `train_stats` collects
    # per-step stats between two metric reports.
    seconds = 0.
    train_stats = []
    logging.info("Starting training loop.")
    logging.info("====================")

    for step in range(start_step, config.num_train_steps):
        # NOTE(review): `jax.profiler.StepTraceContext` and `jax.tree_map`
        # (used below) are believed to be deprecated/removed in newer JAX
        # releases -- confirm against the pinned JAX version.
        with jax.profiler.StepTraceContext("train", step_num=step):
            train_batch = next(train_iter)
            train_batch = common_utils.shard(train_batch)

            tick = time.time()
            state, train_step_stats, rngs = p_train_step(state,
                                                         train_batch,
                                                         rng=rngs)
            if config.measure_step_speed:
                # Block on the new state so the timing covers the actual
                # computation rather than only the asynchronous dispatch.
                jax.tree_map(lambda opt: opt.block_until_ready(), state)
                tock = time.time()
                seconds += tock - tick

            train_stats.append(train_step_stats)

        if (step > 0 and config.save_checkpoints_steps
                and step % config.save_checkpoints_steps == 0):
            # We allow all hosts to potentially save checkpoints because some model
            # parameters are sharded across devices. Parameters replicated across
            # devices (i.e. not sharded) will only be checkpointed by host 0.
            unreplicated_state = jax.tree_map(
                np.array,
                core_utils.tree_unreplicate_by_name(state,
                                                    not_sharded_match_fn))
            checkpoints.save_checkpoint(workdir,
                                        unreplicated_state,
                                        sharded_match_fn,
                                        step,
                                        keep=config.checkpoints_to_keep)
            del unreplicated_state  # Only used for checkpointing.

        # Periodic metric handling: everything below runs only at step 0 and
        # at multiples of `config.eval_frequency`.
        if step % config.eval_frequency != 0 and step > 0:
            continue

        logging.info("Gathering training metrics at step: %d", step)
        train_metrics = train_utils.collect_metrics(train_stats)
        train_summary = train_utils.compute_pretraining_metrics(train_metrics)
        train_summary["learning_rate"] = learning_rate_fn(step)
        if config.measure_step_speed:
            # Average speed since `start_step`, not just since the last report.
            train_summary["steps_per_sec"] = (step - start_step + 1) / seconds

        if jax.process_index() == 0:
            assert train_summary_writer
            for key, val in train_summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        # Reset metric accumulation for next training evaluation cycle.
        train_stats = []

        logging.info("Gathering evaluation metrics at step: %d", step)

        # Evaluate on at most `config.max_num_eval_steps` batches of `eval_ds`.
        eval_stats = []
        for _, eval_batch in zip(range(config.max_num_eval_steps), eval_ds):
            eval_batch = common_utils.shard(eval_batch)
            eval_stats.append(p_eval_step(state.params, eval_batch))
        eval_metrics = train_utils.collect_metrics(eval_stats)
        eval_summary = train_utils.compute_pretraining_metrics(
            eval_metrics, record_grad_norm=False)

        if jax.process_index() == 0:
            assert eval_summary_writer
            for key, val in eval_summary.items():
                eval_summary_writer.scalar(key, val, step)
            eval_summary_writer.flush()
Beispiel #19
0
def create_optimizer(config):
    """Builds the optax gradient transformation described by `config`.

    The returned chain applies, in order: optional gradient clipping, the base
    optimizer together with the learning-rate schedule (schedule after Adam,
    schedule before SGD), optional decoupled weight decay, and finally one
    freeze operation per configured regex.

    Args:
        config: Config object supporting `in`, `.get()` and attribute access.
            Reads `optimizer`, `learning_rate` and `num_train_steps`, plus the
            optional clipping, schedule, `momentum` (SGD only),
            `weight_decay` and `freeze_weights_regex` fields.

    Returns:
        The result of `optax.chain` over all assembled transformations.

    Raises:
        ValueError: If the deprecated `gradient_clip` field is present.
        AssertionError: If both `gradient_norm_clip` and
            `gradient_value_clip` are present.
        NotImplementedError: If `config.optimizer` is neither "adam" nor
            "sgd" (compared case-insensitively).
    """
    # Validate the clipping configuration before building anything.
    if "gradient_clip" in config:
        raise ValueError("'gradient_clip' is deprecated, please use "
                         "'gradient_norm_clip'.")
    clip_by_norm = "gradient_norm_clip" in config
    clip_by_value = "gradient_value_clip" in config
    assert not (clip_by_norm and clip_by_value), (
        "Gradient clipping by norm and by value are exclusive.")

    transforms = []
    if clip_by_norm:
        transforms.append(
            optax.clip_by_global_norm(config.gradient_norm_clip))
    if clip_by_value:
        transforms.append(optax.clip(config.gradient_value_clip))

    # Learning-rate schedule, optionally followed by per-parameter
    # multipliers. `scaling_learning_rate_by_regex` is a list of
    # (regex: str, multiplier: float) pairs.
    schedule_fn = utils.get_optax_schedule_fn(
        warmup_ratio=config.get("warmup_ratio", 0.),
        num_train_steps=config.num_train_steps,
        decay=config.get("learning_rate_step_decay", 1.0),
        decay_at_steps=config.get("learning_rate_decay_at_steps", []),
        cosine_decay_schedule=config.get("cosine_decay", False))
    schedule_parts = [optax.scale_by_schedule(schedule_fn)]
    for regex, multiplier in config.get("scaling_learning_rate_by_regex",
                                        []):
        logging.info(
            "Learning rate is scaled by %f for parameters matching '%s'",
            multiplier, regex)
        schedule_parts.append(
            utils.scale_selected_parameters(regex, multiplier))
    lr_schedule = optax.chain(*schedule_parts)

    # The position of the schedule relative to the base optimizer differs
    # between Adam (after) and SGD (before).
    optimizer_name = config.optimizer.lower()
    if optimizer_name == "adam":
        transforms += [optax.adam(config.learning_rate), lr_schedule]
    elif optimizer_name == "sgd":
        transforms += [
            lr_schedule,
            optax.sgd(config.learning_rate, momentum=config.momentum),
        ]
    else:
        raise NotImplementedError("Invalid optimizer: {}".format(
            config.optimizer))

    if "weight_decay" in config and config.weight_decay > 0.:
        transforms.append(
            utils.decoupled_weight_decay(decay=config.weight_decay,
                                         step_size_fn=schedule_fn))

    # A single regex string is accepted as shorthand for a one-element list;
    # a missing or falsy entry means nothing is frozen.
    frozen_patterns = config.get("freeze_weights_regex", []) or []
    if isinstance(frozen_patterns, str):
        frozen_patterns = [frozen_patterns]
    transforms.extend(utils.freeze(pattern) for pattern in frozen_patterns)

    return optax.chain(*transforms)
Beispiel #20
0
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        forward_fn: networks.PolicyValueRNN,
        unroll_fn: networks.PolicyValueRNN,
        initial_state_fn: Callable[[], hk.LSTMState],
        sequence_length: int,
        sequence_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        max_queue_size: int = 100000,
        batch_size: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        seed: int = 0,
        max_abs_reward: float = np.inf,
        max_gradient_norm: float = np.inf,
    ):
        """Wires together the replay queue, the learner and the actor.

        Args:
            environment_spec: Environment spec; supplies the number of
                actions, the queue signature and the learner's observation
                spec.
            forward_fn: Network function that is Haiku-transformed, jitted
                for the CPU backend and handed to the actor.
            unroll_fn: Network function handed to the learner.
            initial_state_fn: Produces the initial recurrent state; used for
                the `core_state` entry of the queue's extra spec and passed
                to both learner and actor.
            sequence_length: Forwarded to the reverb online queue.
            sequence_period: Forwarded to the reverb online queue.
            counter: Forwarded to the learner.
            logger: Forwarded to the learner as given. The agent itself
                keeps `logger`, or a `TerminalLogger('agent')` when it is
                falsy, in `self._logger`.
            discount: Forwarded to the learner.
            max_queue_size: Forwarded to the reverb online queue.
            batch_size: Forwarded to the reverb online queue.
            learning_rate: Learning rate of the Adam optimizer.
            entropy_cost: Forwarded to the learner.
            baseline_cost: Forwarded to the learner.
            seed: Seed of the PRNG key that is split between learner and
                actor.
            max_abs_reward: Forwarded to the learner.
            max_gradient_norm: Threshold of the global-norm gradient
                clipping applied before Adam.
        """
        action_count = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')

        # Replay: a reverb online queue whose items carry, besides the
        # environment data, the recurrent core state and the policy logits.
        initial_core_state = hk.without_apply_rng(
            hk.transform(initial_state_fn)).apply(None)
        queue = replay.make_reverb_online_queue(
            environment_spec=environment_spec,
            extra_spec={
                'core_state': initial_core_state,
                'logits': np.ones(shape=(action_count, ), dtype=np.float32),
            },
            max_queue_size=max_queue_size,
            sequence_length=sequence_length,
            sequence_period=sequence_period,
            batch_size=batch_size,
        )
        self._server = queue.server
        self._can_sample = queue.can_sample

        # Learner: Adam preceded by global-norm gradient clipping.
        clipped_adam = optax.chain(
            optax.clip_by_global_norm(max_gradient_norm),
            optax.adam(learning_rate),
        )
        learner_key, actor_key = jax.random.split(jax.random.PRNGKey(seed))
        self._learner = learning.IMPALALearner(
            obs_spec=environment_spec.observations,
            unroll_fn=unroll_fn,
            initial_state_fn=initial_state_fn,
            iterator=queue.data_iterator,
            random_key=learner_key,
            counter=counter,
            logger=logger,
            optimizer=clipped_adam,
            discount=discount,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_abs_reward=max_abs_reward,
        )

        # Actor: reads the 'policy' variables from the learner and runs the
        # forward pass jitted for the CPU backend.
        policy_client = variable_utils.VariableClient(self._learner,
                                                      key='policy')
        policy_apply = hk.without_apply_rng(hk.transform(forward_fn)).apply
        self._actor = acting.IMPALAActor(
            forward_fn=jax.jit(policy_apply, backend='cpu'),
            initial_state_fn=initial_state_fn,
            rng=hk.PRNGSequence(actor_key),
            adder=queue.adder,
            variable_client=policy_client,
        )
Beispiel #21
0
def main(argv):
  """Trains a Rainbow agent on Atari, evaluating it after every iteration.

  Builds the environment, network, prioritized replay and optimizer from
  `FLAGS`, then alternates a training phase and an evaluation phase. After
  each iteration one row of statistics is logged and handed to the results
  writer.

  Args:
    argv: Command-line arguments; unused and deleted immediately.
  """
  del argv
  logging.info('Rainbow on Atari on %s.',
               jax.lib.xla_bridge.get_backend().platform)
  # One NumPy RandomState is the root of all randomness below: it seeds the
  # JAX key and the environments, and is passed to the replay. The order of
  # the draws therefore matters for reproducibility.
  random_state = np.random.RandomState(FLAGS.seed)
  rng_key = jax.random.PRNGKey(
      random_state.randint(-sys.maxsize - 1, sys.maxsize + 1, dtype=np.int64))

  # Results go to a CSV file only when a path was given.
  if FLAGS.results_csv_path:
    writer = parts.CsvWriter(FLAGS.results_csv_path)
  else:
    writer = parts.NullWriter()

  def environment_builder():
    """Creates Atari environment."""
    # Two draws from `random_state` per call: one for the environment and one
    # for the no-op wrapper.
    env = gym_atari.GymAtari(
        FLAGS.environment_name, seed=random_state.randint(1, 2**32))
    return gym_atari.RandomNoopsEnvironmentWrapper(
        env,
        min_noop_steps=1,
        max_noop_steps=30,
        seed=random_state.randint(1, 2**32),
    )

  env = environment_builder()

  logging.info('Environment: %s', FLAGS.environment_name)
  logging.info('Action spec: %s', env.action_spec())
  logging.info('Observation spec: %s', env.observation_spec())
  num_actions = env.action_spec().num_values
  # `num_atoms` evenly spaced points in [-vmax, vmax], used as the support by
  # both the network and the agent.
  support = jnp.linspace(-FLAGS.vmax, FLAGS.vmax, FLAGS.num_atoms)
  network_fn = networks.rainbow_atari_network(num_actions, support,
                                              FLAGS.noisy_weight_init)
  network = hk.transform(network_fn)

  def preprocessor_builder():
    """Creates a fresh Atari preprocessor configured from `FLAGS`."""
    return processors.atari(
        additional_discount=FLAGS.additional_discount,
        max_abs_reward=FLAGS.max_abs_reward,
        resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
        num_action_repeats=FLAGS.num_action_repeats,
        num_pooled_frames=2,
        zero_discount_on_life_loss=True,
        num_stacked_frames=FLAGS.num_stacked_frames,
        grayscaling=True,
    )

  # Create sample network input from sample preprocessor output.
  sample_processed_timestep = preprocessor_builder()(env.reset())
  sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                          sample_processed_timestep)
  sample_network_input = sample_processed_timestep.observation
  # Fail early if the preprocessed observation does not have the configured
  # (height, width, stacked frames) shape.
  chex.assert_shape(sample_network_input,
                    (FLAGS.environment_height, FLAGS.environment_width,
                     FLAGS.num_stacked_frames))

  # Note the t in the replay is not exactly aligned with the agent t.
  importance_sampling_exponent_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity),
      end_t=(FLAGS.num_iterations *
             int(FLAGS.num_train_frames / FLAGS.num_action_repeats)),
      begin_value=FLAGS.importance_sampling_exponent_begin_value,
      end_value=FLAGS.importance_sampling_exponent_end_value)

  # Optionally compress the two state fields of each transition when it is
  # stored in the replay and uncompress them when it is read back.
  if FLAGS.compress_state:

    def encoder(transition):
      """Compresses the `s_tm1` and `s_t` fields of a transition."""
      return transition._replace(
          s_tm1=replay_lib.compress_array(transition.s_tm1),
          s_t=replay_lib.compress_array(transition.s_t))

    def decoder(transition):
      """Uncompresses the `s_tm1` and `s_t` fields of a transition."""
      return transition._replace(
          s_tm1=replay_lib.uncompress_array(transition.s_tm1),
          s_t=replay_lib.uncompress_array(transition.s_t))
  else:
    encoder = None
    decoder = None

  # Template transition with every field set to None; passed to the replay as
  # its structure argument.
  replay_structure = replay_lib.Transition(
      s_tm1=None,
      a_tm1=None,
      r_t=None,
      discount_t=None,
      s_t=None,
  )

  transition_accumulator = replay_lib.NStepTransitionAccumulator(FLAGS.n_steps)
  replay = replay_lib.PrioritizedTransitionReplay(
      FLAGS.replay_capacity, replay_structure, FLAGS.priority_exponent,
      importance_sampling_exponent_schedule, FLAGS.uniform_sample_probability,
      FLAGS.normalize_weights, random_state, encoder, decoder)

  # Adam, wrapped with global-norm gradient clipping only when the flag is
  # positive.
  optimizer = optax.adam(
      learning_rate=FLAGS.learning_rate, eps=FLAGS.optimizer_epsilon)
  if FLAGS.max_global_grad_norm > 0:
    optimizer = optax.chain(
        optax.clip_by_global_norm(FLAGS.max_global_grad_norm), optimizer)

  # Independent keys for the training and the evaluation agent.
  train_rng_key, eval_rng_key = jax.random.split(rng_key)

  train_agent = agent.Rainbow(
      preprocessor=preprocessor_builder(),
      sample_network_input=sample_network_input,
      network=network,
      support=support,
      optimizer=optimizer,
      transition_accumulator=transition_accumulator,
      replay=replay,
      batch_size=FLAGS.batch_size,
      min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
      learn_period=FLAGS.learn_period,
      target_network_update_period=FLAGS.target_network_update_period,
      rng_key=train_rng_key,
  )
  # The evaluation agent uses an exploration epsilon of 0.
  eval_agent = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=network,
      exploration_epsilon=0,
      rng_key=eval_rng_key,
  )

  # Set up checkpointing.
  checkpoint = parts.NullCheckpoint()

  # Everything that should be part of a checkpoint is attached to its state
  # object before a restore is attempted.
  state = checkpoint.state
  state.iteration = 0
  state.train_agent = train_agent
  state.eval_agent = eval_agent
  state.random_state = random_state
  state.writer = writer
  if checkpoint.can_be_restored():
    checkpoint.restore()

  # The bound is inclusive: starting from 0 this runs num_iterations + 1
  # iterations, the first of which does no training (see below).
  while state.iteration <= FLAGS.num_iterations:
    # New environment for each iteration to allow for determinism if preempted.
    env = environment_builder()

    logging.info('Training iteration %d.', state.iteration)
    train_seq = parts.run_loop(train_agent, env, FLAGS.max_frames_per_episode)
    # Iteration 0 is truncated to zero training frames, so only the
    # evaluation phase below steps the environment.
    num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames
    train_seq_truncated = itertools.islice(train_seq, num_train_frames)
    train_trackers = parts.make_default_trackers(train_agent)
    train_stats = parts.generate_statistics(train_trackers, train_seq_truncated)

    logging.info('Evaluation iteration %d.', state.iteration)
    # Evaluate with the training agent's current online parameters.
    eval_agent.network_params = train_agent.online_params
    eval_seq = parts.run_loop(eval_agent, env, FLAGS.max_frames_per_episode)
    eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
    eval_trackers = parts.make_default_trackers(eval_agent)
    eval_stats = parts.generate_statistics(eval_trackers, eval_seq_truncated)

    # Logging and checkpointing.
    human_normalized_score = atari_data.get_human_normalized_score(
        FLAGS.environment_name, eval_stats['episode_return'])
    # Cap the normalized score at 1.
    capped_human_normalized_score = np.amin([1., human_normalized_score])
    # Each entry is (name, value, printf-style format used for the log line).
    log_output = [
        ('iteration', state.iteration, '%3d'),
        ('frame', state.iteration * FLAGS.num_train_frames, '%5d'),
        ('eval_episode_return', eval_stats['episode_return'], '% 2.2f'),
        ('train_episode_return', train_stats['episode_return'], '% 2.2f'),
        ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),
        ('train_num_episodes', train_stats['num_episodes'], '%3d'),
        ('eval_frame_rate', eval_stats['step_rate'], '%4.0f'),
        ('train_frame_rate', train_stats['step_rate'], '%4.0f'),
        ('train_state_value', train_stats['state_value'], '%.3f'),
        ('importance_sampling_exponent',
         train_agent.importance_sampling_exponent, '%.3f'),
        ('max_seen_priority', train_agent.max_seen_priority, '%.3f'),
        ('normalized_return', human_normalized_score, '%.3f'),
        ('capped_normalized_return', capped_human_normalized_score, '%.3f'),
        ('human_gap', 1. - capped_human_normalized_score, '%.3f'),
    ]
    log_output_str = ', '.join(('%s: ' + f) % (n, v) for n, v, f in log_output)
    logging.info(log_output_str)
    # The writer receives the raw values, keyed by name, in the same order.
    writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
    state.iteration += 1
    checkpoint.save()

  writer.close()
def train_model(train_file_pattern,
                test_file_pattern,
                max_files_to_load=None,
                n_epochs=1000,
                time_index=9,
                learning_rate=1e-4,
                grad_clip=1.0,
                measurement_store_interval=1000,
                checkpoint_path=None):
    """Trains the graph network with JAX, Haiku and optax.

    Every `measurement_store_interval` update steps, the mean training loss
    since the previous report and the mean test correlation are logged.
    Nothing is returned.

    Args:
        train_file_pattern: pattern matching the files with the training
            data.
        test_file_pattern: pattern matching the files with the test data.
        max_files_to_load: the maximum number of train and test files to
            load. If None, all files will be loaded.
        n_epochs: the number of passes through the training dataset (epochs).
        time_index: the time index (0-9) of the target mobilities.
        learning_rate: the step size applied to the Adam-scaled updates.
        grad_clip: the maximum global norm the gradients are clipped to.
        measurement_store_interval: number of update steps between two
            reports of the objective values (loss and correlation).
        checkpoint_path: ignored by this implementation; a warning is logged
            when it is set.
    """
    if checkpoint_path:
        logging.warning('The checkpoint_path argument is ignored.')

    # Fixed seeds for the Python and NumPy generators.
    random.seed(42)
    np.random.seed(42)

    # Both datasets are loaded with the same options.
    loader_options = {
        'time_index': time_index,
        'max_files_to_load': max_files_to_load,
    }
    logging.info('Load training data')
    train_set = load_data(train_file_pattern, **loader_options)
    logging.info('Load test data')
    test_set = load_data(test_file_pattern, **loader_options)
    logging.info('Finished loading data')

    # Parameters are initialised from the graph of the first training example.
    network = hk.without_apply_rng(hk.transform(network_definition))
    params = network.init(jax.random.PRNGKey(42), train_set[0][0])

    # Clip by global norm, rescale with Adam statistics, then step against
    # the gradient.
    optimizer = optax.chain(
        optax.clip_by_global_norm(grad_clip),
        optax.scale_by_adam(0.9, 0.999, 1e-8),
        optax.scale(-learning_rate),
    )
    opt_state = optimizer.init(params)

    network_apply = jax.jit(network.apply)

    @jax.jit
    def loss_fn(params, graph, targets, mask):
        """Returns the squared error summed over the mask, divided by its sum."""
        masked_output = network_apply(params, graph) * mask
        masked_squared_error = (masked_output - targets)**2 * mask
        return jnp.sum(masked_squared_error) / jnp.sum(mask)

    @jax.jit
    def update(params, opt_state, graph, targets, mask):
        """Applies one optimiser step; returns new params, state and loss."""
        loss, grads = jax.value_and_grad(loss_fn)(params, graph, targets,
                                                  mask)
        updates, opt_state = optimizer.update(grads, opt_state)
        return optax.apply_updates(params, updates), opt_state, loss

    def masked_test_correlation(graph, targets, mask):
        """Correlates predictions and targets over entries where mask == 1."""
        predictions = network_apply(params, graph)
        selected = mask == 1
        return np.corrcoef(predictions[selected], targets[selected])[0, 1]

    steps_done = 0
    recent_losses = []
    logging.info('Start training')
    for epoch in range(n_epochs):
        logging.info('Start epoch %r', epoch)
        random.shuffle(train_set)
        for graph, targets, mask in train_set:
            rotated_graph = apply_random_rotation(graph)
            params, opt_state, loss = update(params, opt_state,
                                             rotated_graph, targets, mask)
            recent_losses.append(loss)
            steps_done += 1

            if steps_done % measurement_store_interval != 0:
                continue

            logging.info('Start evaluation run')
            correlations = [
                masked_test_correlation(test_graph, test_targets, test_mask)
                for test_graph, test_targets, test_mask in test_set
            ]
            logging.info('Train loss %r', np.mean(recent_losses))
            logging.info('Test correlation %r', np.mean(correlations))
            # Start a fresh loss window for the next report.
            recent_losses = []
Beispiel #23
0
  (mel1_hat, mel2_hat), new_aux = (net if is_training else val_net).apply(params, aux, rng, inputs)
  loss1 = (jnp.square(mel1_hat - mels) + jnp.square(mel2_hat - mels)) / 2
  loss2 = (jnp.abs(mel1_hat - mels) + jnp.abs(mel2_hat - mels)) / 2
  loss = jnp.mean((loss1 + loss2)/2, axis=-1)
  mask = (jnp.arange(0, L)[None, :] - 10) < (inputs.wav_lengths // (FLAGS.n_fft // 4))[:, None]
  loss = jnp.sum(loss * mask) / jnp.sum(mask)
  return (loss, new_aux) if is_training else (loss, new_aux, mel2_hat, mels)


# Training and validation variants of `loss_fn`; only the validation variant
# is jitted here.
train_loss_fn = partial(loss_fn, is_training=True)
val_loss_fn = jax.jit(partial(loss_fn, is_training=False))

# In training mode `loss_fn` returns `(loss, new_aux)`, hence `has_aux=True`:
# calling `loss_vag` yields `((loss, new_aux), grads)`.
loss_vag = jax.value_and_grad(train_loss_fn, has_aux=True)

# Adam, preceded by clipping the gradients to a global norm of 1.0.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(FLAGS.learning_rate)
)


@jax.jit
def update(params, aux, rng, optim_state, inputs):
  """Runs one jitted optimisation step.

  Args:
    params: Current model parameters.
    aux: Current auxiliary state passed to, and returned by, the loss.
    rng: PRNG key; split into a key for this step and a key that is returned.
    optim_state: Current optimizer state.
    inputs: Batch passed through to the loss function.

  Returns:
    A pair `(loss, (new_params, new_aux, new_rng, new_optim_state))`.
  """
  rng, new_rng = jax.random.split(rng)
  (loss, new_aux), grads = loss_vag(params, aux, rng, inputs)
  updates, new_optim_state = optimizer.update(grads, optim_state, params)
  # `optax.apply_updates` expects the parameters first and the updates
  # second; the result follows the parameters passed as first argument.
  new_params = optax.apply_updates(params, updates)
  return loss, (new_params, new_aux, new_rng, new_optim_state)


def initial_state(batch):
  rng = jax.random.PRNGKey(42)
Beispiel #24
0
# Jitted apply function of the Haiku-transformed duration model, constructed
# with `is_training=False`. Because `transform_with_state` is used, it is
# called as `forward_fn(params, state, rng, x)`.
forward_fn = jax.jit(
    hk.transform_with_state(lambda x: DurationModel(is_training=False)
                            (x)).apply)


def predict_duration(params, aux, rng, x: DurationInput):
    """Runs the inference-mode duration model on `x`.

    Args:
        params: Model parameters.
        aux: Model state passed to `forward_fn`.
        rng: PRNG key passed to `forward_fn`.
        x: Model input; its `durations` field is returned unchanged.

    Returns:
        A pair `(model_output, x.durations)`. The second value returned by
        `forward_fn` is discarded.
    """
    model_output, _discarded = forward_fn(params, aux, rng, x)
    return model_output, x.durations


# Jitted validation loss with `is_training` fixed to False.
val_loss_fn = jax.jit(partial(loss_fn, is_training=False))

# `has_aux=True`: `loss_fn` must return `(loss, aux)`; only `loss` is
# differentiated.
# NOTE(review): unlike `val_loss_fn`, no `is_training` is bound here —
# confirm that `loss_fn` defaults to training mode.
loss_vag = jax.value_and_grad(loss_fn, has_aux=True)

# Adam, preceded by clipping the gradients to a global norm of
# `FLAGS.max_grad_norm`.
optimizer = optax.chain(optax.clip_by_global_norm(FLAGS.max_grad_norm),
                        optax.adam(FLAGS.learning_rate))


@jax.jit
def update(params, aux, rng, optim_state, inputs: DurationInput):
    """Runs one jitted optimisation step.

    Args:
        params: Current model parameters.
        aux: Current auxiliary state passed to, and returned by, the loss.
        rng: PRNG key; split into a key for this step and a key that is
            returned.
        optim_state: Current optimizer state.
        inputs: Batch passed through to the loss function.

    Returns:
        A pair `(loss, (new_params, new_aux, new_rng, new_optim_state))`.
    """
    step_rng, next_rng = jax.random.split(rng)
    (loss, next_aux), grads = loss_vag(params, aux, step_rng, inputs)
    param_updates, next_optim_state = optimizer.update(
        grads, optim_state, params)
    next_params = optax.apply_updates(params, param_updates)
    return loss, (next_params, next_aux, next_rng, next_optim_state)


def initial_state(batch):
    rng = jax.random.PRNGKey(42)
    params, aux = hk.transform_with_state(lambda x: DurationModel(True)
Beispiel #25
0
def main(argv):
    """
    Train pick-up and drop-off Rainbow agents on ODySSEUS.

    Prepares the experiment directory (copying the scenario configuration
    on first use and refusing to continue if it changed afterwards),
    optionally restores a checkpoint, builds one training and one
    evaluation agent for each of the pick-up (P) and drop-off (D) roles,
    and then alternates training and evaluation phases. One row of
    statistics is logged and written per iteration.

    Args:
        argv: Command-line arguments; unused and deleted immediately.

    Raises:
        IOError: If the experiment directory already exists and its stored
            configuration file differs from the one selected by
            `FLAGS.conf_filename`.
    """
    del argv # Unused arguments

    # Metadata configuration
    parent_dir = pathlib.Path(__file__).parent.absolute()

    sim_input_conf_dir = parent_dir / 'configs' / DEFAULT_sim_scenario_name

    # Load configuration
    sim_conf = importlib.import_module('esbdqn.configs.{}.{}'
                                       .format(DEFAULT_sim_scenario_name,
                                               FLAGS.conf_filename))

    # Extract a single conf pair
    sim_general_conf  = EFFCS_SimConfGrid(sim_conf.General)       \
                                          .conf_list[0]
    sim_scenario_conf = EFFCS_SimConfGrid(sim_conf.Multiple_runs) \
                                          .conf_list[0]

    experiment_dir = parent_dir                     \
                        / 'experiments'             \
                        / DEFAULT_sim_scenario_name \
                        / FLAGS.exp_name            \
                        / sim_general_conf['city']

    if pathlib.Path.exists(experiment_dir):
        # Ensure configuration has not changed
        if not filecmp.cmp(str(sim_input_conf_dir
                               / FLAGS.conf_filename)   + '.py',
                           str(experiment_dir
                               / DEFAULT_conf_filename) + ".py",
                           shallow=False):
            raise IOError('Configuration changed at: {}'
                          .format(str(experiment_dir)))
    else:
        # Creating the directory first makes sure all parents exist; the
        # leaf is then removed again because the copy below creates it.
        pathlib.Path.mkdir(experiment_dir, parents=True,
                           exist_ok=True)

        # Copy configuration files
        shutil.rmtree(experiment_dir)
        shutil.copytree(sim_input_conf_dir, experiment_dir)

        # Rename to the default name
        conf_filepath = experiment_dir / (FLAGS.conf_filename + ".py")
        conf_filepath.rename(experiment_dir
                             / (DEFAULT_conf_filename + ".py"))

        # Delete all other potential conf files
        for filename in experiment_dir.glob(
                DEFAULT_conf_filename + "_*.py"):
            filename.unlink()

    # Create results files
    results_dir = experiment_dir / 'results'

    pathlib.Path.mkdir(results_dir, parents=True,
                       exist_ok=True)

    results_filepath = results_dir / DEFAULT_resu_filename

    logging.info('Rainbow agents on ODySSEUS running on %s.',
                 jax.lib.xla_bridge.get_backend().platform.upper())

    # Checkpoints are pickled under `<experiment_dir>/models` only when
    # requested; otherwise a null checkpoint is used.
    if FLAGS.checkpoint:
        checkpoint = PickleCheckpoint(
            experiment_dir / 'models',
            'ODySSEUS-' + sim_general_conf['city'])
    else:
        checkpoint = parts.NullCheckpoint()

    checkpoint_restored = False

    if FLAGS.checkpoint:
        if checkpoint.can_be_restored():
            logging.info('Restoring checkpoint...')

            checkpoint.restore()
            checkpoint_restored = True

    # Generate RNG key
    rng_state = np.random.RandomState(FLAGS.seed)

    # After a restore, continue from the saved generator state instead of
    # the freshly seeded one.
    if checkpoint_restored:
        rng_state.set_state(checkpoint.state
                                      .rng_state)

    rng_key   = jax.random.PRNGKey(
        rng_state.randint(-sys.maxsize - 1,
                          sys.maxsize + 1,
                          dtype=np.int64))

    # Generate results file writer
    if sim_general_conf['save_history']:
        writer = parts.CsvWriter(str(results_filepath))

        if checkpoint_restored:
            writer.set_state(checkpoint.state
                                       .writer)
    else:
        writer = parts.NullWriter()

    def environment_builder() -> ConstrainedEnvironment:
        """
        Create the ODySSEUS environment.
        """
        return EscooterSimulator(
                        (sim_general_conf,
                         sim_scenario_conf),
                    FLAGS.n_lives)

    def preprocessor_builder():
        """
        Create the ODySSEUS input preprocessor.
        """
        return processor(
            max_abs_reward=FLAGS.max_abs_reward,
            zero_discount_on_life_loss=True
        )

    env = environment_builder()

    logging.info('Environment: %s', FLAGS.exp_name)
    logging.info('Action spec: %s', env.action_spec())
    logging.info('Observation spec: %s', env.observation_spec())

    # Take [0] as both Rainbow have
    # the same number of actions
    num_actions = env.action_spec()[0].num_values
    # `num_atoms` evenly spaced points in [-vmax, vmax], used as the
    # support by the network and the agents.
    support = jnp.linspace(-FLAGS.vmax, FLAGS.vmax,
                           FLAGS.num_atoms)

    network = hk.transform(rainbow_odysseus_network(
                           num_actions, support,
                           FLAGS.noisy_weight_init))

    # Create sample network input from reset.
    sample_processed_timestep = preprocessor_builder()(env.reset())
    sample_processed_timestep = t.cast(dm_env.TimeStep,
                                       sample_processed_timestep)

    sample_processed_network_input = sample_processed_timestep.observation

    # Note the t in the replay is not exactly
    # aligned with the Rainbow agents t.
    importance_sampling_exponent_schedule = parts.LinearSchedule(
        begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity),
        end_t=(FLAGS.num_iterations * FLAGS.num_train_frames),
        begin_value=FLAGS.importance_sampling_exponent_begin_value,
        end_value=FLAGS.importance_sampling_exponent_end_value)

    # Optionally compress the two state fields of each transition when it
    # is stored in the replay and uncompress them when it is read back.
    if FLAGS.compress_state:
        def encoder(transition):
            """Compress the `s_tm1` and `s_t` fields of a transition."""
            return transition._replace(
                s_tm1=replay.compress_array(transition.s_tm1),
                s_t=replay.compress_array(transition.s_t))

        def decoder(transition):
            """Uncompress the `s_tm1` and `s_t` fields of a transition."""
            return transition._replace(
                s_tm1=replay.uncompress_array(transition.s_tm1),
                s_t=replay.uncompress_array(transition.s_t))
    else:
        encoder = None
        decoder = None

    # Template transition with every field set to None; passed to the
    # replay as its structure argument.
    replay_struct = replay.Transition(
        s_tm1=None,
        a_tm1=None,
        r_t=None,
        discount_t=None,
        s_t=None,
    )

    transition_accumulator = replay.NStepTransitionAccumulator(FLAGS.n_steps)

    transition_replay = replay.PrioritizedTransitionReplay(
        FLAGS.replay_capacity, replay_struct,
        FLAGS.priority_exponent,
        importance_sampling_exponent_schedule,
        FLAGS.uniform_sample_probability,
        FLAGS.normalize_weights,
        rng_state, encoder, decoder)

    # Adam, wrapped with global-norm gradient clipping only when the flag
    # is positive.
    optimizer = optax.adam(
        learning_rate=FLAGS.learning_rate,
        eps=FLAGS.optimizer_epsilon)

    if FLAGS.max_global_grad_norm > 0:
        optimizer = optax.chain(
            optax.clip_by_global_norm(
                FLAGS.max_global_grad_norm),
            optimizer)

    train_rng_key, eval_rng_key = jax.random.split(rng_key)

    # Create pick-up/drop-off agents. Each agent receives deep copies of
    # the shared building blocks so that the two agents do not share a
    # replay, accumulator or optimizer.
    # NOTE(review): both training agents receive the same `train_rng_key`
    # (and both evaluation agents the same `eval_rng_key`) — confirm that
    # identical random streams for P and D are intended.
    P_train_agent = agent.Rainbow(
        preprocessor=preprocessor_builder(),
        sample_network_input=copy.deepcopy(sample_processed_network_input),
        network=copy.deepcopy(network),
        support=copy.deepcopy(support),
        optimizer=copy.deepcopy(optimizer),
        transition_accumulator=copy.deepcopy(transition_accumulator),
        replay=copy.deepcopy(transition_replay),
        batch_size=FLAGS.batch_size,
        min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
        learn_period=FLAGS.learn_period,
        target_network_update_period=FLAGS.target_network_update_period,
        rng_key=train_rng_key,
    )

    D_train_agent = agent.Rainbow(
        preprocessor=preprocessor_builder(),
        sample_network_input=copy.deepcopy(sample_processed_network_input),
        network=copy.deepcopy(network),
        support=copy.deepcopy(support),
        optimizer=copy.deepcopy(optimizer),
        transition_accumulator=copy.deepcopy(transition_accumulator),
        replay=copy.deepcopy(transition_replay),
        batch_size=FLAGS.batch_size,
        min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
        learn_period=FLAGS.learn_period,
        target_network_update_period=FLAGS.target_network_update_period,
        rng_key=train_rng_key,
    )

    # The evaluation agents use an exploration epsilon of 0.
    P_eval_agent = parts.EpsilonGreedyActor(
        preprocessor=preprocessor_builder(),
        network=copy.deepcopy(network),
        exploration_epsilon=0,
        rng_key=eval_rng_key,
    )

    D_eval_agent = parts.EpsilonGreedyActor(
        preprocessor=preprocessor_builder(),
        network=copy.deepcopy(network),
        exploration_epsilon=0,
        rng_key=eval_rng_key,
    )

    # Load the saved agent states into the freshly built agents.
    if checkpoint_restored:
        P_train_agent.set_state(checkpoint.state.P_agent['train'])
        D_train_agent.set_state(checkpoint.state.D_agent['train'])

        P_eval_agent.set_state(checkpoint.state.P_agent['eval'])
        D_eval_agent.set_state(checkpoint.state.D_agent['eval'])

    state = checkpoint.state

    # A restored run keeps its saved iteration counter.
    if not checkpoint_restored:
        state.iteration = 0

    # Attach the live objects to the checkpoint state so that subsequent
    # saves capture them.
    state.P_agent = {}
    state.D_agent = {}

    state.rng_state = rng_state
    state.writer = writer

    state.P_agent['train'] = P_train_agent
    state.D_agent['train'] = D_train_agent

    state.P_agent['eval'] = P_eval_agent
    state.D_agent['eval'] = D_eval_agent

    while state.iteration < FLAGS.num_iterations:
        # Create a new environment at each new iteration
        # to allow for determinism if preempted.
        env = environment_builder()

        # Leave some spacing
        print('\n')

        logging.info('Training iteration: %d', state.iteration)

        train_trackers = make_odysseus_trackers(FLAGS.max_abs_reward)
        eval_trackers  = make_odysseus_trackers(FLAGS.max_abs_reward)

        train_seq = run_loop(P_train_agent, D_train_agent,
                             env, FLAGS.max_steps_per_episode)

        # Iteration 0 is truncated to zero training frames, so only the
        # evaluation phase below steps the environment.
        num_train_frames = 0        \
            if state.iteration == 0 \
            else FLAGS.num_train_frames

        train_seq_truncated = it.islice(train_seq, num_train_frames)

        train_stats = generate_statistics(train_trackers,
                                          train_seq_truncated)

        logging.info('Evaluation iteration: %d', state.iteration)

        # Synchronize network parameters: each evaluation agent mirrors
        # its own training agent (the drop-off evaluator previously
        # received the pick-up agent's parameters).
        P_eval_agent.network_params = P_train_agent.online_params
        D_eval_agent.network_params = D_train_agent.online_params

        eval_seq = run_loop(P_eval_agent, D_eval_agent,
                            env, FLAGS.max_steps_per_episode)

        eval_seq_truncated = it.islice(eval_seq, FLAGS.num_eval_frames)

        eval_stats = generate_statistics(eval_trackers,
                                         eval_seq_truncated)

        # Logging and checkpointing. Each entry is
        # (name, value, printf-style format used for the log line).
        L = [
            # Simulation metadata
            ('iteration', state.iteration, '%3d'),

            # ODySSEUS metadata
            ('n_charging_workers', sim_scenario_conf['n_workers'], '%3d'),
            ('n_relocation_workers', sim_scenario_conf['n_relocation_workers'], '%3d'),
            ('n_vehicles', sim_scenario_conf['n_vehicles'], '%3d'),
            ('pct_incentive_willingness', sim_scenario_conf['incentive_willingness'], '%2.2f'),
            ('zone_side_m', sim_general_conf['bin_side_length'], '%3d'),

            # Validation agents
            ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),

            ('eval_P_episode_return', eval_stats['episode_return'][0], '%2.2f'),
            ('eval_D_episode_return', eval_stats['episode_return'][1], '%2.2f'),

            ('eval_min_n_accepted_incentives',
             np.min(eval_stats['episodes_n_accepted_incentives']), '%2.2f'),
            ('eval_avg_n_accepted_incentives',
             np.mean(eval_stats['episodes_n_accepted_incentives']), '%2.2f'),
            ('eval_max_n_accepted_incentives',
             np.max(eval_stats['episodes_n_accepted_incentives']), '%2.2f'),

            ('eval_min_n_lives',
             np.min(eval_stats['episodes_n_lives']), '%2.2f'),
            ('eval_avg_n_lives',
             np.mean(eval_stats['episodes_n_lives']), '%2.2f'),
            ('eval_max_n_lives',
             np.max(eval_stats['episodes_n_lives']), '%2.2f'),

            ('eval_min_pct_satisfied_demand',
             np.min(eval_stats['pct_satisfied_demands']), '%2.2f'),
            ('eval_avg_pct_satisfied_demand',
             np.mean(eval_stats['pct_satisfied_demands']), '%2.2f'),
            ('eval_max_pct_satisfied_demand',
             np.max(eval_stats['pct_satisfied_demands']), '%2.2f'),

            # Training agents
            ('train_num_episodes', train_stats['num_episodes'], '%3d'),

            ('train_P_episode_return', train_stats['episode_return'][0], '%2.2f'),
            ('train_D_episode_return', train_stats['episode_return'][1], '%2.2f'),

            ('train_min_n_accepted_incentives',
             np.min(train_stats['episodes_n_accepted_incentives']), '%2.2f'),
            ('train_avg_n_accepted_incentives',
             np.mean(train_stats['episodes_n_accepted_incentives']), '%2.2f'),
            ('train_max_n_accepted_incentives',
             np.max(train_stats['episodes_n_accepted_incentives']), '%2.2f'),

            ('train_min_n_lives',
             np.min(train_stats['episodes_n_lives']), '%2.2f'),
            ('train_avg_n_lives',
             np.mean(train_stats['episodes_n_lives']), '%2.2f'),
            # NOTE(review): 'train_mac_n_lives' looks like a typo for
            # 'train_max_n_lives'; the key is kept because it names a
            # column of the results file.
            ('train_mac_n_lives',
             np.max(train_stats['episodes_n_lives']), '%2.2f'),

            ('train_min_pct_satisfied_demand',
             np.min(train_stats['pct_satisfied_demands']), '%2.2f'),
            ('train_avg_pct_satisfied_demand',
             np.mean(train_stats['pct_satisfied_demands']), '%2.2f'),
            ('train_max_pct_satisfied_demand',
             np.max(train_stats['pct_satisfied_demands']), '%2.2f'),

            ('P_importance_sampling_exponent',
             P_train_agent.importance_sampling_exponent, '%.3f'),
            ('D_importance_sampling_exponent',
             D_train_agent.importance_sampling_exponent, '%.3f'),

            ('P_max_seen_priority', P_train_agent.max_seen_priority, '%.3f'),
            ('D_max_seen_priority', D_train_agent.max_seen_priority, '%.3f'),
        ]

        L_str = '\n'.join(('%s: ' + f) % (n, v) for n, v, f in L)

        logging.info(L_str)

        # Extra spacing after the last iteration's report.
        if state.iteration == \
                FLAGS.num_iterations - 1:
            print('\n')

        # The writer receives the raw values, keyed by name, in the same
        # order.
        writer.write(collections.OrderedDict(
            (n, v) for n, v, _ in L))

        state.iteration += 1

        # Save only every `checkpoint_period` iterations.
        if state.iteration \
                % FLAGS.checkpoint_period == 0:
            checkpoint.save()

    writer.close()
Beispiel #26
0
# Configure the XLA Python client (no preallocation, "platform" allocator)
# and turn on JAX NaN debugging through environment variables. They are set
# before the imports below.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["JAX_DEBUG_NANS"] = "True"

from swarm_jax.swarm_layer import NetworkPrecision

from loader import TextLoader
from swarm_jax.model import SwarmCharTransformer
from swarm_jax.swarm import Swarm

import ray
import optax

ray.init(resources={"tpu": 999})  # pretend we have infinite tpus lol

# Training data is read from the enwik8 file.
# NOTE(review): the meaning of `batchsize`, `sample_size` and `length` is
# defined by `TextLoader` — confirm against the loader.
train_dataset = TextLoader("data/enwik8",
                           batchsize=(1, 16),
                           sample_size=128,
                           length=90000000)

# Adam, preceded by clipping the gradients to a global norm of 0.25.
optimizer = optax.chain(optax.clip_by_global_norm(0.25),
                        optax.adam(2e-4, b1=0.9, b2=0.99, eps=1e-5))

# All three precision settings are given as the string "uint16".
prec = NetworkPrecision(fwd_act="uint16", rev_act="uint16", grad="uint16")

# The model class itself (not an instance) is handed to the swarm.
model = SwarmCharTransformer
swarm = Swarm(model, optimizer, 2**16, train_dataset.get_samples, prec)
# NOTE(review): presumably (number of steps, log directory, checkpoint
# directory) — confirm against `Swarm.run`.
swarm.run(100000, "runs/512_30L", "ckpt/512_30L")

# Shut ray down once the run has returned.
ray.shutdown()
Beispiel #27
0
def train(config: ml_collections.ConfigDict):
  """Run training.

  Builds the model of the registered task, optionally merges pre-trained
  weights into the freshly initialised variables, restores the latest
  checkpoint found in `config.model_dir` and then runs the pmapped training
  loop with periodic evaluation, optional sample dumps, checkpointing and
  weight export.

  Args:
    config: Experiment configuration. Fields read here include `task_name`,
      `seed`, `model_config`, `load_weights`, the learning-rate and optimizer
      settings, `train_data`, `eval_data`, `per_device_batch_size`,
      `num_train_steps`, `eval_every_steps`, `checkpoint_every_steps`,
      `save_checkpoints`, `save_every_steps` and `model_dir`.

  Raises:
    ValueError: If `save_samples_every_steps` is set and is not a multiple of
      `eval_every_steps`.
  """
  # Local import: only needed to derive a reproducible data seed below.
  import zlib  # pylint: disable=g-import-not-at-top

  # Establish host information
  local_device_count = jax.local_device_count()
  host_count = jax.process_count()
  host_id = jax.process_index()

  task = task_registry.get_registered_task(config.task_name)

  rng = jax.random.PRNGKey(config.seed)

  model_config = ml_collections.FrozenConfigDict(config.model_config)
  model = task.build_model(model_config)

  # Initialization needs to be pmapped because models use collective ops.
  # Create dummy input with a leading axis of size `local_device_count`.
  dummy_input = {
      key: jnp.tile(value, (local_device_count,) + (1,) * value.ndim)
      for key, value in task.dummy_input(config).items()
  }

  rng, init_rng = jax.random.split(rng)
  init_rng = jax.random.split(init_rng, local_device_count)

  logging.info('Initializing model.')
  initial_variables = jax.pmap(
      model.init, 'batch', static_broadcasted_argnums=2)(init_rng, dummy_input,
                                                         True)
  logging.info('Finished initializing model.')
  initial_variables = initial_variables.unfreeze()

  if config.load_weights is not None:
    logging.info('Loading model weights from file')
    loaded_variables = task.load_weights(config)
    unexpected, missing = checkpoint_utils.merge_nested_dicts(
        initial_variables, loaded_variables)
    logging.info('*** Unexpected features: ***')
    for feature_name in unexpected:
      logging.info('\t%s', feature_name)
    logging.info('*** Missing features: ***')
    for feature_name in missing:
      logging.info('\t%s', feature_name)

  # Every variable collection except 'params' is passed to the train / eval
  # steps separately from the train state.
  model_vars = {
      key: value for key, value in initial_variables.items() if key != 'params'
  }

  learning_rate_fn = optim_utils.create_learning_rate_scheduler(
      learning_rate=config.learning_rate,
      warmup=config.warmup,
      warmup_steps=config.get('warmup_steps', None),
      linear_decay=config.linear_decay,
      max_steps=config.num_train_steps,
      decay_minimum_factor=config.get('decay_minimum_factor', None),
  )

  if config.weight_decay_exclude is not None:
    decay_mask = optim_utils.create_dict_mask(initial_variables['params'],
                                              config.weight_decay_exclude)
  else:
    decay_mask = None
  tx = optax.adamw(
      learning_rate=learning_rate_fn,
      weight_decay=config.weight_decay,
      b1=0.9,
      b2=0.999,
      eps=1e-6,
      mask=decay_mask)
  if config.grad_clip is not None:
    # NOTE(review): the clip is chained *after* AdamW, so it bounds the norm
    # of the final update rather than of the raw gradient. Kept as is because
    # reordering would change training dynamics; confirm this is intended.
    tx = optax.chain(tx, optax.clip_by_global_norm(config.grad_clip))

  ignore_k_nans = config.get('ignore_k_nans')
  if ignore_k_nans is not None:
    tx = optax.apply_if_finite(tx, max_consecutive_errors=ignore_k_nans)

  loss_fn = task.make_loss_fn(config)
  train_state = ts.TrainState.create(
      apply_fn=loss_fn,
      params=jax_utils.unreplicate(initial_variables['params']),
      tx=tx,
  )

  # We access model params only from train state.
  del initial_variables

  # Restore unreplicated train state from last checkpoint
  train_state = checkpoints.restore_checkpoint(config.model_dir, train_state)
  # Grab last step.
  start_step = int(train_state.step)

  writer = metric_writers.create_default_writer(
      config.model_dir, just_logging=jax.process_index() > 0)
  if start_step == 0:
    writer.write_hparams(config.to_dict())

  dropout_rngs = jax.random.split(rng, local_device_count)

  del rng

  # Load datasets
  logging.info('Loading dataset.')

  # Make sure we don't re-use same data if we load weights or checkpoint
  seed = config.seed + start_step
  if config.load_weights:
    # The builtin `hash()` of a str is salted per process, which would give
    # a different seed on every run and on every host. CRC32 is stable.
    seed += zlib.crc32(str(config.load_weights).encode('utf-8'))

  name_to_features = task.get_name_to_features(config)
  preprocess_fn = task.make_preprocess_fn(config)
  collater_fn = task.make_collater_fn(config)

  train_data = data_utils.load_multi_dataset(
      datasets_config=config.train_data,
      name_to_features=name_to_features,
      preprocess_fn=preprocess_fn,
      collater_fn=collater_fn,
      is_training=True,
      per_device_batch_size=config.per_device_batch_size,
      local_device_count=local_device_count,
      host_count=host_count,
      host_id=host_id,
      # Use the derived seed (not `config.seed`) so that a resumed or
      # warm-started run does not replay the same training data.
      seed=seed,
  )
  train_iter = iter(train_data)

  pad_eval = config.get('pad_eval', False)
  if pad_eval:
    logging.info('Eval data is padded such that none of samples are dropped.')
  else:
    logging.warning(
        'Eval data is NOT padded -- some samples might be dropped.')

  eval_data = data_utils.load_multi_dataset(
      datasets_config=config.eval_data,
      name_to_features=name_to_features,
      preprocess_fn=preprocess_fn,
      collater_fn=collater_fn,
      is_training=False,
      per_device_batch_size=config.per_device_batch_size,
      local_device_count=local_device_count,
      host_count=host_count,
      host_id=host_id,
      seed=config.seed,
      pad_eval=pad_eval,
  )
  eval_data = list(eval_data)
  logging.info('Loaded %d samples for evaluation.', len(eval_data))

  # Setup postprocessing_fn for saving samples occasionally.
  if config.get('save_samples_every_steps') is not None:
    if config.get('save_samples_every_steps') % config.eval_every_steps != 0:
      raise ValueError(
          '`eval_every_steps` must divide `save_samples_every_steps`.')
    postprocessing_fn = task.make_output_postprocess_fn(config)

  # Training loop
  logging.info('Starting training.')

  # Replicate train state.
  train_state = jax_utils.replicate(train_state)

  # compile multidevice versions of train/eval/predict step
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model_config=model_config,
      ),
      axis_name='batch',
      donate_argnums=(0,),
  )  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          model_config=model_config,
      ),
      axis_name='batch')

  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)

  # Progress reporting and profiling hooks run on process 0 only.
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=config.model_dir, num_profile_steps=5)
    ]
  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and perform a training step.
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        batch = jax.tree_util.tree_map(jnp.asarray, train_iter.get_next())
        train_state, metrics = p_train_step(
            train_state,
            model_vars,
            batch,
            dropout_rngs,
        )
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step % config.eval_every_steps == 0 or is_last_step:
        # NOTE(review): the 'eval' and 'save_samples' timers are nested
        # inside this one, so 'training_metrics' time includes both.
        with report_progress.timed('training_metrics'):
          logging.info('Gathering training metrics.')
          train_metrics = common_utils.get_metrics(train_metrics)
          metrics_sums = jax.tree_util.tree_map(jnp.sum, train_metrics)
          summary = metric_utils.process_metrics(metrics_sums, prefix='train')
          summary['learning_rate'] = learning_rate_fn(step)

          writer.write_scalars(step, summary)
          train_metrics = []

          with report_progress.timed('eval'):
            eval_results, eval_auxiliary = evaluate(
                eval_step_fn=p_eval_step,
                train_state=train_state,
                model_vars=model_vars,
                eval_data=eval_data,
            )
            writer.write_scalars(step, eval_results)

            # NOTE(review): samples are written at every evaluation;
            # `save_samples_every_steps` is validated above but never used
            # as a period here. Confirm the intended frequency.
            if config.get('save_samples_every_steps') is not None:
              with report_progress.timed('save_samples'):
                # Post-process only the first eval batch unless
                # `save_first_batch_only` is explicitly disabled. Slicing
                # (rather than indexing) keeps an empty eval set harmless.
                postprocessing_input = eval_auxiliary
                if config.get('save_first_batch_only', True):
                  postprocessing_input = eval_auxiliary[:1]
                eval_processed = [
                    postprocessing_fn(batch, auxiliary_output)
                    for batch, auxiliary_output in postprocessing_input
                ]
                data_utils.save_samples_to_json(eval_processed, config, step)

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = (
          step % config.checkpoint_every_steps == 0 or is_last_step)
      if (config.save_checkpoints and save_checkpoint and
          jax.process_index() == 0):
        with report_progress.timed('checkpoint'):
          logging.info('Saving checkpoint at step %s', step)
          checkpoints.save_checkpoint(
              config.model_dir,
              jax_utils.unreplicate(train_state),
              step,
              keep=config.get('keep_checkpoints', 1),
              keep_every_n_steps=config.get('keep_checkpoint_every_steps'),
          )

      # Weight export is separate from checkpointing and never happens at
      # step 0.
      save_model = (
          config.save_every_steps and
          (step % config.save_every_steps == 0 or is_last_step) and step != 0)
      if (save_model and jax.process_index() == 0):
        with report_progress.timed('checkpoint'):
          logging.info('Saving weights at step %s', step)
          save_path = os.path.join(config.model_dir, 'weights',
                                   'step' + str(step))
          # By default, save only encoder weights
          weights = jax_utils.unreplicate(train_state).params['encoder']
          checkpoint_utils.save_weights(save_path, weights)
Beispiel #28
0
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        forward_fn: networks.PolicyValueRNN,
        unroll_fn: networks.PolicyValueRNN,
        initial_state_fn: Callable[[], hk.LSTMState],
        sequence_length: int,
        sequence_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        max_queue_size: int = 100000,
        batch_size: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        seed: int = 0,
        max_abs_reward: float = np.inf,
        max_gradient_norm: float = np.inf,
    ):
        """Wires up an in-process replay queue, a learner and an actor.

        A local Reverb server hosting a single queue table is started; a
        sequence adder writes into it and a batched replay dataset reads
        from it. The learner consumes that dataset with a gradient-clipped
        Adam optimizer, and the actor pulls the 'policy' variables from the
        learner through a variable client.

        Args:
          environment_spec: Source of the observation spec and the number of
            discrete actions.
          forward_fn: Function transformed with Haiku and jitted on CPU for
            the actor.
          unroll_fn: Passed through to the learner.
          initial_state_fn: Produces the initial recurrent state; used for
            the replay signature and passed to learner and actor.
          sequence_length: Sequence length for the adder and the dataset.
          sequence_period: Period handed to the sequence adder.
          counter: Passed through to the learner.
          logger: Passed through to the learner; a terminal logger labelled
            'agent' is stored on the instance when this is falsy.
          discount: Passed through to the learner.
          max_queue_size: Maximum size of the replay queue table.
          batch_size: Dataset batch size, also used by the can-sample check.
          learning_rate: Adam learning rate.
          entropy_cost: Passed through to the learner.
          baseline_cost: Passed through to the learner.
          seed: Seeds the PRNG sequences of both learner and actor.
          max_abs_reward: Passed through to the learner.
          max_gradient_norm: Global-norm threshold for gradient clipping.
        """
        action_count = environment_spec.actions.num_values
        self._logger = logger if logger else loggers.TerminalLogger('agent')

        # Extras stored next to each transition: the recurrent core state
        # and one logit per action.
        initial_core_state = hk.without_apply_rng(
            hk.transform(initial_state_fn, apply_rng=True)).apply(None)
        placeholder_logits = np.ones(shape=(action_count, ), dtype=np.float32)
        extra_spec = {
            'core_state': initial_core_state,
            'logits': placeholder_logits,
        }

        # Single queue table served by a local Reverb server.
        table_signature = adders.SequenceAdder.signature(
            environment_spec, extra_spec)
        replay_queue = reverb.Table.queue(
            name=adders.DEFAULT_PRIORITY_TABLE,
            max_size=max_queue_size,
            signature=table_signature,
        )
        self._server = reverb.Server([replay_queue], port=None)
        self._can_sample = lambda: replay_queue.can_sample(batch_size)
        address = f'localhost:{self._server.port}'

        # Writer side: the component that adds sequences into replay.
        replay_client = reverb.Client(address)
        adder = adders.SequenceAdder(
            client=replay_client,
            period=sequence_period,
            sequence_length=sequence_length,
        )

        # Reader side: the dataset object to learn from.
        # datasets.make_reverb_dataset() is deliberately not used here, to
        # avoid interleaving and prefetching, which do not work well with the
        # can_sample() check on update.
        sequence_dataset = reverb.ReplayDataset.from_table_signature(
            server_address=address,
            table=adders.DEFAULT_PRIORITY_TABLE,
            max_in_flight_samples_per_worker=1,
            sequence_length=sequence_length,
            emit_timesteps=False,
        )
        batched_dataset = sequence_dataset.batch(
            batch_size, drop_remainder=True)

        # Clip by global norm first, then take an Adam step.
        gradient_clip = optax.clip_by_global_norm(max_gradient_norm)
        adam = optax.adam(learning_rate)
        optimizer = optax.chain(gradient_clip, adam)

        self._learner = learning.IMPALALearner(
            obs_spec=environment_spec.observations,
            unroll_fn=unroll_fn,
            initial_state_fn=initial_state_fn,
            iterator=batched_dataset.as_numpy_iterator(),
            rng=hk.PRNGSequence(seed),
            counter=counter,
            logger=logger,
            optimizer=optimizer,
            discount=discount,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_abs_reward=max_abs_reward,
        )

        # The actor fetches the 'policy' variables from the learner and runs
        # the transformed forward function jitted for the CPU backend.
        variable_client = variable_utils.VariableClient(
            self._learner, key='policy')
        transformed_forward = hk.without_apply_rng(
            hk.transform(forward_fn, apply_rng=True))
        cpu_forward = jax.jit(transformed_forward.apply, backend='cpu')
        self._actor = acting.IMPALAActor(
            forward_fn=cpu_forward,
            initial_state_fn=initial_state_fn,
            rng=hk.PRNGSequence(seed),
            adder=adder,
            variable_client=variable_client,
        )
Beispiel #29
0

# Jitted apply function of the stateful Haiku transform of DurationModel,
# constructed with `is_training=False`.
forward_fn = jax.jit(
    hk.transform_with_state(
        lambda x: DurationModel(is_training=False)(x)
    ).apply
)


def predict_duration(params, aux, rng, x: DurationInput):
  """Runs the jitted forward pass and pairs its output with `x.durations`.

  Args:
    params: Model parameters passed to `forward_fn`.
    aux: Haiku state passed to `forward_fn`.
    rng: PRNG key passed to `forward_fn`.
    x: Model input; its `durations` attribute is returned unchanged.

  Returns:
    A `(prediction, x.durations)` tuple. The state returned by `forward_fn`
    is discarded.
  """
  predicted, _unused_state = forward_fn(params, aux, rng, x)
  reference = x.durations
  return predicted, reference


# Validation loss: `loss_fn` with `is_training` pinned to False, jitted.
val_loss_fn = jax.jit(partial(loss_fn, is_training=False))

# Loss value and gradients in one call; `has_aux=True` means `loss_fn`
# returns a `(loss, aux)` pair and only the loss is differentiated.
loss_vag = jax.value_and_grad(loss_fn, has_aux=True)

# Clip gradients by global norm, then apply AdamW.
optimizer = optax.chain(
    optax.clip_by_global_norm(FLAGS.max_grad_norm),
    optax.adamw(
        learning_rate=FLAGS.duration_learning_rate,
        weight_decay=FLAGS.weight_decay,
    ),
)


@jax.jit
def update(params, aux, rng, optim_state, inputs: DurationInput):
  """Performs one optimisation step.

  Args:
    params: Current model parameters.
    aux: Current auxiliary state passed to the loss.
    rng: PRNG key; split into a key for this step and a key to carry on.
    optim_state: Current optimizer state.
    inputs: Batch passed to the loss function.

  Returns:
    A `(loss, (new_params, new_aux, new_rng, new_optim_state))` tuple.
  """
  step_rng, next_rng = jax.random.split(rng)
  (loss, next_aux), grads = loss_vag(params, aux, step_rng, inputs)
  # `params` is passed to the optimizer as well as the gradients.
  param_updates, next_optim_state = optimizer.update(
      grads, optim_state, params)
  next_params = optax.apply_updates(params, param_updates)
  next_state = (next_params, next_aux, next_rng, next_optim_state)
  return loss, next_state


def initial_state(batch):
  # NOTE(review): this snippet appears to be cut off after its first
  # statement. As written, the function creates a PRNG key from the fixed
  # seed 42, never uses `batch`, and implicitly returns None.
  rng = jax.random.PRNGKey(42)
Beispiel #30
0
def train_and_evaluate(config, workdir, vocab_filepath):
    """Runs a training and evaluation loop.

    Args:
      config: Model and training configuration.
      workdir: Working directory for checkpoints and TensorBoard summaries.
        If this contains a checkpoint, training will be resumed from the
        latest checkpoint.
      vocab_filepath: Absolute path to SentencePiece vocab model.

    Raises:
      ValueError: If training or eval batch sizes won't fit number of hosts
        and devices, or config is underspecified.
    """
    # Update config before config validation.
    with config.unlocked():
        # Numeric floating point type to use for model computations.
        config.dtype = jnp.float32

    train_utils.validate_config(config)

    # Global batch sizes are split evenly across processes (floor division).
    per_host_train_batch_size = config.train_batch_size // jax.process_count()
    per_host_eval_batch_size = config.eval_batch_size // jax.process_count()

    # Only process 0 owns TensorBoard writers; the others keep None.
    if jax.process_index() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "train"))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "eval"))
    else:
        train_summary_writer = None
        eval_summary_writer = None

    tokenizer = spm.SentencePieceProcessor()
    tokenizer.Load(vocab_filepath)

    ds_info = tfds.builder(config.dataset_name).info
    num_train_examples = ds_info.splits[tfds.Split.TRAIN].num_examples

    num_train_steps = int(num_train_examples * config.num_train_epochs //
                          config.train_batch_size)
    num_warmup_steps = int(config.warmup_proportion * num_train_steps)
    # Round up evaluation frequency to the next multiple of 10.
    # NOTE(review): this is 0 when `eval_proportion * num_train_steps` is 0,
    # which would make the modulo in the training loop fail.
    eval_frequency = int(
        math.ceil(config.eval_proportion * num_train_steps / 10)) * 10

    # STSB is a regression task. COPA and ReCoRD are treated as scalar/regression
    # tasks during training.
    is_regression_task = (config.dataset_name == "glue/stsb"
                          or config.dataset_name == "super_glue/copa"
                          or config.dataset_name == "super_glue/record")
    if is_regression_task:
        num_classes = 1
    else:
        num_classes = ds_info.features["label"].num_classes

    # Fill in tokenizer-derived fields before the config is frozen.
    with config.unlocked():
        config.vocab_size = tokenizer.GetPieceSize()
        config.pad_id = tokenizer.pad_id()

    config = ml_collections.FrozenConfigDict(config)
    model = models.SequenceClassificationModel(config, num_classes)
    rng = random.PRNGKey(config.seed)
    rng, init_rng = random.split(rng)
    params = _init_params(model, init_rng, config)

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors="constant * linear_warmup * linear_decay",
        base_learning_rate=config.learning_rate,
        warmup_steps=num_warmup_steps,
        decay_steps=num_train_steps - num_warmup_steps,
    )

    tx = optax.adamw(learning_rate_fn,
                     b1=0.9,
                     b2=0.999,
                     eps=1e-6,
                     weight_decay=0.01)
    # Optional gradient clipping is chained in front of AdamW.
    if config.clipped_grad_norm:
        tx = optax.chain(optax.clip_by_global_norm(config.clipped_grad_norm),
                         tx)

    # jit state creation to ensure arrays are created on same device as input
    # (i.e. CPU).
    state_cpu = jax.jit(
        functools.partial(FlaxTrainState.create,
                          apply_fn=model.apply,
                          params=params,
                          tx=tx))()

    # We access model params only via state.params
    del params

    # With more than one expert, parameters whose name matches ".*expert.*"
    # are treated as sharded; otherwise every parameter is "not sharded".
    if config.num_experts > 1:
        sharded_match_fn = core_utils.match_fn(r".*expert.*")
        not_sharded_match_fn = lambda name: not sharded_match_fn(name)
    else:
        sharded_match_fn = None
        not_sharded_match_fn = lambda name: True

    state, start_step = _restore_state_from_checkpoint(workdir, state_cpu,
                                                       sharded_match_fn,
                                                       not_sharded_match_fn,
                                                       config)

    # Regression tasks score with the first output component; classification
    # tasks score with the argmax over the last axis.
    if is_regression_task:
        scoring_fn = lambda y: y[Ellipsis, 0]
    else:
        scoring_fn = lambda y: y.argmax(-1)
    compute_stats = functools.partial(_compute_stats,
                                      model=model,
                                      scoring_fn=scoring_fn)

    classification_inputs = functools.partial(
        input_pipeline.classification_inputs,
        dataset_name=config.dataset_name,
        max_seq_length=config.max_seq_length,
        tokenizer=tokenizer)
    train_ds = classification_inputs(split=tfds.Split.TRAIN,
                                     batch_size=per_host_train_batch_size,
                                     training=True)
    train_iter = iter(train_ds)

    if config.dataset_name == "glue/mnli":
        # MNLI contains two validation and test datasets.
        split_suffixes = ["_matched", "_mismatched"]
    else:
        split_suffixes = [""]

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    rngs = random.split(rng, jax.local_device_count())

    loss_and_metrics_fn = functools.partial(
        _compute_loss_and_metrics,
        model=model,
        is_experts_model=config.num_experts > 1,
        auxiliary_loss_factor=config.auxiliary_loss_factor,
        router_z_loss_factor=config.router_z_loss_factor)
    train_step = functools.partial(
        train_utils.pmap_train_step,
        loss_and_metrics_fn=loss_and_metrics_fn,
        axis_name="batch",
        sharded_match_fn=sharded_match_fn,
        gradient_accum_steps=config.gradient_accum_steps)
    p_train_step = jax.pmap(train_step, axis_name="batch")
    p_eval_step = jax.pmap(compute_stats, axis_name="batch")
    eval_metrics_fn = _create_eval_metrics_fn(config.dataset_name)

    train_stats = []
    logging.info("Starting training loop.")
    logging.info("====================")

    for step in range(start_step, num_train_steps):
        with jax.profiler.StepTraceContext("train", step_num=step):
            train_batch = next(train_iter)
            train_batch = common_utils.shard(train_batch)

            # The train step also returns the PRNG keys for the next step.
            state, train_step_stats, rngs = p_train_step(state,
                                                         train_batch,
                                                         rng=rngs)

            train_stats.append(train_step_stats)

        # Checkpoint every `save_checkpoints_steps` steps (never at step 0)
        # and always on the final step.
        if ((step > 0 and config.save_checkpoints_steps
             and step % config.save_checkpoints_steps == 0)
                or step == num_train_steps - 1):
            # We allow all hosts to potentially save checkpoints because some model
            # parameters are sharded across devices. Parameters replicated across
            # devices (i.e. not sharded) will only be checkpointed by host 0.
            unreplicated_train_state = jax.tree_map(
                np.array,
                core_utils.tree_unreplicate_by_name(state,
                                                    not_sharded_match_fn))
            checkpoints.save_checkpoint(workdir,
                                        unreplicated_train_state,
                                        sharded_match_fn,
                                        step,
                                        keep=config.checkpoints_to_keep)
            del unreplicated_train_state  # Only used for checkpointing.

        # Periodic metric handling: skip ahead unless this is an evaluation
        # step or the final step.
        if step % eval_frequency != 0 and step < num_train_steps - 1:
            continue

        logging.info("Gathering training metrics at step: %d", step)
        train_metrics = train_utils.collect_metrics(train_stats)
        train_summary = train_utils.compute_classification_metrics(
            train_metrics, is_regression_task)
        train_summary["learning_rate"] = learning_rate_fn(step)

        if jax.process_index() == 0:
            assert train_summary_writer
            for key, val in train_summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        # Reset metric accumulation for next training evaluation cycle.
        train_stats = []

        logging.info("Gathering validation metrics at step: %d", step)

        for split_suffix in split_suffixes:
            # The validation dataset is rebuilt for every evaluation round.
            eval_ds = classification_inputs(
                split=tfds.Split.VALIDATION + split_suffix,
                batch_size=per_host_eval_batch_size,
                training=False)

            # Evaluate at most `max_num_eval_steps` batches.
            eval_stats = []
            for _, eval_batch in zip(range(config.max_num_eval_steps),
                                     eval_ds):
                eval_stats.append(
                    _evaluate(p_eval_step, state.params, eval_batch))
            # Concatenate the per-batch stats along axis 0.
            # NOTE(review): `eval_stats[0]` raises IndexError when no eval
            # batch was produced.
            eval_metrics = {}
            for k in eval_stats[
                    0]:  # All batches of output stats are the same size
                eval_metrics[k] = np.concatenate(
                    [stat[k] for stat in eval_stats], axis=0)
            eval_summary = eval_metrics_fn(eval_metrics)

            if jax.process_index() == 0:
                assert eval_summary_writer
                for key, val in eval_summary.items():
                    eval_summary_writer.scalar(f"{key}{split_suffix}", val,
                                               step)
                eval_summary_writer.flush()