def valid_iterator_factory():
  it = unbatched_valid_iterator_factory()
  # Fix validation randomness to smooth out noise.
  # (Note: for final evaluation we should compute true marginals and not
  # do any sampling.)
  it = add_rng_to_examples(it, jax.random.PRNGKey(0))
  return data_loading.batch(
      it, (num_devices, batch_size_per_device),
      remainder_behavior=data_loading.BatchRemainderBehavior.PAD_ZERO)
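# Illustrative sketch only (not the library's actual implementation): one way
# an `add_rng_to_examples` helper like the one used above could work is to
# fold each example's position into a base PRNG key, so every example gets a
# deterministic but distinct key. The `ExampleWithRng` container name and the
# `_sketch` suffix are hypothetical; only the overall idea is implied by the
# call sites in this file.
from typing import Any, Iterable, Iterator, NamedTuple

import jax


class ExampleWithRng(NamedTuple):  # Hypothetical container name.
  example: Any
  rng: Any  # A JAX PRNG key array.


def add_rng_to_examples_sketch(
    examples: Iterable[Any], base_rng) -> Iterator[ExampleWithRng]:
  """Pairs each example with a per-example PRNG key derived from base_rng."""
  for i, example in enumerate(examples):
    # fold_in yields a distinct, reproducible key for each example index.
    yield ExampleWithRng(example, jax.random.fold_in(base_rng, i))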
def test_batch_uneven_pad(self):
  # With PAD_ZERO, the final incomplete batch is padded with zeros.
  values = range(10)
  batched = list(
      data_loading.batch(
          values, (3,),
          remainder_behavior=data_loading.BatchRemainderBehavior.PAD_ZERO))
  expected = [
      np.array([0, 1, 2]),
      np.array([3, 4, 5]),
      np.array([6, 7, 8]),
      np.array([9, 0, 0]),
  ]
  jax.test_util.check_eq(batched, expected)
def test_batch(self):
  # Batch dims (3, 2): each output groups six examples into a [3, 2, ...]
  # array for every leaf of the example structure.
  values = [{"v": np.array([i])} for i in range(18)]
  batched = list(data_loading.batch(values, (3, 2)))
  expected = [
      {"v": np.array([[[0], [1]], [[2], [3]], [[4], [5]]])},
      {"v": np.array([[[6], [7]], [[8], [9]], [[10], [11]]])},
      {"v": np.array([[[12], [13]], [[14], [15]], [[16], [17]]])},
  ]
  jax.test_util.check_eq(batched, expected)
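# Illustrative sketch only, not the actual data_loading implementation: a
# minimal batching helper consistent with the tests above. It consumes
# prod(batch_dims) examples at a time, stacks every leaf of the (possibly
# nested) examples with np.stack, and reshapes the leading axis into the
# requested nested batch shape. With PAD_ZERO it fills the last incomplete
# batch with zeros instead of raising. The enum and error message mirror the
# usage in the tests; the internals are guesses.
import enum
import itertools

import jax
import numpy as np


class BatchRemainderBehavior(enum.Enum):
  ERROR = "error"
  PAD_ZERO = "pad_zero"


def batch_sketch(examples, batch_dims,
                 remainder_behavior=BatchRemainderBehavior.ERROR):
  batch_dims = tuple(batch_dims)
  batch_size = int(np.prod(batch_dims))
  examples = iter(examples)
  while True:
    chunk = list(itertools.islice(examples, batch_size))
    if not chunk:
      return
    if len(chunk) < batch_size:
      if remainder_behavior == BatchRemainderBehavior.ERROR:
        raise ValueError("Number of examples not divisible by batch size.")
      # Pad with all-zero copies of the first example's structure.
      padding = jax.tree_util.tree_map(np.zeros_like, chunk[0])
      chunk.extend([padding] * (batch_size - len(chunk)))
    # Stack each leaf across the chunk, then fold the flat batch axis into
    # the nested `batch_dims` shape.
    stacked = jax.tree_util.tree_map(lambda *xs: np.stack(xs), *chunk)
    yield jax.tree_util.tree_map(
        lambda arr: arr.reshape(batch_dims + arr.shape[1:]), stacked)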
def reify_id_and_batch(it):
  return data_loading.batch(
      reify_id(it), (num_devices, batch_size_per_device),
      remainder_behavior=data_loading.BatchRemainderBehavior.PAD_ZERO)
def train(
    runner,
    dataset_paths=gin.REQUIRED,
    prefetch=4,
    target_edge=gin.REQUIRED,
    batch_size_per_device=gin.REQUIRED,
    truncate_training_dataset_at=None,
    validation_example_skip=0,
    validation_example_count=gin.REQUIRED,
    model_type="automaton",
    evaluate_only=False,
    evaluation_model_path=None,
    evaluation_save_path=None,
    use_sampling_model=False,
):
  """Launch a training job for edge supervision.

  The dataset directories should be configured with gin.

  Args:
    runner: Helper object that runs the experiment.
    dataset_paths: Dictionary of dataset paths, with keys:
      - "metadata": Path to JSON file with dataset metadata.
      - "train_dataset": Path to training dataset files.
      - "eval_dataset": Path to validation/test dataset files.
    prefetch: Maximum number of examples to prefetch in a background thread.
      Note that we prefetch a maximum of 1 example for the validation set.
    target_edge: What edge to use as the training target.
    batch_size_per_device: Batch size for each device.
    truncate_training_dataset_at: Number of examples to truncate the training
      dataset to.
    validation_example_skip: Number of examples to skip when computing
      validation metrics.
    validation_example_count: How many examples to use when computing
      validation metrics.
    model_type: Either "automaton" or "baseline".
    evaluate_only: If True, doesn't run any training; instead evaluates a
      trained model on the validation/evaluation set. Make sure to change
      "eval_dataset" to the test dataset if using this to compute final
      metrics.
    evaluation_model_path: Path to the model checkpoint to evaluate.
    evaluation_save_path: Where to save the result JSON file.
    use_sampling_model: Whether to use sample-based version of the loss.

  Returns:
    Optimizer at the end of training (for interactive debugging).
  """
  logging.info("Hello from train_edge_supervision_lib!")
  num_devices = jax.local_device_count()
  logging.info("Found %d devices: %s", num_devices, jax.devices())
  logging.info("Setting up datasets...")
  with contextlib.ExitStack() as exit_stack:
    padding_config, edge_types = load_dataset_metadata(
        dataset_paths["metadata"])

    if evaluate_only:
      assert evaluation_model_path is not None
    else:
      unbatched_train_iterator = runner.build_sampling_iterator(
          dataset_paths["train_dataset"],
          example_type=graph_bundle.GraphBundle,
          truncate_at=truncate_training_dataset_at)
      unbatched_train_iterator = add_rng_to_examples(
          unbatched_train_iterator,
          jax.random.PRNGKey(int(time.time() * 1000)))
      train_iterator = data_loading.batch(
          unbatched_train_iterator, (num_devices, batch_size_per_device))

      if prefetch:
        train_iterator = exit_stack.enter_context(
            data_loading.ThreadedPrefetcher(train_iterator, prefetch))

    unbatched_valid_iterator_factory = runner.build_one_pass_iterator_factory(
        dataset_paths["eval_dataset"],
        example_type=graph_bundle.GraphBundle,
        truncate_at=validation_example_count,
        skip_first=validation_example_skip)

    def valid_iterator_factory():
      it = unbatched_valid_iterator_factory()
      # Fix validation randomness to smooth out noise.
      # (Note: for final evaluation we should compute true marginals and not
      # do any sampling.)
      it = add_rng_to_examples(it, jax.random.PRNGKey(0))
      return data_loading.batch(
          it, (num_devices, batch_size_per_device),
          remainder_behavior=data_loading.BatchRemainderBehavior.PAD_ZERO)

    num_edge_types = len(edge_types)
    edge_types_to_indices = {name: i for i, name in enumerate(edge_types)}

    logging.info("Setting up model...")
    if model_type == "automaton":
      model_def = edge_supervision_models.automaton_model
    elif model_type == "baseline":
      model_def = edge_supervision_models.BaselineModel
    else:
      raise ValueError(f"Unknown model type '{model_type}'")

    # Bind statically-known information from the dataset.
    model_def = model_def.partial(
        graph_metadata=padding_config.static_max_metadata,
        edge_types_to_indices=edge_types_to_indices)

    # Initialize parameters randomly.
    @jax.jit
    def _init(rng):
      # Set up a dummy stochastic scope for random perturbations.
      with flax.nn.stochastic(jax.random.PRNGKey(0)):
        ex = graph_bundle.zeros_like_padded_example(padding_config)
        ex = jax.tree_map(jnp.array, ex)
        _, initial_params = model_def.init(rng, ex)
      return initial_params

    initial_params = _init(jax.random.PRNGKey(int(time.time() * 1000)))

    model = flax.nn.Model(model_def, initial_params)
    optimizer = flax.optim.Adam().create(model)

    validation_fn = build_validation_fn(
        valid_iterator_factory=valid_iterator_factory,
        target_edge_index=edge_types_to_indices[target_edge],
        num_edge_types=num_edge_types,
        full_evaluation=evaluate_only,
        use_sampling_model=use_sampling_model)

    if evaluate_only:
      optimizer, checkpoint_info = runner.load_from_checkpoint(
          optimizer, checkpoint_path=evaluation_model_path)
      model = train_util.device_broadcast(optimizer.target, num_devices)

      _, metrics = validation_fn(model)

      metrics["checkpoint_info"] = checkpoint_info
      metrics["model_path"] = evaluation_model_path
      metrics["dataset_metadata_path"] = dataset_paths["metadata"]
      metrics["dataset_path"] = dataset_paths["eval_dataset"]
      metrics["example_skip"] = validation_example_skip
      metrics["example_count"] = validation_example_count

      array_types = (np.ndarray, jnp.ndarray)
      metrics = jax.tree_map(
          lambda x: x.tolist() if isinstance(x, array_types) else x, metrics)

      gfile.makedirs(os.path.dirname(evaluation_save_path))
      with gfile.GFile(evaluation_save_path, "w") as fp:
        json.dump(metrics, fp, indent=2)

      logging.info("Computed evaluation metrics: %s", metrics)

    else:

      def compute_loss_for_model(model, padded_example_and_rng,
                                 static_metadata):
        assert static_metadata is None
        if use_sampling_model:
          (_, _, _, loss, batch_metrics) = sample_loss_fn(
              model,
              padded_example_and_rng,
              target_edge_index=edge_types_to_indices[target_edge],
              num_edge_types=num_edge_types)
          return loss, batch_metrics
        else:
          return loss_fn(*extract_outputs_and_targets(
              model,
              padded_example_and_rng,
              target_edge_index=edge_types_to_indices[target_edge],
              num_edge_types=num_edge_types))

      extra_artifacts = {
          "builder.pickle": py_ast_graphs.BUILDER,
      }

      return runner.training_loop(
          optimizer=optimizer,
          train_iterator=train_iterator,
          loss_fn=compute_loss_for_model,
          validation_fn=validation_fn,
          extra_artifacts=extra_artifacts)
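# Illustrative sketch only: one way this entry point might be configured and
# launched, assuming `train` is registered as gin-configurable (the
# gin.REQUIRED defaults above suggest this). All paths, the edge type name,
# and `build_runner()` are hypothetical placeholders; the real runner object
# and valid target_edge values come from the surrounding experiment harness
# and the dataset's metadata.
import gin

gin.parse_config("""
train.dataset_paths = {
    "metadata": "/tmp/edge_dataset/metadata.json",
    "train_dataset": "/tmp/edge_dataset/train*",
    "eval_dataset": "/tmp/edge_dataset/valid*",
}
train.target_edge = "SOME_TARGET_EDGE_TYPE"  # Must appear in the dataset's edge_types.
train.batch_size_per_device = 32
train.validation_example_count = 512
""")

runner = build_runner()  # Hypothetical: provided by the experiment harness.
final_optimizer = train(runner)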
def test_batch_uneven_error(self):
  # Without PAD_ZERO, an example count that is not divisible by the batch
  # size is an error.
  values = range(10)
  with self.assertRaisesRegex(ValueError, "not divisible by batch size"):
    for _ in data_loading.batch(values, (3,)):
      pass