示例#1
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Swap a matched Erf-based GeLU sub-graph for a single GeLU node.

        Target formula: f(x) = 0.5 * x * (1 + erf(x / sqrt(2))).

        The replacement is performed only when the three matched constants
        are single-element values holding the expected numbers: the divisor
        within 1e-06 of sqrt(2), the multiplier exactly 0.5 and the addend
        exactly 1.0. In every other case the graph is left untouched.

        :param graph: graph in which the new GeLU node is created
        :param match: matched nodes; the keys read here are 'div', 'mul0',
            'div_param', 'add_param' and 'mul_param'
        """
        source_port = match['div'].in_port(0).get_source()
        producer = source_port.node
        log.debug('Found potential Erf-based GeLU pattern after {} with name {}'.format(
            producer.op, producer.name))

        # Constant operands associated with the div, add and mul of the pattern.
        div_const = match['div_param']
        add_const = match['add_param']
        mul_const = match['mul_param']

        # The formula only involves scalars, so reject multi-element constants.
        scalars_only = (add_const.value.size == 1 and mul_const.value.size == 1
                        and div_const.value.size == 1)
        if not scalars_only:
            return

        scale = mul_const.value.item()
        offset = add_const.value.item()
        divisor = div_const.value.item()

        # Compare the constants against the ones of the exact GeLU formula.
        is_erf_gelu = fabs(divisor - sqrt(2.0)) < 1e-06 and scale == 0.5 and offset == 1.0
        if not is_erf_gelu:
            return

        log.debug('Confirmed Erf-based GELU pattern after {} with name {}'.format(
            producer.op, producer.name))
        gelu_attrs = dict(name=producer.name + '/GELU_')
        gelu = GeLUOP(graph, gelu_attrs).create_node()

        # Feed the GeLU node from the pattern input and make it the source of
        # the connection that used to start at the final multiplication.
        source_port.connect(gelu.in_port(0))
        pattern_output = match['mul0'].out_port(0).get_connection()
        pattern_output.set_source(gelu.out_port(0))
示例#2
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Swap a matched TanH-based GeLU approximation for a single GeLU node.

        Target formula:
        f(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).

        The replacement is performed only when the three matched constants
        are single-element values holding the expected numbers: 'mul0_param'
        within 1e-06 of sqrt(2 / pi), 'mul_param' within 1e-06 of 0.044715
        and 'mul1_param' exactly 0.5. In every other case the graph is left
        untouched.

        :param graph: graph in which the new GeLU node is created
        :param match: matched nodes; the keys read here are 'pow', 'mul2',
            'mul_param', 'mul0_param' and 'mul1_param'
        """
        source_port = match['pow'].in_port(0).get_source()
        producer = source_port.node
        log.debug('Found potential TanH-based GeLU pattern after {} with name {}'.format(
            producer.op, producer.name))

        # Constant operands of the multiplications in the pattern.
        cubic_const = match['mul_param']
        tanh_const = match['mul0_param']
        half_const = match['mul1_param']

        # The formula only involves scalars, so reject multi-element constants.
        scalars_only = (tanh_const.value.size == 1 and cubic_const.value.size == 1
                        and half_const.value.size == 1)
        if not scalars_only:
            return

        cubic_coeff = cubic_const.value.item()
        tanh_scale = tanh_const.value.item()
        outer_scale = half_const.value.item()

        # Compare the constants against the ones of the tanh approximation.
        expected_tanh_scale = sqrt(2.0 / pi)
        is_tanh_gelu = (fabs(tanh_scale - expected_tanh_scale) < 1e-06
                        and fabs(cubic_coeff - 0.044715) < 1e-06
                        and outer_scale == 0.5)
        if not is_tanh_gelu:
            return

        log.debug('Confirmed TanH-based GELU pattern after {} with name {}'.format(
            producer.op, producer.name))
        gelu_attrs = dict(name=producer.name + '/GELU_', approximation_mode='tanh')
        gelu = GeLUOP(graph, gelu_attrs).create_node()

        # Feed the GeLU node from the pattern input and make it the source of
        # the connection that used to start at the final multiplication.
        source_port.connect(gelu.in_port(0))
        pattern_output = match['mul2'].out_port(0).get_connection()
        pattern_output.set_source(gelu.out_port(0))
    def replace_gelu(self, graph: Graph, match: dict):
        """Replace a matched Erf-based GeLU sub-graph with a single GeLU node.

        Target formula: f(x) = 0.5 * x * (1 + erf(x / sqrt(2))).

        The replacement is performed only when the three matched constants
        are single-element values holding the expected numbers: the divisor
        within 1e-06 of sqrt(2), the multiplier exactly 0.5 and the addend
        exactly 1.0. In every other case the graph is left untouched.

        On success the new GeLU node takes over the name of the pattern's
        output node ('mul0'), which itself is renamed with a '/TBD' suffix.

        :param graph: graph in which the new GeLU node is created
        :param match: matched nodes; the keys read here are 'mul0', 'div',
            'div_param', 'add_param' and 'mul_param'
        """
        out_node = match['mul0']
        node_name = out_node.soft_get('name', out_node.id)
        div = match['div']
        inp_node = div.in_port(0).get_source().node
        # Fall back to the input node's own id (not the output node's id) so
        # that the log messages and the derived GeLU name refer to the input.
        inp_name = inp_node.soft_get('name', inp_node.id)
        log.debug('Found potential Erf-based GeLU pattern after {} with name {}'.format(inp_node.op, inp_name))

        # take the values of the mul, add and div
        div_param = match['div_param']
        add_param = match['add_param']
        mul_param = match['mul_param']

        # The formula only involves scalars, so reject multi-element constants.
        if not (add_param.value.size == 1 and mul_param.value.size == 1 and div_param.value.size == 1):
            return

        mul_value = mul_param.value.item()
        add_value = add_param.value.item()
        div_value = div_param.value.item()

        # check that the values match the exact (erf) formula
        sqrt2 = sqrt(2.0)
        if not (fabs(div_value - sqrt2) < 1e-06 and mul_value == 0.5 and add_value == 1.0):
            return

        log.debug('Confirmed Erf-based GELU pattern after {} with name {}'.format(inp_node.op, inp_name))
        # NOTE(review): the attribute is spelled 'approximation' here; confirm
        # this is the attribute name GeLUOP expects.
        gelu = GeLUOP(graph, dict(name=inp_name + '/GELU_', approximation='erf')).create_node()

        # Redirect the pattern's input connection into the GeLU node and make
        # the GeLU output the source of the pattern's output connection.
        div.in_port(0).get_connection().set_destination(gelu.in_port(0))
        out_node.out_port(0).get_connection().set_source(gelu.out_port(0))

        # Hand the original output name over to the fused node.
        rename_nodes([(out_node, node_name + '/TBD'), (gelu, node_name)])