def fit(params: optax.Params, opt: optax.GradientTransformation) -> optax.Params:
    """Train `net` over the training data and return the trained parameters.

    Iterates once over `zip(train_data, train_labels)`, applying one jitted
    gradient step per pair and printing progress every 100 steps.

    Args:
        params: Initial model parameters.
        opt: Gradient transformation handed to the train state as `tx`.

    Returns:
        The parameters held by the train state after the final step.
    """
    # Bundle the apply function, parameters and optimizer together. No
    # explicit `opt_state` is passed; only `apply_fn`, `params` and `tx`.
    state = TrainState.create(
        apply_fn=net.apply,
        params=params,
        tx=opt,
    )

    @jax.jit
    def step(state, batch, labels):
        # `loss` returns `(loss_val, accuracy)`; with `has_aux=True` only the
        # first element is differentiated and accuracy is carried along.
        (loss_val, accuracy), grads = jax.value_and_grad(loss, has_aux=True)(
            state.params, batch, labels)
        state = state.apply_gradients(grads=grads)
        return state, loss_val, accuracy

    for i, (batch, labels) in enumerate(zip(train_data, train_labels)):
        state, loss_val, accuracy = step(state, batch, labels)
        if i % 100 == 0:
            print(
                f"step {i}/{nb_steps} | loss: {loss_val:.5f} | accuracy: {accuracy*100:.2f}%"
            )

    # The `params` argument is never rebound during training, so returning it
    # would discard every update; the trained values live on `state`.
    return state.params
def train_step(
    state: train_state.TrainState,
    batch: Dict[str, Array],
    dropout_rng: PRNGKey,
) -> Tuple[train_state.TrainState, Dict[str, Array], PRNGKey]:
    """Run one optimisation step on `batch` and average the results over "batch".

    The gradient and the metrics are reduced with `jax.lax.pmean` over the
    axis named "batch", so this function expects to run inside a mapped
    computation that defines an axis with that name.

    Args:
        state: Train state providing `apply_fn`, `loss_fn`, `params`, `step`
            and `apply_gradients`.
        batch: Model keyword inputs plus a "labels" entry holding the targets.
            The dict is not modified.
        dropout_rng: PRNG key; it is split into a key used for this step and
            a fresh key handed back to the caller.

    Returns:
        A triple `(new_state, metrics, new_dropout_rng)` where `metrics` holds
        the averaged "loss" and "learning_rate".

    Raises:
        KeyError: If `batch` has no "labels" entry.
    """
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    # Separate targets from model inputs without popping from the caller's
    # dict, so the same batch object can be reused after this call.
    targets = batch["labels"]
    inputs = {name: value for name, value in batch.items() if name != "labels"}

    def loss_fn(params):
        # Only the first element of the model output is used as logits.
        logits = state.apply_fn(
            **inputs, params=params, dropout_rng=dropout_rng, train=True)[0]
        return state.loss_fn(logits, targets)

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = grad_fn(state.params)
    grad = jax.lax.pmean(grad, "batch")
    new_state = state.apply_gradients(grads=grad)

    # The learning rate is evaluated at the step count *before* the update.
    metrics = jax.lax.pmean(
        {
            "loss": loss,
            "learning_rate": learning_rate_fn(state.step)
        },
        axis_name="batch")
    return new_state, metrics, new_dropout_rng
