def _default_event_space_bijector(self):
  # TODO(b/145620027) Finalize choice of bijector.
  return sigmoid_bijector.Sigmoid(
      low=tf.constant(-np.pi, dtype=self.dtype),
      high=tf.constant(np.pi, dtype=self.dtype),
      validate_args=self.validate_args)
def _negative_concentration_bijector(self):
  # Constructed dynamically so that `loc + scale / concentration` is
  # tape-safe.
  high = self.loc + tf.math.abs(self.scale / self.concentration)
  return sigmoid_bijector.Sigmoid(
      low=self.loc, high=high, validate_args=self.validate_args)
def __init__(self,
             temperature,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="RelaxedBernoulli"):
  """Construct RelaxedBernoulli distributions.

  Args:
    temperature: A 0-D `Tensor`, representing the temperature of a set of
      RelaxedBernoulli distributions. The temperature should be positive.
    logits: An N-D `Tensor` representing the log-odds of a positive event.
      Each entry in the `Tensor` parametrizes an independent RelaxedBernoulli
      distribution where the probability of an event is sigmoid(logits).
      Only one of `logits` or `probs` should be passed in.
    probs: An N-D `Tensor` representing the probability of a positive event.
      Each entry in the `Tensor` parameterizes an independent Bernoulli
      distribution. Only one of `logits` or `probs` should be passed in.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: If both `probs` and `logits` are passed, or if neither.
  """
  parameters = dict(locals())
  with tf.compat.v1.name_scope(
      name, values=[logits, probs, temperature]) as name:
    dtype = dtype_util.common_dtype([logits, probs, temperature], tf.float32)
    self._temperature = tf.convert_to_tensor(
        value=temperature, name="temperature", dtype=dtype)
    if validate_args:
      with tf.control_dependencies(
          [tf.compat.v1.assert_positive(temperature)]):
        self._temperature = tf.identity(self._temperature)
    self._logits, self._probs = distribution_util.get_logits_and_probs(
        logits=logits, probs=probs, validate_args=validate_args, dtype=dtype)
    super(RelaxedBernoulli, self).__init__(
        distribution=logistic.Logistic(
            self._logits / self._temperature,
            1. / self._temperature,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name + "/Logistic"),
        bijector=sigmoid_bijector.Sigmoid(validate_args=validate_args),
        validate_args=validate_args,
        name=name)
  self._parameters = parameters
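# A minimal usage sketch of the distribution constructed above, assuming the
# public `tfp.distributions.RelaxedBernoulli` export; the temperature and
# probabilities below are illustrative values, not from the source. Lower
# temperatures push samples toward {0, 1}; higher temperatures smooth them out.
import tensorflow_probability as tfp

tfd = tfp.distributions

relaxed = tfd.RelaxedBernoulli(temperature=0.5, probs=[0.1, 0.5, 0.9])
samples = relaxed.sample(4, seed=42)     # shape [4, 3], values in (0, 1)
log_probs = relaxed.log_prob(samples)    # density of the continuous relaxation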
def _default_event_space_bijector(self):
  # TODO(b/145620027) Finalize choice of bijector.
  return chain_bijector.Chain([
      shift_bijector.Shift(shift=-np.pi, validate_args=self.validate_args),
      scale_bijector.Scale(
          scale=2. * np.pi, validate_args=self.validate_args),
      sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
  ], validate_args=self.validate_args)
def _default_event_space_bijector(self):
  if tensor_util.is_ref(self.low) or tensor_util.is_ref(self.high):
    scale = DeferredTensor(self.high, lambda x: x - self.low)
  else:
    scale = self.high - self.low
  return chain_bijector.Chain([
      shift_bijector.Shift(shift=self.low, validate_args=self.validate_args),
      scale_bijector.Scale(scale=scale, validate_args=self.validate_args),
      sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
  ], validate_args=self.validate_args)
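# A minimal sketch of what the Shift/Scale/Sigmoid chain above computes,
# assuming the public `tfp.bijectors` aliases; `low` and `high` are
# illustrative constants. `Chain` applies its bijectors right-to-left, so an
# unconstrained real is squashed by `Sigmoid`, stretched by `high - low`, and
# shifted by `low`, landing strictly inside (low, high) -- the same mapping as
# `Sigmoid(low=..., high=...)` used elsewhere in this section.
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

low, high = 2.0, 5.0
chain = tfb.Chain([tfb.Shift(low), tfb.Scale(high - low), tfb.Sigmoid()])

x = tf.constant([-3.0, 0.0, 3.0])
y_chain = chain.forward(x)                              # values in (2, 5)
y_direct = tfb.Sigmoid(low=low, high=high).forward(x)   # same mapping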
def _transformed_logistic(self):
  logistic_scale = tf.math.reciprocal(self._temperature)
  logits_parameter = self._logits_parameter_no_checks()
  logistic_loc = logits_parameter * logistic_scale
  return transformed_distribution.TransformedDistribution(
      distribution=logistic.Logistic(
          logistic_loc,
          logistic_scale,
          allow_nan_stats=self.allow_nan_stats),
      bijector=sigmoid_bijector.Sigmoid())
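# The sigmoid-of-Logistic construction above is the usual relaxed-Bernoulli
# (binary Concrete) reparameterization. A minimal sketch of the equivalent
# sampling path in plain TensorFlow, with illustrative `logits` and
# `temperature`; nothing here is the library's internal code.
import tensorflow as tf

logits = tf.constant([0.0, 2.0])
temperature = tf.constant(0.5)

# Logistic(0, 1) noise via the inverse CDF of a uniform draw.
u = tf.random.uniform(tf.shape(logits), minval=1e-6, maxval=1. - 1e-6)
logistic_noise = tf.math.log(u) - tf.math.log1p(-u)

# sigmoid((logits + noise) / temperature) is a draw from the relaxation,
# i.e. the sigmoid of a Logistic(logits / t, 1 / t) sample.
relaxed_sample = tf.sigmoid((logits + logistic_noise) / temperature)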
def _default_event_space_bijector(self):
  low = tfp_util.DeferredTensor(self.low, lambda x: x)
  scale = tfp_util.DeferredTensor(self.high, lambda x: x - self.low)
  return chain_bijector.Chain([
      shift_bijector.Shift(shift=low, validate_args=self.validate_args),
      scale_bijector.Scale(scale=scale, validate_args=self.validate_args),
      sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
  ], validate_args=self.validate_args)
def _negative_concentration_bijector(self):
  # Constructed dynamically so that `scale * reciprocal(concentration)` is
  # tape-safe.
  return chain_bijector.Chain([
      shift_bijector.Shift(shift=self.loc, validate_args=self.validate_args),
      # TODO(b/146568897): Resolve numerical issues by implementing a new
      # bijector instead of multiplying `scale` by `(1. - 1e-6)`.
      scale_bijector.Scale(
          scale=-(self.scale * tf.math.reciprocal(self.concentration) *
                  (1. - 1e-6)),
          validate_args=self.validate_args),
      sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
  ], validate_args=self.validate_args)
def _default_event_space_bijector(self):
  # TODO(b/146568897): Resolve numerical issues by implementing a new bijector
  # instead of multiplying `scale` by `(1. - 1e-6)`.
  if tensor_util.is_ref(self.low) or tensor_util.is_ref(self.high):
    scale = DeferredTensor(
        self.high,
        lambda x: (x - self.low) * (1. - 1e-6),
        shape=tf.broadcast_static_shape(self.high.shape, self.low.shape))
  else:
    scale = (self.high - self.low) * (1. - 1e-6)
  return chain_bijector.Chain([
      shift_bijector.Shift(shift=self.low, validate_args=self.validate_args),
      scale_bijector.Scale(scale=scale, validate_args=self.validate_args),
      sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
  ], validate_args=self.validate_args)
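# A small sketch of the numerical issue the `(1. - 1e-6)` factor appears to
# work around (my reading of the TODO above, not a statement from the source):
# `sigmoid` saturates to exactly 1.0 in float32 for moderately large inputs,
# so without shrinking the scale the chain can land exactly on the boundary
# `high` instead of strictly inside (low, high).
import tensorflow as tf

low, high = tf.constant(0.), tf.constant(1.)
x = tf.constant(40.)                      # sigmoid(40.) rounds to 1.0 in float32

saturated = tf.sigmoid(x)
on_boundary = low + (high - low) * saturated              # == high exactly
shrunk = low + (high - low) * (1. - 1e-6) * saturated     # strictly < high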
def __init__(self,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name='LogitNormal'):
  """Construct a logit-normal distribution.

  The LogitNormal distribution models random variables between 0 and 1 whose
  logit (i.e., sigmoid_inverse, i.e., `log(p) - log1p(-p)`) is normally
  distributed with mean `loc` and standard deviation `scale`. It is
  constructed as the sigmoid transformation, (i.e., `1 / (1 + exp(-x))`) of a
  Normal distribution.

  Args:
    loc: Floating-point `Tensor`; the mean of the underlying Normal
      distribution(s). Must broadcast with `scale`.
    scale: Floating-point `Tensor`; the stddev of the underlying Normal
      distribution(s). Must broadcast with `loc`.
    validate_args: Python `bool`, default `False`. Whether to validate input
      with asserts. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
    allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading to
      undefined statistics will return NaN for this statistic.
    name: The name to give Ops created by the initializer.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    super(LogitNormal, self).__init__(
        distribution=normal.Normal(loc=loc, scale=scale),
        bijector=sigmoid_bijector.Sigmoid(),
        validate_args=validate_args,
        parameters=parameters,
        name=name)
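# A minimal usage sketch of the distribution constructed above, assuming the
# public `tfp.distributions.LogitNormal` export; `loc` and `scale` below are
# illustrative values.
import tensorflow_probability as tfp

tfd = tfp.distributions

logit_normal = tfd.LogitNormal(loc=0., scale=1.)   # sigmoid(Normal(0, 1))
samples = logit_normal.sample(1000, seed=7)        # values in (0, 1)
density_at_half = logit_normal.prob(0.5)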
def _default_event_space_bijector(self):
  return sigmoid_bijector.Sigmoid(
      low=self.low, high=self.high, validate_args=self.validate_args)
def _default_event_space_bijector(self):
  return sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
def _asvi_convex_update_for_base_distribution(dist,
                                              mean_field=False,
                                              initial_prior_weight=0.5,
                                              sample_shape=None):
  """Creates a trainable surrogate for a (non-meta, non-joint) distribution."""
  posterior_batch_shape = dist.batch_shape_tensor()
  if sample_shape is not None:
    posterior_batch_shape = ps.concat([
        posterior_batch_shape,
        distribution_util.expand_to_vector(sample_shape)
    ], axis=0)

  temp_params_dict = {'name': _get_name(dist)}
  all_parameter_properties = dist.parameter_properties(dtype=dist.dtype)
  for param, prior_value in dist.parameters.items():
    if (param in (_NON_STATISTICAL_PARAMS + _NON_TRAINABLE_PARAMS)
        or prior_value is None):
      temp_params_dict[param] = prior_value
      continue

    param_properties = all_parameter_properties[param]
    try:
      bijector = param_properties.default_constraining_bijector_fn()
    except NotImplementedError:
      bijector = identity.Identity()

    param_shape = ps.concat([
        posterior_batch_shape,
        ps.shape(prior_value)[
            ps.rank(prior_value) - param_properties.event_ndims:]
    ], axis=0)

    # Initialize the mean-field parameter as a (constrained) standard
    # normal sample.
    # pylint: disable=cell-var-from-loop
    # Safe because the state utils guarantee to either call `init_fn`
    # immediately upon yielding, or not at all.
    mean_field_parameter = yield trainable_state_util.Parameter(
        init_fn=lambda seed: (  # pylint: disable=g-long-lambda
            bijector.forward(
                samplers.normal(
                    shape=bijector.inverse_event_shape(param_shape),
                    seed=seed))),
        name='mean_field_parameter_{}_{}'.format(_get_name(dist), param),
        constraining_bijector=bijector)

    if mean_field:
      temp_params_dict[param] = mean_field_parameter
    else:
      prior_weight = yield trainable_state_util.Parameter(
          init_fn=lambda: tf.fill(  # pylint: disable=g-long-lambda
              dims=param_shape,
              value=tf.cast(initial_prior_weight,
                            tf.convert_to_tensor(prior_value).dtype)),
          name='prior_weight_{}_{}'.format(_get_name(dist), param),
          constraining_bijector=sigmoid.Sigmoid())
      temp_params_dict[param] = prior_weight * prior_value + (
          (1. - prior_weight) * mean_field_parameter)
    # pylint: enable=cell-var-from-loop
  return type(dist)(**temp_params_dict)
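# The heart of the ASVI update above is the convex combination
# `prior_weight * prior_value + (1 - prior_weight) * mean_field_parameter`,
# with `prior_weight` kept in (0, 1) by a Sigmoid bijector. A minimal numeric
# sketch of that interpolation with illustrative values (plain TensorFlow,
# not the library's trainable-state machinery):
import tensorflow as tf

prior_value = tf.constant(3.0)             # e.g. a prior's `scale` parameter
mean_field_parameter = tf.constant(0.5)    # freely trained surrogate value

unconstrained_weight = tf.Variable(0.0)          # trained in unconstrained space
prior_weight = tf.sigmoid(unconstrained_weight)  # constrained to (0, 1)

# prior_weight -> 1 recovers the prior; prior_weight -> 0 is pure mean-field.
surrogate_param = (prior_weight * prior_value +
                   (1. - prior_weight) * mean_field_parameter)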
def __init__(self,
             temperature,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name='RelaxedBernoulli'):
  """Construct RelaxedBernoulli distributions.

  Args:
    temperature: A `Tensor`, representing the temperature of a set of
      RelaxedBernoulli distributions. The temperature values should be
      positive.
    logits: An N-D `Tensor` representing the log-odds of a positive event.
      Each entry in the `Tensor` parametrizes an independent RelaxedBernoulli
      distribution where the probability of an event is sigmoid(logits).
      Only one of `logits` or `probs` should be passed in.
    probs: An N-D `Tensor` representing the probability of a positive event.
      Each entry in the `Tensor` parameterizes an independent Bernoulli
      distribution. Only one of `logits` or `probs` should be passed in.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: If both `probs` and `logits` are passed, or if neither.
  """
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([logits, probs, temperature], tf.float32)
    self._temperature = tensor_util.convert_nonref_to_tensor(
        temperature, name='temperature', dtype=dtype)
    self._probs = tensor_util.convert_nonref_to_tensor(
        probs, name='probs', dtype=dtype)
    self._logits = tensor_util.convert_nonref_to_tensor(
        logits, name='logits', dtype=dtype)

    if logits is None:
      logits_parameter = tfp_util.DeferredTensor(
          lambda x: tf.math.log(x) - tf.math.log1p(-x), self._probs)
    else:
      logits_parameter = self._logits

    shape = tf.broadcast_static_shape(
        logits_parameter.shape, self._temperature.shape)

    logistic_scale = tfp_util.DeferredTensor(
        tf.math.reciprocal, self._temperature)
    logistic_loc = tfp_util.DeferredTensor(
        lambda x: x * logistic_scale, logits_parameter, shape=shape)

    self._transformed_logistic = (
        transformed_distribution.TransformedDistribution(
            distribution=logistic.Logistic(
                logistic_loc,
                logistic_scale,
                allow_nan_stats=allow_nan_stats,
                name=name + '/Logistic'),
            bijector=sigmoid_bijector.Sigmoid()))

    super(RelaxedBernoulli, self).__init__(
        dtype=dtype,
        reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
def __init__(self,
             loc,
             scale,
             num_probit_terms_approx=2,
             validate_args=False,
             allow_nan_stats=True,
             name='LogitNormal'):
  """Construct a logit-normal distribution.

  The LogitNormal distribution models random variables between 0 and 1 whose
  logit (i.e., sigmoid_inverse, i.e., `log(p) - log1p(-p)`) is normally
  distributed with mean `loc` and standard deviation `scale`. It is
  constructed as the sigmoid transformation, (i.e., `1 / (1 + exp(-x))`) of a
  Normal distribution.

  Args:
    loc: Floating-point `Tensor`; the mean of the underlying Normal
      distribution(s). Must broadcast with `scale`.
    scale: Floating-point `Tensor`; the stddev of the underlying Normal
      distribution(s). Must broadcast with `loc`.
    num_probit_terms_approx: The `k` used in the approximation,
      `sigmoid(x) approx= sum_i^k p[k,i] Normal(0, c[k, i]).cdf(x)`
      where `sum_i^k p[k,i]=1` and `p[k,i],c[k,i] > 0`
      [(Monahan and Stefanski, 1989)][1] and used in `mean_*_approx`
      functions [(Owen, 1980)][2]. Must be a python scalar integer between
      `1` and `8` (inclusive). Using `num_probit_terms_approx=2` should
      result in `mean_approx` error not exceeding `10**-4`.
      Default value: `2`.
    validate_args: Python `bool`, default `False`. Whether to validate input
      with asserts. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
    allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading to
      undefined statistics will return NaN for this statistic.
    name: The name to give Ops created by the initializer.

  #### References

  [1]: Monahan, John H., and Leonard A. Stefanski. Normal scale mixture
       approximations to the logistic distribution with applications.
       North Carolina State University. Dept. of Statistics, 1989.
       http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.154.5032

  [2]: Owen, Donald Bruce. "A table of normal integrals: A table."
       Communications in Statistics-Simulation and Computation 9.4 (1980):
       389-419. https://www.tandfonline.com/doi/abs/10.1080/03610918008812164
  """
  parameters = dict(locals())
  num_probit_terms_approx = int(num_probit_terms_approx)
  if num_probit_terms_approx < 1 or num_probit_terms_approx > 8:
    raise ValueError(
        'Argument `num_probit_terms_approx` must be an integer between '
        '`1` and `8` (inclusive).')
  self._num_probit_terms_approx = num_probit_terms_approx
  with tf.name_scope(name) as name:
    super(LogitNormal, self).__init__(
        distribution=normal_lib.Normal(loc=loc, scale=scale),
        bijector=sigmoid_bijector.Sigmoid(),
        validate_args=validate_args,
        parameters=parameters,
        name=name)
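# A sketch of how the probit approximation in the docstring gets used for the
# mean, via the identity E[Normal(0, c).cdf(X)] =
# Normal(0, 1).cdf(loc / sqrt(c**2 + scale**2)) for X ~ Normal(loc, scale).
# The single constant c = 1.702 below is the commonly cited logistic-probit
# matching value, used here purely for illustration; it is not one of the
# Monahan-Stefanski constants, and the `mean_*_approx` methods mentioned in
# the docstring use the k-term version instead.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

loc, scale = 0.3, 0.8
logit_normal = tfd.LogitNormal(loc=loc, scale=scale)

# Monte Carlo reference for E[sigmoid(X)], X ~ Normal(loc, scale).
mc_mean = tf.reduce_mean(logit_normal.sample(100000, seed=1))

# One-term probit approximation of the same expectation.
c = 1.702
approx_mean = tfd.Normal(0., 1.).cdf(loc / tf.sqrt(c**2 + scale**2))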
def _default_event_space_bijector(self):
  # TODO(b/145620027) Finalize choice of bijector.
  return sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
def _asvi_convex_update_for_base_distribution(dist,
                                              mean_field,
                                              initial_prior_weight,
                                              sample_shape=None,
                                              variables=None,
                                              seed=None):
  """Creates a trainable surrogate for a (non-meta, non-joint) distribution."""
  if variables is None:
    variables = {}

  posterior_batch_shape = dist.batch_shape_tensor()
  if sample_shape is not None:
    posterior_batch_shape = ps.concat([
        posterior_batch_shape,
        distribution_util.expand_to_vector(sample_shape)
    ], axis=0)

  # Create variables backing each parameter, if needed.
  all_parameter_properties = dist.parameter_properties(dtype=dist.dtype)
  for param, prior_value in dist.parameters.items():
    if (param in variables
        or param in (_NON_STATISTICAL_PARAMS + _NON_TRAINABLE_PARAMS)
        or prior_value is None):
      continue

    param_properties = all_parameter_properties[param]
    try:
      bijector = param_properties.default_constraining_bijector_fn()
    except NotImplementedError:
      bijector = identity.Identity()

    param_shape = ps.concat([
        posterior_batch_shape,
        ps.shape(prior_value)[
            ps.rank(prior_value) - param_properties.event_ndims:]
    ], axis=0)

    prior_weight = (
        None if mean_field  # pylint: disable=g-long-ternary
        else tfp_util.TransformedVariable(
            initial_value=tf.fill(
                dims=param_shape,
                value=tf.cast(initial_prior_weight,
                              tf.convert_to_tensor(prior_value).dtype)),
            bijector=sigmoid.Sigmoid(),
            name='prior_weight/{}/{}'.format(_get_name(dist), param)))

    # Initialize the mean-field parameter as a (constrained) standard
    # normal sample.
    seed, param_seed = samplers.split_seed(seed)
    variables[param] = ASVIParameters(
        prior_weight=prior_weight,
        mean_field_parameter=tfp_util.TransformedVariable(
            initial_value=bijector.forward(
                samplers.normal(
                    shape=bijector.inverse_event_shape(param_shape),
                    seed=param_seed)),
            bijector=bijector,
            name='mean_field_parameter/{}/{}'.format(_get_name(dist), param)))

  temp_params_dict = {'name': _get_name(dist)}
  for param, prior_value in dist.parameters.items():
    if (param in (_NON_STATISTICAL_PARAMS + _NON_TRAINABLE_PARAMS)
        or prior_value is None):
      temp_params_dict[param] = prior_value
    else:
      if mean_field:
        temp_params_dict[param] = variables[param].mean_field_parameter
      else:
        temp_params_dict[param] = (
            variables[param].prior_weight * prior_value +
            ((1. - variables[param].prior_weight) *
             variables[param].mean_field_parameter))
  return type(dist)(**temp_params_dict), variables
def _make_asvi_trainable_variables(prior,
                                   mean_field=False,
                                   initial_prior_weight=0.5):
  """Generates parameter dictionaries given a prior distribution and list."""
  with tf.name_scope('make_asvi_trainable_variables'):
    param_dicts = []
    prior_dists = prior._get_single_sample_distributions()  # pylint: disable=protected-access
    for dist in prior_dists:
      original_dist = dist.distribution if isinstance(dist, Root) else dist

      substituted_dist = _as_trainable_family(original_dist)

      # Grab the base distribution if it exists
      try:
        actual_dist = substituted_dist.distribution
      except AttributeError:
        actual_dist = substituted_dist

      new_params_dict = {}

      # Build trainable ASVI representation for each distribution's parameters.
      parameter_properties = actual_dist.parameter_properties(
          dtype=actual_dist.dtype)

      if isinstance(original_dist, sample.Sample):
        posterior_batch_shape = ps.concat([
            actual_dist.batch_shape_tensor(),
            distribution_util.expand_to_vector(original_dist.sample_shape)
        ], axis=0)
      else:
        posterior_batch_shape = actual_dist.batch_shape_tensor()

      for param, value in actual_dist.parameters.items():
        if (param in (_NON_STATISTICAL_PARAMS + _NON_TRAINABLE_PARAMS)
            or value is None):
          continue

        actual_event_shape = parameter_properties[param].shape_fn(
            actual_dist.event_shape_tensor())
        try:
          bijector = parameter_properties[
              param].default_constraining_bijector_fn()
        except NotImplementedError:
          bijector = identity.Identity()

        if mean_field:
          prior_weight = None
        else:
          unconstrained_ones = tf.ones(
              shape=ps.concat([
                  posterior_batch_shape,
                  bijector.inverse_event_shape_tensor(actual_event_shape)
              ], axis=0),
              dtype=tf.convert_to_tensor(value).dtype)

          prior_weight = tfp_util.TransformedVariable(
              initial_prior_weight * unconstrained_ones,
              bijector=sigmoid.Sigmoid(),
              name='prior_weight/{}/{}'.format(dist.name, param))

        # If the prior distribution was a tfd.Sample wrapping a base
        # distribution, we want to give every single sample in the prior its
        # own lambda and alpha value (rather than having a single lambda and
        # alpha).
        if isinstance(original_dist, sample.Sample):
          value = tf.reshape(
              value,
              ps.concat([
                  actual_dist.batch_shape_tensor(),
                  ps.ones(ps.rank_from_shape(original_dist.sample_shape)),
                  actual_event_shape
              ], axis=0))
          value = tf.broadcast_to(
              value,
              ps.concat([posterior_batch_shape, actual_event_shape], axis=0))
        new_params_dict[param] = ASVIParameters(
            prior_weight=prior_weight,
            mean_field_parameter=tfp_util.TransformedVariable(
                value,
                bijector=bijector,
                name='mean_field_parameter/{}/{}'.format(dist.name, param)))

      param_dicts.append(new_params_dict)
  return param_dicts
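# For context, a minimal sketch of the public entry point these ASVI helpers
# back, assuming `tfp.experimental.vi.build_asvi_surrogate_posterior`; the
# two-variable joint prior below is illustrative only. Each prior parameter
# gets a sigmoid-constrained `prior_weight` and a `mean_field_parameter`,
# combined as in the convex updates above.
import tensorflow_probability as tfp

tfd = tfp.distributions

prior = tfd.JointDistributionSequentialAutoBatched([
    tfd.HalfNormal(1.),                              # scale
    lambda scale: tfd.Normal(loc=0., scale=scale),   # observation model
])

surrogate_posterior = tfp.experimental.vi.build_asvi_surrogate_posterior(
    prior, initial_prior_weight=0.5)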