Example #1
def get_fldj_theoretical(bijector,
                         x,
                         event_ndims,
                         input_to_unconstrained=None,
                         output_to_unconstrained=None):
    """Numerically approximate the forward log det Jacobian of a bijector.

  We compute the Jacobian of the chain
  output_to_unconst_vec(bijector(inverse(input_to_unconst_vec))) so that
  we're working with a full rank matrix.  We then adjust the resulting Jacobian
  for the unconstraining bijectors.

  Bijectors that constrain / unconstrain their inputs/outputs may not be
  testable with this method, since the composition above may reduce the test
  to something trivial.  However, bijectors that map within constrained spaces
  should be fine.

  Args:
    bijector: the bijector whose Jacobian we wish to approximate
    x: the value for which we want to approximate the Jacobian.  x must have
      a batch dimension for compatibility with tape.batch_jacobian.
    event_ndims: number of dimensions in an event
    input_to_unconstrained: bijector that maps the input to the above bijector
      to an unconstrained 1-D vector.  If the inputs are already unconstrained
      vectors, use None.
    output_to_unconstrained: bijector that maps the output of the above bijector
      to an unconstrained 1-D vector.  If the outputs are unconstrained
      vectors, use None.

  Returns:
    A numerical approximation to the log det Jacobian of bijector.forward
    evaluated at x.
  """
    if input_to_unconstrained is None:
        input_to_unconstrained = identity_bijector.Identity()
    if output_to_unconstrained is None:
        output_to_unconstrained = identity_bijector.Identity()

    x = tf.convert_to_tensor(value=x)
    # Multiply by 1 to get a fresh Tensor; otherwise the bijector's cache would
    # map `inverse(x_unconstrained)` straight back to `x`, bypassing the tape.
    x_unconstrained = 1 * input_to_unconstrained.forward(x)

    with tf.GradientTape(persistent=True) as tape:
        tape.watch(x_unconstrained)
        f_x = bijector.forward(input_to_unconstrained.inverse(x_unconstrained))
        f_x_unconstrained = output_to_unconstrained.forward(f_x)
    jacobian = tape.batch_jacobian(f_x_unconstrained,
                                   x_unconstrained,
                                   experimental_use_pfor=False)

    return (tf.linalg.slogdet(jacobian).log_abs_determinant +
            input_to_unconstrained.forward_log_det_jacobian(
                x, event_ndims=event_ndims) -
            output_to_unconstrained.forward_log_det_jacobian(
                f_x, event_ndims=event_ndims))
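
A minimal usage sketch (not part of the original source): assuming TensorFlow
Probability is importable and `get_fldj_theoretical` is in scope, the autodiff
estimate can be checked against a bijector's analytic forward log-det-Jacobian.

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

bij = tfb.Softplus()
x = tf.constant([[0.3], [1.7]])  # A leading batch dimension is required.

numerical = get_fldj_theoretical(bij, x, event_ndims=1)
analytic = bij.forward_log_det_jacobian(x, event_ndims=1)
# `numerical` and `analytic` should agree up to numerical precision.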
Example #2
  def __init__(self,
               input_shape,
               blockwise_splits,
               coupling_bijector_fn=None):
    """Creates the exit bijector.

    Args:
      input_shape: A list specifying the input shape to the exit bijector.
        Used in constructing the network.
      blockwise_splits: A list of three integers specifying, in order, the
        number of channels left in the model, the number of channels exiting
        the model at this block, and the number of channels bypassing the exit
        bijector altogether.
      coupling_bijector_fn: A function which takes the argument `input_shape`
        and returns a callable neural network (e.g. a keras Sequential). The
        network should either return a tensor with the same event shape as
        `input_shape` (this will employ additive coupling), a tensor with the
        same height and width as `input_shape` but twice the number of channels
        (this will employ affine coupling), or a bijector which takes in a
        tensor with event shape `input_shape`, and returns a tensor with shape
        `input_shape`.
    """

    nleave, ngrab, npass = blockwise_splits

    new_input_shape = input_shape[:-1]+(nleave,)
    target_output_shape = input_shape[:-1]+(ngrab,)

    # If either nleave or ngrab is 0, just use an identity for everything.
    if nleave == 0 or ngrab == 0:
      exit_layer = None
      exit_bijector_fn = None

      self.exit_layer = exit_layer
      shift_distribution = identity.Identity()

    else:
      exit_layer = coupling_bijector_fn(new_input_shape,
                                        output_chan=ngrab)
      exit_bijector_fn = self.make_bijector_fn(
          exit_layer,
          target_shape=target_output_shape,
          scale_fn=tf.exp)
      self.exit_layer = exit_layer  # For variable tracking.
      shift_distribution = real_nvp.RealNVP(
          num_masked=nleave,
          bijector_fn=exit_bijector_fn)

    super(ExitBijector, self).__init__(
        [shift_distribution, identity.Identity()], [nleave + ngrab, npass])
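
A hypothetical `coupling_bijector_fn` sketch (an assumption, not taken from the
original source) matching how the constructor above calls it (with `input_shape`
and `output_chan`): the returned network emits `2 * output_chan` channels, which
selects the affine (shift-and-log-scale) coupling path.

import tensorflow as tf

def coupling_bijector_fn(input_shape, output_chan):
  """Small conv net producing 2 * output_chan channels (affine coupling)."""
  del input_shape  # In this sketch, the input shape is inferred on first call.
  return tf.keras.Sequential([
      tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
      tf.keras.layers.Conv2D(2 * output_chan, 3, padding='same'),
  ])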
Example #3
        def sample_and_trace_fn(dist, value, **_):
            bij = self._bijector_fn(dist)
            if bij is None:
                bij = identity_bijector.Identity()

            # If the RV is not yet constrained, transform it.
            value = value if constrained else bij.forward(value)
            return jd_lib.ValueWithTrace(value=value, traced=bij)
Example #4
    def _conditioned_bijectors(self, samples, constrained=False):
        if samples is None:
            return self.bijectors

        bijectors = []
        gen = self._jd._model_coroutine()
        cond = None
        for rv in self._jd._model_flatten(samples):
            d = gen.send(cond)
            dist = d.distribution if type(d).__name__ == 'Root' else d
            bij = self._bijector_fn(dist)

            if bij is None:
                bij = identity_bijector.Identity()
            bijectors.append(bij)

            # If the RV is not yet constrained, transform it.
            cond = rv if constrained else bij.forward(rv)
        return bijectors
Example #5
    def __init__(self, input_shape, num_steps, coupling_bijector_fn,
                 use_actnorm, seedstream):
        parameters = dict(locals())
        rnvp_block = [identity.Identity()]
        this_nchan = input_shape[-1]

        for j in range(num_steps):  # pylint: disable=unused-variable

            this_layer_input_shape = input_shape[:-1] + (input_shape[-1] //
                                                         2, )
            this_layer = coupling_bijector_fn(this_layer_input_shape)
            bijector_fn = self.make_bijector_fn(this_layer)

            # For each step in the block, we do (optional) actnorm, followed
            # by an invertible 1x1 convolution, then affine coupling.
            this_rnvp = invert.Invert(
                real_nvp.RealNVP(this_nchan // 2, bijector_fn=bijector_fn))

            # Append the layer to the realNVP bijector for variable tracking.
            this_rnvp.coupling_bijector_layer = this_layer
            rnvp_block.append(this_rnvp)

            rnvp_block.append(
                invert.Invert(
                    OneByOneConv(this_nchan,
                                 seed=seedstream(),
                                 dtype=dtype_util.common_dtype(
                                     this_rnvp.variables,
                                     dtype_hint=tf.float32))))

            if use_actnorm:
                rnvp_block.append(
                    ActivationNormalization(this_nchan,
                                            dtype=dtype_util.common_dtype(
                                                this_rnvp.variables,
                                                dtype_hint=tf.float32)))

        # Note that we reverse the list since Chain applies bijectors in reverse
        # order.
        super(GlowBlock, self).__init__(chain.Chain(rnvp_block[::-1]),
                                        parameters=parameters,
                                        name='glow_block')
Example #6
    def _evaluate_bijector(self, bijector_fn, values):
        gen = self._jd._model_coroutine()
        outputs = []
        d = next(gen)
        index = 0
        try:
            while True:
                dist = d.distribution if type(d).__name__ == 'Root' else d
                bijector = dist._experimental_default_event_space_bijector()

                # For discrete distributions, the default event space bijector is None.
                # For a joint distribution's discrete components, we want the behavior
                # of the Identity bijector.
                bijector = (identity_bijector.Identity()
                            if bijector is None else bijector)

                out, y = bijector_fn(bijector, values[index])
                outputs.append(out)
                d = gen.send(y)
                index += 1
        except StopIteration:
            pass
        return outputs
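
A hedged sketch (an assumption, not from the original source) of a `bijector_fn`
compatible with `_evaluate_bijector`: it receives the per-component bijector and
the corresponding value, and returns both the output to collect and the value to
send back into the model coroutine.

def forward_fn(bijector, value):
  y = bijector.forward(value)
  # Collect the transformed value and also feed it back into the coroutine so
  # downstream distributions are conditioned on constrained values.
  return y, y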
Example #7
def build_factored_surrogate_posterior(
    event_shape=None,
    bijector=None,
    constraining_bijectors=None,
    initial_unconstrained_loc=_sample_uniform_initial_loc,
    initial_unconstrained_scale=1e-2,
    trainable_distribution_fn=_build_trainable_normal_dist,
    seed=None,
    validate_args=False,
    name=None):
  """Builds a joint variational posterior that factors over model variables.

  By default, this method creates an independent trainable Normal distribution
  for each variable, transformed using a bijector (if provided) to
  match the support of that variable. This makes extremely strong
  assumptions about the posterior: that it is approximately normal (or
  transformed normal), and that all model variables are independent.

  Args:
    event_shape: `Tensor` shape, or nested structure of `Tensor` shapes,
      specifying the event shape(s) of the posterior variables.
    bijector: Optional `tfb.Bijector` instance, or nested structure of such
      instances, defining support(s) of the posterior variables. The structure
      must match that of `event_shape` and may contain `None` values. A
      posterior variable will be modeled as
      `tfd.TransformedDistribution(underlying_dist, bijector)` if a
      corresponding constraining bijector is specified, otherwise it is modeled
      as supported on the unconstrained real line.
    constraining_bijectors: Deprecated alias for `bijector`.
    initial_unconstrained_loc: Optional Python `callable` with signature
      `tensor = initial_unconstrained_loc(shape, seed)` used to sample
      real-valued initializations for the unconstrained representation of each
      variable. May alternately be a nested structure of
      `Tensor`s, giving specific initial locations for each variable; these
      must have structure matching `event_shape` and shapes determined by the
      inverse image of `event_shape` under `bijector`, which may optionally be
      prefixed with a common batch shape.
      Default value: `functools.partial(tf.random.uniform,
        minval=-2., maxval=2., dtype=tf.float32)`.
    initial_unconstrained_scale: Optional scalar float `Tensor` initial
      scale for the unconstrained distributions, or a nested structure of
      `Tensor` initial scales for each variable.
      Default value: `1e-2`.
    trainable_distribution_fn: Optional Python `callable` with signature
      `trainable_dist = trainable_distribution_fn(initial_loc, initial_scale,
      event_ndims, validate_args)`. This is called for each model variable to
      build the corresponding factor in the surrogate posterior. It is expected
      that the distribution returned is supported on unconstrained real values.
      Default value: `functools.partial(
        tfp.experimental.vi.build_trainable_location_scale_distribution,
        distribution_fn=tfd.Normal)`, i.e., a trainable Normal distribution.
    seed: Python integer to seed the random number generator. This is used
      only when `initial_loc` is not specified.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_factored_surrogate_posterior').

  Returns:
    surrogate_posterior: A `tfd.Distribution` instance whose samples have
      shape and structure matching that of `event_shape` or `initial_loc`.

  ### Examples

  Consider a Gamma model with unknown parameters, expressed as a joint
  Distribution:

  ```python
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(tfd.Gamma(concentration=concentration, rate=rate),
                         sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)
  ```

  Let's use variational inference to approximate the posterior over the
  data-generating parameters for some observed `y`. We'll build a
  surrogate posterior distribution by specifying the shapes of the latent
  `rate` and `concentration` parameters, and that both are constrained to
  be positive.

  ```python
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=model.event_shape_tensor()[:-1],  # Omit the observed `y`.
    bijector=[tfb.Softplus(),   # Concentration is positive.
              tfb.Softplus()])  # Rate is positive.
  ```

  This creates a trainable joint distribution, defined by variables in
  `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior`
  to fit this distribution by minimizing a divergence to the true posterior.

  ```python
  y = [0.2, 0.5, 0.3, 0.7]
  losses = tfp.vi.fit_surrogate_posterior(
    lambda concentration, rate: model.log_prob([concentration, rate, y]),
    surrogate_posterior=surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)

  # After optimization, samples from the surrogate will approximate
  # samples from the true posterior.
  samples = surrogate_posterior.sample(100)
  posterior_mean = [tf.reduce_mean(x) for x in samples]     # mean ~= [1.1, 2.1]
  posterior_std = [tf.math.reduce_std(x) for x in samples]  # std  ~= [0.3, 0.8]
  ```

  If we wanted to initialize the optimization at a specific location, we can
  specify one when we build the surrogate posterior. This function requires the
  initial location to be specified in *unconstrained* space; we do this by
  inverting the constraining bijectors (note this section also demonstrates the
  creation of a dict-structured model).

  ```python
  initial_loc = {'concentration': 0.4, 'rate': 0.2}
  bijector = {'concentration': tfb.Softplus(),  # Concentration is positive.
              'rate': tfb.Softplus()}           # Rate is positive.
  initial_unconstrained_loc = tf.nest.map_structure(
    lambda b, x: b.inverse(x) if b is not None else x, bijector, initial_loc)
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=tf.nest.map_structure(tf.shape, initial_loc),
    bijector=bijector,
    initial_unconstrained_loc=initial_unconstrained_loc,
    initial_unconstrained_scale=1e-4)
  ```

  """

  with tf.name_scope(name or 'build_factored_surrogate_posterior'):
    bijector = deprecation.deprecated_argument_lookup(
        'bijector', bijector, 'constraining_bijectors', constraining_bijectors)

    seed = tfp_util.SeedStream(seed, salt='build_factored_surrogate_posterior')

    # Convert event shapes to Tensors.
    shallow_structure = _get_event_shape_shallow_structure(event_shape)
    event_shape = nest.map_structure_up_to(
        shallow_structure, lambda s: tf.convert_to_tensor(s, dtype=tf.int32),
        event_shape)

    if nest.is_nested(bijector):
      bijector = nest.map_structure(
          lambda b: identity.Identity() if b is None else b,
          bijector)

      # Support mismatched nested structures for backwards compatibility (e.g.
      # non-nested `event_shape` and a single-element list of `bijector`s).
      bijector = nest.pack_sequence_as(event_shape, nest.flatten(bijector))

      event_space_bijector = joint_map.JointMap(
          bijector, validate_args=validate_args)
    else:
      event_space_bijector = bijector

    if event_space_bijector is None:
      unconstrained_event_shape = event_shape
    else:
      unconstrained_event_shape = (
          event_space_bijector.inverse_event_shape_tensor(event_shape))

    # Construct initial locations for the internal unconstrained dists.
    if callable(initial_unconstrained_loc):  # Sample random initialization.
      initial_unconstrained_loc = nest.map_structure(
          lambda s: initial_unconstrained_loc(shape=s, seed=seed()),
          unconstrained_event_shape)

    if not nest.is_nested(initial_unconstrained_scale):
      initial_unconstrained_scale = nest.map_structure(
          lambda _: initial_unconstrained_scale,
          unconstrained_event_shape)

    # Extract the rank of each event, so that we build distributions with the
    # correct event shapes.
    unconstrained_event_ndims = nest.map_structure(
        ps.rank_from_shape,
        unconstrained_event_shape)

    # Build the component surrogate posteriors.
    unconstrained_distributions = nest.map_structure_up_to(
        unconstrained_event_shape,
        lambda loc, scale, ndims: trainable_distribution_fn(  # pylint: disable=g-long-lambda
            loc, scale, ndims, validate_args=validate_args),
        initial_unconstrained_loc,
        initial_unconstrained_scale,
        unconstrained_event_ndims)

    base_distribution = (
        joint_distribution_util.independent_joint_distribution_from_structure(
            unconstrained_distributions, validate_args=validate_args))
    if event_space_bijector is None:
      return base_distribution
    return transformed_distribution.TransformedDistribution(
        base_distribution, event_space_bijector)
Example #8
def build_split_flow_surrogate_posterior(event_shape,
                                         trainable_bijector,
                                         constraining_bijector=None,
                                         base_distribution=normal.Normal,
                                         batch_shape=(),
                                         dtype=tf.float32,
                                         validate_args=False,
                                         name=None):
    """Builds a joint variational posterior by splitting a normalizing flow.

  Args:
    event_shape: (Nested) event shape of the surrogate posterior.
    trainable_bijector: A trainable `tfb.Bijector` instance that operates on
      `Tensor`s (not structures), e.g. `tfb.MaskedAutoregressiveFlow` or
      `tfb.RealNVP`. This bijector transforms the base distribution before it is
      split.
    constraining_bijector: `tfb.Bijector` instance, or nested structure of
      `tfb.Bijector` instances, that maps (nested) values in R^n to the support
      of the posterior. (This can be the
      `experimental_default_event_space_bijector` of the distribution over the
      prior latent variables.)
      Default value: `None` (i.e., the posterior is over R^n).
    base_distribution: A `tfd.Distribution` subclass parameterized by `loc` and
      `scale`. The base distribution for the transformed surrogate has `loc=0.`
      and `scale=1.`.
      Default value: `tfd.Normal`.
    batch_shape: The `batch_shape` of the output distribution.
      Default value: `()`.
    dtype: The `dtype` of the surrogate posterior.
      Default value: `tf.float32`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_split_flow_surrogate_posterior').

  Returns:
    surrogate_distribution: Trainable `tfd.TransformedDistribution` with event
      shape equal to `event_shape`.

  ### Examples
  ```python

  # Train a normalizing flow on the Eight Schools model [1].

  treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
  treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]
  model = tfd.JointDistributionNamed({
      'avg_effect':
          tfd.Normal(loc=0., scale=10., name='avg_effect'),
      'log_stddev':
          tfd.Normal(loc=5., scale=1., name='log_stddev'),
      'school_effects':
          lambda log_stddev, avg_effect: (
              tfd.Independent(
                  tfd.Normal(
                      loc=avg_effect[..., None] * tf.ones(8),
                      scale=tf.exp(log_stddev[..., None]) * tf.ones(8),
                      name='school_effects'),
                  reinterpreted_batch_ndims=1)),
      'treatment_effects': lambda school_effects: tfd.Independent(
          tfd.Normal(loc=school_effects, scale=treatment_stddevs),
          reinterpreted_batch_ndims=1)
  })

  # Pin the observed values in the model.
  target_model = model.experimental_pin(treatment_effects=treatment_effects)

  # Create a Masked Autoregressive Flow bijector.
  net = tfb.AutoregressiveNetwork(2, hidden_units=[16, 16], dtype=tf.float32)
  maf = tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=net)

  # Build and fit the surrogate posterior.
  surrogate_posterior = (
      tfp.experimental.vi.build_split_flow_surrogate_posterior(
          event_shape=target_model.event_shape_tensor(),
          trainable_bijector=maf,
          constraining_bijector=(
              target_model.experimental_default_event_space_bijector())))

  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```

  #### References

  [1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and
      Donald Rubin. Bayesian Data Analysis, Third Edition.
      Chapman and Hall/CRC, 2013.

  """
    with tf.name_scope(name or 'build_split_flow_surrogate_posterior'):

        shallow_structure = _get_event_shape_shallow_structure(event_shape)
        event_shape = nest.map_structure_up_to(shallow_structure,
                                               ps.convert_to_shape_tensor,
                                               event_shape)

        if nest.is_nested(constraining_bijector):
            constraining_bijector = joint_map.JointMap(
                nest.map_structure(
                    lambda b: identity.Identity()
                    if b is None else b, constraining_bijector),
                validate_args=validate_args)

        if constraining_bijector is None:
            unconstrained_event_shape = event_shape
        else:
            unconstrained_event_shape = (
                constraining_bijector.inverse_event_shape_tensor(event_shape))

        flat_base_event_shape = nest.flatten(unconstrained_event_shape)
        flat_base_event_size = nest.map_structure(tf.reduce_prod,
                                                  flat_base_event_shape)
        event_size = tf.reduce_sum(flat_base_event_size)

        base_distribution = sample.Sample(
            base_distribution(tf.zeros(batch_shape, dtype=dtype), scale=1.),
            [event_size])

        # After transforming base distribution samples with `trainable_bijector`,
        # split them into vector-valued components.
        split_bijector = split.Split(flat_base_event_size,
                                     validate_args=validate_args)

        # Reshape the vectors to the correct posterior event shape.
        event_reshape = joint_map.JointMap(nest.map_structure(
            reshape.Reshape, unconstrained_event_shape),
                                           validate_args=validate_args)

        # Restructure the flat list of components to the correct posterior
        # structure.
        event_unflatten = restructure.Restructure(
            nest.pack_sequence_as(unconstrained_event_shape,
                                  range(len(flat_base_event_shape))))

        bijectors = [] if constraining_bijector is None else [
            constraining_bijector
        ]
        bijectors.extend([
            event_reshape, event_unflatten, split_bijector, trainable_bijector
        ])
        bijector = chain.Chain(bijectors, validate_args=validate_args)

        return transformed_distribution.TransformedDistribution(
            base_distribution, bijector=bijector, validate_args=validate_args)
Example #9
def _factored_surrogate_posterior(  # pylint: disable=dangerous-default-value
        event_shape=None,
        bijector=None,
        batch_shape=(),
        base_distribution_cls=normal.Normal,
        initial_parameters={'scale': 1e-2},
        dtype=tf.float32,
        validate_args=False,
        name=None):
    """Builds a joint variational posterior that factors over model variables.

  By default, this method creates an independent trainable Normal distribution
  for each variable, transformed using a bijector (if provided) to
  match the support of that variable. This makes extremely strong
  assumptions about the posterior: that it is approximately normal (or
  transformed normal), and that all model variables are independent.

  Args:
    event_shape: `Tensor` shape, or nested structure of `Tensor` shapes,
      specifying the event shape(s) of the posterior variables.
    bijector: Optional `tfb.Bijector` instance, or nested structure of such
      instances, defining support(s) of the posterior variables. The structure
      must match that of `event_shape` and may contain `None` values. A
      posterior variable will be modeled as
      `tfd.TransformedDistribution(underlying_dist, bijector)` if a
      corresponding constraining bijector is specified, otherwise it is modeled
      as supported on the unconstrained real line.
    batch_shape: The `batch_shape` of the output distribution.
      Default value: `()`.
    base_distribution_cls: Subclass of `tfd.Distribution` that is instantiated
      and optionally transformed by the bijector to define the component
      distributions. May optionally be a structure of such subclasses
      matching `event_shape`.
      Default value: `tfd.Normal`.
    initial_parameters: Optional `str : Tensor` dictionary specifying initial
      values for some or all of the base distribution's trainable parameters,
      or a Python `callable` with signature
      `value = parameter_init_fn(parameter_name, shape, dtype, seed,
      constraining_bijector)`, passed to `tfp.experimental.util.make_trainable`.
      May optionally be a structure matching `event_shape` of such dictionaries
      and/or callables. Dictionary entries that do not correspond to parameter
      names are ignored.
      Default value: `{'scale': 1e-2}` (ignored when `base_distribution` does
        not have a `scale` parameter).
    dtype: Optional float `dtype` for trainable parameters. May
      optionally be a structure of such `dtype`s matching `event_shape`.
      Default value: `tf.float32`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_factored_surrogate_posterior').
  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.

  ### Examples

  Consider a Gamma model with unknown parameters, expressed as a joint
  Distribution:

  ```python
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(tfd.Gamma(concentration=concentration, rate=rate),
                         sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)
  ```

  Let's use variational inference to approximate the posterior over the
  data-generating parameters for some observed `y`. We'll build a
  surrogate posterior distribution by specifying the shapes of the latent
  `rate` and `concentration` parameters, and that both are constrained to
  be positive.

  ```python
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=model.event_shape_tensor()[:-1],  # Omit the observed `y`.
    bijector=[tfb.Softplus(),   # Concentration is positive.
              tfb.Softplus()])  # Rate is positive.
  ```

  This creates a trainable joint distribution, defined by variables in
  `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior`
  to fit this distribution by minimizing a divergence to the true posterior.

  ```python
  y = [0.2, 0.5, 0.3, 0.7]
  losses = tfp.vi.fit_surrogate_posterior(
    lambda concentration, rate: model.log_prob([concentration, rate, y]),
    surrogate_posterior=surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)

  # After optimization, samples from the surrogate will approximate
  # samples from the true posterior.
  samples = surrogate_posterior.sample(100)
  posterior_mean = [tf.reduce_mean(x) for x in samples]     # mean ~= [1.1, 2.1]
  posterior_std = [tf.math.reduce_std(x) for x in samples]  # std  ~= [0.3, 0.8]
  ```

  If we wanted to initialize the optimization at a specific location, we can
  specify initial parameters when we build the surrogate posterior. Note that
  these parameterize the distribution(s) over unconstrained values,
  so we need to transform our desired constrained locations using the inverse
  of the constraining bijector(s).

  ```python
  initial_loc = {'concentration': 0.4, 'rate': 0.2}
  surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior(
    event_shape=tf.nest.map_structure(tf.shape, initial_loc),
    bijector={'concentration': tfb.Softplus(),  # Concentration is positive.
              'rate': tfb.Softplus()},          # Rate is positive.
    initial_parameters={
      'concentration': {'loc': tfb.Softplus().inverse(0.4), 'scale': 1e-2},
      'rate': {'loc': tfb.Softplus().inverse(0.2), 'scale': 1e-2}})
  ```

  """
    with tf.name_scope(name or 'build_factored_surrogate_posterior'):
        # Convert event shapes to Tensors.
        shallow_structure = _get_event_shape_shallow_structure(event_shape)
        event_shape = nest.map_structure_up_to(
            shallow_structure,
            lambda s: tf.convert_to_tensor(s, dtype=tf.int32), event_shape)

        if nest.is_nested(bijector):
            event_space_bijector = joint_map.JointMap(
                nest.map_structure(
                    lambda b: identity.Identity() if b is None else b,
                    nest_util.coerce_structure(event_shape, bijector)),
                validate_args=validate_args)
        else:
            event_space_bijector = bijector

        if event_space_bijector is None:
            unconstrained_event_shape = event_shape
        else:
            unconstrained_event_shape = (
                event_space_bijector.inverse_event_shape_tensor(event_shape))
        unconstrained_batch_and_event_shape = tf.nest.map_structure(
            lambda s: ps.concat([batch_shape, s], axis=0),
            unconstrained_event_shape)

        base_distribution_cls = nest_util.broadcast_structure(
            event_shape, base_distribution_cls)
        try:
            # Check that we have initial parameters for each event part.
            nest.assert_shallow_structure(event_shape, initial_parameters)
        except (ValueError, TypeError):
            # If not, broadcast the parameters to match the event structure.
            # We do this manually rather than using `nest_util.broadcast_structure`
            # because the initial parameters can themselves be structures (dicts).
            initial_parameters = nest.map_structure(
                lambda x: initial_parameters, event_shape)

        unconstrained_trainable_distributions = yield from (
            nest_util.map_structure_coroutine(
                trainable._make_trainable,  # pylint: disable=protected-access
                cls=base_distribution_cls,
                initial_parameters=initial_parameters,
                batch_and_event_shape=unconstrained_batch_and_event_shape,
                parameter_dtype=nest_util.broadcast_structure(
                    event_shape, dtype),
                _up_to=event_shape))
        unconstrained_trainable_distribution = (
            joint_distribution_util.
            independent_joint_distribution_from_structure(
                unconstrained_trainable_distributions,
                batch_ndims=ps.rank_from_shape(batch_shape),
                validate_args=validate_args))
        if event_space_bijector is None:
            return unconstrained_trainable_distribution
        return transformed_distribution.TransformedDistribution(
            unconstrained_trainable_distribution, event_space_bijector)
Example #10
def _affine_surrogate_posterior_from_base_distribution(
        base_distribution,
        operators='diag',
        bijector=None,
        initial_unconstrained_loc_fn=_sample_uniform_initial_loc,
        validate_args=False,
        name=None):
    """Builds a variational posterior by linearly transforming base distributions.

  This function builds a surrogate posterior by applying a trainable
  transformation to a base distribution (typically a `tfd.JointDistribution`) or
  nested structure of base distributions, and constraining the samples with
  `bijector`. Note that the distributions must have event shapes corresponding
  to the *pretransformed* surrogate posterior -- that is, if `bijector` contains
  a shape-changing bijector, then the corresponding base distribution event
  shape is the inverse event shape of the bijector applied to the desired
  surrogate posterior shape. The surrogate posterior is constructed as follows:

  1. Flatten the base distribution event shapes to vectors, and pack the base
     distributions into a `tfd.JointDistribution`.
  2. Apply a trainable blockwise LinearOperator bijector to the joint base
     distribution.
  3. Apply the constraining bijectors and return the resulting trainable
     `tfd.TransformedDistribution` instance.

  Args:
    base_distribution: `tfd.Distribution` instance (typically a
      `tfd.JointDistribution`), or a nested structure of `tfd.Distribution`
      instances.
    operators: Either a string or a list/tuple containing `LinearOperator`
      subclasses, `LinearOperator` instances, or callables returning
      `LinearOperator` instances. Supported string values are "diag" (to create
      a mean-field surrogate posterior) and "tril" (to create a full-covariance
      surrogate posterior). A list/tuple may be passed to induce other
      posterior covariance structures. If the list is flat, a
      `tf.linalg.LinearOperatorBlockDiag` instance will be created and applied
      to the base distribution. Otherwise the list must be singly-nested and
      have a first element of length 1, second element of length 2, etc.; the
      elements of the outer list are interpreted as rows of a lower-triangular
      block structure, and a `tf.linalg.LinearOperatorBlockLowerTriangular`
      instance is created. For complete documentation and examples, see
      `tfp.experimental.vi.util.build_trainable_linear_operator_block`, which
      receives the `operators` arg if it is list-like.
      Default value: `"diag"`.
    bijector: `tfb.Bijector` instance, or nested structure of `tfb.Bijector`
      instances, that maps (nested) values in R^n to the support of the
      posterior. (This can be the `experimental_default_event_space_bijector` of
      the distribution over the prior latent variables.)
      Default value: `None` (i.e., the posterior is over R^n).
    initial_unconstrained_loc_fn: Optional Python `callable` with signature
      `initial_loc = initial_unconstrained_loc_fn(shape, dtype, seed)` used to
      sample real-valued initializations for the unconstrained location of
      each variable.
      Default value: `functools.partial(tf.random.stateless_uniform,
        minval=-2., maxval=2., dtype=tf.float32)`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e.,
      'build_affine_surrogate_posterior_from_base_distribution').
  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.
  Raises:
    NotImplementedError: Base distributions with mixed dtypes are not supported.

  #### Examples
  ```python
  tfd = tfp.distributions
  tfb = tfp.bijectors

  # Fit a multivariate Normal surrogate posterior on the Eight Schools model
  # [1].

  treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
  treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]

  def model_fn():
    avg_effect = yield tfd.Normal(loc=0., scale=10., name='avg_effect')
    log_stddev = yield tfd.Normal(loc=5., scale=1., name='log_stddev')
    school_effects = yield tfd.Sample(
        tfd.Normal(loc=avg_effect, scale=tf.exp(log_stddev)),
        sample_shape=[8],
        name='school_effects')
    treatment_effects = yield tfd.Independent(
        tfd.Normal(loc=school_effects, scale=treatment_stddevs),
        reinterpreted_batch_ndims=1,
        name='treatment_effects')
  model = tfd.JointDistributionCoroutineAutoBatched(model_fn)

  # Pin the observed values in the model.
  target_model = model.experimental_pin(treatment_effects=treatment_effects)

  # Define a lower triangular structure of `LinearOperator` subclasses that
  # models full covariance among latent variables except for the 8 dimensions
  # of `school_effect`, which are modeled as independent (using
  # `LinearOperatorDiag`).
  operators = [
    [tf.linalg.LinearOperatorLowerTriangular],
    [tf.linalg.LinearOperatorFullMatrix, tf.linalg.LinearOperatorLowerTriangular],
    [tf.linalg.LinearOperatorFullMatrix, tf.linalg.LinearOperatorFullMatrix,
     tf.linalg.LinearOperatorDiag]]


  # Constrain the posterior values to the support of the prior.
  bijector = target_model.experimental_default_event_space_bijector()
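
  # NOTE (assumed construction, not in the original source): `base_distribution`
  # must be defined before it is used below. One simple choice is a standard
  # Normal over each *unconstrained* latent part of the pinned model.
  unconstrained_shapes = bijector.inverse_event_shape_tensor(
      target_model.event_shape_tensor())
  base_distribution = tf.nest.map_structure(
      lambda s: tfd.Sample(tfd.Normal(0., 1.), s), unconstrained_shapes)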

  # Build a full-covariance surrogate posterior.
  surrogate_posterior = (
    tfp.experimental.vi.build_affine_surrogate_posterior_from_base_distribution(
        base_distribution=base_distribution,
        operators=operators,
        bijector=bijector))

  # Fit the model.
  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```

  #### References

  [1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and
      Donald Rubin. Bayesian Data Analysis, Third Edition.
      Chapman and Hall/CRC, 2013.

  """
    with tf.name_scope(name
                       or 'affine_surrogate_posterior_from_base_distribution'):

        if nest.is_nested(base_distribution):
            base_distribution = (joint_distribution_util.
                                 independent_joint_distribution_from_structure(
                                     base_distribution,
                                     validate_args=validate_args))

        if nest.is_nested(bijector):
            bijector = joint_map.JointMap(nest.map_structure(
                lambda b: identity.Identity() if b is None else b, bijector),
                                          validate_args=validate_args)

        batch_shape = base_distribution.batch_shape_tensor()
        if tf.nest.is_nested(
                batch_shape):  # Base is a classic JointDistribution.
            batch_shape = functools.reduce(ps.broadcast_shape,
                                           tf.nest.flatten(batch_shape))
        event_shape = base_distribution.event_shape_tensor()
        flat_event_size = nest.flatten(
            nest.map_structure(ps.reduce_prod, event_shape))

        base_dtypes = set([
            dtype_util.base_dtype(d)
            for d in nest.flatten(base_distribution.dtype)
        ])
        if len(base_dtypes) > 1:
            raise NotImplementedError(
                'Base distributions with mixed dtype are not supported. Saw '
                'components of dtype {}'.format(base_dtypes))
        base_dtype = list(base_dtypes)[0]

        num_components = len(flat_event_size)
        if operators == 'diag':
            operators = [tf.linalg.LinearOperatorDiag] * num_components
        elif operators == 'tril':
            operators = [[tf.linalg.LinearOperatorFullMatrix] * i +
                         [tf.linalg.LinearOperatorLowerTriangular]
                         for i in range(num_components)]
        elif isinstance(operators, str):
            raise ValueError(
                'Unrecognized operator type {}. Valid operators are "diag", "tril", '
                'or a structure that can be passed to '
                '`tfp.experimental.vi.util.build_trainable_linear_operator_block` as '
                'the `operators` arg.'.format(operators))

        if nest.is_nested(operators):
            operators = yield from trainable_linear_operators._trainable_linear_operator_block(  # pylint: disable=protected-access
                operators,
                block_dims=flat_event_size,
                dtype=base_dtype,
                batch_shape=batch_shape)

        linop_bijector = (
            scale_matvec_linear_operator.ScaleMatvecLinearOperatorBlock(
                scale=operators, validate_args=validate_args))

        def generate_shift_bijector(s):
            x = yield trainable_state_util.Parameter(
                functools.partial(initial_unconstrained_loc_fn,
                                  ps.concat([batch_shape, [s]], axis=0),
                                  dtype=base_dtype))
            return shift.Shift(x)

        loc_bijectors = yield from nest_util.map_structure_coroutine(
            generate_shift_bijector, flat_event_size)
        loc_bijector = joint_map.JointMap(loc_bijectors,
                                          validate_args=validate_args)

        unflatten_and_reshape = chain.Chain([
            joint_map.JointMap(nest.map_structure(reshape.Reshape,
                                                  event_shape),
                               validate_args=validate_args),
            restructure.Restructure(
                nest.pack_sequence_as(event_shape, range(num_components)))
        ],
                                            validate_args=validate_args)

        bijectors = [] if bijector is None else [bijector]
        bijectors.extend([
            unflatten_and_reshape,
            loc_bijector,  # Allow the mean of the standard dist to shift from 0.
            linop_bijector
        ])  # Apply LinOp to scale the standard dist.
        bijector = chain.Chain(bijectors, validate_args=validate_args)

        flat_base_distribution = invert.Invert(unflatten_and_reshape)(
            base_distribution)

        return transformed_distribution.TransformedDistribution(
            flat_base_distribution,
            bijector=bijector,
            validate_args=validate_args)
Example #11
def _affine_surrogate_posterior(event_shape,
                                operators='diag',
                                bijector=None,
                                base_distribution=normal.Normal,
                                dtype=tf.float32,
                                batch_shape=(),
                                validate_args=False,
                                name=None):
    """Builds a joint variational posterior with a given `event_shape`.

  This function builds a surrogate posterior by applying a trainable
  transformation to a standard base distribution and constraining the samples
  with `bijector`. The surrogate posterior has event shape equal to
  the input `event_shape`.

  This function is a convenience wrapper around
  `build_affine_surrogate_posterior_from_base_distribution` that allows the
  user to pass in the desired posterior `event_shape` instead of
  pre-constructed base distributions (at the expense of full control over the
  base distribution types and parameterizations).

  Args:
    event_shape: (Nested) event shape of the posterior.
    operators: Either a string or a list/tuple containing `LinearOperator`
      subclasses, `LinearOperator` instances, or callables returning
      `LinearOperator` instances. Supported string values are "diag" (to create
      a mean-field surrogate posterior) and "tril" (to create a full-covariance
      surrogate posterior). A list/tuple may be passed to induce other
      posterior covariance structures. If the list is flat, a
      `tf.linalg.LinearOperatorBlockDiag` instance will be created and applied
      to the base distribution. Otherwise the list must be singly-nested and
      have a first element of length 1, second element of length 2, etc.; the
      elements of the outer list are interpreted as rows of a lower-triangular
      block structure, and a `tf.linalg.LinearOperatorBlockLowerTriangular`
      instance is created. For complete documentation and examples, see
      `tfp.experimental.vi.util.build_trainable_linear_operator_block`, which
      receives the `operators` arg if it is list-like.
      Default value: `"diag"`.
    bijector: `tfb.Bijector` instance, or nested structure of `tfb.Bijector`
      instances, that maps (nested) values in R^n to the support of the
      posterior. (This can be the `experimental_default_event_space_bijector` of
      the distribution over the prior latent variables.)
      Default value: `None` (i.e., the posterior is over R^n).
    base_distribution: A `tfd.Distribution` subclass parameterized by `loc` and
      `scale`. The base distribution of the transformed surrogate has `loc=0.`
      and `scale=1.`.
      Default value: `tfd.Normal`.
    dtype: The `dtype` of the surrogate posterior.
      Default value: `tf.float32`.
    batch_shape: Batch shape (Python tuple, list, or int) of the surrogate
      posterior, to enable parallel optimization from multiple initializations.
      Default value: `()`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e., 'build_affine_surrogate_posterior').
  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.

  #### Examples

  ```python
  tfd = tfp.distributions
  tfb = tfp.bijectors

  # Define a joint probabilistic model.
  Root = tfd.JointDistributionCoroutine.Root
  def model_fn():
    concentration = yield Root(tfd.Exponential(1.))
    rate = yield Root(tfd.Exponential(1.))
    y = yield tfd.Sample(
        tfd.Gamma(concentration=concentration, rate=rate),
        sample_shape=4)
  model = tfd.JointDistributionCoroutine(model_fn)

  # Assume the `y` are observed, such that the posterior is a joint distribution
  # over `concentration` and `rate`. The posterior event shape is then equal to
  # the first two components of the model's event shape.
  posterior_event_shape = model.event_shape_tensor()[:-1]

  # Constrain the posterior values to be positive using the `Exp` bijector.
  bijector = [tfb.Exp(), tfb.Exp()]

  # Build a full-covariance surrogate posterior.
  surrogate_posterior = (
    tfp.experimental.vi.build_affine_surrogate_posterior(
        event_shape=posterior_event_shape,
        operators='tril',
        bijector=bijector))

  # For an example defining `'operators'` as a list to express an alternative
  # covariance structure, see
  # `build_affine_surrogate_posterior_from_base_distribution`.

  # Fit the model.
  y = [0.2, 0.5, 0.3, 0.7]
  target_model = model.experimental_pin(y=y)
  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```
  """
    with tf.name_scope(name or 'build_affine_surrogate_posterior'):

        event_shape = nest.map_structure_up_to(
            _get_event_shape_shallow_structure(event_shape),
            lambda s: tf.convert_to_tensor(s, dtype=tf.int32), event_shape)

        if nest.is_nested(bijector):
            bijector = joint_map.JointMap(nest.map_structure(
                lambda b: identity.Identity() if b is None else b, bijector),
                                          validate_args=validate_args)

        if bijector is None:
            unconstrained_event_shape = event_shape
        else:
            unconstrained_event_shape = (
                bijector.inverse_event_shape_tensor(event_shape))

        standard_base_distribution = nest.map_structure(
            lambda s: base_distribution(loc=tf.zeros([], dtype=dtype),
                                        scale=1.), unconstrained_event_shape)
        standard_base_distribution = nest.map_structure(
            lambda d, s: (  # pylint: disable=g-long-lambda
                sample.Sample(d, sample_shape=s, validate_args=validate_args)
                if distribution_util.shape_may_be_nontrivial(s) else d),
            standard_base_distribution,
            unconstrained_event_shape)
        if distribution_util.shape_may_be_nontrivial(batch_shape):
            standard_base_distribution = nest.map_structure(
                lambda d: batch_broadcast.BatchBroadcast(  # pylint: disable=g-long-lambda
                    d,
                    to_shape=batch_shape,
                    validate_args=validate_args),
                standard_base_distribution)

        surrogate_posterior = yield from _affine_surrogate_posterior_from_base_distribution(
            standard_base_distribution,
            operators=operators,
            bijector=bijector,
            validate_args=validate_args)
        return surrogate_posterior
Example #12
def _asvi_convex_update_for_base_distribution(dist,
                                              mean_field,
                                              initial_prior_weight,
                                              sample_shape=None,
                                              variables=None,
                                              seed=None):
    """Creates a trainable surrogate for a (non-meta, non-joint) distribution."""
    if variables is None:
        variables = {}

    posterior_batch_shape = dist.batch_shape_tensor()
    if sample_shape is not None:
        posterior_batch_shape = ps.concat([
            posterior_batch_shape,
            distribution_util.expand_to_vector(sample_shape)
        ],
                                          axis=0)

    # Create variables backing each parameter, if needed.
    all_parameter_properties = dist.parameter_properties(dtype=dist.dtype)
    for param, prior_value in dist.parameters.items():
        if (param in variables
                or param in (_NON_STATISTICAL_PARAMS + _NON_TRAINABLE_PARAMS)
                or prior_value is None):
            continue

        param_properties = all_parameter_properties[param]
        try:
            bijector = param_properties.default_constraining_bijector_fn()
        except NotImplementedError:
            bijector = identity.Identity()

        param_shape = ps.concat([
            posterior_batch_shape,
            ps.shape(prior_value)[ps.rank(prior_value) -
                                  param_properties.event_ndims:]
        ],
                                axis=0)

        prior_weight = (
            None if mean_field  # pylint: disable=g-long-ternary
            else tfp_util.TransformedVariable(initial_value=tf.fill(
                dims=param_shape,
                value=tf.cast(initial_prior_weight,
                              tf.convert_to_tensor(prior_value).dtype)),
                                              bijector=sigmoid.Sigmoid(),
                                              name='prior_weight/{}/{}'.format(
                                                  _get_name(dist), param)))

        # Initialize the mean-field parameter as a (constrained) standard
        # normal sample.
        seed, param_seed = samplers.split_seed(seed)
        variables[param] = ASVIParameters(
            prior_weight=prior_weight,
            mean_field_parameter=tfp_util.TransformedVariable(
                initial_value=bijector.forward(
                    samplers.normal(
                        shape=bijector.inverse_event_shape(param_shape),
                        seed=param_seed)),
                bijector=bijector,
                name='mean_field_parameter/{}/{}'.format(
                    _get_name(dist), param)))

    temp_params_dict = {'name': _get_name(dist)}
    for param, prior_value in dist.parameters.items():
        if param in (_NON_STATISTICAL_PARAMS +
                     _NON_TRAINABLE_PARAMS) or prior_value is None:
            temp_params_dict[param] = prior_value
        else:
            if mean_field:
                temp_params_dict[param] = variables[param].mean_field_parameter
            else:
                temp_params_dict[param] = (
                    variables[param].prior_weight * prior_value +
                    ((1. - variables[param].prior_weight) *
                     variables[param].mean_field_parameter))
    return type(dist)(**temp_params_dict), variables
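
A minimal numeric sketch (an illustration under assumed values, not the library
API) of the convex update applied above when `mean_field=False`: each surrogate
parameter interpolates between the prior's parameter value and a free mean-field
parameter, with a trainable weight constrained to [0, 1].

import tensorflow as tf

prior_loc = tf.constant(2.0)        # Parameter value taken from the prior.
prior_weight = tf.constant(0.7)     # Trainable weight, Sigmoid-constrained to [0, 1].
mean_field_loc = tf.constant(-0.3)  # Free variational (mean-field) parameter.

# Convex combination used for the surrogate's parameter.
surrogate_loc = prior_weight * prior_loc + (1. - prior_weight) * mean_field_loc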
Example #13
  def _default_event_space_bijector(self):
    return identity_bijector.Identity(validate_args=self.validate_args)
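
For context, a hedged illustration (an assumption, not taken from the source):
distributions supported on the whole real line typically report `Identity` as
their default event space bijector, so mapping from unconstrained space is a
no-op.

import tensorflow_probability as tfp

tfd = tfp.distributions

bij = tfd.Normal(loc=0., scale=1.).experimental_default_event_space_bijector()
# For an unconstrained distribution such as Normal, `bij` is an Identity
# bijector: `bij.forward(x)` returns `x` unchanged.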
Example #14
  def __init__(self,
               output_shape=(32, 32, 3),
               num_glow_blocks=3,
               num_steps_per_block=32,
               coupling_bijector_fn=None,
               exit_bijector_fn=None,
               grab_after_block=None,
               use_actnorm=True,
               seed=None,
               validate_args=False,
               name='glow'):
    """Creates the Glow bijector.

    Args:
      output_shape: A list of integers specifying the event shape of the output
        of the bijector's forward pass (the image). Specified as [H, W, C].
        Default value: (32, 32, 3)
      num_glow_blocks: An integer, specifying how many downsampling levels to
        include in the model. This must divide equally into both H and W,
        otherwise the bijector would not be invertible.
        Default value: 3
      num_steps_per_block: An integer specifying how many Affine Coupling and
        1x1 convolution layers to include at each level of the spatial
        hierarchy.
        Default value: 32 (i.e., the value used in the original Glow paper).
      coupling_bijector_fn: A function which takes the argument `input_shape`
        and returns a callable neural network (e.g. a keras.Sequential). The
        network should either return a tensor with the same event shape as
        `input_shape` (this will employ additive coupling), a tensor with the
        same height and width as `input_shape` but twice the number of channels
        (this will employ affine coupling), or a bijector which takes in a
        tensor with event shape `input_shape`, and returns a tensor with shape
        `input_shape`.
      exit_bijector_fn: Similar to coupling_bijector_fn, exit_bijector_fn is
        a function which takes the argument `input_shape` and `output_chan`
        and returns a callable neural network. The neural network it returns
        should take a tensor of shape `input_shape` as the input, and return
        one of three options: A tensor with `output_chan` channels, a tensor
        with `2 * output_chan` channels, or a bijector. Additional details can
        be found in the documentation for ExitBijector.
      grab_after_block: A tuple of floats, specifying what fraction of the
        remaining channels to remove following each glow block. Glow will take
        the integer floor of this number multiplied by the remaining number of
        channels. The default is half at each level of the spatial hierarchy.
        Default value: None (this will take out half of the channels after each
          block except the last).
      use_actnorm: A bool deciding whether or not to use actnorm. Data-dependent
        initialization is used to initialize this layer.
        Default value: `True`
      seed: A seed to control randomness in the 1x1 convolution initialization.
        Default value: `None` (i.e., non-reproducible sampling).
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
        Default value: `False`
      name: Python `str`, name given to ops managed by this object.
        Default value: `'glow'`.
    """
    # Make sure that the input shape is fully defined.
    if not tensorshape_util.is_fully_defined(output_shape):
      raise ValueError('Shape must be fully defined.')
    if tensorshape_util.rank(output_shape) != 3:
      raise ValueError('Shape ndims must be 3 for images. Saw ndims: '
                       '{}.'.format(tensorshape_util.rank(output_shape)))

    num_glow_blocks_ = tf.get_static_value(num_glow_blocks)
    if (num_glow_blocks_ is None or
        int(num_glow_blocks_) != num_glow_blocks_ or
        num_glow_blocks_ < 1):
      raise ValueError('Argument `num_glow_blocks` must be a statically known '
                       'positive `int` (saw: {}).'.format(num_glow_blocks))
    num_glow_blocks = int(num_glow_blocks_)

    output_shape = tensorshape_util.as_list(output_shape)
    h, w, c = output_shape
    n = num_glow_blocks
    nsteps = num_steps_per_block

    # Default Glow: Half of the channels are split off after each block,
    # and after the final block, no channels are split off.
    if grab_after_block is None:
      grab_after_block = tuple([0.5] * (n - 1) + [0.])

    # Things we know must be true: h and w must each be evenly divisible by 2,
    # n times. Otherwise, the squeeze bijector will not work.
    if w % 2**n != 0:
      raise ValueError('Width must be divisible by 2 at least n times. '
                       'Saw: {} % {} != 0'.format(w, 2**n))
    if h % 2**n != 0:
      raise ValueError('Height must be divisible by 2 at least n times. '
                       'Saw: {} % {} != 0'.format(h, 2**n))
    if h // 2**n < 1:
      raise ValueError('num_glow_blocks ({0}) is too large. The image height '
                       '({1}) must be divisible by 2 no more than {2} '
                       'times.'.format(num_glow_blocks, h,
                                       int(np.log(h) / np.log(2.))))
    if w // 2**n < 1:
      raise ValueError('num_glow_blocks ({0}) is too large. The image width '
                       '({1}) must be divisible by 2 no more than {2} '
                       'times.'.format(num_glow_blocks, w,
                                       int(np.log(w) / np.log(2.))))

    # Other things we want to be true:
    # - The number of splits must equal the number of glow blocks.
    if len(grab_after_block) != num_glow_blocks:
      raise ValueError('Length of grab_after_block ({0}) must match the number '
                       'of blocks ({1}).'.format(len(grab_after_block),
                                                 num_glow_blocks))

    self._blockwise_splits = self._get_blockwise_splits(output_shape,
                                                        grab_after_block[::-1])

    # Now check on the values of the blockwise splits.
    if any([bs[0] < 1 for bs in self._blockwise_splits]):
      first_offender = [
          bs[0] < 1 for bs in self._blockwise_splits].index(True)
      raise ValueError('At least one exit takes out all of your channels, and '
                       'therefore leaves no inputs to later blocks. Try '
                       'setting grab_after_block to a lower value at index '
                       '{}.'.format(first_offender))

    if any(np.isclose(gab, 0) for gab in grab_after_block):
      # Special case: if specifically exiting no channels, then the exit is
      # just an identity bijector.
      pass
    elif any([bs[1] < 1 for bs in self._blockwise_splits]):
      first_offender = [
          bs[1] < 1 for bs in self._blockwise_splits].index(True)
      raise ValueError('At least one of your layers has < 1 output channels. '
                       'This means you set grab_after_block too small. '
                       'Try setting grab_after_block to a larger value at '
                       'index {}.'.format(first_offender))

    # Let's start to build our bijector. We assume that the distribution is
    # 1-dimensional. First, let's reshape it to an image.
    glow_chain = [
        reshape.Reshape(
            event_shape_out=[h // 2**n, w // 2**n, c * 4**n],
            event_shape_in=[h * w * c])
    ]

    seedstream = SeedStream(seed=seed, salt='random_beta')

    for i in range(n):

      # This is the shape of the current tensor
      current_shape = (h // 2**n * 2**i, w // 2**n * 2**i, c * 4**(i + 1))

      # This is the shape of the input to both the glow block and exit bijector.
      this_nchan = sum(self._blockwise_splits[i][0:2])
      this_input_shape = (h // 2**n * 2**i, w // 2**n * 2**i, this_nchan)

      glow_chain.append(invert.Invert(ExitBijector(current_shape,
                                                   self._blockwise_splits[i],
                                                   exit_bijector_fn)))

      glow_block = GlowBlock(input_shape=this_input_shape,
                             num_steps=nsteps,
                             coupling_bijector_fn=coupling_bijector_fn,
                             use_actnorm=use_actnorm,
                             seedstream=seedstream)

      if self._blockwise_splits[i][2] == 0:
        # All channels are passed to the RealNVP
        glow_chain.append(glow_block)
      else:
        # Some channels are passed around the block.
        # This is done with the Blockwise bijector.
        glow_chain.append(
            blockwise.Blockwise(
                [glow_block, identity.Identity()],
                [sum(self._blockwise_splits[i][0:2]),
                 self._blockwise_splits[i][2]]))

      # Finally, let's expand the channels into spatial features.
      glow_chain.append(
          Expand(input_shape=[
              h // 2**n * 2**i,
              w // 2**n * 2**i,
              c * 4**n // 4**i,
          ]))

    glow_chain = glow_chain[::-1]
    # To finish off, we initialize the bijector with the chain we've built
    # This way, the rest of the model attributes are taken care of for us.
    super(Glow, self).__init__(
        bijectors=glow_chain, validate_args=validate_args, name=name)
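For context, a hedged usage sketch of the bijector built above: constructing a Glow bijector and pushing a flat standard-normal base distribution through it to obtain a distribution over images. This assumes the library's bundled `GlowDefaultNetwork` and `GlowDefaultExitNetwork` callables are used as the coupling and exit networks; any networks satisfying the contracts described in the docstring would work equally well.

import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

glow = tfb.Glow(output_shape=(32, 32, 3),
                num_glow_blocks=3,
                num_steps_per_block=32,
                coupling_bijector_fn=tfb.GlowDefaultNetwork,
                exit_bijector_fn=tfb.GlowDefaultExitNetwork)

# The base distribution lives in the flattened (1-D) latent space that the
# Reshape at the end of the chain expects.
z_shape = glow.inverse_event_shape([32, 32, 3])
pz = tfd.Sample(tfd.Normal(0., 1.), sample_shape=z_shape)
px = tfd.TransformedDistribution(distribution=pz, bijector=glow)
images = px.sample(4)  # shape [4, 32, 32, 3]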
Exemplo n.º 15
0
 def _default_event_space_bijector(self):
     return identity_bijector.Identity()
Exemplo n.º 16
0
def _default_constraining_bijector_fn():
    from tensorflow_probability.python.bijectors import identity as identity_bijector  # pylint:disable=g-import-not-at-top
    return identity_bijector.Identity()
Exemplo n.º 17
0
def _make_asvi_trainable_variables(prior,
                                   mean_field=False,
                                   initial_prior_weight=0.5):
  """Generates parameter dictionaries given a prior distribution and list."""
  with tf.name_scope('make_asvi_trainable_variables'):
    param_dicts = []
    prior_dists = prior._get_single_sample_distributions()  # pylint: disable=protected-access
    for dist in prior_dists:
      original_dist = dist.distribution if isinstance(dist, Root) else dist

      substituted_dist = _as_trainable_family(original_dist)

      # Grab the base distribution if it exists
      try:
        actual_dist = substituted_dist.distribution
      except AttributeError:
        actual_dist = substituted_dist

      new_params_dict = {}

      #  Build trainable ASVI representation for each distribution's parameters.
      parameter_properties = actual_dist.parameter_properties(
          dtype=actual_dist.dtype)

      if isinstance(original_dist, sample.Sample):
        posterior_batch_shape = ps.concat([
            actual_dist.batch_shape_tensor(),
            distribution_util.expand_to_vector(original_dist.sample_shape)
        ], axis=0)
      else:
        posterior_batch_shape = actual_dist.batch_shape_tensor()

      for param, value in actual_dist.parameters.items():

        if param in (_NON_STATISTICAL_PARAMS +
                     _NON_TRAINABLE_PARAMS) or value is None:
          continue

        actual_event_shape = parameter_properties[param].shape_fn(
            actual_dist.event_shape_tensor())
        try:
          bijector = parameter_properties[
              param].default_constraining_bijector_fn()
        except NotImplementedError:
          bijector = identity.Identity()

        if mean_field:
          prior_weight = None
        else:
          unconstrained_ones = tf.ones(
              shape=ps.concat([
                  posterior_batch_shape,
                  bijector.inverse_event_shape_tensor(
                      actual_event_shape)
              ], axis=0),
              dtype=tf.convert_to_tensor(value).dtype)

          prior_weight = tfp_util.TransformedVariable(
              initial_prior_weight * unconstrained_ones,
              bijector=sigmoid.Sigmoid(),
              name='prior_weight/{}/{}'.format(dist.name, param))

        # If the prior distribution was a tfd.Sample wrapping a base
        # distribution, we want to give every single sample in the prior its
        # own lambda and alpha value (rather than having a single lambda and
        # alpha).
        if isinstance(original_dist, sample.Sample):
          value = tf.reshape(
              value,
              ps.concat([
                  actual_dist.batch_shape_tensor(),
                  ps.ones(ps.rank_from_shape(original_dist.sample_shape)),
                  actual_event_shape
              ],
                        axis=0))
          value = tf.broadcast_to(
              value,
              ps.concat([posterior_batch_shape, actual_event_shape], axis=0))
        new_params_dict[param] = ASVIParameters(
            prior_weight=prior_weight,
            mean_field_parameter=tfp_util.TransformedVariable(
                value,
                bijector=bijector,
                name='mean_field_parameter/{}/{}'.format(dist.name, param)))

      param_dicts.append(new_params_dict)
  return param_dicts
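This helper is normally reached through ASVI's public entry point rather than called directly. A hedged sketch, assuming `tfp.experimental.vi.build_asvi_surrogate_posterior` is the public wrapper around logic like the above:

import tensorflow_probability as tfp
tfd = tfp.distributions

prior = tfd.JointDistributionNamedAutoBatched({
    'scale': tfd.HalfNormal(1.),
    'loc': tfd.Normal(0., 1.),
})
surrogate_posterior = tfp.experimental.vi.build_asvi_surrogate_posterior(prior)
# surrogate_posterior.trainable_variables now holds the per-parameter
# prior_weight and mean_field_parameter variables built as above.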
Exemplo n.º 18
0
 def _default_event_space_bijector(self):
     # TODO(b/145620027) Finalize choice of bijector. Consider switching to
     # Chain([Softplus(), Log()]) to lighten the doubly-exponential right tail.
     return identity_bijector.Identity(validate_args=self.validate_args)
Exemplo n.º 19
0
def _asvi_convex_update_for_base_distribution(dist,
                                              mean_field=False,
                                              initial_prior_weight=0.5,
                                              sample_shape=None):
    """Creates a trainable surrogate for a (non-meta, non-joint) distribution."""
    posterior_batch_shape = dist.batch_shape_tensor()
    if sample_shape is not None:
        posterior_batch_shape = ps.concat([
            posterior_batch_shape,
            distribution_util.expand_to_vector(sample_shape)
        ],
                                          axis=0)

    temp_params_dict = {'name': _get_name(dist)}
    all_parameter_properties = dist.parameter_properties(dtype=dist.dtype)
    for param, prior_value in dist.parameters.items():
        if (param in (_NON_STATISTICAL_PARAMS + _NON_TRAINABLE_PARAMS)
                or prior_value is None):
            temp_params_dict[param] = prior_value
            continue

        param_properties = all_parameter_properties[param]
        try:
            bijector = param_properties.default_constraining_bijector_fn()
        except NotImplementedError:
            bijector = identity.Identity()

        param_shape = ps.concat([
            posterior_batch_shape,
            ps.shape(prior_value)[ps.rank(prior_value) -
                                  param_properties.event_ndims:]
        ],
                                axis=0)

        # Initialize the mean-field parameter as a (constrained) standard
        # normal sample.
        # pylint: disable=cell-var-from-loop
        # Safe because the state utils guarantee to either call `init_fn`
        # immediately upon yielding, or not at all.
        mean_field_parameter = yield trainable_state_util.Parameter(
            init_fn=lambda seed: (  # pylint: disable=g-long-lambda
                bijector.forward(
                    samplers.normal(shape=bijector.inverse_event_shape(
                        param_shape),
                                    seed=seed))),
            name='mean_field_parameter_{}_{}'.format(_get_name(dist), param),
            constraining_bijector=bijector)
        if mean_field:
            temp_params_dict[param] = mean_field_parameter
        else:
            prior_weight = yield trainable_state_util.Parameter(
                init_fn=lambda: tf.fill(  # pylint: disable=g-long-lambda
                    dims=param_shape,
                    value=tf.cast(initial_prior_weight,
                                  tf.convert_to_tensor(prior_value).dtype)),
                name='prior_weight_{}_{}'.format(_get_name(dist), param),
                constraining_bijector=sigmoid.Sigmoid())
            temp_params_dict[param] = prior_weight * prior_value + (
                (1. - prior_weight) * mean_field_parameter)
    # pylint: enable=cell-var-from-loop

    return type(dist)(**temp_params_dict)
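The `init_fn` above initializes each constrained parameter by drawing a standard normal sample in the bijector's unconstrained space and pushing it forward through the constraining bijector. A standalone sketch of that pattern; the Softplus bijector here is just an illustrative stand-in for whatever `default_constraining_bijector_fn` returns:

import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

bijector = tfb.Softplus()                  # constrains reals to positives
param_shape = [3]
unconstrained_shape = bijector.inverse_event_shape_tensor(param_shape)

seed = tf.constant([0, 42], dtype=tf.int32)
unconstrained_init = tf.random.stateless_normal(unconstrained_shape, seed=seed)
constrained_init = bijector.forward(unconstrained_init)  # strictly positive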
Exemplo n.º 20
0
 def _default_event_space_bijector(self):
     # TODO(b/145620027) Finalize choice of bijector (consider one that
     # transforms away the heavy tails).
     return identity_bijector.Identity(validate_args=self.validate_args)
Exemplo n.º 21
0
 def _default_event_space_bijector(self):
     # TODO(b/145620027) Finalize choice of bijector.
     return identity_bijector.Identity(validate_args=self.validate_args)