def _transformed_beta(self, low=None, peak=None, high=None, temperature=None):
  """Builds a `Beta` distribution scaled and shifted onto `[low, high]`.

  Any argument left as `None` falls back to the matching attribute of `self`
  (`self.low`, `self.peak`, `self.high`, `self.temperature`), converted with
  `tf.convert_to_tensor`.

  Args:
    low: Optional lower bound of the support.
    peak: Optional peak location.
    high: Optional upper bound of the support.
    temperature: Optional temperature; multiplies both concentration offsets.

  Returns:
    A `TransformedDistribution` whose base is a `Beta` and whose bijector
    first scales by `high - low` and then shifts by `low`.
  """
  if low is None:
    low = tf.convert_to_tensor(self.low)
  if peak is None:
    peak = tf.convert_to_tensor(self.peak)
  if high is None:
    high = tf.convert_to_tensor(self.high)
  if temperature is None:
    temperature = tf.convert_to_tensor(self.temperature)

  width = high - low
  # `conc1` grows with (peak - low) and `conc0` with (high - peak), both taken
  # relative to the width of the support and scaled by the temperature.
  conc1 = 1. + temperature * (peak - low) / width
  conc0 = 1. + temperature * (high - peak) / width
  unit_beta = beta.Beta(
      concentration1=conc1,
      concentration0=conc0,
      allow_nan_stats=self.allow_nan_stats)

  shift_to_low = shift_bijector.Shift(shift=low)
  # The scale is broadcast to the full batch shape to avoid dimension
  # mismatches in operations such as `cdf`; `conc1` already reflects the
  # broadcast of all four parameters (low, peak, high, temperature).
  scale_to_width = scale_bijector.Scale(
      scale=tf.broadcast_to(width, ps.shape(conc1)))
  # `Chain` applies bijectors right to left: scale first, then shift.
  unit_to_support = chain_bijector.Chain([shift_to_low, scale_to_width])

  return transformed_distribution.TransformedDistribution(
      distribution=unit_beta, bijector=unit_to_support)
def _make_mixture_dist(self, component_logits, locs, scales):
  """Builds a mixture of quantized logistic distributions.

  Args:
    component_logits: 4D `Tensor` of logits for the Categorical distribution
      over Quantized Logistic mixture components. Dimensions are
      `[batch_size, height, width, num_logistic_mix]`.
    locs: 5D `Tensor` of location parameters for the Quantized Logistic
      mixture components. Dimensions are
      `[batch_size, height, width, num_logistic_mix, num_channels]`.
    scales: 5D `Tensor` of scale parameters for the Quantized Logistic
      mixture components. Dimensions are
      `[batch_size, height, width, num_logistic_mix, num_channels]`.

  Returns:
    dist: A quantized logistic mixture `tfp.distribution` over the input data.
      Given the shapes above, its batch shape is `[batch_size]` and its event
      shape is `[height, width, num_channels]`.
  """
  mixture_distribution = categorical.Categorical(logits=component_logits)

  # Convert distribution parameters for pixel values in
  # `[self._low, self._high]` for use with `QuantizedDistribution`: the
  # location map sends -1 to `self._low` and 1 to `self._high`, and the scales
  # are stretched by the same factor.
  locs = self._low + 0.5 * (self._high - self._low) * (locs + 1.)
  # Rebind instead of using `*=` so an array passed in by the caller is never
  # modified in place.
  scales = scales * (0.5 * (self._high - self._low))

  # `QuantizedDistribution` assigns P(Y = y) = P(y - 1 < X <= y); shifting the
  # logistic by -0.5 makes bin `y` cover `(y - 0.5, y + 0.5]` of the
  # unshifted logistic, i.e. bins are centered on integer pixel values.
  logistic_dist = quantized_distribution.QuantizedDistribution(
      distribution=transformed_distribution.TransformedDistribution(
          distribution=logistic.Logistic(loc=locs, scale=scales),
          bijector=shift.Shift(shift=tf.cast(-0.5, self.dtype))),
      low=self._low, high=self._high)

  # Channels become an event dimension, the mixture consumes the
  # `num_logistic_mix` dimension, and the outer `Independent` folds height and
  # width into the event.
  dist = mixture_same_family.MixtureSameFamily(
      mixture_distribution=mixture_distribution,
      components_distribution=independent.Independent(
          logistic_dist, reinterpreted_batch_ndims=1))
  return independent.Independent(dist, reinterpreted_batch_ndims=2)
def _build_posterior_for_one_parameter(param, batch_shape, seed):
  """Builds a transformed-Normal variational distribution over one parameter.

  A trainable Normal is defined in unconstrained space: its `loc` is a
  `tf.Variable` initialized by `sample_uniform_initial_state`, and its `scale`
  is a `TransformedVariable` initialized to 0.02 and kept positive through a
  `Softplus` bijector. The Normal is then mapped into the parameter's support
  with `param.bijector`.

  Args:
    param: Parameter object providing `name`, `prior` and `bijector`.
    batch_shape: Shape used as `init_sample_shape` when sampling the initial
      location.
    seed: Seed forwarded to `sample_uniform_initial_state`.

  Returns:
    A `TransformedDistribution` named '<param.name>_posterior'.
  """
  initial_loc = sample_uniform_initial_state(
      param,
      init_sample_shape=batch_shape,
      return_constrained=False,
      seed=seed)
  loc = tf.Variable(initial_value=initial_loc, name=param.name + '_loc')

  initial_scale = tf.fill(
      tf.shape(initial_loc),
      value=tf.constant(0.02, initial_loc.dtype),
      name=param.name + '_scale')
  scale = tfp_util.TransformedVariable(initial_scale, softplus_lib.Softplus())
  unconstrained_dist = normal_lib.Normal(loc=loc, scale=scale)

  # Reinterpret trailing batch dims as event dims so that the variational
  # distribution's `event_shape` matches the parameter's prior.
  prior_event_ndims = param.prior.event_shape.ndims
  if prior_event_ndims is None or prior_event_ndims > 0:
    unconstrained_dist = independent_lib.Independent(
        unconstrained_dist, reinterpreted_batch_ndims=prior_event_ndims)

  # Map the unconstrained Normal into the parameter's constrained space.
  return transformed_distribution_lib.TransformedDistribution(
      unconstrained_dist,
      param.bijector,
      name='{}_posterior'.format(param.name))
def _transformed_logistic(self):
  """Builds a `Logistic` distribution pushed through a `Sigmoid` bijector.

  The logistic has scale `1 / self._temperature` and location
  `logits * scale`, where the logits come from
  `self._logits_parameter_no_checks()`.
  """
  inv_temperature = tf.math.reciprocal(self._temperature)
  scaled_logits = self._logits_parameter_no_checks() * inv_temperature
  base = logistic.Logistic(
      scaled_logits, inv_temperature, allow_nan_stats=self.allow_nan_stats)
  return transformed_distribution.TransformedDistribution(
      distribution=base, bijector=sigmoid_bijector.Sigmoid())