def fit(params: optax.Params, opt: optax.GradientTransformation) -> optax.Params:
    """Train `net` for `nb_epochs` epochs and return the trained parameters.

    Each epoch runs one pass over `train_loader` (with gradient updates) and
    one pass over `test_loader` (evaluation only), then prints the per-batch
    mean loss and accuracy of both passes.

    Args:
        params: Initial model parameters.
        opt: Gradient transformation handed to the train state as `tx`.

    Returns:
        The parameters held by the train state after the last epoch.
    """
    # Bundle together everything in the `TrainState` class.
    state = TrainState.create(
        apply_fn=net.apply,
        params=params,
        tx=opt,
    )

    @jax.jit
    def train_step(state, batch, labels):
        # Move axis 1 to the end (presumably channels-first -> channels-last;
        # confirm against the data loader) and one-hot encode the labels.
        batch = jnp.transpose(batch, axes=(0, 2, 3, 1))
        labels = jax.nn.one_hot(labels, nb_classes)
        # `loss` returns `(loss_val, accuracy)`; accuracy is carried as aux.
        (loss_val, accuracy), grads = jax.value_and_grad(loss, has_aux=True)(
            state.params, batch, labels)
        state = state.apply_gradients(grads=grads)
        return state, loss_val, accuracy

    @jax.jit
    def eval_step(params, batch, labels):
        # Same preprocessing as `train_step`, but no gradient is taken.
        batch = jnp.transpose(batch, axes=(0, 2, 3, 1))
        labels = jax.nn.one_hot(labels, nb_classes)
        loss_val, accuracy = loss(params, batch, labels)
        return loss_val, accuracy

    for i in range(nb_epochs):
        train_loss, train_accuracy = 0.0, 0.0
        for batch, labels in train_loader:
            batch, labels = jnp.array(batch), jnp.array(labels)
            state, loss_val, accuracy = train_step(state, batch, labels)
            train_loss += loss_val
            train_accuracy += accuracy

        test_loss, test_accuracy = 0.0, 0.0
        for batch, labels in test_loader:
            batch, labels = jnp.array(batch), jnp.array(labels)
            loss_val, accuracy = eval_step(state.params, batch, labels)
            test_loss += loss_val
            test_accuracy += accuracy

        # Turn the running sums into per-batch means.
        train_loss /= len(train_loader)
        train_accuracy /= len(train_loader)
        test_loss /= len(test_loader)
        test_accuracy /= len(test_loader)
        print(
            f"epoch {i+1}/{nb_epochs} | train: {train_loss:.5f} [{train_accuracy*100:.2f}%] | eval: {test_loss:.5f} [{test_accuracy*100:.2f}%]"
        )

    # The `params` argument is never rebound during training, so returning it
    # would discard every update; the trained values live on `state`.
    return state.params
def train_step(state: TrainState, batch):
    """Apply one gradient update and return the new state with averaged metrics.

    The gradient and the loss are both averaged with `jax.lax.pmean` over the
    axis named "batch", so the function expects to run inside a mapped
    computation that defines that axis.

    Args:
        state: Train state providing `apply_fn`, `params` and
            `apply_gradients`.
        batch: A pair `(inputs, labels)`.

    Returns:
        `(new_state, metrics)` where `metrics` is `{"loss": <averaged loss>}`.
    """

    def compute_loss(params: Dict[str, Any]):
        features, targets = batch
        predictions = state.apply_fn({"params": params}, features)
        return loss_fn(predictions, targets)

    loss, local_grad = jax.value_and_grad(compute_loss)(state.params)

    # Average the gradient across the "batch" axis before updating.
    synced_grad = jax.lax.pmean(local_grad, "batch")
    updated_state = state.apply_gradients(grads=synced_grad)

    averaged_metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
    return updated_state, averaged_metrics
def step(x, y, state: TrainState, training: bool):
    """Evaluate the model on `(x, y)` and, when training, update `state`.

    Args:
        x: Inputs; reshaped to `(-1, 28 * 28)` before being fed to the model.
        y: Integer class labels, one-hot encoded into 10 classes.
        state: Train state holding the parameters (and the optimizer used by
            `apply_gradients`).
        training: When true, gradients are computed and applied; otherwise
            only the forward pass runs and `state` is returned unchanged.

    Returns:
        `(loss, y_pred, state)`: the mean cross-entropy, the model outputs,
        and the (possibly updated) state.
    """
    flat_inputs = x.reshape(-1, 28 * 28)

    def loss_fn(params):
        outputs = model.apply({"params": params}, flat_inputs)
        targets = jax.nn.one_hot(y, 10)
        mean_loss = optax.softmax_cross_entropy(outputs, targets).mean()
        # The model outputs ride along as the auxiliary value.
        return mean_loss, outputs

    # Evaluation: forward pass only, state untouched.
    if not training:
        loss, y_pred = loss_fn(state.params)
        return loss, y_pred, state

    # Training: differentiate the loss and apply the update.
    (loss, y_pred), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    return loss, y_pred, state.apply_gradients(grads=grads)
