コード例 #1
0
    def testAutomatic(self, train_shape, test_shape, network, name, kernel_fn,
                      batch_size):
        """Run `_test_kernel_against_batched` on `batch.batch` wrappers.

        The kernel built by the `kernel_fn` factory is wrapped twice with
        `batch.batch`: once with only `batch_size`, and once additionally
        passing `store_on_device=False`. Each wrapper is handed to
        `_test_kernel_against_batched` together with the unwrapped kernel.
        """
        test_utils.stub_out_pmap(batch, 2)

        root = stateless_uniform(shape=[2],
                                 seed=[0, 0],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
        subkeys = tf_random_split(root, 3)
        x_self = np.asarray(normal(train_shape, seed=subkeys[1]))
        x_other = np.asarray(normal(test_shape, seed=subkeys[2]))

        # `kernel_fn` arrives as a factory; build the concrete kernel once.
        kernel = kernel_fn(subkeys[0], train_shape[1:], network)

        for extra_kwargs in ({}, {'store_on_device': False}):
            batched = batch.batch(kernel,
                                  batch_size=batch_size,
                                  **extra_kwargs)
            _test_kernel_against_batched(self, kernel, batched, x_self,
                                         x_other)
コード例 #2
0
    def testNTKAgainstDirect(self, train_shape, test_shape, network, name,
                             kernel_fn):
        """Assert that the first two kernels returned by `kernel_fn` agree.

        `kernel_fn` is called with empty `diagonal_axes` and `trace_axes`; its
        first two return values (`implicit`, `direct`) are evaluated on
        `(data_self, None)` and on `(data_other, data_self)` and compared
        with `assertAllClose`.
        """
        seed = stateless_uniform(shape=[2],
                                 seed=[0, 0],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
        parts = tf_random_split(seed=tf.convert_to_tensor(seed,
                                                          dtype=tf.int32),
                                num=3)
        x_self = np.asarray(normal(train_shape, seed=parts[1]))
        x_other = np.asarray(normal(test_shape, seed=parts[2]))

        implicit, direct, _ = kernel_fn(parts[0],
                                        train_shape[1:],
                                        network,
                                        diagonal_axes=(),
                                        trace_axes=())

        # Same comparison for the train-train and the test-train pairing.
        for x1, x2 in ((x_self, None), (x_other, x_self)):
            g = implicit(x1, x2)
            g_direct = direct(x1, x2)
            self.assertAllClose(g, g_direct)
コード例 #3
0
    def _test_analytic_kernel_composition(self, batching_fn):
        """Compare chained batched kernel functions with `stax.serial`.

        For a fully-connected block and for a convolutional block followed
        by a pooling readout, the kernel obtained by applying the
        (batched) per-block kernel functions one after another is compared
        with the kernel function of the `stax.serial` composition.

        Args:
          batching_fn: callable applied to each per-block kernel function
            before it is used. When it equals `batch._parallel`, the
            `x1_is_x2` field of the reference output is overwritten before
            comparison (see the comment in `assert_same_kernel`).
        """

        def assert_same_kernel(ker_out, composed_ker_out):
            # Shared comparison step for all four checks below.
            if batching_fn == batch._parallel:
                # In the parallel setting, `x1_is_x2` is not computed
                # correctly when x1==x2.
                composed_ker_out = composed_ker_out.replace(
                    x1_is_x2=ker_out.x1_is_x2)
            self.assertAllClose(ker_out, composed_ker_out)

        rng = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        keys = tf_random_split(rng)
        rng_self = keys[0]
        rng_other = keys[1]

        # Check Fully-Connected.
        x_self = np.asarray(normal((8, 10), seed=rng_self))
        x_other = np.asarray(normal((2, 10), seed=rng_other))
        Block = stax.serial(stax.Dense(256), stax.Relu())

        _, _, ker_fn = Block
        ker_fn = batching_fn(ker_fn)

        _, _, composed_ker_fn = stax.serial(Block, Block)

        assert_same_kernel(ker_fn(ker_fn(x_self)), composed_ker_fn(x_self))
        assert_same_kernel(ker_fn(ker_fn(x_self, x_other)),
                           composed_ker_fn(x_self, x_other))

        # Check convolutional + pooling.
        # Draw the two inputs from the two split keys rather than from the
        # already-split parent `rng`, so that `x_self` and `x_other` do not
        # share a seed.
        x_self = np.asarray(normal((8, 10, 10, 3), seed=rng_self))
        x_other = np.asarray(normal((2, 10, 10, 3), seed=rng_other))

        Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
        Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))

        block_ker_fn, readout_ker_fn = Block[2], Readout[2]
        _, _, composed_ker_fn = stax.serial(Block, Readout)
        block_ker_fn = batching_fn(block_ker_fn)
        readout_ker_fn = batching_fn(readout_ker_fn)

        assert_same_kernel(readout_ker_fn(block_ker_fn(x_self)),
                           composed_ker_fn(x_self))
        assert_same_kernel(readout_ker_fn(block_ker_fn(x_self, x_other)),
                           composed_ker_fn(x_self, x_other))
