def __init__(self,
             mean,
             u_tril,
             v_tril,
             group_ndims=0,
             is_reparameterized=True,
             use_path_derivative=False,
             check_numerics=False,
             **kwargs):
    """Construct the distribution from a mean and two triangular factors.

    :param mean: Tensor-like of rank >= 2. Its last two axes are read as
        ``(n_row, n_col)``.
    :param u_tril: Tensor-like of rank >= 2 whose shape must equal
        ``mean.shape[:-1] + [n_row]`` (presumably the Cholesky factor of the
        row covariance -- confirm against the class docs).
    :param v_tril: Tensor-like of rank >= 2 whose shape must equal
        ``mean.shape[:-2] + [n_col, n_col]`` (presumably the Cholesky factor
        of the column covariance -- confirm against the class docs).
    :param group_ndims: Forwarded to the base-class constructor.
    :param is_reparameterized: Forwarded to the base-class constructor.
    :param use_path_derivative: Forwarded to the base-class constructor.
    :param check_numerics: Stored as ``self._check_numerics``; not read here.
    :param kwargs: Forwarded to the base-class constructor.

    Shape mismatches that are visible from static shapes raise ``ValueError``
    while the graph is being built. Mismatches that only show up at run time
    are reported when the factor tensors are evaluated, via ``tf.assert_equal``
    control dependencies. Dtype agreement is enforced by
    ``assert_same_float_dtype`` (exception type depends on that helper).
    """
    self._check_numerics = check_numerics

    # Normalize every parameter to a tensor and require rank >= 2.
    mean = assert_rank_at_least(
        tf.convert_to_tensor(mean), 2, 'MatrixVariateNormalCholesky.mean')
    # Each size is a Python int when statically known; otherwise it is
    # presumably a scalar tensor.
    n_row = get_shape_at(mean, -2)
    n_col = get_shape_at(mean, -1)
    u_tril = assert_rank_at_least(
        tf.convert_to_tensor(u_tril), 2,
        'MatrixVariateNormalCholesky.u_tril')
    v_tril = assert_rank_at_least(
        tf.convert_to_tensor(v_tril), 2,
        'MatrixVariateNormalCholesky.v_tril')

    # Graph-construction-time check: dimensions that are not plain ints
    # become None, i.e. "unknown, accept anything".
    known_row = n_row if isinstance(n_row, int) else None
    known_col = n_col if isinstance(n_col, int) else None
    u_tril.get_shape().assert_is_compatible_with(
        mean.get_shape()[:-1].concatenate([known_row]))
    v_tril.get_shape().assert_is_compatible_with(
        mean.get_shape()[:-2].concatenate([known_col, known_col]))

    # Run-time check for whatever the static shapes could not settle.
    u_want = tf.concat([tf.shape(mean)[:-1], [n_row]], axis=0)
    u_have = tf.shape(u_tril)
    u_guard = tf.assert_equal(
        u_want, u_have,
        ['MatrixVariateNormalCholesky.u_tril should have compatible '
         'shape with mean. Expected', u_want, ' got ', u_have])
    v_want = tf.concat([tf.shape(mean)[:-2], [n_col, n_col]], axis=0)
    v_have = tf.shape(v_tril)
    v_guard = tf.assert_equal(
        v_want, v_have,
        ['MatrixVariateNormalCholesky.v_tril should have compatible '
         'shape with mean. Expected', v_want, ' got ', v_have])

    # Re-emit the factors behind the shape assertions so that every later
    # use of them triggers the checks.
    with tf.control_dependencies([u_guard, v_guard]):
        u_tril = tf.identity(u_tril)
        v_tril = tf.identity(v_tril)

    self._mean = mean
    self._n_row = n_row
    self._n_col = n_col
    self._u_tril = u_tril
    self._v_tril = v_tril

    dtype = assert_same_float_dtype(
        [(mean, 'MatrixVariateNormalCholesky.mean'),
         (u_tril, 'MatrixVariateNormalCholesky.u_tril'),
         (v_tril, 'MatrixVariateNormalCholesky.v_tril')])
    super(MatrixVariateNormalCholesky, self).__init__(
        dtype=dtype,
        param_dtype=dtype,
        is_continuous=True,
        is_reparameterized=is_reparameterized,
        use_path_derivative=use_path_derivative,
        group_ndims=group_ndims,
        **kwargs)
