Example #1
0
    def test_tensor_data(self):
        """Check that TF constant values survive conversion to ONNX tensors.

        A float32 constant is created for each sample array (empty,
        multi-dimensional empty, scalar, single item, regular 2-D); every
        operation in the resulting graph is then converted and compared
        against the array it was built from.
        """
        expected_by_name = {
            "empty_tensor": np.array([], dtype=np.float32),
            "multi_dim_empty_tensor": np.array([[], []], dtype=np.float32),
            "scalar": np.array(1., dtype=np.float32),
            "one_item_array": np.array([1.], dtype=np.float32),
            "normal_array": np.array([[1., 2.], [2., 3.]], dtype=np.float32),
        }
        tf_reset_default_graph()
        with tf_session() as sess:
            for const_name, const_value in expected_by_name.items():
                tf.constant(const_value, dtype=tf.float32, name=const_name)

        for op in sess.graph.get_operations():
            op_name = op.name
            # Only the constants created above are expected in the graph.
            self.assertTrue(op_name in expected_by_name)
            self.assertTrue("value" in op.node_def.attr)

            # Convert the TF "value" attribute to an ONNX tensor and wrap it
            # in an ONNX attribute, then read it back as a numpy array.
            onnx_tensor = tf_utils.tf_to_onnx_tensor(
                tf_utils.get_tf_node_attr(op, "value"),
                name=utils.port_name(op_name))
            value_attr = helper.make_attribute("value", onnx_tensor)
            # same as node.get_tensor_value(is_list=False)
            actual = numpy_helper.to_array(
                helper.get_attribute_value(value_attr))

            self.assertTrue(np.array_equal(expected_by_name[op_name], actual))
Example #2
0
 def test_rewrite_subgraph(self):
     """Replace each Abs node fed by an Add with a Sub node and compare
     the graphviz rendering of the rewritten sample net."""
     g = GraphUtil.create_graph_from_onnx_graph(self.sample_net())
     abs_of_add = OpTypePattern(
         'Abs', name='output', inputs=[OpTypePattern('Add', name='input')])
     ops = g.get_nodes()
     for match in list(GraphMatcher(abs_of_add).match_ops(ops)):
         add_node = match.get_op('input')
         abs_node = match.get_op('output')
         # The generated name carries a counter suffix that the expected
         # string below relies on, so this must stay the only make_name call.
         sub_name = utils.make_name("ReplacedOp")
         sub_node = g.make_node("Sub",
                                inputs=add_node.input,
                                outputs=[utils.port_name(sub_name)],
                                name=sub_name)
         # Consumers of the Abs output now read from the new Sub node.
         g.replace_all_inputs(abs_node.output[0],
                              sub_node.output[0])  # ops=ops
         for matched in set(match.get_nodes()):
             g.remove_node(matched.name)
     g.topological_sort(ops)
     actual = onnx_to_graphviz(g)
     expected = (
         'digraph { Placeholder__5 [op_type=Placeholder] n1 [op_type=Abs] '
         'n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__6 [op_type=Sub] '
         'n6 [op_type=Identity] n5_graph_outputs_Identity__4 [op_type=Identity] '
         'input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__6 n3:0 -> ReplacedOp__6 '
         'ReplacedOp__6:0 -> n6 ReplacedOp__6:0 -> n5_graph_outputs_Identity__4 }'
     )
     self.assertEqual(expected, actual)
