def _default_event_space_bijector(self):
  """Returns a bijector from the real line onto the open interval (low, high).

  `Chain` applies its members right-to-left, so an unconstrained input is
  first squashed into `(0, 1)` by `Sigmoid`, then multiplied by
  `high - low`, and finally shifted by `low`.
  """
  validate = self.validate_args
  # `DeferredTensor` applies the transform lazily, when the value is consumed
  # (presumably so variable-backed `low`/`high` are re-read on each use).
  offset = tfp_util.DeferredTensor(self.low, lambda v: v)
  width = tfp_util.DeferredTensor(self.high, lambda v: v - self.low)
  move = shift_bijector.Shift(shift=offset, validate_args=validate)
  stretch = scale_bijector.Scale(scale=width, validate_args=validate)
  squash = sigmoid_bijector.Sigmoid(validate_args=validate)
  return chain_bijector.Chain([move, stretch, squash], validate_args=validate)
def constrain_params(params_unconstrained, bijector_name):
  """Maps a dictionary of raw parameters into a bijector's legal parameter space.

  Args:
    params_unconstrained: `dict` of raw parameter values keyed by name.
    bijector_name: Python `str` naming the bijector whose constraint function
      (looked up via `constraint_for`) is applied.

  Returns:
    params_constrained: `dict` of constrained parameters. Any entry whose raw
      value was a `tfp_util.DeferredTensor` or `tf.Variable` that the
      constraint function replaced is re-wrapped so the original object is
      still referenced.
  """
  params_constrained = constraint_for(bijector_name)(params_unconstrained)

  # The constraint function may swap a c2t-tracking `DeferredTensor` (or a raw
  # `tf.Variable`) for a plain Tensor (e.g. fix_triangular). Such entries are
  # re-wrapped in a `DeferredTensor` over the original object whose transform
  # ignores its input and yields the constrained value. This keeps the
  # convert-to-tensor counts observable and keeps variables visible in the
  # bijector's `variables` list so they can be initialized.
  tracked_types = (tfp_util.DeferredTensor, tf.Variable)
  for key in params_constrained:
    if key not in params_unconstrained:
      continue
    raw_value = params_unconstrained[key]
    if not isinstance(raw_value, tracked_types):
      continue
    if raw_value is params_constrained[key]:
      continue

    def constrained_value(v, val=params_constrained[key]):  # pylint: disable=cell-var-from-loop
      # The gradient w.r.t. `v` is zero; only the c2t counts matter here.
      return v * 0 + val

    params_constrained[key] = tfp_util.DeferredTensor(
        raw_value, constrained_value)

  hp.note('Forming bijector {} with constrained parameters {}'.format(
      bijector_name, params_constrained))
  return params_constrained
def _build_posterior_for_one_parameter(param, batch_shape, seed):
  """Builds a transformed-normal variational distribution over `param`'s support.

  A trainable `Normal` is constructed in the unconstrained space and then
  pushed through `param.bijector` to land in the parameter's support.

  Args:
    param: Parameter object exposing `name`, `prior` and `bijector`.
    batch_shape: Sample shape used to draw the initial unconstrained location.
    seed: Seed forwarded to `sample_uniform_initial_state`.

  Returns:
    A `tfd.TransformedDistribution` named '<param.name>_posterior'.
  """
  unconstrained_init = sample_uniform_initial_state(
      param,
      init_sample_shape=batch_shape,
      return_constrained=False,
      seed=seed)
  loc = tf.Variable(initial_value=unconstrained_init,
                    name=param.name + '_loc')
  # The scale is parameterized as softplus of a variable initialized to -4,
  # i.e. a small initial scale.
  scale = tfp_util.DeferredTensor(
      tf.nn.softplus,
      tf.Variable(
          initial_value=tf.fill(
              tf.shape(unconstrained_init),
              value=tf.constant(-4, unconstrained_init.dtype)),
          name=param.name + '_scale'))
  surrogate = tfd.Normal(loc=loc, scale=scale)

  # Reinterpret batch dims as event dims so the surrogate's `event_shape`
  # matches the parameter's prior.
  prior_event_ndims = param.prior.event_shape.ndims
  if prior_event_ndims is None or prior_event_ndims > 0:
    surrogate = tfd.Independent(
        surrogate, reinterpreted_batch_ndims=prior_event_ndims)

  return tfd.TransformedDistribution(
      surrogate, param.bijector, name='{}_posterior'.format(param.name))
def __init__(self,
             constant,
             feature_ndims=1,
             validate_args=False,
             name='Constant'):
  """Construct a constant kernel instance.

  Args:
    constant: Positive floating point `Tensor` (or convertible) that is used
      for all kernel entries.
    feature_ndims: Python `int` number of rightmost dims to include in kernel
      computation.
    validate_args: If `True`, parameters are checked for validity despite
      possibly degrading runtime performance.
    name: Python `str` name prefixed to Ops created by this class.
  """
  parameters = dict(locals())
  with tf.name_scope(name):
    self._constant = tensor_util.convert_nonref_to_tensor(
        constant, name='constant')
    from tensorflow_probability.python import util as tfp_util  # pylint:disable=g-import-not-at-top
    # The parent kernel receives sqrt(constant) as its bias variance, computed
    # lazily (presumably the parent squares it -- confirm against the parent).
    bias_variance = tfp_util.DeferredTensor(self._constant, tf.math.sqrt)
    super(Constant, self).__init__(
        bias_variance=bias_variance,
        slope_variance=0.0,
        shift=None,
        feature_ndims=feature_ndims,
        validate_args=validate_args,
        parameters=parameters,
        name=name)
