Example #1
    def _make_rnn_node(cell_no, cell_info, scope, **kwargs):
        """Make RNN node.

    Args:
      cell_no: Cell No.
      cell_info: Cell info obj.
      scope: Name scope.
      **kwargs: Other args.

    Returns:
      RNN node.

    """
        node = TensorflowNode()
        node.op_type = cell_info["type"].upper()
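        # The first cell keeps the bare op type as its name; later cells get
        # a "_{cell_no}" suffix to keep names unique within the scope.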
        node.name = "/".join([
            scope, node.op_type if cell_no == 0 else node.op_type +
            "_{}".format(cell_no)
        ])
        node.inputs = [
            cell_info["inputs"]["prev_c"], cell_info["inputs"]["w_kernel"],
            cell_info["inputs"]["r_kernel"], cell_info["inputs"]["bias"]
        ]
        node.outputs = node.get_outputs_names(num=2)
        for k, v in kwargs.items():
            node.attr[k] = v
        return node
Example #2
    def process_kernel_and_bias(cls, nodes, cell_dict, node_dict):
        new_kernel = None
        new_bias = None
        scopes = cell_dict["kernel"][0].split("/")
        scope = "/".join(scopes[:scopes.index("kernel")])
        for key, value in [("kernel", node_dict[cell_dict["kernel"][0]]),
                           ("bias", node_dict[cell_dict["bias"][0]])]:
            output_shape = node_dict[value.name].attr["_output_shapes"][0]
            if key == "kernel":
                hidden_size = output_shape[1]
                input_size = output_shape[0] - hidden_size
                transposed_shape = output_shape[::-1]
                transpose_node = TensorflowNode(
                    op_type="Transpose",
                    name="/".join(
                        [scope, key, "transpose_" + get_unique_suffix()]),
                    inputs=[value.name, None],
                    attr={"_output_shapes": [transposed_shape]})

                split_const_node = TensorflowNode(
                    op_type="Const",
                    name="/".join(
                        [scope, key, "split_const_" + get_unique_suffix()]),
                    attr={
                        "value": np.asarray([input_size, hidden_size],
                                            np.int32),
                        "dtype": data_type.tf2onnx(tf.int32),
                        "_output_shapes": [[1]]
                    })

                split_node = TensorflowNode(
                    op_type="SplitV",
                    name="/".join([scope, key,
                                   "split_" + get_unique_suffix()]),
                    inputs=transpose_node.outputs + split_const_node.outputs +
                    [CONST_ONE_INT32],
                    attr={
                        "num_split":
                        2,
                        "_output_shapes": [[hidden_size, input_size],
                                           [hidden_size, hidden_size]]
                    })

                nodes.extend([transpose_node, split_const_node, split_node])
                new_kernel = split_node.outputs
            else:
                new_bias = [value.name]
        return new_kernel + new_bias
Example #3
    def _make_major_transpose_nodes(inputs, scope, node_dict, prev_node, post):
        """Make major transpose nodes if is batch major.

    Args:
      inputs: Inputs names.
      scope: Name scope.
      node_dict: Node dict.
      prev_node: Previous node.
      post: If post transpose flag.

    Returns:
      Perm node.
      Transpose node.

    """
        input_shape = node_dict[inputs[0]].attr["_output_shapes"][0]
        input_rank = len(input_shape)

        perm_node = TensorflowNode(
            op_type="Const",
            name="/".join([scope, "transpose", "perm",
                           get_unique_suffix()]),
            attr={
                "value": np.asarray([1, 0] + list(range(input_rank))[2:],
                                    np.int32),
                "dtype": data_type.tf2onnx(tf.int32),
                "_output_shapes": [input_rank]
            })

        if post:
            # For a post transpose, propagate the permuted shape to the
            # previous node before wiring the Transpose behind it.
            input_shape = [input_shape[i] for i in perm_node.attr["value"]]
            prev_node.attr["_output_shapes"] = [input_shape]

        trans_node = TensorflowNode(
            op_type="Transpose",
            name="/".join([scope, "transpose",
                           get_unique_suffix()]),
            inputs=[inputs[0] if not post else prev_node.name, perm_node.name],
            attr={
                "dtype":
                data_type.tf2onnx(node_dict[inputs[0]].attr["T"]),
                "_output_shapes":
                [[input_shape[i] for i in perm_node.attr["value"]]]
            })
        return [perm_node, trans_node]
Example #4
 def version_1(cls, node, **kwargs):
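     # Square is lowered to an elementwise Mul of the input with itself.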
     mul_node = Multiply.handle(
         TensorflowNode(name='Mul',
                        inputs=[node.inputs[0], node.inputs[0]],
                        outputs=node.outputs,
                        attr=node.attr,
                        domain=node.domain,
                        op_type='Mul'), **kwargs)
     return mul_node
Example #5
 def version_1(cls, node, **kwargs):
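   # FloorDiv is lowered to Div followed by Floor; the Div result gets a
   # uniquely suffixed intermediate name so the final outputs stay intact.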
   div_suffix = '_' + get_unique_suffix()
   div_output_name = node.outputs[0] + div_suffix
   div_node = Div.handle(
       TensorflowNode(
           name='Div',
           inputs=node.inputs[0:2],
           outputs=[div_output_name],
           attr=node.attr,
           domain=node.domain,
           op_type='Div'), **kwargs)
   floor_node = Floor.handle(
       TensorflowNode(
           name='Floor',
           inputs=[div_output_name],
           outputs=node.outputs,
           attr=node.attr,
           domain=node.domain,
           op_type='Floor'))
   return [div_node, floor_node]
Example #6
  def version_1(cls, node, **kwargs):
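    # Rsqrt is lowered to Reciprocal(Sqrt(x)), with a uniquely suffixed
    # intermediate name for the Sqrt output.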
    rsqrt_suffix = "_" + get_unique_suffix()
    rsqrt_output_name = cls.get_outputs_names(node)[0] + rsqrt_suffix

    sqrt_node = Sqrt.handle(
        TensorflowNode(
            op_type='Sqrt',
            name=node.name + rsqrt_suffix,
            inputs=[node.inputs[0]],
            outputs=[rsqrt_output_name],
            attr=node.attr), **kwargs)

    reciprocal_node = Reciprocal.handle(
        TensorflowNode(
            op_type='Reciprocal',
            inputs=[rsqrt_output_name],
            outputs=cls.get_outputs_names(node),
            name=node.name,
            attr=node.attr), **kwargs)
    return [sqrt_node, reciprocal_node]
Example #7
    def _get_input_output_node_names(nodes):
        """Get input and output node names by given nodes.

    Args:
      nodes:

    Returns:
      Input node names.
      Output node names.
    """
        input_names, output_names = set(), set()
        extension_output_names = set()
        for node in nodes:
            tf_node = node if isinstance(
                node, TensorflowNode) else TensorflowNode(node)
            output_names.add(node.name)
            # Add the extra outputs of Split, Switch and TensorArrayV3 nodes.
            if tf_node.op_type == "Split":
                for i in range(1, tf_node.attr["num_split"]):
                    output_names.add(tf_node.name + ":{}".format(i))
            if tf_node.op_type == "Switch":
                output_names.add(tf_node.name + ":1")
                extension_output_names.add((tf_node.name, tf_node.name + ":1"))
            if tf_node.op_type == "TensorArrayV3":
                output_names.add(tf_node.name + ":1")
                extension_output_names.add((tf_node.name, tf_node.name + ":1"))
            input_names.update(
                set([
                    inp if inp[0] != "^" else inp[1:] for inp in tf_node.inputs
                ]))
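        # Graph inputs are names that are consumed but never produced;
        # graph outputs are produced but never consumed.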
        inputs = input_names - output_names
        outputs = output_names - input_names
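        # The paired extra outputs of Switch / TensorArrayV3 must not remain
        # graph outputs; drop any pair that survived the set difference.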
        while extension_output_names:
            ext_names = extension_output_names.pop()
            for name in ext_names:
                if name in outputs:
                    outputs -= set(ext_names)
                    break
        inputs.discard(None)
        return list(inputs), list(outputs)
Example #8
def check_node_args(graph_def, supported):
    """ Check for required node arguments in graph

  :param graph_def: the graph of operations
  :param supported: the supported operators in graph
  :return: whether all required parameters are provided
  """

    logger.info('Checking for required node arguments...')

    opset_dict = {}
    opset_dict[defs.ONNX_DOMAIN] = defs.onnx_opset_version()
    handlers = get_all_frontend_handlers(opset_dict)

    total_nodes = 0
    failed_nodes = 0
    for node in graph_def.node:
        if node.op in supported:
            total_nodes += 1
            tf_node = TensorflowNode(node)
            kwargs = {}
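            # Collect the constant values feeding this node so the handler
            # can check its required arguments against them.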
            for inp in node.input:
                for attr_node in graph_def.node:
                    if inp == attr_node.name:
                        kwargs[inp] = attr_node.attr['value']
                        break
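            # A missing handler raises AttributeError below; the broad except
            # counts that as a failed node as well.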
            handler = handlers.get(defs.ONNX_DOMAIN, {}).get(node.op, None)
            try:
                handler.args_check(tf_node, consts=kwargs)
            except Exception as e:
                logger.info(e)
                failed_nodes += 1

    logger.info('We checked %d supported nodes for required arguments.',
                total_nodes)
    logger.info('  # of nodes passed the args check: %d',
                total_nodes - failed_nodes)
    logger.info('  # of nodes failed the args check: %d', failed_nodes)
    return failed_nodes == 0
Example #9
    def version_1(cls, node, **kwargs):
        # tf.size out_type can be int32 or int64, while ONNX Size always
        # returns int64, so a Cast is needed when out_type is int32.
        need_cast = node.attr['out_type'] == tf.int32

        size_suffix = "_" + get_unique_suffix() if need_cast else ""
        size_output_name = node.outputs[0] + size_suffix
        size_node = cls.make_node_from_tf_node(node, [node.inputs[0]],
                                               outputs=[size_output_name],
                                               name=node.name + size_suffix)

        if not need_cast:
            return [size_node]

        attrs = {}
        attrs['DstT'] = node.attr['out_type']

        cast_node = Cast.handle(
            TensorflowNode(name=node.name,
                           inputs=[size_output_name],
                           outputs=node.outputs,
                           op_type='Cast',
                           attr=attrs))
        return [size_node, cast_node]
Example #10
    def tensorflow_graph_to_onnx_graph(cls,
                                       graph_def,
                                       output,
                                       opset=((defs.ONNX_DOMAIN,
                                               defs.onnx_opset_version()), ),
                                       name="graph",
                                       ignore_unimplemented=False):
        """Converts a Tensorflow Graph Proto to an ONNX graph

    This function converts a Tensorflow Graph proto to an equivalent
    representation of ONNX graph.

    :param graph_def: Tensorflow Graph Proto object.
    :param output: List of Tensorflow NodeDef object specifying which nodes
      to be taken as outputs of the ONNX graph.
    :param opset: Opset, which should be ((str domain: int version number),).
    :param name: The name of the output ONNX Graph.
    :param ignore_unimplemented: Convert to ONNX model and ignore all the operators
      that are not currently supported by onnx-tensorflow.
      This is an experimental feature. By enabling this feature,
      the graph would not be guaranteed to match the ONNX specifications.

    :returns: The equivalent ONNX Graph Proto object.
    """
        onnx_graph = OnnxGraph(name)
        exception.IGNORE_UNIMPLEMENTED = ignore_unimplemented

        opset_dict = {}
        for domain, version in opset:
            if domain == "ai.onnx":
                domain = defs.ONNX_DOMAIN
            opset_dict[domain] = version

        handlers = get_all_frontend_handlers(opset_dict)

        node_tup = [(node.name, TensorflowNode(node))
                    for node in graph_def.node]
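        # Placeholders become graph inputs, Const nodes become initializers
        # (and inputs), and every other op is dispatched to its frontend
        # handler.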
        for name, node in node_tup:

            if node.op_type == "Placeholder":
                onnx_graph.add_input_proto(node)
            elif node.op_type == "Const":
                onnx_graph.add_const(node)
                onnx_graph.add_const_proto(node)
                onnx_graph.add_input_proto(node)
            else:
                onnx_graph.add_value_info_proto(node)
                handler = handlers.get(node.domain, {}).get(node.op_type, None)
                node_proto = None
                if handler:
                    node_proto = handler.handle(
                        node,
                        consts=onnx_graph.consts,
                        node_dict=dict(node_tup),
                        data_type_cast_map=onnx_graph.data_type_cast_map)
                else:
                    exception.OP_UNIMPLEMENTED_EXCEPT(
                        node.op_type,
                        domain=None
                        if node.domain in handlers else node.domain)

                if node_proto is None:
                    node_proto = FrontendHandler.make_node_from_tf_node(
                        node, op_type=node.op_type, should_check=False)
                onnx_graph.add_node_proto(node_proto)

        for o in output:
            output_node = TensorflowNode(o)
            onnx_graph.add_output_proto(output_node)

        return onnx_graph.make_graph_proto()
Example #11
    def parse(cls, nodes):
        """Parse nodes.

    Args:
      nodes: List of NodeDef.

    Returns:
      Parsed nodes of TensorflowNode.

    """
        node_info_holder = cls._make_node_info(nodes)
        node_dict = {
            n.name:
            TensorflowNode(n) if not isinstance(n, TensorflowNode) else n
            for n in nodes
        }
        group_nodes, new_cell_nodes = cls._group_nodes(nodes, node_info_holder)

        # Rewrite each detected RNN scope: add time-major transposes when the
        # input is batch major, then chain the per-cell RNN nodes together.
        for scope in node_info_holder.nodes:
            inputs, outputs = cls._get_input_output_node_names(
                node_info_holder.nodes[scope])
            inputs = [i for i in inputs if scope not in i]
            input_nodes = [node_dict[i] for i in inputs]

            # The scope is batch major when the node consuming the first
            # input is a Transpose.
            batch_major = [
                n for n in node_info_holder.nodes[scope]
                if inputs[0] in n.input
            ][0].op == "Transpose"

            if batch_major:
                perm_node, trans_node = cls._make_major_transpose_nodes(
                    inputs, scope, node_dict, new_cell_nodes[scope][-1], False)
                input_nodes = [trans_node]
                new_cell_nodes[scope].extend([perm_node, trans_node])

            dtype = input_nodes[0].attr["T"]
            for cell_no, cell_info in node_info_holder.cell_dict.items():
                cell_node = cls._make_rnn_node(cell_no,
                                               cell_info,
                                               scope,
                                               dtype=dtype)
                if cell_no == 0:
                    cell_node.inputs[0] = input_nodes[0].name
                else:
                    cell_node.inputs[0] = new_cell_nodes[scope][-1].name + ":2"
                    prev_c_output_shapes = node_dict[
                        cell_info["prev_c"]].attr["_output_shapes"]
                    new_cell_nodes[scope][-1].attr["_output_shapes"] = [
                        "", ""
                    ] + prev_c_output_shapes

                new_cell_nodes[scope].append(cell_node)
            scope_output_shapes = node_dict[outputs[0]].attr["_output_shapes"]
            new_cell_nodes[scope][-1].attr[
                "_output_shapes"] = scope_output_shapes

            if batch_major:
                perm_node, trans_node = cls._make_major_transpose_nodes(
                    outputs, scope, node_dict, new_cell_nodes[scope][-1], True)
                new_cell_nodes[scope].extend([perm_node, trans_node])

            new_cell_nodes[scope][-1].outputs = [outputs[0]]

        res_nodes = []
        for g in group_nodes:
            if isinstance(g, list):
                res_nodes.extend(g)
            else:
                res_nodes.extend(new_cell_nodes[g])
        return [
            n if isinstance(n, TensorflowNode) else TensorflowNode(n)
            for n in res_nodes
        ]
Example #12
    def process_kernel_and_bias(cls, nodes, cell_dict, node_dict):
        new_kernel = None
        new_bias = None
        scopes = cell_dict["kernel"][0].split("/")
        scope = "/".join(scopes[:scopes.index("kernel")])
        for key, value in [[
                "kernel",
            [node_dict[kernel] for kernel in cell_dict["kernel"]]
        ], ["bias", [node_dict[bias] for bias in cell_dict["bias"]]]]:
            gate_output_shape = node_dict[
                value[0].name].attr["_output_shapes"][0]
            candidate_output_shape = node_dict[
                value[1].name].attr["_output_shapes"][0]
            last_idx = range(len(gate_output_shape))[-1]
            concat_output_shapes = [
                g if i != last_idx else g + c for i, (g, c) in enumerate(
                    zip(gate_output_shape, candidate_output_shape))
            ]
            concat_node = TensorflowNode(
                op_type="ConcatV2",
                name="/".join([scope, key, "concat_" + get_unique_suffix()]),
                inputs=[value[0].name, value[1].name, CONST_MINUS_ONE_INT32],
                attr={"_output_shapes": [concat_output_shapes]})
            nodes.append(concat_node)

            if key == "kernel":
                hidden_size = gate_output_shape[1] // 2
                input_size = gate_output_shape[0] - hidden_size
                transposed_shape = concat_output_shapes[::-1]
                transpose_node = TensorflowNode(
                    op_type="Transpose",
                    name="/".join(
                        [scope, key, "transpose_" + get_unique_suffix()]),
                    inputs=concat_node.outputs + [None],
                    attr={"_output_shapes": [transposed_shape]})

                split_const_node = TensorflowNode(
                    op_type="Const",
                    name="/".join(
                        [scope, key, "split_const_" + get_unique_suffix()]),
                    attr={
                        "value": np.asarray([input_size, hidden_size],
                                            np.int32),
                        "dtype": data_type.tf2onnx(tf.int32),
                        "_output_shapes": [[1]]
                    })

                split_node = TensorflowNode(
                    op_type="Split",
                    name="/".join([scope, key,
                                   "split_" + get_unique_suffix()]),
                    inputs=[CONST_ZERO_INT32] + transpose_node.outputs,
                    attr={
                        "num_split":
                        3,
                        "_output_shapes":
                        [[int(transposed_shape[0] / 3), transposed_shape[1]]
                         for _ in range(3)]
                    })

                re_concat_node = TensorflowNode(
                    op_type="ConcatV2",
                    name="/".join(
                        [scope, key, "re_concat_" + get_unique_suffix()]),
                    inputs=[
                        split_node.outputs[1], split_node.outputs[0],
                        CONST_ZERO_INT32
                    ],
                    attr={
                        "_output_shapes": [[
                            int(transposed_shape[0] / 3 * 2),
                            transposed_shape[1]
                        ]]
                    })

                nodes.extend([
                    transpose_node, split_const_node, split_node,
                    re_concat_node
                ])
                new_kernel = re_concat_node.outputs + [split_node.outputs[2]]
            else:
                new_bias = concat_node.outputs

        return new_kernel + new_bias
Example #13
    def version_9(cls, node, **kwargs):
        unique_suffix = get_unique_suffix()
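        # Strategy: emulate TF's bilinear resize with ONNX Upsample.
        # Transpose the input to NCHW, derive the HW scale factor from the
        # input and output shapes, upsample, then transpose back to NHWC.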

        # Convert to NCHW:
        transpose_node = Transpose.handle(TensorflowNode(
            name='transpose_input_to_nchw_' + unique_suffix,
            inputs=node.inputs[:1] + ["perm"],
            outputs=["transposed_input_" + unique_suffix]),
                                          consts={"perm": [0, 3, 1, 2]})

        # Get shape of NCHW input tensor:
        input_shape_node = Shape.handle(
            TensorflowNode(name='get_input_shape_' + unique_suffix,
                           inputs=transpose_node.output,
                           outputs=["input_shape_" + unique_suffix],
                           attr=node.attr))

        util_one = OnnxGraph.CONST_ONE_FP32

        output_shape_tensor = node.inputs[1]

        # Cast output shape (HW dim only) to float32:
        out_shape_float = Cast.handle(
            TensorflowNode(
                name='cast_output_shape_to_fp32_' + unique_suffix,
                inputs=[output_shape_tensor],
                outputs=["output_shape_float_partial_" + unique_suffix],
                attr={"DstT": tf.float32}))

        # Cast input shape to float32:
        in_shape_cast = Cast.handle(
            TensorflowNode(name='cast_input_shape_to_fp32_' + unique_suffix,
                           inputs=input_shape_node.output,
                           outputs=["input_shape_float_" + unique_suffix],
                           attr={"DstT": tf.float32}))

        slice_const_items = [
            ("begin", np.array([2]).astype(np.int32)),
            ("end", np.array([4]).astype(np.int32)),
            ("strides", np.array([1]).astype(np.int32)),
        ]

        slice_const_proto = {}

        for k, v in slice_const_items:
            const_name = "{}_".format(k) + unique_suffix
            slice_const_proto[k] = make_node(
                "Constant", [], [const_name],
                value=make_tensor(const_name,
                                  any_dtype_to_onnx_dtype(np_dtype=v.dtype),
                                  v.shape, v))

        in_shape_slices = StridedSlice.handle(
            TensorflowNode(
                name="stridedslice_input_shape_" + unique_suffix,
                inputs=list(in_shape_cast.output) +
                [slice_const_proto[k].output[0] for k, v in slice_const_items],
                outputs=["sliced_input_shape_" + unique_suffix]),
            consts={
                slice_const_proto[k].output[0]: v
                for k, v in slice_const_items
            },
            add_consts=True)

        # Divide input shape with output shape to get scaling factor:
        div_node = Div.handle(
            TensorflowNode(name='div_to_get_scale_factor_' + unique_suffix,
                           inputs=list(out_shape_float.output) +
                           list(in_shape_slices[-1].output),
                           outputs=["hw_scale_" + unique_suffix]))

        # Prepend 1's in the N, C dimension:
        full_scale = Concat.handle(TensorflowNode(
            name='prepend_ones_to_scale_factor_' + unique_suffix,
            inputs=[util_one, util_one] + list(div_node.output) +
            ["concat_axis"],
            outputs=["scale_" + unique_suffix]),
                                   consts={"concat_axis": 0})

        # Upsample with the computed scaling factor:
        upsample_node = cls.make_node_from_tf_node(
            node,
            op_type="Upsample",
            mode="bilinear",
            inputs=list(transpose_node.output) + list(full_scale.output),
            outputs=["upsample_to_tranpose_" + unique_suffix])

        # Transpose back to NHWC:
        transpose_output_node = Transpose.handle(TensorflowNode(
            name='transpose_output_to_nhwc_' + unique_suffix,
            inputs=list(upsample_node.output) + ["perm"],
            outputs=node.outputs),
                                                 consts={"perm": [0, 2, 3, 1]})

        transpose_and_get_shapes = [
            transpose_node, input_shape_node, out_shape_float, in_shape_cast
        ]
        slice_shape = list(slice_const_proto.values()) + in_shape_slices
        get_scale_and_upsample_and_transpose = [
            div_node, full_scale, upsample_node, transpose_output_node
        ]

        return transpose_and_get_shapes + slice_shape + get_scale_and_upsample_and_transpose
Example #14
 def handle_node_proto(cls, node, **kwargs):
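     # Wrap the raw NodeProto in a TensorflowNode and delegate to the base
     # handler.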
     return super(FrontendHandler, cls).handle(TensorflowNode(node),
                                               **kwargs)
Example #15
 def handle(cls, node, **kwargs):
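     # Accept either a NodeProto or an already-wrapped TensorflowNode.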
     if isinstance(node, NodeProto):
         node = TensorflowNode(node)
     return super(FrontendHandler, cls).handle(node, **kwargs)