def _default_event_space_bijector(self):
  return chain_bijector.Chain([
      shift_bijector.Shift(shift=self.loc, validate_args=self.validate_args),
      exp_bijector.Exp(validate_args=self.validate_args)
  ], validate_args=self.validate_args)
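
# Usage sketch (hedged): the chain above maps the real line onto `(loc, inf)`.
# The public aliases `tfp.bijectors.Chain/Shift/Exp` are assumed; `loc=3.` is
# an illustrative value, not a default of the class above.
import tensorflow_probability as tfp
tfb = tfp.bijectors

b = tfb.Chain([tfb.Shift(3.), tfb.Exp()])
b.forward(0.)   # ==> 4.0, since 3. + exp(0.) = 4.
b.inverse(4.)   # ==> 0.0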
def _transformed_beta(self, low=None, peak=None, high=None, temperature=None):
  low = tf.convert_to_tensor(self.low) if low is None else low
  peak = tf.convert_to_tensor(self.peak) if peak is None else peak
  high = tf.convert_to_tensor(self.high) if high is None else high
  temperature = (tf.convert_to_tensor(self.temperature)
                 if temperature is None else temperature)
  scale = high - low
  concentration1 = (1. + temperature * (peak - low) / scale)
  concentration0 = (1. + temperature * (high - peak) / scale)
  return transformed_distribution.TransformedDistribution(
      distribution=beta.Beta(
          concentration1=concentration1,
          concentration0=concentration0,
          allow_nan_stats=self.allow_nan_stats),
      bijector=chain_bijector.Chain([
          shift_bijector.Shift(shift=low),
          # Broadcasting scale on affine bijector to match batch dimension.
          # This prevents dimension mismatch for operations like cdf.
          # Note that `concentration1` incorporates the broadcast of all four
          # parameters.
          scale_bijector.Scale(
              scale=tf.broadcast_to(scale, ps.shape(concentration1)))
      ]))
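
# Usage sketch (hedged): the same shifted-and-scaled Beta can be built
# directly from the public API. `low=1., peak=7., high=11., temperature=4.`
# are illustrative values, not defaults of the class above.
import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

low, peak, high, temperature = 1., 7., 11., 4.
scale = high - low
pert = tfd.TransformedDistribution(
    distribution=tfd.Beta(
        concentration1=1. + temperature * (peak - low) / scale,
        concentration0=1. + temperature * (high - peak) / scale),
    bijector=tfb.Chain([tfb.Shift(low), tfb.Scale(scale)]))
pert.prob(7.)  # Density of a PERT-style distribution peaked at 7 on (1, 11).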
def __init__(self,
             loc,
             scale,
             concentration,
             validate_args=False,
             name='generalized_pareto'):
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale, concentration],
                                    dtype_hint=tf.float32)
    self._loc = tensor_util.convert_nonref_to_tensor(loc)
    self._scale = tensor_util.convert_nonref_to_tensor(scale)
    self._concentration = tensor_util.convert_nonref_to_tensor(concentration)
    self._non_negative_concentration_bijector = chain_bijector.Chain([
        shift_bijector.Shift(shift=self._loc, validate_args=validate_args),
        softplus_bijector.Softplus(validate_args=validate_args)
    ], validate_args=validate_args)
    super(GeneralizedPareto, self).__init__(
        validate_args=validate_args,
        forward_min_event_ndims=0,
        dtype=dtype,
        parameters=parameters,
        name=name)
def _make_mixture_dist(self, component_logits, locs, scales):
  """Builds a mixture of quantized logistic distributions.

  Args:
    component_logits: 4D `Tensor` of logits for the Categorical distribution
      over Quantized Logistic mixture components. Dimensions are
      `[batch_size, height, width, num_logistic_mix]`.
    locs: 5D `Tensor` of location parameters for the Quantized Logistic
      mixture components. Dimensions are
      `[batch_size, height, width, num_logistic_mix, num_channels]`.
    scales: 5D `Tensor` of scale parameters for the Quantized Logistic
      mixture components. Dimensions are
      `[batch_size, height, width, num_logistic_mix, num_channels]`.

  Returns:
    dist: A quantized logistic mixture `tfp.distribution` over the input
      data.
  """
  mixture_distribution = categorical.Categorical(logits=component_logits)

  # Convert distribution parameters for pixel values in
  # `[self._low, self._high]` for use with `QuantizedDistribution`.
  locs = self._low + 0.5 * (self._high - self._low) * (locs + 1.)
  scales *= 0.5 * (self._high - self._low)
  logistic_dist = quantized_distribution.QuantizedDistribution(
      distribution=transformed_distribution.TransformedDistribution(
          distribution=logistic.Logistic(loc=locs, scale=scales),
          bijector=shift.Shift(shift=tf.cast(-0.5, self.dtype))),
      low=self._low,
      high=self._high)

  dist = mixture_same_family.MixtureSameFamily(
      mixture_distribution=mixture_distribution,
      components_distribution=independent.Independent(
          logistic_dist, reinterpreted_batch_ndims=1))
  return independent.Independent(dist, reinterpreted_batch_ndims=2)
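
# Hedged sketch of the same quantized-logistic mixture built from the public
# API. The shapes and `low=0., high=255.` are illustrative assumptions, not
# values taken from the class above.
import tensorflow as tf
import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

low, high = 0., 255.
component_logits = tf.zeros([2, 4, 4, 5])  # batch=2, 4x4 image, 5 components
locs = tf.zeros([2, 4, 4, 5, 3])           # 3 channels, parameterized in [-1, 1]
scales = tf.ones([2, 4, 4, 5, 3])

locs = low + 0.5 * (high - low) * (locs + 1.)
scales = scales * 0.5 * (high - low)
quantized = tfd.QuantizedDistribution(
    distribution=tfd.TransformedDistribution(
        distribution=tfd.Logistic(loc=locs, scale=scales),
        bijector=tfb.Shift(-0.5)),
    low=low, high=high)
image_dist = tfd.Independent(
    tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(logits=component_logits),
        components_distribution=tfd.Independent(
            quantized, reinterpreted_batch_ndims=1)),
    reinterpreted_batch_ndims=2)
image_dist.event_shape  # ==> [4, 4, 3]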
def _default_event_space_bijector(self):
  # TODO(b/145620027) Finalize choice of bijector.
  deferred_scale = DeferredTensor(self.scale, lambda x: x)
  return chain_bijector.Chain([
      shift_bijector.Shift(
          shift=deferred_scale, validate_args=self.validate_args),
      softplus_bijector.Softplus(validate_args=self.validate_args)
  ], validate_args=self.validate_args)
def _bijector_fn(x0, input_depth, **condition_kwargs):
  shift, log_scale = shift_and_log_scale_fn(x0, input_depth,
                                            **condition_kwargs)
  bijectors = []
  if shift is not None:
    bijectors.append(shift_lib.Shift(shift))
  if log_scale is not None:
    bijectors.append(scale_lib.Scale(log_scale=log_scale))
  return chain_lib.Chain(bijectors)
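
# Hedged sketch: when both outputs are present, the chain acts as
# y = shift + exp(log_scale) * x (`Scale` is applied first, then `Shift`).
# The public aliases `tfp.bijectors.Chain/Shift/Scale` are assumed.
import tensorflow_probability as tfp
tfb = tfp.bijectors

b = tfb.Chain([tfb.Shift(1.), tfb.Scale(log_scale=0.)])
b.forward(2.)  # ==> 3.0, since 1. + exp(0.) * 2. = 3.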
def _default_event_space_bijector(self):
  return chain_bijector.Chain([
      shift_bijector.Shift(shift=self.loc, validate_args=self.validate_args),
      scale_matvec_linear_operator.ScaleMatvecLinearOperator(
          scale=self.scale, validate_args=self.validate_args),
      softplus_bijector.Softplus(validate_args=self.validate_args)
  ], validate_args=self.validate_args)
def bijector_fn(inputs, ignored_input):
  """Decorated function to get the RealNVP bijector."""
  # Build this so we can handle a user passing a NN that returns a tensor
  # OR a NN that returns a bijector.
  possible_output = layer(inputs)

  # We need to produce a bijector, but we do not know if the layer has done
  # so. We are setting this up to handle 2 possibilities:
  # 1) The layer outputs a bijector --> all is good.
  # 2) The layer outputs a tensor --> we need to turn it into a bijector.
  if isinstance(possible_output, bijector.Bijector):
    output = possible_output
  elif isinstance(possible_output, tf.Tensor):
    input_shape = inputs.get_shape().as_list()
    output_shape = possible_output.get_shape().as_list()
    assert input_shape[:-1] == output_shape[:-1]
    c = input_shape[-1]

    # For layers which output a tensor, we have two possibilities:
    # 1) There are twice as many output channels as inputs --> the coupling
    #    is affine, meaning there is a scale followed by a shift.
    # 2) There are an equal number of input and output channels --> the
    #    coupling is additive, meaning there is just a shift.
    if input_shape[-1] == output_shape[-1] // 2:
      this_scale = scale.Scale(scale_fn(possible_output[..., :c] + 2.))
      this_shift = shift.Shift(possible_output[..., c:])
      output = this_shift(this_scale)
    elif input_shape[-1] == output_shape[-1]:
      output = shift.Shift(possible_output[..., :c])
    else:
      raise ValueError(
          'Shape inconsistent with input. Expected shape '
          '{0} or {1} but tensor was shape {2}'.format(
              input_shape,
              tf.concat([input_shape[:-1], [2 * input_shape[-1]]], 0),
              output_shape))
  else:
    raise ValueError(
        'Expected a bijector or a tensor, but instead got '
        '{}'.format(possible_output.__class__))
  return output
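
# Hedged sketch of case (1) above: a layer output with `2 * c` channels is
# split into scale and shift halves. `tf.math.softplus` stands in for
# `scale_fn` here, which is an assumption, not the decorated function's
# actual default.
import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

c = 3
t = tf.random.normal([8, 2 * c])                  # mock layer output
this_scale = tfb.Scale(tf.math.softplus(t[..., :c] + 2.))
this_shift = tfb.Shift(t[..., c:])
coupling = this_shift(this_scale)                 # y = shift + scale * x
coupling.forward(tf.ones([8, c]))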
def _default_event_space_bijector(self):
  # TODO(b/145620027) Finalize choice of bijector.
  return chain_bijector.Chain([
      shift_bijector.Shift(shift=-np.pi, validate_args=self.validate_args),
      scale_bijector.Scale(scale=2. * np.pi,
                           validate_args=self.validate_args),
      sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
  ], validate_args=self.validate_args)
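
# Hedged sketch: the chain above squashes the real line onto `(-pi, pi)`,
# the support of a circular distribution such as von Mises.
import numpy as np
import tensorflow_probability as tfp
tfb = tfp.bijectors

b = tfb.Chain([tfb.Shift(-np.pi), tfb.Scale(2. * np.pi), tfb.Sigmoid()])
b.forward([-10., 0., 10.])  # ==> approximately [-pi, 0., pi]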
def _default_event_space_bijector(self):
  if tensor_util.is_ref(self.low) or tensor_util.is_ref(self.high):
    scale = DeferredTensor(self.high, lambda x: x - self.low)
  else:
    scale = self.high - self.low
  return chain_bijector.Chain([
      shift_bijector.Shift(shift=self.low, validate_args=self.validate_args),
      scale_bijector.Scale(scale=scale, validate_args=self.validate_args),
      sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
  ], validate_args=self.validate_args)
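
# Hedged sketch: Shift(low) o Scale(high - low) o Sigmoid maps the real
# line onto `(low, high)`; `low=2., high=5.` are illustrative values.
import tensorflow_probability as tfp
tfb = tfp.bijectors

b = tfb.Chain([tfb.Shift(2.), tfb.Scale(3.), tfb.Sigmoid()])
b.forward([-5., 0., 5.])  # ==> approximately [2.02, 3.5, 4.98]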
def _default_event_space_bijector(self):
  low = tfp_util.DeferredTensor(self.low, lambda x: x)
  scale = tfp_util.DeferredTensor(self.high, lambda x: x - self.low)
  return chain_bijector.Chain([
      shift_bijector.Shift(shift=low, validate_args=self.validate_args),
      scale_bijector.Scale(scale=scale, validate_args=self.validate_args),
      sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
  ], validate_args=self.validate_args)
def __init__(self,
             diag_bijector=None,
             diag_shift=1e-5,
             validate_args=False,
             name='fill_scale_tril'):
  """Instantiates the `FillScaleTriL` bijector.

  Args:
    diag_bijector: `Bijector` instance, used to transform the output diagonal
      to be positive. Must be an instance of
      `tf.__internal__.CompositeTensor` (including
      `tfb.AutoCompositeTensorBijector`).
      Default value: `None` (i.e., `tfb.Softplus()`).
    diag_shift: Float value broadcastable and added to all diagonal entries
      after applying the `diag_bijector`. Setting a positive value forces the
      output diagonal entries to be positive, but prevents inverting the
      transformation for matrices with diagonal entries less than this value.
      Default value: `1e-5`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
      Default value: `False` (i.e., arguments are not validated).
    name: Python `str` name given to ops managed by this object.
      Default value: `'fill_scale_tril'`.

  Raises:
    TypeError: if `diag_bijector` is not an instance of
      `tf.__internal__.CompositeTensor`.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    if diag_bijector is None:
      diag_bijector = softplus.Softplus(validate_args=validate_args)
    if not isinstance(diag_bijector, tf.__internal__.CompositeTensor):
      raise TypeError('`diag_bijector` must be an instance of '
                      '`tf.__internal__.CompositeTensor`.')

    if diag_shift is not None:
      dtype = dtype_util.common_dtype([diag_bijector, diag_shift], tf.float32)
      diag_shift = tensor_util.convert_nonref_to_tensor(
          diag_shift, name='diag_shift', dtype=dtype)
      diag_bijector = chain.Chain(
          [shift.Shift(shift=diag_shift), diag_bijector])

    super(FillScaleTriL, self).__init__(
        [transform_diagonal.TransformDiagonal(diag_bijector=diag_bijector),
         fill_triangular.FillTriangular()],
        validate_args=validate_args,
        validate_event_size=False,
        parameters=parameters,
        name=name)
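
# Usage sketch (hedged): assuming the public alias
# `tfp.bijectors.FillScaleTriL`, a vector of n*(n+1)/2 unconstrained entries
# becomes an n x n lower-triangular matrix whose diagonal is made positive
# by `diag_bijector` (Softplus by default) plus `diag_shift`.
import tensorflow_probability as tfp
tfb = tfp.bijectors

b = tfb.FillScaleTriL(diag_shift=1e-5)
b.forward([1., 2., 3.])
# ==> roughly [[softplus(2.) + 1e-5, 0.                 ],
#              [3.,                  softplus(1.) + 1e-5]]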
def __init__(self,
             shift,
             scale,
             tailweight,
             validate_args=False,
             name="lambertw_tail"):
  """Construct a location scale heavy-tail Lambert W bijector.

  The parameters `shift`, `scale`, and `tailweight` must be shaped in a way
  that supports broadcasting (e.g. `shift + scale + tailweight` is a valid
  operation).

  Args:
    shift: Floating point tensor; the shift for centering (uncentering) the
      input (output) random variable(s).
    scale: Floating point tensor; the scaling (unscaling) of the input
      (output) random variable(s). Must contain only positive values.
    tailweight: Floating point tensor; the tail behaviors of the output
      random variable(s). Must contain only non-negative values.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    TypeError: if `shift`, `scale`, and `tailweight` have different `dtype`s.
  """
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([tailweight, shift, scale], tf.float32)
    self._tailweight = tensor_util.convert_nonref_to_tensor(
        tailweight, name="tailweight", dtype=dtype)
    self._shift = tensor_util.convert_nonref_to_tensor(
        shift, name="shift", dtype=dtype)
    self._scale = tensor_util.convert_nonref_to_tensor(
        scale, name="scale", dtype=dtype)
    dtype_util.assert_same_float_dtype(
        (self._tailweight, self._shift, self._scale))

    self._shift_and_scale = chain.Chain(
        [tfb_shift.Shift(self._shift), tfb_scale.Scale(self._scale)])
    # The `bijectors` argument of the `tfb.Chain` super class is applied in
    # reverse(!) order. Hence the ordering in the list must be (3, 2, 1),
    # not (1, 2, 3).
    super(LambertWTail, self).__init__(
        bijectors=[self._shift_and_scale,
                   _HeavyTailOnly(tailweight=self._tailweight),
                   invert.Invert(self._shift_and_scale)],
        validate_args=validate_args)
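
# Usage sketch (hedged): assuming the public alias
# `tfp.bijectors.LambertWTail`. The chain normalizes by (shift, scale),
# heavy-tails the normalized value, then restores location and scale, so
# the center point is left fixed.
import tensorflow_probability as tfp
tfb = tfp.bijectors

b = tfb.LambertWTail(shift=0., scale=1., tailweight=0.5)
b.forward(0.)  # ==> 0.0; inputs far from `shift` have their tails inflated.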
def _negative_concentration_bijector(self):
  # Constructed dynamically so that `scale * reciprocal(concentration)` is
  # tape-safe.
  return chain_bijector.Chain([
      shift_bijector.Shift(shift=self.loc, validate_args=self.validate_args),
      # TODO(b/146568897): Resolve numerical issues by implementing a new
      # bijector instead of multiplying `scale` by `(1. - 1e-6)`.
      scale_bijector.Scale(
          scale=-(self.scale * tf.math.reciprocal(self.concentration) *
                  (1. - 1e-6)),
          validate_args=self.validate_args),
      sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
  ], validate_args=self.validate_args)
def _default_event_space_bijector(self):
  # TODO(b/146568897): Resolve numerical issues by implementing a new
  # bijector instead of multiplying `scale` by `(1. - 1e-6)`.
  if tensor_util.is_ref(self.low) or tensor_util.is_ref(self.high):
    scale = DeferredTensor(
        self.high,
        lambda x: (x - self.low) * (1. - 1e-6),
        shape=tf.broadcast_static_shape(self.high.shape, self.low.shape))
  else:
    scale = (self.high - self.low) * (1. - 1e-6)
  return chain_bijector.Chain([
      shift_bijector.Shift(shift=self.low, validate_args=self.validate_args),
      scale_bijector.Scale(scale=scale, validate_args=self.validate_args),
      sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
  ], validate_args=self.validate_args)
def __init__(self, nchan, dtype=tf.float32, validate_args=False, name=None):
  parameters = dict(locals())
  self._initialized = tf.Variable(False, trainable=False)
  self._m = tf.Variable(tf.zeros(nchan, dtype))
  self._s = TransformedVariable(tf.ones(nchan, dtype), exp.Exp())
  self._bijector = invert.Invert(
      chain.Chain([
          scale.Scale(self._s),
          shift.Shift(self._m),
      ]))
  super(ActivationNormalization, self).__init__(
      validate_args=validate_args,
      forward_min_event_ndims=1,
      parameters=parameters,
      name=name or 'ActivationNormalization')
def _as_trainable_family(distribution):
  """Substitutes prior distributions with more easily trainable ones."""
  with tf.name_scope('as_trainable_family'):
    if isinstance(distribution, half_normal.HalfNormal):
      return truncated_normal.TruncatedNormal(
          loc=0.,
          scale=distribution.scale,
          low=0.,
          high=distribution.scale * 10.)
    elif isinstance(distribution, uniform.Uniform):
      return shift.Shift(distribution.low)(
          scale_lib.Scale(distribution.high - distribution.low)(
              beta.Beta(
                  concentration0=tf.ones(
                      distribution.event_shape_tensor(),
                      dtype=distribution.dtype),
                  concentration1=1.)))
    else:
      return distribution
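
# Hedged sketch: the Uniform branch above is a shifted, scaled Beta(1, 1),
# which has the same support (and initially the same density) as the
# original Uniform. Illustrative check with low=2., high=5.:
import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

u = tfd.Uniform(low=2., high=5.)
t = tfb.Shift(2.)(tfb.Scale(3.)(
    tfd.Beta(concentration0=1., concentration1=1.)))
u.prob(3.)  # ==> 1/3
t.prob(3.)  # ==> 1/3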
def build_affine_surrogate_posterior_from_base_distribution(
    base_distribution,
    operators='diag',
    bijector=None,
    initial_unconstrained_loc_fn=_sample_uniform_initial_loc,
    seed=None,
    validate_args=False,
    name=None):
  """Builds a variational posterior by linearly transforming base distributions.

  This function builds a surrogate posterior by applying a trainable
  transformation to a base distribution (typically a
  `tfd.JointDistribution`) or nested structure of base distributions, and
  constraining the samples with `bijector`. Note that the distributions must
  have event shapes corresponding to the *pretransformed* surrogate
  posterior -- that is, if `bijector` contains a shape-changing bijector,
  then the corresponding base distribution event shape is the inverse event
  shape of the bijector applied to the desired surrogate posterior shape.

  The surrogate posterior is constructed as follows:

  1. Flatten the base distribution event shapes to vectors, and pack the base
     distributions into a `tfd.JointDistribution`.
  2. Apply a trainable blockwise LinearOperator bijector to the joint base
     distribution.
  3. Apply the constraining bijectors and return the resulting trainable
     `tfd.TransformedDistribution` instance.

  Args:
    base_distribution: `tfd.Distribution` instance (typically a
      `tfd.JointDistribution`), or a nested structure of `tfd.Distribution`
      instances.
    operators: Either a string or a list/tuple containing `LinearOperator`
      subclasses, `LinearOperator` instances, or callables returning
      `LinearOperator` instances. Supported string values are "diag" (to
      create a mean-field surrogate posterior) and "tril" (to create a
      full-covariance surrogate posterior). A list/tuple may be passed to
      induce other posterior covariance structures. If the list is flat, a
      `tf.linalg.LinearOperatorBlockDiag` instance will be created and
      applied to the base distribution. Otherwise the list must be
      singly-nested and have a first element of length 1, second element of
      length 2, etc.; the elements of the outer list are interpreted as rows
      of a lower-triangular block structure, and a
      `tf.linalg.LinearOperatorBlockLowerTriangular` instance is created.
      For complete documentation and examples, see
      `tfp.experimental.vi.util.build_trainable_linear_operator_block`,
      which receives the `operators` arg if it is list-like.
      Default value: `"diag"`.
    bijector: `tfb.Bijector` instance, or nested structure of `tfb.Bijector`
      instances, that maps (nested) values in R^n to the support of the
      posterior. (This can be the `experimental_default_event_space_bijector`
      of the distribution over the prior latent variables.)
      Default value: `None` (i.e., the posterior is over R^n).
    initial_unconstrained_loc_fn: Optional Python `callable` with signature
      `initial_loc = initial_unconstrained_loc_fn(shape, dtype, seed)` used
      to sample real-valued initializations for the unconstrained location
      of each variable.
      Default value: `functools.partial(tf.random.stateless_uniform,
        minval=-2., maxval=2., dtype=tf.float32)`.
    seed: Python integer to seed the random number generator for initial
      values.
      Default value: `None`.
    validate_args: Python `bool`. Whether to validate input with asserts.
      This imposes a runtime cost. If `validate_args` is `False`, and the
      inputs are invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e.,
        'build_affine_surrogate_posterior_from_base_distribution').

  Returns:
    surrogate_distribution: Trainable `tfd.JointDistribution` instance.
  Raises:
    NotImplementedError: Base distributions with mixed dtypes are not
      supported.

  #### Examples

  ```python
  tfd = tfp.distributions
  tfb = tfp.bijectors

  # Fit a multivariate Normal surrogate posterior on the Eight Schools model
  # [1].

  treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
  treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]

  def model_fn():
    avg_effect = yield tfd.Normal(loc=0., scale=10., name='avg_effect')
    log_stddev = yield tfd.Normal(loc=5., scale=1., name='log_stddev')
    school_effects = yield tfd.Sample(
        tfd.Normal(loc=avg_effect, scale=tf.exp(log_stddev)),
        sample_shape=[8],
        name='school_effects')
    treatment_effects = yield tfd.Independent(
        tfd.Normal(loc=school_effects, scale=treatment_stddevs),
        reinterpreted_batch_ndims=1,
        name='treatment_effects')
  model = tfd.JointDistributionCoroutineAutoBatched(model_fn)

  # Pin the observed values in the model.
  target_model = model.experimental_pin(treatment_effects=treatment_effects)

  # Define the base distribution: a standard Normal for each latent variable
  # in the pinned model.
  base_distribution = tf.nest.map_structure(
      lambda s: tfd.Sample(tfd.Normal(0., 1.), s),
      target_model.event_shape_tensor())

  # Define a lower-triangular structure of `LinearOperator` subclasses that
  # models full covariance among latent variables except for the 8 dimensions
  # of `school_effects`, which are modeled as independent (using
  # `LinearOperatorDiag`).
  operators = [
      [tf.linalg.LinearOperatorLowerTriangular],
      [tf.linalg.LinearOperatorFullMatrix,
       tf.linalg.LinearOperatorLowerTriangular],
      [tf.linalg.LinearOperatorFullMatrix,
       tf.linalg.LinearOperatorFullMatrix,
       tf.linalg.LinearOperatorDiag]]

  # Constrain the posterior values to the support of the prior.
  bijector = target_model.experimental_default_event_space_bijector()

  # Build a full-covariance surrogate posterior.
  surrogate_posterior = (
      tfp.experimental.vi.build_affine_surrogate_posterior_from_base_distribution(
          base_distribution=base_distribution,
          operators=operators,
          bijector=bijector))

  # Fit the model.
  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```

  #### References

  [1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and
      Donald Rubin. Bayesian Data Analysis, Third Edition.
      Chapman and Hall/CRC, 2013.
  """
  with tf.name_scope(
      name or 'build_affine_surrogate_posterior_from_base_distribution'):

    if nest.is_nested(base_distribution):
      base_distribution = (
          joint_distribution_util
          .independent_joint_distribution_from_structure(
              base_distribution, validate_args=validate_args))

    if nest.is_nested(bijector):
      bijector = joint_map.JointMap(
          nest.map_structure(
              lambda b: identity.Identity() if b is None else b, bijector),
          validate_args=validate_args)

    batch_shape = base_distribution.batch_shape_tensor()
    if tf.nest.is_nested(batch_shape):  # Base is a classic JointDistribution.
      batch_shape = functools.reduce(
          ps.broadcast_shape, tf.nest.flatten(batch_shape))
    event_shape = base_distribution.event_shape_tensor()
    flat_event_size = nest.flatten(
        nest.map_structure(ps.reduce_prod, event_shape))

    base_dtypes = set(nest.flatten(base_distribution.dtype))
    if len(base_dtypes) > 1:
      raise NotImplementedError(
          'Base distributions with mixed dtype are not supported. Saw '
          'components of dtype {}'.format(base_dtypes))
    base_dtype = list(base_dtypes)[0]

    num_components = len(flat_event_size)
    if operators == 'diag':
      operators = [tf.linalg.LinearOperatorDiag] * num_components
    elif operators == 'tril':
      operators = [
          [tf.linalg.LinearOperatorFullMatrix] * i +
          [tf.linalg.LinearOperatorLowerTriangular]
          for i in range(num_components)]
    elif isinstance(operators, str):
      raise ValueError(
          'Unrecognized operator type {}. '
          'Valid operators are "diag", "tril", or a structure that can be '
          'passed to '
          '`tfp.experimental.vi.util.build_trainable_linear_operator_block` '
          'as the `operators` arg.'.format(operators))

    if nest.is_nested(operators):
      seed, operators_seed = samplers.split_seed(seed)
      operators = (
          trainable_linear_operators.build_trainable_linear_operator_block(
              operators,
              block_dims=flat_event_size,
              dtype=base_dtype,
              batch_shape=batch_shape,
              seed=operators_seed))

    linop_bijector = (
        scale_matvec_linear_operator.ScaleMatvecLinearOperatorBlock(
            scale=operators, validate_args=validate_args))

    loc_bijector = joint_map.JointMap(
        tf.nest.map_structure(
            lambda s, seed: shift.Shift(  # pylint: disable=g-long-lambda
                tf.Variable(
                    initial_unconstrained_loc_fn(
                        ps.concat([batch_shape, [s]], axis=0),
                        dtype=base_dtype,
                        seed=seed))),
            flat_event_size,
            samplers.split_seed(seed, n=len(flat_event_size))),
        validate_args=validate_args)

    unflatten_and_reshape = chain.Chain([
        joint_map.JointMap(
            nest.map_structure(reshape.Reshape, event_shape),
            validate_args=validate_args),
        restructure.Restructure(
            nest.pack_sequence_as(event_shape, range(num_components)))
    ], validate_args=validate_args)

    bijectors = [] if bijector is None else [bijector]
    bijectors.extend([
        unflatten_and_reshape,
        loc_bijector,  # Allow the mean of the standard dist to shift from 0.
        linop_bijector])  # Apply LinOp to scale the standard dist.
    bijector = chain.Chain(bijectors, validate_args=validate_args)

    flat_base_distribution = invert.Invert(unflatten_and_reshape)(
        base_distribution)

    return transformed_distribution.TransformedDistribution(
        flat_base_distribution, bijector=bijector, validate_args=validate_args)
def __init__(self,
             skewness,
             tailweight,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name=None):
  """Construct Johnson's SU distributions.

  The distributions have shape parameters `tailweight` and `skewness`,
  mean `loc`, and scale `scale`.

  The parameters `tailweight`, `skewness`, `loc`, and `scale` must be shaped
  in a way that supports broadcasting (e.g. `skewness + tailweight + loc +
  scale` is a valid operation).

  Args:
    skewness: Floating-point `Tensor`. Skewness of the distribution(s).
    tailweight: Floating-point `Tensor`. Tail weight of the distribution(s).
      `tailweight` must contain only positive values.
    loc: Floating-point `Tensor`. The mean(s) of the distribution(s).
    scale: Floating-point `Tensor`. The scaling factor(s) for the
      distribution(s). Note that `scale` is not technically the standard
      deviation of this distribution but has semantics more similar to
      standard deviation than variance.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value '`NaN`' to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    TypeError: if any of `skewness`, `tailweight`, `loc` and `scale` have
      different dtypes.
  """
  parameters = dict(locals())
  with tf.name_scope(name or 'JohnsonSU') as name:
    dtype = dtype_util.common_dtype([skewness, tailweight, loc, scale],
                                    tf.float32)
    self._skewness = tensor_util.convert_nonref_to_tensor(
        skewness, name='skewness', dtype=dtype)
    self._tailweight = tensor_util.convert_nonref_to_tensor(
        tailweight, name='tailweight', dtype=dtype)
    self._loc = tensor_util.convert_nonref_to_tensor(
        loc, name='loc', dtype=dtype)
    self._scale = tensor_util.convert_nonref_to_tensor(
        scale, name='scale', dtype=dtype)

    norm_shift = invert_bijector.Invert(
        shift_bijector.Shift(shift=self._skewness,
                             validate_args=validate_args))
    norm_scale = invert_bijector.Invert(
        scale_bijector.Scale(scale=self._tailweight,
                             validate_args=validate_args))
    sinh = sinh_bijector.Sinh(validate_args=validate_args)
    scale = scale_bijector.Scale(scale=self._scale,
                                 validate_args=validate_args)
    shift = shift_bijector.Shift(shift=self._loc,
                                 validate_args=validate_args)

    bijector = shift(scale(sinh(norm_scale(norm_shift))))

    batch_rank = ps.reduce_max([
        distribution_util.prefer_static_rank(x)
        for x in (self._skewness, self._tailweight, self._loc, self._scale)])

    super(JohnsonSU, self).__init__(
        # TODO(b/160730249): Make `loc` a scalar `0.` and remove overridden
        # `batch_shape` and `batch_shape_tensor` when
        # TransformedDistribution's bijector can modify its `batch_shape`.
        distribution=normal.Normal(
            loc=tf.zeros(ps.ones(batch_rank, tf.int32), dtype=dtype),
            scale=tf.ones([], dtype=dtype),
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats),
        bijector=bijector,
        validate_args=validate_args,
        parameters=parameters,
        name=name)
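
# Usage sketch (hedged): assuming the public alias
# `tfp.distributions.JohnsonSU`. Parameter values are illustrative.
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.JohnsonSU(skewness=-2., tailweight=2., loc=1.1, scale=1.5)
dist.sample(3)   # Draws three scalar samples.
dist.prob(1.0)   # Density at 1.0.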
def __init__(self,
             loc=None,
             scale=None,
             validate_args=False,
             allow_nan_stats=True,
             name='MultivariateNormalLinearOperator'):
  """Construct Multivariate Normal distribution on `R^k`.

  The `batch_shape` is the broadcast shape between `loc` and `scale`
  arguments.

  The `event_shape` is given by the last dimension of the matrix implied by
  `scale`. The last dimension of `loc` (if provided) must broadcast with
  this.

  Recall that `covariance = scale @ scale.T`.

  Additional leading dimensions (if any) will index batches.

  Args:
    loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
      implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
      `b >= 0` and `k` is the event size.
    scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
      `[B1, ..., Bb, k, k]`.
    validate_args: Python `bool`, default `False`. Whether to validate input
      with asserts. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
    allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic.
    name: The name to give Ops created by the initializer.

  Raises:
    ValueError: if `scale` is unspecified.
    TypeError: if not `scale.dtype.is_floating`.
  """
  parameters = dict(locals())
  if scale is None:
    raise ValueError('Missing required `scale` parameter.')
  if not dtype_util.is_floating(scale.dtype):
    raise TypeError('`scale` parameter must have floating-point dtype.')

  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale], dtype_hint=tf.float32)
    # Since expand_dims doesn't preserve constant-ness, we obtain the
    # non-dynamic value if possible.
    loc = tensor_util.convert_nonref_to_tensor(loc, dtype=dtype, name='loc')
    batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
        loc, scale)

    self._loc = loc
    self._scale = scale

    bijector = scale_matvec_linear_operator.ScaleMatvecLinearOperator(
        scale, validate_args=validate_args)
    if loc is not None:
      bijector = shift_bijector.Shift(
          shift=loc, validate_args=validate_args)(bijector)
    super(MultivariateNormalLinearOperator, self).__init__(
        distribution=normal.Normal(
            loc=tf.zeros([], dtype=dtype),
            scale=tf.ones([], dtype=dtype)),
        bijector=bijector,
        batch_shape=batch_shape,
        event_shape=event_shape,
        validate_args=validate_args,
        name=name)
    self._parameters = parameters
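
# Usage sketch (hedged): assuming the public alias
# `tfp.distributions.MultivariateNormalLinearOperator`. The covariance is
# `scale @ scale.T`, as noted in the docstring above.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

scale = tf.linalg.LinearOperatorLowerTriangular([[1., 0.], [0.5, 2.]])
mvn = tfd.MultivariateNormalLinearOperator(loc=[1., -1.], scale=scale)
mvn.covariance()  # ==> [[1., 0.5], [0.5, 4.25]]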
def generate_shift_bijector(s):
  x = yield trainable_state_util.Parameter(
      functools.partial(
          initial_unconstrained_loc_fn,
          ps.concat([batch_shape, [s]], axis=0),
          dtype=base_dtype))
  return shift.Shift(x)
def register_asvi_substitution_rule(condition, substitution_fn):
  global ASVI_SURROGATE_SUBSTITUTIONS
  if inspect.isclass(condition):
    condition = lambda distribution, cls=condition: isinstance(  # pylint: disable=g-long-lambda
        distribution, cls)
  ASVI_SURROGATE_SUBSTITUTIONS[condition] = substitution_fn


# Default substitutions attempt to express distributions using the most
# flexible available parameterization.
# pylint: disable=g-long-lambda
register_asvi_substitution_rule(
    half_normal.HalfNormal,
    lambda dist: truncated_normal.TruncatedNormal(
        loc=0., scale=dist.scale, low=0., high=dist.scale * 10.))
register_asvi_substitution_rule(
    uniform.Uniform,
    lambda dist: shift.Shift(dist.low)(
        scale_lib.Scale(dist.high - dist.low)(
            beta.Beta(concentration0=tf.ones_like(dist.mean()),
                      concentration1=1.))))
register_asvi_substitution_rule(
    exponential.Exponential,
    lambda dist: gamma.Gamma(concentration=1., rate=dist.rate))
register_asvi_substitution_rule(
    chi2.Chi2,
    lambda dist: gamma.Gamma(concentration=0.5 * dist.df, rate=0.5))
# pylint: enable=g-long-lambda


# TODO(kateslin): Add support for models with prior+likelihood written as
# a single JointDistribution.
def build_asvi_surrogate_posterior(prior,
                                   mean_field=False,
                                   initial_prior_weight=0.5,
def __init__(self,
             loc=None,
             scale=None,
             validate_args=False,
             allow_nan_stats=True,
             experimental_use_kahan_sum=False,
             name='MultivariateNormalLinearOperator'):
  """Construct Multivariate Normal distribution on `R^k`.

  The `batch_shape` is the broadcast shape between `loc` and `scale`
  arguments.

  The `event_shape` is given by the last dimension of the matrix implied by
  `scale`. The last dimension of `loc` (if provided) must broadcast with
  this.

  Recall that `covariance = scale @ scale.T`.

  Additional leading dimensions (if any) will index batches.

  Args:
    loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
      implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
      `b >= 0` and `k` is the event size.
    scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
      `[B1, ..., Bb, k, k]`.
    validate_args: Python `bool`, default `False`. Whether to validate input
      with asserts. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
    allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic.
    experimental_use_kahan_sum: Python `bool`. When `True`, we use Kahan
      summation to aggregate independent underlying log_prob values. For
      best results, Kahan summation should also be applied when computing
      the log-determinant of the `LinearOperator` representing the scale
      matrix. Kahan summation improves on the precision of a naive float32
      sum, which can be noticeable in particular for large dimensions in
      float32. See CPU caveat on `tfp.math.reduce_kahan_sum`.
    name: The name to give Ops created by the initializer.

  Raises:
    ValueError: if `scale` is unspecified.
    TypeError: if not `scale.dtype.is_floating`.
  """
  parameters = dict(locals())
  self._experimental_use_kahan_sum = experimental_use_kahan_sum
  if scale is None:
    raise ValueError('Missing required `scale` parameter.')
  if not dtype_util.is_floating(scale.dtype):
    raise TypeError('`scale` parameter must have floating-point dtype.')

  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale], dtype_hint=tf.float32)
    # Since expand_dims doesn't preserve constant-ness, we obtain the
    # non-dynamic value if possible.
    loc = tensor_util.convert_nonref_to_tensor(loc, dtype=dtype, name='loc')
    batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
        loc, scale)

    self._loc = loc
    self._scale = scale

    bijector = scale_matvec_linear_operator.ScaleMatvecLinearOperator(
        scale, validate_args=validate_args)
    if loc is not None:
      bijector = shift_bijector.Shift(
          shift=loc, validate_args=validate_args)(bijector)
    super(MultivariateNormalLinearOperator, self).__init__(
        # TODO(b/137665504): Use batch-adding meta-distribution to set the
        # batch shape instead of tf.zeros.
        # We use `Sample` instead of `Independent` because `Independent`
        # requires concatenating `batch_shape` and `event_shape`, which
        # loses static `batch_shape` information when `event_shape` is not
        # statically known.
        distribution=sample.Sample(
            normal.Normal(
                loc=tf.zeros(batch_shape, dtype=dtype),
                scale=tf.ones([], dtype=dtype)),
            event_shape,
            experimental_use_kahan_sum=experimental_use_kahan_sum),
        bijector=bijector,
        validate_args=validate_args,
        name=name)
    self._parameters = parameters
def __init__(self,
             loc=None,
             precision_factor=None,
             precision=None,
             validate_args=False,
             allow_nan_stats=True,
             name='MultivariateNormalPrecisionFactorLinearOperator'):
  """Initialize distribution.

  Precision is the inverse of the covariance matrix, and
  `precision_factor @ precision_factor.T = precision`.

  The `batch_shape` of this distribution is the broadcast of
  `loc.shape[:-1]` and `precision_factor.batch_shape`.

  The `event_shape` of this distribution is determined by `loc.shape[-1:]`,
  OR `precision_factor.shape[-1:]`, which must match.

  Args:
    loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
      implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
      `b >= 0` and `k` is the event size.
    precision_factor: Required nonsingular `tf.linalg.LinearOperator`
      instance with same `dtype` and shape compatible with `loc`.
    precision: Optional square `tf.linalg.LinearOperator` instance with same
      `dtype` and shape compatible with `loc` and `precision_factor`.
    validate_args: Python `bool`, default `False`. Whether to validate input
      with asserts. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
    allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic.
    name: The name to give Ops created by the initializer.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    if precision_factor is None:
      raise ValueError(
          'Argument `precision_factor` must be provided. Found `None`')

    dtype = dtype_util.common_dtype([loc, precision_factor, precision],
                                    dtype_hint=tf.float32)
    loc = tensor_util.convert_nonref_to_tensor(loc, dtype=dtype, name='loc')

    self._loc = loc
    self._precision_factor = precision_factor
    self._precision = precision

    batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
        loc, precision_factor)

    # Proof of factors (used throughout code):
    #
    # Let,
    #   C = covariance,
    #   P = inv(covariance) = precision,
    #   P = F @ F.T  (so F is the `precision_factor`).
    #
    # Then, the log prob term is
    #   x.T @ inv(C) @ x
    #   = x.T @ P @ x
    #   = x.T @ F @ F.T @ x
    #   = || F.T @ x ||**2
    # notice it involves F.T, which is why we set adjoint=True in various
    # places.
    #
    # Also, if w ~ Normal(0, I), then we can sample by setting
    #   x = inv(F.T) @ w + loc,
    # since then
    #   E[(x - loc) @ (x - loc).T]
    #   = E[inv(F.T) @ w @ w.T @ inv(F)]
    #   = inv(F.T) @ inv(F)
    #   = inv(F @ F.T)
    #   = inv(P)
    #   = C.

    if precision is not None:
      precision.shape.assert_is_compatible_with(precision_factor.shape)

    bijector = invert.Invert(
        scale_matvec_linear_operator.ScaleMatvecLinearOperator(
            scale=precision_factor,
            validate_args=validate_args,
            adjoint=True))
    if loc is not None:
      shift = shift_bijector.Shift(shift=loc, validate_args=validate_args)
      bijector = shift(bijector)
    super(MultivariateNormalPrecisionFactorLinearOperator, self).__init__(
        distribution=mvn_diag.MultivariateNormalDiag(
            loc=tf.zeros(
                ps.concat([batch_shape, event_shape], axis=0), dtype=dtype)),
        bijector=bijector,
        validate_args=validate_args,
        name=name)
    self._parameters = parameters
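
# Hedged usage sketch: assuming this class is exported as
# `tfp.experimental.distributions.MultivariateNormalPrecisionFactorLinearOperator`.
# With precision P = F @ F.T, samples are x = inv(F.T) @ w + loc for
# w ~ Normal(0, I), so the covariance is inv(P), per the proof above.
import tensorflow as tf
import tensorflow_probability as tfp

F = tf.linalg.LinearOperatorLowerTriangular([[2., 0.], [1., 1.]])
dist = (tfp.experimental.distributions
        .MultivariateNormalPrecisionFactorLinearOperator(
            loc=[0., 0.], precision_factor=F))
dist.sample(3)             # Three 2-vector draws.
dist.log_prob([0.5, -0.5])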
def __init__(self,
             loc,
             scale,
             skewness=None,
             tailweight=None,
             distribution=None,
             validate_args=False,
             allow_nan_stats=True,
             name='SinhArcsinh'):
  """Construct SinhArcsinh distribution on `(-inf, inf)`.

  Arguments `(loc, scale, skewness, tailweight)` must have broadcastable
  shape (indexing batch dimensions). They must all have the same `dtype`.

  Args:
    loc: Floating-point `Tensor`.
    scale: `Tensor` of same `dtype` as `loc`.
    skewness: Skewness parameter. Default is `0.0` (no skew).
    tailweight: Tailweight parameter. Default is `1.0` (unchanged
      tailweight).
    distribution: `tf.Distribution`-like instance. Distribution that is
      transformed to produce this distribution. Must have a batch shape to
      which the shapes of `loc`, `scale`, `skewness`, and `tailweight` all
      broadcast. Default is `tfd.Normal(batch_shape, 1.)`, where
      `batch_shape` is the broadcasted shape of the parameters. Typically
      `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it
      is a function of non-trainable parameters. WARNING: If you backprop
      through a `SinhArcsinh` sample and `distribution` is not
      `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
      the gradient will be incorrect!
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  parameters = dict(locals())

  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale, skewness, tailweight],
                                    tf.float32)
    self._loc = tensor_util.convert_nonref_to_tensor(
        loc, name='loc', dtype=dtype)
    self._scale = tensor_util.convert_nonref_to_tensor(
        scale, name='scale', dtype=dtype)
    tailweight = 1. if tailweight is None else tailweight
    has_default_skewness = skewness is None
    skewness = 0. if has_default_skewness else skewness
    self._tailweight = tensor_util.convert_nonref_to_tensor(
        tailweight, name='tailweight', dtype=dtype)
    self._skewness = tensor_util.convert_nonref_to_tensor(
        skewness, name='skewness', dtype=dtype)

    # Recall, with Z a random variable,
    #   Y := loc + scale * F(Z),
    #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) * C
    #   C := 2 / F_0(2)
    #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
    if distribution is None:
      batch_shape = functools.reduce(
          ps.broadcast_shape,
          [ps.shape(x) for x in (self._skewness, self._tailweight,
                                 self._loc, self._scale)])
      distribution = normal.Normal(
          loc=tf.zeros(batch_shape, dtype=dtype),
          scale=tf.ones([], dtype=dtype),
          allow_nan_stats=allow_nan_stats,
          validate_args=validate_args)

    # Make the SAS bijector, 'F'.
    f = sinh_arcsinh_bijector.SinhArcsinh(
        skewness=self._skewness,
        tailweight=self._tailweight,
        validate_args=validate_args)

    # Make the affine bijector, Z --> loc + scale * Z.
    affine = shift_bijector.Shift(shift=self._loc)(
        scale_bijector.Scale(scale=self._scale))

    bijector = chain_bijector.Chain([affine, f])

    super(SinhArcsinh, self).__init__(
        distribution=distribution,
        bijector=bijector,
        validate_args=validate_args,
        name=name)
    self._parameters = parameters
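
# Usage sketch (hedged): assuming the public alias
# `tfp.distributions.SinhArcsinh`. With skewness=0. and tailweight=1. the
# result reduces to Normal(loc, scale); other values skew the distribution
# and fatten or thin its tails. Parameter values are illustrative.
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.SinhArcsinh(loc=0., scale=1., skewness=0.5, tailweight=1.2)
dist.sample(3)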
def __init__(self,
             low=None,
             high=None,
             hinge_softness=None,
             validate_args=False,
             name='soft_clip'):
  """Instantiates the SoftClip bijector.

  Args:
    low: Optional float `Tensor` lower bound. If `None`, the lower-bound
      constraint is omitted.
      Default value: `None`.
    high: Optional float `Tensor` upper bound. If `None`, the upper-bound
      constraint is omitted.
      Default value: `None`.
    hinge_softness: Optional nonzero float `Tensor`. Controls the softness
      of the constraint at the boundaries; values outside of the constraint
      set are mapped into intervals of width approximately
      `log(2) * hinge_softness` on the interior of each boundary. High
      softness reserves more space for values outside of the constraint set,
      leading to greater distortion of inputs *within* the constraint set,
      but improved numerical stability near the boundaries.
      Default value: `None` (`1.0`).
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    name: Python `str` name given to ops managed by this object.
  """
  parameters = dict(locals())
  with tf.name_scope(name):
    dtype = dtype_util.common_dtype([low, high, hinge_softness],
                                    dtype_hint=tf.float32)
    low = tensor_util.convert_nonref_to_tensor(low, name='low', dtype=dtype)
    high = tensor_util.convert_nonref_to_tensor(
        high, name='high', dtype=dtype)
    hinge_softness = tensor_util.convert_nonref_to_tensor(
        hinge_softness, name='hinge_softness', dtype=dtype)

    softplus_bijector = softplus.Softplus(hinge_softness=hinge_softness)
    negate = tf.convert_to_tensor(-1., dtype=dtype)

    components = []
    if low is not None and high is not None:
      # Support reference tensors (eg Variables) for `high` and `low` by
      # deferring all computation on them until needed.
      width = tfp_util.DeferredTensor(
          pretransformed_input=high, transform_fn=lambda high: high - low)
      negated_shrinkage_factor = tfp_util.DeferredTensor(
          pretransformed_input=width,
          transform_fn=lambda w: tf.cast(  # pylint: disable=g-long-lambda
              negate * w / softplus_bijector.forward(w), dtype=dtype))

      # Implement the soft constraint from 'Mathematical Details' above:
      #  softclip(x) := -softplus(width - softplus(x - low)) *
      #                   (width) / (softplus(width)) + high
      components = [
          shift.Shift(high),
          scale.Scale(negated_shrinkage_factor),
          softplus_bijector,
          shift.Shift(width),
          scale.Scale(negate),
          softplus_bijector,
          shift.Shift(tfp_util.DeferredTensor(low, lambda x: -x))]
    elif low is not None:
      # Implement a soft lower bound:
      #  softlower(x) := softplus(x - low) + low
      components = [
          shift.Shift(low),
          softplus_bijector,
          shift.Shift(tfp_util.DeferredTensor(low, lambda x: -x))]
    elif high is not None:
      # Implement a soft upper bound:
      #  softupper(x) := -softplus(high - x) + high
      components = [
          shift.Shift(high),
          scale.Scale(negate),
          softplus_bijector,
          scale.Scale(negate),
          shift.Shift(high)]

    self._low = low
    self._high = high
    self._hinge_softness = hinge_softness
    self._chain = chain.Chain(components, validate_args=validate_args)

    super(SoftClip, self).__init__(
        forward_min_event_ndims=0,
        dtype=dtype,
        validate_args=validate_args,
        parameters=parameters,
        is_constant_jacobian=not components,
        name=name)
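
# Usage sketch (hedged): assuming the public alias `tfp.bijectors.SoftClip`.
# Values well inside `(low, high)` pass through nearly unchanged; values
# outside are squashed softly into the interval.
import tensorflow_probability as tfp
tfb = tfp.bijectors

b = tfb.SoftClip(low=-5., high=5.)
b.forward([-15., 0., 20.])  # ==> approximately [-5., 0., 5.]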
def __init__(self,
             loc=None,
             scale=None,
             validate_args=False,
             allow_nan_stats=True,
             name='VectorExponentialLinearOperator'):
  """Construct Vector Exponential distribution supported on a subset of `R^k`.

  The `batch_shape` is the broadcast shape between `loc` and `scale`
  arguments.

  The `event_shape` is given by the last dimension of the matrix implied by
  `scale`. The last dimension of `loc` (if provided) must broadcast with
  this.

  Recall that `covariance = scale @ scale.T`.

  Additional leading dimensions (if any) will index batches.

  Args:
    loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
      implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
      `b >= 0` and `k` is the event size.
    scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
      `[B1, ..., Bb, k, k]`.
    validate_args: Python `bool`, default `False`. Whether to validate input
      with asserts. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
    allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading
      to undefined statistics will return NaN for this statistic.
    name: The name to give Ops created by the initializer.

  Raises:
    ValueError: if `scale` is unspecified.
    TypeError: if not `scale.dtype.is_floating`.
  """
  parameters = dict(locals())
  if loc is None:
    loc = 0.0  # Implicit value for backwards compatibility.
  if scale is None:
    raise ValueError('Missing required `scale` parameter.')
  if not dtype_util.is_floating(scale.dtype):
    raise TypeError('`scale` parameter must have floating-point dtype.')

  with tf.name_scope(name) as name:
    # Since expand_dims doesn't preserve constant-ness, we obtain the
    # non-dynamic value if possible.
    loc = loc if loc is None else tf.convert_to_tensor(
        loc, name='loc', dtype=scale.dtype)
    batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
        loc, scale)
    self._loc = loc
    self._scale = scale

    super(VectorExponentialLinearOperator, self).__init__(
        # TODO(b/137665504): Use batch-adding meta-distribution to set the
        # batch shape instead of tf.ones.
        # We use `Sample` instead of `Independent` because `Independent`
        # requires concatenating `batch_shape` and `event_shape`, which
        # loses static `batch_shape` information when `event_shape` is not
        # statically known.
        distribution=sample.Sample(
            exponential.Exponential(
                rate=tf.ones(batch_shape, dtype=scale.dtype),
                allow_nan_stats=allow_nan_stats),
            event_shape),
        bijector=shift_bijector.Shift(shift=loc)(
            scale_matvec_linear_operator.ScaleMatvecLinearOperator(
                scale=scale, validate_args=validate_args)),
        validate_args=validate_args,
        name=name)
    self._parameters = parameters