def test_non_vector_shape(self):
  """Checks that `BatchReshape` rejects a scalar (non-vector) `batch_shape`.

  With static shapes the check is expected to raise `ValueError` while the
  `BatchReshape` is being constructed. With dynamic shapes it is expected to
  surface as an op error when a sample is evaluated.
  """
  dims = 2
  new_batch_shape = 2  # Deliberately a scalar, not a vector.
  old_batch_shape = [2]

  if self.is_static_shape:
    batch_shape_t = constant_op.constant(np.int32(new_batch_shape))
  else:
    batch_shape_t = array_ops.placeholder_with_default(
        np.int32(new_batch_shape), shape=None)

  scale_np = np.ones(old_batch_shape + [dims], self.dtype)
  scale_static_shape = scale_np.shape if self.is_static_shape else None
  scale_t = array_ops.placeholder_with_default(
      scale_np, shape=scale_static_shape)
  base_dist = mvn_lib.MultivariateNormalDiag(scale_diag=scale_t)

  if not self.is_static_shape:
    # Dynamic case: the failure is only expected once the graph is run.
    with self.cached_session():
      with self.assertRaisesOpError(r".*must be a vector.*"):
        batch_reshape_lib.BatchReshape(
            distribution=base_dist,
            batch_shape=batch_shape_t,
            validate_args=True).sample().eval()
    return

  with self.assertRaisesRegexp(ValueError, r".*must be a vector.*"):
    batch_reshape_lib.BatchReshape(
        distribution=base_dist,
        batch_shape=batch_shape_t,
        validate_args=True)
def test_non_positive_shape(self):
  """Checks that `BatchReshape` reports a "must be >=-1" error.

  The static case passes `[-1, 0]` and expects a `ValueError` at
  construction time. The dynamic case passes `[-2, -2]` and expects an op
  error when a sample is evaluated.
  """
  dims = 2
  old_batch_shape = [4]
  if not self.is_static_shape:
    new_batch_shape = [-2, -2]  # -2 * -2 = 4, same size as the old shape.
  else:
    # Unknown first dimension does not trigger size check. Note that
    # any dimension < 0 is treated statically as unknown.
    new_batch_shape = [-1, 0]

  if self.is_static_shape:
    batch_shape_t = constant_op.constant(np.int32(new_batch_shape))
  else:
    batch_shape_t = array_ops.placeholder_with_default(
        np.int32(new_batch_shape), shape=None)

  scale_np = np.ones(old_batch_shape + [dims], self.dtype)
  scale_static_shape = scale_np.shape if self.is_static_shape else None
  scale_t = array_ops.placeholder_with_default(
      scale_np, shape=scale_static_shape)
  base_dist = mvn_lib.MultivariateNormalDiag(scale_diag=scale_t)

  if not self.is_static_shape:
    # Dynamic case: the failure is only expected once the graph is run.
    with self.cached_session():
      with self.assertRaisesOpError(r".*must be >=-1.*"):
        batch_reshape_lib.BatchReshape(
            distribution=base_dist,
            batch_shape=batch_shape_t,
            validate_args=True).sample().eval()
    return

  with self.assertRaisesRegexp(ValueError, r".*must be >=-1.*"):
    batch_reshape_lib.BatchReshape(
        distribution=base_dist,
        batch_shape=batch_shape_t,
        validate_args=True)
