def train_step(
    state: train_state.TrainState,
    batch: Dict[str, Array],
    dropout_rng: PRNGKey,
) -> Tuple[train_state.TrainState, Dict[str, Array], PRNGKey]:
    """Runs one optimizer step on `batch` inside a `pmap` over axis "batch".

    Args:
        state: train state holding `params`, `apply_fn` and `loss_fn`.
        batch: model inputs; must contain a "labels" entry used as targets.
        dropout_rng: per-device PRNG key for dropout.

    Returns:
        A `(new_state, metrics, new_dropout_rng)` triple, where `metrics`
        holds the cross-replica mean of "loss" and "learning_rate".
        (The original docstring/annotation claimed a `(new_state, loss)`
        pair, which did not match the actual return value.)
    """
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
    # Copy before popping so the caller's batch dict is not mutated.
    batch = dict(batch)
    targets = batch.pop("labels")

    def loss_fn(params):
        logits = state.apply_fn(
            **batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        return state.loss_fn(logits, targets)

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = grad_fn(state.params)
    # Average gradients across devices before applying the update.
    grad = jax.lax.pmean(grad, "batch")
    new_state = state.apply_gradients(grads=grad)
    metrics = jax.lax.pmean(
        {"loss": loss, "learning_rate": learning_rate_fn(state.step)},
        axis_name="batch")
    return new_state, metrics, new_dropout_rng
def step(x, y, state: TrainState, training: bool):
    """Run one forward (and optionally backward) pass on a batch.

    Flattens `x` to (batch, 784), scores it with `model`, and computes the
    mean softmax cross-entropy against one-hot `y` over 10 classes. When
    `training` is True, gradients are applied to `state`.

    Returns:
        A `(loss, y_pred, state)` triple; `state` is updated only when
        `training` is True.
    """
    # Flatten images before the closure is ever invoked (the original did
    # this after defining the closure; late binding makes both equivalent).
    x = x.reshape(-1, 28 * 28)

    def loss_fn(params):
        predictions = model.apply({"params": params}, x)
        one_hot_targets = jax.nn.one_hot(y, 10)
        mean_loss = optax.softmax_cross_entropy(
            predictions, one_hot_targets).mean()
        return mean_loss, predictions

    # Evaluation path: no gradient computation, state untouched.
    if not training:
        loss, y_pred = loss_fn(state.params)
        return loss, y_pred, state

    (loss, y_pred), grads = jax.value_and_grad(
        loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)
    return loss, y_pred, state
def train_step(state: TrainState, batch):
    """One pmapped optimization step.

    Computes the loss and gradients for `batch`, averages both across the
    "batch" device axis, and applies the gradient update.

    Returns:
        A `(new_state, metrics)` pair where `metrics` contains the
        cross-replica mean "loss".
    """
    inputs, labels = batch

    def objective(params: Dict[str, Any]):
        logits = state.apply_fn({"params": params}, inputs)
        return loss_fn(logits, labels)

    loss, grads = jax.value_and_grad(objective)(state.params)
    # Synchronize gradients and the reported loss across devices.
    grads = jax.lax.pmean(grads, "batch")
    metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
    return state.apply_gradients(grads=grads), metrics
def train_step(
    state: train_state.TrainState,
    trajectories: Tuple,
    batch_size: int,
    *,
    clip_param: float,
    vf_coeff: float,
    entropy_coeff: float):
    """Compilable PPO train step.

    Runs an entire epoch of training (i.e. the loop over minibatches within
    an epoch is included here for performance reasons).

    Args:
        state: the train state.
        trajectories: Tuple of the following five elements forming the
            experience:
            states: shape (steps_per_agent*num_agents, 84, 84, 4)
            actions: shape (steps_per_agent*num_agents,)
                (the original docstring repeated the states shape here,
                which looks like a copy-paste error — confirm upstream)
            old_log_probs: shape (steps_per_agent*num_agents,)
            returns: shape (steps_per_agent*num_agents,)
            advantages: shape (steps_per_agent*num_agents,)
        batch_size: the minibatch size, static argument.
        clip_param: the PPO clipping parameter used to clamp ratios in the
            loss function.
        vf_coeff: weighs value function loss in the total loss.
        entropy_coeff: weighs entropy bonus in the total loss.

    Returns:
        state: new train state after the parameter updates (the original
            docstring said "optimizer", a leftover from a pre-TrainState
            API).
        loss: loss summed over training steps.
    """
    iterations = trajectories[0].shape[0] // batch_size
    # Split the leading axis into (iterations, batch_size) minibatches.
    # `jax.tree_map` is deprecated in favor of `jax.tree_util.tree_map`.
    trajectories = jax.tree_util.tree_map(
        lambda x: x.reshape((iterations, batch_size) + x.shape[1:]),
        trajectories)
    # Hoist the loop-invariant transform out of the minibatch loop.
    grad_fn = jax.value_and_grad(loss_fn)
    loss = 0.
    for batch in zip(*trajectories):
        l, grads = grad_fn(state.params, state.apply_fn, batch, clip_param,
                           vf_coeff, entropy_coeff)
        loss += l
        state = state.apply_gradients(grads=grads)
    return state, loss
def train_step(
    train_state: ts.TrainState,
    model_vars: Dict[str, Any],
    batch: Dict[str, Any],
    dropout_rng: jnp.ndarray,
    model_config: ml_collections.FrozenConfigDict,
) -> Tuple[ts.TrainState, Dict[str, Any]]:
    """Perform a single training step.

    Args:
        train_state: contains model params, loss fn, grad update fn.
        model_vars: model variables that are not optimized.
        batch: input to model.
        dropout_rng: seed for dropout rng in model.
        model_config: contains model hyperparameters.

    Returns:
        Train state with updated parameters and a dictionary of metrics.
    """
    # Derive a step-specific dropout key so each step sees fresh randomness.
    step_rng = jax.random.fold_in(dropout_rng, train_state.step)

    def objective(params):
        loss, metrics, _ = train_state.apply_fn(
            model_config,
            params,
            model_vars,
            batch,
            deterministic=False,
            dropout_rng={'dropout': step_rng},
        )
        return loss, metrics

    (_, metrics), grads = jax.value_and_grad(
        objective, has_aux=True)(train_state.params)
    # Cross-device reductions: gradients are averaged, metrics are summed.
    grads = jax.lax.pmean(grads, 'batch')
    metrics = jax.lax.psum(metrics, axis_name='batch')
    metrics = metric_utils.update_metrics_dtype(metrics)
    return train_state.apply_gradients(grads=grads), metrics