Example #1
  def _transformed_beta(self,
                        low=None,
                        peak=None,
                        high=None,
                        temperature=None):
    low = tf.convert_to_tensor(self.low) if low is None else low
    peak = tf.convert_to_tensor(self.peak) if peak is None else peak
    high = tf.convert_to_tensor(self.high) if high is None else high
    temperature = (tf.convert_to_tensor(self.temperature)
                   if temperature is None else temperature)
    scale = high - low
    concentration1 = (1. + temperature * (peak - low) / scale)
    concentration0 = (1. + temperature * (high - peak) / scale)
    return transformed_distribution.TransformedDistribution(
        distribution=beta.Beta(concentration1=concentration1,
                               concentration0=concentration0,
                               allow_nan_stats=self.allow_nan_stats),
        bijector=chain_bijector.Chain([
            shift_bijector.Shift(shift=low),
            # Broadcast `scale` on the `Scale` bijector to match the batch
            # dimensions. This prevents dimension mismatches in operations
            # like `cdf`. Note that `concentration1` incorporates the
            # broadcast of all four parameters.
            scale_bijector.Scale(
                scale=tf.broadcast_to(scale, ps.shape(concentration1)))
        ]))
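This appears to be the construction behind a PERT-style distribution: a Beta on [0, 1] pushed through an affine map onto [low, high]. A minimal sketch against the public API (the parameter values are made up):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Made-up PERT-style parameters.
low, peak, high, temperature = 1., 7., 11., 4.
scale = high - low
concentration1 = 1. + temperature * (peak - low) / scale
concentration0 = 1. + temperature * (high - peak) / scale

# Beta on [0, 1], then affinely mapped onto [low, high].
pert = tfd.TransformedDistribution(
    distribution=tfd.Beta(concentration1, concentration0),
    bijector=tfb.Chain([tfb.Shift(low), tfb.Scale(scale)]))

print(pert.sample(3))  # Samples lie in [low, high].
```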
Example #2
  def _make_mixture_dist(self, component_logits, locs, scales):
    """Builds a mixture of quantized logistic distributions.

    Args:
      component_logits: 4D `Tensor` of logits for the Categorical distribution
        over Quantized Logistic mixture components. Dimensions are `[batch_size,
        height, width, num_logistic_mix]`.
      locs: 5D `Tensor` of location parameters for the Quantized Logistic
        mixture components. Dimensions are `[batch_size, height, width,
        num_logistic_mix, num_channels]`.
      scales: 5D `Tensor` of scale parameters for the Quantized Logistic
        mixture components. Dimensions are `[batch_size, height, width,
        num_logistic_mix, num_channels]`.

    Returns:
      dist: A quantized logistic mixture `tfp.distribution` over the input data.
    """
    mixture_distribution = categorical.Categorical(logits=component_logits)

    # Convert distribution parameters for pixel values in
    # `[self._low, self._high]` for use with `QuantizedDistribution`.
    locs = self._low + 0.5 * (self._high - self._low) * (locs + 1.)
    scales *= 0.5 * (self._high - self._low)
    logistic_dist = quantized_distribution.QuantizedDistribution(
        distribution=transformed_distribution.TransformedDistribution(
            distribution=logistic.Logistic(loc=locs, scale=scales),
            bijector=shift.Shift(shift=tf.cast(-0.5, self.dtype))),
        low=self._low, high=self._high)

    dist = mixture_same_family.MixtureSameFamily(
        mixture_distribution=mixture_distribution,
        components_distribution=independent.Independent(
            logistic_dist, reinterpreted_batch_ndims=1))
    return independent.Independent(dist, reinterpreted_batch_ndims=2)
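In miniature, the same quantized logistic mixture can be assembled from the public API for a single pixel; a sketch with made-up shapes and parameter values:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

num_logistic_mix, num_channels = 5, 3  # made-up sizes
component_logits = tf.zeros([num_logistic_mix])
locs = 127.5 * tf.ones([num_logistic_mix, num_channels])
scales = 20. * tf.ones([num_logistic_mix, num_channels])

# Shift by -0.5 so quantization bins are centered, then quantize to the
# integer pixel values in [0, 255].
quantized = tfd.QuantizedDistribution(
    tfd.TransformedDistribution(
        tfd.Logistic(loc=locs, scale=scales),
        bijector=tfb.Shift(shift=-0.5)),
    low=0., high=255.)

pixel_dist = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(logits=component_logits),
    components_distribution=tfd.Independent(
        quantized, reinterpreted_batch_ndims=1))
print(pixel_dist.sample())  # shape [num_channels]
```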
Example #3
def _build_posterior_for_one_parameter(param, batch_shape, seed):
    """Built a transformed-normal variational dist over a parameter's support."""

    # Build a trainable Normal distribution.
    initial_loc = sample_uniform_initial_state(param,
                                               init_sample_shape=batch_shape,
                                               return_constrained=False,
                                               seed=seed)
    loc = tf.Variable(initial_value=initial_loc, name=param.name + '_loc')
    scale = tfp_util.TransformedVariable(
        tf.fill(tf.shape(initial_loc),
                value=tf.constant(0.02, initial_loc.dtype),
                name=param.name + '_scale'), softplus_lib.Softplus())
    posterior_dist = normal_lib.Normal(loc=loc, scale=scale)

    # Ensure the `event_shape` of the variational distribution matches the
    # parameter.
    if (param.prior.event_shape.ndims is None
            or param.prior.event_shape.ndims > 0):
        posterior_dist = independent_lib.Independent(
            posterior_dist,
            reinterpreted_batch_ndims=param.prior.event_shape.ndims)

    # Transform to constrained parameter space.
    posterior_dist = transformed_distribution_lib.TransformedDistribution(
        posterior_dist, param.bijector, name='{}_posterior'.format(param.name))
    return posterior_dist
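The same pattern, reduced to a single positive scalar parameter (all names here are illustrative):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Trainable Normal in unconstrained space.
loc = tf.Variable(0., name='rate_loc')
scale = tfp.util.TransformedVariable(  # kept positive via Softplus
    0.02, bijector=tfb.Softplus(), name='rate_scale')

# Transform samples into the parameter's support, here (0, inf).
posterior = tfd.TransformedDistribution(
    tfd.Normal(loc=loc, scale=scale),
    bijector=tfb.Softplus(),
    name='rate_posterior')
```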
Example #4
  def _transformed_logistic(self):
    logistic_scale = tf.math.reciprocal(self._temperature)
    logits_parameter = self._logits_parameter_no_checks()
    logistic_loc = logits_parameter * logistic_scale
    return transformed_distribution.TransformedDistribution(
        distribution=logistic.Logistic(
            logistic_loc,
            logistic_scale,
            allow_nan_stats=self.allow_nan_stats),
        bijector=sigmoid_bijector.Sigmoid())
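This is the Logistic-plus-Sigmoid construction used by `RelaxedBernoulli`; a standalone sketch with made-up values:

```python
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

temperature, logits = 0.5, 2.0  # made-up values
relaxed = tfd.TransformedDistribution(
    distribution=tfd.Logistic(loc=logits / temperature,
                              scale=1. / temperature),
    bijector=tfb.Sigmoid())
# Samples lie in (0, 1); as temperature -> 0, mass concentrates near {0, 1}.
```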
Example #5
def joint_prior_on_parameters_and_state(parameter_prior,
                                        parameterized_initial_state_prior_fn,
                                        parameter_constraining_bijector,
                                        prior_is_constrained=True):
    """Constructs a joint dist. from p(parameters) and p(state | parameters)."""
    if prior_is_constrained:
        parameter_prior = transformed_distribution.TransformedDistribution(
            parameter_prior,
            invert.Invert(parameter_constraining_bijector),
            name='unconstrained_parameter_prior')

    return joint_distribution_named.JointDistributionNamed(
        ParametersAndState(
            unconstrained_parameters=parameter_prior,
            state=lambda unconstrained_parameters: (  # pylint: disable=g-long-lambda
                parameterized_initial_state_prior_fn(
                    parameter_constraining_bijector.forward(
                        unconstrained_parameters)))))
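The `invert.Invert` step is just a change of variables that pulls a constrained prior back onto the unconstrained reals. A small sketch of the idea:

```python
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

constrained_prior = tfd.Exponential(1.)  # supported on (0, inf)
constraining_bijector = tfb.Softplus()   # maps R -> (0, inf)

# The same prior, expressed over unconstrained real values.
unconstrained_prior = tfd.TransformedDistribution(
    constrained_prior,
    tfb.Invert(constraining_bijector),
    name='unconstrained_parameter_prior')
```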
Example #6
    def posterior_generator():

      prior_gen = prior._model_coroutine()  # pylint: disable=protected-access
      dist = next(prior_gen)

      i = 0
      try:
        while True:
          original_dist = dist.distribution if isinstance(dist, Root) else dist

          if isinstance(original_dist, joint_distribution.JointDistribution):
            # TODO(kateslin): Build inner JD surrogate in
            # _make_asvi_trainable_variables to avoid rebuilding variables.
            raise TypeError(
                'Argument `prior` cannot be a nested `JointDistribution`.')

          else:

            original_dist = _as_trainable_family(original_dist)

            try:
              actual_dist = original_dist.distribution
            except AttributeError:
              actual_dist = original_dist

            dist_params = actual_dist.parameters
            temp_params_dict = {}

            for param, value in dist_params.items():
              if param in (_NON_STATISTICAL_PARAMS +
                           _NON_TRAINABLE_PARAMS) or value is None:
                temp_params_dict[param] = value
              else:
                prior_weight = param_dicts[i][param].prior_weight
                mean_field_parameter = param_dicts[i][
                    param].mean_field_parameter
                if mean_field:
                  temp_params_dict[param] = mean_field_parameter
                else:
                  temp_params_dict[param] = prior_weight * value + (
                      1. - prior_weight) * mean_field_parameter

            if isinstance(original_dist, sample.Sample):
              inner_dist = type(actual_dist)(**temp_params_dict)

              surrogate_dist = independent.Independent(
                  inner_dist,
                  reinterpreted_batch_ndims=ps.rank_from_shape(
                      original_dist.sample_shape))
            else:
              surrogate_dist = type(actual_dist)(**temp_params_dict)

            if isinstance(original_dist,
                          transformed_distribution.TransformedDistribution):
              surrogate_dist = transformed_distribution.TransformedDistribution(
                  surrogate_dist, bijector=original_dist.bijector)

            if isinstance(original_dist, independent.Independent):
              surrogate_dist = independent.Independent(
                  surrogate_dist,
                  reinterpreted_batch_ndims=original_dist
                  .reinterpreted_batch_ndims)

            if isinstance(dist, Root):
              value_out = yield Root(surrogate_dist)
            else:
              value_out = yield surrogate_dist

          dist = prior_gen.send(value_out)
          i += 1
      except StopIteration:
        pass
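The core ASVI update in the loop above interpolates each prior parameter with a trainable mean-field value. A sketch for a single Normal prior (variable names are illustrative):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

prior = tfd.Normal(loc=3., scale=2.)

prior_weight = tf.sigmoid(tf.Variable(0., name='prior_weight_logit'))
mean_field_loc = tf.Variable(0., name='mean_field_loc')

# The surrogate loc is a trainable convex combination of the prior's loc and
# a free mean-field value; the full implementation treats scale analogously.
surrogate = tfd.Normal(
    loc=prior_weight * prior.loc + (1. - prior_weight) * mean_field_loc,
    scale=prior.scale)
```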
Example #7
def build_factored_surrogate_posterior(
    event_shape=None,
    bijector=None,
    constraining_bijectors=None,
    initial_unconstrained_loc=_sample_uniform_initial_loc,
    initial_unconstrained_scale=1e-2,
    trainable_distribution_fn=_build_trainable_normal_dist,
    seed=None,
    validate_args=False,
    name=None):
  """Builds a joint variational posterior that factors over model variables.

  By default, this method creates an independent trainable Normal distribution
  for each variable, transformed using a bijector (if provided) to
  match the support of that variable. This makes extremely strong
  assumptions about the posterior: that it is approximately normal (or
  transformed normal), and that all model variables are independent.

  Args:
    event_shape: `Tensor` shape, or nested structure of `Tensor` shapes,
      specifying the event shape(s) of the posterior variables.
    bijector: Optional `tfb.Bijector` instance, or nested structure of such
      instances, defining support(s) of the posterior variables. The structure
      must match that of `event_shape` and may contain `None` values. A
      posterior variable will be modeled as
      `tfd.TransformedDistribution(underlying_dist, bijector)` if a
      corresponding constraining bijector is specified, otherwise it is modeled
      as supported on the unconstrained real line.
    constraining_bijectors: Deprecated alias for `bijector`.
    initial_unconstrained_loc: Optional Python `callable` with signature
      `tensor = initial_unconstrained_loc(shape, seed)` used to sample
      real-valued initializations for the unconstrained representation of each
      variable. May alternately be a nested structure of
      `Tensor`s, giving specific initial locations for each variable; these
      must have structure matching `event_shape` and shapes determined by the
      inverse image of `event_shape` under `bijector`, which may optionally be
      prefixed with a common batch shape.
      Default value: `functools.partial(tf.random.uniform,
        minval=-2., maxval=2., dtype=tf.float32)`.
    initial_unconstrained_scale: Optional scalar float `Tensor` initial
      scale for the unconstrained distributions, or a nested structure of
      `Tensor` initial scales for each variable.
      Default value: `1e-2`.
    trainable_distribution_fn: Optional Python `callable` with signature
      `trainable_dist = trainable_distribution_fn(initial_loc, initial_scale,
      event_ndims, validate_args)`. This is called for each model variable to
      build the corresponding factor in the surrogate posterior. It is expected
      that the distribution returned is supported on unconstrained real values.
      Default value: `functools.partial(
        tfp.experimental.vi.build_trainable_location_scale_distribution,
        distribution_fn=tfd.Normal)`, i.e., a trainable Normal distribution.
    seed: Python integer to seed the random number generator. This is used
      only when `initial_unconstrained_loc` is left as a callable, i.e., when
      explicit initial locations are not specified.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_factored_surrogate_posterior').

  Returns:
    surrogate_posterior: A `tfd.Distribution` instance whose samples have
      shape and structure matching that of `event_shape` or
      `initial_unconstrained_loc`.

  ### Examples

  Consider a Gamma model with unknown parameters, expressed as a joint
  Distribution:

  ```python
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(tfd.Gamma(concentration=concentration, rate=rate),
                         sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)
  ```

  Let's use variational inference to approximate the posterior over the
  data-generating parameters for some observed `y`. We'll build a
  surrogate posterior distribution by specifying the shapes of the latent
  `rate` and `concentration` parameters, and that both are constrained to
  be positive.

  ```python
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=model.event_shape_tensor()[:-1],  # Omit the observed `y`.
    bijector=[tfb.Softplus(),   # Concentration is positive.
              tfb.Softplus()])  # Rate is positive.
  ```

  This creates a trainable joint distribution, defined by variables in
  `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior`
  to fit this distribution by minimizing a divergence to the true posterior.

  ```python
  y = [0.2, 0.5, 0.3, 0.7]
  losses = tfp.vi.fit_surrogate_posterior(
    lambda concentration, rate: model.log_prob([concentration, rate, y]),
    surrogate_posterior=surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)

  # After optimization, samples from the surrogate will approximate
  # samples from the true posterior.
  samples = surrogate_posterior.sample(100)
  posterior_mean = [tf.reduce_mean(x) for x in samples]     # mean ~= [1.1, 2.1]
  posterior_std = [tf.math.reduce_std(x) for x in samples]  # std  ~= [0.3, 0.8]
  ```

  If we want to initialize the optimization at a specific location, we can
  specify one when we build the surrogate posterior. This function requires the
  initial location to be specified in *unconstrained* space; we do this by
  inverting the constraining bijectors (note that this section also
  demonstrates the creation of a dict-structured model).

  ```python
  initial_loc = {'concentration': 0.4, 'rate': 0.2}
  bijector={'concentration': tfb.Softplus(),   # Concentration is positive.
            'rate': tfb.Softplus()}            # Rate is positive.
  initial_unconstrained_loc = tf.nest.map_structure(
    lambda b, x: b.inverse(x) if b is not None else x, bijector, initial_loc)
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=tf.nest.map_structure(tf.shape, initial_loc),
    bijector=bijector,
    initial_unconstrained_loc=initial_unconstrained_loc,
    initial_unconstrained_scale=1e-4)
  ```

  """

  with tf.name_scope(name or 'build_factored_surrogate_posterior'):
    bijector = deprecation.deprecated_argument_lookup(
        'bijector', bijector, 'constraining_bijectors', constraining_bijectors)

    seed = tfp_util.SeedStream(seed, salt='build_factored_surrogate_posterior')

    # Convert event shapes to Tensors.
    shallow_structure = _get_event_shape_shallow_structure(event_shape)
    event_shape = nest.map_structure_up_to(
        shallow_structure, lambda s: tf.convert_to_tensor(s, dtype=tf.int32),
        event_shape)

    if nest.is_nested(bijector):
      bijector = nest.map_structure(
          lambda b: identity.Identity() if b is None else b,
          bijector)

      # Support mismatched nested structures for backwards compatibility (e.g.
      # non-nested `event_shape` and a single-element list of `bijector`s).
      bijector = nest.pack_sequence_as(event_shape, nest.flatten(bijector))

      event_space_bijector = joint_map.JointMap(
          bijector, validate_args=validate_args)
    else:
      event_space_bijector = bijector

    if event_space_bijector is None:
      unconstrained_event_shape = event_shape
    else:
      unconstrained_event_shape = (
          event_space_bijector.inverse_event_shape_tensor(event_shape))

    # Construct initial locations for the internal unconstrained dists.
    if callable(initial_unconstrained_loc):  # Sample random initialization.
      initial_unconstrained_loc = nest.map_structure(
          lambda s: initial_unconstrained_loc(shape=s, seed=seed()),
          unconstrained_event_shape)

    if not nest.is_nested(initial_unconstrained_scale):
      initial_unconstrained_scale = nest.map_structure(
          lambda _: initial_unconstrained_scale,
          unconstrained_event_shape)

    # Extract the rank of each event, so that we build distributions with the
    # correct event shapes.
    unconstrained_event_ndims = nest.map_structure(
        ps.rank_from_shape,
        unconstrained_event_shape)

    # Build the component surrogate posteriors.
    unconstrained_distributions = nest.map_structure_up_to(
        unconstrained_event_shape,
        lambda loc, scale, ndims: trainable_distribution_fn(  # pylint: disable=g-long-lambda
            loc, scale, ndims, validate_args=validate_args),
        initial_unconstrained_loc,
        initial_unconstrained_scale,
        unconstrained_event_ndims)

    base_distribution = (
        joint_distribution_util.independent_joint_distribution_from_structure(
            unconstrained_distributions, validate_args=validate_args))
    if event_space_bijector is None:
      return base_distribution
    return transformed_distribution.TransformedDistribution(
        base_distribution, event_space_bijector)
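Structurally, the returned surrogate is a joint base distribution over unconstrained factors pushed through a `JointMap` of constraining bijectors; in miniature (the factor shapes and bijectors are assumed):

```python
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

base = tfd.JointDistributionSequential(
    [tfd.Normal(0., 1.), tfd.Normal(0., 1.)])  # unconstrained factors
surrogate = tfd.TransformedDistribution(
    base, tfb.JointMap([tfb.Softplus(), tfb.Softplus()]))
# surrogate.sample() -> [positive scalar, positive scalar]
```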
Example #8
def build_split_flow_surrogate_posterior(event_shape,
                                         trainable_bijector,
                                         constraining_bijector=None,
                                         base_distribution=normal.Normal,
                                         batch_shape=(),
                                         dtype=tf.float32,
                                         validate_args=False,
                                         name=None):
    """Builds a joint variational posterior by splitting a normalizing flow.

  Args:
    event_shape: (Nested) event shape of the surrogate posterior.
    trainable_bijector: A trainable `tfb.Bijector` instance that operates on
      `Tensor`s (not structures), e.g. `tfb.MaskedAutoregressiveFlow` or
      `tfb.RealNVP`. This bijector transforms the base distribution before it is
      split.
    constraining_bijector: `tfb.Bijector` instance, or nested structure of
      `tfb.Bijector` instances, that maps (nested) values in R^n to the support
      of the posterior. (This can be the
      `experimental_default_event_space_bijector` of the distribution over the
      prior latent variables.)
      Default value: `None` (i.e., the posterior is over R^n).
    base_distribution: A `tfd.Distribution` subclass parameterized by `loc` and
      `scale`. The base distribution for the transformed surrogate has `loc=0.`
      and `scale=1.`.
      Default value: `tfd.Normal`.
    batch_shape: The `batch_shape` of the output distribution.
      Default value: `()`.
    dtype: The `dtype` of the surrogate posterior.
      Default value: `tf.float32`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_split_flow_surrogate_posterior').

  Returns:
    surrogate_distribution: Trainable `tfd.TransformedDistribution` with event
      shape equal to `event_shape`.

  ### Examples
  ```python

  # Train a normalizing flow on the Eight Schools model [1].

  treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
  treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]
  model = tfd.JointDistributionNamed({
      'avg_effect':
          tfd.Normal(loc=0., scale=10., name='avg_effect'),
      'log_stddev':
          tfd.Normal(loc=5., scale=1., name='log_stddev'),
      'school_effects':
          lambda log_stddev, avg_effect: (
              tfd.Independent(
                  tfd.Normal(
                      loc=avg_effect[..., None] * tf.ones(8),
                      scale=tf.exp(log_stddev[..., None]) * tf.ones(8),
                      name='school_effects'),
                  reinterpreted_batch_ndims=1)),
      'treatment_effects': lambda school_effects: tfd.Independent(
          tfd.Normal(loc=school_effects, scale=treatment_stddevs),
          reinterpreted_batch_ndims=1)
  })

  # Pin the observed values in the model.
  target_model = model.experimental_pin(treatment_effects=treatment_effects)

  # Create a Masked Autoregressive Flow bijector.
  net = tfb.AutoregressiveNetwork(2, hidden_units=[16, 16], dtype=tf.float32)
  maf = tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=net)

  # Build and fit the surrogate posterior.
  surrogate_posterior = (
      tfp.experimental.vi.build_split_flow_surrogate_posterior(
          event_shape=target_model.event_shape_tensor(),
          trainable_bijector=maf,
          constraining_bijector=(
              target_model.experimental_default_event_space_bijector())))

  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```

  #### References

  [1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and
      Donald Rubin. Bayesian Data Analysis, Third Edition.
      Chapman and Hall/CRC, 2013.

  """
    with tf.name_scope(name or 'build_split_flow_surrogate_posterior'):

        shallow_structure = _get_event_shape_shallow_structure(event_shape)
        event_shape = nest.map_structure_up_to(shallow_structure,
                                               ps.convert_to_shape_tensor,
                                               event_shape)

        if nest.is_nested(constraining_bijector):
            constraining_bijector = joint_map.JointMap(
                nest.map_structure(
                    lambda b: identity.Identity()
                    if b is None else b, constraining_bijector),
                validate_args=validate_args)

        if constraining_bijector is None:
            unconstrained_event_shape = event_shape
        else:
            unconstrained_event_shape = (
                constraining_bijector.inverse_event_shape_tensor(event_shape))

        flat_base_event_shape = nest.flatten(unconstrained_event_shape)
        flat_base_event_size = nest.map_structure(tf.reduce_prod,
                                                  flat_base_event_shape)
        event_size = tf.reduce_sum(flat_base_event_size)

        base_distribution = sample.Sample(
            base_distribution(tf.zeros(batch_shape, dtype=dtype), scale=1.),
            [event_size])

        # After transforming base distribution samples with `trainable_bijector`,
        # split them into vector-valued components.
        split_bijector = split.Split(flat_base_event_size,
                                     validate_args=validate_args)

        # Reshape the vectors to the correct posterior event shape.
        event_reshape = joint_map.JointMap(
            nest.map_structure(reshape.Reshape, unconstrained_event_shape),
            validate_args=validate_args)

        # Restructure the flat list of components to the correct posterior
        # structure.
        event_unflatten = restructure.Restructure(
            nest.pack_sequence_as(unconstrained_event_shape,
                                  range(len(flat_base_event_shape))))

        bijectors = [] if constraining_bijector is None else [
            constraining_bijector
        ]
        bijectors.extend([
            event_reshape, event_unflatten, split_bijector, trainable_bijector
        ])
        bijector = chain.Chain(bijectors, validate_args=validate_args)

        return transformed_distribution.TransformedDistribution(
            base_distribution, bijector=bijector, validate_args=validate_args)
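A stripped-down version of the same chain, with a two-part event, no constraining bijector, and no reshaping (the split sizes are made up):

```python
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

flat_event_sizes = [2, 3]  # made-up split sizes
base = tfd.Sample(tfd.Normal(0., 1.), sample_shape=[sum(flat_event_sizes)])

maf = tfb.MaskedAutoregressiveFlow(
    shift_and_log_scale_fn=tfb.AutoregressiveNetwork(
        params=2, hidden_units=[8, 8]))

# Forward: flow the base vector with the MAF, then split it into event parts.
surrogate = tfd.TransformedDistribution(
    base, bijector=tfb.Chain([tfb.Split(flat_event_sizes), maf]))
# surrogate.sample() -> [Tensor of shape [2], Tensor of shape [3]]
```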
Example #9
def _factored_surrogate_posterior(  # pylint: disable=dangerous-default-value
        event_shape=None,
        bijector=None,
        batch_shape=(),
        base_distribution_cls=normal.Normal,
        initial_parameters={'scale': 1e-2},
        dtype=tf.float32,
        validate_args=False,
        name=None):
    """Builds a joint variational posterior that factors over model variables.

  By default, this method creates an independent trainable Normal distribution
  for each variable, transformed using a bijector (if provided) to
  match the support of that variable. This makes extremely strong
  assumptions about the posterior: that it is approximately normal (or
  transformed normal), and that all model variables are independent.

  Args:
    event_shape: `Tensor` shape, or nested structure of `Tensor` shapes,
      specifying the event shape(s) of the posterior variables.
    bijector: Optional `tfb.Bijector` instance, or nested structure of such
      instances, defining support(s) of the posterior variables. The structure
      must match that of `event_shape` and may contain `None` values. A
      posterior variable will be modeled as
      `tfd.TransformedDistribution(underlying_dist, bijector)` if a
      corresponding constraining bijector is specified, otherwise it is modeled
      as supported on the unconstrained real line.
    batch_shape: The `batch_shape` of the output distribution.
      Default value: `()`.
    base_distribution_cls: Subclass of `tfd.Distribution` that is instantiated
      and optionally transformed by the bijector to define the component
      distributions. May optionally be a structure of such subclasses
      matching `event_shape`.
      Default value: `tfd.Normal`.
    initial_parameters: Optional `str : Tensor` dictionary specifying initial
      values for some or all of the base distribution's trainable parameters,
      or a Python `callable` with signature
      `value = parameter_init_fn(parameter_name, shape, dtype, seed,
      constraining_bijector)`, passed to `tfp.experimental.util.make_trainable`.
      May optionally be a structure matching `event_shape` of such dictionaries
      and/or callables. Dictionary entries that do not correspond to parameter
      names are ignored.
      Default value: `{'scale': 1e-2}` (ignored when `base_distribution` does
        not have a `scale` parameter).
    dtype: Optional float `dtype` for trainable parameters. May
      optionally be a structure of such `dtype`s matching `event_shape`.
      Default value: `tf.float32`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_factored_surrogate_posterior').
  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.

  ### Examples

  Consider a Gamma model with unknown parameters, expressed as a joint
  Distribution:

  ```python
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(tfd.Gamma(concentration=concentration, rate=rate),
                         sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)
  ```

  Let's use variational inference to approximate the posterior over the
  data-generating parameters for some observed `y`. We'll build a
  surrogate posterior distribution by specifying the shapes of the latent
  `rate` and `concentration` parameters, and that both are constrained to
  be positive.

  ```python
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=model.event_shape_tensor()[:-1],  # Omit the observed `y`.
    bijector=[tfb.Softplus(),   # Concentration is positive.
              tfb.Softplus()])  # Rate is positive.
  ```

  This creates a trainable joint distribution, defined by variables in
  `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior`
  to fit this distribution by minimizing a divergence to the true posterior.

  ```python
  y = [0.2, 0.5, 0.3, 0.7]
  losses = tfp.vi.fit_surrogate_posterior(
    lambda concentration, rate: model.log_prob([concentration, rate, y]),
    surrogate_posterior=surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)

  # After optimization, samples from the surrogate will approximate
  # samples from the true posterior.
  samples = surrogate_posterior.sample(100)
  posterior_mean = [tf.reduce_mean(x) for x in samples]     # mean ~= [1.1, 2.1]
  posterior_std = [tf.math.reduce_std(x) for x in samples]  # std  ~= [0.3, 0.8]
  ```

  If we want to initialize the optimization at a specific location, we can
  specify initial parameters when we build the surrogate posterior. Note that
  these parameterize the distribution(s) over unconstrained values, so we need
  to transform our desired constrained locations using the inverse of the
  constraining bijector(s).

  ```python
  initial_loc = {'concentration': 0.4, 'rate': 0.2}
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=tf.nest.map_structure(tf.shape, initial_loc),
    bijector={'concentration': tfb.Softplus(),   # Concentration is positive.
              'rate': tfb.Softplus()},           # Rate is positive.
    initial_parameters={
      'concentration': {'loc': tfb.Softplus().inverse(0.4), 'scale': 1e-2},
      'rate': {'loc': tfb.Softplus().inverse(0.2), 'scale': 1e-2}})
  ```

  """
    with tf.name_scope(name or 'build_factored_surrogate_posterior'):
        # Convert event shapes to Tensors.
        shallow_structure = _get_event_shape_shallow_structure(event_shape)
        event_shape = nest.map_structure_up_to(
            shallow_structure,
            lambda s: tf.convert_to_tensor(s, dtype=tf.int32), event_shape)

        if nest.is_nested(bijector):
            event_space_bijector = joint_map.JointMap(
                nest.map_structure(
                    lambda b: identity.Identity() if b is None else b,
                    nest_util.coerce_structure(event_shape, bijector)),
                validate_args=validate_args)
        else:
            event_space_bijector = bijector

        if event_space_bijector is None:
            unconstrained_event_shape = event_shape
        else:
            unconstrained_event_shape = (
                event_space_bijector.inverse_event_shape_tensor(event_shape))
        unconstrained_batch_and_event_shape = tf.nest.map_structure(
            lambda s: ps.concat([batch_shape, s], axis=0),
            unconstrained_event_shape)

        base_distribution_cls = nest_util.broadcast_structure(
            event_shape, base_distribution_cls)
        try:
            # Check that we have initial parameters for each event part.
            nest.assert_shallow_structure(event_shape, initial_parameters)
        except (ValueError, TypeError):
            # If not, broadcast the parameters to match the event structure.
            # We do this manually rather than using `nest_util.broadcast_structure`
            # because the initial parameters can themselves be structures (dicts).
            initial_parameters = nest.map_structure(
                lambda x: initial_parameters, event_shape)

        unconstrained_trainable_distributions = yield from (
            nest_util.map_structure_coroutine(
                trainable._make_trainable,  # pylint: disable=protected-access
                cls=base_distribution_cls,
                initial_parameters=initial_parameters,
                batch_and_event_shape=unconstrained_batch_and_event_shape,
                parameter_dtype=nest_util.broadcast_structure(
                    event_shape, dtype),
                _up_to=event_shape))
        unconstrained_trainable_distribution = (
            joint_distribution_util.
            independent_joint_distribution_from_structure(
                unconstrained_trainable_distributions,
                batch_ndims=ps.rank_from_shape(batch_shape),
                validate_args=validate_args))
        if event_space_bijector is None:
            return unconstrained_trainable_distribution
        return transformed_distribution.TransformedDistribution(
            unconstrained_trainable_distribution, event_space_bijector)
Example #10
def _affine_surrogate_posterior_from_base_distribution(
        base_distribution,
        operators='diag',
        bijector=None,
        initial_unconstrained_loc_fn=_sample_uniform_initial_loc,
        validate_args=False,
        name=None):
    """Builds a variational posterior by linearly transforming base distributions.

  This function builds a surrogate posterior by applying a trainable
  transformation to a base distribution (typically a `tfd.JointDistribution`) or
  nested structure of base distributions, and constraining the samples with
  `bijector`. Note that the distributions must have event shapes corresponding
  to the *pretransformed* surrogate posterior -- that is, if `bijector` contains
  a shape-changing bijector, then the corresponding base distribution event
  shape is the inverse event shape of the bijector applied to the desired
  surrogate posterior shape. The surrogate posterior is constructed as follows:

  1. Flatten the base distribution event shapes to vectors, and pack the base
     distributions into a `tfd.JointDistribution`.
  2. Apply a trainable blockwise LinearOperator bijector to the joint base
     distribution.
  3. Apply the constraining bijectors and return the resulting trainable
     `tfd.TransformedDistribution` instance.

  Args:
    base_distribution: `tfd.Distribution` instance (typically a
      `tfd.JointDistribution`), or a nested structure of `tfd.Distribution`
      instances.
    operators: Either a string or a list/tuple containing `LinearOperator`
      subclasses, `LinearOperator` instances, or callables returning
      `LinearOperator` instances. Supported string values are "diag" (to create
      a mean-field surrogate posterior) and "tril" (to create a full-covariance
      surrogate posterior). A list/tuple may be passed to induce other
      posterior covariance structures. If the list is flat, a
      `tf.linalg.LinearOperatorBlockDiag` instance will be created and applied
      to the base distribution. Otherwise the list must be singly-nested and
      have a first element of length 1, second element of length 2, etc.; the
      elements of the outer list are interpreted as rows of a lower-triangular
      block structure, and a `tf.linalg.LinearOperatorBlockLowerTriangular`
      instance is created. For complete documentation and examples, see
      `tfp.experimental.vi.util.build_trainable_linear_operator_block`, which
      receives the `operators` arg if it is list-like.
      Default value: `"diag"`.
    bijector: `tfb.Bijector` instance, or nested structure of `tfb.Bijector`
      instances, that maps (nested) values in R^n to the support of the
      posterior. (This can be the `experimental_default_event_space_bijector` of
      the distribution over the prior latent variables.)
      Default value: `None` (i.e., the posterior is over R^n).
    initial_unconstrained_loc_fn: Optional Python `callable` with signature
      `initial_loc = initial_unconstrained_loc_fn(shape, dtype, seed)` used to
      sample real-valued initializations for the unconstrained location of
      each variable.
      Default value: `functools.partial(tf.random.stateless_uniform,
        minval=-2., maxval=2., dtype=tf.float32)`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e.,
      'build_affine_surrogate_posterior_from_base_distribution').
  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.
  Raises:
    NotImplementedError: Base distributions with mixed dtypes are not supported.

  #### Examples
  ```python
  tfd = tfp.distributions
  tfb = tfp.bijectors

  # Fit a multivariate Normal surrogate posterior on the Eight Schools model
  # [1].

  treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
  treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]

  def model_fn():
    avg_effect = yield tfd.Normal(loc=0., scale=10., name='avg_effect')
    log_stddev = yield tfd.Normal(loc=5., scale=1., name='log_stddev')
    school_effects = yield tfd.Sample(
        tfd.Normal(loc=avg_effect, scale=tf.exp(log_stddev)),
        sample_shape=[8],
        name='school_effects')
    treatment_effects = yield tfd.Independent(
        tfd.Normal(loc=school_effects, scale=treatment_stddevs),
        reinterpreted_batch_ndims=1,
        name='treatment_effects')
  model = tfd.JointDistributionCoroutineAutoBatched(model_fn)

  # Pin the observed values in the model.
  target_model = model.experimental_pin(treatment_effects=treatment_effects)

  # Define a lower triangular structure of `LinearOperator` subclasses that
  # models full covariance among latent variables except for the 8 dimensions
  # of `school_effects`, which are modeled as independent (using
  # `LinearOperatorDiag`).
  operators = [
    [tf.linalg.LinearOperatorLowerTriangular],
    [tf.linalg.LinearOperatorFullMatrix,
     tf.linalg.LinearOperatorLowerTriangular],
    [tf.linalg.LinearOperatorFullMatrix, tf.linalg.LinearOperatorFullMatrix,
     tf.linalg.LinearOperatorDiag]]


  # Constrain the posterior values to the support of the prior.
  bijector = target_model.experimental_default_event_space_bijector()

  # Define a base distribution; one reasonable choice is a standard Normal
  # for each latent variable (event shapes must match the pretransformed
  # posterior).
  base_distribution = tfd.JointDistributionSequential(
      [tfd.Sample(tfd.Normal(0., 1.), s)
       for s in target_model.event_shape_tensor()])

  # Build a full-covariance surrogate posterior.
  surrogate_posterior = (
    tfp.experimental.vi.build_affine_surrogate_posterior_from_base_distribution(
        base_distribution=base_distribution,
        operators=operators,
        bijector=bijector))

  # Fit the model.
  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```

  #### References

  [1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and
      Donald Rubin. Bayesian Data Analysis, Third Edition.
      Chapman and Hall/CRC, 2013.

  """
    with tf.name_scope(name
                       or 'affine_surrogate_posterior_from_base_distribution'):

        if nest.is_nested(base_distribution):
            base_distribution = (joint_distribution_util.
                                 independent_joint_distribution_from_structure(
                                     base_distribution,
                                     validate_args=validate_args))

        if nest.is_nested(bijector):
            bijector = joint_map.JointMap(nest.map_structure(
                lambda b: identity.Identity() if b is None else b, bijector),
                                          validate_args=validate_args)

        batch_shape = base_distribution.batch_shape_tensor()
        if tf.nest.is_nested(
                batch_shape):  # Base is a classic JointDistribution.
            batch_shape = functools.reduce(ps.broadcast_shape,
                                           tf.nest.flatten(batch_shape))
        event_shape = base_distribution.event_shape_tensor()
        flat_event_size = nest.flatten(
            nest.map_structure(ps.reduce_prod, event_shape))

        base_dtypes = set([
            dtype_util.base_dtype(d)
            for d in nest.flatten(base_distribution.dtype)
        ])
        if len(base_dtypes) > 1:
            raise NotImplementedError(
                'Base distributions with mixed dtype are not supported. Saw '
                'components of dtype {}'.format(base_dtypes))
        base_dtype = list(base_dtypes)[0]

        num_components = len(flat_event_size)
        if operators == 'diag':
            operators = [tf.linalg.LinearOperatorDiag] * num_components
        elif operators == 'tril':
            operators = [[tf.linalg.LinearOperatorFullMatrix] * i +
                         [tf.linalg.LinearOperatorLowerTriangular]
                         for i in range(num_components)]
        elif isinstance(operators, str):
            raise ValueError(
                'Unrecognized operator type {}. Valid operators are "diag", "tril", '
                'or a structure that can be passed to '
                '`tfp.experimental.vi.util.build_trainable_linear_operator_block` as '
                'the `operators` arg.'.format(operators))

        if nest.is_nested(operators):
            operators = yield from trainable_linear_operators._trainable_linear_operator_block(  # pylint: disable=protected-access
                operators,
                block_dims=flat_event_size,
                dtype=base_dtype,
                batch_shape=batch_shape)

        linop_bijector = (
            scale_matvec_linear_operator.ScaleMatvecLinearOperatorBlock(
                scale=operators, validate_args=validate_args))

        def generate_shift_bijector(s):
            x = yield trainable_state_util.Parameter(
                functools.partial(initial_unconstrained_loc_fn,
                                  ps.concat([batch_shape, [s]], axis=0),
                                  dtype=base_dtype))
            return shift.Shift(x)

        loc_bijectors = yield from nest_util.map_structure_coroutine(
            generate_shift_bijector, flat_event_size)
        loc_bijector = joint_map.JointMap(loc_bijectors,
                                          validate_args=validate_args)

        unflatten_and_reshape = chain.Chain(
            [
                joint_map.JointMap(
                    nest.map_structure(reshape.Reshape, event_shape),
                    validate_args=validate_args),
                restructure.Restructure(
                    nest.pack_sequence_as(event_shape, range(num_components)))
            ],
            validate_args=validate_args)

        bijectors = [] if bijector is None else [bijector]
        bijectors.extend([
            unflatten_and_reshape,
            loc_bijector,  # Allow the mean of the standard dist to shift from 0.
            linop_bijector
        ])  # Apply LinOp to scale the standard dist.
        bijector = chain.Chain(bijectors, validate_args=validate_args)

        flat_base_distribution = invert.Invert(unflatten_and_reshape)(
            base_distribution)

        return transformed_distribution.TransformedDistribution(
            flat_base_distribution,
            bijector=bijector,
            validate_args=validate_args)
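The trainable affine piece reduces to a blockwise matvec. A non-trainable sketch of the `ScaleMatvecLinearOperatorBlock` mechanics, with fixed (not learned) operators:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

block_operator = tf.linalg.LinearOperatorBlockDiag([
    tf.linalg.LinearOperatorDiag(tf.ones([2])),
    tf.linalg.LinearOperatorLowerTriangular(tf.eye(3))])

scale_bijector = tfb.ScaleMatvecLinearOperatorBlock(scale=block_operator)
# Acts blockwise on a list of vectors of shapes [2] and [3].
y = scale_bijector.forward([tf.ones([2]), tf.ones([3])])
```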
Example #11
def quadrature_scheme_lognormal_quantiles(loc,
                                          scale,
                                          quadrature_size,
                                          validate_args=False,
                                          name=None):
    """Use LogNormal quantiles to form quadrature on positive-reals.

  Args:
    loc: `float`-like (batch of) scalar `Tensor`; the location parameter of
      the LogNormal prior.
    scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
      the LogNormal prior.
    quadrature_size: Python `int` scalar representing the number of quadrature
      points.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this class.

  Returns:
    grid: (Batch of) length-`quadrature_size` vectors representing the
      `log_rate` parameters of a `Poisson`.
    probs: (Batch of) length-`quadrature_size` vectors representing the
      weight associated with each `grid` value.
  """
    with tf.name_scope(name or "quadrature_scheme_lognormal_quantiles"):
        # Create a LogNormal distribution.
        dist = transformed_distribution.TransformedDistribution(
            distribution=normal.Normal(loc=loc, scale=scale),
            bijector=exp_bijector.Exp(),
            validate_args=validate_args)
        batch_ndims = dist.batch_shape.ndims
        if batch_ndims is None:
            batch_ndims = tf.shape(dist.batch_shape_tensor())[0]

        def _compute_quantiles():
            """Helper to build quantiles."""
            # Omit {0, 1} since they might lead to Inf/NaN.
            zero = tf.zeros([], dtype=dist.dtype)
            edges = tf.linspace(zero, 1., quadrature_size + 3)[1:-1]
            # Expand edges so it broadcasts across batch dims.
            edges = tf.reshape(
                edges,
                shape=tf.concat(
                    [[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0))
            quantiles = dist.quantile(edges)
            # Cyclically permute left by one.
            perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
            quantiles = tf.transpose(quantiles, perm)
            return quantiles

        quantiles = _compute_quantiles()

        # Compute grid as quantile midpoints.
        grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2.
        # Set shape hints.
        grid.set_shape(dist.batch_shape.concatenate([quadrature_size]))

        # By construction probs is constant, i.e., `1 / quadrature_size`. This is
        # important, because non-constant probs leads to non-reparameterizable
        # samples.
        probs = tf.fill(dims=[quadrature_size],
                        value=1. / tf.cast(quadrature_size, dist.dtype))

        return grid, probs
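The same quantile-midpoint scheme, written against the public `tfd.LogNormal` for scalar `loc` and `scale` (the quadrature size is made up):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

quadrature_size = 7
dist = tfd.LogNormal(loc=0., scale=1.)

# Drop the endpoints {0, 1}, whose quantiles are 0 and +inf.
edges = tf.linspace(0., 1., quadrature_size + 3)[1:-1]
quantiles = dist.quantile(edges)              # quadrature_size + 1 points
grid = (quantiles[:-1] + quantiles[1:]) / 2.  # midpoints of adjacent quantiles
probs = tf.fill([quadrature_size], 1. / quadrature_size)
```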
Example #12
    def __call__(self, value, name=None, **kwargs):
        """Applies or composes the `Bijector`, depending on input type.

    This is a convenience function which applies the `Bijector` instance in
    three different ways, depending on the input:

    1. If the input is a `tfd.Distribution` instance, return
       `tfd.TransformedDistribution(distribution=input, bijector=self)`.
    2. If the input is a `tfb.Bijector` instance, return
       `tfb.Chain([self, input])`.
    3. Otherwise, return `self.forward(input)`.

    Args:
      value: A `tfd.Distribution`, `tfb.Bijector`, or a `Tensor`.
      name: Python `str` name given to ops created by this function.
      **kwargs: Additional keyword arguments passed into the created
        `tfd.TransformedDistribution`, `tfb.Bijector`, or `self.forward`.

    Returns:
      composition: A `tfd.TransformedDistribution` if the input was a
        `tfd.Distribution`, a `tfb.Chain` if the input was a `tfb.Bijector`, or
        a `Tensor` computed by `self.forward`.

    #### Examples

    ```python
    sigmoid = tfb.Reciprocal()(
        tfb.AffineScalar(shift=1.)(
          tfb.Exp()(
            tfb.AffineScalar(scale=-1.))))
    # ==> `tfb.Chain([
    #         tfb.Reciprocal(),
    #         tfb.AffineScalar(shift=1.),
    #         tfb.Exp(),
    #         tfb.AffineScalar(scale=-1.),
    #      ])`  # i.e., `tfb.Sigmoid()`

    log_normal = tfb.Exp()(tfd.Normal(0, 1))
    # ==> `tfd.TransformedDistribution(tfd.Normal(0, 1), tfb.Exp())`

    tfb.Exp()([-1., 0., 1.])
    # ==> tf.exp([-1., 0., 1.])
    ```

    """

        # To avoid circular dependencies and keep the implementation local to the
        # `Bijector` class, we violate PEP8 guidelines and import here rather than
        # at the top of the file.
        from tensorflow_probability.python.bijectors import chain  # pylint: disable=g-import-not-at-top
        from tensorflow_probability.python.distributions import distribution  # pylint: disable=g-import-not-at-top
        from tensorflow_probability.python.distributions import transformed_distribution  # pylint: disable=g-import-not-at-top

        if isinstance(value, transformed_distribution.TransformedDistribution):
            new_kwargs = value.parameters
            new_kwargs.update(kwargs)
            new_kwargs["name"] = name or new_kwargs.get("name", None)
            new_kwargs["bijector"] = self(value.bijector)
            return transformed_distribution.TransformedDistribution(
                **new_kwargs)

        if isinstance(value, distribution.Distribution):
            return transformed_distribution.TransformedDistribution(
                distribution=value, bijector=self, name=name, **kwargs)

        if isinstance(value, chain.Chain):
            new_kwargs = kwargs.copy()
            new_kwargs["bijectors"] = [self] + ([] if value.bijectors is None
                                                else list(value.bijectors))
            if "validate_args" not in new_kwargs:
                new_kwargs["validate_args"] = value.validate_args
            new_kwargs["name"] = name or value.name
            return chain.Chain(**new_kwargs)

        if isinstance(value, Bijector):
            return chain.Chain([self, value], name=name, **kwargs)

        return self._call_forward(value, name=name or "forward", **kwargs)
Example #13
  def _transform(self, distribution):
    return transformed_distribution_lib.TransformedDistribution(
        bijector=masked_autoregressive_lib.MaskedAutoregressiveFlow(
            lambda x: tf.unstack(self._made(x), axis=-1)),
        distribution=distribution)
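A concrete stand-in for `self._made` is an `AutoregressiveNetwork` with `params=2`, unstacked into shift and log-scale exactly as in the lambda above (the event size is assumed):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

made = tfb.AutoregressiveNetwork(params=2, hidden_units=[16, 16])

flow = tfd.TransformedDistribution(
    distribution=tfd.Sample(tfd.Normal(0., 1.), sample_shape=[4]),
    bijector=tfb.MaskedAutoregressiveFlow(
        lambda x: tf.unstack(made(x), num=2, axis=-1)))
print(flow.sample(2))  # shape [2, 4]
```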
Example #14
def _asvi_surrogate_for_distribution(dist,
                                     base_distribution_surrogate_fn,
                                     sample_shape=None,
                                     variables=None,
                                     seed=None):
    """Recursively creates ASVI surrogates, and creates new variables if needed.

  Args:
    dist: a `tfd.Distribution` instance.
    base_distribution_surrogate_fn: Callable to build a surrogate posterior
      for a 'base' (non-meta and non-joint) distribution, with signature
      `surrogate_posterior, variables = base_distribution_surrogate_fn(
      dist, sample_shape=None, variables=None, seed=None)`.
    sample_shape: Optional `Tensor` shape of samples drawn from `dist` by
      `tfd.Sample` wrappers. If not `None`, the surrogate's event will include
      independent sample dimensions, i.e., it will have event shape
      `concat([sample_shape, dist.event_shape], axis=0)`.
      Default value: `None`.
    variables: Optional nested structure of `tf.Variable`s returned from a
      previous call to `_asvi_surrogate_for_distribution`. If `None`,
      new variables will be created; otherwise, constructs a surrogate posterior
      backed by the passed-in variables.
      Default value: `None`.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
  Returns:
    surrogate_posterior: Instance of `tfd.Distribution` representing a trainable
      surrogate posterior distribution, with the same structure and `name` as
      `dist`.
    variables: Nested structure of `tf.Variable` trainable parameters for the
      surrogate posterior. If `dist` is a base distribution, this is
      a `dict` of `ASVIParameters` instances. If `dist` is a joint
      distribution, this is a `dist.dtype` structure of such `dict`s.
  """
    # Pass args to any nested surrogates.
    build_nested_surrogate = functools.partial(
        _asvi_surrogate_for_distribution,
        base_distribution_surrogate_fn=base_distribution_surrogate_fn,
        sample_shape=sample_shape,
        seed=seed)

    # Apply any substitutions, while attempting to preserve the original name.
    dist = _set_name(_as_substituted_distribution(dist), name=_get_name(dist))

    # Handle wrapper ("meta") distributions.
    if isinstance(dist, markov_chain.MarkovChain):
        return _asvi_surrogate_for_markov_chain(
            dist=dist,
            variables=variables,
            base_distribution_surrogate_fn=base_distribution_surrogate_fn,
            sample_shape=sample_shape,
            seed=seed)
    if isinstance(dist, sample.Sample):
        dist_sample_shape = distribution_util.expand_to_vector(
            dist.sample_shape)
        nested_surrogate, variables = build_nested_surrogate(  # pylint: disable=redundant-keyword-arg
            dist=dist.distribution,
            variables=variables,
            sample_shape=(dist_sample_shape if sample_shape is None
                          else ps.concat([sample_shape, dist_sample_shape],
                                         axis=0)))
        surrogate_posterior = independent.Independent(
            nested_surrogate,
            reinterpreted_batch_ndims=ps.rank_from_shape(dist_sample_shape),
            name=_get_name(dist))
    # Treat distributions that subclass TransformedDistribution with their own
    # parameters (e.g., Gumbel, Weibull, MultivariateNormal*, etc) as their
    # own type of base distribution, rather than as explicit TDs.
    elif type(dist) == transformed_distribution.TransformedDistribution:  # pylint: disable=unidiomatic-typecheck
        nested_surrogate, variables = build_nested_surrogate(
            dist.distribution, variables=variables)
        surrogate_posterior = transformed_distribution.TransformedDistribution(
            nested_surrogate, bijector=dist.bijector, name=_get_name(dist))
    elif isinstance(dist, independent.Independent):
        nested_surrogate, variables = build_nested_surrogate(
            dist.distribution, variables=variables)
        surrogate_posterior = independent.Independent(
            nested_surrogate,
            reinterpreted_batch_ndims=dist.reinterpreted_batch_ndims,
            name=_get_name(dist))
    elif hasattr(dist, '_model_coroutine'):
        surrogate_posterior, variables = _asvi_surrogate_for_joint_distribution(
            dist,
            base_distribution_surrogate_fn=base_distribution_surrogate_fn,
            variables=variables,
            seed=seed)
    elif (hasattr(dist, 'distribution') and
          # Transformed dists not handled above are treated as base distributions.
          not isinstance(dist,
                         transformed_distribution.TransformedDistribution)):
        raise ValueError('Meta-distribution `{}` is not yet supported by this '
                         'implementation of ASVI. Contact '
                         '`[email protected]` if you need this '
                         'functionality.'.format(type(dist)))
    else:
        surrogate_posterior, variables = base_distribution_surrogate_fn(
            dist=dist,
            sample_shape=sample_shape,
            variables=variables,
            seed=seed)
    return surrogate_posterior, variables
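For the `tfd.Sample` branch above, the recursion bottoms out in something like the following: surrogate the inner distribution with the sample shape folded into its batch shape, then re-wrap with `Independent` (the inner surrogate here is a hand-rolled stand-in for `base_distribution_surrogate_fn`):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

prior = tfd.Sample(tfd.Normal(0., 1.), sample_shape=[3])  # event shape [3]

# Stand-in surrogate for the inner Normal, with the sample dimensions
# absorbed into its batch shape.
inner_surrogate = tfd.Normal(loc=tf.Variable(tf.zeros([3])),
                             scale=tf.ones([3]))

surrogate = tfd.Independent(inner_surrogate, reinterpreted_batch_ndims=1)
# surrogate.event_shape == prior.event_shape == [3]
```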
Example #15
  def __call__(self, value, name=None, **kwargs):
    """Applies or composes the `Bijector`, depending on input type.

    This is a convenience function which applies the `Bijector` instance in
    three different ways, depending on the input:

    1. If the input is a `tfd.Distribution` instance, return
       `tfd.TransformedDistribution(distribution=input, bijector=self)`.
    2. If the input is a `tfb.Bijector` instance, return
       `tfb.Chain([self, input])`.
    3. Otherwise, return `self.forward(input)`.

    Args:
      value: A `tfd.Distribution`, `tfb.Bijector`, or a `Tensor`.
      name: Python `str` name given to ops created by this function.
      **kwargs: Additional keyword arguments passed into the created
        `tfd.TransformedDistribution`, `tfb.Bijector`, or `self.forward`.

    Returns:
      composition: A `tfd.TransformedDistribution` if the input was a
        `tfd.Distribution`, a `tfb.Chain` if the input was a `tfb.Bijector`, or
        a `Tensor` computed by `self.forward`.

    #### Examples

    ```python
    sigmoid = tfb.Reciprocal()(
        tfb.AffineScalar(shift=1.)(
          tfb.Exp()(
            tfb.AffineScalar(scale=-1.))))
    # ==> `tfb.Chain([
    #         tfb.Reciprocal(),
    #         tfb.AffineScalar(shift=1.),
    #         tfb.Exp(),
    #         tfb.AffineScalar(scale=-1.),
    #      ])`  # ie, `tfb.Sigmoid()`

    log_normal = tfb.Exp()(tfd.Normal(0, 1))
    # ==> `tfd.TransformedDistribution(tfd.Normal(0, 1), tfb.Exp())`

    tfb.Exp()([-1., 0., 1.])
    # ==> tf.exp([-1., 0., 1.])
    ```

    """

    # To avoid circular dependencies and keep the implementation local to the
    # `Bijector` class, we violate PEP8 guidelines and import here rather than
    # at the top of the file.
    from tensorflow_probability.python.bijectors import chain  # pylint: disable=g-import-not-at-top
    from tensorflow_probability.python.distributions import distribution  # pylint: disable=g-import-not-at-top
    from tensorflow_probability.python.distributions import transformed_distribution  # pylint: disable=g-import-not-at-top

    # TODO(b/128841942): Handle Conditional distributions and bijectors.
    if type(value) is transformed_distribution.TransformedDistribution:  # pylint: disable=unidiomatic-typecheck
      # We cannot accept subclasses with different constructors here, because
      # subclass constructors may accept constructor arguments TD doesn't know
      # how to handle. e.g. `TypeError: __init__() got an unexpected keyword
      # argument 'allow_nan_stats'` when doing
      # `tfb.Identity()(tfd.Chi(df=1., allow_nan_stats=True))`.
      new_kwargs = value.parameters
      new_kwargs.update(kwargs)
      new_kwargs['name'] = name or new_kwargs.get('name', None)
      new_kwargs['bijector'] = self(value.bijector)
      return transformed_distribution.TransformedDistribution(**new_kwargs)

    if isinstance(value, distribution.Distribution):
      return transformed_distribution.TransformedDistribution(
          distribution=value,
          bijector=self,
          name=name,
          **kwargs)

    if isinstance(value, chain.Chain):
      new_kwargs = kwargs.copy()
      new_kwargs['bijectors'] = [self] + ([] if value.bijectors is None
                                          else list(value.bijectors))
      if 'validate_args' not in new_kwargs:
        new_kwargs['validate_args'] = value.validate_args
      new_kwargs['name'] = name or value.name
      return chain.Chain(**new_kwargs)

    if isinstance(value, Bijector):
      return chain.Chain([self, value], name=name, **kwargs)

    return self.forward(value, name=name or 'forward', **kwargs)
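
A brief sketch of these branches in action, using the current `tfb.Scale`/`tfb.Shift` bijectors in place of the deprecated `AffineScalar` shown in the docstring:

```python
import tensorflow_probability as tfp

tfb = tfp.bijectors
tfd = tfp.distributions

# Bijector(Bijector) flattens into a single Chain: the Chain branch above
# prepends `self` to the existing bijector list instead of nesting Chains.
composed = tfb.Exp()(tfb.Scale(2.)(tfb.Shift(1.)))
# ==> tfb.Chain([tfb.Exp(), tfb.Scale(2.), tfb.Shift(1.)])

# Bijector(TransformedDistribution) rebuilds the distribution with a composed
# bijector rather than nesting TransformedDistributions.
log_normal = tfb.Exp()(tfd.Normal(0., 1.))
shifted = tfb.Shift(3.)(log_normal)
# ==> TransformedDistribution(tfd.Normal(0., 1.),
#                             bijector=tfb.Chain([tfb.Shift(3.), tfb.Exp()]))
```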
Example #16
    def __init__(self,
                 temperature,
                 logits=None,
                 probs=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='RelaxedBernoulli'):
        """Construct RelaxedBernoulli distributions.

        Args:
          temperature: A `Tensor`, representing the temperature of a set of
            RelaxedBernoulli distributions. The temperature values should be
            positive.
          logits: An N-D `Tensor` representing the log-odds of a positive
            event. Each entry in the `Tensor` parametrizes an independent
            RelaxedBernoulli distribution where the probability of an event is
            sigmoid(logits). Only one of `logits` or `probs` should be passed
            in.
          probs: An N-D `Tensor` representing the probability of a positive
            event. Each entry in the `Tensor` parameterizes an independent
            Bernoulli distribution. Only one of `logits` or `probs` should be
            passed in.
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite possibly
            degrading runtime performance. When `False` invalid inputs may
            silently render incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.

        Raises:
          ValueError: If both `probs` and `logits` are passed, or if neither.
        """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([logits, probs, temperature],
                                            tf.float32)

            self._temperature = tensor_util.convert_nonref_to_tensor(
                temperature, name='temperature', dtype=dtype)
            self._probs = tensor_util.convert_nonref_to_tensor(probs,
                                                               name='probs',
                                                               dtype=dtype)
            self._logits = tensor_util.convert_nonref_to_tensor(logits,
                                                                name='logits',
                                                                dtype=dtype)

            if logits is None:
                logits_parameter = tfp_util.DeferredTensor(
                    lambda x: tf.math.log(x) - tf.math.log1p(-x), self._probs)
            else:
                logits_parameter = self._logits

            shape = tf.broadcast_static_shape(logits_parameter.shape,
                                              self._temperature.shape)

            logistic_scale = tfp_util.DeferredTensor(tf.math.reciprocal,
                                                     self._temperature)
            logistic_loc = tfp_util.DeferredTensor(
                lambda x: x * logistic_scale, logits_parameter, shape=shape)

            self._transformed_logistic = (
                transformed_distribution.TransformedDistribution(
                    distribution=logistic.Logistic(
                        logistic_loc,
                        logistic_scale,
                        allow_nan_stats=allow_nan_stats,
                        name=name + '/Logistic'),
                    bijector=sigmoid_bijector.Sigmoid()))

            super(RelaxedBernoulli, self).__init__(
                dtype=dtype,
                reparameterization_type=(
                    reparameterization.FULLY_REPARAMETERIZED),
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
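
As a usage note, the temperature controls how closely the transformed logistic approaches the discrete endpoints: as it goes to zero, samples concentrate near {0, 1}, while large temperatures spread mass over the interior of (0, 1). A minimal sketch:

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

cold = tfd.RelaxedBernoulli(temperature=0.1, logits=[0., 2.])
warm = tfd.RelaxedBernoulli(temperature=5.0, logits=[0., 2.])

# Samples live in (0, 1); with low temperature they cluster near 0. or 1.,
# with high temperature they spread across the interval.
cold_samples = cold.sample(1000)
warm_samples = warm.sample(1000)
```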
Example #17
def _asvi_surrogate_for_transformed_distribution(dist, build_nested_surrogate):
    """Builds the surrogate for a `tfd.TransformedDistribution`."""
    nested_surrogate = yield from build_nested_surrogate(dist.distribution)
    return transformed_distribution.TransformedDistribution(
        nested_surrogate, bijector=dist.bijector)
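
This builder is a coroutine: `yield from` delegates construction of the nested surrogate to the caller-supplied `build_nested_surrogate` generator, then resumes with the finished surrogate. A hypothetical sketch of how a driver could run such a coroutine to completion; `resolve_request` and the driver itself are illustrative stand-ins, not the actual ASVI machinery:

```python
def drive_surrogate_coroutine(gen, resolve_request):
    """Runs a surrogate-builder generator to completion (illustrative only).

    Args:
      gen: A generator such as the one returned by
        `_asvi_surrogate_for_transformed_distribution(...)`.
      resolve_request: Hypothetical callable that builds the surrogate for
        each nested-distribution request the coroutine yields.

    Returns:
      The surrogate distribution returned by the coroutine.
    """
    try:
        request = next(gen)  # First yielded nested-surrogate request.
        while True:
            # Resume the coroutine with the resolved nested surrogate.
            request = gen.send(resolve_request(request))
    except StopIteration as result:
        return result.value  # The coroutine's `return` value.
```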