Example #1
def train_for_one_epoch(
    dataset_source: dataset_source_lib.DatasetSource,
    optimizer: flax.optim.Optimizer, state: flax.nn.Collection,
    prng_key: jnp.ndarray, pmapped_train_step: _TrainStep,
    pmapped_update_ema: Optional[_EMAUpdateStep],
    moving_averages: Optional[efficientnet_optim.ExponentialMovingAverage],
    summary_writer: tensorboard.SummaryWriter
) -> Tuple[flax.optim.Optimizer, flax.nn.Collection,
           Optional[efficientnet_optim.ExponentialMovingAverage]]:
  """Trains the model for one epoch.

  Args:
    dataset_source: Container for the training dataset.
    optimizer: The optimizer targeting the model to train.
    state: Current state associated with the model (contains the batch norm
      moving averages).
    prng_key: A PRNG key to use for stochasticity (e.g. for sampling a dropout
      mask, if any). It is not used for shuffling the dataset.
    pmapped_train_step: A pmapped version of the `train_step` function (see its
      documentation for more details).
    pmapped_update_ema: Function to update the parameter moving average. Can be
      None if we don't use EMA.
    moving_averages: Parameters moving average if used.
    summary_writer: A Tensorboard SummaryWriter to use to log metrics.

  Returns:
    The updated optimizer (with the associated updated model), the updated
      state, and the updated parameter moving averages (None if EMA is not
      used).
  """
  start_time = time.time()
  cnt = 0
  train_metrics = []
  for batch in dataset_source.get_train(use_augmentations=True):
    # Generate a PRNG key that will be rolled into the batch.
    step_key = jax.random.fold_in(prng_key, optimizer.state.step[0])
    # Load and shard the TF batch.
    batch = tensorflow_to_numpy(batch)
    batch = shard_batch(batch)
    # Shard the step PRNG key.
    sharded_keys = common_utils.shard_prng_key(step_key)

    optimizer, state, metrics, lr = pmapped_train_step(
        optimizer, state, batch, sharded_keys)
    cnt += 1

    if moving_averages is not None:
      moving_averages = pmapped_update_ema(optimizer, state, moving_averages)

    train_metrics.append(metrics)
  train_metrics = common_utils.get_metrics(train_metrics)
  # Get training epoch summary for logging.
  train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
  train_summary['learning_rate'] = lr[0]
  current_step = int(optimizer.state.step[0])
  info = 'Whole training epoch done in {} ({} steps)'.format(
      time.time() - start_time, cnt)
  logging.info(info)
  for metric_name, metric_value in train_summary.items():
    summary_writer.scalar(metric_name, metric_value, current_step)
  summary_writer.flush()
  return optimizer, state, moving_averages
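
The per-step randomness in this epoch loop comes from two small JAX/Flax utilities: jax.random.fold_in derives a deterministic per-step key from a fixed base key and the current step number, and flax.training.common_utils.shard_prng_key splits that key into one sub-key per local device before it is passed to the pmapped step. A minimal, self-contained sketch of that pattern (the step value below is just an illustrative stand-in for optimizer.state.step[0]):

import jax
from flax.training import common_utils

base_key = jax.random.PRNGKey(0)
step = 7                                       # stands in for optimizer.state.step[0]
step_key = jax.random.fold_in(base_key, step)  # deterministic per-step key
sharded_keys = common_utils.shard_prng_key(step_key)  # one sub-key per local device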
Example #2
def local_train_loop(key,
                     init_params,
                     loss_fn,
                     summarize_fn=default_summarize,
                     lr=1e-4,
                     num_steps=int(1e5),
                     summarize_every=100,
                     checkpoint_every=5000,
                     clobber_checkpoint=False,
                     logdir="/tmp/lda_inference"):

    optimizer_def = optim.Adam()
    optimizer = optimizer_def.create(init_params)
    optimizer = util.maybe_load_checkpoint(
        logdir, optimizer, clobber_checkpoint=clobber_checkpoint)
    lr_fn = util.create_learning_rate_scheduler(base_learning_rate=lr)

    def train_step(optimizer, key):
        loss_val, loss_grad = jax.value_and_grad(loss_fn,
                                                 argnums=0)(optimizer.target,
                                                            key)
        new_optimizer = optimizer.apply_gradient(loss_grad,
                                                 learning_rate=lr_fn(
                                                     optimizer.state.step))
        return loss_val, new_optimizer

    train_step = jit(train_step)

    sw = SummaryWriter(logdir)

    start = timeit.default_timer()
    first_step = optimizer.state.step
    for t in range(optimizer.state.step, num_steps):
        if t % checkpoint_every == 0 and t != first_step:
            checkpoints.save_checkpoint(logdir,
                                        optimizer,
                                        optimizer.state.step,
                                        keep=3)
            print("Checkpoint saved for step %d" % optimizer.state.step)
        key, subkey = jax.random.split(key)
        try:
            loss_val, new_optimizer = train_step(optimizer, subkey)
        except FloatingPointError as e:
            print("Exception on step %d" % t)
            print(e)
            traceback.print_exc()
            checkpoints.save_checkpoint(logdir,
                                        optimizer,
                                        optimizer.state.step,
                                        keep=3)
            print("Checkpoint saved for step %d" % optimizer.state.step)
            print("key ", subkey)
            sys.stdout.flush()
            sys.exit(1)
        optimizer = new_optimizer
        if t % summarize_every == 0:
            key, subkey = jax.random.split(key)
            print("Step %d loss: %0.4f" % (t, loss_val))
            sw.scalar("loss", loss_val, step=t)
            summarize_fn(sw, t, optimizer.target, subkey)
            end = timeit.default_timer()
            if t == 0:
                steps_per_sec = 1. / (end - start)
            else:
                steps_per_sec = summarize_every / (end - start)
            print("Steps/sec: %0.2f" % steps_per_sec)
            sw.scalar("steps_per_sec", steps_per_sec, step=t)
            start = end
            sw.flush()
            sys.stdout.flush()
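
The heart of this loop is the jitted train_step: jax.value_and_grad computes the loss and its gradient in a single pass, and the gradient is applied through the Flax optimizer with a scheduled learning rate. Below is a self-contained toy version of that step, with a quadratic loss standing in for the real loss_fn and a plain SGD update standing in for optim.Adam; all names here are illustrative.

import jax
import jax.numpy as jnp


def toy_loss(params, key):
    # Quadratic loss around a random target, so the step still consumes a PRNG key.
    target = jax.random.normal(key, params.shape)
    return jnp.sum((params - target) ** 2)


@jax.jit
def toy_train_step(params, key, lr=1e-2):
    loss_val, grads = jax.value_and_grad(toy_loss, argnums=0)(params, key)
    return loss_val, params - lr * grads  # plain SGD update in place of optim.Adam


params = jnp.zeros(3)
loss_val, params = toy_train_step(params, jax.random.PRNGKey(0))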
Example #3
def parallel_train_loop(key,
                        init_params,
                        loss_fn,
                        summarize_fn=default_summarize,
                        lr=1e-4,
                        num_steps=int(1e5),
                        summarize_every=100,
                        checkpoint_every=5000,
                        clobber_checkpoint=False,
                        logdir="/tmp/lda_inference"):

    loss_fn = jax.jit(loss_fn)

    optimizer_def = optim.Adam()
    local_optimizer = optimizer_def.create(init_params)
    local_optimizer = util.maybe_load_checkpoint(
        logdir, local_optimizer, clobber_checkpoint=clobber_checkpoint)
    first_step = local_optimizer.state.step
    repl_optimizer = jax_utils.replicate(local_optimizer)

    lr_fn = util.create_learning_rate_scheduler(base_learning_rate=lr)

    @functools.partial(jax.pmap, axis_name="batch")
    def train_step(optimizer, key):
        key, subkey = jax.random.split(key)
        loss_grad = jax.grad(loss_fn, argnums=0)(optimizer.target, key)
        loss_grad = jax.lax.pmean(loss_grad, "batch")
        new_optimizer = optimizer.apply_gradient(loss_grad,
                                                 learning_rate=lr_fn(
                                                     optimizer.state.step))
        return new_optimizer, subkey

    sw = SummaryWriter(logdir)

    repl_key = jax.pmap(jax.random.PRNGKey)(jnp.arange(
        jax.local_device_count()))
    start = timeit.default_timer()
    for t in range(first_step, num_steps):
        if t % checkpoint_every == 0 and t != first_step:
            optimizer = jax_utils.unreplicate(repl_optimizer)
            checkpoints.save_checkpoint(logdir,
                                        optimizer,
                                        optimizer.state.step,
                                        keep=3)
            print("Checkpoint saved for step %d" % optimizer.state.step)

        repl_optimizer, repl_key = train_step(repl_optimizer, repl_key)

        if t % summarize_every == 0:
            key, subkey = jax.random.split(jax_utils.unreplicate(repl_key))
            optimizer = jax_utils.unreplicate(repl_optimizer)
            loss_val = loss_fn(optimizer.target, key)
            print("Step %d loss: %0.4f" % (t, loss_val))
            sw.scalar("loss", loss_val, step=t)
            summarize_fn(sw, t, optimizer.target, subkey)
            end = timeit.default_timer()
            if t == 0:
                steps_per_sec = 1. / (end - start)
            else:
                steps_per_sec = summarize_every / (end - start)
            print("Steps/sec: %0.2f" % steps_per_sec)
            sw.scalar("steps_per_sec", steps_per_sec, step=t)
            start = end
            sw.flush()
            sys.stdout.flush()
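
This parallel loop uses the standard Flax/JAX data-parallel recipe: replicate the optimizer across local devices with jax_utils.replicate, pmap the update step, average the per-device gradients with jax.lax.pmean, and unreplicate whenever a single host-side copy is needed (for checkpoints and summaries). A self-contained toy sketch of that pattern on a quadratic loss; it runs on however many local devices are available, possibly just one.

import functools

import jax
import jax.numpy as jnp
from flax import jax_utils

LR = 1e-2


def toy_loss(params, x):
    return jnp.sum((params - x) ** 2)


@functools.partial(jax.pmap, axis_name="batch")
def toy_step(params, x):
    grads = jax.grad(toy_loss)(params, x)
    grads = jax.lax.pmean(grads, axis_name="batch")  # average grads across devices
    return params - LR * grads


n_dev = jax.local_device_count()
params = jax_utils.replicate(jnp.zeros(3))  # one copy of the params per device
x = jnp.ones((n_dev, 3))                    # one data shard per device
params = toy_step(params, x)
params = jax_utils.unreplicate(params)      # back to a single host-side copy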
Example #4
  def test_summarywriter_flush_after_close(self):
    # Flushing a writer that has already been closed should raise AttributeError.
    log_dir = tempfile.mkdtemp()
    summary_writer = SummaryWriter(log_dir=log_dir)
    summary_writer.close()
    with self.assertRaises(AttributeError):
      summary_writer.flush()