Exemplo n.º 1
0
def create_loop_op(ctx, node, batch_val_input_id, data_type):
    """Build an ONNX Loop that evaluates the per-element Select body graph.

    Args:
        ctx: graph being converted; the Loop's body graph is registered on it.
        node: the Select node being rewritten; ``node.input[0]`` is passed to
            ``create_loop_body_graph`` as the condition tensor.
        batch_val_input_id: tensor name used as the Loop trip count.
        data_type: ONNX dtype of the Select output.

    Returns:
        list of Node: the Constant for the initial loop condition, the
        Constant for the placeholder loop-carried variable, and the Loop
        node (last). The Loop's outputs are ``[port_name(op), port_name(op, 1)]``.
    """
    nodes = []

    # Initial loop condition: constant True. The name is made unique so that
    # several Select conversions in one graph do not clash.
    cond_var_name = utils.make_name("condition")
    true = helper.make_tensor(cond_var_name, TensorProto.BOOL, (), [True])
    init_cond = Node(
        helper.make_node("Constant", [], [cond_var_name], value=true,
                         name=cond_var_name), ctx)
    nodes.append(init_cond)

    # Loop requires at least a variable, add a useless fake variable.
    # A unique name also avoids reusing the body graph's own "fake_var"
    # input name in the outer graph.
    fake_val_name = utils.make_name("fake_var")
    fake_var_init_val = helper.make_tensor(fake_val_name, TensorProto.FLOAT,
                                           (), [0.0])
    fake_var_init_node = Node(
        helper.make_node("Constant", [], [fake_val_name],
                         value=fake_var_init_val, name=fake_val_name), ctx)
    nodes.append(fake_var_init_node)

    op_name = utils.make_name("Loop")
    out_name = port_name(op_name)
    loop_inputs = [
        batch_val_input_id,  # trip count
        cond_var_name,  # termination condition
        fake_val_name  # initial value of loop-carried dependencies
    ]
    loop_scan_output_id = port_name(op_name, 1)
    loop_node = Node(
        helper.make_node("Loop",
                         loop_inputs, [out_name, loop_scan_output_id],
                         name=op_name), ctx)
    loop_body = create_loop_body_graph(ctx, node, node.input[0], data_type)
    loop_node.set_attr("body", loop_body)
    ctx.add_body_graph(out_name, loop_body)
    nodes.append(loop_node)
    return nodes
    def _create_transpose_pairs_after_node(self, node):
        """Splice a Transpose pair between ``node`` and each non-NCHW consumer.

        For every consumer returned by ``_get_non_nchw_transpose_output_nodes``,
        a Transpose(perm=[0, 3, 1, 2]) feeding a Transpose(perm=[0, 2, 3, 1])
        is inserted on ``node.output[0]``. The two perms compose to the
        identity. The consumer is rewired to read the second Transpose.

        Returns:
            list of the newly created Node objects (empty if no consumer
            needed a pair).
        """
        assert len(node.output) == 1  # just support node who has 1 output
        consumers = self._get_non_nchw_transpose_output_nodes(node)
        created = []
        for consumer in consumers:
            source = node.output[0]

            to_nchw_name = utils.make_name("Transpose")
            to_nchw_out = utils.port_name(to_nchw_name)
            to_nchw_proto = helper.make_node("Transpose", [source],
                                             [to_nchw_out],
                                             name=to_nchw_name,
                                             perm=[0, 3, 1, 2])

            to_nhwc_name = utils.make_name("Transpose")
            to_nhwc_out = utils.port_name(to_nhwc_name)
            to_nhwc_proto = helper.make_node("Transpose", [to_nchw_out],
                                             [to_nhwc_out],
                                             name=to_nhwc_name,
                                             perm=[0, 2, 3, 1])

            pair = [Node(to_nchw_proto, self._g), Node(to_nhwc_proto, self._g)]
            self._g.replace_input(consumer, source, to_nhwc_out)
            created.extend(pair)

        if created:
            self._update_graph_nodes(created, None, True)
        return created
Exemplo n.º 3
0
def create_body_graph_for_if_branch(ctx, input_id, data_shape):
    """Build an If-branch subgraph that extracts element ``i`` of ``input_id``.

    The branch is Gather(input_id, "i") followed by Squeeze(axes=[0]) and an
    Identity writing the single graph output ``y``. ``"i"`` is not declared
    as an input of this graph; it is referenced by name from an enclosing
    scope.

    Args:
        ctx: graph used to look up the dtype of ``input_id``.
        input_id: tensor name to select from.
        data_shape: shape of ``input_id``; ``y`` is declared with
            ``tuple(data_shape[1:])``.

    Returns:
        onnx GraphProto named 'if-body-graph' with no inputs and output ``y``.
    """
    element_type = ctx.get_dtype(input_id)

    gather_name = utils.make_name("Gather")
    gather_out = port_name(gather_name)
    gather = helper.make_node("Gather", [input_id, "i"], [gather_out],
                              name=gather_name)

    squeeze_name = utils.make_name("Squeeze")
    squeeze_out = port_name(squeeze_name)
    squeeze = helper.make_node("Squeeze", [gather_out], [squeeze_out],
                               name=squeeze_name, axes=[0])

    to_output = helper.make_node('Identity', [squeeze_out], ['y'],
                                 name=utils.make_name("Identity"))

    # create one output
    output_info = helper.make_tensor_value_info('y', element_type,
                                                tuple(data_shape[1:]))

    return helper.make_graph([gather, squeeze, to_output], 'if-body-graph',
                             [], [output_info])
Exemplo n.º 4
0
def select_op8(ctx, node, name, args):
    """Rewrite a Select node as a Size + Loop + Identity chain.

    Emulates
        T output = Select(bool condition, T x, T y)
    with
        V v_final_and_scan_outputs = Loop(int64 M, B cond, V v_initial)
    where the trip count M is ``Size(condition)``. The Loop's second output
    is routed through an Identity that reuses ``node.name``, so the result
    appears as ``port_name(node.name)``.

    Args:
        ctx: graph being converted.
        node: the Select node; ``node.input[0]`` is the condition and
            ``node.input[1]`` determines the output dtype.
        name, args: unused; part of the handler signature.

    Returns:
        list of Node: the Size node, the nodes from ``create_loop_op``, and
        the final Identity.

    Raises:
        ValueError: if the Select has only its condition input.
    """
    if len(node.input) == 1:
        raise ValueError("Select with only condition is not supported.")
    result_dtype = ctx.get_dtype(node.input[1])

    size_name = utils.make_name("Size")
    size_node = Node(
        helper.make_node("Size", [node.input[0]], [port_name(size_name)],
                         name=size_name),
        ctx)

    produced = [size_node]
    produced.extend(create_loop_op(ctx, node, size_node.output[0],
                                   result_dtype))

    scan_output_id = produced[-1].output[1]
    ctx.copy_shape(node.output[0], scan_output_id)
    ctx.set_dtype(node.output[0], result_dtype)

    # Identity takes over node's name so downstream references stay valid.
    produced.append(
        Node(helper.make_node("Identity", [scan_output_id],
                              [port_name(node.name)], name=node.name), ctx))
    return produced
