Esempio n. 1
0
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
             phi, strides, width, is_ntk, proj_into_2d, layer_norm):
    """Assemble a `stax` network made of one block followed by a readout.

    The block is `affine -> [pool] -> phi -> affine`, optionally wrapped in a
    residual connection and followed by layer normalisation. The readout
    projects the activations with the layer selected by `proj_into_2d` and
    finishes with a dense layer.

    Args:
      W_std, b_std: passed to every dense / convolutional layer and to the
        key, value and query weights of the attention layer.
      filter_shape, strides, padding: convolution settings; `padding` also
        selects the pooling padding ('SAME' stays 'SAME', anything else
        becomes 'CIRCULAR').
      is_conv: use `stax.Conv` for the affine layer, else `stax.Dense`.
      use_pooling: put a `(2, 3)` average pool in front of `phi`.
      is_res: wrap the inner unit in a `FanOut` / `FanInSum` skip connection.
      phi: the nonlinearity layer.
      width: output size of the affine layers.
      is_ntk: if truthy the final dense layer has 1 output, else `width`.
      proj_into_2d: 'FLAT', 'POOL', or a string starting with 'ATTN'.
      layer_norm: `None` for no normalisation, else the `axis` argument of
        `stax.LayerNorm`.

    Returns:
      The result of `stax.serial(block, readout)`.

    Raises:
      ValueError: if `proj_into_2d` is not one of the supported options.
    """
    dense = partial(stax.Dense, W_std=W_std, b_std=b_std)

    # The same affine layer object is reused at the bottom of the block and
    # inside the inner unit.
    if is_conv:
        affine = stax.Conv(width,
                           filter_shape=filter_shape,
                           strides=strides,
                           padding=padding,
                           W_std=W_std,
                           b_std=b_std)
    else:
        affine = dense(width)

    if use_pooling:
        pool_padding = 'SAME' if padding == 'SAME' else 'CIRCULAR'
        pool_or_identity = stax.AvgPool((2, 3), None, pool_padding)
    else:
        pool_or_identity = stax.Identity()
    res_unit = stax.serial(pool_or_identity, phi, affine)

    if is_res:
        inner_layers = [
            stax.FanOut(2),
            stax.parallel(stax.Identity(), res_unit),
            stax.FanInSum(),
        ]
    else:
        inner_layers = [res_unit]

    if layer_norm is None:
        norm_or_identity = stax.Identity()
    else:
        norm_or_identity = stax.LayerNorm(axis=layer_norm)

    block = stax.serial(affine, *inner_layers, norm_or_identity)

    if proj_into_2d == 'FLAT':
        proj_layer = stax.Flatten()
    elif proj_into_2d == 'POOL':
        proj_layer = stax.GlobalAvgPool()
    elif proj_into_2d.startswith('ATTN'):
        n_heads = int(np.sqrt(width))
        n_chan_val = int(np.round(float(width) / n_heads))
        attention = stax.GlobalSelfAttention(
            width,
            n_chan_key=width,
            n_chan_val=n_chan_val,
            n_heads=n_heads,
            fixed=(proj_into_2d == 'ATTN_FIXED'),
            W_key_std=W_std,
            W_value_std=W_std,
            W_query_std=W_std,
            W_out_std=1.0,
            b_std=b_std)
        proj_layer = stax.serial(attention, stax.Flatten())
    else:
        raise ValueError(proj_into_2d)

    out_dim = 1 if is_ntk else width
    readout = stax.serial(proj_layer, dense(out_dim))
    return stax.serial(block, readout)
