def _sample_3d(self, n, seed=None):
  """Inversion sampler specialized to the 3-D case.

  Args:
    n: Number of draws; becomes the leading dimension of the result.
    seed: Optional seed, wrapped in a `SeedStream` salted with
      'von_mises_fisher_3d'.

  Returns:
    A tensor shaped `[n] + batch_shape + [1]` holding the inverted uniform
    draws `u`.
  """
  stream = SeedStream(seed, salt='von_mises_fisher_3d')
  draw_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
  z = tf.random.uniform(draw_shape, seed=stream(), dtype=self.dtype)
  # TODO(bjp): Higher-order odd dim analytic CDFs are available in [1], could
  # be bisected for bounded sampling runtime (i.e. not rejection sampling).
  # [1]: Inversion sampler via: https://ieeexplore.ieee.org/document/7347705/
  #
  # Inversion formula: u = 1 + log(z + (1 - z) * exp(-2 * kappa)) / kappa.
  # Neither kappa nor z may be zero inside it, so both are replaced by 1 where
  # they are not strictly positive and the affected entries are patched below.
  kappa = self.concentration
  safe_kappa = tf.where(kappa > 0, kappa, tf.ones_like(kappa))
  safe_z = tf.where(z > 0, z, tf.ones_like(z))
  log_terms = [
      tf.math.log(safe_z),
      tf.math.log1p(-safe_z) - 2 * safe_kappa,
  ]
  safe_u = 1 + tf.reduce_logsumexp(log_terms, axis=0) / safe_kappa
  # As kappa -> 0 the formula tends to 2 * z - 1.
  u = tf.where(kappa > tf.zeros_like(safe_u), safe_u, 2 * z - 1)
  # As z -> 0 the formula tends to -1.
  u = tf.where(tf.equal(z, 0), -tf.ones_like(u), u)
  if not self._allow_nan_stats:
    u = tf.debugging.check_numerics(u, 'u in _sample_3d')
  return u[..., tf.newaxis]
def _survival_function(self, y):
  """Survival function `P[Y > y]` of the quantized variable `Y`.

  The promise is:

    survival_function(y) := P[Y > y]
                          = 0, if y >= high,
                          = 1, if y < low,
                          = P[X > y], otherwise.

  `Y` only carries mass on whole numbers, so for a fractional `y` we have
  `P[Y > y] = P[Y > floor(y)]`. The evaluation point is therefore floored
  first; rounding up instead would drop the mass sitting at `ceil(y)` and
  break `cdf(y) + survival_function(y) == 1` for non-integer `y`. Integer
  `y` is unaffected by the rounding direction.

  Args:
    y: Floating point tensor of evaluation points.

  Returns:
    Tensor of survival probabilities, as returned by the base distribution's
    `survival_function` and overridden at the `low` / `high` cutoffs.
  """
  low = self._low
  high = self._high

  # Fractional `y` are floored to a whole number before applying the promise.
  j = tf.math.floor(y)

  # P[X > j], used when low < X < high.
  result_so_far = self.distribution.survival_function(j)

  # Re-define values at the cutoffs.
  if low is not None:
    result_so_far = tf.where(
        j < low, tf.ones_like(result_so_far), result_so_far)
  if high is not None:
    result_so_far = tf.where(
        j >= high, tf.zeros_like(result_so_far), result_so_far)

  return result_so_far
def _variance(self):
  """Variance `scale**2 * df / (df - 2)` with the usual `df` special cases.

  Returns:
    The variance where `df > 2` and `inf` where `df <= 2`. When
    `allow_nan_stats` is true, entries with `df <= 1` are `nan`; otherwise
    the result carries an assertion that every `df` exceeds 1.
  """
  df = tf.convert_to_tensor(self.df)
  scale = tf.convert_to_tensor(self.scale)
  np_dtype = dtype_util.as_numpy_dtype(self.dtype)

  has_finite_variance = df > 2.
  # The denominator is made safe *before* the outer `tf.where` so that the
  # unselected branch never produces a NaN in the gradient.
  safe_denom = tf.where(has_finite_variance, df - 2., tf.ones_like(df))
  batch_ones = tf.ones(
      self._batch_shape_tensor(df=df, scale=scale), dtype=self.dtype)
  # Abs(scale) superfluous.
  finite_variance = batch_ones * tf.square(scale) * df / safe_denom

  # When 1 < df <= 2, variance is infinite.
  variance_or_inf = tf.where(
      has_finite_variance, finite_variance, np_dtype(np.inf))

  if not self.allow_nan_stats:
    df_exceeds_one = assert_util.assert_less(
        tf.ones([], dtype=self.dtype),
        df,
        message='variance not defined for components of df <= 1')
    return distribution_util.with_dependencies(
        [df_exceeds_one], variance_or_inf)

  return tf.where(df > 1., variance_or_inf, np_dtype(np.nan))
def assert_finite(x, data=None, summarize=None, message=None, name=None):
  """Assert all elements of `x` are finite.

  Args:
    x: Numeric `Tensor`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_finite".

  Returns:
    `x` itself when its value is statically known and entirely finite.
    Otherwise `tf.identity(x)` carrying a control dependency on a runtime
    assertion that every element of `x` is finite.

  Raises:
    ValueError: If the static value of `x` is known and contains a
      non-finite element.
  """
  with tf.name_scope(name or 'assert_finite'):
    x_ = tf.get_static_value(x)
    if x_ is not None:
      # Logical `not` rather than bitwise `~`: the latter is only correct on
      # NumPy booleans and is always truthy on a plain Python `bool`.
      if not np.all(np.isfinite(x_)):
        # Fall back to a descriptive text instead of raising ValueError(None).
        raise ValueError(
            message if message is not None
            else '`x` must contain only finite values.')
      return x
    assertion = tf1.assert_equal(
        tf.math.is_finite(x),
        tf.ones_like(x, tf.bool),
        data=data,
        summarize=summarize,
        message=message)
    with tf.control_dependencies([assertion]):
      return tf.identity(x)
