def _make_mixture_dist(self, component_logits, locs, scales):
  """Builds a mixture of quantized logistic distributions.

  Args:
    component_logits: 4D `Tensor` of logits for the Categorical distribution
      over Quantized Logistic mixture components. Dimensions are
      `[batch_size, height, width, num_logistic_mix]`.
    locs: 5D `Tensor` of location parameters for the Quantized Logistic
      mixture components. Dimensions are
      `[batch_size, height, width, num_logistic_mix, num_channels]`.
    scales: 5D `Tensor` of scale parameters for the Quantized Logistic
      mixture components. Dimensions are
      `[batch_size, height, width, num_logistic_mix, num_channels]`.

  Returns:
    dist: A quantized logistic mixture `tfp.distribution` over the input data.
  """
  mixture_distribution = categorical.Categorical(logits=component_logits)

  # Convert distribution parameters for pixel values in
  # `[self._low, self._high]` for use with `QuantizedDistribution`: a loc of
  # -1 maps to `self._low` and a loc of +1 maps to `self._high`.
  half_range = 0.5 * (self._high - self._low)
  locs = self._low + half_range * (locs + 1.)
  # Rebind rather than using `*=` so that a caller-supplied mutable array
  # (e.g. a NumPy array) is never modified in place.
  scales = scales * half_range

  logistic_dist = quantized_distribution.QuantizedDistribution(
      distribution=transformed_distribution.TransformedDistribution(
          distribution=logistic.Logistic(loc=locs, scale=scales),
          bijector=shift.Shift(shift=tf.cast(-0.5, self.dtype))),
      low=self._low,
      high=self._high)

  # The innermost (channel) dimension becomes the event of each component.
  dist = mixture_same_family.MixtureSameFamily(
      mixture_distribution=mixture_distribution,
      components_distribution=independent.Independent(
          logistic_dist, reinterpreted_batch_ndims=1))
  # Fold the two trailing batch dimensions of the mixture into the event.
  return independent.Independent(dist, reinterpreted_batch_ndims=2)
def _fn(dtype, shape, name, trainable, add_variable_fn):
  """Builds an `Independent`-wrapped `Deterministic` or `Normal` distribution.

  All arguments are forwarded unchanged to the enclosing `loc_scale_fn`.

  Args:
    dtype: Type of parameter's event.
    shape: Python `list`-like representing the parameter's event shape.
    name: Python `str` name prepended to any created (or existing)
      `tf.Variable`s.
    trainable: Python `bool` indicating all created `tf.Variable`s should be
      added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
    add_variable_fn: `tf.get_variable`-like `callable` used to create (or
      access existing) `tf.Variable`s.

  Returns:
    A `Deterministic` distribution when `loc_scale_fn` yields no scale,
    otherwise a `Normal`; in both cases every batch dimension is reinterpreted
    as an event dimension.
  """
  loc, scale = loc_scale_fn(dtype, shape, name, trainable, add_variable_fn)
  base_dist = (
      deterministic_lib.Deterministic(loc=loc) if scale is None
      else normal_lib.Normal(loc=loc, scale=scale))
  # The number of batch dimensions is the length of the batch-shape vector.
  num_batch_dims = tf.size(input=base_dist.batch_shape_tensor())
  return independent_lib.Independent(
      base_dist, reinterpreted_batch_ndims=num_batch_dims)
def _build_posterior_for_one_parameter(param, batch_shape, seed):
  """Builds a transformed-normal variational dist over a parameter's support."""
  # Trainable Normal, initialized from an unconstrained sample.
  unconstrained_init = sample_uniform_initial_state(
      param,
      init_sample_shape=batch_shape,
      return_constrained=False,
      seed=seed)
  loc = tf.Variable(
      initial_value=unconstrained_init, name=param.name + '_loc')
  initial_scale = tf.fill(
      tf.shape(unconstrained_init),
      value=tf.constant(0.02, unconstrained_init.dtype),
      name=param.name + '_scale')
  scale = tfp_util.TransformedVariable(initial_scale, softplus_lib.Softplus())
  surrogate = normal_lib.Normal(loc=loc, scale=scale)

  # Make the variational distribution's `event_shape` match the parameter's.
  prior_event_ndims = param.prior.event_shape.ndims
  if prior_event_ndims is None or prior_event_ndims > 0:
    surrogate = independent_lib.Independent(
        surrogate, reinterpreted_batch_ndims=prior_event_ndims)

  # Push forward into the constrained parameter space.
  return transformed_distribution_lib.TransformedDistribution(
      surrogate, param.bijector, name='{}_posterior'.format(param.name))
