def _validate_observation_data(self):
    """Check that observation index points and observations broadcast.

    The batch part of `self.observation_index_points` is everything except
    its trailing `self.kernel.feature_ndims` axes; that shape must broadcast
    against `self.observations.shape`.

    When both shapes are fully known statically, the check happens right
    here. Otherwise, and only if `self._validate_args` is set, a dynamic
    broadcast op is returned to be run later. That op fails if the shapes do
    not broadcast.

    Returns:
      A (possibly empty) list of ops acting as runtime checks.

    Raises:
      ValueError: if the static shapes are fully defined and do not
        broadcast.
    """
    feature_ndims = self.kernel.feature_ndims
    points_batch_shape = self.observation_index_points.shape[:-feature_ndims]
    observations_shape = self.observations.shape

    statically_known = (points_batch_shape.is_fully_defined() and
                        observations_shape.is_fully_defined())
    if statically_known:
        try:
            tf.broadcast_static_shape(points_batch_shape, observations_shape)
        except ValueError:
            # Re-raise with a message that names both offending shapes.
            raise ValueError(
                'Observation index point and observation counts are not '
                'broadcastable: {} and {}, respectively.'.format(
                    points_batch_shape, observations_shape))
        return []

    if not self._validate_args:
        return []

    # No dedicated assertion: dynamically broadcasting the two shapes fails
    # at run time exactly when they are not broadcastable.
    dynamic_check = tf.broadcast_dynamic_shape(
        tf.shape(self.observation_index_points)[:-feature_ndims],
        tf.shape(self.observations),
        name='check_that_index_points_and_observation_shapes_broadcast')
    return [dynamic_check]
def _batch_shape(self):
    """Static batch shape from broadcasting the kernel's parameters.

    `amplitude`, `period` and `length_scale` are broadcast in that order.
    A parameter that is `None` contributes a scalar (rank-0) shape.
    """
    scalar_shape = tf.TensorShape([])
    param_shapes = [
        scalar_shape if param is None else param.shape
        for param in (self.amplitude, self.period, self.length_scale)
    ]
    result = param_shapes[0]
    for shape in param_shapes[1:]:
        result = tf.broadcast_static_shape(result, shape)
    return result
def _finish_prob_for_one_fiber(self, y, x, ildj, event_ndims):
    """Finish computation of prob on one element of the inverse image.

    Args:
      y: The point at which the transformed density is evaluated; only its
        static shape is used here.
      x: One element of the inverse image of `y`.
      ildj: Inverse log-det-Jacobian for this element. It is cast to the
        dtype of the base distribution's prob.
      event_ndims: Number of event dims; the static shape hint is only set
        when this is a Python `int`.

    Returns:
      The base distribution's prob at `x`, scaled by `exp(ildj)`.
    """
    rotated_x = self._maybe_rotate_dims(x, rotate_right=True)
    prob = self.distribution.prob(rotated_x)
    event_override = self._is_maybe_event_override
    if event_override:
        prob = tf.reduce_prod(prob, self._reduce_event_indices)
    prob *= tf.exp(tf.cast(ildj, prob.dtype))

    if not (event_override and isinstance(event_ndims, int)):
        return prob

    # NOTE(review): `[:-event_ndims]` yields an empty shape when
    # event_ndims == 0; presumably the override path implies event_ndims >= 1.
    y_batch_shape = y.shape.with_rank_at_least(1)[:-event_ndims]
    prob.set_shape(tf.broadcast_static_shape(y_batch_shape, self.batch_shape))
    return prob
def determine_batch_event_shapes(grid, endpoint_affine):
    """Helper to infer batch_shape and event_shape.

    Args:
      grid: `Tensor` whose trailing two axes are not batch axes (annotated in
        the original code as shape `[B, k, q]`).
      endpoint_affine: Iterable of affine objects (annotated as `len=k`), each
        exposing optional `shift` and `scale` attributes.

    Returns:
      A 4-tuple `(batch_shape, batch_shape_tensor, event_shape,
      event_shape_tensor)`. The event entries remain `None` when no affine
      has a `shift` or a `scale`.
    """

    def _merge_event(current_static, current_dynamic, new_static, new_dynamic):
        # The first contribution is adopted as-is; later ones are broadcast in.
        if current_static is None:
            return new_static, new_dynamic
        return (tf.broadcast_static_shape(current_static, new_static),
                tf.broadcast_dynamic_shape(current_dynamic, new_dynamic))

    with tf.name_scope(name="determine_batch_event_shapes"):
        # grid: [B, k, q]; each endpoint_affine entry: [B, d, d].
        batch_shape = grid.shape[:-2]
        batch_shape_tensor = tf.shape(grid)[:-2]
        event_shape, event_shape_tensor = None, None

        for affine in endpoint_affine:
            shift = affine.shift
            if shift is not None:
                batch_shape = tf.broadcast_static_shape(
                    batch_shape, shift.shape[:-1])
                batch_shape_tensor = tf.broadcast_dynamic_shape(
                    batch_shape_tensor, tf.shape(shift)[:-1])
                event_shape, event_shape_tensor = _merge_event(
                    event_shape, event_shape_tensor,
                    shift.shape[-1:], tf.shape(shift)[-1:])

            scale = affine.scale
            if scale is not None:
                batch_shape = tf.broadcast_static_shape(
                    batch_shape, scale.batch_shape)
                batch_shape_tensor = tf.broadcast_dynamic_shape(
                    batch_shape_tensor, scale.batch_shape_tensor())
                event_shape, event_shape_tensor = _merge_event(
                    event_shape, event_shape_tensor,
                    tf.TensorShape([scale.range_dimension]),
                    scale.range_dimension_tensor()[tf.newaxis])

        return batch_shape, batch_shape_tensor, event_shape, event_shape_tensor
def batch_shape(self):
    """Static batch shape of models represented by this component.

    Computed by broadcasting together the batch shapes of the priors of
    every entry in `self.parameters`, starting from a scalar shape.

    Returns:
      batch_shape: A `tf.TensorShape` giving the broadcast batch shape of
        all model parameters. This should match the batch shape of derived
        state space models, i.e.,
        `self.make_state_space_model(...).batch_shape`. It may be partially
        defined or unknown.
    """
    prior_batch_shapes = [param.prior.batch_shape for param in self.parameters]
    result = tf.TensorShape([])
    for prior_batch_shape in prior_batch_shapes:
        result = tf.broadcast_static_shape(result, prior_batch_shape)
    return result
def broadcast_batch_shape(distributions):
    """Get broadcast batch shape from distributions, statically if possible.

    Args:
      distributions: Non-empty sequence of distribution-like objects that
        expose `batch_shape` and `batch_shape_tensor()`.

    Returns:
      A Python list of ints when the broadcast static batch shape is fully
      defined; otherwise a `Tensor` holding the dynamically broadcast shape.
    """
    first, rest = distributions[0], distributions[1:]

    static_shape = first.batch_shape
    for dist in rest:
        static_shape = tf.broadcast_static_shape(static_shape, dist.batch_shape)
    if static_shape.is_fully_defined():
        return static_shape.as_list()

    # Static information is incomplete: fall back on the dynamic shapes.
    dynamic_shape = first.batch_shape_tensor()
    for dist in rest:
        dynamic_shape = tf.broadcast_dynamic_shape(
            dynamic_shape, dist.batch_shape_tensor())
    return tf.convert_to_tensor(dynamic_shape)
def get_broadcast_shape(*tensors):
    """Get broadcast shape as a Python list of integers (preferred) or `Tensor`.

    Args:
      *tensors: One or more `Tensor` objects (already converted!).

    Returns:
      broadcast shape: Python list (if shapes determined statically),
        otherwise the result of dynamically broadcasting `tf.shape` of every
        input.
    """
    head = tensors[0]
    tail = tensors[1:]

    static_result = head.shape
    for tensor in tail:
        static_result = tf.broadcast_static_shape(static_result, tensor.shape)
    if static_result.is_fully_defined():
        return static_result.as_list()

    dynamic_result = tf.shape(head)
    for tensor in tail:
        dynamic_result = tf.broadcast_dynamic_shape(dynamic_result, tf.shape(tensor))
    return dynamic_result
