Ejemplo n.º 1
0
 def _default_event_space_bijector(self):
     """Return a `Sigmoid` bijector built with `low=-pi` and `high=pi`.

     Both bounds are created as constants of `self.dtype`, and the bijector
     inherits `self.validate_args`.
     """
     # TODO(b/145620027) Finalize choice of bijector.
     lower = tf.constant(-np.pi, dtype=self.dtype)
     upper = tf.constant(np.pi, dtype=self.dtype)
     return sigmoid_bijector.Sigmoid(
         low=lower, high=upper, validate_args=self.validate_args)
Ejemplo n.º 2
0
 def _negative_concentration_bijector(self):
     """Return a `Sigmoid` bijector with `low=loc`, `high=loc + |scale / concentration|`."""
     # Constructed dynamically so that `loc + scale / concentration` is
     # tape-safe.
     upper = self.loc + tf.math.abs(self.scale / self.concentration)
     bounded_sigmoid = sigmoid_bijector.Sigmoid(
         low=self.loc, high=upper, validate_args=self.validate_args)
     return bounded_sigmoid
Ejemplo n.º 3
0
  def __init__(self,
               temperature,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="RelaxedBernoulli"):
    """Construct RelaxedBernoulli distributions.

    The distribution is assembled as a `Logistic` with location
    `logits / temperature` and scale `1 / temperature`, pushed through a
    `Sigmoid` bijector.

    Args:
      temperature: A 0-D `Tensor` giving the temperature of a set of
        RelaxedBernoulli distributions. Should be positive; this is asserted
        only when `validate_args` is `True`.
      logits: An N-D `Tensor` of log-odds of a positive event. Each entry
        parameterizes an independent RelaxedBernoulli distribution whose
        event probability is sigmoid(logits). Pass only one of `logits` or
        `probs`.
      probs: An N-D `Tensor` of probabilities of a positive event. Each entry
        parameterizes an independent RelaxedBernoulli distribution. Pass only
        one of `logits` or `probs`.
      validate_args: Python `bool`, default `False`. When `True`, distribution
        parameters are checked for validity despite possibly degrading
        runtime performance. When `False`, invalid inputs may silently render
        incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined. In this
        constructor the flag is forwarded to the underlying `Logistic`.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: If both `probs` and `logits` are passed, or if neither.
    """
    # Must run before any other local is bound so that only the constructor
    # arguments are captured.
    parameters = dict(locals())
    with tf.compat.v1.name_scope(
        name, values=[logits, probs, temperature]) as name:
      dtype = dtype_util.common_dtype([logits, probs, temperature], tf.float32)
      self._temperature = tf.convert_to_tensor(
          value=temperature, name="temperature", dtype=dtype)
      if validate_args:
        # Gate every later use of the temperature on the positivity check.
        positivity_check = tf.compat.v1.assert_positive(temperature)
        with tf.control_dependencies([positivity_check]):
          self._temperature = tf.identity(self._temperature)
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits, probs=probs, validate_args=validate_args, dtype=dtype)
      base_logistic = logistic.Logistic(
          self._logits / self._temperature,
          1. / self._temperature,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name + "/Logistic")
      squash = sigmoid_bijector.Sigmoid(validate_args=validate_args)
      super(RelaxedBernoulli, self).__init__(
          distribution=base_logistic,
          bijector=squash,
          validate_args=validate_args,
          name=name)
    self._parameters = parameters
Ejemplo n.º 4
0
 def _default_event_space_bijector(self):
     """Return a `Chain` of `Shift(-pi)`, `Scale(2 * pi)` and `Sigmoid` bijectors.

     Every component, and the chain itself, inherits `self.validate_args`.
     """
     # TODO(b/145620027) Finalize choice of bijector.
     recenter = shift_bijector.Shift(
         shift=-np.pi, validate_args=self.validate_args)
     widen = scale_bijector.Scale(
         scale=2. * np.pi, validate_args=self.validate_args)
     squash = sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
     return chain_bijector.Chain(
         [recenter, widen, squash], validate_args=self.validate_args)
