Exemplo n.º 1
0
def main(argv):
  """Continuously evaluates new checkpoints of an experiment.

  Loads Gin configs, builds the datasource and a pmapped render function, then
  polls `<exp_dir>/checkpoints` every 10 seconds. Whenever a checkpoint whose
  step is greater than the last evaluated step is found, the 'val', 'train'
  and (when test cameras exist) 'test' subsets are processed with
  `process_iterator`. The loop stops after one round if
  `eval_config.eval_once` is set, or once the evaluated step reaches
  `train_config.max_steps`.

  Args:
    argv: Unused positional command-line arguments.
  """
  # Hide GPUs from TensorFlow so it does not allocate accelerator memory.
  tf.config.experimental.set_visible_devices([], 'GPU')
  del argv
  logging.info('*** Starting experiment')
  gin_configs = FLAGS.gin_configs

  logging.info('*** Loading Gin configs from: %s', str(gin_configs))
  gin.parse_config_files_and_bindings(
      config_files=gin_configs,
      bindings=FLAGS.gin_bindings,
      skip_unknown=True)

  # Load configurations. Stratified sampling is turned off for evaluation.
  exp_config = configs.ExperimentConfig()
  model_config = configs.ModelConfig(use_stratified_sampling=False)
  train_config = configs.TrainConfig()
  eval_config = configs.EvalConfig()

  # Get directory information.
  exp_dir = gpath.GPath(FLAGS.base_folder)
  if exp_config.subname:
    exp_dir = exp_dir / exp_config.subname
  logging.info('\texp_dir = %s', exp_dir)
  if not exp_dir.exists():
    exp_dir.mkdir(parents=True, exist_ok=True)

  summary_dir = exp_dir / 'summaries' / 'eval'
  logging.info('\tsummary_dir = %s', summary_dir)
  if not summary_dir.exists():
    summary_dir.mkdir(parents=True, exist_ok=True)

  renders_dir = exp_dir / 'renders'
  logging.info('\trenders_dir = %s', renders_dir)
  if not renders_dir.exists():
    renders_dir.mkdir(parents=True, exist_ok=True)

  # Only read from here; the training job is expected to create it.
  checkpoint_dir = exp_dir / 'checkpoints'
  logging.info('\tcheckpoint_dir = %s', checkpoint_dir)

  # Process indices are 0 .. process_count - 1.
  logging.info('Starting host %d. There are %d hosts : %s', jax.process_index(),
               jax.process_count(), str(list(range(jax.process_count()))))
  logging.info('Found %d accelerator devices: %s.', jax.local_device_count(),
               str(jax.local_devices()))
  logging.info('Found %d total devices: %s.', jax.device_count(),
               str(jax.devices()))

  rng = random.PRNGKey(20200823)

  devices_to_use = jax.local_devices()
  n_devices = len(
      devices_to_use) if devices_to_use else jax.local_device_count()

  datasource_spec = exp_config.datasource_spec
  if datasource_spec is None:
    datasource_spec = {
        'type': exp_config.datasource_type,
        'data_dir': FLAGS.data_dir,
    }
  logging.info('Creating datasource: %s', datasource_spec)
  datasource = datasets.from_config(
      datasource_spec,
      image_scale=exp_config.image_scale,
      use_appearance_id=model_config.use_appearance_metadata,
      use_camera_id=model_config.use_camera_metadata,
      use_warp_id=model_config.use_warp,
      use_time=model_config.warp_metadata_encoder_type == 'time',
      random_seed=exp_config.random_seed,
      **exp_config.datasource_kwargs)

  # Get training IDs to evaluate.
  train_eval_ids = utils.strided_subset(
      datasource.train_ids, eval_config.num_train_eval)
  train_eval_iter = datasource.create_iterator(train_eval_ids, batch_size=0)
  val_eval_ids = utils.strided_subset(
      datasource.val_ids, eval_config.num_val_eval)
  val_eval_iter = datasource.create_iterator(val_eval_ids, batch_size=0)

  test_cameras = datasource.load_test_cameras(count=eval_config.num_test_eval)
  if test_cameras:
    test_dataset = datasource.create_cameras_dataset(test_cameras)
    test_eval_ids = [f'{x:03d}' for x in range(len(test_cameras))]
    test_eval_iter = datasets.iterator_from_dataset(test_dataset, batch_size=0)
  else:
    test_eval_ids = None
    test_eval_iter = None

  rng, key = random.split(rng)
  params = {}
  model, params['model'] = models.construct_nerf(
      key,
      model_config,
      batch_size=eval_config.chunk,
      appearance_ids=datasource.appearance_ids,
      camera_ids=datasource.camera_ids,
      warp_ids=datasource.warp_ids,
      near=datasource.near,
      far=datasource.far,
      use_warp_jacobian=False,
      use_weights=False)

  # A zero learning-rate optimizer is only a template for checkpoint restore.
  optimizer_def = optim.Adam(0.0)
  optimizer = optimizer_def.create(params)
  init_state = model_utils.TrainState(optimizer=optimizer)
  del params

  def _model_fn(key_0, key_1, params, rays_dict, warp_extra):
    out = model.apply({'params': params},
                      rays_dict,
                      warp_extra=warp_extra,
                      rngs={
                          'coarse': key_0,
                          'fine': key_1
                      },
                      mutable=False)
    return jax.lax.all_gather(out, axis_name='batch')

  pmodel_fn = jax.pmap(
      # Note rng_keys are useless in eval mode since there's no randomness.
      _model_fn,
      # Every argument (keys, params, rays, warp_extra) carries a leading
      # device axis.
      in_axes=(0, 0, 0, 0, 0),
      devices=devices_to_use,
      donate_argnums=(3,),  # Donate the 'rays' argument.
      axis_name='batch',
  )

  render_fn = functools.partial(evaluation.render_image,
                                model_fn=pmodel_fn,
                                device_count=n_devices,
                                chunk=eval_config.chunk)

  # (tag, ids, iterator) for each subset, in evaluation order.
  eval_subsets = [
      ('val', val_eval_ids, val_eval_iter),
      ('train', train_eval_ids, train_eval_iter),
  ]
  if test_eval_iter is not None:
    eval_subsets.append(('test', test_eval_ids, test_eval_iter))

  last_step = 0
  summary_writer = tensorboard.SummaryWriter(str(summary_dir))

  while True:
    if not checkpoint_dir.exists():
      logging.info('No checkpoints yet.')
      time.sleep(10)
      continue

    state = checkpoints.restore_checkpoint(checkpoint_dir, init_state)
    state = jax_utils.replicate(state, devices=devices_to_use)
    step = int(state.optimizer.state.step[0])
    if step <= last_step:
      logging.info('No new checkpoints (%d <= %d).', step, last_step)
      time.sleep(10)
      continue

    save_dir = renders_dir if eval_config.save_output else None
    for tag, item_ids, iterator in eval_subsets:
      process_iterator(tag=tag,
                       item_ids=item_ids,
                       iterator=iterator,
                       state=state,
                       rng=rng,
                       step=step,
                       render_fn=render_fn,
                       summary_writer=summary_writer,
                       save_dir=save_dir,
                       datasource=datasource)

    if save_dir:
      delete_old_renders(renders_dir, eval_config.max_render_checkpoints)

    if eval_config.eval_once:
      break
    if step >= train_config.max_steps:
      break
    last_step = step