def test_bad_reshape_size(self):
  """Checks that `BatchReshape` rejects a `batch_shape` of the wrong size.

  The new batch shape `[2, 3]` has 6 elements while the distribution's batch
  shape `[2]` has 2. The static case expects a `ValueError` naming both sizes
  at construction time. The dynamic case expects an op error on evaluation.
  """
  dims = 2
  new_batch_shape = [2, 3]
  old_batch_shape = [2]  # 2 != 2*3

  if self.is_static_shape:
    batch_shape_t = constant_op.constant(np.int32(new_batch_shape))
  else:
    batch_shape_t = array_ops.placeholder_with_default(
        np.int32(new_batch_shape), shape=None)

  scale_np = np.ones(old_batch_shape + [dims], self.dtype)
  scale_static_shape = scale_np.shape if self.is_static_shape else None
  scale_t = array_ops.placeholder_with_default(
      scale_np, shape=scale_static_shape)
  base_dist = mvn_lib.MultivariateNormalDiag(scale_diag=scale_t)

  if not self.is_static_shape:
    # Dynamic case: the size mismatch is only expected once the graph is run.
    with self.cached_session():
      with self.assertRaisesOpError(r"Shape sizes do not match."):
        batch_reshape_lib.BatchReshape(
            distribution=base_dist,
            batch_shape=batch_shape_t,
            validate_args=True).sample().eval()
    return

  expected_msg = (r"`batch_shape` size \(6\) must match "
                  r"`distribution\.batch_shape` size \(2\)")
  with self.assertRaisesRegexp(ValueError, expected_msg):
    batch_reshape_lib.BatchReshape(
        distribution=base_dist,
        batch_shape=batch_shape_t,
        validate_args=True)
def test_broadcasting_explicitly_unsupported(self):
  """Checks that `log_prob` rejects inputs that do not match the new shape.

  A `Poisson` with batch shape `[4]` is reshaped to `[1, 4, 1]`. Calling
  `log_prob` on an input of shape `[4]` or `[1, 1, 4]` is expected to fail:
  with `NotImplementedError` for static shapes, or with an op error at
  evaluation time for dynamic shapes.
  """
  old_batch_shape = [4]
  new_batch_shape = [1, 4, 1]

  rate_np = self.dtype([1, 10, 2, 20])
  if self.is_static_shape:
    rate_static_shape = old_batch_shape
  else:
    rate_static_shape = None
  rate_t = array_ops.placeholder_with_default(rate_np, shape=rate_static_shape)
  poisson_4 = poisson_lib.Poisson(rate_t)

  if self.is_static_shape:
    batch_shape_t = constant_op.constant(np.int32(new_batch_shape))
  else:
    batch_shape_t = array_ops.placeholder_with_default(
        np.int32(new_batch_shape), shape=None)
  poisson_141_reshaped = batch_reshape_lib.BatchReshape(
      poisson_4, batch_shape_t, validate_args=True)

  x_4 = self.dtype([2, 12, 3, 23])
  x_114 = x_4.reshape(1, 1, 4)

  if self.is_static_shape:
    with self.assertRaisesRegexp(NotImplementedError,
                                 "too few batch and event dims"):
      poisson_141_reshaped.log_prob(x_4)
    with self.assertRaisesRegexp(NotImplementedError,
                                 "unexpected batch and event shape"):
      poisson_141_reshaped.log_prob(x_114)
  else:
    # Dynamic case: the mismatch is only expected once the graph is run.
    with self.assertRaisesOpError("too few batch and event dims"):
      with self.cached_session():
        poisson_141_reshaped.log_prob(x_4).eval()
    with self.assertRaisesOpError("unexpected batch and event shape"):
      with self.cached_session():
        poisson_141_reshaped.log_prob(x_114).eval()
def make_mvn(self, dims, new_batch_shape, old_batch_shape):
  """Builds a `MultivariateNormalDiag` and a `BatchReshape` wrapping it.

  Args:
    dims: Size of the trailing dimension of `scale_diag`.
    new_batch_shape: Batch shape passed to `BatchReshape`. It is converted
      with `np.int32`.
    old_batch_shape: Batch shape of the underlying distribution. It must be a
      list, since it is concatenated with `[dims]`.

  Returns:
    A `(mvn, reshape_mvn)` tuple: the base distribution with all-ones scale,
    and its reshaped counterpart built with `validate_args=True`. When
    `self.is_static_shape` is false, the shape inputs are placeholders with
    unknown shape.
  """
  if self.is_static_shape:
    batch_shape_t = constant_op.constant(np.int32(new_batch_shape))
  else:
    batch_shape_t = array_ops.placeholder_with_default(
        np.int32(new_batch_shape), shape=None)

  scale_np = np.ones(old_batch_shape + [dims], self.dtype)
  scale_static_shape = scale_np.shape if self.is_static_shape else None
  scale_t = array_ops.placeholder_with_default(
      scale_np, shape=scale_static_shape)

  base_dist = mvn_lib.MultivariateNormalDiag(scale_diag=scale_t)
  reshaped_dist = batch_reshape_lib.BatchReshape(
      distribution=base_dist,
      batch_shape=batch_shape_t,
      validate_args=True)
  return base_dist, reshaped_dist
