def _call_reshape_input_output(self, fn, x, extra_kwargs=None):
  """Invokes `fn` on a reshaped `x` and reshapes whatever `fn` returns.

  `extra_kwargs` is accepted as a dict instead of `**extra_kwargs` because
  user-supplied keyword arguments could themselves be named `fn` or `x`.

  Args:
    fn: Callable applied to the reshaped input.
    x: Input `Tensor`. It is reshaped to the concatenation of its sample
      shape, `self.distribution.batch_shape_tensor()` and
      `self.event_shape_tensor()` before `fn` is called.
    extra_kwargs: Optional dict of keyword arguments forwarded to `fn`.

  Returns:
    The output of `fn`, reshaped to the sample shape followed by
    `self._batch_shape_unexpanded`.
  """
  assertions = self._runtime_assertions + self._validate_sample_arg(x)
  with tf.control_dependencies(assertions):
    sample_shape, static_sample_shape = self._sample_shape(x)

    # Shape handed to `fn`.
    input_shape = tf.concat(
        [
            sample_shape,
            self.distribution.batch_shape_tensor(),
            self.event_shape_tensor(),
        ],
        axis=0)
    reshaped_x = tf.reshape(x, input_shape)

    if extra_kwargs:
      result = fn(reshaped_x, **extra_kwargs)
    else:
      result = fn(reshaped_x)

    # Shape handed back to the caller.
    output_shape = tf.concat(
        [sample_shape, self._batch_shape_unexpanded], axis=0)
    result = tf.reshape(result, output_shape)

    # Attach static shape information only when both ranks are known.
    static_ranks_known = (
        tensorshape_util.rank(static_sample_shape) is not None and
        tensorshape_util.rank(self.batch_shape) is not None)
    if static_ranks_known:
      tensorshape_util.set_shape(
          result,
          tensorshape_util.concatenate(static_sample_shape, self.batch_shape))
    return result
def _sparse_tensor_dense_matmul(sp_a, b, **kwargs):
  """Multiplies a (batched) `SparseTensor` by a (batched) dense `Tensor`.

  Args:
    sp_a: `SparseTensor` representing a (batch of) matrices.
    b: `Tensor` representing a (batch of) matrices, with the same batch shape
      as `sp_a`. Its shape must be compatible with the shape of `sp_a` and
      with `kwargs`.
    **kwargs: Keyword arguments forwarded to `tf.sparse.sparse_dense_matmul`.

  Returns:
    product: A dense (batch of) matrix-shaped `Tensor` of the same batch shape
      and dtype as `sp_a` and `b`. If `sp_a` or `b` is adjointed through
      `kwargs` then the shape is adjusted accordingly.
  """
  leading_batch_shape = _get_shape(sp_a)[:-2]

  # Flatten every batch dimension of the SparseTensor into a single leading
  # dimension, giving a rank 3 SparseTensor. With batch rank 0 this simply
  # introduces a batch dimension of size 1.
  sparse_matrix_dims = _get_shape(sp_a)[-2:]
  sp_a = tf.sparse.reshape(
      sp_a, tf.concat([[-1], sparse_matrix_dims], axis=0))

  # Stack the batch members of `b` along the row axis.
  dense_col_dim = _get_shape(b)[-1:]
  b = tf.reshape(b, tf.concat([[-1], dense_col_dim], axis=0))

  # A block-diagonal arrangement of the [M, N] matrices lets us use
  # tf.sparse.sparse_dense_matmul, which only accepts rank 2 operands.
  block_diagonal = _sparse_block_diag(sp_a)
  product = tf.sparse.sparse_dense_matmul(block_diagonal, b, **kwargs)

  # Restore the original batch shape on the rank 2 product. The final matrix
  # dims are taken from `product` itself rather than from `sp_a` or `b`,
  # because either operand may have been transposed via `kwargs`.
  final_shape = tf.concat(
      [leading_batch_shape, [-1], _get_shape(product)[-1:]], axis=0)
  return tf.reshape(product, final_shape)
