예제 #1
0
    def test_gelu(self):
        """Checks `activations.gelu` against NumPy/SciPy reference values.

        Both the exact form (x * Phi(x), with Phi the standard normal CDF) and
        the tanh approximation are compared. Inputs are drawn from [-1, 1) so
        negative values, where GELU is not simply the identity-like branch, are
        covered as well as positive ones.
        """

        def gelu(x, approximate=False):
            # Float64 reference implementation of both GELU variants.
            if approximate:
                return 0.5 * x * (1.0 + np.tanh(
                    np.sqrt(2.0 / np.pi) * (x + 0.044715 * np.power(x, 3))))
            from scipy.stats import norm  # pylint: disable=g-import-not-at-top
            return x * norm.cdf(x)

        x = backend.placeholder(ndim=2)
        # A local, seeded generator keeps failures reproducible without
        # touching NumPy's global random state.
        rng = np.random.RandomState(1337)
        # The second positional argument of activations.gelu selects the
        # approximation, as in the original calls gelu(x) / gelu(x, True).
        for approximate in (False, True):
            f = backend.function([x], [activations.gelu(x, approximate)])
            test_values = rng.uniform(-1.0, 1.0, size=(2, 5))
            result = f([test_values])[0]
            expected = gelu(test_values, approximate)
            self.assertAllClose(result, expected, rtol=1e-05)
예제 #2
0
def test_gelu():
    """Compare `activations.gelu` with a NumPy reference implementation.

    The reference is the tanh approximation of GELU:
    0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))).
    NOTE(review): this therefore expects `activations.gelu(x)` with no extra
    arguments to use the tanh approximation in this project — confirm.
    """
    def ref_gelu(x):
        # Built only from NumPy ufuncs, so it is already elementwise; asarray
        # lets plain (nested) lists through as np.vectorize used to.
        x = np.asarray(x)
        # np.power, not np.pow: np.pow is only an alias added in NumPy 2.0.
        return 0.5 * x * (
            1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))

    x = K.placeholder(ndim=2)
    f = K.function([x], [activations.gelu(x)])
    test_values = get_standard_values()

    result = f([test_values])[0]
    expected = ref_gelu(test_values)
    assert_allclose(result, expected, rtol=1e-05)