def joint_prior_on_parameters_and_state(parameter_prior,
                                        parameterized_initial_state_prior_fn,
                                        parameter_constraining_bijector,
                                        prior_is_constrained=True):
  """Constructs a joint dist. from p(parameters) and p(state | parameters).

  Args:
    parameter_prior: Distribution over the parameters. If
      `prior_is_constrained` is true it is taken to be over constrained values
      and is pulled back to unconstrained space with the inverse of
      `parameter_constraining_bijector`.
    parameterized_initial_state_prior_fn: Callable mapping constrained
      parameter values to a distribution over the initial state.
    parameter_constraining_bijector: Bijector from unconstrained to
      constrained parameter values.
    prior_is_constrained: Python `bool`; see `parameter_prior`.

  Returns:
    A `JointDistributionNamed` over a `ParametersAndState` structure with parts
    `unconstrained_parameters` and `state`.
  """
  unconstrained_prior = parameter_prior
  if prior_is_constrained:
    unconstrained_prior = transformed_distribution.TransformedDistribution(
        parameter_prior,
        invert.Invert(parameter_constraining_bijector),
        name='unconstrained_parameter_prior')

  # The argument name must stay `unconstrained_parameters`: the joint
  # distribution wires dependencies by matching it to the other part's name.
  def state_prior(unconstrained_parameters):
    constrained_parameters = parameter_constraining_bijector.forward(
        unconstrained_parameters)
    return parameterized_initial_state_prior_fn(constrained_parameters)

  return joint_distribution_named.JointDistributionNamed(
      ParametersAndState(
          unconstrained_parameters=unconstrained_prior,
          state=state_prior))
def posterior_generator(): prior_gen = prior._model_coroutine() # pylint: disable=protected-access dist = next(prior_gen) i = 0 try: while True: original_dist = dist.distribution if isinstance(dist, Root) else dist if isinstance(original_dist, joint_distribution.JointDistribution): # TODO(kateslin): Build inner JD surrogate in # _make_asvi_trainable_variables to avoid rebuilding variables. raise TypeError( 'Argument `prior` cannot be a nested `JointDistribution`.') else: original_dist = _as_trainable_family(original_dist) try: actual_dist = original_dist.distribution except AttributeError: actual_dist = original_dist dist_params = actual_dist.parameters temp_params_dict = {} for param, value in dist_params.items(): if param in (_NON_STATISTICAL_PARAMS + _NON_TRAINABLE_PARAMS) or value is None: temp_params_dict[param] = value else: prior_weight = param_dicts[i][param].prior_weight mean_field_parameter = param_dicts[i][ param].mean_field_parameter if mean_field: temp_params_dict[param] = mean_field_parameter else: temp_params_dict[param] = prior_weight * value + ( 1. - prior_weight) * mean_field_parameter if isinstance(original_dist, sample.Sample): inner_dist = type(actual_dist)(**temp_params_dict) surrogate_dist = independent.Independent( inner_dist, reinterpreted_batch_ndims=ps.rank_from_shape( original_dist.sample_shape)) else: surrogate_dist = type(actual_dist)(**temp_params_dict) if isinstance(original_dist, transformed_distribution.TransformedDistribution): surrogate_dist = transformed_distribution.TransformedDistribution( surrogate_dist, bijector=original_dist.bijector) if isinstance(original_dist, independent.Independent): surrogate_dist = independent.Independent( surrogate_dist, reinterpreted_batch_ndims=original_dist .reinterpreted_batch_ndims) if isinstance(dist, Root): value_out = yield Root(surrogate_dist) else: value_out = yield surrogate_dist dist = prior_gen.send(value_out) i += 1 except StopIteration: pass
def build_factored_surrogate_posterior(
    event_shape=None,
    bijector=None,
    constraining_bijectors=None,
    initial_unconstrained_loc=_sample_uniform_initial_loc,
    initial_unconstrained_scale=1e-2,
    trainable_distribution_fn=_build_trainable_normal_dist,
    seed=None,
    validate_args=False,
    name=None):
  """Builds a joint variational posterior that factors over model variables.

  By default, this method creates an independent trainable Normal distribution
  for each variable, transformed using a bijector (if provided) to
  match the support of that variable. This makes extremely strong
  assumptions about the posterior: that it is approximately normal (or
  transformed normal), and that all model variables are independent.

  Args:
    event_shape: `Tensor` shape, or nested structure of `Tensor` shapes,
      specifying the event shape(s) of the posterior variables.
    bijector: Optional `tfb.Bijector` instance, or nested structure of such
      instances, defining support(s) of the posterior variables. The structure
      must match that of `event_shape` and may contain `None` values. A
      posterior variable will be modeled as
      `tfd.TransformedDistribution(underlying_dist, bijector)` if a
      corresponding constraining bijector is specified, otherwise it is modeled
      as supported on the unconstrained real line.
    constraining_bijectors: Deprecated alias for `bijector`.
    initial_unconstrained_loc: Optional Python `callable` with signature
      `tensor = initial_unconstrained_loc(shape, seed)` used to sample
      real-valued initializations for the unconstrained representation of each
      variable. May alternately be a nested structure of
      `Tensor`s, giving specific initial locations for each variable; these
      must have structure matching `event_shape` and shapes determined by the
      inverse image of `event_shape` under `bijector`, which may optionally be
      prefixed with a common batch shape.
      Default value: `functools.partial(tf.random.uniform,
        minval=-2., maxval=2., dtype=tf.float32)`.
    initial_unconstrained_scale: Optional scalar float `Tensor` initial
      scale for the unconstrained distributions, or a nested structure of
      `Tensor` initial scales for each variable.
      Default value: `1e-2`.
    trainable_distribution_fn: Optional Python `callable` with signature
      `trainable_dist = trainable_distribution_fn(initial_loc, initial_scale,
      event_ndims, validate_args)`. This is called for each model variable to
      build the corresponding factor in the surrogate posterior. It is expected
      that the distribution returned is supported on unconstrained real values.
      Default value: `functools.partial(
        tfp.experimental.vi.build_trainable_location_scale_distribution,
        distribution_fn=tfd.Normal)`, i.e., a trainable Normal distribution.
    seed: Python integer to seed the random number generator. This is used
      only when `initial_unconstrained_loc` is a callable, i.e., when initial
      locations are sampled rather than given explicitly.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_factored_surrogate_posterior').

  Returns:
    surrogate_posterior: A `tfd.Distribution` instance whose samples have
      shape and structure matching that of `event_shape`.

  ### Examples

  Consider a Gamma model with unknown parameters, expressed as a joint
  Distribution:

  ```python
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(tfd.Gamma(concentration=concentration, rate=rate),
                         sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)
  ```

  Let's use variational inference to approximate the posterior over the
  data-generating parameters for some observed `y`. We'll build a
  surrogate posterior distribution by specifying the shapes of the latent
  `concentration` and `rate` parameters, and that both are constrained to be
  positive.

  ```python
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=model.event_shape_tensor()[:-1],  # Omit the observed `y`.
    bijector=[tfb.Softplus(),   # Concentration is positive.
              tfb.Softplus()])  # Rate is positive.
  ```

  This creates a trainable joint distribution, defined by variables in
  `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior`
  to fit this distribution by minimizing a divergence to the true posterior.

  ```python
  y = [0.2, 0.5, 0.3, 0.7]
  losses = tfp.vi.fit_surrogate_posterior(
    lambda concentration, rate: model.log_prob([concentration, rate, y]),
    surrogate_posterior=surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)

  # After optimization, samples from the surrogate will approximate
  # samples from the true posterior.
  samples = surrogate_posterior.sample(100)
  posterior_mean = [tf.reduce_mean(x) for x in samples]     # mean ~= [1.1, 2.1]
  posterior_std = [tf.math.reduce_std(x) for x in samples]  # std  ~= [0.3, 0.8]
  ```

  If we wanted to initialize the optimization at a specific location, we can
  specify one when we build the surrogate posterior. This function requires the
  initial location to be specified in *unconstrained* space; we do this by
  inverting the constraining bijectors (note this section also demonstrates the
  creation of a dict-structured surrogate posterior).

  ```python
  initial_loc = {'concentration': 0.4, 'rate': 0.2}
  bijector = {'concentration': tfb.Softplus(),  # Concentration is positive.
              'rate': tfb.Softplus()}           # Rate is positive.
  initial_unconstrained_loc = tf.nest.map_structure(
    lambda b, x: b.inverse(x) if b is not None else x, bijector, initial_loc)
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=tf.nest.map_structure(tf.shape, initial_loc),
    bijector=bijector,
    initial_unconstrained_loc=initial_unconstrained_loc,
    initial_unconstrained_scale=1e-4)
  ```

  """
  with tf.name_scope(name or 'build_factored_surrogate_posterior'):
    bijector = deprecation.deprecated_argument_lookup(
        'bijector', bijector, 'constraining_bijectors', constraining_bijectors)

    seed = tfp_util.SeedStream(seed, salt='build_factored_surrogate_posterior')

    # Convert event shapes to Tensors.
    shallow_structure = _get_event_shape_shallow_structure(event_shape)
    event_shape = nest.map_structure_up_to(
        shallow_structure, lambda s: tf.convert_to_tensor(s, dtype=tf.int32),
        event_shape)

    if nest.is_nested(bijector):
      bijector = nest.map_structure(
          lambda b: identity.Identity() if b is None else b,
          bijector)

      # Support mismatched nested structures for backwards compatibility (e.g.
      # non-nested `event_shape` and a single-element list of `bijector`s).
      bijector = nest.pack_sequence_as(event_shape, nest.flatten(bijector))

      event_space_bijector = joint_map.JointMap(
          bijector, validate_args=validate_args)
    else:
      event_space_bijector = bijector

    if event_space_bijector is None:
      unconstrained_event_shape = event_shape
    else:
      unconstrained_event_shape = (
          event_space_bijector.inverse_event_shape_tensor(event_shape))

    # Construct initial locations for the internal unconstrained dists.
    if callable(initial_unconstrained_loc):  # Sample random initialization.
      initial_unconstrained_loc = nest.map_structure(
          lambda s: initial_unconstrained_loc(shape=s, seed=seed()),
          unconstrained_event_shape)

    if not nest.is_nested(initial_unconstrained_scale):
      initial_unconstrained_scale = nest.map_structure(
          lambda _: initial_unconstrained_scale,
          unconstrained_event_shape)

    # Extract the rank of each event, so that we build distributions with the
    # correct event shapes.
    unconstrained_event_ndims = nest.map_structure(
        ps.rank_from_shape,
        unconstrained_event_shape)

    # Build the component surrogate posteriors.
    unconstrained_distributions = nest.map_structure_up_to(
        unconstrained_event_shape,
        lambda loc, scale, ndims: trainable_distribution_fn(  # pylint: disable=g-long-lambda
            loc, scale, ndims, validate_args=validate_args),
        initial_unconstrained_loc,
        initial_unconstrained_scale,
        unconstrained_event_ndims)

    base_distribution = (
        joint_distribution_util.independent_joint_distribution_from_structure(
            unconstrained_distributions, validate_args=validate_args))
    if event_space_bijector is None:
      return base_distribution
    return transformed_distribution.TransformedDistribution(
        base_distribution, event_space_bijector)
