def determine_batch_event_shapes(grid, endpoint_affine):
  """Helper to infer batch_shape and event_shape."""
  with tf.name_scope("determine_batch_event_shapes"):
    # grid             # shape: [B, k, q]
    # endpoint_affine  # len=k, shape: [B, d, d]
    batch_shape = grid.shape[:-2]
    batch_shape_tensor = tf.shape(grid)[:-2]
    event_shape = None
    event_shape_tensor = None

    def _set_event_shape(shape, shape_tensor):
      if event_shape is None:
        return shape, shape_tensor
      return (tf.broadcast_static_shape(event_shape, shape),
              tf.broadcast_dynamic_shape(event_shape_tensor, shape_tensor))

    for aff in endpoint_affine:
      if aff.shift is not None:
        batch_shape = tf.broadcast_static_shape(
            batch_shape, aff.shift.shape[:-1])
        batch_shape_tensor = tf.broadcast_dynamic_shape(
            batch_shape_tensor, tf.shape(aff.shift)[:-1])
        event_shape, event_shape_tensor = _set_event_shape(
            aff.shift.shape[-1:], tf.shape(aff.shift)[-1:])

      if aff.scale is not None:
        batch_shape = tf.broadcast_static_shape(
            batch_shape, aff.scale.batch_shape)
        batch_shape_tensor = tf.broadcast_dynamic_shape(
            batch_shape_tensor, aff.scale.batch_shape_tensor())
        event_shape, event_shape_tensor = _set_event_shape(
            tf.TensorShape([aff.scale.range_dimension]),
            aff.scale.range_dimension_tensor()[tf.newaxis])

    return batch_shape, batch_shape_tensor, event_shape, event_shape_tensor
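# A minimal sketch (not library code) of the static/dynamic pattern used
# above: `tf.broadcast_static_shape` combines `TensorShape`s at graph
# construction time, while `tf.broadcast_dynamic_shape` returns a rank-1
# `Tensor` for shapes only known at runtime. Both raise on incompatible
# shapes.
import tensorflow as tf

static = tf.broadcast_static_shape(
    tf.TensorShape([3, 1]), tf.TensorShape([1, 4]))  # TensorShape([3, 4])
dynamic = tf.broadcast_dynamic_shape(
    tf.constant([3, 1]), tf.constant([1, 4]))        # Tensor: [3, 4]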
def _finish_prob_for_one_fiber(self, y, x, ildj, event_ndims,
                               **distribution_kwargs):
  """Finish computation of prob on one element of the inverse image."""
  x = self._maybe_rotate_dims(x, rotate_right=True)
  prob = self.distribution.prob(x, **distribution_kwargs)
  if self._is_maybe_event_override:
    prob = tf.reduce_prod(prob, axis=self._reduce_event_indices)
  prob = prob * tf.exp(tf.cast(ildj, prob.dtype))
  if self._is_maybe_event_override and isinstance(event_ndims, int):
    tensorshape_util.set_shape(
        prob,
        tf.broadcast_static_shape(
            tensorshape_util.with_rank_at_least(y.shape, 1)[:-event_ndims],
            self.batch_shape))
  return prob
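# Hedged usage sketch: `_finish_prob_for_one_fiber` is internal to
# `TransformedDistribution`, but its effect is visible through `prob`,
# which multiplies the base density by exp(ildj), the inverse log-det
# Jacobian term. E.g., pushing a standard Normal through `Exp` gives a
# LogNormal density:
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

log_normal = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.), bijector=tfb.Exp())
y = tf.constant(2.)
# prob_Y(y) = prob_X(log y) * |d(log y)/dy| = prob_X(log y) / y
expected = tfd.Normal(0., 1.).prob(tf.math.log(y)) / y
actual = log_normal.prob(y)  # agrees with `expected` up to float tolerance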
def broadcast_shape(x_shape, y_shape):
  """Computes the shape of a broadcast.

  When both arguments are statically known, the broadcast shape will be
  computed statically and returned as a `TensorShape`. Otherwise, a rank-1
  `Tensor` will be returned.

  Args:
    x_shape: A `TensorShape` or rank-1 integer `Tensor`. The shape to
      broadcast against `y_shape`.
    y_shape: A `TensorShape` or rank-1 integer `Tensor`. The shape to
      broadcast against `x_shape`.

  Returns:
    shape: A `TensorShape` or rank-1 integer `Tensor` representing the
      broadcast shape.
  """
  x_shape_static = tf.get_static_value(x_shape)
  y_shape_static = tf.get_static_value(y_shape)
  if (x_shape_static is None) or (y_shape_static is None):
    return tf.broadcast_dynamic_shape(x_shape, y_shape)
  return tf.broadcast_static_shape(
      tf.TensorShape(x_shape_static), tf.TensorShape(y_shape_static))
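# Hedged usage sketch for `broadcast_shape`: with statically known shapes
# the result stays a `TensorShape`; when `tf.get_static_value` cannot
# resolve an input (e.g. a symbolic shape inside `tf.function` with an
# unknown dimension), the helper falls back to `tf.broadcast_dynamic_shape`.
import tensorflow as tf

s = broadcast_shape(tf.TensorShape([2, 1]), tf.TensorShape([1, 3]))
# => TensorShape([2, 3])

@tf.function(input_signature=[tf.TensorSpec(shape=[None, 1])])
def dynamic_example(x):
  # `tf.shape(x)` has no static value here, so a rank-1 Tensor is returned.
  return broadcast_shape(tf.shape(x), tf.constant([1, 3]))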
def _batch_shape(self):
  param = self._logits if self._logits is not None else self._probs
  return tf.broadcast_static_shape(self.temperature.shape, param.shape[:-1])
def _batch_shape(self):
  return tf.broadcast_static_shape(self.loc.shape, self.scale.shape)
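# Hedged illustration of the broadcasting contract behind `_batch_shape`
# (assuming this is `Normal`'s implementation): parameters may have
# different but broadcast-compatible shapes, and the batch shape is their
# broadcast.
import tensorflow as tf
import tensorflow_probability as tfp

dist = tfp.distributions.Normal(loc=tf.zeros([3, 1]), scale=tf.ones([1, 4]))
print(dist.batch_shape)  # TensorShape([3, 4])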
def _batch_shape(self):
  return tf.broadcast_static_shape(self.concentration1.shape,
                                   self.concentration0.shape)
def _batch_shape(self):
  return tf.broadcast_static_shape(
      tensorshape_util.with_rank_at_least(self.mean_direction.shape, 1)[:-1],
      self.concentration.shape)
def _batch_shape(self):
  x = self._probs if self._logits is None else self._logits
  return tf.broadcast_static_shape(self._total_count.shape, x.shape)
def _kl_brute_force(a, b, name=None):
  """Batched KL divergence `KL(a || b)` for multivariate Normals.

  With `X`, `Y` both multivariate Normals in `R^k` with means `mu_a`, `mu_b`
  and covariance `C_a`, `C_b` respectively,

  ```
  KL(a || b) = 0.5 * ( L - k + T + Q ),
  L := Log[Det(C_b)] - Log[Det(C_a)]
  T := trace(C_b^{-1} C_a),
  Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
  ```

  This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient
  methods for solving systems with `C_b` may be available, a dense version of
  (the square root of) `C_a` is used, so performance is `O(B s k**2)` where
  `B` is the batch size, and `s` is the cost of solving `C_b x = y` for
  vectors `x` and `y`.

  Args:
    a: Instance of `MultivariateNormalLinearOperator`.
    b: Instance of `MultivariateNormalLinearOperator`.
    name: (optional) name to use for created ops. Default "kl_mvn".

  Returns:
    Batchwise `KL(a || b)`.
  """

  def squared_frobenius_norm(x):
    """Helper to make KL calculation slightly more readable."""
    # http://mathworld.wolfram.com/FrobeniusNorm.html
    # The gradient of KL[p,q] is not defined when p==q. The culprit is
    # tf.norm, i.e., we cannot use the commented out code.
    # return tf.square(tf.norm(x, ord="fro", axis=[-2, -1]))
    return tf.reduce_sum(tf.square(x), axis=[-2, -1])

  # TODO(b/35041439): See also b/35040945. Remove this function once LinOp
  # supports something like:
  #   A.inverse().solve(B).norm(order='fro', axis=[-1, -2])
  def is_diagonal(x):
    """Helper to identify if `LinearOperator` has only a diagonal component."""
    return (isinstance(x, tf.linalg.LinearOperatorIdentity) or
            isinstance(x, tf.linalg.LinearOperatorScaledIdentity) or
            isinstance(x, tf.linalg.LinearOperatorDiag))

  with tf.name_scope(name or 'kl_mvn'):
    # Calculation is based on:
    # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
    # and,
    # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
    # i.e.,
    #   If Ca = AA', Cb = BB', then
    #   tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
    #                  = tr[inv(B) A A' inv(B)']
    #                  = tr[(inv(B) A) (inv(B) A)']
    #                  = sum_{ij} (inv(B) A)_{ij}**2
    #                  = ||inv(B) A||_F**2
    # where ||.||_F is the Frobenius norm and the second equality follows
    # from the cyclic permutation property.
    if is_diagonal(a.scale) and is_diagonal(b.scale):
      # Using `stddev` because it handles expansion of Identity cases.
      b_inv_a = (a.stddev() / b.stddev())[..., tf.newaxis]
    else:
      b_inv_a = b.scale.solve(a.scale.to_dense())
    kl_div = (
        b.scale.log_abs_determinant() - a.scale.log_abs_determinant() +
        0.5 * (-tf.cast(a.scale.domain_dimension_tensor(), a.dtype) +
               squared_frobenius_norm(b_inv_a) +
               squared_frobenius_norm(
                   b.scale.solve((b.mean() - a.mean())[..., tf.newaxis]))))
    tensorshape_util.set_shape(
        kl_div, tf.broadcast_static_shape(a.batch_shape, b.batch_shape))
    return kl_div
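# Hedged usage sketch: `_kl_brute_force` is normally reached through
# `tfd.kl_divergence`, which dispatches on the registered distribution
# pair. With diagonal scales it takes the cheaper stddev-ratio branch
# above instead of a dense solve.
import tensorflow_probability as tfp

tfd = tfp.distributions
a = tfd.MultivariateNormalDiag(loc=[0., 0.], scale_diag=[1., 1.])
b = tfd.MultivariateNormalDiag(loc=[1., 1.], scale_diag=[2., 2.])
kl = tfd.kl_divergence(a, b)  # batchwise KL(a || b)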
def _batch_shape(self):
  return tf.broadcast_static_shape(
      self.loc.shape,
      tf.broadcast_static_shape(self.atol.shape, self.rtol.shape))[:-1]
def _log_prob(self, x):
  if self.input_output_cholesky:
    x_sqrt = x
  else:
    # Complexity: O(nbk**3)
    x_sqrt = tf.linalg.cholesky(x)
  batch_shape = self.batch_shape_tensor()
  event_shape = self.event_shape_tensor()
  x_ndims = tf.rank(x_sqrt)
  num_singleton_axes_to_prepend = (
      tf.maximum(tf.size(batch_shape) + 2, x_ndims) - x_ndims)
  x_with_prepended_singletons_shape = tf.concat([
      tf.ones([num_singleton_axes_to_prepend], dtype=tf.int32),
      tf.shape(x_sqrt)
  ], 0)
  x_sqrt = tf.reshape(x_sqrt, x_with_prepended_singletons_shape)
  ndims = tf.rank(x_sqrt)
  # sample_ndims = ndims - batch_ndims - event_ndims
  sample_ndims = ndims - tf.size(batch_shape) - 2
  sample_shape = tf.shape(x_sqrt)[:sample_ndims]

  # We need to be able to pre-multiply each matrix by its corresponding
  # batch scale matrix. Since a Distribution Tensor supports multiple
  # samples per batch, this means we need to reshape the input matrix `x`
  # so that the first b dimensions are batch dimensions and the last two
  # are of shape [dimension, dimension * number_of_samples]. Doing these
  # gymnastics allows us to do a batch_solve.
  #
  # After we're done with sqrt_solve (the batch operation) we need to undo
  # this reshaping so what we're left with is a Tensor partitionable by
  # sample, batch, event dimensions.

  # Complexity: O(nbk**2) since transpose must access every element.
  scale_sqrt_inv_x_sqrt = x_sqrt
  perm = tf.concat(
      [tf.range(sample_ndims, ndims), tf.range(0, sample_ndims)], 0)
  scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt, perm=perm)
  last_dim_size = (
      tf.cast(self.dimension, dtype=tf.int32) *
      tf.reduce_prod(x_with_prepended_singletons_shape[:sample_ndims]))
  shape = tf.concat([
      x_with_prepended_singletons_shape[sample_ndims:-2],
      [tf.cast(self.dimension, dtype=tf.int32), last_dim_size]
  ], axis=0)
  scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)

  # Complexity: O(nbM*k) where M is the complexity of the operator solving
  # a vector system. For LinearOperatorLowerTriangular, each solve is
  # O(k**2) so this step has complexity O(nbk^3).
  scale_sqrt_inv_x_sqrt = self.scale_operator.solve(scale_sqrt_inv_x_sqrt)

  # Undo make batch-op ready.
  # Complexity: O(nbk**2)
  shape = tf.concat(
      [tf.shape(scale_sqrt_inv_x_sqrt)[:-2], event_shape, sample_shape],
      axis=0)
  scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)
  perm = tf.concat([
      tf.range(ndims - sample_ndims, ndims),
      tf.range(0, ndims - sample_ndims)
  ], 0)
  scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt, perm=perm)

  # Write V = SS', X = LL'. Then:
  # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
  #              = tr[inv(S) L L' inv(S)']
  #              = tr[(inv(S) L) (inv(S) L)']
  #              = sum_{ik} (inv(S) L)_{ik}**2
  # The second equality follows from the cyclic permutation property.
  # Complexity: O(nbk**2)
  trace_scale_inv_x = tf.reduce_sum(
      tf.square(scale_sqrt_inv_x_sqrt), axis=[-2, -1])

  # Complexity: O(nbk)
  half_log_det_x = tf.reduce_sum(
      tf.math.log(tf.linalg.diag_part(x_sqrt)), axis=[-1])

  # Complexity: O(nbk**2)
  log_prob = ((self.df - self.dimension - 1.) * half_log_det_x -
              0.5 * trace_scale_inv_x -
              self.log_normalization())

  # Set shape hints.
  # Try to merge what we know from the input x with what we know from the
  # parameters of this distribution.
  if (tensorshape_util.rank(x.shape) is not None and
      tensorshape_util.rank(self.batch_shape) is not None):
    tensorshape_util.set_shape(
        log_prob, tf.broadcast_static_shape(x.shape[:-2], self.batch_shape))

  return log_prob
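# Hedged numeric check of the trace identity in the comment above: for
# symmetric positive-definite V = SS' and X = LL' (Cholesky factors),
#   tr[inv(V) X] == ||inv(S) L||_F**2.
import tensorflow as tf

v = tf.constant([[2., 1.], [1., 2.]])
x = tf.constant([[3., 0.5], [0.5, 1.]])
s = tf.linalg.cholesky(v)
l = tf.linalg.cholesky(x)
lhs = tf.linalg.trace(tf.linalg.solve(v, x))
rhs = tf.reduce_sum(tf.square(tf.linalg.triangular_solve(s, l)))
# lhs and rhs agree up to float tolerance.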
def _batch_shape(self):
  return tf.broadcast_static_shape(self.df.shape,
                                   self.scale_operator.batch_shape)
def __init__(self,
             initial_distribution,
             transition_distribution,
             observation_distribution,
             num_steps,
             validate_args=False,
             allow_nan_stats=True,
             name="HiddenMarkovModel"):
  """Initialize hidden Markov model.

  Args:
    initial_distribution: A `Categorical`-like instance.
      Determines probability of first hidden state in the Markov chain.
      The number of categories must match the number of categories of
      `transition_distribution` as well as both the rightmost batch
      dimension of `transition_distribution` and the rightmost batch
      dimension of `observation_distribution`.
    transition_distribution: A `Categorical`-like instance.
      The rightmost batch dimension indexes the probability distribution
      of each hidden state conditioned on the previous hidden state.
    observation_distribution: A `tfp.distributions.Distribution`-like
      instance. The rightmost batch dimension indexes the distribution
      of each observation conditioned on the corresponding hidden state.
    num_steps: The number of steps taken in the Markov chain. A Python
      `int`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
      Default value: `True`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: "HiddenMarkovModel".

  Raises:
    ValueError: if `num_steps` is not at least 1.
    ValueError: if `initial_distribution` does not have scalar
      `event_shape`.
    ValueError: if `transition_distribution` does not have scalar
      `event_shape`.
    ValueError: if `transition_distribution` and `observation_distribution`
      are fully defined but don't have matching rightmost dimension.
  """
  parameters = dict(locals())

  # pylint: disable=protected-access
  with tf.name_scope(name) as name:
    self._runtime_assertions = []  # pylint: enable=protected-access

    num_steps = tf.convert_to_tensor(value=num_steps, name="num_steps")
    if validate_args:
      self._runtime_assertions += [
          assert_util.assert_equal(
              tf.rank(num_steps), 0, message="`num_steps` must be a scalar")
      ]
      self._runtime_assertions += [
          assert_util.assert_greater_equal(
              num_steps, 1, message="`num_steps` must be at least 1.")
      ]

    self._initial_distribution = initial_distribution
    self._observation_distribution = observation_distribution
    self._transition_distribution = transition_distribution

    if (initial_distribution.event_shape is not None and
        tensorshape_util.rank(initial_distribution.event_shape) != 0):
      raise ValueError(
          "`initial_distribution` must have scalar `event_dim`s")
    elif validate_args:
      self._runtime_assertions += [
          assert_util.assert_equal(
              tf.shape(initial_distribution.event_shape_tensor())[0],
              0,
              message="`initial_distribution` must have scalar "
                      "`event_dim`s")
      ]

    if (transition_distribution.event_shape is not None and
        tensorshape_util.rank(transition_distribution.event_shape) != 0):
      raise ValueError(
          "`transition_distribution` must have scalar `event_dim`s")
    elif validate_args:
      self._runtime_assertions += [
          assert_util.assert_equal(
              tf.shape(transition_distribution.event_shape_tensor())[0],
              0,
              message="`transition_distribution` must have scalar "
                      "`event_dim`s")
      ]

    if (transition_distribution.batch_shape is not None and
        tensorshape_util.rank(transition_distribution.batch_shape) == 0):
      raise ValueError(
          "`transition_distribution` can't have scalar batches")
    elif validate_args:
      self._runtime_assertions += [
          assert_util.assert_greater(
              tf.size(transition_distribution.batch_shape_tensor()),
              0,
              message="`transition_distribution` can't have scalar "
                      "batches")
      ]

    if (observation_distribution.batch_shape is not None and
        tensorshape_util.rank(observation_distribution.batch_shape) == 0):
      raise ValueError(
          "`observation_distribution` can't have scalar batches")
    elif validate_args:
      self._runtime_assertions += [
          assert_util.assert_greater(
              tf.size(observation_distribution.batch_shape_tensor()),
              0,
              message="`observation_distribution` can't have scalar "
                      "batches")
      ]

    # Infer number of hidden states and check consistency
    # between transitions and observations.
    with tf.control_dependencies(self._runtime_assertions):
      self._num_states = (
          (transition_distribution.batch_shape and
           transition_distribution.batch_shape[-1]) or
          transition_distribution.batch_shape_tensor()[-1])

      observation_states = (
          (observation_distribution.batch_shape and
           observation_distribution.batch_shape[-1]) or
          observation_distribution.batch_shape_tensor()[-1])

    if tf.is_tensor(self._num_states) or tf.is_tensor(observation_states):
      if validate_args:
        self._runtime_assertions += [
            assert_util.assert_equal(
                self._num_states,
                observation_states,
                message="`transition_distribution` and "
                        "`observation_distribution` must agree on "
                        "last dimension of batch size")
        ]
    elif self._num_states != observation_states:
      raise ValueError("`transition_distribution` and "
                       "`observation_distribution` must agree on "
                       "last dimension of batch size")

    self._log_init = _extract_log_probs(self._num_states,
                                        initial_distribution)
    self._log_trans = _extract_log_probs(self._num_states,
                                         transition_distribution)

    self._num_steps = num_steps
    self._num_states = tf.shape(self._log_init)[-1]

    self._underlying_event_rank = tf.size(
        self._observation_distribution.event_shape_tensor())

    num_steps_ = tf.get_static_value(num_steps)
    if num_steps_ is not None:
      self.static_event_shape = tf.TensorShape([num_steps_]).concatenate(
          self._observation_distribution.event_shape)
    else:
      self.static_event_shape = None

    with tf.control_dependencies(self._runtime_assertions):
      self.static_batch_shape = tf.broadcast_static_shape(
          self._initial_distribution.batch_shape,
          tf.broadcast_static_shape(
              self._transition_distribution.batch_shape[:-1],
              self._observation_distribution.batch_shape[:-1]))

    # pylint: disable=protected-access
    super(HiddenMarkovModel, self).__init__(
        dtype=self._observation_distribution.dtype,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
    # pylint: enable=protected-access

    self._parameters = parameters
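# Hedged usage sketch for `HiddenMarkovModel`, following the pattern in the
# class documentation: a two-state chain with Normal observations per
# state. The probabilities below are illustrative only.
import tensorflow_probability as tfp

tfd = tfp.distributions

hmm = tfd.HiddenMarkovModel(
    initial_distribution=tfd.Categorical(probs=[0.8, 0.2]),
    transition_distribution=tfd.Categorical(probs=[[0.7, 0.3],
                                                   [0.2, 0.8]]),
    observation_distribution=tfd.Normal(loc=[0., 15.], scale=[5., 10.]),
    num_steps=7)
sample = hmm.sample()      # Tensor of shape [7]
lp = hmm.log_prob(sample)  # scalar log-likelihood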
def _batch_shape(self):
  dist, mixture_dist = self.poisson_and_mixture_distributions()
  return tf.broadcast_static_shape(
      dist.batch_shape, mixture_dist.logits.shape)[:-1]
def _batch_shape(self):
  return tensorshape_util.with_rank_at_least(
      tf.broadcast_static_shape(
          tf.TensorShape(self._total_count.shape).concatenate([1]),
          tf.TensorShape(self._concentration.shape)), 1)[:-1]
def _batch_shape(self):
  return tf.broadcast_static_shape(self.low.shape, self.high.shape)
def _batch_shape(self):
  return tf.broadcast_static_shape(
      (self._probs if self._logits is None else self._logits).shape[:-1],
      self.total_count.shape)