Example #1
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk):
  fc = partial(stax.Dense, W_std=W_std, b_std=b_std)
  conv = partial(
      stax.Conv,
      filter_shape=filter_shape,
      strides=strides,
      padding=padding,
      W_std=W_std,
      b_std=b_std)
  affine = conv(width) if is_conv else fc(width)

  pool_or_identity = (stax.AvgPool((2, 3), None,
                                   'SAME' if padding == 'SAME' else 'CIRCULAR')
                      if use_pooling else stax.Identity())
  res_unit = stax.serial(pool_or_identity, phi, affine)

  if is_res:
    block = stax.serial(affine, stax.FanOut(2),
                        stax.parallel(stax.Identity(), res_unit),
                        stax.FanInSum())
  else:
    block = stax.serial(affine, res_unit)

  readout = stax.serial(stax.GlobalAvgPool() if use_pooling else stax.Flatten(),
                        fc(1 if is_ntk else width))

  net = stax.serial(block, readout)
  return net
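Since stax.serial returns an (init_fn, apply_fn, kernel_fn) triple, the helper above can be consumed as in this minimal sketch; all argument values below are illustrative, not taken from the original:

# Minimal usage sketch for _get_net above (illustrative argument values).
from jax import random
from neural_tangents import stax

net = _get_net(W_std=1.5, b_std=0.1, filter_shape=(3, 3), is_conv=True,
               use_pooling=True, is_res=False, padding='SAME',
               phi=stax.Relu(), strides=(1, 1), width=256, is_ntk=True)
init_fn, apply_fn, kernel_fn = net

x = random.normal(random.PRNGKey(0), (4, 8, 8, 3))  # NHWC image batch
ntk = kernel_fn(x, None, 'ntk')                      # analytic NTK, shape (4, 4)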
Example #2
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
             phi, strides, width, is_ntk, proj_into_2d, layer_norm,
             parameterization, use_dropout):
    fc = partial(stax.Dense,
                 W_std=W_std,
                 b_std=b_std,
                 parameterization=parameterization)
    conv = partial(stax.Conv,
                   filter_shape=filter_shape,
                   strides=strides,
                   padding=padding,
                   W_std=W_std,
                   b_std=b_std,
                   parameterization=parameterization)
    affine = conv(width) if is_conv else fc(width)
    rate = onp.random.uniform(0.5, 0.9)  # onp: plain NumPy (assumed imported)
    dropout = stax.Dropout(rate, mode='train')
    ave_pool = stax.AvgPool((2, 3), None,
                            'SAME' if padding == 'SAME' else 'CIRCULAR')
    ave_pool_or_identity = ave_pool if use_pooling else stax.Identity()
    dropout_or_identity = dropout if use_dropout else stax.Identity()
    layer_norm_or_identity = (stax.Identity() if layer_norm is None else
                              stax.LayerNorm(axis=layer_norm))
    res_unit = stax.serial(ave_pool_or_identity, phi, dropout_or_identity,
                           affine)
    if is_res:
        block = stax.serial(affine, stax.FanOut(2),
                            stax.parallel(stax.Identity(), res_unit),
                            stax.FanInSum(), layer_norm_or_identity)
    else:
        block = stax.serial(affine, res_unit, layer_norm_or_identity)

    if proj_into_2d == 'FLAT':
        proj_layer = stax.Flatten()
    elif proj_into_2d == 'POOL':
        proj_layer = stax.GlobalAvgPool()
    elif proj_into_2d.startswith('ATTN'):
        n_heads = int(np.sqrt(width))
        n_chan_val = int(np.round(float(width) / n_heads))
        fixed = proj_into_2d == 'ATTN_FIXED'
        proj_layer = stax.serial(
            stax.GlobalSelfAttention(width,
                                     n_chan_key=width,
                                     n_chan_val=n_chan_val,
                                     n_heads=n_heads,
                                     fixed=fixed,
                                     W_key_std=W_std,
                                     W_value_std=W_std,
                                     W_query_std=W_std,
                                     W_out_std=1.0,
                                     b_std=b_std), stax.Flatten())
    else:
        raise ValueError(proj_into_2d)
    readout = stax.serial(proj_layer, fc(1 if is_ntk else width))

    return stax.serial(block, readout)
Example #3
 def get_attn():
     return stax.GlobalSelfAttention(
         n_chan_out=width,
         n_chan_key=width,
         n_chan_val=int(np.round(float(width) / int(np.sqrt(width)))),
         n_heads=int(np.sqrt(width)),
     ) if proj == 'avg' else stax.Identity()
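The head arithmetic above keeps n_heads * n_chan_val close to width; a worked instance (the width value is illustrative):

import numpy as np

width = 64                                          # illustrative value
n_heads = int(np.sqrt(width))                       # -> 8
n_chan_val = int(np.round(float(width) / n_heads))  # -> 8
assert n_heads * n_chan_val == width                # exact for perfect squares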
Example #4
  def test_ab_relu_id(self, same_inputs, do_stabilize):
    key = random.PRNGKey(1)
    X0_1 = random.normal(key, (3, 2))
    fc = stax.Dense(5, 1, 0)

    X0_2 = None if same_inputs else random.normal(key, (4, 2))

    # Test that ABRelu(a, a) == a * Identity
    init_fn, apply_id, kernel_fn_id = stax.serial(fc, stax.Identity())
    _, params = init_fn(key, input_shape=X0_1.shape)

    for a in [-5, -1, -0.5, 0, 0.5, 1, 5]:
      with self.subTest(a=a):
        _, apply_ab_relu, kernel_fn_ab_relu = stax.serial(
            fc, stax.ABRelu(a, a, do_stabilize=do_stabilize))

        X1_1_id = a * apply_id(params, X0_1)
        X1_1_ab_relu = apply_ab_relu(params, X0_1)
        self.assertAllClose(X1_1_id, X1_1_ab_relu)

        kernels_id = kernel_fn_id(X0_1 * a, None if X0_2 is None else a * X0_2)
        kernels_ab_relu = kernel_fn_ab_relu(X0_1, X0_2)
        # Manually correct the value of `is_gaussian` because
        # `ab_relu` (incorrectly) sets `is_gaussian=False` when `a==b`.
        kernels_ab_relu = kernels_ab_relu.replace(is_gaussian=True)
        self.assertAllClose(kernels_id, kernels_ab_relu)
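For reference, stax.ABRelu(a, b) computes a * min(x, 0) + b * max(x, 0), so a == b collapses it to a * x; a standalone check of the identity this test exercises (not part of the test itself):

import jax.numpy as np

def ab_relu(a, b, x):
  # a * min(x, 0) + b * max(x, 0); equals a * x whenever a == b.
  return a * np.minimum(x, 0) + b * np.maximum(x, 0)

x = np.array([-1., 0., 3.])
assert bool(np.all(ab_relu(2., 2., x) == 2. * x))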
Example #5
  def test_hermite(self, same_inputs, degree, get, readout):
    key = random.PRNGKey(1)
    key1, key2, key = random.split(key, 3)

    if degree > 2:
      width = 10000
      n_samples = 5000
      test_utils.skip_test(self)
    else:
      width = 10000
      n_samples = 100

    x1 = np.cos(random.normal(key1, [2, 6, 6, 3]))
    x2 = x1 if same_inputs else np.cos(random.normal(key2, [3, 6, 6, 3]))

    conv_layers = [
        stax.Conv(width, (3, 3), W_std=2., b_std=0.5),
        stax.LayerNorm(),
        stax.Hermite(degree),
        stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten(),
        stax.Dense(1) if get == 'ntk' else stax.Identity()]

    init_fn, apply_fn, kernel_fn = stax.serial(*conv_layers)
    analytic_kernel = kernel_fn(x1, x2, get)
    mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, n_samples)
    mc_kernel = mc_kernel_fn(x1, x2, get)
    rtol = degree / 2. * 1e-2
    test_utils.assert_close_matrices(self, mc_kernel, analytic_kernel, rtol)
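The analytic kernel checked here rests on the Hermite orthogonality relation (Mehler's formula, up to the library's normalization convention): for jointly Gaussian unit-variance inputs with correlation rho,

$$\mathbb{E}\big[He_n(u)\,He_n(v)\big] = n!\,\rho^n,$$

which is why stax.LayerNorm() precedes stax.Hermite(degree) to bring the inputs to unit variance.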
Example #6
  def test_nested_parallel(self, same_inputs, kernel_type):
    platform = default_backend()
    rtol = RTOL if platform != 'tpu' else 0.05

    rng = random.PRNGKey(0)
    (input_key1,
     input_key2,
     input_key3,
     input_key4,
     mask_key,
     mc_key) = random.split(rng, 6)

    x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 5))
    x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 2, 2, 2))
    x1_3, x2_3 = _get_inputs(input_key3, same_inputs, (BATCH_SIZE, 2, 2, 3))
    x1_4, x2_4 = _get_inputs(input_key4, same_inputs, (BATCH_SIZE, 3, 4))

    m1_key, m2_key, m3_key, m4_key = random.split(mask_key, 4)

    x1_1 = test_utils.mask(
        x1_1, mask_constant=-1, mask_axis=(1,), key=m1_key, p=0.5)
    x1_2 = test_utils.mask(
        x1_2, mask_constant=-1, mask_axis=(2, 3,), key=m2_key, p=0.5)
    if not same_inputs:
      x2_3 = test_utils.mask(
          x2_3, mask_constant=-1, mask_axis=(1, 3,), key=m3_key, p=0.5)
      x2_4 = test_utils.mask(
          x2_4, mask_constant=-1, mask_axis=(2,), key=m4_key, p=0.5)

    x1 = (((x1_1, x1_2), x1_3), x1_4)
    x2 = (((x2_1, x2_2), x2_3), x2_4) if not same_inputs else None

    N_in = 2 ** 7

    # We only include dropout on non-TPU backends, because it takes large N to
    # converge on TPU.
    dropout_or_id = stax.Dropout(0.9) if platform != 'tpu' else stax.Identity()

    init_fn, apply_fn, kernel_fn = stax.parallel(
        stax.parallel(
            stax.parallel(stax.Dense(N_in),
                          stax.serial(stax.Conv(N_in + 1, (2, 2)),
                                      stax.Flatten())),
            stax.serial(stax.Conv(N_in + 2, (2, 2)),
                        dropout_or_id,
                        stax.GlobalAvgPool())),
        stax.Conv(N_in + 3, (2,)))

    kernel_fn_empirical = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, mc_key, N_SAMPLES,
        implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
        vmap_axes=(((((0, 0), 0), 0), (((0, 0), 0), 0), {})
                   if platform == 'tpu' else None)
    )

    test_utils.assert_close_matrices(
        self,
        kernel_fn(x1, x2, get=kernel_type, mask_constant=-1),
        kernel_fn_empirical(x1, x2, get=kernel_type, mask_constant=-1),
        rtol)
Example #7
def WideResnetBlocknt(channels,
                      strides=(1, 1),
                      channel_mismatch=False,
                      batchnorm='std',
                      parameterization='ntk'):
    """A WideResnet block, with or without BatchNorm."""

    Main = stax_nt.serial(
        _batch_norm_internal(batchnorm), stax_nt.Relu(),
        stax_nt.Conv(channels, (3, 3),
                     strides,
                     padding='SAME',
                     parameterization=parameterization),
        _batch_norm_internal(batchnorm), stax_nt.Relu(),
        stax_nt.Conv(channels, (3, 3),
                     padding='SAME',
                     parameterization=parameterization))

    Shortcut = stax_nt.Identity() if not channel_mismatch else stax_nt.Conv(
        channels, (3, 3),
        strides,
        padding='SAME',
        parameterization=parameterization)
    return stax_nt.serial(stax_nt.FanOut(2), stax_nt.parallel(Main, Shortcut),
                          stax_nt.FanInSum())
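A sketch of how such a block is typically stacked into a group, following the standard WideResnet layout (the group helper below is illustrative, not from the original source):

def WideResnetGroupnt(n, channels, strides=(1, 1), batchnorm='std',
                      parameterization='ntk'):
    # The first block absorbs the channel/stride mismatch; the rest keep shape.
    blocks = [WideResnetBlocknt(channels, strides, channel_mismatch=True,
                                batchnorm=batchnorm,
                                parameterization=parameterization)]
    for _ in range(n - 1):
        blocks.append(WideResnetBlocknt(channels, (1, 1),
                                        batchnorm=batchnorm,
                                        parameterization=parameterization))
    return stax_nt.serial(*blocks)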
Example #8
def _build_network(input_shape, network, out_logits, use_dropout):
    dropout = stax.Dropout(0.9,
                           mode='train') if use_dropout else stax.Identity()
    if len(input_shape) == 1:
        assert network == FLAT
        return stax.serial(stax.Dense(WIDTH, W_std=2.0, b_std=0.5), dropout,
                           stax.Dense(out_logits, W_std=2.0, b_std=0.5))
    elif len(input_shape) == 3:
        if network == POOLING:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05),
                stax.GlobalAvgPool(), dropout,
                stax.Dense(out_logits, W_std=2.0, b_std=0.5))
        elif network == FLAT:
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05),
                stax.Flatten(), dropout,
                stax.Dense(out_logits, W_std=2.0, b_std=0.5))
        elif network == INTERMEDIATE_CONV:
            return stax.Conv(CONVOLUTION_CHANNELS, (2, 2),
                             W_std=2.0,
                             b_std=0.05)
        else:
            raise ValueError(
                'Unexpected network type found: {}'.format(network))
    else:
        raise ValueError('Expected flat or image test input.')
Example #9
    def test_ab_relu_id(self, same_inputs):
        key = random.PRNGKey(1)
        X0_1 = random.normal(key, (5, 7))
        fc = stax.Dense(10, 1, 0)

        X0_2 = None if same_inputs else random.normal(key, (9, 7))

        # Test that ABRelu(a, a) == a * Identity
        init_fn, apply_id, kernel_fn_id = stax.serial(fc, stax.Identity())
        _, params = init_fn(key, input_shape=(-1, 7))  # init_fn returns (output_shape, params)

        for a in [-5, -1, -0.5, 0, 0.5, 1, 5]:
            with self.subTest(a=a):
                _, apply_ab_relu, kernel_fn_ab_relu = stax.serial(
                    fc, stax.ABRelu(a, a))

                X1_1_id = a * apply_id(params, X0_1)
                X1_1_ab_relu = apply_ab_relu(params, X0_1)
                self.assertAllClose(X1_1_id, X1_1_ab_relu, True)

                kernels_id = kernel_fn_id(X0_1 * a,
                                          None if X0_2 is None else a * X0_2,
                                          ('nngp', 'ntk'))
                kernels_ab_relu = kernel_fn_ab_relu(X0_1, X0_2,
                                                    ('nngp', 'ntk'))
                self.assertAllClose(kernels_id, kernels_ab_relu, True)
Example #10
def ResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    Main = stax.serial(stax.Relu(),
                       stax.Conv(channels, (3, 3), strides, padding='SAME'),
                       stax.Relu(), stax.Conv(channels, (3, 3),
                                              padding='SAME'))
    Shortcut = stax.Identity() if not channel_mismatch else stax.Conv(
        channels, (3, 3), strides, padding='SAME')
    return stax.serial(stax.FanOut(2), stax.parallel(Main, Shortcut),
                       stax.FanInSum())
Example #11
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk, proj_into_2d):
  fc = partial(stax.Dense, W_std=W_std, b_std=b_std)
  conv = partial(
      stax.Conv,
      filter_shape=filter_shape,
      strides=strides,
      padding=padding,
      W_std=W_std,
      b_std=b_std)
  affine = conv(width) if is_conv else fc(width)

  pool_or_identity = (stax.AvgPool((2, 3), None,
                                   'SAME' if padding == 'SAME' else 'CIRCULAR')
                      if use_pooling else stax.Identity())
  res_unit = stax.serial(pool_or_identity, phi, affine)

  if is_res:
    block = stax.serial(affine, stax.FanOut(2),
                        stax.parallel(stax.Identity(), res_unit),
                        stax.FanInSum())
  else:
    block = stax.serial(affine, res_unit)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten()
  elif proj_into_2d == 'POOL':
    proj_layer = stax.GlobalAvgPool()
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    fixed = proj_into_2d == 'ATTN_FIXED'
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(
            width, n_chan_key=width, n_chan_val=n_chan_val, n_heads=n_heads,
            fixed=fixed, W_key_std=W_std, W_value_std=W_std, W_query_std=W_std,
            W_out_std=1.0, b_std=b_std),
        stax.Flatten())
  else:
    raise ValueError(proj_into_2d)
  readout = stax.serial(proj_layer, fc(1 if is_ntk else width))

  return stax.serial(block, readout)
Example #12
 def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
   main = stax.serial(
       stax.Relu(),
       stax.Conv(
           channels, (3, 3), strides, padding='SAME',
           parameterization='standard'
       ),
       stax.Relu(),
       stax.Conv(channels, (3, 3), padding='SAME',
                 parameterization='standard'),
   )
   shortcut = (
       stax.Identity()
       if not channel_mismatch
       else stax.Conv(
           channels, (3, 3), strides, padding='SAME',
           parameterization='standard'
       )
   )
   return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut),
                      stax.FanInSum())
Example #13
from neural_tangents._src.empirical import _DEFAULT_TESTING_NTK_IMPLEMENTATION
from tests import test_utils

config.parse_flags_with_absl()
config.update('jax_numpy_rank_promotion', 'raise')

test_utils.update_test_tolerance()

prandom.seed(1)


@parameterized.product(
    same_inputs=[False, True],
    readout=[stax.Flatten(),
             stax.GlobalAvgPool(),
             stax.Identity()],
    readin=[stax.Flatten(),
            stax.GlobalAvgPool(),
            stax.Identity()])
class DiagonalTest(test_utils.NeuralTangentsTestCase):
    def _get_kernel_fn(self, same_inputs, readin, readout):
        key = random.PRNGKey(1)
        x1 = random.normal(key, (2, 5, 6, 3))
        x2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))
        layers = [readin]
        filter_shape = (2, 3) if readin[0].__name__ == 'Identity' else ()
        layers += [
            stax.Conv(1, filter_shape, padding='SAME'),
            stax.Relu(),
            stax.Conv(1, filter_shape, padding='SAME'),
            stax.Erf(), readout
        ]
        # Assumed completion: the original snippet is truncated at this point.
        _, _, kernel_fn = stax.serial(*layers)
        return kernel_fn
Example #14
class ElementwiseTest(test_utils.NeuralTangentsTestCase):

  @parameterized.product(
      phi=[
          stax.Identity(),
          stax.Erf(),
          stax.Sin(),
          stax.Relu(),
      ],
      same_inputs=[False, True, None],
      n=[0, 1, 2],
      diagonal_batch=[True, False],
      diagonal_spatial=[True, False]
  )
  def test_elementwise(
      self,
      same_inputs,
      phi,
      n,
      diagonal_batch,
      diagonal_spatial
  ):
    # phi is an (init_fn, apply_fn, kernel_fn) triple: phi[1] is its apply_fn
    # (called with empty params), and phi[0].__name__ names the nonlinearity.
    fn = lambda x: phi[1]((), x)

    name = phi[0].__name__

    def nngp_fn(cov12, var1, var2):
      if 'Identity' in name:
        res = cov12

      elif 'Erf' in name:
        prod = (1 + 2 * var1) * (1 + 2 * var2)
        res = np.arcsin(2 * cov12 / np.sqrt(prod)) * 2 / np.pi

      elif 'Sin' in name:
        sum_ = (var1 + var2)
        s1 = np.exp((-0.5 * sum_ + cov12))
        s2 = np.exp((-0.5 * sum_ - cov12))
        res = (s1 - s2) / 2

      elif 'Relu' in name:
        prod = var1 * var2
        sqrt = np.sqrt(np.maximum(prod - cov12 ** 2, 1e-30))
        angles = np.arctan2(sqrt, cov12)
        dot_sigma = (1 - angles / np.pi) / 2
        res = sqrt / (2 * np.pi) + dot_sigma * cov12

      else:
        raise NotImplementedError(name)

      return res

    _, _, kernel_fn = stax.serial(stax.Dense(1), stax.Elementwise(fn, nngp_fn),
                                  stax.Dense(1), stax.Elementwise(fn, nngp_fn))
    _, _, kernel_fn_manual = stax.serial(stax.Dense(1), phi,
                                         stax.Dense(1), phi)

    key = random.PRNGKey(1)
    shape = (4, 3, 2)[:n] + (1,)
    x1 = random.normal(key, (5,) + shape)
    if same_inputs is None:
      x2 = None
    elif same_inputs is True:
      x2 = x1
    else:
      x2 = random.normal(key, (6,) + shape)

    kwargs = dict(diagonal_batch=diagonal_batch,
                  diagonal_spatial=diagonal_spatial)

    k = kernel_fn(x1, x2, **kwargs)
    k_manual = kernel_fn_manual(x1, x2, **kwargs).replace(is_gaussian=False)
    self.assertAllClose(k_manual, k)
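The Erf branch of nngp_fn above is the classic arcsine kernel for the error-function nonlinearity (Williams, 1998): with Sigma_11 = var1, Sigma_22 = var2, Sigma_12 = cov12,

$$\mathbb{E}\big[\mathrm{erf}(u)\,\mathrm{erf}(v)\big] = \frac{2}{\pi}\,\arcsin\frac{2\,\Sigma_{12}}{\sqrt{(1+2\,\Sigma_{11})(1+2\,\Sigma_{22})}}.$$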
Example #15
@parameterized.named_parameters(
    test_utils.cases_from_list(
        {
            'testcase_name':
            ' [{}_out={}_in={}]'.format(
                'same_inputs' if same_inputs else 'different_inputs',
                readout[0].__name__, readin[0].__name__),
            'same_inputs':
            same_inputs,
            'readout':
            readout,
            'readin':
            readin
        } for same_inputs in [False, True] for readout in
        [stax.Flatten(), stax.GlobalAvgPool(),
         stax.Identity()]
        for readin in [stax.Flatten(),
                       stax.GlobalAvgPool(),
                       stax.Identity()]))
class DiagonalTest(test_utils.NeuralTangentsTestCase):
    def _get_kernel_fn(self, same_inputs, readin, readout):
        key = random.PRNGKey(1)
        x1 = random.normal(key, (2, 5, 6, 3))
        x2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))
        layers = [readin]
        filter_shape = (2, 3) if readin[0].__name__ == 'Identity' else ()
        layers += [
            stax.Conv(1, filter_shape, padding='SAME'),
            stax.Relu(),
            stax.Conv(1, filter_shape, padding='SAME'),
            stax.Erf(), readout
        ]
        # Assumed completion: the original snippet is truncated at this point.
        _, _, kernel_fn = stax.serial(*layers)
        return kernel_fn
Example #16
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
             phi, strides, width, is_ntk, proj_into_2d, pool_type, layer_norm,
             parameterization, use_dropout):

  if is_conv:
    # Select a random dimension order.
    default_spec = 'NHWC'
    if xla_bridge.get_backend().platform == 'tpu':
      # Keep batch dimension leading for TPU for batching to work.
      specs = ['NHWC', 'NHCW', 'NCHW']
    else:
      specs = ['NHWC', 'NHCW', 'NCHW', 'CHWN', 'CHNW', 'CNHW']
    spec = prandom.choice(specs)
    input_shape = tuple(INPUT_SHAPE[default_spec.index(c)] for c in spec)

    if layer_norm:
      layer_norm = tuple(spec.index(c) for c in layer_norm)

  else:
    # Only `NC` dimension order is supported and is enforced by layers.
    spec = None
    input_shape = INPUT_SHAPE
    if layer_norm:
      layer_norm = prandom.choice([(1,), (-1,)])

  dimension_numbers = (spec, 'HWIO', spec)

  fc = partial(
      stax.Dense, W_std=W_std, b_std=b_std, parameterization=parameterization)

  def conv(out_chan):
    return stax.GeneralConv(
        dimension_numbers=dimension_numbers,
        out_chan=out_chan,
        filter_shape=filter_shape,
        strides=strides,
        padding=padding,
        W_std=W_std,
        b_std=b_std,
        parameterization=parameterization)
  affine = conv(width) if is_conv else fc(width)

  spec = dimension_numbers[-1]

  rate = onp.random.uniform(0.5, 0.9)  # onp: plain NumPy (assumed imported)
  dropout = stax.Dropout(rate, mode='train')

  if pool_type == 'AVG':
    pool_fn = stax.AvgPool
    global_pool_fn = stax.GlobalAvgPool
  elif pool_type == 'SUM':
    pool_fn = stax.SumPool
    global_pool_fn = stax.GlobalSumPool
  else:
    raise ValueError(pool_type)

  if use_pooling:
    pool_or_identity = pool_fn((2, 3),
                               None,
                               'SAME' if padding == 'SAME' else 'CIRCULAR',
                               spec=spec)
  else:
    pool_or_identity = stax.Identity()
  dropout_or_identity = dropout if use_dropout else stax.Identity()
  layer_norm_or_identity = (stax.Identity() if layer_norm is None
                            else stax.LayerNorm(axis=layer_norm, spec=spec))
  res_unit = stax.serial(pool_or_identity, phi, dropout_or_identity, affine)
  if is_res:
    block = stax.serial(
        affine,
        stax.FanOut(2),
        stax.parallel(stax.Identity(),
                      res_unit),
        stax.FanInSum(),
        layer_norm_or_identity)
  else:
    block = stax.serial(
        affine,
        res_unit,
        layer_norm_or_identity)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten(spec=spec)
  elif proj_into_2d == 'POOL':
    proj_layer = global_pool_fn(spec=spec)
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    fixed = proj_into_2d == 'ATTN_FIXED'
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(
            n_chan_out=width,
            n_chan_key=width,
            n_chan_val=n_chan_val,
            n_heads=n_heads,
            fixed=fixed,
            W_key_std=W_std,
            W_value_std=W_std,
            W_query_std=W_std,
            W_out_std=1.0,
            b_std=b_std,
            spec=spec), stax.Flatten(spec=spec))
  else:
    raise ValueError(proj_into_2d)
  readout = stax.serial(proj_layer, fc(1 if is_ntk else width))

  return stax.serial(block, readout), input_shape
Example #17
sigma_w = st.slider("Sigma w for Residual Case ", 0.1, 3.0, 1.5, step=0.1)
sigma_b = st.slider("Sigma b for Residual Case", 0.0, 0.1, 0.05, step=0.01)

activation_fn = st.selectbox("Activation Function for Residual Case",
                             ("Erf", "ReLU", "None"))

activation_fn = activation_fn_dict[activation_fn]

sequence = ((activation_fn, stax.Dense(n_hidden, W_std=sigma_w, b_std=sigma_b))
            if activation_fn else
            (stax.Dense(n_hidden, W_std=sigma_w, b_std=sigma_b), ))

ResBlock = stax.serial(
    stax.FanOut(2),
    stax.parallel(stax.serial(*(sequence * depth)), stax.Identity()),
    stax.FanInSum(),
)

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(n_hidden, W_std=sigma_w, b_std=sigma_b),
    ResBlock,
    ResBlock,
    activation_fn,
    stax.Dense(1, W_std=sigma_w, b_std=sigma_b),
)

apply_fn = jit(apply_fn)
kernel_fn = jit(kernel_fn, static_argnums=(2, ))

opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
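A hedged sketch of the training loop this optimizer triple usually feeds; the synthetic 1-d data, the input shape, and the step count below are assumptions, not from the app:

from jax import grad, jit, random
import jax.numpy as np

train_xs = np.linspace(-np.pi, np.pi, 32).reshape(-1, 1)  # assumed 1-d inputs
train_ys = np.sin(train_xs)

def loss(params, x, y):
    return 0.5 * np.mean((apply_fn(params, x) - y) ** 2)

@jit
def train_step(i, opt_state, x, y):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, x, y), opt_state)

_, params = init_fn(random.PRNGKey(0), (-1, 1))
opt_state = opt_init(params)
for i in range(200):  # number of steps is illustrative
    opt_state = train_step(i, opt_state, train_xs, train_ys)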
Example #18
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
             phi, strides, width, is_ntk, proj_into_2d, pool_type, layer_norm,
             parameterization, s, use_dropout):

  if is_conv:
    # Select a random filter order.
    default_filter_spec = 'HW'
    filter_specs = [''.join(p) for p in itertools.permutations('HWIO')]
    filter_spec = prandom.choice(filter_specs)
    filter_shape = tuple(filter_shape[default_filter_spec.index(c)]
                         for c in filter_spec if c in default_filter_spec)
    strides = tuple(strides[default_filter_spec.index(c)]
                    for c in filter_spec if c in default_filter_spec)

    # Select the activation order.
    default_spec = 'NHWC'
    if default_backend() == 'tpu':
      # Keep batch dimension leading for TPU for batching to work.
      specs = ['N' + ''.join(p) for p in itertools.permutations('CHW')]
    else:
      specs = [''.join(p) for p in itertools.permutations('NCHW')]
    spec = prandom.choice(specs)
    input_shape = tuple(INPUT_SHAPE[default_spec.index(c)] for c in spec)

  else:
    input_shape = (INPUT_SHAPE[0], onp.prod(INPUT_SHAPE[1:]))
    if default_backend() == 'tpu':
      spec = 'NC'
    else:
      spec = prandom.choice(['NC', 'CN'])
      if spec.index('N') == 1:
        input_shape = input_shape[::-1]

    filter_spec = None

  dimension_numbers = (spec, filter_spec, spec)
  batch_axis, channel_axis = spec.index('N'), spec.index('C')

  spec_fc = ''.join(c for c in spec if c in ('N', 'C'))
  batch_axis_fc, channel_axis_fc = spec_fc.index('N'), spec_fc.index('C')

  if not is_conv:
    batch_axis = batch_axis_fc
    channel_axis = channel_axis_fc

  if layer_norm:
    layer_norm = tuple(spec.index(c) for c in layer_norm)

  def fc(out_dim, s):
    return stax.Dense(
        out_dim=out_dim,
        W_std=W_std,
        b_std=b_std,
        parameterization=parameterization,
        s=s,
        batch_axis=batch_axis_fc,
        channel_axis=channel_axis_fc
    )

  def conv(out_chan, s):
    return stax.Conv(
        out_chan=out_chan,
        filter_shape=filter_shape,
        strides=strides,
        padding=padding,
        W_std=W_std,
        b_std=b_std,
        dimension_numbers=dimension_numbers,
        parameterization=parameterization,
        s=s
    )

  affine = conv(width, (s, s)) if is_conv else fc(width, (s, s))
  affine_bottom = conv(width, (1, s)) if is_conv else fc(width, (1, s))

  rate = onp.random.uniform(0.5, 0.9)
  dropout = stax.Dropout(rate, mode='train')

  if pool_type == 'AVG':
    pool_fn = stax.AvgPool
    global_pool_fn = stax.GlobalAvgPool
  elif pool_type == 'SUM':
    pool_fn = stax.SumPool
    global_pool_fn = stax.GlobalSumPool
  else:
    raise ValueError(pool_type)

  if use_pooling:
    pool_or_identity = pool_fn((2, 3),
                               None,
                               'SAME' if padding == 'SAME' else 'CIRCULAR',
                               batch_axis=batch_axis,
                               channel_axis=channel_axis)
  else:
    pool_or_identity = stax.Identity()
  dropout_or_identity = dropout if use_dropout else stax.Identity()
  layer_norm_or_identity = (stax.Identity() if layer_norm is None else
                            stax.LayerNorm(axis=layer_norm,
                                           batch_axis=batch_axis,
                                           channel_axis=channel_axis))
  res_unit = stax.serial(dropout_or_identity, affine, pool_or_identity)
  if is_res:
    block = stax.serial(
        affine_bottom,
        stax.FanOut(2),
        stax.parallel(stax.Identity(),
                      res_unit),
        stax.FanInSum(),
        layer_norm_or_identity,
        phi)
  else:
    block = stax.serial(
        affine_bottom,
        res_unit,
        layer_norm_or_identity,
        phi)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten(batch_axis, batch_axis_fc)
  elif proj_into_2d == 'POOL':
    proj_layer = global_pool_fn(batch_axis, channel_axis)
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(
            n_chan_out=width,
            n_chan_key=width,
            n_chan_val=n_chan_val,
            n_heads=n_heads,
            linear_scaling=True,
            W_key_std=W_std,
            W_value_std=W_std,
            W_query_std=W_std,
            W_out_std=1.0,
            b_std=b_std,
            batch_axis=batch_axis,
            channel_axis=channel_axis),
        stax.Flatten(batch_axis, batch_axis_fc))
  else:
    raise ValueError(proj_into_2d)

  readout = stax.serial(proj_layer,
                        fc(1 if is_ntk else width, (s, 1 if is_ntk else s)))

  device_count = -1 if spec.index('N') == 0 else 0

  net = stax.serial(block, readout)
  return net, input_shape, device_count, channel_axis_fc