Exemplo n.º 1
0
  def test_activation(
      self,
      same_inputs,
      model,
      phi_name,
      get,
      abc,
      approximate
  ):
    """Tests the `stax` activation layer named `phi_name`.

    Builds the layer and delegates the actual check to
    `self._test_activation`.

    Args:
      same_inputs: forwarded to `self._test_activation`.
      model: forwarded to `self._test_activation`.
      phi_name: one of `'Sin'`, `'Erf'`, `'Gelu'`, `'Sign'`, `'Cos'`.
      get: forwarded to `self._test_activation`.
      abc: `[a, b, c]` coefficients; only consumed by `Sin` and `Erf`.
      approximate: `approximate` flag for `stax.Gelu`; combinations with any
        other activation are skipped.

    Raises:
      absltest.SkipTest: for redundant parameter combinations.
      NotImplementedError: if `phi_name` is not recognized.
    """
    if abc != [0.3, 1.5, -np.pi/4]:
      test_utils.skip_test(self)

    if approximate and phi_name != 'Gelu':
      raise absltest.SkipTest(
          f'{phi_name} does not have an `approximate` parameter.')

    a, b, c = abc
    if phi_name == 'Sin':
      activation = stax.Sin(a=a, b=b, c=c)
    elif phi_name == 'Erf':
      activation = stax.Erf(a=a, b=b, c=c)
    elif phi_name in ['Gelu', 'Sign', 'Cos']:
      # These layers are built without `(a, b, c)`, so run them for a single
      # `abc` value only to avoid duplicate test cases.
      if a != 0.3 or b != 0.3 or c != 0.:
        raise absltest.SkipTest('Skip `Gelu/Sign/Cos` test if '
                                '(a, b, c) != (.3, .3, 0.).')
      if phi_name == 'Gelu':
        # Forward `approximate` so both Gelu variants are actually exercised.
        activation = stax.Gelu(approximate=approximate)
      elif phi_name == 'Sign':
        activation = stax.Sign()
      else:
        activation = stax.Cos()
    else:
      raise NotImplementedError(f'Activation {phi_name} is not implemented.')
    self._test_activation(activation, same_inputs, model, get)
Exemplo n.º 2
0
  def test_vmap_axes(self, same_inputs):
    """Checks empirical NTK functions with and without `vmap_axes`.

    The network takes three inputs whose batch axis sits at different
    positions and emits two outputs with different batch-axis positions. The
    direct empirical NTK without `vmap_axes` is the reference. The implicit
    NTK, and both NTK implementations with (equivalent) `vmap_axes`
    specifications, must match it.
    """
    batch_1, batch_2 = 3, 4
    chan_1, chan_2, chan_3 = 9, 5, 7
    height_2, height_3, width_3 = 6, 8, 2

    def make_inputs(batch, rng):
      # Batch axis is at position 0, 1 and 2 of the three arrays respectively.
      rng_1, rng_2, rng_3 = random.split(rng, 3)
      return [
          (random.normal(rng_1, (batch, chan_1)),
           random.normal(rng_2, (height_2, batch, chan_2))),
          random.normal(rng_3, (chan_3, width_3, batch, height_3)),
      ]

    x1 = make_inputs(batch_1, random.PRNGKey(1))
    x2 = None if same_inputs else make_inputs(batch_2, random.PRNGKey(2))

    pattern_1 = random.normal(random.PRNGKey(5),
                              (batch_1, height_2, height_2))
    if same_inputs:
      pattern_2 = None
    else:
      pattern_2 = random.normal(random.PRNGKey(6),
                                (batch_2, height_2, height_2))

    # Branch 1: fully-connected network on the first input.
    fc_branch = stax.serial(stax.Dense(4, 2., 0.1),
                            stax.Relu(),
                            stax.Dense(3, 1., 0.15))
    # Branch 2: 1D convolution on the second input, followed by aggregation,
    # global pooling and a dense readout.
    conv_1d_branch = stax.serial(
        stax.Conv(7, (2,), padding='SAME',
                  dimension_numbers=('HNC', 'OIH', 'NHC')),
        stax.Erf(),
        stax.Aggregate(1, 0, -1),
        stax.GlobalAvgPool(),
        stax.Dense(3, 0.5, 0.2))
    # Branch 3: 2D convolution on the third input.
    conv_2d_branch = stax.serial(
        stax.Conv(5, (2, 3), padding='SAME',
                  dimension_numbers=('CWNH', 'IOHW', 'HWCN')),
        stax.Sin())
    # Sum branches 1 and 2; convolve branch 3 once more.
    head = stax.parallel(
        stax.FanInSum(),
        stax.Conv(2, (2, 1), dimension_numbers=('HWCN', 'OIHW', 'HNWC')))

    body = stax.parallel(stax.parallel(fc_branch, conv_1d_branch),
                         conv_2d_branch)
    init_fn, apply_fn, _ = stax.serial(body, head)

    _, params = init_fn(random.PRNGKey(3), tree_map(np.shape, x1))

    direct = jit(empirical._empirical_direct_ntk_fn(apply_fn))
    implicit = jit(empirical._empirical_implicit_ntk_fn(apply_fn))
    # The two `vmap_axes` specs locate the same batch axes, once with mostly
    # non-negative and once with mostly negative indices.
    implicit_batched = jit(empirical._empirical_implicit_ntk_fn(
        apply_fn, vmap_axes=([(0, 1), 2], [-2, -3], dict(pattern=0))))
    direct_batched = jit(empirical._empirical_direct_ntk_fn(
        apply_fn, vmap_axes=([(-2, -2), -2], [0, 1], dict(pattern=-3))))

    patterns = (pattern_1, pattern_2)
    reference = direct(x1, x2, params, pattern=patterns)
    for ntk_fn in (implicit, direct_batched, implicit_batched):
      self.assertAllClose(reference,
                          ntk_fn(x1, x2, params, pattern=patterns))