def cholesky_concat(chol, cols, name=None):
  """Concatenates `chol @ chol.T` with additional rows and columns.

  This operation is conceptually identical to:
  ```python
  def cholesky_concat_slow(chol, cols):  # cols shaped (n + m) x m = z x m
    mat = tf.matmul(chol, chol, adjoint_b=True)  # batch of n x n
    # Concat columns.
    mat = tf.concat([mat, cols[..., :tf.shape(mat)[-2], :]], axis=-1)  # n x z
    # Concat rows.
    mat = tf.concat([mat, tf.linalg.matrix_transpose(cols)], axis=-2)  # z x z
    return tf.linalg.cholesky(mat)
  ```
  but whereas `cholesky_concat_slow` would cost `O(z**3)` work,
  `cholesky_concat` only costs `O(z**2 + m**3)` work.

  The resulting (implicit) matrix must be symmetric and positive definite.
  Thus, the bottom right `m x m` must be self-adjoint, and we do not require a
  separate `rows` argument (which can be inferred from `conj(cols.T)`).

  Args:
    chol: Cholesky decomposition of `mat = chol @ chol.T`.
    cols: The new columns whose first `n` rows we would like concatenated to
      the right of `mat = chol @ chol.T`, and whose conjugate transpose we
      would like concatenated to the bottom of `concat(mat, cols[:n,:])`. A
      `Tensor` with final dims `(n+m, m)`. The first `n` rows are the top right
      rectangle (their conjugate transpose forms the bottom left), and the
      bottom `m x m` is self-adjoint.
    name: Optional name for this op.

  Returns:
    chol_concat: The Cholesky decomposition of:
      ```
      [ [ mat  cols[:n, :] ]
        [   conj(cols.T)   ] ]
      ```
  """
  with tf.name_scope(name or 'cholesky_extend'):
    # Bring both inputs to a common dtype (float32 if neither specifies one).
    dtype = dtype_util.common_dtype([chol, cols], dtype_hint=tf.float32)
    chol = tf.convert_to_tensor(chol, name='chol', dtype=dtype)
    cols = tf.convert_to_tensor(cols, name='cols', dtype=dtype)
    # `n` is the size of the existing factor; the first `n` rows of `cols` are
    # the new off-diagonal block and the remaining rows the new diagonal block.
    n = prefer_static.shape(chol)[-1]
    mat_nm, mat_mm = cols[..., :n, :], cols[..., n:, :]
    # Triangular solve against the existing factor.
    # NOTE(review): presumably solves `chol @ X = mat_nm` (lower-triangular
    # default) -- confirm against the helper's signature.
    solved_nm = linear_operator_util.matrix_triangular_solve_with_broadcast(
        chol, mat_nm)
    # Factor the new diagonal block after subtracting
    # `solved_nm^H @ solved_nm`; this is the only Cholesky call, on an m x m
    # matrix.
    lower_right_mm = tf.linalg.cholesky(
        mat_mm - tf.matmul(solved_nm, solved_nm, adjoint_a=True))
    # The bottom-left block of the result is the conjugate transpose of the
    # solve output.
    lower_left_mn = tf.math.conj(tf.linalg.matrix_transpose(solved_nm))
    # The solve may have broadcast batch dims; expand `chol` to the same batch
    # shape so the blocks can be concatenated.
    out_batch = prefer_static.shape(solved_nm)[:-2]
    chol = tf.broadcast_to(
        chol, tf.concat([out_batch, prefer_static.shape(chol)[-2:]], axis=0))
    # The result is lower triangular, so the top-right n x m block is zero.
    top_right_zeros_nm = tf.zeros_like(solved_nm)
    return tf.concat([
        tf.concat([chol, top_right_zeros_nm], axis=-1),
        tf.concat([lower_left_mn, lower_right_mm], axis=-1)
    ], axis=-2)
def _variance(self):
  """Computes the variance of the observation at each step.

  The observation at each step is treated as a mixture over hidden states,
  weighted by the marginal hidden-state probabilities from
  `self._marginal_hidden_probs()`.

  Returns:
    A `Tensor` reshaped to `batch_shape + [num_steps] +
    observation_event_shape`.
  """
  with tf.control_dependencies(self._runtime_assertions):
    probs = self._marginal_hidden_probs()
    # probs :: num_steps batch_shape num_states
    means = self._observation_distribution.mean()
    # means :: observation_batch_shape[:-1] num_states
    #          observation_event_shape
    means_shape = tf.concat([
        self.batch_shape_tensor(), [self._num_states],
        self._observation_distribution.event_shape_tensor()
    ], axis=0)
    means = tf.broadcast_to(means, means_shape)
    # means :: batch_shape num_states observation_event_shape

    observation_event_shape = (
        self._observation_distribution.event_shape_tensor())
    # Collapse batch dims and event dims to one axis each so that fixed-rank
    # einsum equations can be used below.
    batch_size = tf.reduce_prod(self.batch_shape_tensor())
    flat_probs_shape = [self._num_steps, batch_size, self._num_states]
    flat_means_shape = [
        batch_size, 1, self._num_states,
        tf.reduce_prod(observation_event_shape)
    ]

    flat_probs = tf.reshape(probs, flat_probs_shape)
    # flat_probs :: num_steps batch_size num_states
    flat_means = tf.reshape(means, flat_means_shape)
    # flat_means :: batch_size 1 num_states observation_event_size
    # Mixture mean per step: contract the state axis `k` of the probabilities
    # against the per-state means.
    flat_mean = tf.einsum("ijk,jmkl->jiml", flat_probs, flat_means)
    # flat_mean :: batch_size num_steps 1 observation_event_size

    variances = self._observation_distribution.variance()
    variances = tf.broadcast_to(variances, means_shape)
    # variances :: batch_shape num_states observation_event_shape
    flat_variances = tf.reshape(variances, flat_means_shape)
    # flat_variances :: batch_size 1 num_states observation_event_size

    # For a mixture of n distributions with mixture probabilities
    # p[i], and where the individual distributions have means and
    # variances given by mean[i] and var[i], the variance of
    # the mixture is given by:
    #
    # var = sum i=1..n p[i] * ((mean[i] - mean)**2 + var[i])
    #
    # (The component variances enter unsquared, as in the code below.)
    flat_variance = tf.einsum("ijk,jikl->jil",
                              flat_probs,
                              (flat_means - flat_mean)**2 + flat_variances)
    # flat_variance :: batch_size num_steps observation_event_size

    unflat_mean_shape = tf.concat([
        self.batch_shape_tensor(), [self._num_steps], observation_event_shape
    ], axis=0)

    # returns :: batch_shape num_steps observation_event_shape
    return tf.reshape(flat_variance, unflat_mean_shape)