Example #3
0
def tflist_to_onnx(g,
                   shape_override,
                   const_node_values=None,
                   ignore_default=None,
                   use_default=None):
    """
    Convert the tf-node list into an onnx graph with minimal rewrites so
    we can use the onnx graph as intermediate graph.

    Args:
        g: graph object whose ``get_operations()`` returns the tf nodes.
        shape_override: mapping from output tensor name to the shape that
            should be used instead of the one read from the tensor. ``None``
            is treated as an empty mapping.
        const_node_values: optional mapping from node name to raw tensor
            content. When a node name is present, the entry overwrites
            ``tensor_content`` of that node's TensorProto-valued attributes
            before they are converted.
        ignore_default: optional container of ``PlaceholderWithDefault``
            node names to emit as ``Placeholder`` nodes without inputs.
        use_default: optional container of ``PlaceholderWithDefault`` node
            names to emit as ``Identity`` nodes.

    Returns:
        A tuple ``(onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes,
        functions)``:

        * ``onnx_nodes``: list of nodes built with ``helper.make_node``.
        * ``op_cnt``: ``Counter`` of node types.
        * ``attr_cnt``: ``Counter`` of attribute names seen.
        * ``output_shapes``: output tensor name -> shape.
        * ``dtypes``: output tensor name -> mapped dtype; a node's name is
          also added as a key when it carries a non-list ``T`` attribute.
        * ``functions``: name of each referenced function attribute ->
          list of the referencing node's input shapes.

    Raises:
        Exception: whatever ``helper.make_node`` raises is logged and
            re-raised unchanged.
    """

    # ignore the following attributes
    ignored_attr = {
        "unknown_rank",
        "_class",
        "Tshape",
        "use_cudnn_on_gpu",
        "Index",
        "Tpaddings",
        "TI",
        "Tparams",
        "Tindices",
        "Tlen",
        "Tdim",
        "Tin",
        "dynamic_size",
        "Tmultiples",
        "Tblock_shape",
        "Tcrops",
        "index_type",
        "Taxis",
        "U",
        "maxval",
        "Tout",
        "Tlabels",
        "Tindex",
        "element_shape",
        "Targmax",
        "Tperm",
        "Tcond",
        "T_threshold",
        "element_dtype",
        "shape_type",
        "_lower_using_switch_merge",
        "parallel_iterations",
        "_num_original_outputs",
        "output_types",
        "output_shapes",
        "key_dtype",
        "value_dtype",
        "capacity",
        "component_types",
        "shapes",
        "Toutput_types",
        "dense_shapes",
        "Tdense",
        "Tsegmentids",
        "Tshift",
        "Tnumsegments",
        "SrcT",
        "Tcomplex",
        "Treal",  # For RFFT, Tcomplex is ignored because
        # onnx.helper.make_node fails,
        # TODO: it should be added back.
    }
    # attributes whose value references another function (sub-graph)
    function_attr = {"body", "cond", "then_branch", "else_branch", "f"}

    if shape_override is None:
        shape_override = {}

    ops = g.get_operations()
    functions = {}

    # some stats
    op_cnt = collections.Counter()
    attr_cnt = collections.Counter()
    onnx_nodes = []
    output_shapes = {}
    dtypes = {}

    # create dict with output to shape mappings
    for node in ops:
        for out in node.outputs:
            shape = shape_override.get(out.name)
            if shape is None:
                shape = get_tf_tensor_shape(out)
            dtypes[out.name] = map_tf_dtype(out.dtype)
            output_shapes[out.name] = shape

    # minimal conversion of attributes
    for node in ops:
        attr = {}
        op_cnt[node.type] += 1
        for a in node.node_def.attr:
            attr_cnt[a] += 1
            # Fetched once up front (also for ignored attributes) and reused
            # by every branch below.
            value = get_tf_node_attr(node, a)
            if a in ignored_attr:
                pass
            elif a == "T":
                if value and not isinstance(value, list):
                    dtypes[node.name] = map_tf_dtype(value)
            elif a == "shape":
                shape = get_tf_shape_attr(node)
                if shape is not None:
                    attr[a] = shape
            elif a in function_attr:
                # Keep only the function name on the node and remember the
                # input shapes of the node that references it.
                attr[a] = value.name
                functions[value.name] = [
                    inp.get_shape() for inp in node.inputs
                ]
            elif a == "DstT":
                attr["to"] = map_tf_dtype(value)
            elif isinstance(value, tensor_pb2.TensorProto):
                if const_node_values and node.name in const_node_values:
                    value.tensor_content = const_node_values[node.name]
                attr[a] = tf_to_onnx_tensor(value, name=port_name(node.name))
            elif isinstance(value, tf.DType):
                attr[a] = map_tf_dtype(value)
            elif isinstance(value, list) and len(value) > 0 and isinstance(
                    value[0], tf.DType):
                attr[a] = [map_tf_dtype(v) for v in value]
            else:
                attr[a] = value

        node_type = node.type
        input_names = [i.name for i in node.inputs]
        output_names = [i.name for i in node.outputs]

        if node_type == 'PlaceholderWithDefault':
            if ignore_default and node.name in ignore_default:
                node_type = 'Placeholder'
                input_names = []
            elif use_default and node.name in use_default:
                node_type = 'Identity'

        try:
            onnx_node = helper.make_node(node_type,
                                         input_names,
                                         output_names,
                                         name=node.name,
                                         **attr)
            onnx_nodes.append(onnx_node)
        except Exception as ex:
            logger.error("pass1 convert failed for %s, ex=%s", node, ex)
            raise

    return onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, functions
