Ejemplo n.º 1
0
    def test_avg_pool(self):
        """Compare `stax.AvgPool` outputs and kernels for both `normalize_edges` modes.

        With `normalize_edges=True` the layer output is asserted to match
        `stax.ostax.AvgPool` and the kernel entries of all-ones inputs are
        asserted to be all ones. With `normalize_edges=False` both the outputs
        and the kernel are compared against hand-written expected values.
        """
        window, strides, padding = (2, 2), (1, 1), 'SAME'
        batch_a = np.ones((4, 2, 3, 2))
        batch_b = np.ones((3, 2, 3, 2))
        batches = (batch_a, batch_b)

        _, apply_raw, kernel_raw = stax.AvgPool(
            window, strides, padding, normalize_edges=False)
        _, apply_normed, kernel_normed = stax.AvgPool(
            window, strides, padding, normalize_edges=True)
        _, apply_reference = stax.ostax.AvgPool(window, strides, padding)

        raw_outputs = tuple(apply_raw((), x) for x in batches)
        normed_outputs = tuple(apply_normed((), x) for x in batches)
        reference_outputs = tuple(apply_reference((), x) for x in batches)

        # The edge-normalized layer must agree with the reference pooling.
        self.assertAllClose(reference_outputs, normed_outputs, True)

        # Expected per-position values of the un-normalized layer on all-ones
        # inputs, broadcast over the batch and channel axes.
        edge_weights = np.array([[1., 1., 0.5],
                                 [0.5, 0.5, 0.25]]).reshape((1, 2, 3, 1))
        expected_raw = tuple(
            np.broadcast_to(edge_weights, x.shape) for x in batches)
        self.assertAllClose(expected_raw, raw_outputs, True)

        ker = kernel_raw(batch_a, batch_b)
        ker_normed = kernel_normed(batch_a, batch_b)

        fields = ('nngp', 'var1', 'var2')
        for field in fields:
            normed_value = getattr(ker_normed, field)
            self.assertAllClose(np.ones_like(normed_value), normed_value, True)
        for field in fields:
            self.assertEqual(getattr(ker_normed, field).shape,
                             getattr(ker, field).shape)

        # Expected un-normalized kernel: outer product of the edge weights,
        # rearranged so the two copies of each spatial axis sit side by side.
        weight_outer = np.outer(edge_weights, edge_weights).reshape(
            (2, 3, 2, 3))
        expected_ker = np.transpose(weight_outer, axes=(0, 2, 1, 3))
        expected_nngp = np.broadcast_to(
            expected_ker.reshape((1, 1) + expected_ker.shape), ker.nngp.shape)
        expected_var1 = np.broadcast_to(np.expand_dims(expected_ker, 0),
                                        ker.var1.shape)
        expected_var2 = np.broadcast_to(np.expand_dims(expected_ker, 0),
                                        ker.var2.shape)
        self.assertAllClose((expected_nngp, expected_var1, expected_var2),
                            (ker.nngp, ker.var1, ker.var2), True)
Ejemplo n.º 2
0
def build_le_net(network_width):
    """Build a LeNet-style network with average pooling using the stax API.

    Args:
        network_width: multiplier applied to the channel / unit count of every
            layer except the final 10-unit `Dense` layer.

    Returns:
        Whatever `stax.serial` returns for the assembled layer sequence.
    """

    def conv_relu_pool(base_channels):
        # One feature stage: 3x3 VALID convolution, ReLU, 2x2 average pooling.
        return [
            stax.Conv(out_chan=base_channels * network_width,
                      filter_shape=(3, 3),
                      strides=(1, 1),
                      padding='VALID'),
            stax.Relu(),
            stax.AvgPool(window_shape=(2, 2), strides=(2, 2)),
        ]

    layers = conv_relu_pool(6) + conv_relu_pool(16)
    layers.append(stax.Flatten())
    for base_units in (120, 84):
        layers.append(stax.Dense(base_units * network_width))
        layers.append(stax.Relu())
    layers.append(stax.Dense(10))
    return stax.serial(*layers)