def make_normal(self, new_batch_shape, old_batch_shape):
  """Builds a `Normal` with distinct scales and a `BatchReshape` wrapping it.

  The scales are `0.5, 1.5, 2.5, ...`, laid out in `old_batch_shape`, and
  `loc` is 0.

  Args:
    new_batch_shape: Batch shape passed to `BatchReshape`. It is converted
      with `np.int32`.
    old_batch_shape: Shape into which the scale values are reshaped.

  Returns:
    A `(normal, reshape_normal)` tuple: the base distribution and its
    reshaped counterpart built with `validate_args=True`.
  """
  if self.is_static_shape:
    batch_shape_t = constant_op.constant(np.int32(new_batch_shape))
  else:
    batch_shape_t = array_ops.placeholder_with_default(
        np.int32(new_batch_shape), shape=None)

  num_scales = np.prod(old_batch_shape)
  scale_grid = np.arange(num_scales).reshape(old_batch_shape)
  scale_np = self.dtype(0.5 + scale_grid)
  scale_static_shape = scale_np.shape if self.is_static_shape else None
  scale_t = array_ops.placeholder_with_default(
      scale_np, shape=scale_static_shape)

  base_dist = normal_lib.Normal(loc=self.dtype(0), scale=scale_t)
  reshaped_dist = batch_reshape_lib.BatchReshape(
      distribution=base_dist,
      batch_shape=batch_shape_t,
      validate_args=True)
  return base_dist, reshaped_dist
def make_wishart(self, dims, new_batch_shape, old_batch_shape):
  """Builds a `WishartFull` (df=5) and a `BatchReshape` wrapping it.

  The scale comes from two hard-coded 2x2 matrices, repeated to four
  matrices (16 values) and reshaped to `old_batch_shape + [dims, dims]`.
  The reshape therefore only succeeds when that target shape has 16
  elements, e.g. `dims == 2` with a 4-element `old_batch_shape`.

  Args:
    dims: Matrix dimension of the scale.
    new_batch_shape: Batch shape passed to `BatchReshape`. It is converted
      with `np.int32`.
    old_batch_shape: Batch shape of the underlying distribution. It must be a
      list, since it is concatenated with `[dims, dims]`.

  Returns:
    A `(wishart, reshape_wishart)` tuple: the base distribution and its
    reshaped counterpart built with `validate_args=True`.
  """
  if self.is_static_shape:
    batch_shape_t = constant_op.constant(np.int32(new_batch_shape))
  else:
    batch_shape_t = array_ops.placeholder_with_default(
        np.int32(new_batch_shape), shape=None)

  base_mats = self.dtype([
      [[1., 0.5], [0.5, 1.]],
      [[0.5, 0.25], [0.25, 0.75]],
  ])
  stacked = np.concatenate([base_mats, base_mats], axis=0)
  scale_np = np.reshape(stacked, old_batch_shape + [dims, dims])
  scale_static_shape = scale_np.shape if self.is_static_shape else None
  scale_t = array_ops.placeholder_with_default(
      scale_np, shape=scale_static_shape)

  base_dist = wishart_lib.WishartFull(df=5, scale=scale_t)
  reshaped_dist = batch_reshape_lib.BatchReshape(
      distribution=base_dist,
      batch_shape=batch_shape_t,
      validate_args=True)
  return base_dist, reshaped_dist