Ejemplo n.º 5
0
 def _default_event_space_bijector(self):
   """Return a `Chain` of `Shift(low)`, `Scale(high - low)` and `Sigmoid`.

   When either bound is a reference (per `tensor_util.is_ref`), the width
   `high - low` is wrapped in a `DeferredTensor` over `high` instead of being
   evaluated here.
   """
   bounds_are_refs = (
       tensor_util.is_ref(self.low) or tensor_util.is_ref(self.high))
   if not bounds_are_refs:
     width = self.high - self.low
   else:
     width = DeferredTensor(self.high, lambda x: x - self.low)
   stages = [
       shift_bijector.Shift(shift=self.low, validate_args=self.validate_args),
       scale_bijector.Scale(scale=width, validate_args=self.validate_args),
       sigmoid_bijector.Sigmoid(validate_args=self.validate_args),
   ]
   return chain_bijector.Chain(stages, validate_args=self.validate_args)
Ejemplo n.º 6
0
 def _transformed_logistic(self):
     """Return a `Logistic` distribution transformed by a `Sigmoid` bijector.

     The logistic scale is `1 / temperature` and its location is the logits
     parameter multiplied by that scale.
     """
     inverse_temperature = tf.math.reciprocal(self._temperature)
     logits = self._logits_parameter_no_checks()
     base_logistic = logistic.Logistic(
         logits * inverse_temperature,
         inverse_temperature,
         allow_nan_stats=self.allow_nan_stats)
     return transformed_distribution.TransformedDistribution(
         distribution=base_logistic, bijector=sigmoid_bijector.Sigmoid())
Ejemplo n.º 7
0
 def _default_event_space_bijector(self):
     """Return a `Chain` of `Shift(low)`, `Scale(high - low)` and `Sigmoid`.

     Both the shift and the scale are wrapped in `DeferredTensor`s built over
     `self.low` and `self.high` respectively.
     """
     deferred_low = tfp_util.DeferredTensor(self.low, lambda x: x)
     deferred_width = tfp_util.DeferredTensor(
         self.high, lambda x: x - self.low)
     stages = [
         shift_bijector.Shift(
             shift=deferred_low, validate_args=self.validate_args),
         scale_bijector.Scale(
             scale=deferred_width, validate_args=self.validate_args),
         sigmoid_bijector.Sigmoid(validate_args=self.validate_args),
     ]
     return chain_bijector.Chain(stages, validate_args=self.validate_args)
Ejemplo n.º 8
0
 def _negative_concentration_bijector(self):
   """Return a `Chain` of `Shift(loc)`, a negative `Scale` and `Sigmoid`.

   The scale is `-(scale / concentration)`, shrunk by a factor of
   `1 - 1e-6`.
   """
   # Constructed dynamically so that `scale * reciprocal(concentration)` is
   # tape-safe.
   shift_by_loc = shift_bijector.Shift(
       shift=self.loc, validate_args=self.validate_args)
   # TODO(b/146568897): Resolve numerical issues by implementing a new
   # bijector instead of multiplying `scale` by `(1. - 1e-6)`.
   shrink = 1. - 1e-6
   negated_width = -(
       self.scale * tf.math.reciprocal(self.concentration) * shrink)
   scale_by_width = scale_bijector.Scale(
       scale=negated_width, validate_args=self.validate_args)
   squash = sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
   return chain_bijector.Chain(
       [shift_by_loc, scale_by_width, squash],
       validate_args=self.validate_args)
Ejemplo n.º 9
0
 def _default_event_space_bijector(self):
   """Return a `Chain` of `Shift(low)`, a shrunk `Scale(high - low)` and `Sigmoid`.

   The width `high - low` is multiplied by `1 - 1e-6`. When either bound is a
   reference (per `tensor_util.is_ref`), the width is wrapped in a
   `DeferredTensor` over `high` whose static shape is the broadcast of the
   two bounds' shapes.
   """
   # TODO(b/146568897): Resolve numerical issues by implementing a new bijector
   # instead of multiplying `scale` by `(1. - 1e-6)`.
   shrink = 1. - 1e-6
   bounds_are_refs = (
       tensor_util.is_ref(self.low) or tensor_util.is_ref(self.high))
   if not bounds_are_refs:
     width = (self.high - self.low) * shrink
   else:
     width = DeferredTensor(
         self.high,
         lambda x: (x - self.low) * shrink,
         shape=tf.broadcast_static_shape(self.high.shape, self.low.shape))
   stages = [
       shift_bijector.Shift(shift=self.low, validate_args=self.validate_args),
       scale_bijector.Scale(scale=width, validate_args=self.validate_args),
       sigmoid_bijector.Sigmoid(validate_args=self.validate_args),
   ]
   return chain_bijector.Chain(stages, validate_args=self.validate_args)