Example #4
0
def rewrite_random_normal(g, ops):
    """Replace TF random-normal subgraphs with ONNX RandomNormal nodes.

    Two subgraph shapes are matched:

    * ``Add(Mul(RandomStandardNormal(*), *), *)`` - the mean is read from
      the second input of the ``Add``.
    * ``Identity(Identity(RandomStandardNormal(*)))`` - the mean is 0.0.

    When the input of ``RandomStandardNormal`` is a ``Shape`` node, a
    ``RandomNormalLike`` node is created on that ``Shape`` node's input;
    otherwise a ``RandomNormal`` node is created with the shape recorded
    for the matched output. Consumers of the matched output are redirected
    to the new node and the matched nodes are removed via
    ``g.safe_remove_nodes``.

    Args:
        g: graph being rewritten.
        ops: list of nodes to match against.

    Returns:
        The ``ops`` argument.
    """
    pattern1 = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', name='input2', inputs=[
                OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"]), "*"
            ]), "*"
        ])

    pattern2 = \
        OpTypePattern('Identity', name='output', inputs=[
            OpTypePattern('Identity', name='input2', inputs=[
                OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"])
            ])
        ])

    pattern_list = [pattern1, pattern2]
    for pattern in pattern_list:
        matcher = GraphMatcher(pattern)
        match_results = list(matcher.match_ops(ops))
        for match in match_results:
            output = match.get_op('output')
            if output.type == 'Add':
                # pattern 1
                mean = output.inputs[1].get_tensor_value()
            else:
                # pattern 2
                mean = 0.0
            # NOTE(review): the other operand of the Mul in pattern 1 is not
            # read; scale is always emitted as 1.0 - confirm this is intended.
            dtype = g.get_dtype(output.output[0])
            op_name = utils.make_name("RandomNormal")
            out_name = utils.port_name(op_name)

            rn_op = match.get_op('input1')
            # ONNX declares `seed` as a float attribute on both RandomNormal
            # and RandomNormalLike, so convert the TF integer seed once here
            # instead of passing an int in one of the branches.
            seed = float(rn_op.get_attr('seed2').i)

            if rn_op.inputs[0].type == "Shape":
                shape_node = rn_op.inputs[0]
                new_node = g.make_node("RandomNormalLike",
                                       [shape_node.input[0]],
                                       outputs=[out_name],
                                       name=op_name,
                                       attr={
                                           "mean": mean,
                                           "scale": 1.0,
                                           "dtype": dtype,
                                           "seed": seed
                                       })
            else:
                shape = g.get_shape(output.output[0])
                new_node = g.make_node("RandomNormal", [],
                                       outputs=[out_name],
                                       name=op_name,
                                       attr={
                                           "shape": shape,
                                           "mean": mean,
                                           "scale": 1.0,
                                           "dtype": dtype,
                                           "seed": seed
                                       })

            g.replace_all_inputs(output.output[0], new_node.output[0], ops=ops)
            g.safe_remove_nodes(match.get_nodes())
    return ops
