def extract(cls, node):
    """Map an MXNet ``LeakyReLU`` symbol onto the matching Model Optimizer op.

    The ``act_type`` attribute (default ``'leaky'``) selects the target op:

    * ``'prelu'`` -> ``PReLU`` with a fixed set of filler attributes;
    * ``'elu'``   -> ``Elu`` whose ``alpha`` is read from ``slope``
      (default 0.25);
    * ``'leaky'`` -> ``ReLU`` when ``slope`` is exactly 0, otherwise
      ``LeakyReLU`` whose ``negative_slope`` is read from ``slope``
      (default 0.25).

    :param node: node whose ``symbol_dict`` holds the MXNet layer attributes.
    :return: ``cls.enabled``.
    :raises Error: if ``act_type`` is none of the values listed above.
    """
    attrs = get_mxnet_layer_attrs(node.symbol_dict)
    act_type = attrs.str('act_type', 'leaky')

    if act_type == 'prelu':
        # Fixed filler metadata recorded on the PReLU node; nothing here is
        # read from the MXNet attributes.
        prelu_attrs = {
            'channel_shared': 1,
            'filler_type': 'constant',
            'filler_value': 0,
            'min': 0,
            'max': 1,
            'mean': 0,
            'std': 0,
            'sparse': -1,
            'variance_norm': "caffe.FillerParameter.FAN_IN",
        }
        PReLU.update_node_stat(node, prelu_attrs)
    elif act_type == 'elu':
        alpha = attrs.float('slope', 0.25)
        Elu.update_node_stat(node, {'alpha': alpha})
    elif act_type == 'leaky':
        negative_slope = attrs.float('slope', 0.25)
        if negative_slope == 0:
            # A leaky ReLU with a zero slope is exactly a plain ReLU.
            ReLU.update_node_stat(node)
        else:
            LeakyReLU.update_node_stat(node, {'negative_slope': negative_slope})
    else:
        # act_type is passed as an extra argument to Error, presumably to be
        # substituted into the '{}' placeholder of the message.
        raise Error(
            "Operation '{}' not supported. Please register it as custom op. " +
            refer_to_faq_msg(86),
            act_type)
    # Read the flag through cls (as the other extractors do) rather than a
    # hard-coded class name, so a subclass reports its own `enabled` value.
    return cls.enabled
def extract(cls, node):
    """Extract a Caffe ``PReLU`` layer into a ``PReLU`` node.

    Reads ``prelu_param`` from the layer's prototxt message and records
    ``channel_shared`` plus the filler settings as node attributes. It then
    merges in weight/bias and layout attributes and applies everything via
    ``PReLU.update_node_stat``.

    :param node: node carrying the layer prototxt message in ``node.pb`` and
        the model (weights) message in ``node.model_pb``.
    :return: ``cls.enabled``.
    """
    proto_layer = node.pb
    pb_model = node.model_pb
    param = proto_layer.prelu_param

    update_attrs = {
        'channel_shared': int(param.channel_shared)
    }

    # Caffe FillerParameter.VarianceNorm enum value -> its textual name.
    variance_norm_caffe_map = {
        0: 'caffe.FillerParameter.FAN_IN',
        1: 'caffe.FillerParameter.FAN_OUT',
        2: 'caffe.FillerParameter.AVERAGE'
    }

    # NOTE(review): for a protobuf message hasattr() is true whenever the
    # message type declares the field, even if it is unset in the prototxt;
    # unset fields then yield their proto defaults.
    if hasattr(param, 'filler'):
        filler = param.filler
        # caffe.proto declares value/min/max/mean/std as float fields. Keep
        # them as floats so fractional settings (e.g. a 0.25 constant filler)
        # are not truncated to integers.
        update_attrs.update({
            'filler_type': filler.type,
            'filler_value': float(filler.value),
            'min': float(filler.min),
            'max': float(filler.max),
            'mean': float(filler.mean),
            'std': float(filler.std),
            'sparse': filler.sparse,
            'variance_norm': variance_norm_caffe_map[filler.variance_norm]
        })

    mapping_rule = merge_attrs(param, update_attrs)
    mapping_rule.update(weights_biases(False, pb_model))
    mapping_rule.update(layout_attrs())

    # update the attributes of the node
    PReLU.update_node_stat(node, mapping_rule)
    return cls.enabled
def replace_pattern(graph: Graph, match: dict):
    """Swap a matched ``LeakyReLU`` node for an equivalent ``PReLU``.

    The LeakyReLU's ``negative_slope`` becomes a one-element ``Const`` wired
    into PReLU input port 1. The PReLU takes over the original node name and
    all of its connections, and the old node is renamed with a
    ``/to_delete`` suffix. Nodes without a valid ``negative_slope`` are left
    untouched.
    """
    leaky = match['leakyrelu']
    original_name = leaky.soft_get('name', leaky.id)

    # Guard: nothing to convert without a usable slope.
    if not leaky.has_valid('negative_slope'):
        return

    # Free up the original name so the replacement node can claim it.
    rename_node(leaky, original_name + '/to_delete')

    prelu = PReLU(graph, dict(name=original_name)).create_node()
    rename_node(prelu, original_name)

    slope_value = np.array([leaky.negative_slope])
    slope_const = Const(graph, dict(name=original_name + "/weights", value=slope_value)).create_node()

    # Rewire: data input -> PReLU port 0, slope -> PReLU port 1,
    # and every former LeakyReLU consumer now reads from PReLU.
    leaky.in_port(0).get_connection().set_destination(prelu.in_port(0))
    slope_const.out_port(0).connect(prelu.in_port(1))
    leaky.out_port(0).get_connection().set_source(prelu.out_port(0))
def replace_sub_graph(self, graph: Graph, match: dict):
    """Collapse a matched PReLU-shaped subgraph into a single ``PReLU`` node.

    The replacement is skipped, with a warning, if any matched node other
    than ``mul``, ``op`` or ``add`` fails ``check_node_usages_out_of_match``.
    Otherwise a ``PReLU`` node fed by ``match['op']`` and the gamma input of
    ``mul`` is created, and it replaces ``match['add']``.
    """
    kept_roles = ('mul', 'op', 'add')
    foreign_users = [role for role in match
                     if role not in kept_roles and not check_node_usages_out_of_match(match, role)]
    if foreign_users:
        offending_ids = ', '.join(match[role].id for role in foreign_users)
        message = ('PReLU pattern was detected. Non pattern consumers of nodes: "{}" were found. '
                   "Won't replace")
        log.warning(message.format(offending_ids))
        return

    mul = match['mul']
    # gamma is the `mul` input that is not `neg_1` (assumes `neg_1` feeds one
    # of the two `mul` inputs).
    if mul.in_node(1).id == match['neg_1'].id:
        gamma = mul.in_node(0)
    else:
        gamma = mul.in_node(1)

    add = match['add']
    source = match['op']
    prelu_node = PReLU(graph, {'name': '{}/PReLU'.format(add.id)}).create_node([source, gamma])
    add.replace_node(prelu_node)
    log.debug('PReLU pattern starting from "{}" was collapsed to "{}"'.format(source.id, prelu_node.id))