def body(m, pchol, perm, matrix_diag):
  """Body of a single `tf.while_loop` iteration.

  Reads `matrix` from the enclosing scope, which is not visible in this
  block. NOTE(review): `batch_gather` is likewise defined elsewhere; it is
  used here as a gather along one axis with leading batch dims -- confirm.

  Args:
    m: Index of the row of `pchol` filled in by this iteration.
    pchol: The partial factor built so far; row `m` is replaced.
    perm: The current permutation vectors; entries `m` and the arg-max
      position are swapped.
    matrix_diag: The remaining diagonal, reduced by the square of the new row.

  Returns:
    The tuple `(m + 1, pchol, perm, matrix_diag)` with updated values.
  """
  # Here is roughly a numpy, non-batched version of what's going to happen.
  # (See also Algorithm 1 of Harbrecht et al.)
  # 1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
  # 2: maxval = matrix_diag[perm][maxi]
  # 3: perm[m], perm[maxi] = perm[maxi], perm[m]
  # 4: row = matrix[perm[m]][perm[m + 1:]]
  # 5: row -= np.sum(pchol[:m][perm[m + 1:]] * pchol[:m][perm[m]], axis=-2)
  # 6: pivot = np.sqrt(maxval); row /= pivot
  # 7: row = np.concatenate([[[pivot]], row], -1)
  # 8: matrix_diag[perm[m:]] -= row**2
  # 9: pchol[m, perm[m:]] = row

  # Find the maximal position of the (remaining) permuted diagonal.
  # Steps 1, 2 above.
  permuted_diag = batch_gather(matrix_diag, perm[..., m:])
  maxi = tf.argmax(
      permuted_diag, axis=-1, output_type=tf.int64)[..., tf.newaxis]
  maxval = batch_gather(permuted_diag, maxi)
  # `maxi` was computed relative to position `m`; shift it to an absolute
  # position in `perm`.
  maxi = maxi + m
  maxval = maxval[..., 0]
  # Update perm: Swap perm[...,m] with perm[...,maxi]. Step 3 above.
  perm = _swap_m_with_i(perm, m, maxi)
  # Step 4.
  row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
  row = batch_gather(row, perm[..., m + 1:])
  # Step 5.
  prev_rows = pchol[..., :m, :]
  prev_rows_perm_m_onward = batch_gather(prev_rows, perm[..., m + 1:])
  prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
  row -= tf.reduce_sum(
      prev_rows_perm_m_onward * prev_rows_pivot_col,
      axis=-2)[..., tf.newaxis, :]
  # Step 6.
  pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]
  # Step 7.
  row = tf.concat([pivot, row / pivot], axis=-1)
  # TODO(b/130899118): Pad grad fails with int64 paddings.
  # Step 8.
  # Left-pad the last axis with `m` zeros so the row lines up with the full
  # permuted index range; no other axis is padded.
  paddings = tf.concat([
      tf.zeros([prefer_static.rank(pchol) - 1, 2], dtype=tf.int32),
      [[tf.cast(m, tf.int32), 0]]
  ], axis=0)
  diag_update = tf.pad(row**2, paddings=paddings)[..., 0, :]
  # Map the update from permuted order back to original index order.
  reverse_perm = _invert_permutation(perm)
  matrix_diag -= batch_gather(diag_update, reverse_perm)
  # Step 9.
  row = tf.pad(row, paddings=paddings)
  # TODO(bjp): Defer the reverse permutation all-at-once at the end?
  row = batch_gather(row, reverse_perm)
  # Replace row `m` by concatenation, then restore the static shape captured
  # before the concat.
  pchol_shape = pchol.shape
  pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]], axis=-2)
  tensorshape_util.set_shape(pchol, pchol_shape)
  return m + 1, pchol, perm, matrix_diag
def _log_prob(self, value):
  """Computes the log probability of observation sequences `value`.

  Folds a forward step over the leading (step) axis of the observation log
  probabilities, starting from the broadcast initial log probabilities, and
  log-sum-exps the result over the final (state) axis.

  Args:
    value: `Tensor` of sequences of observations, laid out as
      `observation_batch_shape + [num_steps] + underlying_event_shape`.

  Returns:
    `Tensor` of log probabilities with shape `batch_shape`, where
    `batch_shape` is `observation_batch_shape` broadcast against
    `self.batch_shape_tensor()`.
  """
  with tf.control_dependencies(self._runtime_assertions):
    # Position (counted from the end) of the step axis in `value`; everything
    # before it is the batch shape of the observations.
    value_shape = tf.shape(value)
    value_batch_shape = value_shape[:-1 - self._underlying_event_rank]
    # value :: value_batch_shape num_steps underlying_event_shape

    # Broadcasting against the model's batch shape gives the result shape.
    full_batch_shape = tf.broadcast_dynamic_shape(value_batch_shape,
                                                  self.batch_shape_tensor())
    initial_log_probs = tf.broadcast_to(
        self._log_init,
        tf.concat([full_batch_shape, [self._num_states]], axis=0))
    # initial_log_probs :: full_batch_shape num_states

    transition_log_probs = self._log_trans

    # Shape of one whole sequence: the step axis plus the per-step event dims.
    sequence_event_shape = value_shape[-1 - self._underlying_event_rank:]
    observations = tf.broadcast_to(
        value, tf.concat([full_batch_shape, sequence_event_shape], axis=0))
    # observations :: full_batch_shape sequence_event_shape

    event_rank = self._underlying_event_rank
    # tf.foldl iterates over axis 0, so bring the step axis to the front; the
    # trailing new axis is appended before calling `log_prob`.
    observations = distribution_util.move_dimension(
        observations, -1 - event_rank, 0)[..., tf.newaxis]
    observation_log_probs = (
        self._observation_distribution.log_prob(observations))

    def forward_step(log_prev_step, log_prob_observation):
      propagated = _log_vector_matrix(log_prev_step, transition_log_probs)
      return propagated + log_prob_observation

    final_log_probs = tf.foldl(
        forward_step, observation_log_probs, initializer=initial_log_probs)
    # final_log_probs :: full_batch_shape num_states

    return tf.reduce_logsumexp(final_log_probs, axis=-1)
