Ejemplo n.º 1
0
    def _cast(g):
        """Append a ``Cast`` node converting the first transient to ``to``.

        ``to`` comes from the enclosing scope; it is passed as the node's
        ``to`` attribute and reused as the element type of the new transient.

        Returns a copy of ``g`` whose nodes include the cast and whose only
        transient describes the cast output.
        """
        out_name = g.generate_name('cast_result')
        cast_node = onnx.helper.make_node(
            "Cast",
            [g.transients[0].name],
            [out_name],
            to=to,
        )
        new_nodes = graph.extend(g.nodes, [cast_node])
        out_info = onnx.helper.make_tensor_value_info(out_name, to, [])
        return g._replace(nodes=new_nodes, transients=[out_info])
Ejemplo n.º 2
0
    def _less_or_equal(g):
        """Append a ``LessOrEqual`` node over the first two transients.

        Returns a copy of ``g`` with the node appended and a single transient
        for the comparison result, declared with element type ``BOOL``.
        """
        out_name = g.generate_name('less_or_equal_result')
        operands = [g.transients[0].name, g.transients[1].name]
        compare_node = onnx.helper.make_node("LessOrEqual", operands, [out_name])
        new_nodes = graph.extend(g.nodes, [compare_node])
        out_info = onnx.helper.make_tensor_value_info(
            out_name, onnx.TensorProto.BOOL, [])
        return g._replace(nodes=new_nodes, transients=[out_info])
Ejemplo n.º 3
0
    def _identity(g):
        """Append an ``Identity`` node on the first transient.

        The output name is generated from ``name``, which is taken from the
        enclosing scope. The new transient's element type is declared as
        ``UNDEFINED``.
        """
        out_name = g.generate_name(name)
        identity_node = onnx.helper.make_node(
            "Identity", [g.transients[0].name], [out_name])
        new_nodes = graph.extend(g.nodes, [identity_node])
        out_info = onnx.helper.make_tensor_value_info(
            out_name, onnx.TensorProto.UNDEFINED, [])
        return g._replace(nodes=new_nodes, transients=[out_info])
Ejemplo n.º 4
0
    def _expand(g):
        """Append an ``Expand`` node fed by the first two transients.

        The first transient is the node's first input and the second
        transient its second input (the target shape, per the ONNX
        ``Expand`` signature). The result's element type is declared
        ``UNDEFINED``.
        """
        out_name = g.generate_name('expand_result')
        data_name = g.transients[0].name
        shape_name = g.transients[1].name
        expand_node = onnx.helper.make_node(
            "Expand", [data_name, shape_name], [out_name])
        new_nodes = graph.extend(g.nodes, [expand_node])
        out_info = onnx.helper.make_tensor_value_info(
            out_name, onnx.TensorProto.UNDEFINED, [])
        return g._replace(nodes=new_nodes, transients=[out_info])
Ejemplo n.º 5
0
    def _add(g):
        """Append an ``Add`` node summing the first two transients.

        The result transient reuses the element type recorded on the first
        input transient.
        """
        out_name = g.generate_name('add_result')
        addends = [g.transients[0].name, g.transients[1].name]
        add_node = onnx.helper.make_node("Add", addends, [out_name])
        new_nodes = graph.extend(g.nodes, [add_node])
        result_type = g.transients[0].type.tensor_type.elem_type
        out_info = onnx.helper.make_tensor_value_info(out_name, result_type, [])
        return g._replace(nodes=new_nodes, transients=[out_info])
Ejemplo n.º 6
0
    def _flatten(g):
        """Append a ``Flatten`` node on the first transient.

        ``axis`` is taken from the enclosing scope and forwarded as the
        node attribute. The result's element type is declared ``UNDEFINED``.
        """
        out_name = g.generate_name('flatten_result')
        flatten_node = onnx.helper.make_node(
            "Flatten",
            [g.transients[0].name],
            [out_name],
            axis=axis,
        )
        new_nodes = graph.extend(g.nodes, [flatten_node])
        out_info = onnx.helper.make_tensor_value_info(
            out_name, onnx.TensorProto.UNDEFINED, [])
        return g._replace(nodes=new_nodes, transients=[out_info])
Ejemplo n.º 7
0
    def _argmax(g):
        """Append an ``ArgMax`` node on the first transient.

        ``axis``, ``keepdims`` and ``select_last_index`` come from the
        enclosing scope and are forwarded as node attributes. The result
        transient is declared with element type ``INT64``.
        """
        out_name = g.generate_name('argmax_result')
        argmax_node = onnx.helper.make_node(
            "ArgMax",
            [g.transients[0].name],
            [out_name],
            axis=axis,
            keepdims=keepdims,
            select_last_index=select_last_index,
        )
        new_nodes = graph.extend(g.nodes, [argmax_node])
        out_info = onnx.helper.make_tensor_value_info(
            out_name, onnx.TensorProto.INT64, [])
        return g._replace(nodes=new_nodes, transients=[out_info])