def _reduce_jacobian_det_over_event(self, y, ildj, min_event_ndims, event_ndims): """Reduce jacobian over event_ndims - min_event_ndims.""" # In this case, we need to tile the Jacobian over the event and reduce. y_rank = tf.rank(y) y_shape = tf.shape(y)[y_rank - event_ndims:y_rank - min_event_ndims] ones = tf.ones(y_shape, ildj.dtype) reduced_ildj = tf.reduce_sum( ones * ildj, axis=self._get_event_reduce_dims(min_event_ndims, event_ndims)) # The multiplication by ones can change the inferred static shape so we try # to recover as much as possible. event_ndims_ = self._maybe_get_static_event_ndims(event_ndims) if (event_ndims_ is not None and y.shape.ndims is not None and ildj.shape.ndims is not None): y_shape = y.shape[y.shape.ndims - event_ndims_:y.shape.ndims - min_event_ndims] broadcast_shape = tf.broadcast_static_shape(ildj.shape, y_shape) reduced_ildj.set_shape(broadcast_shape[:broadcast_shape.ndims - (event_ndims_ - min_event_ndims)]) return reduced_ildj
def _log_prob(self, x):
    """Compute the log-density at the matrix `x`.

    The value returned is
    `(df - dimension - 1) * 0.5 * log|X| - 0.5 * tr(inv(V) X) - log_normalization()`
    with `V = S S'`, `S` being `self.scale_operator` (this matches the
    Wishart log-density form).

    Args:
      x: Batch of matrices. If `self.input_output_cholesky` is True, `x` is
        taken to already be the lower Cholesky factor `L` of `X = L L'`;
        otherwise its Cholesky factor is computed here.

    Returns:
      log_prob: `Tensor` reduced over the trailing two (matrix) axes.
    """
    if self.input_output_cholesky:
        x_sqrt = x
    else:
        # Complexity: O(nbk**3)
        x_sqrt = tf.linalg.cholesky(x)

    batch_shape = self.batch_shape_tensor()
    event_shape = self.event_shape_tensor()
    x_ndims = tf.rank(input=x_sqrt)
    # Prepend singleton axes so x has at least batch_ndims + 2 dims; this lets
    # every axis left of the batch dims be treated as a sample dim below.
    num_singleton_axes_to_prepend = (
        tf.maximum(tf.size(input=batch_shape) + 2, x_ndims) - x_ndims)
    x_with_prepended_singletons_shape = tf.concat([
        tf.ones([num_singleton_axes_to_prepend], dtype=tf.int32),
        tf.shape(input=x_sqrt)
    ], 0)
    x_sqrt = tf.reshape(x_sqrt, x_with_prepended_singletons_shape)
    ndims = tf.rank(x_sqrt)
    # sample_ndims = ndims - batch_ndims - event_ndims
    sample_ndims = ndims - tf.size(input=batch_shape) - 2
    sample_shape = tf.shape(input=x_sqrt)[:sample_ndims]

    # We need to be able to pre-multiply each matrix by its corresponding
    # batch scale matrix. Since a Distribution Tensor supports multiple
    # samples per batch, this means we need to reshape the input matrix `x`
    # so that the first b dimensions are batch dimensions and the last two
    # are of shape [dimension, dimension*number_of_samples]. Doing these
    # gymnastics allows us to do a batch_solve.
    #
    # After we're done with sqrt_solve (the batch operation) we need to undo
    # this reshaping so what we're left with is a Tensor partitionable by
    # sample, batch, event dimensions.

    # Complexity: O(nbk**2) since transpose must access every element.
    scale_sqrt_inv_x_sqrt = x_sqrt
    # Move the sample axes to the end: [batch..., k, k, samples...].
    perm = tf.concat(
        [tf.range(sample_ndims, ndims), tf.range(0, sample_ndims)], 0)
    scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt, perm=perm)
    # Fold the last matrix axis together with all sample axes.
    last_dimsize = (
        tf.cast(self.dimension, dtype=tf.int32) * tf.reduce_prod(
            input_tensor=x_with_prepended_singletons_shape[:sample_ndims]))
    shape = tf.concat([
        x_with_prepended_singletons_shape[sample_ndims:-2],
        [tf.cast(self.dimension, dtype=tf.int32), last_dimsize]
    ], 0)
    scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)

    # Complexity: O(nbM*k) where M is the complexity of the operator solving a
    # vector system. For LinearOperatorLowerTriangular, each solve is O(k**2) so
    # this step has complexity O(nbk^3).
    scale_sqrt_inv_x_sqrt = self.scale_operator.solve(
        scale_sqrt_inv_x_sqrt)

    # Undo the batch-op-ready reshaping.
    # Complexity: O(nbk**2)
    shape = tf.concat([batch_shape, event_shape, sample_shape], 0)
    scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)
    # Move the sample axes back to the front: [samples..., batch..., k, k].
    perm = tf.concat([
        tf.range(ndims - sample_ndims, ndims),
        tf.range(0, ndims - sample_ndims)
    ], 0)
    scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt, perm=perm)

    # Write V = SS', X = LL'. Then:
    # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
    #              = tr[inv(S) L L' inv(S)']
    #              = tr[(inv(S) L) (inv(S) L)']
    #              = sum_{ik} (inv(S) L)_{ik}**2
    # The second equality follows from the cyclic permutation property.
    # Complexity: O(nbk**2)
    trace_scale_inv_x = tf.reduce_sum(
        input_tensor=tf.square(scale_sqrt_inv_x_sqrt), axis=[-2, -1])

    # sum(log(diag(L))) == 0.5 * log|X| since X = L L'.
    # Complexity: O(nbk)
    half_log_det_x = tf.reduce_sum(
        input_tensor=tf.math.log(tf.linalg.diag_part(x_sqrt)), axis=[-1])

    # Complexity: O(nbk**2)
    log_prob = ((self.df - self.dimension - 1.) * half_log_det_x -
                0.5 * trace_scale_inv_x - self.log_normalization())

    # Set shape hints.
    # Try to merge what we know from the input x with what we know from the
    # parameters of this distribution.
    if x.shape.ndims is not None and self.batch_shape.ndims is not None:
        log_prob.set_shape(
            tf.broadcast_static_shape(x.shape[:-2], self.batch_shape))

    return log_prob
