Example #1
def train_for_one_epoch(
    dataset_source: dataset_source_lib.DatasetSource,
    optimizer: flax.optim.Optimizer, state: flax.nn.Collection,
    prng_key: jnp.ndarray, pmapped_train_step: _TrainStep,
    pmapped_update_ema: Optional[_EMAUpdateStep],
    moving_averages: Optional[efficientnet_optim.ExponentialMovingAverage],
    summary_writer: tensorboard.SummaryWriter
) -> Tuple[flax.optim.Optimizer, flax.nn.Collection,
           Optional[efficientnet_optim.ExponentialMovingAverage]]:
  """Trains the model for one epoch.

  Args:
    dataset_source: Container for the training dataset.
    optimizer: The optimizer targeting the model to train.
    state: Current state associated with the model (contains the batch norm
      moving averages).
    prng_key: A PRNG key used for stochasticity (e.g. for sampling a dropout
      mask). It is not used for shuffling the dataset.
    pmapped_train_step: A pmapped version of the `train_step` function (see its
      documentation for more details).
    pmapped_update_ema: Function to update the parameter moving average. Can be
      None if we don't use EMA.
    moving_averages: Parameters moving average if used.
    summary_writer: A Tensorboard SummaryWriter to use to log metrics.

  Returns:
    The updated optimizer (with the associated updated model), the updated
      state, and the updated parameter moving averages (or None if unused).
  """
  start_time = time.time()
  cnt = 0
  train_metrics = []
  for batch in dataset_source.get_train(use_augmentations=True):
    # Derive the per-step PRNG key by folding the current step into the base key.
    step_key = jax.random.fold_in(prng_key, optimizer.state.step[0])
    # Load and shard the TF batch.
    batch = tensorflow_to_numpy(batch)
    batch = shard_batch(batch)
    # Shard the step PRNG key.
    sharded_keys = common_utils.shard_prng_key(step_key)

    optimizer, state, metrics, lr = pmapped_train_step(
        optimizer, state, batch, sharded_keys)
    cnt += 1

    if moving_averages is not None:
      moving_averages = pmapped_update_ema(optimizer, state, moving_averages)

    train_metrics.append(metrics)
  train_metrics = common_utils.get_metrics(train_metrics)
  # Get training epoch summary for logging.
  train_summary = jax.tree_map(lambda x: x.mean(), train_metrics)
  train_summary['learning_rate'] = lr[0]
  current_step = int(optimizer.state.step[0])
  info = 'Training epoch done in {:.2f}s ({} steps)'.format(
      time.time() - start_time, cnt)
  logging.info(info)
  for metric_name, metric_value in train_summary.items():
    summary_writer.scalar(metric_name, metric_value, current_step)
  summary_writer.flush()
  return optimizer, state, moving_averages
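Example #1 relies on project helpers such as `shard_batch` and `common_utils.shard_prng_key` to split data across devices for the pmapped train step. A minimal sketch of what such a sharding helper typically does (an assumption, not the project's actual implementation; it requires the leading batch axis to be divisible by the local device count):

import jax
import numpy as np

def shard_batch_sketch(batch):
  # Sketch (assumption): not the project's actual shard_batch.
  # Reshape every leaf from [batch, ...] to [num_devices, batch // num_devices, ...]
  # so that a jax.pmap-ed step receives one slice per local device.
  n = jax.local_device_count()
  return jax.tree_util.tree_map(
      lambda x: x.reshape((n, x.shape[0] // n) + x.shape[1:]), batch)

batch = {'image': np.zeros((8, 32, 32, 3)), 'label': np.zeros((8,), np.int32)}
sharded = shard_batch_sketch(batch)  # each leaf now has a leading device axis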
Example #2
def process_iterator(tag: str,
                     item_ids: Sequence[str],
                     iterator,
                     rng: types.PRNGKey,
                     state: model_utils.TrainState,
                     step: int,
                     render_fn: Any,
                     summary_writer: tensorboard.SummaryWriter,
                     save_dir: Optional[gpath.GPath],
                     datasource: datasets.DataSource):
  """Process a dataset iterator and compute metrics."""
  save_dir = save_dir / f'{step:08d}' / tag if save_dir else None
  meters = collections.defaultdict(utils.ValueMeter)
  for i, (item_id, batch) in enumerate(zip(item_ids, iterator)):
    logging.info('[%s:%d/%d] Processing %s ', tag, i+1, len(item_ids), item_id)
    if tag == 'test':
      test_rng = random.PRNGKey(step)
      shape = batch['origins'][..., :1].shape
      metadata = {}
      if datasource.use_appearance_id:
        appearance_id = random.choice(
            test_rng, jnp.asarray(datasource.appearance_ids))
        logging.info('\tUsing appearance_id = %d', appearance_id)
        metadata['appearance'] = jnp.full(shape, fill_value=appearance_id,
                                          dtype=jnp.uint32)
      if datasource.use_warp_id:
        warp_id = random.choice(test_rng, jnp.asarray(datasource.warp_ids))
        logging.info('\tUsing warp_id = %d', warp_id)
        metadata['warp'] = jnp.full(shape, fill_value=warp_id, dtype=jnp.uint32)
      if datasource.use_camera_id:
        camera_id = random.choice(test_rng, jnp.asarray(datasource.camera_ids))
        logging.info('\tUsing camera_id = %d', camera_id)
        metadata['camera'] = jnp.full(shape, fill_value=camera_id,
                                      dtype=jnp.uint32)
      if datasource.use_time:
        timestamp = random.uniform(test_rng, minval=0.0, maxval=1.0)
        logging.info('\tUsing time = %f', timestamp)
        metadata['time'] = jnp.full(
            shape, fill_value=timestamp, dtype=jnp.float32)
      batch['metadata'] = metadata

    stats = process_batch(batch=batch,
                          rng=rng,
                          state=state,
                          tag=tag,
                          item_id=item_id,
                          step=step,
                          render_fn=render_fn,
                          summary_writer=summary_writer,
                          save_dir=save_dir,
                          datasource=datasource)
    if jax.process_index() == 0:
      for k, v in stats.items():
        meters[k].update(v)

  if jax.process_index() == 0:
    for meter_name, meter in meters.items():
      summary_writer.scalar(tag=f'metrics-eval/{meter_name}/{tag}',
                            value=meter.reduce('mean'),
                            step=step)
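Example #2 aggregates per-item statistics with `utils.ValueMeter` before logging one scalar per metric. A minimal sketch of such a meter (hypothetical; the project's own `ValueMeter` may differ):

import numpy as np

class ValueMeter:
  """Hypothetical sketch: accumulates scalar values and reduces them."""

  def __init__(self):
    self._values = []

  def update(self, value):
    self._values.append(float(value))

  def reduce(self, mode='mean'):
    if mode == 'mean':
      return float(np.mean(self._values))
    raise ValueError(f'Unsupported reduction: {mode}')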
Example #3
  def test_summarywriter_scalar(self):
    log_dir = tempfile.mkdtemp()
    summary_writer = SummaryWriter(log_dir=log_dir)
    # Write the scalar, then read the event back and verify its value.
    float_value = 99.1232
    summary_writer.scalar(tag='scalar_test', value=float_value, step=1)

    summary_value = self.parse_and_return_summary_value(path=log_dir)
    self.assertEqual(summary_value.tag, 'scalar_test')
    self.assertTrue(onp.allclose(
        tensor_util.make_ndarray(summary_value.tensor).item(),
        float_value))
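The test assumes a helper `parse_and_return_summary_value` that reads the event file back from disk. A minimal sketch of how such a helper could be written with TensorFlow's event iterator (an assumption, not Flax's actual test helper):

import glob
import os
import tensorflow as tf

def parse_and_return_summary_value(path):
  # Sketch (assumption): not the actual test helper.
  # Find the event file written by the SummaryWriter and return the first
  # summary value it contains.
  event_file = glob.glob(os.path.join(path, 'events.out.tfevents.*'))[0]
  for event in tf.compat.v1.train.summary_iterator(event_file):
    for value in event.summary.value:
      return value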
Example #4
def local_train_loop(key,
                     init_params,
                     loss_fn,
                     summarize_fn=default_summarize,
                     lr=1e-4,
                     num_steps=int(1e5),
                     summarize_every=100,
                     checkpoint_every=5000,
                     clobber_checkpoint=False,
                     logdir="/tmp/lda_inference"):

    optimizer_def = optim.Adam()
    optimizer = optimizer_def.create(init_params)
    optimizer = util.maybe_load_checkpoint(
        logdir, optimizer, clobber_checkpoint=clobber_checkpoint)
    lr_fn = util.create_learning_rate_scheduler(base_learning_rate=lr)

    def train_step(optimizer, key):
        loss_val, loss_grad = jax.value_and_grad(loss_fn,
                                                 argnums=0)(optimizer.target,
                                                            key)
        new_optimizer = optimizer.apply_gradient(loss_grad,
                                                 learning_rate=lr_fn(
                                                     optimizer.state.step))
        return loss_val, new_optimizer

    train_step = jit(train_step)

    sw = SummaryWriter(logdir)

    start = timeit.default_timer()
    first_step = int(optimizer.state.step)
    for t in range(first_step, num_steps):
        if t % checkpoint_every == 0 and t != first_step:
            checkpoints.save_checkpoint(logdir,
                                        optimizer,
                                        optimizer.state.step,
                                        keep=3)
            print("Checkpoint saved for step %d" % optimizer.state.step)
        key, subkey = jax.random.split(key)
        try:
            loss_val, new_optimizer = train_step(optimizer, subkey)
        except FloatingPointError as e:
            print("Exception on step %d" % t)
            print(e)
            traceback.print_exc()
            checkpoints.save_checkpoint(logdir,
                                        optimizer,
                                        optimizer.state.step,
                                        keep=3)
            print("Checkpoint saved for step %d" % optimizer.state.step)
            print("key ", subkey)
            sys.stdout.flush()
            sys.exit(1)
        optimizer = new_optimizer
        if t % summarize_every == 0:
            key, subkey = jax.random.split(key)
            print("Step %d loss: %0.4f" % (t, loss_val))
            sw.scalar("loss", loss_val, step=t)
            summarize_fn(sw, t, optimizer.target, subkey)
            end = timeit.default_timer()
            if t == 0:
                steps_per_sec = 1. / (end - start)
            else:
                steps_per_sec = summarize_every / (end - start)
            print("Steps/sec: %0.2f" % steps_per_sec)
            sw.scalar("steps_per_sec", steps_per_sec, step=t)
            start = end
            sw.flush()
            sys.stdout.flush()
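Example #4 calls `summarize_fn(sw, t, optimizer.target, subkey)` every `summarize_every` steps; its default, `default_summarize`, is project-specific. A hypothetical summarize function compatible with that call signature, logging only the global parameter norm:

import jax
import jax.numpy as jnp

def summarize_param_norm(summary_writer, step, params, rng):
  # Hypothetical summarize_fn; not the project's default_summarize.
  del rng  # accepted for signature compatibility, unused in this sketch
  leaves = jax.tree_util.tree_leaves(params)
  param_norm = jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))
  summary_writer.scalar('param_norm', param_norm, step=step)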
Example #5
def parallel_train_loop(key,
                        init_params,
                        loss_fn,
                        summarize_fn=default_summarize,
                        lr=1e-4,
                        num_steps=int(1e5),
                        summarize_every=100,
                        checkpoint_every=5000,
                        clobber_checkpoint=False,
                        logdir="/tmp/lda_inference"):

    loss_fn = jax.jit(loss_fn)

    optimizer_def = optim.Adam()
    local_optimizer = optimizer_def.create(init_params)
    local_optimizer = util.maybe_load_checkpoint(
        logdir, local_optimizer, clobber_checkpoint=clobber_checkpoint)
    first_step = local_optimizer.state.step
    repl_optimizer = jax_utils.replicate(local_optimizer)

    lr_fn = util.create_learning_rate_scheduler(base_learning_rate=lr)

    @functools.partial(jax.pmap, axis_name="batch")
    def train_step(optimizer, key):
        key, subkey = jax.random.split(key)
        loss_grad = jax.grad(loss_fn, argnums=0)(optimizer.target, key)
        loss_grad = jax.lax.pmean(loss_grad, "batch")
        new_optimizer = optimizer.apply_gradient(loss_grad,
                                                 learning_rate=lr_fn(
                                                     optimizer.state.step))
        return new_optimizer, subkey

    sw = SummaryWriter(logdir)

    repl_key = jax.pmap(jax.random.PRNGKey)(jnp.arange(
        jax.local_device_count()))
    start = timeit.default_timer()
    for t in range(first_step, num_steps):
        if t % checkpoint_every == 0 and t != first_step:
            optimizer = jax_utils.unreplicate(repl_optimizer)
            checkpoints.save_checkpoint(logdir,
                                        optimizer,
                                        optimizer.state.step,
                                        keep=3)
            print("Checkpoint saved for step %d" % optimizer.state.step)

        repl_optimizer, repl_key = train_step(repl_optimizer, repl_key)

        if t % summarize_every == 0:
            key, subkey = jax.random.split(jax_utils.unreplicate(repl_key))
            optimizer = jax_utils.unreplicate(repl_optimizer)
            loss_val = loss_fn(optimizer.target, key)
            print("Step %d loss: %0.4f" % (t, loss_val))
            sw.scalar("loss", loss_val, step=t)
            summarize_fn(sw, t, optimizer.target, subkey)
            end = timeit.default_timer()
            if t == 0:
                steps_per_sec = 1. / (end - start)
            else:
                steps_per_sec = summarize_every / (end - start)
            print("Steps/sec: %0.2f" % steps_per_sec)
            sw.scalar("steps_per_sec", steps_per_sec, step=t)
            start = end
            sw.flush()
            sys.stdout.flush()
Example #6
        for i, batch_idx in enumerate(
                tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
            samples = [
                tokenized_datasets["validation"][int(idx)] for idx in batch_idx
            ]
            model_inputs = data_collator(samples, pad_to_multiple_of=16)

            # Model forward
            model_inputs = common_utils.shard(model_inputs.data)
            metrics = p_eval_step(optimizer.target, model_inputs)
            eval_metrics.append(metrics)

        eval_metrics_np = get_metrics(eval_metrics)
        eval_metrics_np = jax.tree_map(jnp.sum, eval_metrics_np)
        eval_normalizer = eval_metrics_np.pop("normalizer")
        eval_summary = jax.tree_map(lambda x: x / eval_normalizer,
                                    eval_metrics_np)

        # Update progress bar
        epochs.desc = (
            f"Epoch... ({epoch + 1}/{nb_epochs} | Loss: {eval_summary['loss']}, Acc: {eval_summary['accuracy']})"
        )

        if wandb_args.wandb_user_name is not None:
            wandb.log({"Eval loss": np.array(eval_summary["loss"]).mean()})

        # Save metrics
        if has_tensorboard and jax.host_id() == 0:
            for name, value in eval_summary.items():
                summary_writer.scalar(name, value, epoch)
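Example #6 stacks per-step device metrics with `get_metrics` before summing and normalizing them. A minimal sketch of what such a helper commonly does in Flax example code (an assumption; `flax.training.common_utils.get_metrics` additionally takes the first replica of device-replicated values):

import jax
import numpy as np

def get_metrics_sketch(step_metrics):
  # Sketch (assumption): simplified relative to flax.training.common_utils.get_metrics.
  # step_metrics is a list of {metric_name: array} dicts, one per step.
  host_metrics = jax.device_get(step_metrics)
  # Stack each metric along a new leading step axis.
  return jax.tree_util.tree_map(lambda *xs: np.stack(xs), *host_metrics)

stacked = get_metrics_sketch([{'loss': np.float32(0.5)}, {'loss': np.float32(0.4)}])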
Example #7
def process_iterator(
    tag: str,
    item_ids: Sequence[str],
    iterator,
    rng: types.PRNGKey,
    state: model_utils.TrainState,
    step: int,
    render_fn: Any,
    summary_writer: tensorboard.SummaryWriter,
    save_dir: Optional[gpath.GPath],
    datasource: datasets.DataSource,
):
    """Process a dataset iterator and compute metrics."""
    save_dir = save_dir / f"{step:08d}" / tag if save_dir else None
    meters = collections.defaultdict(utils.ValueMeter)
    for i, (item_id, batch) in enumerate(zip(item_ids, iterator)):
        logging.info("[%s:%d/%d] Processing %s ", tag, i + 1, len(item_ids),
                     item_id)
        if tag == "test":
            test_rng = random.PRNGKey(step)
            shape = batch["origins"][..., :1].shape
            metadata = {}
            if datasource.use_appearance_id:
                appearance_id = random.choice(
                    test_rng, jnp.asarray(datasource.appearance_ids))
                logging.info("\tUsing appearance_id = %d", appearance_id)
                metadata["appearance"] = jnp.full(shape,
                                                  fill_value=appearance_id,
                                                  dtype=jnp.uint32)
            if datasource.use_warp_id:
                warp_id = random.choice(test_rng,
                                        jnp.asarray(datasource.warp_ids))
                logging.info("\tUsing warp_id = %d", warp_id)
                metadata["warp"] = jnp.full(shape,
                                            fill_value=warp_id,
                                            dtype=jnp.uint32)
            if datasource.use_camera_id:
                camera_id = random.choice(test_rng,
                                          jnp.asarray(datasource.camera_ids))
                logging.info("\tUsing camera_id = %d", camera_id)
                metadata["camera"] = jnp.full(shape,
                                              fill_value=camera_id,
                                              dtype=jnp.uint32)
            batch["metadata"] = metadata

        stats = process_batch(
            batch=batch,
            rng=rng,
            state=state,
            tag=tag,
            item_id=item_id,
            step=step,
            render_fn=render_fn,
            summary_writer=summary_writer,
            save_dir=save_dir,
            datasource=datasource,
        )
        if jax.host_id() == 0:
            for k, v in stats.items():
                meters[k].update(v)

    if jax.host_id() == 0:
        for meter_name, meter in meters.items():
            summary_writer.scalar(
                tag=f"metrics-eval/{meter_name}/{tag}",
                value=meter.reduce("mean"),
                step=step,
            )
Example #8
def _log_to_tensorboard(writer: tensorboard.SummaryWriter,
                        state: model_utils.TrainState,
                        scalar_params: training.ScalarParams,
                        stats: Dict[str, Union[Dict[str, jnp.ndarray],
                                               jnp.ndarray]],
                        time_dict: Dict[str, jnp.ndarray]):
    """Log statistics to Tensorboard."""
    step = int(state.optimizer.state.step)
    writer.scalar('params/learning_rate', scalar_params.learning_rate, step)
    writer.scalar('params/warp_alpha', state.warp_alpha, step)
    writer.scalar('params/time_alpha', state.time_alpha, step)
    writer.scalar('params/elastic_loss/weight',
                  scalar_params.elastic_loss_weight, step)

    # pmean is applied in train_step so just take the item.
    for branch in {'coarse', 'fine'}:
        if branch not in stats:
            continue
        for stat_key, stat_value in stats[branch].items():
            writer.scalar(f'{stat_key}/{branch}', stat_value, step)

    if 'background_loss' in stats:
        writer.scalar('loss/background', stats['background_loss'], step)

    for k, v in time_dict.items():
        writer.scalar(f'time/{k}', v, step)
Example #9
def _log_to_tensorboard(
    writer: tensorboard.SummaryWriter,
    state: model_utils.TrainState,
    scalar_params: training.ScalarParams,
    stats: Dict[str, Union[Dict[str, jnp.ndarray], jnp.ndarray]],
    time_dict: Dict[str, jnp.ndarray],
):
    """Log statistics to Tensorboard."""
    step = int(state.optimizer.state.step)
    writer.scalar("params/learning_rate", scalar_params.learning_rate, step)
    writer.scalar("params/warp_alpha", state.warp_alpha, step)
    writer.scalar("params/elastic_loss/weight",
                  scalar_params.elastic_loss_weight, step)

    # pmean is applied in train_step so just take the item.
    for branch in {"coarse", "fine"}:
        if branch not in stats:
            continue
        for stat_key, stat_value in stats[branch].items():
            writer.scalar(f"{stat_key}/{branch}", stat_value, step)

    if "background_loss" in stats:
        writer.scalar("losses/background", stats["background_loss"], step)

    params = state.optimizer.target["model"]
    if "appearance_encoder" in params:
        embeddings = params["appearance_encoder"]["embed"]["embedding"]
        writer.histogram("appearance_embedding", embeddings, step)
    if "camera_encoder" in params:
        embeddings = params["camera_encoder"]["embed"]["embedding"]
        writer.histogram("camera_embedding", embeddings, step)
    if "warp_field" in params:
        embeddings = params["warp_field"]["metadata_encoder"]["embed"][
            "embedding"]
        writer.histogram("warp_embedding", embeddings, step)

    for k, v in time_dict.items():
        writer.scalar(f"time/{k}", v, step)
Example #10
def _log_to_tensorboard(writer: tensorboard.SummaryWriter,
                        state: model_utils.TrainState,
                        scalar_params: training.ScalarParams,
                        stats: Dict[str, Union[Dict[str, jnp.ndarray],
                                               jnp.ndarray]],
                        time_dict: Dict[str, jnp.ndarray]):
    """Log statistics to Tensorboard."""
    step = int(state.optimizer.state.step)
    writer.scalar('params/learning_rate', scalar_params.learning_rate, step)
    writer.scalar('params/warp_alpha', state.warp_alpha, step)
    writer.scalar('params/elastic_loss/weight',
                  scalar_params.elastic_loss_weight, step)

    # pmean is applied in train_step so just take the item.
    for branch in {'coarse', 'fine'}:
        if branch not in stats:
            continue
        for stat_key, stat_value in stats[branch].items():
            writer.scalar(f'{stat_key}/{branch}', stat_value, step)

    if 'background_loss' in stats:
        writer.scalar('losses/background', stats['background_loss'], step)

    params = state.optimizer.target['model']
    if 'appearance_encoder' in params:
        embeddings = params['appearance_encoder']['embed']['embedding']
        writer.histogram('appearance_embedding', embeddings, step)
    if 'camera_encoder' in params:
        embeddings = params['camera_encoder']['embed']['embedding']
        writer.histogram('camera_embedding', embeddings, step)
    if 'warp_field' in params:
        embeddings = params['warp_field']['metadata_encoder']['embed'][
            'embedding']
        writer.histogram('warp_embedding', embeddings, step)

    for k, v in time_dict.items():
        writer.scalar(f'time/{k}', v, step)
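Examples #9 and #10 also log embedding matrices with `SummaryWriter.histogram`. A small standalone usage sketch of the Flax writer (the log directory and data are placeholders):

import numpy as np
from flax.metrics.tensorboard import SummaryWriter

# Standalone usage sketch; log dir and data are placeholders.
writer = SummaryWriter(log_dir='/tmp/tb_demo')
embeddings = np.random.normal(size=(16, 8))
writer.histogram('appearance_embedding', embeddings, step=0)
writer.scalar('losses/background', 0.123, step=0)
writer.flush()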