Exemplo n.º 2
0
def main(argv):
    """Trains a NeRF model with `jax.pmap` over this host's local devices.

    Parses Gin configs, prepares the experiment directories (main host only),
    builds the training input pipeline (plus an optional background-point
    pipeline), constructs the model and optimizer, restores the latest
    checkpoint if one exists, and runs the training loop up to
    `train_config.max_steps`. Console metrics, checkpoints, TensorBoard
    summaries and histograms are written only by process 0.

    Args:
        argv: Unused positional command-line arguments.

    Raises:
        ValueError: If `train_config.batch_size` is not divisible by the total
            number of devices.
    """
    # Hide GPUs from TensorFlow so it does not allocate accelerator memory.
    tf.config.experimental.set_visible_devices([], 'GPU')
    del argv
    logging.info('*** Starting experiment')
    gin_configs = FLAGS.gin_configs

    logging.info('*** Loading Gin configs from: %s', str(gin_configs))
    gin.parse_config_files_and_bindings(config_files=gin_configs,
                                        bindings=FLAGS.gin_bindings,
                                        skip_unknown=True)

    # Load configurations.
    exp_config = configs.ExperimentConfig()
    model_config = configs.ModelConfig()
    train_config = configs.TrainConfig()

    # Only process 0 writes files / summaries to avoid duplicated output.
    is_main_host = jax.process_index() == 0

    # Get directory information.
    exp_dir = gpath.GPath(FLAGS.base_folder)
    if exp_config.subname:
        exp_dir = exp_dir / exp_config.subname
    summary_dir = exp_dir / 'summaries' / 'train'
    checkpoint_dir = exp_dir / 'checkpoints'

    # Log and create directories if this is the main host.
    if is_main_host:
        logging.info('exp_dir = %s', exp_dir)
        if not exp_dir.exists():
            exp_dir.mkdir(parents=True, exist_ok=True)

        logging.info('summary_dir = %s', summary_dir)
        if not summary_dir.exists():
            summary_dir.mkdir(parents=True, exist_ok=True)

        logging.info('checkpoint_dir = %s', checkpoint_dir)
        if not checkpoint_dir.exists():
            checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # config_str is only defined (and only used) on the main host.
        config_str = gin.operative_config_str()
        logging.info('Configuration: \n%s', config_str)
        with (exp_dir / 'config.gin').open('w') as f:
            f.write(config_str)

    # Process indices are 0 .. process_count - 1.
    logging.info('Starting host %d. There are %d hosts : %s',
                 jax.process_index(), jax.process_count(),
                 str(list(range(jax.process_count()))))
    logging.info('Found %d accelerator devices: %s.', jax.local_device_count(),
                 str(jax.local_devices()))
    logging.info('Found %d total devices: %s.', jax.device_count(),
                 str(jax.devices()))

    rng = random.PRNGKey(exp_config.random_seed)
    # Shift the numpy random seed by the process index to shuffle data loaded
    # by different hosts.
    np.random.seed(exp_config.random_seed + jax.process_index())

    if train_config.batch_size % jax.device_count() != 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices.')

    devices = jax.local_devices()
    datasource_spec = exp_config.datasource_spec
    if datasource_spec is None:
        datasource_spec = {
            'type': exp_config.datasource_type,
            'data_dir': FLAGS.data_dir,
        }
    logging.info('Creating datasource: %s', datasource_spec)
    datasource = datasets.from_config(
        datasource_spec,
        image_scale=exp_config.image_scale,
        use_appearance_id=model_config.use_appearance_metadata,
        use_camera_id=model_config.use_camera_metadata,
        use_warp_id=model_config.use_warp,
        use_time=model_config.warp_metadata_encoder_type == 'time',
        random_seed=exp_config.random_seed,
        **exp_config.datasource_kwargs)
    train_iter = datasource.create_iterator(
        datasource.train_ids,
        flatten=True,
        shuffle=True,
        batch_size=train_config.batch_size,
        prefetch_size=3,
        shuffle_buffer_size=train_config.shuffle_buffer_size,
        devices=devices,
    )

    points_iter = None
    if train_config.use_background_loss:
        points = datasource.load_points(shuffle=True)
        # Cap at the number of available points and round down to a multiple
        # of the local device count so the batch shards evenly.
        points_batch_size = min(
            len(points),
            len(devices) * train_config.background_points_batch_size)
        points_batch_size -= points_batch_size % len(devices)
        points_dataset = tf.data.Dataset.from_tensor_slices(points)
        points_iter = datasets.iterator_from_dataset(
            points_dataset,
            batch_size=points_batch_size,
            prefetch_size=3,
            devices=devices)

    learning_rate_sched = schedules.from_config(train_config.lr_schedule)
    warp_alpha_sched = schedules.from_config(train_config.warp_alpha_schedule)
    time_alpha_sched = schedules.from_config(train_config.time_alpha_schedule)
    elastic_loss_weight_sched = schedules.from_config(
        train_config.elastic_loss_weight_schedule)

    rng, key = random.split(rng)
    params = {}
    model, params['model'] = models.construct_nerf(
        key,
        model_config,
        batch_size=train_config.batch_size,
        appearance_ids=datasource.appearance_ids,
        camera_ids=datasource.camera_ids,
        warp_ids=datasource.warp_ids,
        near=datasource.near,
        far=datasource.far,
        use_warp_jacobian=train_config.use_elastic_loss,
        use_weights=train_config.use_elastic_loss)

    optimizer_def = optim.Adam(learning_rate_sched(0))
    optimizer = optimizer_def.create(params)
    state = model_utils.TrainState(optimizer=optimizer,
                                   warp_alpha=warp_alpha_sched(0),
                                   time_alpha=time_alpha_sched(0))
    scalar_params = training.ScalarParams(
        learning_rate=learning_rate_sched(0),
        elastic_loss_weight=elastic_loss_weight_sched(0),
        warp_reg_loss_weight=train_config.warp_reg_loss_weight,
        warp_reg_loss_alpha=train_config.warp_reg_loss_alpha,
        warp_reg_loss_scale=train_config.warp_reg_loss_scale,
        background_loss_weight=train_config.background_loss_weight)
    state = checkpoints.restore_checkpoint(checkpoint_dir, state)
    # Resume right after the restored step; int() turns the scalar step into
    # a plain Python int for range().
    init_step = int(state.optimizer.state.step) + 1
    state = jax_utils.replicate(state, devices=devices)
    del params

    logging.info('Initializing models')

    summary_writer = None
    if is_main_host:
        summary_writer = tensorboard.SummaryWriter(str(summary_dir))
        summary_writer.text('gin/train',
                            textdata=gin.config.markdown(config_str),
                            step=0)

    train_step = functools.partial(
        training.train_step,
        model,
        elastic_reduce_method=train_config.elastic_reduce_method,
        elastic_loss_type=train_config.elastic_loss_type,
        use_elastic_loss=train_config.use_elastic_loss,
        use_background_loss=train_config.use_background_loss,
        use_warp_reg_loss=train_config.use_warp_reg_loss,
    )
    ptrain_step = jax.pmap(
        train_step,
        axis_name='batch',
        devices=devices,
        # rng_key, state, batch are per-device; scalar_params is broadcast.
        in_axes=(0, 0, 0, None),
        donate_argnums=(2, ),  # Donate the 'batch' argument.
    )

    if devices:
        n_local_devices = len(devices)
    else:
        n_local_devices = jax.local_device_count()

    logging.info('Starting training')
    rng = rng + jax.process_index()  # Make random seed separate across hosts.
    keys = random.split(rng, n_local_devices)
    time_tracker = utils.TimeTracker()
    time_tracker.tic('data', 'total')
    for step, batch in zip(range(init_step, train_config.max_steps + 1),
                           train_iter):
        if points_iter is not None:
            batch['background_points'] = next(points_iter)
        time_tracker.toc('data')
        # pytype: disable=attribute-error
        scalar_params = scalar_params.replace(
            learning_rate=learning_rate_sched(step),
            elastic_loss_weight=elastic_loss_weight_sched(step))
        warp_alpha = jax_utils.replicate(warp_alpha_sched(step), devices)
        time_alpha = jax_utils.replicate(time_alpha_sched(step), devices)
        state = state.replace(warp_alpha=warp_alpha, time_alpha=time_alpha)

        with time_tracker.record_time('train_step'):
            state, stats, keys = ptrain_step(keys, state, batch, scalar_params)
            time_tracker.toc('total')

        if step % train_config.print_every == 0 and is_main_host:
            logging.info('step=%d, warp_alpha=%.04f, time_alpha=%.04f, %s',
                         step, warp_alpha_sched(step), time_alpha_sched(step),
                         time_tracker.summary_str('last'))
            # A level may be absent (e.g. no fine model), so check first.
            for level in ('coarse', 'fine'):
                if level in stats:
                    metrics_str = ', '.join(
                        f'{k}={v.mean():.04f}'
                        for k, v in stats[level].items())
                    logging.info('\t%s metrics: %s', level, metrics_str)

        if step % train_config.save_every == 0 and is_main_host:
            training.save_checkpoint(checkpoint_dir, state)

        if step % train_config.log_every == 0 and is_main_host:
            # Only log via host 0.
            _log_to_tensorboard(summary_writer,
                                jax_utils.unreplicate(state),
                                scalar_params,
                                jax_utils.unreplicate(stats),
                                time_dict=time_tracker.summary('mean'))
            time_tracker.reset()

        if step % train_config.histogram_every == 0 and is_main_host:
            _log_histograms(summary_writer, model,
                            jax_utils.unreplicate(state))

        time_tracker.tic('data', 'total')

    # Save the final state unless the last loop step already saved it; like
    # the in-loop save, only the main host writes.
    if is_main_host and train_config.max_steps % train_config.save_every != 0:
        training.save_checkpoint(checkpoint_dir, state)
