コード例 #1
0
    def __init__(self,
                 mixture_distribution,
                 components_distribution,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="MixtureSameFamily"):
        """Construct a `MixtureSameFamily` distribution.

        Shape information that is known statically is checked immediately and
        a `ValueError` is raised on mismatch. When `validate_args` is `True`,
        assertion ops are additionally collected in
        `self._runtime_assertions`.

        Args:
          mixture_distribution: `tfp.distributions.Categorical`-like instance
            that manages the probability of selecting components. Its number
            of categories must match the rightmost batch dimension of
            `components_distribution`, and its `batch_shape` must be either
            scalar or equal to `components_distribution.batch_shape[:-1]`.
          components_distribution: `tfp.distributions.Distribution`-like
            instance whose right-most batch dimension indexes components.
          validate_args: Python `bool`, default `False`. When `True`,
            distribution parameters are checked for validity despite possibly
            degrading runtime performance. When `False`, invalid inputs may
            silently render incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.

        Raises:
          ValueError: `if not mixture_distribution.dtype.is_integer`.
          ValueError: if `mixture_distribution` does not have scalar
            `event_shape`.
          ValueError: if `mixture_distribution.batch_shape` and
            `components_distribution.batch_shape[:-1]` are both fully defined
            and the former is neither scalar nor equal to the latter.
          ValueError: if the number of `mixture_distribution` categories does
            not equal the rightmost batch dimension of
            `components_distribution`.
        """
        # Snapshot the arguments before any other local is bound, so the dict
        # holds only the constructor arguments (and `self`).
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            self._mixture_distribution = mixture_distribution
            self._components_distribution = components_distribution
            self._runtime_assertions = []

            # Length of the components' event-shape vector: taken from the
            # static shape when it is known, otherwise from `tf.shape`.
            event_shape = components_distribution.event_shape_tensor()
            if event_shape.shape.with_rank_at_least(1)[0].value is not None:
                self._event_ndims = event_shape.shape[0].value
            else:
                self._event_ndims = tf.shape(event_shape)[0]

            mixture_dtype = mixture_distribution.dtype
            if not mixture_dtype.is_integer:
                raise ValueError(
                    "`mixture_distribution.dtype` ({}) is not over integers"
                    .format(mixture_dtype.name))

            scalar_event_msg = (
                "`mixture_distribution` must have scalar `event_dim`s")
            mixture_event_ndims = mixture_distribution.event_shape.ndims
            if mixture_event_ndims is not None and mixture_event_ndims != 0:
                raise ValueError(scalar_event_msg)
            if validate_args:
                # Added whenever `validate_args` is set and the static check
                # above did not raise, even if the rank is statically 0.
                # NOTE(review): this asserts the rank of the event-shape
                # *tensor* itself rather than its length -- confirm that this
                # is the intended check and that `assert_has_rank` exists on
                # `control_flow_ops`.
                self._runtime_assertions.append(
                    control_flow_ops.assert_has_rank(
                        mixture_distribution.event_shape_tensor(),
                        0,
                        message=scalar_event_msg))

            mixture_batch = mixture_distribution.batch_shape
            components_batch = (
                components_distribution.batch_shape.with_rank_at_least(1)[:-1])
            if (mixture_batch.is_fully_defined()
                    and components_batch.is_fully_defined()):
                # Everything is known statically: a non-scalar mixture batch
                # must equal the components' batch minus its last dimension.
                if (mixture_batch.ndims != 0
                        and mixture_batch != components_batch):
                    raise ValueError(
                        "`mixture_distribution.batch_shape` (`{}`) is not "
                        "compatible with `components_distribution.batch_shape` "
                        "(`{}`)".format(mixture_batch.as_list(),
                                        components_batch.as_list()))
            elif validate_args:
                mixture_batch_dyn = mixture_distribution.batch_shape_tensor()
                components_batch_dyn = (
                    components_distribution.batch_shape_tensor()[:-1])
                # Presumably `pick_vector` yields `components_batch_dyn` when
                # the mixture batch is scalar, so the comparison then passes
                # trivially -- verify against `pick_vector`.
                batch_to_compare = distribution_utils.pick_vector(
                    mixture_distribution.is_scalar_batch(),
                    components_batch_dyn,
                    mixture_batch_dyn)
                self._runtime_assertions.append(
                    control_flow_ops.assert_equal(
                        batch_to_compare,
                        components_batch_dyn,
                        message=(
                            "`mixture_distribution.batch_shape` is not "
                            "compatible with "
                            "`components_distribution.batch_shape`")))

            num_categories = (
                mixture_distribution.logits.shape.with_rank_at_least(1)[-1]
                .value)
            num_component_slots = (
                components_distribution.batch_shape.with_rank_at_least(1)[-1]
                .value)
            if (num_categories is not None
                    and num_component_slots is not None
                    and num_categories != num_component_slots):
                raise ValueError(
                    "`mixture_distribution components` ({}) does not "
                    "equal `components_distribution.batch_shape[-1]` "
                    "({})".format(num_categories, num_component_slots))
            if validate_args:
                # With `validate_args` the static count is replaced by the
                # dynamic one, even when the static value was known.
                num_categories = tf.shape(mixture_distribution.logits)[-1]
                num_component_slots = (
                    components_distribution.batch_shape_tensor()[-1])
                self._runtime_assertions.append(
                    control_flow_ops.assert_equal(
                        num_categories,
                        num_component_slots,
                        message=(
                            "`mixture_distribution components` does not equal "
                            "`components_distribution.batch_shape[-1:]`")))
            elif num_categories is None:
                # No static count available: fall back to the dynamic shape.
                num_categories = tf.shape(mixture_distribution.logits)[-1]

            self._num_components = num_categories

            graph_parents = (
                self._mixture_distribution._graph_parents  # pylint: disable=protected-access
                + self._components_distribution._graph_parents)  # pylint: disable=protected-access
            super(MixtureSameFamily, self).__init__(
                dtype=self._components_distribution.dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                graph_parents=graph_parents,
                name=name)