def __init__(self,
             initial_distribution,
             transition_distribution,
             observation_distribution,
             num_steps,
             validate_args=False,
             allow_nan_stats=True,
             name="HiddenMarkovModel"):
    """Initialize hidden Markov model.

    Args:
      initial_distribution: A `Categorical`-like instance.
        Determines probability of first hidden state in Markov chain.
        The number of categories must match the number of categories of
        `transition_distribution` as well as both the rightmost batch
        dimension of `transition_distribution` and the rightmost batch
        dimension of `observation_distribution`.
      transition_distribution: A `Categorical`-like instance.
        The rightmost batch dimension indexes the probability distribution
        of each hidden state conditioned on the previous hidden state.
      observation_distribution: A `tfp.distributions.Distribution`-like
        instance. The rightmost batch dimension indexes the distribution
        of each observation conditioned on the corresponding hidden state.
      num_steps: The number of steps taken in Markov chain. A python `int`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
        Default value: `True`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "HiddenMarkovModel".

    Raises:
      ValueError: if `num_steps` is not at least 1.
      ValueError: if `initial_distribution` does not have scalar `event_shape`.
      ValueError: if `transition_distribution` does not have scalar
        `event_shape`.
      ValueError: if `transition_distribution` or `observation_distribution`
        has a statically known scalar (rank-0) batch shape.
      ValueError: if `transition_distribution` and `observation_distribution`
        are fully defined but don't have matching rightmost dimension.
    """
    # Must be the first statement so only the constructor arguments are captured.
    parameters = dict(locals())

    # pylint: disable=protected-access
    with tf.compat.v2.name_scope(name) as name:
        self._runtime_assertions = []  # pylint: enable=protected-access

        if num_steps < 1:
            raise ValueError(
                "num_steps ({}) must be at least 1.".format(num_steps))

        self._initial_distribution = initial_distribution
        self._observation_distribution = observation_distribution
        self._transition_distribution = transition_distribution

        # Each check below raises eagerly when the static shape decides it and
        # otherwise (with validate_args) records a runtime assertion.
        if (initial_distribution.event_shape is not None and
                initial_distribution.event_shape.ndims != 0):
            raise ValueError(
                "`initial_distribution` must have scalar `event_dim`s")
        elif validate_args:
            self._runtime_assertions += [
                assert_util.assert_equal(
                    tf.shape(input=initial_distribution.event_shape_tensor())[0],
                    0,
                    message="`initial_distribution` must have scalar "
                    "`event_dim`s")
            ]

        if (transition_distribution.event_shape is not None and
                transition_distribution.event_shape.ndims != 0):
            raise ValueError(
                "`transition_distribution` must have scalar `event_dim`s")
        elif validate_args:
            self._runtime_assertions += [
                assert_util.assert_equal(
                    tf.shape(
                        input=transition_distribution.event_shape_tensor())[0],
                    0,
                    message="`transition_distribution` must have scalar "
                    "`event_dim`s")
            ]

        if (transition_distribution.batch_shape is not None and
                transition_distribution.batch_shape.ndims == 0):
            raise ValueError(
                "`transition_distribution` can't have scalar batches")
        elif validate_args:
            self._runtime_assertions += [
                assert_util.assert_greater(
                    tf.size(input=transition_distribution.batch_shape_tensor()),
                    0,
                    message="`transition_distribution` can't have scalar "
                    "batches")
            ]

        if (observation_distribution.batch_shape is not None and
                observation_distribution.batch_shape.ndims == 0):
            raise ValueError(
                "`observation_distribution` can't have scalar batches")
        elif validate_args:
            self._runtime_assertions += [
                assert_util.assert_greater(
                    tf.size(input=observation_distribution.batch_shape_tensor()),
                    0,
                    message="`observation_distribution` can't have scalar "
                    "batches")
            ]

        # Infer number of hidden states and check consistency
        # between transitions and observations.
        # Prefer the static rightmost batch dim; fall back to the dynamic one.
        with tf.control_dependencies(self._runtime_assertions):
            self._num_states = (
                (transition_distribution.batch_shape and
                 transition_distribution.batch_shape[-1]) or
                transition_distribution.batch_shape_tensor()[-1])

            observation_states = (
                (observation_distribution.batch_shape and
                 observation_distribution.batch_shape[-1]) or
                observation_distribution.batch_shape_tensor()[-1])

        if (tf.is_tensor(self._num_states) or
                tf.is_tensor(observation_states)):
            if validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        self._num_states,
                        observation_states,
                        message="`transition_distribution` and "
                        "`observation_distribution` must agree on "
                        "last dimension of batch size")
                ]
        elif self._num_states != observation_states:
            raise ValueError("`transition_distribution` and "
                             "`observation_distribution` must agree on "
                             "last dimension of batch size")

        self._log_init = _extract_log_probs(self._num_states,
                                            initial_distribution)
        self._log_trans = _extract_log_probs(self._num_states,
                                             transition_distribution)

        self._num_steps = num_steps
        # Re-derive the state count as a Tensor from the extracted log probs.
        self._num_states = tf.shape(input=self._log_init)[-1]

        self._underlying_event_rank = tf.size(
            input=self._observation_distribution.event_shape_tensor())

        self.static_event_shape = tf.TensorShape([num_steps]).concatenate(
            self._observation_distribution.event_shape)

        with tf.control_dependencies(self._runtime_assertions):
            # The rightmost batch dim of transitions/observations indexes the
            # hidden state, so it is excluded from the model's batch shape.
            self.static_batch_shape = tf.broadcast_static_shape(
                self._initial_distribution.batch_shape,
                tf.broadcast_static_shape(
                    self._transition_distribution.batch_shape[:-1],
                    self._observation_distribution.batch_shape[:-1]))

        # pylint: disable=protected-access
        super(HiddenMarkovModel, self).__init__(
            dtype=self._observation_distribution.dtype,
            reparameterization_type=tf.compat.v1.distributions.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=(self._initial_distribution._graph_parents +
                           self._transition_distribution._graph_parents +
                           self._observation_distribution._graph_parents),
            name=name)
        # pylint: enable=protected-access

        self._parameters = parameters
def _set_event_shape(shape, shape_tensor):
    """Merge a new event-shape contribution into the accumulated one.

    Reads `event_shape` and `event_shape_tensor` from an enclosing scope that
    is not shown here. If nothing has been accumulated yet (`event_shape is
    None`), the new pair is returned unchanged; otherwise both the static and
    dynamic shapes are broadcast against the new ones.

    Returns:
      A `(static_shape, dynamic_shape)` tuple.
    """
    if event_shape is not None:
        merged_static = tf.broadcast_static_shape(event_shape, shape)
        merged_dynamic = tf.broadcast_dynamic_shape(event_shape_tensor,
                                                    shape_tensor)
        return merged_static, merged_dynamic
    return shape, shape_tensor
def _batch_shape(self):
    """Static batch shape: the shapes of `low` and `high` broadcast together."""
    low_shape = self.low.shape
    high_shape = self.high.shape
    return tf.broadcast_static_shape(low_shape, high_shape)
def _batch_shape(self):
    """Static batch shape from broadcasting `total_count` with `probs`."""
    return tf.broadcast_static_shape(
        shape_x=self.total_count.shape,
        shape_y=self.probs.shape)
def _batch_shape(self):
    """Broadcast the static shapes of the `low` and `high` parameters."""
    return tf.broadcast_static_shape(
        shape_x=self.low.shape,
        shape_y=self.high.shape)
def _batch_shape(self):
    """Static batch shape from broadcasting `df`, `loc` and `scale`.

    `df` and `loc` are broadcast first, then the result with `scale`.
    """
    df_and_loc = tf.broadcast_static_shape(self.df.get_shape(),
                                           self.loc.get_shape())
    return tf.broadcast_static_shape(df_and_loc, self.scale.get_shape())
def _set_event_shape(shape, shape_tensor):
    """Fold `(shape, shape_tensor)` into the running event shape.

    `event_shape` / `event_shape_tensor` are free variables resolved from an
    enclosing scope outside this view.

    Returns:
      `(shape, shape_tensor)` unchanged when `event_shape` is `None`, else
      the static and dynamic broadcasts of the running shape with the new one.
    """
    if event_shape is None:
        # Nothing accumulated yet: adopt the incoming shapes as-is.
        return shape, shape_tensor
    static = tf.broadcast_static_shape(event_shape, shape)
    dynamic = tf.broadcast_dynamic_shape(event_shape_tensor, shape_tensor)
    return static, dynamic
def _batch_shape(self):
    """Static batch shape from `mean_direction` and `concentration`.

    The rightmost axis of `mean_direction` is dropped (presumably the
    direction/event axis) before broadcasting with `concentration`.
    """
    direction_batch_shape = self.mean_direction.shape.with_rank_at_least(1)[:-1]
    concentration_shape = self.concentration.shape
    return tf.broadcast_static_shape(direction_batch_shape, concentration_shape)
def _batch_shape(self):
    """Static batch shape from broadcasting `mass` with `width`."""
    return tf.broadcast_static_shape(shape_x=self.mass.shape,
                                     shape_y=self.width.shape)
def _batch_shape(self):
    """Static batch shape: `low` and `high` shapes broadcast together."""
    low_shape = self.low.get_shape()
    high_shape = self.high.get_shape()
    return tf.broadcast_static_shape(low_shape, high_shape)
def _batch_shape(self):
    """Static batch shape from broadcasting the `_shape` and `_scale` params."""
    return tf.broadcast_static_shape(shape_x=self._shape.get_shape(),
                                     shape_y=self._scale.get_shape())
def _batch_shape(self):
    """Static batch shape: `df` broadcast with the scale operator's batch."""
    df_shape = self.df.shape
    operator_batch_shape = self.scale_operator.batch_shape
    return tf.broadcast_static_shape(df_shape, operator_batch_shape)
def _batch_shape(self):
    """Static batch shape of this kernel.

    Broadcasts the base kernel's batch shape with the batch dimensions of
    `self._fixed_inputs`. This assumes `fixed_inputs` has shape
    `[b1, ..., bB, N, f1, ..., fF]`, where `F == base_kernel.feature_ndims`.
    The `N` fixed points are all conditioned on jointly, so `N` is not a
    batch axis and is dropped together with the feature axes.

    Returns:
      A `tf.TensorShape`.
    """
    # `batch_shape` is a property (the broadcast below consumes it as a
    # value); the previous `batch_shape()` call raised TypeError.
    if self._is_empty_fixed_inputs():
        return self._base_kernel.batch_shape
    # Exclude the feature axes *and* the example axis `N`. This also makes
    # feature_ndims == 0 work (`[:-0]` would have produced an empty shape).
    num_non_batch_dims = self._base_kernel.feature_ndims + 1
    fixed_inputs_batch_shape = self._fixed_inputs.shape[:-num_non_batch_dims]
    return tf.broadcast_static_shape(self._base_kernel.batch_shape,
                                     fixed_inputs_batch_shape)
def _batch_shape(self):
    """Broadcast `mean_direction` (minus its last axis) with `concentration`.

    The last axis of `mean_direction` is excluded from the batch shape
    (presumably it is the direction/event axis).
    """
    return tf.broadcast_static_shape(
        shape_x=self.mean_direction.shape.with_rank_at_least(1)[:-1],
        shape_y=self.concentration.shape)
