Example #1
0
    def test_make_seuence_value_info(self):  # type: () -> None
        """Compare the sequence value-info shortcut against manual assembly.

        A value info named 'test' is built twice for a sequence whose element
        is an unshaped tensor with element type code 2. The first is composed
        step by step via make_tensor_type_proto -> make_sequence_type_proto ->
        make_value_info. The second comes from
        make_tensor_sequence_value_info. The two results must compare equal.
        """
        elem_type_code = 2  # TensorProto enum value (presumably UINT8 — confirm)
        element_proto = helper.make_tensor_type_proto(
            elem_type=elem_type_code, shape=None)
        manual_info = helper.make_value_info(
            name='test',
            type_proto=helper.make_sequence_type_proto(element_proto),
        )
        shortcut_info = helper.make_tensor_sequence_value_info(
            name='test', elem_type=elem_type_code, shape=None)
        self.assertEqual(manual_info, shortcut_info)

    def _test_op_upgrade(
        self,
        op: Text,
        from_opset: int,
        input_shapes: Optional[List[Union[List[Optional[int]], Text]]] = None,
        output_shapes: Optional[List[List[Optional[int]]]] = None,
        input_types: Union[List[Any], None] = None,
        output_types: Union[List[Any], None] = None,
        initializer: Optional[List[Any]] = None,
        attrs: Optional[Dict[Text, Any]] = None,
        seq_inputs: Optional[List[int]] = None,
        seq_outputs: Optional[List[int]] = None,
        optional_inputs: Optional[List[int]] = None,
        optional_outputs: Optional[List[int]] = None
    ) -> None:
        """Build a single-node model for ``op`` and upgrade it to ``latest_opset``.

        The original model is checked with ``onnx.checker.check_model`` and
        shape-inferred in strict mode. It is then converted with
        ``version_converter.convert_version``, and the converted model is
        checked and shape-inferred the same way. Any failure propagates from
        those calls. ``op`` is also appended to the module-level ``tested_ops``.

        Args:
            op: Operator type of the single node; also used as the graph name.
            from_opset: Default-domain ('') opset version of the original model.
            input_shapes: One entry per node input. The empty string marks an
                omitted input: the node gets '' in that slot and no graph input
                is created. Defaults to ``[[3, 4, 5]]``.
            output_shapes: One shape per node output. Defaults to ``[[3, 4, 5]]``.
            input_types: Element type per input; defaults to
                ``TensorProto.FLOAT`` for every input.
            output_types: Element type per output; defaults to
                ``TensorProto.FLOAT`` for every output.
            initializer: Initializers passed to ``helper.make_graph``.
            attrs: Keyword attributes passed to ``helper.make_node``.
            seq_inputs / seq_outputs: Indices declared as tensor sequences.
            optional_inputs / optional_outputs: Indices declared as optional
                tensors (a sequence declaration takes precedence).

        Raises:
            ValueError: if more than 26 inputs plus outputs are requested.
                Names are single lowercase letters, and previously the excess
                entries were silently dropped.
        """
        global tested_ops
        tested_ops.append(op)

        # Materialize defaults here instead of in the signature so that no
        # mutable default object is shared between calls.
        if input_shapes is None:
            input_shapes = [[3, 4, 5]]
        if output_shapes is None:
            output_shapes = [[3, 4, 5]]
        if initializer is None:
            initializer = []
        if attrs is None:
            attrs = {}
        if seq_inputs is None:
            seq_inputs = []
        if seq_outputs is None:
            seq_outputs = []
        if optional_inputs is None:
            optional_inputs = []
        if optional_outputs is None:
            optional_outputs = []

        def _make_value_infos(names, types, shapes, seq_ids, opt_ids):
            """Create one value info per non-empty name (sequence/optional/tensor)."""
            seq_set = set(seq_ids)
            opt_set = set(opt_ids)
            infos: List[ValueInfoProto] = []
            for idx, (name, ttype, shape) in enumerate(zip(names, types, shapes)):
                if name == '':
                    continue  # omitted input: referenced by the node as '' only
                # String placeholders become [0] to ease type analysis.
                shape_list = cast(List[int], [0] if isinstance(shape, str) else shape)
                if idx in seq_set:
                    infos.append(
                        helper.make_tensor_sequence_value_info(name, ttype, shape_list))
                elif idx in opt_set:
                    type_proto = helper.make_tensor_type_proto(ttype, shape_list)
                    optional_type_proto = helper.make_optional_type_proto(type_proto)
                    infos.append(helper.make_value_info(name, optional_type_proto))
                else:
                    infos.append(helper.make_tensor_value_info(name, ttype, shape_list))
            return infos

        n_inputs = len(input_shapes)
        n_outputs = len(output_shapes)
        letters = string.ascii_lowercase
        if n_inputs + n_outputs > len(letters):
            raise ValueError(
                'at most %d inputs+outputs are supported, got %d'
                % (len(letters), n_inputs + n_outputs))

        # Inputs are named 'a', 'b', ...; outputs continue the alphabet.
        input_names = [
            letter if shape != '' else ''
            for letter, shape in zip(letters, input_shapes)
        ]
        output_names = list(letters[n_inputs:n_inputs + n_outputs])

        if input_types is None:
            input_types = [TensorProto.FLOAT] * n_inputs
        if output_types is None:
            output_types = [TensorProto.FLOAT] * n_outputs

        inputs = _make_value_infos(
            input_names, input_types, input_shapes, seq_inputs, optional_inputs)
        outputs = _make_value_infos(
            output_names, output_types, output_shapes, seq_outputs, optional_outputs)

        node = helper.make_node(op, input_names, output_names, **attrs)
        graph = helper.make_graph([node], op, inputs, outputs, initializer)
        original = helper.make_model(
            graph,
            producer_name='test',
            opset_imports=[helper.make_opsetid('', from_opset)]
        )
        onnx.checker.check_model(original)
        shape_inference.infer_shapes(original, strict_mode=True)

        converted = version_converter.convert_version(original, latest_opset)
        onnx.checker.check_model(converted)
        shape_inference.infer_shapes(converted, strict_mode=True)