Example #1
 def _default_event_space_bijector(self):
     # TODO(b/145620027) Finalize choice of bijector.
     return chain_bijector.Chain([
         invert_bijector.Invert(
             square_bijector.Square(validate_args=self.validate_args),
             validate_args=self.validate_args),
         softmax_centered_bijector.SoftmaxCentered(
             validate_args=self.validate_args)
     ],
                                 validate_args=self.validate_args)
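A default event-space bijector like the one above is normally consumed through a distribution's public API. A minimal sketch of that pattern, using `tfd.Dirichlet` purely as an illustrative stand-in (the distribution owning the method above is not shown in this snippet):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Pull the default event-space bijector off a simplex-valued distribution and
# push an unconstrained vector into the distribution's support.
dist = tfd.Dirichlet(concentration=[1., 2., 3.])
bij = dist.experimental_default_event_space_bijector()
unconstrained = tf.constant([0.5, -1.0])   # one fewer entry than the event
on_simplex = bij.forward(unconstrained)    # positive entries, sums to 1
print(dist.log_prob(on_simplex))
```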
Example #2
    def __init__(self, input_shape, num_steps, coupling_bijector_fn,
                 use_actnorm, seedstream):
        parameters = dict(locals())
        rnvp_block = [identity.Identity()]
        this_nchan = input_shape[-1]

        for j in range(num_steps):  # pylint: disable=unused-variable

            this_layer_input_shape = input_shape[:-1] + (input_shape[-1] //
                                                         2, )
            this_layer = coupling_bijector_fn(this_layer_input_shape)
            bijector_fn = self.make_bijector_fn(this_layer)

            # For each step in the block, we do (optional) actnorm, followed
            # by an invertible 1x1 convolution, then affine coupling.
            this_rnvp = invert.Invert(
                real_nvp.RealNVP(this_nchan // 2, bijector_fn=bijector_fn))

            # Append the layer to the RealNVP bijector for variable tracking.
            this_rnvp.coupling_bijector_layer = this_layer
            rnvp_block.append(this_rnvp)

            rnvp_block.append(
                invert.Invert(
                    OneByOneConv(this_nchan,
                                 seed=seedstream(),
                                 dtype=dtype_util.common_dtype(
                                     this_rnvp.variables,
                                     dtype_hint=tf.float32))))

            if use_actnorm:
                rnvp_block.append(
                    ActivationNormalization(this_nchan,
                                            dtype=dtype_util.common_dtype(
                                                this_rnvp.variables,
                                                dtype_hint=tf.float32)))

        # Note that we reverse the list since Chain applies bijectors in reverse
        # order.
        super(GlowBlock, self).__init__(chain.Chain(rnvp_block[::-1]),
                                        parameters=parameters,
                                        name='glow_block')
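The list reversal noted in the final comment is easy to verify on a toy chain; a minimal sketch assuming only the public `tfp.bijectors` aliases:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# Chain applies its bijectors right-to-left, so the last list entry runs first.
bij = tfb.Chain([tfb.Shift(1.), tfb.Scale(2.)])
print(bij.forward(tf.constant(3.)))  # 2. * 3. + 1. = 7. (Scale first, then Shift)
```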
Example #3
  def __init__(self,
               concentration,
               scale,
               validate_args=False,
               allow_nan_stats=True,
               name='Weibull'):
    """Construct Weibull distributions.

    The parameters `concentration` and `scale` must be shaped in a way that
    supports broadcasting (e.g. `concentration + scale` is a valid operation).

    Args:
     concentration: Positive Float-type `Tensor`, the concentration param of the
       distribution. Must contain only positive values.
     scale: Positive Float-type `Tensor`, the scale param of the distribution.
       Must contain only positive values.
     validate_args: Python `bool` indicating whether arguments should be checked
       for correctness.
     allow_nan_stats: Python `bool` indicating whether nan values should be
       allowed.
     name: Python `str` name given to ops managed by this class.
       Default value: `'Weibull'`.

    Raises:
      TypeError: if concentration and scale are different dtypes.

    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype([concentration, scale],
                                      dtype_hint=tf.float32)
      concentration = tensor_util.convert_nonref_to_tensor(
          concentration, name='concentration', dtype=dtype)
      scale = tensor_util.convert_nonref_to_tensor(
          scale, name='scale', dtype=dtype)
      # Positive scale and concentration is asserted by the incorporated
      # Weibull bijector.
      self._weibull_bijector = weibull_cdf_bijector.WeibullCDF(
          scale=scale, concentration=concentration, validate_args=validate_args)

      batch_shape = distribution_util.get_broadcast_shape(concentration, scale)
      super(Weibull, self).__init__(
          distribution=uniform.Uniform(
              # TODO(b/137665504): Use batch-adding meta-distribution to set the
              # batch shape instead of tf.ones.
              low=tf.zeros(batch_shape, dtype=dtype),
              high=tf.ones(batch_shape, dtype=dtype),
              allow_nan_stats=allow_nan_stats),
          # The Weibull bijector encodes the CDF function as the forward,
          # and hence needs to be inverted.
          bijector=invert_bijector.Invert(
              self._weibull_bijector, validate_args=validate_args),
          parameters=parameters,
          name=name)
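The constructor above backs the public `tfd.Weibull` distribution; a minimal usage sketch with illustrative parameter values:

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

weibull = tfd.Weibull(concentration=1.5, scale=2.0)  # illustrative values only
x = weibull.sample(5)
print(weibull.log_prob(x))
print(weibull.mean())
```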
Example #4
    def __init__(self,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Gumbel"):
        """Construct Gumbel distributions with location and scale `loc` and `scale`.

    The parameters `loc` and `scale` must be shaped in a way that supports
    broadcasting (e.g. `loc + scale` is a valid operation).

    Args:
      loc: Floating point tensor, the means of the distribution(s).
      scale: Floating point tensor, the scales of the distribution(s).
        scale must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
        Default value: `True`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: `'Gumbel'`.

    Raises:
      TypeError: if loc and scale are different dtypes.
    """
        with tf.name_scope(name, values=[loc, scale]) as name:
            dtype = dtype_util.common_dtype([loc, scale],
                                            preferred_dtype=tf.float32)
            loc = tf.convert_to_tensor(loc, name="loc", dtype=dtype)
            scale = tf.convert_to_tensor(scale, name="scale", dtype=dtype)
            with tf.control_dependencies(
                [tf.assert_positive(scale)] if validate_args else []):
                loc = tf.identity(loc, name="loc")
                scale = tf.identity(scale, name="scale")
                tf.assert_same_float_dtype([loc, scale])
                self._gumbel_bijector = gumbel_bijector.Gumbel(
                    loc=loc, scale=scale, validate_args=validate_args)

            super(Gumbel, self).__init__(
                distribution=uniform.Uniform(low=tf.zeros([], dtype=loc.dtype),
                                             high=tf.ones([], dtype=loc.dtype),
                                             allow_nan_stats=allow_nan_stats),
                # The Gumbel bijector encodes the quantile
                # function as the forward, and hence needs to
                # be inverted.
                bijector=invert_bijector.Invert(self._gumbel_bijector),
                batch_shape=distribution_util.get_broadcast_shape(loc, scale),
                name=name)
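A short usage sketch for the resulting `tfd.Gumbel` distribution (parameter values are illustrative; broadcasting `loc` against `scale` yields a batch of two):

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

gumbel = tfd.Gumbel(loc=[0., 1.], scale=2.)
x = gumbel.sample(3)        # shape [3, 2]
print(gumbel.log_prob(x))   # shape [3, 2]
print(gumbel.mean())
```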
Example #5
  def __init__(self,
               base_kernel,
               fixed_inputs,
               diag_shift=None,
               validate_args=False,
               name='SchurComplement'):
    """Construct a SchurComplement kernel instance.

    Args:
      base_kernel: A `PositiveSemidefiniteKernel` instance, the kernel used to
        build the block matrices of which this kernel computes the Schur
        complement.
      fixed_inputs: A Tensor, representing a collection of inputs. The Schur
        complement that this kernel computes comes from a block matrix, whose
        bottom-right corner is derived from `base_kernel.matrix(fixed_inputs,
        fixed_inputs)`, and whose top-right and bottom-left pieces are
        constructed by computing the base_kernel at pairs of input locations
        together with these `fixed_inputs`. `fixed_inputs` is allowed to be an
        empty collection (either `None` or having a zero shape entry), in which
        case the kernel falls back to the trivial application of `base_kernel`
        to inputs. See class-level docstring for more details on the exact
        computation this does; `fixed_inputs` correspond to the `Z` structure
        discussed there. `fixed_inputs` is assumed to have shape `[b1, ..., bB,
        N, f1, ..., fF]` where the `b`'s are batch shape entries, the `f`'s are
        feature_shape entries, and `N` is the number of fixed inputs. Use of
        this kernel entails a 1-time O(N^3) cost of computing the Cholesky
        decomposition of the k(Z, Z) matrix. The batch shape elements of
        `fixed_inputs` must be broadcast compatible with
        `base_kernel.batch_shape`.
      diag_shift: A floating point scalar to be added to the diagonal of the
        divisor_matrix before computing its Cholesky.
      validate_args: If `True`, parameters are checked for validity despite
        possibly degrading runtime performance.
        Default value: `False`
      name: Python `str` name prefixed to Ops created by this class.
        Default value: `"SchurComplement"`
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype(
          [base_kernel, fixed_inputs, diag_shift], tf.float32)
      self._base_kernel = base_kernel
      self._diag_shift = tensor_util.convert_nonref_to_tensor(
          diag_shift, dtype=dtype, name='diag_shift')
      self._fixed_inputs = tensor_util.convert_nonref_to_tensor(
          fixed_inputs, dtype=dtype, name='fixed_inputs')
      self._cholesky_bijector = invert.Invert(
          cholesky_outer_product.CholeskyOuterProduct())
      super(SchurComplement, self).__init__(
          base_kernel.feature_ndims,
          dtype=dtype,
          name=name,
          parameters=parameters)
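A minimal sketch of how a `SchurComplement` kernel is typically used, assuming the public `tfp.math.psd_kernels` namespace and purely illustrative data:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfpk = tfp.math.psd_kernels

# Condition an RBF kernel on three "fixed" observation locations.
base_kernel = tfpk.ExponentiatedQuadratic(amplitude=1., length_scale=0.5)
fixed_inputs = tf.constant([[0.0], [0.5], [1.0]])
schur = tfpk.SchurComplement(base_kernel, fixed_inputs=fixed_inputs,
                             diag_shift=1e-5)
x = tf.constant([[0.25], [0.75]])
print(schur.matrix(x, x))  # 2x2 matrix of posterior-style covariances
```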
Example #6
    def __init__(self,
                 shift,
                 scale,
                 tailweight,
                 validate_args=False,
                 name="lambertw_tail"):
        """Construct a location scale heavy-tail Lambert W bijector.

    The parameters `shift`, `scale`, and `tail` must be shaped in a way that
    supports broadcasting (e.g. `shift + scale + tail` is a valid operation).

    Args:
      shift: Floating point tensor; the shift for centering (uncentering) the
        input (output) random variable(s).
      scale: Floating point tensor; the scaling (unscaling) of the input
        (output) random variable(s). Must contain only positive values.
      tailweight: Floating point tensor; the tail behaviors of the output random
        variable(s).  Must contain only non-negative values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `shift` and `scale` and `tail` have different `dtype`.
    """
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([tailweight, shift, scale],
                                            tf.float32)
            self._tailweight = tensor_util.convert_nonref_to_tensor(
                tailweight, name="tailweight", dtype=dtype)
            self._shift = tensor_util.convert_nonref_to_tensor(shift,
                                                               name="shift",
                                                               dtype=dtype)
            self._scale = tensor_util.convert_nonref_to_tensor(scale,
                                                               name="scale",
                                                               dtype=dtype)
            dtype_util.assert_same_float_dtype(
                (self._tailweight, self._shift, self._scale))

            self._shift_and_scale = chain.Chain(
                [tfb_shift.Shift(self._shift),
                 tfb_scale.Scale(self._scale)])
            # The 'bijectors' argument of the tfb.Chain superclass is applied in
            # reverse(!) order, so the list ordering must be (3, 2, 1), not (1, 2, 3).
            super(LambertWTail, self).__init__(bijectors=[
                self._shift_and_scale,
                _HeavyTailOnly(tailweight=self._tailweight),
                invert.Invert(self._shift_and_scale)
            ],
                                               validate_args=validate_args)
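This constructor underlies the public `tfb.LambertWTail` bijector; a short usage sketch with illustrative parameters:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

lw = tfb.LambertWTail(shift=0., scale=1., tailweight=0.5)
x = tf.constant([-2., 0., 2.])
y = lw.forward(x)        # heavier-tailed version of x
print(y, lw.inverse(y))  # inverse recovers x up to numerical error
```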
Example #7
def _get_flat_unconstraining_bijector(jd_model):
    """Create a bijector from a joint distribution that flattens and unconstrains.

  The intention is (loosely) to go from a model joint distribution supported on

  U_1 x U_2 x ... U_n, with U_j a subset of R^{n_j}

  to a model supported on R^N, with N = sum(n_j). (This is "loose" in the sense
  of base measures: some distribution may be supported on an m-dimensional
  subset of R^n, and the default transform for that distribution may then
  have support on R^m; see [1] for details.)

  Args:
    jd_model: subclass of `tfd.JointDistribution` A JointDistribution for a
      model.

  Returns:
    Two `tfb.Bijector`s: the first's `.forward` method flattens and unconstrains
    points, and the second may be used to initialize a step size.
  """
    # TODO(b/180396233): This bijector is in general point-dependent.
    event_space_bij = jd_model.experimental_default_event_space_bijector()
    flat_bijector = restructure.pack_sequence_as(jd_model.event_shape_tensor())

    unconstrained_shapes = event_space_bij(
        flat_bijector).inverse_event_shape_tensor(
            jd_model.event_shape_tensor())

    # This reshaping is required because split can produce a tensor of shape [1]
    # when the distribution event shape is [].
    unsplit = joint_map.JointMap(
        tf.nest.map_structure(
            lambda x: reshape.Reshape(event_shape_out=x, event_shape_in=[-1]),
            unconstrained_shapes))

    bij = invert.Invert(chain.Chain([event_space_bij, flat_bijector, unsplit]))
    step_size_bij = invert.Invert(flat_bijector)

    return bij, step_size_bij
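A sketch of how this helper might be called on a small joint model, assuming the helper and its dependencies are in scope (the model below is purely illustrative):

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

# One constrained (positive) latent and one unconstrained latent.
jd = tfd.JointDistributionNamed(dict(
    scale=tfd.HalfNormal(1.),
    loc=lambda scale: tfd.Normal(0., scale)))

bij, step_size_bij = _get_flat_unconstraining_bijector(jd)
flat_parts = bij.forward(jd.sample())   # structure of flat, unconstrained tensors
structured = bij.inverse(flat_parts)    # back to the model's constrained space
```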
Example #8
  def __init__(self,
               concentration1=1.,
               concentration0=1.,
               validate_args=False,
               allow_nan_stats=True,
               name='Kumaraswamy'):
    """Initialize a batch of Kumaraswamy distributions.

    Args:
      concentration1: Positive floating-point `Tensor` indicating mean
        number of successes; aka 'alpha'. Implies `self.dtype` and
        `self.batch_shape`, i.e.,
        `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`.
      concentration0: Positive floating-point `Tensor` indicating mean
        number of failures; aka 'beta'. Otherwise has same semantics as
        `concentration1`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value '`NaN`' to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype([concentration1, concentration0],
                                      dtype_hint=tf.float32)
      concentration1 = tensor_util.convert_nonref_to_tensor(
          concentration1, name='concentration1', dtype=dtype)
      concentration0 = tensor_util.convert_nonref_to_tensor(
          concentration0, name='concentration0', dtype=dtype)
      self._kumaraswamy_cdf = kumaraswamy_cdf.KumaraswamyCDF(
          concentration1=concentration1,
          concentration0=concentration0,
          validate_args=validate_args)
      batch_shape = distribution_util.get_broadcast_shape(
          concentration1, concentration0)
      super(Kumaraswamy, self).__init__(
          # TODO(b/137665504): Use batch-adding meta-distribution to set the
          # batch shape instead of tf.zeros.
          distribution=uniform.Uniform(
              low=tf.zeros(batch_shape, dtype=dtype),
              high=tf.ones([], dtype=dtype),
              allow_nan_stats=allow_nan_stats),
          bijector=invert.Invert(
              self._kumaraswamy_cdf, validate_args=validate_args),
          parameters=parameters,
          name=name)
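A minimal usage sketch for the resulting `tfd.Kumaraswamy` distribution (concentrations chosen only for illustration):

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

k = tfd.Kumaraswamy(concentration1=2., concentration0=5.)
x = k.sample(4)          # values in (0, 1)
print(k.log_prob(x))
print(k.mean())
```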
Example #9
  def __init__(self, nchan, dtype=tf.float32, validate_args=False, name=None):
    parameters = dict(locals())

    self._initialized = tf.Variable(False, trainable=False)
    self._m = tf.Variable(tf.zeros(nchan, dtype))
    self._s = TransformedVariable(tf.ones(nchan, dtype), exp.Exp())
    self._bijector = invert.Invert(
        chain.Chain([
            scale.Scale(self._s),
            shift.Shift(self._m),
        ]))
    super(ActivationNormalization, self).__init__(
        validate_args=validate_args,
        forward_min_event_ndims=1,
        parameters=parameters,
        name=name or 'ActivationNormalization')
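The `TransformedVariable` used for `self._s` above keeps the scale positive by storing it in log space; a tiny sketch of that mechanism, assuming `tfp.util.TransformedVariable`'s public behavior:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# Stored unconstrained (log space) but always read through Exp, so the value
# can never become non-positive during optimization.
s = tfp.util.TransformedVariable(tf.ones(3), bijector=tfb.Exp())
print(tf.convert_to_tensor(s))  # [1., 1., 1.]
```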
Example #10
def joint_prior_on_parameters_and_state(parameter_prior,
                                        parameterized_initial_state_prior_fn,
                                        parameter_constraining_bijector,
                                        prior_is_constrained=True):
    """Constructs a joint dist. from p(parameters) and p(state | parameters)."""
    if prior_is_constrained:
        parameter_prior = transformed_distribution.TransformedDistribution(
            parameter_prior,
            invert.Invert(parameter_constraining_bijector),
            name='unconstrained_parameter_prior')

    return joint_distribution_named.JointDistributionNamed(
        ParametersAndState(
            unconstrained_parameters=parameter_prior,
            state=lambda unconstrained_parameters: (  # pylint: disable=g-long-lambda
                parameterized_initial_state_prior_fn(
                    parameter_constraining_bijector.forward(
                        unconstrained_parameters)))))
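The core move above is pushing a constrained prior through the inverse of its constraining bijector; a self-contained sketch of that idea with a hypothetical positive-valued parameter:

```python
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

constraining_bijector = tfb.Softplus()        # maps reals -> positives
constrained_prior = tfd.LogNormal(0., 1.)     # prior on the positive axis
unconstrained_prior = tfd.TransformedDistribution(
    constrained_prior, tfb.Invert(constraining_bijector))
u = unconstrained_prior.sample(3)             # unconstrained real values
print(constraining_bijector.forward(u))       # mapped back to the positive axis
```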
Example #11
def _get_flat_unconstraining_bijector(jd_model):
    """Create a bijector from a joint distribution that flattens and unconstrains.

  The intention is (loosely) to go from a model joint distribution supported on

  U_1 x U_2 x ... U_n, with U_j a subset of R^{n_j}

  to a model supported on R^N, with N = sum(n_j). (This is "loose" in the sense
  of base measures: some distribution may be supported on an m-dimensional
  subset of R^n, and the default transform for that distribution may then
  have support on R^m; see [1] for details.)

  Args:
    jd_model: subclass of `tfd.JointDistribution` A JointDistribution for a
      model.

  Returns:
    A `tfb.Bijector` where the `.forward` method flattens and unconstrains
    points.
  """
    # TODO(b/180396233): This bijector is in general point-dependent.
    to_chain = [jd_model.experimental_default_event_space_bijector()]
    flat_bijector = restructure.pack_sequence_as(jd_model.event_shape_tensor())
    to_chain.append(flat_bijector)

    unconstrained_shapes = flat_bijector.inverse_event_shape_tensor(
        jd_model.event_shape_tensor())

    # This reshaping is required because split can produce a tensor of shape [1]
    # when the distribution event shape is [].
    reshapers = [
        reshape.Reshape(event_shape_out=x, event_shape_in=[-1])
        for x in unconstrained_shapes
    ]
    to_chain.append(joint_map.JointMap(bijectors=reshapers))

    size_splits = [ps.reduce_prod(x) for x in unconstrained_shapes]
    to_chain.append(split.Split(num_or_size_splits=size_splits))

    return invert.Invert(chain.Chain(to_chain))
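Unlike the variant in Example #7, the `Split` step here concatenates everything into a single vector; a short sketch, again assuming the helper is in scope:

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

jd = tfd.JointDistributionNamed(dict(
    scale=tfd.HalfNormal(1.),
    loc=lambda scale: tfd.Normal(0., scale)))

bij = _get_flat_unconstraining_bijector(jd)
flat = bij.forward(jd.sample())   # a single rank-1 unconstrained Tensor
print(bij.inverse(flat))          # back to {'scale': ..., 'loc': ...}
```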
Example #12
    def __init__(self,
                 df,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Chi"):
        """Construct Chi distributions with parameter `df`.

    Args:
      df: Floating point tensor, the degrees of freedom of the
        distribution(s). `df` must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value `NaN` to indicate the result
        is undefined. When `False`, an exception is raised if one or more of the
        statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: `'Chi'`.
    """
        parameters = dict(locals())
        with tf.compat.v2.name_scope(name) as name:
            df = tf.convert_to_tensor(value=df,
                                      name="df",
                                      dtype=dtype_util.common_dtype(
                                          [df], preferred_dtype=tf.float32))
            validation_assertions = ([assert_util.assert_positive(df)]
                                     if validate_args else [])
            with tf.control_dependencies(validation_assertions):
                self._df = tf.identity(df, name="df")

            super(Chi, self).__init__(
                distribution=chi2.Chi2(df=self._df,
                                       validate_args=validate_args,
                                       allow_nan_stats=allow_nan_stats),
                bijector=invert_bijector.Invert(square_bijector.Square()),
                parameters=parameters,
                name=name)
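A minimal usage sketch for the resulting `tfd.Chi` distribution:

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

# Chi with 3 degrees of freedom: the norm of a 3-d standard normal vector.
chi = tfd.Chi(df=3.)
x = chi.sample(5)
print(chi.log_prob(x))
print(chi.mean())
```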
Example #13
def tree_flatten(example, name='restructure'):
    """Returns a Bijector variant of tf.nest.flatten.

  To make it a Bijector, it has to know how to "unflatten" as
  well---unlike the real `tf.nest.flatten`, this can only flatten or
  unflatten a specific structure.  The `example` argument defines the
  structure.

  See also the `Restructure` bijector for general rearrangements.

  Args:
    example: A Tensor or (potentially nested) collection of Tensors.
    name: An optional Python string, inserted into names of TF ops
      created by this bijector.

  Returns:
    flatten: A Bijector whose `forward` method flattens structures
      parallel to `example` into a list of Tensors, and whose
      `inverse` method packs a list of Tensors of the right length
      into a structure parallel to `example`.

  #### Example

  ```python
  x = tf.constant(1)
  example = collections.OrderedDict([
      ('a', [x, x, x]),
      ('b', x)])
  bij = tfb.tree_flatten(example)
  ys = collections.OrderedDict([
      ('a', [1, 2, 3]),
      ('b', 4.)])
  bij.forward(ys)
  # Returns [1, 2, 3, 4.]
  ```

  """
    return invert.Invert(pack_sequence_as(example, name))
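A complementary sketch showing the `inverse` direction (packing a flat list back into the example structure), assuming `tree_flatten` is exposed as `tfb.tree_flatten`:

```python
import collections
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

example = collections.OrderedDict([('a', [tf.constant(1)] * 3),
                                   ('b', tf.constant(1))])
bij = tfb.tree_flatten(example)
print(bij.inverse([1, 2, 3, 4.]))  # OrderedDict([('a', [1, 2, 3]), ('b', 4.)])
```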
Example #14
    def __init__(self,
                 skewness,
                 tailweight,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name=None):
        """Construct Johnson's SU distributions.

    The distributions have shape parameters `tailweight` and `skewness`,
    mean `loc`, and scale `scale`.

    The parameters `tailweight`, `skewness`, `loc`, and `scale` must be shaped
    in a way that supports broadcasting
    (e.g. `skewness + tailweight + loc + scale` is a valid operation).

    Args:
      skewness: Floating-point `Tensor`. Skewness of the distribution(s).
      tailweight: Floating-point `Tensor`. Tail weight of the
        distribution(s). `tailweight` must contain only positive values.
      loc: Floating-point `Tensor`. The mean(s) of the distribution(s).
      scale: Floating-point `Tensor`. The scaling factor(s) for the
        distribution(s). Note that `scale` is not technically the standard
        deviation of this distribution but has semantics more similar to
        standard deviation than variance.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value '`NaN`' to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if any of skewness, tailweight, loc and scale are different
        dtypes.
    """
        parameters = dict(locals())
        with tf.name_scope(name or 'JohnsonSU') as name:
            dtype = dtype_util.common_dtype([skewness, tailweight, loc, scale],
                                            tf.float32)
            self._skewness = tensor_util.convert_nonref_to_tensor(
                skewness, name='skewness', dtype=dtype)
            self._tailweight = tensor_util.convert_nonref_to_tensor(
                tailweight, name='tailweight', dtype=dtype)
            self._loc = tensor_util.convert_nonref_to_tensor(loc,
                                                             name='loc',
                                                             dtype=dtype)
            self._scale = tensor_util.convert_nonref_to_tensor(scale,
                                                               name='scale',
                                                               dtype=dtype)

            norm_shift = invert_bijector.Invert(
                shift_bijector.Shift(shift=self._skewness,
                                     validate_args=validate_args))

            norm_scale = invert_bijector.Invert(
                scale_bijector.Scale(scale=self._tailweight,
                                     validate_args=validate_args))

            sinh = sinh_bijector.Sinh(validate_args=validate_args)

            scale = scale_bijector.Scale(scale=self._scale,
                                         validate_args=validate_args)

            shift = shift_bijector.Shift(shift=self._loc,
                                         validate_args=validate_args)

            bijector = shift(scale(sinh(norm_scale(norm_shift))))

            batch_rank = ps.reduce_max([
                distribution_util.prefer_static_rank(x)
                for x in (self._skewness, self._tailweight, self._loc,
                          self._scale)
            ])

            super(JohnsonSU, self).__init__(
                # TODO(b/160730249): Make `loc` a scalar `0.` and remove overridden
                # `batch_shape` and `batch_shape_tensor` when
                # TransformedDistribution's bijector can modify its `batch_shape`.
                distribution=normal.Normal(loc=tf.zeros(ps.ones(
                    batch_rank, tf.int32),
                                                        dtype=dtype),
                                           scale=tf.ones([], dtype=dtype),
                                           validate_args=validate_args,
                                           allow_nan_stats=allow_nan_stats),
                bijector=bijector,
                validate_args=validate_args,
                parameters=parameters,
                name=name)
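A short usage sketch for the resulting `tfd.JohnsonSU` distribution (parameters are illustrative):

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

jsu = tfd.JohnsonSU(skewness=-2., tailweight=2., loc=1., scale=1.5)
x = jsu.sample(4)
print(jsu.log_prob(x))
print(jsu.mean())
```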
Example #15
  def __init__(self,
               loc,
               scale,
               concentration,
               validate_args=False,
               allow_nan_stats=True,
               name='GeneralizedExtremeValue'):
    """Construct generalized extreme value distribution.

    The parameters `loc`, `scale`, and `concentration` must be shaped in a way
    that supports broadcasting (e.g. `loc + scale` + `concentration` is valid).

    Args:
      loc: Floating point tensor, the location parameter of the distribution(s).
      scale: Floating point tensor, the scales of the distribution(s).
        scale must contain only positive values.
      concentration: Floating point tensor, the concentration of
        the distribution(s).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value `NaN` to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
        Default value: `True`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: `'GeneralizedExtremeValue'`.

    Raises:
      TypeError: if loc and scale are different dtypes.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype([loc, scale, concentration],
                                      dtype_hint=tf.float32)
      loc = tensor_util.convert_nonref_to_tensor(
          loc, name='loc', dtype=dtype)
      scale = tensor_util.convert_nonref_to_tensor(
          scale, name='scale', dtype=dtype)
      concentration = tensor_util.convert_nonref_to_tensor(
          concentration, name='concentration', dtype=dtype)
      dtype_util.assert_same_float_dtype([loc, scale, concentration])
      # Positive scale is asserted by the incorporated GEV bijector.
      self._gev_bijector = gev_cdf_bijector.GeneralizedExtremeValueCDF(
          loc=loc, scale=scale, concentration=concentration,
          validate_args=validate_args)

      batch_shape = distribution_util.get_broadcast_shape(loc, scale,
                                                          concentration)
      # Because the uniform sampler generates samples in `[0, 1)`, samples would
      # lie in `[-inf, inf)` instead of `(-inf, inf)`. To fix this, we use
      # `np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny` as the lower bound,
      # because it is the smallest positive 'normal' number.
      super(GeneralizedExtremeValue, self).__init__(
          # TODO(b/137665504): Use batch-adding meta-distribution to set the
          # batch shape instead of tf.ones.
          distribution=uniform.Uniform(
              low=np.finfo(dtype_util.as_numpy_dtype(dtype)).tiny,
              high=tf.ones(batch_shape, dtype=dtype),
              allow_nan_stats=allow_nan_stats),
          # The GEV bijector encodes the CDF function as the forward,
          # and hence needs to be inverted.
          bijector=invert_bijector.Invert(
              self._gev_bijector, validate_args=validate_args),
          parameters=parameters,
          name=name)
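A minimal usage sketch for the resulting `tfd.GeneralizedExtremeValue` distribution (a small positive `concentration` gives a Frechet-type tail):

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

gev = tfd.GeneralizedExtremeValue(loc=0., scale=1., concentration=0.1)
x = gev.sample(4)
print(gev.log_prob(x))
print(gev.mean())
```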
Example #16
    def __init__(self,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='Moyal'):
        """Construct Moyal distributions with location and scale `loc` and `scale`.

    The parameters `loc` and `scale` must be shaped in a way that supports
    broadcasting (e.g. `loc + scale` is a valid operation).

    Args:
      loc: Floating point tensor, the means of the distribution(s).
      scale: Floating point tensor, the scales of the distribution(s).
        scale must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value `NaN` to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
        Default value: `True`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: `'Moyal'`.

    Raises:
      TypeError: if loc and scale are different dtypes.


    #### References

    [1] J.E. Moyal, "XXX. Theory of ionization fluctuations",
       The London, Edinburgh, and Dublin Philosophical Magazine
       and Journal of Science.
       https://www.tandfonline.com/doi/abs/10.1080/14786440308521076
    [2] G. Cordeiro, J. Nobre, R. Pescim, E. Ortega,
        "The beta Moyal: a useful skew distribution",
        https://www.arpapress.com/Volumes/Vol10Issue2/IJRRAS_10_2_02.pdf
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale],
                                            dtype_hint=tf.float32)
            loc = tensor_util.convert_nonref_to_tensor(loc,
                                                       name='loc',
                                                       dtype=dtype)
            scale = tensor_util.convert_nonref_to_tensor(scale,
                                                         name='scale',
                                                         dtype=dtype)
            dtype_util.assert_same_float_dtype([loc, scale])
            # Positive scale is asserted by the incorporated Moyal bijector.
            self._moyal_bijector = moyal_cdf_bijector.MoyalCDF(
                loc=loc, scale=scale, validate_args=validate_args)

            # Because the uniform sampler generates samples in `[0, 1)`, samples
            # would lie in `[-inf, inf)` instead of `(-inf, inf)`. To fix this, we
            # use `np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny` as the
            # lower bound, because it is the smallest positive 'normal' number.
            batch_shape = distribution_util.get_broadcast_shape(loc, scale)
            super(Moyal, self).__init__(
                # TODO(b/137665504): Use batch-adding meta-distribution to set the
                # batch shape instead of tf.ones.
                distribution=uniform.Uniform(low=np.finfo(
                    dtype_util.as_numpy_dtype(dtype)).tiny,
                                             high=tf.ones(batch_shape,
                                                          dtype=dtype),
                                             allow_nan_stats=allow_nan_stats),
                # The Moyal bijector encodes the CDF function as the forward,
                # and hence needs to be inverted.
                bijector=invert_bijector.Invert(self._moyal_bijector,
                                                validate_args=validate_args),
                parameters=parameters,
                name=name)
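A minimal usage sketch for the resulting `tfd.Moyal` distribution:

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

moyal = tfd.Moyal(loc=0., scale=1.)
x = moyal.sample(4)
print(moyal.log_prob(x))
print(moyal.mean())
```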
Example #17
def _affine_surrogate_posterior_from_base_distribution(
        base_distribution,
        operators='diag',
        bijector=None,
        initial_unconstrained_loc_fn=_sample_uniform_initial_loc,
        validate_args=False,
        name=None):
    """Builds a variational posterior by linearly transforming base distributions.

  This function builds a surrogate posterior by applying a trainable
  transformation to a base distribution (typically a `tfd.JointDistribution`) or
  nested structure of base distributions, and constraining the samples with
  `bijector`. Note that the distributions must have event shapes corresponding
  to the *pretransformed* surrogate posterior -- that is, if `bijector` contains
  a shape-changing bijector, then the corresponding base distribution event
  shape is the inverse event shape of the bijector applied to the desired
  surrogate posterior shape. The surrogate posterior is constructed as follows:

  1. Flatten the base distribution event shapes to vectors, and pack the base
     distributions into a `tfd.JointDistribution`.
  2. Apply a trainable blockwise LinearOperator bijector to the joint base
     distribution.
  3. Apply the constraining bijectors and return the resulting trainable
     `tfd.TransformedDistribution` instance.

  Args:
    base_distribution: `tfd.Distribution` instance (typically a
      `tfd.JointDistribution`), or a nested structure of `tfd.Distribution`
      instances.
    operators: Either a string or a list/tuple containing `LinearOperator`
      subclasses, `LinearOperator` instances, or callables returning
      `LinearOperator` instances. Supported string values are "diag" (to create
      a mean-field surrogate posterior) and "tril" (to create a full-covariance
      surrogate posterior). A list/tuple may be passed to induce other
      posterior covariance structures. If the list is flat, a
      `tf.linalg.LinearOperatorBlockDiag` instance will be created and applied
      to the base distribution. Otherwise the list must be singly-nested and
      have a first element of length 1, second element of length 2, etc.; the
      elements of the outer list are interpreted as rows of a lower-triangular
      block structure, and a `tf.linalg.LinearOperatorBlockLowerTriangular`
      instance is created. For complete documentation and examples, see
      `tfp.experimental.vi.util.build_trainable_linear_operator_block`, which
      receives the `operators` arg if it is list-like.
      Default value: `"diag"`.
    bijector: `tfb.Bijector` instance, or nested structure of `tfb.Bijector`
      instances, that maps (nested) values in R^n to the support of the
      posterior. (This can be the `experimental_default_event_space_bijector` of
      the distribution over the prior latent variables.)
      Default value: `None` (i.e., the posterior is over R^n).
    initial_unconstrained_loc_fn: Optional Python `callable` with signature
      `initial_loc = initial_unconstrained_loc_fn(shape, dtype, seed)` used to
      sample real-valued initializations for the unconstrained location of
      each variable.
      Default value: `functools.partial(tf.random.stateless_uniform,
        minval=-2., maxval=2., dtype=tf.float32)`.
    validate_args: Python `bool`. Whether to validate input with asserts. This
      imposes a runtime cost. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: `None` (i.e.,
      'build_affine_surrogate_posterior_from_base_distribution').
  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.
  Raises:
    NotImplementedError: Base distributions with mixed dtypes are not supported.

  #### Examples
  ```python
  tfd = tfp.distributions
  tfb = tfp.bijectors

  # Fit a multivariate Normal surrogate posterior on the Eight Schools model
  # [1].

  treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
  treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]

  def model_fn():
    avg_effect = yield tfd.Normal(loc=0., scale=10., name='avg_effect')
    log_stddev = yield tfd.Normal(loc=5., scale=1., name='log_stddev')
    school_effects = yield tfd.Sample(
        tfd.Normal(loc=avg_effect, scale=tf.exp(log_stddev)),
        sample_shape=[8],
        name='school_effects')
    treatment_effects = yield tfd.Independent(
        tfd.Normal(loc=school_effects, scale=treatment_stddevs),
        reinterpreted_batch_ndims=1,
        name='treatment_effects')
  model = tfd.JointDistributionCoroutineAutoBatched(model_fn)

  # Pin the observed values in the model.
  target_model = model.experimental_pin(treatment_effects=treatment_effects)

  # Define a lower triangular structure of `LinearOperator` subclasses that
  # models full covariance among latent variables except for the 8 dimensions
  # of `school_effect`, which are modeled as independent (using
  # `LinearOperatorDiag`).
  operators = [
    [tf.linalg.LinearOperatorLowerTriangular],
    [tf.linalg.LinearOperatorFullMatrix, tf.linalg.LinearOperatorLowerTriangular],
    [tf.linalg.LinearOperatorFullMatrix, tf.linalg.LinearOperatorFullMatrix,
     tf.linalg.LinearOperatorDiag]]


  # Constrain the posterior values to the support of the prior.
  bijector = target_model.experimental_default_event_space_bijector()

  # Build a full-covariance surrogate posterior.
  surrogate_posterior = (
    tfp.experimental.vi.build_affine_surrogate_posterior_from_base_distribution(
        base_distribution=base_distribution,
        operators=operators,
        bijector=bijector))

  # Fit the model.
  losses = tfp.vi.fit_surrogate_posterior(
      target_model.unnormalized_log_prob,
      surrogate_posterior,
      num_steps=100,
      optimizer=tf.optimizers.Adam(0.1),
      sample_size=10)
  ```

  #### References

  [1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and
      Donald Rubin. Bayesian Data Analysis, Third Edition.
      Chapman and Hall/CRC, 2013.

  """
    with tf.name_scope(name
                       or 'affine_surrogate_posterior_from_base_distribution'):

        if nest.is_nested(base_distribution):
            base_distribution = (joint_distribution_util.
                                 independent_joint_distribution_from_structure(
                                     base_distribution,
                                     validate_args=validate_args))

        if nest.is_nested(bijector):
            bijector = joint_map.JointMap(nest.map_structure(
                lambda b: identity.Identity() if b is None else b, bijector),
                                          validate_args=validate_args)

        batch_shape = base_distribution.batch_shape_tensor()
        if tf.nest.is_nested(
                batch_shape):  # Base is a classic JointDistribution.
            batch_shape = functools.reduce(ps.broadcast_shape,
                                           tf.nest.flatten(batch_shape))
        event_shape = base_distribution.event_shape_tensor()
        flat_event_size = nest.flatten(
            nest.map_structure(ps.reduce_prod, event_shape))

        base_dtypes = set([
            dtype_util.base_dtype(d)
            for d in nest.flatten(base_distribution.dtype)
        ])
        if len(base_dtypes) > 1:
            raise NotImplementedError(
                'Base distributions with mixed dtype are not supported. Saw '
                'components of dtype {}'.format(base_dtypes))
        base_dtype = list(base_dtypes)[0]

        num_components = len(flat_event_size)
        if operators == 'diag':
            operators = [tf.linalg.LinearOperatorDiag] * num_components
        elif operators == 'tril':
            operators = [[tf.linalg.LinearOperatorFullMatrix] * i +
                         [tf.linalg.LinearOperatorLowerTriangular]
                         for i in range(num_components)]
        elif isinstance(operators, str):
            raise ValueError(
                'Unrecognized operator type {}. Valid operators are "diag", "tril", '
                'or a structure that can be passed to '
                '`tfp.experimental.vi.util.build_trainable_linear_operator_block` as '
                'the `operators` arg.'.format(operators))

        if nest.is_nested(operators):
            operators = yield from trainable_linear_operators._trainable_linear_operator_block(  # pylint: disable=protected-access
                operators,
                block_dims=flat_event_size,
                dtype=base_dtype,
                batch_shape=batch_shape)

        linop_bijector = (
            scale_matvec_linear_operator.ScaleMatvecLinearOperatorBlock(
                scale=operators, validate_args=validate_args))

        def generate_shift_bijector(s):
            x = yield trainable_state_util.Parameter(
                functools.partial(initial_unconstrained_loc_fn,
                                  ps.concat([batch_shape, [s]], axis=0),
                                  dtype=base_dtype))
            return shift.Shift(x)

        loc_bijectors = yield from nest_util.map_structure_coroutine(
            generate_shift_bijector, flat_event_size)
        loc_bijector = joint_map.JointMap(loc_bijectors,
                                          validate_args=validate_args)

        unflatten_and_reshape = chain.Chain([
            joint_map.JointMap(nest.map_structure(reshape.Reshape,
                                                  event_shape),
                               validate_args=validate_args),
            restructure.Restructure(
                nest.pack_sequence_as(event_shape, range(num_components)))
        ],
                                            validate_args=validate_args)

        bijectors = [] if bijector is None else [bijector]
        bijectors.extend([
            unflatten_and_reshape,
            loc_bijector,  # Allow the mean of the standard dist to shift from 0.
            linop_bijector
        ])  # Apply LinOp to scale the standard dist.
        bijector = chain.Chain(bijectors, validate_args=validate_args)

        flat_base_distribution = invert.Invert(unflatten_and_reshape)(
            base_distribution)

        return transformed_distribution.TransformedDistribution(
            flat_base_distribution,
            bijector=bijector,
            validate_args=validate_args)
Example #18
    def __init__(self,
                 base_kernel,
                 fixed_inputs,
                 fixed_inputs_mask=None,
                 diag_shift=None,
                 cholesky_fn=None,
                 validate_args=False,
                 name='SchurComplement',
                 _precomputed_divisor_matrix_cholesky=None):
        """Construct a SchurComplement kernel instance.

    Args:
      base_kernel: A `PositiveSemidefiniteKernel` instance, the kernel used to
        build the block matrices of which this kernel computes the Schur
        complement.
      fixed_inputs: A Tensor, representing a collection of inputs. The Schur
        complement that this kernel computes comes from a block matrix, whose
        bottom-right corner is derived from `base_kernel.matrix(fixed_inputs,
        fixed_inputs)`, and whose top-right and bottom-left pieces are
        constructed by computing the base_kernel at pairs of input locations
        together with these `fixed_inputs`. `fixed_inputs` is allowed to be an
        empty collection (either `None` or having a zero shape entry), in which
        case the kernel falls back to the trivial application of `base_kernel`
        to inputs. See class-level docstring for more details on the exact
        computation this does; `fixed_inputs` correspond to the `Z` structure
        discussed there. `fixed_inputs` is assumed to have shape `[b1, ..., bB,
        N, f1, ..., fF]` where the `b`'s are batch shape entries, the `f`'s are
        feature_shape entries, and `N` is the number of fixed inputs. Use of
        this kernel entails a 1-time O(N^3) cost of computing the Cholesky
        decomposition of the k(Z, Z) matrix. The batch shape elements of
        `fixed_inputs` must be broadcast compatible with
        `base_kernel.batch_shape`.
      fixed_inputs_mask: A boolean Tensor of shape `[..., N]`.  When `mask` is
        not None and an element of `mask` is `False`, this kernel will return
        values computed as if the divisor matrix did not contain the
        corresponding row or column.
      diag_shift: A floating point scalar to be added to the diagonal of the
        divisor_matrix before computing its Cholesky.
      cholesky_fn: Callable which takes a single (batch) matrix argument and
        returns a Cholesky-like lower triangular factor.  Default value: `None`,
        in which case `make_cholesky_with_jitter_fn` is used with the `jitter`
        parameter.
      validate_args: If `True`, parameters are checked for validity despite
        possibly degrading runtime performance.
        Default value: `False`
      name: Python `str` name prefixed to Ops created by this class.
        Default value: `"SchurComplement"`
      _precomputed_divisor_matrix_cholesky: Internal parameter -- do not use.
    """
        parameters = dict(locals())

        # Delayed import to avoid circular dependency between `tfp.bijectors` and
        # `tfp.math`
        # pylint: disable=g-import-not-at-top
        from tensorflow_probability.python.bijectors import cholesky_outer_product
        from tensorflow_probability.python.bijectors import invert
        # pylint: enable=g-import-not-at-top
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([
                base_kernel, fixed_inputs, diag_shift,
                _precomputed_divisor_matrix_cholesky
            ], tf.float32)
            self._base_kernel = base_kernel
            self._diag_shift = tensor_util.convert_nonref_to_tensor(
                diag_shift, dtype=dtype, name='diag_shift')
            self._fixed_inputs = tensor_util.convert_nonref_to_tensor(
                fixed_inputs, dtype=dtype, name='fixed_inputs')
            self._fixed_inputs_mask = tensor_util.convert_nonref_to_tensor(
                fixed_inputs_mask, dtype=tf.bool, name='fixed_inputs_mask')
            self._cholesky_bijector = invert.Invert(
                cholesky_outer_product.CholeskyOuterProduct())
            self._precomputed_divisor_matrix_cholesky = _precomputed_divisor_matrix_cholesky
            if self._precomputed_divisor_matrix_cholesky is not None:
                self._precomputed_divisor_matrix_cholesky = tf.convert_to_tensor(
                    self._precomputed_divisor_matrix_cholesky, dtype)
            if cholesky_fn is None:
                from tensorflow_probability.python.distributions import cholesky_util  # pylint:disable=g-import-not-at-top
                cholesky_fn = cholesky_util.make_cholesky_with_jitter_fn()
            self._cholesky_fn = cholesky_fn
            self._cholesky_bijector = invert.Invert(
                cholesky_outer_product.CholeskyOuterProduct(
                    cholesky_fn=cholesky_fn))

            super(SchurComplement, self).__init__(base_kernel.feature_ndims,
                                                  dtype=dtype,
                                                  name=name,
                                                  parameters=parameters)
Example #19
  def __init__(self,
               base_kernel,
               fixed_inputs,
               diag_shift=None,
               validate_args=False,
               name='SchurComplement'):
    """Construct a SchurComplement kernel instance.

    Args:
      base_kernel: A `PositiveSemidefiniteKernel` instance, the kernel used to
        build the block matrices of which this kernel computes the  Schur
        complement.
      fixed_inputs: A Tensor, representing a collection of inputs. The Schur
        complement that this kernel computes comes from a block matrix, whose
        bottom-right corner is derived from `base_kernel.matrix(fixed_inputs,
        fixed_inputs)`, and whose top-right and bottom-left pieces are
        constructed by computing the base_kernel at pairs of input locations
        together with these `fixed_inputs`. `fixed_inputs` is allowed to be an
        empty collection (either `None` or having a zero shape entry), in which
        case the kernel falls back to the trivial application of `base_kernel`
        to inputs. See class-level docstring for more details on the exact
        computation this does; `fixed_inputs` correspond to the `Z` structure
        discussed there. `fixed_inputs` is assumed to have shape `[b1, ..., bB,
        N, f1, ..., fF]` where the `b`'s are batch shape entries, the `f`'s are
        feature_shape entries, and `N` is the number of fixed inputs. Use of
        this kernel entails a 1-time O(N^3) cost of computing the Cholesky
        decomposition of the k(Z, Z) matrix. The batch shape elements of
        `fixed_inputs` must be broadcast compatible with
        `base_kernel.batch_shape`.
      diag_shift: A floating point scalar to be added to the diagonal of the
        divisor_matrix before computing its Cholesky.
      validate_args: If `True`, parameters are checked for validity despite
        possibly degrading runtime performance.
        Default value: `False`
      name: Python `str` name prefixed to Ops created by this class.
        Default value: `"SchurComplement"`
    """
    with tf.compat.v1.name_scope(
        name, values=[base_kernel, fixed_inputs]) as name:
      # If the base_kernel doesn't have a specified dtype, we can't pass it off
      # to common_dtype, which always expects `tf.as_dtype(dtype)` to work (and
      # it doesn't if the given `dtype` is None).
      # TODO(b/130421035): Consider changing common_dtype to allow Nones, and
      # clean this up after.
      #
      # Thus, we spell out the logic
      # here: use the dtype of `fixed_inputs` if possible. If base_kernel.dtype
      # is not None, use the usual logic.
      if base_kernel.dtype is None:
        dtype = None if fixed_inputs is None else fixed_inputs.dtype
      else:
        dtype = dtype_util.common_dtype([base_kernel, fixed_inputs], tf.float32)
      self._base_kernel = base_kernel
      self._fixed_inputs = (None if fixed_inputs is None else
                            tf.convert_to_tensor(value=fixed_inputs,
                                                 dtype=dtype))
      if not self._is_empty_fixed_inputs():
        # We create and store this matrix here, so that we get the caching
        # benefit when we later access its cholesky. If we computed the matrix
        # every time we needed the cholesky, the bijector cache wouldn't be hit.
        self._divisor_matrix = base_kernel.matrix(fixed_inputs, fixed_inputs)
        if diag_shift is not None:
          self._divisor_matrix = _add_diagonal_shift(
              self._divisor_matrix, diag_shift)

      self._cholesky_bijector = invert.Invert(
          cholesky_outer_product.CholeskyOuterProduct())
    super(SchurComplement, self).__init__(
        base_kernel.feature_ndims, dtype=dtype, name=name)
Example #20
  def __init__(self,
               loc=None,
               precision_factor=None,
               precision=None,
               validate_args=False,
               allow_nan_stats=True,
               name='MultivariateNormalPrecisionFactorLinearOperator'):
    """Initialize distribution.

    Precision is the inverse of the covariance matrix, and
    `precision_factor @ precision_factor.T = precision`.

    The `batch_shape` of this distribution is the broadcast of
    `loc.shape[:-1]` and `precision_factor.batch_shape`.

    The `event_shape` of this distribution is determined by `loc.shape[-1:]`,
    OR `precision_factor.shape[-1:]`, which must match.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      precision_factor: Required nonsingular `tf.linalg.LinearOperator` instance
        with same `dtype` and shape compatible with `loc`.
      precision: Optional square `tf.linalg.LinearOperator` instance with same
        `dtype` and shape compatible with `loc` and `precision_factor`.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      if precision_factor is None:
        raise ValueError(
            'Argument `precision_factor` must be provided. Found `None`')

      dtype = dtype_util.common_dtype([loc, precision_factor, precision],
                                      dtype_hint=tf.float32)
      loc = tensor_util.convert_nonref_to_tensor(loc, dtype=dtype, name='loc')

      self._loc = loc
      self._precision_factor = precision_factor
      self._precision = precision

      batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
          loc, precision_factor)

      # Proof of factors (used throughout code):
      # Let,
      #   C = covariance,
      #   P = inv(covariance) = precision
      #   P = F @ F.T  (so F is the `precision_factor`).
      #
      # Then, the log prob term is
      #  x.T @ inv(C) @ x
      #  = x.T @ P @ x
      #  = x.T @ F @ F.T @ x
      #  = || F.T @ x ||**2
      # notice it involves F.T, which is why we set adjoint=True in various
      # places.
      #
      # Also, if w ~ Normal(0, I), then we can sample by setting
      #  x = inv(F.T) @ w + loc,
      # since then
      #  E[(x - loc) @ (x - loc).T]
      #  = E[inv(F.T) @ w @ w.T @ inv(F)]
      #  = inv(F.T) @ inv(F)
      #  = inv(F @ F.T)
      #  = inv(P)
      #  = C.
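      #
      # A 1-D sanity check of the identities above (an illustration added
      # here, not part of the original derivation): if F = [[2.]], then
      # P = F @ F.T = [[4.]] and C = inv(P) = [[0.25]]. Sampling
      # x = inv(F.T) @ w = 0.5 * w with w ~ Normal(0, 1) gives
      # Var(x) = 0.25 = C, as expected.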

      if precision is not None:
        precision.shape.assert_is_compatible_with(precision_factor.shape)

      bijector = invert.Invert(
          scale_matvec_linear_operator.ScaleMatvecLinearOperator(
              scale=precision_factor,
              validate_args=validate_args,
              adjoint=True)
      )
      if loc is not None:
        shift = shift_bijector.Shift(shift=loc, validate_args=validate_args)
        bijector = shift(bijector)

      super(MultivariateNormalPrecisionFactorLinearOperator, self).__init__(
          distribution=mvn_diag.MultivariateNormalDiag(
              loc=tf.zeros(
                  ps.concat([batch_shape, event_shape], axis=0), dtype=dtype)),
          bijector=bijector,
          validate_args=validate_args,
          name=name)
      self._parameters = parameters
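Here `invert.Invert` turns the `ScaleMatvecLinearOperator` built from the precision factor into the transform applied to a standard normal base distribution. A minimal sketch of constructing and using the distribution, assuming it is exported under `tfp.experimental.distributions` (the exact module path may differ across TFP versions):

import tensorflow as tf
import tensorflow_probability as tfp

tfed = tfp.experimental.distributions

# Precision P = F @ F.T, with F a lower-triangular LinearOperator.
precision_factor = tf.linalg.LinearOperatorLowerTriangular(
    [[2., 0.],
     [1., 1.]])

mvn = tfed.MultivariateNormalPrecisionFactorLinearOperator(
    loc=[0., 1.],
    precision_factor=precision_factor)

x = mvn.sample(3, seed=42)  # shape [3, 2]
lp = mvn.log_prob(x)        # shape [3]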
Exemplo n.º 21
  def __init__(self,
               bijectors,
               block_sizes=None,
               validate_args=False,
               maybe_changes_size=True,
               name=None):
    """Creates the bijector.

    Args:
      bijectors: A non-empty list of bijectors.
      block_sizes: A 1-D integer `Tensor` with each element signifying the
        length of the block of the input vector to pass to the corresponding
        bijector. The length of `block_sizes` must be equal to the length of
        `bijectors`. If left as None, a vector of 1's is used.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      maybe_changes_size: Python `bool` indicating that this bijector might
        change the event size. If this is known to be false and set
        appropriately, then this will lead to improved static shape inference
        when the block sizes are not statically known.
      name: Python `str`, name given to ops managed by this object. Defaults to
        a name derived from the bijector names, e.g.
        `Blockwise([Exp(), Softplus()]).name ==
        'blockwise_of_exp_and_softplus'`.

    Raises:
      NotImplementedError: If there is a bijector with `event_ndims` > 1.
      ValueError: If `bijectors` list is empty.
      ValueError: If the size of `block_sizes` does not equal the length of
        `bijectors`, or if `block_sizes` is not a vector.
    """
    parameters = dict(locals())
    if not name:
      name = 'blockwise_of_' + '_and_'.join([b.name for b in bijectors])
      name = name.replace('/', '')

    with tf.name_scope(name) as name:
      for b in bijectors:
        if (nest.is_nested(b.forward_min_event_ndims)
            or nest.is_nested(b.inverse_min_event_ndims)):
          raise ValueError('Bijectors must all be single-part.')
        elif isinstance(b.forward_min_event_ndims, int):
          if b.forward_min_event_ndims != b.inverse_min_event_ndims:
            raise ValueError('Rank-changing bijectors are not supported.')
          elif b.forward_min_event_ndims > 1:
            raise ValueError('Only scalar and vector event-shape '
                             'bijectors are supported at this time.')

      b_joint = joint_map.JointMap(list(bijectors), name='jointmap')

      block_sizes = (
          np.ones(len(bijectors), dtype=np.int32)
          if block_sizes is None else
          _validate_block_sizes(block_sizes, bijectors, validate_args))
      b_split = split.Split(
          block_sizes, name='split', validate_args=validate_args)

      if maybe_changes_size:
        i_block_sizes = _validate_block_sizes(
            ps.concat(b_joint.forward_event_shape_tensor(
                ps.split(block_sizes, len(bijectors))), axis=0),
            bijectors, validate_args)
        maybe_changes_size = not tf.get_static_value(
            ps.reduce_all(block_sizes == i_block_sizes))
      b_concat = invert.Invert(
          (split.Split(i_block_sizes, name='isplit')
           if maybe_changes_size else b_split),
          name='concat')

      self._maybe_changes_size = maybe_changes_size
      super(Blockwise, self).__init__(
          bijectors=[b_concat, b_joint, b_split],
          validate_args=validate_args,
          parameters=parameters,
          name=name)
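A minimal sketch of the resulting behavior, assuming the standard `tfp.bijectors.Blockwise` entry point: the input vector is split according to `block_sizes`, each block is transformed by its own bijector, and the pieces are concatenated back (the `invert.Invert` of a `Split` above performs that final concatenation).

import tensorflow_probability as tfp

tfb = tfp.bijectors

# Exp acts on the first 2 entries, Softplus on the remaining 3.
blockwise = tfb.Blockwise(
    bijectors=[tfb.Exp(), tfb.Softplus()],
    block_sizes=[2, 3])

y = blockwise.forward([0., 0., 0., 0., 0.])
# y ~= [1., 1., 0.6931, 0.6931, 0.6931]  (exp(0) = 1, softplus(0) = log 2)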
Exemplo n.º 22
  def __init__(self,
               output_shape=(32, 32, 3),
               num_glow_blocks=3,
               num_steps_per_block=32,
               coupling_bijector_fn=None,
               exit_bijector_fn=None,
               grab_after_block=None,
               use_actnorm=True,
               seed=None,
               validate_args=False,
               name='glow'):
    """Creates the Glow bijector.

    Args:
      output_shape: A list of integers, specifying the event shape of the
        output of the bijector's forward pass (the image). Specified as
        [H, W, C].
        Default value: (32, 32, 3)
      num_glow_blocks: An integer, specifying how many downsampling levels to
        include in the model. Both H and W must be divisible by 2 this many
        times; otherwise the bijector will not be invertible.
        Default value: 3
      num_steps_per_block: An integer specifying how many Affine Coupling and
        1x1 convolution layers to include at each level of the spatial
        hierarchy.
        Default value: 32 (i.e., the value used in the original Glow paper).
      coupling_bijector_fn: A function which takes the argument `input_shape`
        and returns a callable neural network (e.g. a keras.Sequential). The
        network should either return a tensor with the same event shape as
        `input_shape` (this will employ additive coupling), a tensor with the
        same height and width as `input_shape` but twice the number of channels
        (this will employ affine coupling), or a bijector which takes in a
        tensor with event shape `input_shape`, and returns a tensor with shape
        `input_shape`.
      exit_bijector_fn: Similar to coupling_bijector_fn, exit_bijector_fn is
        a function which takes the argument `input_shape` and `output_chan`
        and returns a callable neural network. The neural network it returns
        should take a tensor of shape `input_shape` as the input, and return
        one of three options: A tensor with `output_chan` channels, a tensor
        with `2 * output_chan` channels, or a bijector. Additional details can
        be found in the documentation for ExitBijector.
      grab_after_block: A tuple of floats, specifying what fraction of the
        remaining channels to remove following each glow block. Glow will take
        the integer floor of this number multiplied by the remaining number of
        channels. The default is half at each spatial hierarchy.
        Default value: None (this will take out half of the channels after each
          block).
      use_actnorm: A bool deciding whether or not to use actnorm. Data-dependent
        initialization is used to initialize this layer.
        Default value: `True`
      seed: A seed to control randomness in the 1x1 convolution initialization.
        Default value: `None` (i.e., non-reproducible sampling).
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
        Default value: `False`
      name: Python `str`, name given to ops managed by this object.
        Default value: `'glow'`.
    """
    # Make sure that the input shape is fully defined.
    if not tensorshape_util.is_fully_defined(output_shape):
      raise ValueError('Shape must be fully defined.')
    if tensorshape_util.rank(output_shape) != 3:
      raise ValueError('Shape ndims must be 3 for images. Your shape is '
                       '{}'.format(tensorshape_util.rank(output_shape)))

    num_glow_blocks_ = tf.get_static_value(num_glow_blocks)
    if (num_glow_blocks_ is None or
        int(num_glow_blocks_) != num_glow_blocks_ or
        num_glow_blocks_ < 1):
      raise ValueError('Argument `num_glow_blocks` must be a statically known '
                       'positive `int` (saw: {}).'.format(num_glow_blocks))
    num_glow_blocks = int(num_glow_blocks_)

    output_shape = tensorshape_util.as_list(output_shape)
    h, w, c = output_shape
    n = num_glow_blocks
    nsteps = num_steps_per_block

    # Default Glow: Half of the channels are split off after each block,
    # and after the final block, no channels are split off.
    if grab_after_block is None:
      grab_after_block = tuple([0.5] * (n - 1) + [0.])

    # Things we know must be true: h and w must be evenly divisible by 2, n
    # times. Otherwise, the squeeze bijector will not work.
    if w % 2**n != 0:
      raise ValueError('Width must be divisible by 2 at least n times. '
                       'Saw: {} % {} != 0'.format(w, 2**n))
    if h % 2**n != 0:
      raise ValueError('Height must be divisible by 2 at least n times. '
                       'Saw: {} % {} != 0'.format(h, 2**n))
    if h // 2**n < 1:
      raise ValueError('num_glow_blocks ({0}) is too large. The image height '
                       '({1}) must be divisible by 2 no more than {2} '
                       'times.'.format(num_glow_blocks, h,
                                       int(np.log(h) / np.log(2.))))
    if w // 2**n < 1:
      raise ValueError('num_glow_blocks ({0}) is too large. The image width '
                       '({1}) must be divisible by 2 no more than {2} '
                       'times.'.format(num_glow_blocks, w,
                                       int(np.log(w) / np.log(2.))))

    # Other things we want to be true:
    # - The number of grab fractions must equal the number of glow blocks.
    if len(grab_after_block) != num_glow_blocks:
      raise ValueError('Length of grab_after_block ({0}) must match the '
                       'number of blocks ({1}).'.format(len(grab_after_block),
                                                        num_glow_blocks))

    self._blockwise_splits = self._get_blockwise_splits(output_shape,
                                                        grab_after_block[::-1])

    # Now check on the values of blockwise splits
    if any([bs[0] < 1 for bs in self._blockwise_splits]):
      first_offender = [
          bs[0] < 1 for bs in self._blockwise_splits].index(True)
      raise ValueError('At least one exit takes out all of your channels, and '
                       'therefore leaves no inputs to later blocks. Try '
                       'setting grab_after_block to a lower value at index '
                       '{}.'.format(first_offender))

    if any(np.isclose(gab, 0) for gab in grab_after_block):
      # Special case: if specifically exiting no channels, then the exit is
      # just an identity bijector.
      pass
    elif any([bs[1] < 1 for bs in self._blockwise_splits]):
      first_offender = [
          bs[1] < 1 for bs in self._blockwise_splits].index(True)
      raise ValueError('At least one of your layers has < 1 output channels. '
                       'This means you set grab_after_block too small. '
                       'Try setting grab_after_block to a larger value at '
                       'index {}.'.format(first_offender))

    # Let's start to build our bijector. We assume that the distribution is
    # 1-dimensional. First, let's reshape it to an image.
    glow_chain = [
        reshape.Reshape(
            event_shape_out=[h // 2**n, w // 2**n, c * 4**n],
            event_shape_in=[h * w * c])
    ]

    seedstream = SeedStream(seed=seed, salt='random_beta')

    for i in range(n):

      # This is the shape of the current tensor
      current_shape = (h // 2**n * 2**i, w // 2**n * 2**i, c * 4**(i + 1))

      # This is the shape of the input to both the glow block and exit bijector.
      this_nchan = sum(self._blockwise_splits[i][0:2])
      this_input_shape = (h // 2**n * 2**i, w // 2**n * 2**i, this_nchan)

      glow_chain.append(invert.Invert(ExitBijector(current_shape,
                                                   self._blockwise_splits[i],
                                                   exit_bijector_fn)))

      glow_block = GlowBlock(input_shape=this_input_shape,
                             num_steps=nsteps,
                             coupling_bijector_fn=coupling_bijector_fn,
                             use_actnorm=use_actnorm,
                             seedstream=seedstream)

      if self._blockwise_splits[i][2] == 0:
        # All channels are passed to the RealNVP
        glow_chain.append(glow_block)
      else:
        # Some channels are passed around the block.
        # This is done with the Blockwise bijector.
        glow_chain.append(
            blockwise.Blockwise(
                [glow_block, identity.Identity()],
                [sum(self._blockwise_splits[i][0:2]),
                 self._blockwise_splits[i][2]]))

      # Finally, let's expand the channels into spatial features.
      glow_chain.append(
          Expand(input_shape=[
              h // 2**n * 2**i,
              w // 2**n * 2**i,
              c * 4**n // 4**i,
          ]))

    glow_chain = glow_chain[::-1]
    # To finish off, we initialize the bijector with the chain we've built.
    # This way, the rest of the model attributes are taken care of for us.
    super(Glow, self).__init__(
        bijectors=glow_chain, validate_args=validate_args, name=name)
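As a rough usage sketch (not taken from the source above): the finished Glow chain is typically pushed through a `TransformedDistribution` over a flat standard normal to obtain an image distribution. The default coupling and exit networks are assumed here to be exposed as `tfb.GlowDefaultNetwork` and `tfb.GlowDefaultExitNetwork`.

import numpy as np
import tensorflow_probability as tfp

tfb = tfp.bijectors
tfd = tfp.distributions

image_shape = (32, 32, 3)
glow = tfb.Glow(
    output_shape=image_shape,
    coupling_bijector_fn=tfb.GlowDefaultNetwork,
    exit_bijector_fn=tfb.GlowDefaultExitNetwork)

# Base distribution over the flattened event, pushed through the Glow chain.
z_size = int(np.prod(image_shape))
base = tfd.Sample(tfd.Normal(0., 1.), sample_shape=[z_size])
images = tfd.TransformedDistribution(distribution=base, bijector=glow)

x = images.sample(2)  # shape [2, 32, 32, 3]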