def _wrap_as_distributions(structure):
  """Wraps every leaf of `structure` in an `Independent(Deterministic)`."""

  def _as_point_mass(x):
    # Particles are a batch dimension, so all but the leading dimension of
    # `x` are reinterpreted as event dimensions.
    return independent.Independent(
        deterministic.Deterministic(x),
        reinterpreted_batch_ndims=tf.rank(x) - 1)

  return tf.nest.map_structure(_as_point_mass, structure)
def inner(dtype, shape, name, trainable, add_variable_fn,
          loc=loc, scale=scale):
  """Builds an `Independent`-wrapped `Normal` with fixed parameters.

  The `loc` and `scale` defaults are bound from the enclosing scope when this
  function is defined.

  Args:
    dtype: Type of parameter's event.
    shape: Python `list`-like representing the parameter's event shape.
    name: Unused.
    trainable: Unused.
    add_variable_fn: Unused.
    loc: Location of the `Normal`; when `None`, zeros of the given `shape` and
      `dtype` are used.
    scale: Scale of the `Normal`, converted to the NumPy dtype matching
      `dtype`.

  Returns:
    A `Normal` distribution whose batch dimensions are all reinterpreted as
    event dimensions.
  """
  del name, trainable, add_variable_fn  # unused
  mean = tf.zeros(shape, dtype) if loc is None else loc
  stddev = dtype.as_numpy_dtype(scale)
  base_dist = normal_lib.Normal(loc=mean, scale=stddev)
  num_batch_dims = tf.size(base_dist.batch_shape_tensor())
  return independent_lib.Independent(
      base_dist, reinterpreted_batch_ndims=num_batch_dims)
def _fn(dtype, shape, name, trainable, add_variable_fn):
  """Builds an `Independent`-wrapped `Deterministic` or `Normal` distribution.

  The arguments are forwarded unchanged to the enclosing `loc_scale_fn`, which
  returns a `(loc, scale)` pair; a `None` scale selects `Deterministic`.
  """
  loc, scale = loc_scale_fn(dtype, shape, name, trainable, add_variable_fn)
  base_dist = (
      deterministic_lib.Deterministic(loc=loc) if scale is None
      else normal_lib.Normal(loc=loc, scale=scale))
  # Every batch dimension is reinterpreted as an event dimension.
  num_batch_dims = tf2.size(base_dist.batch_shape_tensor())
  return independent_lib.Independent(
      base_dist, reinterpreted_batch_ndims=num_batch_dims)
def _asvi_surrogate_for_sample(dist, build_nested_surrogate, sample_shape=None):
  """Generator building the surrogate for a `tfd.Sample`-wrapped distribution.

  Delegates (via `yield from`) to `build_nested_surrogate` for the wrapped
  distribution, then reinterprets the `Sample` dimensions as event dimensions.
  """
  own_sample_shape = distribution_util.expand_to_vector(dist.sample_shape)
  # Any sample shape inherited from outer wrappers goes to the left of ours.
  if sample_shape is None:
    full_sample_shape = own_sample_shape
  else:
    full_sample_shape = ps.concat([sample_shape, own_sample_shape], axis=0)

  inner_surrogate = yield from build_nested_surrogate(
      dist=dist.distribution, sample_shape=full_sample_shape)
  return independent.Independent(
      inner_surrogate,
      reinterpreted_batch_ndims=ps.rank_from_shape(own_sample_shape))
def build_trainable_location_scale_distribution(initial_loc,
                                                initial_scale,
                                                event_ndims,
                                                distribution_fn=normal.Normal,
                                                validate_args=False,
                                                name=None):
  """Creates a trainable location-scale variational distribution.

  Args:
    initial_loc: Float `Tensor` initial location.
    initial_scale: Float `Tensor` initial scale.
    event_ndims: Integer `Tensor` number of event dimensions in `initial_loc`.
    distribution_fn: Optional constructor for a `tfd.Distribution` instance
      in a location-scale family. This should have signature
      `dist = distribution_fn(loc, scale, validate_args)`.
      Default value: `tfd.Normal`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs
      are invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e.,
        'build_trainable_location_scale_distribution').

  Returns:
    posterior_dist: A `tfd.Distribution` instance.
  """
  scope_name = name or 'build_trainable_location_scale_distribution'
  with tf.name_scope(scope_name):
    dtype = dtype_util.common_dtype(
        [initial_loc, initial_scale], dtype_hint=tf.float32)
    # Multiply each initial value by ones shaped like the other so that both
    # end up with a common (broadcast) shape.
    initial_loc = initial_loc * tf.ones(tf.shape(initial_scale), dtype=dtype)
    initial_scale = initial_scale * tf.ones_like(initial_loc)

    loc = tf.Variable(initial_value=initial_loc, name='loc')
    scale = tfp_util.TransformedVariable(
        initial_scale, softplus.Softplus(), name='scale')
    base_dist = distribution_fn(
        loc=loc, scale=scale, validate_args=validate_args)

    # Only skip the `Independent` wrapper when the event rank is statically
    # known to be non-positive.
    static_event_ndims = tf.get_static_value(event_ndims)
    if static_event_ndims is not None and static_event_ndims <= 0:
      return base_dist
    return independent.Independent(
        base_dist,
        reinterpreted_batch_ndims=event_ndims,
        validate_args=validate_args)