コード例 #4
0
    def testPredictOnCPU(self):
        """Check predictions of a batched kernel and where they are stored.

        For every combination of `store_on_device`, `device_count`, `get`
        and test input (absent or `x_test`), the ensemble predictor built on
        `batch.batch(kernel_fn, 2, device_count, store_on_device)` is
        evaluated at `t=None` and `t=np.inf`; both results must be close.
        When test inputs are given, `utils.is_on_cpu` of both results must
        equal "not `store_on_device`, or the backend platform is 'cpu'".
        """

        def make_key():
            # All three keys are built from the same literal seed.
            return stateless_uniform(shape=[2],
                                     seed=[1, 1],
                                     minval=None,
                                     maxval=None,
                                     dtype=tf.int32)

        key1 = make_key()
        key2 = make_key()
        key3 = make_key()
        x_train = np.asarray(normal((4, 4, 4, 2), seed=key1))
        x_test = np.asarray(normal((8, 4, 4, 2), seed=key2))

        y_train = np.asarray(stateless_uniform(shape=(4, 2), seed=key3))

        _, _, kernel_fn = stax.serial(stax.Conv(1, (3, 3)), stax.Relu(),
                                      stax.Flatten(), stax.Dense(1))

        for store_on_device in [False, True]:
            for device_count in [0, 1]:
                for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
                    for x in [None, 'x_test']:
                        with self.subTest(store_on_device=store_on_device,
                                          device_count=device_count,
                                          get=get,
                                          x=x):
                            batched = batch.batch(kernel_fn, 2, device_count,
                                                  store_on_device)
                            predictor = predict.gradient_descent_mse_ensemble(
                                batched, x_train, y_train)

                            # The string 'x_test' is only a readable subTest
                            # label; swap in the actual array here.
                            x_in = None if x is None else x_test
                            predict_none = predictor(None,
                                                     x_in,
                                                     get,
                                                     compute_cov=True)
                            predict_inf = predictor(np.inf,
                                                    x_in,
                                                    get,
                                                    compute_cov=True)
                            self.assertAllClose(predict_none, predict_inf)

                            if x_in is not None:
                                backend = xla_bridge.get_backend()
                                on_cpu = (not store_on_device
                                          or backend.platform == 'cpu')
                                for result in (predict_inf, predict_none):
                                    self.assertEqual(on_cpu,
                                                     utils.is_on_cpu(result))