def sparse_softmax_cross_entropy_with_logits_op_by_gathernd(
        ctx, node, name, args):
    """Convert SparseSoftmaxCrossEntropyWithLogits using LogSoftmax + GatherND.

    With logits ``node.input[0]`` and rank-1 labels ``node.input[1]``, the
    emitted graph computes ``-LogSoftmax(logits)[k, labels[k]]`` per row k:

    1. cast labels to int64 if needed,
    2. build row ids ``0 .. Size(labels) - 1`` with ``make_range``,
    3. pair each row id with its label (Unsqueeze axis 1 + Concat axis 1),
    4. gather from ``LogSoftmax(logits)`` with ``make_gathernd``,
    5. multiply by -1 and Squeeze axis 1 into ``node.output[0]``.

    Args:
        ctx: graph being converted.
        node: the op being converted.
        name, args: unused; part of the handler signature.

    Returns:
        list of nodes created for the conversion. Constants made with
        ``ctx.make_const`` are not included in this list.

    Raises:
        ValueError: if the labels' shape is unknown or is not rank 1.
    """
    nodes = []
    # make subgraph to implement one_hot, idea comes from onehot_op
    indices_name = node.input[1]
    indices_shape = ctx.get_shape(indices_name)
    if indices_shape is None:
        # get_shape returns None for unknown shapes; fail explicitly instead
        # of letting len(None) raise an unrelated TypeError.
        raise ValueError(
            "sparse_softmax_cross_entropy_with_logits: shape of {} is unknown".format(
                indices_name))
    if len(indices_shape) != 1:
        # TODO: this works for rank=1 but tensorflow supports more than this.
        # Same principle should work but we need to implement our own eye.
        raise ValueError("onehot op: only rank1 is supported")
    logit_name = node.input[0]
    logit_dtype = ctx.get_dtype(logit_name)
    utils.make_sure(logit_dtype, "Dtype of {} is None".format(logit_name))
    indices_dtype = ctx.get_dtype(indices_name)
    if indices_dtype != TensorProto.INT64:
        indices_cast = ctx.make_node("Cast", [indices_name],
                                     attr={"to": TensorProto.INT64})
        nodes.append(indices_cast)
        indices_name = indices_cast.output[0]
    indices_size = ctx.make_node("Size", [indices_name])
    indices_unsqueeze = ctx.make_node("Unsqueeze", [indices_name],
                                      attr={"axes": [1]})
    zero_const = ctx.make_const(utils.make_name("zero"),
                                np.array(0, dtype=np.int64))
    one_const = ctx.make_const(utils.make_name("one"),
                               np.array(1, dtype=np.int64))
    # Row ids 0 .. Size(labels) - 1, one per label.
    id_name = utils.make_name("sparse_softmax_id")
    id_output = utils.port_name(id_name)
    nodes.extend(
        make_range(ctx, zero_const.output[0], indices_size.output[0],
                   one_const.output[0], id_output, id_name, TensorProto.INT64))
    id_unsqueeze = ctx.make_node("Unsqueeze", [id_output], attr={"axes": [1]})
    # Each row of indices_with_id is [row_id, label].
    indices_with_id = ctx.make_node(
        "Concat", [id_unsqueeze.output[0], indices_unsqueeze.output[0]],
        attr={"axis": 1})
    log_softmax = ctx.make_node(op_type="LogSoftmax",
                                inputs=[logit_name],
                                dtypes=[logit_dtype])
    gathernd_name = utils.make_name("sparse_softmax_gathernd")
    gathernd_output = utils.port_name(gathernd_name)
    nodes.extend(
        make_gathernd(ctx, log_softmax.output[0], indices_with_id.output[0],
                      gathernd_output, gathernd_name, logit_dtype))
    # Cross entropy is the negated log-probability of the true label.
    const_name = utils.make_name("const_negative_one")
    const_negative_one = ctx.make_const(
        const_name,
        np.array(-1).astype(utils.ONNX_TO_NUMPY_DTYPE[logit_dtype]))
    mul2 = ctx.make_node(
        op_type="Mul", inputs=[const_negative_one.output[0], gathernd_output])
    res = ctx.make_node(op_type="Squeeze",
                        inputs=[mul2.output[0]],
                        outputs=[node.output[0]],
                        attr={"axes": [1]})

    nodes.extend([
        indices_size, indices_unsqueeze, id_unsqueeze, indices_with_id,
        log_softmax, mul2, res
    ])
    return nodes
Exemplo n.º 6
0
def create_loop_body_graph(ctx, node, select_condition_input_id,
                           select_output_data_type):
    """Build the Loop body that evaluates Select for iteration ``i``.

    Declared inputs: ``i`` (INT64, shape (1,)), ``cond`` (BOOL scalar), and
    the placeholder loop-carried ``fake_var`` (FLOAT scalar). Each iteration
    gathers element ``i`` of the Select condition, squeezes axis 0, hands
    the result to an If built by ``create_if_op``, and emits three outputs:
    ``cond_output`` (``cond`` passed through), ``fake_var_output``
    (``fake_var`` passed through) and ``output`` (the If result, declared
    with ``get_hidden_size_best_effort(ctx, node)[1:]``).

    Returns:
        onnx GraphProto named "loop-body-graph".
    """
    body_inputs = [
        helper.make_tensor_value_info("i", TensorProto.INT64, (1, )),
        helper.make_tensor_value_info("cond", TensorProto.BOOL, ()),
        helper.make_tensor_value_info("fake_var", TensorProto.FLOAT, ()),
    ]

    # get the i'th value of "Select"'s condition
    gather_name = utils.make_name("Gather")
    gathered = port_name(gather_name)
    gather = helper.make_node("Gather", [select_condition_input_id, "i"],
                              [gathered], name=gather_name)

    squeeze_name = utils.make_name("Squeeze")
    scalar_cond = port_name(squeeze_name)
    squeeze = helper.make_node("Squeeze", [gathered], [scalar_cond],
                               name=squeeze_name, axes=[0])

    if_node, if_result = create_if_op(ctx, node, scalar_cond)

    body_nodes = [gather, squeeze, if_node]
    # Three pass-through Identities, created in this exact order.
    passthroughs = ((if_result, 'output'),
                    ("cond", 'cond_output'),
                    ("fake_var", 'fake_var_output'))
    for src, dst in passthroughs:
        body_nodes.append(
            helper.make_node('Identity', [src], [dst],
                             name=utils.make_name("Identity")))

    result_shape = get_hidden_size_best_effort(ctx, node)
    body_outputs = [
        helper.make_tensor_value_info("cond_output", TensorProto.BOOL, ()),
        helper.make_tensor_value_info("fake_var_output", TensorProto.FLOAT,
                                      ()),
        helper.make_tensor_value_info("output", select_output_data_type,
                                      result_shape[1:]),
    ]

    return helper.make_graph(body_nodes, "loop-body-graph", body_inputs,
                             body_outputs)