def create_fixed_gaussian_prior(dtype, shape, name, trainable,
                                add_variable_fn):
  """Builds a `Normal` prior backed by non-trainable variables.

  Both variables are created with `trainable=False`; the `name` and
  `trainable` arguments are not used. The initializers come from the enclosing
  scope.

  Args:
    dtype: dtype of the created variables.
    shape: Shape of the location variable (the scale variable is scalar).
    name: Unused.
    trainable: Unused.
    add_variable_fn: Callable used to create the variables.

  Returns:
    A `Normal` distribution whose batch dimensions are all reinterpreted as
    event dimensions.
  """
  loc_var = add_variable_fn(
      'kernel_prior_loc',
      shape=shape,
      dtype=dtype,
      trainable=False,
      initializer=prior_kernel_loc_initializer)
  untransformed_scale_var = add_variable_fn(
      'kernel_prior_untransformed_scale',
      shape=(),
      dtype=dtype,
      trainable=False,
      initializer=prior_kernel_untr_scale_initializer)
  # Softplus maps the unconstrained scale variable to a positive scale.
  prior = normal_lib.Normal(
      loc=loc_var, scale=tf.nn.softplus(untransformed_scale_var))
  num_batch_dims = tf.size(input=prior.batch_shape_tensor())
  return independent_lib.Independent(
      prior, reinterpreted_batch_ndims=num_batch_dims)
def _fn(dtype, shape, name, trainable, add_variable_fn):
  """Builds an `Independent` `Normal` over constant-initialized variables.

  The initial values `loc` and `scale`, and the flag `isPosterior`, come from
  the enclosing scope.

  Args:
    dtype: dtype of the created variables.
    shape: Shape of both created variables.
    name: Prefix for the variable names (`<name>_loc`, `<name>_scale`).
    trainable: Ignored; the variables' trainability is taken from the
      enclosing `isPosterior` instead.
    add_variable_fn: Callable used to create the variables.

  Returns:
    A `Normal` distribution whose batch dimensions are all reinterpreted as
    event dimensions.
  """
  loc_init = tf.compat.v1.constant_initializer(loc)
  scale_init = tf.compat.v1.constant_initializer(scale)

  def _make_variable(suffix, initializer):
    return add_variable_fn(
        name=name + suffix,
        shape=shape,
        initializer=initializer,
        regularizer=None,
        constraint=None,
        dtype=dtype,
        trainable=isPosterior)

  loc_var = _make_variable('_loc', loc_init)
  scale_var = _make_variable('_scale', scale_init)
  base_dist = normal_lib.Normal(loc=loc_var, scale=scale_var)
  num_batch_dims = tf.size(input=base_dist.batch_shape_tensor())
  return independent_lib.Independent(
      base_dist, reinterpreted_batch_ndims=num_batch_dims)
def params_and_state_transition_fn(step, params_and_state, perturbation_scale,
                                   **kwargs):
  """Transition function over joint (unconstrained parameters, state) pairs.

  Args:
    step: Step index, passed through to `parameterized_transition_fn`.
    params_and_state: Pair unpacking to `(unconstrained_params, state)`.
    perturbation_scale: Scale of the Normal noise added to each unconstrained
      parameter; either a single value (applied to every parameter) or a
      nested structure matching the parameters.
    **kwargs: Extra keyword arguments forwarded to
      `parameterized_transition_fn`. If `state_history` is present, only its
      `.state` attribute is forwarded.

  Returns:
    A `JointDistributionNamed` over perturbed unconstrained parameters and
    the state transition conditioned on them.
  """
  # Split out the state so it can be passed through to the transition fn.
  unconstrained_params, state = params_and_state
  if 'state_history' in kwargs:
    kwargs['state_history'] = kwargs['state_history'].state

  # A non-nested scale is converted once per parameter, using that
  # parameter's dtype.
  if not tf.nest.is_nested(perturbation_scale):
    shared_scale = perturbation_scale

    def _scale_like(x):
      return tf.convert_to_tensor(
          shared_scale, name='perturbation_scale', dtype=x.dtype)

    perturbation_scale = tf.nest.map_structure(
        _scale_like, unconstrained_params)

  def _perturbation_dist(x, p, s):
    # Normal noise centred on the current value, with the parameter's event
    # dimensions reinterpreted as such.
    return independent.Independent(
        normal.Normal(loc=x, scale=p),
        reinterpreted_batch_ndims=prefer_static.rank_from_shape(s))

  perturbed_unconstrained_parameter_dists = tf.nest.map_structure(
      _perturbation_dist,
      unconstrained_params,
      perturbation_scale,
      parameter_prior.event_shape_tensor())

  # NOTE: the argument name must stay `unconstrained_parameters`; it is how
  # the joint distribution ties this callable to the field of the same name.
  def _state_transition(unconstrained_parameters):
    # Push the perturbed parameters into constrained space before handing
    # them to the original transition fn.
    return parameterized_transition_fn(
        step,
        state,
        parameters=parameter_constraining_bijector.forward(
            unconstrained_parameters),
        **kwargs)

  return joint_distribution_named.JointDistributionNamed(
      ParametersAndState(
          unconstrained_parameters=_maybe_build_joint_distribution(
              perturbed_unconstrained_parameter_dists),
          state=_state_transition))