def build_split_flow_surrogate_posterior(
    event_shape,
    trainable_bijector,
    constraining_bijector=None,
    base_distribution=normal.Normal,
    batch_shape=(),
    dtype=tf.float32,
    validate_args=False,
    name=None):
  """Builds a joint variational posterior by splitting a normalizing flow.

  A single vector-valued base distribution is pushed through
  `trainable_bijector`. The result is split into one piece per posterior
  variable, then reshaped and restructured to match `event_shape`. Finally it
  is mapped onto the posterior's support by `constraining_bijector`, if given.

  Args:
    event_shape: (Nested) event shape of the surrogate posterior.
    trainable_bijector: A trainable `tfb.Bijector` instance that operates on
      `Tensor`s (not structures), e.g. `tfb.MaskedAutoregressiveFlow` or
      `tfb.RealNVP`. This bijector transforms the base distribution before it
      is split.
    constraining_bijector: `tfb.Bijector` instance, or nested structure of
      `tfb.Bijector` instances, that maps (nested) values in R^n to the support
      of the posterior. (This can be the
      `experimental_default_event_space_bijector` of the distribution over the
      prior latent variables.)
      Default value: `None` (i.e., the posterior is over R^n).
    base_distribution: A `tfd.Distribution` subclass parameterized by `loc` and
      `scale`. The base distribution for the transformed surrogate has `loc=0.`
      and `scale=1.`.
      Default value: `tfd.Normal`.
    batch_shape: The `batch_shape` of the output distribution.
      Default value: `()`.
    dtype: The `dtype` of the surrogate posterior.
      Default value: `tf.float32`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_split_flow_surrogate_posterior').

  Returns:
    surrogate_distribution: Trainable `tfd.TransformedDistribution` with event
      shape equal to `event_shape`.

  ### Examples
  ```python

  # Train a normalizing flow on the Eight Schools model [1].

  treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
  treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]
  model = tfd.JointDistributionNamed({
      'avg_effect':
          tfd.Normal(loc=0., scale=10., name='avg_effect'),
      'log_stddev':
          tfd.Normal(loc=5., scale=1., name='log_stddev'),
      'school_effects':
          lambda log_stddev, avg_effect: (
              tfd.Independent(
                  tfd.Normal(
                      loc=avg_effect[..., None] * tf.ones(8),
                      scale=tf.exp(log_stddev[..., None]) * tf.ones(8),
                      name='school_effects'),
                  reinterpreted_batch_ndims=1)),
      'treatment_effects': lambda school_effects: tfd.Independent(
          tfd.Normal(loc=school_effects, scale=treatment_stddevs),
          reinterpreted_batch_ndims=1)
  })

  # Pin the observed values in the model.
  target_model = model.experimental_pin(treatment_effects=treatment_effects)

  # Create a Masked Autoregressive Flow bijector.
  net = tfb.AutoregressiveNetwork(2, hidden_units=[16, 16], dtype=tf.float32)
  maf = tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=net)

  # Build and fit the surrogate posterior.
  surrogate_posterior = (
      tfp.experimental.vi.build_split_flow_surrogate_posterior(
          event_shape=target_model.event_shape_tensor(),
          trainable_bijector=maf,
          constraining_bijector=(
              target_model.experimental_default_event_space_bijector())))

  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```

  #### References

  [1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and
      Donald Rubin. Bayesian Data Analysis, Third Edition.
      Chapman and Hall/CRC, 2013.

  """
  with tf.name_scope(name or 'build_split_flow_surrogate_posterior'):
    event_shape = nest.map_structure_up_to(
        _get_event_shape_shallow_structure(event_shape),
        ps.convert_to_shape_tensor,
        event_shape)

    def none_to_identity(b):
      return identity.Identity() if b is None else b

    if nest.is_nested(constraining_bijector):
      constraining_bijector = joint_map.JointMap(
          nest.map_structure(none_to_identity, constraining_bijector),
          validate_args=validate_args)

    unconstrained_event_shape = (
        event_shape if constraining_bijector is None
        else constraining_bijector.inverse_event_shape_tensor(event_shape))

    flat_shapes = nest.flatten(unconstrained_event_shape)
    flat_sizes = [tf.reduce_prod(s) for s in flat_shapes]
    total_size = tf.reduce_sum(flat_sizes)

    standard_base = sample.Sample(
        base_distribution(tf.zeros(batch_shape, dtype=dtype), scale=1.),
        [total_size])

    # Once `trainable_bijector` has transformed a base sample, cut the vector
    # into one piece per posterior variable...
    split_bijector = split.Split(flat_sizes, validate_args=validate_args)
    # ...give each piece its unconstrained event shape...
    event_reshape = joint_map.JointMap(
        nest.map_structure(reshape.Reshape, unconstrained_event_shape),
        validate_args=validate_args)
    # ...and arrange the flat list of pieces into the posterior's structure.
    event_unflatten = restructure.Restructure(
        nest.pack_sequence_as(unconstrained_event_shape,
                              range(len(flat_shapes))))

    # `Chain` applies bijectors right to left: `trainable_bijector` runs first
    # and `constraining_bijector` (when present) runs last.
    chain_links = [event_reshape, event_unflatten, split_bijector,
                   trainable_bijector]
    if constraining_bijector is not None:
      chain_links = [constraining_bijector] + chain_links
    flow = chain.Chain(chain_links, validate_args=validate_args)

    return transformed_distribution.TransformedDistribution(
        standard_base, bijector=flow, validate_args=validate_args)