def __init__(self,
             mean,
             cov_tril,
             group_ndims=0,
             is_reparameterized=True,
             use_path_derivative=False,
             check_numerics=False,
             **kwargs):
    """Construct the distribution from a mean and a triangular factor.

    :param mean: Tensor-like of rank >= 1. Its last axis is read as
        ``n_dim``.
    :param cov_tril: Tensor-like of rank >= 2 whose shape must equal
        ``mean.shape + [n_dim]`` (presumably the Cholesky factor of the
        covariance -- confirm against the class docs).
    :param group_ndims: Forwarded to the base-class constructor.
    :param is_reparameterized: Forwarded to the base-class constructor.
    :param use_path_derivative: Forwarded to the base-class constructor.
    :param check_numerics: Stored as ``self._check_numerics``; not read here.
    :param kwargs: Forwarded to the base-class constructor.

    A shape mismatch visible from static shapes raises ``ValueError`` while
    the graph is being built. A mismatch that only shows up at run time is
    reported when ``cov_tril`` is evaluated, via a ``tf.assert_equal``
    control dependency. Dtype agreement is enforced by
    ``assert_same_float_dtype`` (exception type depends on that helper).
    """
    self._check_numerics = check_numerics

    mean = assert_rank_at_least_one(
        tf.convert_to_tensor(mean), 'MultivariateNormalCholesky.mean')
    # A Python int when statically known; otherwise presumably a tensor.
    n_dim = get_shape_at(mean, -1)
    cov_tril = assert_rank_at_least(
        tf.convert_to_tensor(cov_tril), 2,
        'MultivariateNormalCholesky.cov_tril')

    # Graph-construction-time check; an unknown size becomes None.
    known_dim = n_dim if isinstance(n_dim, int) else None
    cov_tril.get_shape().assert_is_compatible_with(
        mean.get_shape().concatenate([known_dim]))

    # Run-time check for whatever the static shapes could not settle.
    want = tf.concat([tf.shape(mean), [n_dim]], axis=0)
    have = tf.shape(cov_tril)
    guard = tf.assert_equal(
        want, have,
        ['MultivariateNormalCholesky.cov_tril should have compatible '
         'shape with mean. Expected', want, ' got ', have])
    # Re-emit the factor behind the assertion so later uses trigger it.
    with tf.control_dependencies([guard]):
        cov_tril = tf.identity(cov_tril)

    self._mean = mean
    self._n_dim = n_dim
    self._cov_tril = cov_tril

    dtype = assert_same_float_dtype(
        [(mean, 'MultivariateNormalCholesky.mean'),
         (cov_tril, 'MultivariateNormalCholesky.cov_tril')])
    super(MultivariateNormalCholesky, self).__init__(
        dtype=dtype,
        param_dtype=dtype,
        is_continuous=True,
        is_reparameterized=is_reparameterized,
        use_path_derivative=use_path_derivative,
        group_ndims=group_ndims,
        **kwargs)
def test_assert_rank_at_least(self):
    """Exercise ``assert_rank_at_least`` with static and dynamic ranks.

    A placeholder with a statically known rank of 2 is accepted for a minimum
    rank of 2. For a minimum rank of 3 it raises ``ValueError`` with a message
    matching 'should have rank'. A placeholder with unknown shape defers the
    check to evaluation time. Feeding a 2x9 array then passes the rank-2 check
    and makes the rank-3 check raise ``tf.errors.InvalidArgumentError``.
    """
    with self.test_session(use_gpu=True):
        # Static case: the rank is available when the graph is built.
        rank_two = tf.placeholder(tf.float32, [2, None])
        assert_rank_at_least(rank_two, 2, 'ph')
        with self.assertRaisesRegexp(ValueError, 'should have rank'):
            assert_rank_at_least(rank_two, 3, 'ph')

        # Dynamic case: the rank is only known once a value is fed.
        shapeless = tf.placeholder(tf.float32, None)
        need_two = assert_rank_at_least(shapeless, 2, 'ph')
        need_three = assert_rank_at_least(shapeless, 3, 'ph')
        feed = {shapeless: np.ones([2, 9])}
        need_two.eval(feed)
        with self.assertRaises(tf.errors.InvalidArgumentError):
            need_three.eval(feed)