Exemplo n.º 7
0
    def pre_optimize_action(self):
        """Fold Reshape nodes whose data and shape inputs are both constant.

        Each such Reshape is evaluated with numpy and replaced by a Const made
        via ``self._g.make_const``. Every node in ``ops`` that read the
        Reshape's outputs is rewired to the Const. The graph's node list is
        refreshed and topologically sorted after each fold.
        """
        # make Reshape into a const, which then can be fused into Conv's weight for mobilenet_v1_75_192
        ops = self.nodes
        foldable = [
            op for op in ops
            if op.type == "Reshape" and op.inputs[0].is_const()
            and op.inputs[1].is_const()
        ]
        for reshape in foldable:
            data = reshape.inputs[0].get_tensor_value(as_list=False)
            new_shape = reshape.inputs[1].get_tensor_value(as_list=False)
            folded = np.reshape(data, tuple(new_shape))
            const_name = utils.port_name(utils.make_name("Const"))

            # point all children nodes inputs to the new node
            for stale in reshape.output:
                for consumer in ops:
                    for idx, inp in enumerate(consumer.input):
                        if inp == stale:
                            consumer.input[idx] = const_name
            ops.append(self._g.make_const(const_name, folded))

            # need call this to make input update synced to protobuf val
            self._g.update_proto()
            ops.remove(reshape)
            self._g.set_nodes(ops)
            self._g.topological_sort(ops)
def rewrite_random_normal(g, ops):
    """Replace TF random-normal subgraphs with ONNX RandomNormal(Like).

    Two patterns are matched:

    1. ``Add(Mul(RandomStandardNormal(*), Const), Const)``: mean and scale
       are taken from the Add's and Mul's second inputs.
    2. ``Identity(Identity(RandomStandardNormal(*)))``: mean 0.0, scale 1.0.

    The RandomStandardNormal's ``seed2`` attribute becomes ``seed``. If the
    shape input is a Shape node, RandomNormalLike is built on that Shape's
    input. Otherwise RandomNormal gets a static ``shape`` attribute, and the
    match is skipped when the output shape is unknown or contains -1.

    Returns:
        ops (the same list object that was passed in).
    """
    scaled_and_shifted = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', name='input2', inputs=[
                OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"]), "Const|ConstV2"
            ]), "Const|ConstV2"
        ])

    identity_chain = \
        OpTypePattern('Identity', name='output', inputs=[
            OpTypePattern('Identity', name='input2', inputs=[
                OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"])
            ])
        ])

    for pattern in (scaled_and_shifted, identity_chain):
        matches = list(GraphMatcher(pattern).match_ops(ops))
        for match in matches:
            out_op = match.get_op('output')
            mean = out_op.inputs[1].get_tensor_value() if out_op.type == 'Add' else 0.0
            mid_op = match.get_op('input2')
            scale = mid_op.inputs[1].get_tensor_value() if mid_op.type == 'Mul' else 1.0
            dtype = g.get_dtype(out_op.output[0])
            rn_name = utils.make_name("RandomNormal")
            rn_out = utils.port_name(rn_name)

            source_op = match.get_op('input1')
            attrs = {
                "mean": mean,
                "scale": scale,
                "dtype": dtype,
                "seed": float(source_op.get_attr('seed2').i),
            }

            shape_src = source_op.inputs[0]
            if shape_src.type == "Shape":
                replacement = g.make_node("RandomNormalLike",
                                          [shape_src.input[0]],
                                          outputs=[rn_out],
                                          name=rn_name,
                                          attr=attrs)
            else:
                static_shape = g.get_shape(out_op.output[0])
                if static_shape is None or -1 in static_shape:
                    continue
                attrs['shape'] = static_shape
                replacement = g.make_node("RandomNormal", [],
                                          outputs=[rn_out],
                                          name=rn_name,
                                          attr=attrs)

            g.replace_all_inputs(out_op.output[0], replacement.output[0], ops=ops)
            g.safe_remove_nodes(match.get_nodes())
    return ops
    def pre_optimize_action(self):
        """Fold Reshape nodes whose data and shape inputs are both initializers.

        Each such Reshape is evaluated with numpy, and the result is stored
        as a new initializer. Nodes in ``ops`` that read the Reshape's
        outputs are rewired to that initializer. The graph's node list is
        refreshed and topologically sorted after each fold.
        """
        # make Reshape into a const, which then can be fused into Conv's weight for mobilenet_v1_75_192
        ops = self.nodes
        graph = self._g
        foldable = [
            op for op in ops
            if op.type == "Reshape" and graph.is_initializer(op.input[0])
            and graph.is_initializer(op.input[1])
        ]
        for reshape in foldable:
            data = numpy_helper.to_array(graph.get_initializer(reshape.input[0]))
            new_shape = numpy_helper.to_array(graph.get_initializer(reshape.input[1]))
            folded = np.reshape(data, tuple(new_shape))
            const_name = utils.port_name(utils.make_name("Const"))
            folded_tensor = numpy_helper.from_array(folded, const_name)

            # point all children nodes inputs to the new node
            for stale in reshape.output:
                for consumer in ops:
                    for idx, inp in enumerate(consumer.input):
                        if inp == stale:
                            consumer.input[idx] = const_name
            graph.add_initializer(folded_tensor)
            # need call this to make input update synced to protobuf val
            graph.update_proto()
            ops.remove(reshape)
            graph.set_nodes(ops)
            graph.topological_sort(ops)
