예제 #1
0
    def replace_pattern(self, graph: Graph, match: dict):
        """Fuse the matched L2-normalization sub-graph into one Normalize node.

        The constant, scalar operand of the ``maximum`` node is used as ``eps``.
        The new Normalize node reads the same producer port that feeds
        ``square``. Its second input is an all-ones weights constant whose
        length is the last dimension of ``match['input']``. It takes over every
        consumer of the ``l2_normalize`` output.

        If neither input of ``maximum`` carries a constant value, or the value
        is not a scalar, the graph is left untouched.

        :param graph: graph being transformed (modified in place).
        :param match: pattern-match dict; the 'maximum', 'square',
            'l2_normalize' and 'input' entries are used.
        """
        maximum = match['maximum']
        y = maximum.in_port(0).data.get_value()
        if y is None:
            y = maximum.in_port(1).data.get_value()

        if y is None or y.shape != ():
            log.debug(
                'The value of the "maximum_y_data" is not defined or is not constant'
            )
            return

        # Keep the exact producer *port* that feeds the pattern. The producer
        # may have several outputs, so out_port(0) is not necessarily the
        # tensor being normalized.
        source_port = match['square'].in_port(0).get_source()
        normalize_input_node = source_port.node
        normalize_node = NormalizeOp(
            graph, {
                'name': normalize_input_node.soft_get(
                    'name', normalize_input_node.id) + '/Normalize',
                'eps': y,
                'across_spatial': 0,
                'channel_shared': 0
            }).create_node()

        weights_node = Const(
            graph, {
                'value':
                np.ones(shape=int64_array([match['input'].shape[-1]]),
                        dtype=match['input'].data_type)
            }).create_node()

        # Add Normalize as an extra consumer of the source rather than
        # disconnecting the whole output port. Consumers of this tensor that
        # are not part of the matched pattern must stay connected.
        source_port.connect(normalize_node.in_port(0))

        # Detach only the pattern's own inputs from the source. The old
        # sub-graph has no consumers left once the l2_normalize output is
        # rerouted below.
        match['square'].in_port(0).disconnect()
        l2_normalize = match['l2_normalize']
        for port_idx in (0, 1):
            producer = l2_normalize.in_port(port_idx).get_source().node
            if producer.id == normalize_input_node.id:
                l2_normalize.in_port(port_idx).disconnect()
                break

        weights_node.out_port(0).get_connection().set_destination(
            normalize_node.in_port(1))
        l2_normalize.out_port(0).get_connection().set_source(
            normalize_node.out_port(0))
예제 #2
0
 def replace_pattern(self, graph: Graph, match: dict):
     """Insert a Normalize node between the matched 'bias_add' and 'merge' nodes.

     The Normalize node (eps=1e-6, across_spatial=0, channel_shared=0) is fed
     from output port 0 of 'bias_add'. It then drives input port 0 of 'merge',
     replacing whatever was previously connected to that input.

     NOTE(review): no weights input (port 1) is connected to the Normalize
     node here — confirm NormalizeOp accepts a single input in this context.

     :param graph: graph being transformed (modified in place).
     :param match: pattern-match dict; the 'bias_add' and 'merge' entries
         are used.
     """
     bias_add = match['bias_add']
     merge = match['merge']
     # Fall back to the node id when the node carries no 'name' attribute,
     # instead of failing on direct attribute access.
     merge_name = merge.soft_get('name', merge.id)
     normalize_node = NormalizeOp(
         graph, {
             'name': merge_name + '/Normalize',
             'eps': 1e-6,
             'across_spatial': 0,
             'channel_shared': 0
         }).create_node()
     # Add Normalize as a new consumer of bias_add's output.
     bias_add.out_port(0).connect(normalize_node.in_port(0))
     # Re-point merge's first input at the Normalize output.
     merge.in_port(0).disconnect()
     normalize_node.out_port(0).get_connection().set_destination(
         merge.in_port(0))
예제 #3
0
    def replace_pattern(self, graph: Graph, match: dict):
        """Replace the matched L2-normalization sub-graph with one Normalize node.

        ``eps`` is taken from the first input of ``maximum`` that carries a
        constant value. The transformation is skipped when no such value exists
        or when it is not a scalar.

        The new node takes over the name of ``l2_normalize``; the old node is
        renamed via ``rename_node`` with a '/normalizel2' suffix. The new node
        reads the producer port that feeds ``square`` and gets an all-ones
        weights constant sized by the last dimension of ``match['input']``.
        It then takes over all consumers of ``l2_normalize``.

        :param graph: graph being transformed (modified in place).
        :param match: pattern-match dict; the 'maximum', 'square', 'rsqrt',
            'l2_normalize' and 'input' entries are used.
        """
        maximum = match['maximum']
        eps = None
        for eps_port_idx in (0, 1):
            eps = maximum.in_port(eps_port_idx).data.get_value()
            if eps is not None:
                break

        if eps is None or eps.shape != ():
            log.debug('The value of the "maximum_y_data" is not defined or is not constant')
            return

        l2_norm = match['l2_normalize']
        square = match['square']
        input_data = match['input']

        # The fused node inherits the original output name, so the old
        # l2_normalize node is moved out of the way first.
        out_name = l2_norm.soft_get('name', l2_norm.id)
        rename_node(l2_norm, out_name + '/normalizel2')

        fused = NormalizeOp(graph, dict(name=out_name,
                                        eps=eps,
                                        across_spatial=0,
                                        channel_shared=0)).create_node()
        rename_node(fused, out_name)

        ones = np.ones(shape=int64_array([input_data.shape[-1]]),
                       dtype=input_data.data_type)
        weights = Const(graph, {'value': ones}).create_node()

        # Feed Normalize from the same producer port that feeds square.
        square.in_port(0).get_source().connect(fused.in_port(0))

        # Detach the pattern from that producer. Whichever l2_normalize input
        # is not fed by rsqrt is the data input that gets dropped.
        square.in_port(0).disconnect()
        rsqrt_on_port0 = l2_norm.in_port(0).get_source().node.id == match['rsqrt'].id
        l2_norm.in_port(1 if rsqrt_on_port0 else 0).disconnect()

        weights.out_port(0).get_connection().set_destination(fused.in_port(1))
        l2_norm.out_port(0).get_connection().set_source(fused.out_port(0))