def sample_posterior_predictive(
    model: ModelType,
    trace: InferenceData,
    var_names: Optional[Union[str, List[str]]] = None,
    observed: Optional[Dict[str, Any]] = None,
    use_auto_batching: bool = True,
    inplace: bool = True,
) -> InferenceData:
    """
    Draw posterior predictive samples from the model for the desired ``var_names``.

    Parameters
    ----------
    model : types.GeneratorType, pymc4.Model
        Model to draw samples from
    trace: ArviZ's InferenceData object
        The samples drawn from the model's posterior distribution that should be used for sampling
        from the posterior predictive
    var_names: Optional[Union[str, List[str]]]
        The list of variable names that will be included in the returned samples. Strings can be
        used to specify a single variable. If ``None``, the samples drawn for all observed
        distributions will be returned in the ``Samples`` dictionary.
    observed : Optional[Dict[str, Any]]
        A dictionary that can be used to override the distribution observed values defined in the
        model.
    use_auto_batching: bool
        A bool value that indicates whether ``sample_posterior_predictive`` should automatically
        batch the draws or not. If you are sure you have manually tuned your model to be fully
        vectorized, then you can set this to ``False``, and your sampling should be faster than the
        auto batched counterpart. If you are not sure if your model is vectorized, then auto
        batching will safely sample from it but with some additional overhead.
    inplace: bool
        If True (default) it will add a posterior_predictive group to the provided ``trace``, instead
        of returning a new InferenceData object. If a posterior_predictive group is already present
        in ``trace`` it will be overwritten.

    Returns
    -------
    Samples: InferenceDataType
        An ArviZ's InferenceData object with a posterior_predictive group

    Raises
    ------
    ValueError
        If ``var_names`` is an empty list.
    KeyError
        (auto-batching path) If some ``var_names`` are not defined in the model.
    TypeError
        (auto-batching path) If ``trace`` contains a variable that the model does not define.

    Examples
    --------
    Let's define a simple model to sample from

    >>> import pymc4 as pm
    >>> @pm.model
    ... def model():
    ...     sd = yield pm.HalfNormal("sd", 5.)
    ...     norm = yield pm.Normal("n", 0, sd, observed=np.random.randn(100))

    Now, we may want to draw samples from the model's posterior to then sample from the posterior
    predictive.

    >>> trace = pm.inference.sampling.sample(model())
    >>> ppc = pm.sample_posterior_predictive(model(), trace).posterior_predictive

    The samples are returned as a dictionary with the variable names as keys

    >>> sorted(list(ppc))
    ['model/n']

    The drawn values are the dictionary's values, and their shape will depend
    on the supplied ``trace``

    >>> ppc["model/n"].shape
    (10, 1000, 100)

    """
    if var_names is not None and len(var_names) == 0:
        raise ValueError("Supplied an empty var_names list to sample from")
    if isinstance(var_names, str):
        var_names = [var_names]

    # If we don't have to deal with auto-batching we can simply evaluate_model
    # passing the trace as values
    if not use_auto_batching:
        values = {
            var_name: tf.convert_to_tensor(value) for var_name, value in trace.posterior.items()
        }
        # We need to pass the number of chains and draws as sample_shape for
        # observed conditionally independent variables
        sample_shape = (trace.posterior.sizes["chain"], trace.posterior.sizes["draw"])
        _, state = evaluate_model_posterior_predictive(
            model, values=values, observed=observed, sample_shape=sample_shape
        )
        # Look values up first among random variables, then deterministics.
        all_values = collections.ChainMap(state.all_values, state.deterministics_values)
        if var_names is None:
            var_names = list(state.posterior_predictives)
        output = {k: all_values[k] for k in var_names}
        return trace_to_arviz(trace=trace, posterior_predictive=output, inplace=inplace)

    # We cannot assume that the model is vectorized, so we have to batch the
    # pm.evaluate_model_posterior_predictive calls across the trace entries.
    # This brings one big problem: we need to infer the batch dimensions from
    # the trace. To do this, we will do
    # 1) A single forward pass with the meta executor to determine the
    #    variable's shapes (we'll call these the core shapes)
    # 2) Go through the supplied trace to get each variable's batch shapes
    #    (the shapes to the left of the core shapes)
    # 3) Broadcast the encountered batch shapes between each other as a sanity
    #    check to get the global trace's batch_shape
    # 4) Broadcast the values in the trace to the global batch_shape to get
    #    each variable's broadcasted value.
    # 5) As tf.vectorized_map only iterates across the first dimension, we want
    #    to flatten the batch dimensions. To do this, we reshape the broadcasted
    #    values to (-1,) + core_shape. This way, tf.vectorized_map will be able
    #    to vectorize across the entire batch
    # 6) Collect the samples, reshape them to batch_shape + core_shape and
    #    return them

    # Do a single forward pass to infer the distributions core shapes and
    # default observeds
    _, state = evaluate_meta_posterior_predictive_model(model, observed=observed)
    if var_names is None:
        var_names = list(state.posterior_predictives)
    else:
        defined_variables = set(state.all_values) | set(state.deterministics_values)
        if not set(var_names) <= defined_variables:
            raise KeyError(
                "The supplied var_names = {} are not defined in the model.\n"
                "Defined variables are = {}".format(
                    list(set(var_names) - defined_variables), list(defined_variables)
                )
            )

    # Get the global batch_shape
    batch_shape = tf.TensorShape([])
    # Get a copy of trace because we may manipulate the dictionary later in this
    # function
    posterior = trace.posterior.copy()  # type: ignore
    posterior_names = list(posterior)
    for var_name in posterior_names:
        values = tf.convert_to_tensor(posterior[var_name].values)
        try:
            core_shape = state.all_values[var_name].shape
        except KeyError:
            if var_name in state.deterministics_values:
                # Remove the deterministics from the trace
                del posterior[var_name]
                continue
            else:
                raise TypeError(
                    "Supplied the variable {} in the trace, yet this variable is "
                    "not defined in the model: {!r}".format(var_name, state)
                )
        assert_values_compatible_with_distribution_shape(
            var_name, values, batch_shape=tf.TensorShape([]), event_shape=core_shape
        )
        # Everything left of the core shape counts as batch; accumulate it.
        batch_shape = tf.TensorShape(
            tf.broadcast_static_shape(
                values.shape[: len(values.shape) - len(core_shape)],  # type: ignore
                batch_shape,
            )
        )

    # Flatten the batch axis
    flattened_posterior = []
    for k, v in posterior.items():
        core_shape = tf.TensorShape(state.all_values[k].shape)
        batched_val = tf.broadcast_to(v.values, batch_shape + core_shape)
        flattened_posterior.append(tf.reshape(batched_val, shape=[-1] + core_shape.as_list()))
    posterior_vars = list(posterior)

    # Setup the function that makes a single draw
    @tf.function(autograph=False)
    def single_draw(elems):
        values = dict(zip(posterior_vars, elems))
        _, st = evaluate_model_posterior_predictive(model, values=values, observed=observed)
        # Lookup priority: untransformed, then deterministic, then transformed values.
        return tuple(
            [
                (
                    st.untransformed_values[k]
                    if k in st.untransformed_values
                    else (
                        st.deterministics_values[k]
                        if k in st.deterministics_values
                        else st.transformed_values[k]
                    )
                )
                for k in var_names
            ]
        )

    # Make draws in parallel across the batch elements with tf.vectorized_map
    samples = tf.vectorized_map(single_draw, flattened_posterior)
    # Convert the samples to ndarrays and make a dictionary with the correct
    # batch_shape + core_shape
    output = dict()
    for name, sample in zip(var_names, samples):
        sample = sample.numpy()
        output[name] = np.reshape(sample, batch_shape + sample.shape[1:])
    return trace_to_arviz(trace=trace, posterior_predictive=output, inplace=inplace)
def _batch_shape(self):
    """Static batch shape of the mixture.

    The component distribution's batch shape is broadcast with the mixture
    logits' shape, and the rightmost axis (presumably the mixture-component
    axis) is dropped.
    """
    components_and_logits = tf.broadcast_static_shape(
        self.distribution.batch_shape,
        self.mixture_distribution.logits.shape)
    return components_and_logits[:-1]
def _batch_shape(self):
    """Broadcast `df`'s static shape with the scale operator's batch shape."""
    return tf.broadcast_static_shape(
        shape_x=self.df.shape,
        shape_y=self.scale_operator.batch_shape)
