Esempio n. 1
0
 def test_add(self):
     """Adding two combine() results concatenates their combinations in order."""
     first = test_combinations.combine(a=[1, 2])
     second = test_combinations.combine(b=[2, 3])
     expected = [{"a": 1}, {"a": 2}, {"b": 2}, {"b": 3}]
     self.assertEqual(expected, first + second)
Esempio n. 2
0
def test_graph_mode_only(test_class_or_method=None):
    """Decorator for ensuring tests run in graph mode.

    Must be applied to subclasses of `parameterized.TestCase` (from
    absl/testing), or a method of such a subclass.

    When applied to a test method, this decorator results in the replacement of
    that method with one new test method, executed in graph mode.

    When applied to a test class, all the methods in the class are affected.

    Args:
      test_class_or_method: the `TestCase` class or method to decorate, or
        `None` to obtain the bare decorator.

    Returns:
      decorator: A generated TF `test_combinations` decorator, or if
      `test_class_or_method` is not `None`, the generated decorator applied to
      that class or method.

    Raises:
      SkipTest: Raised when not running in the TF backend (`JAX_MODE` or
        `NUMPY_MODE` is set). It is raised as soon as this function is called,
        i.e. at decoration time, before any decorator is built.
    """
    # Graph mode only exists on the TF backend; skip before building anything.
    if JAX_MODE or NUMPY_MODE:
        raise unittest.SkipTest(
            'Ignoring TF Graph Mode tests in non-TF backends.')

    decorator = test_combinations.generate(
        test_combinations.combine(mode=['graph']),
        test_combinations=[EagerGraphCombination()])

    # Compare against None explicitly, matching the documented contract.
    if test_class_or_method is not None:
        return decorator(test_class_or_method)
    return decorator
Esempio n. 3
0
def test_graph_and_eager_modes(test_class_or_method=None):
    """Decorator for generating graph and eager mode tests from a single test.

    Must be applied to subclasses of `parameterized.TestCase` (from
    absl/testing), or a method of such a subclass.

    When applied to a test method, this decorator results in the replacement of
    that method with two new test methods, one executed in graph mode and the
    other in eager mode.

    When applied to a test class, all the methods in the class are affected.

    Args:
      test_class_or_method: the `TestCase` class or method to decorate, or
        `None` to obtain the bare decorator.

    Returns:
      decorator: A generated TF `test_combinations` decorator, or if
      `test_class_or_method` is not `None`, the generated decorator applied to
      that class or method.
    """
    decorator = test_combinations.generate(
        test_combinations.combine(mode=['graph', 'eager']),
        test_combinations=[EagerGraphCombination()])

    # Compare against None explicitly, matching the documented contract.
    if test_class_or_method is not None:
        return decorator(test_class_or_method)
    return decorator
Esempio n. 4
0
 def test_arguments_sorted(self):
     """combine() orders keys by name regardless of keyword-argument order."""
     combos = test_combinations.combine(ab=[2, 3], aa=[1, 2])
     expected = [
         OrderedDict([("aa", aa_value), ("ab", ab_value)])
         for aa_value in (1, 2)
         for ab_value in (2, 3)
     ]
     self.assertEqual(expected, combos)
Esempio n. 5
0
 def test_combine_single_parameter(self):
     """A scalar argument to combine() stays fixed across every combination."""
     expected = [dict(a=a_value, b=2) for a_value in (1, 2)]
     self.assertEqual(expected,
                      test_combinations.combine(a=[1, 2], b=2))
Esempio n. 6
0
def test_all_tf_execution_regimes(test_class_or_method=None):
  """Decorator for generating a collection of tests in various contexts.

  Must be applied to subclasses of `parameterized.TestCase` (from
  `absl/testing`), or a method of such a subclass.

  When applied to a test method, this decorator results in the replacement of
  that method with a collection of new test methods, each executed under a
  different set of context managers that control some aspect of the execution
  model. This decorator generates three test scenario combinations:

    1. Eager mode with `tf.function` decorations enabled
    2. Eager mode with `tf.function` decorations disabled
    3. Graph mode (everything)

  When applied to a test class, all the methods in the class are affected.

  Args:
    test_class_or_method: the `TestCase` class or method to decorate, or
      `None` to obtain the bare decorator.

  Returns:
    decorator: A generated TF `test_combinations` decorator, or if
    `test_class_or_method` is not `None`, the generated decorator applied to
    that class or method.
  """
  # Graph mode is only paired with tf_function='' (functions left enabled);
  # eager mode is exercised both with and without tf.function.
  decorator = test_combinations.generate(
      (test_combinations.combine(mode='graph',
                                 tf_function='') +
       test_combinations.combine(
           mode='eager', tf_function=['', 'no_tf_function'])),
      test_combinations=[
          EagerGraphCombination(),
          ExecuteFunctionsEagerlyCombination(),
      ])

  # Compare against None explicitly, matching the documented contract.
  if test_class_or_method is not None:
    return decorator(test_class_or_method)
  return decorator
