Example #1
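This first snippet appears to be the MXNet LeakyReLU front extractor from OpenVINO's Model Optimizer. MXNet funnels several activations through a single LeakyReLU operator, so the extractor dispatches on act_type: prelu, elu, and leaky map to PReLU, Elu, and LeakyReLU respectively (a zero slope degenerating to a plain ReLU), and any other value raises an Error.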
    # Classmethod of LeakyReLUFrontExtractor (a FrontExtractorOp subclass);
    # the activation to emit is selected by the node's act_type attribute.
    @classmethod
    def extract(cls, node):
        attrs = get_mxnet_layer_attrs(node.symbol_dict)
        act_type = attrs.str('act_type', 'leaky')
        if act_type == 'prelu':
            # default weight-filler attributes expected by the PReLU op
            prelu_attrs = {
                'channel_shared': 1,
                'filler_type': 'constant',
                'filler_value': 0,
                'min': 0,
                'max': 1,
                'mean': 0,
                'std': 0,
                'sparse': -1,
                'variance_norm': "caffe.FillerParameter.FAN_IN"
            }
            PReLU.update_node_stat(node, prelu_attrs)
        elif act_type == 'elu':
            alpha = attrs.float('slope', 0.25)
            Elu.update_node_stat(node, {'alpha': alpha})
        elif act_type == 'leaky':
            negative_slope = attrs.float('slope', 0.25)
            # a zero slope is just a plain ReLU
            if negative_slope == 0:
                ReLU.update_node_stat(node)
            else:
                LeakyReLU.update_node_stat(node,
                                           {'negative_slope': negative_slope})
        else:
            raise Error(
                "Operation '{}' not supported. Please register it as custom op. "
                + refer_to_faq_msg(86), act_type)

        return LeakyReLUFrontExtractor.enabled
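
Example #2

The second extractor is the Caffe-side counterpart: it reads prelu_param from the prototxt layer, maps the weight-filler settings (including the variance_norm enum) onto Model Optimizer attributes, merges in the trained slope weights, and stamps the result onto the node as a PReLU.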
    @classmethod
    def extract(cls, node):
        proto_layer = node.pb     # layer definition from the .prototxt
        pb_model = node.model_pb  # trained weights from the .caffemodel
        param = proto_layer.prelu_param

        update_attrs = {
            'channel_shared': int(param.channel_shared)
        }

        # Caffe FillerParameter.VarianceNorm enum values -> MO string attribute
        variance_norm_caffe_map = {
            0: 'caffe.FillerParameter.FAN_IN',
            1: 'caffe.FillerParameter.FAN_OUT',
            2: 'caffe.FillerParameter.AVERAGE'
        }

        if hasattr(param, 'filler'):
            update_attrs.update({
                'filler_type': param.filler.type,
                'filler_value': int(param.filler.value),
                'min': int(param.filler.min),
                'max': int(param.filler.max),
                'mean': int(param.filler.mean),
                'std': int(param.filler.std),
                'sparse': param.filler.sparse,
                'variance_norm': variance_norm_caffe_map[param.filler.variance_norm]
            })

        mapping_rule = merge_attrs(param, update_attrs)
        # attach the trained slope weights from the .caffemodel (PReLU has no bias term)
        mapping_rule.update(weights_biases(False, pb_model))
        mapping_rule.update(layout_attrs())

        # update the attributes of the node
        PReLU.update_node_stat(node, mapping_rule)
        return cls.enabled
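
Both extractors follow the same registration pattern. As a minimal sketch, assuming the standard FrontExtractorOp interface (the class name below is illustrative, and the import path varies across OpenVINO releases):

from mo.front.extractor import FrontExtractorOp  # path differs in newer releases


class PReLUFrontExtractor(FrontExtractorOp):
    op = 'PReLU'    # framework operation this extractor claims
    enabled = True  # lets the extractor registry pick it up automatically

    @classmethod
    def extract(cls, node):
        # ... attribute extraction as in Example #2 above ...
        return cls.enabled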
Example #3
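This transform rewrites a LeakyReLU node whose negative_slope is known into a PReLU that takes the slope as a constant second input: the original node is renamed out of the way, a Const holding the slope is created, and the connections are rewired to the new node.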
    @staticmethod
    def replace_pattern(graph: Graph, match: dict):
        relu = match['leakyrelu']
        relu_name = relu.soft_get('name', relu.id)
        if not relu.has_valid('negative_slope'):
            return

        # move the original name aside so the new PReLU node can take it over
        rename_node(relu, relu_name + '/to_delete')
        # Create PReLU op and reconnect input/output from LeakyReLU to PReLU
        prelu = PReLU(graph, dict(name=relu_name)).create_node()
        rename_node(prelu, relu_name)

        const = Const(graph, dict(name=relu_name + "/weights", value=np.array([relu.negative_slope]))).create_node()

        relu.in_port(0).get_connection().set_destination(prelu.in_port(0))
        const.out_port(0).connect(prelu.in_port(1))
        relu.out_port(0).get_connection().set_source(prelu.out_port(0))
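
For context, a hedged sketch of the matching half such a pass typically declares; the class name and base class are assumptions, since the snippet above only shows replace_pattern:

class LeakyReLUMutation(MiddleReplacementPattern):  # name and base assumed
    enabled = True

    def pattern(self):
        # a single-node pattern: any op-level node with op == 'LeakyReLU'
        return dict(
            nodes=[('leakyrelu', dict(kind='op', op='LeakyReLU'))],
            edges=[]
        )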
Example #4
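The last example collapses a decomposed PReLU sub-graph (a ReLU branch summed with a gamma-scaled negated-ReLU branch, as frameworks without a native PReLU emit it) into a single PReLU node, after first checking that no intermediate node is consumed outside the matched pattern.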
    def replace_sub_graph(self, graph: Graph, match: dict):
        # refuse to collapse if any matched intermediate node also feeds
        # consumers outside the pattern
        consumers = [n for n in match if n not in ['mul', 'op', 'add'] and not check_node_usages_out_of_match(match, n)]
        if consumers:
            log.warning('PReLU pattern was detected. Non pattern consumers of nodes: "{}" were found. Won\'t replace'
                        ''.format(', '.join([match[n].id for n in consumers])))
            return
        # gamma is whichever input of 'mul' is not the negated-ReLU branch
        gamma = match['mul'].in_node(0) if match['mul'].in_node(1).id == match['neg_1'].id else match['mul'].in_node(1)
        prelu_node = PReLU(graph, {'name': '{}/PReLU'.format(match['add'].id)}).create_node([match['op'], gamma])
        match['add'].replace_node(prelu_node)
        log.debug('PReLU pattern starting from "{}" was collapsed to "{}"'.format(match['op'].id, prelu_node.id))
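
The rewrite relies on the usual decomposition identity prelu(x) = relu(x) + gamma * (-relu(-x)), which the 'neg_1'/'mul'/'add' pattern names suggest; a quick NumPy check of that identity:

import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

x = np.array([-2.0, -0.5, 0.0, 1.5])
gamma = 0.1
fused = relu(x) + gamma * (-relu(-x))       # what the matched sub-graph computes
reference = np.where(x >= 0, x, gamma * x)  # PReLU by definition
assert np.allclose(fused, reference)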