def _compute_quantiles():
  """Builds the quantiles of `dist` at evenly spaced interior probabilities.

  Uses `dist`, `quadrature_size` and `batch_ndims` from the enclosing scope.

  Returns:
    The quantiles, with the axis indexing the probabilities moved to the end.
  """
  # The endpoints {0, 1} are dropped because their quantiles might be Inf/NaN.
  start = tf.zeros([], dtype=dist.dtype)
  interior_points = tf.linspace(start, 1., quadrature_size + 3)[1:-1]

  # Append one singleton axis per batch dimension so the points broadcast
  # across the batch dims.
  broadcastable_shape = tf.concat(
      [[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0)
  interior_points = tf.reshape(interior_points, shape=broadcastable_shape)

  result = dist.quantile(interior_points)

  # Cyclic left rotation of the axes: the leading axis becomes the last one.
  rotation = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
  return tf.transpose(a=result, perm=rotation)
def _sample_3d(self, n, seed=None):
  """Specialized inversion sampler for 3D.

  Args:
    n: Number of samples to draw per batch member.
    seed: Optional seed, wrapped in a `SeedStream`.

  Returns:
    A `Tensor` of shape `[n] + batch_shape + [1]`.
  """
  seed = SeedStream(seed, salt='von_mises_fisher_3d')
  sample_batch_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
  z = tf.random.uniform(sample_batch_shape, seed=seed(), dtype=self.dtype)
  # TODO(bjp): Higher-order odd dim analytic CDFs are available in [1], could
  #   be bisected for bounded sampling runtime (i.e. not rejection sampling).
  # [1]: Inversion sampler via: https://ieeexplore.ieee.org/document/7347705/
  # The inversion is: u = 1 + log(z + (1-z)*exp(-2*kappa)) / kappa
  # Both kappa == 0 and z == 0 would break the formula, so each is swapped
  # for a harmless 1 here and patched up with the analytic limit afterwards.
  concentration_is_positive = self.concentration > 0
  safe_conc = tf.where(concentration_is_positive, self.concentration,
                       tf.ones_like(self.concentration))
  z_is_positive = z > 0
  safe_z = tf.where(z_is_positive, z, tf.ones_like(z))
  # log(z + (1 - z) * exp(-2 * kappa)) evaluated stably in log space.
  log_terms = [tf.math.log(safe_z), tf.math.log1p(-safe_z) - 2 * safe_conc]
  safe_u = 1 + tf.reduce_logsumexp(log_terms, axis=0) / safe_conc
  # As kappa -> 0 the expression tends to 2*z-1.
  u = tf.where(self.concentration > tf.zeros_like(safe_u), safe_u, 2 * z - 1)
  # As z -> 0 the expression tends to -1.
  u = tf.where(tf.equal(z, 0), -tf.ones_like(u), u)
  if not self._allow_nan_stats:
    u = tf.debugging.check_numerics(u, 'u in _sample_3d')
  return u[..., tf.newaxis]
def _sample_n(self, n, seed=None):
  """Draws `n` samples per batch member.

  See https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution or
  https://www.jstor.org/stable/2683801

  Args:
    n: Number of samples to draw per batch member.
    seed: Optional seed, wrapped in a `SeedStream`.

  Returns:
    A `Tensor` of samples with shape `[n] + batch_shape`.
  """
  concentration = tf.convert_to_tensor(self.concentration)
  loc = tf.convert_to_tensor(self.loc)
  seed = SeedStream(seed, 'inverse_gaussian')
  batch_shape = self._batch_shape_tensor(loc=loc, concentration=concentration)
  shape = tf.concat([[n], batch_shape], axis=0)

  # A squared standard normal draw, followed by a uniform draw used to choose
  # between the two candidate values below.
  standard_normal = tf.random.normal(
      shape, mean=0., stddev=1., seed=seed(), dtype=self.dtype)
  sampled_chi2 = standard_normal**2.
  sampled_uniform = tf.random.uniform(
      shape, minval=0., maxval=1., seed=seed(), dtype=self.dtype)

  sampled = (
      loc + loc**2. * sampled_chi2 / (2. * concentration) -
      loc / (2. * concentration) *
      (4. * loc * concentration * sampled_chi2 +
       (loc * sampled_chi2)**2)**0.5)

  keep_candidate = sampled_uniform <= loc / (loc + sampled)
  alternative = loc**2 / sampled
  return tf.where(keep_candidate, sampled, alternative)
def _call_sample_n(self, sample_shape, seed, name, **kwargs):
  """Samples from the base distribution and pushes through the bijector.

  `_call_sample_n` is overridden instead of `_sample_n` so that the output of
  `self.bijector.forward` is returned without further modification, which is
  what allows bijector caching to work.

  Args:
    sample_shape: Requested sample shape; converted to an int32 `Tensor`.
    seed: Seed forwarded to `self._sample_n`.
    name: Name passed to `self._name_and_control_scope`.
    **kwargs: Split by `self._kwargs_split_fn` into distribution kwargs and
      bijector kwargs.

  Returns:
    The transformed samples.
  """
  with self._name_and_control_scope(name):
    sample_shape = tf.convert_to_tensor(
        sample_shape, dtype=tf.int32, name="sample_shape")
    sample_shape, n = self._expand_sample_shape_to_vector(
        sample_shape, "sample_shape")

    distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)

    # Step 1: draw the samples. Extra samples may be generated when they
    # need to be reinterpreted as part of the event_shape.
    x = self._sample_n(n, seed, **distribution_kwargs)

    # Step 2: give `x` its final shape *before* calling the bijector, so the
    # bijector sees (and caches) the exact tensor we return the image of.
    trailing_shape = tf.shape(x)[1:]
    x = tf.reshape(x, tf.concat([sample_shape, trailing_shape], 0))

    # Step 3: apply the forward transformation. For caching to work this has
    # to be the last modification made to the returned value.
    y = self.bijector.forward(x, **bijector_kwargs)
    return self._set_sample_static_shape(y, sample_shape)
def _mode(self, samples=None):
  """Computes the most frequent sample for each batch member.

  Args:
    samples: Optional samples to use; when `None`, `self._samples` is
      converted to a `Tensor` and used instead.

  Returns:
    A `Tensor` holding, for every batch member, the value (or event) that
    occurs most often among its samples, reshaped to the batch shape
    (followed by the event shape when `self._event_ndims != 0`). Ties go to
    the value that appears first among the unique values.
  """
  # Samples count can vary by batch member. Use map_fn to compute mode for
  # each batch separately.
  def _get_mode(samples):
    # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
    unique = gen_array_ops.unique_with_counts_v2(samples, axis=[0])
    # `count[i]` is the multiplicity of `y[i]`, so the arg-max of `count` is
    # an index into the unique values `y`, NOT into `samples`. Indexing
    # `samples` with it returns the wrong element whenever duplicates occur
    # before the mode (e.g. [1, 1, 2, 2, 2]).
    return tf.gather(unique.y, tf.argmax(unique.count))

  if samples is None:
    samples = tf.convert_to_tensor(self._samples)
  num_samples = self._compute_num_samples(samples)

  # Flatten samples for each batch.
  if self._event_ndims == 0:
    flattened_samples = tf.reshape(samples, [-1, num_samples])
    mode_shape = self._batch_shape_tensor(samples)
  else:
    event_size = tf.reduce_prod(self._event_shape_tensor(samples))
    mode_shape = tf.concat(
        [self._batch_shape_tensor(samples),
         self._event_shape_tensor(samples)],
        axis=0)
    flattened_samples = tf.reshape(samples, [-1, num_samples, event_size])

  # Each call yields the modal value itself, so no second gather is needed.
  mode = tf.map_fn(
      _get_mode, flattened_samples, dtype=flattened_samples.dtype)
  return tf.reshape(mode, mode_shape)
def _call_and_reshape_output(self,
                             fn,
                             event_shape_list=None,
                             static_event_shape_list=None,
                             extra_kwargs=None):
  """Invokes `fn` and reshapes its output to batch + event shape.

  `extra_kwargs` is accepted as a dict instead of `**extra_kwargs` because
  user-supplied keyword arguments could themselves be named `fn`,
  `event_shape_list`, `static_event_shape_list` and/or `extra_kwargs`.

  Args:
    fn: Callable invoked with `extra_kwargs` (or with no arguments).
    event_shape_list: Optional list of event shape tensors appended after
      `self._batch_shape_unexpanded`. Defaults to
      `[self._event_shape_tensor()]`.
    static_event_shape_list: Optional list of static event shapes used to set
      the static shape of the result. Defaults to `[self.event_shape]`.
    extra_kwargs: Optional dict of keyword arguments forwarded to `fn`.

  Returns:
    The reshaped output of `fn`.
  """
  with tf.control_dependencies(self._runtime_assertions):
    if event_shape_list is None:
      event_shape_list = [self._event_shape_tensor()]
    if static_event_shape_list is None:
      static_event_shape_list = [self.event_shape]

    target_shape = tf.concat(
        [self._batch_shape_unexpanded] + event_shape_list, axis=0)
    if extra_kwargs:
      raw_output = fn(**extra_kwargs)
    else:
      raw_output = fn()
    result = tf.reshape(raw_output, target_shape)

    # Without known ranks there is no static shape to attach.
    if (tensorshape_util.rank(self.batch_shape) is None or
        tensorshape_util.rank(self.event_shape) is None):
      return result

    full_event_shape = tf.TensorShape([])
    for partial_shape in static_event_shape_list:
      full_event_shape = tensorshape_util.concatenate(
          full_event_shape, partial_shape)
    tensorshape_util.set_shape(
        result,
        tensorshape_util.concatenate(self.batch_shape, full_event_shape))
    return result
def _sample_n(self, n, seed=None):
  """Draws `n` samples per batch member by rescaling unit uniform draws.

  Args:
    n: Number of samples to draw per batch member.
    seed: Optional seed passed to `tf.random.uniform`.

  Returns:
    A `Tensor` of samples with shape `[n] + batch_shape`.
  """
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)
  batch_shape = self._batch_shape_tensor(low=low, high=high)
  unit_samples = tf.random.uniform(
      shape=tf.concat([[n], batch_shape], 0), dtype=self.dtype, seed=seed)
  # Scale by the range and shift by `low`.
  width = self._range(low=low, high=high)
  return low + width * unit_samples