Ejemplo n.º 10
0
    def __init__(self,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='LogitNormal'):
        """Construct a logit-normal distribution.

        The LogitNormal distribution models random variables whose logit
        (i.e., sigmoid_inverse, i.e., `log(p) - log1p(-p)`) is normally
        distributed with mean `loc` and standard deviation `scale`. It is
        built as the sigmoid transformation (i.e., `1 / (1 + exp(-x))`) of a
        Normal distribution.

        Args:
          loc: Floating-point `Tensor`; the mean of the underlying
            Normal distribution(s). Must broadcast with `scale`.
          scale: Floating-point `Tensor`; the stddev of the underlying
            Normal distribution(s). Must broadcast with `loc`.
          validate_args: Python `bool`, default `False`. Whether to validate
            input with asserts. If `validate_args` is `False`, and the inputs
            are invalid, correct behavior is not guaranteed.
          allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member. If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this
            statistic.
          name: The name to give Ops created by the initializer.
        """
        # Must run before any other local is bound so that only the
        # constructor arguments are captured.
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            # NOTE(review): `allow_nan_stats` is recorded in `parameters` but
            # is not passed explicitly to either constructor below; confirm
            # that this is intended.
            base_normal = normal.Normal(loc=loc, scale=scale)
            super(LogitNormal, self).__init__(
                distribution=base_normal,
                bijector=sigmoid_bijector.Sigmoid(),
                validate_args=validate_args,
                parameters=parameters,
                name=name)
Ejemplo n.º 11
0
 def _default_event_space_bijector(self):
     """Return a `Sigmoid` bijector built from this object's `low` and `high`."""
     bounded_sigmoid = sigmoid_bijector.Sigmoid(
         low=self.low, high=self.high, validate_args=self.validate_args)
     return bounded_sigmoid
Ejemplo n.º 12
0
 def _default_event_space_bijector(self):
   """Return a default-bounds `Sigmoid` bijector using `self.validate_args`."""
   unit_sigmoid = sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
   return unit_sigmoid
Ejemplo n.º 13
0
def _asvi_convex_update_for_base_distribution(dist,
                                              mean_field=False,
                                              initial_prior_weight=0.5,
                                              sample_shape=None):
    """Creates a trainable surrogate for a (non-meta, non-joint) distribution.

    This is a generator: for every trainable parameter of `dist` it yields a
    `trainable_state_util.Parameter` request and expects the corresponding
    value to be sent back in. The surrogate distribution is the generator's
    return value.

    Args:
      dist: Distribution to build a surrogate for. Its `parameters`,
        `parameter_properties`, `batch_shape_tensor` and `dtype` are read.
      mean_field: Python `bool`. When `True`, each trainable parameter is
        replaced by its mean-field parameter alone and no prior weight is
        requested.
      initial_prior_weight: Value used to fill the initial prior-weight
        tensor of each parameter (cast to the prior value's dtype).
      sample_shape: Optional shape appended to the batch shape of `dist` when
        sizing the trainable parameters.

    Yields:
      `trainable_state_util.Parameter` instances describing the mean-field
      parameter and, unless `mean_field` is set, the prior weight of each
      trainable parameter.

    Returns:
      An instance of `type(dist)` built from the surrogate parameter values.
    """
    # Batch shape of the surrogate: the prior's batch shape, optionally
    # extended by `sample_shape`.
    posterior_batch_shape = dist.batch_shape_tensor()
    if sample_shape is not None:
        posterior_batch_shape = ps.concat([
            posterior_batch_shape,
            distribution_util.expand_to_vector(sample_shape)
        ],
                                          axis=0)

    temp_params_dict = {'name': _get_name(dist)}
    all_parameter_properties = dist.parameter_properties(dtype=dist.dtype)
    for param, prior_value in dist.parameters.items():
        # Non-statistical, non-trainable and unset parameters are passed
        # through to the surrogate unchanged.
        if (param in (_NON_STATISTICAL_PARAMS + _NON_TRAINABLE_PARAMS)
                or prior_value is None):
            temp_params_dict[param] = prior_value
            continue

        # Use the parameter's default constraining bijector, falling back to
        # the identity when none is implemented.
        param_properties = all_parameter_properties[param]
        try:
            bijector = param_properties.default_constraining_bijector_fn()
        except NotImplementedError:
            bijector = identity.Identity()

        # Surrogate parameter shape: posterior batch shape followed by the
        # trailing `event_ndims` dimensions of the prior value.
        param_shape = ps.concat([
            posterior_batch_shape,
            ps.shape(prior_value)[ps.rank(prior_value) -
                                  param_properties.event_ndims:]
        ],
                                axis=0)

        # Initialize the mean-field parameter as a (constrained) standard
        # normal sample.
        # pylint: disable=cell-var-from-loop
        # Safe because the state utils guarantee to either call `init_fn`
        # immediately upon yielding, or not at all.
        mean_field_parameter = yield trainable_state_util.Parameter(
            init_fn=lambda seed: (  # pylint: disable=g-long-lambda
                bijector.forward(
                    samplers.normal(shape=bijector.inverse_event_shape(
                        param_shape),
                                    seed=seed))),
            name='mean_field_parameter_{}_{}'.format(_get_name(dist), param),
            constraining_bijector=bijector)
        if mean_field:
            temp_params_dict[param] = mean_field_parameter
        else:
            # The prior weight is constrained by a sigmoid bijector and starts
            # as a tensor filled with `initial_prior_weight`.
            prior_weight = yield trainable_state_util.Parameter(
                init_fn=lambda: tf.fill(  # pylint: disable=g-long-lambda
                    dims=param_shape,
                    value=tf.cast(initial_prior_weight,
                                  tf.convert_to_tensor(prior_value).dtype)),
                name='prior_weight_{}_{}'.format(_get_name(dist), param),
                constraining_bijector=sigmoid.Sigmoid())
            # Convex combination of the prior value and the mean-field
            # parameter, weighted by `prior_weight`.
            temp_params_dict[param] = prior_weight * prior_value + (
                (1. - prior_weight) * mean_field_parameter)
    # pylint: enable=cell-var-from-loop

    return type(dist)(**temp_params_dict)