コード例 #5
0
    def testAxes(self, diagonal_axes, trace_axes):
        """Check kernel agreement and output rank for given axes settings.

        Skips the test when a canonicalized diagonal axis coincides with a
        canonicalized trace axis. Otherwise the `implicit`, `direct` and
        `nngp` kernels from `KERNELS['empirical_logits_3']` are evaluated;
        `implicit` must be close to `direct`, `nngp` must have the same shape
        as `implicit`, and its rank must be
        `2 * (data_self.ndim - len(trace axes)) - len(diagonal axes)`.
        """
        seed = stateless_uniform(shape=[2],
                                 seed=[0, 0],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
        parts = tf_random_split(seed=tf.convert_to_tensor(seed,
                                                          dtype=tf.int32),
                                num=3)
        data_self = np.asarray(normal((4, 5, 6, 3), seed=parts[1]))
        data_other = np.asarray(normal((2, 5, 6, 3), seed=parts[2]))

        _diagonal_axes = utils.canonicalize_axis(diagonal_axes, data_self)
        _trace_axes = utils.canonicalize_axis(trace_axes, data_self)

        if any(d in _trace_axes for d in _diagonal_axes):
            raise absltest.SkipTest(
                'diagonal axes must be different from channel axes.')

        implicit, direct, nngp = KERNELS['empirical_logits_3'](
            parts[0], (5, 6, 3),
            CONV,
            diagonal_axes=diagonal_axes,
            trace_axes=trace_axes)

        expected_ndim = (2 * (data_self.ndim - len(_trace_axes)) -
                         len(_diagonal_axes))

        def check(x1, x2):
            g = implicit(x1, x2)
            g_direct = direct(x1, x2)
            g_nngp = nngp(x1, x2)

            self.assertAllClose(g, g_direct)
            self.assertEqual(g_nngp.shape, g.shape)
            self.assertEqual(expected_ndim, g_nngp.ndim)

        check(data_self, None)

        # The two-input check only runs when axis 0 is neither a trace axis
        # nor a diagonal axis.
        if not (0 in _trace_axes or 0 in _diagonal_axes):
            check(data_other, data_self)
コード例 #6
0
    def testMaxLearningRate(self, train_shape, network, out_logits,
                            fn_and_kernel):
        """Check SGD behaviour below and above `predict.max_learning_rate`.

        For a step size of 0.5x the value returned by
        `predict.max_learning_rate`, the loss after 20 SGD steps must be less
        than 0.1x the initial loss. For 3x, the loss ratio must exceed 10
        unless it is NaN.

        Args:
          train_shape: shape of the training inputs; replaced below by a
            larger shape (2-D case) or by a fixed `(16, 8, 8, 3)`.
          network: forwarded to `fn_and_kernel`.
          out_logits: number of output columns of `y_train`; also forwarded
            to `fn_and_kernel`.
          fn_and_kernel: callable returning `(params, f, ntk)`.
        """

        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)

        keys = tf_random_split(key)
        key = keys[0]
        split = keys[1]
        if len(train_shape) == 2:
            train_shape = (train_shape[0] * 5, train_shape[1] * 10)
        else:
            train_shape = (16, 8, 8, 3)
        x_train = np.asarray(normal(train_shape, seed=split))

        keys = tf_random_split(key)
        key = keys[0]
        split = keys[1]
        # Binary targets: uniform draws (minval=0, maxval=1) thresholded at
        # 0.5 and converted to float32.
        y_train = np.asarray(
            stateless_uniform(shape=(train_shape[0], out_logits),
                              seed=split,
                              minval=0,
                              maxval=1) < 0.5, np.float32)
        # Regress to an MSE loss.
        # NOTE: `f` is not defined yet; it is assigned inside the loop below
        # and looked up when the lambda is called.
        loss = lambda params, x: 0.5 * np.mean((f(params, x) - y_train)**2)
        grad_loss = jit(grad(loss))

        def get_loss(opt_state):
            # `get_params` is likewise assigned inside the loop below.
            return loss(get_params(opt_state), x_train)

        steps = 20

        for lr_factor in [0.5, 3.]:
            # The same `key` is used on both iterations.
            params, f, ntk = fn_and_kernel(key, train_shape[1:], network,
                                           out_logits)
            g_dd = ntk(x_train, None, 'ntk')

            step_size = predict.max_learning_rate(
                g_dd, y_train_size=y_train.size) * lr_factor
            opt_init, opt_update, get_params = optimizers.sgd(step_size)
            opt_state = opt_init(params)

            init_loss = get_loss(opt_state)

            for i in range(steps):
                params = get_params(opt_state)
                opt_state = opt_update(i, grad_loss(params, x_train),
                                       opt_state)

            trained_loss = get_loss(opt_state)
            # The small constant keeps the ratio finite if `init_loss` is 0.
            loss_ratio = trained_loss / (init_loss + 1e-12)
            if lr_factor == 3.:
                # A NaN ratio is accepted for the too-large step size.
                if not math.isnan(loss_ratio):
                    self.assertGreater(loss_ratio, 10.)
            else:
                self.assertLess(loss_ratio, 0.1)