Esempio n. 2
0
  def test_hermite(self, same_inputs, degree, get, readout):
    """Compare the analytic kernel of a `stax.Hermite` network to sampling.

    Builds `Conv -> LayerNorm -> Hermite(degree) -> readout`, evaluates its
    closed-form kernel and a Monte Carlo estimate on the same inputs, and
    asserts that they agree within a tolerance that grows with `degree`.

    Args:
      same_inputs: if truthy, the second input batch is the first one.
      degree: argument of `stax.Hermite`; values above 2 use more samples and
        go through `test_utils.skip_test`.
      get: kernel to request; 'ntk' appends a `Dense(1)` layer.
      readout: 'pool' selects `GlobalAvgPool`, anything else `Flatten`.
    """
    key1, key2, mc_key = random.split(random.PRNGKey(1), 3)

    width = 10000
    high_degree = degree > 2
    n_samples = 5000 if high_degree else 100
    if high_degree:
      test_utils.skip_test(self)

    x1 = np.cos(random.normal(key1, [2, 6, 6, 3]))
    if same_inputs:
      x2 = x1
    else:
      x2 = np.cos(random.normal(key2, [3, 6, 6, 3]))

    layers = [
        stax.Conv(width, (3, 3), W_std=2., b_std=0.5),
        stax.LayerNorm(),
        stax.Hermite(degree),
    ]
    if readout == 'pool':
      layers.append(stax.GlobalAvgPool())
    else:
      layers.append(stax.Flatten())
    if get == 'ntk':
      layers.append(stax.Dense(1))
    else:
      layers.append(stax.Identity())

    init_fn, apply_fn, kernel_fn = stax.serial(*layers)
    analytic_kernel = kernel_fn(x1, x2, get)

    sampled_kernel_fn = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, mc_key, n_samples)
    mc_kernel = sampled_kernel_fn(x1, x2, get)

    # The allowed discrepancy scales linearly with the polynomial degree.
    tolerance = degree / 2. * 1e-2
    test_utils.assert_close_matrices(
        self, mc_kernel, analytic_kernel, tolerance)
    def test_mask_conv(self, same_inputs, get, mask_axis, mask_constant,
                       concat, proj, p, n, transpose):
        """Check the analytic kernel of a masked conv network against sampling.

        Builds a three-branch network of `stax.Conv` layers (or
        `stax.ConvTranspose` when `transpose` is set), merges the branches
        with `FanInSum` (`concat is None`) or `FanInConcat(concat)`, applies
        the readout selected by `proj`, and asserts that the Monte Carlo
        kernel estimate matches the closed-form kernel within `tol`.

        Args:
          same_inputs: if truthy `x2` is `None`; otherwise a second masked
            batch of 4 inputs is generated.
          get: 'nngp' or 'ntk'; selects the width of the final `Dense` layer
            and the kernel requested.
          mask_axis, mask_constant, p: forwarded to `test_utils.mask`;
            `mask_constant` is also passed to both kernel functions.
          concat: `None`, or the axis given to `stax.FanInConcat`.
          proj: key into the readout table ('avg', 'sum' or 'flatten'); 'avg'
            additionally enables the `GlobalSelfAttention` layers.
          n: how many leading entries are taken from the spatial shape,
            filter shape, stride and 'HWDZX' tables.
          transpose: use transposed convolutions and the alternative shapes.

        Raises:
          absltest.SkipTest: if `concat` is an int larger than `n`, or on the
            'gpu' backend when `n > 3`.
          ValueError: if `get` is neither 'nngp' nor 'ntk'.
        """
        if isinstance(concat, int) and concat > n:
            raise absltest.SkipTest('Concatenation axis out of bounds.')

        test_utils.skip_test(self)
        if default_backend() == 'gpu' and n > 3:
            raise absltest.SkipTest('>=4D-CNN is not supported on GPUs.')

        width = n_samples = 256
        tol = 0.03
        key = random.PRNGKey(1)

        if transpose:
            all_spatial, all_filters = (1, 2, 3, 2, 1), (2, 3, 1, 2, 1)
        else:
            all_spatial, all_filters = (15, 8, 9), (7, 2, 3)
        spatial_shape = all_spatial[:n]
        filter_shape = all_filters[:n]
        strides = (2, 1, 3, 2, 3)[:n]

        spatial_spec = 'HWDZX'[:n]
        activation_spec = 'N' + spatial_spec + 'C'
        dimension_numbers = (activation_spec, 'OI' + spatial_spec,
                             activation_spec)

        def masked_batch(batch_size):
            # The same `key` is used for every draw and every mask.
            shape = (batch_size, ) + spatial_shape + (2, )
            x = np.cos(random.normal(key, shape))
            return test_utils.mask(x, mask_constant, mask_axis, key, p)

        x1 = masked_batch(2)
        x2 = None if same_inputs else masked_batch(4)

        def get_attn():
            # Attention is only inserted for the 'avg' projection.
            if proj != 'avg':
                return stax.Identity()
            n_heads = int(np.sqrt(width))
            return stax.GlobalSelfAttention(
                n_chan_out=width,
                n_chan_key=width,
                n_chan_val=int(np.round(float(width) / n_heads)),
                n_heads=n_heads,
            )

        conv_cls = stax.ConvTranspose if transpose else stax.Conv

        def conv(padding, W_std, b_std):
            # All convolutions share everything except padding and the stds.
            return conv_cls(dimension_numbers=dimension_numbers,
                            out_chan=width,
                            strides=strides,
                            filter_shape=filter_shape,
                            padding=padding,
                            W_std=W_std,
                            b_std=b_std)

        fan_out = stax.FanOut(3)
        first_branch = stax.serial(
            conv('CIRCULAR', 1.5, 0.2),
            stax.LayerNorm(axis=(1, -1)),
            stax.Abs(),
            stax.DotGeneral(rhs=0.9),
            conv('VALID', 1.2, 0.1),
        )
        second_branch = stax.serial(
            conv('SAME', 0.1, 0.3),
            stax.Relu(),
            stax.Dropout(0.7),
            conv('VALID', 0.9, 1.),
        )
        third_branch = stax.serial(
            get_attn(),
            conv('CIRCULAR', 1., 0.1),
            stax.Erf(),
            stax.Dropout(0.2),
            stax.DotGeneral(rhs=0.7),
            conv('VALID', 1., 0.1),
        )

        if concat is None:
            fan_in = stax.FanInSum()
        else:
            fan_in = stax.FanInConcat(concat)

        post_attn = get_attn()
        projections = {
            'avg': stax.GlobalAvgPool(),
            'sum': stax.GlobalSumPool(),
            'flatten': stax.Flatten(),
        }

        nn = stax.serial(
            fan_out,
            stax.parallel(first_branch, second_branch, third_branch),
            fan_in,
            post_attn,
            projections[proj],
        )

        if get == 'nngp':
            head = stax.Dense(width, 1., 0.)
        elif get == 'ntk':
            head = stax.Dense(1, 1., 0.)
        else:
            raise ValueError(get)
        init_fn, apply_fn, kernel_fn = stax.serial(nn, head)

        # For these two concatenation axes the sampler is configured with
        # `device_count=0` and without `vmap_axes`.
        special_concat = concat in (0, -n)
        kernel_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn,
            apply_fn,
            key,
            n_samples,
            device_count=0 if special_concat else -1,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=None if special_concat else 0,
        )

        kernel_fn = jit(kernel_fn, static_argnames='get')
        exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
        empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
        test_utils.assert_close_matrices(self, empirical, exact, tol)
