def calculate_reshape(original_shape, new_shape, validate=False, name=None):
  """Calculates the reshaped dimensions (replacing up to one -1 in reshape)."""
  batch_shape_static = tensorshape_util.constant_value_as_shape(new_shape)
  if tensorshape_util.is_fully_defined(batch_shape_static):
    return np.int32(batch_shape_static), batch_shape_static, []
  with tf.name_scope(name or 'calculate_reshape'):
    original_size = tf.reduce_prod(original_shape)
    implicit_dim = tf.equal(new_shape, -1)
    size_implicit_dim = (
        original_size // tf.maximum(1, -tf.reduce_prod(new_shape)))
    expanded_new_shape = tf.where(  # Assumes exactly one `-1`.
        implicit_dim, size_implicit_dim, new_shape)
    validations = [] if not validate else [  # pylint: disable=g-long-ternary
        assert_util.assert_rank(
            original_shape, 1, message='Original shape must be a vector.'),
        assert_util.assert_rank(
            new_shape, 1, message='New shape must be a vector.'),
        assert_util.assert_less_equal(
            tf.math.count_nonzero(implicit_dim, dtype=tf.int32), 1,
            message='At most one dimension can be unknown.'),
        assert_util.assert_positive(
            expanded_new_shape, message='Shape elements must be >=-1.'),
        assert_util.assert_equal(
            tf.reduce_prod(expanded_new_shape), original_size,
            message='Shape sizes do not match.'),
    ]
    return expanded_new_shape, batch_shape_static, validations
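
# Editor's illustrative sketch (not library code): how the `-1` entry is
# inferred by the arithmetic in `calculate_reshape`, using NumPy only.
def _demo_calculate_reshape_arithmetic():
  """Sketch: the total element count divided by the negated product of the
  requested shape yields the missing dimension (assuming exactly one `-1`)."""
  import numpy as np
  original_shape = np.array([6, 4])                 # 24 elements in total.
  new_shape = np.array([3, -1, 2])
  original_size = np.prod(original_shape)           # 24
  # prod(new_shape) == -6, so -prod == 6 and 24 // 6 == 4.
  size_implicit_dim = original_size // max(1, -np.prod(new_shape))
  expanded = np.where(new_shape == -1, size_implicit_dim, new_shape)
  assert np.prod(expanded) == original_size         # expanded == [3, 4, 2]
  return expanded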
def _variance(self):
  with tf.control_dependencies(self._runtime_assertions):
    probs = self._marginal_hidden_probs()
    # probs :: num_steps batch_shape num_states
    means = self._observation_distribution.mean()
    # means :: observation_batch_shape[:-1] num_states
    #          observation_event_shape
    means_shape = tf.concat(
        [self.batch_shape_tensor(),
         [self._num_states],
         self._observation_distribution.event_shape_tensor()],
        axis=0)
    means = tf.broadcast_to(means, means_shape)
    # means :: batch_shape num_states observation_event_shape

    observation_event_shape = (
        self._observation_distribution.event_shape_tensor())
    batch_size = tf.reduce_prod(self.batch_shape_tensor())
    flat_probs_shape = [self._num_steps, batch_size, self._num_states]
    flat_means_shape = [
        batch_size, 1, self._num_states,
        tf.reduce_prod(observation_event_shape)
    ]

    flat_probs = tf.reshape(probs, flat_probs_shape)
    # flat_probs :: num_steps batch_size num_states
    flat_means = tf.reshape(means, flat_means_shape)
    # flat_means :: batch_size 1 num_states observation_event_size
    flat_mean = tf.einsum("ijk,jmkl->jiml", flat_probs, flat_means)
    # flat_mean :: batch_size num_steps 1 observation_event_size

    variances = self._observation_distribution.variance()
    variances = tf.broadcast_to(variances, means_shape)
    # variances :: batch_shape num_states observation_event_shape
    flat_variances = tf.reshape(variances, flat_means_shape)
    # flat_variances :: batch_size 1 num_states observation_event_size

    # For a mixture of n distributions with mixture probabilities
    # p[i], and where the individual distributions have means and
    # variances given by mean[i] and var[i], the variance of
    # the mixture is given by:
    #
    #   var = sum i=1..n p[i] * ((mean[i] - mean)**2 + var[i])
    flat_variance = tf.einsum("ijk,jikl->jil",
                              flat_probs,
                              (flat_means - flat_mean)**2 + flat_variances)
    # flat_variance :: batch_size num_steps observation_event_size

    unflat_mean_shape = tf.concat(
        [self.batch_shape_tensor(),
         [self._num_steps],
         observation_event_shape],
        axis=0)

    # returns :: batch_shape num_steps observation_event_shape
    return tf.reshape(flat_variance, unflat_mean_shape)
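
# Editor's illustrative sketch (not library code): numeric check of the
# mixture-variance formula used in `_variance`, NumPy only.
def _demo_mixture_variance_identity():
  """Sketch: Var = sum_i p[i] * ((mean[i] - mean)**2 + var[i]), compared
  against a brute-force Monte Carlo estimate for a two-component mixture."""
  import numpy as np
  rng = np.random.RandomState(0)
  p = np.array([0.3, 0.7])
  mean = np.array([-1.0, 2.0])
  var = np.array([0.25, 1.0])
  mixture_mean = np.sum(p * mean)
  formula = np.sum(p * ((mean - mixture_mean)**2 + var))
  # Monte Carlo check against direct sampling from the mixture.
  idx = rng.choice(2, size=200000, p=p)
  samples = rng.normal(mean[idx], np.sqrt(var[idx]))
  assert abs(np.var(samples) - formula) < 0.05
  return formula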
def _mode(self, samples=None):
  # Samples count can vary by batch member. Use map_fn to compute mode for
  # each batch separately.
  def _get_mode(samples):
    # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
    count = gen_array_ops.unique_with_counts_v2(samples, axis=[0]).count
    return tf.argmax(count)

  if samples is None:
    samples = tf.convert_to_tensor(self._samples)
  num_samples = self._compute_num_samples(samples)

  # Flatten samples for each batch.
  if self._event_ndims == 0:
    flattened_samples = tf.reshape(samples, [-1, num_samples])
    mode_shape = self._batch_shape_tensor(samples)
  else:
    event_size = tf.reduce_prod(self._event_shape_tensor(samples))
    mode_shape = tf.concat(
        [self._batch_shape_tensor(samples),
         self._event_shape_tensor(samples)],
        axis=0)
    flattened_samples = tf.reshape(samples, [-1, num_samples, event_size])

  indices = tf.map_fn(_get_mode, flattened_samples, dtype=tf.int64)
  full_indices = tf.stack(
      [tf.range(tf.shape(indices)[0]),
       tf.cast(indices, tf.int32)],
      axis=1)

  mode = tf.gather_nd(flattened_samples, full_indices)
  return tf.reshape(mode, mode_shape)
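
# Editor's illustrative sketch (not library code): the count-and-argmax idea
# behind `_mode`, for a single batch member, NumPy only.
def _demo_empirical_mode():
  """Sketch: count occurrences of each distinct sample and take the value
  with the largest count."""
  import numpy as np
  samples = np.array([2, 7, 7, 1, 7, 2])
  values, counts = np.unique(samples, return_counts=True)
  mode = values[np.argmax(counts)]
  assert mode == 7
  return mode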
def _event_shape_tensor(self):
  event_sizes = tf.nest.map_structure(tensorshape_util.num_elements,
                                      self._distribution.event_shape)
  if any(s is None for s in tf.nest.flatten(event_sizes)):
    event_sizes = tf.nest.map_structure(
        lambda static_size, shape_tensor:  # pylint: disable=g-long-lambda
        (tf.reduce_prod(shape_tensor)
         if static_size is None else static_size),
        event_sizes,
        self._distribution.event_shape_tensor())
  return tf.reduce_sum(tf.nest.flatten(event_sizes))[tf.newaxis]
def _mean(self):
  with tf.control_dependencies(self._runtime_assertions):
    probs = self._marginal_hidden_probs()
    # probs :: num_steps batch_shape num_states
    means = self._observation_distribution.mean()
    # means :: observation_batch_shape[:-1] num_states
    #          observation_event_shape
    means_shape = tf.concat(
        [self.batch_shape_tensor(),
         [self._num_states],
         self._observation_distribution.event_shape_tensor()],
        axis=0)
    means = tf.broadcast_to(means, means_shape)
    # means :: batch_shape num_states observation_event_shape

    observation_event_shape = (
        self._observation_distribution.event_shape_tensor())
    batch_size = tf.reduce_prod(self.batch_shape_tensor())
    flat_probs_shape = [self._num_steps, batch_size, self._num_states]
    flat_means_shape = [
        batch_size, self._num_states,
        tf.reduce_prod(observation_event_shape)
    ]

    flat_probs = tf.reshape(probs, flat_probs_shape)
    # flat_probs :: num_steps batch_size num_states
    flat_means = tf.reshape(means, flat_means_shape)
    # flat_means :: batch_size num_states observation_event_size
    flat_mean = tf.einsum("ijk,jkl->jil", flat_probs, flat_means)
    # flat_mean :: batch_size num_steps observation_event_size

    unflat_mean_shape = tf.concat(
        [self.batch_shape_tensor(),
         [self._num_steps],
         observation_event_shape],
        axis=0)
    # returns :: batch_shape num_steps observation_event_shape
    return tf.reshape(flat_mean, unflat_mean_shape)
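
# Editor's illustrative sketch (not library code): what the einsum in `_mean`
# computes, NumPy only.
def _demo_hmm_mean_einsum():
  """Sketch: "ijk,jkl->jil" contracts the state axis k, i.e. for each step i
  and batch member j it forms sum_k probs[i, j, k] * means[j, k, :]."""
  import numpy as np
  rng = np.random.RandomState(0)
  probs = rng.dirichlet([1.0] * 3, size=(4, 2))   # num_steps=4, batch=2, states=3
  means = rng.normal(size=(2, 3, 5))              # batch=2, states=3, event=5
  via_einsum = np.einsum("ijk,jkl->jil", probs, means)
  manual = np.stack(
      [np.stack([probs[i, j] @ means[j] for i in range(4)]) for j in range(2)])
  assert np.allclose(via_einsum, manual)
  return via_einsum.shape  # (2, 4, 5) :: batch, steps, event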
def _finish_prob_for_one_fiber(self, y, x, ildj, event_ndims,
                               **distribution_kwargs):
  """Finish computation of prob on one element of the inverse image."""
  x = self._maybe_rotate_dims(x, rotate_right=True)
  prob = self.distribution.prob(x, **distribution_kwargs)
  if self._is_maybe_event_override:
    prob = tf.reduce_prod(prob, axis=self._reduce_event_indices)
  prob = prob * tf.exp(tf.cast(ildj, prob.dtype))
  if self._is_maybe_event_override and isinstance(event_ndims, int):
    tensorshape_util.set_shape(
        prob,
        tf.broadcast_static_shape(
            tensorshape_util.with_rank_at_least(y.shape, 1)[:-event_ndims],
            self.batch_shape))
  return prob
def _entropy(self, **kwargs):
  if not self.bijector.is_constant_jacobian:
    raise NotImplementedError("entropy is not implemented")
  if not self.bijector._is_injective:  # pylint: disable=protected-access
    raise NotImplementedError("entropy is not implemented when "
                              "bijector is not injective.")
  distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
  # Suppose Y = g(X) where g is a diffeomorphism and X is a continuous rv. It
  # can be shown that:
  #   H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)].
  # If is_constant_jacobian then:
  #   E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c)
  # where c can be anything.
  entropy = self.distribution.entropy(**distribution_kwargs)
  if self._is_maybe_event_override:
    # H[X] = sum_i H[X_i] if X_i are mutually independent.
    # This means that a reduce_sum is a simple rescaling.
    entropy = entropy * tf.cast(
        tf.reduce_prod(self._override_event_shape),
        dtype=dtype_util.base_dtype(entropy.dtype))
  if self._is_maybe_batch_override:
    new_shape = tf.concat([
        prefer_static.ones_like(self._override_batch_shape),
        self.distribution.batch_shape_tensor()
    ], 0)
    entropy = tf.reshape(entropy, new_shape)
    multiples = tf.concat([
        self._override_batch_shape,
        prefer_static.ones_like(self.distribution.batch_shape_tensor())
    ], 0)
    entropy = tf.tile(entropy, multiples)
  dummy = prefer_static.zeros(
      shape=tf.concat([self.batch_shape_tensor(), self.event_shape_tensor()],
                      0),
      dtype=self.dtype)
  event_ndims = (tensorshape_util.rank(self.event_shape)
                 if tensorshape_util.rank(self.event_shape) is not None
                 else tf.size(self.event_shape_tensor()))
  ildj = self.bijector.inverse_log_det_jacobian(dummy,
                                                event_ndims=event_ndims,
                                                **bijector_kwargs)
  entropy = entropy - tf.cast(ildj, entropy.dtype)
  tensorshape_util.set_shape(entropy, self.batch_shape)
  return entropy
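
# Editor's illustrative sketch (not library code): the change-of-variables
# entropy identity used in `_entropy`, for the simplest constant-Jacobian
# bijector, NumPy only.
def _demo_entropy_under_affine_transform():
  """Sketch: for Y = g(X) with g(x) = a * x, H[Y] = H[X] + log|a|.
  Checked analytically for a standard normal X."""
  import numpy as np
  a = 3.0
  h_x = 0.5 * np.log(2.0 * np.pi * np.e)          # entropy of N(0, 1)
  h_y = 0.5 * np.log(2.0 * np.pi * np.e * a**2)   # entropy of N(0, a**2)
  assert np.isclose(h_y, h_x + np.log(abs(a)))
  return h_y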
def _sample_n(self, n, seed=None):
  # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
  # ids as a [n]-shaped vector.
  distributions = self.poisson_and_mixture_distributions()
  dist, mixture_dist = distributions
  batch_size = tensorshape_util.num_elements(self.batch_shape)
  if batch_size is None:
    batch_size = tf.reduce_prod(
        self._batch_shape_tensor(distributions=distributions))
  # We need to 'sample extra' from the mixture distribution if it doesn't
  # already specify a probs vector for each batch coordinate.
  # We only support this kind of reduced broadcasting, i.e., there is exactly
  # one probs vector for all batch dims or one for each.
  stream = SeedStream(seed, salt='PoissonLogNormalQuadratureCompound')
  ids = mixture_dist.sample(
      sample_shape=concat_vectors(
          [n],
          distribution_util.pick_vector(mixture_dist.is_scalar_batch(),
                                        [batch_size],
                                        np.int32([]))),
      seed=stream())
  # We need to flatten batch dims in case mixture_dist has its own
  # batch dims.
  ids = tf.reshape(
      ids,
      shape=concat_vectors([n],
                           distribution_util.pick_vector(
                               self.is_scalar_batch(),
                               np.int32([]),
                               np.int32([-1]))))
  # Stride `quadrature_size` for `batch_size` number of times.
  offset = tf.range(start=0,
                    limit=batch_size * self._quadrature_size,
                    delta=self._quadrature_size,
                    dtype=ids.dtype)
  ids = ids + offset
  rate = tf.gather(tf.reshape(dist.rate, shape=[-1]), ids)
  rate = tf.reshape(
      rate,
      shape=concat_vectors(
          [n], self._batch_shape_tensor(distributions=distributions)))
  return tf.random.poisson(lam=rate, shape=[], dtype=self.dtype, seed=seed)
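
# Editor's illustrative sketch (not library code): the `ids + offset` gather
# trick used above (and again in the vector-diffeomixture sampler), NumPy only.
def _demo_strided_gather():
  """Sketch: each batch member b owns a contiguous block of `quadrature_size`
  entries in the flattened table, so adding b * quadrature_size to a
  within-block id indexes the right entry."""
  import numpy as np
  quadrature_size = 4
  rates = np.arange(12.0).reshape(3, quadrature_size)   # batch_size = 3
  ids = np.array([2, 0, 3])                              # one id per batch member
  offset = np.arange(3) * quadrature_size                # [0, 4, 8]
  gathered = rates.reshape(-1)[ids + offset]
  assert np.array_equal(gathered, np.array([2.0, 4.0, 11.0]))
  return gathered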
def _entropy(self):
  samples = tf.convert_to_tensor(self.samples)
  num_samples = self._compute_num_samples(samples)
  entropy_shape = self._batch_shape_tensor(samples)

  # Flatten samples for each batch.
  if self._event_ndims == 0:
    samples = tf.reshape(samples, [-1, num_samples])
  else:
    event_size = tf.reduce_prod(self.event_shape_tensor())
    samples = tf.reshape(samples, [-1, num_samples, event_size])

  # Use map_fn to compute entropy for each batch separately.
  def _get_entropy(samples):
    # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
    count = gen_array_ops.unique_with_counts_v2(samples, axis=[0]).count
    prob = tf.cast(count / num_samples, dtype=self.dtype)
    entropy = tf.reduce_sum(-prob * tf.math.log(prob))
    return entropy

  entropy = tf.map_fn(_get_entropy, samples, dtype=self.dtype)
  return tf.reshape(entropy, entropy_shape)
def _sample_n(self, n, seed=None): stream = SeedStream(seed, salt="VectorDiffeomixture") x = self.distribution.sample(sample_shape=concat_vectors( [n], self.batch_shape_tensor(), self.event_shape_tensor()), seed=stream()) # shape: [n, B, e] x = [aff.forward(x) for aff in self.endpoint_affine] # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get # ids as a [n]-shaped vector. batch_size = tensorshape_util.num_elements(self.batch_shape) if batch_size is None: batch_size = tf.reduce_prod(self.batch_shape_tensor()) mix_batch_size = tensorshape_util.num_elements( self.mixture_distribution.batch_shape) if mix_batch_size is None: mix_batch_size = tf.reduce_prod( self.mixture_distribution.batch_shape_tensor()) ids = self.mixture_distribution.sample(sample_shape=concat_vectors( [n], distribution_util.pick_vector(self.is_scalar_batch(), np.int32([]), [batch_size // mix_batch_size])), seed=stream()) # We need to flatten batch dims in case mixture_distribution has its own # batch dims. ids = tf.reshape(ids, shape=concat_vectors([n], distribution_util.pick_vector( self.is_scalar_batch(), np.int32([]), np.int32([-1])))) # Stride `components * quadrature_size` for `batch_size` number of times. stride = tensorshape_util.num_elements( tensorshape_util.with_rank_at_least(self.grid.shape, 2)[-2:]) if stride is None: stride = tf.reduce_prod(tf.shape(self.grid)[-2:]) offset = tf.range(start=0, limit=batch_size * stride, delta=stride, dtype=ids.dtype) weight = tf.gather(tf.reshape(self.grid, shape=[-1]), ids + offset) # At this point, weight flattened all batch dims into one. # We also need to append a singleton to broadcast with event dims. if tensorshape_util.is_fully_defined(self.batch_shape): new_shape = [-1] + tensorshape_util.as_list(self.batch_shape) + [1] else: new_shape = tf.concat(([-1], self.batch_shape_tensor(), [1]), axis=0) weight = tf.reshape(weight, shape=new_shape) if len(x) != 2: # We actually should have already triggered this exception. However as a # policy we're putting this exception wherever we exploit the bimixture # assumption. raise NotImplementedError( "Currently only bimixtures are supported; " "len(scale)={} is not 2.".format(len(x))) # Alternatively: # x = weight * x[0] + (1. - weight) * x[1] x = weight * (x[0] - x[1]) + x[1] return x
def _sample_n(self, n, seed=None):
  if self._use_static_graph:
    with tf.control_dependencies(self._assertions):
      # This sampling approach is almost the same as the approach used by
      # `MixtureSameFamily`. The differences are due to having a list of
      # `Distribution` objects rather than a single object, and maintaining
      # random seed management that is consistent with the non-static code
      # path.
      samples = []
      cat_samples = self.cat.sample(n, seed=seed)
      stream = SeedStream(seed, salt="Mixture")

      for c in range(self.num_components):
        samples.append(self.components[c].sample(n, seed=stream()))
      stack_axis = -1 - tensorshape_util.rank(self._static_event_shape)
      x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
      npdt = dtype_util.as_numpy_dtype(x.dtype)
      mask = tf.one_hot(
          indices=cat_samples,          # [n, B]
          depth=self._num_components,   # == k
          on_value=npdt(1),
          off_value=npdt(0))            # [n, B, k]
      mask = distribution_util.pad_mixture_dimensions(
          mask, self, self._cat,
          tensorshape_util.rank(self._static_event_shape))  # [n, B, k, [1]*e]
      return tf.reduce_sum(x * mask, axis=stack_axis)  # [n, B, E]

  with tf.control_dependencies(self._assertions):
    n = tf.convert_to_tensor(n, name="n")
    static_n = tf.get_static_value(n)
    n = int(static_n) if static_n is not None else n
    cat_samples = self.cat.sample(n, seed=seed)

    static_samples_shape = cat_samples.shape
    if tensorshape_util.is_fully_defined(static_samples_shape):
      samples_shape = tensorshape_util.as_list(static_samples_shape)
      samples_size = tensorshape_util.num_elements(static_samples_shape)
    else:
      samples_shape = tf.shape(cat_samples)
      samples_size = tf.size(cat_samples)
    static_batch_shape = self.batch_shape
    if tensorshape_util.is_fully_defined(static_batch_shape):
      batch_shape = tensorshape_util.as_list(static_batch_shape)
      batch_size = tensorshape_util.num_elements(static_batch_shape)
    else:
      batch_shape = self.batch_shape_tensor()
      batch_size = tf.reduce_prod(batch_shape)
    static_event_shape = self.event_shape
    if tensorshape_util.is_fully_defined(static_event_shape):
      event_shape = np.array(
          tensorshape_util.as_list(static_event_shape), dtype=np.int32)
    else:
      event_shape = self.event_shape_tensor()

    # Get indices into the raw cat sampling tensor. We will
    # need these to stitch sample values back out after sampling
    # within the component partitions.
    samples_raw_indices = tf.reshape(
        tf.range(0, samples_size), samples_shape)

    # Partition the raw indices so that we can use
    # dynamic_stitch later to reconstruct the samples from the
    # known partitions.
    partitioned_samples_indices = tf.dynamic_partition(
        data=samples_raw_indices,
        partitions=cat_samples,
        num_partitions=self.num_components)

    # Copy the batch indices n times, as we will need to know
    # these to pull out the appropriate rows within the
    # component partitions.
    batch_raw_indices = tf.reshape(
        tf.tile(tf.range(0, batch_size), [n]), samples_shape)

    # Explanation of the dynamic partitioning below:
    #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
    # Suppose partitions are:
    #     [1 1 0 0 1 1]
    # After partitioning, batch indices are cut as:
    #     [batch_indices[x] for x in 2, 3]
    #     [batch_indices[x] for x in 0, 1, 4, 5]
    # i.e.
    #     [1 1] and [0 0 0 0]
    # Now we sample n=2 from part 0 and n=4 from part 1.
    # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
    # and for part 1 we want samples from batch entries 0, 0, 0, 0
    # (samples 0, 1, 2, 3).
    partitioned_batch_indices = tf.dynamic_partition(
        data=batch_raw_indices,
        partitions=cat_samples,
        num_partitions=self.num_components)
    samples_class = [None for _ in range(self.num_components)]

    stream = SeedStream(seed, salt="Mixture")

    for c in range(self.num_components):
      n_class = tf.size(partitioned_samples_indices[c])
      samples_class_c = self.components[c].sample(n_class, seed=stream())

      # Pull out the correct batch entries from each index.
      # To do this, we may have to flatten the batch shape.

      # For sample s, batch element b of component c, we get the
      # partitioned batch indices from
      # partitioned_batch_indices[c]; and shift each element by
      # the sample index. The final lookup can be thought of as
      # a matrix gather along locations (s, b) in
      # samples_class_c where the n_class rows correspond to
      # samples within this component and the batch_size columns
      # correspond to batch elements within the component.
      #
      # Thus the lookup index is
      #   lookup[c, i] = batch_size * s[i] + b[c, i]
      # for i = 0 ... n_class[c] - 1.
      lookup_partitioned_batch_indices = (
          batch_size * tf.range(n_class) + partitioned_batch_indices[c])
      samples_class_c = tf.reshape(
          samples_class_c,
          tf.concat([[n_class * batch_size], event_shape], 0))
      samples_class_c = tf.gather(
          samples_class_c,
          lookup_partitioned_batch_indices,
          name="samples_class_c_gather")
      samples_class[c] = samples_class_c

    # Stitch back together the samples across the components.
    lhs_flat_ret = tf.dynamic_stitch(
        indices=partitioned_samples_indices, data=samples_class)
    # Reshape back to proper sample, batch, and event shape.
    ret = tf.reshape(
        lhs_flat_ret,
        tf.concat([samples_shape, self.event_shape_tensor()], 0))
    tensorshape_util.set_shape(
        ret,
        tensorshape_util.concatenate(static_samples_shape, self.event_shape))
    return ret
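
# Editor's illustrative sketch (not library code): the partition/stitch round
# trip used by the sampler above, with the numbers from the explanatory
# comment. Assumes standard TensorFlow is importable as `tensorflow`.
def _demo_dynamic_partition_stitch():
  """Sketch: partition flat sample indices by component, transform each
  partition, then `dynamic_stitch` results back into the original order."""
  import tensorflow as tf
  cat_samples = tf.constant([1, 1, 0, 0, 1, 1])   # component chosen per draw
  raw_indices = tf.range(tf.size(cat_samples))    # [0, 1, 2, 3, 4, 5]
  parts = tf.dynamic_partition(
      data=raw_indices, partitions=cat_samples, num_partitions=2)
  # parts[0] == [2, 3], parts[1] == [0, 1, 4, 5]
  per_component_values = [10 * p for p in parts]  # stand-in for per-component samples
  stitched = tf.dynamic_stitch(indices=parts, data=per_component_values)
  # stitched == [0, 10, 20, 30, 40, 50], back in the original draw order.
  return stitched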
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
  """The inverse LU decomposition, `X == lu_reconstruct(*tf.linalg.lu(X))`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_reconstruct').

  Returns:
    x: The original input to `tf.linalg.lu`, i.e., `x` as in,
      `lu_reconstruct(*tf.linalg.lu(x))`.

  #### Examples

  ```python
  import numpy as np
  from tensorflow_probability.python.internal.backend import jax as tf
  import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.jax

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  x_reconstructed = tfp.math.lu_reconstruct(*tf.linalg.lu(x))
  tf.assert_near(x, x_reconstructed)
  # ==> True
  ```

  """
  with tf.name_scope(name or 'lu_reconstruct'):
    lower_upper = tf.convert_to_tensor(lower_upper,
                                       dtype_hint=tf.float32,
                                       name='lower_upper')
    perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')

    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
    if assertions:
      with tf.control_dependencies(assertions):
        lower_upper = tf.identity(lower_upper)
        perm = tf.identity(perm)

    shape = tf.shape(lower_upper)

    lower = tf.linalg.set_diag(
        tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
        tf.ones(shape[:-1], dtype=lower_upper.dtype))
    upper = tf.linalg.band_part(lower_upper, num_lower=0, num_upper=-1)
    x = tf.matmul(lower, upper)

    if (tensorshape_util.rank(lower_upper.shape) is None or
        tensorshape_util.rank(lower_upper.shape) != 2):
      # We either don't know the batch rank or there are >0 batch dims.
      batch_size = tf.reduce_prod(shape[:-2])
      d = shape[-1]
      x = tf.reshape(x, [batch_size, d, d])
      perm = tf.reshape(perm, [batch_size, d])
      perm = tf.map_fn(tf.math.invert_permutation, perm)
      batch_indices = tf.broadcast_to(
          tf.range(batch_size)[:, tf.newaxis], [batch_size, d])
      x = tf.gather_nd(x, tf.stack([batch_indices, perm], axis=-1))
      x = tf.reshape(x, shape)
    else:
      x = tf.gather(x, tf.math.invert_permutation(perm))

    x.set_shape(lower_upper.shape)
    return x
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
  """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
      `A X = RHS`. To handle vector cases, use:
      `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix
      is actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  from tensorflow_probability.python.internal.backend import jax as tf
  import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.jax

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tfp.math.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.assert_near(tf.matrix_inverse(x), inv_x)
  # ==> True
  ```

  """
  with tf.name_scope(name or 'lu_solve'):
    lower_upper = tf.convert_to_tensor(lower_upper,
                                       dtype_hint=tf.float32,
                                       name='lower_upper')
    perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
    rhs = tf.convert_to_tensor(rhs, dtype_hint=lower_upper.dtype, name='rhs')

    assertions = _lu_solve_assertions(lower_upper, perm, rhs, validate_args)
    if assertions:
      with tf.control_dependencies(assertions):
        lower_upper = tf.identity(lower_upper)
        perm = tf.identity(perm)
        rhs = tf.identity(rhs)

    if (tensorshape_util.rank(rhs.shape) == 2 and
        tensorshape_util.rank(perm.shape) == 1):
      # Both rhs and perm have scalar batch_shape.
      permuted_rhs = tf.gather(rhs, perm, axis=-2)
    else:
      # Either rhs or perm have non-scalar batch_shape or we can't determine
      # this information statically.
      rhs_shape = tf.shape(rhs)
      broadcast_batch_shape = tf.broadcast_dynamic_shape(
          rhs_shape[:-2],
          tf.shape(perm)[:-1])
      d, m = rhs_shape[-2], rhs_shape[-1]
      rhs_broadcast_shape = tf.concat([broadcast_batch_shape, [d, m]], axis=0)

      # Tile out rhs.
      broadcast_rhs = tf.broadcast_to(rhs, rhs_broadcast_shape)
      broadcast_rhs = tf.reshape(broadcast_rhs, [-1, d, m])

      # Tile out perm and add batch indices.
      broadcast_perm = tf.broadcast_to(perm, rhs_broadcast_shape[:-1])
      broadcast_perm = tf.reshape(broadcast_perm, [-1, d])
      broadcast_batch_size = tf.reduce_prod(broadcast_batch_shape)
      broadcast_batch_indices = tf.broadcast_to(
          tf.range(broadcast_batch_size)[:, tf.newaxis],
          [broadcast_batch_size, d])
      broadcast_perm = tf.stack([broadcast_batch_indices, broadcast_perm],
                                axis=-1)

      permuted_rhs = tf.gather_nd(broadcast_rhs, broadcast_perm)
      permuted_rhs = tf.reshape(permuted_rhs, rhs_broadcast_shape)

    lower = tf.linalg.set_diag(
        tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
        tf.ones(tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
    return linear_operator_util.matrix_triangular_solve_with_broadcast(
        lower_upper,  # Only upper is accessed.
        linear_operator_util.matrix_triangular_solve_with_broadcast(
            lower, permuted_rhs),
        lower=False)
def _sample_n(self, n, seed=None):
  sample_and_batch_shape = tf.concat([[n], self.batch_shape_tensor()], 0)
  flat_batch_and_sample_shape = tf.stack(
      [tf.reduce_prod(self.batch_shape_tensor()), n])

  # In order to be reparameterizable we sample on the truncated_normal of
  # unit variance and mean and scale (but with the standardized
  # truncation bounds).

  @tf.custom_gradient
  def _std_samples_with_gradients(lower, upper):
    """Standard truncated Normal with gradient support for low, high."""
    # Note: Unlike the convention in tf_probability,
    # parameterized_truncated_normal returns a tensor with the final dimension
    # being the sample dimension.
    std_samples = random_ops.parameterized_truncated_normal(
        shape=flat_batch_and_sample_shape,
        means=0.0,
        stddevs=1.0,
        minvals=lower,
        maxvals=upper,
        dtype=self.dtype,
        seed=seed)

    def grad(dy):
      """Computes a derivative for the min and max parameters.

      This function implements the derivative wrt the truncation bounds,
      which get blocked by the sampler. We use a custom expression for
      numerical stability instead of automatic differentiation on CDF for
      implicit gradients.

      Args:
        dy: output gradients

      Returns:
        The standard normal samples and the gradients wrt the upper bound and
        lower bound.
      """
      # std_samples has an extra dimension (the sample dimension), expand
      # lower and upper so they broadcast along this dimension.
      # See note above regarding parameterized_truncated_normal, the sample
      # dimension is the final dimension.
      lower_broadcast = lower[..., tf.newaxis]
      upper_broadcast = upper[..., tf.newaxis]

      cdf_samples = ((special_math.ndtr(std_samples) -
                      special_math.ndtr(lower_broadcast)) /
                     (special_math.ndtr(upper_broadcast) -
                      special_math.ndtr(lower_broadcast)))

      # tiny, eps are tolerance parameters to ensure we stay away from giving
      # a zero arg to the log CDF expression.
      tiny = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny
      eps = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).eps
      cdf_samples = tf.clip_by_value(cdf_samples, tiny, 1 - eps)

      du = tf.exp(0.5 * (std_samples**2 - upper_broadcast**2) +
                  tf.math.log(cdf_samples))
      dl = tf.exp(0.5 * (std_samples**2 - lower_broadcast**2) +
                  tf.math.log1p(-cdf_samples))

      # Reduce the gradient across the samples
      grad_u = tf.reduce_sum(dy * du, axis=-1)
      grad_l = tf.reduce_sum(dy * dl, axis=-1)
      return [grad_l, grad_u]

    return std_samples, grad

  std_samples = _std_samples_with_gradients(
      tf.reshape(self._standardized_low, [-1]),
      tf.reshape(self._standardized_high, [-1]))

  # The returned shape is [flat_batch x n]
  std_samples = tf.transpose(a=std_samples, perm=[1, 0])

  std_samples = tf.reshape(std_samples, sample_and_batch_shape)
  samples = (std_samples * tf.expand_dims(self._scale, axis=0) +
             tf.expand_dims(self._loc, axis=0))

  return samples
def _log_prob(self, x):
  if self.input_output_cholesky:
    x_sqrt = x
  else:
    # Complexity: O(nbk**3)
    x_sqrt = tf.linalg.cholesky(x)

  batch_shape = self.batch_shape_tensor()
  event_shape = self.event_shape_tensor()
  x_ndims = tf.rank(x_sqrt)
  num_singleton_axes_to_prepend = (
      tf.maximum(tf.size(batch_shape) + 2, x_ndims) - x_ndims)
  x_with_prepended_singletons_shape = tf.concat([
      tf.ones([num_singleton_axes_to_prepend], dtype=tf.int32),
      tf.shape(x_sqrt)
  ], 0)
  x_sqrt = tf.reshape(x_sqrt, x_with_prepended_singletons_shape)
  ndims = tf.rank(x_sqrt)
  # sample_ndims = ndims - batch_ndims - event_ndims
  sample_ndims = ndims - tf.size(batch_shape) - 2
  sample_shape = tf.shape(x_sqrt)[:sample_ndims]

  # We need to be able to pre-multiply each matrix by its corresponding
  # batch scale matrix. Since a Distribution Tensor supports multiple
  # samples per batch, this means we need to reshape the input matrix `x`
  # so that the first b dimensions are batch dimensions and the last two
  # are of shape [dimension, dimensions*number_of_samples]. Doing these
  # gymnastics allows us to do a batch_solve.
  #
  # After we're done with sqrt_solve (the batch operation) we need to undo
  # this reshaping so what we're left with is a Tensor partitionable by
  # sample, batch, event dimensions.

  # Complexity: O(nbk**2) since transpose must access every element.
  scale_sqrt_inv_x_sqrt = x_sqrt
  perm = tf.concat([tf.range(sample_ndims, ndims),
                    tf.range(0, sample_ndims)], 0)
  scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt, perm=perm)
  last_dim_size = (
      tf.cast(self.dimension, dtype=tf.int32) *
      tf.reduce_prod(x_with_prepended_singletons_shape[:sample_ndims]))
  shape = tf.concat(
      [x_with_prepended_singletons_shape[sample_ndims:-2],
       [tf.cast(self.dimension, dtype=tf.int32), last_dim_size]],
      axis=0)
  scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)

  # Complexity: O(nbM*k) where M is the complexity of the operator solving a
  # vector system. For LinearOperatorLowerTriangular, each solve is O(k**2) so
  # this step has complexity O(nbk^3).
  scale_sqrt_inv_x_sqrt = self.scale_operator.solve(scale_sqrt_inv_x_sqrt)

  # Undo make batch-op ready.
  # Complexity: O(nbk**2)
  shape = tf.concat(
      [tf.shape(scale_sqrt_inv_x_sqrt)[:-2], event_shape, sample_shape],
      axis=0)
  scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)
  perm = tf.concat([
      tf.range(ndims - sample_ndims, ndims),
      tf.range(0, ndims - sample_ndims)
  ], 0)
  scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt, perm=perm)

  # Write V = SS', X = LL'. Then:
  # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
  #              = tr[inv(S) L L' inv(S)']
  #              = tr[(inv(S) L) (inv(S) L)']
  #              = sum_{ik} (inv(S) L)_{ik}**2
  # The second equality follows from the cyclic permutation property.
  # Complexity: O(nbk**2)
  trace_scale_inv_x = tf.reduce_sum(
      tf.square(scale_sqrt_inv_x_sqrt), axis=[-2, -1])

  # Complexity: O(nbk)
  half_log_det_x = tf.reduce_sum(
      tf.math.log(tf.linalg.diag_part(x_sqrt)), axis=[-1])

  # Complexity: O(nbk**2)
  log_prob = ((self.df - self.dimension - 1.) * half_log_det_x -
              0.5 * trace_scale_inv_x -
              self.log_normalization())

  # Set shape hints.
  # Try to merge what we know from the input x with what we know from the
  # parameters of this distribution.
  if tensorshape_util.rank(x.shape) is not None and tensorshape_util.rank(
      self.batch_shape) is not None:
    tensorshape_util.set_shape(
        log_prob, tf.broadcast_static_shape(x.shape[:-2], self.batch_shape))

  return log_prob
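
# Editor's illustrative sketch (not library code): the trace identity used in
# `_log_prob`, checked numerically with NumPy on a random SPD pair.
def _demo_wishart_trace_identity():
  """Sketch: with V = S S' and X = L L' (Cholesky factors),
  tr[inv(V) X] equals the squared Frobenius norm of inv(S) L, which is what
  the triangular solve above computes."""
  import numpy as np
  rng = np.random.RandomState(0)
  a = rng.normal(size=(3, 3))
  b = rng.normal(size=(3, 3))
  v = a @ a.T + 3.0 * np.eye(3)   # SPD scale matrix
  x = b @ b.T + 3.0 * np.eye(3)   # SPD observation
  s = np.linalg.cholesky(v)
  l = np.linalg.cholesky(x)
  lhs = np.trace(np.linalg.inv(v) @ x)
  rhs = np.sum(np.linalg.solve(s, l)**2)
  assert np.isclose(lhs, rhs)
  return lhs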
def __init__(self,
             mixture_distribution,
             components_distribution,
             reparameterize=False,
             validate_args=False,
             allow_nan_stats=True,
             name="MixtureSameFamily"):
  """Construct a `MixtureSameFamily` distribution.

  Args:
    mixture_distribution: `tfp.distributions.Categorical`-like instance.
      Manages the probability of selecting components. The number of
      categories must match the rightmost batch dimension of the
      `components_distribution`. Must have either scalar `batch_shape` or
      `batch_shape` matching `components_distribution.batch_shape[:-1]`.
    components_distribution: `tfp.distributions.Distribution`-like instance.
      Right-most batch dimension indexes components.
    reparameterize: Python `bool`, default `False`. Whether to reparameterize
      samples of the distribution using implicit reparameterization gradients
      [(Figurnov et al., 2018)][1]. The gradients for the mixture logits are
      equivalent to the ones described by [(Graves, 2016)][2]. The gradients
      for the components parameters are also computed using implicit
      reparameterization (as opposed to ancestral sampling), meaning that all
      components are updated every step.
      Only works when:
        (1) components_distribution is fully reparameterized;
        (2) components_distribution is either a scalar distribution or fully
            factorized (tfd.Independent applied to a scalar distribution);
        (3) batch shape has a known rank.
      Experimental, may be slow and produce infs/NaNs.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: `if not dtype_util.is_integer(mixture_distribution.dtype)`.
    ValueError: if mixture_distribution does not have scalar `event_shape`.
    ValueError: if `mixture_distribution.batch_shape` and
      `components_distribution.batch_shape[:-1]` are both fully defined and
      the former is neither scalar nor equal to the latter.
    ValueError: if `mixture_distribution` categories does not equal
      `components_distribution` rightmost batch shape.

  #### References

  [1]: Michael Figurnov, Shakir Mohamed and Andriy Mnih. Implicit
       reparameterization gradients. In _Neural Information Processing
       Systems_, 2018. https://arxiv.org/abs/1805.08498
  [2]: Alex Graves. Stochastic Backpropagation through Mixture Density
       Distributions. _arXiv_, 2016. https://arxiv.org/abs/1607.05690
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    self._mixture_distribution = mixture_distribution
    self._components_distribution = components_distribution
    self._runtime_assertions = []

    s = components_distribution.event_shape_tensor()
    self._event_ndims = tf.compat.dimension_value(s.shape[0])
    if self._event_ndims is None:
      self._event_ndims = tf.size(s)
    self._event_size = tf.reduce_prod(s)

    if not dtype_util.is_integer(mixture_distribution.dtype):
      raise ValueError(
          "`mixture_distribution.dtype` ({}) is not over integers".
          format(dtype_util.name(mixture_distribution.dtype)))

    if (tensorshape_util.rank(mixture_distribution.event_shape) is not None
        and tensorshape_util.rank(mixture_distribution.event_shape) != 0):
      raise ValueError("`mixture_distribution` must have scalar `event_dim`s")
    elif validate_args:
      self._runtime_assertions += [
          assert_util.assert_equal(
              tf.size(mixture_distribution.event_shape_tensor()),
              0,
              message="`mixture_distribution` must have scalar `event_dim`s"),
      ]

    mdbs = mixture_distribution.batch_shape
    cdbs = tensorshape_util.with_rank_at_least(
        components_distribution.batch_shape, 1)[:-1]
    if tensorshape_util.is_fully_defined(
        mdbs) and tensorshape_util.is_fully_defined(cdbs):
      if tensorshape_util.rank(mdbs) != 0 and mdbs != cdbs:
        raise ValueError(
            "`mixture_distribution.batch_shape` (`{}`) is not "
            "compatible with `components_distribution.batch_shape` "
            "(`{}`)".format(tensorshape_util.as_list(mdbs),
                            tensorshape_util.as_list(cdbs)))
    elif validate_args:
      mdbs = mixture_distribution.batch_shape_tensor()
      cdbs = components_distribution.batch_shape_tensor()[:-1]
      self._runtime_assertions += [
          assert_util.assert_equal(
              distribution_utils.pick_vector(
                  mixture_distribution.is_scalar_batch(), cdbs, mdbs),
              cdbs,
              message=("`mixture_distribution.batch_shape` is not "
                       "compatible with "
                       "`components_distribution.batch_shape`"))
      ]

    mixture_dist_param = (mixture_distribution.probs
                          if mixture_distribution.logits is None
                          else mixture_distribution.logits)
    km = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(mixture_dist_param.shape, 1)[-1])
    kc = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(
            components_distribution.batch_shape, 1)[-1])
    if km is not None and kc is not None and km != kc:
      raise ValueError("`mixture_distribution components` ({}) does not "
                       "equal `components_distribution.batch_shape[-1]` "
                       "({})".format(km, kc))
    elif validate_args:
      km = tf.shape(mixture_dist_param)[-1]
      kc = components_distribution.batch_shape_tensor()[-1]
      self._runtime_assertions += [
          assert_util.assert_equal(
              km,
              kc,
              message=("`mixture_distribution components` does not equal "
                       "`components_distribution.batch_shape[-1:]`")),
      ]
    elif km is None:
      km = tf.shape(mixture_dist_param)[-1]

    self._num_components = km

    self._reparameterize = reparameterize
    if reparameterize:
      # Note: tfd.Independent passes through the reparameterization type hence
      # we do not need separate logic for Independent.
      if (self._components_distribution.reparameterization_type !=
          reparameterization.FULLY_REPARAMETERIZED):
        raise ValueError("Cannot reparameterize a mixture of "
                         "non-reparameterized components.")
      reparameterization_type = reparameterization.FULLY_REPARAMETERIZED
    else:
      reparameterization_type = reparameterization.NOT_REPARAMETERIZED

    super(MixtureSameFamily, self).__init__(
        dtype=self._components_distribution.dtype,
        reparameterization_type=reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
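
# Editor's illustrative sketch (not library code): constructing a
# `MixtureSameFamily`. Assumes TensorFlow Probability is installed and
# importable as `tensorflow_probability`.
def _demo_mixture_same_family_usage():
  """Sketch: a two-component Gaussian mixture; `Categorical` supplies the
  mixture weights and the rightmost batch dimension of `Normal` indexes the
  components, as required by the constructor above."""
  import tensorflow_probability as tfp
  tfd = tfp.distributions
  gm = tfd.MixtureSameFamily(
      mixture_distribution=tfd.Categorical(probs=[0.3, 0.7]),
      components_distribution=tfd.Normal(loc=[-1., 1.], scale=[0.1, 0.5]))
  return gm.mean()  # Expected: 0.3 * (-1.) + 0.7 * 1. == 0.4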
def _sample_n(self, n, seed=None):
  with tf.control_dependencies(self._runtime_assertions):
    strm = SeedStream(seed, salt="HiddenMarkovModel")

    num_states = self._num_states

    batch_shape = self.batch_shape_tensor()
    batch_size = tf.reduce_prod(batch_shape)

    # The batch sizes of the underlying initial distributions and
    # transition distributions might not match the batch size of
    # the HMM distribution.
    # As a result we need to ask for more samples from the
    # underlying distributions and then reshape the results into
    # the correct batch size for the HMM.
    init_repeat = (
        tf.reduce_prod(self.batch_shape_tensor()) //
        tf.reduce_prod(self._initial_distribution.batch_shape_tensor()))
    init_state = self._initial_distribution.sample(n * init_repeat,
                                                   seed=strm())
    init_state = tf.reshape(init_state, [n, batch_size])
    # init_state :: n batch_size

    transition_repeat = (
        tf.reduce_prod(self.batch_shape_tensor()) //
        tf.reduce_prod(
            self._transition_distribution.batch_shape_tensor()[:-1]))

    def generate_step(state, _):
      """Take a single step in Markov chain."""

      gen = self._transition_distribution.sample(n * transition_repeat,
                                                 seed=strm())
      # gen :: (n * transition_repeat) transition_batch

      new_states = tf.reshape(gen, [n, batch_size, num_states])
      # new_states :: n batch_size num_states

      old_states_one_hot = tf.one_hot(state, num_states, dtype=tf.int32)
      # old_states :: n batch_size num_states

      return tf.reduce_sum(old_states_one_hot * new_states, axis=-1)

    def _scan_multiple_steps():
      """Take multiple steps with tf.scan."""
      dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
      if seed is not None:
        # Force parallel_iterations to 1 to ensure reproducibility
        # b/139210489
        hidden_states = tf.scan(generate_step,
                                dummy_index,
                                initializer=init_state,
                                parallel_iterations=1)
      else:
        # Invoke default parallel_iterations behavior
        hidden_states = tf.scan(generate_step,
                                dummy_index,
                                initializer=init_state)

      # TODO(b/115618503): add/use prepend_initializer to tf.scan
      return tf.concat([[init_state], hidden_states], axis=0)

    hidden_states = prefer_static.cond(
        self._num_steps > 1,
        _scan_multiple_steps,
        lambda: init_state[tf.newaxis, ...])

    hidden_one_hot = tf.one_hot(hidden_states,
                                num_states,
                                dtype=self._observation_distribution.dtype)
    # hidden_one_hot :: num_steps n batch_size num_states

    # The observation distribution batch size might not match
    # the required batch size so as with the initial and
    # transition distributions we generate more samples and
    # reshape.
    observation_repeat = (
        batch_size //
        tf.reduce_prod(
            self._observation_distribution.batch_shape_tensor()[:-1]))

    possible_observations = self._observation_distribution.sample(
        [self._num_steps, observation_repeat * n], seed=strm())

    inner_shape = self._observation_distribution.event_shape

    # possible_observations :: num_steps (observation_repeat * n)
    #                          observation_batch[:-1] num_states inner_shape

    possible_observations = tf.reshape(
        possible_observations,
        tf.concat([[self._num_steps, n],
                   batch_shape,
                   [num_states],
                   inner_shape],
                  axis=0))
    # possible_observations :: steps n batch_size num_states inner_shape

    hidden_one_hot = tf.reshape(
        hidden_one_hot,
        tf.concat([[self._num_steps, n],
                   batch_shape,
                   [num_states],
                   tf.ones_like(inner_shape)],
                  axis=0))
    # hidden_one_hot :: steps n batch_size num_states "inner_shape"

    observations = tf.reduce_sum(hidden_one_hot * possible_observations,
                                 axis=-1 - tf.size(inner_shape))
    # observations :: steps n batch_size inner_shape

    observations = distribution_util.move_dimension(
        observations, 0, 1 + tf.size(batch_shape))
    # returned :: n batch_shape steps inner_shape

    return observations
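
# Editor's illustrative sketch (not library code): the one-hot selection used
# in `generate_step` and for picking per-state observations above, NumPy only.
def _demo_one_hot_select():
  """Sketch: reduce_sum(one_hot(state) * table, axis=-1) picks, for each row,
  the column indexed by `state`, i.e. a gather expressed as a masked sum."""
  import numpy as np
  num_states = 3
  table = np.array([[10, 11, 12],
                    [20, 21, 22]])         # rows of per-state values
  state = np.array([2, 0])                 # selected state per row
  one_hot = np.eye(num_states, dtype=table.dtype)[state]
  selected = np.sum(one_hot * table, axis=-1)
  assert np.array_equal(selected, np.array([12, 20]))
  return selected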