def maybe_assert_bernoulli_param_correctness(
    is_init, validate_args, probs, logits):
  """Builds validation assertions for `Bernoulli`-type parameters.

  Args:
    is_init: Python `bool`. When `True`, the dtype of the active parameter
      (`logits` if it is given, otherwise `probs`) is checked immediately.
    validate_args: Python `bool`. When `False`, an empty list is returned.
    probs: The probabilities parameter, or `None`.
    logits: The logits parameter, or `None`.

  Returns:
    A list of assertion ops bounding `probs` to `[0, 1]`. The list is empty
    when `validate_args` is `False`, when `probs` is `None`, or when `is_init`
    equals `tensor_util.is_ref(probs)`.

  Raises:
    TypeError: if `is_init` is `True` and the active parameter does not have a
      floating dtype.
  """
  if is_init:
    if logits is None:
      param, param_name = probs, 'probs'
    else:
      param, param_name = logits, 'logits'
    if not dtype_util.is_floating(param.dtype):
      raise TypeError(
          'Argument `{}` must having floating type.'.format(param_name))

  if not validate_args or probs is None:
    return []

  # Checks are only emitted when `is_init` and `is_ref(probs)` disagree
  # (presumably: plain values are validated at construction, ref-like values
  # on later calls -- confirm against `tensor_util.is_ref`).
  if is_init == tensor_util.is_ref(probs):
    return []

  probs = tf.convert_to_tensor(probs)
  upper_bound = tf.constant(1., probs.dtype)
  return [
      assert_util.assert_non_negative(
          probs, message='probs has components less than 0.'),
      assert_util.assert_less_equal(
          probs, upper_bound,
          message='probs has components greater than 1.'),
  ]
