def valid_iterator_factory():
    it = unbatched_valid_iterator_factory()
    # Fix validation randomness to smooth out noise.
    # (note: for final evaluation we should compute true marginals and not
    # do any sampling)
    it = add_rng_to_examples(it, jax.random.PRNGKey(0))
    return data_loading.batch(
        it, (num_devices, batch_size_per_device),
        remainder_behavior=data_loading.BatchRemainderBehavior.PAD_ZERO)
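For context, the training loop here presumably distributes work with jax.pmap (the model is broadcast across devices via train_util.device_broadcast in the full example below), which is why batches get a leading (num_devices, batch_size_per_device) shape. The following sketch uses made-up shapes that are not from the source; it only illustrates how such a batch maps onto jax.pmap.

import jax
import jax.numpy as jnp

# Illustration only: a batch whose leading axes are
# (num_devices, batch_size_per_device) maps cleanly onto jax.pmap,
# which splits the first axis across local devices.
num_devices = jax.local_device_count()
batch_size_per_device = 2
batch = jnp.zeros((num_devices, batch_size_per_device, 8))

per_device_mean = jax.pmap(lambda x: jnp.mean(x, axis=0))(batch)
print(per_device_mean.shape)  # (num_devices, 8)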
Example #2
def test_batch_uneven_pad(self):
    values = range(10)
    batched = list(
        data_loading.batch(
            values, (3,),
            remainder_behavior=data_loading.BatchRemainderBehavior.PAD_ZERO))
    expected = [
        np.array([0, 1, 2]),
        np.array([3, 4, 5]),
        np.array([6, 7, 8]),
        np.array([9, 0, 0]),
    ]
    jax.test_util.check_eq(batched, expected)
Example #3
def test_batch(self):
    values = [{"v": np.array([i])} for i in range(18)]
    batched = list(data_loading.batch(values, (3, 2)))
    expected = [
        {"v": np.array([[[0], [1]], [[2], [3]], [[4], [5]]])},
        {"v": np.array([[[6], [7]], [[8], [9]], [[10], [11]]])},
        {"v": np.array([[[12], [13]], [[14], [15]], [[16], [17]]])},
    ]
    jax.test_util.check_eq(batched, expected)
def reify_id_and_batch(it):
    return data_loading.batch(
        reify_id(it), (num_devices, batch_size_per_device),
        remainder_behavior=data_loading.BatchRemainderBehavior.PAD_ZERO)
def train(
    runner,
    dataset_paths=gin.REQUIRED,
    prefetch=4,
    target_edge=gin.REQUIRED,
    batch_size_per_device=gin.REQUIRED,
    truncate_training_dataset_at=None,
    validation_example_skip=0,
    validation_example_count=gin.REQUIRED,
    model_type="automaton",
    evaluate_only=False,
    evaluation_model_path=None,
    evaluation_save_path=None,
    use_sampling_model=False,
):
    """Launch a training job for edge supervision.

    The dataset directories should be configured with gin.

    Args:
      runner: Helper object that runs the experiment.
      dataset_paths: Dictionary of dataset paths, with keys:
        - "metadata": Path to JSON file with dataset metadata.
        - "train_dataset": Path to training dataset files.
        - "eval_dataset": Path to validation/test dataset files.
      prefetch: Maximum number of examples to prefetch in a background thread.
        Note that we prefetch a maximum of 1 example for the validation set.
      target_edge: What edge to use as the training target.
      batch_size_per_device: Batch size for each device.
      truncate_training_dataset_at: Number of examples to truncate the
        training dataset to.
      validation_example_skip: Number of examples to skip when computing
        validation metrics.
      validation_example_count: How many examples to use when computing
        validation metrics.
      model_type: Either "automaton" or "baseline".
      evaluate_only: If True, doesn't run any training; instead evaluates a
        trained model on the validation/evaluation set. Make sure to change
        "eval_dataset" to the test dataset if using this to compute final
        metrics.
      evaluation_model_path: Path to the model checkpoint to evaluate.
      evaluation_save_path: Where to save the result JSON file.
      use_sampling_model: Whether to use the sample-based version of the loss.

    Returns:
      Optimizer at the end of training (for interactive debugging).
    """

    logging.info("Hello from train_edge_supervision_lib!")
    num_devices = jax.local_device_count()
    logging.info("Found %d devices: %s", num_devices, jax.devices())
    logging.info("Setting up datasets...")
    with contextlib.ExitStack() as exit_stack:

        padding_config, edge_types = load_dataset_metadata(
            dataset_paths["metadata"])

        if evaluate_only:
            assert evaluation_model_path is not None
        else:
            unbatched_train_iterator = runner.build_sampling_iterator(
                dataset_paths["train_dataset"],
                example_type=graph_bundle.GraphBundle,
                truncate_at=truncate_training_dataset_at)
            unbatched_train_iterator = add_rng_to_examples(
                unbatched_train_iterator,
                jax.random.PRNGKey(int(time.time() * 1000)))
            train_iterator = data_loading.batch(
                unbatched_train_iterator, (num_devices, batch_size_per_device))

            if prefetch:
                train_iterator = exit_stack.enter_context(
                    data_loading.ThreadedPrefetcher(train_iterator, prefetch))

        unbatched_valid_iterator_factory = runner.build_one_pass_iterator_factory(
            dataset_paths["eval_dataset"],
            example_type=graph_bundle.GraphBundle,
            truncate_at=validation_example_count,
            skip_first=validation_example_skip)

        def valid_iterator_factory():
            it = unbatched_valid_iterator_factory()
            # Fix validation randomness to smooth out noise.
            # (note: for final evaluation we should compute true marginals and not
            # do any sampling)
            it = add_rng_to_examples(it, jax.random.PRNGKey(0))
            return data_loading.batch(
                it, (num_devices, batch_size_per_device),
                remainder_behavior=data_loading.BatchRemainderBehavior.PAD_ZERO)

        num_edge_types = len(edge_types)
        edge_types_to_indices = {name: i for i, name in enumerate(edge_types)}

        logging.info("Setting up model...")
        if model_type == "automaton":
            model_def = edge_supervision_models.automaton_model
        elif model_type == "baseline":
            model_def = edge_supervision_models.BaselineModel
        else:
            raise ValueError(f"Unknown model type '{model_type}'")

        # Bind statically-known information from the dataset.
        model_def = model_def.partial(
            graph_metadata=padding_config.static_max_metadata,
            edge_types_to_indices=edge_types_to_indices)

        # Initialize parameters randomly.
        @jax.jit
        def _init(rng):
            # Set up a dummy stochastic scope for random perturbations.
            with flax.nn.stochastic(jax.random.PRNGKey(0)):
                ex = graph_bundle.zeros_like_padded_example(padding_config)
                ex = jax.tree_map(jnp.array, ex)
                _, initial_params = model_def.init(rng, ex)
            return initial_params

        initial_params = _init(jax.random.PRNGKey(int(time.time() * 1000)))

        model = flax.nn.Model(model_def, initial_params)
        optimizer = flax.optim.Adam().create(model)

        validation_fn = build_validation_fn(
            valid_iterator_factory=valid_iterator_factory,
            target_edge_index=edge_types_to_indices[target_edge],
            num_edge_types=num_edge_types,
            full_evaluation=evaluate_only,
            use_sampling_model=use_sampling_model)

        if evaluate_only:
            optimizer, checkpoint_info = runner.load_from_checkpoint(
                optimizer, checkpoint_path=evaluation_model_path)
            model = train_util.device_broadcast(optimizer.target, num_devices)

            _, metrics = validation_fn(model)
            metrics["checkpoint_info"] = checkpoint_info
            metrics["model_path"] = evaluation_model_path
            metrics["dataset_metadata_path"] = dataset_paths["metadata"]
            metrics["dataset_path"] = dataset_paths["eval_dataset"]
            metrics["example_skip"] = validation_example_skip
            metrics["example_count"] = validation_example_count

            array_types = (np.ndarray, jnp.ndarray)
            metrics = jax.tree_map(
                lambda x: x.tolist() if isinstance(x, array_types) else x,
                metrics)

            gfile.makedirs(os.path.dirname(evaluation_save_path))
            with gfile.GFile(evaluation_save_path, "w") as fp:
                json.dump(metrics, fp, indent=2)

            logging.info("Computed evaluation metrics: %s", metrics)

        else:

            def compute_loss_for_model(model, padded_example_and_rng,
                                       static_metadata):
                assert static_metadata is None
                if use_sampling_model:
                    (_, _, _, loss, batch_metrics) = sample_loss_fn(
                        model,
                        padded_example_and_rng,
                        target_edge_index=edge_types_to_indices[target_edge],
                        num_edge_types=num_edge_types)
                    return loss, batch_metrics
                else:
                    return loss_fn(*extract_outputs_and_targets(
                        model,
                        padded_example_and_rng,
                        target_edge_index=edge_types_to_indices[target_edge],
                        num_edge_types=num_edge_types))

            extra_artifacts = {
                "builder.pickle": py_ast_graphs.BUILDER,
            }

            return runner.training_loop(optimizer=optimizer,
                                        train_iterator=train_iterator,
                                        loss_fn=compute_loss_for_model,
                                        validation_fn=validation_fn,
                                        extra_artifacts=extra_artifacts)
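Several arguments of train() default to gin.REQUIRED, so they are expected to arrive through gin bindings rather than direct keyword arguments (the runner object itself is supplied by the caller). Below is a minimal sketch of such bindings; it assumes train is registered as a gin configurable under its default name, and every path and value shown is a placeholder rather than something taken from the source.

import gin

# Hypothetical bindings; all paths/values are placeholders, and the binding
# keys assume the configurable is registered simply as "train".
gin.bind_parameter("train.dataset_paths", {
    "metadata": "/tmp/dataset/metadata.json",
    "train_dataset": "/tmp/dataset/train_examples*",
    "eval_dataset": "/tmp/dataset/eval_examples*",
})
gin.bind_parameter("train.target_edge", "some_edge_type")
gin.bind_parameter("train.batch_size_per_device", 32)
gin.bind_parameter("train.validation_example_count", 128)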
Example #6
def test_batch_uneven_error(self):
    values = range(10)
    with self.assertRaisesRegex(ValueError, "not divisible by batch size"):
        for _ in data_loading.batch(values, (3,)):
            pass