def __init__(self, mean, stddev=None, logstd=None, group_event_ndims=None,
             check_numerics=False, name=None, default_name=None):
    """Construct a Normal distribution.

    Exactly one of `stddev` and `logstd` must be supplied. The other scale
    representations (stddev, logstd, variance, precision, log-variance,
    log-precision) are derived from whichever one is given.

    Args:
      mean: Mean of the distribution; its preferred dtype must be floating.
      stddev: Standard deviation (mutually exclusive with `logstd`).
      logstd: Log standard deviation (mutually exclusive with `stddev`).
      group_event_ndims, check_numerics, name, default_name: Forwarded to the
        base class constructor.

    Raises:
      ValueError: if not exactly one of `stddev` / `logstd` is given, or if
        `mean` and the scale parameter are not statically broadcastable.
      TypeError: if the parameter dtype is not floating-point.
    """
    # Both missing or both present is an error.
    if (stddev is None) == (logstd is None):
        raise ValueError('One and only one of `stddev`, `logstd` should '
                         'be specified.')
    dtype = get_preferred_tensor_dtype(mean)
    if not dtype.is_floating:
        raise TypeError('Normal distribution parameters must be float '
                        'numbers.')

    super(Normal, self).__init__(group_event_ndims=group_event_ndims,
                                 check_numerics=check_numerics,
                                 name=name,
                                 default_name=default_name)

    with reopen_variable_scope(self.variable_scope):
        with tf.name_scope('init'):
            # Convert the parameters; `_stdx` holds whichever scale was given.
            self._mean = tf.convert_to_tensor(mean, dtype=dtype)
            self._stdx_is_log = stddev is None
            given_scale = logstd if self._stdx_is_log else stddev
            self._stdx = tf.convert_to_tensor(given_scale, dtype=dtype)

            mean_shape = self._mean.get_shape()
            stdx_shape = self._stdx.get_shape()
            try:
                self._static_batch_shape = tf.broadcast_static_shape(
                    mean_shape, stdx_shape)
            except ValueError:
                raise ValueError(
                    '`mean` and `stddev`/`logstd` should be '
                    'broadcastable to match each other (%r vs %r).'
                    % (mean_shape, stdx_shape)
                )
            self._dynamic_batch_shape = tf.broadcast_dynamic_shape(
                tf.shape(self._mean), tf.shape(self._stdx))

            # Derive the remaining scale representations.
            if not self._stdx_is_log:
                self._stddev = self._stdx
                self._logstd = self._check_numerics(
                    tf.log(self._stdx, name='logstd'), 'logstd')
                self._var = tf.square(self._stddev, name='variance')
                self._precision = self._check_numerics(
                    tf.divide(tf.constant(1., dtype=dtype), self._var,
                              name='precision'),
                    'precision')
            else:
                self._stddev = self._check_numerics(
                    tf.exp(self._stdx, name='stddev'), 'stddev')
                self._logstd = self._stdx
                self._var = self._check_numerics(
                    tf.exp(tf.constant(2., dtype=dtype) * self._logstd,
                           name='variance'),
                    'variance')
                self._precision = self._check_numerics(
                    tf.exp(tf.constant(-2., dtype=dtype) * self._logstd,
                           name='precision'),
                    'precision')
            self._logvar = tf.multiply(tf.constant(2., dtype=dtype),
                                       self._logstd, name='logvar')
            self._log_prec = tf.negative(self._logvar, name='log_precision')
def _batch_shape(self):
    """Static batch shape: `total_count` and `probs` shapes broadcast together."""
    count_shape = self.total_count.get_shape()
    probs_shape = self.probs.get_shape()
    return tf.broadcast_static_shape(count_shape, probs_shape)
def _broadcast_static_shape(shape_x, shape_y):
    """Broadcast two static shapes and return the result as an int32 array.

    Both arguments are first wrapped in `tf.TensorShape`. The result must be
    convertible by `np.array(..., dtype=np.int32)`, so in practice it needs
    to be fully defined.
    """
    broadcast = tf.broadcast_static_shape(tf.TensorShape(shape_x),
                                          tf.TensorShape(shape_y))
    return np.array(broadcast, dtype=np.int32)
def kl_divergence(self, a, b, name=None):
    """Batched KL divergence `KL(a || b)` for multivariate Normals.

    With `X`, `Y` both multivariate Normals in `R^k` with means `mu_a`, `mu_b` and
    covariance `C_a`, `C_b` respectively,

    ```
    KL(a || b) = 0.5 * ( L - k + T + Q ),
    L := Log[Det(C_b)] - Log[Det(C_a)]
    T := trace(C_b^{-1} C_a),
    Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
    ```

    This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient
    methods for solving systems with `C_b` may be available, a dense version of
    (the square root of) `C_a` is used, so performance is `O(B s k**2)` where `B`
    is the batch size, and `s` is the cost of solving `C_b x = y` for vectors `x`
    and `y`.

    The log-determinant terms read the private `_diag` attribute of each scale
    operator, so both scales must be diagonal operators (e.g.
    `LinearOperatorDiag`).

    Args:
      a: Instance of `MultivariateNormalLinearOperator` with a diagonal scale.
      b: Instance of `MultivariateNormalLinearOperator` with a diagonal scale.
      name: (optional) name to use for created ops. Default "kl_mvn".

    Returns:
      Batchwise `KL(a || b)`.
    """

    def squared_frobenius_norm(x):
        """Helper to make KL calculation slightly more readable."""
        return tf.reduce_sum(tf.square(x), axis=[-2, -1])

    def log_abs_determinant(scale):
        """log|det(scale)| for a diagonal operator, i.e. sum(log|diag|)."""
        # The small epsilon guards against log(0). A Python float (rather than
        # a float32 `tf.constant`) adopts the diagonal's dtype, so float64
        # inputs no longer fail with a dtype mismatch.
        return tf.reduce_sum(
            tf.log(tf.abs(scale._diag) + 1e-8), axis=[-1])

    with tf.name_scope(name, "kl_mvn", values=[a.loc, b.loc] +
                       a.scale.graph_parents + b.scale.graph_parents):
        # Calculation is based on:
        # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
        # and,
        # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
        # i.e.,
        #   If Ca = AA', Cb = BB', then
        #   tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
        #                  = tr[inv(B) A A' inv(B)']
        #                  = tr[(inv(B) A) (inv(B) A)']
        #                  = sum_{ij} (inv(B) A)_{ij}**2
        #                  = ||inv(B) A||_F**2
        # where ||.||_F is the Frobenius norm and the second equality follows from
        # the cyclic permutation property.
        b_inv_a = b.scale.solve(a.scale.to_dense())
        # 0.5 * L == log|det B| - log|det A| because C = S S'.
        kl_div = (
            log_abs_determinant(b.scale) - log_abs_determinant(a.scale)
            + 0.5 * (
                -tf.cast(a.scale.domain_dimension_tensor(), a.dtype)
                + squared_frobenius_norm(b_inv_a)
                + squared_frobenius_norm(
                    b.scale.solve((b.mean() - a.mean())[..., tf.newaxis]))))
        kl_div.set_shape(
            tf.broadcast_static_shape(a.batch_shape, b.batch_shape))
        return kl_div
def _batch_shape(self):
    """Static batch shape: `concentration` and `rate` shapes broadcast together."""
    concentration_shape = self.concentration.get_shape()
    rate_shape = self.rate.get_shape()
    return tf.broadcast_static_shape(concentration_shape, rate_shape)