def __init__(self,
             rate=None,
             log_rate=None,
             interpolate_nondiscrete=True,
             validate_args=False,
             allow_nan_stats=True,
             name='Poisson'):
  """Initialize a batch of Poisson distributions.

  Args:
    rate: Floating point tensor, the rate parameter. `rate` must be positive.
      Exactly one of `rate` and `log_rate` must be specified.
    log_rate: Floating point tensor, the log of the rate parameter.
      Exactly one of `rate` and `log_rate` must be specified.
    interpolate_nondiscrete: Python `bool`. When `False`, `log_prob` returns
      `-inf` (and `prob` returns `0`) for non-integer inputs. When `True`,
      `log_prob` evaluates the continuous function
      `k * log_rate - lgamma(k+1) - rate`, which matches the Poisson pmf at
      integer arguments `k` (note that this function is not itself a
      normalized probability log-density).
      Default value: `True`.
    validate_args: Python `bool`. When `True` distribution parameters are
      checked for validity despite possibly degrading runtime performance.
      When `False` invalid inputs may silently render incorrect outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`. When `True`, statistics (e.g., mean, mode,
      variance) use the value "`NaN`" to indicate the result is undefined.
      When `False`, an exception is raised if one or more of the statistic's
      batch members are undefined.
      Default value: `True`.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if none or both of `rate`, `log_rate` are specified.
    TypeError: if `rate` is not a float-type.
    TypeError: if `log_rate` is not a float-type.
  """
  # Must remain the first statement: it snapshots exactly the constructor
  # arguments (plus `self`) before any other local name is bound.
  parameters = dict(locals())

  has_rate = rate is not None
  has_log_rate = log_rate is not None
  if has_rate == has_log_rate:
    raise ValueError('Must specify exactly one of `rate` and `log_rate`.')

  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([rate, log_rate], dtype_hint=tf.float32)
    if not dtype_util.is_floating(dtype):
      dtype_name = dtype_util.name(dtype)
      raise TypeError(
          '[log_]rate.dtype ({}) is a not a float-type.'.format(dtype_name))

    # Whichever of the two parameters was omitted is passed through as `None`.
    self._rate = tensor_util.convert_nonref_to_tensor(
        rate, name='rate', dtype=dtype)
    self._log_rate = tensor_util.convert_nonref_to_tensor(
        log_rate, name='log_rate', dtype=dtype)
    self._interpolate_nondiscrete = interpolate_nondiscrete

    base_kwargs = dict(
        dtype=dtype,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
    super(Poisson, self).__init__(**base_kwargs)
def maybe_assert_categorical_param_correctness(
    is_init, validate_args, probs, logits):
  """Builds validation assertions for `Categorical`-type parameters.

  Args:
    is_init: Python `bool`. When `True`, dtype and rank of the active
      parameter (`logits` if it is given, otherwise `probs`) are checked.
    validate_args: Python `bool`. When `False`, no assertion ops are built
      and an empty list is returned.
    probs: The probabilities parameter, or `None`.
    logits: The logits parameter, or `None`.

  Returns:
    A (possibly empty) list of assertion ops.

  Raises:
    TypeError: if `is_init` and the active parameter is not floating.
    ValueError: if `is_init` and the active parameter has static rank `< 1`.
  """
  checks = []

  # At init time, shape and dtype checks can always be built because the
  # shape is assumed not to change for Variable-backed arguments.
  if is_init:
    if logits is None:
      param, label = probs, 'probs'
    else:
      param, label = logits, 'logits'
    if not dtype_util.is_floating(param.dtype):
      raise TypeError(
          'Argument `{}` must having floating type.'.format(label))

    rank_msg = 'Argument `{}` must have rank at least 1.'.format(label)
    static_rank = tensorshape_util.rank(param.shape)
    if static_rank is None:
      # Rank is only known at run time; defer to an assertion op if asked to.
      if validate_args:
        param = tf.convert_to_tensor(param)
        # Retain the tensor conversion for the checks further below.
        probs = param if logits is None else None
        logits = param if probs is None else None
        checks.append(
            assert_util.assert_rank_at_least(param, 1, message=rank_msg))
    elif static_rank < 1:
      raise ValueError(rank_msg)

  if not validate_args:
    assert not checks  # Should never happen.
    return []

  if logits is not None and is_init != tensor_util.is_ref(logits):
    logits = tf.convert_to_tensor(logits)
    checks.extend(distribution_util.assert_categorical_event_shape(logits))

  if probs is not None and is_init != tensor_util.is_ref(probs):
    probs = tf.convert_to_tensor(probs)
    checks.append(assert_util.assert_non_negative(probs))
    checks.append(
        assert_util.assert_near(
            tf.reduce_sum(probs, axis=-1),
            np.array(1, dtype=dtype_util.as_numpy_dtype(probs.dtype)),
            message='Argument `probs` must sum to 1.'))
    checks.extend(distribution_util.assert_categorical_event_shape(probs))

  return checks
def _prob(self, event):
  """Returns the fraction of stored samples that equal `event`.

  A sample counts as a match only if all of its trailing `self._event_ndims`
  axes equal `event`; matches are summed over the last remaining axis and
  divided by the number of samples.

  Args:
    event: Value to evaluate; converted to a tensor of dtype `self.dtype`.

  Returns:
    The match frequency, cast to `self.dtype` when that dtype is floating.
  """
  samples = tf.convert_to_tensor(self._samples)
  num_samples = self._compute_num_samples(samples)
  event = tf.convert_to_tensor(event, name='event', dtype=self.dtype)
  event, samples = _broadcast_event_and_samples(
      event, samples, event_ndims=self._event_ndims)

  elementwise_equal = tf.equal(samples, event)
  event_axes = tf.range(-self._event_ndims, 0)
  is_match = tf.reduce_all(elementwise_equal, axis=event_axes)
  match_count = tf.reduce_sum(tf.cast(is_match, dtype=tf.int32), axis=-1)
  prob = match_count / num_samples

  if not dtype_util.is_floating(self.dtype):
    return prob
  return tf.cast(prob, self.dtype)
def _maybe_validate_matrix(a, validate_args):
  """Checks that input `a` is a floating-point matrix (rank >= 2).

  Args:
    a: Object exposing `dtype` and `shape`.
    validate_args: Python `bool`. When the rank of `a` is not statically
      known, a runtime rank assertion is returned only if this is `True`.

  Returns:
    A list containing at most one assertion op.

  Raises:
    TypeError: if `a.dtype` is not floating.
    ValueError: if the static rank of `a` is known and less than 2.
  """
  if not dtype_util.is_floating(a.dtype):
    raise TypeError(
        'Input `a` must have `float`-like `dtype` (saw {}).'.format(
            dtype_util.name(a.dtype)))

  static_rank = tensorshape_util.rank(a.shape)
  if static_rank is None:
    # Rank unknown until run time: fall back to an assertion op if requested.
    if not validate_args:
      return []
    return [
        assert_util.assert_rank_at_least(
            a, rank=2,
            message='Input `a` must have at least 2 dimensions.')
    ]

  if static_rank < 2:
    raise ValueError(
        'Input `a` must have at least 2 dimensions (saw: {}).'.format(
            static_rank))
  return []
def _parameter_control_dependencies(self, is_init): """Checks the validity of the concentration parameter.""" assertions = [] # In init, we can always build shape and dtype checks because # we assume shape doesn't change for Variable backed args. if is_init: if not dtype_util.is_floating(self.concentration.dtype): raise TypeError('Argument `concentration` must be float type.') msg = 'Argument `concentration` must have rank at least 1.' ndims = tensorshape_util.rank(self.concentration.shape) if ndims is not None: if ndims < 1: raise ValueError(msg) elif self.validate_args: assertions.append(assert_util.assert_rank_at_least( self.concentration, 1, message=msg)) msg = 'Argument `concentration` must have `event_size` at least 2.' event_size = tf.compat.dimension_value(self.concentration.shape[-1]) if event_size is not None: if event_size < 2: raise ValueError(msg) elif self.validate_args: assertions.append(assert_util.assert_less( 1, tf.shape(self.concentration)[-1], message=msg)) if not self.validate_args: assert not assertions # Should never happen. return [] if is_init != tensor_util.is_ref(self.concentration): assertions.append(assert_util.assert_positive( self.concentration, message='Argument `concentration` must be positive.')) return assertions
def maybe_assert_negative_binomial_param_correctness(
    is_init, validate_args, total_count, probs, logits):
  """Builds validation assertions for `NegativeBinomial`-type parameters.

  Args:
    is_init: Python `bool`. When `True`, the dtype of the active parameter
      (`logits` if it is given, otherwise `probs`) is checked immediately.
    validate_args: Python `bool`. When `False`, an empty list is returned.
    total_count: The total-count parameter.
    probs: The probabilities parameter, or `None`.
    logits: The logits parameter, or `None`.

  Returns:
    A (possibly empty) list of assertion ops on `total_count` and `probs`.

  Raises:
    TypeError: if `is_init` and the active parameter is not floating.
  """
  if is_init:
    if logits is None:
      param, param_name = probs, 'probs'
    else:
      param, param_name = logits, 'logits'
    if not dtype_util.is_floating(param.dtype):
      raise TypeError(
          'Argument `{}` must having floating type.'.format(param_name))

  if not validate_args:
    return []

  checks = []

  # Each parameter is only checked when `is_init` and `is_ref(<param>)`
  # disagree.
  if is_init != tensor_util.is_ref(total_count):
    total_count = tf.convert_to_tensor(total_count)
    checks.append(
        assert_util.assert_non_negative(
            total_count,
            message='`total_count` has components less than 0.'))
    checks.append(
        distribution_util.assert_integer_form(
            total_count,
            message='`total_count` has fractional components.'))

  if probs is not None and is_init != tensor_util.is_ref(probs):
    probs = tf.convert_to_tensor(probs)
    upper_bound = tf.constant(1., probs.dtype)
    checks.append(
        assert_util.assert_non_negative(
            probs, message='`probs` has components less than 0.'))
    checks.append(
        assert_util.assert_less_equal(
            probs, upper_bound,
            message='`probs` has components greater than 1.'))

  return checks
def _broadcast_cat_event_and_params(event, params, base_dtype):
  """Broadcasts the event and the distribution parameters against each other.

  Args:
    event: Tensor of integer or floating dtype; floating events are cast to
      `tf.int32`.
    params: Parameter tensor; all but its last axis are matched to `event`.
    base_dtype: Only used to format the `TypeError` message.

  Returns:
    The tuple `(event, params)` after any needed cast and broadcasting.

  Raises:
    TypeError: if `event` has neither integer nor floating dtype.
  """
  if not dtype_util.is_integer(event.dtype):
    if not dtype_util.is_floating(event.dtype):
      raise TypeError('`value` should have integer `dtype` or '
                      '`self.dtype` ({})'.format(base_dtype))
    # When `validate_args=True` we've already ensured int/float casting
    # is closed.
    event = tf.cast(event, dtype=tf.int32)

  static_shapes_known = (
      tensorshape_util.rank(params.shape) is not None and
      tensorshape_util.is_fully_defined(params.shape[:-1]) and
      tensorshape_util.is_fully_defined(event.shape))

  # Skip the broadcast only when both shapes are fully known and already agree.
  if not static_shapes_known or params.shape[:-1] != event.shape:
    broadcaster = tf.ones_like(event[..., tf.newaxis], dtype=params.dtype)
    params = params * broadcaster
    batch_shape = tf.shape(params)[:-1]
    event = event * tf.ones(batch_shape, dtype=event.dtype)
    if tensorshape_util.rank(params.shape) is not None:
      # Propagate whatever static shape information `params` carries.
      tensorshape_util.set_shape(event, params.shape[:-1])

  return event, params
def _parameter_control_dependencies(self, is_init): assertions = [] logits = self._logits probs = self._probs param, name = (probs, 'probs') if logits is None else (logits, 'logits') # In init, we can always build shape and dtype checks because # we assume shape doesn't change for Variable backed args. if is_init: if not dtype_util.is_floating(param.dtype): raise TypeError( 'Argument `{}` must having floating type.'.format(name)) msg = 'Argument `{}` must have rank at least 1.'.format(name) shape_static = tensorshape_util.dims(param.shape) if shape_static is not None: if len(shape_static) < 1: raise ValueError(msg) elif self.validate_args: param = tf.convert_to_tensor(param) assertions.append( assert_util.assert_rank_at_least(param, 1, message=msg)) msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name) msg2 = 'Argument `{}` must have final dimension <= {}.'.format( name, tf.int32.max) event_size = shape_static[-1] if shape_static is not None else None if event_size is not None: if event_size < 1: raise ValueError(msg1) if event_size > tf.int32.max: raise ValueError(msg2) elif self.validate_args: param = tf.convert_to_tensor(param) assertions.append( assert_util.assert_greater_equal(tf.shape(param)[-1:], 1, message=msg1)) # NOTE: For now, we leave out a runtime assertion that # `tf.shape(param)[-1] <= tf.int32.max`. An earlier `tf.shape` call # will fail before we get to this point. if not self.validate_args: assert not assertions # Should never happen. 
return [] if is_init != tensor_util.is_ref(self.temperature): assertions.append(assert_util.assert_positive(self.temperature)) if probs is not None: probs = param # reuse tensor conversion from above if is_init != tensor_util.is_ref(probs): probs = tf.convert_to_tensor(probs) one = tf.ones([], dtype=probs.dtype) assertions.extend([ assert_util.assert_non_negative(probs), assert_util.assert_less_equal(probs, one), assert_util.assert_near( tf.reduce_sum(probs, axis=-1), one, message='Argument `probs` must sum to 1.'), ]) return assertions
def __init__(self,
             loc=None,
             scale=None,
             validate_args=False,
             allow_nan_stats=True,
             name='MultivariateNormalLinearOperator'):
  """Construct Multivariate Normal distribution on `R^k`.

  The `batch_shape` is the broadcast shape between `loc` and `scale`
  arguments.

  The `event_shape` is given by last dimension of the matrix implied by
  `scale`. The last dimension of `loc` (if provided) must broadcast with this.

  Recall that `covariance = scale @ scale.T`.

  Additional leading dimensions (if any) will index batches.

  Args:
    loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
      implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
      `b >= 0` and `k` is the event size.
    scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
      `[B1, ..., Bb, k, k]`.
    validate_args: Python `bool`, default `False`. Whether to validate input
      with asserts. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
    allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading to
      undefined statistics will return NaN for this statistic.
    name: The name to give Ops created by the initializer.

  Raises:
    ValueError: if `scale` is unspecified.
    TypeError: if not `scale.dtype.is_floating`
  """
  # Must remain the first statement: it snapshots exactly the constructor
  # arguments (plus `self`) before any other local name is bound.
  parameters = dict(locals())

  if scale is None:
    raise ValueError('Missing required `scale` parameter.')
  if not dtype_util.is_floating(scale.dtype):
    raise TypeError('`scale` parameter must have floating-point dtype.')

  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale], dtype_hint=tf.float32)
    # NOTE(review): the previous comment here referred to `expand_dims` not
    # preserving constant-ness, but no such call is made in this block.
    loc = tensor_util.convert_nonref_to_tensor(loc, dtype=dtype, name='loc')
    batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
        loc, scale)

    # Base distribution: a scalar standard Normal, pushed through the affine
    # bijector built from `loc` and `scale`.
    standard_normal = normal.Normal(
        loc=tf.zeros([], dtype=dtype), scale=tf.ones([], dtype=dtype))
    affine = affine_linear_operator_bijector.AffineLinearOperator(
        shift=loc, scale=scale, validate_args=validate_args)

    super(MultivariateNormalLinearOperator, self).__init__(
        distribution=standard_normal,
        bijector=affine,
        batch_shape=batch_shape,
        event_shape=event_shape,
        validate_args=validate_args,
        name=name)
    self._parameters = parameters