Exemplo n.º 10
0
 def test_rewrite_subgraph(self):
     """Replace every Abs(Add) match in sample_net with a single Sub node.

     For each match, a Sub is created on the Add's inputs. Consumers of the
     Abs output are redirected to the Sub, and the matched nodes are
     removed. The resulting graph is then compared with an expected
     Graphviz dump.
     """
     graph_proto = self.sample_net()
     g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
     pattern = \
         OpTypePattern('Abs', name='output', inputs=[
             OpTypePattern('Add', name='input')
         ])
     # Node list fetched once, before any rewrite happens.
     ops = g.get_nodes()
     matcher = GraphMatcher(pattern)
     match_results = list(matcher.match_ops(ops))
     for match in match_results:
         input_node = match.get_op('input')
         output_node = match.get_op('output')
         op_name = utils.make_name("ReplacedOp")
         out_name = utils.port_name(op_name)
         # The Sub takes over the Add's inputs directly.
         new_node = g.make_node("Sub", inputs=input_node.input, outputs=[out_name], name=op_name)
         g.replace_all_inputs(output_node.output[0], new_node.output[0])  # ops=ops
         # set() ensures each matched node is removed only once.
         for n in set(match.get_nodes()):
             g.remove_node(n.name)
     # NOTE(review): `ops` was fetched before the rewrite; depending on whether
     # get_nodes() returns a live list or a copy it may not reflect the added
     # Sub / removed nodes -- confirm this is the intended argument.
     g.topological_sort(ops)
     result = onnx_to_graphviz(g)
     # The "__N" suffixes come from utils.make_name counters, so this string
     # presumably depends on how many names were generated earlier -- TODO confirm.
     expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] ' \
                'n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__5 [op_type=Sub] ' \
                'n6 [op_type=Identity] n5_graph_outputs_Identity__3 [op_type=Identity] ' \
                'input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__5 n3:0 -> ReplacedOp__5 ' \
                'ReplacedOp__5:0 -> n6 ReplacedOp__5:0 -> n5_graph_outputs_Identity__3 }'
     self.assertEqual(expected, result)
Exemplo n.º 11
0
    def _switch_transpose_and_node(self, node, trans):
        """Move Transpose ``trans`` from in front of ``node`` to behind it.

        Before: ``trans -> node -> consumers``.
        After:  ``node -> trans -> consumers``.

        ``node`` is rewired to read ``trans``'s original input, and ``trans``
        is rewired to read ``node.output[0]``. Every former consumer of
        ``node.output[0]`` reads ``trans.output[0]`` instead. If
        ``node.output[0]`` has a known shape, the shape is permuted with
        ``[0, 3, 1, 2]``, because the tensor now holds the un-transposed data.

        Args:
            node: node that consumes ``trans.output[0]``.
            trans: Transpose node feeding ``node``.

        Returns:
            False, with the graph untouched, if
            ``_transpose_has_single_consumer_node`` rejects ``trans``;
            otherwise True.

        Raises:
            ValueError: if ``trans.output[0]`` is not an input of ``node``
                (raised before the graph is modified).
        """
        if not self._transpose_has_single_consumer_node([trans]):
            return False

        # Locate the input slot before mutating anything, so a mismatched
        # (node, trans) pair fails cleanly instead of half-rewiring the graph.
        input_index = list(node.input).index(trans.output[0])

        ops = self._g.get_nodes()
        self._g.replace_all_inputs(ops, node.output[0], trans.output[0])
        node.input[input_index] = trans.input[0]
        # Use the node's real output tensor. Outputs may have been renamed
        # (e.g. to "raw_output_" ports), so port_name(node.name) is not
        # guaranteed to name it.
        trans.input[0] = node.output[0]

        # need to transpose node shape in backward direction as well after switch
        # otherwise, reshape added in post_optimize_action may not work correctly
        shape = self._g.get_shape(node.output[0])
        if shape:
            # only nhwc transpose can reach here
            new_shape = [shape[i] for i in [0, 3, 1, 2]]
            self._g.set_shape(node.output[0], new_shape)

        self._g.set_nodes(ops)
        return True
Exemplo n.º 12
0
    def test_tensor_data(self):
        """Round-trip float32 TF constants through tf_to_onnx_tensor.

        Covers an empty tensor, a 2-D empty tensor, a scalar, a one-item
        array and a 2x2 array. Each TF constant is converted to an ONNX
        tensor, wrapped in an attribute, read back with numpy_helper and
        compared with the original numpy data.
        """
        tensors = {
            "empty_tensor": np.array([], dtype=np.float32),
            "multi_dim_empty_tensor": np.array([[], []], dtype=np.float32),
            "scalar": np.array(1., dtype=np.float32),
            "one_item_array": np.array([1.], dtype=np.float32),
            "normal_array": np.array([[1., 2.], [2., 3.]], dtype=np.float32)
        }
        tf_reset_default_graph()
        with tf_session() as sess:
            for n, data in tensors.items():
                tf.constant(data, dtype=tf.float32, name=n)

        checked = set()
        for tf_node in sess.graph.get_operations():
            name = tf_node.name
            self.assertIn(name, tensors)
            checked.add(name)

            self.assertIn("value", tf_node.node_def.attr)
            # convert to onnx tensor value
            tensor_value = tf_utils.tf_to_onnx_tensor(
                tf_utils.get_tf_node_attr(tf_node, "value"),
                name=utils.port_name(tf_node.name)
            )
            attr = helper.make_attribute("value", tensor_value)
            # same as node.get_tensor_value(as_list=False)
            actual = numpy_helper.to_array(helper.get_attribute_value(attr))

            expected = tensors[name]

            self.assertTrue(np.array_equal(expected, actual))

        # Every constant must have been checked; otherwise the loop above
        # could pass vacuously (e.g. if the graph were empty).
        self.assertEqual(set(tensors), checked)
Exemplo n.º 13
0
    def insert_new_node_on_output(self,
                                  op_type,
                                  output_name,
                                  name,
                                  domain=None,
                                  **kwargs):
        """Insert a new single-input node directly after tensor ``output_name``.

        The new node reads ``output_name`` and writes ``port_name(name)``.
        Every other node in the graph that read ``output_name`` is rewired
        to read the new output instead.

        Args:
            op_type: type of the new operation (must be ``six.text_type``).
            output_name: tensor the new node attaches to (must be
                ``six.text_type``).
            name: name of the new op.
            domain: optional op domain for the new node.
            kwargs: attributes of the new node.

        Returns:
            node that was inserted
        """
        type_checks = (
            (output_name, "output_name's type is not expected: %s"),
            (op_type, "op_type's type is not expected: %s"),
        )
        for value, message in type_checks:
            utils.make_sure(isinstance(value, six.text_type), message,
                            type(value))

        inserted_output = port_name(name)
        inserted = self.make_node(op_type, [output_name],
                                  attr=kwargs,
                                  outputs=[inserted_output],
                                  name=name,
                                  domain=domain)

        # Everyone except the new node itself switches to the new output.
        others = [n for n in self.get_nodes() if n != inserted]
        self.replace_all_inputs(others, output_name, inserted_output)
        return inserted
