Example #1
def get_attn():
    # `width`, `proj`, `stax` and `np` are free variables captured from the
    # enclosing scope.
    return stax.GlobalSelfAttention(
        n_chan_out=width,
        n_chan_key=width,
        n_chan_val=int(np.round(float(width) / int(np.sqrt(width)))),
        n_heads=int(np.sqrt(width)),
    ) if proj == 'avg' else stax.Identity()
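A minimal usage sketch (all values below are hypothetical assumptions, not from the source): `get_attn` closes over `width` and `proj`, so both must exist in the enclosing scope, and the layer is best behaved when `width` is a perfect square so that `n_heads = int(np.sqrt(width))` divides it evenly.

import jax.numpy as np
from neural_tangents import stax

width = 16    # hypothetical: n_heads = 4, n_chan_val = round(16 / 4) = 4
proj = 'avg'  # any other value makes get_attn() return stax.Identity()

# Like every neural_tangents stax layer, the result is an
# (init_fn, apply_fn, kernel_fn) triple.
init_fn, apply_fn, kernel_fn = get_attn()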
Example #2
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
             phi, strides, width, is_ntk, proj_into_2d, layer_norm,
             parameterization, use_dropout):
    fc = partial(stax.Dense,
                 W_std=W_std,
                 b_std=b_std,
                 parameterization=parameterization)
    conv = partial(stax.Conv,
                   filter_shape=filter_shape,
                   strides=strides,
                   padding=padding,
                   W_std=W_std,
                   b_std=b_std,
                   parameterization=parameterization)
    affine = conv(width) if is_conv else fc(width)
    rate = onp.random.uniform(0.5, 0.9)  # assumes plain NumPy: `import numpy as onp`
    dropout = stax.Dropout(rate, mode='train')
    ave_pool = stax.AvgPool((2, 3), None,
                            'SAME' if padding == 'SAME' else 'CIRCULAR')
    ave_pool_or_identity = ave_pool if use_pooling else stax.Identity()
    dropout_or_identity = dropout if use_dropout else stax.Identity()
    layer_norm_or_identity = (stax.Identity() if layer_norm is None else
                              stax.LayerNorm(axis=layer_norm))
    res_unit = stax.serial(ave_pool_or_identity, phi, dropout_or_identity,
                           affine)
    if is_res:
        block = stax.serial(affine, stax.FanOut(2),
                            stax.parallel(stax.Identity(), res_unit),
                            stax.FanInSum(), layer_norm_or_identity)
    else:
        block = stax.serial(affine, res_unit, layer_norm_or_identity)

    if proj_into_2d == 'FLAT':
        proj_layer = stax.Flatten()
    elif proj_into_2d == 'POOL':
        proj_layer = stax.GlobalAvgPool()
    elif proj_into_2d.startswith('ATTN'):
        n_heads = int(np.sqrt(width))
        n_chan_val = int(np.round(float(width) / n_heads))
        fixed = proj_into_2d == 'ATTN_FIXED'
        proj_layer = stax.serial(
            stax.GlobalSelfAttention(width,
                                     n_chan_key=width,
                                     n_chan_val=n_chan_val,
                                     n_heads=n_heads,
                                     fixed=fixed,
                                     W_key_std=W_std,
                                     W_value_std=W_std,
                                     W_query_std=W_std,
                                     W_out_std=1.0,
                                     b_std=b_std), stax.Flatten())
    else:
        raise ValueError(proj_into_2d)
    readout = stax.serial(proj_layer, fc(1 if is_ntk else width))

    return stax.serial(block, readout)
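A hedged construction sketch for the function above: `_get_net` returns the `(init_fn, apply_fn, kernel_fn)` triple produced by `stax.serial`. All argument values here are illustrative assumptions.

from neural_tangents import stax

init_fn, apply_fn, kernel_fn = _get_net(
    W_std=1.5, b_std=0.05, filter_shape=(3, 2), is_conv=True,
    use_pooling=True, is_res=False, padding='SAME', phi=stax.Relu(),
    strides=(1, 2), width=16, is_ntk=False, proj_into_2d='POOL',
    layer_norm=None, parameterization='ntk', use_dropout=False)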
Example #3
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
             padding, phi, strides, width, is_ntk, proj_into_2d):
  fc = partial(stax.Dense, W_std=W_std, b_std=b_std)
  conv = partial(
      stax.Conv,
      filter_shape=filter_shape,
      strides=strides,
      padding=padding,
      W_std=W_std,
      b_std=b_std)
  affine = conv(width) if is_conv else fc(width)

  pool_or_identity = (stax.AvgPool(
      (2, 3), None, 'SAME' if padding == 'SAME' else 'CIRCULAR')
                      if use_pooling else stax.Identity())
  res_unit = stax.serial(pool_or_identity, phi, affine)

  if is_res:
    block = stax.serial(affine, stax.FanOut(2),
                        stax.parallel(stax.Identity(), res_unit),
                        stax.FanInSum())
  else:
    block = stax.serial(affine, res_unit)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten()
  elif proj_into_2d == 'POOL':
    proj_layer = stax.GlobalAvgPool()
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    fixed = proj_into_2d == 'ATTN_FIXED'
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(
            width, n_chan_key=width, n_chan_val=n_chan_val, n_heads=n_heads,
            fixed=fixed, W_key_std=W_std, W_value_std=W_std, W_query_std=W_std,
            W_out_std=1.0, b_std=b_std),
        stax.Flatten())
  else:
    raise ValueError(proj_into_2d)
  readout = stax.serial(proj_layer, fc(1 if is_ntk else width))

  return stax.serial(block, readout)
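The triple returned above can be used to evaluate infinite-width kernels directly. A minimal sketch under assumed NHWC inputs (shapes and values hypothetical):

from jax import random
from neural_tangents import stax

_, _, kernel_fn = _get_net(
    W_std=1.0, b_std=0.1, filter_shape=(3, 3), is_conv=True,
    use_pooling=False, is_res=True, padding='SAME', phi=stax.Relu(),
    strides=(1, 1), width=16, is_ntk=True, proj_into_2d='FLAT')

x1 = random.normal(random.PRNGKey(1), (4, 8, 8, 3))  # batch of 4, NHWC
x2 = random.normal(random.PRNGKey(2), (2, 8, 8, 3))  # batch of 2
ntk = kernel_fn(x1, x2, get='ntk')                   # (4, 2) NTK matrix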
Example #4
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
             phi, strides, width, is_ntk, proj_into_2d, pool_type, layer_norm,
             parameterization, s, use_dropout):

  if is_conv:
    # Select a random filter order.
    default_filter_spec = 'HW'
    filter_specs = [''.join(p) for p in itertools.permutations('HWIO')]
    filter_spec = prandom.choice(filter_specs)
    filter_shape = tuple(filter_shape[default_filter_spec.index(c)]
                         for c in filter_spec if c in default_filter_spec)
    strides = tuple(strides[default_filter_spec.index(c)]
                    for c in filter_spec if c in default_filter_spec)

    # Select the activation order.
    default_spec = 'NHWC'
    if default_backend() == 'tpu':
      # Keep batch dimension leading for TPU for batching to work.
      specs = ['N' + ''.join(p) for p in itertools.permutations('CHW')]
    else:
      specs = [''.join(p) for p in itertools.permutations('NCHW')]
    spec = prandom.choice(specs)
    input_shape = tuple(INPUT_SHAPE[default_spec.index(c)] for c in spec)

  else:
    input_shape = (INPUT_SHAPE[0], onp.prod(INPUT_SHAPE[1:]))
    if default_backend() == 'tpu':
      spec = 'NC'
    else:
      spec = prandom.choice(['NC', 'CN'])
      if spec.index('N') == 1:
        input_shape = input_shape[::-1]

    filter_spec = None

  dimension_numbers = (spec, filter_spec, spec)
  batch_axis, channel_axis = spec.index('N'), spec.index('C')

  spec_fc = ''.join(c for c in spec if c in ('N', 'C'))
  batch_axis_fc, channel_axis_fc = spec_fc.index('N'), spec_fc.index('C')

  if not is_conv:
    batch_axis = batch_axis_fc
    channel_axis = channel_axis_fc

  if layer_norm:
    layer_norm = tuple(spec.index(c) for c in layer_norm)

  def fc(out_dim, s):
    return stax.Dense(
        out_dim=out_dim,
        W_std=W_std,
        b_std=b_std,
        parameterization=parameterization,
        s=s,
        batch_axis=batch_axis_fc,
        channel_axis=channel_axis_fc
    )

  def conv(out_chan, s):
    return stax.Conv(
        out_chan=out_chan,
        filter_shape=filter_shape,
        strides=strides,
        padding=padding,
        W_std=W_std,
        b_std=b_std,
        dimension_numbers=dimension_numbers,
        parameterization=parameterization,
        s=s
    )

  affine = conv(width, (s, s)) if is_conv else fc(width, (s, s))
  affine_bottom = conv(width, (1, s)) if is_conv else fc(width, (1, s))

  rate = onp.random.uniform(0.5, 0.9)
  dropout = stax.Dropout(rate, mode='train')

  if pool_type == 'AVG':
    pool_fn = stax.AvgPool
    global_pool_fn = stax.GlobalAvgPool
  elif pool_type == 'SUM':
    pool_fn = stax.SumPool
    global_pool_fn = stax.GlobalSumPool
  else:
    raise ValueError(pool_type)

  if use_pooling:
    pool_or_identity = pool_fn((2, 3),
                               None,
                               'SAME' if padding == 'SAME' else 'CIRCULAR',
                               batch_axis=batch_axis,
                               channel_axis=channel_axis)
  else:
    pool_or_identity = stax.Identity()
  dropout_or_identity = dropout if use_dropout else stax.Identity()
  layer_norm_or_identity = (stax.Identity() if layer_norm is None else
                            stax.LayerNorm(axis=layer_norm,
                                           batch_axis=batch_axis,
                                           channel_axis=channel_axis))
  res_unit = stax.serial(dropout_or_identity, affine, pool_or_identity)
  if is_res:
    block = stax.serial(
        affine_bottom,
        stax.FanOut(2),
        stax.parallel(stax.Identity(),
                      res_unit),
        stax.FanInSum(),
        layer_norm_or_identity,
        phi)
  else:
    block = stax.serial(
        affine_bottom,
        res_unit,
        layer_norm_or_identity,
        phi)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten(batch_axis, batch_axis_fc)
  elif proj_into_2d == 'POOL':
    proj_layer = global_pool_fn(batch_axis, channel_axis)
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(
            n_chan_out=width,
            n_chan_key=width,
            n_chan_val=n_chan_val,
            n_heads=n_heads,
            linear_scaling=True,
            W_key_std=W_std,
            W_value_std=W_std,
            W_query_std=W_std,
            W_out_std=1.0,
            b_std=b_std,
            batch_axis=batch_axis,
            channel_axis=channel_axis),
        stax.Flatten(batch_axis, batch_axis_fc))
  else:
    raise ValueError(proj_into_2d)

  readout = stax.serial(proj_layer,
                        fc(1 if is_ntk else width, (s, 1 if is_ntk else s)))

  device_count = -1 if spec.index('N') == 0 else 0

  net = stax.serial(block, readout)
  return net, input_shape, device_count, channel_axis_fc
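Unlike the earlier variants, this version also reports the input shape for the randomly sampled dimension order, a device count usable with `nt.batch`, and the fully-connected channel axis. A hedged unpacking sketch (argument values are assumptions):

from neural_tangents import stax

net, input_shape, device_count, channel_axis_fc = _get_net(
    W_std=1.0, b_std=0.1, filter_shape=(3, 2), is_conv=True,
    use_pooling=True, is_res=False, padding='SAME', phi=stax.Relu(),
    strides=(1, 2), width=16, is_ntk=False, proj_into_2d='POOL',
    pool_type='AVG', layer_norm=None, parameterization='ntk',
    s=1, use_dropout=False)
init_fn, apply_fn, kernel_fn = net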
Example #5
def main(*args, use_dummy_data: bool = False, **kwargs) -> None:
    # Mask all padding with this value.
    mask_constant = 100.

    if use_dummy_data:
        x_train, y_train, x_test, y_test = _get_dummy_data(mask_constant)
    else:
        # Build data pipelines.
        print('Loading IMDb data.')
        x_train, y_train, x_test, y_test = datasets.get_dataset(
            name='imdb_reviews',
            n_train=FLAGS.n_train,
            n_test=FLAGS.n_test,
            do_flatten_and_normalize=False,
            data_dir=FLAGS.imdb_path,
            input_key='text')

        # Embed words and pad / truncate sentences to a fixed size.
        x_train, x_test = datasets.embed_glove(
            xs=[x_train, x_test],
            glove_path=FLAGS.glove_path,
            max_sentence_length=FLAGS.max_sentence_length,
            mask_constant=mask_constant)

    # Build the infinite network.
    # Not using the finite model, hence width is set to 1 everywhere.
    _, _, kernel_fn = stax.serial(
        stax.Conv(out_chan=1,
                  filter_shape=(9, ),
                  strides=(1, ),
                  padding='VALID'), stax.Relu(),
        stax.GlobalSelfAttention(n_chan_out=1,
                                 n_chan_key=1,
                                 n_chan_val=1,
                                 pos_emb_type='SUM',
                                 W_pos_emb_std=1.,
                                 pos_emb_decay_fn=lambda d: 1 / (1 + d**2),
                                 n_heads=1), stax.Relu(), stax.GlobalAvgPool(),
        stax.Dense(out_dim=1))

    # Optionally, compute the kernel in batches, in parallel.
    kernel_fn = nt.batch(kernel_fn,
                         device_count=-1,
                         batch_size=FLAGS.batch_size)

    start = time.time()
    # Bayesian and infinite-time gradient descent inference with infinite network.
    predict = nt.predict.gradient_descent_mse_ensemble(
        kernel_fn=kernel_fn,
        x_train=x_train,
        y_train=y_train,
        diag_reg=1e-6,
        mask_constant=mask_constant)

    fx_test_nngp, fx_test_ntk = predict(x_test=x_test, get=('nngp', 'ntk'))

    fx_test_nngp.block_until_ready()
    fx_test_ntk.block_until_ready()

    duration = time.time() - start
    print(f'Kernel construction and inference done in {duration} seconds.')

    # Print out accuracy and loss for infinite network predictions.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
    util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)
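A hedged entry-point sketch: the `FLAGS` references suggest this is an absl app (an assumption); running with `use_dummy_data=True` exercises the whole pipeline without downloading IMDb or GloVe, relying on the script's `_get_dummy_data` helper.

from absl import app

if __name__ == '__main__':
    app.run(main)  # or call main(use_dummy_data=True) directly to skip downloads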
Example #6
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
             phi, strides, width, is_ntk, proj_into_2d, pool_type, layer_norm,
             parameterization, use_dropout):

  if is_conv:
    # Select a random dimension order.
    default_spec = 'NHWC'
    if xla_bridge.get_backend().platform == 'tpu':
      # Keep batch dimension leading for TPU for batching to work.
      specs = ['NHWC', 'NHCW', 'NCHW']
    else:
      specs = ['NHWC', 'NHCW', 'NCHW', 'CHWN', 'CHNW', 'CNHW']
    spec = prandom.choice(specs)
    input_shape = tuple(INPUT_SHAPE[default_spec.index(c)] for c in spec)

    if layer_norm:
      layer_norm = tuple(spec.index(c) for c in layer_norm)

  else:
    # Only `NC` dimension order is supported and is enforced by layers.
    spec = None
    input_shape = INPUT_SHAPE
    if layer_norm:
      layer_norm = prandom.choice([(1,), (-1,)])

  dimension_numbers = (spec, 'HWIO', spec)

  fc = partial(
      stax.Dense, W_std=W_std, b_std=b_std, parameterization=parameterization)

  def conv(out_chan):
    return stax.GeneralConv(
        dimension_numbers=dimension_numbers,
        out_chan=out_chan,
        filter_shape=filter_shape,
        strides=strides,
        padding=padding,
        W_std=W_std,
        b_std=b_std,
        parameterization=parameterization)
  affine = conv(width) if is_conv else fc(width)

  spec = dimension_numbers[-1]

  rate = onp.random.uniform(0.5, 0.9)  # assumes plain NumPy: `import numpy as onp`
  dropout = stax.Dropout(rate, mode='train')

  if pool_type == 'AVG':
    pool_fn = stax.AvgPool
    global_pool_fn = stax.GlobalAvgPool
  elif pool_type == 'SUM':
    pool_fn = stax.SumPool
    global_pool_fn = stax.GlobalSumPool
  else:
    raise ValueError(pool_type)

  if use_pooling:
    pool_or_identity = pool_fn((2, 3),
                               None,
                               'SAME' if padding == 'SAME' else 'CIRCULAR',
                               spec=spec)
  else:
    pool_or_identity = stax.Identity()
  dropout_or_identity = dropout if use_dropout else stax.Identity()
  layer_norm_or_identity = (stax.Identity() if layer_norm is None
                            else stax.LayerNorm(axis=layer_norm, spec=spec))
  res_unit = stax.serial(pool_or_identity, phi, dropout_or_identity, affine)
  if is_res:
    block = stax.serial(
        affine,
        stax.FanOut(2),
        stax.parallel(stax.Identity(),
                      res_unit),
        stax.FanInSum(),
        layer_norm_or_identity)
  else:
    block = stax.serial(
        affine,
        res_unit,
        layer_norm_or_identity)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten(spec=spec)
  elif proj_into_2d == 'POOL':
    proj_layer = global_pool_fn(spec=spec)
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    fixed = proj_into_2d == 'ATTN_FIXED'
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(
            n_chan_out=width,
            n_chan_key=width,
            n_chan_val=n_chan_val,
            n_heads=n_heads,
            fixed=fixed,
            W_key_std=W_std,
            W_value_std=W_std,
            W_query_std=W_std,
            W_out_std=1.0,
            b_std=b_std,
            spec=spec), stax.Flatten(spec=spec))
  else:
    raise ValueError(proj_into_2d)
  readout = stax.serial(proj_layer, fc(1 if is_ntk else width))

  return stax.serial(block, readout), input_shape
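Because the dimension order is sampled at random, test inputs must be built from the returned `input_shape`. A minimal usage sketch (values hypothetical; `onp` is plain NumPy):

import numpy as onp
from neural_tangents import stax

net, input_shape = _get_net(
    W_std=1.0, b_std=0.1, filter_shape=(3, 2), is_conv=True,
    use_pooling=True, is_res=False, padding='SAME', phi=stax.Relu(),
    strides=(1, 2), width=16, is_ntk=False, proj_into_2d='POOL',
    pool_type='AVG', layer_norm=None, parameterization='ntk',
    use_dropout=False)

_, _, kernel_fn = net
x = onp.random.randn(*input_shape)     # matches the sampled dimension order
nngp = kernel_fn(x, None, get='nngp')  # NNGP kernel on a single batch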