def rewrite_random_uniform(g, ops):
    pattern = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', inputs=[
                OpTypePattern('RandomUniform', name='input1', inputs=["*"]),
                OpTypePattern('Sub', name='input2', inputs=["*", "*"]),
            ]), None
        ])

    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        input2 = match.get_op('input2')
        output = match.get_op('output')
        ru_op = match.get_op('input1')
        # max is on input 0
        tmax = input2.inputs[0].get_tensor_value()
        tmin = input2.inputs[1].get_tensor_value()
        to_delete = list(set(match.get_nodes()))
        new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete)
        g.replace_all_inputs(output.output[0], new_node.output[0], ops=ops)
        g.safe_remove_nodes(to_delete)
    return ops

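# --- Editor's illustrative sketch (not part of tf2onnx; helper name is hypothetical) ---
# The subgraph matched above computes RandomUniform(0, 1) * (max - min) + min, which is
# exactly a uniform draw in [min, max). create_onnx_random_uniform_op relies on that
# identity to collapse the Mul/Sub/Add chain into a single ONNX RandomUniform node.
import numpy as np

def _uniform_rescaling_sketch(tmin=2.0, tmax=5.0, n=100000, seed=0):
    rng = np.random.default_rng(seed)
    u = rng.random(n)                    # RandomUniform(0, 1)
    rescaled = u * (tmax - tmin) + tmin  # the matched Mul/Sub/Add chain
    assert rescaled.min() >= tmin and rescaled.max() < tmax
    return rescaled

_uniform_rescaling_sketch()
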
def _parse_input_ta(self, context):
    graph_inputs = [v.switch_true_identity_output.id for v in context.loop_properties.all_variables.values()
                    if v.switch_true_identity_output.id]
    matcher = GraphMatcher(self.ta_read_input_pattern, allow_reorder=False)
    match_results = matcher.match_ops(self.g.get_nodes())
    match_results = [r for r in match_results if r.get_op("ta_index").output[0] in graph_inputs]
    for match in match_results:
        ta_input_scatter = match.get_op("ta_input_scatter")
        # the 3rd input of scatter is the value
        data_input_id = ta_input_scatter.input[2]
        ta_read_node = match.get_op("ta_read")

        # todo: need check ta's index variable is a scalar starting from 1, and increase by 1 each iteration.
        # then we can be sure this is equivalent to scan input behavior.
        index_input_id = ta_read_node.input[1]
        unstacked_ta_consumer = match.get_op("ta_read").output[0]
        ta = InputTensorArray(data_input_id, index_input_id, unstacked_ta_consumer, self.g)
        context.loop_properties.add_scan_input(ta)

def test_match_flipped(self):
    n1 = helper.make_node("Sub", ["i1", "i1"], ["n1:0"], name="n1")
    n2 = helper.make_node("Add", ["i2", "i2"], ["n2:0"], name="n2")
    n3 = helper.make_node("Mul", ["n1:0", "n2:0"], ["n3:0"], name="n3")

    graph_proto = helper.make_graph(
        nodes=[n1, n2, n3],
        name="test",
        inputs=[
            helper.make_tensor_value_info("i1", TensorProto.FLOAT, [2, 2]),
            helper.make_tensor_value_info("i2", TensorProto.FLOAT, [2, 2])
        ],
        outputs=[
            helper.make_tensor_value_info("n2:0", TensorProto.FLOAT, [2, 2])
        ],
        initializer=[])
    g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
    pattern = OpTypePattern('Mul', inputs=[
        OpTypePattern('Add'),
        OpTypePattern('Sub')
    ])
    ops = g.get_nodes()
    matcher = GraphMatcher(pattern, allow_reorder=True)
    match_results = list(matcher.match_ops(ops))
    self.assertEqual(1, len(match_results))

def test_rewrite_subgraph(self):
    graph_proto = self.sample_net()
    g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
    pattern = \
        OpTypePattern('Abs', name='output', inputs=[
            OpTypePattern('Add', name='input')
        ])
    ops = g.get_nodes()
    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        input_node = match.get_op('input')
        output_node = match.get_op('output')
        op_name = utils.make_name("ReplacedOp")
        out_name = utils.port_name(op_name)
        new_node = g.make_node("Sub", inputs=input_node.input, outputs=[out_name], name=op_name)
        g.replace_all_inputs(output_node.output[0], new_node.output[0])  # ops=ops
        for n in set(match.get_nodes()):
            g.remove_node(n.name)
    g.topological_sort(ops)
    result = onnx_to_graphviz(g)
    expected = 'digraph { Placeholder__5 [op_type=Placeholder] n1 [op_type=Abs] ' \
               'n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__6 [op_type=Sub] ' \
               'n6 [op_type=Identity] n5_graph_outputs_Identity__4 [op_type=Identity] ' \
               'input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__6 n3:0 -> ReplacedOp__6 ' \
               'ReplacedOp__6:0 -> n6 ReplacedOp__6:0 -> n5_graph_outputs_Identity__4 }'
    self.assertEqual(expected, result)

def rewrite_quantize_and_dequantize(g, ops):
    pattern_for_qdq_v2 = \
        OpTypePattern('QuantizeAndDequantizeV2', name='output', inputs=[
            OpTypePattern("*"),
            OpTypePattern(None),
            OpTypePattern(None),
        ])
    pattern_for_qdq_v3 = \
        OpTypePattern('QuantizeAndDequantizeV3', name='output', inputs=[
            OpTypePattern("*"),
            OpTypePattern(None),
            OpTypePattern(None),
            OpTypePattern(None),
        ])

    # Match all the patterns for QDQ ops
    patterns = [pattern_for_qdq_v3, pattern_for_qdq_v2]
    match_results = []
    for pattern in patterns:
        matcher = GraphMatcher(pattern)
        results = list(matcher.match_ops(ops))
        match_results.extend(results)

    return create_qdq_nodes(g, match_results)

def rewrite_random_uniform_fold_const(g, ops):
    pattern = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', name='mul', inputs=[
                OpTypePattern('RandomUniform', name='input1', inputs=["*"]),
                None,
            ]),
            None,
        ])

    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        output = match.get_op('output')
        mul = match.get_op('mul')
        ru_op = match.get_op('input1')
        tmax_minus_tmin = mul.inputs[1].get_tensor_value()
        tmin = output.inputs[1].get_tensor_value()
        tmax = tmin + tmax_minus_tmin
        to_delete = list(set(match.get_nodes()))
        new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete)
        g.replace_all_inputs(output.output[0], new_node.output[0], ops=ops)
        g.safe_remove_nodes(to_delete)
    return ops

def rewrite_biasadd_with_conv2d(g, ops):
    pattern = \
        OpTypePattern('BiasAdd', name='biasadd', inputs=[
            OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=['*', '*']), '*'])
    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        biasadd = match.get_op('biasadd')
        conv = match.get_op('conv')

        # backup the conv and biasadd values
        conv_type = conv.type
        conv_input = conv.input
        conv_attr = conv.attr
        dtype = g.get_dtype(conv.output[0])
        shape = g.get_shape(conv.output[0])
        conv_name = biasadd.name
        conv_output = biasadd.output
        conv_inputs = [conv_input[0], conv_input[1], biasadd.input[1]]

        # Remove the Conv and BiasAdd node
        g.remove_node(conv.name)
        g.remove_node(biasadd.name)

        g.make_node(conv_type, conv_inputs, attr=conv_attr, name=conv_name, outputs=conv_output,
                    shapes=[shape], dtypes=[dtype], skip_conversion=False)
    return ops

def rewrite_tfl_qdq(g, ops):
    pattern0 = \
        OpTypePattern('TFL_DEQUANTIZE', name='dequant', inputs=[
            OpTypePattern('TFL_QUANTIZE', name='quant'),
        ])

    matcher = GraphMatcher(pattern0, allow_reorder=False)
    match_results = list(matcher.match_ops(ops))
    if match_results:
        for match in match_results:
            dequant = match.get_op("dequant")
            quant = match.get_op("quant")
            inp_node = quant.inputs[0]
            for k in ["scale", "quantized_dimension", "zero_point"]:
                if dequant.get_attr_value(k) != quant.get_attr_value(k):
                    continue
            needed_relu = None
            if all(k in quant.attr and len(quant.get_attr_value(k)) == 1 for k in ["min", "max"]):
                min_val = quant.get_attr_value("min")[0]
                max_val = quant.get_attr_value("max")[0]
                if min_val == 0.0 and 5.999 <= max_val <= 6.0:
                    needed_relu = "TFL_RELU6"
                elif min_val == 0.0:
                    # This may introduce unneeded relu ops but will be correct.
                    # If the --dequantize feature is used a lot in the future we can optimize this.
                    needed_relu = "TFL_RELU"
                if inp_node.type == needed_relu:
                    # If it's really obviously unneeded, we skip it.
                    needed_relu = None
                elif "TFL_" + inp_node.get_attr_value("fused_activation_function", b'').decode() == needed_relu:
                    needed_relu = None
            if needed_relu is not None:
                relu_name = inp_node.name + "_relu"
                relu6 = g.make_node(needed_relu, [quant.input[0]], op_name_scope=relu_name,
                                    skip_conversion=False, shapes=quant.output_shapes,
                                    dtypes=quant.output_dtypes)
                g.replace_all_inputs(dequant.output[0], relu6.output[0])
            else:
                g.replace_all_inputs(dequant.output[0], quant.input[0])
            g.remove_node(dequant.name)
            if len(g.find_output_consumers(quant.output[0])) == 0:
                g.remove_node(quant.name)
    return ops

def rewrite_conv2d_with_pad(g, ops):
    pattern = \
        OpTypePattern("Conv2D", name="conv", inputs=[
            OpTypePattern("Pad", name="pad"),
            OpTypePattern("*")
        ])
    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        conv = match.get_op("conv")
        pad = match.get_op("pad")
        paddings = pad.inputs[1]

        if not paddings.is_const():
            continue
        mode = pad.get_attr("mode")
        if mode:
            mode = mode.s.decode("utf-8").lower()
        if mode not in [None, "constant"] or len(pad.input) >= 3:
            continue
        # Conv2D already has a pad
        if conv.get_attr("padding").s.decode("utf-8") == "SAME":
            continue

        logger.debug("merge pad [%s] into conv [%s]", pad.name, conv.name)
        paddings_val = np.array(paddings.get_tensor_value())
        # can't pad on batch or channel dimensions
        data_format = conv.get_attr("data_format").s.decode("utf-8")
        if data_format == "NHWC":
            if np.any(paddings_val[0]) or np.any(paddings_val[3]):
                continue
            paddings_val = paddings_val[1:3]
        else:
            if np.any(paddings_val[0]) or np.any(paddings_val[1]):
                continue
            paddings_val = paddings_val[2:4]

        paddings_val = paddings_val.transpose().flatten()
        g.replace_input(conv, conv.input[0], pad.input[0], 0)
        # convert Conv2D
        conv.type = "Conv2D"
        func, _ = handler.tf_op.find_effective_op("Conv2D")
        func(g, conv)
        conv.skip_conversion = True
        conv.set_attr("auto_pad", "NOTSET")
        conv.set_attr("pads", paddings_val)
    return ops

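# --- Editor's illustrative sketch (not part of tf2onnx) ---
# The transpose().flatten() above converts TF's per-dimension [begin, end] padding pairs
# for the two spatial dims into ONNX's [begin_h, begin_w, end_h, end_w] layout expected
# by the Conv "pads" attribute.
import numpy as np

tf_spatial_pads = np.array([[1, 2],   # height: pad 1 before, 2 after
                            [3, 4]])  # width:  pad 3 before, 4 after
onnx_pads = tf_spatial_pads.transpose().flatten()
assert list(onnx_pads) == [1, 3, 2, 4]  # [begin_h, begin_w, end_h, end_w]
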
def rewrite_thresholded_relu(g, ops):
    if g.opset < 10:
        return ops

    pattern = \
        OpTypePattern('Mul', name='mul', inputs=[
            OpTypePattern('Cast', name='cast', inputs=[
                OpTypePattern('Greater', name='greater', inputs=[
                    OpTypePattern('*', name='greater_input'),
                    OpTypePattern('Const', name='theta')
                ])
            ]),
            OpTypePattern('*', name='mul_input')
        ])
    matcher = GraphMatcher(pattern, allow_reorder=True)
    match_results = list(matcher.match_ops(ops))

    for match in match_results:
        greater_node = match.get_op('greater')
        greater_input_node = match.get_op('greater_input')
        mul_node = match.get_op("mul")
        mul_input_node = match.get_op('mul_input')
        cast_node = match.get_op('cast')

        greater_input_edge_name = _find_edge_name_between_nodes(greater_input_node, greater_node)
        mul_input_edge_name = _find_edge_name_between_nodes(mul_input_node, mul_node)
        if greater_input_edge_name == mul_input_edge_name:
            theta = match.get_op('theta').get_tensor_value()
            thresholded_relu = g.make_node("ThresholdedRelu", inputs=[mul_input_edge_name],
                                           attr={"alpha": theta},
                                           shapes=[g.get_shape(mul_node.output[0])],
                                           dtypes=[g.get_dtype(mul_node.output[0])])
            g.replace_all_inputs(mul_node.output[0], thresholded_relu.output[0], ops=ops)
            to_delete = [cast_node, mul_node]
            g.safe_remove_nodes(to_delete)
    return ops

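# --- Editor's illustrative sketch (not part of tf2onnx; helper name is hypothetical) ---
# The matched subgraph computes x * cast(x > theta), which matches the ONNX
# ThresholdedRelu definition y = x if x > alpha else 0, with alpha = theta.
import numpy as np

def _thresholded_relu_sketch(x, theta=1.0):
    tf_style = x * (x > theta).astype(x.dtype)   # Greater -> Cast -> Mul
    onnx_style = np.where(x > theta, x, 0.0)     # ThresholdedRelu semantics
    assert np.allclose(tf_style, onnx_style)
    return onnx_style

_thresholded_relu_sketch(np.linspace(-2.0, 2.0, 9))
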
def rewrite_leakyrelu(g, ops):
    if g.opset < 6:
        return ops

    pattern = \
        OpTypePattern('Maximum', name='max', inputs=[
            OpTypePattern('Mul', name='mul', inputs=[
                OpTypePattern('Const', name='alpha'),
                OpTypePattern('*', name='mul_input'),
            ]),
            OpTypePattern('*', name='max_input'),
        ])
    matcher = GraphMatcher(pattern, allow_reorder=True)
    match_results = list(matcher.match_ops(ops))

    for match in match_results:
        max_node = match.get_op('max')
        max_input_node = match.get_op('max_input')
        mul_node = match.get_op("mul")
        mul_input_node = match.get_op('mul_input')

        max_input_edge_name = _find_edge_name_between_nodes(max_input_node, max_node)
        mul_input_edge_name = _find_edge_name_between_nodes(mul_input_node, mul_node)
        if max_input_edge_name == mul_input_edge_name:
            alpha = match.get_op("alpha").get_tensor_value()
            if alpha >= 1:
                continue
            leakyrelu = g.make_node("LeakyRelu", inputs=[max_input_edge_name],
                                    attr={"alpha": alpha},
                                    shapes=[g.get_shape(max_node.output[0])],
                                    dtypes=[g.get_dtype(max_node.output[0])])
            ops.append(leakyrelu)
            g.replace_all_inputs(max_node.output[0], leakyrelu.output[0], ops=ops)
            to_delete = [max_node, mul_node]
            g.safe_remove_nodes(to_delete)
    return ops

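# --- Editor's illustrative sketch (not part of tf2onnx; helper name is hypothetical) ---
# For 0 <= alpha < 1, max(x, alpha * x) equals LeakyRelu(x, alpha); that is why the
# rewriter skips matches with alpha >= 1.
import numpy as np

def _leakyrelu_sketch(x, alpha=0.2):
    tf_style = np.maximum(x, alpha * x)          # Mul -> Maximum
    onnx_style = np.where(x >= 0, x, alpha * x)  # LeakyRelu semantics
    assert np.allclose(tf_style, onnx_style)
    return onnx_style

_leakyrelu_sketch(np.linspace(-2.0, 2.0, 9))
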
def rewrite_transpose(g, ops):
    pattern = \
        OpTypePattern('Transpose', name='output', inputs=[
            OpTypePattern(None),
            OpTypePattern('Sub', inputs=[
                OpTypePattern('Sub', inputs=["*", "*"]),
                OpTypePattern('Range', inputs=["*", "*", "*"]),
            ]),
        ])

    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        output = match.get_op('output')
        shape = g.get_shape(output.input[0])
        dims = range(len(shape) - 1, -1, -1)
        output.set_attr("perm", dims)
        g.remove_input(output, output.input[1], 1)
        to_delete = [n for n in match.get_nodes() if n != output]
        g.safe_remove_nodes(to_delete)
    return ops

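# --- Editor's illustrative sketch (not part of tf2onnx) ---
# The Sub/Sub/Range subgraph builds the reversed permutation [rank-1, ..., 1, 0] at
# runtime; since the input rank is known from its shape, the rewriter folds it into a
# constant "perm" attribute.
rank = 4
perm = list(range(rank - 1, -1, -1))
assert perm == [3, 2, 1, 0]
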
def _match_cell(self, context, unittype):
    """match unit cell"""
    for cell_pattern in get_pattern(unittype):
        matcher = GraphMatcher(cell_pattern, allow_reorder=True)

        loop_props = context.loop_properties
        inputs = loop_props.state_inputs + loop_props.scan_inputs
        input_ids = [input_tensor_value_info.id for input_tensor_value_info in inputs]
        outputs = loop_props.state_outputs + loop_props.scan_outputs
        output_ids = [out_tensor_value_info.id for out_tensor_value_info in outputs]
        body_graph_ops, _, _ = LoopRewriterBase.find_subgraph(
            set(input_ids), set(output_ids), self.g, merge_as_end=True)

        match_results = list(matcher.match_ops(body_graph_ops))
        if len(match_results) == 1:
            return match_results[0]
    return None

def rewrite_gemm(g, ops):
    if g.opset <= 6:
        return ops

    # pattern0: alpha*A*B + beta*C
    pattern0 = \
        OpTypePattern('Add|AddV2', name='add', inputs=[
            OpTypePattern('Mul', name='mul1', inputs=[
                OpTypePattern('Const', name='alpha'),
                OpTypePattern('MatMul', name='matmul')
            ]),
            OpTypePattern('Mul', name='mul2', inputs=[
                OpTypePattern('Const', name='beta'),
                OpTypePattern('*', name='C')
            ])
        ])

    # pattern1: alpha*A*B + C
    pattern1 = \
        OpTypePattern('Add|AddV2', name='add', inputs=[
            OpTypePattern('Mul', name='mul1', inputs=[
                OpTypePattern('MatMul', name='matmul'),
                OpTypePattern('Const', name='alpha')
            ]),
            OpTypePattern('*', name='C'),
        ])

    # pattern2: A*B + beta*C
    pattern2 = \
        OpTypePattern('Add|AddV2', name='add', inputs=[
            OpTypePattern('MatMul', name='matmul'),
            OpTypePattern('Mul', name='mul2', inputs=[
                OpTypePattern('Const', name='beta'),
                OpTypePattern('*', name='C')
            ])
        ])

    # pattern3: A*B + C
    pattern3 = \
        OpTypePattern('Add|AddV2', name='add', inputs=[
            OpTypePattern('MatMul', name='matmul'),
            OpTypePattern('*', name='C'),
        ])

    # pattern4: A*B + c
    pattern4 = \
        OpTypePattern('BiasAdd', name='add', inputs=[
            OpTypePattern('MatMul', name='matmul'),
            OpTypePattern('*', name='C'),
        ])

    pattern_list = [pattern0, pattern1, pattern2, pattern3, pattern4]

    for pattern in pattern_list:
        matcher = GraphMatcher(pattern, allow_reorder=True)
        match_results = list(matcher.match_ops(ops))
        if match_results:
            for match in match_results:
                matmul_node = match.get_op("matmul")

                if g.get_dtype(matmul_node.input[0]) != onnx_pb.TensorProto.FLOAT:
                    logging.warning(u"For now, onnxruntime only support float32 type for Gemm rewriter")
                    continue

                attr, is_valid = get_gemm_attr(match)
                if not is_valid:
                    continue

                add_node = match.get_op('add')
                input_c_node = match.get_op("C")

                a_edge_name = matmul_node.input[0]
                b_edge_name = matmul_node.input[1]
                c_edge_name = input_c_node.output[0]

                a_mul_b_shape = g.get_shape(matmul_node.output[0])
                c_shape = g.get_shape(c_edge_name)
                if c_shape is None:
                    continue
                if a_mul_b_shape is None:
                    continue
                if -1 in c_shape + a_mul_b_shape:
                    continue
                if g.get_rank(a_edge_name) != 2 or g.get_rank(b_edge_name) != 2:
                    continue

                compatible = True
                for i in range(1, len(c_shape) + 1):
                    if c_shape[-i] not in [1, a_mul_b_shape[-i]]:
                        compatible = False
                if not compatible:
                    continue

                gemm = g.make_node("Gemm", inputs=[a_edge_name, b_edge_name, c_edge_name],
                                   attr=attr,
                                   shapes=[g.get_shape(add_node.output[0])],
                                   dtypes=[g.get_dtype(add_node.output[0])],
                                   op_name_scope=matmul_node.name)

                ops.append(gemm)
                g.replace_all_inputs(add_node.output[0], gemm.output[0], ops=ops)
                to_delete = [add_node, matmul_node]
                g.safe_remove_nodes(to_delete)
    return ops

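# --- Editor's illustrative sketch (not part of tf2onnx; helper name is hypothetical) ---
# ONNX Gemm computes alpha * A @ B + beta * C with C unidirectionally broadcastable to
# A @ B; the shape checks above (c_shape[-i] in [1, a_mul_b_shape[-i]]) guard exactly
# that contract before the Add/Mul/MatMul chain is replaced.
import numpy as np

def _gemm_sketch(a, b, c, alpha=1.0, beta=1.0):
    return alpha * (a @ b) + beta * c

a = np.ones((2, 3), dtype=np.float32)
b = np.ones((3, 4), dtype=np.float32)
c = np.ones((1, 4), dtype=np.float32)  # broadcastable bias, as in pattern4 (BiasAdd)
assert _gemm_sketch(a, b, c).shape == (2, 4)
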
def rewrite_eye(g, ops):
    # schema of eye is eye(num_rows, num_columns=None), if num_columns not specified then it's equal to num_rows
    # tf.eye is implemented by a sub_graph which contains op "MatrixDiag" or "MatrixSetDiag" while
    # these two ops are un-supported directly in onnx
    # but onnx op EyeLike can be used to map the sub_graph
    # "rewrite_eye" supports tf.eye(non_const) and tf.eye(non_const1, non_const2).
    # tf.eye(const) and tf.eye(const1, const2) are not supported in this rewriter

    # ConstantOfShape in opset 9 is used, so if opset less than 9 then do nothing
    if g.opset < 9:
        return g.get_nodes()

    pattern1 = \
        OpTypePattern("MatrixDiag", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill", inputs=[
                OpTypePattern("Const", name="fill_value"),
                OpTypePattern("ConcatV2", inputs=[
                    "*",
                    "*",
                    OpTypePattern("Pack", inputs=[
                        OpTypePattern("Minimum|Cast", name="min_or_cast")
                    ])
                ])
            ])
        ])
    pattern2 = \
        OpTypePattern("MatrixSetDiag", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill"),
            OpTypePattern("Fill", inputs=[
                OpTypePattern("Const", name="fill_value"),
                OpTypePattern("ConcatV2", inputs=[
                    "*",
                    "*",
                    OpTypePattern("Pack", inputs=[
                        OpTypePattern("Minimum|Cast", name="min_or_cast")
                    ])
                ])
            ])
        ])
    pattern3 = \
        OpTypePattern("MatrixDiag", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill", inputs=[
                OpTypePattern("ConcatV2", inputs=[
                    "*",
                    OpTypePattern("ExpandDims", inputs=[
                        OpTypePattern("Minimum|Cast", name="min_or_cast"),
                        "*"
                    ]),
                    "*",
                ]),
                OpTypePattern("Const", name="fill_value"),
            ])
        ])
    pattern4 = \
        OpTypePattern("MatrixSetDiag", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill"),
            OpTypePattern("Fill", inputs=[
                OpTypePattern("ConcatV2", inputs=[
                    "*",
                    OpTypePattern("ExpandDims", inputs=[
                        OpTypePattern("Minimum|Cast", name="min_or_cast"),
                        "*"
                    ]),
                    "*",
                ]),
                OpTypePattern("Const", name="fill_value"),
            ]),
        ])
    pattern5 = \
        OpTypePattern("MatrixDiagV3", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill", inputs=[
                OpTypePattern("ConcatV2", inputs=[
                    "*",
                    OpTypePattern("ExpandDims", inputs=[
                        OpTypePattern("Minimum|Cast", name="min_or_cast"),
                        "*"
                    ]),
                    "*",
                ]),
                OpTypePattern("Const", name="fill_value"),
            ]),
            "*", "*", "*", "*",
        ])
    pattern6 = \
        OpTypePattern("MatrixSetDiagV3", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill"),
            OpTypePattern("Fill", inputs=[
                OpTypePattern("ConcatV2", inputs=[
                    "*",
                    OpTypePattern("ExpandDims", inputs=[
                        OpTypePattern("Minimum|Cast", name="min_or_cast"),
                        "*"
                    ]),
                    "*",
                ]),
                OpTypePattern("Const", name="fill_value"),
            ]),
            "*"
        ])
    pattern7 = \
        OpTypePattern("MatrixDiag", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill", inputs=[
                OpTypePattern("Reshape", inputs=[
                    OpTypePattern("Minimum|Cast", name="min_or_cast"),
                    "*",
                ]),
                OpTypePattern("Const", name="fill_value"),
            ])
        ])
    pattern8 = \
        OpTypePattern("MatrixSetDiag", name="output_eye_matrix", inputs=[
            OpTypePattern("Fill"),
            OpTypePattern("Fill", inputs=[
                OpTypePattern("Reshape", inputs=[
                    OpTypePattern("Minimum|Cast", name="min_or_cast"),
                    "*",
                ]),
                OpTypePattern("Const", name="fill_value"),
            ])
        ])

    for pattern in [pattern1, pattern2, pattern3, pattern4, pattern5, pattern6, pattern7, pattern8]:
        matcher = GraphMatcher(pattern, allow_reorder=True)
        match_results = list(matcher.match_ops(ops))
        for match_result in match_results:
            if match_result.get_op("fill_value").get_tensor_value() != 1:
                continue

            min_or_cast = match_result.get_op("min_or_cast")
            if min_or_cast.type == "Minimum":
                min_node = min_or_cast
            elif min_or_cast.type == "Cast" and min_or_cast.inputs[0].type == "Minimum":
                min_node = min_or_cast.inputs[0]
            else:
                continue

            num_rows = min_node.inputs[0]
            num_columns = min_node.inputs[1]

            old_output = match_result.get_op("output_eye_matrix")
            output_dtypes = [g.get_dtype(old_output.output[0])]
            output_shapes = [g.get_shape(old_output.output[0])]
            g.remove_node(old_output.name)

            # onnx op "EyeLike" need a 2D tensor, so generate it
            num_rows = GraphBuilder(g).make_unsqueeze(
                {"axes": [0], "data": num_rows.output[0]}, return_node=True)
            num_columns = GraphBuilder(g).make_unsqueeze(
                {"axes": [0], "data": num_columns.output[0]}, return_node=True)
            matrix_shape = g.make_node(
                "Concat", [num_rows.output[0], num_columns.output[0]], attr={"axis": 0})
            # cast nodes added for "ConstantOfShape" in ONNX only accepts int64 data.
            matrix_shape_int64 = g.make_node(
                "Cast", matrix_shape.output, attr={"to": onnx_pb.TensorProto.INT64})
            zero_matrix = g.make_node("ConstantOfShape", matrix_shape_int64.output)

            g.make_node("EyeLike",
                        zero_matrix.output,
                        attr={"dtype": output_dtypes[0]},
                        name=old_output.name,
                        shapes=output_shapes,
                        dtypes=output_dtypes,
                        outputs=old_output.output)

    return g.get_nodes()

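# --- Editor's illustrative sketch (not part of tf2onnx) ---
# The replacement builds an all-zeros [num_rows, num_columns] matrix via ConstantOfShape
# and lets EyeLike set its main diagonal to one, reproducing tf.eye(num_rows, num_columns)
# for non-constant dimensions.
import numpy as np

num_rows, num_columns = 3, 5
zero_matrix = np.zeros((num_rows, num_columns), dtype=np.float32)  # ConstantOfShape output
eye_like = zero_matrix.copy()
np.fill_diagonal(eye_like, 1.0)                                    # EyeLike semantics
assert np.array_equal(eye_like, np.eye(num_rows, num_columns, dtype=np.float32))
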
def rewrite_dropout(g, ops):
    patterns = [
        OpTypePattern('Mul', name='outputs', inputs=[
            OpTypePattern('RealDiv', name="input2"),
            OpTypePattern('Floor', inputs=[
                OpTypePattern('Add', inputs=[
                    OpTypePattern("*", name="input3"),
                    OpTypePattern('RandomUniform|RandomUniformLike'),
                ])
            ]),
        ]),
        OpTypePattern("Mul", name="outputs", inputs=[
            OpTypePattern("Mul", name="input2"),
            OpTypePattern("Cast", inputs=[
                OpTypePattern("GreaterEqual", inputs=[
                    OpTypePattern("RandomUniform|RandomUniformLike"),
                    OpTypePattern("*", name="input3")
                ])
            ])
        ]),
        # pattern for tf-2.0 tf.nn.dropout()
        OpTypePattern("Mul", name="outputs", inputs=[
            OpTypePattern("Cast", inputs=[
                OpTypePattern("GreaterEqual", inputs=[
                    OpTypePattern("RandomUniform|RandomUniformLike"),
                    OpTypePattern("*", name="input3")
                ])
            ]),
            OpTypePattern("Mul", name="input2"),
        ]),
    ]
    for pattern in patterns:
        matcher = GraphMatcher(pattern, allow_reorder=True)
        match_results = list(matcher.match_ops(ops))
        for match in match_results:
            input2 = match.get_op('input2')
            input3 = match.get_op('input3')
            outputs = match.get_op('outputs')

            if not input3.is_scalar():
                logger.warning("Dropout pattern rooted at %s does not have a "
                               "constant ratio and cannot be replaced.", outputs.name)
                continue
            ratio = input3.get_tensor_value()

            if input2.inputs[0].is_scalar():
                data = input2.inputs[1]
                scaling_constant = input2.inputs[0].get_tensor_value()
            elif input2.inputs[1].is_scalar():
                data = input2.inputs[0]
                scaling_constant = input2.inputs[1].get_tensor_value()
            else:
                logger.warning("Could not find scaling constant for dropout pattern rooted at %s. "
                               "The pattern will not be replaced with an ONNX dropout node.", outputs.name)
                continue

            # The scaling constant should be 1/(1-ratio), otherwise this isn't truly a dropout node
            if not np.allclose([1], [scaling_constant * (1 - ratio)]):
                logger.warning("Scaling constant %f for dropout pattern rooted at %s is inconsistent with dropout "
                               "ratio %f. The pattern will not be replaced with an ONNX dropout node.",
                               scaling_constant, outputs.name, ratio)
                continue

            nodes_to_remove = [n for n in match.get_nodes() if n.name != input3.name]
            if not g.is_safe_to_remove_nodes(nodes_to_remove, [outputs.output[0]]):
                logger.warning("Nodes in dropout pattern rooted at %s cannot be removed because intermediate results "
                               "of some nodes are referenced elsewhere in graph.", outputs.name)
                continue

            op_name = utils.make_name("Dropout")
            out_name = utils.port_name(op_name)
            new_node = g.make_node("Dropout",
                                   inputs=[data.output[0]],
                                   outputs=[out_name],
                                   name=op_name,
                                   attr={"ratio": ratio},
                                   shapes=[g.get_shape(data.output[0])],
                                   dtypes=[g.get_dtype(data.output[0])])
            g.replace_all_inputs(outputs.output[0], new_node.output[0], ops=ops)
            for n in nodes_to_remove:
                g.remove_node(n.name)

    return ops

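# --- Editor's illustrative sketch (not part of tf2onnx) ---
# TF dropout scales the kept activations by 1 / (1 - ratio), so a genuine dropout
# subgraph must satisfy scaling_constant * (1 - ratio) == 1, which is exactly the
# np.allclose check above.
import numpy as np

ratio = 0.25
scaling_constant = 1.0 / (1.0 - ratio)  # the constant baked into the Mul/RealDiv input
assert np.allclose([1], [scaling_constant * (1 - ratio)])

rng = np.random.default_rng(0)
x = rng.standard_normal(8).astype(np.float32)
keep_mask = (rng.random(8) >= ratio).astype(x.dtype)  # GreaterEqual -> Cast
dropped = x * scaling_constant * keep_mask            # the matched Mul chain
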
def rewrite_tfl_scan_outputs(g, ops):
    pattern0 = \
        OpTypePattern('TFL_CONCATENATION', name='concat', inputs=[
            OpTypePattern('TFL_SLICE', name='begin_slice'),
            OpTypePattern('*', name='middle'),
            OpTypePattern('TFL_SLICE', name='end_slice')
        ])

    matcher = GraphMatcher(pattern0, allow_reorder=False)
    match_results = list(matcher.match_ops(ops))
    if match_results:
        for match in match_results:
            concat = match.get_op("concat")
            begin_slice = match.get_op("begin_slice")
            middle = match.get_op("middle")
            end_slice = match.get_op("end_slice")
            middle_shape = g.get_shape(middle.output[0])
            # Both slices must be slicing the same tensor
            if begin_slice.input[0] != end_slice.input[0]:
                continue
            original_tensor = begin_slice.input[0]
            if concat.get_attr_int("axis") != 0:
                continue
            # The inserted slice must have length 1 (to be a single index)
            if middle_shape is None or len(middle_shape) == 0 or middle_shape[0] != 1:
                continue
            rank = len(middle_shape)
            scan_output = middle.output[0]
            if not begin_slice.inputs[1].is_const() or not end_slice.inputs[2].is_const():
                continue
            # The first slice must start from the beginning (0) for all dims
            if not all(v == 0 for v in begin_slice.inputs[1].get_tensor_value()):
                continue
            # The second slice must slice to the end (-1) for all dims
            if not all(v == -1 for v in end_slice.inputs[2].get_tensor_value()):
                continue
            # The other slice dims are assembled by concatenation if rank > 1
            if rank > 1:
                begin_concat = begin_slice.inputs[2]
                end_concat = end_slice.inputs[1]
                if not begin_concat.type == "TFL_CONCATENATION":
                    continue
                if not end_concat.type == "TFL_CONCATENATION":
                    continue
                # Except for dim 0, slice from beginning to end
                if not all(get_uniform_const_val(inp) == -1 for inp in begin_concat.inputs[1:]):
                    continue
                if not all(get_uniform_const_val(inp) == 0 for inp in end_concat.inputs[1:]):
                    continue
                begin_idx = begin_concat.inputs[0]
                end_idx = end_concat.inputs[0]
            else:
                begin_idx = begin_slice.inputs[2]
                end_idx = end_slice.inputs[1]
            # For dim 0, slice to i for first part and from i+1 for second
            if not node_is_one_plus_node(begin_idx, end_idx):
                continue
            out1, _ = get_out_and_offset(begin_idx)
            graph_inps = [n.output[0] for n in g.inputs]
            # To be a scan output, i must be a graph input
            if out1 not in graph_inps:
                continue
            # The array being sliced must be a graph input
            if original_tensor not in graph_inps:
                continue
            # The input/output index of i
            idx = graph_inps.index(out1)
            # The input/output index of the array
            scan_output_idx = graph_inps.index(original_tensor)
            # For a scan output, i must be assigned to i+1 with each iteration
            if not node_is_one_plus_node(g.get_node_by_output(out1),
                                         g.get_node_by_output(g.outputs[idx])):
                continue
            if len(g.find_output_consumers(concat.output[0])) > 1:
                continue
            if g.opset < 10 and len(g.find_output_consumers(concat.output[0])) <= 1:
                # If opset is < 10, conversion of the subgraph will fail unless we remove the slice nodes
                # We add a tmp node to replace them.
                shape = g.get_shape(concat.output[0])
                dtype = g.get_dtype(concat.output[0])
                tmp_node = g.make_node("TMP_SCAN_OUTPUT", [original_tensor, scan_output],
                                       shapes=[shape], dtypes=[dtype])
                g.replace_all_inputs(concat.output[0], tmp_node.output[0])
            to_remove = []
            out = g.outputs[scan_output_idx]
            node = g.get_node_by_output(out)
            to_remove.append(node)
            while len(node.input) > 0 and node != concat:
                out = node.input[0]
                node = g.get_node_by_output(out)
                to_remove.append(node)
            to_remove += [begin_slice, end_slice, concat]
            out = original_tensor
            node = g.get_node_by_output(out)
            to_remove.append(node)
            while len(node.input) > 0:
                out = node.input[0]
                node = g.get_node_by_output(out)
                to_remove.append(node)
            if not g.is_safe_to_remove_nodes(to_remove):
                continue
            g.scan_outputs.append((scan_output_idx, scan_output))
    return ops

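# --- Editor's illustrative sketch (not part of tf2onnx) ---
# Concatenating x[:i], a length-1 update, and x[i+1:] along axis 0 is just a write of the
# update at index i; when i is the loop counter, this is the accumulation pattern that
# ONNX models as a Scan output, which is what the matched slice/concat subgraph encodes.
import numpy as np

x = np.zeros((5, 3), dtype=np.float32)
i = 2
update = np.ones((1, 3), dtype=np.float32)
written = np.concatenate([x[:i], update, x[i + 1:]], axis=0)  # begin_slice / middle / end_slice
expected = x.copy()
expected[i] = update
assert np.array_equal(written, expected)
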
def rewrite_random_normal(g, ops):
    pattern1 = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', name='input2', inputs=[
                OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"]), "*"
            ]), "*"
        ])
    pattern2 = \
        OpTypePattern('Identity', name='output', inputs=[
            OpTypePattern('Identity', name='input2', inputs=[
                OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"])
            ])
        ])

    pattern_list = [pattern1, pattern2]
    for pattern in pattern_list:
        matcher = GraphMatcher(pattern)
        match_results = list(matcher.match_ops(ops))
        for match in match_results:
            output = match.get_op('output')
            if output.type == 'Add':
                # pattern 1
                mean = output.inputs[1].get_tensor_value()
            else:
                # pattern 2
                mean = 0.0
            dtype = g.get_dtype(output.output[0])
            op_name = utils.make_name("RandomNormal")
            out_name = utils.port_name(op_name)

            rn_op = match.get_op('input1')
            seed = rn_op.get_attr('seed2').i

            if rn_op.inputs[0].type == "Shape":
                shape_node = rn_op.inputs[0]
                new_node = g.make_node("RandomNormalLike", [shape_node.input[0]],
                                       outputs=[out_name], name=op_name,
                                       attr={"mean": mean, "scale": 1.0, "dtype": dtype, "seed": float(seed)})
            else:
                shape = g.get_shape(output.output[0])
                new_node = g.make_node("RandomNormal", [],
                                       outputs=[out_name], name=op_name,
                                       attr={"shape": shape, "mean": mean, "scale": 1.0,
                                             "dtype": dtype, "seed": seed})

            g.replace_all_inputs(output.output[0], new_node.output[0], ops=ops)
            g.safe_remove_nodes(match.get_nodes())
    return ops

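# --- Editor's illustrative sketch (not part of tf2onnx) ---
# The matched subgraph is the usual affine transform of a standard normal draw,
# z * stddev + mean ~ N(mean, stddev**2), which the rewriter maps onto a single ONNX
# RandomNormal / RandomNormalLike node.
import numpy as np

rng = np.random.default_rng(0)
mean = 3.0
z = rng.standard_normal(100000)  # RandomStandardNormal
samples = z + mean               # the matched Mul/Add chain with stddev == 1.0
assert abs(samples.mean() - mean) < 0.05
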
def rewrite_flatten(g, ops):
    pattern_fixed_shape_input = \
        OpTypePattern('Reshape', name='reshape', inputs=[
            OpTypePattern("*", name="input"),
            OpTypePattern('Pack', name="pack", inputs=[
                OpTypePattern('StridedSlice', name="slice", inputs=[
                    "*", "*", "*", "*",
                ]),
                "*",
            ]),
        ])
    pattern_non_fixed_shape_input = \
        OpTypePattern('Reshape', name='reshape', inputs=[
            OpTypePattern("*", name="input"),
            OpTypePattern('Pack', name="pack", inputs=[
                OpTypePattern('StridedSlice', name="slice", inputs=[
                    OpTypePattern('Shape', inputs=[
                        OpTypePattern("*", name="input2")
                    ]),
                    "*", "*", "*",
                ]),
                "*",
            ]),
        ])
    matcher = GraphMatcher(pattern_fixed_shape_input)
    match_results_1 = list(matcher.match_ops(ops))

    matcher = GraphMatcher(pattern_non_fixed_shape_input)
    match_results_2 = list(matcher.match_ops(ops))

    match_results = [(match_results_1, True), (match_results_2, False)]
    for match_results, check_fixed_input_shape in match_results:
        for match in match_results:
            input_node = match.get_op('input')
            reshape_node = match.get_op('reshape')
            pack_node = match.get_op('pack')
            slice_node = match.get_op('slice')
            need_rewrite = pack_node.inputs[1].is_const() and pack_node.inputs[1].get_tensor_value() == -1
            if not need_rewrite:
                continue

            input_shape = g.get_shape(reshape_node.input[0])
            need_rewrite = input_shape is not None
            if not need_rewrite:
                continue

            if check_fixed_input_shape:
                need_rewrite = slice_node.inputs[0].is_const() and \
                               np.array_equal(list(input_shape), list(slice_node.inputs[0].get_tensor_value()))
                if not need_rewrite:
                    continue

            begin = slice_node.inputs[1].get_tensor_value(as_list=False)
            end = slice_node.inputs[2].get_tensor_value(as_list=False)
            strides = slice_node.inputs[3].get_tensor_value(as_list=False)
            need_rewrite = np.array_equal(begin, [0]) and len(end) == 1 and \
                           np.array_equal(strides, [1]) and end[0] - begin[0] == 1
            if not need_rewrite:
                continue

            to_remove = [n for n in match.get_nodes() if n != input_node]
            safe = g.safe_to_remove_nodes(to_remove)
            # Ok if reshape_node is not safe. Will make it safe later.
            if len(to_remove) - len(safe) > 1:
                continue

            op_name = utils.make_name("Flatten")
            out_name = utils.port_name(op_name)
            g.make_node("Flatten", [reshape_node.input[0]], outputs=[out_name], name=op_name)

            last_dim = input_shape[-1]
            sec_last_dim = input_shape[-2]
            new_dim = None
            if last_dim > 0 and sec_last_dim > 0:
                new_dim = last_dim * sec_last_dim
            else:
                new_dim = -1

            g.set_shape(out_name, input_shape[:-2] + [new_dim])
            g.replace_all_inputs(reshape_node.output[0], out_name, ops=ops)
            for n in to_remove:
                g.remove_node(n.name)

    return ops

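# --- Editor's illustrative sketch (not part of tf2onnx) ---
# The Pack(StridedSlice(shape, 0:1), -1) subgraph builds the target shape [batch, -1], so
# the Reshape is equivalent to ONNX Flatten with the default axis=1, which is what the
# rewriter substitutes.
import numpy as np

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
tf_style = x.reshape(x.shape[0], -1)                             # Reshape to [shape[0], -1]
onnx_flatten = x.reshape(x.shape[0], int(np.prod(x.shape[1:])))  # Flatten(axis=1) semantics
assert np.array_equal(tf_style, onnx_flatten)
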