def train(
    runner,
    dataset_paths=gin.REQUIRED,
    padding_and_batch_sizes=gin.REQUIRED,
    prefetch=8,
    validation_example_count=gin.REQUIRED,
    evaluate_only=False,
    evaluation_model_path=None,
    evaluation_save_path=None,
    parameter_seed=int(time.time() * 1000),
    profile_during_warmup=True,
    restore_from_path=None,
):
  """Training script entry point.

  Args:
    runner: Helper object that runs the experiment.
    dataset_paths: Dictionary of dataset paths, with keys:
      - "train_dataset": Path to training dataset files.
      - "train_metadata": Path to JSON file with training dataset metadata.
      - "eval_dataset": Path to validation/test dataset files.
      - "eval_metadata": Path to JSON file with eval dataset metadata.
    padding_and_batch_sizes: Padding configurations and batch sizes to use.
      Padding should be specified using the keyword arguments for the function
      `build_padding_config`. A batch size must be specified explicitly for
      each padding config.
    prefetch: How many examples to prefetch.
    validation_example_count: How many examples to use when validating during
      training. If None, uses all examples.
    evaluate_only: If True, doesn't run any training; instead evaluates a
      trained model on the validation/evaluation set.
    evaluation_model_path: Where to load the model from.
    evaluation_save_path: Where to save the result JSON file.
    parameter_seed: Random seed to use when initializing parameters.
    profile_during_warmup: Whether to use XProf during warmup.
    restore_from_path: Optional path to restore parameters from; useful for
      warm-starting.

  Returns:
    Final optimizer.
  """

  num_devices = jax.local_device_count()
  logging.info('Found %d devices: %s', num_devices, jax.devices())

  padding_and_batch_sizes = [
      (build_padding_config(**config_kwargs), batch_size)
      for config_kwargs, batch_size in padding_and_batch_sizes
  ]

  with contextlib.ExitStack() as exit_stack:
    # Loading metadata and task info.
    with gfile.GFile(dataset_paths['train_metadata'], 'r') as fp:
      train_metadata = json.load(fp)
    with gfile.GFile(dataset_paths['eval_metadata'], 'r') as fp:
      valid_metadata = json.load(fp)

    assert train_metadata['spec_file'] == valid_metadata['spec_file']
    assert train_metadata['vocab_file'] == valid_metadata['vocab_file']
    assert train_metadata['edge_types'] == valid_metadata['edge_types']

    encoding_info = example_definition.ExampleEncodingInfo.from_files(
        train_metadata['spec_file'], train_metadata['vocab_file'])
    assert encoding_info.edge_types == train_metadata['edge_types']

    # Model setup.
    logging.info('Setting up model...')
    model_def = var_misuse_models.var_misuse_model.partial(
        encoding_info=encoding_info)

    # Set up a dummy stochastic scope for random perturbations.
    with flax.nn.stochastic(jax.random.PRNGKey(0)):
      # Initialize parameters based on our seed.
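      # (A tiny padded example suffices here: the input only determines the
      # parameter shapes, not which padding is used at train time.)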
      _, initial_params = model_def.init(
          jax.random.PRNGKey(parameter_seed),
          jax.tree_map(
              jnp.array,
              example_definition.zeros_like_padded_example(
                  TINY_PADDING_CONFIG)), TINY_PADDING_CONFIG)

    model = flax.nn.Model(model_def, initial_params)
    del initial_params
    optimizer = flax.optim.Adam().create(model)

    if restore_from_path:
      optimizer, checkpoint_info = runner.load_from_checkpoint(
          optimizer, restore_from_path)
      logging.info('Warm starting from checkpoint with info: %s',
                   checkpoint_info)

    # Run a warmup train step for each padding config (to trigger compilation
    # ahead of time); batch sizes must already be specified.
    tmp_replicated_optimizer = train_util.device_broadcast(
        optimizer, num_devices)
    for i, (padding_config, batch_size) in enumerate(padding_and_batch_sizes):
      fake_example_and_rng = (
          example_definition.zeros_like_padded_example(padding_config),
          jax.random.PRNGKey(0))
      assert batch_size is not None, 'Batch size must be specified.'
      logging.info(
          'Running a fake train step for batch size %d and padding config %d: %s',
          batch_size, i, padding_config)
      # pylint: disable=cell-var-from-loop
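      # Replicate the single zero-padded example (and RNG) once per batch
      # element and once per device, giving a fake batch with the same
      # [num_devices, batch_size, ...] structure as a real training batch.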
      fake_batch = jax.vmap(lambda _: fake_example_and_rng)(
          jnp.arange(batch_size))
      fake_device_batch = jax.vmap(lambda _: fake_batch)(
          jnp.arange(num_devices))
      # pylint: enable=cell-var-from-loop
      train_util.warmup_train_step(
          tmp_replicated_optimizer,
          fake_device_batch,
          padding_config,
          loss_fn,
          optimizer_is_replicated=True,
          profile=profile_during_warmup,
          runner=runner)

    del tmp_replicated_optimizer

    extra_artifacts = {
        'encoding_info.pickle': encoding_info,
    }

    # Dataset iterator setup.
    logging.info('Setting up datasets...')
    unbatched_train_iterator = runner.build_sampling_iterator(
        dataset_paths['train_dataset'],
        example_type=example_definition.VarMisuseExample)

    # Randomly generate the base RNG. (Since the iterator is already randomly
    # shuffling, and we might have restarted this job anyway, there's no point
    # in setting a seed here.)
    train_iterator = pad_and_batch_with_rng(
        unbatched_train_iterator,
        num_devices,
        padding_and_batch_sizes,
        base_rng=jax.random.PRNGKey(int(time.time() * 1000)))
    if prefetch:
      train_iterator = exit_stack.enter_context(
          data_loading.ThreadedPrefetcher(train_iterator, prefetch))

    unbatched_valid_iterator_factory = (
        runner.build_one_pass_iterator_factory(
            dataset_paths['eval_dataset'],
            example_type=example_definition.VarMisuseExample,
            truncate_at=validation_example_count))

    def logging_progress(it):
      maxct = validation_example_count or valid_metadata['num_examples']
      for i, val in enumerate(it):
        if i % 10000 == 0:
          logging.info('Validation progress: %d of %d', i, maxct)
        yield val

    def valid_iterator_factory():
      unbatched = unbatched_valid_iterator_factory()
      if evaluate_only:
        unbatched = logging_progress(unbatched)
      # Always use the same PRNGKey for the validation set.
      valid_iterator = pad_and_batch_with_rng(
          unbatched,
          num_devices,
          padding_and_batch_sizes,
          base_rng=jax.random.PRNGKey(0))
      if prefetch:
        with data_loading.ThreadedPrefetcher(valid_iterator, prefetch) as it:
          yield from it
      else:
        yield from valid_iterator

    validation_fn = train_util.build_averaging_validator(
        loss_fn,
        valid_iterator_factory,
        objective_metric_name='inaccuracy/overall',
        include_total_counts=evaluate_only)

    if evaluate_only:
      logging.warning('This job is running in evaluation mode!')

      optimizer, checkpoint_info = runner.load_from_checkpoint(
          optimizer, checkpoint_path=evaluation_model_path)
      model = train_util.device_broadcast(optimizer.target, num_devices)

      _, metrics = validation_fn(model)
      metrics['checkpoint_info'] = checkpoint_info
      metrics['model_path'] = evaluation_model_path
      metrics['dataset_path'] = dataset_paths['eval_dataset']
      metrics['example_count'] = validation_example_count

      array_types = (np.ndarray, jnp.ndarray)
      metrics = jax.tree_map(
          lambda x: x.tolist() if isinstance(x, array_types) else x, metrics)

      gfile.makedirs(os.path.dirname(evaluation_save_path))
      with gfile.GFile(evaluation_save_path, 'w') as fp:
        json.dump(metrics, fp, indent=2)

      logging.info('Computed evaluation metrics: %s', metrics)

    else:
      return runner.training_loop(
          optimizer=optimizer,
          train_iterator=train_iterator,
          loss_fn=loss_fn,
          validation_fn=validation_fn,
          extra_artifacts=extra_artifacts)

def train(
    runner,
    dataset_paths=gin.REQUIRED,
    prefetch=4,
    target_edge=gin.REQUIRED,
    batch_size_per_device=gin.REQUIRED,
    truncate_training_dataset_at=None,
    validation_example_skip=0,
    validation_example_count=gin.REQUIRED,
    model_type="automaton",
    evaluate_only=False,
    evaluation_model_path=None,
    evaluation_save_path=None,
    use_sampling_model=False,
):
    """Launch a training job for edge supervision.

  The dataset directories should be configured with gin.

  Args:
    runner: Helper object that runs the experiment.
    dataset_paths: Dictionary of dataset paths, with keys:
      - "metadata": Path to JSON file with dataset metadata.
      - "train_dataset": Path to training dataset files.
      - "eval_dataset": Path to validation/test dataset files.
    prefetch: Maximum number of examples to prefetch in a background thread.
      Note that we prefetch a maximum of 1 example for the validation set.
    target_edge: What edge to use as the training target.
    batch_size_per_device: Batch size for each device.
    truncate_training_dataset_at: Number of examples to truncate the training
      dataset to.
    validation_example_skip: Number of examples to skip when computing
      validation metrics.
    validation_example_count: How many examples to use when computing validation
      metrics.
    model_type: Either "automaton" or "baseline".
    evaluate_only: If True, doesn't run any training; instead evaluates a
      trained model on the validation/evaluation set. Make sure to change
      "eval_dataset" to the test dataset if using this to compute final metrics.
    evaluation_model_path: Path to the model checkpoint to evaluate.
    evaluation_save_path: Where to save the result JSON file.
    use_sampling_model: Whether to use sample-based version of the loss.

  Returns:
    Optimizer at the end of training (for interactive debugging).
  """

    logging.info("Hello from train_edge_supervision_lib!")
    num_devices = jax.local_device_count()
    logging.info("Found %d devices: %s", num_devices, jax.devices())
    logging.info("Setting up datasets...")
    with contextlib.ExitStack() as exit_stack:

        padding_config, edge_types = load_dataset_metadata(
            dataset_paths["metadata"])

        if evaluate_only:
            assert evaluation_model_path is not None
        else:
            unbatched_train_iterator = runner.build_sampling_iterator(
                dataset_paths["train_dataset"],
                example_type=graph_bundle.GraphBundle,
                truncate_at=truncate_training_dataset_at)
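            # Use a time-based seed for the example RNGs; the sampling iterator
            # already shuffles, so a fixed seed is not needed here.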
            unbatched_train_iterator = add_rng_to_examples(
                unbatched_train_iterator,
                jax.random.PRNGKey(int(time.time() * 1000)))
            train_iterator = data_loading.batch(
                unbatched_train_iterator, (num_devices, batch_size_per_device))

            if prefetch:
                train_iterator = exit_stack.enter_context(
                    data_loading.ThreadedPrefetcher(train_iterator, prefetch))

        unbatched_valid_iterator_factory = runner.build_one_pass_iterator_factory(
            dataset_paths["eval_dataset"],
            example_type=graph_bundle.GraphBundle,
            truncate_at=validation_example_count,
            skip_first=validation_example_skip)

        def valid_iterator_factory():
            it = unbatched_valid_iterator_factory()
            # Fix validation randomness to smooth out noise.
            # (note: for final evaluation we should compute true marginals and not
            # do any sampling)
            it = add_rng_to_examples(it, jax.random.PRNGKey(0))
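            # PAD_ZERO pads the final partial batch with all-zero examples so
            # that every validation example is still processed.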
            return data_loading.batch(
                it, (num_devices, batch_size_per_device),
                remainder_behavior=(
                    data_loading.BatchRemainderBehavior.PAD_ZERO))

        num_edge_types = len(edge_types)
        edge_types_to_indices = {name: i for i, name in enumerate(edge_types)}

        logging.info("Setting up model...")
        if model_type == "automaton":
            model_def = edge_supervision_models.automaton_model
        elif model_type == "baseline":
            model_def = edge_supervision_models.BaselineModel
        else:
            raise ValueError(f"Unknown model type '{model_type}'")

        # Bind statically-known information from the dataset.
        model_def = model_def.partial(
            graph_metadata=padding_config.static_max_metadata,
            edge_types_to_indices=edge_types_to_indices)

        # Initialize parameters randomly.
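        # (flax's `init` runs the model once on an all-zeros padded example;
        # the example only fixes the parameter shapes.)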
        @jax.jit
        def _init(rng):
            # Set up a dummy stochastic scope for random perturbations.
            with flax.nn.stochastic(jax.random.PRNGKey(0)):
                ex = graph_bundle.zeros_like_padded_example(padding_config)
                ex = jax.tree_map(jnp.array, ex)
                _, initial_params = model_def.init(rng, ex)
            return initial_params

        initial_params = _init(jax.random.PRNGKey(int(time.time() * 1000)))

        model = flax.nn.Model(model_def, initial_params)
        optimizer = flax.optim.Adam().create(model)

        validation_fn = build_validation_fn(
            valid_iterator_factory=valid_iterator_factory,
            target_edge_index=edge_types_to_indices[target_edge],
            num_edge_types=num_edge_types,
            full_evaluation=evaluate_only,
            use_sampling_model=use_sampling_model)

        if evaluate_only:
            optimizer, checkpoint_info = runner.load_from_checkpoint(
                optimizer, checkpoint_path=evaluation_model_path)
            model = train_util.device_broadcast(optimizer.target, num_devices)

            _, metrics = validation_fn(model)
            metrics["checkpoint_info"] = checkpoint_info
            metrics["model_path"] = evaluation_model_path
            metrics["dataset_metadata_path"] = dataset_paths["metadata"]
            metrics["dataset_path"] = dataset_paths["eval_dataset"]
            metrics["example_skip"] = validation_example_skip
            metrics["example_count"] = validation_example_count

            array_types = (np.ndarray, jnp.ndarray)
            metrics = jax.tree_map(
                lambda x: x.tolist() if isinstance(x, array_types) else x,
                metrics)

            gfile.makedirs(os.path.dirname(evaluation_save_path))
            with gfile.GFile(evaluation_save_path, "w") as fp:
                json.dump(metrics, fp, indent=2)

            logging.info("Computed evaluation metrics: %s", metrics)

        else:

            def compute_loss_for_model(model, padded_example_and_rng,
                                       static_metadata):
                assert static_metadata is None
                if use_sampling_model:
                    (_, _, _, loss, batch_metrics) = sample_loss_fn(
                        model,
                        padded_example_and_rng,
                        target_edge_index=edge_types_to_indices[target_edge],
                        num_edge_types=num_edge_types)
                    return loss, batch_metrics
                else:
                    return loss_fn(*extract_outputs_and_targets(
                        model,
                        padded_example_and_rng,
                        target_edge_index=edge_types_to_indices[target_edge],
                        num_edge_types=num_edge_types))

            extra_artifacts = {
                "builder.pickle": py_ast_graphs.BUILDER,
            }

            return runner.training_loop(optimizer=optimizer,
                                        train_iterator=train_iterator,
                                        loss_fn=compute_loss_for_model,
                                        validation_fn=validation_fn,
                                        extra_artifacts=extra_artifacts)