Exemplo n.º 3
0
def main(argv):
    """Continuously evaluates new checkpoints of an experiment.

    Loads Gin configs, builds the datasource and a pmapped render function,
    then polls `<exp_dir>/checkpoints` every 10 seconds. Whenever a checkpoint
    with a step greater than the last evaluated step appears, the "train",
    "val" and (when test cameras exist) "test" subsets are processed with
    `process_iterator`. The loop stops after one round if
    `eval_config.eval_once` is set, or once the evaluated step reaches
    `train_config.max_steps`.

    Args:
        argv: Unused positional command-line arguments.
    """
    del argv
    logging.info("*** Starting experiment")
    gin_configs = FLAGS.gin_configs

    logging.info("*** Loading Gin configs from: %s", str(gin_configs))
    gin.parse_config_files_and_bindings(config_files=gin_configs,
                                        bindings=FLAGS.gin_bindings,
                                        skip_unknown=True)

    # Load configurations. Stratified sampling is turned off for evaluation.
    exp_config = configs.ExperimentConfig()
    model_config = configs.ModelConfig(use_stratified_sampling=False)
    train_config = configs.TrainConfig()
    eval_config = configs.EvalConfig()

    # Get directory information.
    exp_dir = gpath.GPath(FLAGS.exp_dir)
    if exp_config.subname:
        exp_dir = exp_dir / exp_config.subname
    logging.info("\texp_dir = %s", exp_dir)
    if not exp_dir.exists():
        exp_dir.mkdir(parents=True, exist_ok=True)

    summary_dir = exp_dir / "summaries" / "eval"
    logging.info("\tsummary_dir = %s", summary_dir)
    if not summary_dir.exists():
        summary_dir.mkdir(parents=True, exist_ok=True)

    renders_dir = exp_dir / "renders"
    logging.info("\trenders_dir = %s", renders_dir)
    if not renders_dir.exists():
        renders_dir.mkdir(parents=True, exist_ok=True)

    # Only read from here; the training job is expected to create it.
    checkpoint_dir = exp_dir / "checkpoints"
    logging.info("\tcheckpoint_dir = %s", checkpoint_dir)

    rng = random.PRNGKey(20200823)

    # Use this process's own devices: the state is replicated onto them and
    # the render function shards rays across `n_devices` of them.
    devices_to_use = jax.local_devices()
    n_devices = len(
        devices_to_use) if devices_to_use else jax.local_device_count()

    datasource_spec = exp_config.datasource_spec
    if datasource_spec is None:
        datasource_spec = {
            "type": exp_config.datasource_type,
            "data_dir": FLAGS.data_dir,
        }
    logging.info("Creating datasource: %s", datasource_spec)
    datasource = datasets.from_config(
        datasource_spec,
        image_scale=exp_config.image_scale,
        use_appearance_id=model_config.use_appearance_metadata,
        use_camera_id=model_config.use_camera_metadata,
        use_warp_id=model_config.use_warp,
        random_seed=exp_config.random_seed,
    )

    # Get training IDs to evaluate.
    train_eval_ids = utils.strided_subset(datasource.train_ids,
                                          eval_config.num_train_eval)
    train_eval_iter = datasource.create_iterator(train_eval_ids, batch_size=0)
    val_eval_ids = utils.strided_subset(datasource.val_ids,
                                        eval_config.num_val_eval)
    val_eval_iter = datasource.create_iterator(val_eval_ids, batch_size=0)

    test_cameras = datasource.load_test_cameras(
        count=eval_config.num_test_eval)
    if test_cameras:
        test_dataset = datasource.create_cameras_dataset(test_cameras)
        test_eval_ids = [f"{x:03d}" for x in range(len(test_cameras))]
        test_eval_iter = datasets.iterator_from_dataset(test_dataset,
                                                        batch_size=0)
    else:
        test_eval_ids = None
        test_eval_iter = None

    rng, key = random.split(rng)
    params = {}
    model, params["model"] = models.nerf(
        key,
        model_config,
        batch_size=eval_config.chunk,
        num_appearance_embeddings=len(datasource.appearance_ids),
        num_camera_embeddings=len(datasource.camera_ids),
        num_warp_embeddings=len(datasource.warp_ids),
        near=datasource.near,
        far=datasource.far,
        use_warp_jacobian=False,
        use_weights=False,
    )

    # A zero learning-rate optimizer is only a template for checkpoint restore.
    optimizer_def = optim.Adam(0.0)
    optimizer = optimizer_def.create(params)
    init_state = model_utils.TrainState(optimizer=optimizer, warp_alpha=0.0)
    del params

    def _model_fn(key_0, key_1, params, rays_dict, alpha):
        out = model.apply(
            {"params": params},
            rays_dict,
            warp_alpha=alpha,
            rngs={
                "coarse": key_0,
                "fine": key_1
            },
            mutable=False,
        )
        return jax.lax.all_gather(out, axis_name="batch")

    pmodel_fn = jax.pmap(
        # Note rng_keys are useless in eval mode since there's no randomness.
        _model_fn,
        # Keys, params and rays carry a leading device axis; alpha is
        # broadcast to every device.
        in_axes=(0, 0, 0, 0, None),
        devices=devices_to_use,
        donate_argnums=(3, ),  # Donate the 'rays' argument.
        axis_name="batch",
    )

    render_fn = functools.partial(
        evaluation.render_image,
        model_fn=pmodel_fn,
        device_count=n_devices,
        chunk=eval_config.chunk,
    )

    # (tag, ids, iterator) for each subset, in evaluation order.
    eval_subsets = [
        ("train", train_eval_ids, train_eval_iter),
        ("val", val_eval_ids, val_eval_iter),
    ]
    if test_eval_iter is not None:
        eval_subsets.append(("test", test_eval_ids, test_eval_iter))

    last_step = 0
    summary_writer = tensorboard.SummaryWriter(str(summary_dir))

    while True:
        if not checkpoint_dir.exists():
            logging.info("No checkpoints yet.")
            time.sleep(10)
            continue

        state = checkpoints.restore_checkpoint(checkpoint_dir, init_state)
        state = jax_utils.replicate(state, devices=devices_to_use)
        step = int(state.optimizer.state.step[0])
        if step <= last_step:
            logging.info("No new checkpoints (%d <= %d).", step, last_step)
            time.sleep(10)
            continue

        save_dir = renders_dir if eval_config.save_output else None
        for tag, item_ids, iterator in eval_subsets:
            process_iterator(
                tag=tag,
                item_ids=item_ids,
                iterator=iterator,
                state=state,
                rng=rng,
                step=step,
                render_fn=render_fn,
                summary_writer=summary_writer,
                save_dir=save_dir,
                datasource=datasource,
            )

        if eval_config.eval_once:
            break
        if step >= train_config.max_steps:
            break
        last_step = step
