def mesh_eval_dataset_fn(mixture_name,
                         sequence_length,
                         vocabulary,
                         dataset_split,
                         num_eval_examples=None,
                         use_cached=False):
  """Returns all tf.data.Datasets for evaluation on a given mixture.

  This uses the format required for utils.run's `eval_dataset_fn` argument in
  the Mesh TF transformer standalone.

  Args:
    mixture_name: string, an identifier for a mixture in the MixtureRegistry.
      Must be specified via gin.
    sequence_length: dict mapping feature key to the int max sequence length
      for that feature.
    vocabulary: a SentencePieceVocabulary.
    dataset_split: string, which split of the dataset to load.
    num_eval_examples: maximum number of examples per task to use for
      continuous eval. If None, use all examples.
    use_cached: bool, whether to load the cached version of this dataset.

  Returns:
    A list of mesh_tensorflow.transformer.dataset.EvalDataset tuples.
  """
  if not isinstance(vocabulary, SentencePieceVocabulary):
    raise ValueError("vocabulary must be a SentencePieceVocabulary")
  mixture = MixtureRegistry.get(mixture_name)

  def _get_dataset_for_single_task(task):
    """Get a tensorflow.data.Dataset for the provided task."""
    ds = task.get_dataset(
        sequence_length, split=dataset_split, use_cached=use_cached,
        shuffle=False)
    ds = transformer_dataset.pack_or_pad(
        ds, sequence_length, pack=False,
        feature_keys=task.output_features, ensure_eos=True)
    if num_eval_examples is not None:
      ds = ds.take(num_eval_examples)
    return ds

  outputs = []
  for task in mixture.tasks:
    if dataset_split not in task.splits:
      logging.info("Task %s has no '%s' split, skipping eval.",
                   task.name, dataset_split)
      continue
    outputs.append(
        transformer_dataset.EvalDataset(
            task.name,
            functools.partial(_get_dataset_for_single_task, task),
            task.postprocess_fn,
            task.metric_fns,
        ))
  return outputs
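
# Usage sketch (illustrative, not part of the original module): how the
# EvalDataset tuples returned above are typically consumed. The mixture name
# and vocabulary path are hypothetical.
#
#   vocab = SentencePieceVocabulary("/path/to/spm.model")
#   for name, dataset_fn, postprocess_fn, metric_fns in mesh_eval_dataset_fn(
#       mixture_name="my_mixture",
#       sequence_length={"inputs": 512, "targets": 128},
#       vocabulary=vocab,
#       dataset_split="validation",
#       num_eval_examples=100):
#     ds = dataset_fn()  # a padded (never packed) tf.data.Dataset for `name`
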
def mesh_eval_dataset_fn(mixture_or_task_name,
                         sequence_length,
                         dataset_split,
                         vocabulary=None,
                         num_eval_examples=-1,
                         use_cached=False,
                         pack=False,
                         shuffle_eval_examples=False,
                         seed=None):
  """Returns all tf.data.Datasets for evaluation on a given mixture.

  This uses the format required for utils.run's `eval_dataset_fn` argument in
  the Mesh TF transformer standalone.

  Args:
    mixture_or_task_name: string, an identifier for a Mixture or Task in the
      appropriate registry. Must be specified via gin.
    sequence_length: dict mapping feature key to the int max sequence length
      for that feature. If set to None, packing and padding will be disabled.
    dataset_split: string, which split of the dataset to load.
    vocabulary: unused argument, maintains compatibility with other
      dataset_fns.
    num_eval_examples: maximum number of examples per task to use for
      continuous eval. If None or less than 0, use all examples.
    use_cached: bool, whether to load the cached version of this dataset.
    pack: a boolean, whether to pack examples. This is useful for perplexity
      evals but should not be used for iterative decoding.
    shuffle_eval_examples: boolean, whether to shuffle eval examples, applied
      only when num_eval_examples is not None. Intended to be able to eval on
      a different eval slice at every iteration.
    seed: tf.int64 scalar tf.Tensor (or None). Used for both the global seed
      and shuffle seed for tf.data.

  Returns:
    A list of mesh_tensorflow.transformer.dataset.EvalDataset tuples.
  """
  del vocabulary
  mixture_or_task = t5.data.get_mixture_or_task(mixture_or_task_name)

  def _get_dataset_for_single_task(task, sequence_length):
    """Get a tensorflow.data.Dataset for the provided task."""
    if shuffle_eval_examples and seed is None:
      logging.warning("shuffle_eval_examples is True but no seed was "
                      "provided. Using a random seed.")
    ds = task.get_dataset(
        sequence_length,
        split=dataset_split,
        use_cached=use_cached,
        shuffle=shuffle_eval_examples,
        seed=seed,
    )
    eos_keys = set(
        k for k, f in mixture_or_task.output_features.items() if f.add_eos)
    if sequence_length is None:
      logging.info(
          "Skipping packing/padding for '%s' since sequence length is None.",
          task.name)
    else:
      logging.info("%sing '%s' with sequence lengths: %s",
                   "Pack" if pack else "Padd", task.name, sequence_length)
      ds = transformer_dataset.pack_or_pad(
          ds,
          sequence_length,
          pack=pack,
          feature_keys=tuple(task.output_features),
          ensure_eos=eos_keys)
    if num_eval_examples is not None and num_eval_examples >= 0:
      ds = ds.take(num_eval_examples)
    return ds

  outputs = []
  for task in t5.data.get_subtasks(mixture_or_task):
    if dataset_split not in task.splits:
      logging.info("Task %s has no '%s' split, skipping eval.",
                   task.name, dataset_split)
      continue
    outputs.append(
        transformer_dataset.EvalDataset(
            task.name,
            functools.partial(
                _get_dataset_for_single_task,
                task=task,
                sequence_length=sequence_length),
            task.postprocess_fn,
            task.metric_fns,
        ))

  if not outputs:
    logging.warning("No %s data found for %s.",
                    dataset_split, mixture_or_task_name)

  return outputs
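
# Binding sketch (illustrative): this variant is usually wired up through gin
# as utils.run's `eval_dataset_fn`; an equivalent partial application with
# hypothetical values would be:
#
#   eval_dataset_fn = functools.partial(
#       mesh_eval_dataset_fn,
#       mixture_or_task_name="my_mixture",  # hypothetical registry name
#       num_eval_examples=512,
#       shuffle_eval_examples=True,
#       seed=42)  # fixed seed => the same shuffled eval slice on every call
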
def mesh_inference_dataset_fn(mixture_or_task_name,
                              sequence_length,
                              dataset_split,
                              shuffle=False,
                              seed=None,
                              vocabulary=None,
                              num_inference_examples=-1,
                              use_cached=False,
                              priming_sequence_length=None):
  """Returns all tf.data.Datasets for LM inference on a given mixture.

  For Tasks without inputs (such as language modeling), the first
  `priming_sequence_length` tokens in the target are used as the "inputs" for
  inference.

  Args:
    mixture_or_task_name: string, an identifier for a Mixture or Task in the
      appropriate registry. Must be specified via gin.
    sequence_length: dict mapping feature key to the int max sequence length
      for that feature. If set to None, packing and padding will be disabled.
    dataset_split: string, which split of the dataset to load. NOTE: this
      function does NOT receive the split specified in utils.run. It needs to
      be specified separately.
    shuffle: Whether or not to shuffle dataset.
    seed: tf.int64 scalar tf.Tensor (or None). Used as shuffle seed for
      tf.data.
    vocabulary: unused argument, maintains compatibility with other
      dataset_fns.
    num_inference_examples: maximum number of examples per task to do
      inference on. If None or less than 0, use all examples.
    use_cached: bool, whether to load the cached version of this dataset.
    priming_sequence_length: If the Task only has "targets", use this many
      tokens from the start of each target sequence as "inputs". This is
      useful for decoder-only language models where you would like to use a
      portion of the targets as a priming sequence for generation.

  Returns:
    A list of mesh_tensorflow.transformer.dataset.EvalDataset tuples.
  """
  del vocabulary
  mixture_or_task = t5.data.get_mixture_or_task(mixture_or_task_name)

  def _split_targets_for_primed_inference(ex):
    ex["inputs"] = ex["targets"][:priming_sequence_length]
    ex["targets"] = ex["targets"][priming_sequence_length:]
    ex["inputs"] = tf.pad(
        ex["inputs"],
        [[0, priming_sequence_length - tf.shape(ex["inputs"])[0]]],
        "CONSTANT")
    ex["inputs"] = tf.reshape(ex["inputs"],
                              shape=(priming_sequence_length,))
    return ex

  def _prepare_for_unprimed_inference(ex):
    ex["inputs"] = tf.constant([], dtype=tf.int64)
    return ex

  def _get_dataset_for_single_task(task, sequence_length):
    """Get a tensorflow.data.Dataset for the provided task."""
    ds = task.get_dataset(
        sequence_length,
        split=dataset_split,
        use_cached=use_cached,
        shuffle=shuffle,
        seed=seed)
    if "inputs" not in ds.element_spec:
      if not priming_sequence_length or priming_sequence_length <= 0:
        logging.warning("Priming sequence length not specified so priming "
                        "with the empty string.")
        ds = ds.map(_prepare_for_unprimed_inference)
      else:
        logging.info("Using the first %d tokens of each target as input.",
                     priming_sequence_length)
        ds = ds.map(_split_targets_for_primed_inference)
    elif priming_sequence_length is not None:
      raise ValueError(
          "Setting a priming sequence length only makes sense for "
          "decoder-only Tasks, which have `targets` but no `inputs`.")
    eos_keys = set(
        k for k, f in mixture_or_task.output_features.items() if f.add_eos)
    logging.info("Padding '%s' with sequence lengths: %s",
                 task.name, sequence_length)
    ds = transformer_dataset.pack_or_pad(
        ds,
        sequence_length,
        pack=False,
        feature_keys=tuple(task.output_features),
        ensure_eos=eos_keys)
    if num_inference_examples is not None and num_inference_examples >= 0:
      ds = ds.take(num_inference_examples)
    return ds

  outputs = []
  for task in t5.data.get_subtasks(mixture_or_task):
    if dataset_split not in task.splits:
      logging.info("Task %s has no '%s' split, skipping inference.",
                   task.name, dataset_split)
      continue
    outputs.append(
        transformer_dataset.EvalDataset(
            task.name,
            functools.partial(
                _get_dataset_for_single_task,
                task=task,
                sequence_length=sequence_length),
            task.postprocess_fn,
            task.metric_fns,
        ))

  if not outputs:
    logging.warning("No %s data found for %s.",
                    dataset_split, mixture_or_task_name)

  return outputs
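
# Worked example (illustrative) of the priming split above: with
# priming_sequence_length=3 and targets [5, 6, 7, 8, 9],
# _split_targets_for_primed_inference produces
#   inputs  = [5, 6, 7]   (zero-padded up to length 3 if targets are shorter)
#   targets = [8, 9]
# so a decoder-only model decodes conditioned on the first three target
# tokens.
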
def mesh_eval_dataset_fn(
    mixture_or_task_name,
    sequence_length,
    vocabulary,
    dataset_split,
    num_eval_examples=None,
    use_cached=False,
    pack=False,
    shuffle_eval_examples=False,
    shuffle_buffer_size=t5.data.SHUFFLE_BUFFER_SIZE):
  """Returns all tf.data.Datasets for evaluation on a given mixture.

  This uses the format required for utils.run's `eval_dataset_fn` argument in
  the Mesh TF transformer standalone.

  Args:
    mixture_or_task_name: string, an identifier for a Mixture or Task in the
      appropriate registry. Must be specified via gin.
    sequence_length: dict mapping feature key to the int max sequence length
      for that feature.
    vocabulary: a t5.data.vocabularies.Vocabulary.
    dataset_split: string, which split of the dataset to load.
    num_eval_examples: maximum number of examples per task to use for
      continuous eval. If None, use all examples.
    use_cached: bool, whether to load the cached version of this dataset.
    pack: a boolean, whether to pack examples. This is useful for perplexity
      evals but should not be used for iterative decoding.
    shuffle_eval_examples: boolean, whether to shuffle eval examples, applied
      only when num_eval_examples is not None. Intended to be able to eval on
      a different eval slice at every iteration.
    shuffle_buffer_size: integer, the shuffle buffer size used when shuffling
      eval examples; ideally this should be some large multiple of
      `num_eval_examples` to ensure good mixing and random batches.

  Returns:
    A list of mesh_tensorflow.transformer.dataset.EvalDataset tuples.
  """
  valid_vocabulary(vocabulary)
  mixture_or_task = t5.data.get_mixture_or_task(mixture_or_task_name)

  def _get_dataset_for_single_task(task):
    """Get a tensorflow.data.Dataset for the provided task."""
    ds = task.get_dataset(
        sequence_length, split=dataset_split, use_cached=use_cached,
        shuffle=False)
    eos_keys = set(
        k for k, f in mixture_or_task.output_features.items() if f.add_eos)
    ds = transformer_dataset.pack_or_pad(
        ds,
        sequence_length,
        pack=pack,
        feature_keys=tuple(task.output_features),
        ensure_eos=eos_keys)
    ds = maybe_shuffle_and_subsample_dataset(
        ds, num_eval_examples, shuffle_eval_examples, shuffle_buffer_size)
    return ds

  outputs = []
  for task in t5.data.get_subtasks(mixture_or_task):
    if dataset_split not in task.splits:
      logging.info("Task %s has no '%s' split, skipping eval.",
                   task.name, dataset_split)
      continue
    outputs.append(
        transformer_dataset.EvalDataset(
            task.name,
            functools.partial(_get_dataset_for_single_task, task),
            task.postprocess_fn,
            task.metric_fns,
        ))
  return outputs
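
# `maybe_shuffle_and_subsample_dataset` is defined elsewhere in this module.
# A minimal sketch of the assumed behavior (a hypothetical reconstruction,
# not the verbatim implementation):
#
#   def maybe_shuffle_and_subsample_dataset(ds, num_eval_examples,
#                                           shuffle_eval_examples,
#                                           shuffle_buffer_size):
#     if num_eval_examples is None:
#       return ds
#     if shuffle_eval_examples:
#       ds = ds.shuffle(shuffle_buffer_size)  # shuffle first => random slice
#     return ds.take(num_eval_examples)
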
def eval_model_ll(estimator,
                  vocabulary,
                  sequence_length,
                  batch_size,
                  dataset_split,
                  model_dir,
                  eval_dataset_fn,
                  eval_summary_dir,
                  eval_checkpoint_step,
                  attribute_bit=True,
                  unsupervised_attribute_transfer_metrics=True,
                  control_code_bool=False):
  """Eval a Mesh-TF model.

  Args:
    estimator: Estimator object, created with the appropriate model_fn.
    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
      targets_vocabulary) tuple
    sequence_length: a dict from feature-key to integer giving the (packed)
      sequence length, e.g. {"inputs": 512, "targets": 128}
    batch_size: an integer, global batch size
    dataset_split: a string
    model_dir: a string, directory with the model.
    eval_dataset_fn: A function returning a list of dataset.EvalDataset
      tuples. Must be provided for mode="eval". Should accept the following
      arguments:
      - sequence_length: an integer or a dict from feature-key to integer
        giving the (packed) sequence length, e.g.
        {"inputs": 512, "targets": 128}
      - vocabulary: Vocabulary instance to use for encoding.
      - dataset_split: str, which dataset split to load.
      dataset.EvalDataset tuples are namedtuples with the following fields:
      - name: string, the task name
      - dataset_fn: function which returns a tf.data.Dataset of tokenized and
        padded examples. Must not require any arguments and must include the
        feature keys 'inputs' and 'targets_plaintext'.
      - postprocess_fn: function which converts plaintext targets to values
        that can be processed by a `metric_fn`.
      - list_of_metric_fns: list of metric functions with the call signature
        `metric_fn(targets, predictions)` which returns a dict mapping
        submetric names to scalar values. TensorBoard summaries and other tags
        will be written out using the submetric names.
    eval_summary_dir: str, path to write TensorBoard events file summaries for
      eval. If None, use model_dir/eval_{split}.
    eval_checkpoint_step: int, list of ints, or None. If an int or list of
      ints, evaluation or inference will be run on the checkpoint files in
      `model_dir` whose global steps are closest to the global steps provided.
      If None and mode="eval", run eval continuously waiting for new
      checkpoints via `tf.train.checkpoints_iterator`.
    attribute_bit: bool, whether examples carry an "attribute" feature; if
      True, each example's original attribute is cached and made available to
      attribute-transfer metrics.
    unsupervised_attribute_transfer_metrics: bool, if True (and attribute_bit
      is set), each metric_fn is additionally called with the cached
      `attributes_origin`.
    control_code_bool: bool, whether to pass control-code features through to
      the decode input_fn.
  """
  if eval_dataset_fn is None:
    raise ValueError("Must provide eval_dataset_fn through gin for eval.")

  eval_datasets = eval_dataset_fn(
      sequence_length=sequence_length,
      vocabulary=vocabulary,
      dataset_split=dataset_split,
  )

  valid_eval_datasets = []
  for eval_dataset in eval_datasets:
    if not eval_dataset.metric_fns:
      tf.logging.info("Skipping %s because metric_fns is empty",
                      eval_dataset.name)
      continue
    # Convert to EvalDataset tuple in case eval_dataset_fn returns raw tuples
    valid_eval_datasets.append(transformer_dataset.EvalDataset(*eval_dataset))
  eval_datasets = valid_eval_datasets

  if not eval_datasets:
    tf.logging.info(
        "All provided EvalDatasets have metric_fns=[]; eval is not possible.")
    return

  eval_summary_dir = eval_summary_dir or os.path.join(
      model_dir, "{}_eval".format(dataset_split))
  summary_writer = tf.summary.FileWriter(eval_summary_dir)

  # Pre-load in all of the targets once before entering continuous eval loop
  cached_targets = {}
  cached_examples = {}
  if attribute_bit:
    cached_attributes_origin = {}
  # Need to create a separate graph for loading in plaintext targets
  # or else TF will complain that we modified the graph
  with tf.Graph().as_default():
    for eval_dataset in eval_datasets:
      if eval_dataset.metric_fns:
        ds = eval_dataset.dataset_fn()
        # Create list of postprocessed text targets
        examples = [ex for ex in tfds.as_numpy(ds)]
        targets = [
            eval_dataset.postprocess_fn(  # pylint:disable=g-complex-comprehension
                tf.compat.as_text(ex["targets_plaintext"]),
                example=ex,
                is_target=True) for ex in examples
        ]
        if attribute_bit:
          attributes_origin = [str(ex["attribute"][0] - 1) for ex in examples]
        targets_filename = os.path.join(
            eval_summary_dir,
            "{}_targets".format(eval_dataset.name),
        )
        write_lines_to_file(targets, targets_filename)
        cached_targets[eval_dataset.name] = targets
        cached_examples[eval_dataset.name] = examples
        if attribute_bit:
          cached_attributes_origin[eval_dataset.name] = attributes_origin

  if attribute_bit:
    _INPUT_FEATURES_ll.append("attribute")

  if control_code_bool:  # TODO: check whether all of these are needed.
    _INPUT_FEATURES_ll.extend([
        "controlcode", "controlcode_position", "controlcode_segmentation",
        "controlcode_subsegmentation", "codeprefixedtargets",
        "codeprefixedtargets_position", "codeprefixedtargets_segmentation",
        "codeprefixedtargets_subsegmentation"
    ])

  def input_fn(params):
    """Eval input function for estimator."""
    del params
    # Concatenate all dataset inputs to only have to do one decode loop
    combined_ds = None
    for eval_dataset in eval_datasets:
      # Only include datasets for tasks with metric functions provided
      if eval_dataset.metric_fns:
        ds = eval_dataset.dataset_fn()
        # Only pass those variables which will be used for decoding
        ds = ds.map(
            lambda x: {k: v for k, v in x.items() if k in _INPUT_FEATURES_ll})
        combined_ds = ds if not combined_ds else combined_ds.concatenate(ds)
    combined_ds = combined_ds.batch(batch_size, drop_remainder=False)
    # Pad the final batch.
    combined_ds = transformer_dataset.trim_and_pad_dataset(
        combined_ds, length=batch_size)
    combined_ds = combined_ds.prefetch(tf.data.experimental.AUTOTUNE)
    return combined_ds

  checkpoint_paths = get_checkpoint_iterator(eval_checkpoint_step, model_dir)
  for checkpoint_path in checkpoint_paths:
    tf.logging.info("Checkpoint path %s" % checkpoint_path)
    global_step = int(get_step_from_checkpoint_path(checkpoint_path))
    if global_step == 0:
      continue
    decodes = decode(estimator, input_fn, vocabulary, checkpoint_path)
    for eval_dataset in eval_datasets:
      # Extract the portion of decodes corresponding to this dataset
      examples = cached_examples[eval_dataset.name]
      dataset_size = len(examples)
      predictions = [
          eval_dataset.postprocess_fn(tf.compat.as_text(d), example=ex)
          for d, ex in zip(decodes[:dataset_size], examples)
      ]
      # Remove the used decodes.
      del decodes[:dataset_size]

      global_step = int(get_step_from_checkpoint_path(checkpoint_path))
      predictions_filename = os.path.join(
          eval_summary_dir,
          "{}_{}_predictions".format(eval_dataset.name, global_step),
      )
      write_lines_to_file_ll(predictions, predictions_filename)

      for metric_fn in eval_dataset.metric_fns:
        summary = tf.Summary()
        targets = cached_targets[eval_dataset.name]
        if unsupervised_attribute_transfer_metrics and attribute_bit:
          attributes_origin = cached_attributes_origin[eval_dataset.name]
          metric_result = metric_fn(
              targets, predictions, attributes_origin=attributes_origin)
        else:
          metric_result = metric_fn(targets, predictions)
        for metric_name, metric_value in metric_result.items():
          tag = "eval/{}/{}".format(eval_dataset.name, metric_name)
          tf.logging.info("%s at step %d: %.3f", tag, global_step,
                          metric_value)
          summary.value.add(tag=tag, simple_value=metric_value)
          summary_writer.add_summary(summary, global_step)
      summary_writer.flush()

    # Only padding should remain.
    expected_pad = -sum(len(t) for t in cached_targets.values()) % batch_size
    if len(decodes) != expected_pad:
      raise ValueError("{} padded decodes, {} expected.".format(
          len(decodes), expected_pad))
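
# Worked example (illustrative) for the final padding check above: with two
# cached target lists of 10 and 7 examples and batch_size=8, the combined
# dataset holds 17 examples and is padded up to the next multiple of the batch
# size, so expected_pad = -17 % 8 = 7 decodes must remain after all datasets
# have consumed their slices.
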
def mesh_eval_dataset_fn(mixture_or_task_name,
                         sequence_length,
                         vocabulary,
                         dataset_split,
                         num_eval_examples=None,
                         use_cached=False,
                         pack=False):
  """Returns all tf.data.Datasets for evaluation on a given mixture.

  This uses the format required for utils.run's `eval_dataset_fn` argument in
  the Mesh TF transformer standalone.

  Args:
    mixture_or_task_name: string, an identifier for a Mixture or Task in the
      appropriate registry. Must be specified via gin.
    sequence_length: dict mapping feature key to the int max sequence length
      for that feature.
    vocabulary: a t5.data.vocabularies.Vocabulary.
    dataset_split: string, which split of the dataset to load.
    num_eval_examples: maximum number of examples per task to use for
      continuous eval. If None, use all examples.
    use_cached: bool, whether to load the cached version of this dataset.
    pack: a boolean, whether to pack examples. This is useful for perplexity
      evals but should not be used for iterative decoding.

  Returns:
    A list of mesh_tensorflow.transformer.dataset.EvalDataset tuples.
  """
  valid_vocabulary(vocabulary)
  mixture_or_task = t5.data.get_mixture_or_task(mixture_or_task_name)

  def _get_dataset_for_single_task(task):
    """Get a tensorflow.data.Dataset for the provided task."""
    ds = task.get_dataset(
        sequence_length, split=dataset_split, use_cached=use_cached,
        shuffle=False)
    if any(not f.add_eos for f in task.output_features.values()):
      warnings.warn(
          "pack_or_pad is being called with ensure_eos=True, but EOS is not "
          "being added to all features.")
    ds = transformer_dataset.pack_or_pad(
        ds,
        sequence_length,
        pack=pack,
        feature_keys=tuple(task.output_features),
        ensure_eos=True)
    if num_eval_examples is not None:
      ds = ds.take(num_eval_examples)
    return ds

  outputs = []
  for task in t5.data.get_subtasks(mixture_or_task):
    if dataset_split not in task.splits:
      logging.info("Task %s has no '%s' split, skipping eval.",
                   task.name, dataset_split)
      continue
    outputs.append(
        transformer_dataset.EvalDataset(
            task.name,
            functools.partial(_get_dataset_for_single_task, task),
            task.postprocess_fn,
            task.metric_fns,
        ))
  return outputs
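
# Usage sketch (illustrative): this variant passes ensure_eos=True
# unconditionally and merely warns when a feature was built with
# add_eos=False. A hypothetical direct call for a perplexity eval:
#
#   datasets = mesh_eval_dataset_fn(
#       mixture_or_task_name="my_task",  # hypothetical registry name
#       sequence_length={"inputs": 512, "targets": 128},
#       vocabulary=t5.data.get_default_vocabulary(),
#       dataset_split="validation",
#       pack=True)  # pack for perplexity evals, never for iterative decoding
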
def run(tpu_job_name,
        tpu,
        gcp_project,
        tpu_zone,
        model_dir,
        model_type="bitransformer",
        vocabulary=gin.REQUIRED,
        train_dataset_fn=None,
        eval_dataset_fn=None,
        dataset_split="train",
        autostack=True,
        checkpoint_step=None,
        mode="train",
        iterations_per_loop=100,
        save_checkpoints_steps=1000,
        keep_checkpoint_max=10,
        eval_summary_dir=None,
        batch_size=("tokens_per_replica", 2048),
        train_steps=auto_train_steps,
        sequence_length=gin.REQUIRED,
        mesh_shape=gin.REQUIRED,
        layout_rules=gin.REQUIRED,
        learning_rate_schedule=None,
        optimizer=None,
        predict_fn=None):
  """Run training/eval/inference.

  Args:
    tpu_job_name: string, name of TPU worker binary
    tpu: string, the Cloud TPU to use for training
    gcp_project: string, project name for the Cloud TPU-enabled project
    tpu_zone: string, GCE zone where the Cloud TPU is located
    model_dir: string, estimator model_dir
    model_type: a string - either "bitransformer", "bi_student_teacher", "lm"
      or "aligned"
    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
      targets_vocabulary) tuple.
    train_dataset_fn: A function returning a tf.data.Dataset. Must be provided
      for mode="train". Should accept the following arguments:
      - batch_size: int, number of entries in each batch.
      - sequence_length: int, length of each packed or padded sequence.
      - vocabulary: Vocabulary instance to use for encoding.
      - dataset_split: str, which dataset split to load.
    eval_dataset_fn: A function returning a list of dataset.EvalDataset
      tuples. Must be provided for mode="eval". Should accept the following
      arguments:
      - batch_size: int, number of entries in each batch.
      - sequence_length: int, length of each packed or padded sequence.
      - vocabulary: Vocabulary instance to use for encoding.
      - dataset_split: str, which dataset split to load.
      dataset.EvalDataset tuples are namedtuples with the following fields:
      - name: string, the task name
      - dataset_fn: function which returns a tf.data.Dataset of tokenized and
        padded examples. Must not require any arguments and must include the
        feature keys 'inputs' and 'targets_plaintext'.
      - postprocess_fn: function which converts model outputs to evaluable
        strings
      - list_of_metric_fns: list of metric functions with the call signature
        `metric_fn(targets, predictions)` which return either a scalar value
        or a dict mapping submetric names to scalar values. TensorBoard
        summaries and other tags will be written out using
        `metric_fn.__name__`.
      - dataset_size: number of entries in the dataset.
      - padded_dataset_size: number of entries in the dataset after padding.
    dataset_split: a string
    autostack: boolean, internally combine variables
    checkpoint_step: int, list of ints, or None. Only used when mode="eval" or
      mode="infer". If an int or list of ints, evaluation or inference will be
      run on the checkpoint files in `model_dir` whose global steps are
      closest to the global steps provided. If None and mode="eval", run eval
      continuously waiting for new checkpoints via
      `tf.contrib.training.checkpoints_iterator`.
    mode: string, train/eval/infer
    iterations_per_loop: integer, steps per train loop
    save_checkpoints_steps: integer, steps per checkpoint
    keep_checkpoint_max: an integer, keep up to this many checkpoints
    eval_summary_dir: str, path to write TensorBoard events file summaries for
      eval. If None, use model_dir/eval_{split}.
    batch_size: An integer or a (method, value) pair to pass to
      compute_batch_size(). Note that this is the global batch size and not
      the per-shard batch size.
    train_steps: An integer or a function with the same signature as
      auto_train_steps(). Total number of training steps.
    sequence_length: an integer
    mesh_shape: an input to mtf.convert_to_shape()
    layout_rules: an input to mtf.convert_to_layout_rules()
    learning_rate_schedule: an optional function taking the scalar name
      argument `step` and the numeric argument `total_train_steps` and
      returning the scalar learning rate
    optimizer: a class extending optimize.Optimizer, required for training
    predict_fn: an optional function that can be used to override the default
      transformer prediction behavior. Must return a tensor of shape
      [batch_dim, length_dim] that will be the prediction for each example.
      Must accept the following arguments:
      - model: a Unitransformer or Bitransformer
      - features: a dict representing an example. Every value will be an
        mtf.Tensor with shape [batch_dim, length_dim].
      - variable_dtype: an mtf.VariableDType
  """
  if not isinstance(batch_size, int):
    batch_size = compute_batch_size(
        sequence_length, mesh_shape, layout_rules, batch_size)

  if not isinstance(train_steps, int):
    train_steps = train_steps(batch_size, sequence_length)

  if callable(learning_rate_schedule):
    learning_rate_schedule = functools.partial(
        learning_rate_schedule, total_train_steps=train_steps)

  tf.logging.info("model_type=%s" % model_type)
  tf.logging.info("mode=%s" % mode)
  tf.logging.info("sequence_length=%s" % sequence_length)
  tf.logging.info("batch_size=%s" % batch_size)
  tf.logging.info("train_steps=%s" % train_steps)
  tf.logging.info("mesh_shape=%s" % mesh_shape)
  tf.logging.info("layout_rules=%s" % layout_rules)

  if mode == "train" and dataset_split != "train":
    raise ValueError("mode==\"train\" requires dataset_split==\"train\"")

  mesh_shape = mtf.convert_to_shape(mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(layout_rules)

  cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
      tpu if tpu else "", zone=tpu_zone, project=gcp_project)

  tf.logging.info(
      "Building TPUConfig with tpu_job_name={}".format(tpu_job_name))
  my_tpu_config = tpu_config.TPUConfig(
      tpu_job_name=tpu_job_name,
      iterations_per_loop=iterations_per_loop,
      num_cores_per_replica=1,
      per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST,
  )

  run_config = tpu_config.RunConfig(
      cluster=cluster,
      model_dir=model_dir,
      tpu_config=my_tpu_config,
      # We use a saver hook, so disable checkpoints here to prevent double
      # saving.
      save_checkpoints_steps=None,
      save_checkpoints_secs=None)

  transformer_model = build_model(
      model_type=model_type,
      input_vocab_size=inputs_vocabulary(vocabulary).vocab_size,
      output_vocab_size=targets_vocabulary(vocabulary).vocab_size,
      layout_rules=layout_rules,
      mesh_shape=mesh_shape)

  model_fn = tpu_estimator_model_fn(
      model_type=model_type,
      transformer_model=transformer_model,
      model_dir=model_dir,
      use_tpu=tpu,
      mesh_shape=mesh_shape,
      layout_rules=layout_rules,
      batch_size=batch_size,
      sequence_length=sequence_length,
      autostack=autostack,
      learning_rate_schedule=learning_rate_schedule,
      keep_checkpoint_max=keep_checkpoint_max,
      save_checkpoints_steps=save_checkpoints_steps,
      optimizer=optimizer,
      predict_fn=predict_fn)

  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      config=run_config,
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      predict_batch_size=batch_size,
      use_tpu=tpu,
      export_to_tpu=False,
      params={})

  if mode == "train":
    if train_dataset_fn is None:
      raise ValueError("Must provide train_dataset_fn through gin for train.")

    def input_fn(params):
      del params
      dataset = train_dataset_fn(
          batch_size=batch_size,
          sequence_length=sequence_length,
          vocabulary=vocabulary,
          dataset_split=dataset_split)
      return dataset

    estimator.train(input_fn=input_fn, max_steps=train_steps)
  elif mode == "eval":
    if eval_dataset_fn is None:
      raise ValueError("Must provide eval_dataset_fn through gin for eval.")

    eval_datasets = eval_dataset_fn(
        batch_size=batch_size,
        sequence_length=sequence_length,
        vocabulary=vocabulary,
        dataset_split=dataset_split,
    )

    # Pre-load in all of the targets once before entering continuous eval loop
    cached_targets = {}
    # Need to create a separate graph for loading in plaintext targets
    # or else TF will complain that we modified the graph
    with tf.Graph().as_default():
      for eval_dataset in eval_datasets:
        eval_dataset = transformer_dataset.EvalDataset(*eval_dataset)
        # Only cache targets for those tasks with eval functions provided
        if eval_dataset.metric_fns:
          ds = eval_dataset.dataset_fn()
          # De-batch the dataset
          ds = ds.flat_map(tf.data.Dataset.from_tensor_slices)
          ds = tfds.as_numpy(ds)
          targets = [
              eval_dataset.postprocess_fn(d["targets_plaintext"]) for d in ds
          ]
          targets = targets[:eval_dataset.dataset_size]
          cached_targets[eval_dataset.name] = targets

    for checkpoint_path in get_checkpoint_iterator(checkpoint_step,
                                                   model_dir):
      for eval_dataset in eval_datasets:
        eval_dataset = transformer_dataset.EvalDataset(*eval_dataset)
        if not eval_dataset.metric_fns:
          tf.logging.info("Skipping %s because metric_fns is empty",
                          eval_dataset.name)
          continue
        metric_names = [
            metric.__name__ for metric in eval_dataset.metric_fns
        ]
        tf.logging.info("Evaluating %s on metrics %s", eval_dataset.name,
                        metric_names)
        tf.logging.info("on split %s", dataset_split)

        def input_fn(params):
          del params
          ds = eval_dataset.dataset_fn()  # pylint: disable=cell-var-from-loop
          # Only pass those variables which will be used for decoding
          ds = ds.map(
              lambda x: {k: v for k, v in x.items() if k in _INPUT_FEATURES})
          return ds

        decodes = decode(
            estimator,
            input_fn,
            eval_dataset.dataset_size,
            eval_dataset.padded_dataset_size,
            batch_size,
            vocabulary,
            checkpoint_path=checkpoint_path,
        )
        predictions = [eval_dataset.postprocess_fn(d) for d in decodes]
        # TODO(craffel): Log predictions and targets.
        eval_summary_dir = eval_summary_dir or os.path.join(
            model_dir, "{}_eval".format(dataset_split))
        summary_writer = tf.summary.FileWriter(eval_summary_dir)
        global_step = int(get_step_from_checkpoint_path(checkpoint_path))
        for metric_fn in eval_dataset.metric_fns:
          summary = tf.Summary()
          tag = "eval/{}/{}/{}".format(
              eval_dataset.name, dataset_split, metric_fn.__name__)
          targets = cached_targets[eval_dataset.name]
          metric_result = metric_fn(targets, predictions)
          if isinstance(metric_result, dict):
            tags = ["{}.{}".format(tag, key) for key in metric_result]
            metric_values = metric_result.values()
          else:
            tags, metric_values = [tag], [metric_result]
          for tag, metric_value in zip(tags, metric_values):
            tf.logging.info("%s at step %d: %.3f", tag, global_step,
                            metric_value)
            summary.value.add(tag=tag, simple_value=metric_value)
          summary_writer.add_summary(summary, global_step)
        summary_writer.flush()
  elif mode == "infer":
    for checkpoint_path in get_checkpoint_iterator(checkpoint_step,
                                                   model_dir):
      decode_from_file(
          estimator,
          vocabulary=vocabulary,
          model_type=model_type,
          batch_size=batch_size,
          sequence_length=sequence_length,
          checkpoint_path=checkpoint_path)
  else:
    raise ValueError("unknown mode %s - must be train/eval/infer" % mode)
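
# Invocation sketch (illustrative): `run` is normally configured entirely via
# gin. A direct call with hypothetical values (paths, shapes, and layout
# strings below are placeholders, not recommendations):
#
#   run(tpu_job_name=None,
#       tpu="",  # empty string => TPUClusterResolver without a named TPU
#       gcp_project=None,
#       tpu_zone=None,
#       model_dir="/tmp/model",  # hypothetical path
#       mode="eval",
#       dataset_split="validation",
#       eval_dataset_fn=eval_dataset_fn,  # e.g. a mesh_eval_dataset_fn partial
#       vocabulary=t5.data.get_default_vocabulary(),
#       sequence_length={"inputs": 512, "targets": 128},
#       mesh_shape="model:1,batch:1",
#       layout_rules="batch:batch")
#
# With checkpoint_step left at None, mode="eval" loops forever waiting for new
# checkpoints; pass an int or list of ints to evaluate specific checkpoints.
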