def _covariance(self):
  """Covariance of the von Mises-Fisher distribution (event size <= 2 only).

  Derivation: https://sachinruk.github.io/blog/von-Mises-Fisher/

  Returns:
    The covariance matrix where `concentration > 0`, and the identity divided
    by the event size where `concentration` is not positive.

  Raises:
    ValueError: If the event size is not statically known, or is larger
      than 2.
  """
  event_dim = tf.compat.dimension_value(self.event_shape[0])
  if event_dim is None:
    raise ValueError(
        'event shape must be statically known for _bessel_ive')
  # TODO(bjp): Enable this; numerically unstable.
  if event_dim > 2:
    raise ValueError(
        'vMF covariance is numerically unstable for dim>2')
  concentration = self.concentration[..., tf.newaxis]
  # Substitute 1 for non-positive concentrations so the Bessel ratio and the
  # divisions below stay finite; those entries are overwritten at the end.
  safe_conc = tf.where(concentration > 0, concentration,
                       tf.ones_like(concentration))
  h = (_bessel_ive(event_dim / 2, safe_conc) /
       _bessel_ive(event_dim / 2 - 1, safe_conc))
  intermediate = (
      tf.matmul(self.mean_direction[..., :, tf.newaxis],
                self.mean_direction[..., tf.newaxis, :]) *
      (1 - event_dim * h / safe_conc - h**2)[..., tf.newaxis])
  cov = tf.linalg.set_diag(
      intermediate,
      tf.linalg.diag_part(intermediate) + (h / safe_conc))
  # `tf.linalg.eye` defaults to float32; request the distribution's dtype so
  # both `tf.where` branches agree for e.g. float64 parameters.
  uniform_cov = tf.linalg.eye(
      event_dim,
      batch_shape=self.batch_shape_tensor(),
      dtype=self.dtype) / event_dim
  return tf.where(
      concentration[..., tf.newaxis] > tf.zeros_like(cov), cov, uniform_cov)
def _cdf(self, y):
  """CDF `P[Y <= y]` of the quantized variable `Y`.

  The promise is:

    cdf(y) := P[Y <= y]
            = 1, if y >= high,
            = 0, if y < low,
            = P[X <= y], otherwise.

  Mass sits only on whole numbers, so the evaluation point is floored first.
  """
  lower_cutoff, upper_cutoff = self._low, self._high

  floored = tf.floor(y)

  # P[X <= floored]; valid strictly between the cutoffs.
  cdf = self.distribution.cdf(floored)

  # Override the values at and beyond the cutoffs.
  if lower_cutoff is not None:
    cdf = tf.where(floored < lower_cutoff, tf.zeros_like(cdf), cdf)
  if upper_cutoff is not None:
    cdf = tf.where(floored >= upper_cutoff, tf.ones_like(cdf), cdf)

  return cdf
def _extract_log_probs(num_states, dist):
  """Tabulates `dist.log_prob` over the states `0 .. num_states - 1`.

  The states are laid out along a new leading axis (followed by one size-1
  axis per batch dimension of `dist`), evaluated, and that leading axis is
  then moved to the last position.
  """
  singleton_batch_dims = tf.ones_like(dist.batch_shape_tensor())
  states_shape = tf.concat([[num_states], singleton_batch_dims], axis=0)
  states = tf.reshape(tf.range(num_states), states_shape)
  log_probs = dist.log_prob(states)
  return distribution_util.move_dimension(log_probs, 0, -1)
def _prob(self, x):
  """Density at `x`: `1 / range` on `[low, high]`, 0 outside, NaN for NaN."""
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)

  # Using a strict `>` at the upper end is only sound for a continuous
  # uniform distribution.
  outside_support = (x < low) | (x > high)
  density_inside = tf.ones_like(x) / self._range(low=low, high=high)
  density = tf.where(outside_support, tf.zeros_like(x), density_inside)

  # NaN inputs are passed through unchanged.
  return tf.where(tf.math.is_nan(x), x, density)
def _expand_base_distribution_mean(self):
  """Tiles `self.distribution.mean()` out to `[batch, event]` shape."""
  batch_and_event = concat_vectors(self.batch_shape_tensor(),
                                   self.event_shape_tensor())
  # The base mean is treated as a scalar: reshape it to an all-ones shape of
  # the target rank, then tile it up to the full size.
  all_ones_shape = tf.ones_like(batch_and_event, dtype=tf.int32)
  base_mean = tf.reshape(self.distribution.mean(), shape=all_ones_shape)
  tiled_mean = tf.tile(base_mean, multiples=batch_and_event)

  static_shape = tensorshape_util.concatenate(self.batch_shape,
                                              self.event_shape)
  tensorshape_util.set_shape(tiled_mean, static_shape)
  return tiled_mean