Exemplo n.º 14
0
def rewrite_dropout(g, ops):
    """Collapse the TF dropout subgraph, then drop the resulting Dropout.

    Matches ``Mul(RealDiv(x, ...), Floor(Add(input3, RandomUniform*)))`` and
    replaces it with ``Dropout(x)`` carrying ``ratio=1.0``. Afterwards,
    every Dropout whose ``ratio`` attribute equals 1.0 is removed, and its
    consumers are wired directly to its data input.

    Args:
        g: graph rewritten in place.
        ops: node list to match against.

    Returns:
        ops (the list passed in; it is not refreshed after the rewrite).
    """
    pattern = \
        OpTypePattern('Mul', name='outputs', inputs=[
            OpTypePattern('RealDiv', name="input2"),
            OpTypePattern('Floor', inputs=[
                OpTypePattern('Add', inputs=[
                    OpTypePattern(None, name="input3"),
                    OpTypePattern('RandomUniform|RandomUniformLike'),
                ])
            ]),
        ])
    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        inputs2 = match.get_op('input2')
        outputs = match.get_op('outputs')
        op_name = utils.make_name("Dropout")
        out_name = port_name(op_name)
        # ratio=1.0 marks this Dropout for removal by the loop below.
        new_node = g.make_node("Dropout", [inputs2.input[0]],
                               outputs=[out_name],
                               name=op_name,
                               attr={"ratio": 1.0},
                               shapes=[g.get_shape(inputs2.input[0])],
                               dtypes=[g.get_dtype(inputs2.input[0])])
        g.replace_all_inputs(ops, outputs.output[0], new_node.output[0])
        g.safe_remove_nodes(match.get_nodes())

    # remove dropout if its ratio is 1.0
    # Iterate over a snapshot: remove_node changes the graph's node list, and
    # mutating a list while iterating it silently skips elements.
    for node in list(g.get_nodes()):
        if node.type != "Dropout":
            continue
        ratio = node.get_attr("ratio")
        # A Dropout may carry no "ratio" attribute (presumably when the ratio
        # is supplied as an input); leave such nodes alone.
        if ratio is None or ratio.f != 1.0:
            continue
        g.replace_all_inputs(g.get_nodes(), node.output[0], node.input[0])
        g.remove_node(node.name)

    return ops
Exemplo n.º 15
0
def rewrite_random_normal(g, ops):
    """Replace ``Add(Mul(RandomStandardNormal(shape), scale), mean)`` with an
    ONNX RandomNormal / RandomNormalLike node.

    ``mean`` is read from the Add's second input and ``scale`` from the Mul's
    second input. Multiplying a standard normal by ``scale`` and adding
    ``mean`` gives N(mean, scale**2), so both must be carried over. A match
    is left untouched in these cases:
      * either value is not a constant (the values are baked into node
        attributes), or
      * a static shape is required but the output shape is unknown or
        contains -1.

    Args:
        g: graph rewritten in place.
        ops: node list to match against.

    Returns:
        ops (the list passed in).
    """
    pattern = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', name='input2', inputs=[
                OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"]), "*"
            ]), "*"
        ])

    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        output = match.get_op('output')
        mul = match.get_op('input2')
        mean_node = output.inputs[1]
        scale_node = mul.inputs[1]
        # The pattern accepts any op ("*") here; only constants can become attributes.
        if not (mean_node.is_const() and scale_node.is_const()):
            continue
        mean = mean_node.get_tensor_value()
        # The Mul operand is the standard deviation; it must not be dropped.
        scale = scale_node.get_tensor_value()
        dtype = g.get_dtype(output.output[0])
        op_name = utils.make_name("RandomNormal")
        out_name = port_name(op_name)

        rn_op = match.get_op('input1')
        attr = {"mean": mean, "scale": scale, "dtype": dtype}
        if rn_op.inputs[0].type == "Shape":
            shape_node = rn_op.inputs[0]
            new_node = g.make_node("RandomNormalLike", [shape_node.input[0]], outputs=[out_name], name=op_name,
                                   attr=attr)
        else:
            shape = g.get_shape(output.output[0])
            # RandomNormal needs a fully known static shape attribute.
            if shape is None or -1 in shape:
                continue
            attr["shape"] = shape
            new_node = g.make_node("RandomNormal", [], outputs=[out_name], name=op_name,
                                   attr=attr)

        g.replace_all_inputs(ops, output.output[0], new_node.output[0])
        for n in set(match.get_nodes()):
            g.remove_node(n.name)
    return ops
Exemplo n.º 16
0
    def insert_new_node_on_input(self,
                                 node,
                                 op_type,
                                 input_name,
                                 name=None,
                                 domain=None,
                                 **kwargs):
        """Insert a new single-input node in front of one input of ``node``.

        The new node reads ``input_name`` and writes ``port_name(name)``.
        Only the first slot of ``node.input`` equal to ``input_name`` is
        redirected to the new output. Other consumers of ``input_name`` keep
        reading it.

        Args:
            node: we want to replace the input for this node
            op_type: type for new operation
            input_name: tensor currently feeding ``node``
            name: name of the new op; defaults to ``utils.make_name(node.name)``
            domain: optional op domain for the new node
            kwargs: attributes of the new node

        Returns:
            node that was inserted
        """
        new_name = utils.make_name(node.name) if name is None else name
        new_output = port_name(new_name)
        new_node = self.make_node(op_type, [input_name],
                                  attr=kwargs,
                                  outputs=[new_output],
                                  name=new_name,
                                  domain=domain)

        slot = next((idx for idx, inp in enumerate(node.input)
                     if inp == input_name), None)
        if slot is not None:
            node.input[slot] = new_output
        return new_node
Exemplo n.º 17
0
    def replace_subgraph(ops, subgraph_nodes, old_inputs, old_outputs,
                         new_inputs, new_outputs):
        """Replace subgraph.

        For each pair ``(old_outputs[k], new_outputs[k])``, every input in
        ``ops`` naming one of the old node's outputs is rewritten to
        ``port_name(new_outputs[k].name)``. The truthy nodes of
        ``subgraph_nodes.get_nodes()`` are then removed from ``ops``, each at
        most once, and ``new_outputs`` are appended. ``ops`` is modified in
        place and returned. ``old_inputs``/``new_inputs`` are only checked
        for equal length.

        Raises:
            ValueError: if the input lists or the output lists differ in length.
        """
        same_inputs = len(old_inputs) == len(new_inputs)
        same_outputs = len(old_outputs) == len(new_outputs)
        if not (same_inputs and same_outputs):
            raise ValueError(
                "replace_subgraph - inputs and outputs need to be same length")

        # point all children nodes inputs to the new node
        for old_node, new_node in zip(old_outputs, new_outputs):
            for stale in old_node.output:
                for child in ops:
                    for idx, inp in enumerate(child.input):
                        if inp == stale:
                            child.input[idx] = port_name(new_node.name)

        # delete nodes no longer used
        already_removed = set()
        for doomed in subgraph_nodes.get_nodes():
            if doomed and doomed not in already_removed:
                ops.remove(doomed)
                already_removed.add(doomed)
        ops.extend(new_outputs)
        return ops