def modify_distribution(
    self, dist: ModelType, model_info: Mapping[str, Any], state: SamplingState
) -> ModelType:
    """Remove the observed distribution values but keep their shapes.

    Modify observed Distribution instances in the following way:
    1) The distribution's shape (batch_shape + event_shape) will be checked
    for consistency with the supplied observed value's shape.
    2) If they are inconsistent, an EvaluationError will be raised.
    3) If they are consistent the distribution's observed values' shape
    will be broadcasted with the distribution's shape to construct a new
    Distribution instance with no observations.
    4) This distribution will be yielded instead of the original incoming
    dist, and it will be used for posterior predictive sampling

    Parameters
    ----------
    dist: Union[types.GeneratorType, pymc4.coroutine_model.Model]
        The distribution or model being evaluated, which may be replaced.
    model_info: Mapping[str, Any]
        Either ``dist.model_info`` or
        ``pymc4.coroutine_model.Model.default_model_info`` if ``dist`` is not a
        ``pymc4.coroutine_model.Model`` instance.
    state: SamplingState
        The model's evaluation state. For observed distributions, the entry
        ``state.observed_values[scoped_name]`` is set to ``None``.

    Returns
    -------
    model: Union[types.GeneratorType, pymc4.coroutine_model.Model]
        The original ``dist`` if it was not an observed ``Distribution`` or
        the ``Distribution`` with the changed ``batch_shape`` and observations
        set to ``None``.

    Raises
    ------
    EvaluationError
        When ``dist`` and its passed observed value don't have a consistent
        shape, or when ``dist`` has no scoped name (anonymous Distribution).
    """
    dist = super().modify_distribution(dist, model_info, state)
    # We only modify the shape of Distribution instances that have observed
    # values
    dist = transform_dist_if_necessary(dist, state, allow_transformed_and_untransformed=False)
    if not isinstance(dist, Distribution):
        return dist
    scoped_name = scopes.variable_name(dist.name)
    if scoped_name is None:
        raise EvaluationError("Attempting to create an anonymous Distribution")

    observed_value = observed_value_in_evaluation(scoped_name, dist, state)
    if observed_value is None:
        return dist

    # We set the state's observed value to None to explicitly override
    # any previously given observed and at the same time, have the
    # scope_name added to the posterior_predictives set in
    # self.proceed_distribution
    state.observed_values[scoped_name] = None

    # We first check the TFP distribution's shape and compare it with the
    # observed_value's shape
    assert_values_compatible_with_distribution(scoped_name, observed_value, dist)

    # Now we get the broadcasted shape between the observed value and the distribution
    observed_shape = get_observed_tensor_shape(observed_value)
    dist_shape = dist.batch_shape + dist.event_shape
    new_dist_shape = tf.broadcast_static_shape(observed_shape, dist_shape)
    # Leading axes that the observed value adds beyond the distribution's rank.
    extra_batch_stack = new_dist_shape[: len(new_dist_shape) - len(dist_shape)]

    # Now we construct and return the same distribution but setting
    # observed to None and setting a batch_size that matches the result of
    # broadcasting the observed and distribution shape
    batch_stack = extra_batch_stack + (
        dist.batch_stack if dist.batch_stack is not None else ()
    )
    if len(batch_stack) > 0:
        reinterpreted_batch_ndims = dist.reinterpreted_batch_ndims
        # With an event_stack, the newly added batch axes are also
        # reinterpreted (counted into reinterpreted_batch_ndims).
        if dist.event_stack:
            reinterpreted_batch_ndims += len(extra_batch_stack)
        new_dist = type(dist)(
            name=dist.name,
            transform=dist.transform,
            observed=None,
            batch_stack=batch_stack,
            conditionally_independent=dist.conditionally_independent,
            event_stack=dist.event_stack,
            reinterpreted_batch_ndims=reinterpreted_batch_ndims,
            **dist.conditions,
        )
    else:
        new_dist = type(dist)(
            name=dist.name,
            transform=dist.transform,
            observed=None,
            batch_stack=None,
            conditionally_independent=dist.conditionally_independent,
            event_stack=dist.event_stack,
            reinterpreted_batch_ndims=dist.reinterpreted_batch_ndims,
            **dist.conditions,
        )
    return new_dist
def masked_reconstruct(reconstruct, x, mask, validate_shape=True, name=None):
    """
    Replace masked elements of `x` with reconstructed outputs.

    This method can be used to do missing data imputation on `x`, with
    the reconstruction outputs for `x`.

    Args:
        reconstruct ((tf.Tensor) -> tf.Tensor): Function for reconstructing
            `x`.  It is called once with the converted `x`.
        x: The tensor to be reconstructed by `reconstruct`.
        mask: `int32` mask, must be broadcastable into the shape of `x`.
            Indicating whether or not to mask each element of `x`:
            nonzero entries take the reconstructed value, zero entries
            keep the original value of `x`.
        validate_shape (bool): Whether or not to validate the shape of
            `mask`? (default :obj:`True`)
        name (str): Name of this operation in TensorFlow graph.
            (default "masked_reconstruct")

    Returns:
        tf.Tensor: `x` with masked elements replaced by reconstructed outputs.

    Raises:
        ValueError: If the static shape of `mask` cannot be broadcast into
            the static shape of `x`.
    """
    with tf.name_scope(name, default_name='masked_reconstruct'):
        x = tf.convert_to_tensor(x)  # type: tf.Tensor
        mask = tf.convert_to_tensor(mask, dtype=tf.int32)  # type: tf.Tensor

        # broadcast mask against x
        old_mask = mask
        try:
            _ = tf.broadcast_static_shape(x.get_shape(), mask.get_shape())
        except ValueError as err:
            # Keep the TensorFlow error as the cause for easier debugging.
            raise ValueError('Shape of `mask` cannot broadcast '
                             'into the shape of `x` ({!r} vs {!r})'.format(
                                 old_mask.get_shape(), x.get_shape())) from err
        mask = mask * tf.ones_like(x, dtype=mask.dtype)

        # validate the shape of mask
        if validate_shape:
            x_shape = x.get_shape()
            mask_shape = mask.get_shape()
            if mask_shape.is_fully_defined() and x_shape.is_fully_defined():
                if mask_shape != x_shape:
                    # the only possible situation is that mask has more
                    # dimension than x, and we consider this situation invalid
                    raise ValueError(
                        'Shape of `mask` cannot broadcast '
                        'into the shape of `x` ({!r} vs {!r})'.format(
                            old_mask.get_shape(), x_shape))
            else:
                assert_op = tf.assert_equal(
                    # since already broadcasted by x * ones_like(x),
                    # we only need to compare the rank
                    tf.rank(x), tf.rank(mask),
                    message='Shape of `mask` cannot broadcast into the '
                            'shape of `x`')
                with tf.control_dependencies([assert_op]):
                    mask = tf.identity(mask)

        # get reconstructed x
        r_x = reconstruct(x)

        # get masked outputs: nonzero mask entries select `r_x`
        return tf.where(tf.cast(mask, dtype=tf.bool), r_x, x)
def _batch_shape(self):
    """Return the static batch shape obtained by broadcasting `loc` and `scale`."""
    loc_shape = self.loc.shape
    scale_shape = self.scale.shape
    return tf.broadcast_static_shape(loc_shape, scale_shape)
def _batch_shape(self):
    """Static batch shape: `amplitude` and `length_scale` shapes broadcast together."""
    amplitude_shape, length_scale_shape = (
        self.amplitude.shape, self.length_scale.shape)
    return tf.broadcast_static_shape(amplitude_shape, length_scale_shape)
def _batch_shape(self):
    """Static batch shape from `slope_variance`, `bias_variance` and `exponent`.

    Each parameter's shape is obtained through `_maybe_shape_static`, and
    the three shapes are broadcast pairwise, left to right.
    """
    variances_shape = tf.broadcast_static_shape(
        _maybe_shape_static(self.slope_variance),
        _maybe_shape_static(self.bias_variance))
    exponent_shape = _maybe_shape_static(self.exponent)
    return tf.broadcast_static_shape(variances_shape, exponent_shape)
def _batch_shape(self):
    """Return the broadcast of the static `loc` and `concentration` shapes."""
    parameter_shapes = (self.loc.shape, self.concentration.shape)
    return tf.broadcast_static_shape(*parameter_shapes)
def __init__(self, mean, log_scale, bin_size, min_val=None, max_val=None,
             dtype=tf.float32, biased_edges=True, epsilon=1e-7):
    """
    Construct a new :class:`DiscretizedLogistic`.

    Args:
        mean: A Tensor, the `mean`.
        log_scale: A Tensor, the `log(scale)`.
        bin_size: A scalar, the `bin_size`.
        min_val: A scalar, the minimum possible value of `x`.
        max_val: A scalar, the maximum possible value of `x`.
        dtype: The data type of `x`.
        biased_edges: Whether or not to use bias density for edge values?
            See above.
        epsilon: Small float to avoid dividing by zero or taking
            logarithm of zero.

    Raises:
        ValueError: If `bin_size` is not an integer (per
            `is_integer_number`) while `dtype` is not a floating type; if
            `min_val` or `max_val` is not a multiple of `bin_size`; or if
            the static shapes of `mean` and `log_scale` cannot be
            broadcast together.
    """
    # check the arguments
    mean = tf.convert_to_tensor(mean)
    param_dtype = mean.dtype
    log_scale = tf.convert_to_tensor(log_scale)
    dtype = tf.as_dtype(dtype)

    if not is_integer_number(bin_size) and not dtype.is_floating:
        raise ValueError(
            '`bin_size` is a float number, but `dtype` is not a float '
            'number type: {}'.format(dtype)
        )

    if min_val is not None:
        if not is_integer_number(min_val / bin_size):
            raise ValueError(
                '`min_val` must be multiples of `bin_size`: '
                'min_val {} vs bin_size {}'.format(min_val, bin_size)
            )

    if max_val is not None:
        if not is_integer_number(max_val / bin_size):
            raise ValueError(
                '`max_val` must be multiples of `bin_size`: '
                'max_val {} vs bin_size {}'.format(max_val, bin_size)
            )

    # infer the batch shape
    try:
        batch_static_shape = tf.broadcast_static_shape(
            mean.get_shape(), log_scale.get_shape())
    except ValueError as err:
        # Report the offending static shapes (not the whole tensors) and
        # keep TensorFlow's error as the cause.
        raise ValueError('The shape of `mean` and `log_scale` cannot '
                         'be broadcasted: mean {} vs log_scale {}'.
                         format(mean.get_shape(),
                                log_scale.get_shape())) from err

    with tf.name_scope('DiscretizedLogistic.init'):
        batch_shape = tf.broadcast_dynamic_shape(tf.shape(mean),
                                                 tf.shape(log_scale))

    # memorize the arguments and call parent constructor
    bin_size = convert_to_tensor_and_cast(bin_size, param_dtype)
    if min_val is not None:
        min_val = convert_to_tensor_and_cast(min_val, param_dtype)
    if max_val is not None:
        max_val = convert_to_tensor_and_cast(max_val, param_dtype)

    self._mean = mean
    self._log_scale = log_scale
    self._param_dtype = param_dtype
    self._bin_size = bin_size
    self._min_val = min_val
    self._max_val = max_val
    self._biased_edges = bool(biased_edges)
    self._epsilon = epsilon

    super(DiscretizedLogistic, self).__init__(
        dtype=dtype,
        is_continuous=False,
        is_reparameterized=False,
        batch_shape=batch_shape,
        batch_static_shape=batch_static_shape,
        value_ndims=0
    )