def _log_prob(self, x):
  """Log-density at `x`: log normalizer plus summed log-softmax term."""
  temperature = tf.convert_to_tensor(self.temperature)
  logits = self._logits_parameter_no_checks()
  x = self._assert_valid_sample(x)

  # Broadcast `logits` and `x` against each other unless both static shapes
  # are fully known and already equal.
  needs_broadcast = (
      not tensorshape_util.is_fully_defined(x.shape) or
      not tensorshape_util.is_fully_defined(logits.shape) or
      x.shape != logits.shape)
  if needs_broadcast:
    logits = tf.ones_like(x, dtype=logits.dtype) * logits
    x = tf.ones_like(logits, dtype=x.dtype) * x

  # Normalization constant: lgamma(k) + (k - 1) * log(temperature).
  k = tf.cast(self._event_size(logits), x.dtype)
  log_norm_const = tf.math.lgamma(k) + (k - 1.) * tf.math.log(temperature)

  # Unnormalized density: log-softmax summed over the last axis.
  shifted_logits = logits - x * temperature[..., tf.newaxis]
  log_unnorm_prob = tf.reduce_sum(
      tf.math.log_softmax(shifted_logits), axis=[-1], keepdims=False)

  return log_norm_const + log_unnorm_prob
def _bessel_ive(v, z, cache=None):
  """Computes I_v(z)*exp(-abs(z)) using a recurrence relation, where z > 0.

  Args:
    v: Python number, the order. Closed forms exist for -0.5, 0, 0.5 and 1;
      an order in (1, 2) is obtained by recurrence from `v - 1` and `v - 2`,
      which must themselves be supported (in practice `v == 1.5`).
    z: Tensor-like argument; expected to be positive.
    cache: Optional `dict` mapping order to an already computed result. It is
      keyed by `v` only, so it must only be shared between calls that use
      the same `z`. It is updated in place.

  Returns:
    The exponentially scaled Bessel function, passed through
    `tf.debugging.check_numerics`.

  Raises:
    ValueError: If `v >= 2`, or if `v <= 1` is not one of the orders with a
      closed form.
  """
  # TODO(b/67497980): Switch to a more numerically faithful implementation.
  z = tf.convert_to_tensor(z)

  def wrap(result):
    return tf.debugging.check_numerics(result, 'besseli{}'.format(v))

  if float(v) >= 2:
    raise ValueError(
        'Evaluating bessel_i by recurrence becomes imprecise for large v')

  # `cache or {}` would silently replace a caller-supplied *empty* dict, so
  # results would never be shared across the recursive calls below.
  if cache is None:
    cache = {}
  if v in cache:
    return wrap(cache[v])

  safe_z = tf.where(z > 0, z, tf.ones_like(z))
  # 1 / sqrt(z), with 1 substituted where z is not positive.
  rsqrt_z = tf.where(z > 0, tf.math.rsqrt(safe_z), tf.ones_like(safe_z))

  if v == 0:
    cache[v] = tf.math.bessel_i0e(z)
  elif v == 1:
    cache[v] = tf.math.bessel_i1e(z)
  elif v == 0.5:
    # sinh(x)*exp(-abs(x)), sinh(x) = (e^x - e^{-x}) / 2
    sinhe = lambda x: (tf.exp(x - tf.abs(x)) - tf.exp(-x - tf.abs(x))) / 2
    cache[v] = np.sqrt(2 / np.pi) * sinhe(z) * rsqrt_z
  elif v == -0.5:
    # cosh(x)*exp(-abs(x)), cosh(x) = (e^x + e^{-x}) / 2
    coshe = lambda x: (tf.exp(x - tf.abs(x)) + tf.exp(-x - tf.abs(x))) / 2
    cache[v] = np.sqrt(2 / np.pi) * coshe(z) * rsqrt_z
  elif v <= 1:
    # No closed form and no valid recurrence start: fail with a clear message
    # instead of a bare KeyError on the cache lookup.
    raise ValueError(
        'Unsupported Bessel order {}: only -0.5, 0, 0.5, 1 and orders in '
        '(1, 2) reachable from them are implemented'.format(v))
  else:
    # Recurrence relation: I_v = I_{v-2} - 2 * (v - 1) * I_{v-1} / z.
    cache[v] = (_bessel_ive(v - 2, z, cache) -
                (2 * (v - 1)) * _bessel_ive(v - 1, z, cache) / z)
  return wrap(cache[v])
def _mean(self):
  """Mean: `mean_direction` scaled by a ratio of scaled Bessel functions.

  Derivation: https://sachinruk.github.io/blog/von-Mises-Fisher/

  Returns:
    The scaled mean direction where `concentration > 0`, zeros elsewhere.

  Raises:
    ValueError: If the event size is not statically known.
  """
  event_dim = tf.compat.dimension_value(self.event_shape[0])
  if event_dim is None:
    raise ValueError(
        'event shape must be statically known for _bessel_ive')

  kappa = self.concentration
  # Evaluate the Bessel ratio at 1 wherever kappa is not positive; those
  # entries are replaced by zeros below.
  safe_kappa = tf.where(kappa > 0, kappa, tf.ones_like(kappa))
  bessel_ratio = (_bessel_ive(event_dim / 2, safe_kappa) /
                  _bessel_ive(event_dim / 2 - 1, safe_kappa))
  safe_mean = self.mean_direction * bessel_ratio[..., tf.newaxis]

  kappa_is_positive = kappa[..., tf.newaxis] > tf.zeros_like(safe_mean)
  return tf.where(kappa_is_positive, safe_mean, tf.zeros_like(safe_mean))