def _fn(dtype, shape, name, trainable, add_variable_fn):
  """Creates `loc`, `scale` and weightnorm parameters.

  Initializers, regularizers and constraints, as well as the `weightnorm`
  and `is_singular` flags, come from the enclosing scope.

  Args:
    dtype: Dtype of the created variables.
    shape: Shape of the created variables.
    name: Name prefix for the created variables.
    trainable: Whether the created variables are trainable.
    add_variable_fn: Callable used to create each variable.

  Returns:
    A `(loc, scale)` pair. When `weightnorm` is set, `loc` is a
    `DeferredTensor` equal to the L2-normalized location times a learned gain.
    `scale` is `None` when `is_singular`, and otherwise a `DeferredTensor`
    computing `eps + softplus(untransformed_scale)`.
  """
  loc = add_variable_fn(
      name=name + "_loc",
      shape=shape,
      initializer=loc_initializer,
      regularizer=loc_regularizer,
      constraint=loc_constraint,
      dtype=dtype,
      trainable=trainable,
  )
  if weightnorm:
    # Gain variable initialized to ~sqrt(2). NOTE(review): it reuses the
    # `loc` constraint and regularizer -- confirm this is intended.
    gain = add_variable_fn(
        name=name + "_wn",
        shape=shape,
        initializer=tf.constant_initializer(1.4142),
        constraint=loc_constraint,
        regularizer=loc_regularizer,
        dtype=dtype,
        trainable=trainable,
    )
    loc = tfp_util.DeferredTensor(
        loc, lambda x: (tf.multiply(nn_impl.l2_normalize(x), gain)))

  if is_singular:
    return loc, None

  untransformed_scale = add_variable_fn(
      name=name + "_untransformed_scale",
      shape=shape,
      initializer=untransformed_scale_initializer,
      regularizer=untransformed_scale_regularizer,
      constraint=untransformed_scale_constraint,
      dtype=dtype,
      trainable=trainable,
  )
  # Adding machine epsilon keeps the scale strictly positive.
  scale = tfp_util.DeferredTensor(
      untransformed_scale,
      lambda x: (np.finfo(dtype.as_numpy_dtype).eps + tf.nn.softplus(x)),
  )
  return loc, scale
def build_trainable_location_scale_distribution(initial_loc,
                                                initial_scale,
                                                event_ndims,
                                                distribution_fn=tfd.Normal,
                                                validate_args=False,
                                                name=None):
  """Builds a trainable variational distribution from a location-scale family.

  Args:
    initial_loc: Float `Tensor` initial location.
    initial_scale: Float `Tensor` initial scale.
    event_ndims: Integer `Tensor` number of event dimensions in `initial_loc`.
    distribution_fn: Optional constructor for a `tfd.Distribution` instance
      in a location-scale family. This should have signature
      `dist = distribution_fn(loc, scale, validate_args)`.
      Default value: `tfd.Normal`.
    validate_args: Python `bool`. Whether to validate input with asserts.
      This imposes a runtime cost. If `validate_args` is `False`, and the
      inputs are invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e.,
      'build_trainable_location_scale_distribution').

  Returns:
    posterior_dist: A `tfd.Distribution` instance.
  """
  with tf.name_scope(name or 'build_trainable_location_scale_distribution'):
    dtype = dtype_util.common_dtype([initial_loc, initial_scale],
                                    dtype_hint=tf.float32)
    loc_init = tf.convert_to_tensor(initial_loc, dtype=dtype)
    scale_init = tf.convert_to_tensor(initial_scale, dtype=dtype)

    loc = tf.Variable(initial_value=loc_init, name='loc')
    # The scale is stored as the inverse softplus of `initial_scale`,
    # broadcast to the location's shape, and mapped back through softplus.
    inverse_scale_init = tf.broadcast_to(
        tfp_math.softplus_inverse(scale_init),
        shape=prefer_static.shape(loc_init))
    scale = tfp_util.DeferredTensor(
        tf.nn.softplus,
        tf.Variable(initial_value=inverse_scale_init,
                    name='inverse_softplus_scale'))
    dist = distribution_fn(loc=loc, scale=scale, validate_args=validate_args)

    # Only wrap in `Independent` when there may be event dims to reinterpret.
    static_event_ndims = tf.get_static_value(event_ndims)
    needs_independent = (static_event_ndims is None or
                         static_event_ndims > 0)
    if needs_independent:
      dist = tfd.Independent(
          dist,
          reinterpreted_batch_ndims=event_ndims,
          validate_args=validate_args)
    return dist
def _fn(dtype, shape, name, trainable, add_variable_fn):
  """Creates `loc`, `scale` parameters.

  Initializers, regularizers, constraints and the `is_singular` flag come
  from the enclosing scope.

  Args:
    dtype: Dtype of the created variables.
    shape: Shape of the created variables.
    name: Name prefix for the created variables.
    trainable: Whether the created variables are trainable.
    add_variable_fn: Callable used to create each variable.

  Returns:
    A `(loc, scale)` pair; `scale` is `None` when `is_singular`, otherwise a
    `DeferredTensor` computing `eps + softplus(untransformed_scale)`.
  """
  loc_kwargs = dict(name=name + '_loc',
                    shape=shape,
                    initializer=loc_initializer,
                    regularizer=loc_regularizer,
                    constraint=loc_constraint,
                    dtype=dtype,
                    trainable=trainable)
  loc = add_variable_fn(**loc_kwargs)
  if is_singular:
    return loc, None

  scale_kwargs = dict(name=name + '_untransformed_scale',
                      shape=shape,
                      initializer=untransformed_scale_initializer,
                      regularizer=untransformed_scale_regularizer,
                      constraint=untransformed_scale_constraint,
                      dtype=dtype,
                      trainable=trainable)
  untransformed_scale = add_variable_fn(**scale_kwargs)
  # Adding machine epsilon keeps the scale strictly positive.
  return loc, tfp_util.DeferredTensor(
      untransformed_scale,
      lambda x: (np.finfo(dtype.as_numpy_dtype).eps + tf.nn.softplus(x)))