def __init__(self,
             df,
             scale_operator,
             input_output_cholesky=False,
             validate_args=False,
             allow_nan_stats=True,
             name=None):
  """Construct Wishart distributions.

  Args:
    df: `float` or `double` tensor, the degrees of freedom of the
      distribution(s). `df` must be greater than or equal to `k`.
    scale_operator: `float` or `double` instance of `LinearOperator`.
    input_output_cholesky: Python `bool`. If `True`, functions whose input or
      output have the semantics of samples assume inputs are in Cholesky form
      and return outputs in Cholesky form. In particular, if this flag is
      `True`, input to `log_prob` is presumed of Cholesky form and output from
      `sample`, `mean`, and `mode` are of Cholesky form.  Setting this
      argument to `True` is purely a computational optimization and does not
      change the underlying distribution; for instance, `mean` returns the
      Cholesky of the mean, not the mean of Cholesky factors. The `variance`
      and `stddev` methods are unaffected by this flag.
      Default value: `False` (i.e., input/output does not have Cholesky
      semantics).
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    TypeError: if scale is not floating-type
    TypeError: if scale.dtype != df.dtype
    ValueError: if `scale_operator` is not square.
    ValueError: if df < k, where scale operator event shape is
      `(k, k)`
  """
  # Must remain the first statement: it snapshots exactly the constructor
  # arguments (plus `self`) before any other local name is bound.
  parameters = dict(locals())
  self._input_output_cholesky = input_output_cholesky
  with tf.name_scope(name) as name:
    with tf.name_scope("init"):
      if not dtype_util.is_floating(scale_operator.dtype):
        raise TypeError(
            "scale_operator.dtype=%s is not a floating-point type" %
            scale_operator.dtype)
      if not scale_operator.is_square:
        # A stray debugging `print(scale_operator.to_dense().eval())` used to
        # run here; `.eval()` requires a default session and could fail
        # before the intended error was raised.
        raise ValueError("scale_operator must be square.")

      self._scale_operator = scale_operator
      self._df = tf.convert_to_tensor(
          df, dtype=scale_operator.dtype, name="df")
      dtype_util.assert_same_float_dtype([self._df, self._scale_operator])

      # Prefer the statically known dimension; otherwise fall back to the
      # operator's dynamic `domain_dimension_tensor()`.
      if tf.compat.dimension_value(self._scale_operator.shape[-1]) is None:
        self._dimension = tf.cast(
            self._scale_operator.domain_dimension_tensor(),
            dtype=self._scale_operator.dtype,
            name="dimension")
      else:
        self._dimension = tf.convert_to_tensor(
            tf.compat.dimension_value(self._scale_operator.shape[-1]),
            dtype=self._scale_operator.dtype,
            name="dimension")

      df_val = tf.get_static_value(self._df)
      dim_val = tf.get_static_value(self._dimension)
      if df_val is not None and dim_val is not None:
        # Both values are known statically: validate eagerly.
        df_val = np.asarray(df_val)
        if not df_val.shape:
          df_val = [df_val]
        if np.any(df_val < dim_val):
          raise ValueError(
              "Degrees of freedom (df = %s) cannot be less than "
              "dimension of scale matrix (scale.dimension = %s)" %
              (df_val, dim_val))
      elif validate_args:
        # Format arguments follow the placeholder order (df, dimension); they
        # were previously passed swapped.
        assertions = assert_util.assert_less_equal(
            self._dimension,
            self._df,
            message=("Degrees of freedom (df = %s) cannot be "
                     "less than dimension of scale matrix "
                     "(scale.dimension = %s)" %
                     (self._df, self._dimension)))
        self._df = distribution_util.with_dependencies(
            [assertions], self._df)
  super(_WishartLinearOperator, self).__init__(
      dtype=self._scale_operator.dtype,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
      parameters=parameters,
      name=name)