Exemplo n.º 18
0
    def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=None, extra_opset=None,
                 output_names=None):
        """Create Graph.

        Wraps every entry of ``nodes`` as ``Node(node, self)`` and registers
        them with ``reset_nodes``. Then, for each graph output name ``o``,
        the producer of ``o`` is re-created with that output renamed to a
        fresh "raw_output_" port, and an Identity is added that writes ``o``
        from the renamed tensor. This keeps the graph output name stable
        even if the producer is renamed during conversion.

        Args:
            nodes: node definitions, each passed to Node(node, self)
            output_shapes: dict of tensorflow output shapes
            dtypes: dict of tensorflow dtype
            target: conversion targets, stored as a set (defaults to empty)
            opset: requested opset, resolved through find_opset
            extra_opset: optional additional opsets; must be a list if given
            output_names: names of the graph outputs (defaults to [])
        """
        if target is None:
            target = []
        self._nodes = []
        self._nodes_by_name = {}
        self._output_to_node_name = {}
        self.shapes = {}

        self._target = set(target)
        self._dtypes = dtypes

        self._output_shapes = output_shapes
        self._opset = find_opset(opset)

        if extra_opset is not None:
            utils.make_sure(isinstance(extra_opset, list), "invalid extra_opset")
        self._extra_opset = extra_opset

        self._order_sensitive_inputs = []
        self.outputs = output_names if output_names is not None else []

        self.parent_graph = None
        self.contained_graphs = {}  # {node_name: {node_attribute_name: Graph}}

        ops = [Node(node, self) for node in nodes]
        self.reset_nodes(ops)

        # add identity node after each output, in case it is renamed during conversion.
        for o in self.outputs:
            # NOTE(review): assumes every output name has a producer here; a
            # missing producer would fail on n.name below.
            n = self.get_node_by_output_in_current_graph(o)
            new_output_name = port_name(n.name + "_" + utils.make_name("raw_output_"))
            # Capture everything needed to rebuild n before removing it.
            n_shapes = n.output_shapes
            n_dtypes = n.output_dtypes
            body_graphs = n.graph.contained_graphs.pop(n.name, None)
            self.remove_node(n.name)

            # Same outputs as n, except o is replaced by the renamed tensor.
            new_outputs = [output if output != o else new_output_name for output in n.output]
            # domain should be passed to new node
            new_node = self.make_node(n.type, n.input, outputs=new_outputs, attr=n.attr, name=n.name,
                                      skip_conversion=n._skip_conversion, dtypes=n_dtypes, shapes=n_shapes,
                                      domain=n.domain)

            # Re-attach the body graphs popped above, now parented to this graph.
            if body_graphs:
                for attr_name, body_graph in body_graphs.items():
                    body_graph.parent_graph = self
                    new_node.set_body_graph_as_attr(attr_name, body_graph)

            # Existing consumers of o now read the renamed tensor; the Identity
            # created next re-exposes it under the original output name o.
            self.replace_all_inputs(self.get_nodes(), o, new_output_name)
            self.make_node("Identity", [new_output_name], outputs=[o], op_name_scope=n.name + "_" + "graph_outputs")
            self.copy_shape(new_output_name, o)
            self.copy_dtype(new_output_name, o)
Exemplo n.º 19
0
    def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=None, extra_opset=None,
                 output_names=None):
        """Create Graph.

        Wraps every entry of ``nodes`` as ``Node(node, self)``. When
        ``output_names`` is non-empty, each node output that is a graph
        output is renamed to a fresh "raw_output_" port, and an Identity is
        added that writes the original name from it. This keeps graph output
        names stable even if producers are renamed during conversion.
        Finally the node list is installed with ``set_nodes``.

        Args:
            nodes: node definitions, each passed to Node(node, self)
            output_shapes: dict of tensorflow output shapes
            dtypes: dict of tensorflow dtype
            target: conversion targets, stored as a set (defaults to empty)
            opset: requested opset, resolved through find_opset
            extra_opset: additional opsets, stored as given
            output_names: names of the graph outputs; may be None
        """
        if target is None:
            target = []
        self._nodes = []
        self._initializers = {}
        self._nodes_by_name = {}
        self._output_to_node_name = {}
        self.shapes = {}

        self._target = set(target)
        self._dtypes = dtypes

        self._output_shapes = output_shapes
        self._opset = find_opset(opset)
        self._extra_opset = extra_opset

        self.inputs = []
        self.outputs = output_names

        self.parent_graph = None
        self.contained_graphs = {}  # {node_name: {node_attribute_name: Graph}}

        ops = [Node(node, self) for node in nodes]

        # add identity node after each output, in case it is renamed during conversion.
        if self.outputs:
            # Identity nodes are collected here and only added to ops after
            # the scan, so they are not rewired by replace_all_inputs below.
            to_append = []
            for n in ops:
                raw_outputs = n.output
                # One fresh base name per node, shared by all of its renamed
                # outputs; index_out numbers the renamed ports 0, 1, ...
                new_output_base_name = None
                index_out = 0
                for i, o in enumerate(raw_outputs):
                    if o in output_names:
                        if not new_output_base_name:
                            new_output_base_name = utils.make_name("raw_output_")
                        new_out = port_name(new_output_base_name, index_out)
                        # Consumers in ops read the renamed tensor from now on.
                        self.replace_all_inputs(ops, o, new_out)
                        n.output[i] = new_out
                        index_out += 1
                        # Identity re-exposes the tensor under the graph output name o.
                        new_output_node = self.make_node("Identity", [new_out], outputs=[o])
                        to_append.append(new_output_node)

                        # Carry shape/dtype known under o over to the renamed tensor.
                        self.copy_shape(o, new_out)
                        self.set_dtype(new_out, self.get_dtype(o))

                # NOTE(review): only reached when outputs were given; presumably
                # set_nodes below covers registration otherwise -- confirm.
                self.set_node_by_name(n)
            ops.extend(to_append)

        self.set_nodes(ops)
