Example #1
0
    def test_non_vector_shape(self):
        """BatchReshape must reject a scalar (rank-0) `batch_shape`.

        With statically known shapes the failure is a `ValueError` at
        construction time; with dynamic shapes it surfaces as an op
        error when the validation ops run.
        """
        dims = 2
        new_batch_shape = 2  # Scalar (rank-0), i.e. not a vector.
        old_batch_shape = [2]

        new_batch_shape_ph = (constant_op.constant(np.int32(new_batch_shape))
                              if self.is_static_shape else
                              array_ops.placeholder_with_default(
                                  np.int32(new_batch_shape), shape=None))

        scale = np.ones(old_batch_shape + [dims], self.dtype)
        scale_ph = array_ops.placeholder_with_default(
            scale, shape=scale.shape if self.is_static_shape else None)
        mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph)

        if self.is_static_shape:
            # assertRaisesRegexp is a deprecated alias removed in
            # Python 3.12; assertRaisesRegex is the supported name.
            with self.assertRaisesRegex(ValueError, r".*must be a vector.*"):
                batch_reshape_lib.BatchReshape(distribution=mvn,
                                               batch_shape=new_batch_shape_ph,
                                               validate_args=True)

        else:
            with self.cached_session():
                with self.assertRaisesOpError(r".*must be a vector.*"):
                    batch_reshape_lib.BatchReshape(
                        distribution=mvn,
                        batch_shape=new_batch_shape_ph,
                        validate_args=True).sample().eval()
Example #2
0
    def test_non_positive_shape(self):
        """BatchReshape must reject batch-shape entries below -1.

        Entries < -1 are invalid: only -1 (meaning "infer this
        dimension") is allowed as a negative value. Static shapes fail
        at construction with a `ValueError`; dynamic shapes fail at run
        time.
        """
        dims = 2
        old_batch_shape = [4]
        if self.is_static_shape:
            # Unknown first dimension does not trigger size check. Note that
            # any dimension < 0 is treated statically as unknown.
            new_batch_shape = [-1, 0]
        else:
            new_batch_shape = [-2,
                               -2]  # -2 * -2 = 4, same size as the old shape.

        new_batch_shape_ph = (constant_op.constant(np.int32(new_batch_shape))
                              if self.is_static_shape else
                              array_ops.placeholder_with_default(
                                  np.int32(new_batch_shape), shape=None))

        scale = np.ones(old_batch_shape + [dims], self.dtype)
        scale_ph = array_ops.placeholder_with_default(
            scale, shape=scale.shape if self.is_static_shape else None)
        mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph)

        if self.is_static_shape:
            # assertRaisesRegexp is a deprecated alias removed in
            # Python 3.12; assertRaisesRegex is the supported name.
            with self.assertRaisesRegex(ValueError, r".*must be >=-1.*"):
                batch_reshape_lib.BatchReshape(distribution=mvn,
                                               batch_shape=new_batch_shape_ph,
                                               validate_args=True)

        else:
            with self.cached_session():
                with self.assertRaisesOpError(r".*must be >=-1.*"):
                    batch_reshape_lib.BatchReshape(
                        distribution=mvn,
                        batch_shape=new_batch_shape_ph,
                        validate_args=True).sample().eval()
Example #3
0
    def test_bad_reshape_size(self):
        """BatchReshape must reject a batch_shape with mismatched size.

        The product of the new batch shape (2*3=6) differs from the old
        one (2), so construction (static case) or the validation ops
        (dynamic case) must fail.
        """
        dims = 2
        new_batch_shape = [2, 3]
        old_batch_shape = [2]  # 2 != 2*3

        new_batch_shape_ph = (constant_op.constant(np.int32(new_batch_shape))
                              if self.is_static_shape else
                              array_ops.placeholder_with_default(
                                  np.int32(new_batch_shape), shape=None))

        scale = np.ones(old_batch_shape + [dims], self.dtype)
        scale_ph = array_ops.placeholder_with_default(
            scale, shape=scale.shape if self.is_static_shape else None)
        mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph)

        if self.is_static_shape:
            # assertRaisesRegexp is a deprecated alias removed in
            # Python 3.12; assertRaisesRegex is the supported name.
            with self.assertRaisesRegex(
                    ValueError, (r"`batch_shape` size \(6\) must match "
                                 r"`distribution\.batch_shape` size \(2\)")):
                batch_reshape_lib.BatchReshape(distribution=mvn,
                                               batch_shape=new_batch_shape_ph,
                                               validate_args=True)

        else:
            with self.cached_session():
                with self.assertRaisesOpError(r"Shape sizes do not match."):
                    batch_reshape_lib.BatchReshape(
                        distribution=mvn,
                        batch_shape=new_batch_shape_ph,
                        validate_args=True).sample().eval()
