Example #1
def fit(params: optax.Params,
        opt: optax.GradientTransformation) -> optax.Params:
    state = TrainState.create(
        apply_fn=net.apply,
        params=params,
        tx=opt,
        # opt_state is initialized internally by TrainState.create via tx.init(params)
    )

    @jax.jit
    def step(state, batch, labels):
        (loss_val,
         accuracy), grads = jax.value_and_grad(loss,
                                               has_aux=True)(state.params,
                                                             batch, labels)
        state = state.apply_gradients(grads=grads)
        return state, loss_val, accuracy

    for i, (batch, labels) in enumerate(zip(train_data, train_labels)):
        state, loss_val, accuracy = step(state, batch, labels)
        if i % 100 == 0:
            print(
                f"step {i}/{nb_steps} | loss: {loss_val:.5f} | accuracy: {accuracy*100:.2f}%"
            )

    return state.params  # the trained parameters live in the train state
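The step above differentiates an external `loss` function with `has_aux=True`, which is not shown in this snippet. A minimal sketch of what such a function could look like, assuming `net` is the same Flax module registered in `TrainState.create` and that `labels` are already one-hot encoded:

def loss(params, batch, labels):
    # Forward pass; `net.apply` is the apply_fn registered above (assumption).
    logits = net.apply({"params": params}, batch)
    # Scalar loss to differentiate.
    loss_val = optax.softmax_cross_entropy(logits, labels).mean()
    # Accuracy is the auxiliary output consumed via `has_aux=True`.
    accuracy = jnp.mean(jnp.argmax(logits, -1) == jnp.argmax(labels, -1))
    return loss_val, accuracy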
Example #2
def train_step(
        state: train_state.TrainState, batch: Dict[str, Array],
        dropout_rng: PRNGKey
) -> Tuple[train_state.TrainState, Dict[str, Array], PRNGKey]:
    """Trains the model with an optimizer (both in `state`) on `batch`, returning `(new_state, metrics, new_dropout_rng)`."""
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
    targets = batch.pop("labels")

    def loss_fn(params):
        logits = state.apply_fn(**batch,
                                params=params,
                                dropout_rng=dropout_rng,
                                train=True)[0]
        loss = state.loss_fn(logits, targets)
        return loss

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = grad_fn(state.params)
    grad = jax.lax.pmean(grad, "batch")
    new_state = state.apply_gradients(grads=grad)
    metrics = jax.lax.pmean(
        {
            "loss": loss,
            "learning_rate": learning_rate_fn(state.step)
        },
        axis_name="batch")
    return new_state, metrics, new_dropout_rng
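This step splits the dropout RNG and returns the new key, so the caller is expected to thread it through the loop. A hypothetical driver sketch (in practice the step also reduces over the axis named "batch", so it would be wrapped in jax.pmap as in the later examples):

dropout_rng = jax.random.PRNGKey(0)  # assumed initial key
for batch in batches:  # `batches` is a hypothetical iterable of dict batches
    state, metrics, dropout_rng = train_step(state, batch, dropout_rng)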
Example #3
def fit(params: optax.Params,
        opt: optax.GradientTransformation) -> optax.Params:
    # bundle everything together in the `TrainState` class
    state = TrainState.create(
        apply_fn=net.apply,
        params=params,
        tx=opt,
    )

    # jit compile the step function
    @jax.jit
    def train_step(state, batch, labels):
        batch = jnp.transpose(batch, axes=(0, 2, 3, 1))
        labels = jax.nn.one_hot(labels, nb_classes)
        (loss_val, accuracy), grads = jax.value_and_grad(loss, has_aux=True)(
            state.params, batch, labels)  # return accuracy as aux
        state = state.apply_gradients(
            grads=grads
        )  # apply gradients to training state (calls other things internally)
        return state, loss_val, accuracy

    @jax.jit
    def eval_step(params, batch, labels):
        batch = jnp.transpose(batch, axes=(0, 2, 3, 1))
        labels = jax.nn.one_hot(labels, nb_classes)
        loss_val, accuracy = loss(params, batch, labels)
        return loss_val, accuracy

    for i in range(nb_epochs):
        train_loss, train_accuracy = 0.0, 0.0
        for batch, labels in train_loader:
            batch, labels = jnp.array(batch), jnp.array(labels)
            state, loss_val, accuracy = train_step(state, batch, labels)

            train_loss += loss_val
            train_accuracy += accuracy

        test_loss, test_accuracy = 0.0, 0.0
        for batch, labels in test_loader:
            batch, labels = jnp.array(batch), jnp.array(labels)
            loss_val, accuracy = eval_step(state.params, batch, labels)

            test_loss += loss_val
            test_accuracy += accuracy

        train_loss /= len(train_loader)
        train_accuracy /= len(train_loader)

        test_loss /= len(test_loader)
        test_accuracy /= len(test_loader)

        print(
            f"epoch {i+1}/{nb_epochs} | train: {train_loss:.5f} [{train_accuracy*100:.2f}%] | eval: {test_loss:.5f} [{test_accuracy*100:.2f}%]"
        )

    return state.params  # the trained parameters live in the train state
Example #4
def train_step(state: TrainState, batch):
    def compute_loss(params: Dict[str, Any]):
        inputs, labels = batch
        logits = state.apply_fn({"params": params}, inputs)
        return loss_fn(logits, labels)

    grad_fn = jax.value_and_grad(compute_loss)
    loss, grad = grad_fn(state.params)
    grad = jax.lax.pmean(grad, "batch")
    new_state = state.apply_gradients(grads=grad)

    metrics = {"loss": loss}
    metrics = jax.lax.pmean(metrics, axis_name="batch")

    return new_state, metrics
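Because this step averages gradients and metrics with jax.lax.pmean over the axis named "batch", it is meant to run under jax.pmap. A minimal usage sketch, assuming flax.jax_utils.replicate and a batch whose arrays carry a leading device axis:

from flax.jax_utils import replicate

p_state = replicate(state)  # copy params and optimizer state to every local device
p_train_step = jax.pmap(train_step, axis_name="batch")

# `sharded_batch` is assumed to be an (inputs, labels) pair whose arrays have a
# leading axis of size jax.local_device_count().
p_state, metrics = p_train_step(p_state, sharded_batch)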
Example #5
def step(x, y, state: TrainState, training: bool):
    def loss_fn(params):
        y_pred = model.apply({"params": params}, x)
        y_one_hot = jax.nn.one_hot(y, 10)
        loss = optax.softmax_cross_entropy(y_pred, y_one_hot).mean()
        return loss, y_pred

    x = x.reshape(-1, 28 * 28)
    if training:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, y_pred), grads = grad_fn(state.params)
        state = state.apply_gradients(grads=grads)
    else:
        loss, y_pred = loss_fn(state.params)
    return loss, y_pred, state
Example #6
def create_train_state(rng, model, img_size, lr_schedule_fn, weight_decay,
                       max_norm):

    tx = optax.chain(optax.clip_by_global_norm(max_norm),
                     optax.scale_by_adam(),
                     optax.additive_weight_decay(weight_decay),
                     optax.scale_by_schedule(lr_schedule_fn))

    params = model.init(rng,
                        jax.numpy.ones((1, img_size, img_size, 3)),
                        is_training=False)

    train_state = TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx,
    )
    return train_state
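A hedged usage sketch for this factory; the model, image size, and hyperparameters below are assumptions. Note that optax.scale_by_schedule multiplies updates by the schedule's output directly, so for gradient descent the schedule passed in typically returns negative step sizes:

rng = jax.random.PRNGKey(0)
model = SomeFlaxModel()  # hypothetical flax.linen.Module accepting `is_training`

base_lr = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=10_000)
lr_schedule_fn = lambda step: -base_lr(step)  # negate so the update chain descends

state = create_train_state(rng, model, img_size=224,
                           lr_schedule_fn=lr_schedule_fn,
                           weight_decay=1e-4, max_norm=1.0)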
Example #7
def train_step(
    state: train_state.TrainState,
    trajectories: Tuple,
    batch_size: int,
    *,
    clip_param: float,
    vf_coeff: float,
    entropy_coeff: float):
  """Compilable train step.

  Runs an entire epoch of training (i.e. the loop over minibatches within
  an epoch is included here for performance reasons).

  Args:
    state: the train state
    trajectories: Tuple of the following five elements forming the experience:
                  states: shape (steps_per_agent*num_agents, 84, 84, 4)
                  actions: shape (steps_per_agent*num_agents, )
                  old_log_probs: shape (steps_per_agent*num_agents, )
                  returns: shape (steps_per_agent*num_agents, )
                  advantages: (steps_per_agent*num_agents, )
    batch_size: the minibatch size, static argument
    clip_param: the PPO clipping parameter used to clamp ratios in loss function
    vf_coeff: weighs value function loss in total loss
    entropy_coeff: weighs entropy bonus in the total loss

  Returns:
    state: the new train state after the parameter update
    loss: loss summed over training steps
  """
  iterations = trajectories[0].shape[0] // batch_size
  trajectories = jax.tree_util.tree_map(
      lambda x: x.reshape((iterations, batch_size) + x.shape[1:]), trajectories)
  loss = 0.
  for batch in zip(*trajectories):
    grad_fn = jax.value_and_grad(loss_fn)
    l, grads = grad_fn(state.params, state.apply_fn, batch, clip_param, vf_coeff,
                      entropy_coeff)
    loss += l
    state = state.apply_gradients(grads=grads)
  return state, loss
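The `loss_fn` differentiated in this loop is not included in the snippet. A sketch of a standard PPO clipped-surrogate loss matching the call signature above, under the assumption that the policy returns per-action log-probabilities and a value estimate:

def loss_fn(params, apply_fn, minibatch, clip_param, vf_coeff, entropy_coeff):
    states, actions, old_log_probs, returns, advantages = minibatch
    # Assumed model output: log-probs over actions and a value head.
    log_probs, values = apply_fn({'params': params}, states)
    values = values[:, 0]
    # Log-probability of the actions actually taken.
    taken_log_probs = jax.vmap(lambda lp, a: lp[a])(log_probs, actions)
    ratios = jnp.exp(taken_log_probs - old_log_probs)
    # Clipped surrogate objective.
    clipped_ratios = jnp.clip(ratios, 1. - clip_param, 1. + clip_param)
    ppo_loss = -jnp.mean(jnp.minimum(ratios * advantages,
                                     clipped_ratios * advantages))
    value_loss = jnp.mean(jnp.square(returns - values))
    entropy = -jnp.mean(jnp.sum(jnp.exp(log_probs) * log_probs, axis=1))
    return ppo_loss + vf_coeff * value_loss - entropy_coeff * entropy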
Example #8
def train_step(
    train_state: ts.TrainState,
    model_vars: Dict[str, Any],
    batch: Dict[str, Any],
    dropout_rng: jnp.ndarray,
    model_config: ml_collections.FrozenConfigDict,
) -> Tuple[ts.TrainState, Dict[str, Any]]:
  """Perform a single training step.

  Args:
    train_state: contains model params, loss fn, grad update fn.
    model_vars: model variables that are not optimized.
    batch: input to model.
    dropout_rng: seed for dropout rng in model.
    model_config: contains model hyperparameters.

  Returns:
    Train state with updated parameters and dictionary of metrics.
  """

  dropout_rng = jax.random.fold_in(dropout_rng, train_state.step)

  def loss_fn_partial(model_params):
    loss, metrics, _ = train_state.apply_fn(
        model_config,
        model_params,
        model_vars,
        batch,
        deterministic=False,
        dropout_rng={'dropout': dropout_rng},
    )
    return loss, metrics

  grad_fn = jax.value_and_grad(loss_fn_partial, has_aux=True)
  (_, metrics), grad = grad_fn(train_state.params)
  grad = jax.lax.pmean(grad, 'batch')
  metrics = jax.lax.psum(metrics, axis_name='batch')
  metrics = metric_utils.update_metrics_dtype(metrics)
  new_train_state = train_state.apply_gradients(grads=grad)
  return new_train_state, metrics
Example #9
def collate_fn(batch):
    inputs = np.stack([x[0] for x in batch], axis=0)
    labels = np.array([x[1] for x in batch])
    return {"inputs": inputs, "labels": labels}


if __name__ == "__main__":
    num_epochs = 3
    rng = random.PRNGKey(42)

    model = Model()
    dummy_inputs = np.ones((1, 28, 28, 1), np.float32)
    params = model.init(rng, dummy_inputs)["params"]
    tx = optax.adam(learning_rate=1e-3)
    train_state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

    train_loader = DataLoader(MNIST("data/mnist", download=True),
                              batch_size=64,
                              shuffle=True)
    eval_loader = DataLoader(MNIST("data/mnist", train=False), batch_size=64)

    p_train_state = replicate(train_state)
    p_train_step = jax.pmap(train_step, "batch")
    p_eval_step = jax.pmap(eval_step, "batch")

    for epoch in range(num_epochs):
        print(f"\nEpoch: {epoch}")
        rng, input_rng = random.split(rng)
        train_metrics = []
Example #10
class Net(flax.linen.Module):
    @flax.linen.compact
    def __call__(self, x):
        x = flax.linen.Dense(128)(x)
        x = flax.linen.relu(x)
        x = flax.linen.Dense(32)(x)
        x = flax.linen.relu(x)
        x = flax.linen.Dense(10)(x)
        x = flax.linen.log_softmax(x)
        return x


model = Net()
params = model.init(jax.random.PRNGKey(42), numpy.ones((1, 28 * 28)))["params"]
optimizer = optax.adam(0.001)
state = TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)


@functools.partial(jax.jit, static_argnums=(3, ))
def step(x, y, state: TrainState, training: bool):
    def loss_fn(params):
        y_pred = model.apply({"params": params}, x)
        y_one_hot = jax.nn.one_hot(y, 10)
        loss = optax.softmax_cross_entropy(y_pred, y_one_hot).mean()
        return loss, y_pred

    x = x.reshape(-1, 28 * 28)
    if training:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, y_pred), grads = grad_fn(state.params)
        state = state.apply_gradients(grads=grads)
    else:
        loss, y_pred = loss_fn(state.params)
    return loss, y_pred, state
Example #11
def train(base_dir, config):
    """Train function."""
    print(config)
    chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'train'))

    writer = create_default_writer()

    # Initialize dataset
    key = jax.random.PRNGKey(config.seed)
    key, subkey = jax.random.split(key)
    ds = dataset.get_dataset(config, subkey, num_tasks=config.num_tasks)
    ds_iter = iter(ds)

    key, subkey = jax.random.split(key)
    encoder = MLPEncoder(**config.encoder)

    train_config = config.train.to_dict()
    train_method = train_config.pop('method')

    module_config = train_config.pop('module')
    module_class = module_config.pop('name')

    module = globals().get(module_class)(encoder, **module_config)
    train_step = globals().get(f'train_step_{train_method}')
    train_step = functools.partial(train_step, **train_config)

    params = module.init(subkey, next(ds_iter)[0])
    lr = optax.cosine_decay_schedule(config.learning_rate,
                                     config.num_train_steps)
    optim = optax.chain(optax.adam(lr),
                        # optax.adaptive_grad_clip(0.15)
                        )

    state = TrainState.create(apply_fn=module.apply, params=params, tx=optim)
    state = chkpt_manager.restore_or_initialize(state)

    # Hooks
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    hooks = [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
    ]

    def handle_preemption(signal_number, _):
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with report_progress.timed('checkpointing'):
            chkpt_manager.save(state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

    metrics = TrainMetrics.empty()
    with metric_writers.ensure_flushes(writer):
        for step in tqdm.tqdm(range(state.step, config.num_train_steps)):
            with jax.profiler.StepTraceAnnotation('train', step_num=step):
                states, targets = next(ds_iter)
                state, metrics = train_step(state, metrics, states, targets)

            logging.log_first_n(logging.INFO, 'Finished training step %d', 5,
                                step)

            if step % config.log_metrics_every == 0:
                writer.write_scalars(step, metrics.compute())
                metrics = TrainMetrics.empty()

            # if step % config.log_eval_metrics_every == 0 and isinstance(
            #     ds, dataset.MDPDataset):
            #   eval_metrics = evaluate_mdp(state, ds.aux_task_matrix, config)
            #   writer.write_scalars(step, eval_metrics.compute())

            for hook in hooks:
                hook(step)

    chkpt_manager.save(state)
    return state
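The TrainMetrics collection used above appears only through empty() and compute(); one plausible definition, assuming it is built on clu.metrics (a sketch, not the original code):

import flax
from clu import metrics as clu_metrics

@flax.struct.dataclass
class TrainMetrics(clu_metrics.Collection):
    # Running average of the training loss; train_step would merge new values
    # in via TrainMetrics.single_from_model_output(loss=...).
    loss: clu_metrics.Average.from_output('loss')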
Example #12
def evaluate(base_dir, config, *, train_state):
    """Eval function."""
    chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'eval'))

    writer = create_default_writer()

    key = jax.random.PRNGKey(config.eval.seed)
    model_init_key, ds_key = jax.random.split(key)

    linear_module = LinearModule(config.eval.num_tasks)
    params = linear_module.init(model_init_key,
                                jnp.zeros((config.encoder.embedding_dim, )))
    lr = optax.cosine_decay_schedule(config.eval.learning_rate,
                                     config.num_eval_steps)
    optim = optax.adam(lr)

    ds = dataset.get_dataset(config, ds_key, num_tasks=config.eval.num_tasks)
    ds_iter = iter(ds)

    state = TrainState.create(apply_fn=linear_module.apply,
                              params=params,
                              tx=optim)
    state = chkpt_manager.restore_or_initialize(state)

    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_eval_steps, writer=writer)
    hooks = [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
    ]

    def handle_preemption(signal_number, _):
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with report_progress.timed('checkpointing'):
            chkpt_manager.save(state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

    metrics = EvalMetrics.empty()
    with metric_writers.ensure_flushes(writer):
        for step in tqdm.tqdm(range(state.step, config.num_eval_steps)):
            with jax.profiler.StepTraceAnnotation('eval', step_num=step):
                states, targets = next(ds_iter)
                state, metrics = evaluate_step(train_state, state, metrics,
                                               states, targets)

            if step % config.log_metrics_every == 0:
                writer.write_scalars(step, metrics.compute())
                metrics = EvalMetrics.empty()

            for hook in hooks:
                hook(step)

        # Finally, evaluate on the true(ish) test aux task matrix.
        states, targets = dataset.EvalDataset(config, ds_key).get_batch()

        @jax.jit
        def loss_fn():
            outputs = train_state.apply_fn(train_state.params, states)
            phis = outputs.phi
            predictions = jax.vmap(state.apply_fn,
                                   in_axes=(None, 0))(state.params, phis)
            return jnp.mean(optax.l2_loss(predictions, targets))

        test_loss = loss_fn()
        writer.write_scalars(config.num_eval_steps + 1,
                             {'test_loss': test_loss})