Esempio n. 7
0
 def test_combine(self):
     """combine() yields the cartesian product of its list arguments."""
     expected = [{"a": a_value, "b": b_value}
                 for a_value in (1, 2)
                 for b_value in (2, 3)]
     actual = test_combinations.combine(a=[1, 2], b=[2, 3])
     self.assertEqual(expected, actual)
Esempio n. 8
0
class TestCombinationsTest(test_util.TestCase):
    """Tests for generated test names and execution-mode combinations."""

    def _assert_has_test_names(self, test_class, expected_test_names):
        """Asserts every name in `expected_test_names` is on `test_class`.

        All missing names are reported together in a single failure.
        """
        # Build the attribute listing once instead of calling dir() per name.
        available_names = set(dir(test_class))
        missing_names = [
            name for name in expected_test_names if name not in available_names
        ]
        self.assertEqual(
            [], missing_names,
            msg='{} is missing generated tests'.format(test_class.__name__))

    #
    # These tests check that the generated names are as expected.
    #
    def test_generated_test_case_names(self):
        self._assert_has_test_names(PretendTestCaseClass, [
            'test_snake_case_name_eager_no_tf_function',
            'test_snake_case_name_eager',
            'test_snake_case_name_graph',
            'testCamelCaseName_eager_no_tf_function',
            'testCamelCaseName_eager',
            'testCamelCaseName_graph',
        ])

    def test_generated_parameterized_test_case_names(self):
        self._assert_has_test_names(PretendParameterizedTestCaseClass, [
            'test_snake_case_name_p123_eager_no_tf_function',
            'test_snake_case_name_p123_eager',
            'test_snake_case_name_p123_graph',
            'testCamelCaseNamep123_eager_no_tf_function',
            'testCamelCaseNamep123_eager',
            'testCamelCaseNamep123_graph',
        ])

    def test_generated_graph_and_eager_test_case_names(self):
        self._assert_has_test_names(PretendTestCaseClassGraphAndEagerOnly, [
            'test_snake_case_name_eager',
            'test_snake_case_name_graph',
            'testCamelCaseName_eager',
            'testCamelCaseName_graph',
        ])

    #
    # These tests ensure that the test generators do what they say on the tin.
    #
    @test_combinations.generate(
        test_combinations.combine(mode='graph'),
        test_combinations=[test_util.EagerGraphCombination()])
    def test_graph_mode_combination(self):
        self.assertFalse(context.executing_eagerly())

    @test_combinations.generate(
        test_combinations.combine(mode='eager'),
        test_combinations=[test_util.EagerGraphCombination()])
    def test_eager_mode_combination(self):
        self.assertTrue(context.executing_eagerly())

    @test_combinations.generate(
        test_combinations.combine(tf_function=''),
        test_combinations=[test_util.ExecuteFunctionsEagerlyCombination()])
    def test_tf_function_enabled_mode_combination(self):
        self.assertFalse(tf.config.experimental_functions_run_eagerly())

    @test_combinations.generate(
        test_combinations.combine(tf_function='no_tf_function'),
        test_combinations=[test_util.ExecuteFunctionsEagerlyCombination()])
    def test_tf_function_disabled_mode_combination(self):
        self.assertTrue(tf.config.experimental_functions_run_eagerly())
Esempio n. 9
0
    def test_add(self):
        """`+` on combine() results concatenates their combinations in order."""
        expected = [{"a": v} for v in (1, 2)] + [{"b": v} for v in (2, 3)]
        combined = (test_combinations.combine(a=[1, 2]) +
                    test_combinations.combine(b=[2, 3]))
        self.assertEqual(expected, combined)


@test_combinations.generate(
    test_combinations.combine(a=[1, 0], b=[2, 3], c=[1]))
