def test_rewrite_subgraph(self):
    """Swap every Abs-over-Add match for a Sub node and compare the graphviz dump."""
    graph = tf2onnx.graph.Graph(self.sample_net().node, output_shapes={}, dtypes={})
    abs_over_add = OpTypePattern('Abs', name='output', inputs=[OpTypePattern('Add', name='input')])

    ops = graph.get_nodes()
    # All matches are collected up front, before any replacement rebinds `ops`.
    for hit in list(GraphMatcher(abs_over_add).match_ops(ops)):
        add_node = hit.get_op('input')
        abs_node = hit.get_op('output')
        replacement_name = tf2onnx.utils.make_name("ReplacedOp")
        replacement_out = tf2onnx.utils.port_name(replacement_name)
        # The Sub node consumes the same inputs the matched Add consumed.
        sub_node = graph.make_node("Sub", inputs=add_node.input, outputs=[replacement_out],
                                   name=replacement_name)
        ops = graph.replace_subgraph(ops, hit, [], [abs_node], [], [sub_node])

    graph.topological_sort(ops)
    expected = ('digraph { n1 [op_type=Abs] n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__2 [op_type=Sub] '
                'n6 [op_type=Identity] input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__2 '
                'n3:0 -> ReplacedOp__2 ReplacedOp__2:0 -> n6 }')
    self.assertEqual(expected, onnx_to_graphviz(graph))
def test_match_flipped(self):
    """A Mul(Add, Sub) pattern must match a Mul(Sub, Add) graph when reordering is allowed."""

    def float_2x2(tensor_name):
        # Every graph input/output in this test is a 2x2 float tensor.
        return helper.make_tensor_value_info(tensor_name, TensorProto.FLOAT, [2, 2])

    sub_node = helper.make_node("Sub", ["i1", "i1"], ["n1:0"], name="n1")
    add_node = helper.make_node("Add", ["i2", "i2"], ["n2:0"], name="n2")
    # The Mul takes Sub first and Add second -- the reverse of the pattern below.
    mul_node = helper.make_node("Mul", ["n1:0", "n2:0"], ["n3:0"], name="n3")

    graph_proto = helper.make_graph(
        nodes=[sub_node, add_node, mul_node],
        name="test",
        inputs=[float_2x2("i1"), float_2x2("i2")],
        outputs=[float_2x2("n2:0")],
        initializer=[],
    )
    graph = GraphUtil.create_graph_from_onnx_graph(graph_proto)

    flipped = OpTypePattern('Mul', inputs=[OpTypePattern('Add'), OpTypePattern('Sub')])
    found = list(GraphMatcher(flipped, allow_reorder=True).match_ops(graph.get_nodes()))
    self.assertEqual(1, len(found))
def rewrite_thresholded_relu(g, ops):
    """Collapse Mul(Cast(Greater(x, theta)), x) into a single ThresholdedRelu node.

    Nothing is rewritten when the graph opset is below 10. A match is replaced
    only when the edge feeding Greater is the same edge that feeds Mul; the
    value of the Const ``theta`` node becomes the ``alpha`` attribute.

    Args:
        g: graph being rewritten.
        ops: node list that is searched and passed to ``replace_all_inputs``.

    Returns:
        The ``ops`` list that was passed in.
    """
    if g.opset < 10:
        return ops

    above_theta = OpTypePattern('Greater', name='greater', inputs=[
        OpTypePattern('*', name='greater_input'),
        OpTypePattern('Const', name='theta'),
    ])
    mask = OpTypePattern('Cast', name='cast', inputs=[above_theta])
    pattern = OpTypePattern('Mul', name='mul', inputs=[mask, OpTypePattern('*', name='mul_input')])

    for hit in list(GraphMatcher(pattern, allow_reorder=True).match_ops(ops)):
        greater_node = hit.get_op('greater')
        greater_src = hit.get_op('greater_input')
        mul_node = hit.get_op("mul")
        mul_src = hit.get_op('mul_input')
        cast_node = hit.get_op('cast')

        gated_edge = _find_edge_name_between_nodes(greater_src, greater_node)
        data_edge = _find_edge_name_between_nodes(mul_src, mul_node)
        if gated_edge != data_edge:
            # The comparison and the multiplication act on different tensors.
            continue

        theta = hit.get_op('theta').get_tensor_value()
        mul_out = mul_node.output[0]
        fused = g.make_node("ThresholdedRelu", inputs=[data_edge], attr={"alpha": theta},
                            shapes=[g.get_shape(mul_out)],
                            dtypes=[g.get_dtype(mul_out)])
        g.replace_all_inputs(mul_node.output[0], fused.output[0], ops=ops)
        g.safe_remove_nodes([cast_node, mul_node])
    return ops
def run(self, unit_type):
    """Find every cell of ``unit_type`` in the graph and process each match.

    Outline:
      1. Match the cell op pattern; each match is the starting point for
         the steps below.
      2. Collect from the tensorflow graph: rnn scope name, input_x, weight,
         sequence node, initializer, state output and hidden output.
      3. Format the collected information according to ONNX requirements.

    The op pattern and the scope name are the two handles used to locate
    the needed information in the tensorflow graph.

    Returns:
        The graph's node list after processing.
    """
    # Reordering has to be allowed: LSTMCell and BasicLSTMCell define the
    # same calculation with operands in different orders, and allowing
    # reorder lets both share one pattern.
    matcher = GraphMatcher(get_pattern(unit_type), allow_reorder=True)
    found = list(matcher.match_ops(self.g.get_nodes()))

    if not found:
        return self.g.get_nodes()

    for cell_match in found:
        self.run_single_match(cell_match)
    self.g.delete_unused_nodes(self.g.outputs)
    self.print_step("finish handling")
    return self.g.get_nodes()
def rewrite_biasadd_with_conv2d(g, ops):
    """Merge each BiasAdd that consumes a two-input Conv2D/Conv2DBackpropInput into one node.

    The replacement keeps the conv's op type, attributes, output shape and
    dtype, takes over the BiasAdd's name and outputs, and receives the bias
    (BiasAdd input 1) as a third input.

    Returns:
        The ``ops`` list that was passed in.
    """
    conv_pattern = OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=['*', '*'])
    pattern = OpTypePattern('BiasAdd', name='biasadd', inputs=[conv_pattern, '*'])

    for hit in list(GraphMatcher(pattern).match_ops(ops)):
        bias_node = hit.get_op('biasadd')
        conv_node = hit.get_op('conv')

        # Read everything needed from both nodes before they are removed.
        fused_type = conv_node.type
        conv_in = conv_node.input
        fused_attr = conv_node.attr
        out_dtype = g.get_dtype(conv_node.output[0])
        out_shape = g.get_shape(conv_node.output[0])
        fused_name = bias_node.name
        fused_outputs = bias_node.output
        fused_inputs = [conv_in[0], conv_in[1], bias_node.input[1]]

        # Drop the original pair, then recreate a single node in their place.
        g.remove_node(conv_node.name)
        g.remove_node(bias_node.name)
        g.make_node(fused_type, fused_inputs, attr=fused_attr, name=fused_name,
                    outputs=fused_outputs, shapes=[out_shape], dtypes=[out_dtype],
                    skip_conversion=False)
    return ops