def _log_normalization(self):
  """Log-normalizer of the distribution.

  Returns:
    The negated log normalizing constant where `concentration > 0`, and the
    log surface area of the n-sphere elsewhere.

  Raises:
    ValueError: If the event size is not statically known.
  """
  event_dim = tf.compat.dimension_value(self.event_shape[0])
  if event_dim is None:
    raise ValueError('vMF _log_normalizer currently only supports '
                     'statically known event shape')

  half_dim = event_dim / 2
  kappa = self.concentration
  # Non-positive concentrations are evaluated at 1 and handled separately.
  safe_kappa = tf.where(kappa > 0, kappa, tf.ones_like(kappa))
  safe_lognorm = (
      (half_dim - 1) * tf.math.log(safe_kappa) -
      half_dim * np.log(2 * np.pi) -
      tf.math.log(_bessel_ive(half_dim - 1, safe_kappa)) -
      tf.abs(safe_kappa))
  log_nsphere_surface_area = (
      np.log(2.) + half_dim * np.log(np.pi) -
      tf.math.lgamma(tf.cast(half_dim, self.dtype)))
  return tf.where(kappa > 0, -safe_lognorm, log_nsphere_surface_area)
def _broadcast_event_and_samples(event, samples, event_ndims):
  """Broadcasts `event` and `samples` against each other.

  Args:
    event: Tensor to compare against the samples.
    samples: Tensor whose axis at position `-event_ndims - 1` is the samples
      axis.
    event_ndims: Number of trailing event dimensions.

  Returns:
    `(event, samples)`: `event` broadcast to the shape of `samples` without
    its samples axis and then given a singleton axis in that position, and
    `samples` broadcast against that expanded `event`.
  """
  # Shape of `samples` with the samples axis removed, i.e. the shape a single
  # draw would have.
  full_shape = tf.shape(samples)
  leading_dims = full_shape[:-event_ndims - 1]
  event_dims = full_shape[tf.rank(samples) - event_ndims:]
  samples_shape = tf.concat([leading_dims, event_dims], axis=0)

  event = event * tf.ones(samples_shape, dtype=event.dtype)
  # Put the singleton samples axis back at -event_ndims - 1.
  event = tf.expand_dims(event, axis=-event_ndims - 1)
  samples = samples * tf.ones_like(event, dtype=samples.dtype)

  return event, samples
def _bdtr(k, n, p):
  """The binomial cumulative distribution function.

  Args:
    k: floating point `Tensor`.
    n: floating point `Tensor`.
    p: floating point `Tensor`.

  Returns:
    `sum_{j=0}^k (n choose j) p^j (1 - p)^(n - j)`, evaluated as the
    regularized incomplete beta function `betainc(n - k, k + 1, 1 - p)`,
    with entries where `k == n` set to 1.
  """
  # betainc(a = 0, ..) = nan, so to keep backprop into n and k safe we use
  #   where(unsafe, safe_output, betainc(where(unsafe, safe_input, input)))
  is_full_count = tf.equal(k, n)
  unit = tf.ones_like(n - k)
  safe_a = tf.where(is_full_count, unit, n - k)
  cdf = tf.math.betainc(a=safe_a, b=k + 1, x=1 - p)
  return tf.where(is_full_count, unit, cdf)
def _stddev(self):
  """Standard deviation of the mixture from per-component means/stddevs."""
  with tf.control_dependencies(self._assertions):
    num_components = len(self.components)

    component_means = [d.mean() for d in self.components]
    component_devs = [d.stddev() for d in self.components]
    mixing_probs = self._cat_probs(log_probs=False)

    # Components go on a new trailing axis.
    stacked_means = tf.stack(component_means, axis=-1)
    stacked_devs = tf.stack(component_devs, axis=-1)
    expanded_probs = [self._expand_to_event_rank(p) for p in mixing_probs]
    broadcasted_cat_probs = (
        tf.stack(expanded_probs, axis=-1) * tf.ones_like(stacked_means))

    def flatten(t):
      # Collapse everything but the component axis.
      return tf.reshape(t, [-1, num_components])

    batched_dev = distribution_util.mixture_stddev(
        flatten(broadcasted_cat_probs),
        flatten(stacked_means),
        flatten(stacked_devs))

    # I.e. re-shape to list(batch_shape) + list(event_shape).
    return tf.reshape(batched_dev, tf.shape(broadcasted_cat_probs)[:-1])
def _cdf(self, x):
  """Piecewise-quadratic CDF: 0 below `low`, 1 from `high` upward."""
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)
  peak = tf.convert_to_tensor(self.peak)
  interval_length = high - low

  # The PDF is two line segments meeting (non-smoothly) at the peak, so each
  # side contributes its own quadratic.
  # Left of the peak: (x - low) ** 2 / ((high - low) * (peak - low))
  rising_side = (
      tf.math.squared_difference(x, low) / (interval_length * (peak - low)))
  # Right of the peak: 1 - (high - x) ** 2 / ((high - low) * (high - peak))
  falling_side = 1. - (
      tf.math.squared_difference(high, x) /
      (interval_length * (high - peak)))
  inside_interval = tf.where(
      (x >= low) & (x <= peak), rising_side, falling_side)

  # Left tail is 0, right tail is 1.
  below_high = tf.where(x < low, tf.zeros_like(x), inside_interval)
  return tf.where(x >= high, tf.ones_like(x), below_high)