Exemplo n.º 20
0
def create_loop_op(gather_input_ids, output_type, output_shape,
                   trip_count_input_ids, rank):
    """Create a Loop node (plus its two Constant inputs) for one dimension.

    The trip count is ``trip_count_input_ids[-rank]``. The initial condition
    is a constant True, and a placeholder FLOAT loop-carried variable is
    added because Loop needs at least one. The body graph comes from
    ``create_loop_body_graph``.

    Returns:
        list of NodeProto: [condition Constant, fake_var Constant, Loop].
        The Loop's outputs are ``port_name(loop)`` (the final fake variable)
        and ``port_name(loop, 1)``.

    Raises:
        ValueError: if ``rank`` < 1.
    """
    def _scalar_constant(base_name, proto_type, value):
        # Returns (unique name, Constant node producing a scalar under that name).
        const_name = utils.make_name(base_name)
        tensor = helper.make_tensor(const_name, proto_type, (), [value])
        constant = helper.make_node("Constant", [], [const_name],
                                    value=tensor,
                                    name=const_name)
        return const_name, constant

    cond_name, cond_node = _scalar_constant("condition", TensorProto.BOOL,
                                            True)
    # Loop requires at least a variable, add a useless fake variable.
    fake_name, fake_node = _scalar_constant("fake_var", TensorProto.FLOAT,
                                            0.0)

    if rank < 1:
        raise ValueError("rank is < 1")
    trip_count = trip_count_input_ids[-rank]

    loop_name = utils.make_name("loop")
    loop_outputs = [port_name(loop_name), port_name(loop_name, 1)]

    body = create_loop_body_graph(gather_input_ids, output_type,
                                  output_shape, trip_count_input_ids,
                                  rank, loop_name)
    loop_node = helper.make_node("Loop",
                                 [trip_count, cond_name, fake_name],
                                 loop_outputs,
                                 name=loop_name,
                                 body=body)
    return [cond_node, fake_node, loop_node]
Exemplo n.º 21
0
def get_inputs_for_current_iteration(input_id, iter_index):
    """Pick the ``iter_index``-th slice of ``input_id`` and drop its leading axis.

    Emits a ``Gather`` on ``input_id`` with ``iter_index`` as indices,
    followed by a ``Squeeze`` with ``axes=[0]``.

    Returns:
        (nodes, out_name): the two NodeProtos in creation order and the
        output name of the ``Squeeze``.
    """
    gather_name = utils.make_name("Gather")
    gather_out = port_name(gather_name)
    gather = helper.make_node("Gather", [input_id, iter_index], [gather_out],
                              name=gather_name)

    squeeze_name = utils.make_name("Squeeze")
    squeeze_out = port_name(squeeze_name)
    squeeze = helper.make_node("Squeeze", [gather_out], [squeeze_out],
                               name=squeeze_name, axes=[0])

    return [gather, squeeze], squeeze_out
Exemplo n.º 22
0
def create_if_op(g, input_ids, output_data_type, output_shape):
    """Create an ``If`` node on graph ``g``.

    ``input_ids[0]`` is the condition. The then/else branch bodies are built by
    ``create_body_graph_for_if_branch`` from ``input_ids[1]`` and
    ``input_ids[2]``, and passed to ``g.make_node`` via ``branches``.

    Returns:
        (if_node, out_name) where ``out_name`` is the node's single output.
    """
    op_name = utils.make_name("If")
    # Dict literal entries are evaluated in order: then-branch first.
    branches = {
        "then_branch": create_body_graph_for_if_branch(
            g, output_data_type, output_shape, input_ids[1], op_name),
        "else_branch": create_body_graph_for_if_branch(
            g, output_data_type, output_shape, input_ids[2], op_name),
    }
    out_name = utils.port_name(op_name)

    if_node = g.make_node("If", [input_ids[0]], outputs=[out_name],
                          name=op_name, skip_conversion=True,
                          branches=branches)
    return if_node, out_name
Exemplo n.º 23
0
def create_if_op(g, input_ids, output_data_type, output_shape):
    """Create an ``If`` node on graph ``g`` and attach its branch graphs.

    ``input_ids[0]`` is the condition. The then/else bodies come from
    ``create_body_graph_for_if_branch`` on ``input_ids[1]`` / ``input_ids[2]``
    and are attached with ``set_body_graph_as_attr``.

    Returns:
        (if_node, out_name) where ``out_name`` is the node's single output.
    """
    op_name = utils.make_name("If")
    then_graph = create_body_graph_for_if_branch(
        g, output_data_type, output_shape, input_ids[1], op_name)
    else_graph = create_body_graph_for_if_branch(
        g, output_data_type, output_shape, input_ids[2], op_name)
    out_name = utils.port_name(op_name)

    if_node = g.make_node("If", [input_ids[0]], outputs=[out_name],
                          name=op_name, skip_conversion=False)
    for attr_name, body in (("then_branch", then_graph),
                            ("else_branch", else_graph)):
        if_node.set_body_graph_as_attr(attr_name, body)
    return if_node, out_name
Exemplo n.º 24
0
 def _add_handler(self, trans, node):
     """Try to fold ``node``'s constant second input into the conv feeding ``trans``.

     When ``node.input[1]`` is an initializer and the producer of ``trans`` is a
     Conv/ConvTranspose with exactly two inputs (no bias yet), a new node of the
     same type and attributes is created with ``node.input[1]`` as third input.
     ``trans`` is re-pointed at it, readers of ``node.output[0]`` are redirected
     to ``trans.output[0]``, and the new/old nodes are handed to
     ``_update_graph_nodes``.

     Returns:
         True when folded; False when the input is constant but cannot be
         folded; otherwise the result of ``_handle_node_having_branches(node)``.
     """
     if not self._g.is_initializer(node.input[1]):
         return self._handle_node_having_branches(node)

     producer = trans.inputs[0]
     # Only fold when the conv has no bias input yet (data + weights only).
     # TODO: when a bias is already present, it could be summed with this constant.
     if producer.type not in ("Conv", "ConvTranspose") or len(producer.input) != 2:
         return False

     biased_inputs = [producer.input[0], producer.input[1], node.input[1]]
     biased_conv = self._g.make_node(producer.type, biased_inputs, attr=producer.attr_onnx)
     all_nodes = self._g.get_nodes()
     trans.input[0] = utils.port_name(biased_conv.name)
     self._g.replace_all_inputs(all_nodes, node.output[0], trans.output[0])
     self._update_graph_nodes([biased_conv], [producer, node], True)
     return True