コード例 #7
0
    def testGradientDescentMseEnsembleTrain(self):
        """Compare train-set predictions with and without explicit inputs.

        The ensemble predictor is called once with `x_test=None` and once
        with the training inputs passed as test inputs; the two results
        must agree (tolerance argument 0.04) for `t=None` and for an array
        of times.
        """
        rng = stateless_uniform(shape=[2],
                                seed=[1, 1],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        inputs = np.asarray(normal((8, 4, 6, 3), seed=rng))
        _, _, kernel_fn = stax.serial(stax.Conv(1, (2, 2)), stax.Relu(),
                                      stax.Conv(1, (2, 1)))
        # The targets are drawn with the same seed as the inputs.
        targets = np.asarray(normal((8, 2, 5, 1), seed=rng))
        predictor = predict.gradient_descent_mse_ensemble(
            kernel_fn, inputs, targets)

        time_settings = (None, np.array([0., 1., 10.]))
        for t in time_settings:
            with self.subTest(t=t):
                implicit_train = predictor(t, None, None, compute_cov=True)
                explicit_train = predictor(t, inputs, None, compute_cov=True)
                self._assertAllClose(implicit_train, explicit_train, 0.04)
コード例 #8
0
    def testSerial(self, train_shape, test_shape, network, name, kernel_fn,
                   batch_size):
        """Run `_test_kernel_against_batched` on a `batch._serial` wrapper.

        The kernel built by the `kernel_fn` factory and its
        `batch._serial(..., batch_size=batch_size)` wrapper are passed to
        `_test_kernel_against_batched` along with two random inputs.
        """
        base_key = stateless_uniform(shape=[2],
                                     seed=[0, 0],
                                     minval=None,
                                     maxval=None,
                                     dtype=tf.int32)
        subkeys = tf_random_split(base_key, 3)
        net_key, self_key, other_key = (subkeys[i] for i in range(3))

        x_self = np.asarray(normal(train_shape, seed=self_key))
        x_other = np.asarray(normal(test_shape, seed=other_key))

        kernel = kernel_fn(net_key, train_shape[1:], network)
        serial_kernel = batch._serial(kernel, batch_size=batch_size)

        _test_kernel_against_batched(self, kernel, serial_kernel, x_self,
                                     x_other)
コード例 #9
0
    def testLinearization(self, shape):
        """Compare `empirical.linearize` with `EmpiricalTest.f_lin_exact`.

        Random parameters `(w1, w2, b)` and a base point `x0` are drawn, `w1`
        is symmetrized, and for `TAYLOR_RANDOM_SAMPLES` rounds and every
        combination of `do_alter` / `do_shift_x` a fresh random `x` is drawn
        and both functions are required to agree on it.
        """

        def next_keys(parent, num):
            # Split `parent` into `num` keys and return them as a tuple.
            parts = tf_random_split(seed=tf.convert_to_tensor(
                parent, dtype=tf.int32),
                                    num=num)
            return tuple(parts[i] for i in range(num))

        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        key, s1, s2, s3 = next_keys(key, 4)
        w1 = np.asarray(normal(shape, seed=s1))
        w1 = 0.5 * (w1 + w1.T)
        w2 = np.asarray(normal(shape, seed=s2))
        b = np.asarray(normal((shape[-1], ), seed=s3))
        params = (w1, w2, b)

        key, x0_key = next_keys(key, 2)
        x0 = np.asarray(normal((shape[-1], ), seed=x0_key))

        f_lin = empirical.linearize(EmpiricalTest.f, x0)

        flags = [True, False]
        for _ in range(TAYLOR_RANDOM_SAMPLES):
            for do_alter in flags:
                for do_shift_x in flags:
                    key, x_key = next_keys(key, 2)
                    x = np.asarray(normal((shape[-1], ), seed=x_key))
                    expected = EmpiricalTest.f_lin_exact(
                        x0, x, params, do_alter, do_shift_x=do_shift_x)
                    actual = f_lin(x,
                                   params,
                                   do_alter,
                                   do_shift_x=do_shift_x)
                    self.assertAllClose(expected, actual)
コード例 #10
0
    def testParallel(self, train_shape, test_shape, network, name, kernel_fn):
        """Run `_test_kernel_against_batched` on a `batch._parallel` wrapper.

        The kernel is built by the `kernel_fn` factory with
        `use_dropout=False`, wrapped with `batch._parallel`, and both are
        handed to `_test_kernel_against_batched` with a trailing `True`
        argument.
        """
        test_utils.stub_out_pmap(batch, 2)
        base_key = stateless_uniform(shape=[2],
                                     seed=[0, 0],
                                     minval=None,
                                     maxval=None,
                                     dtype=tf.int32)
        subkeys = tf_random_split(base_key, 3)
        net_key, self_key, other_key = (subkeys[i] for i in range(3))

        x_self = np.asarray(normal(train_shape, seed=self_key))
        x_other = np.asarray(normal(test_shape, seed=other_key))

        kernel = kernel_fn(net_key,
                           train_shape[1:],
                           network,
                           use_dropout=False)
        parallel_kernel = batch._parallel(kernel)

        _test_kernel_against_batched(self, kernel, parallel_kernel, x_self,
                                     x_other, True)
コード例 #11
0
def _get_inputs_and_model(width=1, n_classes=2, use_conv=True):
    """Build two random input batches and a small `stax` model.

    Args:
      width: first argument of the first layer (`stax.Conv` or `stax.Dense`).
      n_classes: first argument of the final `stax.Dense` layer.
      use_conv: if truthy, the first layer is `stax.Conv(width, (3, 3))`;
        otherwise the inputs are flattened to 2-D and the first layer is
        `stax.Dense(width)`.

    Returns:
      Tuple `(x1, x2, init_fn, apply_fn, kernel_fn, key)`. Note that `key` is
      the same key that was used as the seed for `x1`.
    """
    root = stateless_uniform(shape=[2],
                             seed=[1, 1],
                             minval=None,
                             maxval=None,
                             dtype=tf.int32)
    subkeys = tf_random_split(root)
    key, x2_key = subkeys[0], subkeys[1]
    x1 = np.asarray(normal((8, 4, 3, 2), seed=key))
    x2 = np.asarray(normal((4, 4, 3, 2), seed=x2_key))

    if use_conv:
        first_layer = stax.Conv(width, (3, 3))
    else:
        # Dense layers take 2-D inputs: keep the batch axis, flatten the rest.
        x1 = np.reshape(x1, (x1.shape[0], -1))
        x2 = np.reshape(x2, (x2.shape[0], -1))
        first_layer = stax.Dense(width)

    init_fn, apply_fn, kernel_fn = stax.serial(first_layer, stax.Relu(),
                                               stax.Flatten(),
                                               stax.Dense(n_classes, 2., 0.5))
    return x1, x2, init_fn, apply_fn, kernel_fn, key
コード例 #12
0
 def _get_inputs(cls, out_logits, test_shape, train_shape):
     """Draw random train inputs, binary train targets and test inputs.

     Args:
       out_logits: number of columns of `y_train`.
       test_shape: shape of `x_test`.
       train_shape: shape of `x_train`; its first entry is also the number
         of rows of `y_train`.

     Returns:
       Tuple `(key, x_test, x_train, y_train)`, where `key` is the key left
       over after three successive splits.
     """

     def advance(parent):
         # One split step: returns (carried key, key to consume now).
         parts = tf_random_split(parent)
         return parts[0], parts[1]

     key = stateless_uniform(shape=[2],
                             seed=[0, 0],
                             minval=None,
                             maxval=None,
                             dtype=tf.int32)

     key, draw_key = advance(key)
     x_train = np.asarray(normal(train_shape, seed=draw_key))

     key, draw_key = advance(key)
     # Uniform draws (minval=0, maxval=1) thresholded at 0.5, as float32.
     coin_flips = stateless_uniform(shape=(train_shape[0], out_logits),
                                    seed=draw_key,
                                    minval=0,
                                    maxval=1) < 0.5
     y_train = np.asarray(coin_flips, np.float32)

     key, draw_key = advance(key)
     x_test = np.asarray(normal(test_shape, seed=draw_key))
     return key, x_test, x_train, y_train
コード例 #13
0
    def testPredictND(self):
        """Check prediction output shapes for multi-dimensional outputs.

        For every combination of `trace_axes`, time array `ts` (absent or a
        `(2, 3)` array) and test input (absent or random), the outputs of
        `predict.gradient_descent_mse`,
        `predict.gradient_descent_mse_ensemble` and, when `ts` is given,
        `predict.gradient_descent` are required to have shape
        `ts.shape + y_train.shape` on the training set and
        `ts.shape + (n_test,) + y_train.shape[1:]` on the test set.
        Only shapes are asserted, not values.
        """
        n_chan = 6
        key = stateless_uniform(shape=[2],
                                seed=[1, 1],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        im_shape = (5, 4, 3)
        n_train = 2
        n_test = 2
        x_train = np.asarray(normal((n_train, ) + im_shape, seed=key))
        y_train = stateless_uniform(shape=(n_train, 3, 2, n_chan), seed=key)
        init_fn, apply_fn, _ = stax.Conv(n_chan, (3, 2), (1, 2))
        _, params = init_fn(key, x_train.shape)
        fx_train_0 = apply_fn(params, x_train)

        def loss(x, y):
            # Half mean-squared error. The square must be applied to the
            # residuals before averaging; `np.mean(x - y)**2` would instead
            # square the mean residual.
            return 0.5 * np.mean((x - y)**2)

        for trace_axes in [(), (-1, ), (-2, ), (-3, ), (0, 1), (2, 3), (2, ),
                           (1, 3), (0, -1), (0, 0, -3), (0, 1, 2, 3),
                           (0, 1, -1, 2)]:
            for ts in [None, np.arange(6).reshape((2, 3))]:
                for x in [None, 'x_test']:
                    with self.subTest(trace_axes=trace_axes, ts=ts, x=x):
                        t_shape = ts.shape if ts is not None else ()
                        y_test_shape = t_shape + (n_test, ) + y_train.shape[1:]
                        y_train_shape = t_shape + y_train.shape
                        # The string 'x_test' is only a subTest label; swap
                        # in actual random test inputs here.
                        x = x if x is None else np.asarray(
                            normal((n_test, ) + im_shape, seed=key))
                        fx_test_0 = None if x is None else apply_fn(params, x)

                        kernel_fn = empirical.empirical_kernel_fn(
                            apply_fn, trace_axes=trace_axes)

                        # TODO(romann): investigate the SIGTERM error on CPU.
                        # kernel_fn = jit(kernel_fn, static_argnums=(2,))
                        ntk_train_train = kernel_fn(x_train, None, 'ntk',
                                                    params)
                        if x is not None:
                            ntk_test_train = kernel_fn(x, x_train, 'ntk',
                                                       params)

                        predict_fn_mse = predict.gradient_descent_mse(
                            ntk_train_train, y_train, trace_axes=trace_axes)

                        predict_fn_mse_ensemble = predict.gradient_descent_mse_ensemble(
                            kernel_fn,
                            x_train,
                            y_train,
                            trace_axes=trace_axes,
                            params=params)

                        if x is None:
                            p_train_mse = predict_fn_mse(ts, fx_train_0)
                        else:
                            p_train_mse, p_test_mse = predict_fn_mse(
                                ts, fx_train_0, fx_test_0, ntk_test_train)
                            self.assertAllClose(y_test_shape, p_test_mse.shape)
                        self.assertAllClose(y_train_shape, p_train_mse.shape)

                        p_nngp_mse_ens, p_ntk_mse_ens = predict_fn_mse_ensemble(
                            ts, x, ('nngp', 'ntk'), compute_cov=True)
                        ref_shape = y_train_shape if x is None else y_test_shape
                        self.assertAllClose(ref_shape,
                                            p_ntk_mse_ens.mean.shape)
                        self.assertAllClose(ref_shape,
                                            p_nngp_mse_ens.mean.shape)

                        # The generic-loss predictor is only exercised when
                        # explicit times are given.
                        if ts is not None:
                            predict_fn = predict.gradient_descent(
                                loss,
                                ntk_train_train,
                                y_train,
                                trace_axes=trace_axes)

                            if x is None:
                                p_train = predict_fn(ts, fx_train_0)
                            else:
                                p_train, p_test = predict_fn(
                                    ts, fx_train_0, fx_test_0, ntk_test_train)
                                self.assertAllClose(y_test_shape, p_test.shape)
                            self.assertAllClose(y_train_shape, p_train.shape)
コード例 #14
0
    def testGpInference(self):
        """Compare `predict.gp_inference` with the MSE ensemble predictor.

        For an analytic and an empirical kernel function, every `get`
        setting, with and without test inputs, and with and without
        covariance, the ensemble prediction at `t=None` must match the one at
        `t=np.inf` (tolerance argument 0.08) and the output of
        `gp_inference`. The one exception is the combination in which `get`
        does not contain 'nngp' while covariance and test inputs are
        requested; there `gp_inference` is expected to raise `ValueError`.
        """
        reg = 1e-5
        key = stateless_uniform(shape=[2],
                                seed=[1, 1],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        x_train = np.asarray(normal((4, 2), seed=key))
        init_fn, apply_fn, kernel_fn_analytic = stax.serial(
            stax.Dense(32, 2., 0.5), stax.Relu(), stax.Dense(10, 2., 0.5))
        # `y_train` (and `x_test` below) are drawn with the same `key`.
        y_train = np.asarray(normal((4, 10), seed=key))
        for kernel_fn_is_analytic in [True, False]:
            if kernel_fn_is_analytic:
                kernel_fn = kernel_fn_analytic
            else:
                _, params = init_fn(key, x_train.shape)
                kernel_fn_empirical = empirical.empirical_kernel_fn(apply_fn)

                # Adapter giving the empirical kernel the same
                # `(x1, x2, get)` signature as the analytic one by binding
                # `params`.
                def kernel_fn(x1, x2, get):
                    return kernel_fn_empirical(x1, x2, get, params)

            for get in [
                    None, 'nngp', 'ntk', ('nngp', ), ('ntk', ),
                ('nngp', 'ntk'), ('ntk', 'nngp')
            ]:
                k_dd = kernel_fn(x_train, None, get)

                gp_inference = predict.gp_inference(k_dd,
                                                    y_train,
                                                    diag_reg=reg)
                gd_ensemble = predict.gradient_descent_mse_ensemble(
                    kernel_fn, x_train, y_train, diag_reg=reg)
                for x_test in [None, 'x_test']:
                    # The string 'x_test' is only a marker; it is replaced
                    # by actual random test inputs here.
                    x_test = None if x_test is None else np.asarray(
                        normal((8, 2), seed=key))
                    k_td = None if x_test is None else kernel_fn(
                        x_test, x_train, get)

                    for compute_cov in [True, False]:
                        with self.subTest(
                                kernel_fn_is_analytic=kernel_fn_is_analytic,
                                get=get,
                                x_test=x_test if x_test is None else 'x_test',
                                compute_cov=compute_cov):
                            # NOTE(review): `True` is passed as
                            # `nngp_test_test` when covariance is requested
                            # without test inputs; presumably a flag
                            # understood by `gp_inference` - confirm.
                            if compute_cov:
                                nngp_tt = (True if x_test is None else
                                           kernel_fn(x_test, None, 'nngp'))
                            else:
                                nngp_tt = None

                            out_ens = gd_ensemble(None, x_test, get,
                                                  compute_cov)
                            out_ens_inf = gd_ensemble(np.inf, x_test, get,
                                                      compute_cov)
                            self._assertAllClose(out_ens_inf, out_ens, 0.08)

                            # `'nngp' not in get` covers both the string and
                            # the tuple forms of `get`.
                            if (get is not None and 'nngp' not in get
                                    and compute_cov and k_td is not None):
                                with self.assertRaises(ValueError):
                                    out_gp_inf = gp_inference(
                                        get=get,
                                        k_test_train=k_td,
                                        nngp_test_test=nngp_tt)
                            else:
                                out_gp_inf = gp_inference(
                                    get=get,
                                    k_test_train=k_td,
                                    nngp_test_test=nngp_tt)
                                self.assertAllClose(out_ens, out_gp_inf)
コード例 #15
0
    def testTaylorExpansion(self, shape):
        """Compare `empirical.taylor_expand` with closed-form expansions.

        The first-order expansion is compared with
        `EmpiricalTest.f_lin_exact`, the second-order expansion with the
        local `f_2_exact`, over `TAYLOR_RANDOM_SAMPLES` rounds and every
        combination of `do_alter` / `do_shift_x`, each with a fresh random
        `x`.
        """

        def f_2_exact(x0, x, params, do_alter, do_shift_x=True):
            # Second-order reference: the linear reference plus a quadratic
            # term in `w1`. Only `w1` enters the quadratic term, so only it
            # is unpacked and altered. The adjustment is done out of place:
            # an augmented assignment (`w1 += 5.`) would modify the caller's
            # `params` in place if the arrays are mutable.
            w1, _, _ = params
            f_lin = EmpiricalTest.f_lin_exact(x0, x, params, do_alter,
                                              do_shift_x)
            if do_shift_x:
                x0 = x0 * 2 + 1.
                x = x * 2 + 1.
            if do_alter:
                w1 = w1 + 5.
            dx = x - x0
            return f_lin + 0.5 * np.dot(np.dot(dx.T, w1), dx)

        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        splits = tf_random_split(seed=tf.convert_to_tensor(key,
                                                           dtype=tf.int32),
                                 num=4)
        key = splits[0]
        s1 = splits[1]
        s2 = splits[2]
        s3 = splits[3]
        w1 = np.asarray(normal(shape, seed=s1))
        # Symmetrize `w1`.
        w1 = 0.5 * (w1 + w1.T)
        w2 = np.asarray(normal(shape, seed=s2))
        b = np.asarray(normal((shape[-1], ), seed=s3))
        params = (w1, w2, b)

        splits = tf_random_split(seed=tf.convert_to_tensor(key,
                                                           dtype=tf.int32),
                                 num=2)
        key = splits[0]
        split = splits[1]
        x0 = np.asarray(normal((shape[-1], ), seed=split))

        f_lin = empirical.taylor_expand(EmpiricalTest.f, x0, 1)
        f_2 = empirical.taylor_expand(EmpiricalTest.f, x0, 2)

        for _ in range(TAYLOR_RANDOM_SAMPLES):
            for do_alter in [True, False]:
                for do_shift_x in [True, False]:
                    # Advance the key so every evaluation point is new.
                    splits = tf_random_split(seed=tf.convert_to_tensor(
                        key, dtype=tf.int32),
                                             num=2)
                    key = splits[0]
                    split = splits[1]
                    x = np.asarray(normal((shape[-1], ), seed=split))
                    self.assertAllClose(
                        EmpiricalTest.f_lin_exact(x0,
                                                  x,
                                                  params,
                                                  do_alter,
                                                  do_shift_x=do_shift_x),
                        f_lin(x, params, do_alter, do_shift_x=do_shift_x))
                    self.assertAllClose(
                        f_2_exact(x0,
                                  x,
                                  params,
                                  do_alter,
                                  do_shift_x=do_shift_x),
                        f_2(x, params, do_alter, do_shift_x=do_shift_x))