def _factored_surrogate_posterior(  # pylint: disable=dangerous-default-value
    event_shape=None,
    bijector=None,
    batch_shape=(),
    base_distribution_cls=normal.Normal,
    initial_parameters={'scale': 1e-2},
    dtype=tf.float32,
    validate_args=False,
    name=None):
  """Builds a joint variational posterior that factors over model variables.

  By default, this method creates an independent trainable Normal distribution
  for each variable, transformed using a bijector (if provided) to
  match the support of that variable. This makes extremely strong
  assumptions about the posterior: that it is approximately normal (or
  transformed normal), and that all model variables are independent.

  Args:
    event_shape: `Tensor` shape, or nested structure of `Tensor` shapes,
      specifying the event shape(s) of the posterior variables.
    bijector: Optional `tfb.Bijector` instance, or nested structure of such
      instances, defining support(s) of the posterior variables. The structure
      must match that of `event_shape` and may contain `None` values. A
      posterior variable will be modeled as
      `tfd.TransformedDistribution(underlying_dist, bijector)` if a
      corresponding constraining bijector is specified, otherwise it is modeled
      as supported on the unconstrained real line.
    batch_shape: The `batch_shape` of the output distribution.
      Default value: `()`.
    base_distribution_cls: Subclass of `tfd.Distribution` that is instantiated
      and optionally transformed by the bijector to define the component
      distributions. May optionally be a structure of such subclasses
      matching `event_shape`.
      Default value: `tfd.Normal`.
    initial_parameters: Optional `str : Tensor` dictionary specifying initial
      values for some or all of the base distribution's trainable parameters,
      or a Python `callable` with signature
      `value = parameter_init_fn(parameter_name, shape, dtype, seed,
      constraining_bijector)`, passed to `tfp.experimental.util.make_trainable`.
      May optionally be a structure matching `event_shape` of such dictionaries
      and/or callables. Dictionary entries that do not correspond to parameter
      names are ignored.
      Default value: `{'scale': 1e-2}` (ignored when `base_distribution_cls`
        does not have a `scale` parameter).
    dtype: Optional float `dtype` for trainable parameters. May
      optionally be a structure of such `dtype`s matching `event_shape`.
      Default value: `tf.float32`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_factored_surrogate_posterior').

  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.

  ### Examples

  Consider a Gamma model with unknown parameters, expressed as a joint
  Distribution:

  ```python
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(tfd.Gamma(concentration=concentration, rate=rate),
                         sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)
  ```

  Let's use variational inference to approximate the posterior over the
  data-generating parameters for some observed `y`. We'll build a
  surrogate posterior distribution by specifying the shapes of the latent
  `concentration` and `rate` parameters, and that both are constrained to be
  positive.

  ```python
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=model.event_shape_tensor()[:-1],  # Omit the observed `y`.
    bijector=[tfb.Softplus(),   # Concentration is positive.
              tfb.Softplus()])  # Rate is positive.
  ```

  This creates a trainable joint distribution, defined by variables in
  `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior`
  to fit this distribution by minimizing a divergence to the true posterior.

  ```python
  y = [0.2, 0.5, 0.3, 0.7]
  losses = tfp.vi.fit_surrogate_posterior(
    lambda concentration, rate: model.log_prob([concentration, rate, y]),
    surrogate_posterior=surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)

  # After optimization, samples from the surrogate will approximate
  # samples from the true posterior.
  samples = surrogate_posterior.sample(100)
  posterior_mean = [tf.reduce_mean(x) for x in samples]     # mean ~= [1.1, 2.1]
  posterior_std = [tf.math.reduce_std(x) for x in samples]  # std  ~= [0.3, 0.8]
  ```

  If we wanted to initialize the optimization at a specific location, we can
  specify initial parameters when we build the surrogate posterior. Note that
  these parameterize the distribution(s) over unconstrained values,
  so we need to transform our desired constrained locations using the inverse
  of the constraining bijector(s).

  ```python
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape={'concentration': [], 'rate': []},
    bijector={'concentration': tfb.Softplus(),  # Concentration is positive.
              'rate': tfb.Softplus()},          # Rate is positive.
    initial_parameters={
      'concentration': {'loc': tfb.Softplus().inverse(0.4), 'scale': 1e-2},
      'rate': {'loc': tfb.Softplus().inverse(0.2), 'scale': 1e-2}})
  ```

  """
  with tf.name_scope(name or 'build_factored_surrogate_posterior'):
    # Convert event shapes to Tensors.
    shallow_structure = _get_event_shape_shallow_structure(event_shape)
    event_shape = nest.map_structure_up_to(
        shallow_structure, lambda s: tf.convert_to_tensor(s, dtype=tf.int32),
        event_shape)

    if nest.is_nested(bijector):
      event_space_bijector = joint_map.JointMap(
          nest.map_structure(
              lambda b: identity.Identity() if b is None else b,
              nest_util.coerce_structure(event_shape, bijector)),
          validate_args=validate_args)
    else:
      event_space_bijector = bijector

    if event_space_bijector is None:
      unconstrained_event_shape = event_shape
    else:
      unconstrained_event_shape = (
          event_space_bijector.inverse_event_shape_tensor(event_shape))
    # Prefix every unconstrained event shape with the requested batch shape.
    unconstrained_batch_and_event_shape = tf.nest.map_structure(
        lambda s: ps.concat([batch_shape, s], axis=0),
        unconstrained_event_shape)

    base_distribution_cls = nest_util.broadcast_structure(
        event_shape, base_distribution_cls)
    try:
      # Check that we have initial parameters for each event part.
      nest.assert_shallow_structure(event_shape, initial_parameters)
    except (ValueError, TypeError):
      # If not, broadcast the parameters to match the event structure.
      # We do this manually rather than using `nest_util.broadcast_structure`
      # because the initial parameters can themselves be structures (dicts).
      initial_parameters = nest.map_structure(
          lambda x: initial_parameters, event_shape)

    unconstrained_trainable_distributions = yield from (
        nest_util.map_structure_coroutine(
            trainable._make_trainable,  # pylint: disable=protected-access
            cls=base_distribution_cls,
            initial_parameters=initial_parameters,
            batch_and_event_shape=unconstrained_batch_and_event_shape,
            parameter_dtype=nest_util.broadcast_structure(
                event_shape, dtype),
            _up_to=event_shape))
    unconstrained_trainable_distribution = (
        joint_distribution_util.
        independent_joint_distribution_from_structure(
            unconstrained_trainable_distributions,
            batch_ndims=ps.rank_from_shape(batch_shape),
            validate_args=validate_args))
    if event_space_bijector is None:
      return unconstrained_trainable_distribution
    return transformed_distribution.TransformedDistribution(
        unconstrained_trainable_distribution, event_space_bijector)