def variational_loss(self,
                     observations,
                     observation_index_points=None,
                     kl_weight=1.,
                     name='variational_loss'):
  """Variational loss for the VGP.

  Given `observations` and `observation_index_points`, compute the
  negative variational lower bound as specified in [Hensman, 2013][1].

  Args:
    observations: `float` `Tensor` representing collection, or batch of
      collections, of observations corresponding to
      `observation_index_points`. Shape has the form `[b1, ..., bB, e]`, which
      must be broadcastable with the batch and example shapes of
      `observation_index_points`. The batch shape `[b1, ..., bB]` must be
      broadcastable with the shapes of all other batched parameters
      (`kernel.batch_shape`, `observation_index_points`, etc.).
    observation_index_points: `float` `Tensor` representing finite (batch of)
      vector(s) of points where observations are defined. Shape has the
      form `[b1, ..., bB, e1, f1, ..., fF]` where `F` is the number of feature
      dimensions and must equal `kernel.feature_ndims` and `e1` is the number
      (size) of index points in each batch (we denote it `e1` to distinguish
      it from the number of inducing index points, denoted `e2` below). If
      set to `None` uses `index_points` as the origin for observations.
      Default value: None.
    kl_weight: Amount by which to scale the KL divergence loss between prior
      and posterior.
      Default value: 1.
    name: Python `str` name prefixed to Ops created by this method. A falsy
      value falls back to 'variational_gp_loss'.
      Default value: 'variational_loss'.

  Returns:
    loss: Scalar tensor representing the negative variational lower bound.
      Can be directly used in a `tf.Optimizer`.

  #### References

  [1]: Hensman, J., Lawrence, N. "Gaussian Processes for Big Data", 2013
       https://arxiv.org/abs/1309.6835
  """
  with tf.name_scope(name or 'variational_gp_loss'):
    if observation_index_points is None:
      observation_index_points = self._index_points
    observation_index_points = tf.convert_to_tensor(
        observation_index_points, dtype=self._dtype,
        name='observation_index_points')
    observations = tf.convert_to_tensor(
        observations, dtype=self._dtype, name='observations')
    kl_weight = tf.convert_to_tensor(
        kl_weight, dtype=self._dtype, name='kl_weight')

    # The variational loss is a negative ELBO. The ELBO can be broken down
    # into three terms:
    #  1. a likelihood term
    #  2. a trace term arising from the covariance of the posterior predictive
    #  3. a KL term between the surrogate posterior and the prior

    kzx = self.kernel.matrix(self._inducing_index_points,
                             observation_index_points)

    kzx_linop = tf.linalg.LinearOperatorFullMatrix(kzx)
    loc = (self._mean_fn(observation_index_points) +
           kzx_linop.matvec(self._kzz_inv_varloc, adjoint=True))

    likelihood = independent.Independent(
        normal.Normal(
            loc=loc,
            scale=tf.sqrt(self._observation_noise_variance + self._jitter),
            name='NormalLikelihood'),
        reinterpreted_batch_ndims=1)
    obs_ll = likelihood.log_prob(observations)

    # Two triangular solves against chol(Kzz) give Kzz^-1 Kzx.
    chol_kzz_linop = tf.linalg.LinearOperatorLowerTriangular(self._chol_kzz)
    chol_kzz_inv_kzx = chol_kzz_linop.solve(kzx)
    kzz_inv_kzx = chol_kzz_linop.solve(chol_kzz_inv_kzx, adjoint=True)

    kxx_diag = self.kernel.apply(
        observation_index_points, observation_index_points, example_ndims=1)
    ktilde_trace_term = (
        tf.reduce_sum(kxx_diag, axis=-1) -
        tf.reduce_sum(chol_kzz_inv_kzx**2, axis=[-2, -1]))

    # Tr(SB)
    # where S = A A.T, A = variational_inducing_observations_scale
    # and B = Kzz^-1 Kzx Kzx.T Kzz^-1
    #
    # Now Tr(SB) = Tr(A A.T Kzz^-1 Kzx Kzx.T Kzz^-1)
    #            = Tr(A.T Kzz^-1 Kzx Kzx.T Kzz^-1 A)
    #            = sum_ij (A.T Kzz^-1 Kzx)_{ij}^2
    #
    # `adjoint=True` applies A.T as the derivation requires; multiplying by A
    # itself gives a different value whenever A is not symmetric.
    other_trace_term = tf.reduce_sum(
        (self._variational_inducing_observations_posterior.scale.matmul(
            kzz_inv_kzx, adjoint=True)**2),
        axis=[-2, -1])

    trace_term = (.5 * (ktilde_trace_term + other_trace_term) /
                  self._observation_noise_variance)

    kl_term = kl_weight * self.surrogate_posterior_kl_divergence_prior()

    lower_bound = (obs_ll - trace_term - kl_term)

    return -tf.reduce_mean(lower_bound)