Ejemplo n.º 14
0
    def __init__(self,
                 temperature,
                 logits=None,
                 probs=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='RelaxedBernoulli'):
        """Construct RelaxedBernoulli distributions.

        Internally a `Logistic` with location `logits / temperature` and scale
        `1 / temperature` is pushed through a `Sigmoid` bijector; the result
        is stored as `self._transformed_logistic`.

        Args:
          temperature: A `Tensor`, representing the temperature of a set of
            RelaxedBernoulli distributions. The temperature values should be
            positive.
          logits: An N-D `Tensor` representing the log-odds of a positive
            event. Each entry in the `Tensor` parameterizes an independent
            RelaxedBernoulli distribution where the probability of an event is
            sigmoid(logits). Only one of `logits` or `probs` should be passed
            in.
          probs: An N-D `Tensor` representing the probability of a positive
            event. Each entry in the `Tensor` parameterizes an independent
            RelaxedBernoulli distribution. Only one of `logits` or `probs`
            should be passed in.
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite possibly
            degrading runtime performance. When `False` invalid inputs may
            silently render incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.

        Raises:
          ValueError: If both `probs` and `logits` are passed, or if neither.
        """
        # Must run before any other local is bound so that only the
        # constructor arguments are captured.
        parameters = dict(locals())

        # Enforce the documented contract up front: exactly one of `probs`
        # and `logits` must be supplied. Without this check, passing both
        # silently ignores `probs`, and passing neither fails later with an
        # unrelated error.
        if (probs is None) == (logits is None):
            raise ValueError('Must pass `probs` or `logits`, but not both.')

        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([logits, probs, temperature],
                                            tf.float32)

            self._temperature = tensor_util.convert_nonref_to_tensor(
                temperature, name='temperature', dtype=dtype)
            self._probs = tensor_util.convert_nonref_to_tensor(
                probs, name='probs', dtype=dtype)
            self._logits = tensor_util.convert_nonref_to_tensor(
                logits, name='logits', dtype=dtype)

            # NOTE(review): every `DeferredTensor` below is called as
            # `(transform_fn, input)`, exactly as in the original code;
            # confirm this matches the installed TFP version's signature.
            if logits is None:
                # Derive logits lazily from probs: log(p) - log(1 - p).
                logits_parameter = tfp_util.DeferredTensor(
                    lambda x: tf.math.log(x) - tf.math.log1p(-x), self._probs)
            else:
                logits_parameter = self._logits

            shape = tf.broadcast_static_shape(logits_parameter.shape,
                                              self._temperature.shape)

            logistic_scale = tfp_util.DeferredTensor(tf.math.reciprocal,
                                                     self._temperature)
            logistic_loc = tfp_util.DeferredTensor(
                lambda x: x * logistic_scale, logits_parameter, shape=shape)

            self._transformed_logistic = (
                transformed_distribution.TransformedDistribution(
                    distribution=logistic.Logistic(
                        logistic_loc,
                        logistic_scale,
                        allow_nan_stats=allow_nan_stats,
                        name=name + '/Logistic'),
                    bijector=sigmoid_bijector.Sigmoid()))

            super(RelaxedBernoulli, self).__init__(
                dtype=dtype,
                reparameterization_type=(
                    reparameterization.FULLY_REPARAMETERIZED),
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
0
    def __init__(self,
                 loc,
                 scale,
                 num_probit_terms_approx=2,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='LogitNormal'):
        """Construct a logit-normal distribution.

        The LogitNormal distribution models random variables between 0 and 1
        whose logit (i.e., sigmoid_inverse, i.e., `log(p) - log1p(-p)`) is
        normally distributed with mean `loc` and standard deviation `scale`.
        It is built as the sigmoid transformation (i.e.,
        `1 / (1 + exp(-x))`) of a Normal distribution.

        Args:
          loc: Floating-point `Tensor`; the mean of the underlying
            Normal distribution(s). Must broadcast with `scale`.
          scale: Floating-point `Tensor`; the stddev of the underlying
            Normal distribution(s). Must broadcast with `loc`.
          num_probit_terms_approx: The `k` used in the approximation,
            `sigmoid(x) approx= sum_i^k p[k,i] Normal(0, c[k, i]).cdf(x)`
            where `sum_i^k p[k,i]=1` and `p[k,i],c[k,i] > 0`
            [(Monahan and Stefanski, 1989)][1] and used in `mean_*_approx`
            functions [(Owen, 1980)][2]. Must be a python scalar integer
            between `1` and `8` (inclusive). Using
            `num_probit_terms_approx=2` should result in `mean_approx` error
            not exceeding `10**-4`.
            Default value: `2`.
          validate_args: Python `bool`, default `False`. Whether to validate
            input with asserts. If `validate_args` is `False`, and the inputs
            are invalid, correct behavior is not guaranteed.
          allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member. If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this
            statistic.
          name: The name to give Ops created by the initializer.

        Raises:
          ValueError: If `num_probit_terms_approx`, after conversion with
            `int`, lies outside the range `1` to `8` (inclusive).

        #### References

        [1]: Monahan, John H., and Leonard A. Stefanski. Normal scale mixture
             approximations to the logistic distribution with applications.
             North Carolina State University. Dept. of Statistics, 1989.
             http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.154.5032
        [2]: Owen, Donald Bruce. "A table of normal integrals: A table."
             Communications in Statistics-Simulation and Computation 9.4
             (1980): 389-419.
             https://www.tandfonline.com/doi/abs/10.1080/03610918008812164
        """
        # Must run before any other local is bound so that only the
        # constructor arguments (as originally passed) are captured.
        parameters = dict(locals())
        num_probit_terms_approx = int(num_probit_terms_approx)
        if not 1 <= num_probit_terms_approx <= 8:
            raise ValueError(
                'Argument `num_probit_terms_approx` must be an integer between '
                '`1` and `8` (inclusive).')
        self._num_probit_terms_approx = num_probit_terms_approx
        with tf.name_scope(name) as name:
            base_normal = normal_lib.Normal(loc=loc, scale=scale)
            super(LogitNormal, self).__init__(
                distribution=base_normal,
                bijector=sigmoid_bijector.Sigmoid(),
                validate_args=validate_args,
                parameters=parameters,
                name=name)
Ejemplo n.º 16
0
 def _default_event_space_bijector(self):
     """Return a default-bounds `Sigmoid` bijector using `self.validate_args`."""
     # TODO(b/145620027) Finalize choice of bijector.
     unit_sigmoid = sigmoid_bijector.Sigmoid(validate_args=self.validate_args)
     return unit_sigmoid
Ejemplo n.º 17
0
def _asvi_convex_update_for_base_distribution(dist,
                                              mean_field,
                                              initial_prior_weight,
                                              sample_shape=None,
                                              variables=None,
                                              seed=None):
    """Creates a trainable surrogate for a (non-meta, non-joint) distribution.

    Args:
      dist: Distribution to build a surrogate for. Its `parameters`,
        `parameter_properties`, `batch_shape_tensor` and `dtype` are read.
      mean_field: Python `bool`. When `True`, no prior-weight variable is
        created and each trainable parameter is replaced by its mean-field
        variable alone.
      initial_prior_weight: Value used to fill the initial prior-weight
        tensor of each parameter (cast to the prior value's dtype).
      sample_shape: Optional shape appended to the batch shape of `dist` when
        sizing the trainable variables.
      variables: Optional dict mapping parameter name to `ASVIParameters`.
        Entries already present are reused; missing ones are created and
        added to this same dict.
      seed: Seed that is split once per newly created parameter to
        initialize its mean-field variable.

    Returns:
      A `(surrogate, variables)` pair, where `surrogate` is an instance of
      `type(dist)` built from the surrogate parameter values and `variables`
      is the (possibly extended) parameter-name -> `ASVIParameters` dict.
    """
    if variables is None:
        variables = {}

    # Parameters that are never replaced by trainable variables.
    passthrough_params = _NON_STATISTICAL_PARAMS + _NON_TRAINABLE_PARAMS

    # Batch shape of the surrogate: the prior's batch shape, optionally
    # extended by `sample_shape`.
    posterior_batch_shape = dist.batch_shape_tensor()
    if sample_shape is not None:
        posterior_batch_shape = ps.concat(
            [
                posterior_batch_shape,
                distribution_util.expand_to_vector(sample_shape),
            ],
            axis=0)

    # Create variables backing each parameter, if needed.
    all_parameter_properties = dist.parameter_properties(dtype=dist.dtype)
    for param, prior_value in dist.parameters.items():
        already_backed = param in variables
        if (already_backed or param in passthrough_params
                or prior_value is None):
            continue

        # Use the parameter's default constraining bijector, falling back to
        # the identity when none is implemented.
        param_properties = all_parameter_properties[param]
        try:
            bijector = param_properties.default_constraining_bijector_fn()
        except NotImplementedError:
            bijector = identity.Identity()

        # Variable shape: posterior batch shape followed by the trailing
        # `event_ndims` dimensions of the prior value.
        param_shape = ps.concat(
            [
                posterior_batch_shape,
                ps.shape(prior_value)[ps.rank(prior_value) -
                                      param_properties.event_ndims:],
            ],
            axis=0)

        if mean_field:
            prior_weight = None
        else:
            # Sigmoid-constrained weight, initially filled with
            # `initial_prior_weight`.
            initial_weights = tf.fill(
                dims=param_shape,
                value=tf.cast(initial_prior_weight,
                              tf.convert_to_tensor(prior_value).dtype))
            prior_weight = tfp_util.TransformedVariable(
                initial_value=initial_weights,
                bijector=sigmoid.Sigmoid(),
                name='prior_weight/{}/{}'.format(_get_name(dist), param))

        # Initialize the mean-field parameter as a (constrained) standard
        # normal sample.
        seed, param_seed = samplers.split_seed(seed)
        unconstrained_sample = samplers.normal(
            shape=bijector.inverse_event_shape(param_shape), seed=param_seed)
        mean_field_variable = tfp_util.TransformedVariable(
            initial_value=bijector.forward(unconstrained_sample),
            bijector=bijector,
            name='mean_field_parameter/{}/{}'.format(_get_name(dist), param))
        variables[param] = ASVIParameters(
            prior_weight=prior_weight,
            mean_field_parameter=mean_field_variable)

    # Assemble the surrogate's constructor arguments.
    surrogate_params = {'name': _get_name(dist)}
    for param, prior_value in dist.parameters.items():
        if param in passthrough_params or prior_value is None:
            surrogate_params[param] = prior_value
            continue
        if mean_field:
            surrogate_params[param] = variables[param].mean_field_parameter
            continue
        # Convex combination of the prior value and the mean-field variable.
        surrogate_params[param] = (
            variables[param].prior_weight * prior_value +
            ((1. - variables[param].prior_weight) *
             variables[param].mean_field_parameter))
    return type(dist)(**surrogate_params), variables
Ejemplo n.º 18
0
def _make_asvi_trainable_variables(prior,
                                   mean_field=False,
                                   initial_prior_weight=0.5):
  """Builds one dict of `ASVIParameters` per component distribution of `prior`.

  Args:
    prior: Object exposing `_get_single_sample_distributions()`. Each returned
      distribution (unwrapped from `Root` when applicable) is processed in
      turn.
    mean_field: Python `bool`. When `True`, no prior-weight variable is
      created and `prior_weight` is `None` in every entry.
    initial_prior_weight: Value each prior-weight variable is initialized to.

  Returns:
    A list with one dict per distribution, mapping each trainable parameter
    name to an `ASVIParameters(prior_weight, mean_field_parameter)`.
  """
  # Parameters that never receive trainable variables.
  skipped_params = _NON_STATISTICAL_PARAMS + _NON_TRAINABLE_PARAMS

  with tf.name_scope('make_asvi_trainable_variables'):
    param_dicts = []
    prior_dists = prior._get_single_sample_distributions()  # pylint: disable=protected-access
    for dist in prior_dists:
      original_dist = dist.distribution if isinstance(dist, Root) else dist
      substituted_dist = _as_trainable_family(original_dist)

      # Grab the base distribution if it exists
      try:
        actual_dist = substituted_dist.distribution
      except AttributeError:
        actual_dist = substituted_dist

      #  Build trainable ASVI representation for each distribution's parameters.
      parameter_properties = actual_dist.parameter_properties(
          dtype=actual_dist.dtype)

      is_sample = isinstance(original_dist, sample.Sample)
      posterior_batch_shape = actual_dist.batch_shape_tensor()
      if is_sample:
        posterior_batch_shape = ps.concat(
            [
                posterior_batch_shape,
                distribution_util.expand_to_vector(original_dist.sample_shape),
            ],
            axis=0)

      new_params_dict = {}
      for param, value in actual_dist.parameters.items():
        if param in skipped_params or value is None:
          continue

        properties = parameter_properties[param]
        actual_event_shape = properties.shape_fn(
            actual_dist.event_shape_tensor())
        # Fall back to the identity when no default constraining bijector is
        # implemented for this parameter.
        try:
          bijector = properties.default_constraining_bijector_fn()
        except NotImplementedError:
          bijector = identity.Identity()

        prior_weight = None
        if not mean_field:
          # NOTE(review): this shape uses the bijector's *inverse* event
          # shape, whereas the mean-field variable below is initialized from
          # `value` itself. TODO confirm that this difference is intended for
          # shape-changing bijectors.
          weight_shape = ps.concat(
              [
                  posterior_batch_shape,
                  bijector.inverse_event_shape_tensor(actual_event_shape),
              ],
              axis=0)
          unconstrained_ones = tf.ones(
              shape=weight_shape, dtype=tf.convert_to_tensor(value).dtype)
          prior_weight = tfp_util.TransformedVariable(
              initial_prior_weight * unconstrained_ones,
              bijector=sigmoid.Sigmoid(),
              name='prior_weight/{}/{}'.format(dist.name, param))

        # If the prior distribution was a tfd.Sample wrapping a base
        # distribution, we want to give every single sample in the prior its
        # own lambda and alpha value (rather than having a single lambda and
        # alpha).
        if is_sample:
          # Insert singleton sample dimensions, then broadcast them out.
          padded_shape = ps.concat(
              [
                  actual_dist.batch_shape_tensor(),
                  ps.ones(ps.rank_from_shape(original_dist.sample_shape)),
                  actual_event_shape,
              ],
              axis=0)
          value = tf.reshape(value, padded_shape)
          full_shape = ps.concat(
              [posterior_batch_shape, actual_event_shape], axis=0)
          value = tf.broadcast_to(value, full_shape)

        mean_field_variable = tfp_util.TransformedVariable(
            value,
            bijector=bijector,
            name='mean_field_parameter/{}/{}'.format(dist.name, param))
        new_params_dict[param] = ASVIParameters(
            prior_weight=prior_weight,
            mean_field_parameter=mean_field_variable)

      param_dicts.append(new_params_dict)
  return param_dicts