Ejemplo n.º 3
0
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk):
  """Assemble a small stax network from the given configuration flags.

  The network is `affine -> (optional pooling, phi, affine) -> readout`, where
  the middle unit is wrapped in a skip connection when `is_res` is truthy.
  The same `affine` layer object is used in both positions.

  Args:
    W_std, b_std: forwarded to every `stax.Dense` / `stax.Conv` layer.
    filter_shape, strides, padding: forwarded to `stax.Conv`.
    is_conv: use `stax.Conv` for the affine layer, otherwise `stax.Dense`.
    use_pooling: insert `stax.AvgPool` before `phi` and read out through
      `stax.GlobalAvgPool` instead of `stax.Flatten`.
    is_res: wrap the middle unit in FanOut / parallel / FanInSum.
    phi: layer placed between the (optional) pooling and the affine layer.
    width: output size of the affine layers.
    is_ntk: when truthy the readout `Dense` has 1 output, otherwise `width`.

  Returns:
    The result of `stax.serial(block, readout)`.
  """
  make_dense = partial(stax.Dense, W_std=W_std, b_std=b_std)
  make_conv = partial(
      stax.Conv,
      filter_shape=filter_shape,
      strides=strides,
      padding=padding,
      W_std=W_std,
      b_std=b_std)
  if is_conv:
    affine = make_conv(width)
  else:
    affine = make_dense(width)

  if use_pooling:
    # Any padding other than 'SAME' falls back to 'CIRCULAR' for the pooling.
    pool_padding = 'SAME' if padding == 'SAME' else 'CIRCULAR'
    pre_activation = stax.AvgPool((2, 3), None, pool_padding)
  else:
    pre_activation = stax.Identity()
  res_unit = stax.serial(pre_activation, phi, affine)

  if is_res:
    block = stax.serial(affine, stax.FanOut(2),
                        stax.parallel(stax.Identity(), res_unit),
                        stax.FanInSum())
  else:
    block = stax.serial(affine, res_unit)

  if use_pooling:
    to_2d = stax.GlobalAvgPool()
  else:
    to_2d = stax.Flatten()
  n_outputs = 1 if is_ntk else width
  readout = stax.serial(to_2d, make_dense(n_outputs))

  return stax.serial(block, readout)
Ejemplo n.º 4
0
def WideResnet(block_size, k, num_classes):
    """Assemble a wide residual network from three `WideResnetGroup` stages.

    Args:
        block_size: forwarded to every `WideResnetGroup`.
        k: widening factor; the groups use `int(16 * k)`, `int(32 * k)` and
            `int(64 * k)` channels.
        num_classes: output size of the final `stax.Dense` layer.

    Returns:
        The result of `stax.serial` applied to the layer sequence.
    """
    layers = [
        stax.Conv(16, (3, 3), padding='SAME'),
        WideResnetGroup(block_size, int(16 * k)),
    ]
    # The second and third groups additionally receive (2, 2) strides.
    for base_channels in (32, 64):
        layers.append(
            WideResnetGroup(block_size, int(base_channels * k), (2, 2)))
    layers.extend([
        stax.AvgPool((8, 8)),
        stax.Flatten(),
        stax.Dense(num_classes, 1., 0.),
    ])
    return stax.serial(*layers)
Ejemplo n.º 5
0
def WideResnetnt(
        block_size,
        k,
        num_classes,
        batchnorm='std'):
    """Build a WideResnet (as in the paper) from `stax_nt` layers.

    The network can be built with or without BatchNorm, controlled by
    `batchnorm` (set config.wrn_block_size=3, config.wrn_widening_f=10 in
    that case). Default weight and bias initialization is used, with the
    'standard' parameterization throughout.

    Args:
        block_size: forwarded to every `WideResnetGroupnt`.
        k: widening factor; the groups use 16 * k, 32 * k and 64 * k channels.
        num_classes: output size of the final `stax_nt.Dense` layer.
        batchnorm: forwarded to `WideResnetGroupnt` and `_batch_norm_internal`.

    Returns:
        The result of `stax_nt.serial` applied to the layer sequence.
    """
    parameterization = 'standard'

    def group(channels, *strides):
        # All groups share the block size, parameterization and batchnorm mode.
        return WideResnetGroupnt(block_size,
                                 channels,
                                 *strides,
                                 parameterization=parameterization,
                                 batchnorm=batchnorm)

    stem = stax_nt.Conv(16, (3, 3),
                        padding='SAME',
                        parameterization=parameterization)
    body = [group(16 * k), group(32 * k, (2, 2)), group(64 * k, (2, 2))]
    post_norm = [_batch_norm_internal(batchnorm), stax_nt.Relu()]
    head = [
        stax_nt.AvgPool((8, 8)),
        stax_nt.Flatten(),
        stax_nt.Dense(num_classes, parameterization=parameterization),
    ]
    return stax_nt.serial(stem, *body, *post_norm, *head)