def _batched_isotropic_normal_like(state_part):
  """Builds a unit-scale, zero-mean `Normal` shaped like `state_part`.

  The leading `batch_rank` dimensions (from the enclosing scope) stay batch
  dimensions; the remaining ones are reinterpreted as event dimensions.
  """
  num_event_dims = ps.rank(state_part) - batch_rank
  zero_loc = ps.zeros_like(state_part, tf.float32)
  return independent.Independent(
      normal.Normal(zero_loc, 1.),
      reinterpreted_batch_ndims=num_event_dims)
def _asvi_surrogate_for_independent(dist, build_nested_surrogate):
  """Generator building the surrogate for a `tfd.Independent` wrapper.

  Delegates (via `yield from`) to `build_nested_surrogate` for the wrapped
  distribution and re-wraps the result with the same number of reinterpreted
  batch dimensions as `dist`.
  """
  inner_surrogate = yield from build_nested_surrogate(dist.distribution)
  return independent.Independent(
      inner_surrogate,
      reinterpreted_batch_ndims=dist.reinterpreted_batch_ndims)
def independent_joint_distribution_from_structure(structure_of_distributions,
                                                  batch_ndims=None,
                                                  validate_args=False):
  """Combines a (possibly nested) structure of dists into a single dist.

  Args:
    structure_of_distributions: instance of `tfd.Distribution`, or nested
      structure (tuple, list, dict, etc.) in which all leaves are
      `tfd.Distribution` instances.
    batch_ndims: Optional integer `Tensor` number of leftmost batch dimensions
      shared across all members of the input structure. If this is specified,
      the returned joint distribution will be an autobatched distribution with
      the given batch rank, and all other dimensions absorbed into the event.
    validate_args: Python `bool`. Whether the joint distribution should
      validate input with asserts. This imposes a runtime cost. If
      `validate_args` is `False`, and the inputs are invalid, correct behavior
      is not guaranteed.
      Default value: `False`.

  Returns:
    distribution: instance of `tfd.Distribution` such that
      `distribution.sample()` is equivalent to
      `tf.nest.map_structure(lambda d: d.sample(), structure_of_distributions)`.
      If `structure_of_distributions` was indeed a structure (as opposed to
      a single `Distribution` instance), this will be a `JointDistribution`
      with the corresponding structure.

  Raises:
    TypeError: if any leaves of the input structure are not `tfd.Distribution`
      instances. NOTE(review): no explicit check is performed in this
      function; presumably raised by the joint distribution constructors.
  """
  # Base case: a single distribution is returned as-is, or wrapped so that
  # only `batch_ndims` batch dimensions remain.
  if dist_util.is_distribution_instance(structure_of_distributions):
    dist = structure_of_distributions
    if batch_ndims is None:
      return dist
    excess_ndims = (
        ps.rank_from_shape(dist.batch_shape_tensor()) - batch_ndims)
    # The static value may be None, in which case we must still wrap.
    if tf.get_static_value(excess_ndims) == 0:
      return dist
    return independent.Independent(
        dist, reinterpreted_batch_ndims=excess_ndims)

  # If this structure contains other structures (ie, has elements at
  # depth > 1), recursively turn them into JDs.
  element_depths = nest.map_structure_with_tuple_paths(
      lambda path, x: len(path), structure_of_distributions)
  if max(tf.nest.flatten(element_depths)) > 1:
    recurse = functools.partial(
        independent_joint_distribution_from_structure,
        batch_ndims=batch_ndims,
        validate_args=validate_args)
    next_level_shallow_structure = nest.get_traverse_shallow_structure(
        traverse_fn=lambda x: min(tf.nest.flatten(x)) <= 1,
        structure=element_depths)
    structure_of_distributions = nest.map_structure_up_to(
        next_level_shallow_structure, recurse, structure_of_distributions)

  # Use an autobatched JD if a specific batch rank was requested.
  if batch_ndims is None:
    jdnamed = joint_distribution_named.JointDistributionNamed
    jdsequential = joint_distribution_sequential.JointDistributionSequential
  else:
    jdnamed = functools.partial(
        joint_distribution_auto_batched.JointDistributionNamedAutoBatched,
        batch_ndims=batch_ndims,
        use_vectorized_map=False)
    jdsequential = functools.partial(
        joint_distribution_auto_batched.JointDistributionSequentialAutoBatched,
        batch_ndims=batch_ndims,
        use_vectorized_map=False)

  # Namedtuples and mappings become named JDs; everything else sequential.
  is_named = (
      hasattr(structure_of_distributions, '_asdict') or
      isinstance(structure_of_distributions, collections.abc.Mapping))
  build_joint = jdnamed if is_named else jdsequential
  return build_joint(structure_of_distributions, validate_args=validate_args)
