def fit(params: optax.Params, opt: optax.GradientTransformation) -> optax.Params:
    # TrainState bundles the apply function, the parameters and the optimizer;
    # the optimizer state is created internally, so an explicit
    # opt_state=opt.init(params) is not needed.
    state = TrainState.create(
        apply_fn=net.apply,
        params=params,
        tx=opt,
    )

    @jax.jit
    def step(state, batch, labels):
        (loss_val, accuracy), grads = jax.value_and_grad(loss, has_aux=True)(
            state.params, batch, labels)
        state = state.apply_gradients(grads=grads)
        return state, loss_val, accuracy

    for i, (batch, labels) in enumerate(zip(train_data, train_labels)):
        state, loss_val, accuracy = step(state, batch, labels)
        if i % 100 == 0:
            print(
                f"step {i}/{nb_steps} | loss: {loss_val:.5f} | accuracy: {accuracy*100:.2f}%"
            )

    return state.params  # return the trained parameters, not the initial ones
def fit(params: optax.Params, opt: optax.GradientTransformation) -> optax.Params:
    # Bundle everything together in the `TrainState` class.
    state = TrainState.create(
        apply_fn=net.apply,
        params=params,
        tx=opt,
    )

    # JIT-compile the step functions.
    @jax.jit
    def train_step(state, batch, labels):
        batch = jnp.transpose(batch, axes=(0, 2, 3, 1))  # NCHW -> NHWC
        labels = jax.nn.one_hot(labels, nb_classes)
        (loss_val, accuracy), grads = jax.value_and_grad(loss, has_aux=True)(
            state.params, batch, labels)  # accuracy is returned as aux
        # Apply the gradients to the training state (updates params and the
        # optimizer state internally).
        state = state.apply_gradients(grads=grads)
        return state, loss_val, accuracy

    @jax.jit
    def eval_step(params, batch, labels):
        batch = jnp.transpose(batch, axes=(0, 2, 3, 1))
        labels = jax.nn.one_hot(labels, nb_classes)
        loss_val, accuracy = loss(params, batch, labels)
        return loss_val, accuracy

    for i in range(nb_epochs):
        train_loss, train_accuracy = 0.0, 0.0
        for batch, labels in train_loader:
            batch, labels = jnp.array(batch), jnp.array(labels)
            state, loss_val, accuracy = train_step(state, batch, labels)
            train_loss += loss_val
            train_accuracy += accuracy

        test_loss, test_accuracy = 0.0, 0.0
        for batch, labels in test_loader:
            batch, labels = jnp.array(batch), jnp.array(labels)
            loss_val, accuracy = eval_step(state.params, batch, labels)
            test_loss += loss_val
            test_accuracy += accuracy

        train_loss /= len(train_loader)
        train_accuracy /= len(train_loader)
        test_loss /= len(test_loader)
        test_accuracy /= len(test_loader)

        print(
            f"epoch {i+1}/{nb_epochs} | train: {train_loss:.5f} [{train_accuracy*100:.2f}%]"
            f" | eval: {test_loss:.5f} [{test_accuracy*100:.2f}%]"
        )

    return state.params  # return the trained parameters
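# Both `fit` variants above assume a helper with the signature
# loss(params, batch, labels) -> (loss_value, accuracy), so that
# jax.value_and_grad(..., has_aux=True) can return the accuracy alongside the
# gradients. A minimal sketch of such a helper; `net`, the exact form of the
# apply call, and one-hot labels are assumptions taken from the snippets above,
# not part of the original code:
def loss(params, batch, labels):
    logits = net.apply({"params": params}, batch)
    loss_val = optax.softmax_cross_entropy(logits=logits, labels=labels).mean()
    accuracy = jnp.mean(jnp.argmax(logits, -1) == jnp.argmax(labels, -1))
    return loss_val, accuracy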
def create_train_state(rng, model, img_size, lr_schedule_fn, weight_decay, max_norm):
    # AdamW-style optimizer built from primitive transformations: gradient
    # clipping, Adam scaling, decoupled weight decay (renamed to
    # optax.add_decayed_weights in newer optax releases), then the
    # learning-rate schedule. With this chain, lr_schedule_fn must return
    # negative values (or an optax.scale(-1.0) must be appended) so that
    # apply_gradients performs gradient descent rather than ascent.
    tx = optax.chain(optax.clip_by_global_norm(max_norm),
                     optax.scale_by_adam(),
                     optax.additive_weight_decay(weight_decay),
                     optax.scale_by_schedule(lr_schedule_fn))
    params = model.init(rng,
                        jax.numpy.ones((1, img_size, img_size, 3)),
                        is_training=False)
    train_state = TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx,
    )
    return train_state
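# A hypothetical usage sketch for create_train_state. The warmup-cosine
# schedule and all hyperparameter values are illustrative assumptions; the
# schedule is negated so that the scale_by_schedule step above produces
# descent updates.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-3, warmup_steps=500, decay_steps=10_000)
state = create_train_state(
    rng=jax.random.PRNGKey(0),
    model=model,  # assumed: a Flax module whose __call__ accepts is_training
    img_size=224,
    lr_schedule_fn=lambda count: -schedule(count),
    weight_decay=1e-4,
    max_norm=1.0,
)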
def collate_fn(batch):
    # Stack a list of (image, label) pairs into batched numpy arrays.
    inputs = np.stack([x[0] for x in batch], axis=0)
    labels = np.array([x[1] for x in batch])
    return {"inputs": inputs, "labels": labels}


if __name__ == "__main__":
    num_epochs = 3
    rng = random.PRNGKey(42)

    model = Model()
    dummy_inputs = np.ones((1, 28, 28, 1), np.float32)
    params = model.init(rng, dummy_inputs)["params"]
    tx = optax.adam(learning_rate=1e-3)
    train_state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

    train_loader = DataLoader(MNIST("data/mnist", download=True),
                              batch_size=64, shuffle=True,
                              collate_fn=collate_fn)
    eval_loader = DataLoader(MNIST("data/mnist", train=False),
                             batch_size=64, collate_fn=collate_fn)

    # Replicate the train state across devices and pmap the step functions.
    p_train_state = replicate(train_state)
    p_train_step = jax.pmap(train_step, "batch")
    p_eval_step = jax.pmap(eval_step, "batch")

    for epoch in range(num_epochs):
        print(f"\nEpoch: {epoch}")
        rng, input_rng = random.split(rng)
        train_metrics = []
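# The train_step and eval_step passed to jax.pmap above are not shown in this
# snippet. A minimal sketch of a pmap-compatible train step, assuming each
# device receives a shard of the {"inputs", "labels"} batch produced by
# collate_fn and that Model returns class logits (both assumptions):
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, batch["inputs"])
        one_hot = jax.nn.one_hot(batch["labels"], 10)
        return optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # Average the loss and gradients across devices over the "batch" pmap axis.
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    return state.apply_gradients(grads=grads), loss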
class Net(flax.linen.Module):

    @flax.linen.compact
    def __call__(self, x):
        x = flax.linen.Dense(128)(x)
        x = flax.linen.relu(x)
        x = flax.linen.Dense(32)(x)
        x = flax.linen.relu(x)
        x = flax.linen.Dense(10)(x)
        x = flax.linen.log_softmax(x)
        return x


model = Net()
params = model.init(jax.random.PRNGKey(42), numpy.ones((1, 28 * 28)))["params"]
optimizer = optax.adam(0.001)
state = TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)


@functools.partial(jax.jit, static_argnums=(3, ))
def step(x, y, state: TrainState, training: bool):
    # Flatten the images before defining the loss so the closure sees 2-D input.
    x = x.reshape(-1, 28 * 28)

    def loss_fn(params):
        y_pred = model.apply({"params": params}, x)
        y_one_hot = jax.nn.one_hot(y, 10)
        loss = optax.softmax_cross_entropy(y_pred, y_one_hot).mean()
        return loss, y_pred

    if training:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, y_pred), grads = grad_fn(state.params)
        state = state.apply_gradients(grads=grads)
    else:
        # Evaluation branch (assumed; the original snippet ends inside the
        # training branch): forward pass only, no parameter update.
        loss, y_pred = loss_fn(state.params)
    return state, loss, y_pred
def train(base_dir, config):
    """Train function."""
    print(config)
    chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'train'))
    writer = create_default_writer()

    # Initialize dataset
    key = jax.random.PRNGKey(config.seed)
    key, subkey = jax.random.split(key)
    ds = dataset.get_dataset(config, subkey, num_tasks=config.num_tasks)
    ds_iter = iter(ds)

    key, subkey = jax.random.split(key)
    encoder = MLPEncoder(**config.encoder)

    train_config = config.train.to_dict()
    train_method = train_config.pop('method')
    module_config = train_config.pop('module')
    module_class = module_config.pop('name')
    module = globals().get(module_class)(encoder, **module_config)

    train_step = globals().get(f'train_step_{train_method}')
    train_step = functools.partial(train_step, **train_config)

    params = module.init(subkey, next(ds_iter)[0])
    lr = optax.cosine_decay_schedule(config.learning_rate,
                                     config.num_train_steps)
    optim = optax.chain(optax.adam(lr),
                        # optax.adaptive_grad_clip(0.15)
                        )
    state = TrainState.create(apply_fn=module.apply, params=params, tx=optim)
    state = chkpt_manager.restore_or_initialize(state)

    # Hooks
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    hooks = [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
    ]

    def handle_preemption(signal_number, _):
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with report_progress.timed('checkpointing'):
            chkpt_manager.save(state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

    metrics = TrainMetrics.empty()
    with metric_writers.ensure_flushes(writer):
        for step in tqdm.tqdm(range(state.step, config.num_train_steps)):
            with jax.profiler.StepTraceAnnotation('train', step_num=step):
                states, targets = next(ds_iter)
                state, metrics = train_step(state, metrics, states, targets)

            logging.log_first_n(logging.INFO, 'Finished training step %d', 5,
                                step)

            if step % config.log_metrics_every == 0:
                writer.write_scalars(step, metrics.compute())
                metrics = TrainMetrics.empty()

            # if step % config.log_eval_metrics_every == 0 and isinstance(
            #     ds, dataset.MDPDataset):
            #   eval_metrics = evaluate_mdp(state, ds.aux_task_matrix, config)
            #   writer.write_scalars(step, eval_metrics.compute())

            for hook in hooks:
                hook(step)

    chkpt_manager.save(state)
    return state
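# The train_step_{method} functions are looked up dynamically above and are not
# part of this snippet. A hypothetical sketch of one such step, assuming
# TrainMetrics is a clu.metrics Collection (consistent with the .empty() and
# .compute() calls above) and that the module maps states to predictions;
# every name below is illustrative rather than the project's actual code.
import flax
import jax
import jax.numpy as jnp
import optax
from clu import metrics as clu_metrics


@flax.struct.dataclass
class TrainMetricsSketch(clu_metrics.Collection):
    loss: clu_metrics.Average.from_output('loss')


@jax.jit
def train_step_regression(state, train_metrics, states, targets):
    def loss_fn(params):
        predictions = state.apply_fn(params, states)
        return jnp.mean(optax.l2_loss(predictions, targets))

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    train_metrics = train_metrics.merge(
        TrainMetricsSketch.single_from_model_output(loss=loss))
    return state, train_metrics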
def evaluate(base_dir, config, *, train_state):
    """Eval function."""
    chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'eval'))
    writer = create_default_writer()

    key = jax.random.PRNGKey(config.eval.seed)
    model_init_key, ds_key = jax.random.split(key)

    linear_module = LinearModule(config.eval.num_tasks)
    params = linear_module.init(model_init_key,
                                jnp.zeros((config.encoder.embedding_dim, )))
    lr = optax.cosine_decay_schedule(config.eval.learning_rate,
                                     config.num_eval_steps)
    optim = optax.adam(lr)

    ds = dataset.get_dataset(config, ds_key, num_tasks=config.eval.num_tasks)
    ds_iter = iter(ds)

    state = TrainState.create(apply_fn=linear_module.apply,
                              params=params,
                              tx=optim)
    state = chkpt_manager.restore_or_initialize(state)

    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_eval_steps, writer=writer)
    hooks = [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
    ]

    def handle_preemption(signal_number, _):
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with report_progress.timed('checkpointing'):
            chkpt_manager.save(state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

    metrics = EvalMetrics.empty()
    with metric_writers.ensure_flushes(writer):
        for step in tqdm.tqdm(range(state.step, config.num_eval_steps)):
            with jax.profiler.StepTraceAnnotation('eval', step_num=step):
                states, targets = next(ds_iter)
                state, metrics = evaluate_step(train_state, state, metrics,
                                               states, targets)

            if step % config.log_metrics_every == 0:
                writer.write_scalars(step, metrics.compute())
                metrics = EvalMetrics.empty()

            for hook in hooks:
                hook(step)

        # Finally, evaluate on the true(ish) test aux task matrix.
        states, targets = dataset.EvalDataset(config, ds_key).get_batch()

        @jax.jit
        def loss_fn():
            outputs = train_state.apply_fn(train_state.params, states)
            phis = outputs.phi
            predictions = jax.vmap(state.apply_fn,
                                   in_axes=(None, 0))(state.params, phis)
            return jnp.mean(optax.l2_loss(predictions, targets))

        test_loss = loss_fn()
        writer.write_scalars(config.num_eval_steps + 1,
                             {'test_loss': test_loss})