コード例 #2
0
  def __init__(self,
               mixture_distribution,
               components_distribution,
               validate_args=False,
               allow_nan_stats=True,
               name="MixtureSameFamily"):
    """Construct a `MixtureSameFamily` distribution.

    Statically known shape information is validated right away and raises
    `ValueError` on mismatch. When `validate_args` is `True`, assertion ops
    are additionally accumulated in `self._runtime_assertions`.

    Args:
      mixture_distribution: `tf.distributions.Categorical`-like instance that
        manages the probability of selecting components. Its number of
        categories must match the rightmost batch dimension of
        `components_distribution`, and its `batch_shape` must be either scalar
        or equal to `components_distribution.batch_shape[:-1]`.
      components_distribution: `tf.distributions.Distribution`-like instance
        whose right-most batch dimension indexes components.
      validate_args: Python `bool`, default `False`. When `True`, distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False`, invalid inputs may silently render
        incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: `if not mixture_distribution.dtype.is_integer`.
      ValueError: if `mixture_distribution` does not have scalar
        `event_shape`.
      ValueError: if `mixture_distribution.batch_shape` and
        `components_distribution.batch_shape[:-1]` are both fully defined and
        the former is neither scalar nor equal to the latter.
      ValueError: if the number of `mixture_distribution` categories does not
        equal the rightmost batch dimension of `components_distribution`.
    """
    # Presumably records the constructor arguments by inspecting the calling
    # frame -- keep this as the first statement of the method.
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name) as name:
      self._mixture_distribution = mixture_distribution
      self._components_distribution = components_distribution
      self._runtime_assertions = []

      # Length of the components' event-shape vector: the static value when it
      # is known, otherwise the value read from `array_ops.shape`.
      components_event_shape = components_distribution.event_shape_tensor()
      static_event_len = (
          components_event_shape.shape.with_rank_at_least(1)[0].value)
      if static_event_len is not None:
        self._event_ndims = components_event_shape.shape[0].value
      else:
        self._event_ndims = array_ops.shape(components_event_shape)[0]

      if not mixture_distribution.dtype.is_integer:
        dtype_name = mixture_distribution.dtype.name
        raise ValueError(
            "`mixture_distribution.dtype` ({}) is not over integers".format(
                dtype_name))

      non_scalar_event_msg = (
          "`mixture_distribution` must have scalar `event_dim`s")
      static_event_rank = mixture_distribution.event_shape.ndims
      if static_event_rank is not None and static_event_rank != 0:
        raise ValueError(non_scalar_event_msg)
      if validate_args:
        # Added whenever `validate_args` is set and the static check above did
        # not raise, even if the rank is statically known to be 0.
        # NOTE(review): this asserts the rank of the event-shape *tensor*
        # itself rather than its length -- confirm that this is the intended
        # check and that `assert_has_rank` exists on `control_flow_ops`.
        rank_assertion = control_flow_ops.assert_has_rank(
            mixture_distribution.event_shape_tensor(), 0,
            message=non_scalar_event_msg)
        self._runtime_assertions.append(rank_assertion)

      static_mix_batch = mixture_distribution.batch_shape
      static_comp_batch = (
          components_distribution.batch_shape.with_rank_at_least(1)[:-1])
      if (static_mix_batch.is_fully_defined()
          and static_comp_batch.is_fully_defined()):
        # Fully static case: a non-scalar mixture batch must equal the
        # components' batch shape with its last dimension dropped.
        if static_mix_batch.ndims != 0 and static_mix_batch != static_comp_batch:
          raise ValueError(
              "`mixture_distribution.batch_shape` (`{}`) is not "
              "compatible with `components_distribution.batch_shape` "
              "(`{}`)".format(static_mix_batch.as_list(),
                              static_comp_batch.as_list()))
      elif validate_args:
        dynamic_mix_batch = mixture_distribution.batch_shape_tensor()
        dynamic_comp_batch = (
            components_distribution.batch_shape_tensor()[:-1])
        # Presumably `pick_vector` returns `dynamic_comp_batch` when the
        # mixture batch is scalar, so the comparison then passes trivially --
        # verify against `pick_vector`.
        lhs_batch = distribution_util.pick_vector(
            mixture_distribution.is_scalar_batch(),
            dynamic_comp_batch,
            dynamic_mix_batch)
        batch_assertion = control_flow_ops.assert_equal(
            lhs_batch,
            dynamic_comp_batch,
            message=(
                "`mixture_distribution.batch_shape` is not "
                "compatible with `components_distribution.batch_shape`"))
        self._runtime_assertions.append(batch_assertion)

      mixture_k = (
          mixture_distribution.logits.shape.with_rank_at_least(1)[-1].value)
      components_k = (
          components_distribution.batch_shape.with_rank_at_least(1)[-1].value)
      if (mixture_k is not None and components_k is not None
          and mixture_k != components_k):
        raise ValueError("`mixture_distribution components` ({}) does not "
                         "equal `components_distribution.batch_shape[-1]` "
                         "({})".format(mixture_k, components_k))
      if validate_args:
        # With `validate_args` the static count is replaced by the dynamic
        # one, even when the static value was known.
        mixture_k = array_ops.shape(mixture_distribution.logits)[-1]
        components_k = components_distribution.batch_shape_tensor()[-1]
        count_assertion = control_flow_ops.assert_equal(
            mixture_k, components_k,
            message=("`mixture_distribution components` does not equal "
                     "`components_distribution.batch_shape[-1:]`"))
        self._runtime_assertions.append(count_assertion)
      elif mixture_k is None:
        # No static count available: fall back to the dynamic shape.
        mixture_k = array_ops.shape(mixture_distribution.logits)[-1]

      self._num_components = mixture_k

      combined_graph_parents = (
          self._mixture_distribution._graph_parents  # pylint: disable=protected-access
          + self._components_distribution._graph_parents)  # pylint: disable=protected-access
      super(MixtureSameFamily, self).__init__(
          dtype=self._components_distribution.dtype,
          reparameterization_type=distribution.NOT_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          graph_parents=combined_graph_parents,
          name=name)