예제 #1
0
    def __init__(self,
                 mean,
                 u_tril,
                 v_tril,
                 group_ndims=0,
                 is_reparameterized=True,
                 use_path_derivative=False,
                 check_numerics=False,
                 **kwargs):
        """Build a matrix-variate normal from a mean and two factor tensors.

        The names suggest ``u_tril``/``v_tril`` are lower-triangular Cholesky
        factors of the row and column covariances. This constructor only
        checks their shapes, not triangularity or positivity (TODO confirm
        how the factors are consumed elsewhere in the class).

        :param mean: Tensor-convertible, rank >= 2. Its last two axes give
            ``n_row`` and ``n_col``.
        :param u_tril: Tensor-convertible, rank >= 2. Its shape must equal
            ``mean.shape[:-1] + [n_row]``, i.e. ``[..., n_row, n_row]``.
        :param v_tril: Tensor-convertible, rank >= 2. Its shape must equal
            ``mean.shape[:-2] + [n_col, n_col]``.
        :param group_ndims: Forwarded to the base-class constructor.
        :param is_reparameterized: Forwarded to the base-class constructor.
        :param use_path_derivative: Forwarded to the base-class constructor.
        :param check_numerics: Stored as ``self._check_numerics``; not used
            inside this constructor.
        :param kwargs: Forwarded to the base-class constructor.

        Shape mismatches that are visible in the static shapes are rejected
        here by ``TensorShape.assert_is_compatible_with``. All others are
        checked at run time: ``self._u_tril``/``self._v_tril`` are re-bound
        to identities that depend on ``tf.assert_equal`` ops. All three
        inputs must share one float dtype (``assert_same_float_dtype``).
        That dtype is passed as both ``dtype`` and ``param_dtype``.
        """
        self._check_numerics = check_numerics

        # Convert each parameter and enforce a minimum rank of 2.
        self._mean = assert_rank_at_least(
            tf.convert_to_tensor(mean), 2, 'MatrixVariateNormalCholesky.mean')
        self._n_row = get_shape_at(self._mean, -2)
        self._n_col = get_shape_at(self._mean, -1)
        self._u_tril = assert_rank_at_least(
            tf.convert_to_tensor(u_tril), 2,
            'MatrixVariateNormalCholesky.u_tril')
        self._v_tril = assert_rank_at_least(
            tf.convert_to_tensor(v_tril), 2,
            'MatrixVariateNormalCholesky.v_tril')

        # Graph-construction-time check. A non-int size (i.e. one only known
        # at run time) becomes an unknown dimension.
        static_row = self._n_row if isinstance(self._n_row, int) else None
        static_col = self._n_col if isinstance(self._n_col, int) else None
        mean_static = self._mean.get_shape()
        self._u_tril.get_shape().assert_is_compatible_with(
            mean_static[:-1].concatenate([static_row]))
        self._v_tril.get_shape().assert_is_compatible_with(
            mean_static[:-2].concatenate([static_col, static_col]))

        def shape_guard(expected, actual, prefix):
            # Assert op that reports both shapes when they differ.
            return tf.assert_equal(
                expected, actual, [prefix, expected, ' got ', actual])

        # Run-time check, built in the same op order as the static one.
        u_expected = tf.concat(
            [tf.shape(self._mean)[:-1], [self._n_row]], axis=0)
        u_actual = tf.shape(self._u_tril)
        u_guard = shape_guard(
            u_expected, u_actual,
            'MatrixVariateNormalCholesky.u_tril should have compatible '
            'shape with mean. Expected')
        v_expected = tf.concat(
            [tf.shape(self._mean)[:-2], [self._n_col, self._n_col]], axis=0)
        v_actual = tf.shape(self._v_tril)
        v_guard = shape_guard(
            v_expected, v_actual,
            'MatrixVariateNormalCholesky.v_tril should have compatible '
            'shape with mean. Expected')
        # Re-bind the factors so evaluating them forces both assertions.
        with tf.control_dependencies([u_guard, v_guard]):
            self._u_tril = tf.identity(self._u_tril)
            self._v_tril = tf.identity(self._v_tril)

        float_dtype = assert_same_float_dtype([
            (self._mean, 'MatrixVariateNormalCholesky.mean'),
            (self._u_tril, 'MatrixVariateNormalCholesky.u_tril'),
            (self._v_tril, 'MatrixVariateNormalCholesky.v_tril'),
        ])
        super(MatrixVariateNormalCholesky, self).__init__(
            dtype=float_dtype,
            param_dtype=float_dtype,
            is_continuous=True,
            is_reparameterized=is_reparameterized,
            use_path_derivative=use_path_derivative,
            group_ndims=group_ndims,
            **kwargs)
