def _swap_m_with_i(vecs, m, i):
  """Swaps `m` and `i` on axis -1. (Helper for pivoted_cholesky.)

  Given a batch of int64 vectors `vecs`, scalar index `m`, and compatibly
  shaped per-vector indices `i`, this function swaps elements `m` and `i` in
  each vector. For the use-case below, these are permutation vectors.

  Args:
    vecs: Vectors on which we perform the swap, int64 `Tensor`.
    m: Scalar int64 `Tensor`, the index into which the `i`th element is going.
    i: Batch int64 `Tensor`, shaped like vecs.shape[:-1] + [1]; the index into
      which the `m`th element is going.

  Returns:
    vecs: The updated vectors.
  """
  vecs = tf.convert_to_tensor(vecs, dtype=tf.int64, name='vecs')
  m = tf.convert_to_tensor(m, dtype=tf.int64, name='m')
  i = tf.convert_to_tensor(i, dtype=tf.int64, name='i')
  trailing_elts = tf.broadcast_to(
      tf.range(m + 1, prefer_static.shape(vecs, out_type=tf.int64)[-1]),
      prefer_static.shape(vecs[..., m + 1:]))
  trailing_elts = tf.where(
      tf.equal(trailing_elts, i),
      tf.gather(vecs, [m], axis=-1),
      vecs[..., m + 1:])
  # TODO(bjp): Could we use tensor_scatter_nd_update?
  vecs_shape = vecs.shape
  vecs = tf.concat([
      vecs[..., :m],
      tf.gather(vecs, i, batch_dims=int(prefer_static.rank(vecs)) - 1),
      trailing_elts
  ], axis=-1)
  tensorshape_util.set_shape(vecs, vecs_shape)
  return vecs
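# Illustrative sketch (not part of the library): how `_swap_m_with_i` acts on a
# small batch of permutation vectors, assuming eager execution and the same
# imports the helper above relies on (`tf`, `prefer_static`, `tensorshape_util`).
perms = tf.constant([[0, 1, 2, 3],
                     [0, 1, 2, 3]], dtype=tf.int64)
i = tf.constant([[2], [3]], dtype=tf.int64)
# Swap position 0 with the per-batch index `i` in each vector.
print(_swap_m_with_i(perms, 0, i))
# ==> [[2, 1, 0, 3],
#      [3, 1, 2, 0]]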
def _cdf(self, k):
  # TODO(b/135263541): Improve numerical precision of categorical.cdf.
  probs = self.probs_parameter()
  num_categories = self._num_categories(probs)

  k, probs = _broadcast_cat_event_and_params(
      k, probs, base_dtype=dtype_util.base_dtype(self.dtype))

  # Since the lowest number in the support is 0, any k < 0 should be zero in
  # the output.
  should_be_zero = k < 0

  # Will use k as an index in the gather below, so clip it to {0,...,K-1}.
  k = tf.clip_by_value(tf.cast(k, tf.int32), 0, num_categories - 1)

  batch_shape = tf.shape(k)

  # tf.gather(..., batch_dims=batch_dims) requires static batch_dims kwarg, so
  # to handle the case where the batch shape is dynamic, flatten the batch
  # dims (so we know batch_dims=1).
  k_flat_batch = tf.reshape(k, [-1])
  probs_flat_batch = tf.reshape(
      probs, tf.concat(([-1], [num_categories]), axis=0))

  cdf_flat = tf.gather(
      tf.cumsum(probs_flat_batch, axis=-1),
      k_flat_batch[..., tf.newaxis],
      batch_dims=1)

  cdf = tf.reshape(cdf_flat, shape=batch_shape)

  zero = np.array(0, dtype=dtype_util.as_numpy_dtype(cdf.dtype))
  return tf.where(should_be_zero, zero, cdf)
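# Minimal sketch (illustrative, not from the library) of the cumsum + batched
# gather trick used above: with `batch_dims=1`, each batch row of `k` indexes
# into the matching row of the cumulative probabilities. The concrete values
# below are made up; assumes eager execution.
import tensorflow.compat.v2 as tf

probs = tf.constant([[0.2, 0.3, 0.5],
                     [0.1, 0.1, 0.8]])     # two batch entries, 3 categories
k = tf.constant([[1], [2]], dtype=tf.int32)  # event index per batch entry
cdf = tf.gather(tf.cumsum(probs, axis=-1), k, batch_dims=1)
print(cdf)
# ==> [[0.5],
#      [1.0]]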
def _log_prob(self, x):
  x = tf.convert_to_tensor(x, name='x')
  right_indices = tf.minimum(
      tf.size(self.outcomes) - 1,
      tf.reshape(
          tf.searchsorted(
              self.outcomes,
              values=tf.reshape(x, shape=[-1]),
              side='right'),
          dist_util.prefer_static_shape(x)))
  use_right_indices = self._is_equal_or_close(
      x, tf.gather(self.outcomes, indices=right_indices))
  left_indices = tf.maximum(0, right_indices - 1)
  use_left_indices = self._is_equal_or_close(
      x, tf.gather(self.outcomes, indices=left_indices))
  log_probs = self._categorical.log_prob(
      tf.where(use_left_indices, left_indices, right_indices))
  return tf.where(
      tf.logical_not(use_left_indices | use_right_indices),
      dtype_util.as_numpy_dtype(log_probs.dtype)(-np.inf),
      log_probs)
def _sample_n(self, n, seed=None):
  samples = tf.convert_to_tensor(self._samples)
  indices = tf.random.uniform(
      [n], maxval=self._compute_num_samples(samples),
      dtype=tf.int32, seed=seed)
  draws = tf.gather(samples, indices, axis=self._samples_axis)
  axes = tf.concat(
      [[self._samples_axis],
       tf.range(self._samples_axis, dtype=tf.int32),
       tf.range(self._event_ndims, dtype=tf.int32) + self._samples_axis + 1],
      axis=0)
  draws = tf.transpose(a=draws, perm=axes)
  return draws
def _forward(self, x):
  map_values = tf.convert_to_tensor(self.map_values)
  if self.validate_args:
    with tf.control_dependencies([
        assert_util.assert_equal(
            (0 <= x) & (x < tf.size(map_values)),
            True,
            message='indices out of bounds')
    ]):
      x = tf.identity(x)
  # If we want batch dims in self.map_values, we can (after broadcasting) use:
  #   tf.gather(self.map_values, x, batch_dims=-1, axis=-1)
  return tf.gather(map_values, indices=x)
def _cdf(self, x):
  x = tf.convert_to_tensor(x, name='x')
  flat_x = tf.reshape(x, shape=[-1])
  upper_bound = tf.searchsorted(self.outcomes, values=flat_x, side='right')
  values_at_ub = tf.gather(
      self.outcomes,
      indices=tf.minimum(
          upper_bound,
          dist_util.prefer_static_shape(self.outcomes)[-1] - 1))
  should_use_upper_bound = self._is_equal_or_close(flat_x, values_at_ub)
  indices = tf.where(should_use_upper_bound, upper_bound, upper_bound - 1)
  return self._categorical.cdf(
      tf.reshape(indices, shape=dist_util.prefer_static_shape(x)))
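# Illustrative sketch (not from the library): `tf.searchsorted(..., side='right')`
# returns, for each value, the number of outcomes <= that value, so
# `upper_bound - 1` is the index of the largest outcome <= x (or -1 below the
# support); `_is_equal_or_close` above corrects values that sit within
# floating-point tolerance of an outcome. The support below is made up;
# assumes eager execution.
import tensorflow.compat.v2 as tf

outcomes = tf.constant([1., 2., 4.])
x = tf.constant([0.5, 2., 3., 4.])
upper = tf.searchsorted(outcomes, values=x, side='right')
print(upper)      # ==> [0 2 2 3]
print(upper - 1)  # ==> [-1 1 1 2]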
def _prob(self, x):
  if self.validate_args:
    is_vector_check = assert_util.assert_rank_at_least(x, 1)
    right_vec_space_check = assert_util.assert_equal(
        self.event_shape_tensor(),
        tf.gather(tf.shape(x), tf.rank(x) - 1),
        message="Argument 'x' not defined in the same space R^k as this "
                "distribution")
    with tf.control_dependencies([is_vector_check]):
      with tf.control_dependencies([right_vec_space_check]):
        x = tf.identity(x)
  loc = tf.convert_to_tensor(self.loc)
  return tf.cast(
      tf.reduce_all(tf.abs(x - loc) <= self._slack(loc), axis=-1),
      dtype=self.dtype)
def _sample_n(self, n, seed=None):
  # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
  # ids as a [n]-shaped vector.
  distributions = self.poisson_and_mixture_distributions()
  dist, mixture_dist = distributions
  batch_size = tensorshape_util.num_elements(self.batch_shape)
  if batch_size is None:
    batch_size = tf.reduce_prod(
        self._batch_shape_tensor(distributions=distributions))
  # We need to 'sample extra' from the mixture distribution if it doesn't
  # already specify a probs vector for each batch coordinate.
  # We only support this kind of reduced broadcasting, i.e., there is exactly
  # one probs vector for all batch dims or one for each.
  stream = SeedStream(seed, salt='PoissonLogNormalQuadratureCompound')
  ids = mixture_dist.sample(
      sample_shape=concat_vectors(
          [n],
          distribution_util.pick_vector(
              mixture_dist.is_scalar_batch(),
              [batch_size],
              np.int32([]))),
      seed=stream())
  # We need to flatten batch dims in case mixture_dist has its own
  # batch dims.
  ids = tf.reshape(
      ids,
      shape=concat_vectors([n],
                           distribution_util.pick_vector(
                               self.is_scalar_batch(),
                               np.int32([]),
                               np.int32([-1]))))
  # Stride `quadrature_size` for `batch_size` number of times.
  offset = tf.range(
      start=0,
      limit=batch_size * self._quadrature_size,
      delta=self._quadrature_size,
      dtype=ids.dtype)
  ids = ids + offset
  rate = tf.gather(tf.reshape(dist.rate, shape=[-1]), ids)
  rate = tf.reshape(
      rate,
      shape=concat_vectors([n], self._batch_shape_tensor(
          distributions=distributions)))
  return tf.random.poisson(lam=rate, shape=[], dtype=self.dtype, seed=seed)
def _sample_n(self, n, seed=None):
  seed = SeedStream(seed, salt='vom_mises_fisher')
  # The sampling strategy relies on the fact that vMF variates are symmetric
  # about the mean direction. Accordingly, if we have a sampling strategy for
  # the away-from-mean angle, then we can uniformly sample the remaining
  # dimensions on the S^{dim-2} sphere for free, and rotate these samples from
  # a (1, 0, 0, ..., 0)-mode distribution into the target orientation.
  #
  # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
  # von-Mises distributed `x` value in [-1, 1], then uniformly select what
  # amounts to an "up" or "down" additional degree of freedom after unit
  # normalizing, followed by a final rotation to the desired mean direction
  # from a basis of (1, 0).
  #
  # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
  # unit sphere over which the distribution is uniform, in particular the
  # circle where x = \hat{x} intersects the unit sphere. We pick a point on
  # that circle, then rotate to the desired mean direction from a basis of
  # (1, 0, 0).
  event_dim = (tf.compat.dimension_value(self.event_shape[0]) or
               self._event_shape_tensor()[0])

  sample_batch_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
  dim = tf.cast(event_dim - 1, self.dtype)
  if event_dim == 3:
    samples_dim0 = self._sample_3d(n, seed=seed)
  else:
    # Wood'94 provides a rejection algorithm to sample the x coordinate.
    # Wood'94 definition of b:
    #   b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
    # https://stats.stackexchange.com/questions/156729 suggests:
    b = dim / (2 * self.concentration +
               tf.sqrt(4 * self.concentration**2 + dim**2))
    # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE
    #     https://github.com/nicola-decao/s-vae-tf/
    x = (1 - b) / (1 + b)
    c = self.concentration * x + dim * tf.math.log1p(-x**2)
    beta = beta_lib.Beta(dim / 2, dim / 2)

    def cond_fn(w, should_continue):
      del w
      return tf.reduce_any(should_continue)

    def body_fn(w, should_continue):
      z = beta.sample(sample_shape=sample_batch_shape, seed=seed())
      # set_shape needed here because of b/139013403
      z.set_shape(w.shape)
      w = tf.where(should_continue,
                   (1 - (1 + b) * z) / (1 - (1 - b) * z),
                   w)
      w = tf.debugging.check_numerics(w, 'w')
      unif = tf.random.uniform(
          sample_batch_shape, seed=seed(), dtype=self.dtype)
      # set_shape needed here because of b/139013403
      unif.set_shape(w.shape)
      should_continue = tf.logical_and(
          should_continue,
          self.concentration * w + dim * tf.math.log1p(-x * w) - c <
          tf.math.log(unif))
      return w, should_continue

    w = tf.zeros(sample_batch_shape, dtype=self.dtype)
    should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
    samples_dim0 = tf.while_loop(
        cond=cond_fn, body=body_fn, loop_vars=(w, should_continue))[0]
    samples_dim0 = samples_dim0[..., tf.newaxis]
  if not self._allow_nan_stats:
    # Verify samples are w/in -1, 1, with useful error output tensors (top
    # value rather than all values).
    with tf.control_dependencies([
        assert_util.assert_less_equal(
            samples_dim0,
            dtype_util.as_numpy_dtype(self.dtype)(1.01),
            data=[tf.math.top_k(tf.reshape(samples_dim0, [-1]))[0]]),
        assert_util.assert_greater_equal(
            samples_dim0,
            dtype_util.as_numpy_dtype(self.dtype)(-1.01),
            data=[-tf.math.top_k(tf.reshape(-samples_dim0, [-1]))[0]])
    ]):
      samples_dim0 = tf.identity(samples_dim0)
  samples_otherdims_shape = tf.concat(
      [sample_batch_shape, [event_dim - 1]], axis=0)
  unit_otherdims = tf.math.l2_normalize(
      tf.random.normal(
          samples_otherdims_shape, seed=seed(), dtype=self.dtype),
      axis=-1)
  samples = tf.concat([
      samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
      tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
  ], axis=-1)
  samples = tf.math.l2_normalize(samples, axis=-1)
  if not self._allow_nan_stats:
    samples = tf.debugging.check_numerics(samples, 'samples')
  # Runtime assert that samples are unit length.
  if not self._allow_nan_stats:
    worst, idx = tf.math.top_k(
        tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
    with tf.control_dependencies([
        assert_util.assert_near(
            dtype_util.as_numpy_dtype(self.dtype)(0),
            worst,
            data=[
                worst, idx,
                tf.gather(tf.reshape(samples, [-1, event_dim]), idx)
            ],
            atol=1e-4,
            summarize=100)
    ]):
      samples = tf.identity(samples)
  # The samples generated are symmetric around a mode at (1, 0, 0, ...., 0).
  # Now, we move the mode to `self.mean_direction` using a rotation matrix.
  if not self._allow_nan_stats:
    # Assert that the basis vector rotates to the mean direction, as expected.
    basis = tf.cast(
        tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0), self.dtype)
    with tf.control_dependencies([
        assert_util.assert_less(
            tf.linalg.norm(
                self._rotate(basis) - self.mean_direction, axis=-1),
            dtype_util.as_numpy_dtype(self.dtype)(1e-5))
    ]):
      return self._rotate(samples)
  return self._rotate(samples)
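# Side note (illustrative, not from the library): the two expressions for `b`
# referenced in the comment above are algebraically identical,
#   (-2*kappa + sqrt(4*kappa**2 + dim**2)) / dim
#     == dim / (2*kappa + sqrt(4*kappa**2 + dim**2)),
# but the second form avoids catastrophic cancellation for large kappa.
# A quick numerical check with NumPy (concrete values are made up):
import numpy as np

kappa, dim = 1e8, 5.
b_wood = (-2 * kappa + np.sqrt(4 * kappa**2 + dim**2)) / dim
b_stable = dim / (2 * kappa + np.sqrt(4 * kappa**2 + dim**2))
print(b_wood, b_stable)
# The first loses precision for large kappa; the exact value is
# approximately dim / (4 * kappa) = 1.25e-8.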
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
  """The inverse LU decomposition, `X == lu_reconstruct(*tf.linalg.lu(X))`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_reconstruct').

  Returns:
    x: The original input to `tf.linalg.lu`, i.e., `x` as in,
      `lu_reconstruct(*tf.linalg.lu(x))`.

  #### Examples

  ```python
  import numpy as np
  from tensorflow_probability.python.internal.backend import numpy as tf
  import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.numpy

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  x_reconstructed = tfp.math.lu_reconstruct(*tf.linalg.lu(x))
  tf.assert_near(x, x_reconstructed)
  # ==> True
  ```
  """
  with tf.name_scope(name or 'lu_reconstruct'):
    lower_upper = tf.convert_to_tensor(
        lower_upper, dtype_hint=tf.float32, name='lower_upper')
    perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')

    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
    if assertions:
      with tf.control_dependencies(assertions):
        lower_upper = tf.identity(lower_upper)
        perm = tf.identity(perm)

    shape = tf.shape(lower_upper)

    lower = tf.linalg.set_diag(
        tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
        tf.ones(shape[:-1], dtype=lower_upper.dtype))
    upper = tf.linalg.band_part(lower_upper, num_lower=0, num_upper=-1)
    x = tf.matmul(lower, upper)

    if (tensorshape_util.rank(lower_upper.shape) is None or
        tensorshape_util.rank(lower_upper.shape) != 2):
      # We either don't know the batch rank or there are >0 batch dims.
      batch_size = tf.reduce_prod(shape[:-2])
      d = shape[-1]
      x = tf.reshape(x, [batch_size, d, d])
      perm = tf.reshape(perm, [batch_size, d])
      perm = tf.map_fn(tf.math.invert_permutation, perm)
      batch_indices = tf.broadcast_to(
          tf.range(batch_size)[:, tf.newaxis], [batch_size, d])
      x = tf.gather_nd(x, tf.stack([batch_indices, perm], axis=-1))
      x = tf.reshape(x, shape)
    else:
      x = tf.gather(x, tf.math.invert_permutation(perm))

    x.set_shape(lower_upper.shape)
    return x
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
  """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
      `A X = RHS`. To handle vector cases, use:
      `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix
      is actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  from tensorflow_probability.python.internal.backend import numpy as tf
  import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.numpy

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tfp.math.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.assert_near(tf.linalg.inv(x), inv_x)
  # ==> True
  ```
  """
  with tf.name_scope(name or 'lu_solve'):
    lower_upper = tf.convert_to_tensor(
        lower_upper, dtype_hint=tf.float32, name='lower_upper')
    perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
    rhs = tf.convert_to_tensor(rhs, dtype_hint=lower_upper.dtype, name='rhs')

    assertions = _lu_solve_assertions(lower_upper, perm, rhs, validate_args)
    if assertions:
      with tf.control_dependencies(assertions):
        lower_upper = tf.identity(lower_upper)
        perm = tf.identity(perm)
        rhs = tf.identity(rhs)

    if (tensorshape_util.rank(rhs.shape) == 2 and
        tensorshape_util.rank(perm.shape) == 1):
      # Both rhs and perm have scalar batch_shape.
      permuted_rhs = tf.gather(rhs, perm, axis=-2)
    else:
      # Either rhs or perm have non-scalar batch_shape or we can't determine
      # this information statically.
      rhs_shape = tf.shape(rhs)
      broadcast_batch_shape = tf.broadcast_dynamic_shape(
          rhs_shape[:-2], tf.shape(perm)[:-1])
      d, m = rhs_shape[-2], rhs_shape[-1]
      rhs_broadcast_shape = tf.concat([broadcast_batch_shape, [d, m]], axis=0)

      # Tile out rhs.
      broadcast_rhs = tf.broadcast_to(rhs, rhs_broadcast_shape)
      broadcast_rhs = tf.reshape(broadcast_rhs, [-1, d, m])

      # Tile out perm and add batch indices.
      broadcast_perm = tf.broadcast_to(perm, rhs_broadcast_shape[:-1])
      broadcast_perm = tf.reshape(broadcast_perm, [-1, d])
      broadcast_batch_size = tf.reduce_prod(broadcast_batch_shape)
      broadcast_batch_indices = tf.broadcast_to(
          tf.range(broadcast_batch_size)[:, tf.newaxis],
          [broadcast_batch_size, d])
      broadcast_perm = tf.stack(
          [broadcast_batch_indices, broadcast_perm], axis=-1)

      permuted_rhs = tf.gather_nd(broadcast_rhs, broadcast_perm)
      permuted_rhs = tf.reshape(permuted_rhs, rhs_broadcast_shape)

    lower = tf.linalg.set_diag(
        tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
        tf.ones(tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
    return linear_operator_util.matrix_triangular_solve_with_broadcast(
        lower_upper,  # Only upper is accessed.
        linear_operator_util.matrix_triangular_solve_with_broadcast(
            lower, permuted_rhs),
        lower=False)
def batch_gather(params, indices, batch_dims, axis=-1):
  # Gather `indices` from `params`, treating the leading `batch_dims`
  # dimensions of both tensors as shared batch dimensions.
  return tf.gather(params, indices, axis=axis, batch_dims=batch_dims)
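# Illustrative usage sketch (not from the library), assuming eager execution:
# with `batch_dims=1`, each row of `indices` selects from the matching row of
# `params`. The concrete values are made up.
import tensorflow.compat.v2 as tf

params = tf.constant([[10, 11, 12],
                      [20, 21, 22]])
indices = tf.constant([[2, 0],
                       [1, 1]])
print(batch_gather(params, indices, batch_dims=1))
# ==> [[12, 10],
#      [21, 21]]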
def _sample_n(self, n, seed=None):
  if self._use_static_graph:
    with tf.control_dependencies(self._assertions):
      # This sampling approach is almost the same as the approach used by
      # `MixtureSameFamily`. The differences are due to having a list of
      # `Distribution` objects rather than a single object, and maintaining
      # random seed management that is consistent with the non-static code
      # path.
      samples = []
      cat_samples = self.cat.sample(n, seed=seed)
      stream = SeedStream(seed, salt="Mixture")
      for c in range(self.num_components):
        samples.append(self.components[c].sample(n, seed=stream()))
      stack_axis = -1 - tensorshape_util.rank(self._static_event_shape)
      x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
      npdt = dtype_util.as_numpy_dtype(x.dtype)
      mask = tf.one_hot(
          indices=cat_samples,  # [n, B]
          depth=self._num_components,  # == k
          on_value=npdt(1),
          off_value=npdt(0))  # [n, B, k]
      mask = distribution_util.pad_mixture_dimensions(
          mask, self, self._cat,
          tensorshape_util.rank(self._static_event_shape))  # [n, B, k, [1]*e]
      return tf.reduce_sum(x * mask, axis=stack_axis)  # [n, B, E]

  with tf.control_dependencies(self._assertions):
    n = tf.convert_to_tensor(n, name="n")
    static_n = tf.get_static_value(n)
    n = int(static_n) if static_n is not None else n
    cat_samples = self.cat.sample(n, seed=seed)

    static_samples_shape = cat_samples.shape
    if tensorshape_util.is_fully_defined(static_samples_shape):
      samples_shape = tensorshape_util.as_list(static_samples_shape)
      samples_size = tensorshape_util.num_elements(static_samples_shape)
    else:
      samples_shape = tf.shape(cat_samples)
      samples_size = tf.size(cat_samples)
    static_batch_shape = self.batch_shape
    if tensorshape_util.is_fully_defined(static_batch_shape):
      batch_shape = tensorshape_util.as_list(static_batch_shape)
      batch_size = tensorshape_util.num_elements(static_batch_shape)
    else:
      batch_shape = self.batch_shape_tensor()
      batch_size = tf.reduce_prod(batch_shape)
    static_event_shape = self.event_shape
    if tensorshape_util.is_fully_defined(static_event_shape):
      event_shape = np.array(
          tensorshape_util.as_list(static_event_shape), dtype=np.int32)
    else:
      event_shape = self.event_shape_tensor()

    # Get indices into the raw cat sampling tensor. We will
    # need these to stitch sample values back out after sampling
    # within the component partitions.
    samples_raw_indices = tf.reshape(
        tf.range(0, samples_size), samples_shape)

    # Partition the raw indices so that we can use
    # dynamic_stitch later to reconstruct the samples from the
    # known partitions.
    partitioned_samples_indices = tf.dynamic_partition(
        data=samples_raw_indices,
        partitions=cat_samples,
        num_partitions=self.num_components)

    # Copy the batch indices n times, as we will need to know
    # these to pull out the appropriate rows within the
    # component partitions.
    batch_raw_indices = tf.reshape(
        tf.tile(tf.range(0, batch_size), [n]), samples_shape)

    # Explanation of the dynamic partitioning below:
    #   batch indices are, e.g., [0, 1, 0, 1, 0, 1]
    # Suppose partitions are:
    #     [1 1 0 0 1 1]
    # After partitioning, batch indices are cut as:
    #     [batch_indices[x] for x in 2, 3]
    #     [batch_indices[x] for x in 0, 1, 4, 5]
    # i.e.
    #     [0 1] and [0 1 0 1]
    # Now we sample n=2 from part 0 and n=4 from part 1.
    # For part 0 we want samples from batch entries 0, 1 (samples 0, 1),
    # and for part 1 we want samples from batch entries 0, 1, 0, 1
    # (samples 0, 1, 2, 3).
    partitioned_batch_indices = tf.dynamic_partition(
        data=batch_raw_indices,
        partitions=cat_samples,
        num_partitions=self.num_components)
    samples_class = [None for _ in range(self.num_components)]

    stream = SeedStream(seed, salt="Mixture")

    for c in range(self.num_components):
      n_class = tf.size(partitioned_samples_indices[c])
      samples_class_c = self.components[c].sample(n_class, seed=stream())

      # Pull out the correct batch entries from each index.
      # To do this, we may have to flatten the batch shape.

      # For sample s, batch element b of component c, we get the
      # partitioned batch indices from
      # partitioned_batch_indices[c]; and shift each element by
      # the sample index. The final lookup can be thought of as
      # a matrix gather along locations (s, b) in
      # samples_class_c where the n_class rows correspond to
      # samples within this component and the batch_size columns
      # correspond to batch elements within the component.
      #
      # Thus the lookup index is
      #   lookup[c, i] = batch_size * s[i] + b[c, i]
      # for i = 0 ... n_class[c] - 1.
      lookup_partitioned_batch_indices = (
          batch_size * tf.range(n_class) + partitioned_batch_indices[c])
      samples_class_c = tf.reshape(
          samples_class_c,
          tf.concat([[n_class * batch_size], event_shape], 0))
      samples_class_c = tf.gather(
          samples_class_c,
          lookup_partitioned_batch_indices,
          name="samples_class_c_gather")
      samples_class[c] = samples_class_c

    # Stitch back together the samples across the components.
    lhs_flat_ret = tf.dynamic_stitch(
        indices=partitioned_samples_indices, data=samples_class)
    # Reshape back to proper sample, batch, and event shape.
    ret = tf.reshape(
        lhs_flat_ret,
        tf.concat([samples_shape, self.event_shape_tensor()], 0))
    tensorshape_util.set_shape(
        ret,
        tensorshape_util.concatenate(static_samples_shape, self.event_shape))
    return ret
def _sample_n(self, n, seed=None, **distribution_kwargs):
  return tf.gather(
      self.outcomes,
      indices=self._categorical.sample(
          sample_shape=[n], seed=seed, **distribution_kwargs))
def _forward(self, x):
  y = tf.gather(x, self.permutation, axis=self.axis)
  tensorshape_util.set_shape(y, x.shape)
  return y
def _forward_event_shape_tensor(self, input_shape):
  perm = self._make_perm(tf.size(input_shape), self.perm)
  return tf.gather(input_shape, perm)
def _inverse(self, y):
  x = tf.gather(y, tf.math.invert_permutation(self.permutation),
                axis=self.axis)
  tensorshape_util.set_shape(x, y.shape)
  return x
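# Illustrative sketch (not from the library): the forward/inverse pair above is
# just `gather` with a permutation and with its inverse permutation, so the two
# compose to the identity. The concrete values are made up; assumes eager
# execution.
import tensorflow.compat.v2 as tf

permutation = tf.constant([2, 0, 1])
x = tf.constant([10., 20., 30.])
y = tf.gather(x, permutation)  # forward: [30., 10., 20.]
x_back = tf.gather(y, tf.math.invert_permutation(permutation))  # inverse
print(x_back)
# ==> [10., 20., 30.]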
def _sample_n(self, n, seed=None):
  stream = SeedStream(seed, salt="VectorDiffeomixture")
  x = self.distribution.sample(
      sample_shape=concat_vectors(
          [n],
          self.batch_shape_tensor(),
          self.event_shape_tensor()),
      seed=stream())  # shape: [n, B, e]
  x = [aff.forward(x) for aff in self.endpoint_affine]

  # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
  # ids as a [n]-shaped vector.
  batch_size = tensorshape_util.num_elements(self.batch_shape)
  if batch_size is None:
    batch_size = tf.reduce_prod(self.batch_shape_tensor())
  mix_batch_size = tensorshape_util.num_elements(
      self.mixture_distribution.batch_shape)
  if mix_batch_size is None:
    mix_batch_size = tf.reduce_prod(
        self.mixture_distribution.batch_shape_tensor())
  ids = self.mixture_distribution.sample(
      sample_shape=concat_vectors(
          [n],
          distribution_util.pick_vector(
              self.is_scalar_batch(),
              np.int32([]),
              [batch_size // mix_batch_size])),
      seed=stream())
  # We need to flatten batch dims in case mixture_distribution has its own
  # batch dims.
  ids = tf.reshape(
      ids,
      shape=concat_vectors([n],
                           distribution_util.pick_vector(
                               self.is_scalar_batch(),
                               np.int32([]),
                               np.int32([-1]))))

  # Stride `components * quadrature_size` for `batch_size` number of times.
  stride = tensorshape_util.num_elements(
      tensorshape_util.with_rank_at_least(self.grid.shape, 2)[-2:])
  if stride is None:
    stride = tf.reduce_prod(tf.shape(self.grid)[-2:])
  offset = tf.range(
      start=0, limit=batch_size * stride, delta=stride, dtype=ids.dtype)

  weight = tf.gather(tf.reshape(self.grid, shape=[-1]), ids + offset)
  # At this point, weight flattened all batch dims into one.
  # We also need to append a singleton to broadcast with event dims.
  if tensorshape_util.is_fully_defined(self.batch_shape):
    new_shape = [-1] + tensorshape_util.as_list(self.batch_shape) + [1]
  else:
    new_shape = tf.concat(([-1], self.batch_shape_tensor(), [1]), axis=0)
  weight = tf.reshape(weight, shape=new_shape)

  if len(x) != 2:
    # We actually should have already triggered this exception. However as a
    # policy we're putting this exception wherever we exploit the bimixture
    # assumption.
    raise NotImplementedError(
        "Currently only bimixtures are supported; "
        "len(scale)={} is not 2.".format(len(x)))

  # Alternatively:
  #   x = weight * x[0] + (1. - weight) * x[1]
  x = weight * (x[0] - x[1]) + x[1]

  return x
def _inverse_event_shape_tensor(self, output_shape):
  perm = self._make_perm(tf.size(output_shape), tf.argsort(self.perm))
  return tf.gather(output_shape, perm)
def _mode(self):
  return tf.gather(self.outcomes, indices=self._categorical.mode())