def _cdf(self, k):
  """Evaluates the cumulative distribution function at `k`.

  Args:
    k: Category indices, broadcast against the probabilities.

  Returns:
    The cumulative sum of probabilities up to and including `k`, with 0 for
    any `k < 0`.
  """
  # TODO(b/135263541): Improve numerical precision of categorical.cdf.
  probs = self.probs_parameter()
  num_categories = self._num_categories(probs)

  k, probs = _broadcast_cat_event_and_params(
      k, probs, base_dtype=dtype_util.base_dtype(self.dtype))

  # The support starts at 0, so the CDF is zero for every negative k.
  below_support = k < 0

  # k indexes into the cumulative sums below; clip it to {0,...,K-1}.
  clipped_k = tf.clip_by_value(tf.cast(k, tf.int32), 0, num_categories - 1)
  batch_shape = tf.shape(clipped_k)

  # tf.gather(..., batch_dims=batch_dims) needs a static batch_dims kwarg. To
  # support a dynamic batch shape, all batch dims are flattened into one so
  # that batch_dims=1 always holds.
  flat_k = tf.reshape(clipped_k, [-1])
  flat_probs = tf.reshape(
      probs, tf.concat(([-1], [num_categories]), axis=0))
  flat_cumulative = tf.cumsum(flat_probs, axis=-1)
  flat_cdf = tf.gather(
      flat_cumulative, flat_k[..., tf.newaxis], batch_dims=1)

  cdf = tf.reshape(flat_cdf, shape=batch_shape)

  zero = np.array(0, dtype=dtype_util.as_numpy_dtype(cdf.dtype))
  return tf.where(below_support, zero, cdf)
