def _log_prob(self, x, **kwargs):
  batch_ndims = prefer_static.rank_from_shape(
      self.distribution.batch_shape_tensor, self.distribution.batch_shape)
  extra_sample_ndims = prefer_static.rank_from_shape(self.sample_shape)
  event_ndims = prefer_static.rank_from_shape(
      self.distribution.event_shape_tensor, self.distribution.event_shape)
  ndims = prefer_static.rank(x)

  # (1) Expand x's dims.
  d = ndims - batch_ndims - extra_sample_ndims - event_ndims
  x = tf.reshape(x, shape=tf.pad(
      tf.shape(x),
      paddings=[[prefer_static.maximum(0, -d), 0]],
      constant_values=1))
  sample_ndims = prefer_static.maximum(0, d)

  # (2) Transpose x's dims.
  sample_dims = prefer_static.range(0, sample_ndims)
  batch_dims = prefer_static.range(sample_ndims, sample_ndims + batch_ndims)
  extra_sample_dims = prefer_static.range(
      sample_ndims + batch_ndims,
      sample_ndims + batch_ndims + extra_sample_ndims)
  event_dims = prefer_static.range(
      sample_ndims + batch_ndims + extra_sample_ndims, ndims)
  perm = prefer_static.concat(
      [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
  x = tf.transpose(a=x, perm=perm)

  # (3) Compute x's log_prob.
  lp = self.distribution.log_prob(x, **kwargs)

  # (4) Make the final reduction in x.
  axis = prefer_static.range(sample_ndims, sample_ndims + extra_sample_ndims)
  return tf.reduce_sum(lp, axis=axis)

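# A minimal sketch (not library code) of the permutation in `_log_prob` above:
# an input laid out as sample + batch + extra_sample + event dims is permuted
# so the extra sample dims sit next to the sample dims before evaluating the
# underlying distribution's log_prob. The shapes below are made up.
import tensorflow as tf

x = tf.zeros([7, 2, 5, 3])  # [sample=7, batch=2, extra_sample=5, event=3]
perm = [0, 2, 1, 3]         # -> [sample, extra_sample, batch, event]
x_t = tf.transpose(x, perm=perm)
print(x_t.shape)            # (7, 5, 2, 3)
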
def _maybe_rotate_dims(self, x, rotate_right=False):
  """Helper which rolls left event_dims left or right event_dims right."""
  needs_rotation_const = tf.get_static_value(self._needs_rotation)
  if needs_rotation_const is not None and not needs_rotation_const:
    return x
  ndims = prefer_static.rank(x)
  n = (ndims - self._rotate_ndims) if rotate_right else self._rotate_ndims
  perm = prefer_static.concat(
      [prefer_static.range(n, ndims), prefer_static.range(0, n)], axis=0)
  return tf.transpose(a=x, perm=perm)

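# A minimal sketch (not library code) of the rotation in `_maybe_rotate_dims`
# above: with rotate_right=False the leading `rotate_ndims` dims move to the
# back; with rotate_right=True the trailing dims move to the front. The
# tensor and `rotate_ndims` below are made up for illustration.
import tensorflow as tf

x = tf.zeros([2, 3, 4, 5])
ndims, rotate_ndims = 4, 2
perm_left = tf.concat(
    [tf.range(rotate_ndims, ndims), tf.range(0, rotate_ndims)], axis=0)
print(tf.transpose(x, perm=perm_left).shape)  # (4, 5, 2, 3)
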
def _sample_n(self, n, seed=None):
  samples = tf.convert_to_tensor(self._samples)
  indices = tf.random.uniform(
      [n], maxval=self._compute_num_samples(samples),
      dtype=tf.int32, seed=seed)
  draws = tf.gather(samples, indices, axis=self._samples_axis)
  axes = tf.concat(
      [[self._samples_axis],
       tf.range(self._samples_axis, dtype=tf.int32),
       tf.range(self._event_ndims, dtype=tf.int32) + self._samples_axis + 1],
      axis=0)
  draws = tf.transpose(a=draws, perm=axes)
  return draws

def _sample_n(self, n, seed=None):
  logits = self._logits_parameter_no_checks()
  logits_2d = tf.reshape(logits, [-1, self._num_categories(logits)])
  sample_dtype = tf.int64 if dtype_util.size(self.dtype) > 4 else tf.int32
  draws = tf.random.categorical(logits_2d, n, dtype=sample_dtype, seed=seed)
  draws = tf.cast(draws, self.dtype)
  return tf.reshape(
      tf.transpose(draws),
      shape=tf.concat([[n], self._batch_shape_tensor(logits)], axis=0))

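# A minimal sketch (not library code) of the flatten/sample/transpose pattern
# in `_sample_n` above: logits are flattened to 2-D, `tf.random.categorical`
# draws n samples per row, and the transpose puts the sample dim first. The
# logits below are made up for illustration.
import tensorflow as tf

n = 4
logits = tf.math.log([[0.2, 0.8], [0.5, 0.5], [0.9, 0.1]])  # [batch=3, classes=2]
draws = tf.random.categorical(logits, n)  # [batch, n]
draws = tf.transpose(draws)               # [n, batch]
print(draws.shape)                        # (4, 3)
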
def _sample_n(self, n, seed=None):
  logits = self._logits_parameter_no_checks()
  sample_shape = prefer_static.concat([[n], prefer_static.shape(logits)], 0)
  event_size = self._event_size(logits)
  if tensorshape_util.rank(logits.shape) == 2:
    logits_2d = logits
  else:
    logits_2d = tf.reshape(logits, [-1, event_size])
  samples = tf.random.categorical(logits_2d, n, seed=seed)
  samples = tf.transpose(a=samples)
  samples = tf.one_hot(samples, event_size, dtype=self.dtype)
  ret = tf.reshape(samples, sample_shape)
  return ret

def _compute_quantiles():
  """Helper to build quantiles."""
  # Omit {0, 1} since they might lead to Inf/NaN.
  zero = tf.zeros([], dtype=dist.dtype)
  edges = tf.linspace(zero, 1., quadrature_size + 3)[1:-1]
  # Expand edges so they broadcast across batch dims.
  edges = tf.reshape(
      edges,
      shape=tf.concat([[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0))
  quantiles = dist.quantile(edges)
  # Cyclically permute left by one.
  perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
  quantiles = tf.transpose(a=quantiles, perm=perm)
  return quantiles

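# A minimal sketch (not library code) of the cyclic permutation in
# `_compute_quantiles` above: the leading quadrature axis is moved to the end
# so quantiles index the last dimension. The shapes below are made up.
import tensorflow as tf

batch_ndims = 2
quantiles = tf.zeros([9, 4, 6])  # [quadrature, batch0, batch1]
perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
print(tf.transpose(quantiles, perm=perm).shape)  # (4, 6, 9)
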
def _sample_n(self, n, seed, **kwargs):
  fake_sample_ndims = prefer_static.rank_from_shape(self.sample_shape)
  event_ndims = prefer_static.rank_from_shape(
      self.distribution.event_shape_tensor, self.distribution.event_shape)
  batch_ndims = prefer_static.rank_from_shape(
      self.distribution.batch_shape_tensor, self.distribution.batch_shape)
  perm = prefer_static.concat([
      [0],
      prefer_static.range(1 + fake_sample_ndims,
                          1 + fake_sample_ndims + batch_ndims),
      prefer_static.range(1, 1 + fake_sample_ndims),
      prefer_static.range(
          1 + fake_sample_ndims + batch_ndims,
          1 + fake_sample_ndims + batch_ndims + event_ndims),
  ], axis=0)
  x = self.distribution.sample(
      prefer_static.concat([[n], self.sample_shape], axis=0),
      seed=seed, **kwargs)
  return tf.transpose(a=x, perm=perm)

def _transpose(self, x, perm):
  perm = self._make_perm(tf.rank(x), perm)
  return tf.transpose(a=x, perm=perm)

def _sample_n(self, n, seed=None):
  sample_and_batch_shape = tf.concat([[n], self.batch_shape_tensor()], 0)
  flat_batch_and_sample_shape = tf.stack(
      [tf.reduce_prod(self.batch_shape_tensor()), n])

  # In order to be reparameterizable we sample from a truncated normal with
  # zero mean and unit variance (but with the standardized truncation
  # bounds) and then scale and shift.
  @tf.custom_gradient
  def _std_samples_with_gradients(lower, upper):
    """Standard truncated Normal with gradient support for low, high."""
    # Note: Unlike the convention in tf_probability,
    # parameterized_truncated_normal returns a tensor with the final
    # dimension being the sample dimension.
    std_samples = random_ops.parameterized_truncated_normal(
        shape=flat_batch_and_sample_shape,
        means=0.0,
        stddevs=1.0,
        minvals=lower,
        maxvals=upper,
        dtype=self.dtype,
        seed=seed)

    def grad(dy):
      """Computes a derivative for the min and max parameters.

      This function implements the derivative wrt the truncation bounds,
      which get blocked by the sampler. We use a custom expression for
      numerical stability instead of automatic differentiation on CDF for
      implicit gradients.

      Args:
        dy: output gradients

      Returns:
        The gradients wrt the lower bound and the upper bound.
      """
      # std_samples has an extra dimension (the sample dimension), expand
      # lower and upper so they broadcast along this dimension.
      # See note above regarding parameterized_truncated_normal, the sample
      # dimension is the final dimension.
      lower_broadcast = lower[..., tf.newaxis]
      upper_broadcast = upper[..., tf.newaxis]

      cdf_samples = ((special_math.ndtr(std_samples) -
                      special_math.ndtr(lower_broadcast)) /
                     (special_math.ndtr(upper_broadcast) -
                      special_math.ndtr(lower_broadcast)))

      # tiny, eps are tolerance parameters to ensure we stay away from giving
      # a zero arg to the log CDF expression.
      tiny = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny
      eps = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).eps
      cdf_samples = tf.clip_by_value(cdf_samples, tiny, 1 - eps)

      du = tf.exp(0.5 * (std_samples**2 - upper_broadcast**2) +
                  tf.math.log(cdf_samples))
      dl = tf.exp(0.5 * (std_samples**2 - lower_broadcast**2) +
                  tf.math.log1p(-cdf_samples))

      # Reduce the gradient across the samples.
      grad_u = tf.reduce_sum(dy * du, axis=-1)
      grad_l = tf.reduce_sum(dy * dl, axis=-1)
      return [grad_l, grad_u]

    return std_samples, grad

  std_samples = _std_samples_with_gradients(
      tf.reshape(self._standardized_low, [-1]),
      tf.reshape(self._standardized_high, [-1]))

  # The returned shape is [flat_batch x n].
  std_samples = tf.transpose(a=std_samples, perm=[1, 0])

  std_samples = tf.reshape(std_samples, sample_and_batch_shape)
  samples = (std_samples * tf.expand_dims(self._scale, axis=0) +
             tf.expand_dims(self._loc, axis=0))
  return samples

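# A hedged usage sketch (not library code), assuming tfp.distributions
# .TruncatedNormal is available and fully reparameterized: it only checks
# that reparameterized samples propagate gradients to the truncation bounds,
# which is what the custom gradient above provides.
import tensorflow as tf
import tensorflow_probability as tfp

low = tf.Variable(-1.0)
high = tf.Variable(2.0)
with tf.GradientTape() as tape:
  dist = tfp.distributions.TruncatedNormal(loc=0., scale=1., low=low, high=high)
  loss = tf.reduce_mean(dist.sample(1000, seed=42))
grad_low, grad_high = tape.gradient(loss, [low, high])  # both non-None
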
def _log_prob(self, x):
  if self.input_output_cholesky:
    x_sqrt = x
  else:
    # Complexity: O(nbk**3)
    x_sqrt = tf.linalg.cholesky(x)

  batch_shape = self.batch_shape_tensor()
  event_shape = self.event_shape_tensor()
  x_ndims = tf.rank(x_sqrt)
  num_singleton_axes_to_prepend = (
      tf.maximum(tf.size(batch_shape) + 2, x_ndims) - x_ndims)
  x_with_prepended_singletons_shape = tf.concat(
      [tf.ones([num_singleton_axes_to_prepend], dtype=tf.int32),
       tf.shape(x_sqrt)], 0)
  x_sqrt = tf.reshape(x_sqrt, x_with_prepended_singletons_shape)
  ndims = tf.rank(x_sqrt)
  # sample_ndims = ndims - batch_ndims - event_ndims
  sample_ndims = ndims - tf.size(batch_shape) - 2
  sample_shape = tf.shape(x_sqrt)[:sample_ndims]

  # We need to be able to pre-multiply each matrix by its corresponding
  # batch scale matrix. Since a Distribution Tensor supports multiple
  # samples per batch, this means we need to reshape the input matrix `x`
  # so that the first b dimensions are batch dimensions and the last two
  # are of shape [dimension, dimension * number_of_samples]. Doing these
  # gymnastics allows us to do a batch_solve.
  #
  # After we're done with sqrt_solve (the batch operation) we need to undo
  # this reshaping so what we're left with is a Tensor partitionable by
  # sample, batch, event dimensions.

  # Complexity: O(nbk**2) since transpose must access every element.
  scale_sqrt_inv_x_sqrt = x_sqrt
  perm = tf.concat(
      [tf.range(sample_ndims, ndims), tf.range(0, sample_ndims)], 0)
  scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt, perm=perm)
  last_dim_size = (
      tf.cast(self.dimension, dtype=tf.int32) *
      tf.reduce_prod(x_with_prepended_singletons_shape[:sample_ndims]))
  shape = tf.concat(
      [x_with_prepended_singletons_shape[sample_ndims:-2],
       [tf.cast(self.dimension, dtype=tf.int32), last_dim_size]],
      axis=0)
  scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)

  # Complexity: O(nbM*k) where M is the complexity of the operator solving a
  # vector system. For LinearOperatorLowerTriangular, each solve is O(k**2)
  # so this step has complexity O(nbk^3).
  scale_sqrt_inv_x_sqrt = self.scale_operator.solve(scale_sqrt_inv_x_sqrt)

  # Undo make batch-op ready.
  # Complexity: O(nbk**2)
  shape = tf.concat(
      [tf.shape(scale_sqrt_inv_x_sqrt)[:-2], event_shape, sample_shape],
      axis=0)
  scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)
  perm = tf.concat(
      [tf.range(ndims - sample_ndims, ndims),
       tf.range(0, ndims - sample_ndims)], 0)
  scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt, perm=perm)

  # Write V = SS', X = LL'. Then:
  # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
  #              = tr[inv(S) L L' inv(S)']
  #              = tr[(inv(S) L) (inv(S) L)']
  #              = sum_{ik} (inv(S) L)_{ik}**2
  # The second equality follows from the cyclic permutation property.
  # Complexity: O(nbk**2)
  trace_scale_inv_x = tf.reduce_sum(
      tf.square(scale_sqrt_inv_x_sqrt), axis=[-2, -1])

  # Complexity: O(nbk)
  half_log_det_x = tf.reduce_sum(
      tf.math.log(tf.linalg.diag_part(x_sqrt)), axis=[-1])

  # Complexity: O(nbk**2)
  log_prob = ((self.df - self.dimension - 1.) * half_log_det_x -
              0.5 * trace_scale_inv_x -
              self.log_normalization())

  # Set shape hints.
  # Try to merge what we know from the input x with what we know from the
  # parameters of this distribution.
  if (tensorshape_util.rank(x.shape) is not None and
      tensorshape_util.rank(self.batch_shape) is not None):
    tensorshape_util.set_shape(
        log_prob, tf.broadcast_static_shape(x.shape[:-2], self.batch_shape))

  return log_prob

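# A minimal numeric check (not library code) of the trace identity used in
# the comment above: with V = S S' and X = L L', tr[inv(V) X] equals the
# squared Frobenius norm of inv(S) L. The matrices below are made up.
import numpy as np

rng = np.random.default_rng(0)
S = np.tril(rng.normal(size=(3, 3))) + 3. * np.eye(3)
L = np.tril(rng.normal(size=(3, 3))) + 3. * np.eye(3)
V, X = S @ S.T, L @ L.T
lhs = np.trace(np.linalg.solve(V, X))
rhs = np.sum(np.linalg.solve(S, L) ** 2)
assert np.allclose(lhs, rhs)
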
def _sample_n(self, n, seed):
  batch_shape = self.batch_shape_tensor()
  event_shape = self.event_shape_tensor()
  batch_ndims = tf.shape(batch_shape)[0]

  ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
  shape = tf.concat([[n], batch_shape, event_shape], 0)
  stream = SeedStream(seed, salt="Wishart")

  # Complexity: O(nbk**2)
  x = tf.random.normal(
      shape=shape, mean=0., stddev=1., dtype=self.dtype, seed=stream())

  # Complexity: O(nbk)
  # This parametrization is equivalent to Chi2, i.e.,
  # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
  expanded_df = self.df * tf.ones(
      self.scale_operator.batch_shape_tensor(),
      dtype=dtype_util.base_dtype(self.df.dtype))
  g = tf.random.gamma(
      shape=[n],
      alpha=self._multi_gamma_sequence(0.5 * expanded_df, self.dimension),
      beta=0.5,
      dtype=self.dtype,
      seed=stream())

  # Complexity: O(nbk**2)
  x = tf.linalg.band_part(x, -1, 0)  # Tri-lower.

  # Complexity: O(nbk)
  x = tf.linalg.set_diag(x, tf.sqrt(g))

  # Make batch-op ready.
  # Complexity: O(nbk**2)
  perm = tf.concat([tf.range(1, ndims), [0]], 0)
  x = tf.transpose(a=x, perm=perm)
  shape = tf.concat([batch_shape, [event_shape[0]], [event_shape[1] * n]], 0)
  x = tf.reshape(x, shape)

  # Complexity: O(nbM) where M is the complexity of the operator solving a
  # vector system. For LinearOperatorLowerTriangular, each matmul is O(k^3)
  # so this step has complexity O(nbk^3).
  x = self.scale_operator.matmul(x)

  # Undo make batch-op ready.
  # Complexity: O(nbk**2)
  shape = tf.concat([batch_shape, event_shape, [n]], 0)
  x = tf.reshape(x, shape)
  perm = tf.concat([[ndims - 1], tf.range(0, ndims - 1)], 0)
  x = tf.transpose(a=x, perm=perm)

  if not self.input_output_cholesky:
    # Complexity: O(nbk**3)
    x = tf.matmul(x, x, adjoint_b=True)

  return x

def draw_sample(num_samples, num_classes, logits, num_trials, dtype, seed):
  """Sample a multinomial.

  The batch shape is given by broadcasting num_trials with
  remove_last_dimension(logits).

  Args:
    num_samples: Python int or singleton integer Tensor: number of multinomial
      samples to draw.
    num_classes: Python int or singleton integer Tensor: number of classes.
    logits: Floating Tensor with last dimension k, of (unnormalized) logit
      probabilities per class.
    num_trials: Tensor of number of categorical trials each multinomial
      consists of.  num_trials[..., tf.newaxis] must broadcast with logits.
    dtype: dtype at which to emit samples.
    seed: Random seed.

  Returns:
    samples: Tensor of given dtype and shape [n] + batch_shape + [k].
  """
  with tf.name_scope('draw_sample'):
    # Broadcast num_trials and logits to the same (batch) shape.
    num_trials = tf.ones_like(
        logits[..., 0], dtype=num_trials.dtype) * num_trials
    logits = tf.ones_like(
        num_trials[..., tf.newaxis], dtype=logits.dtype) * logits

    # Flatten the total_count and logits.
    # flat_logits has shape [B1B2...Bm, num_classes].
    flat_logits = tf.reshape(logits, [-1, num_classes])
    flat_num_trials = num_samples * tf.reshape(num_trials, [-1])  # [B1B2...Bm]

    # Compute each logits/num_trials combination with map_fn.
    #
    # Using just one batch tf.random.categorical call doesn't work because
    # that requires num_trials to be the same across all members of the batch
    # of logits. This restriction makes sense for tf.random.categorical
    # because for it, num_trials is part of the returned shape. However, the
    # multinomial sampler does not need that restriction, because it sums out
    # exactly that dimension.
    #
    # One possibility would be to draw a batch categorical whose sample count
    # is max(num_trials) and mask out the excess ones. However, if the
    # elements of num_trials vary widely, this can be wasteful of memory.
    #
    # TODO(b/123763054, b/112152209): Revisit the possibility of writing this
    # with a batch categorical followed by batch unsorted_segment_sum, once
    # both of those work and are memory-efficient enough.
    def _sample_one_batch_member(args):
      logits, num_cat_samples = args[0], args[1]  # [K], []
      # x has shape [1, num_cat_samples = num_samples * num_trials]
      x = tf.random.categorical(
          logits[tf.newaxis, ...], num_cat_samples, seed=seed)
      x = tf.reshape(x, shape=[num_samples, -1])  # [num_samples, num_trials]
      x = tf.one_hot(
          x, depth=num_classes)  # [num_samples, num_trials, num_classes]
      x = tf.reduce_sum(x, axis=-2)  # [num_samples, num_classes]
      return tf.cast(x, dtype=dtype)

    if seed is not None:
      # Force parallel_iterations to 1 to ensure reproducibility (b/139210489).
      x = tf.map_fn(
          _sample_one_batch_member,
          [flat_logits, flat_num_trials],
          dtype=dtype,  # [B1B2...Bm, num_samples, num_classes]
          parallel_iterations=1)
    else:
      # Invoke default parallel_iterations behavior.
      x = tf.map_fn(
          _sample_one_batch_member,
          [flat_logits, flat_num_trials],
          dtype=dtype)  # [B1B2...Bm, num_samples, num_classes]

    # Reshape the results to the proper shape.
    x = tf.transpose(a=x, perm=[1, 0, 2])
    final_shape = tf.concat(
        [[num_samples], tf.shape(num_trials), [num_classes]], axis=0)
    x = tf.reshape(x, final_shape)

    return x

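# A hypothetical usage sketch of `draw_sample` above (argument values are
# made up; assumes eager TensorFlow 2.x): each slice samples[:, j, :] sums to
# num_trials[j] along the class axis.
import tensorflow as tf

logits = tf.math.log([[0.2, 0.3, 0.5],
                      [0.6, 0.3, 0.1]])  # batch of 2, 3 classes
num_trials = tf.constant([10, 20])
samples = draw_sample(num_samples=4, num_classes=3, logits=logits,
                      num_trials=num_trials, dtype=tf.float32, seed=17)
print(samples.shape)  # (4, 2, 3)
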