Example #1
0
  def _covariance(self):
    static_event_ndims = tensorshape_util.rank(self.event_shape)
    if static_event_ndims is not None and static_event_ndims != 1:
      # Covariance is defined only for vector distributions.
      raise NotImplementedError("covariance is not implemented")

    with tf.control_dependencies(self._runtime_assertions):
      # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
      probs = distribution_utils.pad_mixture_dimensions(
          distribution_utils.pad_mixture_dimensions(
              self.mixture_distribution.probs_parameter(),
              self,
              self.mixture_distribution,
              self._event_ndims),
          self, self.mixture_distribution,
          self._event_ndims)                         # [B, k, 1, 1]
      mean_cond_var = tf.reduce_sum(
          probs * self.components_distribution.covariance(),
          axis=-3)  # [B, e, e]
      var_cond_mean = tf.reduce_sum(
          probs * _outer_squared_difference(
              self.components_distribution.mean(),
              self._pad_sample_dims(self._mean())),
          axis=-3)  # [B, e, e]
      return mean_cond_var + var_cond_mean                   # [B, e, e]
 def _mean(self):
     with tf.control_dependencies(self._runtime_assertions):
         probs = distribution_utils.pad_mixture_dimensions(
             self.mixture_distribution.probs, self,
             self.mixture_distribution, self._event_ndims)  # [B, k, [1]*e]
         return tf.reduce_sum(probs * self.components_distribution.mean(),
                              axis=-1 - self._event_ndims)  # [B, E]
 def _mean(self):
   with tf.control_dependencies(self._runtime_assertions):
     probs = distribution_utils.pad_mixture_dimensions(
         self.mixture_distribution.probs, self, self.mixture_distribution,
         self._event_shape().ndims)                         # [B, k, [1]*e]
     return tf.reduce_sum(
         probs * self.components_distribution.mean(),
         axis=-1 - self._event_ndims)  # [B, E]
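Both `_covariance` variants above apply the law of total variance, Cov(Y) = E[Cov(Y|Z)] + Cov(E[Y|Z]) with Z the mixture index: the padded mixture probabilities broadcast against the per-component covariances and outer squared mean differences, and the component axis is summed out. The following sketch is not part of the snippets above; it assumes a TF2 eager setup with `tensorflow_probability` installed and checks that decomposition against a hand computation using the public `MixtureSameFamily` API.

import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions

probs = np.array([0.3, 0.7], dtype=np.float32)
locs = np.array([[-1., 1.], [1., -1.]], dtype=np.float32)    # [k, e]
scales = np.array([[1., 1.], [0.5, 0.5]], dtype=np.float32)  # [k, e]

gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=probs),
    components_distribution=tfd.MultivariateNormalDiag(
        loc=locs, scale_diag=scales))

mixture_mean = np.sum(probs[:, None] * locs, axis=0)              # E[Y]
mean_cond_cov = np.sum(                                           # E[Cov(Y|Z)]
    probs[:, None, None] * np.stack([np.diag(s**2) for s in scales]), axis=0)
diffs = locs - mixture_mean                                       # [k, e]
cov_cond_mean = np.sum(                                           # Cov(E[Y|Z])
    probs[:, None, None] * np.einsum('ki,kj->kij', diffs, diffs), axis=0)

np.testing.assert_allclose(gm.mean().numpy(), mixture_mean, rtol=1e-5)
np.testing.assert_allclose(
    gm.covariance().numpy(), mean_cond_cov + cov_cond_mean, rtol=1e-5, atol=1e-6)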
Example #4
0
    def _sample_n(self, n, seed=None):
        seeds = samplers.split_seed(seed,
                                    n=self.num_components + 1,
                                    salt='Mixture')
        try:
            seed_stream = SeedStream(seed, salt='Mixture')
        except TypeError as e:  # Can happen for Tensor seed.
            seed_stream = None
            seed_stream_err = e

        # This sampling approach is almost the same as the approach used by
        # `MixtureSameFamily`. The differences are due to having a list of
        # `Distribution` objects rather than a single object.
        samples = []
        cat_samples = self.cat.sample(n, seed=seeds[0])

        for c in range(self.num_components):
            try:
                samples.append(self.components[c].sample(n, seed=seeds[c + 1]))
                if seed_stream is not None:
                    seed_stream()
            except TypeError as e:
                if ('Expected int for argument' not in str(e)
                        and TENSOR_SEED_MSG_PREFIX not in str(e)):
                    raise
                if seed_stream is None:
                    raise seed_stream_err
                msg = (
                    'Falling back to stateful sampling for `components[{}]` {} of '
                    'type `{}`. Please update to use `tf.random.stateless_*` RNGs. '
                    'This fallback may be removed after 20-Aug-2020. ({})')
                warnings.warn(
                    msg.format(c, self.components[c].name,
                               type(self.components[c]), str(e)))
                samples.append(self.components[c].sample(n,
                                                         seed=seed_stream()))
        stack_axis = -1 - tensorshape_util.rank(self._static_event_shape)
        x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
        # TODO(b/170730865): Is all this masking stuff really called for?
        npdt = dtype_util.as_numpy_dtype(x.dtype)
        mask = tf.one_hot(
            indices=cat_samples,  # [n, B]
            depth=self._num_components,  # == k
            on_value=npdt(1),
            off_value=npdt(0))  # [n, B, k]
        mask = distribution_util.pad_mixture_dimensions(
            mask, self, self._cat,
            tensorshape_util.rank(
                self._static_event_shape))  # [n, B, k, [1]*e]
        if x.dtype.is_floating:
            masked = tf.math.multiply_no_nan(x, mask)
        else:
            masked = x * mask
        return tf.reduce_sum(masked, axis=stack_axis)  # [n, B, E]
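The seed plumbing above draws one stateless seed for the categorical plus one per component; components that still require integer seeds raise a `TypeError` and fall back to the stateful `SeedStream`. A minimal sketch of the split step, assuming current TensorFlow Probability where `tfp.random.split_seed` is the public counterpart of the internal `samplers.split_seed` used here:

import tensorflow_probability as tfp

num_components = 3
seed = [1, 2]  # a stateless seed; components needing an int seed hit the fallback
seeds = tfp.random.split_seed(seed, n=num_components + 1, salt='Mixture')
cat_seed, component_seeds = seeds[0], seeds[1:]  # all derived from `seed`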
  def test_pad_mixture_dimensions_mixture_same_family(self):
    gm = MixtureSameFamily(
        mixture_distribution=Categorical(probs=[0.3, 0.7]),
        components_distribution=MultivariateNormalDiag(
            loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1.0, 0.5]))

    x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    x_pad = distribution_util.pad_mixture_dimensions(
        x, gm, gm.mixture_distribution, tensorshape_util.rank(gm.event_shape))
    x_out, x_pad_out = self.evaluate([x, x_pad])

    self.assertAllEqual(x_pad_out.shape, [2, 2, 1])
    self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1]))
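In this test the event rank is 1, so `pad_mixture_dimensions` only needs to append one singleton axis to the `[B, k]` tensor, turning `[2, 2]` into the expected `[2, 2, 1]`. A rough equivalent of that step (a sketch only; the real helper also reconciles differing batch ranks between the mixture and component distributions):

import tensorflow as tf

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])  # [B, k] = [2, 2]
event_ndims = 1                            # vector event of the MVN components
x_pad = tf.reshape(x, tf.concat([tf.shape(x), [1] * event_ndims], axis=0))
# x_pad.shape == (2, 2, 1), matching the assertion in the test above.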
Example #6
0
 def _variance(self):
   with tf.control_dependencies(self._runtime_assertions):
     # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
     probs = distribution_utils.pad_mixture_dimensions(
         self.mixture_distribution.probs, self, self.mixture_distribution,
         self._event_ndims)                         # [B, k, [1]*e]
     mean_cond_var = tf.reduce_sum(
         probs * self.components_distribution.variance(),
         axis=-1 - self._event_ndims)  # [B, E]
     var_cond_mean = tf.reduce_sum(
         probs * tf.squared_difference(self.components_distribution.mean(),
                                       self._pad_sample_dims(self._mean())),
         axis=-1 - self._event_ndims)  # [B, E]
     return mean_cond_var + var_cond_mean                   # [B, E]
  def _covariance(self):
    static_event_ndims = self.event_shape.ndims
    if static_event_ndims != 1:
      # Covariance is defined only for vector distributions.
      raise NotImplementedError("covariance is not implemented")

    with tf.control_dependencies(self._runtime_assertions):
      # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
      probs = distribution_utils.pad_mixture_dimensions(
          distribution_utils.pad_mixture_dimensions(
              self.mixture_distribution.probs, self, self.mixture_distribution,
              self._event_shape().ndims),
          self, self.mixture_distribution,
          self._event_shape().ndims)                         # [B, k, 1, 1]
      mean_cond_var = tf.reduce_sum(
          probs * self.components_distribution.covariance(),
          axis=-3)  # [B, e, e]
      var_cond_mean = tf.reduce_sum(
          probs * _outer_squared_difference(self.components_distribution.mean(),
                                            self._pad_sample_dims(
                                                self._mean())),
          axis=-3)  # [B, e, e]
      return mean_cond_var + var_cond_mean                   # [B, e, e]
 def _variance(self):
   with tf.control_dependencies(self._runtime_assertions):
     # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
     probs = distribution_utils.pad_mixture_dimensions(
         self.mixture_distribution.probs, self, self.mixture_distribution,
         self._event_shape().ndims)                         # [B, k, [1]*e]
     mean_cond_var = tf.reduce_sum(
         probs * self.components_distribution.variance(),
         axis=-1 - self._event_ndims)  # [B, E]
     var_cond_mean = tf.reduce_sum(
         probs * tf.squared_difference(self.components_distribution.mean(),
                                       self._pad_sample_dims(self._mean())),
         axis=-1 - self._event_ndims)  # [B, E]
     return mean_cond_var + var_cond_mean                   # [B, E]
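`_variance` is the elementwise analogue of `_covariance`: Var(Y) = E[Var(Y|Z)] + Var(E[Y|Z]), again summing out the component axis. A compact numerical check (a sketch assuming TF2 eager mode with `tensorflow_probability`):

import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions

probs = np.array([0.25, 0.75], dtype=np.float32)
locs = np.array([-2., 3.], dtype=np.float32)
scales = np.array([1., 0.5], dtype=np.float32)

gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=probs),
    components_distribution=tfd.Normal(loc=locs, scale=scales))

mixture_mean = np.sum(probs * locs)
mean_cond_var = np.sum(probs * scales**2)                   # E[Var(Y|Z)]
var_cond_mean = np.sum(probs * (locs - mixture_mean)**2)    # Var(E[Y|Z])
np.testing.assert_allclose(
    gm.variance().numpy(), mean_cond_var + var_cond_mean, rtol=1e-5)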
 def _sample_n(self, n, seed):
   with tf.control_dependencies(self._runtime_assertions):
     x = self.components_distribution.sample(n)             # [n, B, k, E]
     # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
     npdt = x.dtype.as_numpy_dtype
     mask = tf.one_hot(
         indices=self.mixture_distribution.sample(n),  # [n, B]
         depth=self._num_components,  # == k
         on_value=np.ones([], dtype=npdt),
         off_value=np.zeros([], dtype=npdt))  # [n, B, k]
     mask = distribution_utils.pad_mixture_dimensions(
         mask, self, self.mixture_distribution,
         self._event_shape().ndims)                         # [n, B, k, [1]*e]
     return tf.reduce_sum(x * mask, axis=-1 - self._event_ndims)  # [n, B, E]
Example #10
0
 def _sample_n(self, n, seed):
   with tf.control_dependencies(self._runtime_assertions):
     x = self.components_distribution.sample(n)             # [n, B, k, E]
     # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
     npdt = x.dtype.as_numpy_dtype
     mask = tf.one_hot(
         indices=self.mixture_distribution.sample(n),  # [n, B]
         depth=self._num_components,  # == k
         on_value=np.ones([], dtype=npdt),
         off_value=np.zeros([], dtype=npdt))  # [n, B, k]
     mask = distribution_utils.pad_mixture_dimensions(
         mask, self, self.mixture_distribution,
         self._event_ndims)                         # [n, B, k, [1]*e]
     return tf.reduce_sum(x * mask, axis=-1 - self._event_ndims)  # [n, B, E]
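The one-hot mask above zeroes out every component's sample except the one selected by the categorical draw, and the `reduce_sum` over the component axis keeps just that sample. The `tf.gather` alternative mentioned in the TODO selects the same entries directly; a sketch comparing the two on random data (names and shapes here are illustrative, not from the library):

import tensorflow as tf

n, B, k, E = 4, 3, 2, 5
x = tf.random.normal([n, B, k, E])                           # component samples
idx = tf.random.uniform([n, B], maxval=k, dtype=tf.int32)    # mixture draws

gathered = tf.gather(x, idx, batch_dims=2)                   # [n, B, E]

mask = tf.one_hot(idx, depth=k, dtype=x.dtype)               # [n, B, k]
masked = tf.reduce_sum(x * mask[..., tf.newaxis], axis=2)    # [n, B, E]
tf.debugging.assert_near(gathered, masked)                   # same selection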
Example #11
0
    def test_pad_mixture_dimensions_mixture(self):
        gm = Mixture(cat=Categorical(probs=[[0.3, 0.7]]),
                     components=[
                         Normal(loc=[-1.0], scale=[1.0]),
                         Normal(loc=[1.0], scale=[0.5])
                     ])

        x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
        x_pad = distribution_util.pad_mixture_dimensions(
            x, gm, gm.cat, tensorshape_util.rank(gm.event_shape))
        x_out, x_pad_out = self.evaluate([x, x_pad])

        self.assertAllEqual(x_pad_out.shape, [2, 2])
        self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1]))
    def test_pad_mixture_dimensions_mixture(self):
        with self.test_session() as sess:
            gm = Mixture(cat=tf.distributions.Categorical(probs=[[0.3, 0.7]]),
                         components=[
                             tf.distributions.Normal(loc=[-1.0], scale=[1.0]),
                             tf.distributions.Normal(loc=[1.0], scale=[0.5])
                         ])

            x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
            x_pad = distribution_util.pad_mixture_dimensions(
                x, gm, gm.cat, gm.event_shape.ndims)
            x_out, x_pad_out = sess.run([x, x_pad])

        self.assertAllEqual(x_pad_out.shape, [2, 2])
        self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1]))
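Unlike the `MixtureSameFamily` test earlier, this test builds a `Mixture`, which takes a `Categorical` plus a Python list of component `Distribution`s; since the `Normal` components have scalar events, the padding appends no event axes and the shape stays `[2, 2]`. A sketch of the two constructions side by side (assumes TF2 eager mode with `tensorflow_probability`):

import tensorflow_probability as tfp

tfd = tfp.distributions

mix_list = tfd.Mixture(
    cat=tfd.Categorical(probs=[[0.3, 0.7]]),
    components=[tfd.Normal(loc=[-1.0], scale=[1.0]),
                tfd.Normal(loc=[1.0], scale=[0.5])])

mix_batched = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[[0.3, 0.7]]),
    components_distribution=tfd.Normal(loc=[[-1.0, 1.0]], scale=[[1.0, 0.5]]))
# Both have batch_shape [1] and scalar events (event_ndims == 0), so
# pad_mixture_dimensions has no event axes to append here.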
Example #13
0
 def _sample_n(self, n, seed):
   with tf.control_dependencies(self._runtime_assertions):
     seed = seed_stream.SeedStream(seed, salt="MixtureSameFamily")
     x = self.components_distribution.sample(n, seed=seed())  # [n, B, k, E]
     # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
     npdt = dtype_util.as_numpy_dtype(x.dtype)
     mask = tf.one_hot(
         indices=self.mixture_distribution.sample(n, seed=seed()),  # [n, B]
         depth=self._num_components,  # == k
         on_value=npdt(1),
         off_value=npdt(0))  # [n, B, k]
     mask = distribution_utils.pad_mixture_dimensions(
         mask, self, self.mixture_distribution,
         self._event_ndims)                         # [n, B, k, [1]*e]
     x = tf.reduce_sum(x * mask, axis=-1 - self._event_ndims)  # [n, B, E]
     if self._reparameterize:
       x = self._reparameterize_sample(x)
     return x
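The `SeedStream` used above hashes (seed, salt, call count), so each call yields a fresh seed while the whole sequence stays reproducible for a fixed seed and salt; that is what keeps the component draw and the categorical draw distinct but deterministic. A sketch with the public `tfp.util.SeedStream`, which is assumed here to behave like the internal `seed_stream` module referenced in the snippet:

import tensorflow_probability as tfp

s1 = tfp.util.SeedStream(123, salt="MixtureSameFamily")
s2 = tfp.util.SeedStream(123, salt="MixtureSameFamily")
assert s1() == s2() and s1() == s2()  # same seed and salt: same sequence
assert s1() != s1()                   # successive calls yield new seeds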
Example #14
0
  def _sample_n(self, n, seed=None):
    if self._use_static_graph:
      # This sampling approach is almost the same as the approach used by
      # `MixtureSameFamily`. The differences are due to having a list of
      # `Distribution` objects rather than a single object, and maintaining
      # random seed management that is consistent with the non-static code path.
      samples = []
      cat_samples = self.cat.sample(n, seed=seed)
      for c in range(self.num_components):
        seed = distribution_util.gen_new_seed(seed, "mixture")
        samples.append(self.components[c].sample(n, seed=seed))
      x = tf.stack(samples, -self._static_event_shape.ndims - 1)  # [n, B, k, E]
      npdt = x.dtype.as_numpy_dtype
      mask = tf.one_hot(
          indices=cat_samples,  # [n, B]
          depth=self._num_components,  # == k
          on_value=np.ones([], dtype=npdt),
          off_value=np.zeros([], dtype=npdt))  # [n, B, k]
      mask = distribution_utils.pad_mixture_dimensions(
          mask, self, self._cat,
          self._static_event_shape.ndims)                   # [n, B, k, [1]*e]
      return tf.reduce_sum(
          x * mask, axis=-1 - self._static_event_shape.ndims)  # [n, B, E]

    with tf.control_dependencies(self._assertions):
      n = tf.convert_to_tensor(n, name="n")
      static_n = tensor_util.constant_value(n)
      n = int(static_n) if static_n is not None else n
      cat_samples = self.cat.sample(n, seed=seed)

      static_samples_shape = cat_samples.get_shape()
      if static_samples_shape.is_fully_defined():
        samples_shape = static_samples_shape.as_list()
        samples_size = static_samples_shape.num_elements()
      else:
        samples_shape = tf.shape(cat_samples)
        samples_size = tf.size(cat_samples)
      static_batch_shape = self.batch_shape
      if static_batch_shape.is_fully_defined():
        batch_shape = static_batch_shape.as_list()
        batch_size = static_batch_shape.num_elements()
      else:
        batch_shape = self.batch_shape_tensor()
        batch_size = tf.reduce_prod(batch_shape)
      static_event_shape = self.event_shape
      if static_event_shape.is_fully_defined():
        event_shape = np.array(static_event_shape.as_list(), dtype=np.int32)
      else:
        event_shape = self.event_shape_tensor()

      # Get indices into the raw cat sampling tensor. We will
      # need these to stitch sample values back out after sampling
      # within the component partitions.
      samples_raw_indices = tf.reshape(tf.range(0, samples_size), samples_shape)

      # Partition the raw indices so that we can use
      # dynamic_stitch later to reconstruct the samples from the
      # known partitions.
      partitioned_samples_indices = tf.dynamic_partition(
          data=samples_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)

      # Copy the batch indices n times, as we will need to know
      # these to pull out the appropriate rows within the
      # component partitions.
      batch_raw_indices = tf.reshape(
          tf.tile(tf.range(0, batch_size), [n]), samples_shape)

      # Explanation of the dynamic partitioning below:
      #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
      # Suppose partitions are:
      #     [1 1 0 0 1 1]
      # After partitioning, batch indices are cut as:
      #     [batch_indices[x] for x in 2, 3]
      #     [batch_indices[x] for x in 0, 1, 4, 5]
      # i.e.
      #     [1 1] and [0 0 0 0]
      # Now we sample n=2 from part 0 and n=4 from part 1.
      # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
      # and for part 1 we want samples from batch entries 0, 0, 0, 0
      #   (samples 0, 1, 2, 3).
      partitioned_batch_indices = tf.dynamic_partition(
          data=batch_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)
      samples_class = [None for _ in range(self.num_components)]

      for c in range(self.num_components):
        n_class = tf.size(partitioned_samples_indices[c])
        seed = distribution_util.gen_new_seed(seed, "mixture")
        samples_class_c = self.components[c].sample(n_class, seed=seed)

        # Pull out the correct batch entries from each index.
        # To do this, we may have to flatten the batch shape.

        # For sample s, batch element b of component c, we get the
        # partitioned batch indices from
        # partitioned_batch_indices[c]; and shift each element by
        # the sample index. The final lookup can be thought of as
        # a matrix gather along locations (s, b) in
        # samples_class_c where the n_class rows correspond to
        # samples within this component and the batch_size columns
        # correspond to batch elements within the component.
        #
        # Thus the lookup index is
        #   lookup[c, i] = batch_size * s[i] + b[c, i]
        # for i = 0 ... n_class[c] - 1.
        lookup_partitioned_batch_indices = (
            batch_size * tf.range(n_class) + partitioned_batch_indices[c])
        samples_class_c = tf.reshape(
            samples_class_c, tf.concat([[n_class * batch_size], event_shape],
                                       0))
        samples_class_c = tf.gather(
            samples_class_c,
            lookup_partitioned_batch_indices,
            name="samples_class_c_gather")
        samples_class[c] = samples_class_c

      # Stitch back together the samples across the components.
      lhs_flat_ret = tf.dynamic_stitch(
          indices=partitioned_samples_indices, data=samples_class)
      # Reshape back to proper sample, batch, and event shape.
      ret = tf.reshape(
          lhs_flat_ret, tf.concat(
              [samples_shape, self.event_shape_tensor()], 0))
      ret.set_shape(
          tf.TensorShape(static_samples_shape).concatenate(self.event_shape))
      return ret
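The non-static path above relies on `tf.dynamic_partition` to split the flat sample indices by categorical draw and `tf.dynamic_stitch` to put the per-component samples back in their original order, exactly as the worked example in the comments describes. A self-contained sketch of that round trip (plain TF ops, illustrative values only):

import tensorflow as tf

cat_samples = tf.constant([1, 1, 0, 0, 1, 1])            # flattened draws, k = 2
samples_raw_indices = tf.range(tf.size(cat_samples))     # [0, 1, 2, 3, 4, 5]
parts = tf.dynamic_partition(samples_raw_indices, cat_samples, 2)
# parts[0] == [2, 3]; parts[1] == [0, 1, 4, 5]

# Pretend component 0 produced the value 10 and component 1 the value 20 for
# each index assigned to it.
data = [tf.fill(tf.shape(parts[0]), 10), tf.fill(tf.shape(parts[1]), 20)]
stitched = tf.dynamic_stitch(parts, data)
# stitched == [20, 20, 10, 10, 20, 20]: per-component values back in draw order.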
Example #15
0
    def _sample_n(self, n, seed=None):
        if self._use_static_graph:
            with tf.control_dependencies(self._assertions):
                # This sampling approach is almost the same as the approach used by
                # `MixtureSameFamily`. The differences are due to having a list of
                # `Distribution` objects rather than a single object, and maintaining
                # random seed management that is consistent with the non-static code
                # path.
                samples = []
                cat_samples = self.cat.sample(n, seed=seed)
                stream = seed_stream.SeedStream(seed, salt="Mixture")

                for c in range(self.num_components):
                    samples.append(self.components[c].sample(n, seed=stream()))
                x = tf.stack(
                    samples, -self._static_event_shape.ndims - 1)  # [n, B, k, E]
                npdt = x.dtype.as_numpy_dtype
                mask = tf.one_hot(
                    indices=cat_samples,  # [n, B]
                    depth=self._num_components,  # == k
                    on_value=np.ones([], dtype=npdt),
                    off_value=np.zeros([], dtype=npdt))  # [n, B, k]
                mask = distribution_util.pad_mixture_dimensions(
                    mask, self, self._cat,
                    self._static_event_shape.ndims)  # [n, B, k, [1]*e]
                return tf.reduce_sum(
                    input_tensor=x * mask,
                    axis=-1 - self._static_event_shape.ndims)  # [n, B, E]

        with tf.control_dependencies(self._assertions):
            n = tf.convert_to_tensor(value=n, name="n")
            static_n = tf.get_static_value(n)
            n = int(static_n) if static_n is not None else n
            cat_samples = self.cat.sample(n, seed=seed)

            static_samples_shape = cat_samples.shape
            if static_samples_shape.is_fully_defined():
                samples_shape = static_samples_shape.as_list()
                samples_size = static_samples_shape.num_elements()
            else:
                samples_shape = tf.shape(input=cat_samples)
                samples_size = tf.size(input=cat_samples)
            static_batch_shape = self.batch_shape
            if static_batch_shape.is_fully_defined():
                batch_shape = static_batch_shape.as_list()
                batch_size = static_batch_shape.num_elements()
            else:
                batch_shape = self.batch_shape_tensor()
                batch_size = tf.reduce_prod(input_tensor=batch_shape)
            static_event_shape = self.event_shape
            if static_event_shape.is_fully_defined():
                event_shape = np.array(static_event_shape.as_list(),
                                       dtype=np.int32)
            else:
                event_shape = self.event_shape_tensor()

            # Get indices into the raw cat sampling tensor. We will
            # need these to stitch sample values back out after sampling
            # within the component partitions.
            samples_raw_indices = tf.reshape(tf.range(0, samples_size),
                                             samples_shape)

            # Partition the raw indices so that we can use
            # dynamic_stitch later to reconstruct the samples from the
            # known partitions.
            partitioned_samples_indices = tf.dynamic_partition(
                data=samples_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)

            # Copy the batch indices n times, as we will need to know
            # these to pull out the appropriate rows within the
            # component partitions.
            batch_raw_indices = tf.reshape(
                tf.tile(tf.range(0, batch_size), [n]), samples_shape)

            # Explanation of the dynamic partitioning below:
            #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
            # Suppose partitions are:
            #     [1 1 0 0 1 1]
            # After partitioning, batch indices are cut as:
            #     [batch_indices[x] for x in 2, 3]
            #     [batch_indices[x] for x in 0, 1, 4, 5]
            # i.e.
            #     [1 1] and [0 0 0 0]
            # Now we sample n=2 from part 0 and n=4 from part 1.
            # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
            # and for part 1 we want samples from batch entries 0, 0, 0, 0
            #   (samples 0, 1, 2, 3).
            partitioned_batch_indices = tf.dynamic_partition(
                data=batch_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)
            samples_class = [None for _ in range(self.num_components)]

            stream = seed_stream.SeedStream(seed, salt="Mixture")

            for c in range(self.num_components):
                n_class = tf.size(input=partitioned_samples_indices[c])
                samples_class_c = self.components[c].sample(n_class,
                                                            seed=stream())

                # Pull out the correct batch entries from each index.
                # To do this, we may have to flatten the batch shape.

                # For sample s, batch element b of component c, we get the
                # partitioned batch indices from
                # partitioned_batch_indices[c]; and shift each element by
                # the sample index. The final lookup can be thought of as
                # a matrix gather along locations (s, b) in
                # samples_class_c where the n_class rows correspond to
                # samples within this component and the batch_size columns
                # correspond to batch elements within the component.
                #
                # Thus the lookup index is
                #   lookup[c, i] = batch_size * s[i] + b[c, i]
                # for i = 0 ... n_class[c] - 1.
                lookup_partitioned_batch_indices = (
                    batch_size * tf.range(n_class) +
                    partitioned_batch_indices[c])
                samples_class_c = tf.reshape(
                    samples_class_c,
                    tf.concat([[n_class * batch_size], event_shape], 0))
                samples_class_c = tf.gather(samples_class_c,
                                            lookup_partitioned_batch_indices,
                                            name="samples_class_c_gather")
                samples_class[c] = samples_class_c

            # Stitch back together the samples across the components.
            lhs_flat_ret = tf.dynamic_stitch(
                indices=partitioned_samples_indices, data=samples_class)
            # Reshape back to proper sample, batch, and event shape.
            ret = tf.reshape(
                lhs_flat_ret,
                tf.concat(
                    [samples_shape, self.event_shape_tensor()], 0))
            ret.set_shape(
                tf.TensorShape(static_samples_shape).concatenate(
                    self.event_shape))
            return ret
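The lookup index described in the comments, lookup[c, i] = batch_size * s[i] + b[c, i], works because each component's samples are first reshaped to a flat [n_class * batch_size, ...] tensor, so row s * batch_size + b is exactly sample s of batch element b. A small sketch of that flattened gather (illustrative shapes and values, not library code):

import tensorflow as tf

n_class, batch_size = 3, 2
samples_class_c = tf.reshape(
    tf.range(n_class * batch_size * 4), [n_class, batch_size, 4])  # [s, b, E]
batch_indices = tf.constant([1, 0, 1])      # batch element wanted by each draw
lookup = batch_size * tf.range(n_class) + batch_indices            # [1, 2, 5]
flat = tf.reshape(samples_class_c, [n_class * batch_size, 4])
picked = tf.gather(flat, lookup)
# picked[i] == samples_class_c[i, batch_indices[i]] for i = 0, 1, 2.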
Example #16
0
    def _distributional_transform(self, x):
        """Performs distributional transform of the mixture samples.

    Distributional transform removes the parameters from samples of a
    multivariate distribution by applying conditional CDFs:
      (F(x_1), F(x_2 | x1_), ..., F(x_d | x_1, ..., x_d-1))
    (the indexing is over the "flattened" event dimensions).
    The result is a sample of product of Uniform[0, 1] distributions.

    We assume that the components are factorized, so the conditional CDFs become
      F(x_i | x_1, ..., x_i-1) = sum_k w_i^k F_k (x_i),
    where w_i^k is the posterior mixture weight: for i > 0
      w_i^k = w_k prob_k(x_1, ..., x_i-1) / sum_k' w_k' prob_k'(x_1, ..., x_i-1)
    and w_0^k = w_k is the mixture probability of the k-th component.

    Arguments:
      x: Sample of mixture distribution

    Returns:
      Result of the distributional transform
    """

        if x.shape.ndims is None:
            # tf.nn.softmax raises an error when applied to inputs of undefined rank.
            raise ValueError(
                "Distributional transform does not support inputs of "
                "undefined rank.")

        # Obtain factorized components distribution and assert that it's
        # a scalar distribution.
        if isinstance(self._components_distribution, independent.Independent):
            univariate_components = self._components_distribution.distribution
        else:
            univariate_components = self._components_distribution

        with tf.control_dependencies([
                assert_util.assert_equal(
                    univariate_components.is_scalar_event(),
                    True,
                    message="`univariate_components` must have scalar event")
        ]):
            x_padded = self._pad_sample_dims(x)  # [S, B, 1, E]
            log_prob_x = univariate_components.log_prob(
                x_padded)  # [S, B, k, E]
            cdf_x = univariate_components.cdf(x_padded)  # [S, B, k, E]

            # log prob_k (x_1, ..., x_i-1)
            cumsum_log_prob_x = tf.reshape(
                tf.math.cumsum(
                    # [S*prod(B)*k, prod(E)]
                    tf.reshape(log_prob_x, [-1, self._event_size]),
                    exclusive=True,
                    axis=-1),
                tf.shape(input=log_prob_x))  # [S, B, k, E]

            logits_mix_prob = distribution_utils.pad_mixture_dimensions(
                self.mixture_distribution.logits, self,
                self.mixture_distribution, self._event_ndims)  # [B, k, 1]

            # Logits of the posterior weights: log w_k + log prob_k (x_1, ..., x_i-1)
            log_posterior_weights_x = logits_mix_prob + cumsum_log_prob_x

            component_axis = x.shape.ndims - self._event_ndims
            posterior_weights_x = tf.nn.softmax(log_posterior_weights_x,
                                                axis=component_axis)
            return tf.reduce_sum(input_tensor=posterior_weights_x * cdf_x,
                                 axis=component_axis)
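The posterior weights in the docstring are computed in log space: an exclusive cumulative sum of the per-component log-probabilities gives log prob_k(x_1, ..., x_{i-1}), which is added to the mixture logits and normalized with a softmax over the component axis; for i = 0 the cumsum is zero, so the weights reduce to the prior mixture weights. A stripped-down sketch for one scalar sample and two factorized Normal components (assumes TF2 eager mode with `tensorflow_probability`):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

logits = tf.math.log([0.3, 0.7])                          # log w_k, shape [k]
components = tfd.Normal(loc=[-1., 1.], scale=[1., 0.5])   # k scalar components

x = tf.constant([0.2, -0.4, 1.3])                         # one sample, E = 3
log_prob_x = components.log_prob(x[:, tf.newaxis])        # [E, k]
cumsum_log_prob_x = tf.cumsum(log_prob_x, axis=0, exclusive=True)
posterior_weights = tf.nn.softmax(logits + cumsum_log_prob_x, axis=-1)  # [E, k]
cdf_x = components.cdf(x[:, tf.newaxis])                  # [E, k]
transform = tf.reduce_sum(posterior_weights * cdf_x, axis=-1)           # [E]
# posterior_weights[0] == [0.3, 0.7]: with no conditioning yet, the posterior
# weights are just the prior mixture weights.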
Example #17
0
    def _sample_n(self, n, seed=None):
        seeds = samplers.split_seed(seed,
                                    n=self.num_components + 1,
                                    salt='Mixture')
        try:
            seed_stream = SeedStream(seed, salt='Mixture')
        except TypeError as e:  # Can happen for Tensor seed.
            seed_stream = None
            seed_stream_err = e
        if self._use_static_graph:
            # This sampling approach is almost the same as the approach used by
            # `MixtureSameFamily`. The differences are due to having a list of
            # `Distribution` objects rather than a single object, and maintaining
            # random seed management that is consistent with the non-static code
            # path.
            samples = []
            cat_samples = self.cat.sample(n, seed=seeds[0])

            for c in range(self.num_components):
                try:
                    samples.append(
                        self.components[c].sample(n, seed=seeds[c + 1]))
                    if seed_stream is not None:
                        seed_stream()
                except TypeError as e:
                    if ('Expected int for argument' not in str(e)
                            and TENSOR_SEED_MSG_PREFIX not in str(e)):
                        raise
                    if seed_stream is None:
                        raise seed_stream_err
                    msg = (
                        'Falling back to stateful sampling for `components[{}]` {} of '
                        'type `{}`. Please update to use `tf.random.stateless_*` RNGs. '
                        'This fallback may be removed after 20-Aug-2020. ({})')
                    warnings.warn(
                        msg.format(c, self.components[c].name,
                                   type(self.components[c]), str(e)))
                    samples.append(self.components[c].sample(
                        n, seed=seed_stream()))
            stack_axis = -1 - tensorshape_util.rank(self._static_event_shape)
            x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
            npdt = dtype_util.as_numpy_dtype(x.dtype)
            mask = tf.one_hot(
                indices=cat_samples,  # [n, B]
                depth=self._num_components,  # == k
                on_value=npdt(1),
                off_value=npdt(0))  # [n, B, k]
            mask = distribution_util.pad_mixture_dimensions(
                mask, self, self._cat,
                tensorshape_util.rank(
                    self._static_event_shape))  # [n, B, k, [1]*e]
            if x.dtype.is_floating:
                masked = tf.math.multiply_no_nan(x, mask)
            else:
                masked = x * mask
            return tf.reduce_sum(masked, axis=stack_axis)  # [n, B, E]

        n = tf.convert_to_tensor(n, name='n')
        static_n = tf.get_static_value(n)
        n = int(static_n) if static_n is not None else n
        cat_samples = self.cat.sample(n, seed=seeds[0])

        static_samples_shape = cat_samples.shape
        if tensorshape_util.is_fully_defined(static_samples_shape):
            samples_shape = tensorshape_util.as_list(static_samples_shape)
            samples_size = tensorshape_util.num_elements(static_samples_shape)
        else:
            samples_shape = tf.shape(cat_samples)
            samples_size = tf.size(cat_samples)
        static_batch_shape = self.batch_shape
        if tensorshape_util.is_fully_defined(static_batch_shape):
            batch_shape = tensorshape_util.as_list(static_batch_shape)
            batch_size = tensorshape_util.num_elements(static_batch_shape)
        else:
            batch_shape = tf.shape(cat_samples)[1:]
            batch_size = tf.reduce_prod(batch_shape)
        static_event_shape = self.event_shape
        if tensorshape_util.is_fully_defined(static_event_shape):
            event_shape = np.array(
                tensorshape_util.as_list(static_event_shape), dtype=np.int32)
        else:
            event_shape = None

        # Get indices into the raw cat sampling tensor. We will
        # need these to stitch sample values back out after sampling
        # within the component partitions.
        samples_raw_indices = tf.reshape(tf.range(0, samples_size),
                                         samples_shape)

        # Partition the raw indices so that we can use
        # dynamic_stitch later to reconstruct the samples from the
        # known partitions.
        partitioned_samples_indices = tf.dynamic_partition(
            data=samples_raw_indices,
            partitions=cat_samples,
            num_partitions=self.num_components)

        # Copy the batch indices n times, as we will need to know
        # these to pull out the appropriate rows within the
        # component partitions.
        batch_raw_indices = tf.reshape(tf.tile(tf.range(0, batch_size), [n]),
                                       samples_shape)

        # Explanation of the dynamic partitioning below:
        #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
        # Suppose partitions are:
        #     [1 1 0 0 1 1]
        # After partitioning, batch indices are cut as:
        #     [batch_indices[x] for x in 2, 3]
        #     [batch_indices[x] for x in 0, 1, 4, 5]
        # i.e.
        #     [1 1] and [0 0 0 0]
        # Now we sample n=2 from part 0 and n=4 from part 1.
        # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
        # and for part 1 we want samples from batch entries 0, 0, 0, 0
        #   (samples 0, 1, 2, 3).
        partitioned_batch_indices = tf.dynamic_partition(
            data=batch_raw_indices,
            partitions=cat_samples,
            num_partitions=self.num_components)
        samples_class = [None for _ in range(self.num_components)]

        for c in range(self.num_components):
            n_class = tf.size(partitioned_samples_indices[c])
            try:
                samples_class_c = self.components[c].sample(n_class,
                                                            seed=seeds[c + 1])
                if seed_stream is not None:
                    seed_stream()
            except TypeError as e:
                if ('Expected int for argument' not in str(e)
                        and TENSOR_SEED_MSG_PREFIX not in str(e)):
                    raise
                if seed_stream is None:
                    raise seed_stream_err
                msg = (
                    'Falling back to stateful sampling for `components[{}]` {} of '
                    'type `{}`. Please update to use `tf.random.stateless_*` RNGs. '
                    'This fallback may be removed after 20-Aug-2020. ({})')
                warnings.warn(
                    msg.format(c, self.components[c].name,
                               type(self.components[c]), str(e)))
                samples_class_c = self.components[c].sample(n_class,
                                                            seed=seed_stream())

            if event_shape is None:
                batch_ndims = prefer_static.rank_from_shape(batch_shape)
                event_shape = tf.shape(samples_class_c)[1 + batch_ndims:]

            # Pull out the correct batch entries from each index.
            # To do this, we may have to flatten the batch shape.

            # For sample s, batch element b of component c, we get the
            # partitioned batch indices from
            # partitioned_batch_indices[c]; and shift each element by
            # the sample index. The final lookup can be thought of as
            # a matrix gather along locations (s, b) in
            # samples_class_c where the n_class rows correspond to
            # samples within this component and the batch_size columns
            # correspond to batch elements within the component.
            #
            # Thus the lookup index is
            #   lookup[c, i] = batch_size * s[i] + b[c, i]
            # for i = 0 ... n_class[c] - 1.
            lookup_partitioned_batch_indices = (
                batch_size * tf.range(n_class) + partitioned_batch_indices[c])
            samples_class_c = tf.reshape(
                samples_class_c,
                tf.concat([[n_class * batch_size], event_shape], 0))
            samples_class_c = tf.gather(samples_class_c,
                                        lookup_partitioned_batch_indices,
                                        name='samples_class_c_gather')
            samples_class[c] = samples_class_c

        # Stitch back together the samples across the components.
        lhs_flat_ret = tf.dynamic_stitch(indices=partitioned_samples_indices,
                                         data=samples_class)
        # Reshape back to proper sample, batch, and event shape.
        ret = tf.reshape(lhs_flat_ret,
                         tf.concat([samples_shape, event_shape], 0))
        tensorshape_util.set_shape(
            ret,
            tensorshape_util.concatenate(static_samples_shape,
                                         self.event_shape))
        return ret
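When the static event shape is unknown, the code above recovers it from the first concrete component sample by slicing off the leading sample dimension and the batch dimensions. A sketch of that slice with illustrative shapes (not library code):

import tensorflow as tf

batch_shape = tf.constant([3], dtype=tf.int32)     # e.g. B = [3]
samples_class_c = tf.zeros([5, 3, 2, 2])           # [n_class] + B + E
batch_ndims = tf.size(batch_shape)                 # rank of the batch shape
event_shape = tf.shape(samples_class_c)[1 + batch_ndims:]  # == [2, 2]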