Example #1
# Imports used by this snippet (the Graph and GeLUOP paths follow the
# pre-2022 OpenVINO Model Optimizer layout; they may differ across releases):
import logging as log
from math import fabs, pi, sqrt

from extensions.ops.gelu import GeLUOP
from mo.graph.graph import Graph
    def replace_sub_graph(self, graph: Graph, match: dict):
        # Gaussian Error Linear Unit, tanh-based approximation:
        # f(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
        inp_port = match['pow'].in_port(0).get_source()
        inp = inp_port.node
        log.debug('Found potential TanH-based GeLU pattern after {} with name {}'.format(inp.op, inp.name))

        # take the values of the mul ops
        mul_param = match['mul_param']
        mul0_param = match['mul0_param']
        mul1_param = match['mul1_param']
        if mul0_param.value.size == 1 and mul_param.value.size == 1 and mul1_param.value.size == 1:
            mul_param = match['mul_param'].value.item()
            mul0_param = match['mul0_param'].value.item()
            mul1_param = match['mul1_param'].value.item()
            sqrt2pi = sqrt(2.0 / pi)
            # check that the values match the approximation
            if (fabs(mul0_param - sqrt2pi) < 1e-06
                    and fabs(mul_param - 0.044715) < 1e-06
                    and mul1_param == 0.5):
                log.debug('Confirmed TanH-based GELU pattern after {} with name {}'.format(inp.op, inp.name))
                gelu = GeLUOP(
                    graph,
                    dict(name=inp.name + '/GELU_',
                         approximation_mode='tanh')).create_node()
                inp_port.connect(gelu.in_port(0))
                match['mul2'].out_port(0).get_connection().set_source(
                    gelu.out_port(0))
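
As a sanity check on the two formulas referenced above, the tanh approximation tracks the exact erf-based GELU closely over typical activation ranges. A minimal standalone sketch (plain Python, no Model Optimizer dependencies) comparing the two:

from math import erf, pi, sqrt, tanh

def gelu_erf(x):
    # exact definition: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + erf(x / sqrt(2.0)))

def gelu_tanh(x):
    # tanh approximation matched by the pattern above
    return 0.5 * x * (1.0 + tanh(sqrt(2.0 / pi) * (x + 0.044715 * x ** 3)))

for x in (-3.0, -1.0, 0.0, 0.5, 2.0):
    print('{:+.1f}  {:+.6f}  {:+.6f}'.format(x, gelu_erf(x), gelu_tanh(x)))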
Example #2
# Imports used by this snippet (same Model Optimizer layout caveat as above):
import logging as log
from math import fabs, sqrt

from extensions.ops.gelu import GeLUOP
from mo.graph.graph import Graph
    def replace_sub_graph(self, graph: Graph, match: dict):
        # Gaussian Error Linear Unit
        # f(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
        div = match['div']
        inp_port = div.in_port(0).get_source()
        inp = inp_port.node
        log.debug('Found potential Erf-based GeLU pattern after {} with name {}'.format(inp.op, inp.name))

        # take the values of the mul, add and div
        div_param = match['div_param']
        add_param = match['add_param']
        mul_param = match['mul_param']

        if add_param.value.size == 1 and mul_param.value.size == 1 and div_param.value.size == 1:
            mul_param = match['mul_param'].value.item()
            add_param = match['add_param'].value.item()
            div_param = match['div_param'].value.item()

            sqrt2 = sqrt(2.0)
            # check that the values match the approximation
            if (fabs(div_param - sqrt2) < 1e-06
                    and mul_param == 0.5 and add_param == 1.0):
                log.debug('Confirmed Erf-based GELU pattern after {} with name {}'.format(inp.op, inp.name))
                gelu = GeLUOP(graph,
                              dict(name=inp.name + '/GELU_')).create_node()
                inp_port.connect(gelu.in_port(0))
                match['mul0'].out_port(0).get_connection().set_source(
                    gelu.out_port(0))
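
Both fusers use the same scalar guard: value.size == 1 confirms a Const holds exactly one element before value.item() unwraps it to a Python float for comparison. A minimal stand-in with a bare NumPy array (hypothetical values, no Model Optimizer objects):

import numpy as np

param = np.array([0.5])    # scalar constant, as it would sit on a Const node
if param.size == 1:        # exactly one element, regardless of shape
    scalar = param.item()  # unwrap to a plain Python float
    print(scalar == 0.5)   # True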
Example #3
# These snippets additionally rely on Model Optimizer front-end classes
# (PReLU, Elu, ReLU, LeakyReLU, get_mxnet_layer_attrs, rename_nodes, Error,
# refer_to_faq_msg); their exact module paths vary across releases.
import logging as log
from math import fabs, sqrt
    def extract(cls, node):
        attrs = get_mxnet_layer_attrs(node.symbol_dict)
        act_type = attrs.str('act_type', 'leaky')
        if act_type == 'prelu':
            prelu_attrs = {
                'channel_shared': 1,
                'filler_type': 'constant',
                'filler_value': 0,
                'min': 0,
                'max': 1,
                'mean': 0,
                'std': 0,
                'sparse': -1,
                'variance_norm': "caffe.FillerParameter.FAN_IN"
            }
            PReLU.update_node_stat(node, prelu_attrs)
        elif act_type == 'elu':
            alpha = attrs.float('slope', 0.25)
            Elu.update_node_stat(node, {'alpha': alpha})
        elif act_type == 'leaky':
            negative_slope = attrs.float('slope', 0.25)
            if negative_slope == 0:
                ReLU.update_node_stat(node)
            else:
                LeakyReLU.update_node_stat(node,
                                           {'negative_slope': negative_slope})
        elif act_type == 'gelu':
            GeLUOP.update_node_stat(node, {'approximation_mode': 'erf'})
        else:
            raise Error(
                "Operation '{}' not supported. Please register it as custom op. "
                + refer_to_faq_msg(86), act_type)

        return LeakyReLUFrontExtractor.enabled
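
For reference, the activations this extractor dispatches on follow the standard definitions below; a minimal plain-Python sketch, where the 0.25 defaults mirror the attrs.float('slope', 0.25) fallbacks above:

from math import erf, exp, sqrt

def leaky_relu(x, negative_slope=0.25):
    # act_type == 'leaky'; a slope of 0 degenerates to plain ReLU
    return x if x >= 0 else negative_slope * x

def elu(x, alpha=0.25):
    # act_type == 'elu'
    return x if x >= 0 else alpha * (exp(x) - 1.0)

def gelu(x):
    # act_type == 'gelu', erf-based
    return 0.5 * x * (1.0 + erf(x / sqrt(2.0)))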
    def replace_gelu(self, graph: Graph, match: dict):
        # Gaussian Error Linear Unit
        # f(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
        out_node = match['mul0']
        node_name = out_node.soft_get('name', out_node.id)
        div = match['div']
        inp_node = div.in_port(0).get_source().node
        inp_name = inp_node.soft_get('name', inp_node.id)
        log.debug('Found potential Erf-based GeLU pattern after {} with name {}'.format(inp_node.op, inp_name))

        # take the values of the mul, add and div
        div_param = match['div_param']
        add_param = match['add_param']
        mul_param = match['mul_param']

        if add_param.value.size == 1 and mul_param.value.size == 1 and div_param.value.size == 1:
            mul_param = match['mul_param'].value.item()
            add_param = match['add_param'].value.item()
            div_param = match['div_param'].value.item()

            sqrt2 = sqrt(2.0)
            # check that the values match the approximation
            if fabs(div_param - sqrt2) < 1e-06 and mul_param == 0.5 and add_param == 1.0:
                log.debug('Confirmed Erf-based GELU pattern after {} with name {}'.format(inp_node.op, inp_name))
                gelu = GeLUOP(graph, dict(name=inp_name + '/GELU_', approximation_mode='erf')).create_node()
                div.in_port(0).get_connection().set_destination(gelu.in_port(0))
                out_node.out_port(0).get_connection().set_source(gelu.out_port(0))
                rename_nodes([(out_node, node_name + '/TBD'), (gelu, node_name)])
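
A note on the comparisons used throughout these patterns: sqrt(2) and sqrt(2/pi) are irrational, so the constants serialized in a model (often in float32) differ slightly from Python's float64 values, which is why they are checked with a 1e-06 tolerance; 0.5 and 1.0 are exactly representable in both precisions, so plain equality is safe for them. A quick standalone check:

import numpy as np
from math import fabs, sqrt

stored = np.float32(sqrt(2.0))  # the constant as a float32 model would store it
print(float(stored) == sqrt(2.0))                # False: float32 rounding
print(fabs(float(stored) - sqrt(2.0)) < 1e-06)   # True: well within tolerance
print(np.float32(0.5) == 0.5, np.float32(1.0) == 1.0)  # True True: exact in float32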