class CombineTheTestSuite(test_util.TestCase):
    """Checks that `generate` applied to a class parameterizes its tests.

    Every combination has a in {1, 0}, b in {2, 3} and c == 1, so
    `a + b + c` always lies in [3, 5]. `not_a_test` and `_test_but_private`
    are not meant to be run; each calls `self.fail()` so an accidental
    invocation is reported as a failure.
    """

    def test_add_things(self, a, b, c):
        self.assertLessEqual(3, a + b + c)
        self.assertLessEqual(a + b + c, 5)

    def test_add_things_one_more(self, a, b, c):
        self.assertLessEqual(3, a + b + c)
        self.assertLessEqual(a + b + c, 5)

    def not_a_test(self, a=0, b=0, c=0):
        # Not named `test*`: must never be collected or parameterized.
        del a, b, c
        self.fail()

    def _test_but_private(self, a=0, b=0, c=0):
        # Private methods must be left alone too; fail loudly if one ever runs.
        del a, b, c
        self.fail()
class MVNPrecisionFactorLinOpTest(test_util.TestCase):
    """Tests for `tfd_e.MultivariateNormalPrecisionFactorLinearOperator`.

    Most checks compare against a `tfd.MultivariateNormalTriL` (or the raw
    covariance operator) built from the same randomly drawn, then evaluated,
    covariance.
    """
    def _random_constant_spd_linop(
            self,
            event_size,
            batch_shape=(),
            conditioning=1.2,
            dtype=np.float32,
    ):
        """Randomly generate a constant SPD LinearOperator.

        The matrix is sampled from a Wishart distribution with identity scale
        and `df = conditioning * event_size`, then evaluated to a concrete
        array so that repeated use sees the same values.

        Args:
          event_size: Python `int`, number of rows/columns of each matrix.
          batch_shape: Python `tuple`, passed as the Wishart sample shape; the
            result is presumably `batch_shape + (event_size, event_size)`.
          conditioning: Python `float` >= 1; larger is better posed (see the
            comments below).
          dtype: NumPy dtype of the generated matrix.

        Returns:
          A `tf.linalg.LinearOperatorFullMatrix` flagged as positive definite
          and self-adjoint.

        Raises:
          AssertionError: If `conditioning < 1` (only when asserts are enabled).
        """
        # The larger conditioning is, the better posed the matrix is.
        # With conditioning = 1, it will be on the edge of singular, and likely
        # numerically singular if event_size is large enough.
        # Conditioning on the small side is best, since then the matrix is not so
        # diagonally dominant, and we therefore test use of transpositions better.
        assert conditioning >= 1

        scale_wishart = tfd.WishartLinearOperator(
            df=dtype(conditioning * event_size),
            scale=tf.linalg.LinearOperatorIdentity(event_size, dtype=dtype),
            input_output_cholesky=False,
        )
        # Make sure to evaluate here. This ensures that the linear operator is a
        # constant rather than a random operator.
        matrix = self.evaluate(
            scale_wishart.sample(batch_shape, seed=test_util.test_seed()))
        return tf.linalg.LinearOperatorFullMatrix(matrix,
                                                  is_positive_definite=True,
                                                  is_self_adjoint=True)

    @test_combinations.generate(
        test_combinations.combine(
            use_loc=[True, False],
            use_precision=[True, False],
            event_size=[3],
            batch_shape=[(), (2, )],
            n_samples=[5000],
            dtype=[np.float32, np.float64],
        ), )
    def test_log_prob_and_sample(
        self,
        use_loc,
        use_precision,
        event_size,
        batch_shape,
        dtype,
        n_samples,
    ):
        """Compares log_prob with MultivariateNormalTriL, then checks samples.

        log_prob is compared at one point and at a batch of two points. Sample
        mean, variance and covariance are then checked against the true
        moments with tolerances of about 5 standard errors.
        """
        cov = self._random_constant_spd_linop(event_size,
                                              batch_shape=batch_shape,
                                              dtype=dtype)
        precision = cov.inverse()
        precision_factor = precision.cholesky()

        # Make sure to evaluate here, else you'll have a random loc vector!
        if use_loc:
            loc = self.evaluate(
                tf.random.normal(batch_shape + (event_size, ),
                                 dtype=dtype,
                                 seed=test_util.test_seed()))
        else:
            loc = None

        # Reference distribution parameterized by the Cholesky of the covariance.
        mvn_scale = tfd.MultivariateNormalTriL(
            loc=loc, scale_tril=cov.cholesky().to_dense())

        mvn_precision = tfd_e.MultivariateNormalPrecisionFactorLinearOperator(
            loc=loc,
            precision_factor=precision_factor,
            precision=precision if use_precision else None,
        )

        point = tf.random.normal(batch_shape + (event_size, ),
                                 dtype=dtype,
                                 seed=test_util.test_seed())
        mvn_scale_log_prob, mvn_precision_log_prob = self.evaluate(
            [mvn_scale.log_prob(point),
             mvn_precision.log_prob(point)])
        self.assertAllClose(mvn_scale_log_prob,
                            mvn_precision_log_prob,
                            atol=5e-4,
                            rtol=5e-4)

        # Same comparison with an extra leading sample dimension of size 2.
        batch_point = tf.random.normal((2, ) + batch_shape + (event_size, ),
                                       dtype=dtype,
                                       seed=test_util.test_seed())
        mvn_scale_log_prob, mvn_precision_log_prob = self.evaluate([
            mvn_scale.log_prob(batch_point),
            mvn_precision.log_prob(batch_point)
        ])
        self.assertAllClose(mvn_scale_log_prob,
                            mvn_precision_log_prob,
                            atol=5e-4,
                            rtol=5e-4)

        samples = mvn_precision.sample(n_samples, seed=test_util.test_seed())
        # Evaluate true and sample moments together so they share one draw.
        arrs = self.evaluate({
            'stddev':
            tf.sqrt(cov.diag_part()),
            'var':
            cov.diag_part(),
            'cov':
            cov.to_dense(),
            'samples':
            samples,
            'sample_var':
            tfp.stats.variance(samples, sample_axis=0),
            'sample_cov':
            tfp.stats.covariance(samples, sample_axis=0),
        })

        # Tolerance: 5 standard errors of the sample mean (stddev / sqrt(n)).
        self.assertAllMeansClose(
            arrs['samples'],
            loc if loc is not None else np.zeros_like(arrs['cov'][..., 0]),
            axis=0,
            atol=5 * np.max(arrs['stddev']) / np.sqrt(n_samples))
        # For Gaussian data the sample variance has stddev ~ sqrt(2) * var /
        # sqrt(n); use 5 of those, based on the largest variance.
        self.assertAllClose(arrs['sample_var'],
                            arrs['var'],
                            atol=5 * np.sqrt(2) * np.max(arrs['var']) /
                            np.sqrt(n_samples))
        # The same bound is reused for the full covariance matrix.
        self.assertAllClose(arrs['sample_cov'],
                            arrs['cov'],
                            atol=5 * np.sqrt(2) * np.max(arrs['var']) /
                            np.sqrt(n_samples))

    def test_dynamic_shape(self):
        """log_prob with an unknown static last dimension matches static shape."""
        x = tf.Variable(ps.ones([7, 3]), shape=[7, None])
        self.evaluate(x.initializer)

        # Check that the shape is actually `None`.
        if not tf.executing_eagerly():
            last_shape = x.shape[-1]
            if last_shape is not None:  # This is a `tf.Dimension` in tf1.
                last_shape = last_shape.value
            self.assertIsNone(last_shape)
        dynamic_dist = tfd_e.MultivariateNormalPrecisionFactorLinearOperator(
            precision_factor=tf.linalg.LinearOperatorDiag(tf.ones_like(x)))
        static_dist = tfd_e.MultivariateNormalPrecisionFactorLinearOperator(
            precision_factor=tf.linalg.LinearOperatorDiag(tf.ones([7, 3])))
        in_ = tf.zeros([7, 3])
        self.assertAllClose(self.evaluate(dynamic_dist.log_prob(in_)),
                            static_dist.log_prob(in_))

    @test_combinations.generate(
        test_combinations.combine(
            batch_shape=[(), (2, )],
            dtype=[np.float32, np.float64],
        ), )
    def test_mean_and_mode(self, batch_shape, dtype):
        """Both mean() and mode() should return `loc`."""
        event_size = 3
        cov = self._random_constant_spd_linop(event_size,
                                              batch_shape=batch_shape,
                                              dtype=dtype)
        precision_factor = cov.inverse().cholesky()

        # Make sure to evaluate here, else you'll have a random loc vector!
        loc = self.evaluate(
            tf.random.normal(batch_shape + (event_size, ),
                             dtype=dtype,
                             seed=test_util.test_seed()))

        mvn_precision = tfd_e.MultivariateNormalPrecisionFactorLinearOperator(
            loc=loc, precision_factor=precision_factor)
        self.assertAllClose(mvn_precision.mean(), loc)
        self.assertAllClose(mvn_precision.mode(), loc)

    @test_combinations.generate(
        test_combinations.combine(
            batch_shape=[(), (2, )],
            use_precision=[True, False],
            dtype=[np.float32, np.float64],
        ), )
    def test_cov_var_stddev(self, batch_shape, use_precision, dtype):
        """covariance/variance/stddev recovered from the precision factor.

        Results must match `cov`, with and without the optional explicit
        `precision` argument.
        """
        event_size = 3
        cov = self._random_constant_spd_linop(event_size,
                                              batch_shape=batch_shape,
                                              dtype=dtype)
        precision = cov.inverse()
        precision_factor = precision.cholesky()

        # Make sure to evaluate here, else you'll have a random loc vector!
        loc = self.evaluate(
            tf.random.normal(batch_shape + (event_size, ),
                             dtype=dtype,
                             seed=test_util.test_seed()))

        mvn_precision = tfd_e.MultivariateNormalPrecisionFactorLinearOperator(
            loc=loc,
            precision_factor=precision_factor,
            precision=precision if use_precision else None)
        self.assertAllClose(mvn_precision.covariance(),
                            cov.to_dense(),
                            atol=1e-4)
        self.assertAllClose(mvn_precision.variance(),
                            cov.diag_part(),
                            atol=1e-4)
        self.assertAllClose(mvn_precision.stddev(),
                            tf.sqrt(cov.diag_part()),
                            atol=1e-5)