def _parse_input_ta(self, context):
    """Register a scan input for every tensor-array read driven by a loop variable.

    A match of ``self.ta_read_input_pattern`` is kept only when the first
    output of its ``ta_index`` op is one of the non-empty
    ``switch_true_identity_output`` ids of the loop variables.
    """
    loop_vars = context.loop_properties.all_variables.values()
    graph_inputs = []
    for var in loop_vars:
        if var.switch_true_identity_output.id:
            graph_inputs.append(var.switch_true_identity_output.id)

    matcher = GraphMatcher(self.ta_read_input_pattern, allow_reorder=False)
    relevant = [
        hit for hit in matcher.match_ops(self.g.get_nodes())
        if hit.get_op("ta_index").output[0] in graph_inputs
    ]

    for hit in relevant:
        # The value being scattered is the scatter op's third input.
        data_input_id = hit.get_op("ta_input_scatter").input[2]

        # TODO: check that the tensor array's index variable is a scalar that
        # starts from 1 and increases by 1 each iteration; only then is this
        # certain to be equivalent to scan input behavior.
        index_input_id = hit.get_op("ta_read").input[1]
        unstacked_ta_consumer = hit.get_op("ta_read").output[0]

        scan_input = InputTensorArray(data_input_id, index_input_id, unstacked_ta_consumer, self.g)
        context.loop_properties.add_scan_input(scan_input)
def rewrite_random_uniform(g, ops):
    """Replace Add(Mul(RandomUniform, Sub(max, min)), *) with one ONNX random-uniform op.

    ``max`` and ``min`` must be Const/ConstV2 inputs of the Sub node; their
    values are handed to ``create_onnx_random_uniform_op`` together with the
    list of matched nodes scheduled for removal.

    Returns:
        The ``ops`` list that was passed in.
    """
    const_operand = "Const|ConstV2"
    value_range = OpTypePattern('Sub', name='input2', inputs=[const_operand, const_operand])
    random_source = OpTypePattern('RandomUniform', name='input1', inputs=["*"])
    pattern = OpTypePattern('Add', name='output', inputs=[
        OpTypePattern('Mul', inputs=[random_source, value_range]),
        None,
    ])

    for hit in list(GraphMatcher(pattern).match_ops(ops)):
        range_node = hit.get_op('input2')
        add_node = hit.get_op('output')
        ru_op = hit.get_op('input1')
        # Sub computes max - min, so max is input 0 and min is input 1.
        upper = range_node.inputs[0].get_tensor_value()
        lower = range_node.inputs[1].get_tensor_value()
        doomed = list(set(hit.get_nodes()))
        replacement = create_onnx_random_uniform_op(g, upper, lower, ru_op, add_node, doomed)
        g.replace_all_inputs(add_node.output[0], replacement.output[0], ops=ops)
        g.safe_remove_nodes(doomed)
    return ops
def rewrite_random_uniform_fold_const(g, ops):
    """Replace Add(Mul(RandomUniform, c1), c2) with one ONNX random-uniform op.

    This is the variant where ``max - min`` is already folded into the
    constant ``c1``: ``min`` is read from the Add's second input and ``max``
    is reconstructed as ``min + c1``.

    Returns:
        The ``ops`` list that was passed in.
    """
    random_source = OpTypePattern('RandomUniform', name='input1', inputs=["*"])
    scaled = OpTypePattern('Mul', name='mul', inputs=[random_source, "Const|ConstV2"])
    pattern = OpTypePattern('Add', name='output', inputs=[scaled, "Const|ConstV2"])

    for hit in list(GraphMatcher(pattern).match_ops(ops)):
        add_node = hit.get_op('output')
        mul_node = hit.get_op('mul')
        ru_op = hit.get_op('input1')

        span = mul_node.inputs[1].get_tensor_value()
        lower = add_node.inputs[1].get_tensor_value()
        upper = lower + span

        doomed = list(set(hit.get_nodes()))
        replacement = create_onnx_random_uniform_op(g, upper, lower, ru_op, add_node, doomed)
        g.replace_all_inputs(add_node.output[0], replacement.output[0], ops=ops)
        g.safe_remove_nodes(doomed)
    return ops
def test_rewrite_subgraph(self):
    """Swap every Abs-over-Add match for a Sub node and compare the graphviz dump."""
    graph = GraphUtil.create_graph_from_onnx_graph(self.sample_net())
    abs_over_add = OpTypePattern('Abs', name='output', inputs=[OpTypePattern('Add', name='input')])

    ops = graph.get_nodes()
    for hit in list(GraphMatcher(abs_over_add).match_ops(ops)):
        add_node = hit.get_op('input')
        abs_node = hit.get_op('output')
        replacement_name = utils.make_name("ReplacedOp")
        replacement_out = utils.port_name(replacement_name)
        sub_node = graph.make_node("Sub", inputs=add_node.input, outputs=[replacement_out],
                                   name=replacement_name)
        # Rewire over the whole graph (no ops=ops restriction), then drop the matched nodes.
        graph.replace_all_inputs(abs_node.output[0], sub_node.output[0])
        for stale in set(hit.get_nodes()):
            graph.remove_node(stale.name)

    graph.topological_sort(ops)
    expected = ('digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] '
                'n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__5 [op_type=Sub] '
                'n6 [op_type=Identity] n5_graph_outputs_Identity__3 [op_type=Identity] '
                'input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__5 n3:0 -> ReplacedOp__5 '
                'ReplacedOp__5:0 -> n6 ReplacedOp__5:0 -> n5_graph_outputs_Identity__3 }')
    self.assertEqual(expected, onnx_to_graphviz(graph))
