    def __init__(self,
                 loc=None,
                 scale_diag=None,
                 scale_identity_multiplier=None,
                 skewness=None,
                 tailweight=None,
                 distribution=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="MultivariateNormalLinearOperator"):
        """Construct VectorSinhArcsinhDiag distribution on `R^k`.

    The arguments `scale_diag` and `scale_identity_multiplier` combine to
    define the diagonal `scale` referred to in this class docstring:

    ```none
    scale = diag(scale_diag + scale_identity_multiplier * ones(k))
    ```

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by the last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
        matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
        and characterizes `b`-batches of `k x k` diagonal matrices added to
        `scale`. When both `scale_identity_multiplier` and `scale_diag` are
        `None` then `scale` is the `Identity`.
      scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
        a scale-identity-matrix added to `scale`. May have shape
        `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scale
        `k x k` identity matrices added to `scale`. When both
        `scale_identity_multiplier` and `scale_diag` are `None` then `scale`
        is the `Identity`.
      skewness:  Skewness parameter.  Floating-point `Tensor` with shape
        broadcastable with `event_shape`.
      tailweight:  Tailweight parameter.  Floating-point `Tensor` with shape
        broadcastable with `event_shape`.
      distribution: `tf.Distribution`-like instance. Distribution from which `k`
        iid samples are used as input to transformation `F`.  Default is
        `tf.distributions.Normal(loc=0., scale=1.)`.
        Must be a scalar-batch, scalar-event distribution.  Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a VectorSinhArcsinhDiag sample and `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if at most `scale_identity_multiplier` is specified.
    """
        parameters = dict(locals())

        with tf.name_scope(name,
                           values=[
                               loc, scale_diag, scale_identity_multiplier,
                               skewness, tailweight
                           ]) as name:
            loc = tf.convert_to_tensor(loc,
                                       name="loc") if loc is not None else loc
            tailweight = 1. if tailweight is None else tailweight
            has_default_skewness = skewness is None
            skewness = 0. if skewness is None else skewness

            # Recall, with Z a random variable,
            #   Y := loc + C * F(Z),
            #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight )
            #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
            #   C := 2 * scale / F_0(2)
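            # Note: the normalization C = 2 * scale / F_0(2) is chosen so that,
            # with the default skewness of 0, Y = loc + 2 * scale exactly when
            # Z = 2; i.e. `scale` keeps its usual interpretation of spread even
            # after `tailweight` reshapes the tails.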

            # Construct shapes and 'scale' out of the scale_* and loc kwargs.
            # scale_linop is only an intermediary to:
            #  1. get shapes from looking at loc and the two scale args.
            #  2. combine scale_diag with scale_identity_multiplier, which gives us
            #     'scale', which in turn gives us 'C'.
            scale_linop = distribution_util.make_diag_scale(
                loc=loc,
                scale_diag=scale_diag,
                scale_identity_multiplier=scale_identity_multiplier,
                validate_args=False,
                assert_positive=False)
            batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
                loc, scale_linop)
            # scale_linop.diag_part() is efficient since it is a diag type linop.
            scale_diag_part = scale_linop.diag_part()
            dtype = scale_diag_part.dtype

            if distribution is None:
                distribution = tf.distributions.Normal(
                    loc=tf.zeros([], dtype=dtype),
                    scale=tf.ones([], dtype=dtype),
                    allow_nan_stats=allow_nan_stats)
            else:
                asserts = distribution_util.maybe_check_scalar_distribution(
                    distribution, dtype, validate_args)
                if asserts:
                    scale_diag_part = control_flow_ops.with_dependencies(
                        asserts, scale_diag_part)

            # Make the SAS bijector, 'F'.
            skewness = tf.convert_to_tensor(skewness,
                                            dtype=dtype,
                                            name="skewness")
            tailweight = tf.convert_to_tensor(tailweight,
                                              dtype=dtype,
                                              name="tailweight")
            f = bijectors.SinhArcsinh(skewness=skewness, tailweight=tailweight)
            if has_default_skewness:
                f_noskew = f
            else:
                f_noskew = bijectors.SinhArcsinh(
                    skewness=skewness.dtype.as_numpy_dtype(0.),
                    tailweight=tailweight)

            # Make the Affine bijector, Z --> loc + C * Z.
            c = 2 * scale_diag_part / f_noskew.forward(
                tf.convert_to_tensor(2, dtype=dtype))
            affine = bijectors.Affine(shift=loc,
                                      scale_diag=c,
                                      validate_args=validate_args)

            bijector = bijectors.Chain([affine, f])

            super(VectorSinhArcsinhDiag,
                  self).__init__(distribution=distribution,
                                 bijector=bijector,
                                 batch_shape=batch_shape,
                                 event_shape=event_shape,
                                 validate_args=validate_args,
                                 name=name)
        self._parameters = parameters
        self._loc = loc
        self._scale = scale_linop
        self._tailweight = tailweight
        self._skewness = skewness
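
# A minimal usage sketch for the constructor above (not part of the library
# source). It assumes the class is exposed the way TF 1.x contrib did, as
# tf.contrib.distributions.VectorSinhArcsinhDiag, and only exercises arguments
# documented in the docstring.
import tensorflow as tf

tfd = tf.contrib.distributions

# scale = diag(scale_diag), since scale_identity_multiplier is left as None.
vsad = tfd.VectorSinhArcsinhDiag(
    loc=[0., 0., 0.],
    scale_diag=[1., 2., 3.],
    skewness=0.5,        # broadcasts against event_shape = [3]
    tailweight=1.5)
samples = vsad.sample(10)            # shape [10, 3]
log_probs = vsad.log_prob(samples)   # shape [10]
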
    def __init__(self,
                 mix_loc,
                 temperature,
                 distribution,
                 loc=None,
                 scale=None,
                 quadrature_size=8,
                 quadrature_fn=quadrature_scheme_softmaxnormal_quantiles,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="VectorDiffeomixture"):
        """Constructs the VectorDiffeomixture on `R^d`.

    The vector diffeomixture (VDM) approximates the compound distribution

    ```none
    p(x) = int p(x | z) p(z) dz,
    where z is in the K-simplex, and
    p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k])
    ```

    Args:
      mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`.
        In terms of samples, larger `mix_loc[..., k]` ==>
        `Z` is more likely to put more weight on its `kth` component.
      temperature: `float`-like `Tensor`. Broadcastable with `mix_loc`.
        In terms of samples, smaller `temperature` means one component is more
        likely to dominate.  I.e., smaller `temperature` makes the VDM look more
        like a standard mixture of `K` components.
      distribution: `tfp.distributions.Distribution`-like instance. Distribution
        from which `d` iid samples are used as input to the selected affine
        transformation. Must be a scalar-batch, scalar-event distribution.
        Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED`
        or it is a function of non-trainable parameters. WARNING: If you
        backprop through a VectorDiffeomixture sample and the `distribution`
        is not `FULLY_REPARAMETERIZED` yet is a function of trainable variables,
        then the gradient will be incorrect!
      loc: Length-`K` list of `float`-type `Tensor`s. The `k`-th element
        represents the `shift` used for the `k`-th affine transformation.  If
        the `k`-th item is `None`, `loc` is implicitly `0`.  When specified,
        must have shape `[B1, ..., Bb, d]` where `b >= 0` and `d` is the event
        size.
      scale: Length-`K` list of `LinearOperator`s. Each should be
        positive-definite and operate on a `d`-dimensional vector space. The
        `k`-th element represents the `scale` used for the `k`-th affine
        transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`,
        `b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices.
      quadrature_size: Python `int` scalar representing number of
        quadrature points.  Larger `quadrature_size` means `q_N(x)` better
        approximates `p(x)`.
      quadrature_fn: Python callable taking `normal_loc`, `normal_scale`,
        `quadrature_size`, `validate_args` and returning `tuple(grid, probs)`
        representing the SoftmaxNormal grid and corresponding normalized weight.
        Default value: `quadrature_scheme_softmaxnormal_quantiles`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if `not scale or len(scale) < 2`.
      ValueError: if `len(loc) != len(scale)`
      ValueError: if `quadrature_grid_and_probs is not None` and
        `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
      ValueError: if `validate_args` and any not scale.is_positive_definite.
      TypeError: if any scale.dtype != scale[0].dtype.
      TypeError: if any loc.dtype != scale[0].dtype.
      NotImplementedError: if `len(scale) != 2`.
      ValueError: if `not distribution.is_scalar_batch`.
      ValueError: if `not distribution.is_scalar_event`.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            if not scale or len(scale) < 2:
                raise ValueError(
                    "Must specify list (or list-like object) of scale "
                    "LinearOperators, one for each component with "
                    "num_component >= 2.")

            if loc is None:
                loc = [None] * len(scale)

            if len(loc) != len(scale):
                raise ValueError("loc/scale must be same-length lists "
                                 "(or same-length list-like objects).")

            dtype = dtype_util.base_dtype(scale[0].dtype)

            loc = [
                tf.convert_to_tensor(
                    value=loc_, dtype=dtype, name="loc{}".format(k))
                if loc_ is not None else None for k, loc_ in enumerate(loc)
            ]

            for k, scale_ in enumerate(scale):
                if validate_args and not scale_.is_positive_definite:
                    raise ValueError(
                        "scale[{}].is_positive_definite = {} != True".format(
                            k, scale_.is_positive_definite))
                if dtype_util.base_dtype(scale_.dtype) != dtype:
                    raise TypeError(
                        "dtype mismatch; scale[{}].base_dtype=\"{}\" != \"{}\""
                        .format(k, dtype_util.name(scale_.dtype),
                                dtype_util.name(dtype)))

            self._endpoint_affine = [
                affine_linear_operator_bijector.AffineLinearOperator(  # pylint: disable=g-complex-comprehension
                    shift=loc_,
                    scale=scale_,
                    validate_args=validate_args,
                    name="endpoint_affine_{}".format(k))
                for k, (loc_, scale_) in enumerate(zip(loc, scale))
            ]

            # TODO(jvdillon): Remove once we support k-mixtures.
            # We make this assertion here because otherwise `grid` would need to be a
            # vector not a scalar.
            if len(scale) != 2:
                raise NotImplementedError(
                    "Currently only bimixtures are supported; "
                    "len(scale)={} is not 2.".format(len(scale)))

            mix_loc = tf.convert_to_tensor(value=mix_loc,
                                           dtype=dtype,
                                           name="mix_loc")
            temperature = tf.convert_to_tensor(value=temperature,
                                               dtype=dtype,
                                               name="temperature")
            self._grid, probs = tuple(
                quadrature_fn(mix_loc / temperature, 1. / temperature,
                              quadrature_size, validate_args))

            # Note: by creating the logits as `log(prob)` we ensure that
            # `self.mixture_distribution.logits` is equivalent to
            # `math_ops.log(self.mixture_distribution.probs)`.
            self._mixture_distribution = categorical.Categorical(
                logits=tf.math.log(probs),
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats)

            asserts = distribution_util.maybe_check_scalar_distribution(
                distribution, dtype, validate_args)
            if asserts:
                self._grid = distribution_util.with_dependencies(
                    asserts, self._grid)
            self._distribution = distribution

            self._interpolated_affine = [
                affine_linear_operator_bijector.AffineLinearOperator(  # pylint: disable=g-complex-comprehension
                    shift=loc_,
                    scale=scale_,
                    validate_args=validate_args,
                    name="interpolated_affine_{}".format(k))
                for k, (loc_, scale_) in enumerate(
                    zip(interpolate_loc(self._grid, loc),
                        interpolate_scale(self._grid, scale)))
            ]

            [
                self._batch_shape_,
                self._batch_shape_tensor_,
                self._event_shape_,
                self._event_shape_tensor_,
            ] = determine_batch_event_shapes(self._grid, self._endpoint_affine)

            super(VectorDiffeomixture, self).__init__(
                dtype=dtype,
                # We hard-code `FULLY_REPARAMETERIZED` because when
                # `validate_args=True` we verify that indeed
                # `distribution.reparameterization_type == FULLY_REPARAMETERIZED`. A
                # distribution which is a function of only non-trainable parameters
                # also implies we can use `FULLY_REPARAMETERIZED`. However, we cannot
                # easily test for that possibility thus we use `validate_args=False`
                # as a "back-door" to allow users a way to use non
                # `FULLY_REPARAMETERIZED` distribution. In such cases IT IS THE USERS
                # RESPONSIBILITY to verify that the base distribution is a function of
                # non-trainable parameters.
                reparameterization_type=reparameterization.
                FULLY_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                graph_parents=(
                    distribution._graph_parents  # pylint: disable=protected-access
                    + [loc_ for loc_ in loc if loc_ is not None] +
                    [p for scale_ in scale for p in scale_.graph_parents]),  # pylint: disable=g-complex-comprehension
                name=name)
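
# A minimal usage sketch for the constructor above (not part of the library
# source). It assumes the class is available as
# tfp.distributions.VectorDiffeomixture and builds the documented bimixture
# case (len(scale) == 2) on R^2.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

vdm = tfd.VectorDiffeomixture(
    mix_loc=[[0.], [1.]],      # shape [2, K-1] with K = 2: two batch members
    temperature=[1.],
    distribution=tfd.Normal(loc=0., scale=1.),
    loc=[None,                 # component 0: loc is implicitly 0
         [2., 2.5]],           # component 1
    scale=[tf.linalg.LinearOperatorDiag([1., 1.]),
           tf.linalg.LinearOperatorDiag([0.5, 0.5])])
x = vdm.sample(5)              # shape [5, 2, 2]: sample x batch x event
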
    def __init__(self,
                 loc,
                 scale,
                 skewness=None,
                 tailweight=None,
                 distribution=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="SinhArcsinh"):
        """Construct SinhArcsinh distribution on `(-inf, inf)`.

    Arguments `(loc, scale, skewness, tailweight)` must have broadcastable shape
    (indexing batch dimensions).  They must all have the same `dtype`.

    Args:
      loc: Floating-point `Tensor`.
      scale:  `Tensor` of same `dtype` as `loc`.
      skewness:  Skewness parameter.  Default is `0.0` (no skew).
      tailweight:  Tailweight parameter. Default is `1.0` (unchanged tailweight).
      distribution: `tf.Distribution`-like instance. Distribution that is
        transformed to produce this distribution.
        Default is `tfd.Normal(0., 1.)`.
        Must be a scalar-batch, scalar-event distribution.  Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a `SinhArcsinh` sample and `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = dict(locals())

        with tf.compat.v2.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale, skewness, tailweight],
                                            tf.float32)
            loc = tf.convert_to_tensor(value=loc, name="loc", dtype=dtype)
            scale = tf.convert_to_tensor(value=scale,
                                         name="scale",
                                         dtype=dtype)
            tailweight = 1. if tailweight is None else tailweight
            has_default_skewness = skewness is None
            skewness = 0. if skewness is None else skewness
            tailweight = tf.convert_to_tensor(value=tailweight,
                                              name="tailweight",
                                              dtype=dtype)
            skewness = tf.convert_to_tensor(value=skewness,
                                            name="skewness",
                                            dtype=dtype)

            batch_shape = distribution_util.get_broadcast_shape(
                loc, scale, tailweight, skewness)

            # Recall, with Z a random variable,
            #   Y := loc + C * F(Z),
            #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight )
            #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
            #   C := 2 * scale / F_0(2)
            if distribution is None:
                distribution = normal.Normal(loc=tf.zeros([], dtype=dtype),
                                             scale=tf.ones([], dtype=dtype),
                                             allow_nan_stats=allow_nan_stats)
            else:
                asserts = distribution_util.maybe_check_scalar_distribution(
                    distribution, dtype, validate_args)
                if asserts:
                    loc = distribution_util.with_dependencies(asserts, loc)

            # Make the SAS bijector, 'F'.
            f = sinh_arcsinh_bijector.SinhArcsinh(skewness=skewness,
                                                  tailweight=tailweight)
            if has_default_skewness:
                f_noskew = f
            else:
                f_noskew = sinh_arcsinh_bijector.SinhArcsinh(
                    skewness=skewness.dtype.as_numpy_dtype(0.),
                    tailweight=tailweight)

            # Make the AffineScalar bijector,
            #   Z --> loc + C * Z,  where C := 2 * scale / F_0(2).
            c = 2 * scale / f_noskew.forward(
                tf.convert_to_tensor(value=2, dtype=dtype))
            affine = affine_scalar_bijector.AffineScalar(
                shift=loc, scale=c, validate_args=validate_args)

            bijector = chain_bijector.Chain([affine, f])

            super(SinhArcsinh, self).__init__(distribution=distribution,
                                              bijector=bijector,
                                              batch_shape=batch_shape,
                                              validate_args=validate_args,
                                              name=name)
        self._parameters = parameters
        self._loc = loc
        self._scale = scale
        self._tailweight = tailweight
        self._skewness = skewness
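
# A minimal usage sketch for the constructor above (not part of the library
# source), assuming the class is exposed as tfp.distributions.SinhArcsinh.
import tensorflow_probability as tfp

tfd = tfp.distributions

# tailweight > 1 fattens both tails; positive skewness shifts mass rightward.
sas = tfd.SinhArcsinh(loc=0., scale=1., skewness=0.3, tailweight=2.0)
y = sas.sample(1000)        # shape [1000]
lp = sas.log_prob(0.)       # scalar
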
  def __init__(self,
               loc,
               scale,
               skewness=None,
               tailweight=None,
               distribution=None,
               validate_args=False,
               allow_nan_stats=True,
               name="SinhArcsinh"):
    """Construct SinhArcsinh distribution on `(-inf, inf)`.

    Arguments `(loc, scale, skewness, tailweight)` must have broadcastable shape
    (indexing batch dimensions).  They must all have the same `dtype`.

    Args:
      loc: Floating-point `Tensor`.
      scale:  `Tensor` of same `dtype` as `loc`.
      skewness:  Skewness parameter.  Default is `0.0` (no skew).
      tailweight:  Tailweight parameter. Default is `1.0` (unchanged tailweight).
      distribution: `tf.Distribution`-like instance. Distribution that is
        transformed to produce this distribution.
        Default is `tf.distributions.Normal(0., 1.)`.
        Must be a scalar-batch, scalar-event distribution.  Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a `SinhArcsinh` sample and `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())

    with tf.name_scope(name, values=[loc, scale, skewness, tailweight]) as name:
      loc = tf.convert_to_tensor(loc, name="loc")
      dtype = loc.dtype
      scale = tf.convert_to_tensor(scale, name="scale", dtype=dtype)
      tailweight = 1. if tailweight is None else tailweight
      has_default_skewness = skewness is None
      skewness = 0. if skewness is None else skewness
      tailweight = tf.convert_to_tensor(
          tailweight, name="tailweight", dtype=dtype)
      skewness = tf.convert_to_tensor(skewness, name="skewness", dtype=dtype)

      batch_shape = distribution_util.get_broadcast_shape(
          loc, scale, tailweight, skewness)

      # Recall, with Z a random variable,
      #   Y := loc + C * F(Z),
      #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight )
      #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
      #   C := 2 * scale / F_0(2)
      if distribution is None:
        distribution = tf.distributions.Normal(
            loc=tf.zeros([], dtype=dtype),
            scale=tf.ones([], dtype=dtype),
            allow_nan_stats=allow_nan_stats)
      else:
        asserts = distribution_util.maybe_check_scalar_distribution(
            distribution, dtype, validate_args)
        if asserts:
          loc = control_flow_ops.with_dependencies(asserts, loc)

      # Make the SAS bijector, 'F'.
      f = bijectors.SinhArcsinh(
          skewness=skewness, tailweight=tailweight)
      if has_default_skewness:
        f_noskew = f
      else:
        f_noskew = bijectors.SinhArcsinh(
            skewness=skewness.dtype.as_numpy_dtype(0.),
            tailweight=tailweight)

      # Make the AffineScalar bijector,
      #   Z --> loc + C * Z,  where C := 2 * scale / F_0(2).
      c = 2 * scale / f_noskew.forward(tf.convert_to_tensor(2, dtype=dtype))
      affine = bijectors.AffineScalar(
          shift=loc,
          scale=c,
          validate_args=validate_args)

      bijector = bijectors.Chain([affine, f])

      super(SinhArcsinh, self).__init__(
          distribution=distribution,
          bijector=bijector,
          batch_shape=batch_shape,
          validate_args=validate_args,
          name=name)
    self._parameters = parameters
    self._loc = loc
    self._scale = scale
    self._tailweight = tailweight
    self._skewness = skewness
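
# A small NumPy sketch (not part of the library source) of the transformation
# described by the comments above: Y = loc + C * F(Z), with
# F(Z) = sinh((arcsinh(Z) + skewness) * tailweight) and C = 2 * scale / F_0(2).
import numpy as np


def sas_transform(z, loc=0., scale=1., skewness=0., tailweight=1.):
  f = lambda z_, s: np.sinh((np.arcsinh(z_) + s) * tailweight)
  c = 2. * scale / f(2., 0.)   # F_0(2), i.e. zero skewness
  return loc + c * f(z, skewness)


# With skewness=0, the normalization pins Z=2 to loc + 2 * scale.
print(sas_transform(2., loc=1., scale=3., tailweight=1.7))  # ~ 7.0
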
  def __init__(self,
               loc=None,
               scale_diag=None,
               scale_identity_multiplier=None,
               skewness=None,
               tailweight=None,
               distribution=None,
               validate_args=False,
               allow_nan_stats=True,
               name="MultivariateNormalLinearOperator"):
    """Construct VectorSinhArcsinhDiag distribution on `R^k`.

    The arguments `scale_diag` and `scale_identity_multiplier` combine to
    define the diagonal `scale` referred to in this class docstring:

    ```none
    scale = diag(scale_diag + scale_identity_multiplier * ones(k))
    ```

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by the last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
        matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
        and characterizes `b`-batches of `k x k` diagonal matrices added to
        `scale`. When both `scale_identity_multiplier` and `scale_diag` are
        `None` then `scale` is the `Identity`.
      scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
        a scale-identity-matrix added to `scale`. May have shape
        `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scale
        `k x k` identity matrices added to `scale`. When both
        `scale_identity_multiplier` and `scale_diag` are `None` then `scale`
        is the `Identity`.
      skewness:  Skewness parameter.  Floating-point `Tensor` with shape
        broadcastable with `event_shape`.
      tailweight:  Tailweight parameter.  Floating-point `Tensor` with shape
        broadcastable with `event_shape`.
      distribution: `tf.Distribution`-like instance. Distribution from which `k`
        iid samples are used as input to transformation `F`.  Default is
        `tfd.Normal(loc=0., scale=1.)`.
        Must be a scalar-batch, scalar-event distribution.  Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a VectorSinhArcsinhDiag sample and `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if at most `scale_identity_multiplier` is specified.
    """
    parameters = dict(locals())

    with tf.name_scope(
        name,
        values=[
            loc, scale_diag, scale_identity_multiplier, skewness, tailweight
        ]) as name:
      dtype = dtype_util.common_dtype(
          [loc, scale_diag, scale_identity_multiplier, skewness, tailweight],
          tf.float32)
      loc = loc if loc is None else tf.convert_to_tensor(
          loc, name="loc", dtype=dtype)
      tailweight = 1. if tailweight is None else tailweight
      has_default_skewness = skewness is None
      skewness = 0. if skewness is None else skewness

      # Recall, with Z a random variable,
      #   Y := loc + C * F(Z),
      #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight )
      #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
      #   C := 2 * scale / F_0(2)

      # Construct shapes and 'scale' out of the scale_* and loc kwargs.
      # scale_linop is only an intermediary to:
      #  1. get shapes from looking at loc and the two scale args.
      #  2. combine scale_diag with scale_identity_multiplier, which gives us
      #     'scale', which in turn gives us 'C'.
      scale_linop = distribution_util.make_diag_scale(
          loc=loc,
          scale_diag=scale_diag,
          scale_identity_multiplier=scale_identity_multiplier,
          validate_args=False,
          assert_positive=False,
          dtype=dtype)
      batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
          loc, scale_linop)
      # scale_linop.diag_part() is efficient since it is a diag type linop.
      scale_diag_part = scale_linop.diag_part()
      dtype = scale_diag_part.dtype

      if distribution is None:
        distribution = normal.Normal(
            loc=tf.zeros([], dtype=dtype),
            scale=tf.ones([], dtype=dtype),
            allow_nan_stats=allow_nan_stats)
      else:
        asserts = distribution_util.maybe_check_scalar_distribution(
            distribution, dtype, validate_args)
        if asserts:
          scale_diag_part = control_flow_ops.with_dependencies(
              asserts, scale_diag_part)

      # Make the SAS bijector, 'F'.
      skewness = tf.convert_to_tensor(skewness, dtype=dtype, name="skewness")
      tailweight = tf.convert_to_tensor(
          tailweight, dtype=dtype, name="tailweight")
      f = sinh_arcsinh_bijector.SinhArcsinh(
          skewness=skewness, tailweight=tailweight)
      if has_default_skewness:
        f_noskew = f
      else:
        f_noskew = sinh_arcsinh_bijector.SinhArcsinh(
            skewness=skewness.dtype.as_numpy_dtype(0.),
            tailweight=tailweight)

      # Make the Affine bijector, Z --> loc + C * Z.
      c = 2 * scale_diag_part / f_noskew.forward(
          tf.convert_to_tensor(2, dtype=dtype))
      affine = affine_bijector.Affine(
          shift=loc, scale_diag=c, validate_args=validate_args)

      bijector = chain_bijector.Chain([affine, f])

      super(VectorSinhArcsinhDiag, self).__init__(
          distribution=distribution,
          bijector=bijector,
          batch_shape=batch_shape,
          event_shape=event_shape,
          validate_args=validate_args,
          name=name)
    self._parameters = parameters
    self._loc = loc
    self._scale = scale_linop
    self._tailweight = tailweight
    self._skewness = skewness
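
# A small NumPy sketch (not part of the library source) of how `scale_diag`
# and `scale_identity_multiplier` combine into the diagonal `scale` described
# in the docstring: scale = diag(scale_diag + scale_identity_multiplier * ones(k)).
import numpy as np

scale_diag = np.array([1., 2., 3.])
scale_identity_multiplier = 0.5
scale = np.diag(scale_diag + scale_identity_multiplier * np.ones(3))
# scale is the 3 x 3 diagonal matrix diag([1.5, 2.5, 3.5]).
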
  def __init__(self,
               mix_loc,
               temperature,
               distribution,
               loc=None,
               scale=None,
               quadrature_size=8,
               quadrature_fn=quadrature_scheme_softmaxnormal_quantiles,
               validate_args=False,
               allow_nan_stats=True,
               name="VectorDiffeomixture"):
    """Constructs the VectorDiffeomixture on `R^d`.

    The vector diffeomixture (VDM) approximates the compound distribution

    ```none
    p(x) = int p(x | z) p(z) dz,
    where z is in the K-simplex, and
    p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k])
    ```

    Args:
      mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`.
        In terms of samples, larger `mix_loc[..., k]` ==>
        `Z` is more likely to put more weight on its `kth` component.
      temperature: `float`-like `Tensor`. Broadcastable with `mix_loc`.
        In terms of samples, smaller `temperature` means one component is more
        likely to dominate.  I.e., smaller `temperature` makes the VDM look more
        like a standard mixture of `K` components.
      distribution: `tfp.distributions.Distribution`-like instance. Distribution
        from which `d` iid samples are used as input to the selected affine
        transformation. Must be a scalar-batch, scalar-event distribution.
        Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED`
        or it is a function of non-trainable parameters. WARNING: If you
        backprop through a VectorDiffeomixture sample and the `distribution`
        is not `FULLY_REPARAMETERIZED` yet is a function of trainable variables,
        then the gradient will be incorrect!
      loc: Length-`K` list of `float`-type `Tensor`s. The `k`-th element
        represents the `shift` used for the `k`-th affine transformation.  If
        the `k`-th item is `None`, `loc` is implicitly `0`.  When specified,
        must have shape `[B1, ..., Bb, d]` where `b >= 0` and `d` is the event
        size.
      scale: Length-`K` list of `LinearOperator`s. Each should be
        positive-definite and operate on a `d`-dimensional vector space. The
        `k`-th element represents the `scale` used for the `k`-th affine
        transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`,
        `b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices.
      quadrature_size: Python `int` scalar representing number of
        quadrature points.  Larger `quadrature_size` means `q_N(x)` better
        approximates `p(x)`.
      quadrature_fn: Python callable taking `normal_loc`, `normal_scale`,
        `quadrature_size`, `validate_args` and returning `tuple(grid, probs)`
        representing the SoftmaxNormal grid and corresponding normalized weight.
        Default value: `quadrature_scheme_softmaxnormal_quantiles`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if `not scale or len(scale) < 2`.
      ValueError: if `len(loc) != len(scale)`
      ValueError: if `quadrature_grid_and_probs is not None` and
        `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
      ValueError: if `validate_args` and any not scale.is_positive_definite.
      TypeError: if any scale.dtype != scale[0].dtype.
      TypeError: if any loc.dtype != scale[0].dtype.
      NotImplementedError: if `len(scale) != 2`.
      ValueError: if `not distribution.is_scalar_batch`.
      ValueError: if `not distribution.is_scalar_event`.
    """
    parameters = dict(locals())
    with tf.name_scope(name, values=[mix_loc, temperature]) as name:
      if not scale or len(scale) < 2:
        raise ValueError("Must specify list (or list-like object) of scale "
                         "LinearOperators, one for each component with "
                         "num_component >= 2.")

      if loc is None:
        loc = [None]*len(scale)

      if len(loc) != len(scale):
        raise ValueError("loc/scale must be same-length lists "
                         "(or same-length list-like objects).")

      dtype = scale[0].dtype.base_dtype

      loc = [
          tf.convert_to_tensor(loc_, dtype=dtype, name="loc{}".format(k))
          if loc_ is not None else None for k, loc_ in enumerate(loc)
      ]

      for k, scale_ in enumerate(scale):
        if validate_args and not scale_.is_positive_definite:
          raise ValueError("scale[{}].is_positive_definite = {} != True".format(
              k, scale_.is_positive_definite))
        if scale_.dtype.base_dtype != dtype:
          raise TypeError(
              "dtype mismatch; scale[{}].base_dtype=\"{}\" != \"{}\"".format(
                  k, scale_.dtype.base_dtype.name, dtype.name))

      self._endpoint_affine = [
          affine_linear_operator_bijector.AffineLinearOperator(
              shift=loc_, scale=scale_,
              validate_args=validate_args,
              name="endpoint_affine_{}".format(k))
          for k, (loc_, scale_) in enumerate(zip(loc, scale))]

      # TODO(jvdillon): Remove once we support k-mixtures.
      # We make this assertion here because otherwise `grid` would need to be a
      # vector not a scalar.
      if len(scale) != 2:
        raise NotImplementedError("Currently only bimixtures are supported; "
                                  "len(scale)={} is not 2.".format(len(scale)))

      mix_loc = tf.convert_to_tensor(mix_loc, dtype=dtype, name="mix_loc")
      temperature = tf.convert_to_tensor(
          temperature, dtype=dtype, name="temperature")
      self._grid, probs = tuple(quadrature_fn(
          mix_loc / temperature,
          1. / temperature,
          quadrature_size,
          validate_args))

      # Note: by creating the logits as `log(prob)` we ensure that
      # `self.mixture_distribution.logits` is equivalent to
      # `math_ops.log(self.mixture_distribution.probs)`.
      self._mixture_distribution = categorical.Categorical(
          logits=tf.log(probs),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats)

      asserts = distribution_util.maybe_check_scalar_distribution(
          distribution, dtype, validate_args)
      if asserts:
        self._grid = control_flow_ops.with_dependencies(
            asserts, self._grid)
      self._distribution = distribution

      self._interpolated_affine = [
          affine_linear_operator_bijector.AffineLinearOperator(
              shift=loc_, scale=scale_,
              validate_args=validate_args,
              name="interpolated_affine_{}".format(k))
          for k, (loc_, scale_) in enumerate(zip(
              interpolate_loc(self._grid, loc),
              interpolate_scale(self._grid, scale)))]

      [
          self._batch_shape_,
          self._batch_shape_tensor_,
          self._event_shape_,
          self._event_shape_tensor_,
      ] = determine_batch_event_shapes(self._grid,
                                       self._endpoint_affine)

      super(VectorDiffeomixture, self).__init__(
          dtype=dtype,
          # We hard-code `FULLY_REPARAMETERIZED` because when
          # `validate_args=True` we verify that indeed
          # `distribution.reparameterization_type == FULLY_REPARAMETERIZED`. A
          # distribution which is a function of only non-trainable parameters
          # also implies we can use `FULLY_REPARAMETERIZED`. However, we cannot
          # easily test for that possibility thus we use `validate_args=False`
          # as a "back-door" to allow users a way to use non
          # `FULLY_REPARAMETERIZED` distribution. In such cases IT IS THE USERS
          # RESPONSIBILITY to verify that the base distribution is a function of
          # non-trainable parameters.
          reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          graph_parents=(
              distribution._graph_parents  # pylint: disable=protected-access
              + [loc_ for loc_ in loc if loc_ is not None] +
              [p for scale_ in scale for p in scale_.graph_parents]),
          name=name)
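
# A small NumPy sketch (not part of the library source) of the interpolation
# the docstring describes for the bimixture case: for mixing weights z on the
# simplex, the conditional component uses loc = sum_k z[k] loc[k] and
# scale = sum_k z[k] scale[k].
import numpy as np

loc = [np.zeros(2), np.array([2., 2.5])]   # K = 2 endpoint shifts
scale = [np.eye(2), 0.5 * np.eye(2)]       # K = 2 endpoint scales (d x d)


def interpolated_affine(z):
  """Returns (loc, scale) for mixing weights z = [z0, z1], with z0 + z1 = 1."""
  loc_z = z[0] * loc[0] + z[1] * loc[1]
  scale_z = z[0] * scale[0] + z[1] * scale[1]
  return loc_z, scale_z


loc_z, scale_z = interpolated_affine([0.25, 0.75])
# loc_z == [1.5, 1.875]; scale_z == 0.625 * identity(2)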