Example 1

from typing import Dict, List, Tuple

from transformers import BatchEncoding, Pipeline
from transformers.utils import ModelOutput

def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
    """
    Attempt to infer the static vs dynamic axes for each input and output tensors for a specific model.
    Args:
        nlp: The pipeline object holding the model to be exported
        framework: The framework identifier to dispatch to the correct inference scheme (pt/tf)

    Returns:
        - List of the inferred input variable names
        - List of the inferred output variable names
        - Dictionary with input/output variables names as key and shape tensor as value
        - a BatchEncoding reference which was used to infer all the above information
    """

    def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int):
        if isinstance(tensor, (tuple, list)):
            return [build_shape_dict(name, t, is_input, seq_len) for t in tensor]

        else:
            # Assume the batch axis is the first axis that has exactly one element (may not always hold)
            axes = {[axis for axis, numel in enumerate(tensor.shape) if numel == 1][0]: "batch"}
            if is_input:
                if len(tensor.shape) == 2:
                    axes[1] = "sequence"
                else:
                    raise ValueError(f"Unable to infer tensor axes ({len(tensor.shape)})")
            else:
                seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len]
                axes.update({dim: "sequence" for dim in seq_axes})

        print(f"Found {'input' if is_input else 'output'} {name} with shape: {axes}")
        return axes
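    # Illustrative (an assumption, not from the source): an input_ids tensor of
    # shape (1, 7) with seq_len == 7 yields {0: "batch", 1: "sequence"}.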

    tokens = nlp.tokenizer("This is a sample output", return_tensors=framework)
    seq_len = tokens.input_ids.shape[-1]
    # Keras models expect a plain dict, so pass the BatchEncoding's underlying .data for TF
    outputs = nlp.model(**tokens) if framework == "pt" else nlp.model(tokens.data)
    if isinstance(outputs, ModelOutput):
        outputs = outputs.to_tuple()
    if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)

    # Generate input names & axes
    input_vars = list(tokens.keys())
    input_dynamic_axes = {k: build_shape_dict(k, v, True, seq_len) for k, v in tokens.items()}

    # flatten potentially grouped outputs (past for gpt2, attentions)
    outputs_flat = []
    for output in outputs:
        if isinstance(output, (tuple, list)):
            outputs_flat.extend(output)
        else:
            outputs_flat.append(output)

    # Generate output names & axes
    output_names = [f"output_{i}" for i in range(len(outputs_flat))]
    output_dynamic_axes = {k: build_shape_dict(k, v, False, seq_len) for k, v in zip(output_names, outputs_flat)}

    # Create the aggregated axes representation
    dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes)
    return input_vars, output_names, dynamic_axes, tokens
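
A minimal usage sketch for Example 1 (the pipeline task and model name here are
illustrative assumptions, not part of the source):

from transformers import pipeline

nlp = pipeline("feature-extraction", model="bert-base-cased", framework="pt")
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
# dynamic_axes maps each tensor name to {axis_index: "batch" | "sequence"} and
# can be passed to torch.onnx.export so the exported graph accepts variable
# batch sizes and sequence lengths.
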
Example 2

from typing import Dict, List, Tuple

from transformers import BatchEncoding, Pipeline

def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
    def build_shape_dict(tensor, is_input: bool, seq_len: int):
        if isinstance(tensor, (tuple, list)):
            return [build_shape_dict(t, is_input, seq_len) for t in tensor]

        else:
            # Assume the batch axis is the first axis that has exactly one element (may not always hold)
            axes = {[axis for axis, numel in enumerate(tensor.shape) if numel == 1][0]: "batch"}
            if is_input:
                if len(tensor.shape) == 2:
                    axes[1] = "sequence"
                else:
                    raise ValueError("Unable to infer tensor axes ({})".format(len(tensor.shape)))
            else:
                seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len]
                axes.update({dim: "sequence" for dim in seq_axes})

        return axes

    tokens = nlp.tokenizer.encode_plus("This is a sample output", return_tensors=framework)
    seq_len = tokens.input_ids.shape[-1]
    # Keras models expect a plain dict, so pass the BatchEncoding's underlying .data for TF
    outputs = nlp.model(**tokens) if framework == "pt" else nlp.model(tokens.data)

    if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)

    # Generate input names & axes
    input_vars = list(tokens.keys())
    input_dynamic_axes = {k: build_shape_dict(v, True, seq_len) for k, v in tokens.items()}

    # flatten potentially grouped outputs (past for gpt2, attentions)
    outputs_flat = []
    for output in outputs:
        if isinstance(output, (tuple, list)):
            outputs_flat.extend(output)
        else:
            outputs_flat.append(output)

    # Generate output names & axes
    output_names = ["output_{}".format(i) for i in range(len(outputs_flat))]
    output_dynamic_axes = {k: build_shape_dict(v, False, seq_len) for k, v in zip(output_names, outputs_flat)}

    # Create the aggregated axes representation
    dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes)
    return input_vars, output_names, dynamic_axes, tokens
Example 3

    def _test_pipeline(self, nlp: Pipeline):
        output_keys = {"entity", "word", "score"}
        if nlp.grouped_entities:
            output_keys = {"entity_group", "word", "score"}

        ungrouped_ner_inputs = [
            [
                {"entity": "B-PER", "index": 1, "score": 0.9994944930076599, "is_subword": False, "word": "Cons"},
                {"entity": "B-PER", "index": 2, "score": 0.8025449514389038, "is_subword": True, "word": "##uelo"},
                {"entity": "I-PER", "index": 3, "score": 0.9993102550506592, "is_subword": False, "word": "Ara"},
                {"entity": "I-PER", "index": 4, "score": 0.9993743896484375, "is_subword": True, "word": "##új"},
                {"entity": "I-PER", "index": 5, "score": 0.9992871880531311, "is_subword": True, "word": "##o"},
                {"entity": "I-PER", "index": 6, "score": 0.9993029236793518, "is_subword": False, "word": "No"},
                {"entity": "I-PER", "index": 7, "score": 0.9981776475906372, "is_subword": True, "word": "##guera"},
                {"entity": "B-PER", "index": 15, "score": 0.9998136162757874, "is_subword": False, "word": "Andrés"},
                {"entity": "I-PER", "index": 16, "score": 0.999740719795227, "is_subword": False, "word": "Pas"},
                {"entity": "I-PER", "index": 17, "score": 0.9997414350509644, "is_subword": True, "word": "##tran"},
                {"entity": "I-PER", "index": 18, "score": 0.9996136426925659, "is_subword": True, "word": "##a"},
                {"entity": "B-ORG", "index": 28, "score": 0.9989739060401917, "is_subword": False, "word": "Far"},
                {"entity": "I-ORG", "index": 29, "score": 0.7188422083854675, "is_subword": True, "word": "##c"},
            ],
            [
                {"entity": "I-PER", "index": 1, "score": 0.9968166351318359, "is_subword": False, "word": "En"},
                {"entity": "I-PER", "index": 2, "score": 0.9957635998725891, "is_subword": True, "word": "##zo"},
                {"entity": "I-ORG", "index": 7, "score": 0.9986497163772583, "is_subword": False, "word": "UN"},
            ],
        ]

        expected_grouped_ner_results = [
            [
                {"entity_group": "PER", "score": 0.999369223912557, "word": "Consuelo Araújo Noguera"},
                {"entity_group": "PER", "score": 0.9997771680355072, "word": "Andrés Pastrana"},
                {"entity_group": "ORG", "score": 0.9989739060401917, "word": "Farc"},
            ],
            [
                {"entity_group": "PER", "score": 0.9968166351318359, "word": "Enzo"},
                {"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN"},
            ],
        ]

        expected_grouped_ner_results_w_subword = [
            [
                {"entity_group": "PER", "score": 0.9994944930076599, "word": "Cons"},
                {"entity_group": "PER", "score": 0.9663328925768534, "word": "##uelo Araújo Noguera"},
                {"entity_group": "PER", "score": 0.9997273534536362, "word": "Andrés Pastrana"},
                {"entity_group": "ORG", "score": 0.8589080572128296, "word": "Farc"},
            ],
            [
                {"entity_group": "PER", "score": 0.9962901175022125, "word": "Enzo"},
                {"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN"},
            ],
        ]
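        # Note (editorial): each grouped score is the arithmetic mean of the
        # member token scores, e.g. "Enzo" in the subword case above:
        # (0.9968166351318359 + 0.9957635998725891) / 2 == 0.9962901175022125.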

        self.assertIsNotNone(nlp)

        mono_result = nlp(VALID_INPUTS[0])
        self.assertIsInstance(mono_result, list)
        self.assertIsInstance(mono_result[0], (dict, list))

        if isinstance(mono_result[0], list):
            mono_result = mono_result[0]

        for key in output_keys:
            self.assertIn(key, mono_result[0])

        multi_result = [nlp(input) for input in VALID_INPUTS]
        self.assertIsInstance(multi_result, list)
        self.assertIsInstance(multi_result[0], (dict, list))

        if isinstance(multi_result[0], list):
            multi_result = multi_result[0]

        for result in multi_result:
            for key in output_keys:
                self.assertIn(key, result)

        if nlp.grouped_entities:
            if nlp.ignore_subwords:
                for ungrouped_input, grouped_result in zip(ungrouped_ner_inputs, expected_grouped_ner_results):
                    self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
            else:
                for ungrouped_input, grouped_result in zip(
                    ungrouped_ner_inputs, expected_grouped_ner_results_w_subword
                ):
                    self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
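
A sketch of the kind of pipeline _test_pipeline exercises (the checkpoint name
is an assumption, any token-classification model works; grouped_entities and
ignore_subwords match the flags the test branches on in the transformers
version this test targets):

from transformers import pipeline

nlp = pipeline(
    "ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    grouped_entities=True,
    ignore_subwords=True,
)
print(nlp("Consuelo Araújo Noguera met Andrés Pastrana."))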