def _batch_shape(self):
    """Static batch shape from `temperature` and `logits` minus its last axis.

    The trailing `logits` axis is excluded before broadcasting (presumably
    the category axis -- TODO confirm).
    """
    temperature_shape = self.temperature.shape
    logits_batch_shape = self.logits.shape[:-1]
    return tf.broadcast_static_shape(temperature_shape, logits_batch_shape)
def _batch_shape(self):
    """Static batch shape from the optional `amplitude` and `length_scale`.

    A parameter that is `None` contributes a scalar (rank-0) shape.
    """
    scalar_shape = tf.TensorShape([])
    if self.amplitude is None:
        amplitude_shape = scalar_shape
    else:
        amplitude_shape = self.amplitude.shape
    if self.length_scale is None:
        length_scale_shape = scalar_shape
    else:
        length_scale_shape = self.length_scale.shape
    return tf.broadcast_static_shape(amplitude_shape, length_scale_shape)
def _batch_shape(self):
    """Static batch shape of the mixture.

    Broadcasts the component distribution's batch shape with the shape of
    the mixing logits, then drops the trailing axis (presumably the
    mixture-component axis -- TODO confirm).
    """
    components_shape = self.distribution.batch_shape
    mixing_logits_shape = self.mixture_distribution.logits.shape
    joint_shape = tf.broadcast_static_shape(components_shape,
                                            mixing_logits_shape)
    return joint_shape[:-1]
def __init__(self,
             initial_distribution,
             transition_distribution,
             observation_distribution,
             num_steps,
             validate_args=False,
             allow_nan_stats=True,
             name="HiddenMarkovModel"):
    """Initialize hidden Markov model.

    Args:
      initial_distribution: A `Categorical`-like instance.
        Determines probability of first hidden state in Markov chain.
        The number of categories must match the number of categories of
        `transition_distribution` as well as both the rightmost batch
        dimension of `transition_distribution` and the rightmost batch
        dimension of `observation_distribution`.
      transition_distribution: A `Categorical`-like instance.
        The rightmost batch dimension indexes the probability distribution
        of each hidden state conditioned on the previous hidden state.
      observation_distribution: A `tfp.distributions.Distribution`-like
        instance.  The rightmost batch dimension indexes the distribution
        of each observation conditioned on the corresponding hidden state.
      num_steps: The number of steps taken in Markov chain. A python `int`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading
        runtime performance. When `False` invalid inputs may silently render
        incorrect outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
        Default value: `True`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "HiddenMarkovModel".

    Raises:
      ValueError: if `num_steps` is not at least 1.
      ValueError: if `initial_distribution`'s `event_shape` has a statically
        known, non-zero rank.
      ValueError: if `transition_distribution`'s `event_shape` has a
        statically known, non-zero rank.
      ValueError: if `transition_distribution` or `observation_distribution`
        has a statically scalar batch shape.
      ValueError: if `transition_distribution` and `observation_distribution`
        are fully defined but don't have matching rightmost dimension.

    When a rank or dimension is not statically known, the corresponding
    check is performed at runtime via assertions only if `validate_args`.
    """
    # Must stay the first statement so that only the constructor arguments
    # are captured.
    parameters = dict(locals())

    # pylint: disable=protected-access
    with tf.name_scope(name=name, values=(
        initial_distribution._graph_parents +
        transition_distribution._graph_parents +
        observation_distribution._graph_parents)) as name:
        self._runtime_assertions = []
        # pylint: enable=protected-access

        if num_steps < 1:
            raise ValueError("num_steps ({}) must be at least 1.".format(num_steps))

        self._initial_distribution = initial_distribution
        self._observation_distribution = observation_distribution
        self._transition_distribution = transition_distribution

        # Reject statically only when the event rank is known; an unknown
        # rank (`ndims is None`) falls through to the runtime assertion.
        if (initial_distribution.event_shape is not None
                and initial_distribution.event_shape.ndims is not None
                and initial_distribution.event_shape.ndims != 0):
            raise ValueError(
                "`initial_distribution` must have scalar `event_dim`s")
        elif validate_args:
            self._runtime_assertions += [
                tf.assert_equal(
                    tf.shape(initial_distribution.event_shape_tensor())[0],
                    0,
                    message="`initial_distribution` must have scalar "
                    "`event_dim`s")]

        if (transition_distribution.event_shape is not None
                and transition_distribution.event_shape.ndims is not None
                and transition_distribution.event_shape.ndims != 0):
            raise ValueError(
                "`transition_distribution` must have scalar `event_dim`s")
        elif validate_args:
            self._runtime_assertions += [
                tf.assert_equal(
                    tf.shape(transition_distribution.event_shape_tensor())[0],
                    0,
                    message="`transition_distribution` must have scalar "
                    "`event_dim`s")]

        if (transition_distribution.batch_shape is not None
                and transition_distribution.batch_shape.ndims == 0):
            raise ValueError(
                "`transition_distribution` can't have scalar batches")
        elif validate_args:
            self._runtime_assertions += [
                tf.assert_greater(
                    tf.size(transition_distribution.batch_shape_tensor()),
                    0,
                    message="`transition_distribution` can't have scalar "
                    "batches")]

        if (observation_distribution.batch_shape is not None
                and observation_distribution.batch_shape.ndims == 0):
            raise ValueError(
                "`observation_distribution` can't have scalar batches")
        elif validate_args:
            self._runtime_assertions += [
                tf.assert_greater(
                    tf.size(observation_distribution.batch_shape_tensor()),
                    0,
                    message="`observation_distribution` can't have scalar "
                    "batches")]

        # Infer number of hidden states and check consistency
        # between transitions and observations
        with tf.control_dependencies(self._runtime_assertions):
            self._num_states = ((transition_distribution.batch_shape and
                                 transition_distribution.batch_shape[-1]) or
                                transition_distribution.batch_shape_tensor()[-1])

            observation_states = ((observation_distribution.batch_shape and
                                   observation_distribution.batch_shape[-1]) or
                                  observation_distribution.batch_shape_tensor()[-1])

        if (tf.contrib.framework.is_tensor(self._num_states) or
                tf.contrib.framework.is_tensor(observation_states)):
            if validate_args:
                self._runtime_assertions += [
                    tf.assert_equal(
                        self._num_states,
                        observation_states,
                        message="`transition_distribution` and "
                        "`observation_distribution` must agree on "
                        "last dimension of batch size")]
        elif self._num_states != observation_states:
            raise ValueError("`transition_distribution` and "
                             "`observation_distribution` must agree on "
                             "last dimension of batch size")

        self._log_init = _extract_log_probs(self._num_states,
                                            initial_distribution)
        self._log_trans = _extract_log_probs(self._num_states,
                                             transition_distribution)

        self._num_steps = num_steps
        self._num_states = tf.shape(self._log_init)[-1]

        self._underlying_event_rank = (
            self._observation_distribution.event_shape.ndims)

        self.static_event_shape = tf.TensorShape(
            [num_steps]).concatenate(self._observation_distribution.event_shape)

        with tf.control_dependencies(self._runtime_assertions):
            self.static_batch_shape = tf.broadcast_static_shape(
                self._initial_distribution.batch_shape,
                tf.broadcast_static_shape(
                    self._transition_distribution.batch_shape[:-1],
                    self._observation_distribution.batch_shape[:-1]))

        # pylint: disable=protected-access
        super(HiddenMarkovModel, self).__init__(
            dtype=self._observation_distribution.dtype,
            reparameterization_type=tf.distributions.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=(
                self._initial_distribution._graph_parents +
                self._transition_distribution._graph_parents +
                self._observation_distribution._graph_parents),
            name=name)
        # pylint: enable=protected-access

        self._parameters = parameters