def __init__(self,
             outcomes,
             logits=None,
             probs=None,
             rtol=None,
             atol=None,
             validate_args=False,
             allow_nan_stats=True,
             name='FiniteDiscrete'):
  """Construct a finite discrete distribution.

  Args:
    outcomes: A 1-D floating or integer `Tensor`, representing a list of
      possible outcomes in strictly ascending order.
    logits: A floating N-D `Tensor`, `N >= 1`, representing the log
      probabilities of a set of FiniteDiscrete distributions. The first
      `N - 1` dimensions index into a batch of independent distributions and
      the last dimension represents a vector of logits for each discrete
      value. Only one of `logits` or `probs` should be passed in.
    probs: A floating N-D `Tensor`, `N >= 1`, representing the probabilities
      of a set of FiniteDiscrete distributions. The first `N - 1` dimensions
      index into a batch of independent distributions and the last dimension
      represents a vector of probabilities for each discrete value. Only one
      of `logits` or `probs` should be passed in.
    rtol: `Tensor` with same `dtype` as `outcomes`. The relative tolerance
      for floating number comparison. Only effective when `outcomes` is a
      floating `Tensor`. Default is `10 * eps`.
    atol: `Tensor` with same `dtype` as `outcomes`. The absolute tolerance
      for floating number comparison. Only effective when `outcomes` is a
      floating `Tensor`. Default is `10 * eps`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may render incorrect outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value '`NaN`' to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  # Must remain the first statement: it snapshots exactly the constructor
  # arguments (plus `self`) before any other local name is bound.
  parameters = dict(locals())

  with tf.name_scope(name) as name:
    outcomes_dtype = dtype_util.common_dtype(
        [outcomes], dtype_hint=tf.float32)
    self._outcomes = tensor_util.convert_nonref_to_tensor(
        outcomes, dtype_hint=outcomes_dtype, name='outcomes')

    if not dtype_util.is_floating(self._outcomes.dtype):
      # Tolerances are only stored for floating outcomes.
      self._rtol = None
      self._atol = None
    else:
      eps = np.finfo(dtype_util.as_numpy_dtype(outcomes_dtype)).eps
      default_tol = 10 * eps
      self._rtol = default_tol if rtol is None else rtol
      self._atol = default_tol if atol is None else atol

    # Probabilities over outcome indices are held by an int32 `Categorical`.
    self._categorical = categorical.Categorical(
        logits=logits,
        probs=probs,
        dtype=tf.int32,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats)

    base_kwargs = dict(
        dtype=self._outcomes.dtype,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
    super(FiniteDiscrete, self).__init__(**base_kwargs)
def __init__(self,
             power,
             dtype=tf.int32,
             interpolate_nondiscrete=True,
             sample_maximum_iterations=100,
             validate_args=False,
             allow_nan_stats=False,
             name='Zipf'):
  """Initialize a batch of Zipf distributions.

  Args:
    power: `Float` like `Tensor` representing the power parameter. Must be
      strictly greater than `1`.
    dtype: The `dtype` of `Tensor` returned by `sample`.
      Default value: `tf.int32`.
    interpolate_nondiscrete: Python `bool`. When `False`, `log_prob` returns
      `-inf` (and `prob` returns `0`) for non-integer inputs. When `True`,
      `log_prob` evaluates the continuous function
      `-power log(k) - log(zeta(power))`, which matches the Zipf pmf at
      integer arguments `k` (note that this function is not itself a
      normalized probability log-density).
      Default value: `True`.
    sample_maximum_iterations: Maximum number of iterations allowed in
      `sample`. When `validate_args=True`, samples which fail to reach
      convergence (subject to this cap) are masked out with `self.dtype.min`
      or `nan` depending on `self.dtype.is_integer`.
      Default value: `100`.
    validate_args: Python `bool`. When `True` distribution parameters are
      checked for validity despite possibly degrading runtime performance.
      When `False` invalid inputs may silently render incorrect outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`. When `True`, statistics (e.g., mean, mode,
      variance) use the value "`NaN`" to indicate the result is undefined.
      When `False`, an exception is raised if one or more of the statistic's
      batch members are undefined.
      Default value: `False`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: `'Zipf'`.

  Raises:
    TypeError: if `power` is not `float` like.
  """
  # Must remain the first statement: it snapshots exactly the constructor
  # arguments (plus `self`) before any other local name is bound.
  parameters = dict(locals())

  with tf.name_scope(name) as name:
    power_dtype = dtype_util.common_dtype([power], dtype_hint=tf.float32)
    self._power = tensor_util.convert_nonref_to_tensor(
        power, name='power', dtype=power_dtype)

    # float16 is floating but explicitly rejected here.
    dtype_supported = (
        dtype_util.is_floating(self._power.dtype) and
        not dtype_util.base_equal(self._power.dtype, tf.float16))
    if not dtype_supported:
      raise TypeError(
          'power.dtype ({}) is not a supported `float` type.'.format(
              dtype_util.name(self._power.dtype)))

    self._interpolate_nondiscrete = interpolate_nondiscrete
    self._sample_maximum_iterations = sample_maximum_iterations

    base_kwargs = dict(
        dtype=dtype,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
    super(Zipf, self).__init__(**base_kwargs)
def _sample_n(self, num_samples, seed=None, name=None):
  """Returns a Tensor of samples from an LKJ distribution.

  Args:
    num_samples: Python `int`. The number of samples to draw.
    seed: Python integer seed for RNG
    name: Python `str` name prefixed to Ops created by this function.
      Defaults to `'sample_lkj'` when `None` or empty.

  Returns:
    samples: A Tensor of correlation matrices with shape `[n, B, D, D]`,
      where `B` is the shape of the `concentration` parameter, and `D`
      is the `dimension`. When `self.input_output_cholesky` is set, the
      Cholesky factors are returned instead.

  Raises:
    ValueError: If `dimension` is negative.
    TypeError: If `concentration` does not have a floating dtype.
  """
  if self.dimension < 0:
    raise ValueError(
        'Cannot sample negative-dimension correlation matrices.')
  # Notation below: B is the batch shape, i.e., tf.shape(concentration)
  seed = SeedStream(seed, 'sample_lkj')
  # `name` takes precedence; the previous `'sample_lkj' or name` always
  # evaluated to the literal and silently ignored the caller's `name`.
  with tf.name_scope(name or 'sample_lkj'):
    concentration = tf.convert_to_tensor(self.concentration)
    if not dtype_util.is_floating(concentration.dtype):
      raise TypeError(
          'The concentration argument should have floating type, not '
          '{}'.format(dtype_util.name(concentration.dtype)))

    concentration = _replicate(num_samples, concentration)
    concentration_shape = tf.shape(concentration)
    if self.dimension <= 1:
      # For any dimension <= 1, there is only one possible correlation matrix.
      shape = tf.concat(
          [concentration_shape, [self.dimension, self.dimension]], axis=0)
      return tf.ones(shape=shape, dtype=concentration.dtype)
    beta_conc = concentration + (self.dimension - 2.) / 2.
    beta_dist = beta.Beta(concentration1=beta_conc, concentration0=beta_conc)

    # Note that the sampler below deviates from [1], by doing the sampling in
    # cholesky space. This does not change the fundamental logic of the
    # sampler, but does speed up the sampling.

    # This is the correlation coefficient between the first two dimensions.
    # This is also `r` in reference [1].
    corr12 = 2. * beta_dist.sample(seed=seed()) - 1.

    # Below we construct the Cholesky of the initial 2x2 correlation matrix,
    # which is of the form:
    # [[1, 0], [r, sqrt(1 - r**2)]], where r is the correlation between the
    # first two dimensions.
    # This is the top-left corner of the cholesky of the final sample.
    first_row = tf.concat([
        tf.ones_like(corr12)[..., tf.newaxis],
        tf.zeros_like(corr12)[..., tf.newaxis]
    ], axis=-1)
    second_row = tf.concat([
        corr12[..., tf.newaxis],
        tf.sqrt(1 - corr12**2)[..., tf.newaxis]
    ], axis=-1)

    chol_result = tf.concat([
        first_row[..., tf.newaxis, :],
        second_row[..., tf.newaxis, :]
    ], axis=-2)

    for n in range(2, self.dimension):
      # Loop invariant: on entry, result has shape B + [n, n]
      beta_conc = beta_conc - 0.5
      # norm is y in reference [1].
      norm = beta.Beta(
          concentration1=n / 2.,
          concentration0=beta_conc).sample(seed=seed())
      # distance shape: B + [1] for broadcast
      distance = tf.sqrt(norm)[..., tf.newaxis]
      # direction is u in reference [1].
      # direction shape: B + [n]
      direction = _uniform_unit_norm(
          n, concentration_shape, concentration.dtype, seed)
      # raw_correlation is w in reference [1].
      raw_correlation = distance * direction  # shape: B + [n]

      # This is the next row in the cholesky of the result,
      # which differs from the construction in reference [1].
      # In the reference, the new row `z` = chol_result @ raw_correlation^T
      # = C @ raw_correlation^T (where as short hand we use C = chol_result).
      # We prove that the below equation is the right row to add to the
      # cholesky, by showing equality with reference [1].
      # Let S be the sample constructed so far, and let `z` be as in
      # reference [1]. Then at this iteration, the new sample S' will be
      # [[S z^T]
      #  [z 1]]
      # In our case we have the cholesky decomposition factor C, so
      # we want our new row x (same size as z) to satisfy:
      #  [[S z^T]    [[C 0]    [[C^T  x^T]     [[CC^T  Cx^T]
      #   [z 1]]  =   [x k]]    [0     k]]  =   [xC^t   xx^T + k**2]]
      # Since C @ raw_correlation^T = z = C @ x^T, and C is invertible,
      # we have that x = raw_correlation. Also 1 = xx^T + k**2, so k
      # = sqrt(1 - xx^T) = sqrt(1 - |raw_correlation|**2) = sqrt(1 -
      # distance**2).
      new_row = tf.concat(
          [raw_correlation, tf.sqrt(1. - norm[..., tf.newaxis])], axis=-1)

      # Finally add this new row, by growing the cholesky of the result.
      chol_result = tf.concat([
          chol_result,
          tf.zeros_like(chol_result[..., 0][..., tf.newaxis])
      ], axis=-1)

      chol_result = tf.concat(
          [chol_result, new_row[..., tf.newaxis, :]], axis=-2)

    if self.input_output_cholesky:
      return chol_result

    result = tf.matmul(chol_result, chol_result, transpose_b=True)
    # The diagonal for a correlation matrix should always be ones. Due to
    # numerical instability the matmul might not achieve that, so manually set
    # these to ones.
    result = tf.linalg.set_diag(
        result, tf.ones(shape=tf.shape(result)[:-1], dtype=result.dtype))
    # This sampling algorithm can produce near-PSD matrices on which standard
    # algorithms such as `tf.cholesky` or `tf.linalg.self_adjoint_eigvals`
    # fail. Specifically, as documented in b/116828694, around 2% of trials
    # of 900,000 5x5 matrices (distributed according to 9 different
    # concentration parameter values) contained at least one matrix on which
    # the Cholesky decomposition failed.
    return result