def _asvi_surrogate_for_distribution(dist,
                                     base_distribution_surrogate_fn,
                                     sample_shape=None,
                                     variables=None,
                                     seed=None):
  """Recursively creates ASVI surrogates, and creates new variables if needed.

  Args:
    dist: a `tfd.Distribution` instance.
    base_distribution_surrogate_fn: Callable to build a surrogate posterior
      for a 'base' (non-meta and non-joint) distribution, with signature
      `surrogate_posterior, variables = base_distribution_fn(
      dist, sample_shape=None, variables=None, seed=None)`.
    sample_shape: Optional `Tensor` shape of samples drawn from `dist` by
      `tfd.Sample` wrappers. If not `None`, the surrogate's event will include
      independent sample dimensions, i.e., it will have event shape
      `concat([sample_shape, dist.event_shape], axis=0)`.
      Default value: `None`.
    variables: Optional nested structure of `tf.Variable`s returned from a
      previous call to `_asvi_surrogate_for_distribution`. If `None`,
      new variables will be created; otherwise, constructs a surrogate
      posterior backed by the passed-in variables.
      Default value: `None`.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    surrogate_posterior: Instance of `tfd.Distribution` representing a
      trainable surrogate posterior distribution, with the same structure and
      `name` as `dist`.
    variables: Nested structure of `tf.Variable` trainable parameters for the
      surrogate posterior. If `dist` is a base distribution, this is
      a `dict` of `ASVIParameters` instances. If `dist` is a joint
      distribution, this is a `dist.dtype` structure of such `dict`s.

  Raises:
    ValueError: if `dist` is an unsupported wrapper ("meta") distribution,
      i.e. it has a `distribution` attribute but is none of the handled
      wrapper types and is not a `TransformedDistribution`.
  """
  # Pass args to any nested surrogates.
  build_nested_surrogate = functools.partial(
      _asvi_surrogate_for_distribution,
      base_distribution_surrogate_fn=base_distribution_surrogate_fn,
      sample_shape=sample_shape,
      seed=seed)

  # Apply any substitutions, while attempting to preserve the original name.
  dist = _set_name(_as_substituted_distribution(dist), name=_get_name(dist))

  # Handle wrapper ("meta") distributions.
  if isinstance(dist, markov_chain.MarkovChain):
    return _asvi_surrogate_for_markov_chain(
        dist=dist,
        variables=variables,
        base_distribution_surrogate_fn=base_distribution_surrogate_fn,
        sample_shape=sample_shape,
        seed=seed)
  if isinstance(dist, sample.Sample):
    dist_sample_shape = distribution_util.expand_to_vector(dist.sample_shape)
    # The keyword `sample_shape` here overrides the one bound in the partial.
    nested_surrogate, variables = build_nested_surrogate(  # pylint: disable=redundant-keyword-arg
        dist=dist.distribution,
        variables=variables,
        sample_shape=(
            dist_sample_shape if sample_shape is None
            else ps.concat([sample_shape, dist_sample_shape], axis=0)))
    surrogate_posterior = independent.Independent(
        nested_surrogate,
        reinterpreted_batch_ndims=ps.rank_from_shape(dist_sample_shape),
        name=_get_name(dist))
  # Treat distributions that subclass TransformedDistribution with their own
  # parameters (e.g., Gumbel, Weibull, MultivariateNormal*, etc) as their
  # own type of base distribution, rather than as explicit TDs.
  elif type(dist) == transformed_distribution.TransformedDistribution:  # pylint: disable=unidiomatic-typecheck
    nested_surrogate, variables = build_nested_surrogate(
        dist.distribution, variables=variables)
    surrogate_posterior = transformed_distribution.TransformedDistribution(
        nested_surrogate,
        bijector=dist.bijector,
        name=_get_name(dist))
  elif isinstance(dist, independent.Independent):
    nested_surrogate, variables = build_nested_surrogate(
        dist.distribution, variables=variables)
    surrogate_posterior = independent.Independent(
        nested_surrogate,
        reinterpreted_batch_ndims=dist.reinterpreted_batch_ndims,
        name=_get_name(dist))
  elif hasattr(dist, '_model_coroutine'):
    surrogate_posterior, variables = _asvi_surrogate_for_joint_distribution(
        dist,
        base_distribution_surrogate_fn=base_distribution_surrogate_fn,
        variables=variables,
        seed=seed)
  elif (hasattr(dist, 'distribution') and
        # Transformed dists not handled above are treated as base
        # distributions.
        not isinstance(dist, transformed_distribution.TransformedDistribution)):
    # The contact address was previously garbled to `[email protected]`.
    raise ValueError('Meta-distribution `{}` is not yet supported by this '
                     'implementation of ASVI. Contact '
                     '`tfprobability@tensorflow.org` if you need this '
                     'functionality.'.format(type(dist)))
  else:
    surrogate_posterior, variables = base_distribution_surrogate_fn(
        dist=dist, sample_shape=sample_shape, variables=variables, seed=seed)
  return surrogate_posterior, variables
