def replace_pattern(self, graph: Graph, match: dict):
    """Replace a matched Minimum node with an equivalent Maximum-based subgraph.

    Uses the identity ``min(a, b) == -max(-a, -b)``:
    1. Each input is multiplied by a constant -1.
    2. The two negated inputs feed a Maximum.
    3. The Maximum output is multiplied by -1 again.

    The final Mul writes into the Minimum node's existing output data node.
    The Minimum node is then removed from the graph.

    Args:
        graph: graph being transformed; the new Const/Mul/Maximum nodes are
            created in it.
        match: pattern match; ``match['minimum']`` is the node to replace.
    """
    minimum = match['minimum']

    # Constant propagation case: both operands already carry values, so the
    # node is left untouched.
    if not (minimum.in_node(0).value is None or minimum.in_node(1).value is None):
        return

    prefix = minimum.name
    lhs_minus_one = Const(graph, {'value': np.array(-1), 'name': prefix + '/negate1_const'})
    rhs_minus_one = Const(graph, {'value': np.array(-1), 'name': prefix + '/negate2_const'})
    negate_lhs = Mul(graph, {'name': prefix + '/negate1_'})
    negate_rhs = Mul(graph, {'name': prefix + '/negate2_'})
    max_op = Maximum(graph, {'name': prefix + '/Max_'})
    out_minus_one = Const(graph, {'value': np.array(-1), 'name': prefix + '/negate_out_const_'})
    # NOTE(review): ``scale=-1`` looks redundant next to the -1 Const input
    # wired in below; kept unchanged so the node's attributes are identical.
    negate_out = Mul(graph, {'scale': -1, 'name': prefix + '/negate_out_'})

    # -a and -b
    negated_lhs = negate_lhs.create_node_with_data(
        [minimum.in_node(0), lhs_minus_one.create_node_with_data()])
    negated_rhs = negate_rhs.create_node_with_data(
        [minimum.in_node(1), rhs_minus_one.create_node_with_data()])

    # max(-a, -b)
    max_data = max_op.create_node_with_data([negated_lhs, negated_rhs])

    # -max(-a, -b), written into the Minimum node's existing output data node
    out_minus_one_data = out_minus_one.create_node_with_data()
    negate_out.create_node_with_data(
        inputs=[max_data, out_minus_one_data],
        data_nodes=minimum.out_node())

    # The Minimum op is no longer needed once its output is re-produced.
    minimum.graph.remove_node(minimum.id)
def replace_pattern(self, graph: Graph, match: dict):
    """Decompose a batch-norm node with non-constant statistics into elementwise ops.

    Builds ``(x - mean) * (variance + eps) ** -0.5 * scale + offset`` from
    Add / Mul / Pow / Const nodes. The result is written into the matched
    op's existing output data node, and the matched op is then removed.

    The graph is left untouched unless all of the following hold:

    * ``node.data_format`` equals ``b'NHWC'``;
    * the node has exactly 5 inputs: x, scale, offset, mean, variance;
    * x, mean and variance are non-constant (their ``value`` is None);
    * scale and offset are constant and 1-D.

    Args:
        graph: graph being transformed; the new nodes are created in it.
        match: pattern match; ``match['op']`` is the node to decompose.
            ``node.eps`` is read as the variance epsilon.
    """
    node = match['op']
    if (node.data_format != b'NHWC' or
            len(node.in_nodes()) != 5 or
            node.in_node(0).value is not None or  # input
            node.in_node(1).value is None or  # scale
            node.in_node(2).value is None or  # offset
            node.in_node(3).value is not None or  # mean
            node.in_node(4).value is not None or  # variance
            # 1-D scale/offset presumably broadcast over the trailing
            # (channel) axis of the NHWC input -- confirm
            node.in_node(1).value.ndim != 1 or
            node.in_node(2).value.ndim != 1):
        return

    x = node.in_node(0)
    scale = node.in_node(1)
    offset = node.in_node(2)
    mean = node.in_node(3)
    variance = node.in_node(4)

    scale_mul = Mul(graph, dict(name=node.name + '/scale_mul_'))
    shift_add = Add(graph, dict(name=node.name + '/shift_add_'))
    mean_add = Add(graph, dict(name=node.name + '/mean_add_'))
    variance_mul = Mul(graph, dict(name=node.name + '/variance_mul_'))

    # x - mean, computed as x + (-1 * mean).
    # The -1 Const gets its own '_const' suffix so it does not share the
    # '/mean_negate_' name with the Mul that consumes it.
    neg_const = Const(graph, dict(value=np.array(-1), name=node.name + '/mean_negate_const_'))
    mean_negate = Mul(graph, dict(name=node.name + '/mean_negate_'))
    negated_mean = mean_negate.create_node_with_data([mean, neg_const.create_node_with_data()])
    centered = mean_add.create_node_with_data([x, negated_mean])

    # (variance + eps) ** -0.5
    shift_const = Const(graph, dict(value=node.eps, name=node.name + '/variance_denom_shift_const_'))
    power_const = Const(graph, dict(value=-0.5, name=node.name + '/variance_denom_power_const_'))
    variance_denom_shift = Add(graph, dict(name=node.name + '/variance_denom_shift_'))
    variance_denom_power = Pow(graph, dict(name=node.name + '/variance_denom_power_'))
    shifted_variance = variance_denom_shift.create_node_with_data(
        [variance, shift_const.create_node_with_data()])
    inv_std = variance_denom_power.create_node_with_data(
        [shifted_variance, power_const.create_node_with_data()])

    # (x - mean) * inv_std * scale + offset, written into the original output
    normalized = variance_mul.create_node_with_data([centered, inv_std])
    scaled = scale_mul.create_node_with_data([normalized, scale])
    shift_add.create_node_with_data([scaled, offset], data_nodes=node.out_node())

    node.graph.remove_node(node.id)