Ejemplo n.º 6
0
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
             phi, strides, width, is_ntk, proj_into_2d, layer_norm,
             parameterization, use_dropout):
    """Assemble a small stax network from the given configuration flags.

    The network is `affine -> residual unit -> optional LayerNorm -> readout`.
    The residual unit is `(optional pooling, phi, optional dropout, affine)`
    and is wrapped in a skip connection when `is_res` is truthy. The same
    `affine` layer object is used in both positions.

    Args:
        W_std, b_std: forwarded to every `stax.Dense` / `stax.Conv` layer and
            to `stax.GlobalSelfAttention` (key/value/query stds and `b_std`).
        filter_shape, strides, padding: forwarded to `stax.Conv`.
        is_conv: use `stax.Conv` for the affine layer, otherwise `stax.Dense`.
        use_pooling: insert `stax.AvgPool` at the start of the residual unit.
        is_res: wrap the residual unit in FanOut / parallel / FanInSum.
        phi: layer placed after the (optional) pooling in the residual unit.
        width: output size of the affine layers and attention layer.
        is_ntk: when truthy the readout `Dense` has 1 output, else `width`.
        proj_into_2d: 'FLAT' (Flatten), 'POOL' (GlobalAvgPool), or a string
            starting with 'ATTN' (GlobalSelfAttention followed by Flatten;
            'ATTN_FIXED' sets `fixed=True`).
        layer_norm: `None` for no layer normalization, otherwise passed as
            `axis` to `stax.LayerNorm`.
        parameterization: forwarded to `stax.Dense` and `stax.Conv`.
        use_dropout: insert `stax.Dropout` (train mode) after `phi`.

    Returns:
        The result of `stax.serial(block, readout)`.

    Raises:
        ValueError: if `proj_into_2d` is not one of the supported values.
    """
    # Use NumPy directly for the host-side random draw instead of reaching it
    # through an `onp` attribute of the module bound to `np`; that attribute
    # is not public API and is not available on current library versions.
    import numpy as onp

    fc = partial(stax.Dense,
                 W_std=W_std,
                 b_std=b_std,
                 parameterization=parameterization)
    conv = partial(stax.Conv,
                   filter_shape=filter_shape,
                   strides=strides,
                   padding=padding,
                   W_std=W_std,
                   b_std=b_std,
                   parameterization=parameterization)
    affine = conv(width) if is_conv else fc(width)

    # The rate is drawn unconditionally (even when dropout ends up unused) so
    # the amount of global NumPy RNG state consumed does not depend on flags.
    rate = onp.random.uniform(0.5, 0.9)
    dropout = stax.Dropout(rate, mode='train')
    # Any padding other than 'SAME' falls back to 'CIRCULAR' for the pooling.
    ave_pool = stax.AvgPool((2, 3), None,
                            'SAME' if padding == 'SAME' else 'CIRCULAR')
    ave_pool_or_identity = ave_pool if use_pooling else stax.Identity()
    dropout_or_identity = dropout if use_dropout else stax.Identity()
    layer_norm_or_identity = (stax.Identity() if layer_norm is None else
                              stax.LayerNorm(axis=layer_norm))
    res_unit = stax.serial(ave_pool_or_identity, phi, dropout_or_identity,
                           affine)
    if is_res:
        block = stax.serial(affine, stax.FanOut(2),
                            stax.parallel(stax.Identity(), res_unit),
                            stax.FanInSum(), layer_norm_or_identity)
    else:
        block = stax.serial(affine, res_unit, layer_norm_or_identity)

    if proj_into_2d == 'FLAT':
        proj_layer = stax.Flatten()
    elif proj_into_2d == 'POOL':
        proj_layer = stax.GlobalAvgPool()
    elif proj_into_2d.startswith('ATTN'):
        # Roughly sqrt(width) heads, with value channels chosen so that
        # n_heads * n_chan_val is close to width.
        n_heads = int(np.sqrt(width))
        n_chan_val = int(np.round(float(width) / n_heads))
        fixed = proj_into_2d == 'ATTN_FIXED'
        proj_layer = stax.serial(
            stax.GlobalSelfAttention(width,
                                     n_chan_key=width,
                                     n_chan_val=n_chan_val,
                                     n_heads=n_heads,
                                     fixed=fixed,
                                     W_key_std=W_std,
                                     W_value_std=W_std,
                                     W_query_std=W_std,
                                     W_out_std=1.0,
                                     b_std=b_std), stax.Flatten())
    else:
        raise ValueError(proj_into_2d)
    readout = stax.serial(proj_layer, fc(1 if is_ntk else width))

    return stax.serial(block, readout)