def posterior_generator(): prior_gen = prior._model_coroutine() # pylint: disable=protected-access dist = next(prior_gen) i = 0 try: while True: original_dist = dist.distribution if isinstance(dist, Root) else dist if isinstance(original_dist, joint_distribution.JointDistribution): # TODO(kateslin): Build inner JD surrogate in # _make_asvi_trainable_variables to avoid rebuilding variables. raise TypeError( 'Argument `prior` cannot be a nested `JointDistribution`.') else: original_dist = _as_trainable_family(original_dist) try: actual_dist = original_dist.distribution except AttributeError: actual_dist = original_dist dist_params = actual_dist.parameters temp_params_dict = {} for param, value in dist_params.items(): if param in (_NON_STATISTICAL_PARAMS + _NON_TRAINABLE_PARAMS) or value is None: temp_params_dict[param] = value else: prior_weight = param_dicts[i][param].prior_weight mean_field_parameter = param_dicts[i][ param].mean_field_parameter if mean_field: temp_params_dict[param] = mean_field_parameter else: temp_params_dict[param] = prior_weight * value + ( 1. - prior_weight) * mean_field_parameter if isinstance(original_dist, sample.Sample): inner_dist = type(actual_dist)(**temp_params_dict) surrogate_dist = independent.Independent( inner_dist, reinterpreted_batch_ndims=ps.rank_from_shape( original_dist.sample_shape)) else: surrogate_dist = type(actual_dist)(**temp_params_dict) if isinstance(original_dist, transformed_distribution.TransformedDistribution): surrogate_dist = transformed_distribution.TransformedDistribution( surrogate_dist, bijector=original_dist.bijector) if isinstance(original_dist, independent.Independent): surrogate_dist = independent.Independent( surrogate_dist, reinterpreted_batch_ndims=original_dist .reinterpreted_batch_ndims) if isinstance(dist, Root): value_out = yield Root(surrogate_dist) else: value_out = yield surrogate_dist dist = prior_gen.send(value_out) i += 1 except StopIteration: pass
def extended_kalman_filter_one_step(
    state, observation, transition_fn, observation_fn,
    transition_jacobian_fn, observation_jacobian_fn, name=None):
  """A single step of the EKF.

  Args:
    state: Filter state from the previous step. This function reads its
      `filtered_mean`, `filtered_cov` and `timestep` attributes (the same
      fields carried by the `KalmanFilterState` this function returns).
    observation: A `Tensor` whose last dimension is the observation size
      (`event_size`), with optional leading batch dimensions `b1, ..., bN`.
    transition_fn: a Python `callable` that accepts (batched) vectors of
      length `state_size`, and returns a `tfd.Distribution` instance,
      typically a `MultivariateNormal`, representing the state transition and
      covariance.
    observation_fn: a Python `callable` that accepts a (batched) vector of
      length `state_size` and returns a `tfd.Distribution` instance, typically
      a `MultivariateNormal` representing the observation model and
      covariance.
    transition_jacobian_fn: a Python `callable` that accepts a (batched)
      vector of length `state_size` and returns a (batched) matrix of shape
      `[state_size, state_size]`, representing the Jacobian of
      `transition_fn`.
    observation_jacobian_fn: a Python `callable` that accepts a (batched)
      vector of length `state_size` and returns a (batched) matrix
      representing the Jacobian of `observation_fn`.
      NOTE(review): the result left-multiplies the predicted state
      covariance below, which suggests shape `[event_size, state_size]`
      rather than the previously documented `[state_size, event_size]` --
      confirm against callers.
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'extended_kalman_filter_one_step'`).

  Returns:
    updated_state: `KalmanFilterState` object containing the updated state
      estimate.
  """
  with tf.name_scope(name or 'extended_kalman_filter_one_step') as name:

    # If observations are scalar, we can avoid some matrix ops.
    observation_size_is_static_and_scalar = (observation.shape[-1] == 1)

    current_state = state.filtered_mean
    current_covariance = state.filtered_cov
    current_jacobian = transition_jacobian_fn(current_state)
    state_prior = transition_fn(current_state)

    # Predicted covariance: F P F^T plus the transition noise covariance,
    # where F is the transition Jacobian and P the current covariance.
    predicted_cov = (tf.matmul(
        current_jacobian,
        tf.matmul(current_covariance, current_jacobian, transpose_b=True)) +
                     state_prior.covariance())
    predicted_mean = state_prior.mean()

    observation_dist = observation_fn(predicted_mean)
    observation_mean = observation_dist.mean()
    observation_cov = observation_dist.covariance()

    # With H the observation Jacobian: tmp_obs_cov = H P_pred, and the
    # residual covariance is H (H P_pred)^T plus the observation noise.
    predicted_jacobian = observation_jacobian_fn(predicted_mean)
    tmp_obs_cov = tf.matmul(predicted_jacobian, predicted_cov)
    residual_covariance = tf.matmul(
        predicted_jacobian, tmp_obs_cov, transpose_b=True) + observation_cov

    # gain_transpose is the residual covariance's inverse applied to
    # tmp_obs_cov; the scalar case uses a division instead of a Cholesky
    # solve.
    if observation_size_is_static_and_scalar:
      gain_transpose = tmp_obs_cov / residual_covariance
    else:
      chol_residual_cov = tf.linalg.cholesky(residual_covariance)
      gain_transpose = tf.linalg.cholesky_solve(
          chol_residual_cov, tmp_obs_cov)

    # Mean update: predicted mean plus gain applied to the residual.
    filtered_mean = predicted_mean + tf.matmul(
        gain_transpose,
        (observation - observation_mean)[..., tf.newaxis],
        transpose_a=True)[..., 0]

    # tmp_term = I - H^T gain_transpose = (I - K H)^T with K = gain_transpose^T.
    tmp_term = -tf.matmul(predicted_jacobian, gain_transpose, transpose_a=True)
    tmp_term = tf.linalg.set_diag(tmp_term, tf.linalg.diag_part(tmp_term) + 1.)
    # Covariance update (I - K H) P_pred (I - K H)^T + K R K^T, i.e. the
    # Joseph form, with R the observation noise covariance.
    filtered_cov = (
        tf.matmul(
            tmp_term, tf.matmul(predicted_cov, tmp_term), transpose_a=True) +
        tf.matmul(gain_transpose,
                  tf.matmul(observation_cov, gain_transpose), transpose_a=True))

    if observation_size_is_static_and_scalar:
      # A plain Normal would have event shape `[]`; wrapping with Independent
      # ensures `event_shape=[1]` as required.
      predictive_dist = independent.Independent(
          normal.Normal(loc=observation_mean,
                        scale=tf.sqrt(residual_covariance[..., 0])),
          reinterpreted_batch_ndims=1)
    else:
      # `chol_residual_cov` was computed in the matching non-scalar branch
      # above.
      predictive_dist = mvn_tril.MultivariateNormalTriL(
          loc=observation_mean, scale_tril=chol_residual_cov)

    log_marginal_likelihood = predictive_dist.log_prob(observation)

    return linear_gaussian_ssm.KalmanFilterState(
        filtered_mean=filtered_mean,
        filtered_cov=filtered_cov,
        predicted_mean=predicted_mean,
        predicted_cov=predicted_cov,
        observation_mean=observation_mean,
        observation_cov=observation_cov,
        log_marginal_likelihood=log_marginal_likelihood,
        timestep=state.timestep + 1)