def __init__(self,
             kernel,
             index_points=None,
             observation_index_points=None,
             observations=None,
             observation_noise_variance=0.,
             predictive_noise_variance=None,
             mean_fn=None,
             jitter=1e-6,
             validate_args=False,
             allow_nan_stats=False,
             name='GaussianProcessRegressionModel'):
  """Construct a GaussianProcessRegressionModel instance.

  Args:
    kernel: `PositiveSemidefiniteKernel`-like instance representing the
      GP's covariance function.
    index_points: `float` `Tensor` representing finite collection, or batch of
      collections, of points in the index set over which the GP is defined.
      Shape has the form `[b1, ..., bB, e, f1, ..., fF]` where `F` is the
      number of feature dimensions and must equal `kernel.feature_ndims` and
      `e` is the number (size) of index points in each batch. Ultimately this
      distribution corresponds to an `e`-dimensional multivariate normal. The
      batch shape must be broadcastable with `kernel.batch_shape` and any
      batch dims yielded by `mean_fn`.
    observation_index_points: `float` `Tensor` representing finite collection,
      or batch of collections, of points in the index set for which some data
      has been observed. Shape has the form `[b1, ..., bB, e, f1, ..., fF]`
      where `F` is the number of feature dimensions and must equal
      `kernel.feature_ndims`, and `e` is the number (size) of index points in
      each batch. `[b1, ..., bB, e]` must be broadcastable with the shape of
      `observations`, and `[b1, ..., bB]` must be broadcastable with the
      shapes of all other batched parameters (`kernel.batch_shape`,
      `index_points`, etc). The default value is `None`, which corresponds to
      the empty set of observations, and simply results in the prior
      predictive model (a GP with noise of variance
      `predictive_noise_variance`).
    observations: `float` `Tensor` representing collection, or batch of
      collections, of observations corresponding to
      `observation_index_points`. Shape has the form `[b1, ..., bB, e]`, which
      must be broadcastable with the batch and example shapes of
      `observation_index_points`. The batch shape `[b1, ..., bB]` must be
      broadcastable with the shapes of all other batched parameters
      (`kernel.batch_shape`, `index_points`, etc.). The default value is
      `None`, which corresponds to the empty set of observations, and simply
      results in the prior predictive model (a GP with noise of variance
      `predictive_noise_variance`).
    observation_noise_variance: `float` `Tensor` representing the variance
      of the noise in the Normal likelihood distribution of the model. May be
      batched, in which case the batch shape must be broadcastable with the
      shapes of all other batched parameters (`kernel.batch_shape`,
      `index_points`, etc.).
      Default value: `0.`
    predictive_noise_variance: `float` `Tensor` representing the variance in
      the posterior predictive model. If `None`, we simply re-use
      `observation_noise_variance` for the posterior predictive noise. If set
      explicitly, however, we use this value. This allows us, for example, to
      omit predictive noise variance (by setting this to zero) to obtain
      noiseless posterior predictions of function values, conditioned on noisy
      observations.
    mean_fn: Python `callable` that acts on `index_points` to produce a
      collection, or batch of collections, of mean values at `index_points`.
      Takes a `Tensor` of shape `[b1, ..., bB, f1, ..., fF]` and returns a
      `Tensor` whose shape is broadcastable with `[b1, ..., bB]`.
      Default value: `None` implies the constant zero function.
    jitter: `float` scalar `Tensor` added to the diagonal of the covariance
      matrix to ensure positive definiteness of the covariance matrix.
      Default value: `1e-6`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`, default `False`. When `True`,
      statistics (e.g., mean, mode, variance) use the value `NaN` to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
      Default value: `False`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: 'GaussianProcessRegressionModel'.

  Raises:
    ValueError: if either
      - only one of `observations` and `observation_index_points` is given, or
      - `mean_fn` is not `None` and not callable.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([
        index_points, observation_index_points, observations,
        observation_noise_variance, predictive_noise_variance, jitter
    ], tf.float32)
    index_points = tensor_util.convert_nonref_to_tensor(
        index_points, dtype=dtype, name='index_points')
    observation_index_points = tensor_util.convert_nonref_to_tensor(
        observation_index_points, dtype=dtype,
        name='observation_index_points')
    observations = tensor_util.convert_nonref_to_tensor(
        observations, dtype=dtype, name='observations')
    observation_noise_variance = tensor_util.convert_nonref_to_tensor(
        observation_noise_variance,
        dtype=dtype,
        name='observation_noise_variance')
    # Previously this was mislabeled 'observation_noise_variance'.
    predictive_noise_variance = tensor_util.convert_nonref_to_tensor(
        predictive_noise_variance,
        dtype=dtype,
        name='predictive_noise_variance')
    if predictive_noise_variance is None:
      predictive_noise_variance = observation_noise_variance
    jitter = tensor_util.convert_nonref_to_tensor(
        jitter, dtype=dtype, name='jitter')
    if (observation_index_points is None) != (observations is None):
      raise ValueError(
          '`observations` and `observation_index_points` must both be given '
          'or None. Got {} and {}, respectively.'.format(
              observations, observation_index_points))
    # Default to a constant zero function, using the common `dtype` computed
    # above to ensure consistency.
    if mean_fn is None:
      mean_fn = lambda x: tf.zeros([1], dtype=dtype)
    else:
      if not callable(mean_fn):
        raise ValueError('`mean_fn` must be a Python callable')

    self._name = name
    self._observation_index_points = observation_index_points
    self._observations = observations
    self._observation_noise_variance = observation_noise_variance
    self._predictive_noise_variance = predictive_noise_variance
    self._jitter = jitter
    self._validate_args = validate_args

    with tf.name_scope('init'):
      # The observation noise (plus jitter) is added to the diagonal of the
      # matrix the SchurComplement kernel divides by.
      conditional_kernel = tfpk.SchurComplement(
          base_kernel=kernel,
          fixed_inputs=observation_index_points,
          diag_shift=tfp_util.DeferredTensor(
              observation_noise_variance, lambda x: jitter + x))

      # Special logic for mean_fn only; SchurComplement already handles the
      # case of empty observations (i.e., falls back to base_kernel).
      if _is_empty_observation_data(
          feature_ndims=kernel.feature_ndims,
          observation_index_points=observation_index_points,
          observations=observations):
        conditional_mean_fn = mean_fn
      else:
        _validate_observation_data(
            kernel=kernel,
            observation_index_points=observation_index_points,
            observations=observations)

        def conditional_mean_fn(x):
          """Conditional mean."""
          observations = tf.convert_to_tensor(self._observations)
          observation_index_points = tf.convert_to_tensor(
              self._observation_index_points)
          k_x_obs_linop = tf.linalg.LinearOperatorFullMatrix(
              kernel.matrix(x, observation_index_points))
          chol_linop = tf.linalg.LinearOperatorLowerTriangular(
              conditional_kernel.divisor_matrix_cholesky(
                  fixed_inputs=observation_index_points))

          diff = observations - mean_fn(observation_index_points)
          # Two triangular solves with the Cholesky factor apply the inverse
          # of the divisor matrix to `diff`.
          return mean_fn(x) + k_x_obs_linop.matvec(
              chol_linop.solvevec(chol_linop.solvevec(diff), adjoint=True))

      super(GaussianProcessRegressionModel, self).__init__(
          kernel=conditional_kernel,
          mean_fn=conditional_mean_fn,
          index_points=index_points,
          jitter=jitter,
          # What the GP super class calls "observation noise variance" we call
          # here the "predictive noise variance". We use the observation noise
          # variance for the fit/solve process above, and predictive for
          # downstream computations like sampling.
          observation_noise_variance=predictive_noise_variance,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
      self._parameters = parameters
def __init__(self,
             temperature,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name='RelaxedBernoulli'):
  """Construct RelaxedBernoulli distributions.

  Args:
    temperature: A `Tensor`, representing the temperature of a set of
      RelaxedBernoulli distributions. The temperature values should be
      positive.
    logits: An N-D `Tensor` representing the log-odds of a positive event.
      Each entry in the `Tensor` parametrizes an independent
      RelaxedBernoulli distribution where the probability of an event is
      sigmoid(logits). Only one of `logits` or `probs` should be passed in.
    probs: An N-D `Tensor` representing the probability of a positive event.
      Each entry in the `Tensor` parameterizes an independent Bernoulli
      distribution. Only one of `logits` or `probs` should be passed in.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value `NaN` to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: If both `probs` and `logits` are passed, or if neither.
  """
  parameters = dict(locals())
  # Enforce the documented contract up front. Without this, passing neither
  # failed obscurely inside `DeferredTensor`, and passing both silently
  # ignored `probs`.
  if (probs is None) == (logits is None):
    raise ValueError('Must pass probs or logits, but not both.')
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([logits, probs, temperature], tf.float32)
    self._temperature = tensor_util.convert_nonref_to_tensor(
        temperature, name='temperature', dtype=dtype)
    self._probs = tensor_util.convert_nonref_to_tensor(
        probs, name='probs', dtype=dtype)
    self._logits = tensor_util.convert_nonref_to_tensor(
        logits, name='logits', dtype=dtype)
    if logits is None:
      # Lazily convert probs to logits: log(p) - log(1 - p).
      logits_parameter = tfp_util.DeferredTensor(
          lambda x: tf.math.log(x) - tf.math.log1p(-x), self._probs)
    else:
      logits_parameter = self._logits

    shape = tf.broadcast_static_shape(logits_parameter.shape,
                                      self._temperature.shape)

    # The relaxed Bernoulli is sigmoid(Logistic(logits / T, 1 / T)).
    logistic_scale = tfp_util.DeferredTensor(
        tf.math.reciprocal, self._temperature)
    logistic_loc = tfp_util.DeferredTensor(
        lambda x: x * logistic_scale, logits_parameter, shape=shape)

    self._transformed_logistic = (
        transformed_distribution.TransformedDistribution(
            distribution=logistic.Logistic(
                logistic_loc,
                logistic_scale,
                allow_nan_stats=allow_nan_stats,
                name=name + '/Logistic'),
            bijector=sigmoid_bijector.Sigmoid()))

    super(RelaxedBernoulli, self).__init__(
        dtype=dtype,
        reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
def base_distributions(draw,
                       dist_name=None,
                       batch_shape=None,
                       event_dim=None,
                       enable_vars=False,
                       eligibility_filter=lambda name: True):
  """Hypothesis strategy producing a single base (non-compound) Distribution.

  Compound distributions such as `Independent`, `MixtureSameFamily` or
  `TransformedDistribution` are never drawn; only distributions that do not
  take other Distributions as arguments.

  Args:
    draw: Hypothesis strategy sampler supplied by `@hps.composite`.
    dist_name: Optional Python `str`. If given, every produced distribution
      has this type.
    batch_shape: An optional `TensorShape`. The batch shape of the resulting
      Distribution. Hypothesis picks one if omitted.
    event_dim: Optional Python int giving the size of each of the
      distribution's parameters' event dimensions. This is shared across all
      parameters, permitting square event matrices, compatible location and
      scale Tensors, etc. If omitted, Hypothesis will choose one.
    enable_vars: TODO(bjp): Make this `True` all the time and put variable
      initialization in slicing_test. If `False`, the returned parameters are
      all `tf.Tensor`s and not {`tf.Variable`, `tfp.util.DeferredTensor`
      `tfp.util.TransformedVariable`}.
    eligibility_filter: Optional Python callable. Distribution class names
      for which it returns a falsy value are never drawn at the top level.

  Returns:
    dists: A strategy for drawing Distributions with the specified
      `batch_shape` (or an arbitrary one if omitted).
  """
  if dist_name is None:
    eligible = [k for k in INSTANTIABLE_BASE_DISTS if eligibility_filter(k)]
    dist_name = draw(hps.sampled_from(sorted(eligible)))

  if dist_name == 'Empirical':
    empirical_variants = [
        k for k in INSTANTIABLE_BASE_DISTS
        if eligibility_filter(k) and 'Empirical' in k
    ]
    dist_name = draw(hps.sampled_from(sorted(empirical_variants)))

  if batch_shape is None:
    batch_shape = draw(tfp_hps.shapes())

  raw_params = draw(
      broadcasting_params(dist_name,
                          batch_shape,
                          event_dim=event_dim,
                          enable_vars=enable_vars))
  params_constrained = constraint_for(dist_name)(raw_params)

  # The constraint fn may replace a c2t-tracking `DeferredTensor` with a plain
  # Tensor (e.g. fix_triangular). Re-wrap such entries so the original
  # DeferredTensor is still referenced while its value is ignored.
  for key in params_constrained:
    if key not in raw_params:
      continue
    raw_value = raw_params[key]
    if not isinstance(raw_value, tfp_util.DeferredTensor):
      continue
    if raw_value is params_constrained[key]:
      continue

    def constrained_value(v, val=params_constrained[key]):
      # The gradient w.r.t. `v` is zero; only the c2t counts matter.
      return v * 0 + val

    params_constrained[key] = tfp_util.DeferredTensor(
        raw_value, constrained_value)

  hp.note('Forming dist {} with constrained parameters {}'.format(
      dist_name, params_constrained))
  assert_shapes_unchanged(raw_params, params_constrained)
  params_constrained['validate_args'] = True

  dist = INSTANTIABLE_BASE_DISTS[dist_name].cls(**params_constrained)
  if batch_shape != dist.batch_shape:
    raise AssertionError(
        ('Distributions strategy generated a bad batch shape '
         'for {}, should have been {}.').format(dist, batch_shape))
  return dist
def base_distributions(draw,
                       dist_name=None,
                       batch_shape=None,
                       event_dim=None,
                       enable_vars=False,
                       eligibility_filter=lambda name: True,
                       validate_args=True):
  """Hypothesis strategy producing a single base (non-compound) Distribution.

  Compound distributions such as `Independent`, `MixtureSameFamily` or
  `TransformedDistribution` are never drawn; only distributions that do not
  take other Distributions as arguments.

  Args:
    draw: Hypothesis strategy sampler supplied by `@hps.composite`.
    dist_name: Optional Python `str`. If given, every produced distribution
      has this type.
    batch_shape: An optional `TensorShape`. The batch shape of the resulting
      Distribution. Hypothesis picks one if omitted.
    event_dim: Optional Python int giving the size of each of the
      distribution's parameters' event dimensions. This is shared across all
      parameters, permitting square event matrices, compatible location and
      scale Tensors, etc. If omitted, Hypothesis will choose one.
    enable_vars: TODO(bjp): Make this `True` all the time and put variable
      initialization in slicing_test. If `False`, the returned parameters are
      all `tf.Tensor`s and not {`tf.Variable`, `tfp.util.DeferredTensor`
      `tfp.util.TransformedVariable`}.
    eligibility_filter: Optional Python callable. Distribution class names
      for which it returns a falsy value are never drawn at the top level.
    validate_args: Python `bool`; whether to enable runtime assertions.

  Returns:
    dists: A strategy for drawing Distributions with the specified
      `batch_shape` (or an arbitrary one if omitted).
  """
  if dist_name is None:
    eligible = [k for k in INSTANTIABLE_BASE_DISTS if eligibility_filter(k)]
    dist_name = draw(hps.sampled_from(sorted(eligible)))

  if dist_name == 'Empirical':
    empirical_variants = [
        k for k in INSTANTIABLE_BASE_DISTS
        if eligibility_filter(k) and 'Empirical' in k
    ]
    dist_name = draw(hps.sampled_from(sorted(empirical_variants)))

  # SphericalUniform is delegated to its own dedicated strategy.
  if dist_name == 'SphericalUniform':
    return draw(
        spherical_uniforms(batch_shape=batch_shape,
                           event_dim=event_dim,
                           validate_args=validate_args))

  if batch_shape is None:
    batch_shape = draw(tfp_hps.shapes())

  # Draw raw (unconstrained) parameters.
  raw_params = draw(
      broadcasting_params(dist_name,
                          batch_shape,
                          event_dim=event_dim,
                          enable_vars=enable_vars))
  hp.note('Forming dist {} with raw parameters {}'.format(
      dist_name, raw_params))

  # Map them to legal values.
  params_constrained = constraint_for(dist_name)(raw_params)

  # The constraint fn may replace a c2t-tracking `DeferredTensor` (or a raw
  # `tf.Variable`) with a plain Tensor (e.g. fix_triangular). Re-wrap such
  # entries so the original object is still referenced while its value is
  # ignored; this also keeps variables in the distribution's `variables` list
  # so they can be initialized. In JAX_MODE `tfp_util.DeferredTensor` is a
  # function rather than a class, so the whole pass is skipped there.
  if not JAX_MODE:
    tracked_types = (tfp_util.DeferredTensor, tf.Variable)
    for key in params_constrained:
      if key not in raw_params:
        continue
      raw_value = raw_params[key]
      if not isinstance(raw_value, tracked_types):
        continue
      if raw_value is params_constrained[key]:
        continue

      def constrained_value(v, val=params_constrained[key]):  # pylint: disable=cell-var-from-loop
        # The gradient w.r.t. `v` is zero; only the c2t counts matter.
        return v * 0 + val

      params_constrained[key] = tfp_util.DeferredTensor(
          raw_value, constrained_value)

  hp.note('Forming dist {} with constrained parameters {}'.format(
      dist_name, params_constrained))
  assert_shapes_unchanged(raw_params, params_constrained)
  params_constrained['validate_args'] = validate_args

  if dist_name in ('Wishart', 'WishartTriL'):
    # With the default `input_output_cholesky = False`, Wishart occasionally
    # produces samples whose Cholesky decomposition fails, breaking
    # `log_prob` on a sample in testDistribution.
    params_constrained['input_output_cholesky'] = True

  # Build the distribution and confirm its batch shape.
  dist = INSTANTIABLE_BASE_DISTS[dist_name].cls(**params_constrained)
  if batch_shape != dist.batch_shape:
    raise AssertionError(
        ('Distributions strategy generated a bad batch shape '
         'for {}, should have been {}.').format(dist, batch_shape))
  return dist
def __init__(self,
             low=None,
             high=None,
             hinge_softness=None,
             validate_args=False,
             name='soft_clip'):
  """Instantiates the SoftClip bijector.

  Args:
    low: Optional float `Tensor` lower bound. If `None`, the lower-bound
      constraint is omitted.
      Default value: `None`.
    high: Optional float `Tensor` upper bound. If `None`, the upper-bound
      constraint is omitted.
      Default value: `None`.
    hinge_softness: Optional nonzero float `Tensor`. Controls the softness
      of the constraint at the boundaries; values outside of the constraint
      set are mapped into intervals of width approximately
      `log(2) * hinge_softness` on the interior of each boundary. High
      softness reserves more space for values outside of the constraint set,
      leading to greater distortion of inputs *within* the constraint set,
      but improved numerical stability near the boundaries.
      Default value: `None` (`1.0`).
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    name: Python `str` name given to ops managed by this object.
  """
  parameters = dict(locals())
  with tf.name_scope(name):
    dtype = dtype_util.common_dtype([low, high, hinge_softness],
                                    dtype_hint=tf.float32)
    low = tensor_util.convert_nonref_to_tensor(low, name='low', dtype=dtype)
    high = tensor_util.convert_nonref_to_tensor(high, name='high', dtype=dtype)
    hinge_softness = tensor_util.convert_nonref_to_tensor(
        hinge_softness, name='hinge_softness', dtype=dtype)

    softplus_bijector = softplus.Softplus(hinge_softness=hinge_softness)
    negate = tf.convert_to_tensor(-1., dtype=dtype)

    # NOTE: `Chain` applies its components right-to-left (the last element
    # acts on the input first).
    components = []
    if low is not None and high is not None:
      # Support reference tensors (eg Variables) for `high` and `low` by
      # deferring all computation on them until needed.
      width = tfp_util.DeferredTensor(
          pretransformed_input=high, transform_fn=lambda high: high - low)
      negated_shrinkage_factor = tfp_util.DeferredTensor(
          pretransformed_input=width,
          transform_fn=lambda w: tf.cast(  # pylint: disable=g-long-lambda
              negate * w / softplus_bijector.forward(w), dtype=dtype))

      # Implement the soft constraint from 'Mathematical Details' above:
      #  softclip(x) := -softplus(width - softplus(x - low)) *
      #                        (width) / (softplus(width)) + high
      components = [
          shift.Shift(high),
          scale.Scale(negated_shrinkage_factor),
          softplus_bijector,
          shift.Shift(width),
          scale.Scale(negate),
          softplus_bijector,
          shift.Shift(tfp_util.DeferredTensor(low, lambda x: -x))
      ]
    elif low is not None:
      # Implement a soft lower bound:
      #  softlower(x) := softplus(x - low) + low
      components = [
          shift.Shift(low),
          softplus_bijector,
          shift.Shift(tfp_util.DeferredTensor(low, lambda x: -x))
      ]
    elif high is not None:
      # Implement a soft upper bound:
      #  softupper(x) := -softplus(high - x) + high
      # Applied right-to-left: negate x, add `high` (giving `high - x`),
      # softplus, negate, then add `high`. The previous ordering shifted by
      # `+high` before negating, which computed `high - softplus(-x - high)`.
      components = [
          shift.Shift(high),
          scale.Scale(negate),
          softplus_bijector,
          shift.Shift(high),
          scale.Scale(negate),
      ]

    self._low = low
    self._high = high
    self._hinge_softness = hinge_softness
    self._chain = chain.Chain(components, validate_args=validate_args)

  super(SoftClip, self).__init__(
      forward_min_event_ndims=0,
      dtype=dtype,
      validate_args=validate_args,
      parameters=parameters,
      # With no bounds the chain is empty (identity), whose Jacobian is
      # constant.
      is_constant_jacobian=not components,
      name=name)
def __init__(self,
             df,
             kernel,
             index_points=None,
             observation_index_points=None,
             observations=None,
             observation_noise_variance=0.,
             predictive_noise_variance=None,
             mean_fn=None,
             cholesky_fn=None,
             marginal_fn=None,
             validate_args=False,
             allow_nan_stats=False,
             name='StudentTProcessRegressionModel',
             _conditional_kernel=None,
             _conditional_mean_fn=None):
  """Construct a StudentTProcessRegressionModel instance.

  Builds a Student-T process (STP) over `index_points`. When observations
  are given, the kernel and mean function are conditioned on `observations`
  made at `observation_index_points`.

  Args:
    df: Positive Floating-point `Tensor` representing the degrees of freedom.
      Must be greater than 2.
    kernel: `PositiveSemidefiniteKernel`-like instance representing the
      STP's covariance function.
    index_points: `float` `Tensor` representing finite collection, or batch
      of collections, of points in the index set over which the STP is
      defined. Shape has the form `[b1, ..., bB, e, f1, ..., fF]` where `F`
      is the number of feature dimensions and must equal
      `kernel.feature_ndims` and `e` is the number (size) of index points in
      each batch. Ultimately this distribution corresponds to an
      `e`-dimensional multivariate Student-T. The batch shape must be
      broadcastable with `kernel.batch_shape`.
    observation_index_points: `float` `Tensor` representing finite
      collection, or batch of collections, of points in the index set for
      which some data has been observed. Shape has the form
      `[b1, ..., bB, e, f1, ..., fF]` where `F` is the number of feature
      dimensions and must equal `kernel.feature_ndims`, and `e` is the number
      (size) of index points in each batch. `[b1, ..., bB, e]` must be
      broadcastable with the shape of `observations`, and `[b1, ..., bB]`
      must be broadcastable with the shapes of all other batched parameters
      (`kernel.batch_shape`, `index_points`, etc).
    observations: `float` `Tensor` representing collection, or batch of
      collections, of observations corresponding to
      `observation_index_points`. Shape has the form `[b1, ..., bB, e]`,
      which must be broadcastable with the batch and example shapes of
      `observation_index_points`. The batch shape `[b1, ..., bB]` must be
      broadcastable with the shapes of all other batched parameters
      (`kernel.batch_shape`, `index_points`, etc.).
    observation_noise_variance: `float` `Tensor` representing the variance
      of the noise in the Normal likelihood distribution of the model. May be
      batched, in which case the batch shape must be broadcastable with the
      shapes of all other batched parameters (`kernel.batch_shape`,
      `index_points`, etc.).
      Default value: `0.`
    predictive_noise_variance: `float` `Tensor` representing the variance in
      the posterior predictive model. If `None`, we simply re-use
      `observation_noise_variance` for the posterior predictive noise. If set
      explicitly, however, we use this value. This allows us, for example,
      to omit predictive noise variance (by setting this to zero) to obtain
      noiseless posterior predictions of function values, conditioned on
      noisy observations.
      Default value: `None`.
    mean_fn: Python `callable` that acts on `index_points` to produce a
      collection, or batch of collections, of mean values at `index_points`.
      Takes a `Tensor` of shape `[b1, ..., bB, f1, ..., fF]` and returns a
      `Tensor` whose shape is broadcastable with `[b1, ..., bB]`.
      Default value: `None` implies the constant zero function.
    cholesky_fn: Callable which takes a single (batch) matrix argument and
      returns a Cholesky-like lower triangular factor.
      Default value: `None`, in which case one is built with
      `make_cholesky_with_jitter_fn()`.
    marginal_fn: A Python callable that takes a location, covariance matrix,
      optional `validate_args`, `allow_nan_stats` and `name` arguments, and
      returns a multivariate Student-T subclass of `tfd.Distribution`.
      Default value: `None`.
      NOTE(review): this constructor records `marginal_fn` in `parameters`
      but does not otherwise use it (in particular, it is not forwarded to
      the parent constructor) -- confirm whether that is intended.
    validate_args: Python `bool`. When `True`, distribution parameters are
      checked for validity despite possibly degrading runtime performance.
      When `False`, invalid inputs may silently render incorrect outputs.
      Default value: `False`.
    allow_nan_stats: Python `bool`. When `True`, statistics (e.g., mean,
      mode, variance) use the value `NaN` to indicate the result is
      undefined. When `False`, an exception is raised if one or more of the
      statistic's batch members are undefined.
      Default value: `False`.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: 'StudentTProcessRegressionModel'.
    _conditional_kernel: Internal parameter -- do not use. When given, it is
      used instead of the `DampedSchurComplement` kernel built here.
    _conditional_mean_fn: Internal parameter -- do not use. When given, it is
      used instead of the conditional mean function built here.

  Raises:
    ValueError: if exactly one of `observations` and
      `observation_index_points` is `None`.
    ValueError: if `mean_fn` is not `None` and is not callable.
  """
  # Must be the first statement so that only the constructor arguments are
  # captured (later statements rebind several of these names).
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([
        df, kernel, index_points, observation_noise_variance, observations
    ], tf.float32)
    df = tensor_util.convert_nonref_to_tensor(df, dtype=dtype, name='df')
    index_points = tensor_util.convert_nonref_to_tensor(
        index_points, dtype=dtype, name='index_points')
    observation_index_points = tensor_util.convert_nonref_to_tensor(
        observation_index_points, dtype=dtype,
        name='observation_index_points')
    observations = tensor_util.convert_nonref_to_tensor(
        observations, dtype=dtype, name='observations')
    observation_noise_variance = tensor_util.convert_nonref_to_tensor(
        observation_noise_variance,
        dtype=dtype,
        name='observation_noise_variance')
    predictive_noise_variance = tensor_util.convert_nonref_to_tensor(
        predictive_noise_variance,
        dtype=dtype,
        name='predictive_noise_variance')
    # Fall back to the observation noise for posterior predictions.
    if predictive_noise_variance is None:
      predictive_noise_variance = observation_noise_variance
    if (observation_index_points is None) != (observations is None):
      raise ValueError(
          '`observations` and `observation_index_points` must both be given '
          'or None. Got {} and {}, respectively.'.format(
              observations, observation_index_points))
    # Default to a constant zero function, using the common `dtype` computed
    # above to ensure consistency.
    if mean_fn is None:
      mean_fn = lambda x: tf.zeros([1], dtype=dtype)
    else:
      if not callable(mean_fn):
        raise ValueError('`mean_fn` must be a Python callable')

    if cholesky_fn is None:
      cholesky_fn = cholesky_util.make_cholesky_with_jitter_fn()

    self._observation_index_points = observation_index_points
    self._observations = observations
    self._observation_noise_variance = observation_noise_variance
    self._predictive_noise_variance = predictive_noise_variance

    with tf.name_scope('init'):
      if _conditional_kernel is None:
        _conditional_kernel = DampedSchurComplement(
            df=df,
            schur_complement=tfpk.SchurComplement(
                base_kernel=kernel,
                fixed_inputs=self._observation_index_points,
                diag_shift=observation_noise_variance),
            fixed_inputs_observations=self._observations,
            validate_args=validate_args)

      # Special logic for mean_fn only; SchurComplement already handles the
      # case of empty observations (ie, falls back to base_kernel).
      if _is_empty_observation_data(
          feature_ndims=kernel.feature_ndims,
          observation_index_points=observation_index_points,
          observations=observations):
        if _conditional_mean_fn is None:
          _conditional_mean_fn = mean_fn
      else:
        _validate_observation_data(
            kernel=kernel,
            observation_index_points=observation_index_points,
            observations=observations)
        # The process handed to the parent uses `df + n` degrees of freedom,
        # where `n` is the size of the last dimension of `observations`.
        # Deferred so that a Variable-valued `df` is read when needed.
        n = tf.cast(ps.shape(observations)[-1], dtype=dtype)
        df = tfp_util.DeferredTensor(df, lambda x: x + n)

        if _conditional_mean_fn is None:

          def conditional_mean_fn(x):
            """Conditional mean at `x`.

            Computes `mean_fn(x) + K(x, X_obs) @ (L L^T)^{-1} @ diff`, where
            `X_obs` are the observation index points, `L` is
            `_conditional_kernel.divisor_matrix_cholesky(X_obs)`, and
            `diff = observations - mean_fn(X_obs)`.
            """
            # Read the stored (possibly Variable-backed) observation data at
            # call time; these locals intentionally shadow the outer names.
            observations = tf.convert_to_tensor(
                self._observations)
            observation_index_points = tf.convert_to_tensor(
                self._observation_index_points)
            k_x_obs_linop = tf.linalg.LinearOperatorFullMatrix(
                kernel.matrix(x, observation_index_points))
            chol_linop = tf.linalg.LinearOperatorLowerTriangular(
                _conditional_kernel.divisor_matrix_cholesky(
                    fixed_inputs=observation_index_points))
            diff = observations - mean_fn(
                observation_index_points)
            # Two triangular solves: L^{-T} (L^{-1} diff) = (L L^T)^{-1} diff.
            return mean_fn(x) + k_x_obs_linop.matvec(
                chol_linop.solvevec(chol_linop.solvevec(diff), adjoint=True))

          _conditional_mean_fn = conditional_mean_fn

    # The parent's observation noise is the *predictive* noise variance; the
    # observation noise was already folded into the conditional kernel above.
    super(StudentTProcessRegressionModel, self).__init__(
        df=df,
        kernel=_conditional_kernel,
        mean_fn=_conditional_mean_fn,
        cholesky_fn=cholesky_fn,
        index_points=index_points,
        observation_noise_variance=predictive_noise_variance,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
    # Record this class's own constructor arguments; presumably this
    # overrides what the parent constructor stored -- confirm.
    self._parameters = parameters