def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
    """Decompose a matched Gemm node into MatMul plus explicit scaling / bias nodes.

    The Gemm node itself is converted to a MatMul (its ``transpose_a`` /
    ``transpose_b`` flags are kept). A non-unit ``alpha`` becomes a Mul on the
    MatMul output. When the third input (``C``) is connected, it is fed into a
    new Add together with the (possibly scaled) product, and a non-unit
    ``beta`` becomes a Mul on the ``C`` branch. When ``C`` is not connected,
    no Add is created.

    :param graph: graph being transformed.
    :param match: pattern match; the Gemm node is stored under the ``'op'`` key.
    """
    node = match['op']
    name = node.soft_get('name', node.id)

    if node.is_in_port_connected(2):
        # biases normalization
        bias_node = Add(graph, {'name': name + '/Bias_', 'can_be_scaleshift': False}).create_node()
        if not graph.graph['cmd_params'].generate_deprecated_IR_V7:
            # The Add becomes the last node of the decomposition, so it takes over the original name.
            rename_nodes([(node, name + '/WithoutBiases'), (bias_node, name)])

        # Move the Gemm consumers to the Add, move C onto the Add's second input,
        # and feed the product into the Add's first input.
        node.out_port(0).get_connection().set_source(bias_node.out_port(0))
        node.in_port(2).get_connection().set_destination(bias_node.in_port(1))
        node.out_port(0).connect(bias_node.in_port(0))

        if node.has_valid('alpha') and not math.isclose(node.alpha, 1):
            bias_node.insert_op_on_input_port(in_port_idx=0, new_op_class=Mul, value=np.array(node.alpha),
                                              new_op_attrs={'name': name + '/Alpha_', 'can_be_scaleshift': False})
            del node['alpha']

        if node.has_valid('beta') and not math.isclose(node.beta, 1):
            bias_node.insert_op_on_input_port(in_port_idx=1, new_op_class=Mul, value=np.array(node.beta),
                                              new_op_attrs={'name': name + '/Beta_', 'can_be_scaleshift': False})
            del node['beta']
    elif node.has_valid('alpha') and not math.isclose(node.alpha, 1):
        # No C input (optional in ONNX Gemm): scale the product directly, no bias Add is needed.
        mul_alpha = create_op_with_const_inputs(graph, Mul, {1: np.array(node.alpha)},
                                                {'name': name + '/Alpha_', 'can_be_scaleshift': False})
        node.out_port(0).get_connection().insert_node(mul_alpha)
        del node['alpha']

    MatMul.update_node_stat(node, {
        'transpose_a': node.has_and_set('transpose_a'),
        'transpose_b': node.has_and_set('transpose_b'),
    })
def find_and_replace_pattern(self, graph: Graph):
    """Rewrite every ``op='Gemm'`` node of ``graph`` as MatMul plus explicit helper nodes.

    For each Gemm node:
      * a set ``alpha`` that is not close to 1 becomes a Mul by a constant placed
        right after the Gemm output;
      * if input port 2 (``C``) is connected, a new Add sums the (possibly scaled)
        product with ``C``; the Add takes over the Gemm's name, the Gemm is renamed
        ``<name>/WithoutBiases``, and a ``beta`` not close to 1 becomes a Mul on the
        ``C`` branch;
      * the node is finally converted to MatMul, keeping its transpose flags.
    """
    def _has_non_unit(gemm, attr_name):
        # Only attributes that are present and differ from 1 need an explicit Mul.
        return gemm.has_valid(attr_name) and not math.isclose(getattr(gemm, attr_name), 1)

    for gemm in graph.get_op_nodes(op='Gemm'):
        gemm_name = gemm.soft_get('name', gemm.id)
        # Output port that currently carries the partially decomposed Gemm result.
        result_port = gemm.out_port(0)

        if _has_non_unit(gemm, 'alpha'):
            alpha_attrs = {'name': gemm_name + '/Alpha', 'can_be_scaleshift': False}
            alpha_mul = create_op_with_const_inputs(graph, Mul, {1: np.array(gemm.alpha)}, alpha_attrs)
            result_port.get_connection().insert_node(alpha_mul)
            result_port = alpha_mul.out_port(0)
            del gemm['alpha']

        if gemm.is_in_port_connected(2):
            # biases normalization
            add_attrs = {'name': gemm_name + '/Bias_', 'can_be_scaleshift': False}
            add_node = Add(graph, add_attrs).create_node()
            rename_nodes([(gemm, gemm_name + '/WithoutBiases'), (add_node, gemm_name)])

            # Consumers now read from the Add; C goes to Add input 1, the product to Add input 0.
            result_port.get_connection().set_source(add_node.out_port(0))
            gemm.in_port(2).get_connection().set_destination(add_node.in_port(1))
            result_port.connect(add_node.in_port(0))

            if _has_non_unit(gemm, 'beta'):
                beta_attrs = {'name': gemm_name + '/Beta', 'can_be_scaleshift': False}
                add_node.insert_op_on_input_port(in_port_idx=1, new_op_class=Mul,
                                                 value=np.array(gemm.beta), new_op_attrs=beta_attrs)
                del gemm['beta']

        matmul_attrs = {flag: gemm.has_and_set(flag) for flag in ('transpose_a', 'transpose_b')}
        MatMul.update_node_stat(gemm, matmul_attrs)