def rewrite_random_normal(g, ops):
    """Fuse Add(Mul(RandomStandardNormal, stddev), mean) into one ONNX random-normal node.

    The matcher is positional (no reordering), so the Mul's second input is
    the standard-deviation operand and the Add's second input is the mean.
    Both are read with ``get_tensor_value()``.

    When the random op's shape input comes from a ``Shape`` node, a
    ``RandomNormalLike`` node fed by that Shape's input is created; otherwise
    a ``RandomNormal`` node with the output's static shape is created.

    Args:
        g: graph being rewritten.
        ops: node list that is searched and rewired.

    Returns:
        The ``ops`` list that was passed in.
    """
    pattern = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', name='input2', inputs=[
                OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"]),
                "*"
            ]),
            "*"
        ])

    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        output = match.get_op('output')
        mul = match.get_op('input2')
        mean = output.inputs[1].get_tensor_value()
        # The multiplier of the standard-normal sample is the distribution's
        # scale. It used to be dropped and replaced by a hard-coded 1.0,
        # which silently changed the distribution whenever stddev != 1.
        scale = mul.inputs[1].get_tensor_value()
        dtype = g.get_dtype(output.output[0])
        op_name = utils.make_name("RandomNormal")
        out_name = port_name(op_name)

        rn_op = match.get_op('input1')
        if rn_op.inputs[0].type == "Shape":
            shape_node = rn_op.inputs[0]
            new_node = g.make_node("RandomNormalLike", [shape_node.input[0]], outputs=[out_name], name=op_name,
                                   attr={"mean": mean, "scale": scale, "dtype": dtype})
        else:
            shape = g.get_shape(output.output[0])
            new_node = g.make_node("RandomNormal", [], outputs=[out_name], name=op_name,
                                   attr={"shape": shape, "mean": mean, "scale": scale, "dtype": dtype})

        g.replace_all_inputs(ops, output.output[0], new_node.output[0])
        # A set guards against removing the same matched node twice.
        for n in set(match.get_nodes()):
            g.remove_node(n.name)
    return ops
def rewrite_quantize_and_dequantize(g, ops):
    """Collect every QuantizeAndDequantizeV2/V3 match and hand them to ``create_qdq_nodes``.

    V3 matches are gathered before V2 matches.

    Returns:
        Whatever ``create_qdq_nodes(g, matches)`` returns.
    """

    def qdq_pattern(op_type, extra_input_count):
        # One wildcard data input followed by `extra_input_count` untyped inputs.
        operands = [OpTypePattern("*")]
        operands.extend(OpTypePattern(None) for _ in range(extra_input_count))
        return OpTypePattern(op_type, name='output', inputs=operands)

    pattern_for_qdq_v2 = qdq_pattern('QuantizeAndDequantizeV2', 2)
    pattern_for_qdq_v3 = qdq_pattern('QuantizeAndDequantizeV3', 3)

    # Match all the patterns for QDQ ops
    collected = []
    for candidate in (pattern_for_qdq_v3, pattern_for_qdq_v2):
        collected += list(GraphMatcher(candidate).match_ops(ops))
    return create_qdq_nodes(g, collected)
def rewrite_ragged_variant_shape(g, ops):
    """Rewrite Shape(RaggedTensorToVariant(...)) to work on the ragged op's first input.

    Only matches whose ``batched_input`` and ``RAGGED_RANK`` attribute values
    both equal 1 are rewritten. For those, the Shape node is re-pointed at the
    RaggedTensorToVariant's first input and a ``Sub`` by a constant one is
    inserted on the Shape's output.

    Returns:
        The ``ops`` list that was passed in.
    """
    variant_shape = OpTypePattern('Shape', name='shape', inputs=[
        OpTypePattern('RaggedTensorToVariant', name='raggedtovariant'),
    ])

    for pattern in [variant_shape]:
        for hit in list(GraphMatcher(pattern).match_ops(ops)):
            shape_node = hit.get_op('shape')
            ragged_node = hit.get_op('raggedtovariant')
            if (ragged_node.get_attr_value("batched_input") != 1
                    or ragged_node.get_attr_value("RAGGED_RANK") != 1):
                continue

            # Shape of batched variant from ragged is same as number of splits minus 1
            g.replace_inputs(shape_node, [ragged_node.input[0]])
            shape_out = shape_node.output[0]
            one_dtype = utils.map_onnx_to_numpy_type(g.get_dtype(shape_out))
            one_const = g.make_const(utils.make_name("const_one"), np.array(1, one_dtype))
            const_one = one_const.output[0]
            g.insert_new_node_on_output("Sub", shape_node.output[0],
                                        inputs=[shape_node.output[0], const_one])
    return ops
def rewrite_leakyrelu(g, ops):
    """Collapse Maximum(Mul(alpha, x), x) into a single LeakyRelu node.

    Nothing is rewritten when the graph opset is below 6. A match is replaced
    only when Maximum and Mul read the same tensor and the Const ``alpha``
    value is below 1. The new node is appended to ``ops``.

    Returns:
        The (possibly extended) ``ops`` list.
    """
    if g.opset < 6:
        return ops

    scaled = OpTypePattern('Mul', name='mul', inputs=[
        OpTypePattern('Const', name='alpha'),
        OpTypePattern('*', name='mul_input'),
    ])
    pattern = OpTypePattern('Maximum', name='max', inputs=[
        scaled,
        OpTypePattern('*', name='max_input'),
    ])

    for hit in list(GraphMatcher(pattern, allow_reorder=True).match_ops(ops)):
        max_node = hit.get_op('max')
        mul_node = hit.get_op("mul")

        data_edge = hit.get_tensor('max_input')
        scaled_edge = hit.get_tensor('mul_input')
        if data_edge != scaled_edge:
            continue

        alpha = hit.get_op("alpha").get_tensor_value()
        if alpha >= 1:
            continue

        max_out = max_node.output[0]
        fused = g.make_node("LeakyRelu", inputs=[data_edge], attr={"alpha": alpha},
                            shapes=[g.get_shape(max_out)],
                            dtypes=[g.get_dtype(max_out)])
        ops.append(fused)
        g.replace_all_inputs(max_node.output[0], fused.output[0], ops=ops)
        g.safe_remove_nodes([max_node, mul_node])
    return ops
def rewrite_dropout(g, ops):
    """Replace the Mul(RealDiv, Floor(Add(*, RandomUniform))) pattern, then strip no-op Dropouts.

    Each match is first replaced by a ``Dropout`` node with ``ratio`` 1.0
    that reads the RealDiv's first input. A second pass then bypasses every
    ``Dropout`` node in the graph whose ratio is 1.0 by rewiring its
    consumers to its input and removing it.

    Args:
        g: graph being rewritten.
        ops: node list that is searched and rewired.

    Returns:
        The ``ops`` list that was passed in.
    """
    pattern = \
        OpTypePattern('Mul', name='outputs', inputs=[
            OpTypePattern('RealDiv', name="input2"),
            OpTypePattern('Floor', inputs=[
                OpTypePattern('Add', inputs=[
                    OpTypePattern(None, name="input3"),
                    OpTypePattern('RandomUniform|RandomUniformLike'),
                ])
            ]),
        ])
    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        inputs2 = match.get_op('input2')
        outputs = match.get_op('outputs')
        op_name = utils.make_name("Dropout")
        out_name = port_name(op_name)
        new_node = g.make_node(
            "Dropout",
            [inputs2.input[0]],
            outputs=[out_name],
            name=op_name,
            attr={"ratio": 1.0},
            shapes=[g.get_shape(inputs2.input[0])],
            dtypes=[g.get_dtype(inputs2.input[0])]
        )
        g.replace_all_inputs(ops, outputs.output[0], new_node.output[0])
        g.safe_remove_nodes(match.get_nodes())

    # remove dropout if its ratio is 1.0
    # Iterate over a snapshot: remove_node() changes the graph's node
    # collection, and if get_nodes() hands back that live collection,
    # removing during iteration would skip the element after each removal.
    for node in list(g.get_nodes()):
        if node.type == "Dropout" and node.get_attr("ratio").f == 1.0:
            g.replace_all_inputs(g.get_nodes(), node.output[0], node.input[0])
            g.remove_node(node.name)

    return ops
