def _check_shapes(self):
  """Static check that shapes are compatible."""
  # Broadcast shape also checks that u and v are compatible.
  uv_shape = array_ops.broadcast_static_shape(
      self.u.get_shape(), self.v.get_shape())

  batch_shape = array_ops.broadcast_static_shape(
      self.base_operator.batch_shape, uv_shape[:-2])

  self.base_operator.domain_dimension.assert_is_compatible_with(
      uv_shape[-2])

  if self._diag_update is not None:
    uv_shape[-1].assert_is_compatible_with(self._diag_update.get_shape()[-1])
    array_ops.broadcast_static_shape(
        batch_shape, self._diag_update.get_shape()[:-1])

def _reduce_jacobian_det_over_event(
    self, y, ildj, min_event_ndims, event_ndims):
  """Reduce jacobian over event_ndims - min_event_ndims."""

  if not self.is_constant_jacobian:
    return math_ops.reduce_sum(
        ildj, self._get_event_reduce_dims(min_event_ndims, event_ndims))

  # In this case, we need to tile the jacobian over the event and reduce.
  y_rank = array_ops.rank(y)
  y_shape = array_ops.shape(y)[
      y_rank - event_ndims : y_rank - min_event_ndims]

  ones = array_ops.ones(y_shape, ildj.dtype)
  reduced_ildj = math_ops.reduce_sum(
      ones * ildj,
      axis=self._get_event_reduce_dims(min_event_ndims, event_ndims))
  # The multiplication by ones can change the inferred static shape so we try
  # to recover as much as possible.
  if (isinstance(event_ndims, int) and
      y.get_shape().ndims and ildj.get_shape().ndims):
    y_shape = y.get_shape()
    y_shape = y_shape[y_shape.ndims - event_ndims :
                      y_shape.ndims - min_event_ndims]
    ildj_shape = ildj.get_shape()
    broadcast_shape = array_ops.broadcast_static_shape(ildj_shape, y_shape)
    reduced_ildj.set_shape(
        broadcast_shape[: broadcast_shape.ndims - (
            event_ndims - min_event_ndims)])

  return reduced_ildj

def benchmarkBatchMatMulBroadcast(self):
  for (a_shape, b_shape) in self.shape_pairs:
    with compat.forward_compatibility_horizon(2019, 4, 26):
      with ops.Graph().as_default(), \
          session.Session(config=benchmark.benchmark_config()) as sess, \
          ops.device("/cpu:0"):
        matrix_a = variables.Variable(
            GetRandomNormalInput(a_shape, np.float32))
        matrix_b = variables.Variable(
            GetRandomNormalInput(b_shape, np.float32))
        variables.global_variables_initializer().run()

        # Use batch matmul op's internal broadcasting.
        self.run_op_benchmark(
            sess,
            math_ops.matmul(matrix_a, matrix_b),
            min_iters=50,
            name="batch_matmul_cpu_{}_{}".format(a_shape, b_shape))

        # Manually broadcast the input matrices using the broadcast_to op.
        broadcasted_batch_shape = array_ops.broadcast_static_shape(
            matrix_a.shape[:-2], matrix_b.shape[:-2])
        broadcasted_a_shape = broadcasted_batch_shape.concatenate(
            matrix_a.shape[-2:])
        broadcasted_b_shape = broadcasted_batch_shape.concatenate(
            matrix_b.shape[-2:])
        self.run_op_benchmark(
            sess,
            math_ops.matmul(
                array_ops.broadcast_to(matrix_a, broadcasted_a_shape),
                array_ops.broadcast_to(matrix_b, broadcasted_b_shape)),
            min_iters=50,
            name="batch_matmul_manual_broadcast_cpu_{}_{}".format(
                a_shape, b_shape))

def _possibly_broadcast_batch_shape(self, x):
  """Return 'x', possibly after broadcasting the leading dimensions."""
  # If we have no batch shape, our batch shape broadcasts with everything!
  if self._batch_shape_arg is None:
    return x

  # Static attempt:
  #   If we determine that no broadcast is necessary, pass x through.
  #   If we need a broadcast, add to an array of zeros.
  #
  # special_shape is the shape that, when broadcast with x's shape, will give
  # the correct broadcast_shape.  Note that
  #   we have already verified the second to last dimension of self.shape
  #   matches x's shape in assert_compatible_matrix_dimensions, and
  #   the final dimension of 'x' can have any shape.
  # Therefore, the final two dimensions of special_shape are 1's.
  special_shape = self.batch_shape.concatenate([1, 1])
  bshape = array_ops.broadcast_static_shape(x.get_shape(), special_shape)
  if special_shape.is_fully_defined():
    # bshape.is_fully_defined iff special_shape.is_fully_defined.
    if bshape == x.get_shape():
      return x
    # Use the built in broadcasting of addition.
    zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
    return x + zeros

  # Dynamic broadcast:
  #   Always add to an array of zeros, rather than using a "cond", since a
  #   cond would require copying data from GPU --> CPU.
  special_shape = array_ops.concat((self.batch_shape_dynamic(), [1, 1]), 0)
  zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
  return x + zeros

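# --- Hedged illustration (not part of the original source) ---
# A minimal sketch of the "add zeros" broadcasting trick used in
# `_possibly_broadcast_batch_shape` above: adding a zeros tensor whose shape
# ends in [1, 1] broadcasts `x` up to the operator's batch shape without a
# `cond`. The shapes and the `tf` alias are assumptions for this example only.
import tensorflow as tf

x = tf.ones([4, 3])                              # matrix with no batch dims
batch_shape = tf.TensorShape([2])                # desired leading batch shape
special_shape = batch_shape.concatenate([1, 1])  # [2, 1, 1]
zeros = tf.zeros(special_shape, dtype=x.dtype)
broadcasted = x + zeros                          # shape [2, 4, 3]
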
def _reduce_jacobian_det_over_event(
    self, y, ildj, min_event_ndims, event_ndims):
  """Reduce jacobian over event_ndims - min_event_ndims."""
  # In this case, we need to tile the Jacobian over the event and reduce.
  y_rank = array_ops.rank(y)
  y_shape = array_ops.shape(y)[
      y_rank - event_ndims : y_rank - min_event_ndims]

  ones = array_ops.ones(y_shape, ildj.dtype)
  reduced_ildj = math_ops.reduce_sum(
      ones * ildj,
      axis=self._get_event_reduce_dims(min_event_ndims, event_ndims))
  # The multiplication by ones can change the inferred static shape so we try
  # to recover as much as possible.
  event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
  if (event_ndims_ is not None and
      y.shape.ndims is not None and
      ildj.shape.ndims is not None):
    y_shape = y.shape[y.shape.ndims - event_ndims_ :
                      y.shape.ndims - min_event_ndims]
    broadcast_shape = array_ops.broadcast_static_shape(ildj.shape, y_shape)
    reduced_ildj.set_shape(
        broadcast_shape[: broadcast_shape.ndims - (
            event_ndims_ - min_event_ndims)])

  return reduced_ildj

def _broadcast_shape(shape1, shape2):
  """Convenience function which statically broadcasts shape when possible."""
  if (tensor_util.constant_value(shape1) is not None and
      tensor_util.constant_value(shape2) is not None):
    return array_ops.broadcast_static_shape(
        tensor_shape.TensorShape(tensor_util.constant_value(shape1)),
        tensor_shape.TensorShape(tensor_util.constant_value(shape2)))
  return array_ops.broadcast_dynamic_shape(shape1, shape2)

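# --- Hedged usage sketch (not part of the original source) ---
# The same static-first, dynamic-fallback idea as `_broadcast_shape` above,
# expressed with the public API. The shapes below are assumptions for this
# example only.
import tensorflow as tf

# Both shapes statically known: the broadcast shape can be computed at graph
# construction time and returned as a `TensorShape`.
static = tf.broadcast_static_shape(
    tf.TensorShape([3, 1, 2]), tf.TensorShape([5, 1]))  # TensorShape([3, 5, 2])

# If either shape is only known at run time, fall back to the dynamic op,
# which returns an int32 `Tensor`.
dynamic = tf.broadcast_dynamic_shape(
    tf.constant([3, 1, 2]), tf.constant([5, 1]))        # [3, 5, 2]
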
def _static_check_for_broadcastable_batch_shape(operators):
  """ValueError if operators determined to have non-broadcastable shapes."""
  if len(operators) < 2:
    return

  # This will fail if they cannot be broadcast together.
  batch_shape = operators[0].batch_shape
  for op in operators[1:]:
    batch_shape = array_ops.broadcast_static_shape(batch_shape,
                                                   op.batch_shape)

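# --- Hedged illustration (not part of the original source) ---
# The check above relies on `broadcast_static_shape` raising `ValueError` when
# two batch shapes cannot be broadcast. The shapes are assumptions for this
# example only.
import tensorflow as tf

tf.broadcast_static_shape(tf.TensorShape([2, 3]),
                          tf.TensorShape([2, 1]))  # OK: TensorShape([2, 3])
try:
  tf.broadcast_static_shape(tf.TensorShape([2, 3]), tf.TensorShape([4]))
except ValueError as e:
  print("non-broadcastable batch shapes:", e)
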
def _finish_log_prob_for_one_fiber(self, y, x, ildj):
  """Finish computation of log_prob on one element of the inverse image."""
  x = self._maybe_rotate_dims(x, rotate_right=True)
  log_prob = self.distribution.log_prob(x)
  if self._is_maybe_event_override:
    log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices)
  log_prob += math_ops.cast(ildj, log_prob.dtype)
  if self._is_maybe_event_override:
    log_prob.set_shape(array_ops.broadcast_static_shape(
        y.get_shape().with_rank_at_least(1)[:-1], self.batch_shape))
  return log_prob

def _finish_prob_for_one_fiber(self, y, x, ildj, event_ndims):
  """Finish computation of prob on one element of the inverse image."""
  x = self._maybe_rotate_dims(x, rotate_right=True)
  prob = self.distribution.prob(x)
  if self._is_maybe_event_override:
    prob = math_ops.reduce_prod(prob, self._reduce_event_indices)
  prob *= math_ops.exp(math_ops.cast(ildj, prob.dtype))
  if self._is_maybe_event_override and isinstance(event_ndims, int):
    prob.set_shape(array_ops.broadcast_static_shape(
        y.get_shape().with_rank_at_least(1)[:-event_ndims],
        self.batch_shape))
  return prob

def _prob(self, y):
  x, ildj = self.bijector.inverse_and_inverse_log_det_jacobian(y)
  x = self._maybe_rotate_dims(x, rotate_right=True)
  prob = self.distribution.prob(x)
  if self._is_maybe_event_override:
    prob = math_ops.reduce_prod(prob, self._reduce_event_indices)
  prob *= math_ops.exp(ildj)
  if self._is_maybe_event_override:
    prob.set_shape(array_ops.broadcast_static_shape(
        y.get_shape().with_rank_at_least(1)[:-1], self.batch_shape))
  return prob

def _log_prob(self, y):
  x = self.bijector.inverse(y)
  ildj = self.bijector.inverse_log_det_jacobian(y)
  x = self._maybe_rotate_dims(x, rotate_right=True)
  log_prob = self.distribution.log_prob(x)
  if self._is_maybe_event_override:
    log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices)
  log_prob = ildj + log_prob
  if self._is_maybe_event_override:
    log_prob.set_shape(array_ops.broadcast_static_shape(
        y.get_shape().with_rank_at_least(1)[:-1], self.batch_shape))
  return log_prob

def _log_prob(self, y):
  # For caching to work, it is imperative that the bijector is the first to
  # modify the input.
  x = self.bijector.inverse(y)
  ildj = self.bijector.inverse_log_det_jacobian(y)
  x = self._maybe_rotate_dims(x, rotate_right=True)
  log_prob = self.distribution.log_prob(x)
  if self._is_maybe_event_override:
    log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices)
  log_prob = ildj + log_prob
  if self._is_maybe_event_override:
    log_prob.set_shape(array_ops.broadcast_static_shape(
        y.get_shape().with_rank_at_least(1)[:-1], self.batch_shape))
  return log_prob

def determine_batch_event_shapes(grid, endpoint_affine):
  """Helper to infer batch_shape and event_shape."""
  with ops.name_scope(name="determine_batch_event_shapes"):
    # grid  # shape: [B, k, q]
    # endpoint_affine  # len=k, shape: [B, d, d]
    batch_shape = grid.shape[:-2]
    batch_shape_tensor = array_ops.shape(grid)[:-2]
    event_shape = None
    event_shape_tensor = None

    def _set_event_shape(shape, shape_tensor):
      if event_shape is None:
        return shape, shape_tensor
      return (array_ops.broadcast_static_shape(event_shape, shape),
              array_ops.broadcast_dynamic_shape(
                  event_shape_tensor, shape_tensor))

    for aff in endpoint_affine:
      if aff.shift is not None:
        batch_shape = array_ops.broadcast_static_shape(
            batch_shape, aff.shift.shape[:-1])
        batch_shape_tensor = array_ops.broadcast_dynamic_shape(
            batch_shape_tensor, array_ops.shape(aff.shift)[:-1])
        event_shape, event_shape_tensor = _set_event_shape(
            aff.shift.shape[-1:], array_ops.shape(aff.shift)[-1:])

      if aff.scale is not None:
        batch_shape = array_ops.broadcast_static_shape(
            batch_shape, aff.scale.batch_shape)
        batch_shape_tensor = array_ops.broadcast_dynamic_shape(
            batch_shape_tensor, aff.scale.batch_shape_tensor())
        event_shape, event_shape_tensor = _set_event_shape(
            tensor_shape.TensorShape([aff.scale.range_dimension]),
            aff.scale.range_dimension_tensor()[array_ops.newaxis])

    return batch_shape, batch_shape_tensor, event_shape, event_shape_tensor

def prefer_static_broadcast_shape(
    shape1, shape2, name="prefer_static_broadcast_shape"):
  """Convenience function which statically broadcasts shape when possible.

  Args:
    shape1: `1-D` integer `Tensor`. Already converted to tensor!
    shape2: `1-D` integer `Tensor`. Already converted to tensor!
    name: A string name to prepend to created ops.

  Returns:
    The broadcast shape, either as `TensorShape` (if broadcast can be done
    statically), or as a `Tensor`.
  """
  with ops.name_scope(name, values=[shape1, shape2]):
    if (tensor_util.constant_value(shape1) is not None and
        tensor_util.constant_value(shape2) is not None):
      return array_ops.broadcast_static_shape(
          tensor_shape.TensorShape(tensor_util.constant_value(shape1)),
          tensor_shape.TensorShape(tensor_util.constant_value(shape2)))
    return array_ops.broadcast_dynamic_shape(shape1, shape2)

def prefer_static_broadcast_shape(
    shape1, shape2, name="prefer_static_broadcast_shape"):
  """Convenience function which statically broadcasts shape when possible.

  Args:
    shape1: `1-D` integer `Tensor`. Already converted to tensor!
    shape2: `1-D` integer `Tensor`. Already converted to tensor!
    name: A string name to prepend to created ops.

  Returns:
    The broadcast shape, either as `TensorShape` (if broadcast can be done
    statically), or as a `Tensor`.
  """
  with ops.name_scope(name, values=[shape1, shape2]):

    def make_shape_tensor(x):
      return ops.convert_to_tensor(x, name="shape", dtype=dtypes.int32)

    def get_tensor_shape(s):
      if isinstance(s, tensor_shape.TensorShape):
        return s
      s_ = tensor_util.constant_value(make_shape_tensor(s))
      if s_ is not None:
        return tensor_shape.TensorShape(s_)
      return None

    def get_shape_tensor(s):
      if not isinstance(s, tensor_shape.TensorShape):
        return make_shape_tensor(s)
      if s.is_fully_defined():
        return make_shape_tensor(s.as_list())
      raise ValueError("Cannot broadcast from partially "
                       "defined `TensorShape`.")

    shape1_ = get_tensor_shape(shape1)
    shape2_ = get_tensor_shape(shape2)
    if shape1_ is not None and shape2_ is not None:
      return array_ops.broadcast_static_shape(shape1_, shape2_)

    shape1_ = get_shape_tensor(shape1)
    shape2_ = get_shape_tensor(shape2)
    return array_ops.broadcast_dynamic_shape(shape1_, shape2_)

def get_broadcast_shape(*tensors):
  """Get broadcast shape as a Python list of integers (preferred) or `Tensor`.

  Args:
    *tensors: One or more `Tensor` objects (already converted!).

  Returns:
    broadcast shape: Python list (if shapes determined statically), otherwise
      an `int32` `Tensor`.
  """
  # Try static.
  s_shape = tensors[0].shape
  for t in tensors[1:]:
    s_shape = array_ops.broadcast_static_shape(s_shape, t.shape)
  if s_shape.is_fully_defined():
    return s_shape.as_list()

  # Fallback on dynamic.
  d_shape = array_ops.shape(tensors[0])
  for t in tensors[1:]:
    d_shape = array_ops.broadcast_dynamic_shape(d_shape, array_ops.shape(t))
  return d_shape

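# --- Hedged usage sketch (not part of the original source) ---
# `get_broadcast_shape` above returns a plain Python list when every input
# shape is fully defined, and otherwise falls back to an int32 shape `Tensor`.
# The tensors below are assumptions for this example only.
import tensorflow as tf

x = tf.zeros([2, 1, 3])
y = tf.zeros([5, 1])
# get_broadcast_shape(x, y)  ==> [2, 5, 3]  (static path)
# Equivalent public-API check of the static result:
assert tf.broadcast_static_shape(x.shape, y.shape).as_list() == [2, 5, 3]
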
def _batch_shape(self):
  return array_ops.broadcast_static_shape(
      self.distribution.batch_shape,
      self.mixture_distribution.logits.shape)[:-1]

def _batch_shape(self):
  return array_ops.broadcast_static_shape(self.low.get_shape(),
                                          self.high.get_shape())

def _get_batch_shape(self):
  return array_ops.broadcast_static_shape(
      array_ops.broadcast_static_shape(self.df.get_shape(),
                                       self.mu.get_shape()),
      self.sigma.get_shape())

def _get_batch_shape(self):
  return array_ops.broadcast_static_shape(
      array_ops.broadcast_static_shape(self.df.get_shape(),
                                       self.mu.get_shape()),
      self.sigma.get_shape())

def _batch_shape(self):
  return array_ops.broadcast_static_shape(
      self.total_count.get_shape(),
      self.probs.get_shape())

def broadcast_matrix_batch_dims(batch_matrices, name=None):
  """Broadcast leading dimensions of zero or more [batch] matrices.

  Example broadcasting one batch dim of two simple matrices.

  ```python
  x = [[1, 2],
       [3, 4]]  # Shape [2, 2], no batch dims

  y = [[[1]]]   # Shape [1, 1, 1], 1 batch dim of shape [1]

  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc
  ==> [[[1, 2],
        [3, 4]]]  # Shape [1, 2, 2], 1 batch dim of shape [1].

  y_bc
  ==> same as y
  ```

  Example broadcasting many batch dims

  ```python
  x = tf.random.normal(shape=(2, 3, 1, 4, 4))
  y = tf.random.normal(shape=(1, 3, 2, 5, 5))
  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])
  x_bc.shape
  ==> (2, 3, 2, 4, 4)
  y_bc.shape
  ==> (2, 3, 2, 5, 5)
  ```

  Args:
    batch_matrices: Iterable of `Tensor`s, each having two or more dimensions.
    name: A string name to prepend to created ops.

  Returns:
    bcast_matrices: List of `Tensor`s, with `bcast_matrices[i]` containing
      the values from `batch_matrices[i]`, with possibly broadcast batch dims.

  Raises:
    ValueError: If any input `Tensor` is statically determined to have less
      than two dimensions.
  """
  with ops.name_scope(name or "broadcast_matrix_batch_dims",
                      values=batch_matrices):
    check_ops.assert_proper_iterable(batch_matrices)
    batch_matrices = list(batch_matrices)

    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = ops.convert_to_tensor(mat)
      assert_is_batch_matrix(batch_matrices[i])

    if len(batch_matrices) < 2:
      return batch_matrices

    # Try static broadcasting.
    # bcast_batch_shape is the broadcast batch shape of ALL matrices.
    # E.g. if batch_matrices = [x, y], with
    # x.shape =    [2, j, k]  (batch shape =    [2])
    # y.shape = [3, 1, l, m]  (batch shape = [3, 1])
    # ==> bcast_batch_shape = [3, 2]
    bcast_batch_shape = batch_matrices[0].shape[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_static_shape(
          bcast_batch_shape, mat.shape[:-2])
    if bcast_batch_shape.is_fully_defined():
      for i, mat in enumerate(batch_matrices):
        if mat.shape[:-2] != bcast_batch_shape:
          bcast_shape = array_ops.concat(
              [bcast_batch_shape.as_list(), array_ops.shape(mat)[-2:]], axis=0)
          batch_matrices[i] = array_ops.broadcast_to(mat, bcast_shape)
      return batch_matrices

    # Since static didn't work, do dynamic, which always copies data.
    bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_dynamic_shape(
          bcast_batch_shape, array_ops.shape(mat)[:-2])
    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = array_ops.broadcast_to(
          mat,
          array_ops.concat(
              [bcast_batch_shape, array_ops.shape(mat)[-2:]], axis=0))

    return batch_matrices

def _batch_shape(self):
  return array_ops.broadcast_static_shape(
      self.concentration.get_shape(),
      self.rate.get_shape())

def _kl_brute_force(a, b, name=None):
  """Batched KL divergence `KL(a || b)` for multivariate Normals.

  With `X`, `Y` both multivariate Normals in `R^k` with means `mu_a`, `mu_b`
  and covariance `C_a`, `C_b` respectively,

  ```
  KL(a || b) = 0.5 * ( L - k + T + Q ),
  L := Log[Det(C_b)] - Log[Det(C_a)]
  T := trace(C_b^{-1} C_a),
  Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
  ```

  This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient
  methods for solving systems with `C_b` may be available, a dense version of
  (the square root of) `C_a` is used, so performance is `O(B s k**2)` where
  `B` is the batch size, and `s` is the cost of solving `C_b x = y` for
  vectors `x` and `y`.

  Args:
    a: Instance of `MultivariateNormalLinearOperator`.
    b: Instance of `MultivariateNormalLinearOperator`.
    name: (optional) name to use for created ops. Default "kl_mvn".

  Returns:
    Batchwise `KL(a || b)`.
  """

  def squared_frobenius_norm(x):
    """Helper to make KL calculation slightly more readable."""
    # http://mathworld.wolfram.com/FrobeniusNorm.html
    # The gradient of KL[p,q] is not defined when p==q. The culprit is
    # linalg_ops.norm, i.e., we cannot use the commented out code.
    # return math_ops.square(linalg_ops.norm(x, ord="fro", axis=[-2, -1]))
    return math_ops.reduce_sum(math_ops.square(x), axis=[-2, -1])

  # TODO(b/35041439): See also b/35040945. Remove this function once LinOp
  # supports something like:
  #   A.inverse().solve(B).norm(order='fro', axis=[-1, -2])
  def is_diagonal(x):
    """Helper to identify if `LinearOperator` has only a diagonal component."""
    return (isinstance(x, linalg.LinearOperatorIdentity) or
            isinstance(x, linalg.LinearOperatorScaledIdentity) or
            isinstance(x, linalg.LinearOperatorDiag))

  with ops.name_scope(
      name, "kl_mvn",
      values=[a.loc, b.loc] + a.scale.graph_parents + b.scale.graph_parents):
    # Calculation is based on:
    # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
    # and,
    # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
    # i.e.,
    #   If Ca = AA', Cb = BB', then
    #   tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
    #                  = tr[inv(B) A A' inv(B)']
    #                  = tr[(inv(B) A) (inv(B) A)']
    #                  = sum_{ij} (inv(B) A)_{ij}**2
    #                  = ||inv(B) A||_F**2
    # where ||.||_F is the Frobenius norm and the second equality follows from
    # the cyclic permutation property.
    if is_diagonal(a.scale) and is_diagonal(b.scale):
      # Using `stddev` because it handles expansion of Identity cases.
      b_inv_a = (a.stddev() / b.stddev())[..., array_ops.newaxis]
    else:
      b_inv_a = b.scale.solve(a.scale.to_dense())
    kl_div = (b.scale.log_abs_determinant() -
              a.scale.log_abs_determinant() +
              0.5 * (
                  - math_ops.cast(a.scale.domain_dimension_tensor(), a.dtype) +
                  squared_frobenius_norm(b_inv_a) +
                  squared_frobenius_norm(b.scale.solve(
                      (b.mean() - a.mean())[..., array_ops.newaxis]))))
    kl_div.set_shape(array_ops.broadcast_static_shape(
        a.batch_shape, b.batch_shape))
    return kl_div

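# --- Hedged numeric check (not part of the original source) ---
# Verifies the identity used in the comment above:
#   tr[inv(Cb) Ca] == ||inv(B) A||_F**2   when  Ca = A A',  Cb = B B'.
# Pure NumPy; the matrices and seed are assumptions for this example only.
import numpy as np

rng = np.random.RandomState(0)
a_mat = rng.randn(3, 3)
b_mat = rng.randn(3, 3) + 3. * np.eye(3)         # well-conditioned, invertible B
ca = a_mat @ a_mat.T
cb = b_mat @ b_mat.T
lhs = np.trace(np.linalg.solve(cb, ca))          # tr[inv(Cb) Ca]
rhs = np.sum(np.linalg.solve(b_mat, a_mat) ** 2)  # ||inv(B) A||_F**2
assert np.allclose(lhs, rhs)
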
def _batch_shape(self):
  return array_ops.broadcast_static_shape(
      self.low.get_shape(), self.high.get_shape())

def _get_batch_shape(self):
  return array_ops.broadcast_static_shape(
      self.n.get_shape(), self.p.get_shape())

def _set_event_shape(shape, shape_tensor):
  if event_shape is None:
    return shape, shape_tensor
  return (array_ops.broadcast_static_shape(event_shape, shape),
          array_ops.broadcast_dynamic_shape(event_shape_tensor, shape_tensor))

def _get_batch_shape(self):
  return array_ops.broadcast_static_shape(
      self.alpha.get_shape(), self.beta.get_shape())

def _shape(self):
  # If d_shape = [5, 3], we return [5, 3, 3].
  v_shape = array_ops.broadcast_static_shape(self.row.shape, self.col.shape)
  return v_shape.concatenate(v_shape[-1:])

def broadcast_matrix_batch_dims(batch_matrices, name=None):
  """Broadcast leading dimensions of zero or more [batch] matrices.

  Example broadcasting one batch dim of two simple matrices.

  ```python
  x = [[1, 2],
       [3, 4]]  # Shape [2, 2], no batch dims

  y = [[[1]]]   # Shape [1, 1, 1], 1 batch dim of shape [1]

  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc
  ==> [[[1, 2],
        [3, 4]]]  # Shape [1, 2, 2], 1 batch dim of shape [1].

  y_bc
  ==> same as y
  ```

  Example broadcasting many batch dims

  ```python
  x = tf.random_normal(shape=(2, 3, 1, 4, 4))
  y = tf.random_normal(shape=(1, 3, 2, 5, 5))
  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])
  x_bc.shape
  ==> (2, 3, 2, 4, 4)
  y_bc.shape
  ==> (2, 3, 2, 5, 5)
  ```

  Args:
    batch_matrices: Iterable of `Tensor`s, each having two or more dimensions.
    name: A string name to prepend to created ops.

  Returns:
    bcast_matrices: List of `Tensor`s, with `bcast_matrices[i]` containing
      the values from `batch_matrices[i]`, with possibly broadcast batch dims.

  Raises:
    ValueError: If any input `Tensor` is statically determined to have less
      than two dimensions.
  """
  with ops.name_scope(
      name or "broadcast_matrix_batch_dims", values=batch_matrices):
    check_ops.assert_proper_iterable(batch_matrices)
    batch_matrices = list(batch_matrices)

    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = ops.convert_to_tensor(mat)
      assert_is_batch_matrix(batch_matrices[i])

    if len(batch_matrices) < 2:
      return batch_matrices

    # Try static broadcasting.
    # bcast_batch_shape is the broadcast batch shape of ALL matrices.
    # E.g. if batch_matrices = [x, y], with
    # x.shape =    [2, j, k]  (batch shape =    [2])
    # y.shape = [3, 1, l, m]  (batch shape = [3, 1])
    # ==> bcast_batch_shape = [3, 2]
    bcast_batch_shape = batch_matrices[0].get_shape()[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_static_shape(
          bcast_batch_shape, mat.get_shape()[:-2])
    if bcast_batch_shape.is_fully_defined():
      # The [1, 1] at the end will broadcast with anything.
      bcast_shape = bcast_batch_shape.concatenate([1, 1])
      for i, mat in enumerate(batch_matrices):
        if mat.get_shape()[:-2] != bcast_batch_shape:
          batch_matrices[i] = _broadcast_to_shape(mat, bcast_shape)
      return batch_matrices

    # Since static didn't work, do dynamic, which always copies data.
    bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_dynamic_shape(
          bcast_batch_shape, array_ops.shape(mat)[:-2])
    bcast_shape = array_ops.concat([bcast_batch_shape, [1, 1]], axis=0)
    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = _broadcast_to_shape(mat, bcast_shape)

    return batch_matrices

def _set_event_shape(shape, shape_tensor):
  if event_shape is None:
    return shape, shape_tensor
  return (array_ops.broadcast_static_shape(event_shape, shape),
          array_ops.broadcast_dynamic_shape(
              event_shape_tensor, shape_tensor))

def _shape(self):
  batch_shape = array_ops.broadcast_static_shape(
      self.base_operator.batch_shape, self.u.get_shape()[:-2])
  return batch_shape.concatenate(self.base_operator.shape[-2:])

def _get_batch_shape(self):
  return array_ops.broadcast_static_shape(self.n.get_shape(),
                                          self.p.get_shape())

def _get_batch_shape(self):
  return array_ops.broadcast_static_shape(self._sigma.get_shape(),
                                          self._xi.get_shape())

def _batch_shape(self):
  return array_ops.broadcast_static_shape(self.concentration.get_shape(),
                                          self.rate.get_shape())

def _kl_brute_force(a, b, name=None):
  """Batched KL divergence `KL(a || b)` for multivariate Normals.

  With `X`, `Y` both multivariate Normals in `R^k` with means `mu_a`, `mu_b`
  and covariance `C_a`, `C_b` respectively,

  ```
  KL(a || b) = 0.5 * ( L - k + T + Q ),
  L := Log[Det(C_b)] - Log[Det(C_a)]
  T := trace(C_b^{-1} C_a),
  Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
  ```

  This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient
  methods for solving systems with `C_b` may be available, a dense version of
  (the square root of) `C_a` is used, so performance is `O(B s k**2)` where
  `B` is the batch size, and `s` is the cost of solving `C_b x = y` for
  vectors `x` and `y`.

  Args:
    a: Instance of `MultivariateNormalLinearOperator`.
    b: Instance of `MultivariateNormalLinearOperator`.
    name: (optional) name to use for created ops. Default "kl_mvn".

  Returns:
    Batchwise `KL(a || b)`.
  """

  def squared_frobenius_norm(x):
    """Helper to make KL calculation slightly more readable."""
    # http://mathworld.wolfram.com/FrobeniusNorm.html
    return math_ops.square(linalg_ops.norm(x, ord="fro", axis=[-2, -1]))

  # TODO(b/35041439): See also b/35040945. Remove this function once LinOp
  # supports something like:
  #   A.inverse().solve(B).norm(order='fro', axis=[-1, -2])
  def is_diagonal(x):
    """Helper to identify if `LinearOperator` has only a diagonal component."""
    return (isinstance(x, linalg.LinearOperatorIdentity) or
            isinstance(x, linalg.LinearOperatorScaledIdentity) or
            isinstance(x, linalg.LinearOperatorDiag))

  with ops.name_scope(
      name, "kl_mvn",
      values=[a.loc, b.loc] + a.scale.graph_parents + b.scale.graph_parents):
    # Calculation is based on:
    # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
    # and,
    # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
    # i.e.,
    #   If Ca = AA', Cb = BB', then
    #   tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
    #                  = tr[inv(B) A A' inv(B)']
    #                  = tr[(inv(B) A) (inv(B) A)']
    #                  = sum_{ij} (inv(B) A)_{ij}**2
    #                  = ||inv(B) A||_F**2
    # where ||.||_F is the Frobenius norm and the second equality follows from
    # the cyclic permutation property.
    if is_diagonal(a.scale) and is_diagonal(b.scale):
      # Using `stddev` because it handles expansion of Identity cases.
      b_inv_a = (a.stddev() / b.stddev())[..., array_ops.newaxis]
    else:
      b_inv_a = b.scale.solve(a.scale.to_dense())
    kl_div = (b.scale.log_abs_determinant() -
              a.scale.log_abs_determinant() +
              0.5 * (-math_ops.cast(a.scale.domain_dimension_tensor(), a.dtype)
                     + squared_frobenius_norm(b_inv_a) +
                     squared_frobenius_norm(
                         b.scale.solve(
                             (b.mean() - a.mean())[..., array_ops.newaxis]))))
    kl_div.set_shape(
        array_ops.broadcast_static_shape(a.batch_shape, b.batch_shape))
    return kl_div

def _batch_shape(self):
  return array_ops.broadcast_static_shape(
      self.loc.get_shape(), self.scale.get_shape())

def _batch_shape(self):
  return array_ops.broadcast_static_shape(self.loc.get_shape(),
                                          self.scale.get_shape())