def _swap_m_with_i(vecs, m, i):
  """Swaps `m` and `i` on axis -1. (Helper for pivoted_cholesky.)

  Given a batch of int64 vectors `vecs`, scalar index `m`, and compatibly shaped
  per-vector indices `i`, this function swaps elements `m` and `i` in each
  vector. For the use-case below, these are permutation vectors.

  Args:
    vecs: Vectors on which we perform the swap, int64 `Tensor`.
    m: Scalar int64 `Tensor`, the index into which the `i`th element is going.
    i: Batch int64 `Tensor`, shaped like vecs.shape[:-1] + [1]; the index into
      which the `m`th element is going.

  Returns:
    vecs: The updated vectors.
  """
  vecs = tf.convert_to_tensor(vecs, dtype=tf.int64, name='vecs')
  m = tf.convert_to_tensor(m, dtype=tf.int64, name='m')
  i = tf.convert_to_tensor(i, dtype=tf.int64, name='i')
  trailing_elts = tf.broadcast_to(
      tf.range(m + 1, prefer_static.shape(vecs, out_type=tf.int64)[-1]),
      prefer_static.shape(vecs[..., m + 1:]))
  trailing_elts = tf.where(
      tf.equal(trailing_elts, i),
      tf.gather(vecs, [m], axis=-1),
      vecs[..., m + 1:])
  # TODO(bjp): Could we use tensor_scatter_nd_update?
  vecs_shape = vecs.shape
  vecs = tf.concat([
      vecs[..., :m],
      tf.gather(vecs, i, batch_dims=int(prefer_static.rank(vecs)) - 1),
      trailing_elts
  ], axis=-1)
  tensorshape_util.set_shape(vecs, vecs_shape)
  return vecs
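# A minimal, non-batched NumPy sketch of the swap above (an editor's
# illustration, not part of the library):
#
#   import numpy as np
#   perm = np.array([0, 1, 2, 3, 4])
#   m, i = 1, 3
#   perm[[m, i]] = perm[[i, m]]   # swap elements m and i
#   # ==> array([0, 3, 2, 1, 4])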
def _log_prob(self, x, **kwargs):
  batch_ndims = prefer_static.rank_from_shape(
      self.distribution.batch_shape_tensor,
      self.distribution.batch_shape)
  extra_sample_ndims = prefer_static.rank_from_shape(self.sample_shape)
  event_ndims = prefer_static.rank_from_shape(
      self.distribution.event_shape_tensor,
      self.distribution.event_shape)
  ndims = prefer_static.rank(x)
  # (1) Expand x's dims.
  d = ndims - batch_ndims - extra_sample_ndims - event_ndims
  x = tf.reshape(x, shape=tf.pad(
      tf.shape(x),
      paddings=[[prefer_static.maximum(0, -d), 0]],
      constant_values=1))
  sample_ndims = prefer_static.maximum(0, d)
  # (2) Transpose x's dims.
  sample_dims = prefer_static.range(0, sample_ndims)
  batch_dims = prefer_static.range(sample_ndims, sample_ndims + batch_ndims)
  extra_sample_dims = prefer_static.range(
      sample_ndims + batch_ndims,
      sample_ndims + batch_ndims + extra_sample_ndims)
  event_dims = prefer_static.range(
      sample_ndims + batch_ndims + extra_sample_ndims,
      ndims)
  perm = prefer_static.concat(
      [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
  x = tf.transpose(a=x, perm=perm)
  # (3) Compute x's log_prob.
  lp = self.distribution.log_prob(x, **kwargs)
  # (4) Make the final reduction in x.
  axis = prefer_static.range(sample_ndims, sample_ndims + extra_sample_ndims)
  return tf.reduce_sum(lp, axis=axis)
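# Rough shape bookkeeping for the steps above (editor's illustration; the
# concrete shapes are assumptions chosen for the example):
#
#   import numpy as np
#   # Suppose batch_ndims = 1, extra_sample_ndims = 1, event_ndims = 1, and x
#   # arrives with shape sample + batch + extra_sample + event = [2, 3, 4, 5].
#   x = np.zeros([2, 3, 4, 5])
#   # (2) transpose to sample + extra_sample + batch + event, perm = [0, 2, 1, 3].
#   x_t = np.transpose(x, [0, 2, 1, 3])   # shape (2, 4, 3, 5)
#   # (3) the inner log_prob removes the event dim, giving shape (2, 4, 3).
#   # (4) summing over the extra-sample axis (axis=1) leaves (2, 3),
#   #     i.e. sample_shape + batch_shape.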
def body(m, pchol, perm, matrix_diag):
  """Body of a single `tf.while_loop` iteration."""
  # Here is roughly a numpy, non-batched version of what's going to happen.
  # (See also Algorithm 1 of Harbrecht et al.)
  # 1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
  # 2: maxval = matrix_diag[perm][maxi]
  # 3: perm[m], perm[maxi] = perm[maxi], perm[m]
  # 4: row = matrix[perm[m]][perm[m + 1:]]
  # 5: row -= np.sum(pchol[:m][perm[m + 1:]] * pchol[:m][perm[m]], axis=-2)
  # 6: pivot = np.sqrt(maxval); row /= pivot
  # 7: row = np.concatenate([[[pivot]], row], -1)
  # 8: matrix_diag[perm[m:]] -= row**2
  # 9: pchol[m, perm[m:]] = row

  # Find the maximal position of the (remaining) permuted diagonal.
  # Steps 1, 2 above.
  permuted_diag = batch_gather(matrix_diag, perm[..., m:])
  maxi = tf.argmax(
      permuted_diag, axis=-1, output_type=tf.int64)[..., tf.newaxis]
  maxval = batch_gather(permuted_diag, maxi)
  maxi = maxi + m
  maxval = maxval[..., 0]
  # Update perm: Swap perm[...,m] with perm[...,maxi]. Step 3 above.
  perm = _swap_m_with_i(perm, m, maxi)
  # Step 4.
  row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
  row = batch_gather(row, perm[..., m + 1:])
  # Step 5.
  prev_rows = pchol[..., :m, :]
  prev_rows_perm_m_onward = batch_gather(prev_rows, perm[..., m + 1:])
  prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
  row -= tf.reduce_sum(
      prev_rows_perm_m_onward * prev_rows_pivot_col,
      axis=-2)[..., tf.newaxis, :]
  # Step 6.
  pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]
  # Step 7.
  row = tf.concat([pivot, row / pivot], axis=-1)
  # TODO(b/130899118): Pad grad fails with int64 paddings.
  # Step 8.
  paddings = tf.concat([
      tf.zeros([prefer_static.rank(pchol) - 1, 2], dtype=tf.int32),
      [[tf.cast(m, tf.int32), 0]]
  ], axis=0)
  diag_update = tf.pad(row**2, paddings=paddings)[..., 0, :]
  reverse_perm = _invert_permutation(perm)
  matrix_diag -= batch_gather(diag_update, reverse_perm)
  # Step 9.
  row = tf.pad(row, paddings=paddings)
  # TODO(bjp): Defer the reverse permutation all-at-once at the end?
  row = batch_gather(row, reverse_perm)
  pchol_shape = pchol.shape
  pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]], axis=-2)
  tensorshape_util.set_shape(pchol, pchol_shape)
  return m + 1, pchol, perm, matrix_diag
def _maybe_rotate_dims(self, x, rotate_right=False):
  """Helper which rolls `x`'s dims left (default) or right by `rotate_ndims`."""
  needs_rotation_const = tf.get_static_value(self._needs_rotation)
  if needs_rotation_const is not None and not needs_rotation_const:
    return x
  ndims = prefer_static.rank(x)
  n = (ndims - self._rotate_ndims) if rotate_right else self._rotate_ndims
  perm = prefer_static.concat(
      [prefer_static.range(n, ndims), prefer_static.range(0, n)], axis=0)
  return tf.transpose(a=x, perm=perm)
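# A NumPy sketch of the rotation (editor's illustration), assuming
# rotate_ndims == 1 and a rank-3 input:
#
#   import numpy as np
#   x = np.zeros([2, 3, 4])
#   # rotate_right=False: roll dims left, perm = [1, 2, 0].
#   np.transpose(x, [1, 2, 0]).shape   # ==> (3, 4, 2)
#   # rotate_right=True: roll dims right, perm = [2, 0, 1].
#   np.transpose(x, [2, 0, 1]).shape   # ==> (4, 2, 3)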
def __init__(self,
             samples,
             event_ndims=0,
             validate_args=False,
             allow_nan_stats=True,
             name='Empirical'):
  """Initialize `Empirical` distributions.

  Args:
    samples: Numeric `Tensor` of shape `[B1, ..., Bk, S, E1, ..., En]`,
      `k, n >= 0`. Samples or batches of samples on which the distribution
      is based. The first `k` dimensions index into a batch of independent
      distributions. Length of the `S` dimension determines the number of
      samples in each multiset. The last `n` dimensions represent samples for
      each distribution. `n` is specified by the argument `event_ndims`.
    event_ndims: Python `int32`, default `0`. Number of dimensions for each
      event. When `0` this distribution has scalar samples. When `1` this
      distribution has vector-like samples.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value `NaN` to indicate the result
      is undefined. When `False`, an exception is raised if one or more of
      the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if the rank of `samples` is statically known and less than
      `event_ndims + 1`.
  """
  parameters = dict(locals())
  with tf.name_scope(name):
    self._samples = tensor_util.convert_nonref_to_tensor(samples)
    dtype = dtype_util.common_dtype(
        [self._samples], dtype_hint=self._samples.dtype)
    self._event_ndims = event_ndims
    # Note: this tf.rank call affects the graph, but is ok in `__init__`
    # because we don't expect shapes (or ranks) to be runtime-variable, nor
    # ever need to differentiate with respect to them.
    samples_rank = prefer_static.rank(self.samples)
    self._samples_axis = samples_rank - self._event_ndims - 1
    super(Empirical, self).__init__(
        dtype=dtype,
        reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
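# A short usage sketch (editor's illustration; assumes the usual
# `tfp.distributions` alias for the package this class lives in):
#
#   import tensorflow_probability as tfp
#   tfd = tfp.distributions
#
#   # Scalar events: four samples define a single empirical distribution.
#   d = tfd.Empirical(samples=[0., 1., 1., 2.])
#   d.mean()    # ==> 1.0
#   d.prob(1.)  # ==> 0.5  (two of the four samples equal 1.)
#
#   # Vector events: with event_ndims=1 the trailing dim is the event.
#   d = tfd.Empirical(samples=[[0., 1.], [1., 2.]], event_ndims=1)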
def _inverse(self, y):
  ndims = prefer_static.rank(y)
  shifted_y = tf.pad(
      tf.slice(
          y, tf.zeros(ndims, dtype=tf.int32),
          prefer_static.shape(y) -
          tf.one_hot(ndims + self.axis, ndims, dtype=tf.int32)
      ),  # Remove the last entry of y in the chosen dimension.
      paddings=tf.one_hot(
          tf.one_hot(ndims + self.axis, ndims, on_value=0, off_value=-1),
          2,
          dtype=tf.int32
      )  # Insert zeros at the beginning of the chosen dimension.
  )
  return y - shifted_y
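# The slice-and-pad above computes a first difference along `self.axis`; a
# non-batched NumPy sketch (editor's illustration), assuming axis == -1:
#
#   import numpy as np
#   y = np.array([1., 3., 6., 10.])             # a cumulative sum
#   shifted_y = np.concatenate([[0.], y[:-1]])  # drop last entry, prepend zero
#   y - shifted_y                               # ==> array([1., 2., 3., 4.])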
def _forward(self, x):
  x = tf.convert_to_tensor(x, name='x')
  batch_shape = prefer_static.shape(x)[:-1]

  # Pad zeros on the top row and right column.
  y = fill_triangular.FillTriangular().forward(x)
  rank = prefer_static.rank(y)
  paddings = tf.concat([
      tf.zeros(shape=(rank - 2, 2), dtype=tf.int32),
      tf.constant([[1, 0], [0, 1]], dtype=tf.int32)
  ], axis=0)
  y = tf.pad(y, paddings)

  # Set diagonal to 1s.
  n = prefer_static.shape(y)[-1]
  diag = tf.ones(tf.concat([batch_shape, [n]], axis=-1), dtype=x.dtype)
  y = tf.linalg.set_diag(y, diag)

  # Normalize each row to have Euclidean (L2) norm 1.
  y /= tf.norm(y, axis=-1)[..., tf.newaxis]
  return y
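# A non-batched NumPy sketch of the steps above (editor's illustration; the
# intermediate matrix is worked out by hand for this particular x):
#
#   import numpy as np
#   x = np.array([2., 0., 1.])
#   # Fill the strictly lower triangle from x, pad a zero top row and right
#   # column, then set the diagonal to ones:
#   y = np.array([[1., 0., 0.],
#                 [1., 1., 0.],
#                 [0., 2., 1.]])
#   y /= np.linalg.norm(y, axis=-1, keepdims=True)
#   # Each row now has unit L2 norm, so (y @ y.T) has a unit diagonal, i.e. y
#   # is the Cholesky factor of a correlation matrix.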
def fill_triangular(x, upper=False, name=None):
  """Creates a (batch of) triangular matrix from a vector of inputs.

  Created matrix can be lower- or upper-triangular. (It is more efficient to
  create the matrix as upper or lower, rather than transpose.)

  Triangular matrix elements are filled in a clockwise spiral. See example,
  below.

  If `x.shape` is `[b1, b2, ..., bB, d]` then the output shape is
  `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.

  Example:

  ```python
  fill_triangular([1, 2, 3, 4, 5, 6])
  # ==> [[4, 0, 0],
  #      [6, 5, 0],
  #      [3, 2, 1]]

  fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
  # ==> [[1, 2, 3],
  #      [0, 5, 6],
  #      [0, 0, 4]]
  ```

  The key trick is to create an upper triangular matrix by concatenating `x`
  and a tail of itself, then reshaping.

  Suppose that we are filling the upper triangle of an `n`-by-`n` matrix `M`
  from a vector `x`. The matrix `M` contains n**2 entries total. The vector `x`
  contains `n * (n+1) / 2` entries. For concreteness, we'll consider `n = 5`
  (so `x` has `15` entries and `M` has `25`). We'll concatenate `x` and `x` with
  the first (`n = 5`) elements removed and reversed:

  ```python
  x = np.arange(15) + 1
  xc = np.concatenate([x, x[5:][::-1]])
  # ==> array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,
  #            15, 14, 13, 12, 11, 10,  9,  8,  7,  6])

  # (We add one to the arange result to disambiguate the zeros below the
  # diagonal of our upper-triangular matrix from the first entry in `x`.)

  # Now, lay this out as a matrix by reshaping:
  y = np.reshape(xc, [5, 5])
  # ==> array([[ 1,  2,  3,  4,  5],
  #            [ 6,  7,  8,  9, 10],
  #            [11, 12, 13, 14, 15],
  #            [15, 14, 13, 12, 11],
  #            [10,  9,  8,  7,  6]])

  # Finally, zero the elements below the diagonal:
  y = np.triu(y, k=0)
  # ==> array([[ 1,  2,  3,  4,  5],
  #            [ 0,  7,  8,  9, 10],
  #            [ 0,  0, 13, 14, 15],
  #            [ 0,  0,  0, 12, 11],
  #            [ 0,  0,  0,  0,  6]])
  ```

  From this example we see that the resulting matrix is upper-triangular, and
  contains all the entries of `x`, as desired. The rest is details:

  - If `n` is even, `x` doesn't exactly fill an even number of rows (it fills
    `n / 2` rows and half of an additional row), but the whole scheme still
    works.
  - If we want a lower triangular matrix instead of an upper triangular, we
    remove the first `n` elements from `x` rather than from the reversed `x`.

  For additional comparisons, a pure numpy version of this function can be
  found in `distribution_util_test.py`, function `_fill_triangular`.

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    tril: `Tensor` with lower (or upper) triangular elements filled from `x`.

  Raises:
    ValueError: if `x` cannot be mapped to a triangular matrix.
  """
  with tf.name_scope(name or 'fill_triangular'):
    x = tf.convert_to_tensor(x, name='x')
    m = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(x.shape, 1)[-1])
    if m is not None:
      # Formula derived by solving for n: m = n(n+1)/2.
      m = np.int32(m)
      n = np.sqrt(0.25 + 2. * m) - 0.5
      if n != np.floor(n):
        raise ValueError('Input right-most shape ({}) does not '
                         'correspond to a triangular matrix.'.format(m))
      n = np.int32(n)
      static_final_shape = tensorshape_util.concatenate(x.shape[:-1], [n, n])
    else:
      m = tf.shape(x)[-1]
      # For derivation, see above. Casting automatically lops off the 0.5, so
      # we omit it. We don't validate n is an integer because this has
      # graph-execution cost; an error will be thrown from the reshape, below.
      n = tf.cast(
          tf.sqrt(0.25 + tf.cast(2 * m, dtype=tf.float32)), dtype=tf.int32)
      static_final_shape = tensorshape_util.concatenate(
          tensorshape_util.with_rank_at_least(x.shape, 1)[:-1], [None, None])

    # Try it out in numpy:
    #  n = 3
    #  x = np.arange(n * (n + 1) / 2)
    #  m = x.shape[0]
    #  n = np.int32(np.sqrt(.25 + 2 * m) - .5)
    #  x_tail = x[(m - (n**2 - m)):]
    #  np.concatenate([x_tail, x[::-1]], 0).reshape(n, n)  # lower
    #  # ==> array([[3, 4, 5],
    #  #            [5, 4, 3],
    #  #            [2, 1, 0]])
    #  np.concatenate([x, x_tail[::-1]], 0).reshape(n, n)  # upper
    #  # ==> array([[0, 1, 2],
    #  #            [3, 4, 5],
    #  #            [5, 4, 3]])
    #
    # Note that we can't simply do `x[..., -(n**2 - m):]` because this doesn't
    # correctly handle `m == n == 1`. Hence, we do nonnegative indexing.
    # Furthermore observe that:
    #   m - (n**2 - m)
    #   = n**2 / 2 + n / 2 - (n**2 - (n**2 / 2 + n / 2))
    #   = 2 (n**2 / 2 + n / 2) - n**2
    #   = n**2 + n - n**2
    #   = n
    ndims = prefer_static.rank(x)
    if upper:
      x_list = [x, tf.reverse(x[..., n:], axis=[ndims - 1])]
    else:
      x_list = [x[..., n:], tf.reverse(x, axis=[ndims - 1])]
    new_shape = (
        tensorshape_util.as_list(static_final_shape)
        if tensorshape_util.is_fully_defined(static_final_shape)
        else tf.concat([tf.shape(x)[:-1], [n, n]], axis=0))
    x = tf.reshape(tf.concat(x_list, axis=-1), new_shape)
    x = tf.linalg.band_part(
        x,
        num_lower=(0 if upper else -1),
        num_upper=(-1 if upper else 0))
    tensorshape_util.set_shape(x, static_final_shape)
    return x
def fill_triangular_inverse(x, upper=False, name=None):
  """Creates a vector from a (batch of) triangular matrix.

  The vector is created from the lower-triangular or upper-triangular portion
  depending on the value of the parameter `upper`.

  If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
  `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.

  Example:

  ```python
  fill_triangular_inverse(
    [[4, 0, 0],
     [6, 5, 0],
     [3, 2, 1]])
  # ==> [1, 2, 3, 4, 5, 6]

  fill_triangular_inverse(
    [[1, 2, 3],
     [0, 5, 6],
     [0, 0, 4]], upper=True)
  # ==> [1, 2, 3, 4, 5, 6]
  ```

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether `x` should be read as upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
      (or upper) triangular elements from `x`.
  """
  with tf.name_scope(name or 'fill_triangular_inverse'):
    x = tf.convert_to_tensor(x, name='x')
    n = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(x.shape, 2)[-1])
    if n is not None:
      n = np.int32(n)
      m = np.int32((n * (n + 1)) // 2)
      static_final_shape = tensorshape_util.concatenate(x.shape[:-2], [m])
    else:
      n = tf.shape(x)[-1]
      m = (n * (n + 1)) // 2
      static_final_shape = tensorshape_util.concatenate(
          tensorshape_util.with_rank_at_least(x.shape, 2)[:-2], [None])
    ndims = prefer_static.rank(x)
    if upper:
      initial_elements = x[..., 0, :]
      triangular_portion = x[..., 1:, :]
    else:
      initial_elements = tf.reverse(x[..., -1, :], axis=[ndims - 2])
      triangular_portion = x[..., :-1, :]
    rotated_triangular_portion = tf.reverse(
        tf.reverse(triangular_portion, axis=[ndims - 1]), axis=[ndims - 2])
    consolidated_matrix = triangular_portion + rotated_triangular_portion
    end_sequence = tf.reshape(
        consolidated_matrix,
        tf.concat([tf.shape(x)[:-2], [n * (n - 1)]], axis=0))
    y = tf.concat([initial_elements, end_sequence[..., :m - n]], axis=-1)
    tensorshape_util.set_shape(y, static_final_shape)
    return y
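# A non-batched NumPy sketch of the folding trick above (editor's
# illustration), for the lower-triangular example in the docstring:
#
#   import numpy as np
#   x = np.array([[4, 0, 0],
#                 [6, 5, 0],
#                 [3, 2, 1]])
#   initial = x[-1, ::-1]           # reversed last row       ==> [1, 2, 3]
#   tri = x[:-1, :]                 # drop the last row
#   folded = tri + tri[::-1, ::-1]  # add 180-degree-rotated copy
#   # ==> [[4, 5, 6],
#   #      [6, 5, 4]]
#   np.concatenate([initial, folded.reshape(-1)[:3]])   # keep m - n = 3 elts
#   # ==> array([1, 2, 3, 4, 5, 6])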