Example 1
    def test_layernorm(self, model, width, same_inputs, is_ntk, proj_into_2d,
                       layer_norm):
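        """Tests analytic vs. empirical kernel agreement with layer norm."""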
        is_conv = 'conv' in model
        # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
        if is_conv:
            if xla_bridge.get_backend().platform == 'cpu':
                raise jtu.SkipTest(
                    'Not running CNN models on CPU to save time.')
        elif proj_into_2d != PROJECTIONS[0] or layer_norm != LAYER_NORM[0]:
            raise jtu.SkipTest('FC models do not have these parameters.')

        W_std, b_std = 2.**0.5, 0.5**0.5
        filter_size = FILTER_SIZES[0]
        padding = PADDINGS[0]
        strides = STRIDES[0]
        phi = stax.Relu()
        use_pooling, is_res = False, False
        parameterization = 'ntk'
        use_dropout = False

        self._check_agreement_with_empirical(W_std, b_std, filter_size,
                                             is_conv, is_ntk, is_res,
                                             layer_norm, padding, phi,
                                             proj_into_2d, same_inputs,
                                             strides, use_pooling, width,
                                             parameterization, use_dropout)
Example 2
  def test_exact(self, model, width, strides, padding, phi, same_inputs,
                 filter_size, use_pooling, is_ntk, is_res, proj_into_2d):
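    """Tests analytic vs. empirical kernel agreement across architectures."""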
    is_conv = 'conv' in model

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if is_conv:
      if xla_bridge.get_backend().platform == 'cpu':
        raise jtu.SkipTest('Not running CNN models on CPU to save time.')

      if (is_res and is_conv and ((strides is not None and strides != (1, 1)) or
                                  (padding == 'VALID' and filter_size !=
                                   (1, 1)))):
        raise jtu.SkipTest('Different paths in a residual model need to return'
                           ' outputs of the same shape.')
    elif (filter_size != FILTER_SIZES[0] or padding != PADDINGS[0] or
          strides != STRIDES[0] or proj_into_2d != PROJECTIONS[0] or
          use_pooling):
      raise jtu.SkipTest('FC models do not have these parameters.')

    if (proj_into_2d.startswith('ATTN') and strides == (2, 1) and
        padding == 'VALID' and xla_bridge.get_backend().platform == 'tpu'):
      # TODO: speed up the vmap alternative impl or fix the current one
      raise jtu.SkipTest('ATTN forward pass on TPU is broken if one of'
                         ' the spatial dimensions is singleton.')

    W_std, b_std = 2.**0.5, 0.5**0.5
    layer_norm = None

    self._check_agreement_with_empirical(W_std, b_std, filter_size, is_conv,
                                         is_ntk, is_res, layer_norm, padding,
                                         phi, proj_into_2d, same_inputs,
                                         strides, use_pooling, width)
Example 3
  def test_dropout(self, model, width, same_inputs, is_ntk, padding, strides,
                   filter_shape, phi, use_pooling, proj_into_2d):
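    """Tests analytic vs. empirical kernel agreement with dropout."""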
    if xla_bridge.get_backend().platform == 'tpu' and same_inputs:
      raise jtu.SkipTest(
          'Skip TPU test for `same_inputs`. Need to handle '
          'random keys carefully for dropout + empirical kernel.')

    pool_type = 'AVG'
    use_dropout = True
    is_conv = 'conv' in model
    is_res = False
    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    W_std, b_std = 2.**0.5, 0.5**0.5
    layer_norm = None
    parameterization = 'ntk'
    if is_conv:
      if xla_bridge.get_backend().platform == 'cpu':
        raise jtu.SkipTest('Not running CNN models on CPU to save time.')

      if (is_res and is_conv and ((strides is not None and strides != (1, 1)) or
                                  (padding == 'VALID' and filter_shape !=
                                   (1, 1)))):
        raise jtu.SkipTest('Different paths in a residual model need to return'
                           ' outputs of the same shape.')
    elif (filter_shape != FILTER_SHAPES[0] or padding != PADDINGS[0] or
          strides != STRIDES[0] or proj_into_2d != PROJECTIONS[0] or
          use_pooling):
      raise jtu.SkipTest('FC models do not have these parameters.')

    net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                   padding, phi, strides, width, is_ntk, proj_into_2d,
                   pool_type, layer_norm, parameterization, use_dropout)
    self._check_agreement_with_empirical(net, same_inputs, is_conv, use_dropout,
                                         is_ntk, proj_into_2d)
Example 4
  def test_parameterizations(self, model, width, same_inputs, is_ntk,
                             filter_shape, proj_into_2d, parameterization):
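    """Tests analytic vs. empirical kernels for each parameterization."""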
    is_conv = 'conv' in model

    W_std, b_std = 2.**0.5, 0.5**0.5
    padding = PADDINGS[0]
    strides = STRIDES[0]
    phi = stax.Relu()
    use_pooling, is_res = False, False
    layer_norm = None
    pool_type = 'AVG'
    use_dropout = False

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if is_conv:
      if xla_bridge.get_backend().platform == 'cpu':
        raise jtu.SkipTest('Not running CNN models on CPU to save time.')
    elif proj_into_2d != PROJECTIONS[0]:
      raise jtu.SkipTest('FC models do not have these parameters.')

    net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                   padding, phi, strides, width, is_ntk, proj_into_2d,
                   pool_type, layer_norm, parameterization, use_dropout)
    self._check_agreement_with_empirical(net, same_inputs, is_conv, use_dropout,
                                         is_ntk, proj_into_2d)
Example 5
  def test_exact(self, model, width, strides, padding, phi, same_inputs,
                 filter_size, use_pooling, is_ntk, is_res, proj_into_2d):
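    """Checks the analytic kernel against a Monte Carlo estimate."""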
    is_conv = 'conv' in model

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if is_conv:
      if xla_bridge.get_backend().platform == 'cpu':
        raise jtu.SkipTest('Not running CNN models on CPU to save time.')

      if (is_res and is_conv and ((strides is not None and strides != (1, 1)) or
                                  (padding == 'VALID' and filter_size !=
                                   (1, 1)))):
        raise jtu.SkipTest('Different paths in a residual model need to return'
                           ' outputs of the same shape.')
    elif (filter_size != FILTER_SIZES[0] or padding != PADDINGS[0] or
          strides != STRIDES[0] or proj_into_2d != PROJECTIONS[0] or
          use_pooling):
      raise jtu.SkipTest('FC models do not have these parameters.')

    if (proj_into_2d.startswith('ATTN') and strides == (2, 1) and
        padding == 'VALID' and xla_bridge.get_backend().platform == 'tpu'):
      # TODO(jirihron): speed up the vmap alternative impl or fix the current one
      raise jtu.SkipTest('ATTN forward pass on TPU is broken if one of'
                         ' the spatial dimensions is singleton.')

    W_std, b_std = 2.**0.5, 0.5**0.5

    key = random.PRNGKey(1)
    x1, x2 = _get_inputs(key, is_conv, same_inputs, INPUT_SHAPE)

    init_fn, apply_fn, kernel_fn = _get_net(W_std, b_std, filter_size,
                                            is_conv, use_pooling, is_res,
                                            padding, phi, strides, width,
                                            is_ntk, proj_into_2d)

    def _get_empirical(n_samples, get):
      kernel_fn_empirical = monte_carlo.monte_carlo_kernel_fn(
          init_fn, apply_fn, key, n_samples)
      return kernel_fn_empirical(x1, x2, get)

    if proj_into_2d == 'ATTN_PARAM':
      # no analytic kernel available, just test forward/backward pass
      _get_empirical(1, 'ntk' if is_ntk else 'nngp')
    else:
      if is_ntk:
        exact = kernel_fn(x1, x2, 'ntk')
        empirical = np.reshape(_get_empirical(N_SAMPLES, 'ntk'), exact.shape)
      else:
        exact = kernel_fn(x1, x2, 'nngp')
        empirical = _get_empirical(N_SAMPLES, 'nngp')
      utils.assert_close_matrices(self, empirical, exact, RTOL)
Example 6
    def test_exact(self, model, width, strides, padding, phi, same_inputs,
                   filter_size, use_pooling, is_ntk, is_res):
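        """Checks the analytic kernel against a Monte Carlo estimate."""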
        is_conv = 'conv' in model

        # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
        if is_conv:
            if xla_bridge.get_backend().platform == 'cpu':
                raise jtu.SkipTest(
                    'Not running CNN models on CPU to save time.')

            if use_pooling and not same_inputs:
                raise jtu.SkipTest(
                    'Pooling layers for different inputs or for same '
                    'padding not implemented.')

            if (is_res and is_conv
                    and ((strides is not None and strides !=
                          (1, 1)) or (padding == 'VALID' and filter_size !=
                                      (1, 1)))):
                raise jtu.SkipTest(
                    'Different paths in a residual model need to return'
                    ' outputs of the same shape.')
        elif (filter_size != FILTER_SIZES[0] or padding != PADDINGS[0]
              or strides != STRIDES[0] or use_pooling):
            raise jtu.SkipTest('FC models do not have these parameters.')

        W_std, b_std = 2.**0.5, 0.5**0.5

        key = random.PRNGKey(1)
        x1, x2 = _get_inputs(key, is_conv, same_inputs, INPUT_SHAPE)

        init_fun, apply_fun, ker_fun = _get_net(W_std, b_std, filter_size,
                                                is_conv, use_pooling, is_res,
                                                padding, phi, strides, width,
                                                is_ntk)

        if is_ntk:
            exact = ker_fun(x1, x2).ntk
            ker_fun_empirical = monte_carlo.get_ker_fun_monte_carlo(
                init_fun, apply_fun, False, True)
            empirical = ker_fun_empirical(x1, x2, key, N_SAMPLES).ntk
            empirical = np.reshape(empirical, exact.shape)
        else:
            exact = ker_fun(x1, x2, compute_ntk=False).nngp
            ker_fun_empirical = monte_carlo.get_ker_fun_monte_carlo(
                init_fun, apply_fun, True, False)
            empirical = ker_fun_empirical(x1, x2, key, N_SAMPLES).nngp

        utils.assert_close_matrices(self, empirical, exact, RTOL)
Example 7
  def test_pool(self, width, same_inputs, is_ntk, pool_type,
                padding, filter_shape, strides, normalize_edges):
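    """Tests analytic vs. empirical kernel agreement for pooling layers."""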
    is_conv = True
    use_dropout = False
    proj_into_2d = 'POOL'
    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.

    if xla_bridge.get_backend().platform == 'cpu':
      raise jtu.SkipTest('Not running CNN models on CPU to save time.')
    if pool_type == 'SUM' and normalize_edges:
      raise jtu.SkipTest('normalize_edges not applicable to SumPool.')

    net = _get_net_pool(width, is_ntk, pool_type,
                        padding, filter_shape, strides, normalize_edges)
    self._check_agreement_with_empirical(net, same_inputs, is_conv, use_dropout,
                                         is_ntk, proj_into_2d)
Example 8
    def test_batch_sample_once(self, batch_size, device_count, store_on_device,
                               get):
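        """Checks a single MC kernel sample against its batched version."""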
        utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, _, key = _get_inputs_and_model()
        kernel_fn = empirical.empirical_kernel_fn(apply_fn)
        sample_once_fn = monte_carlo._sample_once_kernel_fn(kernel_fn,
                                                            init_fn,
                                                            device_count=0)
        batch_sample_once_fn = batch.batch(sample_once_fn, batch_size,
                                           device_count, store_on_device)
        if get is None:
            raise jtu.SkipTest('No default `get` values for this method.')
        else:
            one_sample = sample_once_fn(x1, x2, key, get)
            one_batch_sample = batch_sample_once_fn(x1, x2, key, get)
            self.assertAllClose(one_sample, one_batch_sample, True)
Example 9
    def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                                   out_logits):
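        """Compares ensemble prediction mean/covariance to the NTK posterior."""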
        if xla_bridge.get_backend().platform == 'gpu' and config.read(
                'jax_enable_x64'):
            raise jtu.SkipTest('Not running GPU x64 to save time.')
        training_steps = 5000
        learning_rate = 1.0
        ensemble_size = 50

        init_fn, apply_fn, ker_fn = stax.serial(
            stax.Dense(1024, W_std=1.2, b_std=0.05), stax.Erf(),
            stax.Dense(out_logits, W_std=1.2, b_std=0.05))

        opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
        opt_update = jit(opt_update)

        key = random.PRNGKey(0)
        key, = random.split(key, 1)

        key, split = random.split(key)
        x_train = np.cos(random.normal(split, train_shape))

        key, split = random.split(key)
        y_train = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)
        train = (x_train, y_train)
        key, split = random.split(key)
        x_test = np.cos(random.normal(split, test_shape))

        ensemble_key = random.split(key, ensemble_size)

        loss = jit(lambda params, x, y: 0.5 * np.mean(
            (apply_fn(params, x) - y)**2))
        grad_loss = jit(lambda state, x, y: grad(loss)
                        (get_params(state), x, y))

        def train_network(key):
            _, params = init_fn(key, (-1, ) + train_shape[1:])
            opt_state = opt_init(params)
            for i in range(training_steps):
                opt_state = opt_update(i, grad_loss(opt_state, *train),
                                       opt_state)

            return get_params(opt_state)

        params = vmap(train_network)(ensemble_key)

        ensemble_fx = vmap(apply_fn, (0, None))(params, x_test)
        ensemble_loss = vmap(loss, (0, None, None))(params, x_train, y_train)
        ensemble_loss = np.mean(ensemble_loss)
        self.assertLess(ensemble_loss, 1e-5, True)

        mean_emp = np.mean(ensemble_fx, axis=0)
        mean_subtracted = ensemble_fx - mean_emp
        cov_emp = np.einsum(
            'ijk,ilk->jl', mean_subtracted, mean_subtracted, optimize=True) / (
                mean_subtracted.shape[0] * mean_subtracted.shape[-1])

        reg = 1e-7
        ntk_predictions = predict.gp_inference(ker_fn,
                                               x_train,
                                               y_train,
                                               x_test,
                                               'ntk',
                                               reg,
                                               compute_cov=True)

        self.assertAllClose(mean_emp, ntk_predictions.mean, True, RTOL, ATOL)
        self.assertAllClose(cov_emp, ntk_predictions.covariance, True, RTOL,
                            ATOL)
Example 10
  def test_fan_in_conv(self,
                       same_inputs,
                       axis,
                       n_branches,
                       get,
                       branch_in,
                       readout):
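    """Tests analytic vs. empirical kernels for convolutional fan-in layers."""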
    if xla_bridge.get_backend().platform == 'cpu':
      raise jtu.SkipTest('Not running CNNs on CPU to save time.')

    if axis in (None, 0, 1, 2) and branch_in == 'dense_after_branch_in':
      raise jtu.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                         'require `is_gaussian`.')

    if axis == 3 and branch_in == 'dense_before_branch_in':
      raise jtu.SkipTest('`FanInConcat` on feature axis requires a dense layer '
                         'after concatenation.')

    key = random.PRNGKey(1)
    X0_1 = random.normal(key, (2, 5, 6, 3))
    X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

    if xla_bridge.get_backend().platform == 'tpu':
      width = 2048
      n_samples = 1024
      tol = 0.02
    else:
      width = 1024
      n_samples = 512
      tol = 0.01

    conv = stax.Conv(out_chan=width,
                     filter_shape=(3, 3),
                     padding='SAME',
                     W_std=1.25,
                     b_std=0.1)

    input_layers = [conv,
                    stax.FanOut(n_branches)]

    branches = []
    for b in range(n_branches):
      branch_layers = [FanInTest._get_phi(b)]
      for i in range(b):
        branch_layers += [
            stax.Conv(
                out_chan=width,
                filter_shape=(i + 1, 4 - i),
                padding='SAME',
                W_std=1.25 + i,
                b_std=0.1 + i),
            FanInTest._get_phi(i)]

      if branch_in == 'dense_before_branch_in':
        branch_layers += [conv]
      branches += [stax.serial(*branch_layers)]

    output_layers = [
        stax.FanInSum() if axis is None else stax.FanInConcat(axis),
        stax.Relu(),
        stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten()
    ]
    if branch_in == 'dense_after_branch_in':
      output_layers.insert(1, conv)

    nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                       output_layers))

    init_fn, apply_fn, kernel_fn = stax.serial(
        nn, stax.Dense(1 if get == 'ntk' else width, 1.25, 0.5))

    kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(
        init_fn,
        apply_fn,
        key,
        n_samples,
        device_count=0 if axis in (0, -4) else -1)

    exact = kernel_fn(X0_1, X0_2, get=get)
    empirical = kernel_fn_mc(X0_1, X0_2, get=get)
    empirical = empirical.reshape(exact.shape)
    utils.assert_close_matrices(self, empirical, exact, tol)
Example 11
  def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in):
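    """Tests analytic vs. empirical kernels for dense fan-in layers."""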
    if axis in (None, 0) and branch_in == 'dense_after_branch_in':
      raise jtu.SkipTest('`FanInSum` and `FanInConcat(0)` '
                         'require `is_gaussian`.')

    if axis == 1 and branch_in == 'dense_before_branch_in':
      raise jtu.SkipTest('`FanInConcat` on feature axis requires a dense layer'
                         ' after concatenation.')

    key = random.PRNGKey(1)
    X0_1 = random.normal(key, (10, 20))
    X0_2 = None if same_inputs else random.normal(key, (8, 20))

    if xla_bridge.get_backend().platform == 'tpu':
      width = 2048
      n_samples = 1024
      tol = 0.02
    else:
      width = 1024
      n_samples = 256
      tol = 0.01

    dense = stax.Dense(width, 1.25, 0.1)
    input_layers = [dense,
                    stax.FanOut(n_branches)]

    branches = []
    for b in range(n_branches):
      branch_layers = [FanInTest._get_phi(b)]
      for i in range(b):
        branch_layers += [
            stax.Dense(width, 1. + 2 * i, 0.5 + i),
            FanInTest._get_phi(i)]

      if branch_in == 'dense_before_branch_in':
        branch_layers += [dense]
      branches += [stax.serial(*branch_layers)]

    output_layers = [
        stax.FanInSum() if axis is None else stax.FanInConcat(axis),
        stax.Relu()
    ]
    if branch_in == 'dense_after_branch_in':
      output_layers.insert(1, dense)

    nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                       output_layers))

    if get == 'nngp':
      init_fn, apply_fn, kernel_fn = nn
    elif get == 'ntk':
      init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1.25, 0.5))
    else:
      raise ValueError(get)

    kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples, device_count=0)

    exact = kernel_fn(X0_1, X0_2, get=get)
    empirical = kernel_fn_mc(X0_1, X0_2, get=get)
    empirical = empirical.reshape(exact.shape)
    utils.assert_close_matrices(self, empirical, exact, tol)