def _broadcast_cat_event_and_params(event, params, base_dtype):
  """Broadcasts `event` and `params` (minus its last axis) to a common shape.

  Args:
    event: Integer or floating tensor; floating values are cast to int32.
    params: Tensor whose last axis is left untouched by the broadcast.
    base_dtype: Only used to format the `TypeError` message.

  Returns:
    `(event, params)` after broadcasting.

  Raises:
    TypeError: If `event` is neither integer nor floating typed.
  """
  if dtype_util.is_floating(event.dtype):
    # When `validate_args=True` we've already ensured int/float casting
    # is closed.
    event = tf.cast(event, dtype=tf.int32)
  elif not dtype_util.is_integer(event.dtype):
    raise TypeError('`value` should have integer `dtype` or '
                    '`self.dtype` ({})'.format(base_dtype))

  params_rank_known = tensorshape_util.rank(params.shape) is not None
  shape_known_statically = (
      params_rank_known and
      tensorshape_util.is_fully_defined(params.shape[:-1]) and
      tensorshape_util.is_fully_defined(event.shape))

  if not shape_known_statically or params.shape[:-1] != event.shape:
    broadcast_ones = tf.ones_like(event[..., tf.newaxis], dtype=params.dtype)
    params = params * broadcast_ones
    event = event * tf.ones(tf.shape(params)[:-1], dtype=event.dtype)
    if tensorshape_util.rank(params.shape) is not None:
      tensorshape_util.set_shape(event, params.shape[:-1])

  return event, params
def _sample_n(self, n, seed=None):
  """Draws `n` samples: gamma-distributed logits fed to a multinomial draw.

  Args:
    n: Number of samples.
    seed: Optional seed, wrapped in a `SeedStream` salted with
      'dirichlet_multinomial'.

  Returns:
    Tensor reshaped to `[n] + batch_shape + [k]`, where `k` is the first
    entry of the event shape.
  """
  stream = SeedStream(seed, 'dirichlet_multinomial')

  concentration = tf.convert_to_tensor(self._concentration)
  total_count = tf.convert_to_tensor(self._total_count)
  num_trials = tf.cast(total_count, dtype=tf.int32)
  num_classes = self._event_shape_tensor(concentration)[0]

  # Broadcast the concentration against total_count's batch dimensions.
  count_ones = tf.ones_like(total_count[..., tf.newaxis])
  alpha = tf.math.multiply(count_ones, concentration, name='alpha')

  gamma_draws = tf.random.gamma(
      shape=[n], alpha=alpha, dtype=self.dtype, seed=stream())
  unnormalized_logits = tf.math.log(gamma_draws)
  counts = multinomial.draw_sample(
      1, num_classes, unnormalized_logits, num_trials, self.dtype, stream())

  batch_shape = self._batch_shape_tensor(concentration, total_count)
  final_shape = tf.concat([[n], batch_shape, [num_classes]], 0)
  return tf.reshape(counts, final_shape)
def _mode(self):
  """Mode: `loc`, broadcast against the shape of `scale`."""
  scale_shaped_ones = tf.ones_like(self.scale)
  return self.loc * scale_shaped_ones
def _entropy(self):
  """Entropy `log(2 * pi) + log(scale)`, broadcast against `loc`."""
  log_scale = tf.math.log(self.scale)
  entropy = np.log(2 * np.pi) + log_scale
  return entropy * tf.ones_like(self.loc)
def _entropy(self):
  """Entropy `0.5 + 0.5 * log(2 * pi) + log(scale)`, broadcast to `loc`."""
  half_log_two_pi = 0.5 * np.log(2. * np.pi)
  log_normalization = half_log_two_pi + tf.math.log(self.scale)
  loc_shaped_ones = tf.ones_like(self.loc)
  return (0.5 + log_normalization) * loc_shaped_ones
def _stddev(self):
  """Standard deviation `scale * pi / sqrt(6)`, broadcast against `loc`."""
  broadcast_scale = self.scale * tf.ones_like(self.loc)
  return broadcast_scale * np.pi / np.sqrt(6)
def _stddev(self):
  """Standard deviation: `scale`, broadcast against the shape of `loc`."""
  loc_shaped_ones = tf.ones_like(self.loc)
  return self.scale * loc_shaped_ones
def reduce_weighted_logsumexp(logx,
                              w=None,
                              axis=None,
                              keep_dims=False,
                              return_sign=False,
                              name=None):
  """Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`.

  When every weight in `w` is known to be positive, prefer
  `tf.reduce_logsumexp(logx + tf.log(w))`; it is more efficient than
  `du.reduce_weighted_logsumexp(logx, w)`.

  `logx` is reduced along the dimensions listed in `axis`. Each reduced
  dimension is dropped from the result unless `keep_dims` is true, in which
  case it is kept with length 1. With `axis=None` every dimension is reduced
  and a single-element tensor is returned.

  Compared to a literal `log(sum(w * exp(input)))`, the largest log-magnitude
  is subtracted before exponentiating and added back afterwards, which avoids
  overflow from `exp` of large inputs and underflow from `log` of small ones.

  For example:

  ```python
  x = tf.constant([[0., 0, 0],
                   [0, 0, 0]])

  w = tf.constant([[-1., 1, 1],
                   [1, 1, 1]])

  du.reduce_weighted_logsumexp(x, w)
  # ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4)

  du.reduce_weighted_logsumexp(x, w, axis=0)
  # ==> [log(-1+1), log(1+1), log(1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1)
  # ==> [log(-1+1+1), log(1+1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True)
  # ==> [[log(-1+1+1)], [log(1+1+1)]]

  du.reduce_weighted_logsumexp(x, w, axis=[0, 1])
  # ==> log(-1+5)
  ```

  Args:
    logx: The tensor to reduce. Should have numeric type.
    w: The weight tensor. Should have numeric type identical to `logx`.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(input_tensor),
      rank(input_tensor))`.
    keep_dims: If true, retains reduced dimensions with length 1.
    return_sign: If `True`, returns the sign of the result.
    name: A name for the operation (optional).

  Returns:
    lswe: The `log(abs(sum(weight * exp(x))))` reduced tensor.
    sign: (Optional) The sign of `sum(weight * exp(x))`.
  """
  with tf.name_scope(name or 'reduce_weighted_logsumexp'):
    logx = tf.convert_to_tensor(logx, name='logx')

    if w is None:
      # Unweighted case: plain log-sum-exp, whose sign is always +1.
      lswe = tf.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims)
      return (lswe, tf.ones_like(lswe)) if return_sign else lswe

    w = tf.convert_to_tensor(w, dtype=logx.dtype, name='w')
    log_abs_terms = logx + tf.math.log(tf.abs(w))
    shift = tf.reduce_max(log_abs_terms, axis=axis, keepdims=True)
    # A maximum of `-inf` or `inf` is not subtracted, since that would yield
    # `inf - inf = NaN`. This is fine: any value may be subtracted so long as
    # it is added back after taking the `log(sum(...))`.
    shift = tf.where(
        tf.math.is_inf(shift), tf.zeros([], shift.dtype), shift)

    signed_scaled_terms = tf.sign(w) * tf.exp(log_abs_terms - shift)
    scaled_sum = tf.reduce_sum(
        signed_scaled_terms, axis=axis, keepdims=keep_dims)
    if not keep_dims:
      shift = tf.squeeze(shift, axis)

    sgn = tf.sign(scaled_sum)
    lswe = shift + tf.math.log(sgn * scaled_sum)
    return (lswe, sgn) if return_sign else lswe