def _parse_input_ta(self, context):
    """Record every input tensor array that belongs to the current rnn scope.

    A match of ``self.rnn_input_pattern`` is kept only when the name of its
    ``ta_input_scatter`` op starts with ``context.rnn_scope``. Each kept match
    yields a ``TensorArrayProp`` appended to ``context.input_tas``.
    """
    matcher = GraphMatcher(self.rnn_input_pattern, allow_reorder=True)
    every_match = list(matcher.match_ops(self.g.get_nodes()))
    in_scope = list(filter(
        lambda hit: hit.get_op("ta_input_scatter").name.startswith(context.rnn_scope),
        every_match,
    ))

    for hit in in_scope:
        scatter_node = hit.get_op("ta_input_scatter")
        input_ta = TensorArrayProp()

        # dynamic_rnn specific approach: the scattered value is the
        # scatter op's third input.
        input_ta.data_input_id = scatter_node.input[2]

        ta_read_node = hit.get_op("ta_read")
        input_ta.index_input_id = ta_read_node.input[1]
        input_ta.output_id = hit.get_op("ta_read").output[0]

        data_shape = self.g.get_shape(input_ta.data_input_id)
        read_shape = self.g.get_shape(input_ta.output_id)
        if read_shape is None and data_shape is not None:
            # The read's shape is the data shape without its first dimension.
            self.g.set_shape(input_ta.output_id, data_shape[1:])

        context.input_tas.append(input_ta)

        # Shapes are queried again so the log shows any shape set just above.
        log.debug(
            "input ta %s - data input (%s) shape: %s, output (%s) shape: %s",
            ta_read_node.name,
            input_ta.data_input_id,
            self.g.get_shape(input_ta.data_input_id),
            input_ta.output_id,
            self.g.get_shape(input_ta.output_id),
        )
def rewrite_random_normal(g, ops):
    """Replace RandomStandardNormal sub-graphs with ONNX RandomNormal / RandomNormalLike.

    Two shapes are recognised:
      * Add(Mul(RandomStandardNormal, const), const): mean comes from the
        Add's second input and scale from the Mul's second input.
      * Identity(Identity(RandomStandardNormal)): mean 0.0 and scale 1.0.

    The seed attribute is the random op's ``seed2`` converted to float. When
    no ``Shape`` node feeds the random op and the output shape is unknown or
    contains -1, the match is left untouched.

    Returns:
        The ``ops`` list that was passed in.
    """

    def standard_normal():
        # A fresh pattern object per use; the two patterns do not share nodes.
        return OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"])

    scaled_and_shifted = OpTypePattern('Add', name='output', inputs=[
        OpTypePattern('Mul', name='input2', inputs=[standard_normal(), "Const|ConstV2"]),
        "Const|ConstV2",
    ])
    passthrough = OpTypePattern('Identity', name='output', inputs=[
        OpTypePattern('Identity', name='input2', inputs=[standard_normal()]),
    ])

    for pattern in (scaled_and_shifted, passthrough):
        for hit in list(GraphMatcher(pattern).match_ops(ops)):
            out_node = hit.get_op('output')
            # Only the Add-rooted pattern carries an explicit mean.
            mean = out_node.inputs[1].get_tensor_value() if out_node.type == 'Add' else 0.0
            inner = hit.get_op('input2')
            # Only the Mul node carries an explicit scale.
            scale = inner.inputs[1].get_tensor_value() if inner.type == 'Mul' else 1.0

            dtype = g.get_dtype(out_node.output[0])
            op_name = utils.make_name("RandomNormal")
            out_name = utils.port_name(op_name)

            rn_op = hit.get_op('input1')
            seed = float(rn_op.get_attr('seed2').i)
            attr = {"mean": mean, "scale": scale, "dtype": dtype, "seed": seed}

            shape_source = rn_op.inputs[0]
            if shape_source.type == "Shape":
                new_node = g.make_node("RandomNormalLike", [shape_source.input[0]],
                                       outputs=[out_name], name=op_name, attr=attr)
            else:
                static_shape = g.get_shape(out_node.output[0])
                if static_shape is None or -1 in static_shape:
                    continue
                attr['shape'] = static_shape
                new_node = g.make_node("RandomNormal", [], outputs=[out_name], name=op_name, attr=attr)

            g.replace_all_inputs(out_node.output[0], new_node.output[0], ops=ops)
            g.safe_remove_nodes(hit.get_nodes())
    return ops