Ejemplo n.º 7
0
def _MyrtleNetwork(width, depth, W_std=jnp.sqrt(2.0), b_std=0.0):
    """Build a Myrtle-style convolutional network out of stax layers.

    The network has three convolutional stages, each followed by 2x2 average
    pooling; the last stage is followed by three pooling layers in a row, then
    `Flatten` and a 10-unit `Dense` readout.

    Args:
        width: number of output channels of every 3x3 convolution.
        depth: key into the stage table; must be 5, 7 or 10 (any other value
            raises `KeyError`).
        W_std, b_std: forwarded to every convolution and to the final `Dense`.

    Returns:
        The result of `stax.serial` applied to the layer sequence.
    """
    # Number of (conv, relu) repetitions in each of the three stages.
    stage_repeats = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]}
    relu = stax.Relu()
    make_conv = functools.partial(stax.Conv,
                                  W_std=W_std,
                                  b_std=b_std,
                                  padding="SAME")

    def conv_stage(stage_index):
        # List repetition re-uses one (conv, relu) pair within a stage.
        return [make_conv(width, (3, 3)), relu] * stage_repeats[depth][stage_index]

    def downsample():
        return stax.AvgPool((2, 2), strides=(2, 2))

    layers = [
        *conv_stage(0),
        downsample(),
        *conv_stage(1),
        downsample(),
        *conv_stage(2),
        *([downsample()] * 3),
        stax.Flatten(),
        stax.Dense(10, W_std, b_std),
    ]
    return stax.serial(*layers)
Ejemplo n.º 8
0
def WideResnet(block_size, k, num_classes, W_std=1., b_std=0.):
    """Assemble a wide residual network with configurable weight/bias scales.

    Args:
        block_size: forwarded to every `WideResnetGroup`.
        k: widening factor; the groups use `int(16 * k)`, `int(32 * k)` and
            `int(64 * k)` channels.
        num_classes: output size of the final `stax.Dense` layer.
        W_std, b_std: forwarded to the stem convolution, every group and the
            final `Dense` layer.

    Returns:
        The result of `stax.serial` applied to the layer sequence.
    """
    init = {'W_std': W_std, 'b_std': b_std}

    stem = stax.Conv(16, (3, 3), **init, padding='SAME')
    groups = [
        WideResnetGroup(block_size, int(16 * k), **init),
        # The second and third groups additionally receive (2, 2) strides.
        WideResnetGroup(block_size, int(32 * k), (2, 2), **init),
        WideResnetGroup(block_size, int(64 * k), (2, 2), **init),
    ]
    head = [
        stax.AvgPool((7, 7)),
        stax.Flatten(),
        stax.Dense(num_classes, **init),
    ]
    return stax.serial(stem, *groups, *head)