class ComparingMethodsTest(test_util.TestCase):
    """Compare various KF EnKF versions.

  If the model is linear and Gaussian the EnKF sample mean/cov and marginal
  likelihood converges to that of a KF in the large ensemble limit.
  This class tests that they are the same. It does that by implementing a
  one-step KF. It also does some simple checks on the KF, to make sure we didn't
  just replicate misunderstanding in the EnKF.

  This class also checks that various flavors of the EnKF are the same.
  """
    def _random_spd_matrix(self, n, noise_level, seed, dtype):
        """Random SPD matrix with inflated diagonal.

        Computes `noise_level**2 * (W W^T + 0.5 I)`, where `W` is an `n x n`
        standard-normal matrix scaled by `1 / sqrt(n)`. `W W^T` is positive
        semi-definite, so adding `0.5 I` makes the result positive definite
        (for nonzero `noise_level`).

        Args:
          n: Python `int`, matrix dimension.
          noise_level: scalar multiplier; the result scales as its square.
          seed: seed for `tf.random.normal`.
          dtype: dtype of the result.

        Returns:
          An `[n, n]` tensor of dtype `dtype`.
        """
        gaussian = tf.random.normal(shape=[n, n], seed=seed, dtype=dtype)
        wigner_mat = gaussian / tf.sqrt(tf.cast(n, dtype))
        gram = tf.linalg.matmul(wigner_mat, wigner_mat, adjoint_b=True)
        inflated_diag = 0.5 * tf.linalg.eye(n, dtype=dtype)
        return noise_level**2 * (gram + inflated_diag)

    def _get_linear_model_params(
        self,
        noise_level,
        n_states,
        n_observations,
        seed_stream,
        dtype,
    ):
        """Get parameters defining a linear state space model (for KF & EnKF).

        Seeds are pulled from `seed_stream` in a fixed order: prior mean,
        prior covariance, transition matrix, observation matrix, transition
        covariance, observation noise covariance.
        """
        def _normal(shape):
            return tf.random.normal(shape, seed=seed_stream(), dtype=dtype)

        def _uniform(shape):
            # minval > 0 keeps entries away from zero, which helps tests with rtol.
            return tf.random.uniform(shape,
                                     minval=1.0,
                                     maxval=2.0,
                                     seed=seed_stream(),
                                     dtype=dtype)

        def _spd(size, level):
            return self._random_spd_matrix(size, level, seed_stream(),
                                           dtype=dtype)

        # The draw order below decides which seed each parameter gets; keep it.
        prior_mean = _uniform([n_states])
        prior_cov = _spd(n_states, 1.0)
        transition_mat = _normal([n_states, n_states])
        observation_mat = _normal([n_observations, n_states])
        transition_cov = _spd(n_states, noise_level)
        observation_noise_cov = _spd(n_observations, noise_level)

        return LinearModelParams(
            dtype=dtype,
            n_states=n_states,
            n_observations=n_observations,
            prior_mean=prior_mean,
            prior_cov=prior_cov,
            transition_mat=transition_mat,
            observation_mat=observation_mat,
            transition_cov=transition_cov,
            observation_noise_cov=observation_noise_cov,
        )

    def _kalman_filter_solve(self, observation, linear_model_params):
        """Solve one assimilation step using a KF.

        With transition matrix A, transition covariance Q, observation matrix
        H, observation noise covariance R and prior (m, P), this computes the
        predict step (A m, A P A^T + Q) and then the standard Kalman update.

        Returns:
          A dict with keys `predictive_mean`, `predictive_cov`,
          `predictive_stddev`, `updated_mean`, `updated_cov`, `updated_stddev`
          and `log_marginal_likelihood` (log density of `observation` under
          the predictive observation distribution).
        """
        # See http://screen/tnjSAEuo5nPKmYt for equations.
        params = linear_model_params
        transition = params.transition_mat
        obs_mat = params.observation_mat

        def sandwich(outer, inner):
            """Returns outer @ inner @ outer^T."""
            return tf.linalg.matmul(
                outer, tf.linalg.matmul(inner, outer, adjoint_b=True))

        # Predict step.
        predictive_mean = tf.linalg.matvec(transition, params.prior_mean)
        predictive_cov = (sandwich(transition, params.prior_cov) +
                          params.transition_cov)

        # Innovation covariance H P H^T + R, used by both the gain and p(Y).
        innovation_cov = (sandwich(obs_mat, predictive_cov) +
                          params.observation_noise_cov)
        predicted_observation = tf.linalg.matvec(obs_mat, predictive_mean)

        # Update step.
        kalman_gain = tf.linalg.matmul(
            tf.linalg.matmul(predictive_cov, obs_mat, adjoint_b=True),
            tf.linalg.inv(innovation_cov))
        updated_mean = predictive_mean + tf.linalg.matvec(
            kalman_gain, observation - predicted_observation)
        identity = tf.linalg.eye(params.n_states, dtype=params.dtype)
        updated_cov = tf.linalg.matmul(
            identity - tf.linalg.matmul(kalman_gain, obs_mat), predictive_cov)

        # p(Y | X_{predictive})
        marginal_dist = tfd.MultivariateNormalTriL(
            loc=predicted_observation,
            scale_tril=tf.linalg.cholesky(innovation_cov),
        )

        return {
            'predictive_mean': predictive_mean,
            'predictive_cov': predictive_cov,
            'predictive_stddev': tf.sqrt(tf.linalg.diag_part(predictive_cov)),
            'updated_mean': updated_mean,
            'updated_cov': updated_cov,
            'updated_stddev': tf.sqrt(tf.linalg.diag_part(updated_cov)),
            'log_marginal_likelihood': marginal_dist.log_prob(observation),
        }

    def _get_enkf_params(
        self,
        n_ensemble,
        linear_model_params,
        prior_dist,
        seed_stream,
        dtype,
    ):
        """Get parameters specific to EnKF reconstructions.

        Draws `n_ensemble` initial particles from `prior_dist` (one seed from
        `seed_stream`) and builds linear-Gaussian observation and transition
        functions from `linear_model_params`. `dtype` is accepted for call-site
        symmetry but is not used here.
        """
        params = linear_model_params

        def _linear_gaussian_fn(matrix, cov):
            """Builds fn(step, particles, extra) -> (MVN(matrix x, cov), extra)."""
            def fn(_, particles, extra):
                dist = tfd.MultivariateNormalTriL(
                    loc=tf.linalg.matvec(matrix, particles),
                    scale_tril=tf.linalg.cholesky(cov))
                return dist, extra
            return fn

        initial_state = tfs.EnsembleKalmanFilterState(
            step=0,
            particles=prior_dist.sample(n_ensemble, seed=seed_stream()),
            extra={})

        return EnKFParams(
            state=initial_state,
            n_ensemble=n_ensemble,
            observation_fn=_linear_gaussian_fn(params.observation_mat,
                                               params.observation_noise_cov),
            transition_fn=_linear_gaussian_fn(params.transition_mat,
                                              params.transition_cov),
        )

    def _enkf_solve(self, observation, enkf_params, predict_kwargs,
                    update_kwargs, log_marginal_likelihood_kwargs,
                    seed_stream):
        """Solve one data assimilation step using an EnKF.

        Seeds are drawn from `seed_stream` in this order: predict, update,
        log marginal likelihood. The marginal likelihood is computed from the
        predicted (not updated) state.

        Returns:
          A dict with the same keys as `_kalman_filter_solve`, filled with
          ensemble statistics of the predicted and updated particles.
        """
        predicted_state = tfs.ensemble_kalman_filter_predict(
            enkf_params.state, enkf_params.transition_fn,
            seed=seed_stream(), **predict_kwargs)
        updated_state = tfs.ensemble_kalman_filter_update(
            predicted_state, observation, enkf_params.observation_fn,
            seed=seed_stream(), **update_kwargs)
        log_marginal_likelihood = (
            tfs.ensemble_kalman_filter_log_marginal_likelihood(
                predicted_state, observation, enkf_params.observation_fn,
                seed=seed_stream(), **log_marginal_likelihood_kwargs))

        predicted = predicted_state.particles
        updated = updated_state.particles
        return {
            'predictive_mean': tf.reduce_mean(predicted, axis=0),
            'predictive_cov': tfp.stats.covariance(predicted),
            'predictive_stddev': tfp.stats.stddev(predicted),
            'updated_mean': tf.reduce_mean(updated, axis=0),
            'updated_cov': tfp.stats.covariance(updated),
            'updated_stddev': tfp.stats.stddev(updated),
            'log_marginal_likelihood': log_marginal_likelihood,
        }

    @test_combinations.generate(
        test_combinations.combine(
            noise_level=[0.001, 0.1, 1.0],
            n_states=[2, 5],
            n_observations=[2, 5],
            perturbed_observations=[False, True],
        ))
    def test_kf_vs_enkf(
        self,
        noise_level,
        n_states,
        n_observations,
        perturbed_observations,
    ):
        """Check that the KF and EnKF solutions are the same.

        A random linear-Gaussian model is drawn, and an observation is made of
        a state sampled from the prior. The exact one-step KF solution is then
        compared, key by key, with a large-ensemble EnKF solution. The test
        also sanity-checks that the KF reconstructs the true state.
        """
        # Tests pass with n_ensemble = 1e7. The KF vs. EnKF tolerance is
        # proportional to 1 / sqrt(n_ensemble), so this shows good agreement.
        n_ensemble = int(1e4) if NUMPY_MODE else int(1e6)

        # NOTE(review): `perturbed_observations` is not part of the salt, so both
        # settings presumably share the same seeds -- confirm this is intended.
        salt = str(noise_level) + str(n_states) + str(n_observations)
        seed_stream = test_util.test_seed_stream(salt)
        dtype = tf.float64
        predict_kwargs = {}
        update_kwargs = {}
        log_marginal_likelihood_kwargs = {
            'perturbed_observations': perturbed_observations,
        }

        linear_model_params = self._get_linear_model_params(
            noise_level=noise_level,
            n_states=n_states,
            n_observations=n_observations,
            seed_stream=seed_stream,
            dtype=dtype)

        # Ensure that our observation comes from a state that ~ prior.
        prior_dist = tfd.MultivariateNormalTriL(
            loc=linear_model_params.prior_mean,
            scale_tril=tf.linalg.cholesky(linear_model_params.prior_cov))
        true_state = prior_dist.sample(seed=seed_stream())
        # The observation is noise-free: H @ true_state.
        observation = tf.linalg.matvec(linear_model_params.observation_mat,
                                       true_state)

        kf_soln = self._kalman_filter_solve(observation, linear_model_params)

        enkf_params = self._get_enkf_params(n_ensemble, linear_model_params,
                                            prior_dist, seed_stream, dtype)
        enkf_soln = self._enkf_solve(observation, enkf_params, predict_kwargs,
                                     update_kwargs,
                                     log_marginal_likelihood_kwargs,
                                     seed_stream)

        # In the low noise limit, the spectral norm of the posterior covariance is
        # bounded by reconstruction_tol**2.
        # http://screen/96UV8kiXMvp8QSM
        reconstruction_tol = noise_level / tf.reduce_min(
            tf.linalg.svd(linear_model_params.observation_mat,
                          compute_uv=False))

        # Evaluate at the same time, so both use the same randomness!
        # Do not use anything that was not evaluated here!
        true_state, reconstruction_tol, kf_soln, enkf_soln = self.evaluate(
            [true_state, reconstruction_tol, kf_soln, enkf_soln])

        # Square root of the largest singular value of the KF posterior cov,
        # i.e. the largest posterior "scale".
        max_updated_scale = self.evaluate(
            tf.sqrt(
                tf.reduce_max(
                    tf.linalg.svd(kf_soln['updated_cov'], compute_uv=False))))

        if noise_level < 0.2 and n_states == n_observations:
            # Check that the theoretical error bound is obeyed.
            # We use max_updated_scale below to check reconstruction error, but
            # without this check here, it's possible that max_updated_scale is large
            # due to some error in the kalman filter...which would invalidate checks
            # below.
            slop = 2. + 5 * noise_level
            self.assertLess(max_updated_scale, slop * reconstruction_tol)

        # The KF should reconstruct the correct value up to 5 stddevs.
        # The relevant stddev is that of a χ² random variable.
        reconstruction_error = np.linalg.norm(kf_soln['updated_mean'] -
                                              true_state,
                                              axis=-1)
        self.assertLess(reconstruction_error,
                        5 * np.sqrt(2 * n_states) * max_updated_scale)

        # We know the EnKF converges at rate 1 / Sqrt(n_ensemble). The factor in
        # front is set empirically.
        tol_scale = 1 / np.sqrt(n_ensemble)  # 0.001 for n_ensemble = 1e6.
        self.assertAllCloseNested(kf_soln,
                                  enkf_soln,
                                  atol=20 * tol_scale,
                                  rtol=50 * tol_scale)

    @parameterized.named_parameters(
        dict(
            testcase_name='low_rank_ensemble',
            kwargs_1=dict(
                predict={},
                update={
                    'low_rank_ensemble': False,
                },
                log_marginal_likelihood={
                    'low_rank_ensemble': False,
                    'perturbed_observations': False
                },
            ),
            kwargs_2=dict(
                predict={},
                update={
                    'low_rank_ensemble': True,
                },
                log_marginal_likelihood={
                    'low_rank_ensemble': True,
                    'perturbed_observations': False
                },
            ),
        ),
        dict(
            testcase_name='low_rank_ensemble_1d_obs',
            # n_observations = 1 invokes a special code path.
            n_observations=1,
            kwargs_1=dict(
                predict={},
                update={
                    'low_rank_ensemble': False,
                },
                log_marginal_likelihood={
                    'low_rank_ensemble': False,
                    'perturbed_observations': False
                },
            ),
            kwargs_2=dict(
                predict={},
                update={
                    'low_rank_ensemble': True,
                },
                log_marginal_likelihood={
                    'low_rank_ensemble': True,
                    'perturbed_observations': False
                },
            ),
        ),
    )
    def test_cases_where_different_kwargs_give_same_enkf_result(
        self,
        kwargs_1,
        kwargs_2,
        n_states=5,
        n_observations=5,
        n_ensemble=10,
    ):
        """Check that two sets of kwargs give same result.

        Both EnKF runs share the same initial ensemble (`enkf_params`) and
        identically seeded streams. Any difference in the solution therefore
        comes only from the kwargs passed to predict/update/log-likelihood.

        Args:
          kwargs_1: dict with keys 'predict', 'update' and
            'log_marginal_likelihood', each a kwargs dict for the matching
            `tfs.ensemble_kalman_filter_*` call.
          kwargs_2: same structure as `kwargs_1`, for the second run.
          n_states: state dimension.
          n_observations: observation dimension.
          n_ensemble: number of particles (small on purpose, see below).
        """
        # In most cases, `test_kf_vs_enkf` is more complete, since it tests
        # correctness. However, `test_kf_vs_enkf` requires a huge ensemble.
        # This test is useful when you cannot use a huge ensemble and/or you want to
        # compare to a method already checked for correctness by `test_kf_vs_enkf`.
        salt = str(n_ensemble) + str(n_states) + str(n_observations)
        seed_stream = test_util.test_seed_stream(salt)
        dtype = tf.float64

        linear_model_params = self._get_linear_model_params(
            noise_level=0.1,
            n_states=n_states,
            n_observations=n_observations,
            seed_stream=seed_stream,
            dtype=dtype)

        # Ensure that our observation comes from a state that ~ prior.
        prior_dist = tfd.MultivariateNormalTriL(
            loc=linear_model_params.prior_mean,
            scale_tril=tf.linalg.cholesky(linear_model_params.prior_cov))
        true_state = prior_dist.sample(seed=seed_stream())
        observation = tf.linalg.matvec(linear_model_params.observation_mat,
                                       true_state)

        enkf_params = self._get_enkf_params(n_ensemble, linear_model_params,
                                            prior_dist, seed_stream, dtype)

        # Use the exact same seeds for each.
        enkf_soln_1 = self._enkf_solve(observation, enkf_params,
                                       kwargs_1['predict'], kwargs_1['update'],
                                       kwargs_1['log_marginal_likelihood'],
                                       test_util.test_seed_stream(salt))
        enkf_soln_2 = self._enkf_solve(observation, enkf_params,
                                       kwargs_2['predict'], kwargs_2['update'],
                                       kwargs_2['log_marginal_likelihood'],
                                       test_util.test_seed_stream(salt))

        # Evaluate at the same time, so both use the same randomness!
        # Do not use anything that was not evaluated here!
        enkf_soln_1, enkf_soln_2 = self.evaluate([enkf_soln_1, enkf_soln_2])

        # We used the same seed, so solutions should be identical up to tolerance of
        # different solver methods.
        self.assertAllCloseNested(enkf_soln_1, enkf_soln_2)