def _swap_m_with_i(vecs, m, i):
  """Swaps `m` and `i` on axis -1. (Helper for pivoted_cholesky.)

  Given a batch of int64 vectors `vecs`, scalar index `m`, and compatibly
  shaped per-vector indices `i`, this function swaps elements `m` and `i` in
  each vector. For the use-case below, these are permutation vectors.

  The result is assembled from three pieces: the untouched prefix
  `vecs[..., :m]`, the element gathered from position `i` (which lands at
  position `m`), and the trailing elements with position `i` overwritten by
  the old element `m`.

  Args:
    vecs: Vectors on which we perform the swap, int64 `Tensor`.
    m: Scalar int64 `Tensor`, the index into which the `i`th element is going.
    i: Batch int64 `Tensor`, shaped like vecs.shape[:-1] + [1]; the index into
      which the `m`th element is going.

  Returns:
    vecs: The updated vectors.
  """
  vecs = tf.convert_to_tensor(vecs, dtype=tf.int64, name='vecs')
  m = tf.convert_to_tensor(m, dtype=tf.int64, name='m')
  i = tf.convert_to_tensor(i, dtype=tf.int64, name='i')
  # Positions m+1 .. len-1, broadcast to the shape of the trailing slice, so
  # each trailing slot can be compared against the per-vector index `i`.
  trailing_elts = tf.broadcast_to(
      tf.range(m + 1,
               prefer_static.shape(vecs, out_type=tf.int64)[-1]),
      prefer_static.shape(vecs[..., m + 1:]))
  # Where the slot's position equals `i`, write the old element `m`; every
  # other trailing slot keeps its current value.
  trailing_elts = tf.where(
      tf.equal(trailing_elts, i),
      tf.gather(vecs, [m], axis=-1),
      vecs[..., m + 1:])
  # TODO(bjp): Could we use tensor_scatter_nd_update?
  # Capture the static shape before the concat and restore it afterwards.
  vecs_shape = vecs.shape
  vecs = tf.concat([
      vecs[..., :m],
      tf.gather(vecs, i, batch_dims=int(prefer_static.rank(vecs)) - 1),
      trailing_elts
  ], axis=-1)
  tensorshape_util.set_shape(vecs, vecs_shape)
  return vecs
def _make_columnar(self, x):
  """Promotes a vector input to a single-row matrix.

  Only rank-1 input is changed: it gains a new leading axis of size 1. Input
  of any other rank (including scalars) is returned with its shape unchanged.
  The static-rank and dynamic-rank code paths produce the same shapes.

  Example:
    If `x = [1, 2, 3]` then the output is `[[1, 2, 3]]`.

    If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged.

    If `x = 1` then the output is unchanged.

  Args:
    x: `Tensor`.

  Returns:
    columnar_x: `Tensor` with the same values as `x`; rank 2 when `x` has
      rank 1, otherwise the same rank as `x`.
  """
  static_rank = tensorshape_util.rank(x.shape)
  if static_rank is not None:
    if static_rank == 1:
      x = x[tf.newaxis, :]
    return x
  # Rank unknown statically: build the target shape dynamically, inserting a
  # `1` before the last dim only when the runtime rank is 1.
  shape = tf.shape(x)
  maybe_expanded_shape = tf.concat([
      shape[:-1],
      distribution_util.pick_vector(
          tf.equal(tf.rank(x), 1), [1], np.array([], dtype=np.int32)),
      shape[-1:],
  ], 0)
  return tf.reshape(x, maybe_expanded_shape)
def _marginal_hidden_probs(self):
  """Computes the marginal hidden-state probabilities at every step.

  Returns:
    The exponentiated forward log probabilities, stacked along a new leading
    step axis: `[num_steps] + batch_shape + [num_states]`.
  """
  initial_shape = tf.concat(
      [self.batch_shape_tensor(), [self._num_states]], axis=0)
  initial_log_probs = tf.broadcast_to(self._log_init, initial_shape)
  # initial_log_probs :: batch_shape num_states

  def _scan_multiple_steps():
    """Perform `scan` operation when `num_steps` > 1."""
    transition_log_probs = self._log_trans

    def forward_step(log_probs, _):
      return _log_vector_matrix(log_probs, transition_log_probs)

    # The scanned values are ignored; they only set the iteration count.
    dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
    later_log_probs = tf.scan(
        forward_step,
        dummy_index,
        initializer=initial_log_probs,
        name="forward_log_probs")
    return tf.concat([[initial_log_probs], later_log_probs], axis=0)

  def _single_step():
    """With a single step, only the initial distribution is needed."""
    return initial_log_probs[tf.newaxis, ...]

  forward_log_probs = prefer_static.cond(
      self._num_steps > 1, _scan_multiple_steps, _single_step)

  return tf.exp(forward_log_probs)
def _sample_n(self, n, seed=None):
  """Draws `n` samples per batch member via inverse CDF sampling.

  Args:
    n: Number of samples to draw per batch member.
    seed: Optional seed, wrapped in a `SeedStream`.

  Returns:
    A `Tensor` of samples with shape `[n] + batch_shape`.
  """
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)
  peak = tf.convert_to_tensor(self.peak)

  stream = SeedStream(seed, salt='triangular')
  batch_shape = self._batch_shape_tensor(low=low, high=high, peak=peak)
  shape = tf.concat([[n], batch_shape], axis=0)
  samples = tf.random.uniform(shape=shape, dtype=self.dtype, seed=stream())

  # Inverse CDF sampling: the CDF is piecewise quadratic, hence the sqrts.
  interval_length = high - low

  # Left of the peak the CDF is
  #   (x - low) ** 2 / ((high - low) * (peak - low)),
  # which at x = peak equals (peak - low) / (high - low). Comparing the
  # uniform draw against that value selects which piece to invert.
  left_of_peak = samples < (peak - low) / interval_length
  # Inverse of (x - low) ** 2 / ((high - low) * (peak - low)).
  left_branch = low + tf.sqrt(samples * interval_length * (peak - low))
  # Inverse of 1 - (high - x) ** 2 / ((high - low) * (high - peak))
  right_branch = high - tf.sqrt(
      (1. - samples) * interval_length * (high - peak))
  return tf.where(left_of_peak, left_branch, right_branch)
