Example #1
    def test_prefetch_with_error(self):
        def raise_after(n):
            yield from range(n)
            raise RuntimeError("error in generation")

        with data_loading.ThreadedPrefetcher(raise_after(100),
                                             10) as prefetched:
            with self.assertRaisesRegex(RuntimeError, "error in generation"):
                for _ in prefetched:
                    pass
Example #2
def valid_iterator_factory():
  unbatched = unbatched_valid_iterator_factory()
  if evaluate_only:
    unbatched = logging_progress(unbatched)
  # Always use the same PRNGKey for the validation set.
  valid_iterator = pad_and_batch_with_rng(
      unbatched,
      num_devices,
      padding_and_batch_sizes,
      base_rng=jax.random.PRNGKey(0))
  if prefetch:
    with data_loading.ThreadedPrefetcher(valid_iterator, prefetch) as it:
      yield from it
  else:
    yield from valid_iterator
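The factory above ties the prefetcher's lifetime to a single validation pass: the with block is entered lazily when the generator is first pulled, and it exits once the pass is exhausted or the generator is closed. A minimal, self-contained sketch of the same scoping pattern, using a printing stand-in instead of data_loading.ThreadedPrefetcher:

import contextlib


@contextlib.contextmanager
def noisy_scope():
  # Stand-in for data_loading.ThreadedPrefetcher in this sketch: reports when
  # the prefetching scope opens and closes.
  print("prefetcher started")
  try:
    yield iter(range(3))
  finally:
    print("prefetcher stopped")


def iterator_factory():
  # Same shape as valid_iterator_factory above: the scope opens on the first
  # next() call and closes when the pass ends or the generator is closed.
  with noisy_scope() as it:
    yield from it


for value in iterator_factory():
  print(value)
# Output: prefetcher started, 0, 1, 2, prefetcher stopped.

If a consumer abandons a pass early, the scope is only released when the generator object is closed or garbage-collected, which is a property of generator-based context scoping in Python rather than of the prefetcher itself.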
Example #3
    def validation_function(model):
        with contextlib.ExitStack() as exit_stack:
            valid_iterator = valid_iterator_factory()
            if prefetch:
                valid_iterator = exit_stack.enter_context(
                    data_loading.ThreadedPrefetcher(valid_iterator, 4))
            accumulated = None
            example_count = 0
            for batch in valid_iterator:
                results = parallel_metrics_batch(model, batch.example,
                                                 batch.mask,
                                                 batch.static_metadata)
                metrics = jax.tree_map(float,
                                       flax.jax_utils.unreplicate(results))
                metrics["epoch"] = np.sum(batch.epoch)
                if accumulated is None:
                    accumulated = metrics
                else:
                    accumulated = jax.tree_multimap(operator.add, accumulated,
                                                    metrics)
                example_count += jnp.count_nonzero(batch.mask)

            assert example_count > 0, "Validation iterator must be nonempty"
            accumulated = typing.cast(Dict[str, Any], accumulated)

            final_metrics = {}
            for k, v in accumulated.items():
                if isinstance(v, RatioMetric):
                    final_metrics[k] = v.numerator / v.denominator
                    if include_total_counts:
                        final_metrics[k + "_numerator"] = v.numerator
                        final_metrics[k + "_denominator"] = v.denominator
                else:
                    final_metrics[k] = v / example_count

            objective = final_metrics[objective_metric_name]
            if include_total_counts:
                final_metrics["validation_total_example_count"] = example_count
            return (objective, final_metrics)
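The accumulation above relies on every metric value being a pytree of addable leaves, so per-batch metrics can be summed with jax.tree_multimap and RatioMetric entries can later be finalized as numerator / denominator. A small sketch of how that works, assuming RatioMetric is a pytree-compatible container such as a NamedTuple (the real definition in the codebase may differ):

import operator
from typing import NamedTuple

import jax


class RatioMetric(NamedTuple):  # Hypothetical stand-in definition.
    numerator: float
    denominator: float


batch_a = {"inaccuracy/overall": RatioMetric(2.0, 10.0), "loss": 1.5}
batch_b = {"inaccuracy/overall": RatioMetric(1.0, 10.0), "loss": 2.5}

# NamedTuples are pytrees, so elementwise addition covers both plain scalars
# and the numerator/denominator fields of RatioMetric.
accumulated = jax.tree_multimap(operator.add, batch_a, batch_b)
# accumulated["inaccuracy/overall"] == RatioMetric(3.0, 20.0) -> 3.0 / 20.0
# accumulated["loss"] == 4.0, later divided by example_count.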
Example #4
    def test_prefetch_slow_interrupted(self):
        lock = threading.Lock()
        kept_going = False

        def slow_prefetch():
            nonlocal kept_going
            for i in range(5):
                time.sleep(0.1)
                yield i
            time.sleep(1)
            yield "prefetched, but not consumed"
            with lock:
                kept_going = True
            yield "should not be prefetched"

        with data_loading.ThreadedPrefetcher(slow_prefetch(),
                                             10) as prefetched:
            # Wait for the first five values to be ready.
            first_five = list(itertools.islice(prefetched, 5))
            # The thread starts prefetching the sixth element...
            time.sleep(0.1)
            # ... and we exit the context manager before it finishes.

        time.sleep(2)
        # Main thread should have waited for the first five elements.
        self.assertEqual(first_five, list(range(5)))
        with lock:
            # Worker should NOT have prefetched the seventh element after we exited
            # the context manager.
            self.assertFalse(kept_going)

        # Main thread is not allowed to keep using the iterator at this point.
        with self.assertRaisesRegex(
                RuntimeError,
                "Iteration is only allowed inside the prefetching context manager!"
        ):
            next(prefetched)
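Taken together, the tests pin down the ThreadedPrefetcher contract: items are produced by a background thread into a bounded buffer, exceptions from the source iterator are re-raised to the consumer, prefetching stops once the context manager exits, and iterating outside the context raises RuntimeError. Below is a minimal queue-based sketch that satisfies that contract; the real data_loading.ThreadedPrefetcher is the authoritative implementation and may differ in details such as shutdown and draining.

import queue
import threading


class SimpleThreadedPrefetcher:
    """Minimal sketch of the prefetching contract exercised above."""

    _DONE = object()

    def __init__(self, iterable, buffer_size):
        self._iterable = iterable
        self._buffer = queue.Queue(maxsize=buffer_size)
        self._stop = threading.Event()
        self._active = False

    def _worker(self):
        try:
            for item in self._iterable:
                if self._stop.is_set():
                    return  # Exited the context manager; stop prefetching.
                self._buffer.put((item, None))
            self._buffer.put((self._DONE, None))
        except Exception as exc:  # Propagate source errors to the consumer.
            self._buffer.put((None, exc))

    def __enter__(self):
        self._active = True
        threading.Thread(target=self._worker, daemon=True).start()
        return self

    def __exit__(self, *unused_exc_info):
        self._active = False
        self._stop.set()
        try:  # Free one buffer slot so a blocked worker can see the stop flag.
            self._buffer.get_nowait()
        except queue.Empty:
            pass

    def __iter__(self):
        return self

    def __next__(self):
        if not self._active:
            raise RuntimeError(
                "Iteration is only allowed inside the prefetching context "
                "manager!")
        item, exc = self._buffer.get()
        if exc is not None:
            raise exc
        if item is self._DONE:
            raise StopIteration
        return item

A production version would also need to drain more than one buffered item on exit and handle worker shutdown at interpreter exit; the sketch only covers the behaviors the tests check.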
Example #5
def train(
    runner,
    dataset_paths=gin.REQUIRED,
    padding_and_batch_sizes=gin.REQUIRED,
    prefetch=8,
    validation_example_count=gin.REQUIRED,
    evaluate_only=False,
    evaluation_model_path=None,
    evaluation_save_path=None,
    parameter_seed=int(time.time() * 1000),
    profile_during_warmup=True,
    restore_from_path=None,
):
  """Training script entry point.

  Args:
    runner: Helper object that runs the experiment.
    dataset_paths: Dictionary of dataset paths, with keys:
      - "train_dataset": Path to training dataset files.
      - "train_metadata": Path to JSON file with training dataset metadata.
      - "eval_dataset": Path to validation/test dataset files.
      - "eval_metadata": Path to JSON file with eval dataset metadata.
    padding_and_batch_sizes: Padding configurations and batch sizes to use.
      Padding should be specified using the keyword arguments for the function
      `build_padding_config`. Batch sizes may be None, in which case we will try
      to find the maximum batch size for each padding config.
    prefetch: How many examples to prefetch.
    validation_example_count: How many examples to use when validating during
      training. If None, uses all examples.
    evaluate_only: If True, doesn't run any training; instead evaluates a
      trained model on the validation/evaluation set.
    evaluation_model_path: Where to load the model from.
    evaluation_save_path: Where to save the result JSON file.
    parameter_seed: Random seed to use when initializing parameters.
    profile_during_warmup: Whether to use XProf during warmup.
    restore_from_path: Optional path to restore parameters from; useful for
      warm-starting.

  Returns:
    Final optimizer.
  """

  num_devices = jax.local_device_count()
  logging.info('Found %d devices: %s', num_devices, jax.devices())

  padding_and_batch_sizes = [
      (build_padding_config(**config_kwargs), batch_size)
      for config_kwargs, batch_size in padding_and_batch_sizes
  ]

  with contextlib.ExitStack() as exit_stack:
    # Loading metadata and task info.
    with gfile.GFile(dataset_paths['train_metadata'], 'r') as fp:
      train_metadata = json.load(fp)
    with gfile.GFile(dataset_paths['eval_metadata'], 'r') as fp:
      valid_metadata = json.load(fp)

    assert train_metadata['spec_file'] == valid_metadata['spec_file']
    assert train_metadata['vocab_file'] == valid_metadata['vocab_file']
    assert train_metadata['edge_types'] == valid_metadata['edge_types']

    encoding_info = example_definition.ExampleEncodingInfo.from_files(
        train_metadata['spec_file'], train_metadata['vocab_file'])
    assert encoding_info.edge_types == train_metadata['edge_types']

    # Model setup.
    logging.info('Setting up model...')
    model_def = var_misuse_models.var_misuse_model.partial(
        encoding_info=encoding_info)

    # Set up a dummy stochastic scope for random perturbations.
    with flax.nn.stochastic(jax.random.PRNGKey(0)):
      # Initialize parameters based on our seed.
      _, initial_params = model_def.init(
          jax.random.PRNGKey(parameter_seed),
          jax.tree_map(
              jnp.array,
              example_definition.zeros_like_padded_example(
                  TINY_PADDING_CONFIG)), TINY_PADDING_CONFIG)

    model = flax.nn.Model(model_def, initial_params)
    del initial_params
    optimizer = flax.optim.Adam().create(model)

    if restore_from_path:
      optimizer, checkpoint_info = runner.load_from_checkpoint(
          optimizer, restore_from_path)
      logging.info('Warm starting from checkpoint with info: %s',
                   checkpoint_info)

    # Compute missing batch sizes.
    tmp_replicated_optimizer = train_util.device_broadcast(
        optimizer, num_devices)
    for i, (padding_config, batch_size) in enumerate(padding_and_batch_sizes):
      fake_example_and_rng = (
          example_definition.zeros_like_padded_example(padding_config),
          jax.random.PRNGKey(0))
      assert batch_size is not None
      logging.info(
          'Running a fake train step for batch size %d and padding config %d: %s',
          batch_size, i, padding_config)
      # pylint: disable=cell-var-from-loop
      fake_batch = jax.vmap(lambda _: fake_example_and_rng)(
          jnp.arange(batch_size))
      fake_device_batch = jax.vmap(lambda _: fake_batch)(
          jnp.arange(num_devices))
      # pylint: enable=cell-var-from-loop
      train_util.warmup_train_step(
          tmp_replicated_optimizer,
          fake_device_batch,
          padding_config,
          loss_fn,
          optimizer_is_replicated=True,
          profile=profile_during_warmup,
          runner=runner)

    del tmp_replicated_optimizer

    extra_artifacts = {
        'encoding_info.pickle': encoding_info,
    }

    # Dataset iterator setup.
    logging.info('Setting up datasets...')
    unbatched_train_iterator = runner.build_sampling_iterator(
        dataset_paths['train_dataset'],
        example_type=example_definition.VarMisuseExample)

    # Randomly generate the base RNG. (Since the iterator is already randomly
    # shuffling, and we might have restarted this job anyway, there's no point
    # in setting a seed here.)
    train_iterator = pad_and_batch_with_rng(
        unbatched_train_iterator,
        num_devices,
        padding_and_batch_sizes,
        base_rng=jax.random.PRNGKey(int(time.time() * 1000)))
    if prefetch:
      train_iterator = exit_stack.enter_context(
          data_loading.ThreadedPrefetcher(train_iterator, prefetch))

    unbatched_valid_iterator_factory = (
        runner.build_one_pass_iterator_factory(
            dataset_paths['eval_dataset'],
            example_type=example_definition.VarMisuseExample,
            truncate_at=validation_example_count))

    def logging_progress(it):
      maxct = validation_example_count or valid_metadata['num_examples']
      for i, val in enumerate(it):
        if i % 10000 == 0:
          logging.info('Validation progress: %d of %d', i, maxct)
        yield val

    def valid_iterator_factory():
      unbatched = unbatched_valid_iterator_factory()
      if evaluate_only:
        unbatched = logging_progress(unbatched)
      # Always use the same PRNGKey for the validation set.
      valid_iterator = pad_and_batch_with_rng(
          unbatched,
          num_devices,
          padding_and_batch_sizes,
          base_rng=jax.random.PRNGKey(0))
      if prefetch:
        with data_loading.ThreadedPrefetcher(valid_iterator, prefetch) as it:
          yield from it
      else:
        yield from valid_iterator

    validation_fn = train_util.build_averaging_validator(
        loss_fn,
        valid_iterator_factory,
        objective_metric_name='inaccuracy/overall',
        include_total_counts=evaluate_only)

    if evaluate_only:
      logging.warning('This job is running in evaluation mode!')

      optimizer, checkpoint_info = runner.load_from_checkpoint(
          optimizer, checkpoint_path=evaluation_model_path)
      model = train_util.device_broadcast(optimizer.target, num_devices)

      _, metrics = validation_fn(model)
      metrics['checkpoint_info'] = checkpoint_info
      metrics['model_path'] = evaluation_model_path
      metrics['dataset_path'] = dataset_paths['eval_dataset']
      metrics['example_count'] = validation_example_count

      array_types = (np.ndarray, jnp.ndarray)
      metrics = jax.tree_map(
          lambda x: x.tolist() if isinstance(x, array_types) else x, metrics)

      gfile.makedirs(os.path.dirname(evaluation_save_path))
      with gfile.GFile(evaluation_save_path, 'w') as fp:
        json.dump(metrics, fp, indent=2)

      logging.info('Computed evaluation metrics: %s', metrics)

    else:
      return runner.training_loop(
          optimizer=optimizer,
          train_iterator=train_iterator,
          loss_fn=loss_fn,
          validation_fn=validation_fn,
          extra_artifacts=extra_artifacts)
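Since every required argument is bound through gin, the most load-bearing structure is padding_and_batch_sizes, which the top of train unpacks as (config_kwargs, batch_size) pairs and forwards to build_padding_config(**config_kwargs). A hypothetical value with placeholder keyword names (the real build_padding_config arguments are defined elsewhere in the codebase):

# Hypothetical example of the expected structure; the keyword names passed to
# build_padding_config are placeholders.
padding_and_batch_sizes = [
    # (kwargs for build_padding_config, batch size)
    ({"max_num_nodes": 128, "max_num_tokens": 256}, 64),
    ({"max_num_nodes": 512, "max_num_tokens": 1024}, 16),
]

The warmup loop in the function above builds a fake batch of this size per device, so each batch size is effectively a per-device value.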
Example #6
def train(
    runner,
    dataset_paths=gin.REQUIRED,
    prefetch=4,
    batch_size_per_device=gin.REQUIRED,
    validation_example_count=gin.REQUIRED,
):
    """Train the maze automaton.

  Args:
    runner: Helper object that runs the experiment.
    dataset_paths: Dictionary of dataset paths, with keys:
      - "train_dataset": Path to training dataset files.
      - "eval_dataset": Path to validation dataset files.
    prefetch: Maximum number of examples to prefetch in a background thread.
    batch_size_per_device: Batch size for each device.
    validation_example_count: How many examples to use when computing validation
      metrics.

  Returns:
    Optimizer at the end of training (for interactive debugging).
  """
    num_devices = jax.local_device_count()
    logging.info("Found %d devices: %s", num_devices, jax.devices())

    with contextlib.ExitStack() as exit_stack:
        logging.info("Setting up datasets...")
        raw_train_iterator = runner.build_sampling_iterator(
            dataset_paths["train_dataset"],
            example_type=graph_bundle.GraphBundle)

        raw_valid_iterator_factory = runner.build_one_pass_iterator_factory(
            dataset_paths["eval_dataset"],
            example_type=graph_bundle.GraphBundle,
            truncate_at=validation_example_count)

        # Add the example id into the example itself, so that we can use it to
        # randomly choose a goal.
        def reify_id(it):
            for item in it:
                yield dataclasses.replace(item,
                                          example=(item.example,
                                                   item.example_id))

        def reify_id_and_batch(it):
            return data_loading.batch(
                reify_id(it), (num_devices, batch_size_per_device),
                remainder_behavior=data_loading.BatchRemainderBehavior.PAD_ZERO)

        train_iterator = reify_id_and_batch(raw_train_iterator)
        valid_iterator_factory = (
            lambda: reify_id_and_batch(raw_valid_iterator_factory()))

        if prefetch:
            train_iterator = exit_stack.enter_context(
                data_loading.ThreadedPrefetcher(train_iterator, prefetch))

        logging.info("Setting up model...")
        padding_config = maze_task.PADDING_CONFIG
        model_def = automaton_layer.FiniteStateGraphAutomaton.partial(
            static_metadata=padding_config.static_max_metadata,
            builder=maze_task.BUILDER)

        # Initialize parameters randomly.
        _, initial_params = model_def.init(
            jax.random.PRNGKey(int(time.time() * 1000)),
            graph_bundle.zeros_like_padded_example(
                padding_config).automaton_graph,
            dynamic_metadata=padding_config.static_max_metadata)

        model = flax.nn.Model(model_def, initial_params)
        optimizer = flax.optim.Adam().create(model)

        extra_artifacts = {
            "builder.pickle": maze_task.BUILDER,
        }

        return runner.training_loop(
            optimizer=optimizer,
            train_iterator=train_iterator,
            loss_fn=loss_fn,
            validation_fn=train_util.build_averaging_validator(
                loss_fn, valid_iterator_factory),
            extra_artifacts=extra_artifacts)
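The (num_devices, batch_size_per_device) tuple passed to data_loading.batch gives the leading batch axes, one per device and one per example, which is the layout a pmapped training step expects. A toy numpy illustration of that layout (illustrative only; the real batching stacks pytree leaves rather than reshaping a flat array):

import numpy as np

num_devices, batch_size_per_device = 2, 3
examples = np.arange(num_devices * batch_size_per_device)

# Leaves end up with shape (num_devices, batch_size_per_device, ...).
device_batches = examples.reshape(num_devices, batch_size_per_device)
print(device_batches)
# [[0 1 2]
#  [3 4 5]]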
Example #7
def train(
    runner,
    dataset_paths=gin.REQUIRED,
    prefetch=4,
    target_edge=gin.REQUIRED,
    batch_size_per_device=gin.REQUIRED,
    truncate_training_dataset_at=None,
    validation_example_skip=0,
    validation_example_count=gin.REQUIRED,
    model_type="automaton",
    evaluate_only=False,
    evaluation_model_path=None,
    evaluation_save_path=None,
    use_sampling_model=False,
):
    """Launch a training job for edge supervision.

  The dataset directories should be configured with gin.

  Args:
    runner: Helper object that runs the experiment.
    dataset_paths: Dictionary of dataset paths, with keys:
      - "metadata": Path to JSON file with dataset metadata.
      - "train_dataset": Path to training dataset files.
      - "eval_dataset": Path to validation/test dataset files.
    prefetch: Maximum number of examples to prefetch in a background thread.
      Note that we prefetch a maximum of 1 example for the validation set.
    target_edge: What edge to use as the training target.
    batch_size_per_device: Batch size for each device.
    truncate_training_dataset_at: Number of examples to truncate the training
      dataset to.
    validation_example_skip: Number of examples to skip when computing
      validation metrics.
    validation_example_count: How many examples to use when computing validation
      metrics.
    model_type: Either "automaton" or "baseline".
    evaluate_only: If True, doesn't run any training; instead evaluates a
      trained model on the validation/evaluation set. Make sure to change
      "eval_dataset" to the test dataset if using this to compute final metrics.
    evaluation_model_path: Path to the model checkpoint to evaluate.
    evaluation_save_path: Where to save the result JSON file.
    use_sampling_model: Whether to use sample-based version of the loss.

  Returns:
    Optimizer at the end of training (for interactive debugging).
  """

    logging.info("Hello from train_edge_supervision_lib!")
    num_devices = jax.local_device_count()
    logging.info("Found %d devices: %s", num_devices, jax.devices())
    logging.info("Setting up datasets...")
    with contextlib.ExitStack() as exit_stack:

        padding_config, edge_types = load_dataset_metadata(
            dataset_paths["metadata"])

        if evaluate_only:
            assert evaluation_model_path is not None
        else:
            unbatched_train_iterator = runner.build_sampling_iterator(
                dataset_paths["train_dataset"],
                example_type=graph_bundle.GraphBundle,
                truncate_at=truncate_training_dataset_at)
            unbatched_train_iterator = add_rng_to_examples(
                unbatched_train_iterator,
                jax.random.PRNGKey(int(time.time() * 1000)))
            train_iterator = data_loading.batch(
                unbatched_train_iterator, (num_devices, batch_size_per_device))

            if prefetch:
                train_iterator = exit_stack.enter_context(
                    data_loading.ThreadedPrefetcher(train_iterator, prefetch))

        unbatched_valid_iterator_factory = runner.build_one_pass_iterator_factory(
            dataset_paths["eval_dataset"],
            example_type=graph_bundle.GraphBundle,
            truncate_at=validation_example_count,
            skip_first=validation_example_skip)

        def valid_iterator_factory():
            it = unbatched_valid_iterator_factory()
            # Fix validation randomness to smooth out noise.
            # (note: for final evaluation we should compute true marginals and not
            # do any sampling)
            it = add_rng_to_examples(it, jax.random.PRNGKey(0))
            return data_loading.batch(
                it, (num_devices, batch_size_per_device),
                remainder_behavior=data_loading.BatchRemainderBehavior.PAD_ZERO)

        num_edge_types = len(edge_types)
        edge_types_to_indices = {name: i for i, name in enumerate(edge_types)}

        logging.info("Setting up model...")
        if model_type == "automaton":
            model_def = edge_supervision_models.automaton_model
        elif model_type == "baseline":
            model_def = edge_supervision_models.BaselineModel
        else:
            raise ValueError(f"Unknown model type '{model_type}'")

        # Bind statically-known information from the dataset.
        model_def = model_def.partial(
            graph_metadata=padding_config.static_max_metadata,
            edge_types_to_indices=edge_types_to_indices)

        # Initialize parameters randomly.
        @jax.jit
        def _init(rng):
            # Set up a dummy stochastic scope for random perturbations.
            with flax.nn.stochastic(jax.random.PRNGKey(0)):
                ex = graph_bundle.zeros_like_padded_example(padding_config)
                ex = jax.tree_map(jnp.array, ex)
                _, initial_params = model_def.init(rng, ex)
            return initial_params

        initial_params = _init(jax.random.PRNGKey(int(time.time() * 1000)))

        model = flax.nn.Model(model_def, initial_params)
        optimizer = flax.optim.Adam().create(model)

        validation_fn = build_validation_fn(
            valid_iterator_factory=valid_iterator_factory,
            target_edge_index=edge_types_to_indices[target_edge],
            num_edge_types=num_edge_types,
            full_evaluation=evaluate_only,
            use_sampling_model=use_sampling_model)

        if evaluate_only:
            optimizer, checkpoint_info = runner.load_from_checkpoint(
                optimizer, checkpoint_path=evaluation_model_path)
            model = train_util.device_broadcast(optimizer.target, num_devices)

            _, metrics = validation_fn(model)
            metrics["checkpoint_info"] = checkpoint_info
            metrics["model_path"] = evaluation_model_path
            metrics["dataset_metadata_path"] = dataset_paths["metadata"]
            metrics["dataset_path"] = dataset_paths["eval_dataset"]
            metrics["example_skip"] = validation_example_skip
            metrics["example_count"] = validation_example_count

            array_types = (np.ndarray, jnp.ndarray)
            metrics = jax.tree_map(
                lambda x: x.tolist() if isinstance(x, array_types) else x,
                metrics)

            gfile.makedirs(os.path.dirname(evaluation_save_path))
            with gfile.GFile(evaluation_save_path, "w") as fp:
                json.dump(metrics, fp, indent=2)

            logging.info("Computed evaluation metrics: %s", metrics)

        else:

            def compute_loss_for_model(model, padded_example_and_rng,
                                       static_metadata):
                assert static_metadata is None
                if use_sampling_model:
                    (_, _, _, loss, batch_metrics) = sample_loss_fn(
                        model,
                        padded_example_and_rng,
                        target_edge_index=edge_types_to_indices[target_edge],
                        num_edge_types=num_edge_types)
                    return loss, batch_metrics
                else:
                    return loss_fn(*extract_outputs_and_targets(
                        model,
                        padded_example_and_rng,
                        target_edge_index=edge_types_to_indices[target_edge],
                        num_edge_types=num_edge_types))

            extra_artifacts = {
                "builder.pickle": py_ast_graphs.BUILDER,
            }

            return runner.training_loop(optimizer=optimizer,
                                        train_iterator=train_iterator,
                                        loss_fn=compute_loss_for_model,
                                        validation_fn=validation_fn,
                                        extra_artifacts=extra_artifacts)
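Both the training and validation paths call add_rng_to_examples with a base PRNGKey (time-seeded for training, a fixed PRNGKey(0) for validation). A hypothetical sketch of what such a helper might do, deriving a deterministic per-example key and attaching it alongside the example; the real helper in the codebase may store the key differently:

import dataclasses

import jax


def add_rng_to_examples_sketch(example_iterator, base_rng):
    for i, item in enumerate(example_iterator):
        # Deterministically derive a distinct key for each example.
        example_rng = jax.random.fold_in(base_rng, i)
        yield dataclasses.replace(item, example=(item.example, example_rng))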
Example #8
    def test_prefetch_to_end(self):
        with data_loading.ThreadedPrefetcher(range(100), 10) as prefetched:
            values = list(prefetched)
        self.assertEqual(values, list(range(100)))