Esempio n. 4
0
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
             phi, strides, width, is_ntk, proj_into_2d, pool_type, layer_norm,
             parameterization, s, use_dropout):
  """Build a `stax` test network with a randomly chosen dimension layout.

  For convolutional networks a random kernel layout (a permutation of 'HWIO')
  and a random activation layout (a permutation of 'NCHW', batch-leading on
  TPU) are drawn with `prandom.choice`. For dense networks the layout is 'NC',
  or 'CN' (off TPU only). The block is
  `affine_bottom -> [dropout] -> affine -> [pool]`, optionally residual,
  followed by layer normalisation and `phi`, then a projection and a dense
  readout.

  Args:
    W_std, b_std, parameterization: forwarded to the dense / conv layers;
      `W_std` and `b_std` are also used by the attention layer.
    filter_shape, strides: given in ('H', 'W') order and permuted here to
      follow the drawn kernel layout.
    is_conv: build `stax.Conv` affine layers, else `stax.Dense`.
    use_pooling: append a `(2, 3)` pool to the inner unit.
    is_res: wrap the inner unit in a `FanOut` / `FanInSum` skip connection.
    padding: convolution padding; the pool uses 'SAME' when this is 'SAME'
      and 'CIRCULAR' otherwise.
    phi: the nonlinearity layer, applied last in the block.
    width: output size of the affine layers.
    is_ntk: if truthy the readout has 1 output and `s=(s, 1)`, else `width`
      outputs and `s=(s, s)`.
    proj_into_2d: 'FLAT', 'POOL', or a string starting with 'ATTN'.
    pool_type: 'AVG' or 'SUM'.
    layer_norm: if truthy, an iterable of layout characters that is mapped to
      axis indices of the drawn layout; `None` disables layer normalisation.
    s: value combined into the pairs passed as the `s` argument of the affine
      layers (its meaning is defined by `stax`, not by this function).
    use_dropout: prepend a train-mode dropout to the inner unit.

  Returns:
    A tuple `(net, input_shape, device_count, channel_axis_fc)`, where
    `device_count` is -1 when the batch axis is leading and 0 otherwise.

  Raises:
    ValueError: for an unsupported `pool_type` or `proj_into_2d`.
  """
  if is_conv:
    # Draw a random kernel layout, then reorder the ('H', 'W') filter shape
    # and strides to follow the order of 'H' and 'W' inside that layout.
    hw = 'HW'
    filter_spec = prandom.choice(
        list(map(''.join, itertools.permutations('HWIO'))))
    spatial_order = [hw.index(c) for c in filter_spec if c in hw]
    filter_shape = tuple(filter_shape[i] for i in spatial_order)
    strides = tuple(strides[i] for i in spatial_order)

    # Draw a random activation layout.
    if default_backend() == 'tpu':
      # Keep batch dimension leading for TPU for batching to work.
      candidates = ['N' + ''.join(q) for q in itertools.permutations('CHW')]
    else:
      candidates = [''.join(q) for q in itertools.permutations('NCHW')]
    spec = prandom.choice(candidates)
    input_shape = tuple(INPUT_SHAPE['NHWC'.index(c)] for c in spec)

  else:
    filter_spec = None
    input_shape = (INPUT_SHAPE[0], onp.prod(INPUT_SHAPE[1:]))
    if default_backend() == 'tpu':
      spec = 'NC'
    else:
      spec = prandom.choice(['NC', 'CN'])
      if spec == 'CN':
        input_shape = input_shape[::-1]

  dimension_numbers = (spec, filter_spec, spec)

  # Axes as seen by dense layers: only 'N' and 'C' remain.
  spec_fc = ''.join(c for c in spec if c in ('N', 'C'))
  batch_axis_fc = spec_fc.index('N')
  channel_axis_fc = spec_fc.index('C')

  if is_conv:
    batch_axis = spec.index('N')
    channel_axis = spec.index('C')
  else:
    batch_axis, channel_axis = batch_axis_fc, channel_axis_fc

  if layer_norm:
    layer_norm = tuple(map(spec.index, layer_norm))

  def dense(out_dim, s_pair):
    return stax.Dense(
        out_dim=out_dim,
        W_std=W_std,
        b_std=b_std,
        parameterization=parameterization,
        s=s_pair,
        batch_axis=batch_axis_fc,
        channel_axis=channel_axis_fc
    )

  def make_affine(s_pair):
    if not is_conv:
      return dense(width, s_pair)
    return stax.Conv(
        out_chan=width,
        filter_shape=filter_shape,
        strides=strides,
        padding=padding,
        W_std=W_std,
        b_std=b_std,
        dimension_numbers=dimension_numbers,
        parameterization=parameterization,
        s=s_pair
    )

  affine = make_affine((s, s))
  affine_bottom = make_affine((1, s))

  # The rate is drawn (and the layer built) even when dropout is unused.
  rate = onp.random.uniform(0.5, 0.9)
  dropout = stax.Dropout(rate, mode='train')

  if pool_type == 'AVG':
    pool_fn, global_pool_fn = stax.AvgPool, stax.GlobalAvgPool
  elif pool_type == 'SUM':
    pool_fn, global_pool_fn = stax.SumPool, stax.GlobalSumPool
  else:
    raise ValueError(pool_type)

  if use_pooling:
    pool_padding = 'SAME' if padding == 'SAME' else 'CIRCULAR'
    pool_or_identity = pool_fn((2, 3),
                               None,
                               pool_padding,
                               batch_axis=batch_axis,
                               channel_axis=channel_axis)
  else:
    pool_or_identity = stax.Identity()

  dropout_or_identity = dropout if use_dropout else stax.Identity()

  if layer_norm is None:
    layer_norm_or_identity = stax.Identity()
  else:
    layer_norm_or_identity = stax.LayerNorm(axis=layer_norm,
                                            batch_axis=batch_axis,
                                            channel_axis=channel_axis)

  res_unit = stax.serial(dropout_or_identity, affine, pool_or_identity)
  if is_res:
    inner_layers = [
        stax.FanOut(2),
        stax.parallel(stax.Identity(), res_unit),
        stax.FanInSum(),
    ]
  else:
    inner_layers = [res_unit]
  block = stax.serial(
      affine_bottom, *inner_layers, layer_norm_or_identity, phi)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten(batch_axis, batch_axis_fc)
  elif proj_into_2d == 'POOL':
    proj_layer = global_pool_fn(batch_axis, channel_axis)
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    attention = stax.GlobalSelfAttention(
        n_chan_out=width,
        n_chan_key=width,
        n_chan_val=n_chan_val,
        n_heads=n_heads,
        linear_scaling=True,
        W_key_std=W_std,
        W_value_std=W_std,
        W_query_std=W_std,
        W_out_std=1.0,
        b_std=b_std,
        batch_axis=batch_axis,
        channel_axis=channel_axis)
    proj_layer = stax.serial(attention,
                             stax.Flatten(batch_axis, batch_axis_fc))
  else:
    raise ValueError(proj_into_2d)

  out_dim, s_out = (1, 1) if is_ntk else (width, s)
  readout = stax.serial(proj_layer, dense(out_dim, (s, s_out)))

  if spec.index('N') == 0:
    device_count = -1
  else:
    device_count = 0

  net = stax.serial(block, readout)
  return net, input_shape, device_count, channel_axis_fc