예제 #2
0
    def __init__(self,
                 mean,
                 cov_tril,
                 group_ndims=0,
                 is_reparameterized=True,
                 use_path_derivative=False,
                 check_numerics=False,
                 **kwargs):
        """Build a multivariate normal from a mean and a factor tensor.

        The name suggests ``cov_tril`` is a lower-triangular Cholesky factor
        of the covariance. This constructor only checks its shape (TODO
        confirm how it is consumed elsewhere in the class).

        :param mean: Tensor-convertible, rank >= 1. Its last axis gives
            ``n_dim``.
        :param cov_tril: Tensor-convertible, rank >= 2. Its shape must equal
            ``mean.shape + [n_dim]``.
        :param group_ndims: Forwarded to the base-class constructor.
        :param is_reparameterized: Forwarded to the base-class constructor.
        :param use_path_derivative: Forwarded to the base-class constructor.
        :param check_numerics: Stored as ``self._check_numerics``; not used
            inside this constructor.
        :param kwargs: Forwarded to the base-class constructor.

        A shape mismatch that is visible in the static shapes is rejected
        here by ``TensorShape.assert_is_compatible_with``. Otherwise it is
        caught at run time: ``self._cov_tril`` is re-bound to an identity
        that depends on a ``tf.assert_equal`` op. Both inputs must share one
        float dtype, which is passed as both ``dtype`` and ``param_dtype``.
        """
        self._check_numerics = check_numerics

        self._mean = assert_rank_at_least_one(
            tf.convert_to_tensor(mean), 'MultivariateNormalCholesky.mean')
        self._n_dim = get_shape_at(self._mean, -1)
        self._cov_tril = assert_rank_at_least(
            tf.convert_to_tensor(cov_tril), 2,
            'MultivariateNormalCholesky.cov_tril')

        # Graph-construction-time check using whatever is statically known.
        if isinstance(self._n_dim, int):
            static_dim = self._n_dim
        else:
            static_dim = None
        self._cov_tril.get_shape().assert_is_compatible_with(
            self._mean.get_shape().concatenate([static_dim]))

        # Run-time check: cov_tril is re-bound to an identity op that only
        # runs after the shape assertion has passed.
        target_shape = tf.concat(
            [tf.shape(self._mean), [self._n_dim]], axis=0)
        cov_shape = tf.shape(self._cov_tril)
        shape_check = tf.assert_equal(
            target_shape, cov_shape,
            ['MultivariateNormalCholesky.cov_tril should have compatible '
             'shape with mean. Expected', target_shape, ' got ', cov_shape])
        with tf.control_dependencies([shape_check]):
            self._cov_tril = tf.identity(self._cov_tril)

        float_dtype = assert_same_float_dtype([
            (self._mean, 'MultivariateNormalCholesky.mean'),
            (self._cov_tril, 'MultivariateNormalCholesky.cov_tril'),
        ])
        super(MultivariateNormalCholesky, self).__init__(
            dtype=float_dtype,
            param_dtype=float_dtype,
            is_continuous=True,
            is_reparameterized=is_reparameterized,
            use_path_derivative=use_path_derivative,
            group_ndims=group_ndims,
            **kwargs)
예제 #3
0
 def test_assert_rank_at_least(self):
     """Check assert_rank_at_least for static and dynamic ranks.

     If the rank is known when the graph is built, a too-low rank raises
     ValueError immediately. If the rank is unknown, the failure surfaces
     as InvalidArgumentError when the returned tensor is evaluated.
     """
     with self.test_session(use_gpu=True):
         # Statically rank 2: "at least 2" passes, "at least 3" raises now.
         static_ph = tf.placeholder(tf.float32, [2, None])
         assert_rank_at_least(static_ph, 2, 'ph')
         with self.assertRaisesRegexp(ValueError, 'should have rank'):
             assert_rank_at_least(static_ph, 3, 'ph')

         # Rank unknown until feed time, so the check runs on eval.
         dynamic_ph = tf.placeholder(tf.float32, None)
         rank_2_check = assert_rank_at_least(dynamic_ph, 2, 'ph')
         rank_3_check = assert_rank_at_least(dynamic_ph, 3, 'ph')
         feed = {dynamic_ph: np.ones([2, 9])}
         rank_2_check.eval(feed)
         with self.assertRaises(tf.errors.InvalidArgumentError):
             rank_3_check.eval(feed)