Example #1
def normal_generator(shape):
  shape = ps.convert_to_shape_tensor(shape, dtype=np.int32)
  loc = yield trainable_state_util.Parameter(
      init_fn=functools.partial(samplers.normal, shape=shape),
      name='loc')
  bij = tfb.Softplus()
  scale = yield trainable_state_util.Parameter(
      init_fn=lambda seed: bij.forward(samplers.normal(shape, seed=seed)),
      constraining_bijector=bij,
      name='scale')
  return tfd.Normal(loc=loc, scale=scale, validate_args=True)
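All of the examples on this page assume the imports below (TFP-internal module paths). The consumption sketch that follows is hedged: it uses the stateless builder to turn the generator into an `(init_fn, apply_fn)` pair, the same pattern the Markov-chain example further down uses, and it assumes `init_fn` accepts a `seed` keyword, mirroring the `Parameter` seed conventions shown in `seed_generator` below.

# Imports assumed by the examples on this page (TFP-internal modules).
import functools
import inspect

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from tensorflow_probability.python.bijectors import fill_scale_tril
from tensorflow_probability.python.bijectors import identity
from tensorflow_probability.python.bijectors import shift
from tensorflow_probability.python.bijectors import sigmoid
from tensorflow_probability.python.distributions import markov_chain
from tensorflow_probability.python.distributions import transformed_distribution
from tensorflow_probability.python.internal import distribution_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.internal import trainable_state_util

tfb = tfp.bijectors
tfd = tfp.distributions

# Sketch: consume the generator statelessly. `init_fn` is assumed to take a
# `seed` kwarg; `apply_fn` maps raw (unconstrained) parameters to the object
# the generator returns.
build_normal = trainable_state_util.as_stateless_builder(normal_generator)
init_fn, apply_fn = build_normal(shape=[3])
raw_params = init_fn(seed=samplers.sanitize_seed(0))  # Unconstrained values.
normal_dist = apply_fn(raw_params)                    # ==> a `tfd.Normal`.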
Example #2
def _trainable_linear_operator_full_matrix(shape,
                                           scale_initializer=1e-2,
                                           dtype=None,
                                           name=None):
    """Build a trainable `LinearOperatorFullMatrix` instance.

  Args:
    shape: Shape of the `LinearOperator`, equal to `[b0, ..., bn, h, w]`, where
      `b0...bn` are batch dimensions, and `h` and `w` are the height and width
      of the matrix represented by the `LinearOperator`.
    scale_initializer: Variables are initialized with samples from
      `Normal(0, scale_initializer)`.
    dtype: `tf.dtype` of the `LinearOperator`.
    name: str, name for `tf.name_scope`.
  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.
  """
    with tf.name_scope(name or 'trainable_linear_operator_full_matrix'):
        if dtype is None:
            dtype = dtype_util.common_dtype([scale_initializer],
                                            dtype_hint=tf.float32)
        scale_initializer = tf.convert_to_tensor(scale_initializer, dtype)
        scale_matrix = yield trainable_state_util.Parameter(
            init_fn=functools.partial(samplers.normal,
                                      mean=0.,
                                      stddev=scale_initializer,
                                      shape=shape,
                                      dtype=dtype),
            name='scale_matrix')
        return tf.linalg.LinearOperatorFullMatrix(matrix=scale_matrix)
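A hedged usage sketch for the builder above; whether the wrapped `init_fn` takes its seed as a kwarg is an assumption, as in Example #1.

# Stateless consumption of the full-matrix builder (sketch). The stateful
# variant, `as_stateful_builder`, would instead create one `tf.Variable`
# per yielded `Parameter`.
build_op = trainable_state_util.as_stateless_builder(
    _trainable_linear_operator_full_matrix)
op_init_fn, op_apply_fn = build_op(shape=[2, 3, 3])  # Batch of two 3x3 matrices.
operator = op_apply_fn(op_init_fn(seed=samplers.sanitize_seed(1)))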
Example #3
def _asvi_surrogate_for_markov_chain(dist, build_nested_surrogate):
    """Builds a structured surrogate posterior for a Markov chain."""
    surrogate_prior = yield from build_nested_surrogate(
        dist.initial_state_prior)

    (transition_all_steps_init_fn,
     _) = trainable_state_util.as_stateless_builder(
         lambda: build_nested_surrogate(  # pylint: disable=g-long-lambda
             dist.transition_fn(
                 tf.range(dist.num_steps - 1),
                 dist.initial_state_prior.sample(
                     dist.num_steps - 1, seed=samplers.zeros_seed()))))()
    transition_params = yield trainable_state_util.Parameter(
        transition_all_steps_init_fn, name='markov_chain_transition_params')

    build_transition_one_step = trainable_state_util.as_stateless_builder(
        lambda step, state: build_nested_surrogate(  # pylint: disable=g-long-lambda
            dist.transition_fn(step, state)))

    def surrogate_transition_fn(step, state):
        _, one_step_apply_fn = build_transition_one_step(step, state)
        return one_step_apply_fn(
            tf.nest.map_structure(
                # Gather parameters for this specific step of the chain.
                lambda v: tf.gather(v, step, axis=0),
                transition_params))

    return markov_chain.MarkovChain(initial_state_prior=surrogate_prior,
                                    transition_fn=surrogate_transition_fn,
                                    num_steps=dist.num_steps,
                                    validate_args=dist.validate_args)
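The per-step `tf.gather` above selects one step's slice from parameters stored with a leading `num_steps - 1` axis; a two-line illustration:

# Each transition parameter carries a leading step axis, so gathering index
# `step` along axis 0 recovers that step's parameter values.
v = tf.reshape(tf.range(6.), [3, 2])  # 3 steps of a 2-dimensional parameter.
tf.gather(v, 1, axis=0)               # ==> [2., 3.] (the step-1 slice).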
Example #4
def yields_structured_parameter():
  dict_loc_scale = yield trainable_state_util.Parameter(
      init_fn=lambda: {'scale': tf.ones([2]), 'loc': tf.zeros([2])},
      name='dict_loc_scale',
      constraining_bijector=tfb.JointMap(
          {'scale': tfb.Softplus(), 'loc': tfb.Identity()}))
  return tfd.Normal(**dict_loc_scale)
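The structured `constraining_bijector` applies one bijector per dict key; for example:

# `JointMap` maps a dict of unconstrained values to a dict of constrained
# ones, key by key: softplus for 'scale', identity for 'loc'.
bij = tfb.JointMap({'scale': tfb.Softplus(), 'loc': tfb.Identity()})
bij.forward({'scale': tf.zeros([2]), 'loc': tf.zeros([2])})
# ==> {'scale': [0.693, 0.693], 'loc': [0., 0.]}  (softplus(0) = log 2)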
Example #5
def seed_generator():
  # Seed must be passed as kwarg.
  a = yield trainable_state_util.Parameter(
      functools.partial(samplers.normal, shape=[5]))
  # Seed must be passed positionally.
  b = yield trainable_state_util.Parameter(
      lambda my_seed: samplers.normal([], seed=my_seed))
  # Seed not accepted.
  c = yield trainable_state_util.Parameter(lambda: tf.zeros([3]))
  # Bare value in place of callable.
  d = yield trainable_state_util.Parameter(tf.ones([1, 1]))
  # Distribution sample method.
  e = yield trainable_state_util.Parameter(tfd.LogNormal([-1., 1.], 1.).sample)
  return tfd.JointDistributionSequential(
      [tfd.Deterministic(a), tfd.Deterministic(b), tfd.Deterministic(c),
       tfd.Deterministic(d), tfd.Deterministic(e)])
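One plausible way a consumer could honor these five seed conventions is to inspect the `init_fn` signature. The helper below is an illustration of the contract only, not `trainable_state_util`'s actual implementation.

def _call_init_fn(init_fn, seed):
  """Illustrative seed dispatch; not the library's real code."""
  if not callable(init_fn):
    return init_fn                 # Bare value in place of a callable.
  try:
    params = inspect.signature(init_fn).parameters
  except (TypeError, ValueError):
    return init_fn(seed=seed)      # Assume a `seed` kwarg if uninspectable.
  if 'seed' in params:
    return init_fn(seed=seed)      # Seed passed as kwarg.
  if len(params) == 1:
    return init_fn(seed)           # Single argument: seed passed positionally.
  return init_fn()                 # Seed not accepted.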
Example #6
def _trainable_linear_operator_tril(shape,
                                    scale_initializer=1e-2,
                                    diag_bijector=None,
                                    dtype=None,
                                    name=None):
    """Build a trainable `LinearOperatorLowerTriangular` instance.

  Args:
    shape: Shape of the `LinearOperator`, equal to `[b0, ..., bn, d]`, where
      `b0...bn` are batch dimensions and `d` is the length of the diagonal.
    scale_initializer: Variables are initialized with samples from
      `Normal(0, scale_initializer)`.
    diag_bijector: Bijector to apply to the diagonal of the operator.
    dtype: `tf.dtype` of the `LinearOperator`.
    name: str, name for `tf.name_scope`.
  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.
  """
    with tf.name_scope(name or 'trainable_linear_operator_tril'):
        if dtype is None:
            dtype = dtype_util.common_dtype([scale_initializer],
                                            dtype_hint=tf.float32)

        scale_initializer = tf.convert_to_tensor(scale_initializer,
                                                 dtype=dtype)
        diag_bijector = diag_bijector or _DefaultScaleDiagonal()
        batch_shape, dim = ps.split(shape, num_or_size_splits=[-1, 1])

        scale_tril_bijector = fill_scale_tril.FillScaleTriL(
            diag_bijector, diag_shift=tf.zeros([], dtype=dtype))
        scale_tril = yield trainable_state_util.Parameter(
            init_fn=lambda seed: scale_tril_bijector(  # pylint: disable=g-long-lambda
                samplers.normal(mean=0.,
                                stddev=scale_initializer,
                                shape=ps.concat(
                                    [batch_shape, dim * (dim + 1) // 2],
                                    axis=0),
                                seed=seed,
                                dtype=dtype)),
            name='scale_tril',
            constraining_bijector=scale_tril_bijector)
        return tf.linalg.LinearOperatorLowerTriangular(tril=scale_tril,
                                                       is_non_singular=True)
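The `dim * (dim + 1) // 2` count is the number of free entries in a `d x d` lower triangle; `FillScaleTriL` packs a flat vector into that triangle and constrains the diagonal. A sketch with the stock bijector defaults (the code above substitutes its own `_DefaultScaleDiagonal` for the diagonal):

# d = 2 gives 2 * 3 // 2 = 3 free entries. `FillScaleTriL` fills the
# triangle and applies a softplus to the diagonal.
b = tfb.FillScaleTriL(diag_shift=None)
b.forward(tf.constant([1., 2., 3.]))
# ==> [[softplus(3.), 0.], [2., softplus(1.)]] ~= [[3.05, 0.], [2., 1.31]]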
Example #7
def convert_operator(path, op):
    # Nested helper: `batch_shape`, `block_dims`, `dtype`, and
    # `_OPERATOR_COROUTINES` are assumed to be defined in the enclosing
    # block-operator builder's scope.
    if isinstance(op, tf.linalg.LinearOperator):
        return op
    if len(set(path)) == 1:  # for operators on the diagonal
        shape = ps.concat([batch_shape, [block_dims[path[0]]]], axis=0)
    else:
        shape = ps.concat(
            [batch_shape, [block_dims[path[0]], block_dims[path[1]]]],
            axis=0)
    if op in _OPERATOR_COROUTINES:
        operator = yield from _OPERATOR_COROUTINES[op](shape=shape,
                                                       dtype=dtype)
    else:  # Custom stateless constructor.
        init_fn, apply_fn = op(shape=shape, dtype=dtype)
        raw_params = yield trainable_state_util.Parameter(init_fn)
        operator = apply_fn(raw_params)
    return operator
Example #8
def generate_shift_bijector(s):
    # Nested helper: `initial_unconstrained_loc_fn`, `batch_shape`, and
    # `base_dtype` are assumed to be defined in the enclosing scope.
    x = yield trainable_state_util.Parameter(
        functools.partial(initial_unconstrained_loc_fn,
                          ps.concat([batch_shape, [s]], axis=0),
                          dtype=base_dtype))
    return shift.Shift(x)
Example #9
def _asvi_convex_update_for_base_distribution(dist,
                                              mean_field=False,
                                              initial_prior_weight=0.5,
                                              sample_shape=None):
    """Creates a trainable surrogate for a (non-meta, non-joint) distribution."""
    posterior_batch_shape = dist.batch_shape_tensor()
    if sample_shape is not None:
        posterior_batch_shape = ps.concat(
            [posterior_batch_shape,
             distribution_util.expand_to_vector(sample_shape)],
            axis=0)

    temp_params_dict = {'name': _get_name(dist)}
    all_parameter_properties = dist.parameter_properties(dtype=dist.dtype)
    for param, prior_value in dist.parameters.items():
        if (param in (_NON_STATISTICAL_PARAMS + _NON_TRAINABLE_PARAMS)
                or prior_value is None):
            temp_params_dict[param] = prior_value
            continue

        param_properties = all_parameter_properties[param]
        try:
            bijector = param_properties.default_constraining_bijector_fn()
        except NotImplementedError:
            bijector = identity.Identity()

        param_shape = ps.concat(
            [posterior_batch_shape,
             ps.shape(prior_value)[
                 ps.rank(prior_value) - param_properties.event_ndims:]],
            axis=0)

        # Initialize the mean-field parameter as a (constrained) standard
        # normal sample.
        # pylint: disable=cell-var-from-loop
        # Safe because the state utils guarantee to either call `init_fn`
        # immediately upon yielding, or not at all.
        mean_field_parameter = yield trainable_state_util.Parameter(
            init_fn=lambda seed: (  # pylint: disable=g-long-lambda
                bijector.forward(
                    samplers.normal(shape=bijector.inverse_event_shape(
                        param_shape),
                                    seed=seed))),
            name='mean_field_parameter_{}_{}'.format(_get_name(dist), param),
            constraining_bijector=bijector)
        if mean_field:
            temp_params_dict[param] = mean_field_parameter
        else:
            prior_weight = yield trainable_state_util.Parameter(
                init_fn=lambda: tf.fill(  # pylint: disable=g-long-lambda
                    dims=param_shape,
                    value=tf.cast(initial_prior_weight,
                                  tf.convert_to_tensor(prior_value).dtype)),
                name='prior_weight_{}_{}'.format(_get_name(dist), param),
                constraining_bijector=sigmoid.Sigmoid())
            temp_params_dict[param] = prior_weight * prior_value + (
                (1. - prior_weight) * mean_field_parameter)
    # pylint: enable=cell-var-from-loop

    return type(dist)(**temp_params_dict)
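The convex update in numbers: each surrogate parameter starts as a learned mixture of the prior parameter value and a free mean-field parameter.

prior_weight, prior_value, mean_field_parameter = 0.8, 2., -1.
prior_weight * prior_value + (1. - prior_weight) * mean_field_parameter
# ==> 0.8 * 2. + 0.2 * (-1.) = 1.4. Training the sigmoid-constrained
# weight toward 0 recovers a pure mean-field surrogate; toward 1, the prior.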
Example #10
def _asvi_surrogate_for_distribution(dist,
                                     base_distribution_surrogate_fn,
                                     prior_substitution_rules,
                                     surrogate_rules,
                                     sample_shape=None):
    """Recursively creates ASVI surrogates, and creates new variables if needed.

  Args:
    dist: a `tfd.Distribution` instance.
    base_distribution_surrogate_fn: Generator callable to build a surrogate
      posterior for a 'base' (non-meta and non-joint) distribution, invoked
      as `surrogate_posterior = yield from base_distribution_surrogate_fn(
      dist=dist, sample_shape=sample_shape)`.
    prior_substitution_rules: Iterable of substitution rules applied to the
      prior before constructing a surrogate. Each rule is a `(condition,
      substitution_fn)` tuple; these are checked in order and *all* applicable
      substitutions are made. The `condition` may be either a class or a
      callable returning a boolean (for example, `tfd.Normal` or, equivalently,
      `lambda dist: isinstance(dist, tfd.Normal)`). The `substitution_fn` should
      have signature `new_dist = substitution_fn(dist)`.
    surrogate_rules: Iterable of special-purpose rules to create surrogates
      for specific distribution types. Each rule is a `(condition,
      surrogate_fn)` tuple; these are checked in order and the first applicable
      `surrogate_fn` is used. The `condition` may be either a class or a
      callable returning a boolean (for example, `tfd.Normal` or, equivalently,
      `lambda dist: isinstance(dist, tfd.Normal)`). The `surrogate_fn` is
      called as `surrogate_fn(dist, build_nested_surrogate=...,
      sample_shape=sample_shape)` and may be either a generator yielding
      `Parameter`s or a stateless builder returning an `(init_fn, apply_fn)`
      pair.
    sample_shape: Optional `Tensor` shape of samples drawn from `dist` by
      `tfd.Sample` wrappers. If not `None`, the surrogate's event will include
      independent sample dimensions, i.e., it will have event shape
      `concat([sample_shape, dist.event_shape], axis=0)`.
      Default value: `None`.
  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.
  """
    dist_name = _get_name(dist)  # Attempt to preserve the original name.
    dist = _as_substituted_distribution(dist, prior_substitution_rules)
    # Apply the first surrogate rule that matches this distribution.
    surrogate_posterior = None
    for condition, surrogate_fn in surrogate_rules:
        if _satisfies_condition(dist, condition):
            # The surrogate fn may be a generator (internal interface) or a stateless
            # trainable builder returning `(init_fn, apply_fn)` (external interface).
            maybe_gen = surrogate_fn(
                dist,
                build_nested_surrogate=functools.partial(
                    _asvi_surrogate_for_distribution,
                    base_distribution_surrogate_fn=
                    base_distribution_surrogate_fn,
                    prior_substitution_rules=prior_substitution_rules,
                    surrogate_rules=surrogate_rules,
                    sample_shape=sample_shape),
                sample_shape=sample_shape)
            if inspect.isgenerator(maybe_gen):
                surrogate_posterior = yield from maybe_gen
            else:
                init_fn, apply_fn = maybe_gen
                params = yield trainable_state_util.Parameter(init_fn)
                surrogate_posterior = apply_fn(params)
            break
    if surrogate_posterior is None:
        if (hasattr(dist, 'distribution') and
                # Transformed dists not handled above are treated as base distributions.
                not isinstance(
                    dist, transformed_distribution.TransformedDistribution)):
            raise ValueError(
                'None of the provided substitution rules matched meta-distribution: '
                '`{}`.'.format(dist))
        else:
            surrogate_posterior = yield from base_distribution_surrogate_fn(
                dist=dist, sample_shape=sample_shape)
    return _set_name(surrogate_posterior, dist_name)
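A hedged sketch of a custom surrogate rule; `my_normal_surrogate` and its body are illustrative, not library code. The dispatcher above accepts either a generator (as here) or a stateless builder returning an `(init_fn, apply_fn)` pair.

def my_normal_surrogate(dist, build_nested_surrogate, sample_shape=None):
  # Illustrative generator-style surrogate for `tfd.Normal` priors only.
  del build_nested_surrogate, sample_shape  # Unused in this tiny example.
  loc = yield trainable_state_util.Parameter(
      functools.partial(samplers.normal, shape=dist.batch_shape_tensor()),
      name='surrogate_loc')
  return tfd.Normal(loc=loc, scale=1.)

surrogate_rules = ((tfd.Normal, my_normal_surrogate),)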
Example #11
def yields_non_callable_init_fn():
  yield trainable_state_util.Parameter(0.)
Example #12
def _make_trainable(cls,
                    initial_parameters=None,
                    batch_and_event_shape=(),
                    parameter_dtype=tf.float32,
                    **init_kwargs):
    """Constructs a distribution or bijector instance with trainable parameters.

  This is a convenience method that instantiates a class with trainable
  parameters. Parameters are randomly initialized, and transformed to enforce
  any domain constraints. This method assumes that the class exposes a
  `parameter_properties` method annotating its trainable parameters, and that
  the caller provides any additional (non-trainable) arguments required by the
  class.

  Args:
    cls: Python class that implements `cls.parameter_properties()`, e.g., a TFP
      distribution (`tfd.Normal`) or bijector (`tfb.Scale`).
    initial_parameters: a dictionary containing initial values for some or
      all of the parameters to `cls`, OR a Python `callable` with signature
      `value = parameter_init_fn(parameter_name, shape, dtype, seed,
      constraining_bijector)`. If a dictionary is provided, any parameters not
      specified will be initialized to a random value in their domain.
      Default value: `None` (equivalent to `{}`; all parameters are
        initialized randomly).
    batch_and_event_shape: Optional int `Tensor` giving the desired shape of
      samples (for distributions) or inputs (for bijectors), used to determine
      the shape of the trainable parameters.
      Default value: `()`.
    parameter_dtype: Optional float `dtype` for trainable variables.
    **init_kwargs: Additional keyword arguments passed to `cls.__init__()` to
      specify any non-trainable parameters. If a value is passed for
      an otherwise-trainable parameter---for example,
      `trainable(tfd.Normal, scale=1.)`---it will be taken as a fixed value and
      no variable will be constructed for that parameter.
  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.

  #### Example

  Suppose we want to fit a normal distribution to observed data. We could
  of course just examine the empirical mean and standard deviation of the data:

  ```python
  samples = [4.57, 6.37, 5.93, 7.98, 2.03, 3.59, 8.55, 3.45, 5.06, 6.44]
  model = tfd.Normal(
    loc=tf.reduce_mean(samples),  # ==> 5.40
    scale=tf.math.reduce_std(samples))  # ==> 1.95
  ```

  and this would be a very sensible approach. But that's boring, so instead,
  let's do way more work to get the same result. We'll build a trainable normal
  distribution, and explicitly optimize to find the maximum-likelihood estimate
  for the parameters given our data:

  ${minimize_example_code}

  In this trivial case, doing the explicit optimization has few advantages over
  the first approach in which we simply matched the empirical moments of the
  data. However, trainable distributions are useful more generally. For example,
  they can enable maximum-likelihood estimation of distributions when a
  moment-matching estimator is not available, and they can also serve as
  surrogate posteriors in variational inference.

  """

    # Attempt to set a name scope using the name of the object we're about to
    # create, so that the variables we create are easy to identify.
    name_arg = _get_arg_value(arg_name='name',
                              f=cls.__init__,
                              kwargs=init_kwargs)
    with tf.name_scope(((name_arg + '_') if name_arg else '') +
                       'trainable_variables'):

        # Canonicalize initial parameter specification as `parameter_init_fn`.
        if initial_parameters is None:
            initial_parameters = {}
        parameter_init_fn = initial_parameters
        if not callable(parameter_init_fn):
            parameter_init_fn = _default_parameter_init_fn(initial_parameters)

        # Create a trainable variable for each parameter.
        for parameter_name, properties in cls.parameter_properties(
                dtype=parameter_dtype).items():
            if parameter_name in init_kwargs:  # Prefer user-provided values.
                continue
            if not (properties.is_tensor and properties.is_preferred):
                continue
            if properties.specifies_shape or (properties.event_ndims is None):
                continue

            parameter_shape = properties.shape_fn(batch_and_event_shape)
            constraining_bijector = (
                properties.default_constraining_bijector_fn())

            init_kwargs[parameter_name] = yield trainable_state_util.Parameter(
                init_fn=functools.partial(
                    parameter_init_fn,
                    parameter_name,
                    shape=parameter_shape,
                    dtype=parameter_dtype,
                    constraining_bijector=constraining_bijector),
                constraining_bijector=constraining_bijector,
                name=parameter_name)

    return cls(**init_kwargs)
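A hedged sketch of fitting the trainable distribution by maximum likelihood, along the lines the `${minimize_example_code}` placeholder above suggests. Wrapping with the stateful builder and optimizing via `tfp.math.minimize` are assumptions about how this generator is consumed, not the docstring's elided example.

# Sketch: build a trainable Normal and fit it to the docstring's `samples`.
samples = [4.57, 6.37, 5.93, 7.98, 2.03, 3.59, 8.55, 3.45, 5.06, 6.44]
make_trainable = trainable_state_util.as_stateful_builder(_make_trainable)
model = make_trainable(tfd.Normal)
losses = tfp.math.minimize(
    lambda: -tf.reduce_mean(model.log_prob(samples)),
    optimizer=tf.optimizers.Adam(0.1),
    num_steps=200)
# After optimization, model.mean() and model.stddev() approximate the
# empirical moments (~5.40 and ~1.95).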