Esempio n. 1
0
    def _create_scale_operator(self, identity_multiplier, diag, tril,
                               perturb_diag, perturb_factor, shift,
                               validate_args):
        """Assemble the `scale` term of `scale @ X + shift` from its parts.

        Args:
          identity_multiplier: Floating-point rank-0 `Tensor` scaling the
            identity matrix.
          diag: Floating-point `Tensor` holding the diagonal of a `k x k`
            diagonal matrix; shape `[N1, N2, ..., k]`.
          tril: Floating-point `Tensor` representing a `k x k`
            lower-triangular matrix; presumably shape `[N1, N2, ..., k, k]`.
          perturb_diag: Floating-point `Tensor` (or `None`) representing the
            diagonal matrix of the low-rank update.
          perturb_factor: Floating-point `Tensor` (or `None`) representing the
            factor matrix of the low-rank update.
          shift: Floating-point `Tensor` representing `shift` in
            `scale @ X + shift`; forwarded to `make_tril_scale` as `loc`.
          validate_args: Python `bool`; when `True`, runtime checks are
            attached to the result.

        Returns:
          If `self._is_only_identity_multiplier`, the `identity_multiplier`
          `Tensor` (with a non-zero check attached when `validate_args`).
          Otherwise a `LinearOperator`, wrapped in a
          `LinearOperatorLowRankUpdate` when `perturb_factor` is given.

        Raises:
          ValueError: if all of `tril`, `diag` and `identity_multiplier` are
            `None` (not checked here; presumably raised by
            `distribution_util.make_tril_scale` -- confirm).
        """
        # Convert the components in the same order as listed so the
        # resulting graph ops are created (and auto-named) deterministically.
        (identity_multiplier, diag, tril,
         perturb_diag, perturb_factor) = [
             _as_tensor(value, label) for value, label in (
                 (identity_multiplier, "identity_multiplier"),
                 (diag, "diag"),
                 (tril, "tril"),
                 (perturb_diag, "perturb_diag"),
                 (perturb_factor, "perturb_factor"),
             )]

        # When a low-rank factor is present, its second-to-last dimension
        # tells us the size of a (scaled) identity `scale`.
        shape_hint = (
            None if perturb_factor is None
            else distribution_util.dimension_size(perturb_factor, axis=-2))

        if self._is_only_identity_multiplier:
            if not validate_args:
                return identity_multiplier
            # A zero multiplier would make the scaled identity singular.
            nonzero_check = check_ops.assert_none_equal(
                identity_multiplier,
                array_ops.zeros([], identity_multiplier.dtype),
                ["identity_multiplier should be non-zero."])
            return control_flow_ops.with_dependencies(
                [nonzero_check], identity_multiplier)

        base_scale = distribution_util.make_tril_scale(
            loc=shift,
            scale_tril=tril,
            scale_diag=diag,
            scale_identity_multiplier=identity_multiplier,
            validate_args=validate_args,
            assert_positive=False,
            shape_hint=shape_hint)

        if perturb_factor is None:
            return base_scale

        # NOTE(review): the structural hints below are asserted
        # unconditionally (even when `tril` is given) rather than checked --
        # confirm they hold for every caller.
        return linalg.LinearOperatorLowRankUpdate(
            base_scale,
            u=perturb_factor,
            diag_update=perturb_diag,
            # Without an explicit `diag_update`, the update diagonal is
            # declared positive.
            is_diag_update_positive=perturb_diag is None,
            is_non_singular=True,  # Implied by is_positive_definite=True.
            is_self_adjoint=True,
            is_positive_definite=True,
            is_square=True)
