Example #1
    def _mean(self):
        with ops.control_dependencies(self._runtime_assertions):
            probs = distribution_utils.pad_mixture_dimensions(
                self.mixture_distribution.probs, self,
                self.mixture_distribution,
                self._event_shape().ndims)  # [B, k, [1]*e]
            return math_ops.reduce_sum(
                probs * self.components_distribution.mean(),
                axis=-1 - self._event_ndims)  # [B, E]
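
The shape bookkeeping above can be mimicked with plain NumPy (an illustrative sketch, not the TensorFlow implementation; the sizes B = 2, k = 3, E = 4 are made up): pad_mixture_dimensions appends one singleton axis per event dimension so the mixture probabilities broadcast against the component means, and reducing over the component axis (axis = -1 - event_ndims) gives the mixture mean.

import numpy as np

# Hypothetical sizes: batch B = 2, k = 3 components, event size E = 4 (e = 1).
probs = np.random.dirichlet([1.0, 1.0, 1.0], size=2)            # [B, k] = [2, 3]
component_means = np.random.randn(2, 3, 4)                       # [B, k, E] = [2, 3, 4]

probs_padded = probs[..., np.newaxis]                            # [B, k, 1], i.e. [B, k, [1]*e]
mixture_mean = np.sum(probs_padded * component_means, axis=-2)   # axis -1 - e == -2
print(mixture_mean.shape)                                        # (2, 4) == [B, E]
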
Example #2
    def test_pad_mixture_dimensions_mixture_same_family(self):
        with self.cached_session() as sess:
            gm = mixture_same_family.MixtureSameFamily(
                mixture_distribution=categorical.Categorical(probs=[0.3, 0.7]),
                components_distribution=mvn_diag.MultivariateNormalDiag(
                    loc=[[-1., 1], [1, -1]],
                    scale_identity_multiplier=[1.0, 0.5]))

            x = array_ops.constant([[1.0, 2.0], [3.0, 4.0]])
            x_pad = distribution_util.pad_mixture_dimensions(
                x, gm, gm.mixture_distribution, gm.event_shape.ndims)
            x_out, x_pad_out = sess.run([x, x_pad])

        self.assertAllEqual(x_pad_out.shape, [2, 2, 1])
        self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1]))
Example #3
    def _covariance(self):
        static_event_ndims = self.event_shape.ndims
        if static_event_ndims != 1:
            # Covariance is defined only for vector distributions.
            raise NotImplementedError("covariance is not implemented")

        with ops.control_dependencies(self._runtime_assertions):
            # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
            probs = distribution_utils.pad_mixture_dimensions(
                distribution_utils.pad_mixture_dimensions(
                    self.mixture_distribution.probs, self,
                    self.mixture_distribution,
                    self._event_shape().ndims), self,
                self.mixture_distribution,
                self._event_shape().ndims)  # [B, k, 1, 1]
            mean_cond_var = math_ops.reduce_sum(
                probs * self.components_distribution.covariance(),
                axis=-3)  # [B, e, e]
            var_cond_mean = math_ops.reduce_sum(
                probs *
                _outer_squared_difference(self.components_distribution.mean(),
                                          self._pad_sample_dims(self._mean())),
                axis=-3)  # [B, e, e]
            return mean_cond_var + var_cond_mean  # [B, e, e]
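
The law of total variance used here can be sanity-checked directly in NumPy for a small two-component mixture of 2-D Gaussians (a hedged sketch with made-up parameters mirroring the test above, not TensorFlow code): the mixture covariance is the probability-weighted mean of the component covariances plus the covariance of the component means.

import numpy as np

probs = np.array([0.3, 0.7])                                     # mixture weights, [k]
means = np.array([[-1.0, 1.0], [1.0, -1.0]])                     # component means, [k, e]
covs = np.stack([np.eye(2) * 1.0 ** 2, np.eye(2) * 0.5 ** 2])    # component covariances, [k, e, e]

mixture_mean = np.einsum('k,ke->e', probs, means)                # E[Y]
mean_cond_cov = np.einsum('k,kij->ij', probs, covs)              # E[Cov(Y|X)]
diff = means - mixture_mean                                      # E[Y|X] - E[Y], [k, e]
cov_cond_mean = np.einsum('k,ki,kj->ij', probs, diff, diff)      # Cov(E[Y|X])
print(mean_cond_cov + cov_cond_mean)                             # Cov(Y), [e, e]
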
Example #4
    def _sample_n(self, n, seed):
        with ops.control_dependencies(self._runtime_assertions):
            x = self.components_distribution.sample(n)  # [n, B, k, E]
            # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
            npdt = x.dtype.as_numpy_dtype
            mask = array_ops.one_hot(
                indices=self.mixture_distribution.sample(n),  # [n, B]
                depth=self._num_components,  # == k
                on_value=np.ones([], dtype=npdt),
                off_value=np.zeros([], dtype=npdt))  # [n, B, k]
            mask = distribution_utils.pad_mixture_dimensions(
                mask, self, self.mixture_distribution,
                self._event_shape().ndims)  # [n, B, k, [1]*e]
            return math_ops.reduce_sum(
                x * mask, axis=-1 - self._event_ndims)  # [n, B, E]
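
The masking trick in _sample_n (draw from every component, then keep one draw per sample/batch entry via a one-hot mask padded over the event dimensions) can be imitated in NumPy; the shapes n = 5, B = 2, k = 3, E = 4 below are illustrative assumptions, not taken from the source.

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 2, 3, 4))            # [n, B, k, E] draws from every component
idx = rng.integers(0, 3, size=(5, 2))        # [n, B] sampled component indices
mask = np.eye(3)[idx]                        # [n, B, k] one-hot mask
mask = mask[..., np.newaxis]                 # [n, B, k, 1], padded by the event ndims
samples = np.sum(x * mask, axis=-2)          # [n, B, E]: keep only the selected component
print(samples.shape)                         # (5, 2, 4)
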
Example #5
    def test_pad_mixture_dimensions_mixture(self):
        with self.cached_session() as sess:
            gm = mixture.Mixture(
                cat=categorical.Categorical(probs=[[0.3, 0.7]]),
                components=[
                    normal.Normal(loc=[-1.0], scale=[1.0]),
                    normal.Normal(loc=[1.0], scale=[0.5])
                ])

            x = array_ops.constant([[1.0, 2.0], [3.0, 4.0]])
            x_pad = distribution_util.pad_mixture_dimensions(
                x, gm, gm.cat, gm.event_shape.ndims)
            x_out, x_pad_out = sess.run([x, x_pad])

        self.assertAllEqual(x_pad_out.shape, [2, 2])
        self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1]))
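
The two tests expect different padded shapes because pad_mixture_dimensions appends one singleton axis per event dimension: the scalar-event Mixture here (event ndims 0) leaves x unchanged, while the vector-event MixtureSameFamily in the earlier test (event ndims 1) gains a trailing 1. A minimal reshape sketch (NumPy, illustrative only) shows the same effect:

import numpy as np

x = np.array([[1.0, 2.0], [3.0, 4.0]])             # [2, 2]
print(x.reshape(x.shape + (1,) * 0).shape)         # event ndims 0 -> (2, 2)
print(x.reshape(x.shape + (1,) * 1).shape)         # event ndims 1 -> (2, 2, 1)
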
Example #6
    def _variance(self):
        with ops.control_dependencies(self._runtime_assertions):
            # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
            probs = distribution_utils.pad_mixture_dimensions(
                self.mixture_distribution.probs, self,
                self.mixture_distribution,
                self._event_shape().ndims)  # [B, k, [1]*e]
            mean_cond_var = math_ops.reduce_sum(
                probs * self.components_distribution.variance(),
                axis=-1 - self._event_ndims)  # [B, E]
            var_cond_mean = math_ops.reduce_sum(
                probs * math_ops.squared_difference(
                    self.components_distribution.mean(),
                    self._pad_sample_dims(self._mean())),
                axis=-1 - self._event_ndims)  # [B, E]
            return mean_cond_var + var_cond_mean  # [B, E]
Example #7
    def _sample_n(self, n, seed=None):
        if self._use_static_graph:
            # This sampling approach is almost the same as the approach used by
            # `MixtureSameFamily`. The differences are due to having a list of
            # `Distribution` objects rather than a single object, and maintaining
            # random seed management that is consistent with the non-static code path.
            samples = []
            cat_samples = self.cat.sample(n, seed=seed)
            for c in range(self.num_components):
                seed = distribution_util.gen_new_seed(seed, "mixture")
                samples.append(self.components[c].sample(n, seed=seed))
            x = array_ops.stack(samples, -self._static_event_shape.ndims -
                                1)  # [n, B, k, E]
            npdt = x.dtype.as_numpy_dtype
            mask = array_ops.one_hot(
                indices=cat_samples,  # [n, B]
                depth=self._num_components,  # == k
                on_value=np.ones([], dtype=npdt),
                off_value=np.zeros([], dtype=npdt))  # [n, B, k]
            mask = distribution_utils.pad_mixture_dimensions(
                mask, self, self._cat,
                self._static_event_shape.ndims)  # [n, B, k, [1]*e]
            return math_ops.reduce_sum(
                x * mask,
                axis=-1 - self._static_event_shape.ndims)  # [n, B, E]

        with ops.control_dependencies(self._assertions):
            n = ops.convert_to_tensor(n, name="n")
            static_n = tensor_util.constant_value(n)
            n = int(static_n) if static_n is not None else n
            cat_samples = self.cat.sample(n, seed=seed)

            static_samples_shape = cat_samples.get_shape()
            if static_samples_shape.is_fully_defined():
                samples_shape = static_samples_shape.as_list()
                samples_size = static_samples_shape.num_elements()
            else:
                samples_shape = array_ops.shape(cat_samples)
                samples_size = array_ops.size(cat_samples)
            static_batch_shape = self.batch_shape
            if static_batch_shape.is_fully_defined():
                batch_shape = static_batch_shape.as_list()
                batch_size = static_batch_shape.num_elements()
            else:
                batch_shape = self.batch_shape_tensor()
                batch_size = math_ops.reduce_prod(batch_shape)
            static_event_shape = self.event_shape
            if static_event_shape.is_fully_defined():
                event_shape = np.array(static_event_shape.as_list(),
                                       dtype=np.int32)
            else:
                event_shape = self.event_shape_tensor()

            # Get indices into the raw cat sampling tensor. We will
            # need these to stitch sample values back out after sampling
            # within the component partitions.
            samples_raw_indices = array_ops.reshape(
                math_ops.range(0, samples_size), samples_shape)

            # Partition the raw indices so that we can use
            # dynamic_stitch later to reconstruct the samples from the
            # known partitions.
            partitioned_samples_indices = data_flow_ops.dynamic_partition(
                data=samples_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)

            # Copy the batch indices n times, as we will need to know
            # these to pull out the appropriate rows within the
            # component partitions.
            batch_raw_indices = array_ops.reshape(
                array_ops.tile(math_ops.range(0, batch_size), [n]),
                samples_shape)

            # Explanation of the dynamic partitioning below:
            #   Suppose the batch indices are [0, 1, 0, 1, 0, 1]
            #   and the partitions (cat samples) are:
            #     [1 1 0 0 1 1]
            # After partitioning, the batch indices are cut as:
            #     [batch_indices[x] for x in 2, 3]
            #     [batch_indices[x] for x in 0, 1, 4, 5]
            # i.e.
            #     [0 1] and [0 1 0 1]
            # Now we sample n=2 from part 0 and n=4 from part 1.
            # For part 0 we want samples for batch entries 0, 1 (samples 0, 1),
            # and for part 1 we want samples for batch entries 0, 1, 0, 1
            #   (samples 0, 1, 2, 3).
            partitioned_batch_indices = data_flow_ops.dynamic_partition(
                data=batch_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)
            samples_class = [None for _ in range(self.num_components)]

            for c in range(self.num_components):
                n_class = array_ops.size(partitioned_samples_indices[c])
                seed = distribution_util.gen_new_seed(seed, "mixture")
                samples_class_c = self.components[c].sample(n_class, seed=seed)

                # Pull out the correct batch entries from each index.
                # To do this, we may have to flatten the batch shape.

                # For sample s, batch element b of component c, we get the
                # partitioned batch indices from
                # partitioned_batch_indices[c]; and shift each element by
                # the sample index. The final lookup can be thought of as
                # a matrix gather along locations (s, b) in
                # samples_class_c where the n_class rows correspond to
                # samples within this component and the batch_size columns
                # correspond to batch elements within the component.
                #
                # Thus the lookup index is
                #   lookup[c, i] = batch_size * s[i] + b[c, i]
                # for i = 0 ... n_class[c] - 1.
                lookup_partitioned_batch_indices = (
                    batch_size * math_ops.range(n_class) +
                    partitioned_batch_indices[c])
                samples_class_c = array_ops.reshape(
                    samples_class_c,
                    array_ops.concat([[n_class * batch_size], event_shape], 0))
                samples_class_c = array_ops.gather(
                    samples_class_c,
                    lookup_partitioned_batch_indices,
                    name="samples_class_c_gather")
                samples_class[c] = samples_class_c

            # Stitch back together the samples across the components.
            lhs_flat_ret = data_flow_ops.dynamic_stitch(
                indices=partitioned_samples_indices, data=samples_class)
            # Reshape back to proper sample, batch, and event shape.
            ret = array_ops.reshape(
                lhs_flat_ret,
                array_ops.concat(
                    [samples_shape, self.event_shape_tensor()], 0))
            ret.set_shape(
                tensor_shape.TensorShape(static_samples_shape).concatenate(
                    self.event_shape))
            return ret
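
The partition/stitch round trip in the dynamic code path can be mirrored with plain NumPy (a sketch using the toy values from the comment above; dynamic_partition and dynamic_stitch are replaced by boolean indexing and scatter-style assignment, and each component is pretended to return one scalar sample per assigned slot):

import numpy as np

cat_samples = np.array([1, 1, 0, 0, 1, 1])           # flattened component choices, [n * B]
raw_indices = np.arange(cat_samples.size)

# dynamic_partition: split the raw sample indices by chosen component.
partitioned = [raw_indices[cat_samples == c] for c in range(2)]

# Pretend component c produced the constant value c for each assigned slot.
samples_class = [np.full(p.size, float(c)) for c, p in enumerate(partitioned)]

# dynamic_stitch: scatter each partition's samples back to the original order.
out = np.empty(cat_samples.size)
for idx, vals in zip(partitioned, samples_class):
    out[idx] = vals
print(out)                                            # [1. 1. 0. 0. 1. 1.]
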