Exemplo n.º 3
0
class ElementwiseNumericalTest(test_utils.NeuralTangentsTestCase):
  """Compares `stax.ElementwiseNumerical` to layers with analytic kernels."""

  @parameterized.product(
      model=[
          'fc',
          'conv-pool',
          'conv-flatten'
      ],
      phi=[
          stax.Erf(),
          stax.Gelu(),
          stax.Sin(),
      ],
      same_inputs=[False, True],
      get=['nngp', 'ntk']
  )
  def test_elementwise_numerical(self, same_inputs, model, phi, get):
    """Checks that a numerical version of `phi` reproduces its kernel.

    The same `affine -> nonlinearity` stack (repeated `depth` times, then a
    readout) is built twice:
    - once with the layer `phi` itself;
    - once with `stax.ElementwiseNumerical` wrapping `phi`'s `apply_fn`.

    The two `get` kernels must agree to within `rtol`.
    """
    if 'conv' in model:
      test_utils.skip_test(self)

    key, split = random.split(random.PRNGKey(1))

    # Looser tolerance for the NTK and on TPU.
    ntk_factor = 2 if get == 'ntk' else 1
    tpu_factor = 2 if default_backend() == 'tpu' else 1
    rtol = 2e-3 * ntk_factor * tpu_factor

    width = 1024
    w_std, b_std = 1.0, 0.01
    out_dim = 1
    degree = 25

    if model == 'fc':
      x1 = random.normal(key, (3, 7))
      x2 = None if same_inputs else random.normal(split, (5, 7))
      affine = stax.Dense(width, w_std, b_std)
      readout = stax.Dense(out_dim)
      depth = 1
    else:
      x1 = random.normal(key, (2, 8, 8, 3))
      x2 = None if same_inputs else random.normal(split, (3, 8, 8, 3))
      affine = stax.Conv(width, (3, 2), W_std=w_std, b_std=b_std,
                         padding='SAME')
      pool = stax.GlobalAvgPool() if 'pool' in model else stax.Flatten()
      readout = stax.serial(pool, stax.Dense(out_dim))
      depth = 2

    def kernel_with(nonlinearity):
      # Kernel of the stack built around `nonlinearity`.
      _, _, kernel_fn = stax.serial(*([affine, nonlinearity] * depth),
                                    readout)
      return kernel_fn(x1, x2, get)

    def apply_phi(x):
      return phi[1]((), x)

    analytic = kernel_with(phi)
    numerical = kernel_with(stax.ElementwiseNumerical(apply_phi, deg=degree))

    test_utils.assert_close_matrices(self, analytic, numerical, rtol)