Ejemplo n.º 8
0
    def _concat(g):
        """Append a ``Concat`` node joining every current transient.

        All transients are used as inputs, in their existing order; ``axis``
        is taken from the enclosing scope. The single result transient has
        its element type declared ``UNDEFINED``.
        """
        out_name = g.generate_name('concat_result')

        input_names = []
        for tensor in g.transients:
            input_names.append(tensor.name)

        concat_node = onnx.helper.make_node(
            "Concat", input_names, [out_name], axis=axis)
        new_nodes = graph.extend(g.nodes, [concat_node])
        out_info = onnx.helper.make_tensor_value_info(
            out_name, onnx.TensorProto.UNDEFINED, [])
        return g._replace(nodes=new_nodes, transients=[out_info])
Ejemplo n.º 9
0
    def _reshape(g):
        """Append a ``Reshape`` node fed by the first two transients.

        The first transient is the data input and the second the shape
        input, following the ONNX ``Reshape`` signature. The result's
        element type is declared ``UNDEFINED``.
        """
        out_name = g.generate_name('reshape_result')
        data_name, shape_name = g.transients[0].name, g.transients[1].name
        reshape_node = onnx.helper.make_node(
            "Reshape", [data_name, shape_name], [out_name])
        new_nodes = graph.extend(g.nodes, [reshape_node])
        out_info = onnx.helper.make_tensor_value_info(
            out_name, onnx.TensorProto.UNDEFINED, [])
        return g._replace(nodes=new_nodes, transients=[out_info])
Ejemplo n.º 10
0
    def _softmax(g):
        """Append a ``Softmax`` node on the first transient.

        ``axis`` is taken from the enclosing scope. The result transient's
        element type is declared ``UNDEFINED``.
        """
        out_name = g.generate_name('softmax_result')
        softmax_node = onnx.helper.make_node(
            "Softmax", [g.transients[0].name], [out_name], axis=axis)
        new_nodes = graph.extend(g.nodes, [softmax_node])
        out_info = onnx.helper.make_tensor_value_info(
            out_name, onnx.TensorProto.UNDEFINED, [])
        return g._replace(nodes=new_nodes, transients=[out_info])
Ejemplo n.º 11
0
    def _gather_elements(g):
        """Append a ``GatherElements`` node over the first two transients.

        The first transient is the data input and the second the indices
        input, per the ONNX ``GatherElements`` signature; ``axis`` is taken
        from the enclosing scope. The result's element type is declared
        ``UNDEFINED``.
        """
        out_name = g.generate_name('gather_elements_result')
        data_name = g.transients[0].name
        indices_name = g.transients[1].name
        gather_node = onnx.helper.make_node(
            "GatherElements",
            [data_name, indices_name],
            [out_name],
            axis=axis,
        )
        new_nodes = graph.extend(g.nodes, [gather_node])
        out_info = onnx.helper.make_tensor_value_info(
            out_name, onnx.TensorProto.UNDEFINED, [])
        return g._replace(nodes=new_nodes, transients=[out_info])
Ejemplo n.º 12
0
    def _category_mapper(g):
        """Append an ``ai.onnx.ml`` ``CategoryMapper`` node on the first transient.

        The category lists and defaults (``cats_int64s``, ``cats_strings``,
        ``default_int64``, ``default_string``) come from the enclosing scope
        and are forwarded as node attributes. The result's element type is
        declared ``UNDEFINED``.
        """
        out_name = g.generate_name('category_mapper_result')
        mapper_node = onnx.helper.make_node(
            "CategoryMapper",
            [g.transients[0].name],
            [out_name],
            cats_int64s=cats_int64s,
            cats_strings=cats_strings,
            default_int64=default_int64,
            default_string=default_string,
            domain='ai.onnx.ml',
        )
        new_nodes = graph.extend(g.nodes, [mapper_node])
        out_info = onnx.helper.make_tensor_value_info(
            out_name, onnx.TensorProto.UNDEFINED, [])
        return g._replace(nodes=new_nodes, transients=[out_info])
Ejemplo n.º 13
0
    def _reduce_sum(g):
        """Append a ``ReduceSum`` node over the first two transients.

        The first transient is the data input and the second is passed as
        the node's second input. ``keepdims`` and ``noop_with_empty_axes``
        come from the enclosing scope. The result's element type is declared
        ``UNDEFINED``.
        """
        out_name = g.generate_name('reduce_sum_result')
        inputs = [g.transients[0].name, g.transients[1].name]
        reduce_node = onnx.helper.make_node(
            "ReduceSum",
            inputs,
            [out_name],
            keepdims=keepdims,
            noop_with_empty_axes=noop_with_empty_axes,
        )
        new_nodes = graph.extend(g.nodes, [reduce_node])
        out_info = onnx.helper.make_tensor_value_info(
            out_name, onnx.TensorProto.UNDEFINED, [])
        return g._replace(nodes=new_nodes, transients=[out_info])