def create_train_state(rng, model, img_size, lr_schedule_fn, weight_decay, max_norm):
    """Initialise `model` and wrap it with its optimizer in a `TrainState`.

    Args:
        rng: PRNG key passed to `model.init`.
        model: Module exposing `init` and `apply`.
        img_size: Side length of the square, 3-channel dummy image used for
            initialisation.
        lr_schedule_fn: Schedule function given to `optax.scale_by_schedule`.
        weight_decay: Coefficient given to `optax.additive_weight_decay`.
        max_norm: Threshold given to `optax.clip_by_global_norm`.

    Returns:
        The `TrainState` built from `model.apply`, the initial parameters and
        the chained gradient transformation.
    """
    # Gradient pipeline, applied in this order: clip -> Adam scaling ->
    # weight decay -> schedule scaling.
    # NOTE(review): there is no explicit sign flip in this chain; presumably
    # `lr_schedule_fn` returns negative values so that updates descend.
    # Confirm against the caller.
    stages = (
        optax.clip_by_global_norm(max_norm),
        optax.scale_by_adam(),
        optax.additive_weight_decay(weight_decay),
        optax.scale_by_schedule(lr_schedule_fn),
    )
    tx = optax.chain(*stages)

    # A single all-ones image is enough to trace shapes during init.
    dummy_image = jax.numpy.ones((1, img_size, img_size, 3))
    params = model.init(rng, dummy_image, is_training=False)

    return TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx,
    )
def train_step(
    state: train_state.TrainState,
    trajectories: Tuple,
    batch_size: int,
    *,
    clip_param: float,
    vf_coeff: float,
    entropy_coeff: float):
    """Compilable train step.

    Runs an entire epoch of training (i.e. the loop over minibatches within
    an epoch is included here for performance reasons).

    Args:
      state: the train state
      trajectories: Tuple of the following five elements forming the
        experience (shapes as stated by the original author):
          states: shape (steps_per_agent*num_agents, 84, 84, 4)
          actions: shape (steps_per_agent*num_agents, 84, 84, 4)
          old_log_probs: shape (steps_per_agent*num_agents, )
          returns: shape (steps_per_agent*num_agents, )
          advantages: (steps_per_agent*num_agents, )
        Every element is reshaped to (iterations, batch_size, ...), so the
        shared leading dimension must be a multiple of `batch_size`.
      batch_size: the minibatch size, static argument
      clip_param: the PPO clipping parameter used to clamp ratios in loss
        function
      vf_coeff: weighs value function loss in total loss
      entropy_coeff: weighs entropy bonus in the total loss

    Returns:
      state: the train state after all minibatch updates
      loss: loss summed over training steps
    """
    iterations = trajectories[0].shape[0] // batch_size
    # `jax.tree_map` is a deprecated alias that newer JAX releases no longer
    # provide; `jax.tree_util.tree_map` is the supported spelling.
    trajectories = jax.tree_util.tree_map(
        lambda x: x.reshape((iterations, batch_size) + x.shape[1:]),
        trajectories)

    # The differentiated function does not change between minibatches, so
    # build it once instead of on every iteration.
    grad_fn = jax.value_and_grad(loss_fn)

    loss = 0.
    # Zipping the reshaped arrays yields one tuple of minibatch slices per
    # iteration along the new leading axis.
    for batch in zip(*trajectories):
        batch_loss, grads = grad_fn(state.params, state.apply_fn, batch,
                                    clip_param, vf_coeff, entropy_coeff)
        loss += batch_loss
        state = state.apply_gradients(grads=grads)
    return state, loss