def rewrite_depthwise_conv_dilations(g, ops):
    """Fold SpaceToBatchND -> DepthwiseConv2dNative -> BatchToSpaceND into a dilated conv.

    A match is rewritten only when both block shapes agree, the conv has unit
    dilations and strides, NHWC data format and VALID padding, and the
    BatchToSpaceND crops are all zero. The conv then reads the
    SpaceToBatchND's input directly, gets ``dilations`` and
    ``explicit_paddings`` derived from the block shape and paddings, and
    replaces the BatchToSpaceND output everywhere.

    Returns:
        ``g.get_nodes()`` after rewriting.
    """

    def const():
        return OpTypePattern("Const|ConstV2")

    space_to_batch_pattern = OpTypePattern("SpaceToBatchND", name="space_to_batch", inputs=[
        OpTypePattern("*"), const(), const(),
    ])
    conv_pattern = OpTypePattern("DepthwiseConv2dNative", name="depthwise_conv", inputs=[
        space_to_batch_pattern, OpTypePattern("*"),
    ])
    pattern1 = OpTypePattern("BatchToSpaceND", name="batch_to_space", inputs=[
        conv_pattern, const(), const(),
    ])

    unit = [1, 1, 1, 1]
    for pattern in [pattern1]:
        matcher = GraphMatcher(pattern, allow_reorder=False)
        for hit in list(matcher.match_ops(ops)):
            s2b = hit.get_op("space_to_batch")
            conv = hit.get_op("depthwise_conv")
            b2s = hit.get_op("batch_to_space")

            block_in = s2b.inputs[1].get_tensor_value(as_list=True)
            paddings = s2b.inputs[2].get_tensor_value(as_list=False).flatten().tolist()
            block_out = b2s.inputs[1].get_tensor_value(as_list=True)
            crops = b2s.inputs[2].get_tensor_value(as_list=True)

            if block_in != block_out:
                continue
            if (conv.get_attr_value("dilations", [1, 1, 1, 1]) != unit
                    or conv.get_attr_value("strides", [1, 1, 1, 1]) != unit):
                continue
            if (conv.get_attr_value("data_format", b"NHWC") != b"NHWC"
                    or conv.get_attr_value("padding") != b"VALID"):
                continue
            if crops != [[0, 0], [0, 0]]:
                continue

            # Bypass SpaceToBatchND: the conv reads the original tensor and
            # expresses the block shape as dilations instead.
            data_edge = s2b.input[0]
            kernel_edge = conv.input[1]
            g.replace_inputs(conv, [data_edge, kernel_edge])
            conv.set_attr("dilations", [1] + block_in + [1])
            conv.set_attr("explicit_paddings", [0, 0] + paddings + [0, 0])
            conv.set_attr("padding", "EXPLICIT")
            g.copy_shape(b2s.output[0], conv.output[0])
            g.replace_all_inputs(b2s.output[0], conv.output[0])
    return g.get_nodes()
def rewrite_test(g, ops):
    """Turn every two-input Add node of the graph into a Mul node.

    The incoming ``ops`` argument is ignored; the node list is taken from
    ``g.get_nodes()`` and that list is returned.
    """
    ops = g.get_nodes()
    two_input_add = OpTypePattern('Add', name='op', inputs=["*", "*"])
    for hit in list(GraphMatcher(two_input_add).match_ops(ops)):
        hit.get_op('op').type = "Mul"
    return ops
def find_sequence_length_node(self, context):
    """Return the sequence-length node, or None when the loop has none.

    The producer of the first state variable's next-iteration input is
    inspected. If it is not a select op, no sequence length was given.
    Otherwise ``seq_len_pattern`` must match at that node.

    Raises:
        RuntimeError: the producer is a select op but the pattern does not match.
    """
    # Any state variable will do, so take the first one.
    first_state_var = list(context.state_variables.values())[0]
    producer = self.g.get_node_by_output(first_state_var.next_iteration_input.id)

    if is_select_op(producer):
        found = GraphMatcher(seq_len_pattern).match_op(producer)
        if found:
            return found.get_op("seq_len_node")
        raise RuntimeError("failed to find sequence length.")

    log.debug("no sequence length node is given")
    return None
def rewrite_conv2d_with_pad(g, ops):
    """Merge a constant Pad feeding a Conv2D into the conv's ``pads`` attribute.

    A match is skipped when the paddings are not constant, the pad mode is
    neither absent nor "constant", the Pad has three or more inputs, the conv
    already uses SAME padding, or any padding is applied to the batch or
    channel dimension.

    Returns:
        The ``ops`` list that was passed in.
    """
    pattern = OpTypePattern("Conv2D", name="conv", inputs=[
        OpTypePattern("Pad", name="pad"),
        OpTypePattern("*"),
    ])

    for hit in list(GraphMatcher(pattern).match_ops(ops)):
        conv = hit.get_op("conv")
        pad = hit.get_op("pad")
        paddings = pad.inputs[1]
        if not paddings.is_const():
            continue

        mode = pad.get_attr("mode")
        if mode:
            mode = mode.s.decode("utf-8").lower()
        if mode not in [None, "constant"] or len(pad.input) >= 3:
            continue

        # Conv2D already has a pad
        if conv.get_attr("padding").s.decode("utf-8") == "SAME":
            continue

        logger.debug("merge pad [%s] into conv [%s]", pad.name, conv.name)
        paddings_val = np.array(paddings.get_tensor_value())

        # Padding is only allowed on the spatial axes; which rows hold the
        # batch/channel paddings depends on the data layout.
        data_format = conv.get_attr("data_format").s.decode("utf-8")
        if data_format == "NHWC":
            batch_row, channel_row, spatial_rows = 0, 3, slice(1, 3)
        else:
            batch_row, channel_row, spatial_rows = 0, 1, slice(2, 4)
        if np.any(paddings_val[batch_row]) or np.any(paddings_val[channel_row]):
            continue
        paddings_val = paddings_val[spatial_rows]
        paddings_val = paddings_val.transpose().flatten()

        g.replace_input(conv, conv.input[0], pad.input[0])
        # Run the Conv2D conversion now, so the explicit pads can be set on
        # the converted node afterwards.
        conv.type = "Conv2D"
        func, _ = handler.tf_op.find_effective_op("Conv2D")
        func(g, conv)
        conv.skip_conversion = True
        conv.set_attr("auto_pad", "NOTSET")
        conv.set_attr("pads", paddings_val)
    return ops
def rewrite_tfl_qdq(g, ops):
    """Remove TFL_QUANTIZE -> TFL_DEQUANTIZE pairs, inserting a relu where the range requires one.

    A pair is only removed when the quantize and dequantize nodes agree on
    ``scale``, ``quantized_dimension`` and ``zero_point``. If the quantize
    node carries single-element ``min``/``max`` attributes with ``min == 0``,
    a ``TFL_RELU6`` (max in [5.999, 6.0]) or ``TFL_RELU`` node replaces the
    pair, unless the producer already is that op or already has it as fused
    activation. Otherwise consumers are wired straight to the quantize input.

    Args:
        g: graph being rewritten.
        ops: node list that is searched.

    Returns:
        The ``ops`` list that was passed in.
    """
    pattern0 = \
        OpTypePattern('TFL_DEQUANTIZE', name='dequant', inputs=[
            OpTypePattern('TFL_QUANTIZE', name='quant'),
        ])

    matcher = GraphMatcher(pattern0, allow_reorder=False)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        dequant = match.get_op("dequant")
        quant = match.get_op("quant")
        inp_node = quant.inputs[0]

        # Skip the pair when the quantization parameters disagree. This check
        # previously sat in its own `for k in ...` loop, where `continue`
        # only advanced that inner loop and so never rejected a match.
        if any(dequant.get_attr_value(k) != quant.get_attr_value(k)
               for k in ("scale", "quantized_dimension", "zero_point")):
            continue

        needed_relu = None
        if all(k in quant.attr and len(quant.get_attr_value(k)) == 1 for k in ["min", "max"]):
            min_val = quant.get_attr_value("min")[0]
            max_val = quant.get_attr_value("max")[0]
            if min_val == 0.0 and 5.999 <= max_val <= 6.0:
                needed_relu = "TFL_RELU6"
            elif min_val == 0.0:
                # This may introduce unneeded relu ops but will be correct.
                # If the --dequantize feature is used a lot in the future we can optimize this.
                needed_relu = "TFL_RELU"
            if inp_node.type == needed_relu:
                # If it's really obviously unneeded, we skip it.
                needed_relu = None
            elif "TFL_" + inp_node.get_attr_value("fused_activation_function", b'').decode() == needed_relu:
                needed_relu = None

        if needed_relu is not None:
            relu_name = inp_node.name + "_relu"
            relu6 = g.make_node(needed_relu, [quant.input[0]], op_name_scope=relu_name,
                                skip_conversion=False, shapes=quant.output_shapes, dtypes=quant.output_dtypes)
            g.replace_all_inputs(dequant.output[0], relu6.output[0])
        else:
            g.replace_all_inputs(dequant.output[0], quant.input[0])

        g.remove_node(dequant.name)
        # The quantize node may still feed other consumers; only drop it once unused.
        if len(g.find_output_consumers(quant.output[0])) == 0:
            g.remove_node(quant.name)

    return ops