def rewrite_flatten(g, ops):
    """Replace ``Reshape(x, Pack(StridedSlice(...), -1))`` with ``Flatten``.

    Two pattern variants are matched. In the first, the ``StridedSlice``
    reads from an arbitrary node, which must then be a constant equal to the
    known shape of the reshape input. In the second it reads from a
    ``Shape`` node. In both cases the rewrite only happens when

    * the second ``Pack`` input is the constant -1,
    * the shape of the reshape input is known,
    * the slice has begin ``[0]``, stride ``[1]`` and a single end value
      one past begin, and
    * at most one of the matched nodes is reported as unsafe to remove.

    Args:
        g: graph being rewritten.
        ops: list of nodes to match against.

    Returns:
        The ``ops`` argument.
    """
    pattern_fixed_shape_input = OpTypePattern('Reshape', name='reshape', inputs=[
        OpTypePattern("*", name="input"),
        OpTypePattern('Pack', name="pack", inputs=[
            OpTypePattern('StridedSlice', name="slice", inputs=["*", "*", "*", "*"]),
            "*",
        ]),
    ])
    pattern_non_fixed_shape_input = OpTypePattern('Reshape', name='reshape', inputs=[
        OpTypePattern("*", name="input"),
        OpTypePattern('Pack', name="pack", inputs=[
            OpTypePattern('StridedSlice', name="slice", inputs=[
                OpTypePattern('Shape', inputs=[OpTypePattern("*", name="input2")]),
                "*", "*", "*",
            ]),
            "*",
        ]),
    ])

    fixed_matches = list(GraphMatcher(pattern_fixed_shape_input).match_ops(ops))
    dynamic_matches = list(GraphMatcher(pattern_non_fixed_shape_input).match_ops(ops))

    # Both match lists are collected before any rewrite takes place.
    for matches, check_fixed_input_shape in ((fixed_matches, True), (dynamic_matches, False)):
        for match in matches:
            input_node = match.get_op('input')
            reshape_node = match.get_op('reshape')
            pack_node = match.get_op('pack')
            slice_node = match.get_op('slice')

            # The second Pack operand must be the constant -1.
            if not (pack_node.inputs[1].is_const()
                    and pack_node.inputs[1].get_tensor_value() == -1):
                continue

            input_shape = g.get_shape(reshape_node.input[0])
            if input_shape is None:
                continue

            if check_fixed_input_shape:
                # The sliced constant has to be the shape of the reshape input.
                shape_source = slice_node.inputs[0]
                if not (shape_source.is_const()
                        and np.array_equal(list(input_shape),
                                           list(shape_source.get_tensor_value()))):
                    continue

            begin, end, strides = [
                slice_node.inputs[idx].get_tensor_value(as_list=False)
                for idx in (1, 2, 3)
            ]
            slices_single_leading_entry = (
                np.array_equal(begin, [0]) and len(end) == 1
                and np.array_equal(strides, [1]) and end[0] - begin[0] == 1
            )
            if not slices_single_leading_entry:
                continue

            to_remove = [node for node in match.get_nodes() if node != input_node]
            safe = g.safe_to_remove_nodes(to_remove)

            # Ok if reshape_node is not safe. Will make it safe later.
            if len(to_remove) - len(safe) > 1:
                continue

            op_name = utils.make_name("Flatten")
            out_name = utils.port_name(op_name)
            g.make_node("Flatten", [reshape_node.input[0]], outputs=[out_name], name=op_name)

            # The recorded output shape merges the last two input dims, or
            # uses -1 when either of them is not positive.
            last_dim, sec_last_dim = input_shape[-1], input_shape[-2]
            new_dim = last_dim * sec_last_dim if last_dim > 0 and sec_last_dim > 0 else -1

            g.set_shape(out_name, input_shape[:-2] + [new_dim])
            g.replace_all_inputs(reshape_node.output[0], out_name, ops=ops)
            for node in to_remove:
                g.remove_node(node.name)

    return ops
