Code example #1
def _make_asvi_trainable_variables(prior,
                                   mean_field=False,
                                   initial_prior_weight=0.5):
  """Generates parameter dictionaries given a prior distribution and list."""
  with tf.name_scope('make_asvi_trainable_variables'):
    param_dicts = []
    prior_dists = prior._get_single_sample_distributions()  # pylint: disable=protected-access
    for dist in prior_dists:
      original_dist = dist.distribution if isinstance(dist, Root) else dist

      substituted_dist = _as_trainable_family(original_dist)

      # Grab the base distribution if one exists.
      try:
        actual_dist = substituted_dist.distribution
      except AttributeError:
        actual_dist = substituted_dist

      new_params_dict = {}

      # Build a trainable ASVI representation for each distribution's parameters.
      parameter_properties = actual_dist.parameter_properties(
          dtype=actual_dist.dtype)
      sample_shape = tf.concat(
          [dist.batch_shape_tensor(),
           dist.event_shape_tensor()], axis=0)
      for param, value in actual_dist.parameters.items():
        if param in (_NON_STATISTICAL_PARAMS +
                     _NON_TRAINABLE_PARAMS) or value is None:
          continue
        try:
          bijector = parameter_properties[
              param].default_constraining_bijector_fn()
        except NotImplementedError:
          bijector = tfb.Identity()
        unconstrained_ones = tf.ones(
            shape=bijector.inverse_event_shape_tensor(
                parameter_properties[param].shape_fn(
                    sample_shape=sample_shape)),
            dtype=actual_dist.dtype)

        if mean_field:
          new_params_dict[param] = ASVIParameters(
              prior_weight=None,
              mean_field_parameter=tfp_util.TransformedVariable(
                  value,
                  bijector=bijector,
                  name='mean_field_parameter/{}/{}'.format(dist.name, param)))
        else:
          new_params_dict[param] = ASVIParameters(
              prior_weight=tfp_util.TransformedVariable(
                  initial_prior_weight * unconstrained_ones,
                  bijector=tfb.Sigmoid(),
                  name='prior_weight/{}/{}'.format(dist.name, param)),
              mean_field_parameter=tfp_util.TransformedVariable(
                  value,
                  bijector=bijector,
                  name='mean_field_parameter/{}/{}'.format(dist.name, param)))
      param_dicts.append(new_params_dict)
  return param_dicts
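
A minimal usage sketch (not from the original source): it assumes the module-level names used above — `Root`, `_as_trainable_family`, `ASVIParameters`, `tfp_util`, `tfb`, `_NON_STATISTICAL_PARAMS`, `_NON_TRAINABLE_PARAMS` — are in scope, as in TFP's ASVI module, and that `prior` is a `tfd.JointDistribution`:

import tensorflow_probability as tfp

tfd = tfp.distributions

def prior_model():
  # Toy two-variable prior; the helper walks its components via
  # `prior._get_single_sample_distributions()`.
  scale = yield tfd.LogNormal(loc=0., scale=1., name='scale')
  yield tfd.Normal(loc=0., scale=scale, name='x')

prior = tfd.JointDistributionCoroutineAutoBatched(prior_model)

# One dict of `ASVIParameters` per component distribution of the prior.
param_dicts = _make_asvi_trainable_variables(prior, initial_prior_weight=0.5)
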
Code example #2
def _build_posterior_for_one_parameter(param, batch_shape, seed):
    """Built a transformed-normal variational dist over a parameter's support."""

    # Build a trainable Normal distribution.
    initial_loc = sample_uniform_initial_state(param,
                                               init_sample_shape=batch_shape,
                                               return_constrained=False,
                                               seed=seed)
    loc = tf.Variable(initial_value=initial_loc, name=param.name + '_loc')
    scale = tfp_util.TransformedVariable(
        tf.fill(tf.shape(initial_loc),
                value=tf.constant(0.02, initial_loc.dtype),
                name=param.name + '_scale'), softplus_lib.Softplus())
    posterior_dist = normal_lib.Normal(loc=loc, scale=scale)

    # Ensure the `event_shape` of the variational distribution matches the
    # parameter. (A statically unknown rank passes `ndims=None`, in which case
    # `Independent` defaults to reinterpreting all but the leftmost batch dim.)
    if (param.prior.event_shape.ndims is None
            or param.prior.event_shape.ndims > 0):
        posterior_dist = independent_lib.Independent(
            posterior_dist,
            reinterpreted_batch_ndims=param.prior.event_shape.ndims)

    # Transform to constrained parameter space.
    posterior_dist = transformed_distribution_lib.TransformedDistribution(
        posterior_dist, param.bijector, name='{}_posterior'.format(param.name))
    return posterior_dist
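
A hedged calling sketch: the `Parameter` namedtuple below is a stand-in for the `(name, prior, bijector)` parameter records that `tfp.sts` models expose, and `sample_uniform_initial_state` is assumed to be `tfp.sts.sample_uniform_initial_state`, per the surrounding module:

import collections

import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

Parameter = collections.namedtuple('Parameter', ['name', 'prior', 'bijector'])

param = Parameter(name='observation_noise_scale',
                  prior=tfd.LogNormal(loc=-2., scale=1.),
                  bijector=tfb.Softplus())

# Trainable variational distribution over the parameter, in constrained space.
surrogate = _build_posterior_for_one_parameter(param, batch_shape=[], seed=42)
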
Code example #3
def build_trainable_location_scale_distribution(initial_loc,
                                                initial_scale,
                                                event_ndims,
                                                distribution_fn=normal.Normal,
                                                validate_args=False,
                                                name=None):
    """Builds a variational distribution from a location-scale family.

  Args:
    initial_loc: Float `Tensor` initial location.
    initial_scale: Float `Tensor` initial scale.
    event_ndims: Integer `Tensor` number of event dimensions in `initial_loc`.
    distribution_fn: Optional constructor for a `tfd.Distribution` instance
      in a location-scale family. This should have signature `dist =
      distribution_fn(loc, scale, validate_args)`.
      Default value: `tfd.Normal`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e.,
        'build_trainable_location_scale_distribution').

  Returns:
    posterior_dist: A `tfd.Distribution` instance.
  """
    with tf.name_scope(name or 'build_trainable_location_scale_distribution'):
        dtype = dtype_util.common_dtype([initial_loc, initial_scale],
                                        dtype_hint=tf.float32)
        # Broadcast initial loc and scale against each other so both carry
        # the same (fully broadcast) shape.
        initial_loc = initial_loc * tf.ones(tf.shape(initial_scale),
                                            dtype=dtype)
        initial_scale = initial_scale * tf.ones_like(initial_loc)

        loc = tf.Variable(initial_value=initial_loc, name='loc')
        scale = tfp_util.TransformedVariable(initial_scale,
                                             softplus.Softplus(),
                                             name='scale')
        posterior_dist = distribution_fn(loc=loc,
                                         scale=scale,
                                         validate_args=validate_args)

        # Ensure the distribution has the desired number of event dimensions.
        static_event_ndims = tf.get_static_value(event_ndims)
        if static_event_ndims is None or static_event_ndims > 0:
            posterior_dist = independent.Independent(
                posterior_dist,
                reinterpreted_batch_ndims=event_ndims,
                validate_args=validate_args)

    return posterior_dist
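
For illustration, a mean-field normal surrogate over a length-3 event, assuming the module aliases above (`dtype_util`, `tfp_util`, `softplus`, `independent`) are in scope:

import tensorflow as tf

# loc/scale broadcast to shape [3]; `event_ndims=1` wraps the factored
# Normal in an `Independent`, so `log_prob` sums over the last axis.
surrogate = build_trainable_location_scale_distribution(
    initial_loc=tf.zeros([3]),
    initial_scale=0.1,
    event_ndims=1)

surrogate.event_shape  # TensorShape([3])
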
Code example #4
def build_trainable_linear_operator_tril(shape,
                                         scale_initializer=1e-2,
                                         diag_bijector=None,
                                         dtype=None,
                                         seed=None,
                                         name=None):
    """Build a trainable `LinearOperatorLowerTriangular` instance.

  Args:
    shape: Shape of the `LinearOperator`, equal to `[b0, ..., bn, d]`, where
      `b0...bn` are batch dimensions and `d` is the length of the diagonal.
    scale_initializer: Variables are initialized with samples from
      `Normal(0, scale_initializer)`.
    diag_bijector: Bijector to apply to the diagonal of the operator.
    dtype: `tf.dtype` of the `LinearOperator`.
    seed: Python integer to seed the random number generator.
    name: str, name for `tf.name_scope`.

  Returns:
    operator: Trainable instance of `tf.linalg.LinearOperatorLowerTriangular`.
  """
    with tf.name_scope(name or 'build_trainable_linear_operator_tril'):
        if dtype is None:
            dtype = dtype_util.common_dtype([scale_initializer],
                                            dtype_hint=tf.float32)

        scale_initializer = tf.convert_to_tensor(scale_initializer,
                                                 dtype=dtype)
        diag_bijector = diag_bijector or _DefaultScaleDiagonal()
        batch_shape, dim = ps.split(shape, num_or_size_splits=[-1, 1])

        scale_tril_bijector = fill_scale_tril.FillScaleTriL(
            diag_bijector, diag_shift=tf.zeros([], dtype=dtype))
        flat_initial_scale = samplers.normal(
            mean=0.,
            stddev=scale_initializer,
            shape=ps.concat([batch_shape, dim * (dim + 1) // 2], axis=0),
            seed=seed,
            dtype=dtype)
        return tf.linalg.LinearOperatorLowerTriangular(
            tril=tfp_util.TransformedVariable(
                scale_tril_bijector.forward(flat_initial_scale),
                bijector=scale_tril_bijector,
                name='tril'),
            is_non_singular=True)
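
A usage sketch, assuming the module aliases above (`ps`, `samplers`, `fill_scale_tril`, `tfp_util`, `_DefaultScaleDiagonal`) are in scope: the operator can serve as the trainable scale of a full-covariance normal surrogate.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Trainable lower-triangular scale for a 4-dimensional normal; the single
# underlying variable holds the 4 * 5 / 2 = 10 flat triangular entries.
scale = build_trainable_linear_operator_tril([4], seed=42)
surrogate = tfd.MultivariateNormalLinearOperator(
    loc=tf.Variable(tf.zeros([4])), scale=scale)
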
Code example #5
def build_trainable_linear_operator_diag(shape,
                                         scale_initializer=1e-2,
                                         diag_bijector=None,
                                         dtype=None,
                                         seed=None,
                                         name=None):
    """Build a trainable `LinearOperatorDiag` instance.

  Args:
    shape: Shape of the `LinearOperator`, equal to `[b0, ..., bn, d]`, where
      `b0...bn` are batch dimensions and `d` is the length of the diagonal.
    scale_initializer: Variables are initialized with samples from
      `Normal(0, scale_initializer)`.
    diag_bijector: Bijector to apply to the diagonal of the operator.
    dtype: `tf.dtype` of the `LinearOperator`.
    seed: Python integer to seed the random number generator.
    name: str, name for `tf.name_scope`.

  Returns:
    operator: Trainable instance of `tf.linalg.LinearOperatorDiag`.
  """
    with tf.name_scope(name or 'build_trainable_linear_operator_diag'):
        if dtype is None:
            dtype = dtype_util.common_dtype([scale_initializer],
                                            dtype_hint=tf.float32)
        scale_initializer = tf.convert_to_tensor(scale_initializer,
                                                 dtype=dtype)

        diag_bijector = diag_bijector or _DefaultScaleDiagonal()
        initial_scale_diag = samplers.normal(mean=0.,
                                             stddev=scale_initializer,
                                             shape=shape,
                                             dtype=dtype,
                                             seed=seed)
        return tf.linalg.LinearOperatorDiag(
            tfp_util.TransformedVariable(
                diag_bijector.forward(initial_scale_diag),
                bijector=diag_bijector,
                name='diag'),
            is_non_singular=True)
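
The diagonal analogue, under the same scope assumptions; a leading batch dimension in `shape` becomes a batch dimension of the operator:

# Trainable diagonal scale with batch shape [2] and dimension 3.
operator = build_trainable_linear_operator_diag([2, 3], seed=42)
operator.shape  # [2, 3, 3]
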
Code example #6
def _make_asvi_trainable_variables(prior,
                                   mean_field=False,
                                   initial_prior_weight=0.5):
  """Generates parameter dictionaries given a prior distribution and list."""
  with tf.name_scope('make_asvi_trainable_variables'):
    param_dicts = []
    prior_dists = prior._get_single_sample_distributions()  # pylint: disable=protected-access
    for dist in prior_dists:
      original_dist = dist.distribution if isinstance(dist, Root) else dist

      substituted_dist = _as_trainable_family(original_dist)

      # Grab the base distribution if one exists.
      try:
        actual_dist = substituted_dist.distribution
      except AttributeError:
        actual_dist = substituted_dist

      new_params_dict = {}

      # Build a trainable ASVI representation for each distribution's parameters.
      parameter_properties = actual_dist.parameter_properties(
          dtype=actual_dist.dtype)

      if isinstance(original_dist, sample.Sample):
        posterior_batch_shape = ps.concat([
            actual_dist.batch_shape_tensor(),
            distribution_util.expand_to_vector(original_dist.sample_shape)
        ], axis=0)
      else:
        posterior_batch_shape = actual_dist.batch_shape_tensor()

      for param, value in actual_dist.parameters.items():

        if param in (_NON_STATISTICAL_PARAMS +
                     _NON_TRAINABLE_PARAMS) or value is None:
          continue

        actual_event_shape = parameter_properties[param].shape_fn(
            actual_dist.event_shape_tensor())
        try:
          bijector = parameter_properties[
              param].default_constraining_bijector_fn()
        except NotImplementedError:
          bijector = identity.Identity()

        if mean_field:
          prior_weight = None
        else:
          unconstrained_ones = tf.ones(
              shape=ps.concat([
                  posterior_batch_shape,
                  bijector.inverse_event_shape_tensor(
                      actual_event_shape)
              ], axis=0),
              dtype=tf.convert_to_tensor(value).dtype)

          prior_weight = tfp_util.TransformedVariable(
              initial_prior_weight * unconstrained_ones,
              bijector=sigmoid.Sigmoid(),
              name='prior_weight/{}/{}'.format(dist.name, param))

        # If the prior distribution was a tfd.Sample wrapping a base
        # distribution, we want to give every single sample in the prior its
        # own lambda and alpha value (rather than having a single lambda and
        # alpha).
        if isinstance(original_dist, sample.Sample):
          value = tf.reshape(
              value,
              ps.concat([
                  actual_dist.batch_shape_tensor(),
                  ps.ones(ps.rank_from_shape(original_dist.sample_shape)),
                  actual_event_shape
              ], axis=0))
          value = tf.broadcast_to(
              value,
              ps.concat([posterior_batch_shape, actual_event_shape], axis=0))
        new_params_dict[param] = ASVIParameters(
            prior_weight=prior_weight,
            mean_field_parameter=tfp_util.TransformedVariable(
                value,
                bijector=bijector,
                name='mean_field_parameter/{}/{}'.format(dist.name, param)))

      param_dicts.append(new_params_dict)
  return param_dicts
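
This revision differs from code example #1 mainly in its `tfd.Sample` handling. A sketch under the same scope assumptions (plus the `sample`, `ps`, and `distribution_util` module aliases):

import tensorflow_probability as tfp

tfd = tfp.distributions

def prior_model():
  # Eight i.i.d. weights: via the reshape/broadcast above, each draw gets
  # its own prior_weight and mean_field_parameter instead of sharing one.
  yield tfd.Sample(tfd.Normal(0., 1.), sample_shape=8, name='w')

prior = tfd.JointDistributionCoroutineAutoBatched(prior_model)
param_dicts = _make_asvi_trainable_variables(prior)
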
Code example #7
def _asvi_convex_update_for_base_distribution(dist,
                                              mean_field,
                                              initial_prior_weight,
                                              sample_shape=None,
                                              variables=None,
                                              seed=None):
    """Creates a trainable surrogate for a (non-meta, non-joint) distribution."""
    if variables is None:
        variables = {}

    posterior_batch_shape = dist.batch_shape_tensor()
    if sample_shape is not None:
        posterior_batch_shape = ps.concat([
            posterior_batch_shape,
            distribution_util.expand_to_vector(sample_shape)
        ], axis=0)

    # Create variables backing each parameter, if needed.
    all_parameter_properties = dist.parameter_properties(dtype=dist.dtype)
    for param, prior_value in dist.parameters.items():
        if (param in variables
                or param in (_NON_STATISTICAL_PARAMS + _NON_TRAINABLE_PARAMS)
                or prior_value is None):
            continue

        param_properties = all_parameter_properties[param]
        try:
            bijector = param_properties.default_constraining_bijector_fn()
        except NotImplementedError:
            bijector = identity.Identity()

        param_shape = ps.concat([
            posterior_batch_shape,
            ps.shape(prior_value)[ps.rank(prior_value) -
                                  param_properties.event_ndims:]
        ], axis=0)

        prior_weight = (
            None if mean_field  # pylint: disable=g-long-ternary
            else tfp_util.TransformedVariable(
                initial_value=tf.fill(
                    dims=param_shape,
                    value=tf.cast(initial_prior_weight,
                                  tf.convert_to_tensor(prior_value).dtype)),
                bijector=sigmoid.Sigmoid(),
                name='prior_weight/{}/{}'.format(_get_name(dist), param)))

        # Initialize the mean-field parameter as a (constrained) standard
        # normal sample.
        seed, param_seed = samplers.split_seed(seed)
        variables[param] = ASVIParameters(
            prior_weight=prior_weight,
            mean_field_parameter=tfp_util.TransformedVariable(
                initial_value=bijector.forward(
                    samplers.normal(
                        shape=bijector.inverse_event_shape(param_shape),
                        seed=param_seed)),
                bijector=bijector,
                name='mean_field_parameter/{}/{}'.format(
                    _get_name(dist), param)))

    temp_params_dict = {'name': _get_name(dist)}
    for param, prior_value in dist.parameters.items():
        if param in (_NON_STATISTICAL_PARAMS +
                     _NON_TRAINABLE_PARAMS) or prior_value is None:
            temp_params_dict[param] = prior_value
        else:
            if mean_field:
                temp_params_dict[param] = variables[param].mean_field_parameter
            else:
                temp_params_dict[param] = (
                    variables[param].prior_weight * prior_value +
                    ((1. - variables[param].prior_weight) *
                     variables[param].mean_field_parameter))
    return type(dist)(**temp_params_dict), variables
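
A calling sketch for a single base distribution, assuming the module helpers above (`_get_name`, `ps`, `samplers`, `tfp_util`, `sigmoid`, `identity`, `distribution_util`) are in scope:

import tensorflow_probability as tfp

tfd = tfp.distributions

dist = tfd.Gamma(concentration=2., rate=1.)
surrogate, variables = _asvi_convex_update_for_base_distribution(
    dist, mean_field=False, initial_prior_weight=0.5, seed=42)

# Each surrogate parameter is the convex combination
#   prior_weight * prior_value + (1 - prior_weight) * mean_field_parameter,
# with prior_weight and mean_field_parameter trainable.
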
Code example #8
def build_trainable_highway_flow(width,
                                 residual_fraction_initial_value=0.5,
                                 activation_fn=None,
                                 gate_first_n=None,
                                 seed=None,
                                 validate_args=False):
    """Builds a HighwayFlow parameterized by trainable variables.

  The variables are transformed to enforce the following parameter constraints:

  - `residual_fraction` is bounded between 0 and 1.
  - `upper_diagonal_weights_matrix` is a randomly initialized (lower) triangular
     matrix of size `width x width` with a positive diagonal.
  - `lower_diagonal_weights_matrix` is a randomly initialized lower triangular
     matrix of size `width x width` with ones on the diagonal.
  - `bias` is a randomly initialized vector of size `width`.

  Args:
    width: Input dimension of the bijector.
    residual_fraction_initial_value: Initial value of the gating parameter;
      must be between 0 and 1.
    activation_fn: Callable invertible activation function
      (e.g., `tf.nn.softplus`), or `None`.
    gate_first_n: Number of leading input dimensions to gate (useful, for
      example, when using auxiliary variables).
    seed: Seed for random initialization of the weights.
    validate_args: Python `bool`. Whether to validate input with runtime
        assertions.
        Default value: `False`.

  Returns:
    trainable_highway_flow: The initialized bijector.
  """

    residual_fraction_initial_value = tf.convert_to_tensor(
        residual_fraction_initial_value,
        dtype_hint=tf.float32,
        name='residual_fraction_initial_value')
    dtype = residual_fraction_initial_value.dtype

    bias_seed, upper_seed, lower_seed = samplers.split_seed(seed, n=3)
    lower_bijector = tfb.Chain([
        tfb.TransformDiagonal(diag_bijector=tfb.Shift(1.)),
        tfb.Pad(paddings=[(1, 0), (0, 1)]),
        tfb.FillTriangular()
    ])
    unconstrained_lower_initial_values = samplers.normal(
        shape=lower_bijector.inverse_event_shape([width, width]),
        mean=0.,
        stddev=.01,
        seed=lower_seed)
    upper_bijector = tfb.FillScaleTriL(diag_bijector=tfb.Softplus(),
                                       diag_shift=None)
    unconstrained_upper_initial_values = samplers.normal(
        shape=upper_bijector.inverse_event_shape([width, width]),
        mean=0.,
        stddev=.01,
        seed=upper_seed)

    return HighwayFlow(
        residual_fraction=util.TransformedVariable(
            initial_value=residual_fraction_initial_value,
            bijector=tfb.Sigmoid(),
            dtype=dtype),
        activation_fn=activation_fn,
        bias=tf.Variable(
            samplers.normal((width,), mean=0., stddev=0.01, seed=bias_seed),
            dtype=dtype),
        upper_diagonal_weights_matrix=util.TransformedVariable(
            initial_value=upper_bijector.forward(
                unconstrained_upper_initial_values),
            bijector=upper_bijector,
            dtype=dtype),
        lower_diagonal_weights_matrix=util.TransformedVariable(
            initial_value=lower_bijector.forward(
                unconstrained_lower_initial_values),
            bijector=lower_bijector,
            dtype=dtype),
        gate_first_n=gate_first_n,
        validate_args=validate_args)
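
A usage sketch, assuming `HighwayFlow` is `tfp.experimental.bijectors.HighwayFlow` and `util`, `samplers`, and `tfb` are the TFP aliases used above:

import tensorflow as tf

# Trainable bijector over 4-dimensional inputs.
flow = build_trainable_highway_flow(
    width=4, activation_fn=tf.nn.softplus, seed=42)

y = flow.forward(tf.zeros([10, 4]))  # output shape matches the input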