def rewrite_tfl_rfft(g, ops):
    """Rewrite TFL_COMPLEX_ABS(TFL_RESHAPE(TFL_RFFT2D(x, length), shape)) to use an ``RFFT`` op.

    A match is rewritten only when the RFFT2D output shape equals ``shape``
    with a 1 inserted before its last entry, and the first entry of
    ``length`` is 1. The RFFT2D node is retyped to "RFFT", its length input is
    replaced by a new int64 constant holding ``length[1]``, and the Reshape is
    moved from before the complex-abs to after it.

    Args:
        g: graph being rewritten.
        ops: node list that is searched.

    Returns:
        The ``ops`` list that was passed in.
    """
    pattern0 = \
        OpTypePattern('TFL_COMPLEX_ABS', name='complex_abs', inputs=[
            OpTypePattern('TFL_RESHAPE', name='reshape', inputs=[
                OpTypePattern('TFL_RFFT2D', name='rfft2d', inputs=[
                    OpTypePattern('*'),
                    OpTypePattern('Const|ConstV2', name='length'),
                ]),
                OpTypePattern('Const|ConstV2', name='shape'),
            ], allow_reorder=True),
        ])

    matcher = GraphMatcher(pattern0, allow_reorder=False)
    match_results = list(matcher.match_ops(ops))
    if match_results:
        for match in match_results:
            length = match.get_op("length").get_tensor_value(as_list=True)
            rfft2d = match.get_op("rfft2d")
            complex_abs = match.get_op("complex_abs")
            reshape = match.get_op("reshape")
            shape = match.get_op("shape").get_tensor_value(as_list=True)
            output_shape = g.get_shape(rfft2d.output[0])
            # The reshape must do nothing but drop a size-1 axis that sits
            # just before the last dimension of the RFFT2D output.
            if output_shape is None or output_shape != shape[:-1] + [1, shape[-1]]:
                continue
            # Only a first length entry of 1 is rewritten.
            if length[0] != 1:
                continue
            rfft2d.type = "RFFT"
            # Call order matters from here on: each step reads edges that the
            # following steps rewire.
            g.copy_shape(complex_abs.input[0], rfft2d.output[0])
            # Skip the Reshape
            g.replace_input(complex_abs, complex_abs.input[0], rfft2d.output[0], 0)
            # The retyped op receives only the second length entry, as int64.
            new_length = g.make_const(utils.make_name("rfft_length"), np.array([length[1]], np.int64))
            g.replace_input(rfft2d, rfft2d.input[1], new_length.output[0], 1)
            g.replace_all_inputs(complex_abs.output[0], reshape.output[0])
            # Move reshape below complex abs
            g.replace_input(reshape, reshape.input[0], complex_abs.output[0], 0)

    return ops
def _match_cell(self, context, unittype):
    """Return the unique cell match inside the loop body, or None.

    Each pattern from ``get_pattern(unittype)`` is tried in turn against the
    sub-graph spanned by the loop's state/scan inputs and outputs. The first
    pattern that matches exactly once wins.
    """
    for cell_pattern in get_pattern(unittype):
        matcher = GraphMatcher(cell_pattern, allow_reorder=True)

        props = context.loop_properties
        source_ids = {info.id for info in props.state_inputs + props.scan_inputs}
        sink_ids = {info.id for info in props.state_outputs + props.scan_outputs}
        subgraph = LoopRewriterBase.find_subgraph(source_ids, sink_ids, self.g, merge_as_end=True)
        body_graph_ops = subgraph[0]

        found = list(matcher.match_ops(body_graph_ops))
        if len(found) == 1:
            return found[0]
    return None
def rewrite_transpose(g, ops):
    """Give Transpose nodes fed by Sub(Sub(*, *), Range(*, *, *)) a static reversed ``perm``.

    The permutation is the reversed axis order of the Transpose's first
    input, taken from that input's shape. The second Transpose input is
    removed and all matched nodes except the Transpose itself are deleted.

    Returns:
        The ``ops`` list that was passed in.
    """
    reversed_axes = OpTypePattern('Sub', inputs=[
        OpTypePattern('Sub', inputs=["*", "*"]),
        OpTypePattern('Range', inputs=["*", "*", "*"]),
    ])
    pattern = OpTypePattern('Transpose', name='output', inputs=[
        OpTypePattern(None),
        reversed_axes,
    ])

    for hit in list(GraphMatcher(pattern).match_ops(ops)):
        transpose = hit.get_op('output')
        rank = len(g.get_shape(transpose.input[0]))
        transpose.set_attr("perm", list(range(rank - 1, -1, -1)))
        g.remove_input(transpose, transpose.input[1])

        leftovers = []
        for node in hit.get_nodes():
            if node != transpose:
                leftovers.append(node)
        g.safe_remove_nodes(leftovers)
    return ops