Esempio n. 5
0
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
             phi, strides, width, is_ntk, proj_into_2d, layer_norm,
             parameterization, use_dropout, dimension_numbers):
    """Assemble a `stax` test network for a given `dimension_numbers` layout.

    The block is `affine -> [pool] -> phi -> [dropout] -> affine`, optionally
    wrapped in a residual connection and followed by layer normalisation. The
    readout projects with the layer selected by `proj_into_2d` and ends in a
    dense layer. The last entry of `dimension_numbers` is passed as `spec` to
    the pooling, normalisation, attention and flatten layers.

    Args:
      W_std, b_std, parameterization: forwarded to the dense / conv layers;
        `W_std` and `b_std` are also used by the attention layer.
      filter_shape, strides, padding: convolution settings; the pool uses
        'SAME' when `padding` is 'SAME' and 'CIRCULAR' otherwise.
      is_conv: use `stax.GeneralConv` for the affine layer, else `stax.Dense`.
      use_pooling: put a `(2, 3)` average pool in front of `phi`.
      is_res: wrap the inner unit in a `FanOut` / `FanInSum` skip connection.
      phi: the nonlinearity layer.
      width: output size of the affine layers.
      is_ntk: if truthy the final dense layer has 1 output, else `width`.
      proj_into_2d: 'FLAT', 'POOL', or a string starting with 'ATTN'.
      layer_norm: `None` for no normalisation, else the `axis` argument of
        `stax.LayerNorm`.
      use_dropout: insert a train-mode dropout after `phi`.
      dimension_numbers: passed to `stax.GeneralConv`.

    Returns:
      The result of `stax.serial(block, readout)`.

    Raises:
      ValueError: if `proj_into_2d` is not one of the supported options.
    """
    dense = partial(stax.Dense,
                    W_std=W_std,
                    b_std=b_std,
                    parameterization=parameterization)

    if is_conv:
        affine = stax.GeneralConv(dimension_numbers=dimension_numbers,
                                  out_chan=width,
                                  filter_shape=filter_shape,
                                  strides=strides,
                                  padding=padding,
                                  W_std=W_std,
                                  b_std=b_std,
                                  parameterization=parameterization)
    else:
        affine = dense(width)

    spec = dimension_numbers[-1]

    # The rate is drawn (and the layer built) even when dropout is unused.
    rate = np.onp.random.uniform(0.5, 0.9)
    dropout = stax.Dropout(rate, mode='train')
    dropout_or_identity = dropout if use_dropout else stax.Identity()

    if use_pooling:
        pool_padding = 'SAME' if padding == 'SAME' else 'CIRCULAR'
        avg_pool_or_identity = stax.AvgPool(window_shape=(2, 3),
                                            strides=None,
                                            padding=pool_padding,
                                            spec=spec)
    else:
        avg_pool_or_identity = stax.Identity()

    if layer_norm is None:
        layer_norm_or_identity = stax.Identity()
    else:
        layer_norm_or_identity = stax.LayerNorm(axis=layer_norm, spec=spec)

    res_unit = stax.serial(avg_pool_or_identity, phi, dropout_or_identity,
                           affine)

    if is_res:
        inner_layers = [
            stax.FanOut(2),
            stax.parallel(stax.Identity(), res_unit),
            stax.FanInSum(),
        ]
    else:
        inner_layers = [res_unit]
    block = stax.serial(affine, *inner_layers, layer_norm_or_identity)

    if proj_into_2d == 'FLAT':
        proj_layer = stax.Flatten(spec=spec)
    elif proj_into_2d == 'POOL':
        proj_layer = stax.GlobalAvgPool(spec=spec)
    elif proj_into_2d.startswith('ATTN'):
        n_heads = int(np.sqrt(width))
        n_chan_val = int(np.round(float(width) / n_heads))
        attention = stax.GlobalSelfAttention(
            width,
            n_chan_key=width,
            n_chan_val=n_chan_val,
            n_heads=n_heads,
            fixed=(proj_into_2d == 'ATTN_FIXED'),
            W_key_std=W_std,
            W_value_std=W_std,
            W_query_std=W_std,
            W_out_std=1.0,
            b_std=b_std,
            spec=spec)
        proj_layer = stax.serial(attention, stax.Flatten(spec=spec))
    else:
        raise ValueError(proj_into_2d)

    out_dim = 1 if is_ntk else width
    readout = stax.serial(proj_layer, dense(out_dim))
    return stax.serial(block, readout)