Exemplo n.º 25
0
def create_if_op(ctx, node, cur_cond_val_out_name):
    """Build an ONNX ``If`` NodeProto choosing between ``node.input[1]`` and ``node.input[2]``.

    ``cur_cond_val_out_name`` is the condition. Each branch body is built by
    ``create_body_graph_for_if_branch`` and also registered on ``ctx`` via
    ``add_body_graph``.

    Returns:
        (if_node, out_name) where ``out_name`` is the If node's output.
    """
    then_graph = create_body_graph_for_if_branch(ctx, node.input[1])
    else_graph = create_body_graph_for_if_branch(ctx, node.input[2])

    op_name = utils.make_name("If")
    out_name = port_name(op_name)

    if_node = helper.make_node("If", [cur_cond_val_out_name], [out_name],
                               name=op_name,
                               then_branch=then_graph,
                               else_branch=else_graph)
    # NOTE(review): both branches are registered under the same ``out_name``
    # key. If ``add_body_graph`` keeps one graph per key, the then-branch entry
    # is overwritten by the else-branch one. Confirm against add_body_graph.
    for body in (then_graph, else_graph):
        ctx.add_body_graph(out_name, body)
    return if_node, out_name
Exemplo n.º 26
0
 def follow_inputs(self, node, num, space=""):
     """Debug helper: walk ``node``'s producers up to ``num`` levels deep.

     Each visited node contributes one line of indent, op type, name and the
     shape looked up via ``self.get_shape(port_name(node.name))``. Nested calls
     return their lines. The outermost call (``space == ""``) prints them in
     reverse order followed by a blank line, then returns an empty list.
     """
     if num == 0:
         return []
     is_outermost = space == ""
     lines = ["{}{} {} {}".format(space, node.type, node.name, self.get_shape(port_name(node.name)))]
     child_space = space + "    "
     for producer in node.inputs:
         lines += self.follow_inputs(producer, num - 1, child_space)
     if not is_outermost:
         return lines
     print("\n".join(reversed(lines)))
     print()
     return []
    def _create_transpose_pairs_before_node(self, node):
        """Wrap each non-NHWC-transpose input of ``node`` in a transpose pair.

        For every input whose producer fails ``is_nhwc_transpose``, insert
        Transpose(perm=[0, 3, 1, 2]) followed by Transpose(perm=[0, 2, 3, 1]).
        ``node`` is rewired with ``self._g.replace_input`` to read the second
        transpose's output. A (input id, producer) pair that appears several
        times on ``node`` is wrapped only once.

        Returns:
            list of the created Node objects (two per wrapped input), which are
            also passed to ``_update_graph_nodes`` when non-empty.
        """
        pending = []
        for input_id, producer in zip(node.input, node.inputs):
            if is_nhwc_transpose(producer):
                continue
            # node may consume the same producer output more than once; wrap it once
            if [input_id, producer] in pending:
                continue
            pending.append([input_id, producer])

        created = []
        for input_id, _producer in pending:
            to_nchw_name = utils.make_name("Transpose")
            to_nchw_out = utils.port_name(to_nchw_name)
            to_nchw = helper.make_node("Transpose", [input_id], [to_nchw_out],
                                       name=to_nchw_name,
                                       perm=[0, 3, 1, 2])

            to_nhwc_name = utils.make_name("Transpose")
            to_nhwc_out = utils.port_name(to_nhwc_name)
            to_nhwc = helper.make_node("Transpose", [to_nchw_out], [to_nhwc_out],
                                       name=to_nhwc_name,
                                       perm=[0, 2, 3, 1])

            pair = [Node(to_nchw, self._g), Node(to_nhwc, self._g)]
            self._g.replace_input(node, input_id, to_nhwc_out)
            created += pair

        if created:
            self._update_graph_nodes(created, None, True)
        return created
Exemplo n.º 28
0
    def insert_new_node_on_output(self, op_type, output_name, name=None, **kwargs):
        """Create a node that reads ``output_name`` and redirect its consumers to it.

        Every input equal to ``output_name`` among ``self.get_nodes()`` is
        replaced with the new node's output. The new node therefore sits
        between the original producer and its former consumers.

        Args:
            op_type: op type of the new node (text string).
            output_name: the single tensor name the new node consumes (text string).
            name: name of the new node. When omitted, a unique name is
                generated from ``op_type``.
            kwargs: attributes of the new node.

        Returns:
            the newly created Node.
        """
        assert isinstance(output_name, six.text_type) and isinstance(op_type, six.text_type)
        if name is None:
            # Without this, port_name(None) would give every unnamed insertion
            # the same, meaningless output name.
            name = utils.make_name(op_type)
        new_output = port_name(name)
        new_node = Node(helper.make_node(op_type, [output_name], [new_output], name=name, **kwargs), self)
        # NOTE(review): assumes get_nodes() does not yet contain new_node. Otherwise
        # its own input would be rewired to its own output. Confirm Node() does
        # not register itself with the graph.
        self.replace_all_inputs(self.get_nodes(), output_name, new_output)
        return new_node
Exemplo n.º 29
0
def create_if_op(ctx, node, cur_cond_val_out_name):
    """Build an ONNX ``If`` NodeProto choosing between ``node.input[1]`` and ``node.input[2]``.

    Both branch bodies are built by ``create_body_graph_for_if_branch`` with the
    shape returned by ``get_hidden_size_best_effort(ctx, node)``.
    ``cur_cond_val_out_name`` is the condition.

    Returns:
        (if_node, out_name) where ``out_name`` is the If node's output.
    """
    data_shape = get_hidden_size_best_effort(ctx, node)
    then_graph = create_body_graph_for_if_branch(ctx, node.input[1], data_shape)
    else_graph = create_body_graph_for_if_branch(ctx, node.input[2], data_shape)

    op_name = utils.make_name("If")
    out_name = port_name(op_name)
    branch_attrs = {"then_branch": then_graph, "else_branch": else_graph}
    if_node = helper.make_node("If", [cur_cond_val_out_name], [out_name],
                               name=op_name, **branch_attrs)
    return if_node, out_name
Exemplo n.º 30
0
def create_if_op(input_ids, output_data_type, output_shape):
    """Build an ONNX ``If`` NodeProto.

    ``input_ids[0]`` is the condition. The then/else bodies are built by
    ``create_body_graph_for_if_branch`` from ``input_ids[1]`` and
    ``input_ids[2]``.

    Returns:
        (if_node, out_name) where ``out_name`` is the If node's output.
    """
    op_name = utils.make_name("If")
    then_graph = create_body_graph_for_if_branch(
        output_data_type, output_shape, input_ids[1], op_name)
    else_graph = create_body_graph_for_if_branch(
        output_data_type, output_shape, input_ids[2], op_name)
    out_name = port_name(op_name)

    cond_id = input_ids[0]
    if_node = helper.make_node("If", [cond_id], [out_name], name=op_name,
                               then_branch=then_graph, else_branch=else_graph)
    return if_node, out_name