def _batch_shape(self):
    """Return the static batch shape implied by `amplitude` and `length_scale`."""
    amp = self.amplitude.shape
    scale_len = self.length_scale.shape
    batch = tf.broadcast_static_shape(amp, scale_len)
    return batch
def _batch_shape(self):
    """Static batch shape: `loc.get_shape()` broadcast with `scale.get_shape()`."""
    loc_static = self.loc.get_shape()
    scale_static = self.scale.get_shape()
    return tf.broadcast_static_shape(loc_static, scale_static)
def _batch_shape(self):
    """Return the broadcast of the static `concentration` and `rate` shapes."""
    concentration_shape = self.concentration.shape
    rate_shape = self.rate.shape
    return tf.broadcast_static_shape(concentration_shape, rate_shape)
def _static_broadcast_shape_from_tensors(*tensors): shape = tensors[0].get_shape() for t in tensors[1:]: shape = tf.broadcast_static_shape(shape, t.get_shape()) return shape
def _batch_shape(self):
    """Static batch shape from `_loc` (all but its last axis) and `_radius_dist`.

    The trailing axis of `_loc` is excluded before broadcasting (presumably
    the event dimension -- TODO confirm).
    """
    loc_batch_shape = self._loc.get_shape()[:-1]
    radius_batch_shape = self._radius_dist.batch_shape
    return tf.broadcast_static_shape(loc_batch_shape, radius_batch_shape)
def _kl_brute_force(a, b, name=None):
    """Batched KL divergence `KL(a || b)` for multivariate Normals.

    With `X`, `Y` both multivariate Normals in `R^k` with means `mu_a`,
    `mu_b` and covariance `C_a`, `C_b` respectively,

    ```
    KL(a || b) = 0.5 * ( L - k + T + Q ),
    L := Log[Det(C_b)] - Log[Det(C_a)]
    T := trace(C_b^{-1} C_a),
    Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
    ```

    This `Op` computes the trace by solving `C_b^{-1} C_a`. Although
    efficient methods for solving systems with `C_b` may be available, a
    dense version of (the square root of) `C_a` is used, so performance is
    `O(B s k**2)` where `B` is the batch size, and `s` is the cost of
    solving `C_b x = y` for vectors `x` and `y`.

    Args:
        a: Instance of `MultivariateNormalLinearOperator`.
        b: Instance of `MultivariateNormalLinearOperator`.
        name: (optional) name to use for created ops. Default "kl_mvn".

    Returns:
        Batchwise `KL(a || b)`.
    """

    def squared_frobenius_norm(x):
        """Sum of squared entries over the two innermost axes of `x`."""
        # http://mathworld.wolfram.com/FrobeniusNorm.html
        # `tf.norm(x, ord="fro", axis=[-2, -1])` is deliberately avoided:
        # through it, the gradient of KL[p,q] is not defined when p==q.
        return tf.reduce_sum(tf.square(x), axis=[-2, -1])

    # TODO(b/35041439): See also b/35040945. Remove this function once LinOp
    # supports something like:
    #   A.inverse().solve(B).norm(order='fro', axis=[-1, -2])
    def is_diagonal(x):
        """Helper to identify if `LinearOperator` has only a diagonal component."""
        diagonal_operator_types = (tf.linalg.LinearOperatorIdentity,
                                   tf.linalg.LinearOperatorScaledIdentity,
                                   tf.linalg.LinearOperatorDiag)
        return isinstance(x, diagonal_operator_types)

    scope_values = ([a.loc, b.loc] + a.scale.graph_parents +
                    b.scale.graph_parents)
    with tf.name_scope(name, "kl_mvn", values=scope_values):
        # If Ca = AA', Cb = BB', then
        #   tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
        #                  = tr[inv(B) A A' inv(B)']
        #                  = tr[(inv(B) A) (inv(B) A)']
        #                  = sum_{ij} (inv(B) A)_{ij}**2
        #                  = ||inv(B) A||_F**2
        # using the cyclic permutation property of the trace.  See
        # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
        # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
        if is_diagonal(a.scale) and is_diagonal(b.scale):
            # Using `stddev` because it handles expansion of Identity cases.
            b_inv_a = (a.stddev() / b.stddev())[..., tf.newaxis]
        else:
            b_inv_a = b.scale.solve(a.scale.to_dense())

        # `scale` is a square root of the covariance, so this difference of
        # log-determinants already equals 0.5 * L.
        half_log_det_ratio = (b.scale.log_abs_determinant() -
                              a.scale.log_abs_determinant())
        negative_dim = -tf.cast(a.scale.domain_dimension_tensor(), a.dtype)
        trace_term = squared_frobenius_norm(b_inv_a)
        mean_gap = (b.mean() - a.mean())[..., tf.newaxis]
        mahalanobis_term = squared_frobenius_norm(b.scale.solve(mean_gap))

        kl_div = half_log_det_ratio + 0.5 * (
            negative_dim + trace_term + mahalanobis_term)
        kl_div.set_shape(tf.broadcast_static_shape(a.batch_shape, b.batch_shape))
        return kl_div
def _batch_shape(self):
    """Static batch shape formed by broadcasting the `loc` and `scale` shapes."""
    location_and_scale = (self.loc.shape, self.scale.shape)
    return tf.broadcast_static_shape(*location_and_scale)
def _kl_brute_force(a, b, name=None):
    """Batched KL divergence `KL(a || b)` between multivariate Normals.

    For `a`, `b` multivariate Normals in `R^k` with means `mu_a`, `mu_b`
    and covariances `C_a`, `C_b`:

    ```
    KL(a || b) = 0.5 * ( L - k + T + Q ),
    L := Log[Det(C_b)] - Log[Det(C_a)]
    T := trace(C_b^{-1} C_a),
    Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
    ```

    The trace is obtained by solving `C_b^{-1} C_a` against a dense copy of
    (the square root of) `C_a`, so the cost is `O(B s k**2)`, with `B` the
    batch size and `s` the cost of solving `C_b x = y` for vectors `x`, `y`.

    Args:
        a: Instance of `MultivariateNormalLinearOperator`.
        b: Instance of `MultivariateNormalLinearOperator`.
        name: (optional) name to use for created ops. Default "kl_mvn".

    Returns:
        Batchwise `KL(a || b)`.
    """

    def squared_frobenius_norm(x):
        """Squared Frobenius norm over the last two axes of `x`."""
        # Computed by hand rather than with `tf.norm(..., ord="fro")`, whose
        # gradient of KL[p,q] is undefined when p==q.
        return tf.reduce_sum(tf.square(x), axis=[-2, -1])

    # TODO(b/35041439): See also b/35040945. Remove this function once LinOp
    # supports something like:
    #   A.inverse().solve(B).norm(order='fro', axis=[-1, -2])
    def is_diagonal(x):
        """True when the `LinearOperator` `x` is purely diagonal."""
        return any(isinstance(x, operator_type) for operator_type in (
            tf.linalg.LinearOperatorIdentity,
            tf.linalg.LinearOperatorScaledIdentity,
            tf.linalg.LinearOperatorDiag))

    with tf.name_scope(
            name, "kl_mvn",
            values=[a.loc, b.loc] + a.scale.graph_parents +
            b.scale.graph_parents):
        # With Ca = AA' and Cb = BB', the trace term reduces to a squared
        # Frobenius norm:
        #   tr[inv(Cb) Ca] = tr[(inv(B) A) (inv(B) A)'] = ||inv(B) A||_F**2
        # (cyclic permutation property of the trace).
        both_diagonal = is_diagonal(a.scale) and is_diagonal(b.scale)
        b_inv_a = ((a.stddev() / b.stddev())[..., tf.newaxis]
                   if both_diagonal  # `stddev` handles Identity expansion.
                   else b.scale.solve(a.scale.to_dense()))

        kl_div = (
            b.scale.log_abs_determinant()
            - a.scale.log_abs_determinant()
            + 0.5 * (
                -tf.cast(a.scale.domain_dimension_tensor(), a.dtype)
                + squared_frobenius_norm(b_inv_a)
                + squared_frobenius_norm(
                    b.scale.solve((b.mean() - a.mean())[..., tf.newaxis]))))
        static_batch = tf.broadcast_static_shape(a.batch_shape, b.batch_shape)
        kl_div.set_shape(static_batch)
        return kl_div