def _affine_surrogate_posterior_from_base_distribution(
    base_distribution,
    operators='diag',
    bijector=None,
    initial_unconstrained_loc_fn=_sample_uniform_initial_loc,
    validate_args=False,
    name=None):
  """Builds a variational posterior by linearly transforming base distributions.

  This function builds a surrogate posterior by applying a trainable
  transformation to a base distribution (typically a `tfd.JointDistribution`)
  or nested structure of base distributions, and constraining the samples with
  `bijector`. Note that the distributions must have event shapes corresponding
  to the *pretransformed* surrogate posterior -- that is, if `bijector` contains
  a shape-changing bijector, then the corresponding base distribution event
  shape is the inverse event shape of the bijector applied to the desired
  surrogate posterior shape.

  The surrogate posterior is constructed as follows:
  1. Flatten the base distribution event shapes to vectors, and pack the base
     distributions into a `tfd.JointDistribution`.
  2. Apply a trainable blockwise LinearOperator bijector to the joint base
     distribution.
  3. Apply the constraining bijectors and return the resulting trainable
     `tfd.TransformedDistribution` instance.

  Args:
    base_distribution: `tfd.Distribution` instance (typically a
      `tfd.JointDistribution`), or a nested structure of `tfd.Distribution`
      instances.
    operators: Either a string or a list/tuple containing `LinearOperator`
      subclasses, `LinearOperator` instances, or callables returning
      `LinearOperator` instances. Supported string values are "diag" (to create
      a mean-field surrogate posterior) and "tril" (to create a full-covariance
      surrogate posterior). A list/tuple may be passed to induce other
      posterior covariance structures. If the list is flat, a
      `tf.linalg.LinearOperatorBlockDiag` instance will be created and applied
      to the base distribution. Otherwise the list must be singly-nested and
      have a first element of length 1, second element of length 2, etc.; the
      elements of the outer list are interpreted as rows of a lower-triangular
      block structure, and a `tf.linalg.LinearOperatorBlockLowerTriangular`
      instance is created. For complete documentation and examples, see
      `tfp.experimental.vi.util.build_trainable_linear_operator_block`, which
      receives the `operators` arg if it is list-like.
      Default value: `"diag"`.
    bijector: `tfb.Bijector` instance, or nested structure of `tfb.Bijector`
      instances, that maps (nested) values in R^n to the support of the
      posterior. (This can be the `experimental_default_event_space_bijector`
      of the distribution over the prior latent variables.)
      Default value: `None` (i.e., the posterior is over R^n).
    initial_unconstrained_loc_fn: Optional Python `callable` with signature
      `initial_loc = initial_unconstrained_loc_fn(shape, dtype, seed)` used to
      sample real-valued initializations for the unconstrained location of
      each variable.
      Default value: `functools.partial(tf.random.stateless_uniform,
        minval=-2., maxval=2., dtype=tf.float32)`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e.,
      'affine_surrogate_posterior_from_base_distribution').

  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.

  Raises:
    NotImplementedError: Base distributions with mixed dtypes are not supported.
    ValueError: `operators` is a string other than "diag" or "tril".

  #### Examples
  ```python
  tfd = tfp.distributions
  tfb = tfp.bijectors

  # Fit a multivariate Normal surrogate posterior on the Eight Schools model
  # [1].

  treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
  treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]

  def model_fn():
    avg_effect = yield tfd.Normal(loc=0., scale=10., name='avg_effect')
    log_stddev = yield tfd.Normal(loc=5., scale=1., name='log_stddev')
    school_effects = yield tfd.Sample(
        tfd.Normal(loc=avg_effect, scale=tf.exp(log_stddev)),
        sample_shape=[8],
        name='school_effects')
    treatment_effects = yield tfd.Independent(
        tfd.Normal(loc=school_effects, scale=treatment_stddevs),
        reinterpreted_batch_ndims=1,
        name='treatment_effects')
  model = tfd.JointDistributionCoroutineAutoBatched(model_fn)

  # Pin the observed values in the model.
  target_model = model.experimental_pin(treatment_effects=treatment_effects)

  # Define a lower triangular structure of `LinearOperator` subclasses that
  # models full covariance among latent variables except for the 8 dimensions
  # of `school_effect`, which are modeled as independent (using
  # `LinearOperatorDiag`).
  operators = [
    [tf.linalg.LinearOperatorLowerTriangular],
    [tf.linalg.LinearOperatorFullMatrix,
     tf.linalg.LinearOperatorLowerTriangular],
    [tf.linalg.LinearOperatorFullMatrix, tf.linalg.LinearOperatorFullMatrix,
     tf.linalg.LinearOperatorDiag]]

  # Constrain the posterior values to the support of the prior.
  bijector = target_model.experimental_default_event_space_bijector()

  # Standard Normal base distributions whose event shapes are the
  # *unconstrained* event shapes of the latent variables.
  base_distribution = tf.nest.map_structure(
    lambda s: tfd.Sample(tfd.Normal(0., 1.), sample_shape=s),
    bijector.inverse_event_shape_tensor(target_model.event_shape_tensor()))

  # Build a full-covariance surrogate posterior.
  surrogate_posterior = (
    tfp.experimental.vi.build_affine_surrogate_posterior_from_base_distribution(
        base_distribution=base_distribution,
        operators=operators,
        bijector=bijector))

  # Fit the model.
  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```

  #### References

  [1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and
      Donald Rubin. Bayesian Data Analysis, Third Edition.
      Chapman and Hall/CRC, 2013.

  """
  with tf.name_scope(name or
                     'affine_surrogate_posterior_from_base_distribution'):

    if nest.is_nested(base_distribution):
      base_distribution = (
          joint_distribution_util.independent_joint_distribution_from_structure(
              base_distribution, validate_args=validate_args))

    if nest.is_nested(bijector):
      bijector = joint_map.JointMap(
          nest.map_structure(
              lambda b: identity.Identity() if b is None else b, bijector),
          validate_args=validate_args)

    batch_shape = base_distribution.batch_shape_tensor()
    if tf.nest.is_nested(batch_shape):  # Base is a classic JointDistribution.
      batch_shape = functools.reduce(ps.broadcast_shape,
                                     tf.nest.flatten(batch_shape))

    event_shape = base_distribution.event_shape_tensor()
    flat_event_size = nest.flatten(
        nest.map_structure(ps.reduce_prod, event_shape))

    base_dtypes = {dtype_util.base_dtype(d)
                   for d in nest.flatten(base_distribution.dtype)}
    if len(base_dtypes) > 1:
      raise NotImplementedError(
          'Base distributions with mixed dtype are not supported. Saw '
          'components of dtype {}'.format(base_dtypes))
    base_dtype = list(base_dtypes)[0]

    num_components = len(flat_event_size)
    if operators == 'diag':
      operators = [tf.linalg.LinearOperatorDiag] * num_components
    elif operators == 'tril':
      # Row `i` of the block-lower-triangular structure: `i` full off-diagonal
      # blocks followed by a lower-triangular diagonal block.
      operators = [
          [tf.linalg.LinearOperatorFullMatrix] * i
          + [tf.linalg.LinearOperatorLowerTriangular]
          for i in range(num_components)]
    elif isinstance(operators, str):
      raise ValueError(
          'Unrecognized operator type {}. Valid operators are "diag", "tril", '
          'or a structure that can be passed to '
          '`tfp.experimental.vi.util.build_trainable_linear_operator_block` as '
          'the `operators` arg.'.format(operators))

    if nest.is_nested(operators):
      operators = yield from trainable_linear_operators._trainable_linear_operator_block(  # pylint: disable=protected-access
          operators,
          block_dims=flat_event_size,
          dtype=base_dtype,
          batch_shape=batch_shape)

    linop_bijector = (
        scale_matvec_linear_operator.ScaleMatvecLinearOperatorBlock(
            scale=operators, validate_args=validate_args))

    def generate_shift_bijector(s):
      # Yields one trainable parameter (the shift for a block of size `s`)
      # and returns the corresponding `Shift` bijector.
      x = yield trainable_state_util.Parameter(
          functools.partial(initial_unconstrained_loc_fn,
                            ps.concat([batch_shape, [s]], axis=0),
                            dtype=base_dtype))
      return shift.Shift(x)

    loc_bijectors = yield from nest_util.map_structure_coroutine(
        generate_shift_bijector, flat_event_size)
    loc_bijector = joint_map.JointMap(
        loc_bijectors, validate_args=validate_args)

    # Maps the flat list of vectors back to the base distribution's nested
    # structure and event shapes.
    unflatten_and_reshape = chain.Chain(
        [joint_map.JointMap(
            nest.map_structure(reshape.Reshape, event_shape),
            validate_args=validate_args),
         restructure.Restructure(
             nest.pack_sequence_as(event_shape, range(num_components)))],
        validate_args=validate_args)

    bijectors = [] if bijector is None else [bijector]
    bijectors.extend(
        [unflatten_and_reshape,
         loc_bijector,  # Allow the mean of the standard dist to shift from 0.
         linop_bijector])  # Apply LinOp to scale the standard dist.
    bijector = chain.Chain(bijectors, validate_args=validate_args)

    # Flatten the base distribution so the linear operator acts on vectors.
    flat_base_distribution = invert.Invert(
        unflatten_and_reshape)(base_distribution)

    return transformed_distribution.TransformedDistribution(
        flat_base_distribution, bijector=bijector, validate_args=validate_args)