Esempio n. 6
0
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
             phi, strides, width, is_ntk, proj_into_2d, pool_type, layer_norm,
             parameterization, use_dropout):
  """Build a `stax` test network with a randomly chosen activation layout.

  For convolutional networks an activation layout is drawn with
  `prandom.choice` (batch-leading layouts only on TPU) and `layer_norm`, if
  truthy, is translated from layout characters to axis indices. For dense
  networks the layout is `None` and a truthy `layer_norm` is replaced by a
  random choice between `(1,)` and `(-1,)`. The block is
  `affine -> [pool] -> phi -> [dropout] -> affine`, optionally residual,
  followed by layer normalisation; the readout is a projection plus a dense
  layer.

  Args:
    W_std, b_std, parameterization: forwarded to the dense / conv layers;
      `W_std` and `b_std` are also used by the attention layer.
    filter_shape, strides, padding: convolution settings; the pool uses
      'SAME' when `padding` is 'SAME' and 'CIRCULAR' otherwise.
    is_conv: use `stax.GeneralConv` for the affine layer, else `stax.Dense`.
    use_pooling: put a `(2, 3)` pool in front of `phi`.
    is_res: wrap the inner unit in a `FanOut` / `FanInSum` skip connection.
    phi: the nonlinearity layer.
    width: output size of the affine layers.
    is_ntk: if truthy the final dense layer has 1 output, else `width`.
    proj_into_2d: 'FLAT', 'POOL', or a string starting with 'ATTN'.
    pool_type: 'AVG' or 'SUM'. Other values are only tolerated when no
      pooling layer is needed (`use_pooling` is false and `proj_into_2d` is
      not 'POOL').
    layer_norm: `None` for no normalisation; see above for truthy values.
    use_dropout: insert a train-mode dropout after `phi`.

  Returns:
    A tuple `(net, input_shape)` where `net` is
    `stax.serial(block, readout)`.

  Raises:
    ValueError: if `proj_into_2d` is unsupported, or if a pooling layer is
      required and `pool_type` is neither 'AVG' nor 'SUM'.
  """
  if is_conv:
    # Select a random dimension order.
    default_spec = 'NHWC'
    if xla_bridge.get_backend().platform == 'tpu':
      # Keep batch dimension leading for TPU for batching to work.
      specs = ['NHWC', 'NHCW', 'NCHW']
    else:
      specs = ['NHWC', 'NHCW', 'NCHW', 'CHWN', 'CHNW', 'CNHW']
    spec = prandom.choice(specs)
    input_shape = tuple(INPUT_SHAPE[default_spec.index(c)] for c in spec)

    if layer_norm:
      layer_norm = tuple(spec.index(c) for c in layer_norm)

  else:
    # Only `NC` dimension order is supported and is enforced by layers.
    spec = None
    input_shape = INPUT_SHAPE
    if layer_norm:
      layer_norm = prandom.choice([(1,), (-1,)])

  dimension_numbers = (spec, 'HWIO', spec)

  fc = partial(
      stax.Dense, W_std=W_std, b_std=b_std, parameterization=parameterization)

  def conv(out_chan):
    return stax.GeneralConv(
        dimension_numbers=dimension_numbers,
        out_chan=out_chan,
        filter_shape=filter_shape,
        strides=strides,
        padding=padding,
        W_std=W_std,
        b_std=b_std,
        parameterization=parameterization
    )

  affine = conv(width) if is_conv else fc(width)

  # The rate is drawn (and the layer built) even when dropout is unused.
  rate = np.onp.random.uniform(0.5, 0.9)
  dropout = stax.Dropout(rate, mode='train')

  if pool_type == 'AVG':
    pool_fn = stax.AvgPool
    global_pool_fn = stax.GlobalAvgPool
  elif pool_type == 'SUM':
    pool_fn = stax.SumPool
    global_pool_fn = stax.GlobalSumPool
  elif use_pooling or proj_into_2d == 'POOL':
    # Fail with a clear error instead of an unbound-local error further down.
    raise ValueError(pool_type)
  else:
    # No pooling layer is built, so the unsupported value is harmless.
    pool_fn = global_pool_fn = None

  if use_pooling:
    pool_or_identity = pool_fn((2, 3),
                               None,
                               'SAME' if padding == 'SAME' else 'CIRCULAR',
                               spec=spec)
  else:
    pool_or_identity = stax.Identity()
  dropout_or_identity = dropout if use_dropout else stax.Identity()
  layer_norm_or_identity = (stax.Identity() if layer_norm is None
                            else stax.LayerNorm(axis=layer_norm, spec=spec))
  res_unit = stax.serial(pool_or_identity, phi, dropout_or_identity, affine)
  if is_res:
    block = stax.serial(
        affine,
        stax.FanOut(2),
        stax.parallel(stax.Identity(),
                      res_unit),
        stax.FanInSum(),
        layer_norm_or_identity)
  else:
    block = stax.serial(
        affine,
        res_unit,
        layer_norm_or_identity)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten(spec=spec)
  elif proj_into_2d == 'POOL':
    proj_layer = global_pool_fn(spec=spec)
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    fixed = proj_into_2d == 'ATTN_FIXED'
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(
            n_chan_out=width,
            n_chan_key=width,
            n_chan_val=n_chan_val,
            n_heads=n_heads,
            fixed=fixed,
            W_key_std=W_std,
            W_value_std=W_std,
            W_query_std=W_std,
            W_out_std=1.0,
            b_std=b_std,
            spec=spec), stax.Flatten(spec=spec))
  else:
    raise ValueError(proj_into_2d)
  readout = stax.serial(proj_layer, fc(1 if is_ntk else width))

  return stax.serial(block, readout), input_shape