def draw_sample(num_samples, num_classes, logits, num_trials, dtype, seed):
  """Sample a multinomial.

  The batch shape is given by broadcasting num_trials with
  remove_last_dimension(logits).

  Args:
    num_samples: Python int or singleton integer Tensor: number of multinomial
      samples to draw.
    num_classes: Python int or singleton integer Tensor: number of classes.
    logits: Floating Tensor with last dimension k, of (unnormalized) logit
      probabilities per class.
    num_trials: Tensor of number of categorical trials each multinomial
      consists of. num_trials[..., tf.newaxis] must broadcast with logits.
    dtype: dtype at which to emit samples.
    seed: Random seed.

  Returns:
    samples: Tensor of given dtype and shape [n] + batch_shape + [k].
  """
  with tf.name_scope('draw_sample'):
    # Bring num_trials and logits to a common batch shape.
    num_trials = (
        tf.ones_like(logits[..., 0], dtype=num_trials.dtype) * num_trials)
    logits = (
        tf.ones_like(num_trials[..., tf.newaxis], dtype=logits.dtype) *
        logits)

    # Flatten the batch: logits -> [B1B2...Bm, num_classes],
    # per-member draw counts -> [B1B2...Bm].
    flat_logits = tf.reshape(logits, [-1, num_classes])
    flat_num_trials = num_samples * tf.reshape(num_trials, [-1])

    # Every (logits, num_trials) pair is handled separately via map_fn.
    # A single batched tf.random.categorical call cannot be used because it
    # needs one num_trials for the whole batch (num_trials is part of its
    # output shape). The multinomial sampler has no such need, since it sums
    # that dimension out. Drawing max(num_trials) categoricals for everyone
    # and masking the excess would work, but wastes memory when num_trials
    # varies widely.
    # TODO(b/123763054, b/112152209): Revisit the possibility of writing this
    # with a batch categorical followed by batch unsorted_segment_sum, once
    # both of those work and are memory-efficient enough.
    def _sample_one_batch_member(args):
      member_logits, total_draws = args[0], args[1]  # [K], []
      # [1, total_draws = num_samples * num_trials]
      draws = tf.random.categorical(
          member_logits[tf.newaxis, ...], total_draws, seed=seed)
      # [num_samples, num_trials]
      draws = tf.reshape(draws, shape=[num_samples, -1])
      # [num_samples, num_trials, num_classes]
      one_hot_draws = tf.one_hot(draws, depth=num_classes)
      # [num_samples, num_classes]
      counts = tf.reduce_sum(one_hot_draws, axis=-2)
      return tf.cast(counts, dtype=dtype)

    map_kwargs = {}
    if seed is not None:
      # Force parallel_iterations to 1 to ensure reproducibility
      # b/139210489
      map_kwargs['parallel_iterations'] = 1
    # Without a seed, map_fn's default parallel_iterations behavior applies.
    # Result: [B1B2...Bm, num_samples, num_classes]
    x = tf.map_fn(
        _sample_one_batch_member, [flat_logits, flat_num_trials],
        dtype=dtype,
        **map_kwargs)

    # Move the sample axis to the front and restore the batch dimensions.
    x = tf.transpose(a=x, perm=[1, 0, 2])
    final_shape = tf.concat(
        [[num_samples], tf.shape(num_trials), [num_classes]], axis=0)
    return tf.reshape(x, final_shape)
def _entropy(self):
  """Entropy `1 + log(scale) + euler_gamma`, with `scale` broadcast to `loc`."""
  # Multiplying by ones shaped like `loc` yields the fully broadcast scale.
  broadcast_scale = self.scale * tf.ones_like(self.loc)
  log_scale = tf.math.log(broadcast_scale)
  return 1. + log_scale + np.euler_gamma
def _mode(self):
  """Mode: a tensor of ones shaped like `power`, in the distribution dtype."""
  mode = tf.ones_like(self.power, dtype=self.dtype)
  return mode
