def decode_from_file_ll(estimator,
                        vocabulary,
                        model_type,
                        batch_size,
                        sequence_length,
                        checkpoint_path=None,
                        input_filename=gin.REQUIRED,
                        output_filename=gin.REQUIRED,
                        eos_id=1,
                        repeats=1,
                        control_codes_decode=None,
                        attribute_embedding=False):
    """Decode from a text file and write to output_filename.

    Args:
      estimator: a TPUEstimator
      vocabulary: a mtf.transformer.vocabulary.Vocabulary
      model_type: a string
      batch_size: an integer
      sequence_length: an integer or a dict from feature-key to integer
        the (packed) sequence length, e.g. {"inputs": 512, "targets": 128}
      checkpoint_path: an optional string
      input_filename: a string
      output_filename: a string
      eos_id: EOS id
      repeats: an integer, the number of times to repeat each input.
      control_codes_decode: an optional list mapping destination-attribute
        ids to control-code prefix strings.
      attribute_embedding: a boolean, whether to pass the destination
        attribute as an "attribute" feature.
    """
    # Each input line has the form "<text>|dst_attribute:<id>".
    inputs_and_dst_attributes = get_inputs_from_file(input_filename)
    inputs_split = [
        line.split("|dst_attribute:") for line in inputs_and_dst_attributes
    ]
    inputs = []
    dst_attributes = []
    control_code_strings = []
    for l in inputs_split:
        inputs.append(l[0])
        dst_attributes.append(l[1])
        if control_codes_decode:
            # TODO: legacy input files used 1-indexed attributes; those would
            # need int(l[1]) - 1 here.
            control_code_strings.append(control_codes_decode[int(l[1])])

    all_input_ids = encode_inputs(inputs,
                                  vocabulary,
                                  model_type,
                                  batch_size,
                                  sequence_length["inputs"],
                                  eos_id=eos_id)
    if control_codes_decode:
        all_controlcode_ids = encode_inputs(control_code_strings,
                                            vocabulary,
                                            "lm",
                                            batch_size,
                                            sequence_length["controlcode"],
                                            eos_id=eos_id)

    def input_fn(params):
        del params
        tensors = {"inputs": all_input_ids}
        if attribute_embedding:
            tensors["attribute"] = dst_attributes
        if control_codes_decode:
            tensors["controlcode"] = all_controlcode_ids
        dataset = tf.data.Dataset.from_tensor_slices(tensors)
        if attribute_embedding:
            dataset = process_attribute(dataset, mode="infer")
        dataset = dataset.flat_map(
            lambda x: tf.data.Dataset.from_tensors(x).repeat(repeats))
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
        return dataset

    checkpoint_step = get_step_from_checkpoint_path(checkpoint_path)
    decodes = decode(estimator,
                     input_fn,
                     vocabulary,
                     checkpoint_path=checkpoint_path)
    # Remove any padded examples.
    dataset_size = len(inputs) * repeats
    decodes = decodes[:dataset_size]
    output_filename = "{}-{}".format(output_filename, checkpoint_step)
    write_lines_to_file(decodes, output_filename)
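# A minimal usage sketch for decode_from_file_ll, assuming an estimator and
# vocabulary have already been built elsewhere. Everything below (the
# placeholder paths, the two-entry control-code list, and the helper name
# itself) is hypothetical and only illustrates the expected call shape and
# the "<text>|dst_attribute:<id>" input-file format.
def _example_decode_from_file_ll(estimator, vocabulary):
    """Hypothetical call site for decode_from_file_ll."""
    decode_from_file_ll(
        estimator=estimator,
        vocabulary=vocabulary,
        model_type="bitransformer",
        batch_size=8,
        sequence_length={"inputs": 64, "targets": 64, "controlcode": 8},
        checkpoint_path="/tmp/model_dir/model.ckpt-10000",
        # Each line: "some input text|dst_attribute:0" (or ...:1).
        input_filename="/tmp/inputs.txt",
        # The checkpoint step is appended, so this writes decodes.txt-10000.
        output_filename="/tmp/decodes.txt",
        control_codes_decode=["negative: ", "positive: "],
        attribute_embedding=True)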
def _predict_or_score_fn(self,
                         tasks,
                         vocabulary,
                         checkpoint_step,
                         sequence_length,
                         examples,
                         split,
                         eval_with_score=False,
                         **unused_kwargs):
    """Helper function used by eval method to generate predictions or scores.

    Args:
      tasks: list, list of valid tasks to generate predictions or scores.
      vocabulary: a t5.data.vocabulary object or a tuple with separate
        vocabularies for inputs and targets.
      checkpoint_step: integer, step to evaluate the tasks.
      sequence_length: a dict, dictionary with sequence length for inputs
        and targets.
      examples: dict, cached examples for each task.
      split: string, split to run the evaluation on.
      eval_with_score: bool, whether to compute the log likelihood of the
        targets instead of predictions.

    Returns:
      list of decoded predictions or scores depending on the eval_with_score
      flag.
    """
    estimator = self.estimator(vocabulary,
                               score_in_predict_mode=eval_with_score,
                               sequence_length=sequence_length)

    def estimator_input_fn(params):
        """Eval input function for estimator."""
        del params
        # Concatenate all dataset inputs to only have to do one decode loop.
        combined_ds = None
        for task in tasks:
            ds = t5.models.mesh_transformer.mesh_eval_dataset_fn(
                mixture_or_task_name=task.name,
                sequence_length=sequence_length,
                dataset_split=split)[0].dataset_fn()
            ds = ds.map(utils.filter_features,
                        num_parallel_calls=tf.data.experimental.AUTOTUNE)
            combined_ds = ds if not combined_ds else combined_ds.concatenate(
                ds)
        combined_ds = combined_ds.batch(self.batch_size, drop_remainder=False)  # pytype:disable=attribute-error
        # Pad the final batch.
        combined_ds = transformer_dataset.trim_and_pad_dataset(
            combined_ds, length=self.batch_size)
        combined_ds = combined_ds.prefetch(tf.data.experimental.AUTOTUNE)
        return combined_ds

    checkpoint_path = os.path.join(self._model_dir,
                                   "model.ckpt-{}".format(checkpoint_step))
    if eval_with_score:
        outputs, _ = mtf_utils.score_with_estimator(
            estimator,
            estimator_input_fn,
            checkpoint_step,
            self._model_dir,
            vocabulary,
            num_examples=sum(len(cex) for cex in examples.values()))
    else:
        outputs = [
            tf.compat.as_text(d) for d in mtf_utils.decode(
                estimator, estimator_input_fn, vocabulary, checkpoint_path)
        ]
    return outputs
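# Hedged sketch of how _predict_or_score_fn is typically driven from an eval
# loop. `model` stands in for the MtfModel-style object this method belongs
# to, and `tasks`/`cached_examples` would come from the registered mixture's
# eval datasets; all names here are illustrative, not part of the module.
def _example_predict_and_score(model, tasks, vocabulary, cached_examples):
    """Hypothetical: runs the same helper in decode mode and in score mode."""
    sequence_length = {"inputs": 512, "targets": 128}
    # eval_with_score=False returns decoded text predictions in dataset order.
    predictions = model._predict_or_score_fn(
        tasks, vocabulary, checkpoint_step=10000,
        sequence_length=sequence_length, examples=cached_examples,
        split="validation")
    # eval_with_score=True instead returns log likelihoods of the gold
    # targets.
    scores = model._predict_or_score_fn(
        tasks, vocabulary, checkpoint_step=10000,
        sequence_length=sequence_length, examples=cached_examples,
        split="validation", eval_with_score=True)
    return predictions, scores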
def eval_model_ll(estimator,
                  vocabulary,
                  sequence_length,
                  batch_size,
                  dataset_split,
                  model_dir,
                  eval_dataset_fn,
                  eval_summary_dir,
                  eval_checkpoint_step,
                  attribute_bit=True,
                  unsupervised_attribute_transfer_metrics=True,
                  control_code_bool=False):
    """Eval a Mesh-TF model.

    Args:
      estimator: Estimator object, created with the appropriate model_fn.
      vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
        targets_vocabulary) tuple
      sequence_length: a dict from feature-key to integer the (packed)
        sequence length, e.g. {"inputs": 512, "targets": 128}
      batch_size: an integer, global batch size
      dataset_split: a string
      model_dir: a string, directory with the model.
      eval_dataset_fn: A function returning a list of dataset.EvalDataset
        tuples. Must be provided for mode="eval". Should accept the following
        arguments:
          - sequence_length: an integer or a dict from feature-key to integer
            the (packed) sequence length, e.g.
            {"inputs": 512, "targets": 128}
          - vocabulary: Vocabulary instance to use for encoding.
          - dataset_split: str, which dataset split to load.
        dataset.EvalDataset tuples are namedtuples with the following fields:
          - name: string, the task name
          - dataset_fn: function which returns a tf.data.Dataset of tokenized
            and padded examples. Must not require any arguments and must
            include the feature keys 'inputs' and 'targets_plaintext'.
          - postprocess_fn: function which converts plaintext targets to
            values that can be processed by a `metric_fn`.
          - metric_fns: list of metric functions with the call signature
            `metric_fn(targets, predictions)` which returns a dict mapping
            submetric names to scalar values. TensorBoard summaries and other
            tags will be written out using the submetric names.
      eval_summary_dir: str, path to write TensorBoard events file summaries
        for eval. If None, use model_dir/eval_{split}.
      eval_checkpoint_step: int, list of ints, or None. If an int or list of
        ints, evaluation or inference will be run on the checkpoint files in
        `model_dir` whose global steps are closest to the global steps
        provided. If None and mode="eval", run eval continuously waiting for
        new checkpoints via `tf.train.checkpoints_iterator`.
      attribute_bit: a boolean, whether examples carry an "attribute" feature
        that should be cached and passed through to decoding.
      unsupervised_attribute_transfer_metrics: a boolean, whether metric_fns
        additionally receive the original attributes via the
        `attributes_origin` keyword.
      control_code_bool: a boolean, whether the control-code features are
        passed through to decoding.
    """
    if eval_dataset_fn is None:
        raise ValueError("Must provide eval_dataset_fn through gin for eval.")

    eval_datasets = eval_dataset_fn(
        sequence_length=sequence_length,
        vocabulary=vocabulary,
        dataset_split=dataset_split,
    )

    valid_eval_datasets = []
    for eval_dataset in eval_datasets:
        if not eval_dataset.metric_fns:
            tf.logging.info("Skipping %s because metric_fns is empty",
                            eval_dataset.name)
            continue
        # Convert to EvalDataset tuple in case eval_dataset_fn returns raw
        # tuples.
        valid_eval_datasets.append(
            transformer_dataset.EvalDataset(*eval_dataset))
    eval_datasets = valid_eval_datasets

    if not eval_datasets:
        tf.logging.info(
            "All provided EvalDatasets have metric_fns=[]; eval is not "
            "possible.")
        return

    eval_summary_dir = eval_summary_dir or os.path.join(
        model_dir, "{}_eval".format(dataset_split))
    summary_writer = tf.summary.FileWriter(eval_summary_dir)

    # Pre-load in all of the targets once before entering continuous eval
    # loop.
    cached_targets = {}
    cached_examples = {}
    if attribute_bit:
        cached_attributes_origin = {}
    # Need to create a separate graph for loading in plaintext targets
    # or else TF will complain that we modified the graph.
    with tf.Graph().as_default():
        for eval_dataset in eval_datasets:
            if eval_dataset.metric_fns:
                ds = eval_dataset.dataset_fn()
                # Create list of postprocessed text targets.
                examples = [ex for ex in tfds.as_numpy(ds)]
                targets = [
                    eval_dataset.postprocess_fn(  # pylint:disable=g-complex-comprehension
                        tf.compat.as_text(ex["targets_plaintext"]),
                        example=ex,
                        is_target=True) for ex in examples
                ]
                if attribute_bit:
                    attributes_origin = [
                        str(ex["attribute"][0] - 1) for ex in examples
                    ]
                targets_filename = os.path.join(
                    eval_summary_dir,
                    "{}_targets".format(eval_dataset.name),
                )
                write_lines_to_file(targets, targets_filename)
                cached_targets[eval_dataset.name] = targets
                cached_examples[eval_dataset.name] = examples
                if attribute_bit:
                    cached_attributes_origin[
                        eval_dataset.name] = attributes_origin

    if attribute_bit:
        _INPUT_FEATURES_ll.append("attribute")
    if control_code_bool:
        # TODO: check whether all of these features are actually needed.
        _INPUT_FEATURES_ll.extend([
            "controlcode", "controlcode_position", "controlcode_segmentation",
            "controlcode_subsegmentation", "codeprefixedtargets",
            "codeprefixedtargets_position",
            "codeprefixedtargets_segmentation",
            "codeprefixedtargets_subsegmentation"
        ])

    def input_fn(params):
        """Eval input function for estimator."""
        del params
        # Concatenate all dataset inputs to only have to do one decode loop.
        combined_ds = None
        for eval_dataset in eval_datasets:
            # Only decode those tasks for which metric functions were
            # provided.
            if eval_dataset.metric_fns:
                ds = eval_dataset.dataset_fn()
                # Only pass those features which will be used for decoding.
                ds = ds.map(lambda x: {
                    k: v for k, v in x.items() if k in _INPUT_FEATURES_ll
                })
                combined_ds = ds if not combined_ds else combined_ds.concatenate(
                    ds)
        combined_ds = combined_ds.batch(batch_size, drop_remainder=False)
        # Pad the final batch.
        combined_ds = transformer_dataset.trim_and_pad_dataset(
            combined_ds, length=batch_size)
        combined_ds = combined_ds.prefetch(tf.data.experimental.AUTOTUNE)
        return combined_ds

    checkpoint_paths = get_checkpoint_iterator(eval_checkpoint_step,
                                               model_dir)
    for checkpoint_path in checkpoint_paths:
        tf.logging.info("Checkpoint path %s" % checkpoint_path)
        global_step = int(get_step_from_checkpoint_path(checkpoint_path))
        if global_step == 0:
            continue
        decodes = decode(estimator, input_fn, vocabulary, checkpoint_path)
        for eval_dataset in eval_datasets:
            # Extract the portion of decodes corresponding to this dataset.
            examples = cached_examples[eval_dataset.name]
            dataset_size = len(examples)
            predictions = [
                eval_dataset.postprocess_fn(tf.compat.as_text(d), example=ex)
                for d, ex in zip(decodes[:dataset_size], examples)
            ]
            # Remove the used decodes.
            del decodes[:dataset_size]

            predictions_filename = os.path.join(
                eval_summary_dir,
                "{}_{}_predictions".format(eval_dataset.name, global_step),
            )
            write_lines_to_file_ll(predictions, predictions_filename)

            for metric_fn in eval_dataset.metric_fns:
                summary = tf.Summary()
                targets = cached_targets[eval_dataset.name]
                if unsupervised_attribute_transfer_metrics and attribute_bit:
                    attributes_origin = cached_attributes_origin[
                        eval_dataset.name]
                    metric_result = metric_fn(
                        targets,
                        predictions,
                        attributes_origin=attributes_origin)
                else:
                    metric_result = metric_fn(targets, predictions)
                for metric_name, metric_value in metric_result.items():
                    tag = "eval/{}/{}".format(eval_dataset.name, metric_name)
                    tf.logging.info("%s at step %d: %.3f", tag, global_step,
                                    metric_value)
                    summary.value.add(tag=tag, simple_value=metric_value)
                summary_writer.add_summary(summary, global_step)
            summary_writer.flush()

        # Only padding should remain.
        expected_pad = -sum(len(t)
                            for t in cached_targets.values()) % batch_size
        if len(decodes) != expected_pad:
            raise ValueError("{} padded decodes, {} expected.".format(
                len(decodes), expected_pad))
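# Quick sanity illustration (with made-up numbers) of the expected_pad check
# at the end of eval_model_ll: trim_and_pad_dataset pads the final batch up
# to batch_size, so after one decode is consumed per cached target, only the
# pad rows should remain.
def _example_expected_pad():
    batch_size = 8
    num_targets = 10  # hypothetical: e.g. two tasks with 6 and 4 examples
    expected_pad = -num_targets % batch_size
    # Two batches yield 16 decodes; 10 are real, so 6 pad decodes remain.
    assert expected_pad == 6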