Ejemplo n.º 9
0
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk, proj_into_2d):
  """Assemble a small stax network with a selectable projection before readout.

  The network is `affine -> (optional pooling, phi, affine) -> readout`, where
  the middle unit is wrapped in a skip connection when `is_res` is truthy.
  The same `affine` layer object is used in both positions.

  Args:
    W_std, b_std: forwarded to every `stax.Dense` / `stax.Conv` layer and to
      `stax.GlobalSelfAttention` (key/value/query stds and `b_std`).
    filter_shape, strides, padding: forwarded to `stax.Conv`.
    is_conv: use `stax.Conv` for the affine layer, otherwise `stax.Dense`.
    use_pooling: insert `stax.AvgPool` before `phi`.
    is_res: wrap the middle unit in FanOut / parallel / FanInSum.
    phi: layer placed between the (optional) pooling and the affine layer.
    width: output size of the affine layers and attention layer.
    is_ntk: when truthy the readout `Dense` has 1 output, otherwise `width`.
    proj_into_2d: 'FLAT' (Flatten), 'POOL' (GlobalAvgPool), or a string
      starting with 'ATTN' (GlobalSelfAttention followed by Flatten;
      'ATTN_FIXED' sets `fixed=True`).

  Returns:
    The result of `stax.serial(block, readout)`.

  Raises:
    ValueError: if `proj_into_2d` is not one of the supported values.
  """
  weight_init = dict(W_std=W_std, b_std=b_std)
  dense = partial(stax.Dense, **weight_init)
  if is_conv:
    affine = stax.Conv(
        width,
        filter_shape=filter_shape,
        strides=strides,
        padding=padding,
        **weight_init)
  else:
    affine = dense(width)

  if use_pooling:
    # Any padding other than 'SAME' falls back to 'CIRCULAR' for the pooling.
    pool_padding = 'SAME' if padding == 'SAME' else 'CIRCULAR'
    first_layer = stax.AvgPool((2, 3), None, pool_padding)
  else:
    first_layer = stax.Identity()
  res_unit = stax.serial(first_layer, phi, affine)

  if is_res:
    skip_and_unit = stax.parallel(stax.Identity(), res_unit)
    block = stax.serial(affine, stax.FanOut(2), skip_and_unit,
                        stax.FanInSum())
  else:
    block = stax.serial(affine, res_unit)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten()
  elif proj_into_2d == 'POOL':
    proj_layer = stax.GlobalAvgPool()
  elif not proj_into_2d.startswith('ATTN'):
    raise ValueError(proj_into_2d)
  else:
    # Roughly sqrt(width) heads, with value channels chosen so that
    # n_heads * n_chan_val is close to width.
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    attention = stax.GlobalSelfAttention(
        width,
        n_chan_key=width,
        n_chan_val=n_chan_val,
        n_heads=n_heads,
        fixed=(proj_into_2d == 'ATTN_FIXED'),
        W_key_std=W_std,
        W_value_std=W_std,
        W_query_std=W_std,
        W_out_std=1.0,
        b_std=b_std)
    proj_layer = stax.serial(attention, stax.Flatten())

  n_outputs = 1 if is_ntk else width
  readout = stax.serial(proj_layer, dense(n_outputs))

  return stax.serial(block, readout)
