    def test_with_rank_ndarray(self):
        x = np.array([2], dtype=np.int32)
        with self.assertRaises(ValueError):
            tensorshape_util.with_rank(x, 2)

        x = np.array([2, 3, 4], dtype=np.int32)
        y = tensorshape_util.with_rank(x, 3)
        self.assertAllEqual(x, y)

        x = np.array([2, 3, 4, 5], dtype=np.int32)
        y = tensorshape_util.with_rank_at_least(x, 3)
        self.assertAllEqual(x, y)

    def test_with_rank_list_tuple(self):
        with self.assertRaises(ValueError):
            tensorshape_util.with_rank([2], 2)

        with self.assertRaises(ValueError):
            tensorshape_util.with_rank((2, ), 2)

        self.assertAllEqual((2, 1), tensorshape_util.with_rank((2, 1), 2))
        self.assertAllEqual([2, 1], tensorshape_util.with_rank([2, 1], 2))

        self.assertAllEqual((2, 3, 4),
                            tensorshape_util.with_rank_at_least((2, 3, 4), 2))
        self.assertAllEqual([2, 3, 4],
                            tensorshape_util.with_rank_at_least([2, 3, 4], 2))
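
A minimal standalone sketch of the behavior these tests exercise, assuming tensorshape_util can be imported from TensorFlow Probability's internal package as shown (the import path is an assumption, not part of the example above):

from tensorflow_probability.python.internal import tensorshape_util

shape = (2, 3, 4)
# A matching rank hands the dimensions back unchanged.
print(tensorshape_util.with_rank(shape, 3))
# with_rank_at_least only enforces a lower bound on the rank.
print(tensorshape_util.with_rank_at_least(shape, 2))
# A rank mismatch raises ValueError, which is what the assertRaises blocks check.
try:
    tensorshape_util.with_rank(shape, 2)
except ValueError as e:
    print("ValueError:", e)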
Example #3
    def _event_shape(self):
        # We will never broadcast the num_categories with total_count.
        return tensorshape_util.with_rank(
            (self._probs if self._logits is None else self._logits).shape[-1:],
            rank=1)
Example #4
    def _event_shape(self):
        return tensorshape_util.with_rank(self.concentration.shape[-1:],
                                          rank=1)
Example #5
    def _event_shape(self):
        param = self._logits if self._logits is not None else self._probs
        return tensorshape_util.with_rank(param.shape[-1:], rank=1)
Example #6
    def _event_shape(self):
        return tensorshape_util.with_rank(self.mean_direction.shape[-1:], rank=1)
Example #7
    def _event_shape(self):
        # Event shape depends only on concentration, not total_count.
        return tensorshape_util.with_rank(self.concentration.shape[-1:],
                                          rank=1)
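
The _event_shape implementations above all share one pattern: slice the trailing dimension off a parameter's static shape and assert that the result has rank 1, i.e. a vector-valued event. A rough illustration with a made-up parameter shape (the [5, 3] batch of logits is only an example):

import tensorflow as tf
from tensorflow_probability.python.internal import tensorshape_util

logits = tf.zeros([5, 3])  # e.g. a batch of 5 parameter vectors, each of length 3
event_shape = tensorshape_util.with_rank(logits.shape[-1:], rank=1)
print(event_shape)  # a rank-1 shape with a single dimension of size 3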
Example #8
    def _sample_n(self, n, seed=None):
        stream = SeedStream(seed, salt="VectorDiffeomixture")
        x = self.distribution.sample(sample_shape=concat_vectors(
            [n], self.batch_shape_tensor(), self.event_shape_tensor()),
                                     seed=stream())  # shape: [n, B, e]
        x = [aff.forward(x) for aff in self.endpoint_affine]

        # Get ids as an [n, batch_size]-shaped matrix, unless batch_shape=[], in which
        # case get ids as an [n]-shaped vector.
        batch_size = tensorshape_util.num_elements(self.batch_shape)
        if batch_size is None:
            batch_size = tf.reduce_prod(self.batch_shape_tensor())
        mix_batch_size = tensorshape_util.num_elements(
            self.mixture_distribution.batch_shape)
        if mix_batch_size is None:
            mix_batch_size = tf.reduce_prod(
                self.mixture_distribution.batch_shape_tensor())
        ids = self.mixture_distribution.sample(sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(self.is_scalar_batch(), np.int32([]),
                                          [batch_size // mix_batch_size])),
                                               seed=stream())
        # We need to flatten batch dims in case mixture_distribution has its own
        # batch dims.
        ids = tf.reshape(ids,
                         shape=concat_vectors([n],
                                              distribution_util.pick_vector(
                                                  self.is_scalar_batch(),
                                                  np.int32([]),
                                                  np.int32([-1]))))

        # Stride by `components * quadrature_size`, once for each of the
        # `batch_size` batch members.
        stride = tensorshape_util.num_elements(
            tensorshape_util.with_rank(self.grid.shape[-2:], rank=2))
        if stride is None:
            stride = tf.reduce_prod(tf.shape(self.grid)[-2:])
        offset = tf.range(start=0,
                          limit=batch_size * stride,
                          delta=stride,
                          dtype=ids.dtype)

        weight = tf.gather(tf.reshape(self.grid, shape=[-1]), ids + offset)
        # At this point, weight has all batch dims flattened into one.
        # We also need to append a singleton to broadcast with event dims.
        if tensorshape_util.is_fully_defined(self.batch_shape):
            new_shape = [-1] + tensorshape_util.as_list(self.batch_shape) + [1]
        else:
            new_shape = tf.concat(([-1], self.batch_shape_tensor(), [1]),
                                  axis=0)
        weight = tf.reshape(weight, shape=new_shape)

        if len(x) != 2:
            # We actually should have already triggered this exception. However, as a
            # policy we're putting this exception wherever we exploit the bimixture
            # assumption.
            raise NotImplementedError(
                "Currently only bimixtures are supported; "
                "len(scale)={} is not 2.".format(len(x)))

        # Alternatively:
        # x = weight * x[0] + (1. - weight) * x[1]
        x = weight * (x[0] - x[1]) + x[1]

        return x
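
The ids + offset indexing is the non-obvious step above: each mixture draw in ids selects a row of the flattened grid, and offset shifts that index to the correct batch member. A toy NumPy re-enactment of the gather, with made-up sizes and values:

import numpy as np

batch_size, stride = 3, 4                            # stride = components * quadrature_size
grid = np.arange(batch_size * stride).reshape(batch_size, stride)

ids = np.array([[1, 3, 0]])                          # [n=1, batch_size] mixture draws
offset = np.arange(0, batch_size * stride, stride)   # [0, 4, 8]: one start index per batch member
weight = grid.reshape(-1)[ids + offset]              # picks grid[b, ids[:, b]] for each b
print(weight)                                        # [[1 7 8]]

The last line of _sample_n then applies those weights as a convex combination: weight * (x[0] - x[1]) + x[1] is algebraically the same as the weight * x[0] + (1. - weight) * x[1] spelled out in the comment.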