def train_step(
    train_state: ts.TrainState,
    model_vars: Dict[str, Any],
    batch: Dict[str, Any],
    dropout_rng: jnp.ndarray,
    model_config: ml_collections.FrozenConfigDict,
) -> Tuple[ts.TrainState, Dict[str, Any]]:
    """Run one optimisation step and collect metrics summed over "batch".

    Gradients are averaged (`pmean`) and metrics are summed (`psum`) over the
    axis named 'batch', so the function expects to run inside a mapped
    computation that defines that axis.

    Args:
      train_state: provides `apply_fn`, `params`, `step` and
        `apply_gradients`.
      model_vars: model variables that are not optimized.
      batch: input to model.
      dropout_rng: base dropout key; folded with the current step so each
        step uses a different key.
      model_config: contains model hyperparameters.

    Returns:
      The updated train state and the dictionary of metrics.
    """
    # Derive this step's key from the base key and the step counter.
    step_rng = jax.random.fold_in(dropout_rng, train_state.step)

    def loss_fn_partial(model_params):
        # The third element returned by `apply_fn` is not needed here.
        loss, aux_metrics, _ = train_state.apply_fn(
            model_config,
            model_params,
            model_vars,
            batch,
            deterministic=False,
            dropout_rng={'dropout': step_rng},
        )
        return loss, aux_metrics

    value_and_grad = jax.value_and_grad(loss_fn_partial, has_aux=True)
    (_, raw_metrics), local_grad = value_and_grad(train_state.params)

    synced_grad = jax.lax.pmean(local_grad, 'batch')
    summed_metrics = jax.lax.psum(raw_metrics, axis_name='batch')
    metrics = metric_utils.update_metrics_dtype(summed_metrics)

    return train_state.apply_gradients(grads=synced_grad), metrics
def collate_fn(batch):
    """Stack a sequence of `(input, label)` samples into a dict of arrays.

    Args:
        batch: Sequence whose items expose the input at index 0 and the label
            at index 1.

    Returns:
        `{"inputs": <inputs stacked along a new axis 0>, "labels": <array of labels>}`.
    """
    inputs = np.stack([x[0] for x in batch], axis=0)
    labels = np.array([x[1] for x in batch])
    return {"inputs": inputs, "labels": labels}


if __name__ == "__main__":
    # Script entry point: build the model, optimizer, data loaders and the
    # pmapped step functions, then start the epoch loop.
    num_epochs = 3
    rng = random.PRNGKey(42)
    model = Model()
    # A single all-ones sample is used only to initialise the parameters.
    dummy_inputs = np.ones((1, 28, 28, 1), np.float32)
    params = model.init(rng, dummy_inputs)["params"]
    tx = optax.adam(learning_rate=1e-3)
    train_state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
    # NOTE(review): `collate_fn` above is not passed to either loader in the
    # visible code; presumably it was meant to be supplied here. Confirm.
    train_loader = DataLoader(MNIST("data/mnist", download=True), batch_size=64, shuffle=True)
    eval_loader = DataLoader(MNIST("data/mnist", train=False), batch_size=64)
    # Replicate the state and map both step functions over an axis named
    # "batch" (the name the step functions reduce over).
    p_train_state = replicate(train_state)
    p_train_step = jax.pmap(train_step, "batch")
    p_eval_step = jax.pmap(eval_step, "batch")
    for epoch in range(num_epochs):
        print(f"\nEpoch: {epoch}")
        # Split off a fresh key for this epoch.
        rng, input_rng = random.split(rng)
        # NOTE(review): the visible loop body ends here; the code that uses
        # `input_rng` and fills `train_metrics` is presumably outside this
        # view.
        train_metrics = []
class Net(flax.linen.Module):
    """Three-layer dense network (128 -> 32 -> 10) ending in `log_softmax`."""

    @flax.linen.compact
    def __call__(self, x):
        """Return log-softmax outputs over 10 classes for input `x`."""
        x = flax.linen.Dense(128)(x)
        x = flax.linen.relu(x)
        x = flax.linen.Dense(32)(x)
        x = flax.linen.relu(x)
        x = flax.linen.Dense(10)(x)
        x = flax.linen.log_softmax(x)
        return x


# Module-level setup: initialise the parameters with a dummy flattened
# 28 * 28 input and bundle them with an Adam optimizer in a train state.
model = Net()
params = model.init(jax.random.PRNGKey(42), numpy.ones((1, 28 * 28)))["params"]
optimizer = optax.adam(0.001)
state = TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)