def quadrature_scheme_lognormal_quantiles(loc,
                                          scale,
                                          quadrature_size,
                                          validate_args=False,
                                          name=None):
  """Use LogNormal quantiles to form quadrature on positive-reals.

  Builds `LogNormal(loc, scale)` as `Exp` applied to `Normal(loc, scale)` and
  evaluates its quantiles at `quadrature_size + 1` evenly spaced
  probabilities strictly inside `(0, 1)`. The endpoints 0 and 1 are omitted
  because they can produce Inf/NaN. The quadrature grid is the set of
  midpoints of adjacent quantiles, and every grid point gets the same weight.

  Args:
    loc: `float`-like (batch of) scalar `Tensor`; the location parameter of
      the LogNormal prior.
    scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
      the LogNormal prior.
    quadrature_size: Python `int` scalar representing the number of
      quadrature points.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "quadrature_scheme_lognormal_quantiles").

  Returns:
    grid: (Batch of) length-`quadrature_size` vectors representing the
      `log_rate` parameters of a `Poisson`.
    probs: Length-`quadrature_size` vector (not batched; it broadcasts
      against `grid`) representing the weight associated with each `grid`
      value. Every entry equals `1 / quadrature_size`.
  """
  # Single-argument `tf.name_scope` works in both TF1 and TF2. The TF1-only
  # `(name, default_name, values)` form raises `TypeError` under TF2.
  with tf.name_scope(name or "quadrature_scheme_lognormal_quantiles"):
    # Create a LogNormal distribution.
    dist = transformed_distribution.TransformedDistribution(
        distribution=normal.Normal(loc=loc, scale=scale),
        bijector=exp_bijector.Exp(),
        validate_args=validate_args)

    # Prefer the static rank; fall back to a dynamic rank tensor.
    batch_ndims = dist.batch_shape.ndims
    if batch_ndims is None:
      batch_ndims = tf.shape(dist.batch_shape_tensor())[0]

    def _compute_quantiles():
      """Helper to build quantiles, with the quadrature axis moved last."""
      # Omit {0, 1} since they might lead to Inf/NaN.
      zero = tf.zeros([], dtype=dist.dtype)
      edges = tf.linspace(zero, 1., quadrature_size + 3)[1:-1]
      # Expand edges so it broadcasts across batch dims: shape becomes
      # [quadrature_size + 1, 1, ..., 1] with `batch_ndims` trailing ones.
      edges = tf.reshape(
          edges,
          shape=tf.concat(
              [[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0))
      quantiles = dist.quantile(edges)
      # Cyclically permute left by one so the quantile axis is last.
      perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
      quantiles = tf.transpose(quantiles, perm)
      return quantiles

    quantiles = _compute_quantiles()

    # Compute grid as quantile midpoints. `quadrature_size + 1` edges give
    # `quadrature_size` midpoints.
    # NOTE(review): these are LogNormal quantiles, hence positive
    # (rate-space) values; confirm callers intend to use them as `log_rate`.
    grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2.
    # Set shape hints.
    grid.set_shape(dist.batch_shape.concatenate([quadrature_size]))

    # By construction probs is constant, i.e., `1 / quadrature_size`. This is
    # important, because non-constant probs leads to non-reparameterizable
    # samples.
    probs = tf.fill(
        dims=[quadrature_size],
        value=1. / tf.cast(quadrature_size, dist.dtype))

    return grid, probs
def __call__(self, value, name=None, **kwargs):
  """Applies or composes the `Bijector`, depending on input type.

  This is a convenience function which applies the `Bijector` instance in
  three different ways, depending on the input:

  1. If the input is a `tfd.Distribution` instance, return
     `tfd.TransformedDistribution(distribution=input, bijector=self)`.
     If the input is exactly a `tfd.TransformedDistribution` (not a
     subclass), `self` is instead composed onto its existing bijector.
  2. If the input is a `tfb.Bijector` instance, return
     `tfb.Chain([self, input])`.
  3. Otherwise, return `self.forward(input)`

  Args:
    value: A `tfd.Distribution`, `tfb.Bijector`, or a `Tensor`.
    name: Python `str` name given to ops created by this function.
    **kwargs: Additional keyword arguments passed into the created
      `tfd.TransformedDistribution`, `tfb.Bijector`, or `self.forward`.

  Returns:
    composition: A `tfd.TransformedDistribution` if the input was a
      `tfd.Distribution`, a `tfb.Chain` if the input was a `tfb.Bijector`, or
      a `Tensor` computed by `self.forward`.

  #### Examples

  ```python
  sigmoid = tfb.Reciprocal()(
      tfb.AffineScalar(shift=1.)(
        tfb.Exp()(
          tfb.AffineScalar(scale=-1.))))
  # ==> `tfb.Chain([
  #         tfb.Reciprocal(),
  #         tfb.AffineScalar(shift=1.),
  #         tfb.Exp(),
  #         tfb.AffineScalar(scale=-1.),
  #      ])`  # ie, `tfb.Sigmoid()`

  log_normal = tfb.Exp()(tfd.Normal(0, 1))
  # ==> `tfd.TransformedDistribution(tfd.Normal(0, 1), tfb.Exp())`

  tfb.Exp()([-1., 0., 1.])
  # ==> tf.exp([-1., 0., 1.])
  ```
  """
  # To avoid circular dependencies and keep the implementation local to the
  # `Bijector` class, we violate PEP8 guidelines and import here rather than
  # at the top of the file.
  from tensorflow_probability.python.bijectors import chain  # pylint: disable=g-import-not-at-top
  from tensorflow_probability.python.distributions import distribution  # pylint: disable=g-import-not-at-top
  from tensorflow_probability.python.distributions import transformed_distribution  # pylint: disable=g-import-not-at-top

  # Only fold into an *exact* `TransformedDistribution`. Subclasses may expose
  # constructor parameters in `.parameters` that the base
  # `TransformedDistribution.__init__` does not accept (e.g.
  # `allow_nan_stats` for `tfd.Chi`), which would raise `TypeError`.
  # Subclasses fall through to the generic `Distribution` branch below.
  if type(value) is transformed_distribution.TransformedDistribution:  # pylint: disable=unidiomatic-typecheck
    # Copy so we never mutate whatever dict `parameters` hands back.
    new_kwargs = dict(value.parameters)
    new_kwargs.update(kwargs)
    new_kwargs["name"] = name or new_kwargs.get("name", None)
    new_kwargs["bijector"] = self(value.bijector)
    return transformed_distribution.TransformedDistribution(
        **new_kwargs)

  if isinstance(value, distribution.Distribution):
    return transformed_distribution.TransformedDistribution(
        distribution=value, bijector=self, name=name, **kwargs)

  if isinstance(value, chain.Chain):
    # Flatten into the existing chain instead of nesting a new one.
    new_kwargs = kwargs.copy()
    new_kwargs["bijectors"] = [self] + ([] if value.bijectors is None
                                        else list(value.bijectors))
    if "validate_args" not in new_kwargs:
      new_kwargs["validate_args"] = value.validate_args
    new_kwargs["name"] = name or value.name
    return chain.Chain(**new_kwargs)

  if isinstance(value, Bijector):
    return chain.Chain([self, value], name=name, **kwargs)

  return self._call_forward(value, name=name or "forward", **kwargs)
def _transform(self, distribution):
  """Wraps `distribution` in a masked autoregressive flow driven by `self._made`.

  Args:
    distribution: The base distribution to transform.

  Returns:
    A `TransformedDistribution` of `distribution` under a
    `MaskedAutoregressiveFlow` bijector.
  """
  def _split_made_output(x):
    # Unstack the network output along its last axis. The flow receives the
    # resulting list of tensors (presumably shift and log-scale; TODO confirm
    # the layout of `self._made`'s output).
    return tf.unstack(self._made(x), axis=-1)

  flow_bijector = masked_autoregressive_lib.MaskedAutoregressiveFlow(
      _split_made_output)
  return transformed_distribution_lib.TransformedDistribution(
      bijector=flow_bijector,
      distribution=distribution)
def _asvi_surrogate_for_distribution(dist,
                                     base_distribution_surrogate_fn,
                                     sample_shape=None,
                                     variables=None,
                                     seed=None):
  """Recursively creates ASVI surrogates, and creates new variables if needed.

  Args:
    dist: a `tfd.Distribution` instance.
    base_distribution_surrogate_fn: Callable to build a surrogate posterior
      for a 'base' (non-meta and non-joint) distribution, with signature
      `surrogate_posterior, variables = base_distribution_fn(
      dist, sample_shape=None, variables=None, seed=None)`.
    sample_shape: Optional `Tensor` shape of samples drawn from `dist` by
      `tfd.Sample` wrappers. If not `None`, the surrogate's event will include
      independent sample dimensions, i.e., it will have event shape
      `concat([sample_shape, dist.event_shape], axis=0)`.
      Default value: `None`.
    variables: Optional nested structure of `tf.Variable`s returned from a
      previous call to `_asvi_surrogate_for_distribution`. If `None`,
      new variables will be created; otherwise, constructs a surrogate
      posterior backed by the passed-in variables.
      Default value: `None`.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    surrogate_posterior: Instance of `tfd.Distribution` representing a
      trainable surrogate posterior distribution, with the same structure and
      `name` as `dist`.
    variables: Nested structure of `tf.Variable` trainable parameters for the
      surrogate posterior. If `dist` is a base distribution, this is a `dict`
      of `ASVIParameters` instances. If `dist` is a joint distribution, this
      is a `dist.dtype` structure of such `dict`s.

  Raises:
    ValueError: If `dist` is a meta-distribution (has a `distribution`
      attribute) of a kind not handled above and is not a
      `TransformedDistribution` subclass.
  """
  # Pass args to any nested surrogates.
  build_nested_surrogate = functools.partial(
      _asvi_surrogate_for_distribution,
      base_distribution_surrogate_fn=base_distribution_surrogate_fn,
      sample_shape=sample_shape,
      seed=seed)

  # Apply any substitutions, while attempting to preserve the original name.
  dist = _set_name(_as_substituted_distribution(dist), name=_get_name(dist))

  # Handle wrapper ("meta") distributions.
  if isinstance(dist, markov_chain.MarkovChain):
    return _asvi_surrogate_for_markov_chain(
        dist=dist,
        variables=variables,
        base_distribution_surrogate_fn=base_distribution_surrogate_fn,
        sample_shape=sample_shape,
        seed=seed)

  if isinstance(dist, sample.Sample):
    # Fold this wrapper's sample shape into the nested surrogate's event
    # (prepending any outer sample shape), then reinterpret those dims.
    dist_sample_shape = distribution_util.expand_to_vector(dist.sample_shape)
    nested_surrogate, variables = build_nested_surrogate(  # pylint: disable=redundant-keyword-arg
        dist=dist.distribution,
        variables=variables,
        sample_shape=(
            dist_sample_shape if sample_shape is None
            else ps.concat([sample_shape, dist_sample_shape], axis=0)))
    surrogate_posterior = independent.Independent(
        nested_surrogate,
        reinterpreted_batch_ndims=ps.rank_from_shape(dist_sample_shape),
        name=_get_name(dist))
  # Treat distributions that subclass TransformedDistribution with their own
  # parameters (e.g., Gumbel, Weibull, MultivariateNormal*, etc) as their
  # own type of base distribution, rather than as explicit TDs.
  elif type(dist) is transformed_distribution.TransformedDistribution:  # pylint: disable=unidiomatic-typecheck
    nested_surrogate, variables = build_nested_surrogate(
        dist.distribution, variables=variables)
    surrogate_posterior = transformed_distribution.TransformedDistribution(
        nested_surrogate,
        bijector=dist.bijector,
        name=_get_name(dist))
  elif isinstance(dist, independent.Independent):
    nested_surrogate, variables = build_nested_surrogate(
        dist.distribution, variables=variables)
    surrogate_posterior = independent.Independent(
        nested_surrogate,
        reinterpreted_batch_ndims=dist.reinterpreted_batch_ndims,
        name=_get_name(dist))
  elif hasattr(dist, '_model_coroutine'):
    surrogate_posterior, variables = _asvi_surrogate_for_joint_distribution(
        dist,
        base_distribution_surrogate_fn=base_distribution_surrogate_fn,
        variables=variables,
        seed=seed)
  elif (hasattr(dist, 'distribution') and
        # Transformed dists not handled above are treated as base
        # distributions.
        not isinstance(dist, transformed_distribution.TransformedDistribution)):
    raise ValueError('Meta-distribution `{}` is not yet supported by this '
                     'implementation of ASVI. Contact '
                     '`tfprobability@tensorflow.org` if you need this '
                     'functionality.'.format(type(dist)))
  else:
    surrogate_posterior, variables = base_distribution_surrogate_fn(
        dist=dist, sample_shape=sample_shape, variables=variables, seed=seed)
  return surrogate_posterior, variables
def __call__(self, value, name=None, **kwargs):
  """Applies or composes the `Bijector`, depending on input type.

  Dispatch rules, checked in this order:

    * `value` is exactly a `tfd.TransformedDistribution` (not a subclass):
      rebuild it with the same parameters, but with bijector
      `self(value.bijector)`, so `self` is folded into the existing bijector.
    * `value` is any other `tfd.Distribution`: return
      `tfd.TransformedDistribution(distribution=value, bijector=self)`.
    * `value` is a `tfb.Chain`: return a new `tfb.Chain` whose bijectors are
      `self` followed by `value.bijectors`. It inherits `value.validate_args`
      unless one is passed, and `value.name` unless `name` is given.
    * `value` is any other `tfb.Bijector`: return `tfb.Chain([self, value])`.
    * Anything else: return `self.forward(value)`.

  Args:
    value: A `tfd.Distribution`, `tfb.Bijector`, or a `Tensor`.
    name: Python `str` name given to ops created by this function.
    **kwargs: Additional keyword arguments passed into the created
      `tfd.TransformedDistribution`, `tfb.Bijector`, or `self.forward`.

  Returns:
    composition: A `tfd.TransformedDistribution` if the input was a
      `tfd.Distribution`, a `tfb.Chain` if the input was a `tfb.Bijector`, or
      a `Tensor` computed by `self.forward`.

  #### Examples

  ```python
  sigmoid = tfb.Reciprocal()(
      tfb.AffineScalar(shift=1.)(
        tfb.Exp()(
          tfb.AffineScalar(scale=-1.))))
  # ==> `tfb.Chain([
  #         tfb.Reciprocal(),
  #         tfb.AffineScalar(shift=1.),
  #         tfb.Exp(),
  #         tfb.AffineScalar(scale=-1.),
  #      ])`  # ie, `tfb.Sigmoid()`

  log_normal = tfb.Exp()(tfd.Normal(0, 1))
  # ==> `tfd.TransformedDistribution(tfd.Normal(0, 1), tfb.Exp())`

  tfb.Exp()([-1., 0., 1.])
  # ==> tf.exp([-1., 0., 1.])
  ```
  """
  # Local imports break the circular dependency between the bijector and
  # distribution packages; this intentionally departs from PEP8 placement.
  from tensorflow_probability.python.bijectors import chain  # pylint: disable=g-import-not-at-top
  from tensorflow_probability.python.distributions import distribution  # pylint: disable=g-import-not-at-top
  from tensorflow_probability.python.distributions import transformed_distribution  # pylint: disable=g-import-not-at-top

  # TODO(b/128841942): Handle Conditional distributions and bijectors.
  td_class = transformed_distribution.TransformedDistribution

  # Exact type match on purpose: a subclass constructor may take arguments
  # the plain TD constructor rejects, e.g. `tfb.Identity()(tfd.Chi(df=1.,
  # allow_nan_stats=True))` would otherwise fail with
  # `TypeError: __init__() got an unexpected keyword argument
  # 'allow_nan_stats'`.
  if type(value) is td_class:  # pylint: disable=unidiomatic-typecheck
    td_kwargs = value.parameters
    td_kwargs.update(kwargs)
    td_kwargs['name'] = name or td_kwargs.get('name', None)
    td_kwargs['bijector'] = self(value.bijector)
    return td_class(**td_kwargs)

  if isinstance(value, distribution.Distribution):
    return td_class(distribution=value, bijector=self, name=name, **kwargs)

  if isinstance(value, chain.Chain):
    chain_kwargs = dict(kwargs)
    inner_bijectors = value.bijectors
    prepended = [self]
    if inner_bijectors is not None:
      prepended.extend(inner_bijectors)
    chain_kwargs['bijectors'] = prepended
    if 'validate_args' not in chain_kwargs:
      chain_kwargs['validate_args'] = value.validate_args
    chain_kwargs['name'] = name or value.name
    return chain.Chain(**chain_kwargs)

  if isinstance(value, Bijector):
    return chain.Chain([self, value], name=name, **kwargs)

  return self.forward(value, name=name or 'forward', **kwargs)
def __init__(self,
             temperature,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name='RelaxedBernoulli'):
  """Construct RelaxedBernoulli distributions.

  Internally the distribution is represented as a `Sigmoid`-transformed
  `Logistic(loc=logits / temperature, scale=1 / temperature)`.

  Args:
    temperature: A `Tensor`, representing the temperature of a set of
      RelaxedBernoulli distributions. The temperature values should be
      positive.
    logits: An N-D `Tensor` representing the log-odds of a positive event.
      Each entry in the `Tensor` parametrizes an independent
      RelaxedBernoulli distribution where the probability of an event is
      sigmoid(logits). Only one of `logits` or `probs` should be passed in.
    probs: An N-D `Tensor` representing the probability of a positive event.
      Each entry in the `Tensor` parameterizes an independent Bernoulli
      distribution. Only one of `logits` or `probs` should be passed in.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: If both `probs` and `logits` are passed, or if neither.
  """
  # Must be first so that only the constructor arguments are captured.
  parameters = dict(locals())
  # Enforce the documented contract. Previously, passing both silently
  # ignored `probs`, and passing neither failed later with an unrelated error.
  if (probs is None) == (logits is None):
    raise ValueError('Must pass probs or logits, but not both.')
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([logits, probs, temperature], tf.float32)
    self._temperature = tensor_util.convert_nonref_to_tensor(
        temperature, name='temperature', dtype=dtype)
    self._probs = tensor_util.convert_nonref_to_tensor(
        probs, name='probs', dtype=dtype)
    self._logits = tensor_util.convert_nonref_to_tensor(
        logits, name='logits', dtype=dtype)
    if logits is None:
      # log(p) - log1p(-p) == log(p / (1 - p)), i.e. the log-odds. It is
      # wrapped in `DeferredTensor`, presumably so the transform is
      # re-applied if `probs` is a mutable variable.
      logits_parameter = tfp_util.DeferredTensor(
          lambda x: tf.math.log(x) - tf.math.log1p(-x), self._probs)
    else:
      logits_parameter = self._logits
    shape = tf.broadcast_static_shape(logits_parameter.shape,
                                      self._temperature.shape)
    # If L ~ Logistic(0, 1), then (logits + L) / temperature is distributed
    # as Logistic(logits / temperature, 1 / temperature).
    logistic_scale = tfp_util.DeferredTensor(
        tf.math.reciprocal, self._temperature)
    logistic_loc = tfp_util.DeferredTensor(
        lambda x: x * logistic_scale, logits_parameter, shape=shape)
    self._transformed_logistic = (
        transformed_distribution.TransformedDistribution(
            distribution=logistic.Logistic(
                logistic_loc,
                logistic_scale,
                allow_nan_stats=allow_nan_stats,
                name=name + '/Logistic'),
            bijector=sigmoid_bijector.Sigmoid()))
    super(RelaxedBernoulli, self).__init__(
        dtype=dtype,
        reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name)
def _asvi_surrogate_for_transformed_distribution(dist, build_nested_surrogate):
  """Builds the surrogate for a `tfd.TransformedDistribution`.

  This is a generator-based coroutine. It delegates (via `yield from`) to
  `build_nested_surrogate` to obtain a surrogate for `dist.distribution`. It
  then re-applies `dist.bijector` to that surrogate.

  Args:
    dist: The `TransformedDistribution` to build a surrogate for.
    build_nested_surrogate: Coroutine-returning callable that builds the
      surrogate for the base distribution.

  Returns:
    A `TransformedDistribution` of the base surrogate under `dist.bijector`.
  """
  base_surrogate = yield from build_nested_surrogate(dist.distribution)
  original_bijector = dist.bijector
  return transformed_distribution.TransformedDistribution(
      distribution=base_surrogate,
      bijector=original_bijector)