Ejemplo n.º 10
0
    def test_input_req(self, same_inputs):
        """Check kernel functions on 5-D inputs with varied `dimension_numbers`.

        The test alternates between two kinds of checks:
          * a network whose kernel function is expected to raise `ValueError`
            when called on the inputs, and
          * a network whose analytic kernel is compared against a Monte Carlo
            estimate built from the same `init_fn` / `apply_fn`.

        Args:
            same_inputs: when truthy, `x2` is `None` and the kernel functions
                are called with `(x1, None)`; otherwise a second batch is
                sampled.
        """
        # NOTE(review): presumably skips under some test configurations -
        # confirm against `test_utils.skip_test`.
        test_utils.skip_test(self)

        key = random.PRNGKey(1)
        x1 = random.normal(key, (2, 7, 8, 4, 3))
        # x2 is sampled with the same key as x1, only the batch size differs.
        x2 = None if same_inputs else random.normal(key, (4, 7, 8, 4, 3))

        # Case 1a: stacked 3-D convolutions whose kernel_fn must raise.
        _, _, wrong_conv_fn = stax.serial(
            stax.Conv(out_chan=1,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NDHWC', 'HDWIO', 'NCDWH')),
            stax.Relu(),
            stax.Conv(out_chan=1,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NHDWC', 'HWDIO', 'NCWHD')))
        with self.assertRaises(ValueError):
            wrong_conv_fn(x1, x2)

        # Case 1b: stacked 3-D convolutions with accepted layouts; the analytic
        # NNGP kernel must match the Monte Carlo estimate.
        init_fn, apply_fn, correct_conv_fn = stax.serial(
            stax.Conv(out_chan=1024,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NHWDC', 'DHWIO', 'NCWDH')),
            stax.Relu(),
            stax.Conv(out_chan=1024,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NCHDW', 'WHDIO', 'NCDWH')),
            stax.Flatten(), stax.Dense(1024))

        correct_conv_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn=init_fn,
            apply_fn=apply_fn,
            key=key,
            n_samples=400,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=0)
        K = correct_conv_fn(x1, x2, get='nngp')
        K_mc = correct_conv_fn_mc(x1, x2, get='nngp')
        self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05)

        # Case 2a: convolution followed by GlobalAvgPool(channel_axis=2); the
        # kernel_fn must raise.
        _, _, wrong_conv_fn = stax.serial(
            stax.Conv(out_chan=1,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NDHWC', 'HDWIO', 'NCDWH')),
            stax.GlobalAvgPool(channel_axis=2))
        with self.assertRaises(ValueError):
            wrong_conv_fn(x1, x2)

        # Case 2b: convolutions and pooling layers with explicit batch/channel
        # axes; the analytic NNGP kernel must match the Monte Carlo estimate.
        init_fn, apply_fn, correct_conv_fn = stax.serial(
            stax.Conv(out_chan=1024,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NHDWC', 'DHWIO', 'NDWCH')),
            stax.Relu(), stax.AvgPool((2, 1, 3), batch_axis=0,
                                      channel_axis=-2),
            stax.Conv(out_chan=1024,
                      filter_shape=(1, 2, 3),
                      dimension_numbers=('NDHCW', 'IHWDO', 'NDCHW')),
            stax.Relu(), stax.GlobalAvgPool(channel_axis=2), stax.Dense(1024))

        correct_conv_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn=init_fn,
            apply_fn=apply_fn,
            key=key,
            n_samples=300,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=0)
        K = correct_conv_fn(x1, x2, get='nngp')
        K_mc = correct_conv_fn_mc(x1, x2, get='nngp')
        self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05)

        # Case 3a: a convolution with a 2-element filter shape applied after
        # Flatten/Dense; the kernel_fn must raise.
        _, _, wrong_conv_fn = stax.serial(
            stax.Flatten(),
            stax.Dense(1),
            stax.Erf(),
            stax.Conv(out_chan=1,
                      filter_shape=(1, 2),
                      dimension_numbers=('CN', 'IO', 'NC')),
        )
        with self.assertRaises(ValueError):
            wrong_conv_fn(x1, x2)

        # Case 3b: a convolution with an empty filter shape applied after
        # Flatten; here the NTK (not the NNGP) is compared to Monte Carlo.
        init_fn, apply_fn, correct_conv_fn = stax.serial(
            stax.Flatten(), stax.Conv(out_chan=1024, filter_shape=()),
            stax.Relu(), stax.Dense(1))

        correct_conv_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn=init_fn,
            apply_fn=apply_fn,
            key=key,
            n_samples=200,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=0)
        K = correct_conv_fn(x1, x2, get='ntk')
        K_mc = correct_conv_fn_mc(x1, x2, get='ntk')
        self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05)
Ejemplo n.º 11
0
# Load the CIFAR10 inputs; the label outputs of `prep_data` are discarded.
X, _, Xtest, _ = prep_data('CIFAR10', False, noise_index)

n = X.shape[0]
ntest = Xtest.shape[0]
W_std = 1.0
b_std = 0.0
# Number of rows generated at each job.
# A plain Python int replaces `onp.int(200)`: `numpy.int` was only an alias of
# the builtin `int`, was deprecated in NumPy 1.20 and removed in NumPy 1.24.
m = 200

if model_name == 'Myrtle':
    # Four conv+ReLU stages with 2x2 average pooling after the second, third
    # and fourth, then two more pooling layers, Flatten and a 10-unit Dense
    # readout. Line continuations are implicit inside the parentheses.
    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Conv(512, (3, 3), strides=(1, 1), W_std=W_std, b_std=b_std, padding='SAME'),
        stax.Relu(),
        stax.Conv(512, (3, 3), strides=(1, 1), W_std=W_std, b_std=b_std, padding='SAME'),
        stax.Relu(),
        stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),
        stax.Conv(512, (3, 3), strides=(1, 1), W_std=W_std, b_std=b_std, padding='SAME'),
        stax.Relu(),
        stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),
        stax.Conv(512, (3, 3), strides=(1, 1), W_std=W_std, b_std=b_std, padding='SAME'),
        stax.Relu(),
        stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),
        stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),
        stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),
        stax.Flatten(),
        stax.Dense(10, W_std, b_std))
else:
    raise Exception('Invalid Input Error')

# Compile both functions; argument index 2 of kernel_fn is marked static.
apply_fn = jit(apply_fn)
kernel_fn = jit(kernel_fn, static_argnums=(2, ))