def train(
    runner,
    dataset_paths=gin.REQUIRED,
    padding_and_batch_sizes=gin.REQUIRED,
    prefetch=8,
    validation_example_count=gin.REQUIRED,
    evaluate_only=False,
    evaluation_model_path=None,
    evaluation_save_path=None,
    parameter_seed=int(time.time() * 1000),
    profile_during_warmup=True,
    restore_from_path=None,
):
  """Training script entry point.

  Args:
    runner: Helper object that runs the experiment.
    dataset_paths: Dictionary of dataset paths, with keys:
      - "train_dataset": Path to training dataset files.
      - "train_metadata": Path to JSON file with training dataset metadata.
      - "eval_dataset": Path to validation/test dataset files.
      - "eval_metadata": Path to JSON file with eval dataset metadata.
    padding_and_batch_sizes: Padding configurations and batch sizes to use.
      Padding should be specified using the keyword arguments for the
      function `build_padding_config`. Batch sizes may be None, in which
      case we will try to find the maximum batch size for each padding
      config.
    prefetch: How many examples to prefetch.
    validation_example_count: How many examples to use when validating during
      training. If None, uses all examples.
    evaluate_only: If True, doesn't run any training; instead evaluates a
      trained model on the validation/evaluation set.
    evaluation_model_path: Where to load the model from.
    evaluation_save_path: Where to save the result JSON file.
    parameter_seed: Random seed to use when initializing parameters.
    profile_during_warmup: Whether to use XProf during warmup.
    restore_from_path: Optional path to restore parameters from; useful for
      warm-starting.

  Returns:
    Final optimizer.
  """
  num_devices = jax.local_device_count()
  logging.info('Found %d devices: %s', num_devices, jax.devices())

  padding_and_batch_sizes = [
      (build_padding_config(**config_kwargs), batch_size)
      for config_kwargs, batch_size in padding_and_batch_sizes
  ]

  with contextlib.ExitStack() as exit_stack:
    # Load metadata and task info.
    with gfile.GFile(dataset_paths['train_metadata'], 'r') as fp:
      train_metadata = json.load(fp)
    with gfile.GFile(dataset_paths['eval_metadata'], 'r') as fp:
      valid_metadata = json.load(fp)

    # The train and eval splits must agree on their encoding.
    assert train_metadata['spec_file'] == valid_metadata['spec_file']
    assert train_metadata['vocab_file'] == valid_metadata['vocab_file']
    assert train_metadata['edge_types'] == valid_metadata['edge_types']

    encoding_info = example_definition.ExampleEncodingInfo.from_files(
        train_metadata['spec_file'], train_metadata['vocab_file'])
    assert encoding_info.edge_types == train_metadata['edge_types']

    # Model setup.
    logging.info('Setting up model...')
    model_def = var_misuse_models.var_misuse_model.partial(
        encoding_info=encoding_info)

    # Set up a dummy stochastic scope for random perturbations.
    with flax.nn.stochastic(jax.random.PRNGKey(0)):
      # Initialize parameters based on our seed.
      _, initial_params = model_def.init(
          jax.random.PRNGKey(parameter_seed),
          jax.tree_map(
              jnp.array,
              example_definition.zeros_like_padded_example(
                  TINY_PADDING_CONFIG)),
          TINY_PADDING_CONFIG)

    model = flax.nn.Model(model_def, initial_params)
    del initial_params
    optimizer = flax.optim.Adam().create(model)

    if restore_from_path:
      optimizer, checkpoint_info = runner.load_from_checkpoint(
          optimizer, restore_from_path)
      logging.info('Warm-starting from checkpoint with info: %s',
                   checkpoint_info)

    # Warm up the train step for each padding config. Batch sizes must be
    # concrete by this point; no automatic batch-size search happens here.
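    # (Running one fake step per configuration lets XLA compile each distinct
    # input shape up front, so compilation time doesn't land inside the timed
    # training loop.)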
    tmp_replicated_optimizer = train_util.device_broadcast(
        optimizer, num_devices)
    for i, (padding_config, batch_size) in enumerate(padding_and_batch_sizes):
      fake_example_and_rng = (
          example_definition.zeros_like_padded_example(padding_config),
          jax.random.PRNGKey(0))
      assert batch_size is not None
      logging.info(
          'Running a fake train step for batch size %d and padding config %d: %s',
          batch_size, i, padding_config)
      # pylint: disable=cell-var-from-loop
      fake_batch = jax.vmap(lambda _: fake_example_and_rng)(
          jnp.arange(batch_size))
      fake_device_batch = jax.vmap(lambda _: fake_batch)(
          jnp.arange(num_devices))
      # pylint: enable=cell-var-from-loop
      train_util.warmup_train_step(
          tmp_replicated_optimizer,
          fake_device_batch,
          padding_config,
          loss_fn,
          optimizer_is_replicated=True,
          profile=profile_during_warmup,
          runner=runner)

    del tmp_replicated_optimizer

    extra_artifacts = {
        'encoding_info.pickle': encoding_info,
    }

    # Dataset iterator setup.
    logging.info('Setting up datasets...')
    unbatched_train_iterator = runner.build_sampling_iterator(
        dataset_paths['train_dataset'],
        example_type=example_definition.VarMisuseExample)

    # Randomly generate the base RNG. (Since the iterator is already randomly
    # shuffling, and we might have restarted this job anyway, there's no point
    # in setting a seed here.)
    train_iterator = pad_and_batch_with_rng(
        unbatched_train_iterator,
        num_devices,
        padding_and_batch_sizes,
        base_rng=jax.random.PRNGKey(int(time.time() * 1000)))
    if prefetch:
      train_iterator = exit_stack.enter_context(
          data_loading.ThreadedPrefetcher(train_iterator, prefetch))

    unbatched_valid_iterator_factory = (
        runner.build_one_pass_iterator_factory(
            dataset_paths['eval_dataset'],
            example_type=example_definition.VarMisuseExample,
            truncate_at=validation_example_count))

    def logging_progress(it):
      maxct = validation_example_count or valid_metadata['num_examples']
      for i, val in enumerate(it):
        if i % 10000 == 0:
          logging.info('Validation progress: %d of %d', i, maxct)
        yield val

    def valid_iterator_factory():
      unbatched = unbatched_valid_iterator_factory()
      if evaluate_only:
        unbatched = logging_progress(unbatched)
      # Always use the same PRNGKey for the validation set.
      valid_iterator = pad_and_batch_with_rng(
          unbatched,
          num_devices,
          padding_and_batch_sizes,
          base_rng=jax.random.PRNGKey(0))

      if prefetch:
        with data_loading.ThreadedPrefetcher(valid_iterator, prefetch) as it:
          yield from it
      else:
        yield from valid_iterator

    validation_fn = train_util.build_averaging_validator(
        loss_fn,
        valid_iterator_factory,
        objective_metric_name='inaccuracy/overall',
        include_total_counts=evaluate_only)

    if evaluate_only:
      logging.warning('This job is running in evaluation mode!')
      optimizer, checkpoint_info = runner.load_from_checkpoint(
          optimizer, checkpoint_path=evaluation_model_path)
      model = train_util.device_broadcast(optimizer.target, num_devices)

      _, metrics = validation_fn(model)
      metrics['checkpoint_info'] = checkpoint_info
      metrics['model_path'] = evaluation_model_path
      metrics['dataset_path'] = dataset_paths['eval_dataset']
      metrics['example_count'] = validation_example_count

      array_types = (np.ndarray, jnp.ndarray)
      metrics = jax.tree_map(
          lambda x: x.tolist() if isinstance(x, array_types) else x, metrics)

      gfile.makedirs(os.path.dirname(evaluation_save_path))
      with gfile.GFile(evaluation_save_path, 'w') as fp:
        json.dump(metrics, fp, indent=2)
      logging.info('Computed evaluation metrics: %s', metrics)
    else:
      return runner.training_loop(
          optimizer=optimizer,
          train_iterator=train_iterator,
          loss_fn=loss_fn,
          validation_fn=validation_fn,
          extra_artifacts=extra_artifacts)
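

# A minimal sketch of wiring this entry point up with gin; the binding names,
# paths, and padding keyword arguments below are illustrative assumptions
# rather than the repository's actual configs:
#
#   train.dataset_paths = {
#       'train_dataset': '/data/train_examples*',
#       'train_metadata': '/data/train_metadata.json',
#       'eval_dataset': '/data/eval_examples*',
#       'eval_metadata': '/data/eval_metadata.json',
#   }
#   # Each entry pairs kwargs for `build_padding_config` with a batch size.
#   train.padding_and_batch_sizes = [
#       ({'max_tokens': 512}, 16),
#   ]
#   train.validation_example_count = 10000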


  tf.io.gfile.makedirs(log_dir)
  train_metrics_file = tf.io.gfile.GFile(
      os.path.join(log_dir, "train_metrics.json"), "w")
  valid_metrics_file = tf.io.gfile.GFile(
      os.path.join(log_dir, "valid_metrics.json"), "w")
  train_metrics_file = typing.cast(IO[str], train_metrics_file)
  valid_metrics_file = typing.cast(IO[str], valid_metrics_file)

  # Peek at the first example in our dataset.
  logging.info("Peeking at dataset format...")
  first_batch = next(train_iterator)
  train_iterator = itertools.chain((first_batch,), train_iterator)
  num_devices = first_batch.epoch.shape[0]

  # Vectorize our optimizer accordingly.
  replicated_optimizer = train_util.device_broadcast(optimizer, num_devices)
  del optimizer

  # Prepare for graceful job shutdown.
  shutdown_after_this_iteration = False

  def graceful_shutdown_handler(unused_signal_number):
    del unused_signal_number
    nonlocal shutdown_after_this_iteration
    shutdown_after_this_iteration = True

  start_time = last_summary_time = time.time()
  last_summary_step = None
  best_objective_value = np.inf
  best_optimizer = None
  best_at_step = None
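
  # `graceful_shutdown_handler` only sets a flag; the loop body is expected
  # to check `shutdown_after_this_iteration` after each step and checkpoint
  # before exiting. A hypothetical registration (handlers must be installed
  # from the main thread):
  #
  #   import signal
  #   signal.signal(signal.SIGTERM,
  #                 lambda signum, frame: graceful_shutdown_handler(signum))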


def train(
    runner,
    dataset_paths=gin.REQUIRED,
    prefetch=4,
    target_edge=gin.REQUIRED,
    batch_size_per_device=gin.REQUIRED,
    truncate_training_dataset_at=None,
    validation_example_skip=0,
    validation_example_count=gin.REQUIRED,
    model_type="automaton",
    evaluate_only=False,
    evaluation_model_path=None,
    evaluation_save_path=None,
    use_sampling_model=False,
):
  """Launch a training job for edge supervision.

  The dataset directories should be configured with gin.

  Args:
    runner: Helper object that runs the experiment.
    dataset_paths: Dictionary of dataset paths, with keys:
      - "metadata": Path to JSON file with dataset metadata.
      - "train_dataset": Path to training dataset files.
      - "eval_dataset": Path to validation/test dataset files.
    prefetch: Maximum number of examples to prefetch in a background thread.
      Note that we prefetch a maximum of 1 example for the validation set.
    target_edge: What edge to use as the training target.
    batch_size_per_device: Batch size for each device.
    truncate_training_dataset_at: Number of examples to truncate the training
      dataset to.
    validation_example_skip: Number of examples to skip when computing
      validation metrics.
    validation_example_count: How many examples to use when computing
      validation metrics.
    model_type: Either "automaton" or "baseline".
    evaluate_only: If True, doesn't run any training; instead evaluates a
      trained model on the validation/evaluation set. Make sure to change
      "eval_dataset" to the test dataset if using this to compute final
      metrics.
    evaluation_model_path: Path to the model checkpoint to evaluate.
    evaluation_save_path: Where to save the result JSON file.
    use_sampling_model: Whether to use the sample-based version of the loss.

  Returns:
    Optimizer at the end of training (for interactive debugging).
  """
  logging.info("Hello from train_edge_supervision_lib!")
  num_devices = jax.local_device_count()
  logging.info("Found %d devices: %s", num_devices, jax.devices())
  logging.info("Setting up datasets...")

  with contextlib.ExitStack() as exit_stack:
    padding_config, edge_types = load_dataset_metadata(
        dataset_paths["metadata"])

    if evaluate_only:
      assert evaluation_model_path is not None
    else:
      unbatched_train_iterator = runner.build_sampling_iterator(
          dataset_paths["train_dataset"],
          example_type=graph_bundle.GraphBundle,
          truncate_at=truncate_training_dataset_at)
      unbatched_train_iterator = add_rng_to_examples(
          unbatched_train_iterator,
          jax.random.PRNGKey(int(time.time() * 1000)))
      train_iterator = data_loading.batch(
          unbatched_train_iterator, (num_devices, batch_size_per_device))

      if prefetch:
        train_iterator = exit_stack.enter_context(
            data_loading.ThreadedPrefetcher(train_iterator, prefetch))

    unbatched_valid_iterator_factory = runner.build_one_pass_iterator_factory(
        dataset_paths["eval_dataset"],
        example_type=graph_bundle.GraphBundle,
        truncate_at=validation_example_count,
        skip_first=validation_example_skip)

    def valid_iterator_factory():
      it = unbatched_valid_iterator_factory()
      # Fix validation randomness to smooth out noise.
      # (Note: for final evaluation we should compute true marginals and not
      # do any sampling.)
      it = add_rng_to_examples(it, jax.random.PRNGKey(0))
      return data_loading.batch(
          it, (num_devices, batch_size_per_device),
          remainder_behavior=data_loading.BatchRemainderBehavior.PAD_ZERO)
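
    # (PAD_ZERO presumably zero-pads the final partial batch rather than
    # dropping it, so a one-pass evaluation sees every example exactly once;
    # the exact semantics live in `data_loading.BatchRemainderBehavior`.)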

    num_edge_types = len(edge_types)
    edge_types_to_indices = {name: i for i, name in enumerate(edge_types)}

    logging.info("Setting up model...")
    if model_type == "automaton":
      model_def = edge_supervision_models.automaton_model
    elif model_type == "baseline":
      model_def = edge_supervision_models.BaselineModel
    else:
      raise ValueError(f"Unknown model type '{model_type}'")

    # Bind statically-known information from the dataset.
    model_def = model_def.partial(
        graph_metadata=padding_config.static_max_metadata,
        edge_types_to_indices=edge_types_to_indices)

    # Initialize parameters randomly.
    @jax.jit
    def _init(rng):
      # Set up a dummy stochastic scope for random perturbations.
      with flax.nn.stochastic(jax.random.PRNGKey(0)):
        ex = graph_bundle.zeros_like_padded_example(padding_config)
        ex = jax.tree_map(jnp.array, ex)
        _, initial_params = model_def.init(rng, ex)
        return initial_params

    initial_params = _init(jax.random.PRNGKey(int(time.time() * 1000)))

    model = flax.nn.Model(model_def, initial_params)
    optimizer = flax.optim.Adam().create(model)

    validation_fn = build_validation_fn(
        valid_iterator_factory=valid_iterator_factory,
        target_edge_index=edge_types_to_indices[target_edge],
        num_edge_types=num_edge_types,
        full_evaluation=evaluate_only,
        use_sampling_model=use_sampling_model)

    if evaluate_only:
      optimizer, checkpoint_info = runner.load_from_checkpoint(
          optimizer, checkpoint_path=evaluation_model_path)
      model = train_util.device_broadcast(optimizer.target, num_devices)
      _, metrics = validation_fn(model)

      metrics["checkpoint_info"] = checkpoint_info
      metrics["model_path"] = evaluation_model_path
      metrics["dataset_metadata_path"] = dataset_paths["metadata"]
      metrics["dataset_path"] = dataset_paths["eval_dataset"]
      metrics["example_skip"] = validation_example_skip
      metrics["example_count"] = validation_example_count

      array_types = (np.ndarray, jnp.ndarray)
      metrics = jax.tree_map(
          lambda x: x.tolist() if isinstance(x, array_types) else x, metrics)

      gfile.makedirs(os.path.dirname(evaluation_save_path))
      with gfile.GFile(evaluation_save_path, "w") as fp:
        json.dump(metrics, fp, indent=2)
      logging.info("Computed evaluation metrics: %s", metrics)
    else:

      def compute_loss_for_model(model, padded_example_and_rng,
                                 static_metadata):
        assert static_metadata is None
        if use_sampling_model:
          (_, _, _, loss, batch_metrics) = sample_loss_fn(
              model,
              padded_example_and_rng,
              target_edge_index=edge_types_to_indices[target_edge],
              num_edge_types=num_edge_types)
          return loss, batch_metrics
        else:
          return loss_fn(*extract_outputs_and_targets(
              model,
              padded_example_and_rng,
              target_edge_index=edge_types_to_indices[target_edge],
              num_edge_types=num_edge_types))

      extra_artifacts = {
          "builder.pickle": py_ast_graphs.BUILDER,
      }

      return runner.training_loop(
          optimizer=optimizer,
          train_iterator=train_iterator,
          loss_fn=compute_loss_for_model,
          validation_fn=validation_fn,
          extra_artifacts=extra_artifacts)
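

# An illustrative evaluation-only invocation; the runner object, edge name,
# and paths below are placeholders for whatever the experiment harness
# provides:
#
#   train(
#       runner=runner,
#       dataset_paths={
#           "metadata": "/data/metadata.json",
#           "eval_dataset": "/data/eval_examples*",
#       },
#       target_edge="example_edge_type",  # hypothetical edge name
#       batch_size_per_device=32,
#       validation_example_count=10000,
#       evaluate_only=True,
#       evaluation_model_path="/checkpoints/best",
#       evaluation_save_path="/results/eval_metrics.json")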