def _inverse(self, y):
  map_values = tf.convert_to_tensor(self.map_values)
  flat_y = tf.reshape(y, shape=[-1])
  # Search for the indices of map_values that are closest to flat_y.
  # Since map_values is strictly increasing, the closest is either the
  # first one that is strictly greater than flat_y, or the one before it.
  upper_candidates = tf.minimum(
      tf.size(map_values) - 1,
      tf.searchsorted(map_values, values=flat_y, side='right'))
  lower_candidates = tf.maximum(0, upper_candidates - 1)
  candidates = tf.stack([lower_candidates, upper_candidates], axis=-1)
  lower_cand_diff = tf.abs(flat_y - self._forward(lower_candidates))
  upper_cand_diff = tf.abs(flat_y - self._forward(upper_candidates))
  if self.validate_args:
    with tf.control_dependencies([
        assert_util.assert_near(
            tf.minimum(lower_cand_diff, upper_cand_diff),
            0,
            message='inverse value not found')
    ]):
      candidates = tf.identity(candidates)
  candidate_selector = tf.stack([
      tf.range(tf.size(flat_y), dtype=tf.int32),
      tf.argmin([lower_cand_diff, upper_cand_diff], output_type=tf.int32)
  ], axis=-1)
  return tf.reshape(tf.gather_nd(candidates, candidate_selector),
                    shape=y.shape)
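# A minimal, self-contained sketch (not part of this module; illustrative
# data, public TF ops only) of the "nearest value via searchsorted" idea used
# in `_inverse` above: `side='right'` finds the first strictly greater entry,
# and clamping plus a distance comparison picks the closer of the two
# neighboring candidates.
import tensorflow as tf

map_values = tf.constant([0.0, 1.0, 4.0, 9.0])  # strictly increasing
queries = tf.constant([3.9, 0.2, 9.0])

upper = tf.minimum(tf.size(map_values) - 1,
                   tf.searchsorted(map_values, queries, side='right'))
lower = tf.maximum(0, upper - 1)
# Pick whichever of the two candidate indices is closer to each query.
pick_upper = (tf.abs(queries - tf.gather(map_values, upper)) <
              tf.abs(queries - tf.gather(map_values, lower)))
nearest_idx = tf.where(pick_upper, upper, lower)
# nearest_idx == [2, 0, 3], i.e. values [4.0, 0.0, 9.0].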
def _cdf(self, k):
  # TODO(b/135263541): Improve numerical precision of categorical.cdf.
  probs = self.probs_parameter()
  num_categories = self._num_categories(probs)

  k, probs = _broadcast_cat_event_and_params(
      k, probs, base_dtype=dtype_util.base_dtype(self.dtype))

  # Since the lowest number in the support is 0, any k < 0 should be zero in
  # the output.
  should_be_zero = k < 0

  # Will use k as an index in the gather below, so clip it to {0,...,K-1}.
  k = tf.clip_by_value(tf.cast(k, tf.int32), 0, num_categories - 1)

  batch_shape = tf.shape(k)

  # tf.gather(..., batch_dims=batch_dims) requires static batch_dims kwarg, so
  # to handle the case where the batch shape is dynamic, flatten the batch
  # dims (so we know batch_dims=1).
  k_flat_batch = tf.reshape(k, [-1])
  probs_flat_batch = tf.reshape(
      probs, tf.concat(([-1], [num_categories]), axis=0))

  cdf_flat = tf.gather(
      tf.cumsum(probs_flat_batch, axis=-1),
      k_flat_batch[..., tf.newaxis],
      batch_dims=1)

  cdf = tf.reshape(cdf_flat, shape=batch_shape)

  zero = np.array(0, dtype=dtype_util.as_numpy_dtype(cdf.dtype))
  return tf.where(should_be_zero, zero, cdf)
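# A minimal sketch (illustrative numbers, public TF ops only) of the
# "cumulative sum + batched gather" step that `_cdf` relies on above.
import tensorflow as tf

probs = tf.constant([[0.1, 0.2, 0.7],
                     [0.5, 0.25, 0.25]])  # batch of two categoricals
k = tf.constant([[1], [2]])               # one k per batch member
cdf = tf.gather(tf.cumsum(probs, axis=-1), k, batch_dims=1)
# cdf ~= [[0.3], [1.0]] (up to float rounding).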
def _mode(self, samples=None):
  # Sample count can vary by batch member. Use map_fn to compute mode for
  # each batch separately.
  def _get_mode(samples):
    # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
    count = gen_array_ops.unique_with_counts_v2(samples, axis=[0]).count
    return tf.argmax(count)

  if samples is None:
    samples = tf.convert_to_tensor(self._samples)
  num_samples = self._compute_num_samples(samples)

  # Flatten samples for each batch.
  if self._event_ndims == 0:
    flattened_samples = tf.reshape(samples, [-1, num_samples])
    mode_shape = self._batch_shape_tensor(samples)
  else:
    event_size = tf.reduce_prod(self._event_shape_tensor(samples))
    mode_shape = tf.concat([
        self._batch_shape_tensor(samples),
        self._event_shape_tensor(samples)
    ], axis=0)
    flattened_samples = tf.reshape(samples, [-1, num_samples, event_size])

  indices = tf.map_fn(_get_mode, flattened_samples, dtype=tf.int64)
  full_indices = tf.stack(
      [tf.range(tf.shape(indices)[0]),
       tf.cast(indices, tf.int32)], axis=1)

  mode = tf.gather_nd(flattened_samples, full_indices)
  return tf.reshape(mode, mode_shape)
def _call_reshape_input_output(self, fn, x, extra_kwargs=None):
  """Calls `fn`, appropriately reshaping its input `x` and output."""
  # Note: we take `extra_kwargs` as a dict rather than `**extra_kwargs`
  # because the user-provided extra kwargs could themselves have `fn`
  # and/or `x` as a key.
  with tf.control_dependencies(self._runtime_assertions +
                               self._validate_sample_arg(x)):
    sample_shape, static_sample_shape = self._sample_shape(x)
    old_shape = tf.concat([
        sample_shape,
        self.distribution.batch_shape_tensor(),
        self.event_shape_tensor(),
    ], axis=0)
    x_reshape = tf.reshape(x, old_shape)
    result = fn(x_reshape, **extra_kwargs) if extra_kwargs else fn(x_reshape)
    new_shape = tf.concat([
        sample_shape,
        self._batch_shape_unexpanded,
    ], axis=0)
    result = tf.reshape(result, new_shape)
    if (tensorshape_util.rank(static_sample_shape) is not None and
        tensorshape_util.rank(self.batch_shape) is not None):
      new_shape = tensorshape_util.concatenate(
          static_sample_shape, self.batch_shape)
      tensorshape_util.set_shape(result, new_shape)
    return result
def _sparse_tensor_dense_matmul(sp_a, b, **kwargs):
  """Returns (batched) matmul of a SparseTensor with a Tensor.

  Args:
    sp_a: `SparseTensor` representing a (batch of) matrices.
    b: `Tensor` representing a (batch of) matrices, with the same batch shape
      as `sp_a`. The shape must be compatible with the shape of `sp_a` and
      kwargs.
    **kwargs: Keyword arguments to `tf.sparse.sparse_dense_matmul`.

  Returns:
    product: A dense (batch of) matrix-shaped Tensor of the same batch shape
      and dtype as `sp_a` and `b`. If `sp_a` or `b` is adjointed through
      `kwargs` then the shape is adjusted accordingly.
  """
  batch_shape = _get_shape(sp_a)[:-2]

  # Reshape the SparseTensor into a rank-3 SparseTensor, with the batch shape
  # flattened to a single dimension. If the batch rank is 0, then we add a
  # batch dimension of rank 1.
  sp_a = tf.sparse.reshape(sp_a,
                           tf.concat([[-1], _get_shape(sp_a)[-2:]], axis=0))
  # Reshape b to stack the batch dimension along the rows.
  b = tf.reshape(b, tf.concat([[-1], _get_shape(b)[-1:]], axis=0))

  # Convert the SparseTensor to a matrix in block diagonal form with blocks of
  # matrices [M, N]. This allows us to use tf.sparse.sparse_dense_matmul,
  # which only accepts rank-2 (Sparse)Tensors.
  out = tf.sparse.sparse_dense_matmul(_sparse_block_diag(sp_a), b, **kwargs)

  # Finally retrieve the original batch shape from the resulting rank-2 Tensor.
  # Note that we avoid inferring the final shape from `sp_a` or `b` because we
  # might have transposed one or both of them.
  return tf.reshape(
      out,
      tf.concat([batch_shape, [-1], _get_shape(out)[-1:]], axis=0))
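# A self-contained sketch of the block-diagonal trick documented above, using
# dense tensors and public TF ops only (the module-private `_sparse_block_diag`
# and `_get_shape` helpers are not reproduced here); values are illustrative.
import tensorflow as tf

a = tf.random.normal([2, 2, 3])  # a toy "batch" of two 2x3 matrices
b = tf.random.normal([2, 3, 1])  # two matching 3x1 right-hand sides

ref = tf.matmul(a, b)  # reference batched matmul, shape [2, 2, 1]

# Same product via block-diagonal form: put the two 2x3 blocks on the diagonal
# of a 4x6 matrix and stack the right-hand sides along the rows into 6x1.
zeros = tf.zeros([2, 3])
block_diag = tf.concat(
    [tf.concat([a[0], zeros], axis=1),
     tf.concat([zeros, a[1]], axis=1)], axis=0)  # [4, 6]
stacked_b = tf.reshape(b, [-1, 1])               # [6, 1]
out = tf.reshape(tf.matmul(block_diag, stacked_b), [2, -1, 1])
# `out` matches `ref` up to floating-point error.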
def _log_prob(self, x):
  logits = self._logits_parameter_no_checks()
  event_size = self._event_size(logits)

  x = tf.cast(x, logits.dtype)
  x = self._maybe_assert_valid_sample(x, dtype=logits.dtype)

  # Broadcast logits or x if need be.
  if (not tensorshape_util.is_fully_defined(x.shape) or
      not tensorshape_util.is_fully_defined(logits.shape) or
      x.shape != logits.shape):
    broadcast_shape = tf.broadcast_dynamic_shape(
        tf.shape(logits), tf.shape(x))
    logits = tf.broadcast_to(logits, broadcast_shape)
    x = tf.broadcast_to(x, broadcast_shape)

  logits_shape = tf.shape(tf.reduce_sum(logits, axis=-1))
  logits_2d = tf.reshape(logits, [-1, event_size])
  x_2d = tf.reshape(x, [-1, event_size])
  ret = -tf.nn.softmax_cross_entropy_with_logits(
      labels=tf.stop_gradient(x_2d), logits=logits_2d)

  # Reshape back to user-supplied batch and sample dims prior to 2D reshape.
  ret = tf.reshape(ret, logits_shape)
  return ret
def _variance(self):
  with tf.control_dependencies(self._runtime_assertions):
    probs = self._marginal_hidden_probs()
    # probs :: num_steps batch_shape num_states
    means = self._observation_distribution.mean()
    # means :: observation_batch_shape[:-1] num_states
    #          observation_event_shape
    means_shape = tf.concat([
        self.batch_shape_tensor(),
        [self._num_states],
        self._observation_distribution.event_shape_tensor()
    ], axis=0)
    means = tf.broadcast_to(means, means_shape)
    # means :: batch_shape num_states observation_event_shape

    observation_event_shape = (
        self._observation_distribution.event_shape_tensor())
    batch_size = tf.reduce_prod(self.batch_shape_tensor())
    flat_probs_shape = [self._num_steps, batch_size, self._num_states]
    flat_means_shape = [
        batch_size, 1, self._num_states,
        tf.reduce_prod(observation_event_shape)
    ]

    flat_probs = tf.reshape(probs, flat_probs_shape)
    # flat_probs :: num_steps batch_size num_states
    flat_means = tf.reshape(means, flat_means_shape)
    # flat_means :: batch_size 1 num_states observation_event_size
    flat_mean = tf.einsum("ijk,jmkl->jiml", flat_probs, flat_means)
    # flat_mean :: batch_size num_steps 1 observation_event_size

    variances = self._observation_distribution.variance()
    variances = tf.broadcast_to(variances, means_shape)
    # variances :: batch_shape num_states observation_event_shape

    flat_variances = tf.reshape(variances, flat_means_shape)
    # flat_variances :: batch_size 1 num_states observation_event_size

    # For a mixture of n distributions with mixture probabilities
    # p[i], and where the individual distributions have means and
    # variances given by mean[i] and var[i], the variance of
    # the mixture is given by:
    #
    #   var = sum i=1..n p[i] * ((mean[i] - mean)**2 + var[i])
    flat_variance = tf.einsum("ijk,jikl->jil",
                              flat_probs,
                              (flat_means - flat_mean)**2 + flat_variances)
    # flat_variance :: batch_size num_steps observation_event_size

    unflat_mean_shape = tf.concat([
        self.batch_shape_tensor(),
        [self._num_steps],
        observation_event_shape
    ], axis=0)

    # returns :: batch_shape num_steps observation_event_shape
    return tf.reshape(flat_variance, unflat_mean_shape)
def _sample_n(self, n, seed=None):
  logits = self._logits_parameter_no_checks()
  logits_2d = tf.reshape(logits, [-1, self._num_categories(logits)])
  sample_dtype = tf.int64 if dtype_util.size(self.dtype) > 4 else tf.int32
  draws = tf.random.categorical(
      logits_2d, n, dtype=sample_dtype, seed=seed)
  draws = tf.cast(draws, self.dtype)
  return tf.reshape(
      tf.transpose(draws),
      shape=tf.concat([[n], self._batch_shape_tensor(logits)], axis=0))
def _cdf(self, x):
  x = tf.convert_to_tensor(x, name='x')
  flat_x = tf.reshape(x, shape=[-1])
  upper_bound = tf.searchsorted(self.outcomes, values=flat_x, side='right')
  values_at_ub = tf.gather(
      self.outcomes,
      indices=tf.minimum(
          upper_bound,
          dist_util.prefer_static_shape(self.outcomes)[-1] - 1))
  should_use_upper_bound = self._is_equal_or_close(flat_x, values_at_ub)
  indices = tf.where(should_use_upper_bound, upper_bound, upper_bound - 1)
  return self._categorical.cdf(
      tf.reshape(indices, shape=dist_util.prefer_static_shape(x)))
def _sample_n(self, n, seed=None):
  logits = self._logits_parameter_no_checks()
  sample_shape = prefer_static.concat([[n], prefer_static.shape(logits)], 0)
  event_size = self._event_size(logits)
  if tensorshape_util.rank(logits.shape) == 2:
    logits_2d = logits
  else:
    logits_2d = tf.reshape(logits, [-1, event_size])
  samples = tf.random.categorical(logits_2d, n, seed=seed)
  samples = tf.transpose(a=samples)
  samples = tf.one_hot(samples, event_size, dtype=self.dtype)
  ret = tf.reshape(samples, sample_shape)
  return ret
def _log_prob(self, x, **kwargs):
  batch_ndims = prefer_static.rank_from_shape(
      self.distribution.batch_shape_tensor,
      self.distribution.batch_shape)
  extra_sample_ndims = prefer_static.rank_from_shape(self.sample_shape)
  event_ndims = prefer_static.rank_from_shape(
      self.distribution.event_shape_tensor,
      self.distribution.event_shape)
  ndims = prefer_static.rank(x)

  # (1) Expand x's dims.
  d = ndims - batch_ndims - extra_sample_ndims - event_ndims
  x = tf.reshape(
      x,
      shape=tf.pad(
          tf.shape(x),
          paddings=[[prefer_static.maximum(0, -d), 0]],
          constant_values=1))
  sample_ndims = prefer_static.maximum(0, d)

  # (2) Transpose x's dims.
  sample_dims = prefer_static.range(0, sample_ndims)
  batch_dims = prefer_static.range(sample_ndims, sample_ndims + batch_ndims)
  extra_sample_dims = prefer_static.range(
      sample_ndims + batch_ndims,
      sample_ndims + batch_ndims + extra_sample_ndims)
  event_dims = prefer_static.range(
      sample_ndims + batch_ndims + extra_sample_ndims,
      ndims)
  perm = prefer_static.concat(
      [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
  x = tf.transpose(a=x, perm=perm)

  # (3) Compute x's log_prob.
  lp = self.distribution.log_prob(x, **kwargs)

  # (4) Make the final reduction in x.
  axis = prefer_static.range(sample_ndims, sample_ndims + extra_sample_ndims)
  return tf.reduce_sum(lp, axis=axis)
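# A hedged end-to-end check of the reduction above, via the public TFP API
# (shapes are illustrative): the log-prob of a `Sample` distribution is the
# base distribution's log-prob summed over the extra sample dims.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

d = tfd.Sample(tfd.Normal(loc=tf.zeros([2]), scale=1.), sample_shape=[3])
x = tf.random.normal([4, 2, 3])  # [sample, batch, extra-sample] dims
lp = d.log_prob(x)               # shape [4, 2]
# Matches tf.reduce_sum(tfd.Normal(0., 1.).log_prob(x), axis=-1) here, since
# the base distribution has zero loc and unit scale.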
def _inverse(self, y):
  output_shape, output_tensorshape = _replace_event_shape_in_shape_tensor(
      tf.shape(y), self._event_shape_out, self._event_shape_in,
      self.validate_args)
  x = tf.reshape(y, output_shape)
  tensorshape_util.set_shape(x, output_tensorshape)
  return x
def _forward(self, x):
  output_shape, output_tensorshape = _replace_event_shape_in_shape_tensor(
      tf.shape(x), self._event_shape_in, self._event_shape_out,
      self.validate_args)
  y = tf.reshape(x, output_shape)
  tensorshape_util.set_shape(y, output_tensorshape)
  return y
def _std_var_helper(self, statistic, statistic_name, statistic_ndims,
                    df_factor_fn):
  """Helper to compute stddev, covariance and variance."""
  df = tf.reshape(
      self.df,
      tf.concat(
          [tf.shape(self.df),
           tf.ones([statistic_ndims], dtype=tf.int32)], -1))
  # We need to put the tf.where inside the outer tf.where to ensure we never
  # hit a NaN in the gradient.
  denom = tf.where(df > 2., df - 2., tf.ones_like(df))
  statistic = statistic * df_factor_fn(df / denom)
  # When 1 < df <= 2, stddev/variance are infinite.
  result_where_defined = tf.where(
      df > 2., statistic,
      dtype_util.as_numpy_dtype(self.dtype)(np.inf))

  if self.allow_nan_stats:
    return tf.where(df > 1., result_where_defined,
                    dtype_util.as_numpy_dtype(self.dtype)(np.nan))
  else:
    with tf.control_dependencies([
        assert_util.assert_less(
            tf.cast(1., self.dtype),
            df,
            message='{} not defined for components of df <= 1.'.format(
                statistic_name.capitalize())),
    ]):
      return tf.identity(result_where_defined)
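# A minimal sketch of why the inner tf.where matters above: the gradient of
# tf.where can be poisoned by a NaN produced in the *unselected* branch, so
# the denominator is sanitized first. The numbers below are illustrative.
import tensorflow as tf

df = tf.constant([1.5, 5.0])
inf = tf.constant(float('inf'))

with tf.GradientTape(persistent=True) as tape:
  tape.watch(df)
  # Naive: sqrt(df / (df - 2.)) is NaN for df < 2, and that NaN leaks into
  # the gradient even though the outer tf.where discards the value.
  naive = tf.where(df > 2., tf.sqrt(df / (df - 2.)), inf)
  # Guarded: sanitize the denominator with an inner tf.where first.
  denom = tf.where(df > 2., df - 2., tf.ones_like(df))
  guarded = tf.where(df > 2., tf.sqrt(df / denom), inf)

print(tape.gradient(naive, df))    # [nan, ...]: poisoned by the dead branch
print(tape.gradient(guarded, df))  # finite everywhere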
def _call_sample_n(self, sample_shape, seed, name, **kwargs):
  # We override `_call_sample_n` rather than `_sample_n` so we can ensure that
  # the result of `self.bijector.forward` is not modified (and thus caching
  # works).
  with self._name_and_control_scope(name):
    sample_shape = tf.convert_to_tensor(
        sample_shape, dtype=tf.int32, name="sample_shape")
    sample_shape, n = self._expand_sample_shape_to_vector(
        sample_shape, "sample_shape")

    distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)

    # First, generate samples. We will possibly generate extra samples in the
    # event that we need to reinterpret the samples as part of the
    # event_shape.
    x = self._sample_n(n, seed, **distribution_kwargs)

    # Next, we reshape `x` into its final form. We do this prior to the call
    # to the bijector to ensure that the bijector caching works.
    batch_event_shape = tf.shape(x)[1:]
    final_shape = tf.concat([sample_shape, batch_event_shape], 0)
    x = tf.reshape(x, final_shape)

    # Finally, we apply the bijector's forward transformation. For caching to
    # work, it is imperative that this is the last modification to the
    # returned result.
    y = self.bijector.forward(x, **bijector_kwargs)
    y = self._set_sample_static_shape(y, sample_shape)

    return y
def _mean(self, **kwargs):
  if not self.bijector.is_constant_jacobian:
    raise NotImplementedError("mean is not implemented for non-affine "
                              "bijectors")
  distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
  x = self.distribution.mean(**distribution_kwargs)

  if self._is_maybe_batch_override or self._is_maybe_event_override:
    # A batch (respectively event) shape override is only allowed if the
    # batch (event) shape of the base distribution is [], so concatenating
    # all the shapes does the right thing.
    new_shape = prefer_static.concat([
        prefer_static.ones_like(self._override_batch_shape),
        self.distribution.batch_shape_tensor(),
        prefer_static.ones_like(self._override_event_shape),
        self.distribution.event_shape_tensor(),
    ], 0)
    x = tf.reshape(x, new_shape)
    new_shape = prefer_static.concat(
        [self.batch_shape_tensor(), self.event_shape_tensor()], 0)
    x = tf.broadcast_to(x, new_shape)

  y = self.bijector.forward(x, **bijector_kwargs)

  sample_shape = tf.convert_to_tensor([], dtype=tf.int32, name="sample_shape")
  y = self._set_sample_static_shape(y, sample_shape)
  return y
def _make_columnar(self, x):
  """Ensures non-scalar input has at least one column.

  Example:
    If `x = [1, 2, 3]` then the output is `[[1, 2, 3]]`.

    If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged.

    If `x = 1` then the output is unchanged.

  Args:
    x: `Tensor`.

  Returns:
    columnar_x: `Tensor` with at least two dimensions.
  """
  if tensorshape_util.rank(x.shape) is not None:
    if tensorshape_util.rank(x.shape) == 1:
      x = x[tf.newaxis, :]
    return x
  shape = tf.shape(x)
  maybe_expanded_shape = tf.concat([
      shape[:-1],
      distribution_util.pick_vector(
          tf.equal(tf.rank(x), 1), [1], np.array([], dtype=np.int32)),
      shape[-1:],
  ], 0)
  return tf.reshape(x, maybe_expanded_shape)
def _call_and_reshape_output(self,
                             fn,
                             event_shape_list=None,
                             static_event_shape_list=None,
                             extra_kwargs=None):
  """Calls `fn` and appropriately reshapes its output."""
  # Note: we take `extra_kwargs` as a dict rather than `**extra_kwargs`
  # because the user-provided extra kwargs could themselves have `fn`,
  # `event_shape_list`, `static_event_shape_list` and/or `extra_kwargs`
  # as keys.
  with tf.control_dependencies(self._runtime_assertions):
    if event_shape_list is None:
      event_shape_list = [self._event_shape_tensor()]
    if static_event_shape_list is None:
      static_event_shape_list = [self.event_shape]
    new_shape = tf.concat(
        [self._batch_shape_unexpanded] + event_shape_list, axis=0)
    result = tf.reshape(
        fn(**extra_kwargs) if extra_kwargs else fn(), new_shape)
    if (tensorshape_util.rank(self.batch_shape) is not None and
        tensorshape_util.rank(self.event_shape) is not None):
      event_shape = tf.TensorShape([])
      for rss in static_event_shape_list:
        event_shape = tensorshape_util.concatenate(event_shape, rss)
      static_shape = tensorshape_util.concatenate(
          self.batch_shape, event_shape)
      tensorshape_util.set_shape(result, static_shape)
    return result
def _slice_single_param(param, param_event_ndims, slices, dist_batch_shape):
  """Slices a single parameter of a distribution.

  Args:
    param: A `Tensor`, the original parameter to slice.
    param_event_ndims: `int` event parameterization rank for this parameter.
    slices: A `tuple` of normalized slices.
    dist_batch_shape: The distribution's batch shape `Tensor`.

  Returns:
    new_param: A `Tensor`, batch-sliced according to slices.
  """
  # Extend param shape with ones on the left to match dist_batch_shape.
  param_shape = tf.shape(input=param)
  insert_ones = tf.ones(
      [tf.size(input=dist_batch_shape) + param_event_ndims - tf.rank(param)],
      dtype=param_shape.dtype)
  new_param_shape = tf.concat([insert_ones, param_shape], axis=0)
  full_batch_param = tf.reshape(param, new_param_shape)
  param_slices = []
  # We separately track the batch axis from the parameter axis because we want
  # them to align for positive indexing, and be offset by param_event_ndims
  # for negative indexing.
  param_dim_idx = 0
  batch_dim_idx = 0
  for slc in slices:
    if slc is tf.newaxis:
      param_slices.append(slc)
      continue
    if slc is Ellipsis:
      if batch_dim_idx < 0:
        raise ValueError('Found multiple `...` in slices {}'.format(slices))
      param_slices.append(slc)
      # Switch over to negative indexing for the broadcast check.
      num_remaining_non_newaxis_slices = sum(
          [s is not tf.newaxis for s in slices[slices.index(Ellipsis) + 1:]])
      batch_dim_idx = -num_remaining_non_newaxis_slices
      param_dim_idx = batch_dim_idx - param_event_ndims
      continue
    # Find the batch dimension sizes for both parameter and distribution.
    param_dim_size = new_param_shape[param_dim_idx]
    batch_dim_size = dist_batch_shape[batch_dim_idx]
    is_broadcast = batch_dim_size > param_dim_size
    # Slices are denoted by start:stop:step.
    if isinstance(slc, slice):
      start, stop, step = slc.start, slc.stop, slc.step
      if start is not None:
        start = tf.where(is_broadcast, 0, start)
      if stop is not None:
        stop = tf.where(is_broadcast, 1, stop)
      if step is not None:
        step = tf.where(is_broadcast, 1, step)
      param_slices.append(slice(start, stop, step))
    else:  # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
      param_slices.append(tf.where(is_broadcast, 0, slc))
    param_dim_idx += 1
    batch_dim_idx += 1
  param_slices.extend([ALL_SLICE] * param_event_ndims)
  return full_batch_param.__getitem__(param_slices)
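# A hedged illustration, via the public tfp.distributions slicing API, of the
# broadcast case this helper handles: a parameter whose batch dimensions are
# broadcast (size 1 or absent) must not be indexed away, only kept at size 1
# so it still broadcasts against the sliced batch shape. Values illustrative.
import tensorflow as tf
import tensorflow_probability as tfp

d = tfp.distributions.Normal(loc=tf.zeros([5, 3]), scale=1.)  # scalar scale
d2 = d[2:, 1]
print(d2.batch_shape)  # [3]
# `loc` is sliced to shape [3]; the scalar `scale` ends up with shape [1],
# which still broadcasts against the new batch shape.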
def _reshape_part(part, dtype, event_shape):
  part = tf.cast(part, dtype)
  static_rank = tf.get_static_value(ps.rank_from_shape(event_shape))
  if static_rank == 1:
    return part
  new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
  return tf.reshape(part, ps.cast(new_shape, tf.int32))
def _extract_log_probs(num_states, dist):
  """Tabulate log probabilities from a batch of distributions."""
  states = tf.reshape(
      tf.range(num_states),
      tf.concat([[num_states], tf.ones_like(dist.batch_shape_tensor())],
                axis=0))
  return distribution_util.move_dimension(dist.log_prob(states), 0, -1)
def _sample_n(self, n, seed=None):
  # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[], in
  # which case get ids as a [n]-shaped vector.
  distributions = self.poisson_and_mixture_distributions()
  dist, mixture_dist = distributions
  batch_size = tensorshape_util.num_elements(self.batch_shape)
  if batch_size is None:
    batch_size = tf.reduce_prod(
        self._batch_shape_tensor(distributions=distributions))
  # We need to 'sample extra' from the mixture distribution if it doesn't
  # already specify a probs vector for each batch coordinate.
  # We only support this kind of reduced broadcasting, i.e., there is exactly
  # one probs vector for all batch dims or one for each.
  stream = SeedStream(seed, salt='PoissonLogNormalQuadratureCompound')
  ids = mixture_dist.sample(
      sample_shape=concat_vectors(
          [n],
          distribution_util.pick_vector(
              mixture_dist.is_scalar_batch(), [batch_size], np.int32([]))),
      seed=stream())
  # We need to flatten batch dims in case mixture_dist has its own
  # batch dims.
  ids = tf.reshape(
      ids,
      shape=concat_vectors(
          [n],
          distribution_util.pick_vector(
              self.is_scalar_batch(), np.int32([]), np.int32([-1]))))
  # Stride `quadrature_size` for `batch_size` number of times.
  offset = tf.range(
      start=0,
      limit=batch_size * self._quadrature_size,
      delta=self._quadrature_size,
      dtype=ids.dtype)
  ids = ids + offset
  rate = tf.gather(tf.reshape(dist.rate, shape=[-1]), ids)
  rate = tf.reshape(
      rate,
      shape=concat_vectors(
          [n], self._batch_shape_tensor(distributions=distributions)))
  return tf.random.poisson(lam=rate, shape=[], dtype=self.dtype, seed=seed)
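# A small sketch (illustrative numbers) of the "add a per-batch offset, then
# gather from a flattened table" step above: each batch member's ids index
# into that member's block of `quadrature_size` rates.
import tensorflow as tf

quadrature_size = 3
rates = tf.constant([[0.1, 0.2, 0.3],    # batch member 0
                     [1.0, 2.0, 3.0]])   # batch member 1
ids = tf.constant([[2, 0],               # n = 2 draws, one column per
                   [1, 1]])              # batch member
offset = tf.range(0, 2 * quadrature_size, quadrature_size)  # [0, 3]
picked = tf.gather(tf.reshape(rates, [-1]), ids + offset)
# picked == [[0.3, 1.0], [0.2, 2.0]]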
def _stddev(self):
  with tf.control_dependencies(self._assertions):
    distribution_means = [d.mean() for d in self.components]
    distribution_devs = [d.stddev() for d in self.components]
    cat_probs = self._cat_probs(log_probs=False)

    stacked_means = tf.stack(distribution_means, axis=-1)
    stacked_devs = tf.stack(distribution_devs, axis=-1)
    cat_probs = [self._expand_to_event_rank(c_p) for c_p in cat_probs]
    broadcasted_cat_probs = (
        tf.stack(cat_probs, axis=-1) * tf.ones_like(stacked_means))

    batched_dev = distribution_util.mixture_stddev(
        tf.reshape(broadcasted_cat_probs, [-1, len(self.components)]),
        tf.reshape(stacked_means, [-1, len(self.components)]),
        tf.reshape(stacked_devs, [-1, len(self.components)]))

    # I.e. re-shape to list(batch_shape) + list(event_shape).
    return tf.reshape(batched_dev, tf.shape(broadcasted_cat_probs)[:-1])
def _sample_n(self, n, seed=None, **kwargs):
  with tf.control_dependencies(self._runtime_assertions):
    x = self.distribution.sample(sample_shape=n, seed=seed, **kwargs)
    new_shape = tf.concat([
        [n],
        self._batch_shape_unexpanded,
        self.event_shape_tensor(),
    ], axis=0)
    return tf.reshape(x, new_shape)
def _log_prob(self, x):
  x = tf.convert_to_tensor(x, name='x')
  right_indices = tf.minimum(
      tf.size(self.outcomes) - 1,
      tf.reshape(
          tf.searchsorted(
              self.outcomes, values=tf.reshape(x, shape=[-1]), side='right'),
          dist_util.prefer_static_shape(x)))
  use_right_indices = self._is_equal_or_close(
      x, tf.gather(self.outcomes, indices=right_indices))
  left_indices = tf.maximum(0, right_indices - 1)
  use_left_indices = self._is_equal_or_close(
      x, tf.gather(self.outcomes, indices=left_indices))
  log_probs = self._categorical.log_prob(
      tf.where(use_left_indices, left_indices, right_indices))
  return tf.where(
      tf.logical_not(use_left_indices | use_right_indices),
      dtype_util.as_numpy_dtype(log_probs.dtype)(-np.inf),
      log_probs)
def _pad_sample_dims(self, x):
  with tf.name_scope("pad_sample_dims"):
    ndims = (tensorshape_util.rank(x.shape)
             if tensorshape_util.rank(x.shape) is not None
             else tf.rank(x))
    shape = tf.shape(x)
    d = ndims - self._event_ndims
    x = tf.reshape(x, shape=tf.concat([shape[:d], [1], shape[d:]], axis=0))
    return x
def _expand_base_distribution_mean(self):
  """Ensures `self.distribution.mean()` has `[batch, event]` shape."""
  single_draw_shape = concat_vectors(self.batch_shape_tensor(),
                                     self.event_shape_tensor())
  m = tf.reshape(
      self.distribution.mean(),  # A scalar.
      shape=tf.ones_like(single_draw_shape, dtype=tf.int32))
  m = tf.tile(m, multiples=single_draw_shape)
  tensorshape_util.set_shape(
      m, tensorshape_util.concatenate(self.batch_shape, self.event_shape))
  return m
def _sample_one_batch_member(args):
  logits, num_cat_samples = args[0], args[1]  # [K], []
  # x has shape [1, num_cat_samples = num_samples * num_trials]
  x = tf.random.categorical(
      logits[tf.newaxis, ...], num_cat_samples, seed=seed)
  x = tf.reshape(x, shape=[num_samples, -1])  # [num_samples, num_trials]
  x = tf.one_hot(x, depth=num_classes)  # [num_samples, num_trials, num_classes]
  x = tf.reduce_sum(x, axis=-2)  # [num_samples, num_classes]
  return tf.cast(x, dtype=dtype)
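# A minimal, self-contained sketch (illustrative sizes, public TF ops only) of
# the "categorical draws -> one-hot -> sum" way of drawing multinomial counts
# used above.
import tensorflow as tf

logits = tf.math.log([0.2, 0.3, 0.5])  # K = 3 classes
num_trials = 4
num_samples = 2

draws = tf.random.categorical(logits[tf.newaxis, :],
                              num_samples * num_trials)      # [1, 8]
draws = tf.reshape(draws, [num_samples, num_trials])         # [2, 4]
counts = tf.reduce_sum(tf.one_hot(draws, depth=3), axis=-2)  # [2, 3]
# Each row of `counts` sums to num_trials and is one multinomial draw.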
def _entropy(self):
  samples = tf.convert_to_tensor(self.samples)
  num_samples = self._compute_num_samples(samples)
  entropy_shape = self._batch_shape_tensor(samples)

  # Flatten samples for each batch.
  if self._event_ndims == 0:
    samples = tf.reshape(samples, [-1, num_samples])
  else:
    event_size = tf.reduce_prod(self.event_shape_tensor())
    samples = tf.reshape(samples, [-1, num_samples, event_size])

  # Use map_fn to compute entropy for each batch separately.
  def _get_entropy(samples):
    # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
    count = gen_array_ops.unique_with_counts_v2(samples, axis=[0]).count
    prob = tf.cast(count / num_samples, dtype=self.dtype)
    entropy = tf.reduce_sum(-prob * tf.math.log(prob))
    return entropy

  entropy = tf.map_fn(_get_entropy, samples, dtype=self.dtype)
  return tf.reshape(entropy, entropy_shape)
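# A sketch of the empirical-entropy-from-counts step above for a single flat
# vector of samples, using the public tf.unique_with_counts (the method above
# uses the gen_array_ops variant with an axis argument so that vector-valued
# events can be deduplicated row-wise). Data are illustrative.
import tensorflow as tf

samples = tf.constant([0., 0., 1., 2., 2., 2.])
_, _, count = tf.unique_with_counts(samples)
prob = tf.cast(count, tf.float32) / tf.cast(tf.size(samples), tf.float32)
entropy = -tf.reduce_sum(prob * tf.math.log(prob))
# entropy == -(2/6*log(2/6) + 1/6*log(1/6) + 3/6*log(3/6)) ~= 1.011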
def _mean(self):
  with tf.control_dependencies(self._runtime_assertions):
    probs = self._marginal_hidden_probs()
    # probs :: num_steps batch_shape num_states
    means = self._observation_distribution.mean()
    # means :: observation_batch_shape[:-1] num_states
    #          observation_event_shape
    means_shape = tf.concat([
        self.batch_shape_tensor(),
        [self._num_states],
        self._observation_distribution.event_shape_tensor()
    ], axis=0)
    means = tf.broadcast_to(means, means_shape)
    # means :: batch_shape num_states observation_event_shape

    observation_event_shape = (
        self._observation_distribution.event_shape_tensor())
    batch_size = tf.reduce_prod(self.batch_shape_tensor())
    flat_probs_shape = [self._num_steps, batch_size, self._num_states]
    flat_means_shape = [
        batch_size, self._num_states,
        tf.reduce_prod(observation_event_shape)
    ]

    flat_probs = tf.reshape(probs, flat_probs_shape)
    # flat_probs :: num_steps batch_size num_states
    flat_means = tf.reshape(means, flat_means_shape)
    # flat_means :: batch_size num_states observation_event_size
    flat_mean = tf.einsum("ijk,jkl->jil", flat_probs, flat_means)
    # flat_mean :: batch_size num_steps observation_event_size

    unflat_mean_shape = tf.concat([
        self.batch_shape_tensor(),
        [self._num_steps],
        observation_event_shape
    ], axis=0)
    # returns :: batch_shape num_steps observation_event_shape
    return tf.reshape(flat_mean, unflat_mean_shape)
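# A shape-level sketch (random values, illustrative sizes) of the einsum that
# computes the per-step mixture mean above: contract the state dimension of
# the marginal state probabilities against the per-state observation means.
import tensorflow as tf

num_steps, batch_size, num_states, event_size = 3, 2, 4, 5
flat_probs = tf.random.uniform([num_steps, batch_size, num_states])
flat_means = tf.random.uniform([batch_size, num_states, event_size])
flat_mean = tf.einsum('ijk,jkl->jil', flat_probs, flat_means)
# flat_mean.shape == [batch_size, num_steps, event_size] == [2, 3, 5]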