Exemplo n.º 4
0
class ElementwiseTest(test_utils.NeuralTangentsTestCase):
  """Checks `stax.Elementwise` with hand-written closed-form NNGP functions."""

  @parameterized.product(
      phi=[
          stax.Identity(),
          stax.Erf(),
          stax.Sin(),
          stax.Relu(),
      ],
      same_inputs=[False, True, None],
      n=[0, 1, 2],
      diagonal_batch=[True, False],
      diagonal_spatial=[True, False]
  )
  def test_elementwise(
      self,
      same_inputs,
      phi,
      n,
      diagonal_batch,
      diagonal_spatial
  ):
    """Compares `stax.Elementwise(fn, nngp_fn)` against the built-in `phi`.

    A `Dense -> phi -> Dense -> phi` network is built twice:
    - from `stax.Elementwise`, using `phi`'s `apply_fn` and a closed-form
      `nngp_fn` chosen by the name of `phi`'s `init_fn`;
    - directly from `phi`.

    The resulting kernels must match.

    Args:
      same_inputs: `None` passes `x2=None`; `True` passes `x2 = x1`; `False`
        draws a second batch from an independent PRNG key.
      phi: a built-in `stax` elementwise layer.
      n: number of axes (sizes taken from `(4, 3, 2)`) inserted between the
        batch axis and a trailing size-1 axis of the inputs.
      diagonal_batch: forwarded to both `kernel_fn`s.
      diagonal_spatial: forwarded to both `kernel_fn`s.
    """
    def fn(x):
      return phi[1]((), x)

    name = phi[0].__name__

    def nngp_fn(cov12, var1, var2):
      # E[phi(u) phi(v)] for zero-mean jointly Gaussian (u, v) with variances
      # `var1`, `var2` and covariance `cov12`.
      if 'Identity' in name:
        res = cov12

      elif 'Erf' in name:
        prod = (1 + 2 * var1) * (1 + 2 * var2)
        res = np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi

      elif 'Sin' in name:
        sum_ = (var1 + var2)
        s1 = np.exp((-0.5 * sum_ + cov12))
        s2 = np.exp((-0.5 * sum_ - cov12))
        res = (s1 - s2) / 2

      elif 'Relu' in name:
        prod = var1 * var2
        # Clamp to keep the square root (and its gradient) finite when
        # `cov12**2` rounds to or above `prod`.
        sqrt = np.sqrt(np.maximum(prod - cov12 ** 2, 1e-30))
        angles = np.arctan2(sqrt, cov12)
        dot_sigma = (1 - angles / np.pi) / 2
        res = sqrt / (2 * np.pi) + dot_sigma * cov12

      else:
        raise NotImplementedError(name)

      return res

    _, _, kernel_fn = stax.serial(stax.Dense(1), stax.Elementwise(fn, nngp_fn),
                                  stax.Dense(1), stax.Elementwise(fn, nngp_fn))
    _, _, kernel_fn_manual = stax.serial(stax.Dense(1), phi,
                                         stax.Dense(1), phi)

    # Separate keys so the `same_inputs=False` case really uses inputs that
    # are independent of `x1` (JAX PRNG keys must not be reused).
    key_1, key_2 = random.split(random.PRNGKey(1))
    shape = (4, 3, 2)[:n] + (1,)
    x1 = random.normal(key_1, (5,) + shape)
    if same_inputs is None:
      x2 = None
    elif same_inputs is True:
      x2 = x1
    else:
      x2 = random.normal(key_2, (6,) + shape)

    kwargs = dict(diagonal_batch=diagonal_batch,
                  diagonal_spatial=diagonal_spatial)

    k = kernel_fn(x1, x2, **kwargs)
    # NOTE(review): `is_gaussian` is reset so the comparison ignores that
    # flag; presumably `stax.Elementwise` reports it differently — confirm.
    k_manual = kernel_fn_manual(x1, x2, **kwargs).replace(is_gaussian=False)
    self.assertAllClose(k_manual, k)