def rewrite_biasadd_with_conv2d(g, ops):
    """Merge each BiasAdd that consumes a Conv2D/Conv2DBackpropInput into one conv node.

    Two conv shapes are handled: two inputs, and three inputs (matched with
    reordering allowed on the BiasAdd). The bias (BiasAdd input 1) is
    appended after the conv's own inputs. A match is skipped when the conv
    output has more than one consumer.

    Returns:
        The ``ops`` list that was passed in.
    """
    conv_types = 'Conv2D|Conv2DBackpropInput'
    pattern1 = OpTypePattern('BiasAdd', name='biasadd', inputs=[
        OpTypePattern(conv_types, name='conv', inputs=['*', '*']),
        '*',
    ])
    pattern2 = OpTypePattern('BiasAdd', name='biasadd', inputs=[
        OpTypePattern(conv_types, name='conv', inputs=['*', '*', '*']),
        '*',
    ], allow_reorder=True)

    for pattern in [pattern1, pattern2]:
        # How many of the conv's own inputs the fused node keeps.
        kept_inputs = 3 if pattern == pattern2 else 2
        for hit in list(GraphMatcher(pattern).match_ops(ops)):
            bias_node = hit.get_op('biasadd')
            conv_node = hit.get_op('conv')

            # Read everything needed from both nodes before they are removed.
            fused_type = conv_node.type
            conv_in = conv_node.input
            fused_attr = conv_node.attr
            out_dtype = g.get_dtype(conv_node.output[0])
            out_shape = g.get_shape(conv_node.output[0])
            fused_name = bias_node.name
            fused_outputs = bias_node.output
            fused_inputs = [conv_in[i] for i in range(kept_inputs)]
            fused_inputs.append(bias_node.input[1])

            # The conv must stay when something besides the BiasAdd reads it.
            if len(g.find_output_consumers(conv_node.output[0])) > 1:
                continue

            g.remove_node(conv_node.name)
            g.remove_node(bias_node.name)
            g.make_node(fused_type, fused_inputs, attr=fused_attr, name=fused_name,
                        outputs=fused_outputs, shapes=[out_shape], dtypes=[out_dtype],
                        skip_conversion=False)
    return ops
def rewrite_random_uniform(g, ops):
    """Replace Add(Mul(RandomUniform, Sub(max, min)), *) with one ONNX random-uniform op.

    ``max`` and ``min`` are the first elements of the tensor values of the
    Sub node's two inputs. Each match is swapped out via
    ``g.replace_subgraph``.

    Returns:
        The node list produced by the last ``replace_subgraph`` call, or the
        incoming ``ops`` when nothing matched.
    """
    value_range = OpTypePattern('Sub', name='input2', inputs=["*", "*"])
    random_source = OpTypePattern('RandomUniform', name='input1', inputs=["*"])
    pattern = OpTypePattern('Add', name='output', inputs=[
        OpTypePattern('Mul', inputs=[random_source, value_range]),
        None,
    ])

    # Matches are collected once, against the incoming node list.
    for hit in list(GraphMatcher(pattern).match_ops(ops)):
        range_node = hit.get_op('input2')
        add_node = hit.get_op('output')
        ru_op = hit.get_op('input1')
        # Sub computes max - min, so max is input 0 and min is input 1.
        upper = range_node.inputs[0].get_tensor_value()[0]
        lower = range_node.inputs[1].get_tensor_value()[0]
        replacement = create_onnx_random_uniform_op(g, upper, lower, ru_op, add_node)
        ops = g.replace_subgraph(ops, hit, [], [add_node], [], [replacement])
    return ops
def rewrite_random_uniform_fold_const(g, ops):
    """Replace Add(Mul(RandomUniform, c1), c2) with one ONNX random-uniform op.

    This is the variant where ``max - min`` is already folded into ``c1``:
    ``min`` is the first element of the Add's second input and ``max`` is
    ``min`` plus the first element of the Mul's second input.

    Returns:
        The node list produced by the last ``replace_subgraph`` call, or the
        incoming ``ops`` when nothing matched.
    """
    random_source = OpTypePattern('RandomUniform', name='input1', inputs=["*"])
    scaled = OpTypePattern('Mul', name='mul', inputs=[random_source, None])
    pattern = OpTypePattern('Add', name='output', inputs=[scaled, None])

    # Matches are collected once, against the incoming node list.
    for hit in list(GraphMatcher(pattern).match_ops(ops)):
        add_node = hit.get_op('output')
        mul_node = hit.get_op('mul')
        ru_op = hit.get_op('input1')

        span = mul_node.inputs[1].get_tensor_value()[0]
        lower = add_node.inputs[1].get_tensor_value()[0]
        upper = lower + span

        replacement = create_onnx_random_uniform_op(g, upper, lower, ru_op, add_node)
        ops = g.replace_subgraph(ops, hit, [], [add_node], [], [replacement])
    return ops
def rewrite_tfl_select_zero_mul(g, ops):
    """Flag TFL_SELECT_V2(x == 0, 0, x * y) nodes with ``handles_nan``.

    A select node is flagged only when both constants compare equal to 0 and
    the term compared against zero is also one of the two Mul operands
    (compared by node name).

    Returns:
        The ``ops`` list that was passed in.
    """
    is_zero = OpTypePattern('TFL_EQUAL', name='equal', inputs=[
        OpTypePattern('Const|ConstV2', name='const_eq'),
        OpTypePattern('*', name='term_eq'),
    ], allow_reorder=True)
    product = OpTypePattern('TFL_MUL', name='mul', inputs=[
        OpTypePattern('*', name='term_mul1'),
        OpTypePattern('*', name='term_mul2'),
    ])
    pattern0 = OpTypePattern('TFL_SELECT_V2', name='select', inputs=[
        is_zero,
        OpTypePattern('Const|ConstV2', name='const_select'),
        product,
    ])

    matcher = GraphMatcher(pattern0, allow_reorder=False)
    for hit in list(matcher.match_ops(ops)):
        select = hit.get_op("select")
        compared = hit.get_op("term_eq")
        const_select = hit.get_op("const_select")
        const_eq = hit.get_op("const_eq")
        factor_a = hit.get_op("term_mul1")
        factor_b = hit.get_op("term_mul2")

        if (const_select.get_tensor_value(as_list=True) != 0
                or const_eq.get_tensor_value(as_list=True) != 0):
            continue
        # The zero-tested term may be either operand of the multiplication.
        if factor_a.name != compared.name and factor_b.name != compared.name:
            continue

        # Tell downstream conversion to avoid Mul/Add optimization
        select.set_attr("handles_nan", True)
    return ops
