Example #1
def decode_from_file_ll(estimator,
                        vocabulary,
                        model_type,
                        batch_size,
                        sequence_length,
                        checkpoint_path=None,
                        input_filename=gin.REQUIRED,
                        output_filename=gin.REQUIRED,
                        eos_id=1,
                        repeats=1,
                        control_codes_decode=None,
                        attribute_embedding=False):
    """Decode from a text file and write to output_filename.
    Args:
      estimator: a TPUEstimator
      vocabulary: a mtf.transformer.vocabulary.Vocabulary
      model_type: a string
      batch_size: an integer
      sequence_length: an integer or a dict from feature-key to integer
        the (packed) sequence length, e.g. {"inputs": 512, "targets": 128}
      checkpoint_path: an optional string
      input_filename: a string
      output_filename: a string
      eos_id: EOS id
      repeats: an integer, the number of times to repeat each input.
    """
    inputs_and_dst_attributes = get_inputs_from_file(input_filename)

    inputs_split = [
        line.split("|dst_attribute:") for line in inputs_and_dst_attributes
    ]

    inputs = []
    dst_attributes = []
    control_code_strings = []
    # Earlier revisions hard-coded the prefixes for attributes "1" and "2";
    # the mapping now comes from the control_codes_decode list.
    for l in inputs_split:
        inputs.append(l[0])
        dst_attributes.append(l[1])
        if control_codes_decode:
            control_code_strings.append(control_codes_decode[int(l[1])])

    all_input_ids = encode_inputs(inputs,
                                  vocabulary,
                                  model_type,
                                  batch_size,
                                  sequence_length["inputs"],
                                  eos_id=eos_id)
    if control_codes_decode:
        all_controlcode_ids = encode_inputs(control_code_strings,
                                            vocabulary,
                                            "lm",
                                            batch_size,
                                            sequence_length["controlcode"],
                                            eos_id=eos_id)

    def input_fn(params):
        del params

        tensors = {"inputs": all_input_ids}
        if attribute_embedding:
            tensors["attribute"] = dst_attributes
        if control_codes_decode:
            tensors["controlcode"] = all_controlcode_ids

        dataset = tf.data.Dataset.from_tensor_slices(tensors)
        if attribute_embedding:
            dataset = process_attribute(dataset, mode="infer")
        dataset = dataset.flat_map(
            lambda x: tf.data.Dataset.from_tensors(x).repeat(repeats))
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
        return dataset

    checkpoint_step = get_step_from_checkpoint_path(checkpoint_path)
    decodes = decode(estimator,
                     input_fn,
                     vocabulary,
                     checkpoint_path=checkpoint_path)
    # Remove any padded examples
    dataset_size = len(inputs) * repeats
    decodes = decodes[:dataset_size]
    output_filename = "{}-{}".format(output_filename, checkpoint_step)
    write_lines_to_file(decodes, output_filename)
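
A note on usage: each input line carries the source text and the destination attribute separated by "|dst_attribute:", and the attribute index selects an entry of control_codes_decode. Below is a minimal, hedged sketch of preparing such a file and calling the function; my_estimator, my_vocab, the model_type, and the file paths are placeholders rather than part of the original example.

# Hedged usage sketch for decode_from_file_ll; the estimator, vocabulary and
# paths are assumed to be built elsewhere and are placeholders.
with open("/tmp/transfer_inputs.txt", "w") as f:
    f.write("the food was cold and bland|dst_attribute:1\n")
    f.write("the staff could not have been nicer|dst_attribute:0\n")

decode_from_file_ll(
    estimator=my_estimator,
    vocabulary=my_vocab,
    model_type="bitransformer",
    batch_size=8,
    sequence_length={"inputs": 64, "targets": 64, "controlcode": 8},
    input_filename="/tmp/transfer_inputs.txt",
    output_filename="/tmp/transfer_outputs.txt",
    control_codes_decode=["negative: ", "positive: "],
    attribute_embedding=True)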
Example #2
    def _predict_or_score_fn(self,
                             tasks,
                             vocabulary,
                             checkpoint_step,
                             sequence_length,
                             examples,
                             split,
                             eval_with_score=False,
                             **unused_kwargs):
        """Helper function used by eval method to generate predictions or scores.

    Args:
      tasks: list, list of valid tasks to generate predictions or scores.
      vocabulary: a t5.data.vocabulary object or a tuple with separate
        vocabularies for inputs and targets,
      checkpoint_step: integer, step to evaluate the tasks.
      sequence_length: a dict, dictionary with sequence length for inputs and
        targets.
      examples: dict, cached examples for each task.
      split: string, split to run the evaluation on.
      eval_with_score: bool, whether to compute log likelihood of targets
        instead of predictions.
    Returns:
      list of decoded predictions or scores depending on eval_with_score flag.
    """
        estimator = self.estimator(vocabulary,
                                   score_in_predict_mode=eval_with_score,
                                   sequence_length=sequence_length)

        def estimator_input_fn(params):
            """Eval input function for estimator."""
            del params
            # Concatenate all dataset inputs to only have to do one decode loop
            combined_ds = None
            for task in tasks:
                ds = t5.models.mesh_transformer.mesh_eval_dataset_fn(
                    mixture_or_task_name=task.name,
                    sequence_length=sequence_length,
                    dataset_split=split)[0].dataset_fn()
                ds = ds.map(utils.filter_features,
                            num_parallel_calls=tf.data.experimental.AUTOTUNE)
                combined_ds = ds if not combined_ds else combined_ds.concatenate(
                    ds)
            combined_ds = combined_ds.batch(self.batch_size,
                                            drop_remainder=False)  # pytype:disable=attribute-error
            # Pad the final batch.
            combined_ds = transformer_dataset.trim_and_pad_dataset(
                combined_ds, length=self.batch_size)
            combined_ds = combined_ds.prefetch(tf.data.experimental.AUTOTUNE)
            return combined_ds

        checkpoint_path = os.path.join(self._model_dir,
                                       "model.ckpt-{}".format(checkpoint_step))
        if eval_with_score:
            outputs, _ = mtf_utils.score_with_estimator(
                estimator,
                estimator_input_fn,
                checkpoint_step,
                self._model_dir,
                vocabulary,
                num_examples=sum(len(cex) for cex in examples.values()))
        else:
            outputs = [
                tf.compat.as_text(d) for d in mtf_utils.decode(
                    estimator, estimator_input_fn, vocabulary, checkpoint_path)
            ]

        return outputs
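
Since the estimator input concatenates the per-task datasets in order, the returned outputs list holds the per-task decodes (or scores) back to back. A small hedged sketch of slicing it apart again using the cached example counts; split_outputs_by_task is a hypothetical helper, not part of the original class.

def split_outputs_by_task(tasks, outputs, cached_examples):
    """Hypothetical helper: slice concatenated outputs into per-task lists."""
    per_task = {}
    offset = 0
    for task in tasks:
        num_examples = len(cached_examples[task.name])
        per_task[task.name] = outputs[offset:offset + num_examples]
        offset += num_examples
    return per_task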
Example #3
def eval_model_ll(estimator,
                  vocabulary,
                  sequence_length,
                  batch_size,
                  dataset_split,
                  model_dir,
                  eval_dataset_fn,
                  eval_summary_dir,
                  eval_checkpoint_step,
                  attribute_bit=True,
                  unsupervised_attribute_transfer_metrics=True,
                  control_code_bool=False):
    """Eval a Mesh-TF model.
    Args:
      estimator: Estimator object, created with the appropriate model_fn.
      vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
        targets_vocabulary) tuple
      sequence_length: a dict from feature-key to integer the (packed)
        sequence length, e.g. {"inputs": 512, "targets": 128}
      batch_size: an integer, global batch size
      dataset_split: a string
      model_dir: a string, directory with the model.
      eval_dataset_fn: A function returning a list of dataset.EvalDataset tuples.
        Must be provided for mode="eval". Should accept the following arguments:
          - sequence_length: an integer or a dict from feature-key to integer
            the (packed) sequence length, e.g. {"inputs": 512, "targets": 128}
          - vocabulary: Vocabulary instance to use for encoding.
          - dataset_split: str, which dataset split to load.
        dataset.EvalDataset tuples are namedtuples with the following fields:
          - name: string, the task name
          - dataset_fn: function which returns a tf.data.Dataset of tokenized and
            padded examples. Must not require any arguments and must include the
            feature keys 'inputs' and 'targets_plaintext'.
          - postprocess_fn: function which converts plaintext targets to values
            that can be processed by a `metric_fn`.
          - list_of_metric_fns: list of metric_name functions with the call signature
            `metric_fn(targets, predictions)` which returns a dict mapping
            submetric names to scalar values. TensorBoard summaries and other tags
            will be written out using the submetric names.
      eval_summary_dir: str, path to write TensorBoard events file summaries for
        eval. If None, use model_dir/eval_{split}.
      eval_checkpoint_step: int, list of ints, or None. If an int or list of ints,
        evaluation or inference will be run on the checkpoint files in `model_dir`
        whose global steps are closest to the global steps provided. If None and
        mode="eval", run eval continuously waiting for new checkpoints via
        `tf.train.checkpoints_iterator`.
    """
    if eval_dataset_fn is None:
        raise ValueError("Must provide eval_dataset_fn through gin for eval.")

    eval_datasets = eval_dataset_fn(
        sequence_length=sequence_length,
        vocabulary=vocabulary,
        dataset_split=dataset_split,
    )

    valid_eval_datasets = []
    for eval_dataset in eval_datasets:
        if not eval_dataset.metric_fns:
            tf.logging.info("Skipping %s because metric_fns is empty",
                            eval_dataset.name)
            continue
        # Convert to EvalDataset tuple in case eval_dataset_fn returns raw tuples
        valid_eval_datasets.append(
            transformer_dataset.EvalDataset(*eval_dataset))
    eval_datasets = valid_eval_datasets

    if not eval_datasets:
        tf.logging.info(
            "All provided EvalDatasets have metric_fns=[]; eval is not possible."
        )
        return

    eval_summary_dir = eval_summary_dir or os.path.join(
        model_dir, "{}_eval".format(dataset_split))
    summary_writer = tf.summary.FileWriter(eval_summary_dir)

    # Pre-load all of the targets once before entering the continuous eval loop
    cached_targets = {}
    cached_examples = {}
    if attribute_bit:
        cached_attributes_origin = {}
    # Need to create a separate graph for loading in plaintext targets
    # or else TF will complain that we modified the graph
    with tf.Graph().as_default():
        for eval_dataset in eval_datasets:
            if eval_dataset.metric_fns:
                ds = eval_dataset.dataset_fn()
                # Create list of postprocessed text targets
                examples = [ex for ex in tfds.as_numpy(ds)]
                targets = [
                    eval_dataset.postprocess_fn(  # pylint:disable=g-complex-comprehension
                        tf.compat.as_text(ex["targets_plaintext"]),
                        example=ex,
                        is_target=True) for ex in examples
                ]

                if attribute_bit:
                    attributes_origin = [
                        str(ex["attribute"][0] - 1) for ex in examples
                    ]

                targets_filename = os.path.join(
                    eval_summary_dir,
                    "{}_targets".format(eval_dataset.name),
                )
                write_lines_to_file(targets, targets_filename)
                cached_targets[eval_dataset.name] = targets
                cached_examples[eval_dataset.name] = examples
                if attribute_bit:
                    cached_attributes_origin[
                        eval_dataset.name] = attributes_origin

    if attribute_bit:
        _INPUT_FEATURES_ll.append("attribute")

    if control_code_bool:  # TODO: check whether all of these features are needed.
        _INPUT_FEATURES_ll.extend([
            "controlcode", "controlcode_position", "controlcode_segmentation",
            "controlcode_subsegmentation", "codeprefixedtargets",
            "codeprefixedtargets_position", "codeprefixedtargets_segmentation",
            "codeprefixedtargets_subsegmentation"
        ])

    def input_fn(params):
        """Eval input function for estimator."""
        del params
        # Concatenate all dataset inputs to only have to do one decode loop
        combined_ds = None
        for eval_dataset in eval_datasets:
            # Only include datasets for tasks with metric functions provided
            if eval_dataset.metric_fns:
                ds = eval_dataset.dataset_fn()
                # Only pass those variables which will be used for decoding
                ds = ds.map(
                    lambda x: {k: v for k, v in x.items()
                               if k in _INPUT_FEATURES_ll})
                combined_ds = ds if not combined_ds else combined_ds.concatenate(
                    ds)
        combined_ds = combined_ds.batch(batch_size, drop_remainder=False)
        # Pad the final batch.
        combined_ds = transformer_dataset.trim_and_pad_dataset(
            combined_ds, length=batch_size)
        combined_ds = combined_ds.prefetch(tf.data.experimental.AUTOTUNE)
        return combined_ds

    checkpoint_paths = get_checkpoint_iterator(eval_checkpoint_step, model_dir)
    for checkpoint_path in checkpoint_paths:
        tf.logging.info("Checkpoint path %s" % checkpoint_path)
        global_step = int(get_step_from_checkpoint_path(checkpoint_path))
        if global_step == 0:
            continue
        decodes = decode(estimator, input_fn, vocabulary, checkpoint_path)
        for eval_dataset in eval_datasets:
            # Extract the portion of decodes corresponding to this dataset
            examples = cached_examples[eval_dataset.name]
            dataset_size = len(examples)
            predictions = [
                eval_dataset.postprocess_fn(tf.compat.as_text(d), example=ex)
                for d, ex in zip(decodes[:dataset_size], examples)
            ]
            # Remove the used decodes.
            del decodes[:dataset_size]

            predictions_filename = os.path.join(
                eval_summary_dir,
                "{}_{}_predictions".format(eval_dataset.name, global_step),
            )
            write_lines_to_file_ll(predictions, predictions_filename)

            for metric_fn in eval_dataset.metric_fns:
                summary = tf.Summary()
                targets = cached_targets[eval_dataset.name]
                if unsupervised_attribute_transfer_metrics and attribute_bit:
                    attributes_origin = cached_attributes_origin[
                        eval_dataset.name]
                    metric_result = metric_fn(
                        targets,
                        predictions,
                        attributes_origin=attributes_origin)
                else:
                    metric_result = metric_fn(targets, predictions)
                for metric_name, metric_value in metric_result.items():
                    tag = "eval/{}/{}".format(eval_dataset.name, metric_name)
                    tf.logging.info("%s at step %d: %.3f", tag, global_step,
                                    metric_value)
                    summary.value.add(tag=tag, simple_value=metric_value)
                    summary_writer.add_summary(summary, global_step)
            summary_writer.flush()

        # Only padding should remain.
        expected_pad = -sum(len(t)
                            for t in cached_targets.values()) % batch_size
        if len(decodes) != expected_pad:
            raise ValueError("{} padded decodes, {} expected.".format(
                len(decodes), expected_pad))
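
The final consistency check relies on Python's negative modulo: when n real examples are batched and the last batch is padded up to a multiple of batch_size, exactly -n % batch_size padded decodes should be left over. A tiny illustration with invented numbers:

# Hedged illustration of the padding arithmetic; the counts are made up.
batch_size = 8
num_real_examples = 21            # sum(len(t) for t in cached_targets.values())
expected_pad = -num_real_examples % batch_size
assert expected_pad == 3          # 21 examples are padded out to 24 = 3 * 8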