def _sample_n(self, n, seed=None):
  """Draws `n` observation sequences from the hidden Markov model.

  Hidden states are sampled first: one draw from the initial distribution,
  then `num_steps - 1` transitions via `tf.scan`. Candidate observations are
  then drawn for every state at every step, and the ones belonging to the
  sampled hidden states are picked out with a one-hot mask.

  Args:
    n: Number of sequences to draw.
    seed: Optional seed, wrapped in a `SeedStream` salted with
      "HiddenMarkovModel". When it is not `None`, the scan runs with
      `parallel_iterations=1`.

  Returns:
    Observations laid out as `n batch_shape steps inner_shape` (see the
    shape comments below).
  """
  with tf.control_dependencies(self._runtime_assertions):
    # One stream feeds every sampling call below, so the order of `strm()`
    # calls determines which seed each call receives.
    strm = SeedStream(seed, salt="HiddenMarkovModel")

    num_states = self._num_states

    batch_shape = self.batch_shape_tensor()
    batch_size = tf.reduce_prod(batch_shape)

    # The batch sizes of the underlying initial distributions and
    # transition distributions might not match the batch size of
    # the HMM distribution.
    # As a result we need to ask for more samples from the
    # underlying distributions and then reshape the results into
    # the correct batch size for the HMM.
    init_repeat = (
        tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod(
            self._initial_distribution.batch_shape_tensor()))
    init_state = self._initial_distribution.sample(n * init_repeat,
                                                   seed=strm())
    init_state = tf.reshape(init_state, [n, batch_size])
    # init_state :: n batch_size

    # The last batch axis of the transition distribution is excluded from
    # the product; presumably it indexes the source state - TODO confirm.
    transition_repeat = (
        tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod(
            self._transition_distribution.batch_shape_tensor()[:-1]))

    def generate_step(state, _):
      """Take a single step in Markov chain."""

      # Draw a candidate next state for every possible current state.
      gen = self._transition_distribution.sample(n * transition_repeat,
                                                 seed=strm())
      # gen :: (n * transition_repeat) transition_batch

      new_states = tf.reshape(gen,
                              [n, batch_size, num_states])

      # new_states :: n batch_size num_states

      old_states_one_hot = tf.one_hot(state, num_states, dtype=tf.int32)

      # old_states :: n batch_size num_states

      # The one-hot mask keeps only the candidate that corresponds to the
      # actual current state.
      return tf.reduce_sum(old_states_one_hot * new_states, axis=-1)

    def _scan_multiple_steps():
      """Take multiple steps with tf.scan."""
      # Only the length of `dummy_index` matters: it makes the scan run
      # `num_steps - 1` times; `generate_step` ignores its values.
      dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
      if seed is not None:
        # Force parallel_iterations to 1 to ensure reproducibility
        # b/139210489
        hidden_states = tf.scan(generate_step, dummy_index,
                                initializer=init_state,
                                parallel_iterations=1)
      else:
        # Invoke default parallel_iterations behavior
        hidden_states = tf.scan(generate_step, dummy_index,
                                initializer=init_state)

      # TODO(b/115618503): add/use prepend_initializer to tf.scan
      # Prepend the initial state along the steps axis.
      return tf.concat([[init_state],
                        hidden_states], axis=0)

    # With a single step there is nothing to scan: the initial state alone,
    # given a leading steps axis, is the whole hidden sequence.
    hidden_states = prefer_static.cond(
        self._num_steps > 1,
        _scan_multiple_steps, lambda: init_state[tf.newaxis, ...])

    hidden_one_hot = tf.one_hot(hidden_states, num_states,
                                dtype=self._observation_distribution.dtype)
    # hidden_one_hot :: num_steps n batch_size num_states

    # The observation distribution batch size might not match
    # the required batch size so as with the initial and
    # transition distributions we generate more samples and
    # reshape.
    observation_repeat = (
        batch_size // tf.reduce_prod(
            self._observation_distribution.batch_shape_tensor()[:-1]))

    possible_observations = self._observation_distribution.sample(
        [self._num_steps, observation_repeat * n], seed=strm())

    inner_shape = self._observation_distribution.event_shape

    # possible_observations :: num_steps (observation_repeat * n)
    #                          observation_batch[:-1] num_states inner_shape

    possible_observations = tf.reshape(
        possible_observations,
        tf.concat([[self._num_steps, n],
                   batch_shape,
                   [num_states],
                   inner_shape], axis=0))

    # possible_observations :: steps n batch_size num_states inner_shape

    # Append one size-1 axis per event dimension so the mask broadcasts
    # against the observations.
    hidden_one_hot = tf.reshape(
        hidden_one_hot,
        tf.concat([[self._num_steps, n],
                   batch_shape,
                   [num_states],
                   tf.ones_like(inner_shape)], axis=0))

    # hidden_one_hot :: steps n batch_size num_states "inner_shape"

    # Sum over the states axis, which sits just before the event axes.
    observations = tf.reduce_sum(
        hidden_one_hot * possible_observations,
        axis=-1 - tf.size(inner_shape))

    # observations :: steps n batch_size inner_shape

    # Move the steps axis from the front to just after the batch axes.
    observations = distribution_util.move_dimension(
        observations, 0, 1 + tf.size(batch_shape))

    # returned :: n batch_shape steps inner_shape

    return observations
