Example #1
  def test_matmul_biasadd_gelu_fusion(self, mode):
    """Test MatMul+BiasAdd+Gelu fusion."""
    self._maybe_skip(mode)
    is_bf16_supported = _pywrap_utils.IsBF16SupportedByOneDNNOnThisCPU()

    m, n, k = (3, 3, 4)  # Matrix dimensions
    for precision in ('float32', 'bfloat16'):
      for approximate in (False, True):
        # Gelu exact (approximate=False) is not supported with bfloat16
        # precision since no support for Erf with bfloat16 data type.
        # TODO(intel-tf): Enable gelu exact with bfloat16, when Erf op is
        # supported with bfloat16.
        if precision == 'bfloat16':
          if not (approximate and is_bf16_supported):
            continue

        # Create MatMul + BiasAdd + Gelu graph
        ops.reset_default_graph()
        with ops.device('/device:CPU:0'):
          x = _input([m, k])
          w = _weight([k, n])
          b = _bias([n])
          if precision == 'bfloat16':
            x = math_ops.cast(x, dtypes.bfloat16)
            w = math_ops.cast(w, dtypes.bfloat16)
            b = math_ops.cast(b, dtypes.bfloat16)
          y = math_ops.matmul(x, w)
          z = nn.bias_add(y, b)
          out = nn.gelu(z, approximate=approximate)

        gelu_type = b'GeluApproximate' if approximate else b'GeluExact'
        epilog_ops = [b'BiasAdd', gelu_type]
        fused_op = ['_MklNativeFusedMatMul', '_MklFusedMatMul']
        graph = self._VerifyValues(out, precision == 'bfloat16', fused_op,
                                   epilog_ops)
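
The helpers _input, _weight, and _bias used above come from the surrounding test module and are not shown in these excerpts. Below is a minimal sketch of what they might look like, assuming they simply build small graph-mode tensors and variables of the requested shapes; the names are real but the bodies here are assumptions, and the actual helpers in the TensorFlow test file may differ.

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.ops import variables


def _input(shape):
  # Hypothetical stand-in: a constant tensor of random input data.
  return constant_op.constant(np.random.randn(*shape).astype(np.float32))


def _weight(shape):
  # Hypothetical stand-in: a trainable weight variable.
  return variables.Variable(np.random.randn(*shape).astype(np.float32))


def _bias(shape):
  # Hypothetical stand-in: a trainable bias variable.
  return variables.Variable(np.random.randn(*shape).astype(np.float32))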
Example #2
    def test_matmul_biasadd_gelu_fusion(self, mode):
        """Test MatMul+BiasAdd+Gelu fusion."""
        self.maybe_skip_test(mode)
        data_types = [dtypes.float32]
        if mode == 'cuda':
            data_types.append(dtypes.float16)
        elif mode == 'mkl':
            data_types.append(dtypes.bfloat16)

        is_bf16_supported = _pywrap_utils.IsBF16SupportedByOneDNNOnThisCPU()

        m, n, k = (3, 3, 4)  # Matrix dimensions
        for precision in data_types:
            for approximate in (False, True):
                # Gelu exact (approximate=False) is not supported with bfloat16
                # precision since no support for Erf with bfloat16 data type.
                # TODO(intel-tf): Enable gelu exact with bfloat16, when Erf op is
                # supported with bfloat16.
                if precision == dtypes.bfloat16:
                    if not (approximate and is_bf16_supported):
                        continue

                # TODO(kaixih@nvidia): Enable gelu exact when Erf op is supported with
                # cublaslt.
                if mode == 'cuda' and not approximate:
                    continue

                device = '/device:GPU:0' if mode == 'cuda' else '/device:CPU:0'
                # Create MatMul + BiasAdd + Gelu graph
                ops.reset_default_graph()
                with ops.device(device):
                    x = _input([m, k])
                    w = _weight([k, n])
                    b = _bias([n])
                    x = math_ops.cast(x, precision)
                    w = math_ops.cast(w, precision)
                    b = math_ops.cast(b, precision)
                    y = math_ops.matmul(x, w)
                    z = nn.bias_add(y, b)
                    out = nn.gelu(z, approximate=approximate)

                gelu_type = b'GeluApproximate' if approximate else b'GeluExact'
                epilog_ops = [b'BiasAdd', gelu_type]
                fused_op = [
                    '_MklNativeFusedMatMul', '_MklFusedMatMul', '_FusedMatMul'
                ]
                graph = self._VerifyValues(out, precision != dtypes.float32,
                                           fused_op, epilog_ops)
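
Both variants above take a mode argument, which suggests the tests are driven by a parameterized runner that exercises the CUDA and oneDNN ('mkl') paths separately. A rough sketch of such a harness is shown below; the class name, decorator, and skip logic are assumptions for illustration only, not the actual module's code.

from absl.testing import parameterized

from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class MatMulBiasAddGeluFusionTest(test.TestCase, parameterized.TestCase):

    def maybe_skip_test(self, mode):
        # Hypothetical skip logic: only run the CUDA variant when a GPU exists.
        if mode == 'cuda' and not test_util.is_gpu_available():
            self.skipTest('Test requires a GPU.')

    @parameterized.parameters('cuda', 'mkl')
    def test_matmul_biasadd_gelu_fusion(self, mode):
        ...  # body as in the examples above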
Example #3
def gelu(x, approximate=False):
    """Applies the Gaussian error linear unit (GELU) activation function.

  Gaussian error linear unit (GELU) computes
  `x * P(X <= x)`, where `P(X) ~ N(0, 1)`.
  The (GELU) nonlinearity weights inputs by their value, rather than gates
  inputs by their sign as in ReLU.

  For example:

  >>> x = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0], dtype=tf.float32)
  >>> y = tf.keras.activations.gelu(x)
  >>> y.numpy()
  array([-0.00404951, -0.15865529,  0.        ,  0.8413447 ,  2.9959507 ],
      dtype=float32)
  >>> y = tf.keras.activations.gelu(x, approximate=True)
  >>> y.numpy()
  array([-0.00363752, -0.15880796,  0.        ,  0.841192  ,  2.9963627 ],
      dtype=float32)

  Args:
      x: Input tensor.
      approximate: A `bool`, whether to enable approximation.

  Returns:
      The gaussian error linear activation:
      `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`
      if `approximate` is `True` or
      `x * P(X <= x) = 0.5 * x * (1 + erf(x / sqrt(2)))`,
      where `P(X) ~ N(0, 1)`,
      if `approximate` is `False`.

  Reference:
    - [Gaussian Error Linear Units (GELUs)](https://arxiv.org/abs/1606.08415)
  """
    return nn.gelu(x, approximate)
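
As a sanity check on the two formulas quoted in the docstring, here is a short sketch that evaluates both the exact and the tanh-approximate GELU directly with TensorFlow ops on the same sample input. This only illustrates the formulas; it is not how nn.gelu is implemented internally.

import math

import tensorflow as tf

x = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0], dtype=tf.float32)

# Exact form: x * P(X <= x) = 0.5 * x * (1 + erf(x / sqrt(2))).
gelu_exact = 0.5 * x * (1.0 + tf.math.erf(x / math.sqrt(2.0)))

# Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
gelu_approx = 0.5 * x * (
    1.0 + tf.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * tf.pow(x, 3))))

# Both should be close to the docstring outputs above, up to float32 rounding.
print(gelu_exact.numpy())
print(gelu_approx.numpy())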
Example #4
    def test_matmul_biasadd_gelu_fusion(self, mode):
        """Test MatMul+BiasAdd+Gelu fusion."""
        self._maybe_skip(mode)
        is_bf16_supported = _pywrap_utils.IsBF16SupportedByOneDNNOnThisCPU()
        run_options = config_pb2.RunOptions(output_partition_graphs=True)
        metadata = config_pb2.RunMetadata()

        m, n, k = (3, 3, 4)  # Matrix dimensions
        for precision in ('float32', 'bfloat16'):
            for approximate in (False, True):
                # Gelu exact (approximate=False) is not supported with bfloat16
                # precision since no support for Erf with bfloat16 data type.
                # TODO(intel-tf): Enable gelu exact with bfloat16, when Erf op is
                # supported with bfloat16.
                if precision == 'bfloat16':
                    if not (approximate and is_bf16_supported):
                        continue

                # Create MatMul + BiasAdd + Gelu graph
                ops.reset_default_graph()
                with ops.device('/device:CPU:0'):
                    x = _input([m, k])
                    w = _weight([k, n])
                    b = _bias([n])
                    if precision == 'bfloat16':
                        x = math_ops.cast(x, dtypes.bfloat16)
                        w = math_ops.cast(w, dtypes.bfloat16)
                        b = math_ops.cast(b, dtypes.bfloat16)
                    y = math_ops.matmul(x, w)
                    z = nn.bias_add(y, b)
                    out = nn.gelu(z, approximate=approximate)

                # Compute reference value.
                config = _get_config(remapping_on=False)
                with session.Session(config=config) as sess:
                    sess.run(variables.global_variables_initializer())
                    output_val_ref = sess.run(out,
                                              options=run_options,
                                              run_metadata=metadata)
                # Compute output with fusion.
                config = _get_config(remapping_on=True)
                with session.Session(config=config) as sess:
                    sess.run(variables.global_variables_initializer())
                    output_val = sess.run(out,
                                          options=run_options,
                                          run_metadata=metadata)
                    graph = metadata.partition_graphs[0]

                # Graph should contain fused op.
                found_fused_op = False
                gelu_type = b'GeluApproximate' if approximate else b'GeluExact'
                for node in graph.node:
                    if node.op in ('_MklNativeFusedMatMul', '_MklFusedMatMul'):
                        fused_ops = node.attr['fused_ops'].list.s
                        found_fused_op = len(fused_ops) == 2 and \
                            fused_ops[0] == b'BiasAdd' and fused_ops[1] == gelu_type
                        break
                self.assertTrue(found_fused_op)

                # Computed output value should be close to reference value.
                tol = 1e-5 if precision == 'float32' else 1e-2
                self.assertAllClose(output_val_ref,
                                    output_val,
                                    atol=tol,
                                    rtol=tol)
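
Example #4 relies on a _get_config helper to toggle Grappler's remapper pass on and off, so the fused graph can be compared against an unfused reference. The helper itself is not shown; a plausible sketch, assuming it only flips the remapping option in the rewriter config, might look like this (the real helper in the test module may set additional options):

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2


def _get_config(remapping_on=False):
    # Hypothetical helper: build a session ConfigProto with the Grappler
    # remapping (op fusion) pass switched on or off.
    remapping = (rewriter_config_pb2.RewriterConfig.ON
                 if remapping_on else rewriter_config_pb2.RewriterConfig.OFF)
    rewrite_options = rewriter_config_pb2.RewriterConfig(remapping=remapping)
    graph_options = config_pb2.GraphOptions(rewrite_options=rewrite_options)
    return config_pb2.ConfigProto(graph_options=graph_options)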