# `training` (positional index 3) is marked static so the Python `if` below
# is resolved at trace time.
@functools.partial(jax.jit, static_argnums=(3, ))
def step(x, y, state: TrainState, training: bool):
    """Compute the loss on `(x, y)` and, when `training`, apply gradients.

    NOTE(review): the visible body ends after the training branch with no
    `else` branch and no `return`; the rest of the function is presumably
    outside this view.
    """

    def loss_fn(params):
        # NOTE(review): `Net` already applies `log_softmax`, and the result
        # is passed to `optax.softmax_cross_entropy`. Confirm this
        # combination is intended.
        y_pred = model.apply({"params": params}, x)
        y_one_hot = jax.nn.one_hot(y, 10)
        loss = optax.softmax_cross_entropy(y_pred, y_one_hot).mean()
        # The predictions are returned as the auxiliary value.
        return loss, y_pred

    # Flatten each sample to 28 * 28 features; `loss_fn` reads this rebound
    # `x` because it is only called afterwards.
    x = x.reshape(-1, 28 * 28)
    if training:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, y_pred), grads = grad_fn(state.params)
        state = state.apply_gradients(grads=grads)
def train(base_dir, config):
    """Run the training loop and return the final train state.

    Args:
        base_dir: Path-like base directory. Checkpoints are managed under
            `base_dir / 'train'` and profiles are written under `base_dir`.
        config: Configuration object. This function reads `seed`,
            `num_tasks`, `encoder`, `train`, `learning_rate`,
            `num_train_steps` and `log_metrics_every` from it.

    Returns:
        The train state after the loop; it is also saved through the
        checkpoint manager before returning.
    """
    print(config)
    chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'train'))
    writer = create_default_writer()

    # Initialize dataset
    key = jax.random.PRNGKey(config.seed)
    key, subkey = jax.random.split(key)
    ds = dataset.get_dataset(config, subkey, num_tasks=config.num_tasks)
    ds_iter = iter(ds)

    # Fresh subkey, used further down for parameter initialisation.
    key, subkey = jax.random.split(key)
    encoder = MLPEncoder(**config.encoder)

    # `config.train` carries the training method name, the module spec and
    # any remaining keyword arguments for the train step. Both the module
    # class and the step function are looked up by name in this module's
    # globals.
    train_config = config.train.to_dict()
    train_method = train_config.pop('method')
    module_config = train_config.pop('module')
    module_class = module_config.pop('name')

    module = globals().get(module_class)(encoder, **module_config)
    train_step = globals().get(f'train_step_{train_method}')
    # Whatever is left in `train_config` is bound as keyword arguments.
    train_step = functools.partial(train_step, **train_config)

    # One batch is consumed from the iterator to initialise the parameters;
    # only its first element is used.
    params = module.init(subkey, next(ds_iter)[0])
    lr = optax.cosine_decay_schedule(config.learning_rate,
                                     config.num_train_steps)
    optim = optax.chain(
        optax.adam(lr),
        # optax.adaptive_grad_clip(0.15)
    )

    state = TrainState.create(apply_fn=module.apply, params=params, tx=optim)
    state = chkpt_manager.restore_or_initialize(state)

    # Hooks
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    hooks = [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
    ]

    def handle_preemption(signal_number, _):
        # `state` is read from the enclosing scope when the signal fires, so
        # the most recently assigned state is the one that gets saved.
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with report_progress.timed('checkpointing'):
            chkpt_manager.save(state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

    metrics = TrainMetrics.empty()
    # NOTE(review): the nesting below was reconstructed from
    # whitespace-collapsed text; confirm which statements belong inside each
    # `with` block.
    with metric_writers.ensure_flushes(writer):
        # Starts from `state.step`, i.e. from the restored step if a
        # checkpoint was loaded.
        for step in tqdm.tqdm(range(state.step, config.num_train_steps)):
            with jax.profiler.StepTraceAnnotation('train', step_num=step):
                states, targets = next(ds_iter)
                state, metrics = train_step(state, metrics, states, targets)

            logging.log_first_n(logging.INFO, 'Finished training step %d', 5,
                                step)

            if step % config.log_metrics_every == 0:
                # Write the accumulated metrics, then start a new
                # accumulation window.
                writer.write_scalars(step, metrics.compute())
                metrics = TrainMetrics.empty()

            # if step % config.log_eval_metrics_every == 0 and isinstance(
            #     ds, dataset.MDPDataset):
            #   eval_metrics = evaluate_mdp(state, ds.aux_task_matrix, config)
            #   writer.write_scalars(step, eval_metrics.compute())

            for hook in hooks:
                hook(step)

    chkpt_manager.save(state)
    return state
def evaluate(base_dir, config, *, train_state):
    """Fit a linear module on top of `train_state` outputs and log a test loss.

    Args:
        base_dir: Path-like base directory. Checkpoints are managed under
            `base_dir / 'eval'` and profiles are written under `base_dir`.
        config: Configuration object. This function reads `eval.seed`,
            `eval.num_tasks`, `eval.learning_rate`, `encoder.embedding_dim`,
            `num_eval_steps` and `log_metrics_every` from it.
        train_state: Keyword-only. Its `apply_fn` and `params` produce the
            `phi` outputs that the linear module is applied to.

    Returns:
        None. Metrics and the final `test_loss` go to the metric writer.
    """
    chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'eval'))
    writer = create_default_writer()

    # One key for initialising the linear module, one for the datasets.
    key = jax.random.PRNGKey(config.eval.seed)
    model_init_key, ds_key = jax.random.split(key)

    linear_module = LinearModule(config.eval.num_tasks)
    # The module is initialised on a single zero vector of length
    # `embedding_dim`.
    params = linear_module.init(model_init_key,
                                jnp.zeros((config.encoder.embedding_dim, )))
    lr = optax.cosine_decay_schedule(config.eval.learning_rate,
                                     config.num_eval_steps)
    optim = optax.adam(lr)

    ds = dataset.get_dataset(config, ds_key, num_tasks=config.eval.num_tasks)
    ds_iter = iter(ds)

    state = TrainState.create(apply_fn=linear_module.apply,
                              params=params,
                              tx=optim)
    state = chkpt_manager.restore_or_initialize(state)

    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_eval_steps, writer=writer)
    hooks = [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
    ]

    def handle_preemption(signal_number, _):
        # `state` is read from the enclosing scope when the signal fires, so
        # the most recently assigned state is the one that gets saved.
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with report_progress.timed('checkpointing'):
            chkpt_manager.save(state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

    metrics = EvalMetrics.empty()
    # NOTE(review): the nesting below was reconstructed from
    # whitespace-collapsed text; confirm which statements belong inside each
    # `with` block.
    with metric_writers.ensure_flushes(writer):
        for step in tqdm.tqdm(range(state.step, config.num_eval_steps)):
            with jax.profiler.StepTraceAnnotation('eval', step_num=step):
                states, targets = next(ds_iter)
                # Both the outer `train_state` and the linear module's
                # `state` are passed to the step; only `state` is reassigned.
                state, metrics = evaluate_step(train_state, state, metrics,
                                               states, targets)

            if step % config.log_metrics_every == 0:
                # Write the accumulated metrics, then start a new
                # accumulation window.
                writer.write_scalars(step, metrics.compute())
                metrics = EvalMetrics.empty()

            for hook in hooks:
                hook(step)

        # Finally, evaluate on the true(ish) test aux task matrix.
        states, targets = dataset.EvalDataset(config, ds_key).get_batch()

        # Takes no arguments: every array it uses is captured from the
        # enclosing scope.
        @jax.jit
        def loss_fn():
            outputs = train_state.apply_fn(train_state.params, states)
            phis = outputs.phi
            # Apply the linear module to each entry along axis 0 of `phis`
            # while sharing the same parameters.
            predictions = jax.vmap(state.apply_fn,
                                   in_axes=(None, 0))(state.params, phis)
            return jnp.mean(optax.l2_loss(predictions, targets))

        test_loss = loss_fn()
        # Logged one step past the last eval step.
        writer.write_scalars(config.num_eval_steps + 1,
                             {'test_loss': test_loss})