def _sample_n(self, num_samples, seed=None, name=None):
  """Returns a Tensor of samples from an LKJ distribution.

  Args:
    num_samples: Python `int`. The number of samples to draw.
    seed: Python integer seed for RNG
    name: Python `str` name prefixed to Ops created by this function.
      Defaults to 'sample_lkj'.

  Returns:
    samples: A Tensor of correlation matrices with shape `[n, B, D, D]`,
      where `B` is the shape of the `concentration` parameter, and `D`
      is the `dimension`. When `input_output_cholesky` is set (and
      `dimension > 1`), the Cholesky factors are returned instead.

  Raises:
    ValueError: If `dimension` is negative.
    TypeError: If `concentration` does not have a floating dtype.
  """
  if self.dimension < 0:
    raise ValueError(
        'Cannot sample negative-dimension correlation matrices.')
  # Notation below: B is the batch shape, i.e., tf.shape(concentration)
  seed = SeedStream(seed, 'sample_lkj')
  # `name or default`, not `default or name`: the latter always evaluates to
  # the default and silently ignores a caller-supplied name.
  with tf.name_scope(name or 'sample_lkj'):
    concentration = tf.convert_to_tensor(self.concentration)
    if not dtype_util.is_floating(concentration.dtype):
      raise TypeError(
          'The concentration argument should have floating type, not '
          '{}'.format(dtype_util.name(concentration.dtype)))

    concentration = _replicate(num_samples, concentration)
    concentration_shape = tf.shape(concentration)
    if self.dimension <= 1:
      # For any dimension <= 1, there is only one possible correlation matrix.
      shape = tf.concat(
          [concentration_shape, [self.dimension, self.dimension]], axis=0)
      return tf.ones(shape=shape, dtype=concentration.dtype)
    beta_conc = concentration + (self.dimension - 2.) / 2.
    beta_dist = beta.Beta(concentration1=beta_conc, concentration0=beta_conc)

    # Note that the sampler below deviates from [1], by doing the sampling in
    # cholesky space. This does not change the fundamental logic of the
    # sampler, but does speed up the sampling.

    # This is the correlation coefficient between the first two dimensions.
    # This is also `r` in reference [1].
    corr12 = 2. * beta_dist.sample(seed=seed()) - 1.

    # Below we construct the Cholesky of the initial 2x2 correlation matrix,
    # which is of the form:
    #   [[1, 0], [r, sqrt(1 - r**2)]],
    # where r is the correlation between the first two dimensions.
    # This is the top-left corner of the cholesky of the final sample.
    first_row = tf.concat([
        tf.ones_like(corr12)[..., tf.newaxis],
        tf.zeros_like(corr12)[..., tf.newaxis]
    ], axis=-1)
    second_row = tf.concat([
        corr12[..., tf.newaxis],
        tf.sqrt(1 - corr12**2)[..., tf.newaxis]
    ], axis=-1)

    chol_result = tf.concat([
        first_row[..., tf.newaxis, :],
        second_row[..., tf.newaxis, :]
    ], axis=-2)

    for n in range(2, self.dimension):
      # Loop invariant: on entry, result has shape B + [n, n]
      beta_conc = beta_conc - 0.5
      # norm is y in reference [1].
      norm = beta.Beta(
          concentration1=n / 2.,
          concentration0=beta_conc).sample(seed=seed())
      # distance shape: B + [1] for broadcast
      distance = tf.sqrt(norm)[..., tf.newaxis]
      # direction is u in reference [1].
      # direction shape: B + [n]
      direction = _uniform_unit_norm(
          n, concentration_shape, concentration.dtype, seed)
      # raw_correlation is w in reference [1].
      raw_correlation = distance * direction  # shape: B + [n]

      # This is the next row in the cholesky of the result,
      # which differs from the construction in reference [1].
      # In the reference, the new row `z` = chol_result @ raw_correlation^T
      # = C @ raw_correlation^T (where as short hand we use C = chol_result).
      # We prove that the below equation is the right row to add to the
      # cholesky, by showing equality with reference [1].
      # Let S be the sample constructed so far, and let `z` be as in
      # reference [1]. Then at this iteration, the new sample S' will be
      #   [[S, z^T],
      #    [z, 1  ]]
      # In our case we have the cholesky decomposition factor C, so
      # we want our new row x (same size as z) to satisfy:
      #   [[S, z^T],   [[C, 0],   [[C^T, x^T],   [[C C^T, C x^T       ],
      #    [z, 1  ]] =  [x, k]] @  [0,   k  ]] =  [x C^T, x x^T + k**2]]
      # Since C @ raw_correlation^T = z = C @ x^T, and C is invertible,
      # we have that x = raw_correlation. Also 1 = xx^T + k**2, so k
      # = sqrt(1 - xx^T) = sqrt(1 - |raw_correlation|**2) = sqrt(1 -
      # distance**2).
      new_row = tf.concat(
          [raw_correlation, tf.sqrt(1. - norm[..., tf.newaxis])], axis=-1)

      # Finally add this new row, by growing the cholesky of the result:
      # first a zero column on the right, then the new row at the bottom.
      chol_result = tf.concat([
          chol_result,
          tf.zeros_like(chol_result[..., 0][..., tf.newaxis])
      ], axis=-1)

      chol_result = tf.concat(
          [chol_result, new_row[..., tf.newaxis, :]], axis=-2)

    if self.input_output_cholesky:
      return chol_result

    result = tf.matmul(chol_result, chol_result, transpose_b=True)
    # The diagonal for a correlation matrix should always be ones. Due to
    # numerical instability the matmul might not achieve that, so manually set
    # these to ones.
    result = tf.linalg.set_diag(
        result, tf.ones(shape=tf.shape(result)[:-1], dtype=result.dtype))
    # This sampling algorithm can produce near-PSD matrices on which standard
    # algorithms such as `tf.cholesky` or `tf.linalg.self_adjoint_eigvals`
    # fail. Specifically, as documented in b/116828694, around 2% of trials
    # of 900,000 5x5 matrices (distributed according to 9 different
    # concentration parameter values) contained at least one matrix on which
    # the Cholesky decomposition failed.
    return result