def train_step(
            state: train_state.TrainState, batch: Dict[str, Array],
            dropout_rng: PRNGKey
) -> Tuple[train_state.TrainState, Dict[str, Array], PRNGKey]:
        """Runs one optimization step of `state` on `batch`.

        Note: `batch` is mutated in place — its "labels" entry is popped and
        used as the targets; the remaining entries are fed to the model.

        Args:
            state: train state holding `params`, `apply_fn`, `loss_fn` and the
                optimizer.
            batch: model inputs; must contain a "labels" entry.
            dropout_rng: PRNG key consumed for dropout during this step.

        Returns:
            A triple `(new_state, metrics, new_dropout_rng)` where `metrics`
            holds the cross-replica mean loss and the current learning rate.
        """
        # Split so this step consumes a fresh dropout key and the caller gets
        # the next one back.
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        targets = batch.pop("labels")

        def loss_fn(params):
            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]
            return state.loss_fn(logits, targets)

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        # Average gradients across data-parallel replicas before the update.
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": learning_rate_fn(state.step)
            },
            axis_name="batch")
        return new_state, metrics, new_dropout_rng
Example #2
0
def step(x, y, state: TrainState, training: bool):
    """Runs a forward pass (and, when `training`, a gradient update) on one batch.

    Returns `(loss, y_pred, state)`; `state` is updated only when training.
    """

    def loss_fn(params):
        predictions = model.apply({"params": params}, x)
        labels_one_hot = jax.nn.one_hot(y, 10)
        mean_loss = optax.softmax_cross_entropy(predictions, labels_one_hot).mean()
        return mean_loss, predictions

    # Flatten each 28x28 image into a 784-dim vector before the forward pass.
    x = x.reshape(-1, 28 * 28)
    if not training:
        # Evaluation: forward pass only, no parameter update.
        loss, y_pred = loss_fn(state.params)
        return loss, y_pred, state
    (loss, y_pred), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)
    return loss, y_pred, state
Example #3
0
def train_step(state: TrainState, batch):
    """Runs a single data-parallel training step on `batch`.

    Returns `(new_state, metrics)` where `metrics["loss"]` is the
    replica-averaged loss.
    """

    def compute_loss(params: Dict[str, Any]):
        inputs, labels = batch
        logits = state.apply_fn({"params": params}, inputs)
        return loss_fn(logits, labels)

    loss, grads = jax.value_and_grad(compute_loss)(state.params)
    # Average both the gradients and the reported loss across replicas.
    synced_grads = jax.lax.pmean(grads, "batch")
    new_state = state.apply_gradients(grads=synced_grads)
    metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")

    return new_state, metrics
Example #4
0
def train_step(
    state: train_state.TrainState,
    trajectories: Tuple,
    batch_size: int,
    *,
    clip_param: float,
    vf_coeff: float,
    entropy_coeff: float):
  """Compilable train step.

  Runs an entire epoch of training (i.e. the loop over minibatches within
  an epoch is included here for performance reasons).

  Args:
    state: the train state
    trajectories: Tuple of the following five elements forming the experience:
                  states: shape (steps_per_agent*num_agents, 84, 84, 4)
                  actions: shape (steps_per_agent*num_agents, )
                  old_log_probs: shape (steps_per_agent*num_agents, )
                  returns: shape (steps_per_agent*num_agents, )
                  advantages: (steps_per_agent*num_agents, )
    batch_size: the minibatch size, static argument
    clip_param: the PPO clipping parameter used to clamp ratios in loss function
    vf_coeff: weighs value function loss in total loss
    entropy_coeff: weighs entropy bonus in the total loss

  Returns:
    state: the train state after the parameter updates
    loss: loss summed over training steps
  """
  iterations = trajectories[0].shape[0] // batch_size
  # Split the flat experience into (iterations, batch_size, ...) minibatches.
  # NOTE(review): reshape assumes steps_per_agent*num_agents is divisible by
  # batch_size — confirm at the call site.
  trajectories = jax.tree_map(
      lambda x: x.reshape((iterations, batch_size) + x.shape[1:]), trajectories)
  # value_and_grad is loop-invariant: build the gradient function once
  # instead of once per minibatch.
  grad_fn = jax.value_and_grad(loss_fn)
  loss = 0.
  for batch in zip(*trajectories):
    l, grads = grad_fn(state.params, state.apply_fn, batch, clip_param, vf_coeff,
                       entropy_coeff)
    loss += l
    state = state.apply_gradients(grads=grads)
  return state, loss
Example #5
0
def train_step(
    train_state: ts.TrainState,
    model_vars: Dict[str, Any],
    batch: Dict[str, Any],
    dropout_rng: jnp.ndarray,
    model_config: ml_collections.FrozenConfigDict,
) -> Tuple[ts.TrainState, Dict[str, Any]]:
  """Perform a single training step.

  Args:
    train_state: contains model params, loss fn, grad update fn.
    model_vars: model variables that are not optimized.
    batch: input to model.
    dropout_rng: seed for dropout rng in model.
    model_config: contains model hyperparameters.

  Returns:
    Train state with updated parameters and dictionary of metrics.
  """

  # Fold the step count into the key so every step gets fresh dropout noise.
  step_rng = jax.random.fold_in(dropout_rng, train_state.step)

  def compute_loss(model_params):
    loss, metrics, _ = train_state.apply_fn(
        model_config,
        model_params,
        model_vars,
        batch,
        deterministic=False,
        dropout_rng={'dropout': step_rng},
    )
    # Metrics ride along as the auxiliary output of value_and_grad.
    return loss, metrics

  (_, metrics), grads = jax.value_and_grad(
      compute_loss, has_aux=True)(train_state.params)
  # Cross-replica reduction: gradients are averaged, metrics are summed.
  grads = jax.lax.pmean(grads, 'batch')
  metrics = jax.lax.psum(metrics, axis_name='batch')
  metrics = metric_utils.update_metrics_dtype(metrics)
  return train_state.apply_gradients(grads=grads), metrics