Example #4
0
    def test_broadcasting_explicitly_unsupported(self):
        """log_prob on a BatchReshape must reject broadcastable inputs.

        A Poisson with batch shape [4] is reshaped to [1, 4, 1]; inputs
        shaped [4] (too few dims) and [1, 1, 4] (wrong dims) would need
        broadcasting, which BatchReshape explicitly does not support.
        Static shapes raise `NotImplementedError` eagerly; dynamic
        shapes raise an op error at evaluation time.
        """
        old_batch_shape = [4]
        new_batch_shape = [1, 4, 1]
        rate_ = self.dtype([1, 10, 2, 20])

        rate = array_ops.placeholder_with_default(
            rate_, shape=old_batch_shape if self.is_static_shape else None)
        poisson_4 = poisson_lib.Poisson(rate)
        new_batch_shape_ph = (constant_op.constant(np.int32(new_batch_shape))
                              if self.is_static_shape else
                              array_ops.placeholder_with_default(
                                  np.int32(new_batch_shape), shape=None))
        poisson_141_reshaped = batch_reshape_lib.BatchReshape(
            poisson_4, new_batch_shape_ph, validate_args=True)

        x_4 = self.dtype([2, 12, 3, 23])
        x_114 = self.dtype([2, 12, 3, 23]).reshape(1, 1, 4)

        if self.is_static_shape:
            # assertRaisesRegexp is a deprecated alias removed in
            # Python 3.12; assertRaisesRegex is the supported name.
            with self.assertRaisesRegex(NotImplementedError,
                                        "too few batch and event dims"):
                poisson_141_reshaped.log_prob(x_4)
            with self.assertRaisesRegex(NotImplementedError,
                                        "unexpected batch and event shape"):
                poisson_141_reshaped.log_prob(x_114)
            return

        with self.assertRaisesOpError("too few batch and event dims"):
            with self.cached_session():
                poisson_141_reshaped.log_prob(x_4).eval()

        with self.assertRaisesOpError("unexpected batch and event shape"):
            with self.cached_session():
                poisson_141_reshaped.log_prob(x_114).eval()
Example #5
0
    def make_mvn(self, dims, new_batch_shape, old_batch_shape):
        """Build an MVN-diag distribution plus its batch-reshaped wrapper.

        Returns a pair `(base, reshaped)` where `base` has batch shape
        `old_batch_shape` and `reshaped` wraps it with
        `new_batch_shape`.
        """
        # Feed the target batch shape as a constant when shapes are
        # static, otherwise through a defaulted placeholder.
        if self.is_static_shape:
            batch_shape_t = constant_op.constant(np.int32(new_batch_shape))
        else:
            batch_shape_t = array_ops.placeholder_with_default(
                np.int32(new_batch_shape), shape=None)

        diag = np.ones(old_batch_shape + [dims], self.dtype)
        diag_shape = diag.shape if self.is_static_shape else None
        diag_ph = array_ops.placeholder_with_default(diag, shape=diag_shape)

        base = mvn_lib.MultivariateNormalDiag(scale_diag=diag_ph)
        reshaped = batch_reshape_lib.BatchReshape(
            distribution=base,
            batch_shape=batch_shape_t,
            validate_args=True)
        return base, reshaped
Example #6
0
    def make_normal(self, new_batch_shape, old_batch_shape):
        """Build a Normal distribution plus its batch-reshaped wrapper.

        Returns a pair `(base, reshaped)` where `base` has batch shape
        `old_batch_shape` and `reshaped` wraps it with
        `new_batch_shape`.
        """
        # Feed the target batch shape as a constant when shapes are
        # static, otherwise through a defaulted placeholder.
        if self.is_static_shape:
            batch_shape_t = constant_op.constant(np.int32(new_batch_shape))
        else:
            batch_shape_t = array_ops.placeholder_with_default(
                np.int32(new_batch_shape), shape=None)

        # Distinct positive scales: 0.5, 1.5, 2.5, ... over the batch.
        sigma = self.dtype(
            0.5 + np.arange(np.prod(old_batch_shape)).reshape(old_batch_shape))
        sigma_shape = sigma.shape if self.is_static_shape else None
        sigma_ph = array_ops.placeholder_with_default(sigma, shape=sigma_shape)

        base = normal_lib.Normal(loc=self.dtype(0), scale=sigma_ph)
        reshaped = batch_reshape_lib.BatchReshape(
            distribution=base,
            batch_shape=batch_shape_t,
            validate_args=True)
        return base, reshaped
Example #7
0
    def make_wishart(self, dims, new_batch_shape, old_batch_shape):
        """Build a Wishart distribution plus its batch-reshaped wrapper.

        Returns a pair `(base, reshaped)` where `base` has batch shape
        `old_batch_shape` and `reshaped` wraps it with
        `new_batch_shape`.
        """
        # Feed the target batch shape as a constant when shapes are
        # static, otherwise through a defaulted placeholder.
        if self.is_static_shape:
            batch_shape_t = constant_op.constant(np.int32(new_batch_shape))
        else:
            batch_shape_t = array_ops.placeholder_with_default(
                np.int32(new_batch_shape), shape=None)

        # Two SPD base matrices, duplicated then reshaped to cover the
        # whole old batch shape.
        spd_pair = self.dtype([
            [[1., 0.5], [0.5, 1.]],
            [[0.5, 0.25], [0.25, 0.75]],
        ])
        scale = np.reshape(np.concatenate([spd_pair, spd_pair], axis=0),
                           old_batch_shape + [dims, dims])
        scale_shape = scale.shape if self.is_static_shape else None
        scale_ph = array_ops.placeholder_with_default(scale, shape=scale_shape)

        base = wishart_lib.WishartFull(df=5, scale=scale_ph)
        reshaped = batch_reshape_lib.BatchReshape(
            distribution=base,
            batch_shape=batch_shape_t,
            validate_args=True)

        return base, reshaped