def _sample_n(self, n, seed=None):
  """Draws `n` samples per batch member from standard normal noise.

  Args:
    n: Number of samples to draw per batch member.
    seed: Optional seed passed to `tf.random.normal`.

  Returns:
    A `Tensor` of samples: standard normal draws of shape
    `[n] + batch_shape`, multiplied by `scale` and shifted by `loc`.
  """
  loc = tf.convert_to_tensor(self.loc)
  scale = tf.convert_to_tensor(self.scale)
  batch_shape = self._batch_shape_tensor(loc=loc, scale=scale)
  standard_normal = tf.random.normal(
      shape=tf.concat([[n], batch_shape], axis=0),
      mean=0.,
      stddev=1.,
      dtype=self.dtype,
      seed=seed)
  return standard_normal * scale + loc
def _entropy(self, **kwargs):
  """Computes the entropy for constant-Jacobian, injective bijectors.

  Args:
    **kwargs: Split by `self._kwargs_split_fn` into distribution kwargs and
      bijector kwargs.

  Returns:
    The entropy of the base distribution minus the bijector's inverse log det
    Jacobian, with static shape set to `self.batch_shape`.

  Raises:
    NotImplementedError: If the bijector does not have a constant Jacobian or
      is not injective.
  """
  if not self.bijector.is_constant_jacobian:
    raise NotImplementedError("entropy is not implemented")
  if not self.bijector._is_injective:  # pylint: disable=protected-access
    raise NotImplementedError("entropy is not implemented when "
                              "bijector is not injective.")
  distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
  # Let Y = g(X) with g a diffeomorphism and X a continuous rv. Then
  #   H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)].
  # With is_constant_jacobian the expectation collapses to
  #   (log o abs o det o J o g)(c)
  # for an arbitrary point c.
  entropy = self.distribution.entropy(**distribution_kwargs)

  if self._is_maybe_event_override:
    # H[X] = sum_i H[X_i] if X_i are mutually independent, so the reduce_sum
    # amounts to multiplying by the number of event elements.
    event_size = tf.reduce_prod(self._override_event_shape)
    entropy = entropy * tf.cast(
        event_size, dtype=dtype_util.base_dtype(entropy.dtype))

  if self._is_maybe_batch_override:
    # Prepend singleton axes for the override batch dims, then tile along
    # them to fill out the overridden batch shape.
    expanded_shape = tf.concat([
        prefer_static.ones_like(self._override_batch_shape),
        self.distribution.batch_shape_tensor()
    ], 0)
    entropy = tf.reshape(entropy, expanded_shape)
    tile_multiples = tf.concat([
        self._override_batch_shape,
        prefer_static.ones_like(self.distribution.batch_shape_tensor())
    ], 0)
    entropy = tf.tile(entropy, tile_multiples)

  # Any point works for a constant Jacobian; use zeros of the full shape.
  dummy_shape = tf.concat(
      [self.batch_shape_tensor(), self.event_shape_tensor()], 0)
  dummy = prefer_static.zeros(shape=dummy_shape, dtype=self.dtype)

  static_event_rank = tensorshape_util.rank(self.event_shape)
  if static_event_rank is not None:
    event_ndims = static_event_rank
  else:
    event_ndims = tf.size(self.event_shape_tensor())

  ildj = self.bijector.inverse_log_det_jacobian(
      dummy, event_ndims=event_ndims, **bijector_kwargs)

  entropy = entropy - tf.cast(ildj, entropy.dtype)
  tensorshape_util.set_shape(entropy, self.batch_shape)
  return entropy
def _rotate(self, samples):
  """Applies a Householder rotation to `samples`.

  Reflects `samples` across the hyperplane orthogonal to
  `u = normalize(basis - mean_direction)`, where `basis` is the first
  standard basis vector `[1, 0, ..., 0]`:
  `samples - 2 * <samples, u> * u`.

  Args:
    samples: `Tensor` whose last axis has the event dimension.

  Returns:
    The reflected samples, broadcast against `self.mean_direction`.
  """
  # Prefer the static event size; fall back to the dynamic one.
  event_dim = (
      tf.compat.dimension_value(self.event_shape[0]) or
      self._event_shape_tensor()[0])
  # NOTE: no trailing comma here. A trailing comma would make `basis` a
  # 1-tuple, which only works through implicit tuple-to-tensor conversion and
  # silently adds a leading axis of size 1.
  basis = tf.concat(
      [[1.], tf.zeros([event_dim - 1], dtype=self.dtype)], axis=0)
  u = tf.math.l2_normalize(basis - self.mean_direction, axis=-1)
  return samples - 2 * tf.reduce_sum(samples * u, axis=-1, keepdims=True) * u
def _inverse_event_shape_tensor(self, output_shape):
  """Returns `output_shape` with its last dimension reduced by one.

  Args:
    output_shape: Shape `Tensor` of the forward output.

  Returns:
    Shape `Tensor` equal to `output_shape` except that the last entry is
    decremented by one. When `self.validate_args` is set, an assertion that
    the last entry is greater than 1 is attached.
  """
  dependencies = []
  if self.validate_args:
    # A shape can never be negative, so only `<= 1` has to be ruled out.
    dependencies.append(
        assert_util.assert_greater(
            output_shape[-1], 1,
            message="Need last dimension greater than 1."))
  with tf.control_dependencies(dependencies):
    leading_dims = output_shape[:-1]
    shrunk_last_dim = [output_shape[-1] - 1]
    return tf.concat([leading_dims, shrunk_last_dim], axis=0)