Esempio n. 2
0
  def __init__(self,
               loc,
               scale,
               quadrature_grid_and_probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="PoissonLogNormalQuadratureCompound"):
    """Constructs the PoissonLogNormalQuadratureCompound.

    The LogNormal prior over the Poisson rate is approximated by a finite
    mixture of Poisson distributions, one per quadrature grid point, mixed by
    a `Categorical` over the quadrature weights.

    Args:
      loc: `float`-like (batch of) scalar `Tensor`; the location parameter of
        the LogNormal prior.
      scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
        the LogNormal prior.
      quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
        representing the sample points and the corresponding (possibly
        normalized) weight.  When `None`, defaults to:
        `np.polynomial.hermite.hermgauss(deg=8)`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `loc.dtype != scale.dtype` (compared on base dtypes).
    """
    # Snapshot the constructor arguments with `dict(...)`. In CPython before
    # 3.13, `locals()` returns the frame's live namespace dict, which can be
    # refreshed later (by a tracer/debugger or another `locals()` call) and
    # would then hold the rebound `loc`/`scale` tensors and other locals.
    parameters = dict(locals())
    with ops.name_scope(name, values=[loc, scale]):
      loc = ops.convert_to_tensor(loc, name="loc")
      self._loc = loc

      scale = ops.convert_to_tensor(scale, name="scale")
      self._scale = scale

      dtype = loc.dtype.base_dtype
      if dtype != scale.dtype.base_dtype:
        raise TypeError(
            "loc.dtype(\"{}\") does not match scale.dtype(\"{}\")".format(
                loc.dtype.name, scale.dtype.name))

      grid, probs = distribution_util.process_quadrature_grid_and_probs(
          quadrature_grid_and_probs, dtype, validate_args)
      self._quadrature_grid = grid
      self._quadrature_probs = probs
      self._quadrature_size = distribution_util.dimension_size(probs, axis=0)

      # Mixture weights are given as logits, i.e. `log(probs)`.
      self._mixture_distribution = categorical_lib.Categorical(
          logits=math_ops.log(self._quadrature_probs),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats)

      # The following maps the broadcast of `loc` and `scale` to each grid
      # point, i.e., we are creating several log-rates that correspond to the
      # different Gauss-Hermite quadrature points and (possible) batches of
      # `loc` and `scale`.
      self._log_rate = (loc[..., array_ops.newaxis]
                        + np.sqrt(2.) * scale[..., array_ops.newaxis] * grid)

      self._distribution = poisson_lib.Poisson(
          log_rate=self._log_rate,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats)

      super(PoissonLogNormalQuadratureCompound, self).__init__(
          dtype=dtype,
          reparameterization_type=distribution_lib.NOT_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          graph_parents=[loc, scale],
          name=name)