def rewrite_dropout(g, ops):
    """Replace recognised tensorflow dropout sub-graphs with an ONNX ``Dropout`` node.

    Three pattern shapes are tried, all with operand reordering allowed. A
    match is replaced only when:
      * the ``input3`` node is a scalar (it supplies the ratio),
      * one operand of the ``input2`` node is a scalar scaling constant,
      * ``scaling_constant * (1 - ratio)`` is close to 1, and
      * the matched nodes (except ``input3``) can be removed safely.
    Every rejected match is reported with a warning.

    Returns:
        The ``ops`` list that was passed in.
    """

    def random_source():
        return OpTypePattern('RandomUniform|RandomUniformLike')

    def keep_mask():
        # Cast(GreaterEqual(random, ratio)); built fresh for each pattern.
        return OpTypePattern("Cast", inputs=[
            OpTypePattern("GreaterEqual", inputs=[
                random_source(),
                OpTypePattern("*", name="input3"),
            ]),
        ])

    floor_style = OpTypePattern('Mul', name='outputs', inputs=[
        OpTypePattern('RealDiv', name="input2"),
        OpTypePattern('Floor', inputs=[
            OpTypePattern('Add', inputs=[
                OpTypePattern("*", name="input3"),
                random_source(),
            ]),
        ]),
    ])
    mask_last = OpTypePattern("Mul", name="outputs", inputs=[
        OpTypePattern("Mul", name="input2"),
        keep_mask(),
    ])
    # pattern for tf-2.0 tf.nn.dropout()
    mask_first = OpTypePattern("Mul", name="outputs", inputs=[
        keep_mask(),
        OpTypePattern("Mul", name="input2"),
    ])

    for pattern in (floor_style, mask_last, mask_first):
        matcher = GraphMatcher(pattern, allow_reorder=True)
        for hit in list(matcher.match_ops(ops)):
            scaled = hit.get_op('input2')
            ratio_node = hit.get_op('input3')
            root = hit.get_op('outputs')

            if not ratio_node.is_scalar():
                logger.warning(
                    "Dropout pattern rooted at %s does not have a "
                    "constant ratio and cannot be replaced.", root.name)
                continue
            ratio = ratio_node.get_tensor_value()

            # Whichever operand of the scaling op is scalar is the constant;
            # the other operand is the data.
            if scaled.inputs[0].is_scalar():
                data = scaled.inputs[1]
                scaling_constant = scaled.inputs[0].get_tensor_value()
            elif scaled.inputs[1].is_scalar():
                data = scaled.inputs[0]
                scaling_constant = scaled.inputs[1].get_tensor_value()
            else:
                logger.warning(
                    "Could not find scaling constant for dropout pattern rooted at %s. "
                    "The pattern will not be replaced with an ONNX dropout node.", root.name)
                continue

            # The scaling constant should be 1/(1-ratio), otherwise this isn't truly a dropout node
            consistent = np.allclose([1], [scaling_constant * (1 - ratio)])
            if not consistent:
                logger.warning(
                    "Scaling constant %f for dropout pattern rooted at %s is inconsistent with dropout "
                    "ratio %f. The pattern will not be replaced with an ONNX dropout node.",
                    scaling_constant, root.name, ratio)
                continue

            # The ratio node itself is left in the graph.
            nodes_to_remove = [node for node in hit.get_nodes() if node.name != ratio_node.name]
            if not g.is_safe_to_remove_nodes(nodes_to_remove, [root.output[0]]):
                logger.warning(
                    "Nodes in dropout pattern rooted at %s cannot be removed because intermediate results "
                    "of some nodes are referenced elsewhere in graph.", root.name)
                continue

            op_name = utils.make_name("Dropout")
            out_name = utils.port_name(op_name)
            data_edge = data.output[0]
            new_node = g.make_node(
                "Dropout",
                inputs=[data_edge],
                outputs=[out_name],
                name=op_name,
                attr={"ratio": ratio},
                shapes=[g.get_shape(data_edge)],
                dtypes=[g.get_dtype(data_edge)],
            )
            g.replace_all_inputs(root.output[0], new_node.output[0], ops=ops)
            for node in nodes_to_remove:
                g.remove_node(node.name)

    return ops
def rewrite_flatten(g, ops):
    """Replace Reshape(x, Pack(StridedSlice(...), -1)) with an ONNX ``Flatten`` node.

    Two variants are matched up front: one where the StridedSlice reads an
    arbitrary input (which must then be a constant equal to the input shape)
    and one where it reads ``Shape`` of some tensor. A match is rewritten only
    when the Pack's second input is the constant -1, the Reshape input shape
    is known, and the slice has begin [0], a single-element end one past
    begin, and strides [1].

    Returns:
        The ``ops`` list that was passed in.
    """

    def reshape_pattern(slice_source):
        # Reshape(input, Pack(StridedSlice(slice_source, *, *, *), *))
        return OpTypePattern('Reshape', name='reshape', inputs=[
            OpTypePattern("*", name="input"),
            OpTypePattern('Pack', name="pack", inputs=[
                OpTypePattern('StridedSlice', name="slice", inputs=[
                    slice_source, "*", "*", "*",
                ]),
                "*",
            ]),
        ])

    pattern_fixed_shape_input = reshape_pattern("*")
    pattern_non_fixed_shape_input = reshape_pattern(
        OpTypePattern('Shape', inputs=[OpTypePattern("*", name="input2")]))

    # Both match lists are built before any rewriting starts.
    fixed_matches = list(GraphMatcher(pattern_fixed_shape_input).match_ops(ops))
    dynamic_matches = list(GraphMatcher(pattern_non_fixed_shape_input).match_ops(ops))

    for candidates, check_fixed_input_shape in [(fixed_matches, True), (dynamic_matches, False)]:
        for hit in candidates:
            input_node = hit.get_op('input')
            reshape_node = hit.get_op('reshape')
            pack_node = hit.get_op('pack')
            slice_node = hit.get_op('slice')

            tail = pack_node.inputs[1]
            if not (tail.is_const() and tail.get_tensor_value() == -1):
                continue

            input_shape = g.get_shape(reshape_node.input[0])
            if input_shape is None:
                continue

            if check_fixed_input_shape:
                sliced = slice_node.inputs[0]
                same_shape = sliced.is_const() and \
                    np.array_equal(list(input_shape), list(sliced.get_tensor_value()))
                if not same_shape:
                    continue

            begin = slice_node.inputs[1].get_tensor_value(as_list=False)
            end = slice_node.inputs[2].get_tensor_value(as_list=False)
            strides = slice_node.inputs[3].get_tensor_value(as_list=False)
            takes_first_dim = np.array_equal(begin, [0]) and len(end) == 1 and \
                np.array_equal(strides, [1]) and end[0] - begin[0] == 1
            if not takes_first_dim:
                continue

            op_name = utils.make_name("Flatten")
            out_name = port_name(op_name)
            g.make_node("Flatten", [reshape_node.input[0]], outputs=[out_name], name=op_name)

            # The last two dimensions are merged; unknown sizes give -1.
            last_dim = input_shape[-1]
            sec_last_dim = input_shape[-2]
            new_dim = last_dim * sec_last_dim if last_dim > 0 and sec_last_dim > 0 else -1

            g.set_shape(out_name, input_shape[:-2] + [new_dim])
            g.replace_all_inputs(ops, reshape_node.output[0], out_name)

            to_delete = [node for node in hit.get_nodes() if node != input_node]
            g.safe_remove_nodes(to_delete)

    return ops