def rewrite_dropout(g, ops):
    """Replace matched dropout subgraphs with an ONNX ``Dropout`` node.

    Three ``Mul``-rooted patterns are tried in turn, with operand reordering
    allowed: ``RealDiv * Floor(Add(x, RandomUniform))`` and two orderings of
    ``Mul * Cast(GreaterEqual(RandomUniform, x))``. For each match the
    rewrite is skipped (with a warning) unless

    * the node bound to ``input3`` is a scalar (used as the ratio),
    * one operand of the node bound to ``input2`` is a scalar (used as the
      scaling constant), with the other operand taken as the data,
    * ``scaling_constant * (1 - ratio)`` is close to 1, and
    * the matched nodes other than ``input3`` can be removed safely.

    Args:
        g: graph being rewritten.
        ops: list of nodes to match against.

    Returns:
        The ``ops`` argument.
    """
    patterns = [
        OpTypePattern('Mul', name='outputs', inputs=[
            OpTypePattern('RealDiv', name="input2"),
            OpTypePattern('Floor', inputs=[
                OpTypePattern('Add', inputs=[
                    OpTypePattern("*", name="input3"),
                    OpTypePattern('RandomUniform|RandomUniformLike'),
                ])
            ]),
        ]),
        OpTypePattern("Mul", name="outputs", inputs=[
            OpTypePattern("Mul", name="input2"),
            OpTypePattern("Cast", inputs=[
                OpTypePattern("GreaterEqual", inputs=[
                    OpTypePattern("RandomUniform|RandomUniformLike"),
                    OpTypePattern("*", name="input3")
                ])
            ])
        ]),
        # pattern for tf-2.0 tf.nn.dropout()
        OpTypePattern("Mul", name="outputs", inputs=[
            OpTypePattern("Cast", inputs=[
                OpTypePattern("GreaterEqual", inputs=[
                    OpTypePattern("RandomUniform|RandomUniformLike"),
                    OpTypePattern("*", name="input3")
                ])
            ]),
            OpTypePattern("Mul", name="input2"),
        ]),
    ]
    for pattern in patterns:
        found = list(GraphMatcher(pattern, allow_reorder=True).match_ops(ops))
        for match in found:
            scale_node = match.get_op('input2')
            ratio_node = match.get_op('input3')
            root = match.get_op('outputs')

            if not ratio_node.is_scalar():
                logger.warning(
                    "Dropout pattern rooted at %s does not have a "
                    "constant ratio and cannot be replaced.", root.name)
                continue
            ratio = ratio_node.get_tensor_value()

            # Whichever operand of the scaling node is a scalar is the
            # scaling constant; the other operand is the data.
            if scale_node.inputs[0].is_scalar():
                data = scale_node.inputs[1]
                scaling_constant = scale_node.inputs[0].get_tensor_value()
            elif scale_node.inputs[1].is_scalar():
                data = scale_node.inputs[0]
                scaling_constant = scale_node.inputs[1].get_tensor_value()
            else:
                logger.warning(
                    "Could not find scaling constant for dropout pattern rooted at %s. "
                    "The pattern will not be replaced with an ONNX dropout node.",
                    root.name)
                continue

            # The scaling constant should be 1/(1-ratio), otherwise this
            # isn't truly a dropout node
            consistent = np.allclose([1], [scaling_constant * (1 - ratio)])
            if not consistent:
                logger.warning(
                    "Scaling constant %f for dropout pattern rooted at %s is inconsistent with dropout "
                    "ratio %f. The pattern will not be replaced with an ONNX dropout node.",
                    scaling_constant, root.name, ratio)
                continue

            # The ratio node itself is kept; everything else in the match goes.
            nodes_to_remove = [
                node for node in match.get_nodes()
                if node.name != ratio_node.name
            ]
            removable = g.is_safe_to_remove_nodes(nodes_to_remove,
                                                  [root.output[0]])
            if not removable:
                logger.warning(
                    "Nodes in dropout pattern rooted at %s cannot be removed because intermediate results "
                    "of some nodes are referenced elsewhere in graph.",
                    root.name)
                continue

            op_name = utils.make_name("Dropout")
            out_name = utils.port_name(op_name)
            data_output = data.output[0]
            new_node = g.make_node("Dropout",
                                   inputs=[data_output],
                                   outputs=[out_name],
                                   name=op_name,
                                   attr={"ratio": ratio},
                                   shapes=[g.get_shape(data_output)],
                                   dtypes=[g.get_dtype(data_output)])
            g.replace_all_inputs(root.output[0], new_node.output[0], ops=ops)
            for node in nodes_to_remove:
                g.remove_node(node.name)

    return ops