Exemplo n.º 4
0
def main(argv):
    """Trains a NeRF model with `jax.pmap` over this host's local devices.

    Parses Gin configs, prepares the experiment directories (main host only),
    builds the training input pipeline (plus an optional background-point
    pipeline), constructs the model and optimizer, restores the latest
    checkpoint if one exists, and runs the training loop up to
    `train_config.max_steps`. Console metrics, checkpoints and TensorBoard
    summaries are written only by process 0.

    Args:
        argv: Unused positional command-line arguments.

    Raises:
        ValueError: If `train_config.batch_size` is not divisible by the total
            number of devices.
    """
    del argv
    logging.info("*** Starting experiment")
    gin_configs = FLAGS.gin_configs

    logging.info("*** Loading Gin configs from: %s", str(gin_configs))
    gin.parse_config_files_and_bindings(config_files=gin_configs,
                                        bindings=FLAGS.gin_bindings,
                                        skip_unknown=True)

    # Load configurations.
    exp_config = configs.ExperimentConfig()
    model_config = configs.ModelConfig()
    train_config = configs.TrainConfig()

    # Only process 0 writes files / summaries so hosts do not race on disk.
    is_main_host = jax.process_index() == 0

    # Get directory information.
    exp_dir = gpath.GPath(FLAGS.exp_dir)
    if exp_config.subname:
        exp_dir = exp_dir / exp_config.subname
    summary_dir = exp_dir / "summaries" / "train"
    checkpoint_dir = exp_dir / "checkpoints"

    # Log and create directories if this is the main host.
    if is_main_host:
        logging.info("exp_dir = %s", exp_dir)
        if not exp_dir.exists():
            exp_dir.mkdir(parents=True, exist_ok=True)

        logging.info("summary_dir = %s", summary_dir)
        if not summary_dir.exists():
            summary_dir.mkdir(parents=True, exist_ok=True)

        logging.info("checkpoint_dir = %s", checkpoint_dir)
        if not checkpoint_dir.exists():
            checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # config_str is only defined (and only used) on the main host.
        config_str = gin.operative_config_str()
        logging.info("Configuration: \n%s", config_str)
        with (exp_dir / "config.gin").open("w") as f:
            f.write(config_str)

    rng = random.PRNGKey(exp_config.random_seed)
    # Shift the numpy random seed by the process index to shuffle data loaded
    # by different hosts.
    np.random.seed(exp_config.random_seed + jax.process_index())

    if train_config.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")

    # Only this process's devices can receive its input shards and state.
    devices = jax.local_devices()
    datasource_spec = exp_config.datasource_spec
    if datasource_spec is None:
        datasource_spec = {
            "type": exp_config.datasource_type,
            "data_dir": FLAGS.data_dir,
        }
    logging.info("Creating datasource: %s", datasource_spec)
    datasource = datasets.from_config(
        datasource_spec,
        image_scale=exp_config.image_scale,
        use_appearance_id=model_config.use_appearance_metadata,
        use_camera_id=model_config.use_camera_metadata,
        use_warp_id=model_config.use_warp,
        random_seed=exp_config.random_seed,
    )
    train_dataset = datasource.create_dataset(datasource.train_ids,
                                              flatten=True,
                                              shuffle=True)
    train_iter = datasets.iterator_from_dataset(
        train_dataset,
        batch_size=train_config.batch_size,
        prefetch_size=3,
        devices=devices,
    )

    points_iter = None
    if train_config.use_background_loss:
        # Infinite, reshuffled stream of background points.
        points_dataset = (tf.data.Dataset.from_tensor_slices(
            datasource.load_points()).repeat().shuffle(
                65536,
                reshuffle_each_iteration=True,
                seed=exp_config.random_seed))
        points_iter = datasets.iterator_from_dataset(
            points_dataset,
            batch_size=len(devices) *
            train_config.background_points_batch_size,
            prefetch_size=3,
            devices=devices,
        )

    learning_rate_sched = schedules.from_config(train_config.lr_schedule)
    warp_alpha_sched = schedules.from_config(train_config.warp_alpha_schedule)
    elastic_loss_weight_sched = schedules.from_config(
        train_config.elastic_loss_weight_schedule)

    rng, key = random.split(rng)
    params = {}
    model, params["model"] = models.nerf(
        key,
        model_config,
        batch_size=train_config.batch_size,
        num_appearance_embeddings=len(datasource.appearance_ids),
        num_camera_embeddings=len(datasource.camera_ids),
        num_warp_embeddings=len(datasource.warp_ids),
        near=datasource.near,
        far=datasource.far,
        use_warp_jacobian=train_config.use_elastic_loss,
        use_weights=train_config.use_elastic_loss,
    )

    optimizer_def = optim.Adam(learning_rate_sched(0))
    optimizer = optimizer_def.create(params)
    state = model_utils.TrainState(optimizer=optimizer,
                                   warp_alpha=warp_alpha_sched(0))
    scalar_params = training.ScalarParams(
        learning_rate=learning_rate_sched(0),
        elastic_loss_weight=elastic_loss_weight_sched(0),
        background_loss_weight=train_config.background_loss_weight,
    )
    state = checkpoints.restore_checkpoint(checkpoint_dir, state)
    # Resume right after the restored step; int() turns the scalar step into
    # a plain Python int for range().
    init_step = int(state.optimizer.state.step) + 1
    state = jax_utils.replicate(state, devices=devices)
    del params

    logging.info("Initializing models")

    summary_writer = None
    if is_main_host:
        summary_writer = tensorboard.SummaryWriter(str(summary_dir))
        summary_writer.text(
            "gin/train",
            textdata=gin.config.markdownify_operative_config_str(config_str),
            step=0,
        )

    train_step = functools.partial(
        training.train_step,
        model,
        elastic_reduce_method=train_config.elastic_reduce_method,
        use_elastic_loss=train_config.use_elastic_loss,
        use_background_loss=train_config.use_background_loss,
    )
    ptrain_step = jax.pmap(
        train_step,
        axis_name="batch",
        devices=devices,
        # rng_key, state, batch are per-device; scalar_params is broadcast.
        in_axes=(0, 0, 0, None),
        donate_argnums=(2, ),  # Donate the 'batch' argument.
    )

    if devices:
        n_local_devices = len(devices)
    else:
        n_local_devices = jax.local_device_count()

    logging.info("Starting training")
    # Make random seed separate across hosts.
    rng = rng + jax.process_index()
    keys = random.split(rng, n_local_devices)
    time_tracker = utils.TimeTracker()
    time_tracker.tic("data", "total")
    for step, batch in zip(range(init_step, train_config.max_steps + 1),
                           train_iter):
        if points_iter is not None:
            batch["background_points"] = next(points_iter)
        time_tracker.toc("data")
        # See: b/162398046.
        # pytype: disable=attribute-error
        scalar_params = scalar_params.replace(
            learning_rate=learning_rate_sched(step),
            elastic_loss_weight=elastic_loss_weight_sched(step),
        )
        warp_alpha = jax_utils.replicate(warp_alpha_sched(step), devices)
        state = state.replace(warp_alpha=warp_alpha)

        with time_tracker.record_time("train_step"):
            state, stats, keys = ptrain_step(keys, state, batch, scalar_params)
            time_tracker.toc("total")

        if step % train_config.print_every == 0 and is_main_host:
            logging.info(
                "step=%d, warp_alpha=%.04f, %s",
                step,
                warp_alpha_sched(step),
                time_tracker.summary_str("last"),
            )
            # A level may be absent (e.g. no fine model), so check first.
            for level in ("coarse", "fine"):
                if level in stats:
                    metrics_str = ", ".join(
                        f"{k}={v.mean():.04f}"
                        for k, v in stats[level].items())
                    logging.info("\t%s metrics: %s", level, metrics_str)

        if step % train_config.save_every == 0 and is_main_host:
            training.save_checkpoint(checkpoint_dir, state)

        if step % train_config.log_every == 0 and is_main_host:
            # Only log via host 0.
            _log_to_tensorboard(
                summary_writer,
                jax_utils.unreplicate(state),
                scalar_params,
                jax_utils.unreplicate(stats),
                time_dict=time_tracker.summary("mean"),
            )
            time_tracker.reset()

        time_tracker.tic("data", "total")

    # Save the final state unless the last loop step already saved it; like
    # the in-loop save, only the main host writes.
    if is_main_host and train_config.max_steps % train_config.save_every != 0:
        training.save_checkpoint(checkpoint_dir, state)