Esempio n. 3
0
  def __init__(self,
               loc=None,
               scale_tril=None,
               validate_args=False,
               allow_nan_stats=True,
               name="MultivariateNormalTriL"):
    """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by the last dimension of `loc` or the last
    dimension of the matrix implied by `scale`.

    Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix is:

    ```none
    scale = scale_tril
    ```

    where `scale_tril` is lower-triangular `k x k` matrix with non-zero
    diagonal, i.e., `tf.diag_part(scale_tril) != 0`.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale_tril: Floating-point, lower-triangular `Tensor` with non-zero
        diagonal elements. `scale_tril` has shape `[B1, ..., Bb, k, k]` where
        `b >= 0` and `k` is the event size. If `None`, an identity scale
        sized from the last dimension of `loc` is used.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if neither `loc` nor `scale_tril` are specified.
    """
    # Snapshot the constructor arguments with `dict(...)`. In CPython before
    # 3.13, `locals()` returns the frame's live namespace dict, which can be
    # refreshed later (by a tracer/debugger or another `locals()` call) and
    # would then hold the rebound `loc`/`scale_tril` tensors and other locals.
    parameters = dict(locals())
    def _convert_to_tensor(x, name):
      # `None` passes through so omitted arguments can still be detected.
      return None if x is None else ops.convert_to_tensor(x, name=name)
    if loc is None and scale_tril is None:
      raise ValueError("Must specify one or both of `loc`, `scale_tril`.")
    with ops.name_scope(name):
      with ops.name_scope("init", values=[loc, scale_tril]):
        loc = _convert_to_tensor(loc, name="loc")
        scale_tril = _convert_to_tensor(scale_tril, name="scale_tril")
        if scale_tril is None:
          # No `scale_tril`: use an identity scale whose size is taken from
          # the last dimension of `loc`.
          scale = linalg.LinearOperatorIdentity(
              num_rows=distribution_util.dimension_size(loc, -1),
              dtype=loc.dtype,
              is_self_adjoint=True,
              is_positive_definite=True,
              assert_proper_shapes=validate_args)
        else:
          if validate_args:
            # Runtime check that every diagonal entry is non-zero, which is
            # what makes the triangular `scale` non-singular.
            scale_tril = control_flow_ops.with_dependencies([
                # TODO(b/35157376): Use `assert_none_equal` once it exists.
                check_ops.assert_greater(
                    math_ops.abs(array_ops.matrix_diag_part(scale_tril)),
                    array_ops.zeros([], scale_tril.dtype),
                    message="`scale_tril` must have non-zero diagonal"),
            ], scale_tril)
          scale = linalg.LinearOperatorTriL(
              scale_tril,
              is_non_singular=True,
              is_self_adjoint=False,
              is_positive_definite=False)
    super(MultivariateNormalTriL, self).__init__(
        loc=loc,
        scale=scale,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
    # Overwrite whatever the base constructor recorded so this class's own
    # arguments are reported (presumably the intent -- confirm).
    self._parameters = parameters
    def __init__(self,
                 loc,
                 scale,
                 quadrature_grid_and_probs=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="PoissonLogNormalQuadratureCompound"):
        """Constructs the PoissonLogNormalQuadratureCompound.

        The LogNormal prior over the Poisson rate is approximated by a finite
        mixture of Poisson distributions, one per quadrature grid point, mixed
        by a `Categorical` over the quadrature weights.

        Args:
          loc: `float`-like (batch of) scalar `Tensor`; the location parameter
            of the LogNormal prior.
          scale: `float`-like (batch of) scalar `Tensor`; the scale parameter
            of the LogNormal prior.
          quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
            representing the sample points and the corresponding (possibly
            normalized) weight.  When `None`, defaults to:
            `np.polynomial.hermite.hermgauss(deg=8)`.
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite possibly
            degrading runtime performance. When `False` invalid inputs may
            silently render incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.

        Raises:
          TypeError: if `loc.dtype != scale.dtype` (compared on base dtypes).
        """
        # Snapshot the constructor arguments with `dict(...)`. In CPython
        # before 3.13, `locals()` returns the frame's live namespace dict,
        # which can be refreshed later (by a tracer/debugger or another
        # `locals()` call) and would then hold the rebound `loc`/`scale`
        # tensors and other locals.
        parameters = dict(locals())
        with ops.name_scope(name, values=[loc, scale]):
            loc = ops.convert_to_tensor(loc, name="loc")
            self._loc = loc

            scale = ops.convert_to_tensor(scale, name="scale")
            self._scale = scale

            dtype = loc.dtype.base_dtype
            if dtype != scale.dtype.base_dtype:
                raise TypeError(
                    "loc.dtype(\"{}\") does not match scale.dtype(\"{}\")".
                    format(loc.dtype.name, scale.dtype.name))

            grid, probs = distribution_util.process_quadrature_grid_and_probs(
                quadrature_grid_and_probs, dtype, validate_args)
            self._quadrature_grid = grid
            self._quadrature_probs = probs
            self._quadrature_size = distribution_util.dimension_size(probs,
                                                                     axis=0)

            # Mixture weights are given as logits, i.e. `log(probs)`.
            self._mixture_distribution = categorical_lib.Categorical(
                logits=math_ops.log(self._quadrature_probs),
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats)

            # The following maps the broadcast of `loc` and `scale` to each
            # grid point, i.e., we are creating several log-rates that
            # correspond to the different Gauss-Hermite quadrature points and
            # (possible) batches of `loc` and `scale`.
            self._log_rate = (
                loc[..., array_ops.newaxis] +
                np.sqrt(2.) * scale[..., array_ops.newaxis] * grid)

            self._distribution = poisson_lib.Poisson(
                log_rate=self._log_rate,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats)

            super(PoissonLogNormalQuadratureCompound, self).__init__(
                dtype=dtype,
                reparameterization_type=distribution_lib.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                graph_parents=[loc, scale],
                name=name)
Esempio n. 5
0
  def _create_scale_operator(self, identity_multiplier, diag, tril,
                             perturb_diag, perturb_factor, shift,
                             validate_args):
    """Assemble the `scale` term of `scale @ X + shift` from its parts.

    Args:
      identity_multiplier: Floating-point rank-0 `Tensor` scaling the identity
        matrix.
      diag: Floating-point `Tensor` holding the diagonal of a `k x k` diagonal
        matrix; shape `[N1, N2, ..., k]`.
      tril: Floating-point `Tensor` representing a `k x k` lower-triangular
        matrix; presumably shape `[N1, N2, ..., k, k]`.
      perturb_diag: Floating-point `Tensor` (or `None`) representing the
        diagonal matrix of the low-rank update.
      perturb_factor: Floating-point `Tensor` (or `None`) representing the
        factor matrix of the low-rank update.
      shift: Floating-point `Tensor` representing `shift` in
        `scale @ X + shift`; forwarded to `make_tril_scale` as `loc`.
      validate_args: Python `bool`; when `True`, runtime checks are attached
        to the result.

    Returns:
      If `self._is_only_identity_multiplier`, the `identity_multiplier`
      `Tensor` (with a non-zero check attached when `validate_args`).
      Otherwise a `LinearOperator`, wrapped in a `LinearOperatorUDVHUpdate`
      when `perturb_factor` is given.

    Raises:
      ValueError: if all of `tril`, `diag` and `identity_multiplier` are
        `None` (not checked here; presumably raised by
        `distribution_util.make_tril_scale` -- confirm).
    """
    # Convert the components in the same order as listed so the resulting
    # graph ops are created (and auto-named) deterministically.
    (identity_multiplier, diag, tril,
     perturb_diag, perturb_factor) = [
         _as_tensor(value, label) for value, label in (
             (identity_multiplier, "identity_multiplier"),
             (diag, "diag"),
             (tril, "tril"),
             (perturb_diag, "perturb_diag"),
             (perturb_factor, "perturb_factor"),
         )]

    # When a low-rank factor is present, its second-to-last dimension tells
    # us the size of a (scaled) identity `scale`.
    shape_hint = (
        None if perturb_factor is None
        else distribution_util.dimension_size(perturb_factor, axis=-2))

    if self._is_only_identity_multiplier:
      if not validate_args:
        return identity_multiplier
      # A zero multiplier would make the scaled identity singular.
      nonzero_check = check_ops.assert_none_equal(
          identity_multiplier,
          array_ops.zeros([], identity_multiplier.dtype),
          ["identity_multiplier should be non-zero."])
      return control_flow_ops.with_dependencies(
          [nonzero_check], identity_multiplier)

    base_scale = distribution_util.make_tril_scale(
        loc=shift,
        scale_tril=tril,
        scale_diag=diag,
        scale_identity_multiplier=identity_multiplier,
        validate_args=validate_args,
        assert_positive=False,
        shape_hint=shape_hint)

    if perturb_factor is None:
      return base_scale

    # NOTE(review): the structural hints below are asserted unconditionally
    # (even when `tril` is given) rather than checked -- confirm they hold.
    return linalg.LinearOperatorUDVHUpdate(
        base_scale,
        u=perturb_factor,
        diag_update=perturb_diag,
        # Without an explicit `diag_update`, the update diagonal is declared
        # positive.
        is_diag_update_positive=perturb_diag is None,
        is_non_singular=True,  # Implied by is_positive_definite=True.
        is_self_adjoint=True,
        is_positive_definite=True,
        is_square=True)
Esempio n. 6
0
  def __init__(self,
               mix_loc,
               mix_scale,
               distribution,
               loc=None,
               scale=None,
               quadrature_grid_and_probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="VectorDiffeomixture"):
    """Constructs the VectorDiffeomixture on `R**k`.

    Args:
      mix_loc: `float`-like `Tensor`. Represents the `location` parameter of the
        SoftmaxNormal used for selecting one of the `K` affine transformations.
      mix_scale: `float`-like `Tensor`. Represents the `scale` parameter of the
        SoftmaxNormal used for selecting one of the `K` affine transformations.
      distribution: `tf.Distribution`-like instance. Distribution from which `d`
        iid samples are used as input to the selected affine transformation.
        Must be a scalar-batch, scalar-event distribution.  Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a VectorDiffeomixture sample and the `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      loc: Length-`K` list of `float`-type `Tensor`s. The `k`-th element
        represents the `shift` used for the `k`-th affine transformation.  If
        the `k`-th item is `None`, `loc` is implicitly `0`.  When specified,
        must have shape `[B1, ..., Bb, d]` where `b >= 0` and `d` is the event
        size.
      scale: Length-`K` list of `LinearOperator`s. Each should be
        positive-definite and operate on a `d`-dimensional vector space. The
        `k`-th element represents the `scale` used for the `k`-th affine
        transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`,
        `b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices.
      quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
        representing the sample points and the corresponding (possibly
        normalized) weight.  When `None`, defaults to:
        `np.polynomial.hermite.hermgauss(deg=8)`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if `not scale or len(scale) < 2`.
      ValueError: if `len(loc) != len(scale)`
      ValueError: if `quadrature_grid_and_probs is not None` and
        `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
      ValueError: if `validate_args` and any not scale.is_positive_definite.
      TypeError: if any scale.dtype != scale[0].dtype.
      TypeError: if any loc.dtype != scale[0].dtype.
      NotImplementedError: if `len(scale) != 2`.
      ValueError: if `not distribution.is_scalar_batch`.
      ValueError: if `not distribution.is_scalar_event`.
    """
    # Snapshot the constructor arguments with `dict(...)`. In CPython before
    # 3.13, `locals()` returns the frame's live namespace dict, which can be
    # refreshed later (by a tracer/debugger or another `locals()` call) and
    # would then hold the rebound `loc`/`mix_loc`/`mix_scale` and other locals.
    parameters = dict(locals())
    with ops.name_scope(name, values=[mix_loc, mix_scale]):
      if not scale or len(scale) < 2:
        raise ValueError("Must specify list (or list-like object) of scale "
                         "LinearOperators, one for each component with "
                         "num_component >= 2.")

      if loc is None:
        loc = [None]*len(scale)

      if len(loc) != len(scale):
        raise ValueError("loc/scale must be same-length lists "
                         "(or same-length list-like objects).")

      # The first scale operator fixes the dtype for every other component.
      dtype = scale[0].dtype.base_dtype

      loc = [ops.convert_to_tensor(loc_, dtype=dtype, name="loc{}".format(k))
             if loc_ is not None else None
             for k, loc_ in enumerate(loc)]

      for k, scale_ in enumerate(scale):
        if validate_args and not scale_.is_positive_definite:
          raise ValueError("scale[{}].is_positive_definite = {} != True".format(
              k, scale_.is_positive_definite))
        if scale_.dtype.base_dtype != dtype:
          raise TypeError(
              "dtype mismatch; scale[{}].base_dtype=\"{}\" != \"{}\"".format(
                  k, scale_.dtype.base_dtype.name, dtype.name))

      self._endpoint_affine = [
          AffineLinearOperator(shift=loc_,
                               scale=scale_,
                               event_ndims=1,
                               validate_args=validate_args,
                               name="endpoint_affine_{}".format(k))
          for k, (loc_, scale_) in enumerate(zip(loc, scale))]

      # TODO(jvdillon): Remove once we support k-mixtures.
      # We make this assertion here because otherwise `grid` would need to be a
      # vector not a scalar.
      if len(scale) != 2:
        raise NotImplementedError("Currently only bimixtures are supported; "
                                  "len(scale)={} is not 2.".format(len(scale)))

      grid, probs = distribution_util.process_quadrature_grid_and_probs(
          quadrature_grid_and_probs, dtype, validate_args)
      self._quadrature_grid = grid
      self._quadrature_probs = probs
      self._quadrature_size = distribution_util.dimension_size(probs, axis=0)

      # Note: by creating the logits as `log(prob)` we ensure that
      # `self.mixture_distribution.logits` is equivalent to
      # `math_ops.log(self.mixture_distribution.probs)`.
      self._mixture_distribution = categorical_lib.Categorical(
          logits=math_ops.log(probs),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats)

      mix_loc = maybe_check_mix_param(
          mix_loc, "mix_loc", dtype, validate_args)
      mix_scale = maybe_check_mix_param(
          mix_scale, "mix_scale", dtype, validate_args)

      # Any checks on `distribution` are attached to `mix_loc` so they run
      # whenever `mix_loc` is evaluated.
      asserts = distribution_util.maybe_check_scalar_distribution(
          distribution, dtype, validate_args)
      if asserts:
        mix_loc = control_flow_ops.with_dependencies(asserts, mix_loc)
      self._distribution = distribution

      # shape: [B, deg]
      self._interpolate_weight = math_ops.sigmoid(
          mix_loc
          + np.sqrt(2.) * mix_scale * grid)

      self._interpolated_affine = [
          AffineLinearOperator(shift=loc_,
                               scale=scale_,
                               event_ndims=1,
                               validate_args=validate_args,
                               name="interpolated_affine_{}".format(k))
          for k, (loc_, scale_) in enumerate(zip(
              interpolate_loc(self._quadrature_size,
                              self._interpolate_weight,
                              loc),
              interpolate_scale(self._quadrature_size,
                                self._interpolate_weight,
                                scale)))]

      self._batch_shape_, self._event_shape_ = determine_batch_event_shapes(
          mix_loc, mix_scale, self._endpoint_affine)

      super(VectorDiffeomixture, self).__init__(
          dtype=dtype,
          # We hard-code `FULLY_REPARAMETERIZED` because when
          # `validate_args=True` we verify that indeed
          # `distribution.reparameterization_type == FULLY_REPARAMETERIZED`. A
          # distribution which is a function of only non-trainable parameters
          # also implies we can use `FULLY_REPARAMETERIZED`. However, we cannot
          # easily test for that possibility thus we use `validate_args=False`
          # as a "back-door" to allow users a way to use non
          # `FULLY_REPARAMETERIZED` distribution. In such cases IT IS THE USERS
          # RESPONSIBILITY to verify that the base distribution is a function of
          # non-trainable parameters.
          reparameterization_type=distribution_lib.FULLY_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          graph_parents=(
              [mix_loc, mix_scale]
              + distribution._graph_parents  # pylint: disable=protected-access
              + [loc_ for loc_ in loc if loc_ is not None]
              + [p for scale_ in scale for p in scale_.graph_parents]),
          name=name)
Esempio n. 7
0
    def __init__(self,
                 loc=None,
                 scale_tril=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="MultivariateNormalTriL"):
        """Construct Multivariate Normal distribution on `R^k`.

        The `batch_shape` is the broadcast shape between `loc` and `scale`
        arguments.

        The `event_shape` is given by the last dimension of `loc` or the last
        dimension of the matrix implied by `scale`.

        Recall that `covariance = scale @ scale.T`. A (non-batch) `scale`
        matrix is:

        ```none
        scale = scale_tril
        ```

        where `scale_tril` is lower-triangular `k x k` matrix with non-zero
        diagonal, i.e., `tf.diag_part(scale_tril) != 0`.

        Additional leading dimensions (if any) will index batches.

        Args:
          loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
            implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]`
            where `b >= 0` and `k` is the event size.
          scale_tril: Floating-point, lower-triangular `Tensor` with non-zero
            diagonal elements. `scale_tril` has shape `[B1, ..., Bb, k, k]`
            where `b >= 0` and `k` is the event size. If `None`, an identity
            scale sized from the last dimension of `loc` is used.
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite possibly
            degrading runtime performance. When `False` invalid inputs may
            silently render incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.

        Raises:
          ValueError: if neither `loc` nor `scale_tril` are specified.
        """
        # Snapshot the constructor arguments with `dict(...)`. In CPython
        # before 3.13, `locals()` returns the frame's live namespace dict,
        # which can be refreshed later (by a tracer/debugger or another
        # `locals()` call) and would then hold the rebound `loc`/`scale_tril`
        # tensors and other locals.
        parameters = dict(locals())

        def _convert_to_tensor(x, name):
            # `None` passes through so omitted arguments can still be detected.
            return None if x is None else ops.convert_to_tensor(x, name=name)

        if loc is None and scale_tril is None:
            raise ValueError(
                "Must specify one or both of `loc`, `scale_tril`.")
        with ops.name_scope(name):
            with ops.name_scope("init", values=[loc, scale_tril]):
                loc = _convert_to_tensor(loc, name="loc")
                scale_tril = _convert_to_tensor(scale_tril, name="scale_tril")
                if scale_tril is None:
                    # No `scale_tril`: use an identity scale whose size is
                    # taken from the last dimension of `loc`.
                    scale = linalg.LinearOperatorIdentity(
                        num_rows=distribution_util.dimension_size(loc, -1),
                        dtype=loc.dtype,
                        is_self_adjoint=True,
                        is_positive_definite=True,
                        assert_proper_shapes=validate_args)
                else:
                    if validate_args:
                        # Runtime check that every diagonal entry is non-zero,
                        # which is what makes the triangular `scale`
                        # non-singular.
                        scale_tril = control_flow_ops.with_dependencies(
                            [
                                # TODO(b/35157376): Use `assert_none_equal`
                                # once it exists.
                                check_ops.assert_greater(
                                    math_ops.abs(
                                        array_ops.matrix_diag_part(
                                            scale_tril)),
                                    array_ops.zeros([], scale_tril.dtype),
                                    message=
                                    "`scale_tril` must have non-zero diagonal"
                                ),
                            ],
                            scale_tril)
                    scale = linalg.LinearOperatorTriL(
                        scale_tril,
                        is_non_singular=True,
                        is_self_adjoint=False,
                        is_positive_definite=False)
        super(MultivariateNormalTriL,
              self).__init__(loc=loc,
                             scale=scale,
                             validate_args=validate_args,
                             allow_nan_stats=allow_nan_stats,
                             name=name)
        # Overwrite whatever the base constructor recorded so this class's
        # own arguments are reported (presumably the intent -- confirm).
        self._parameters = parameters
    def __init__(self,
                 mix_loc,
                 mix_scale,
                 distribution,
                 loc=None,
                 scale=None,
                 quadrature_grid_and_probs=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="VectorDiffeomixture"):
        """Constructs the VectorDiffeomixture on `R**k`.

    Args:
      mix_loc: `float`-like `Tensor`. Represents the `location` parameter of the
        SoftmaxNormal used for selecting one of the `K` affine transformations.
      mix_scale: `float`-like `Tensor`. Represents the `scale` parameter of the
        SoftmaxNormal used for selecting one of the `K` affine transformations.
      distribution: `tf.Distribution`-like instance. Distribution from which `d`
        iid samples are used as input to the selected affine transformation.
        Must be a scalar-batch, scalar-event distribution.  Typically
        `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
        a function of non-trainable parameters. WARNING: If you backprop through
        a VectorDiffeomixture sample and the `distribution` is not
        `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
        the gradient will be incorrect!
      loc: Length-`K` list of `float`-type `Tensor`s. The `k`-th element
        represents the `shift` used for the `k`-th affine transformation.  If
        the `k`-th item is `None`, `loc` is implicitly `0`.  When specified,
        must have shape `[B1, ..., Bb, d]` where `b >= 0` and `d` is the event
        size.
      scale: Length-`K` list of `LinearOperator`s. Each should be
        positive-definite and operate on a `d`-dimensional vector space. The
        `k`-th element represents the `scale` used for the `k`-th affine
        transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`,
        `b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices
      quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
        representing the sample points and the corresponding (possibly
        normalized) weight.  When `None`, defaults to:
        `np.polynomial.hermite.hermgauss(deg=8)`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if `not scale or len(scale) < 2`.
      ValueError: if `len(loc) != len(scale)`
      ValueError: if `quadrature_grid_and_probs is not None` and
        `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
      ValueError: if `validate_args` and any not scale.is_positive_definite.
      TypeError: if any scale.dtype != scale[0].dtype.
      TypeError: if any loc.dtype != scale[0].dtype.
      NotImplementedError: if `len(scale) != 2`.
      ValueError: if `not distribution.is_scalar_batch`.
      ValueError: if `not distribution.is_scalar_event`.
    """
        parameters = locals()
        with ops.name_scope(name, values=[mix_loc, mix_scale]):
            if not scale or len(scale) < 2:
                raise ValueError(
                    "Must specify list (or list-like object) of scale "
                    "LinearOperators, one for each component with "
                    "num_component >= 2.")

            if loc is None:
                loc = [None] * len(scale)

            if len(loc) != len(scale):
                raise ValueError("loc/scale must be same-length lists "
                                 "(or same-length list-like objects).")

            dtype = scale[0].dtype.base_dtype

            loc = [
                ops.convert_to_tensor(
                    loc_, dtype=dtype, name="loc{}".format(k))
                if loc_ is not None else None for k, loc_ in enumerate(loc)
            ]

            for k, scale_ in enumerate(scale):
                if validate_args and not scale_.is_positive_definite:
                    raise ValueError(
                        "scale[{}].is_positive_definite = {} != True".format(
                            k, scale_.is_positive_definite))
                if scale_.dtype.base_dtype != dtype:
                    raise TypeError(
                        "dtype mismatch; scale[{}].base_dtype=\"{}\" != \"{}\""
                        .format(k, scale_.dtype.base_dtype.name, dtype.name))

            self._endpoint_affine = [
                AffineLinearOperator(shift=loc_,
                                     scale=scale_,
                                     event_ndims=1,
                                     validate_args=validate_args,
                                     name="endpoint_affine_{}".format(k))
                for k, (loc_, scale_) in enumerate(zip(loc, scale))
            ]

            # TODO(jvdillon): Remove once we support k-mixtures.
            # We make this assertion here because otherwise `grid` would need to be a
            # vector not a scalar.
            if len(scale) != 2:
                raise NotImplementedError(
                    "Currently only bimixtures are supported; "
                    "len(scale)={} is not 2.".format(len(scale)))

            grid, probs = distribution_util.process_quadrature_grid_and_probs(
                quadrature_grid_and_probs, dtype, validate_args)
            self._quadrature_grid = grid
            self._quadrature_probs = probs
            self._quadrature_size = distribution_util.dimension_size(probs,
                                                                     axis=0)

            # Note: by creating the logits as `log(prob)` we ensure that
            # `self.mixture_distribution.logits` is equivalent to
            # `math_ops.log(self.mixture_distribution.probs)`.
            self._mixture_distribution = categorical_lib.Categorical(
                logits=math_ops.log(probs),
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats)

            mix_loc = maybe_check_mix_param(mix_loc, "mix_loc", dtype,
                                            validate_args)
            mix_scale = maybe_check_mix_param(mix_scale, "mix_scale", dtype,
                                              validate_args)

            asserts = distribution_util.maybe_check_scalar_distribution(
                distribution, dtype, validate_args)
            if asserts:
                mix_loc = control_flow_ops.with_dependencies(asserts, mix_loc)
            self._distribution = distribution

            # shape: [B, deg]
            self._interpolate_weight = math_ops.sigmoid(mix_loc + np.sqrt(2.) *
                                                        mix_scale * grid)

            self._interpolated_affine = [
                AffineLinearOperator(shift=loc_,
                                     scale=scale_,
                                     event_ndims=1,
                                     validate_args=validate_args,
                                     name="interpolated_affine_{}".format(k))
                for k, (loc_, scale_) in enumerate(
                    zip(
                        interpolate_loc(self._quadrature_size,
                                        self._interpolate_weight, loc),
                        interpolate_scale(self._quadrature_size,
                                          self._interpolate_weight, scale)))
            ]

            self._batch_shape_, self._event_shape_ = determine_batch_event_shapes(
                mix_loc, mix_scale, self._endpoint_affine)

            super(VectorDiffeomixture, self).__init__(
                dtype=dtype,
                # We hard-code `FULLY_REPARAMETERIZED` because when
                # `validate_args=True` we verify that indeed
                # `distribution.reparameterization_type == FULLY_REPARAMETERIZED`. A
                # distribution which is a function of only non-trainable parameters
                # also implies we can use `FULLY_REPARAMETERIZED`. However, we cannot
                # easily test for that possibility thus we use `validate_args=False`
                # as a "back-door" to allow users a way to use non
                # `FULLY_REPARAMETERIZED` distribution. In such cases IT IS THE USERS
                # RESPONSIBILITY to verify that the base distribution is a function of
                # non-trainable parameters.
                reparameterization_type=distribution_lib.FULLY_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                graph_parents=(
                    [mix_loc, mix_scale] + distribution._graph_parents  # pylint: disable=protected-access
                    + [loc_ for loc_ in loc if loc_ is not None] +
                    [p for scale_ in scale for p in scale_.graph_parents]),
                name=name)