def _sample_n(self, n, seed=None):
  """Returns `loc` broadcast to `[n] + batch_shape + event_shape`.

  Args:
    n: Number of samples per batch member.
    seed: Ignored.

  Returns:
    `loc` broadcast to the sample shape.
  """
  del seed  # unused
  loc = tf.convert_to_tensor(self.loc)
  batch_shape = self._batch_shape_tensor(loc=loc)
  event_shape = self._event_shape_tensor(loc=loc)
  sample_shape = tf.concat([[n], batch_shape, event_shape], axis=0)
  return tf.broadcast_to(loc, sample_shape)
def _extract_log_probs(num_states, dist):
  """Tabulates `dist.log_prob` at every state `0..num_states-1`.

  Args:
    num_states: Number of states to evaluate.
    dist: A batch of distributions providing `batch_shape_tensor` and
      `log_prob`.

  Returns:
    The log probabilities with the axis indexing the states moved to the end.
  """
  # One singleton axis per batch dimension, so that the states broadcast
  # against the batch of distributions.
  singleton_batch_dims = tf.ones_like(dist.batch_shape_tensor())
  states_shape = tf.concat([[num_states], singleton_batch_dims], axis=0)
  states = tf.reshape(tf.range(num_states), states_shape)
  log_probs = dist.log_prob(states)
  return distribution_util.move_dimension(log_probs, 0, -1)
def _sample_n(self, n, seed=None):
  """Draws `n` samples as the absolute value of scaled normal noise.

  Args:
    n: Number of samples to draw per batch member.
    seed: Optional seed passed to `tf.random.normal`.

  Returns:
    A non-negative `Tensor` of shape `[n] + shape(scale)`.
  """
  scale = tf.convert_to_tensor(self.scale)
  sample_shape = tf.concat([[n], tf.shape(scale)], 0)
  standard_normal = tf.random.normal(
      shape=sample_shape, mean=0., stddev=1., dtype=self.dtype, seed=seed)
  return tf.abs(standard_normal * scale)
def _uniform_unit_norm(dimension, shape, dtype, seed):
  """Returns a batch of points chosen uniformly from the unit hypersphere.

  Args:
    dimension: Size of the last axis of the returned points.
    shape: Leading shape of the returned batch.
    dtype: dtype of the returned points.
    seed: Zero-argument callable returning the seed for the draw.

  Returns:
    A `Tensor` of shape `shape + [dimension]` whose last axis has unit
    2-norm.
  """
  # Normalizing Gaussian draws works because the Gaussian distribution is
  # spherically symmetric.
  numpy_dtype = dtype_util.as_numpy_dtype(dtype)
  standard_normal = normal.Normal(loc=numpy_dtype(0), scale=numpy_dtype(1))
  # raw shape: shape + [dimension]
  raw = standard_normal.sample(
      tf.concat([shape, [dimension]], axis=0), seed=seed())
  norms = tf.norm(raw, ord=2, axis=-1)[..., tf.newaxis]
  return raw / norms
def _sample_n(self, n, seed=None, **kwargs):
  """Samples from the wrapped distribution and reshapes the result.

  Args:
    n: Number of samples to draw.
    seed: Optional seed forwarded to `self.distribution.sample`.
    **kwargs: Extra keyword arguments forwarded to
      `self.distribution.sample`.

  Returns:
    The samples reshaped to
    `[n] + self._batch_shape_unexpanded + self.event_shape_tensor()`.
  """
  with tf.control_dependencies(self._runtime_assertions):
    raw_samples = self.distribution.sample(
        sample_shape=n, seed=seed, **kwargs)
    target_shape = tf.concat(
        [
            [n],
            self._batch_shape_unexpanded,
            self.event_shape_tensor(),
        ],
        axis=0)
    return tf.reshape(raw_samples, target_shape)
def _pad_sample_dims(self, x):
  """Inserts a size-1 axis immediately before the event dims of `x`.

  Args:
    x: `Tensor` whose trailing `self._event_ndims` axes are event dims.

  Returns:
    `x` reshaped so that a singleton axis sits in front of its event dims.
  """
  with tf.name_scope("pad_sample_dims"):
    # Use the static rank when it is known, otherwise the dynamic rank.
    static_rank = tensorshape_util.rank(x.shape)
    ndims = tf.rank(x) if static_rank is None else static_rank
    shape = tf.shape(x)
    split_at = ndims - self._event_ndims
    padded_shape = tf.concat(
        [shape[:split_at], [1], shape[split_at:]], axis=0)
    return tf.reshape(x, shape=padded_shape)
def _sample_n(self, n, seed=None):
  """Draws `n` samples per batch member by applying the quantile function.

  Args:
    n: Number of samples to draw per batch member.
    seed: Optional seed passed to `tf.random.uniform`.

  Returns:
    `self._quantile` evaluated at uniform draws of shape `[n] + batch_shape`.
  """
  loc = tf.convert_to_tensor(self.loc)
  scale = tf.convert_to_tensor(self.scale)
  sample_shape = tf.concat(
      [[n], self._batch_shape_tensor(loc=loc, scale=scale)], 0)
  uniform_probs = tf.random.uniform(
      shape=sample_shape, minval=0., maxval=1., dtype=self.dtype, seed=seed)
  return self._quantile(uniform_probs, loc=loc, scale=scale)
def _mode_mean_shape(self):
  """Shape for the mode/mean Tensors.

  Returns:
    The static concatenation of `batch_shape` and `event_shape` when it is
    fully defined; otherwise the dynamic concatenation of
    `batch_shape_tensor()` and `event_shape_tensor()`.
  """
  static_shape = tensorshape_util.concatenate(
      self.batch_shape, self.event_shape)
  if tensorshape_util.is_fully_defined(static_shape):
    return static_shape
  return tf.concat(
      [
          self.batch_shape_tensor(),
          self.event_shape_tensor(),
      ],
      0)