Example #1
    def replace_pattern(self, graph: Graph, match: dict):
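        # the scalar constant feeding Maximum (on either of its input ports) becomes the eps of Normalize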
        y = match['maximum'].in_port(0).data.get_value()
        if y is None:
            y = match['maximum'].in_port(1).data.get_value()

        if y is None or y.shape != ():
            log.debug(
                'The value of the "maximum_y_data" is not defined or is not constant'
            )
            return

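        # the node feeding 'square' produces the tensor being normalized; Normalize is inserted right after it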
        normalize_input_node = match['square'].in_port(0).get_source().node
        normalize_node = NormalizeOp(
            graph, {
                'name': normalize_input_node.soft_get('name') + '/Normalize',
                'eps': y,
                'across_spatial': 0,
                'channel_shared': 0
            }).create_node()

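        # per-channel weights of ones sized to the last input dimension, so Normalize
        # only normalizes and does not rescale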
        weights_node = Const(
            graph, {
                'value':
                np.ones(shape=int64_array([match['input'].shape[-1]]),
                        dtype=match['input'].data_type)
            }).create_node()

        # the normalize_input_node has two consumers, so its output port is disconnected first;
        # the (now destination-less) connection is then re-pointed at the Normalize input
        normalize_input_node.out_port(0).disconnect()
        normalize_input_node.out_port(0).get_connection().set_destination(
            normalize_node.in_port(0))
        weights_node.out_port(0).get_connection().set_destination(
            normalize_node.in_port(1))
        match['l2_normalize'].out_port(0).get_connection().set_source(
            normalize_node.out_port(0))
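The keys used to index match ('input', 'square', 'maximum', 'rsqrt', 'l2_normalize') are node names declared by the transformation's pattern() method. A simplified, illustrative sketch of what such a pattern might look like (the exact ops and edges are an assumption, not the Model Optimizer source):

    def pattern(self):
        return dict(
            nodes=[
                ('input', dict()),                 # the tensor being normalized
                ('square', dict(op='Pow')),        # x ** 2
                ('sum', dict(op='ReduceSum')),     # sum of squares over the last axis
                ('maximum', dict(op='Maximum')),   # clamp the sum with the eps constant
                ('rsqrt', dict(op='Rsqrt')),       # reciprocal square root
                ('l2_normalize', dict(op='Mul')),  # x * rsqrt(max(sum(x ** 2), eps))
            ],
            edges=[
                ('input', 'square'),
                ('square', 'sum'),
                ('sum', 'maximum'),
                ('maximum', 'rsqrt'),
                ('rsqrt', 'l2_normalize'),
                ('input', 'l2_normalize'),
            ])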
Example #2
    def extract(cls, node):
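        # node.pb holds the original Caffe layer proto; the Normalize parameters live in its norm_param field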
        proto_layer = node.pb
        param = proto_layer.norm_param

        attrs = collect_attributes(param, enable_flattening_nested_params=True)
        attrs.update(weights_biases(False, node.model_pb))
        # update the attributes of the node
        NormalizeOp.update_node_stat(node, attrs)
        return cls.enabled
Example #3
    def extract(cls, node):
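        # the 'i' and 'f' arguments select the integer and float fields of the ONNX attribute being read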
        across_spatial = onnx_attr(node, 'across_spatial', 'i', default=0)
        channel_shared = onnx_attr(node, 'channel_shared', 'i', default=0)
        eps = onnx_attr(node, 'eps', 'f', default=0)

        attrs = {
            'across_spatial': bool(across_spatial),
            'channel_shared': bool(channel_shared),
            'eps': eps,
            'layout': 'NCHW'
        }

        # update the attributes of the node
        NormalizeOp.update_node_stat(node, attrs)
        return cls.enabled
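For reference, the attribute trio above (across_spatial, channel_shared, eps) follows the semantics of Caffe's Normalize layer. A minimal NumPy sketch of the across_spatial=0, channel_shared=0 case on an NCHW tensor (an illustration under those assumptions, not part of the Model Optimizer sources):

import numpy as np

def normalize_reference(x, weights, eps=1e-10):
    # x: NCHW tensor; each spatial position is L2-normalized across the channel axis
    sq_sum = np.sum(x * x, axis=1, keepdims=True)   # sum of squares over C
    y = x / np.sqrt(np.maximum(sq_sum, eps))        # the Maximum/Rsqrt pair from example #1
    return y * weights.reshape(1, -1, 1, 1)         # per-channel scale (channel_shared=0)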
Example #4
    def replace_pattern(self, graph: Graph, match: dict):
        bias_add = match['bias_add']
        merge = match['merge']
        normalize_node = NormalizeOp(
            graph, {
                'name': merge.name + '/Normalize',
                'eps': 1e-6,
                'across_spatial': 0,
                'channel_shared': 0
            }).create_node()
        # insert the Normalize node between bias_add and merge: feed bias_add into
        # Normalize, then re-point the disconnected merge input at the Normalize output
        bias_add.out_port(0).connect(normalize_node.in_port(0))
        merge.in_port(0).disconnect()
        normalize_node.out_port(0).get_connection().set_destination(
            merge.in_port(0))
Example #5
    def replace_pattern(self, graph: Graph, match: dict):
        y = match['maximum'].in_port(0).data.get_value()
        if y is None:
            y = match['maximum'].in_port(1).data.get_value()

        if y is None or y.shape != ():
            log.debug(
                'The value of the "maximum_y_data" is not defined or is not constant'
            )
            return

        # rename the l2_normalize node since it will no longer be the output after the transformation
        output_name = match['l2_normalize'].soft_get('name',
                                                     match['l2_normalize'].id)
        normalizel2_name = output_name + '/normalizel2'
        rename_node(match['l2_normalize'], normalizel2_name)

        normalize_node = NormalizeOp(
            graph, {
                'name': output_name,
                'eps': y,
                'across_spatial': 0,
                'channel_shared': 0
            }).create_node()
        rename_node(normalize_node, output_name)

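        # per-channel weights of ones, sized to the last dimension of the input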
        weights_node = Const(
            graph, {
                'value':
                np.ones(shape=int64_array([match['input'].shape[-1]]),
                        dtype=match['input'].data_type)
            }).create_node()

        match['square'].in_port(0).get_source().connect(
            normalize_node.in_port(0))

        match['square'].in_port(0).disconnect()
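        # the Mul ('l2_normalize') consumes both the original tensor and the Rsqrt output;
        # disconnect whichever port carries the original tensor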
        if match['l2_normalize'].in_port(
                0).get_source().node.id == match['rsqrt'].id:
            match['l2_normalize'].in_port(1).disconnect()
        else:
            match['l2_normalize'].in_port(0).disconnect()

        weights_node.out_port(0).get_connection().set_destination(
            normalize_node.in_port(1))
        match['l2_normalize'].out_port(0).get_connection().set_source(
            normalize_node.out_port(0))
Example #6
    def extract(cls, node):
        pb = node.parameters
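        # node.parameters is a binary stream over the Kaldi component description;
        # collect_until_token advances the stream to the given token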
        try:
            collect_until_token(pb, b'<Dim>')
        except Error:
            try:
                pb.seek(0)
                collect_until_token(pb, b'<InputDim>')
            except Error:
                raise Error("Neither <Dim> nor <InputDim> were found")
        in_dim = read_binary_integer32_token(pb)

        try:
            collect_until_token(pb, b'<TargetRms>')
            target_rms = read_binary_float_token(pb)
        except Error:
            # model does not contain TargetRms
            target_rms = 1.0

        try:
            collect_until_token(pb, b'<AddLogStddev>')
            add_log = read_binary_bool_token(pb)
        except Error:
            # model does not contain AddLogStddev
            add_log = False

        if add_log:
            raise Error(
                "AddLogStddev = True in Normalize component is not supported")

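        # Kaldi's NormalizeComponent computes y = x * target_rms * sqrt(in_dim) / ||x||;
        # folding that factor into one shared weight reproduces it with Normalize (channel_shared=1)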
        scale = target_rms * np.sqrt(in_dim)

        attrs = {
            'eps': 0.00000001,
            'across_spatial': 0,
            'channel_shared': 1,
            'in_dim': in_dim,
        }
        embed_input(attrs, 1, 'weights', [scale])

        NormalizeOp.update_node_stat(node, attrs)
        return cls.enabled