Example #1
def fit(params: optax.Params,
        opt: optax.GradientTransformation) -> optax.Params:
    state = TrainState.create(
        apply_fn=net.apply,
        params=params,
        tx=opt,
        # opt_state is created internally by TrainState.create via tx.init(params)
    )

    @jax.jit
    def step(state, batch, labels):
        (loss_val,
         accuracy), grads = jax.value_and_grad(loss,
                                               has_aux=True)(state.params,
                                                             batch, labels)
        state = state.apply_gradients(grads=grads)
        return state, loss_val, accuracy

    for i, (batch, labels) in enumerate(zip(train_data, train_labels)):
        state, loss_val, accuracy = step(state, batch, labels)
        if i % 100 == 0:
            print(
                f"step {i}/{nb_steps} | loss: {loss_val:.5f} | accuracy: {accuracy*100:.2f}%"
            )

    return state.params  # the trained parameters live in the TrainState, not in the initial `params`
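
This snippet assumes `net`, `loss`, `nb_steps`, `train_data`, and `train_labels` are defined elsewhere. A minimal sketch of what the missing `net` and `loss` might look like (the architecture and loss below are assumptions, not taken from the source); the key constraint is that `loss` must return a `(scalar_loss, aux)` pair because `fit` calls `jax.value_and_grad(loss, has_aux=True)`:

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn

# hypothetical stand-in for the snippet's `net`
net = nn.Dense(features=10)

def loss(params, batch, labels):
    # must return a (scalar_loss, aux) pair to match has_aux=True in fit()
    logits = net.apply(params, batch)
    one_hot = jax.nn.one_hot(labels, 10)
    loss_val = optax.softmax_cross_entropy(logits, one_hot).mean()
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    return loss_val, accuracy
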
Example #2
def fit(params: optax.Params,
        opt: optax.GradientTransformation) -> optax.Params:
    # bundle everything together in the `TrainState` class
    state = TrainState.create(
        apply_fn=net.apply,
        params=params,
        tx=opt,
    )

    # jit compile the step function
    @jax.jit
    def train_step(state, batch, labels):
        batch = jnp.transpose(batch, axes=(0, 2, 3, 1))
        labels = jax.nn.one_hot(labels, nb_classes)
        (loss_val, accuracy), grads = jax.value_and_grad(loss, has_aux=True)(
            state.params, batch, labels)  # return accuracy as aux
        state = state.apply_gradients(
            grads=grads
        )  # apply gradients to training state (calls other things internally)
        return state, loss_val, accuracy

    @jax.jit
    def eval_step(params, batch, labels):
        batch = jnp.transpose(batch, axes=(0, 2, 3, 1))
        labels = jax.nn.one_hot(labels, nb_classes)
        loss_val, accuracy = loss(params, batch, labels)
        return loss_val, accuracy

    for i in range(nb_epochs):
        train_loss, train_accuracy = 0.0, 0.0
        for batch, labels in train_loader:
            batch, labels = jnp.array(batch), jnp.array(labels)
            state, loss_val, accuracy = train_step(state, batch, labels)

            train_loss += loss_val
            train_accuracy += accuracy

        test_loss, test_accuracy = 0.0, 0.0
        for batch, labels in test_loader:
            batch, labels = jnp.array(batch), jnp.array(labels)
            loss_val, accuracy = eval_step(state.params, batch, labels)

            test_loss += loss_val
            test_accuracy += accuracy

        train_loss /= len(train_loader)
        train_accuracy /= len(train_loader)

        test_loss /= len(test_loader)
        test_accuracy /= len(test_loader)

        print(
            f"epoch {i+1}/{nb_epochs} | train: {train_loss:.5f} [{train_accuracy*100:.2f}%] | eval: {test_loss:.5f} [{test_accuracy*100:.2f}%]"
        )

    return state.params  # the TrainState holds the trained weights
Example #3
def create_train_state(rng, model, img_size, lr_schedule_fn, weight_decay,
                       max_norm):

    tx = optax.chain(optax.clip_by_global_norm(max_norm),
                     optax.scale_by_adam(),
                     optax.additive_weight_decay(weight_decay),
                     optax.scale_by_schedule(lr_schedule_fn))

    params = model.init(rng,
                        jax.numpy.ones((1, img_size, img_size, 3)),
                        is_training=False)

    train_state = TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx,
    )
    return train_state
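
A hedged usage sketch (the model, image size, and schedule below are assumptions, not part of the snippet). Because the optax chain ends with `optax.scale_by_schedule(lr_schedule_fn)` and contains no `optax.scale(-1)`, the schedule is expected to return the full update scale, including the negative sign for gradient descent:

import jax
import optax

# hypothetical schedule; it is negated because the chain above has no optax.scale(-1),
# so the schedule itself must carry the descent sign
base_lr = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=10_000)
lr_schedule_fn = lambda step: -base_lr(step)

state = create_train_state(jax.random.PRNGKey(0),
                           model=MyModel(),  # MyModel is a placeholder flax module
                           img_size=224,
                           lr_schedule_fn=lr_schedule_fn,
                           weight_decay=1e-4,
                           max_norm=1.0)
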
Example #4
def collate_fn(batch):
    inputs = np.stack([x[0] for x in batch], axis=0)
    labels = np.array([x[1] for x in batch])
    return {"inputs": inputs, "labels": labels}


if __name__ == "__main__":
    num_epochs = 3
    rng = random.PRNGKey(42)

    model = Model()
    dummy_inputs = np.ones((1, 28, 28, 1), np.float32)
    params = model.init(rng, dummy_inputs)["params"]
    tx = optax.adam(learning_rate=1e-3)
    train_state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

    train_loader = DataLoader(MNIST("data/mnist", download=True),
                              batch_size=64,
                              shuffle=True,
                              collate_fn=collate_fn)
    eval_loader = DataLoader(MNIST("data/mnist", train=False),
                             batch_size=64,
                             collate_fn=collate_fn)

    p_train_state = replicate(train_state)
    p_train_step = jax.pmap(train_step, "batch")
    p_eval_step = jax.pmap(eval_step, "batch")

    for epoch in range(num_epochs):
        print(f"\nEpoch: {epoch}")
        rng, input_rng = random.split(rng)
        train_metrics = []
Example #5
class Net(flax.linen.Module):
    @flax.linen.compact
    def __call__(self, x):
        x = flax.linen.Dense(128)(x)
        x = flax.linen.relu(x)
        x = flax.linen.Dense(32)(x)
        x = flax.linen.relu(x)
        x = flax.linen.Dense(10)(x)
        x = flax.linen.log_softmax(x)
        return x


model = Net()
params = model.init(jax.random.PRNGKey(42), numpy.ones((1, 28 * 28)))["params"]
optimizer = optax.adam(0.001)
state = TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)


@functools.partial(jax.jit, static_argnums=(3, ))
def step(x, y, state: TrainState, training: bool):
    def loss_fn(params):
        y_pred = model.apply({"params": params}, x)
        y_one_hot = jax.nn.one_hot(y, 10)
        loss = optax.softmax_cross_entropy(y_pred, y_one_hot).mean()
        return loss, y_pred

    x = x.reshape(-1, 28 * 28)
    if training:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, y_pred), grads = grad_fn(state.params)
        state = state.apply_gradients(grads=grads)
    else:
        # evaluation: compute loss and predictions without updating parameters
        loss, y_pred = loss_fn(state.params)
    return state, loss, y_pred
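
Because `training` is listed in `static_argnums`, it must be passed as a plain Python bool, and each value (True/False) gets its own compiled version of `step`. A brief usage sketch with made-up MNIST-shaped arrays:

# hypothetical batch; step() reshapes inputs to (batch, 28 * 28) internally
x_batch = numpy.ones((64, 28, 28), numpy.float32)
y_batch = numpy.zeros((64,), numpy.int32)

state, train_loss, _ = step(x_batch, y_batch, state, True)   # training update
_, eval_loss, logits = step(x_batch, y_batch, state, False)  # evaluation only
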
Example #6
def train(base_dir, config):
    """Train function."""
    print(config)
    chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'train'))

    writer = create_default_writer()

    # Initialize dataset
    key = jax.random.PRNGKey(config.seed)
    key, subkey = jax.random.split(key)
    ds = dataset.get_dataset(config, subkey, num_tasks=config.num_tasks)
    ds_iter = iter(ds)

    key, subkey = jax.random.split(key)
    encoder = MLPEncoder(**config.encoder)

    train_config = config.train.to_dict()
    train_method = train_config.pop('method')

    module_config = train_config.pop('module')
    module_class = module_config.pop('name')

    module = globals().get(module_class)(encoder, **module_config)
    train_step = globals().get(f'train_step_{train_method}')
    train_step = functools.partial(train_step, **train_config)

    params = module.init(subkey, next(ds_iter)[0])
    lr = optax.cosine_decay_schedule(config.learning_rate,
                                     config.num_train_steps)
    optim = optax.chain(optax.adam(lr),
                        # optax.adaptive_grad_clip(0.15)
                        )

    state = TrainState.create(apply_fn=module.apply, params=params, tx=optim)
    state = chkpt_manager.restore_or_initialize(state)

    # Hooks
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    hooks = [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
    ]

    def handle_preemption(signal_number, _):
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with report_progress.timed('checkpointing'):
            chkpt_manager.save(state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

    metrics = TrainMetrics.empty()
    with metric_writers.ensure_flushes(writer):
        for step in tqdm.tqdm(range(state.step, config.num_train_steps)):
            with jax.profiler.StepTraceAnnotation('train', step_num=step):
                states, targets = next(ds_iter)
                state, metrics = train_step(state, metrics, states, targets)

            logging.log_first_n(logging.INFO, 'Finished training step %d', 5,
                                step)

            if step % config.log_metrics_every == 0:
                writer.write_scalars(step, metrics.compute())
                metrics = TrainMetrics.empty()

            # if step % config.log_eval_metrics_every == 0 and isinstance(
            #     ds, dataset.MDPDataset):
            #   eval_metrics = evaluate_mdp(state, ds.aux_task_matrix, config)
            #   writer.write_scalars(step, eval_metrics.compute())

            for hook in hooks:
                hook(step)

    chkpt_manager.save(state)
    return state
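
The `globals().get(f'train_step_{train_method}')` lookup expects a function named `train_step_<method>` taking `(state, metrics, states, targets, **extra_config)` and returning the updated state and metrics. A sketch of one conforming function; the name, the L2 loss, and the `loss` field on `TrainMetrics` are assumptions, not taken from the source:

import jax
import jax.numpy as jnp
import optax

def train_step_regression(state, metrics, states, targets):
    # hypothetical step matching a `method: regression` entry in the config
    def loss_fn(params):
        predictions = state.apply_fn(params, states)
        return jnp.mean(optax.l2_loss(predictions, targets))

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = metrics.merge(TrainMetrics.single_from_model_output(loss=loss))
    return state, metrics
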
Example #7
def evaluate(base_dir, config, *, train_state):
    """Eval function."""
    chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'eval'))

    writer = create_default_writer()

    key = jax.random.PRNGKey(config.eval.seed)
    model_init_key, ds_key = jax.random.split(key)

    linear_module = LinearModule(config.eval.num_tasks)
    params = linear_module.init(model_init_key,
                                jnp.zeros((config.encoder.embedding_dim, )))
    lr = optax.cosine_decay_schedule(config.eval.learning_rate,
                                     config.num_eval_steps)
    optim = optax.adam(lr)

    ds = dataset.get_dataset(config, ds_key, num_tasks=config.eval.num_tasks)
    ds_iter = iter(ds)

    state = TrainState.create(apply_fn=linear_module.apply,
                              params=params,
                              tx=optim)
    state = chkpt_manager.restore_or_initialize(state)

    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_eval_steps, writer=writer)
    hooks = [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
    ]

    def handle_preemption(signal_number, _):
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with report_progress.timed('checkpointing'):
            chkpt_manager.save(state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

    metrics = EvalMetrics.empty()
    with metric_writers.ensure_flushes(writer):
        for step in tqdm.tqdm(range(state.step, config.num_eval_steps)):
            with jax.profiler.StepTraceAnnotation('eval', step_num=step):
                states, targets = next(ds_iter)
                state, metrics = evaluate_step(train_state, state, metrics,
                                               states, targets)

            if step % config.log_metrics_every == 0:
                writer.write_scalars(step, metrics.compute())
                metrics = EvalMetrics.empty()

            for hook in hooks:
                hook(step)

        # Finally, evaluate on the true(ish) test aux task matrix.
        states, targets = dataset.EvalDataset(config, ds_key).get_batch()

        @jax.jit
        def loss_fn():
            outputs = train_state.apply_fn(train_state.params, states)
            phis = outputs.phi
            predictions = jax.vmap(state.apply_fn,
                                   in_axes=(None, 0))(state.params, phis)
            return jnp.mean(optax.l2_loss(predictions, targets))

        test_loss = loss_fn()
        writer.write_scalars(config.num_eval_steps + 1,
                             {'test_loss': test_loss})