예제 #1
0
    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        """Split the matched op into a bias-free node followed by an explicit Add.

        The node taken from ``match['op']`` has its third input (port 2) moved
        onto a newly created ``Add`` node, optional ``alpha`` / ``beta``
        attributes are turned into explicit ``Mul`` nodes feeding that ``Add``,
        and finally the node's attributes are updated via
        ``MatMul.update_node_stat``.

        NOTE(review): presumably the matched op is a Gemm-like operation
        (alpha * A * B + beta * C) -- confirm against the replacer's pattern.

        :param graph: graph being transformed; ``graph.graph['cmd_params']`` is
            read to decide whether nodes are renamed.
        :param match: mapping that provides the matched node under the ``'op'`` key.
        """
        node = match['op']
        # Fall back to the node id when the node has no valid 'name' attribute.
        name = node.soft_get('name', node.id)

        # biases normalization
        bias_node = Add(graph, {'name': name + '/Bias_', 'can_be_scaleshift': False}).create_node()
        if not graph.graph['cmd_params'].generate_deprecated_IR_V7:
            # The Add becomes the final producer, so it takes over the original
            # name while the bias-free node gets a suffix. The safely resolved
            # `name` is used here (not `node.name`) so that nodes without a
            # 'name' attribute are handled the same way as everywhere else in
            # this method.
            rename_nodes([(node, name + '/WithoutBiases'), (bias_node, name)])

        # Consumers of the node's output now read from the Add's output.
        node.out_port(0).get_connection().set_source(bias_node.out_port(0))
        # Whatever fed input port 2 of the node now feeds port 1 of the Add.
        node.in_port(2).get_connection().set_destination(bias_node.in_port(1))
        # The node's own output becomes the first operand of the Add.
        node.out_port(0).connect(bias_node.in_port(0))

        # A Mul is inserted only when the coefficient differs from 1.
        if node.has_valid('alpha') and not math.isclose(node.alpha, 1):
            bias_node.insert_op_on_input_port(in_port_idx=0, new_op_class=Mul, value=np.array(node.alpha),
                                              new_op_attrs={'name': name + '/Alpha_', 'can_be_scaleshift': False})
            del node['alpha']

        if node.has_valid('beta') and not math.isclose(node.beta, 1):
            bias_node.insert_op_on_input_port(in_port_idx=1, new_op_class=Mul, value=np.array(node.beta),
                                              new_op_attrs={'name': name + '/Beta_', 'can_be_scaleshift': False})
            del node['beta']

        MatMul.update_node_stat(node, {
            'transpose_a': node.has_and_set('transpose_a'),
            'transpose_b': node.has_and_set('transpose_b'),
        })
예제 #2
0
    def find_and_replace_pattern(self, graph: Graph):
        """Rewrite every 'Gemm' op node into MatMul plus explicit Mul/Add nodes.

        For each node returned by ``graph.get_op_nodes(op='Gemm')``:

        * when ``alpha`` is set and not close to 1, a ``Mul`` node is inserted
          on the node's output connection and ``alpha`` is removed;
        * when input port 2 is connected, an ``Add`` node is placed after the
          current output, the port-2 connection is moved to that ``Add``, and
          a ``Mul`` is inserted on it when ``beta`` is set and not close to 1;
        * the node's attributes are updated through ``MatMul.update_node_stat``.

        :param graph: graph to transform in place.
        """
        for gemm in graph.get_op_nodes(op='Gemm'):
            base_name = gemm.soft_get('name', gemm.id)
            # Port that currently produces the (possibly scaled) result.
            tail_port = gemm.out_port(0)

            scale_by_alpha = gemm.has_valid('alpha') and not math.isclose(gemm.alpha, 1)
            if scale_by_alpha:
                alpha_value = np.array(gemm.alpha)
                alpha_attrs = {'name': base_name + '/Alpha', 'can_be_scaleshift': False}
                # presumably {1: value} attaches the constant to input port 1 -- confirm
                alpha_mul = create_op_with_const_inputs(graph, Mul, {1: alpha_value}, alpha_attrs)
                tail_port.get_connection().insert_node(alpha_mul)
                tail_port = alpha_mul.out_port(0)
                del gemm['alpha']

            if gemm.is_in_port_connected(2):
                # biases normalization
                add_attrs = {'name': base_name + '/Bias_', 'can_be_scaleshift': False}
                add_bias = Add(graph, add_attrs).create_node()

                # The Add takes over the original name; the Gemm gets a suffix.
                renames = [
                    (gemm, base_name + '/WithoutBiases'),
                    (add_bias, base_name),
                ]
                rename_nodes(renames)

                # Redirect existing consumers to the Add, move the third input
                # to the Add, then feed the current result into the Add.
                downstream = tail_port.get_connection()
                downstream.set_source(add_bias.out_port(0))
                bias_input = gemm.in_port(2).get_connection()
                bias_input.set_destination(add_bias.in_port(1))
                tail_port.connect(add_bias.in_port(0))

                scale_by_beta = gemm.has_valid('beta') and not math.isclose(gemm.beta, 1)
                if scale_by_beta:
                    beta_attrs = {'name': base_name + '/Beta', 'can_be_scaleshift': False}
                    add_bias.insert_op_on_input_port(
                        in_port_idx=1,
                        new_op_class=Mul,
                        value=np.array(gemm.beta),
                        new_op_attrs=beta_attrs,
                    )
                    del gemm['beta']

            transpose_flags = {
                flag: gemm.has_and_set(flag)
                for flag in ('transpose_a', 'transpose_b')
            }
            MatMul.update_node_stat(gemm, transpose_flags)