Example #1
    def __init__(self,
                 mean=0.,
                 logstd=None,
                 std=None,
                 group_event_ndims=0,
                 is_reparameterized=True,
                 check_numerics=False):
        self._mean = tf.convert_to_tensor(mean)
        warnings.warn("Normal: The order of arguments logstd/std will change "
                      "to std/logstd in the coming version.")
        if (logstd is None) == (std is None):
            raise ValueError("Either std or logstd should be passed but not "
                             "both of them.")
        elif logstd is None:
            self._std = tf.convert_to_tensor(std)
            dtype = assert_same_float_dtype([(self._mean, 'Normal.mean'),
                                             (self._std, 'Normal.std')])
            logstd = tf.log(self._std)
            if check_numerics:
                with tf.control_dependencies(
                        [tf.check_numerics(logstd, "log(std)")]):
                    logstd = tf.identity(logstd)
            self._logstd = logstd
        else:
            # std is None
            self._logstd = tf.convert_to_tensor(logstd)
            dtype = assert_same_float_dtype([(self._mean, 'Normal.mean'),
                                             (self._logstd, 'Normal.logstd')])
            std = tf.exp(self._logstd)
            if check_numerics:
                with tf.control_dependencies(
                        [tf.check_numerics(std, "exp(logstd)")]):
                    std = tf.identity(std)
            self._std = std

        try:
            tf.broadcast_static_shape(self._mean.get_shape(),
                                      self._std.get_shape())
        except ValueError:
            raise ValueError(
                "mean and std/logstd should be broadcastable to match each "
                "other. ({} vs. {})".format(
                    self._mean.get_shape(), self._std.get_shape()))
        self._check_numerics = check_numerics
        super(Normal, self).__init__(
            dtype=dtype,
            param_dtype=dtype,
            is_continuous=True,
            is_reparameterized=is_reparameterized,
            group_event_ndims=group_event_ndims)
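Example #1 derives whichever of std/logstd was not supplied and guards the derived tensor with tf.check_numerics behind a control dependency. Below is a minimal standalone sketch of that pattern (TensorFlow 1.x assumed; resolve_std_logstd is a hypothetical helper name, not part of the code above):

import tensorflow as tf

def resolve_std_logstd(std=None, logstd=None, check_numerics=False):
    # Exactly one of std / logstd may be given; the other is derived.
    if (logstd is None) == (std is None):
        raise ValueError("Exactly one of std or logstd should be passed.")
    if std is not None:
        std = tf.convert_to_tensor(std)
        logstd = tf.log(std)  # NaN for std < 0, -inf for std == 0
        if check_numerics:
            with tf.control_dependencies(
                    [tf.check_numerics(logstd, "log(std)")]):
                logstd = tf.identity(logstd)
    else:
        logstd = tf.convert_to_tensor(logstd)
        std = tf.exp(logstd)  # can overflow to inf for large logstd
        if check_numerics:
            with tf.control_dependencies(
                    [tf.check_numerics(std, "exp(logstd)")]):
                std = tf.identity(std)
    return std, logstd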
Example #2
    def __init__(self,
                 temperature,
                 logits,
                 group_ndims=0,
                 is_reparameterized=True,
                 use_path_derivative=False,
                 check_numerics=False,
                 **kwargs):
        self._logits = tf.convert_to_tensor(logits)
        self._temperature = tf.convert_to_tensor(temperature)
        param_dtype = assert_same_float_dtype([
            (self._logits, 'Concrete.logits'),
            (self._temperature, 'Concrete.temperature')
        ])

        self._logits, self._n_categories = assert_rank_at_least_one(
            self._logits, 'Concrete.logits')

        self._temperature = assert_scalar(self._temperature,
                                          'Concrete.temperature')

        self._check_numerics = check_numerics
        super(Concrete, self).__init__(dtype=param_dtype,
                                       param_dtype=param_dtype,
                                       is_continuous=True,
                                       is_reparameterized=is_reparameterized,
                                       use_path_derivative=use_path_derivative,
                                       group_ndims=group_ndims,
                                       **kwargs)
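Example #2 relies on assert_scalar to reject non-scalar temperatures. A hypothetical re-implementation in the same spirit (TensorFlow 1.x): check the rank statically when it is known, otherwise attach a graph-time assertion:

import tensorflow as tf

def assert_scalar(tensor, name):
    # Static path: the rank is known at graph-construction time.
    static_shape = tensor.get_shape()
    if static_shape.ndims is not None:
        if static_shape.ndims != 0:
            raise ValueError("{} should be a scalar.".format(name))
        return tensor
    # Dynamic path: defer the check to session run time.
    assert_op = tf.assert_rank(
        tensor, 0, message="{} should be a scalar (0-D Tensor).".format(name))
    with tf.control_dependencies([assert_op]):
        return tf.identity(tensor)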
Example #3
    def __init__(self,
                 mean,
                 u_tril,
                 v_tril,
                 group_ndims=0,
                 is_reparameterized=True,
                 use_path_derivative=False,
                 check_numerics=False,
                 **kwargs):
        self._check_numerics = check_numerics
        self._mean = tf.convert_to_tensor(mean)
        self._mean = assert_rank_at_least(
            self._mean, 2, 'MatrixVariateNormalCholesky.mean')
        self._n_row = get_shape_at(self._mean, -2)
        self._n_col = get_shape_at(self._mean, -1)
        self._u_tril = tf.convert_to_tensor(u_tril)
        self._u_tril = assert_rank_at_least(
            self._u_tril, 2, 'MatrixVariateNormalCholesky.u_tril')
        self._v_tril = tf.convert_to_tensor(v_tril)
        self._v_tril = assert_rank_at_least(
            self._v_tril, 2, 'MatrixVariateNormalCholesky.v_tril')

        # Static shape check
        expected_u_shape = self._mean.get_shape()[:-1].concatenate(
            [self._n_row if isinstance(self._n_row, int) else None])
        self._u_tril.get_shape().assert_is_compatible_with(expected_u_shape)
        expected_v_shape = self._mean.get_shape()[:-2].concatenate(
            [self._n_col if isinstance(self._n_col, int) else None] * 2)
        self._v_tril.get_shape().assert_is_compatible_with(expected_v_shape)
        # Dynamic shape check
        expected_u_shape = tf.concat(
            [tf.shape(self._mean)[:-1], [self._n_row]], axis=0)
        actual_u_shape = tf.shape(self._u_tril)
        msg = ['MatrixVariateNormalCholesky.u_tril should have compatible '
               'shape with mean. Expected', expected_u_shape, ' got ',
               actual_u_shape]
        assert_u_ops = tf.assert_equal(expected_u_shape, actual_u_shape, msg)
        expected_v_shape = tf.concat(
            [tf.shape(self._mean)[:-2], [self._n_col, self._n_col]], axis=0)
        actual_v_shape = tf.shape(self._v_tril)
        msg = ['MatrixVariateNormalCholesky.v_tril should have compatible '
               'shape with mean. Expected', expected_v_shape, ' got ',
               actual_v_shape]
        assert_v_ops = tf.assert_equal(expected_v_shape, actual_v_shape, msg)
        with tf.control_dependencies([assert_u_ops, assert_v_ops]):
            self._u_tril = tf.identity(self._u_tril)
            self._v_tril = tf.identity(self._v_tril)

        dtype = assert_same_float_dtype(
            [(self._mean, 'MatrixVariateNormalCholesky.mean'),
             (self._u_tril, 'MatrixVariateNormalCholesky.u_tril'),
             (self._v_tril, 'MatrixVariateNormalCholesky.v_tril')])
        super(MatrixVariateNormalCholesky, self).__init__(
            dtype=dtype,
            param_dtype=dtype,
            is_continuous=True,
            is_reparameterized=is_reparameterized,
            use_path_derivative=use_path_derivative,
            group_ndims=group_ndims,
            **kwargs)
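Example #3 performs each compatibility check twice: once statically against the known TensorShape and once dynamically with tf.assert_equal, so errors surface as early as possible. A condensed sketch of the pattern for a single square factor (hypothetical helper; assumes mean has a statically known rank of at least 2):

import tensorflow as tf

def check_square_factor(mean, tril, name):
    # Static check: tril must be compatible with batch_shape + [n, n],
    # where n is mean's last dimension (None if unknown).
    n = mean.get_shape()[-1].value
    expected = mean.get_shape()[:-2].concatenate([n, n])
    tril.get_shape().assert_is_compatible_with(expected)
    # Dynamic check: the same comparison at run time.
    dyn_n = tf.shape(mean)[-1]
    dyn_expected = tf.concat([tf.shape(mean)[:-2], [dyn_n, dyn_n]], axis=0)
    assert_op = tf.assert_equal(
        dyn_expected, tf.shape(tril),
        ['{} should have compatible shape with mean. Expected'.format(name),
         dyn_expected, ' got ', tf.shape(tril)])
    with tf.control_dependencies([assert_op]):
        return tf.identity(tril)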
Example #4
    def __init__(self,
                 minval=0.,
                 maxval=1.,
                 group_event_ndims=0,
                 is_reparameterized=True,
                 check_numerics=False):
        self._minval = tf.convert_to_tensor(minval)
        self._maxval = tf.convert_to_tensor(maxval)
        dtype = assert_same_float_dtype([(self._minval, 'Uniform.minval'),
                                         (self._maxval, 'Uniform.maxval')])

        try:
            tf.broadcast_static_shape(self._minval.get_shape(),
                                      self._maxval.get_shape())
        except ValueError:
            raise ValueError(
                "minval and maxval should be broadcastable to match each "
                "other. ({} vs. {})".format(self._minval.get_shape(),
                                            self._maxval.get_shape()))
        self._check_numerics = check_numerics
        super(Uniform, self).__init__(dtype=dtype,
                                      param_dtype=dtype,
                                      is_continuous=True,
                                      is_reparameterized=is_reparameterized,
                                      group_event_ndims=group_event_ndims)
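The try/except around tf.broadcast_static_shape in Example #4 works because that function raises ValueError when the two static shapes cannot be broadcast. A quick demonstration (TensorFlow 1.x):

import tensorflow as tf

a = tf.placeholder(tf.float32, [3, 1])
b = tf.placeholder(tf.float32, [1, 4])
# Compatible shapes: returns the broadcast TensorShape (3, 4).
print(tf.broadcast_static_shape(a.get_shape(), b.get_shape()))

c = tf.placeholder(tf.float32, [3])
d = tf.placeholder(tf.float32, [4])
try:
    # Incompatible shapes: raises ValueError, which the constructor
    # re-raises with a friendlier message.
    tf.broadcast_static_shape(c.get_shape(), d.get_shape())
except ValueError as e:
    print("incompatible:", e)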
Example #5
    def __init__(self,
                 loc,
                 scale,
                 group_event_ndims=0,
                 is_reparameterized=True,
                 check_numerics=False):
        self._loc = tf.convert_to_tensor(loc)
        self._scale = tf.convert_to_tensor(scale)
        dtype = assert_same_float_dtype([(self._loc, 'Laplace.loc'),
                                         (self._scale, 'Laplace.scale')])

        try:
            tf.broadcast_static_shape(self._loc.get_shape(),
                                      self._scale.get_shape())
        except ValueError:
            raise ValueError(
                "loc and scale should be broadcastable to match each "
                "other. ({} vs. {})".format(self._loc.get_shape(),
                                            self._scale.get_shape()))
        self._check_numerics = check_numerics
        super(Laplace, self).__init__(dtype=dtype,
                                      param_dtype=dtype,
                                      is_continuous=True,
                                      is_reparameterized=is_reparameterized,
                                      group_event_ndims=group_event_ndims)
Example #6
    def __init__(self,
                 logits,
                 normalize_logits=True,
                 dtype=tf.int32,
                 group_ndims=0,
                 **kwargs):
        self._logits = tf.convert_to_tensor(logits)
        param_dtype = assert_same_float_dtype(
            [(self._logits, 'UnnormalizedMultinomial.logits')])

        assert_dtype_is_int_or_float(dtype)

        self._logits = assert_rank_at_least_one(
            self._logits, 'UnnormalizedMultinomial.logits')
        self._n_categories = get_shape_at(self._logits, -1)

        self.normalize_logits = normalize_logits

        super(UnnormalizedMultinomial, self).__init__(
            dtype=dtype,
            param_dtype=param_dtype,
            is_continuous=False,
            is_reparameterized=False,
            group_ndims=group_ndims,
            **kwargs)
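The normalize_logits flag in Example #6 is not consumed inside the constructor; presumably the log-probability code normalizes the logits downstream. If so, the usual normalization is a log-softmax over the category axis, sketched here as an assumption:

import tensorflow as tf

def normalize(logits):
    # Subtract the log partition function so that exp(logits) sums to
    # one over the last (category) axis.
    return logits - tf.reduce_logsumexp(logits, axis=-1, keepdims=True)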
Example #7
    def __init__(self,
                 alpha,
                 beta,
                 dtype=None,
                 group_event_ndims=0,
                 check_numerics=False):
        self._alpha = tf.convert_to_tensor(alpha)
        self._beta = tf.convert_to_tensor(beta)
        dtype = assert_same_float_dtype([(self._alpha, 'Beta.alpha'),
                                         (self._beta, 'Beta.beta')])

        try:
            tf.broadcast_static_shape(self._alpha.get_shape(),
                                      self._beta.get_shape())
        except ValueError:
            raise ValueError(
                "alpha and beta should be broadcastable to match each "
                "other. ({} vs. {})".format(self._alpha.get_shape(),
                                            self._beta.get_shape()))
        self._check_numerics = check_numerics
        super(Beta, self).__init__(dtype=dtype,
                                   param_dtype=dtype,
                                   is_continuous=True,
                                   is_reparameterized=False,
                                   group_event_ndims=group_event_ndims)
Example #8
    def __init__(self, alpha, group_ndims=0, check_numerics=False, **kwargs):
        self._alpha = tf.convert_to_tensor(alpha)
        dtype = assert_same_float_dtype([(self._alpha, 'Dirichlet.alpha')])

        static_alpha_shape = self._alpha.get_shape()
        shape_err_msg = "alpha should have rank >= 1."
        cat_err_msg = "n_categories (length of the last axis " \
                      "of alpha) should be at least 2."
        if static_alpha_shape and (static_alpha_shape.ndims < 1):
            raise ValueError(shape_err_msg)
        elif static_alpha_shape and (static_alpha_shape[-1].value is not None):
            self._n_categories = static_alpha_shape[-1].value
            if self._n_categories < 2:
                raise ValueError(cat_err_msg)
        else:
            _assert_shape_op = tf.assert_rank_at_least(self._alpha,
                                                       1,
                                                       message=shape_err_msg)
            with tf.control_dependencies([_assert_shape_op]):
                self._alpha = tf.identity(self._alpha)
            self._n_categories = tf.shape(self._alpha)[-1]

            _assert_cat_op = tf.assert_greater_equal(self._n_categories,
                                                     2,
                                                     message=cat_err_msg)
            with tf.control_dependencies([_assert_cat_op]):
                self._alpha = tf.identity(self._alpha)
        self._check_numerics = check_numerics

        super(Dirichlet, self).__init__(dtype=dtype,
                                        param_dtype=dtype,
                                        is_continuous=True,
                                        is_reparameterized=False,
                                        group_ndims=group_ndims,
                                        **kwargs)
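Example #8 illustrates the static-first/dynamic-fallback idiom for reading the number of categories. A hypothetical helper that distills just that idiom (TensorFlow 1.x):

import tensorflow as tf

def infer_n_categories(alpha):
    static_shape = alpha.get_shape()
    # Static path: reject rank-0 inputs immediately.
    if static_shape.ndims is not None and static_shape.ndims < 1:
        raise ValueError("alpha should have rank >= 1.")
    # Static path: read the last dimension when it is known.
    if static_shape.ndims is not None and static_shape[-1].value is not None:
        return alpha, static_shape[-1].value
    # Dynamic fallback: assert the rank at run time, then read the
    # last dimension as a tensor.
    assert_op = tf.assert_rank_at_least(
        alpha, 1, message="alpha should have rank >= 1.")
    with tf.control_dependencies([assert_op]):
        alpha = tf.identity(alpha)
    return alpha, tf.shape(alpha)[-1]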
Example #9
    def __init__(self, logits, dtype=None, group_event_ndims=0):
        self._logits = tf.convert_to_tensor(logits)
        param_dtype = assert_same_float_dtype(
            [(self._logits, 'Categorical.logits')])

        if dtype is None:
            dtype = tf.int32
        assert_same_float_and_int_dtype([], dtype)

        static_logits_shape = self._logits.get_shape()
        shape_err_msg = "logits should have rank >= 1."
        if static_logits_shape and (static_logits_shape.ndims < 1):
            raise ValueError(shape_err_msg)
        elif static_logits_shape and (
                static_logits_shape[-1].value is not None):
            self._n_categories = static_logits_shape[-1].value
        else:
            _assert_shape_op = tf.assert_rank_at_least(
                self._logits, 1, message=shape_err_msg)
            with tf.control_dependencies([_assert_shape_op]):
                self._logits = tf.identity(self._logits)
            self._n_categories = tf.shape(self._logits)[-1]

        super(Categorical, self).__init__(
            dtype=dtype,
            param_dtype=param_dtype,
            is_continuous=False,
            is_reparameterized=False,
            group_event_ndims=group_event_ndims)
Example #10
    def __init__(self,
                 logits,
                 n_experiments,
                 dtype=None,
                 group_ndims=0,
                 **kwargs):
        self._logits = tf.convert_to_tensor(logits)
        param_dtype = assert_same_float_dtype([(self._logits,
                                                'Multinomial.logits')])

        if dtype is None:
            dtype = tf.int32
        assert_same_float_and_int_dtype([], dtype)

        self._logits, self._n_categories = assert_rank_at_least_one(
            self._logits, 'Multinomial.logits')

        self._n_experiments = assert_positive_int32_integer(
            n_experiments, 'Multinomial.n_experiments')

        super(Multinomial, self).__init__(dtype=dtype,
                                          param_dtype=param_dtype,
                                          is_continuous=False,
                                          is_reparameterized=False,
                                          group_ndims=group_ndims,
                                          **kwargs)
Example #11
    def __init__(self,
                 temperature,
                 logits,
                 group_ndims=0,
                 is_reparameterized=True,
                 check_numerics=False,
                 **kwargs):
        self._logits = tf.convert_to_tensor(logits)
        self._temperature = tf.convert_to_tensor(temperature)
        param_dtype = assert_same_float_dtype([
            (self._logits, 'BinConcrete.logits'),
            (self._temperature, 'BinConcrete.temperature')
        ])

        self._temperature = assert_scalar(self._temperature,
                                          'BinConcrete.temperature')

        self._check_numerics = check_numerics
        super(BinConcrete,
              self).__init__(dtype=param_dtype,
                             param_dtype=param_dtype,
                             is_continuous=True,
                             is_reparameterized=is_reparameterized,
                             group_ndims=group_ndims,
                             **kwargs)
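Example #11 only validates the parameters; the reparameterized sample itself is not shown. A binary Concrete (Gumbel-Sigmoid) variable is conventionally drawn by pushing Logistic noise through a temperature-scaled sigmoid; a sketch under that assumption:

import tensorflow as tf

def sample_bin_concrete(logits, temperature):
    # Standard Logistic noise: log(u) - log(1 - u) for u ~ Uniform(0, 1).
    u = tf.random_uniform(tf.shape(logits), minval=1e-8, maxval=1.0)
    logistic = tf.log(u) - tf.log(1.0 - u)
    # Lower temperatures push samples toward {0, 1}.
    return tf.sigmoid((logits + logistic) / temperature)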
Example #12
    def __init__(self,
                 mean=0.,
                 logstd=0.,
                 group_event_ndims=0,
                 is_reparameterized=True,
                 check_numerics=False):
        self._mean = tf.convert_to_tensor(mean)
        self._logstd = tf.convert_to_tensor(logstd)
        dtype = assert_same_float_dtype([(self._mean, 'Normal.mean'),
                                         (self._logstd, 'Normal.logstd')])

        try:
            tf.broadcast_static_shape(self._mean.get_shape(),
                                      self._logstd.get_shape())
        except ValueError:
            raise ValueError(
                "mean and logstd should be broadcastable to match each "
                "other. ({} vs. {})".format(self._mean.get_shape(),
                                            self._logstd.get_shape()))
        self._check_numerics = check_numerics
        super(Normal, self).__init__(dtype=dtype,
                                     param_dtype=dtype,
                                     is_continuous=True,
                                     is_reparameterized=is_reparameterized,
                                     group_event_ndims=group_event_ndims)
Example #13
    def __init__(self,
                 logits,
                 n_experiments,
                 dtype=None,
                 group_event_ndims=0):
        self._logits = tf.convert_to_tensor(logits)
        param_dtype = assert_same_float_dtype(
            [(self._logits, 'Multinomial.logits')])

        if dtype is None:
            dtype = tf.int32
        assert_same_float_and_int_dtype([], dtype)

        static_logits_shape = self._logits.get_shape()
        shape_err_msg = "logits should have rank >= 1."
        if static_logits_shape and (static_logits_shape.ndims < 1):
            raise ValueError(shape_err_msg)
        elif static_logits_shape and (
                static_logits_shape[-1].value is not None):
            self._n_categories = static_logits_shape[-1].value
        else:
            _assert_shape_op = tf.assert_rank_at_least(
                self._logits, 1, message=shape_err_msg)
            with tf.control_dependencies([_assert_shape_op]):
                self._logits = tf.identity(self._logits)
            self._n_categories = tf.shape(self._logits)[-1]

        sign_err_msg = "n_experiments must be positive"
        if isinstance(n_experiments, int):
            if n_experiments <= 0:
                raise ValueError(sign_err_msg)
            self._n_experiments = n_experiments
        else:
            try:
                n_experiments = tf.convert_to_tensor(n_experiments, tf.int32)
            except ValueError:
                raise TypeError('n_experiments must be int32')
            _assert_rank_op = tf.assert_rank(
                n_experiments, 0,
                message="n_experiments should be a scalar (0-D Tensor).")
            _assert_positive_op = tf.assert_greater(
                n_experiments, 0, message=sign_err_msg)
            with tf.control_dependencies([_assert_rank_op,
                                          _assert_positive_op]):
                self._n_experiments = tf.identity(n_experiments)

        super(Multinomial, self).__init__(
            dtype=dtype,
            param_dtype=param_dtype,
            is_continuous=False,
            is_reparameterized=False,
            group_event_ndims=group_event_ndims)
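Examples #13 and #17 validate n_experiments inline, while Example #10 delegates to assert_positive_int32_integer. A hypothetical version of that helper, matching the inline logic: accept a Python int eagerly, otherwise fall back to graph-time assertions:

import tensorflow as tf

def assert_positive_int32_integer(value, name):
    # Eager path: plain Python ints are checked immediately.
    if isinstance(value, int):
        if value <= 0:
            raise ValueError("{} must be positive".format(name))
        return value
    # Graph path: convert and attach run-time assertions.
    try:
        value = tf.convert_to_tensor(value, tf.int32)
    except ValueError:
        raise TypeError("{} must be int32".format(name))
    checks = [
        tf.assert_rank(
            value, 0,
            message="{} should be a scalar (0-D Tensor).".format(name)),
        tf.assert_greater(
            value, 0, message="{} must be positive".format(name)),
    ]
    with tf.control_dependencies(checks):
        return tf.identity(value)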
Example #14
    def __init__(self, logits, dtype=None, group_event_ndims=0):
        self._logits = tf.convert_to_tensor(logits)
        param_dtype = assert_same_float_dtype([(self._logits,
                                                'Bernoulli.logits')])

        if dtype is None:
            dtype = tf.int32
        assert_same_float_and_int_dtype([], dtype)

        super(Bernoulli, self).__init__(dtype=dtype,
                                        param_dtype=param_dtype,
                                        is_continuous=False,
                                        is_reparameterized=False,
                                        group_event_ndims=group_event_ndims)
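Example #14 stores only the logits; the log-probability of a logits-parameterized Bernoulli is typically computed with the numerically stable sigmoid cross-entropy, sketched here as an assumption about code not shown above:

import tensorflow as tf

def bernoulli_log_prob(logits, given):
    given = tf.cast(given, logits.dtype)
    # Negated sigmoid cross-entropy is exactly the Bernoulli
    # log-likelihood, computed without forming sigmoid(logits) explicitly.
    return -tf.nn.sigmoid_cross_entropy_with_logits(labels=given,
                                                    logits=logits)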
Example #15
    def __init__(self,
                 mean,
                 cov_tril,
                 group_ndims=0,
                 is_reparameterized=True,
                 use_path_derivative=False,
                 check_numerics=False,
                 **kwargs):
        self._check_numerics = check_numerics
        self._mean = tf.convert_to_tensor(mean)
        self._mean = assert_rank_at_least_one(
            self._mean, 'MultivariateNormalCholesky.mean')
        self._n_dim = get_shape_at(self._mean, -1)
        self._cov_tril = tf.convert_to_tensor(cov_tril)
        self._cov_tril = assert_rank_at_least(
            self._cov_tril, 2, 'MultivariateNormalCholesky.cov_tril')

        # Static shape check
        expected_shape = self._mean.get_shape().concatenate(
            [self._n_dim if isinstance(self._n_dim, int) else None])
        self._cov_tril.get_shape().assert_is_compatible_with(expected_shape)
        # Dynamic shape check
        expected_shape = tf.concat([tf.shape(self._mean), [self._n_dim]],
                                   axis=0)
        actual_shape = tf.shape(self._cov_tril)
        msg = [
            'MultivariateNormalCholesky.cov_tril should have compatible '
            'shape with mean. Expected', expected_shape, ' got ', actual_shape
        ]
        assert_ops = [tf.assert_equal(expected_shape, actual_shape, msg)]
        with tf.control_dependencies(assert_ops):
            self._cov_tril = tf.identity(self._cov_tril)

        dtype = assert_same_float_dtype([
            (self._mean, 'MultivariateNormalCholesky.mean'),
            (self._cov_tril, 'MultivariateNormalCholesky.cov_tril')
        ])
        super(MultivariateNormalCholesky,
              self).__init__(dtype=dtype,
                             param_dtype=dtype,
                             is_continuous=True,
                             is_reparameterized=is_reparameterized,
                             use_path_derivative=use_path_derivative,
                             group_ndims=group_ndims,
                             **kwargs)
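Example #15 again pairs a static shape check with a dynamic one. For context, a Cholesky-parameterized Gaussian is conventionally reparameterized as x = mean + L·eps; a sketch under that assumption (the sampling code is not part of the constructor above):

import tensorflow as tf

def sample_mvn_cholesky(mean, cov_tril):
    # eps ~ N(0, I) with the same batch shape as mean.
    eps = tf.random_normal(tf.shape(mean), dtype=mean.dtype)
    # x = mean + L @ eps gives x ~ N(mean, L @ L^T).
    return mean + tf.squeeze(
        tf.matmul(cov_tril, tf.expand_dims(eps, -1)), axis=-1)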
Example #16
    def __init__(self, logits, dtype=None, group_event_ndims=0):
        self._logits = tf.convert_to_tensor(logits)
        param_dtype = assert_same_float_dtype([(self._logits,
                                                'Categorical.logits')])

        if dtype is None:
            dtype = tf.int32
        assert_same_float_and_int_dtype([], dtype)

        self._logits, self._n_categories = assert_rank_at_least_one(
            self._logits, 'Categorical.logits')

        super(Categorical, self).__init__(dtype=dtype,
                                          param_dtype=param_dtype,
                                          is_continuous=False,
                                          is_reparameterized=False,
                                          group_event_ndims=group_event_ndims)
Example #17
    def __init__(self,
                 logits,
                 n_experiments,
                 dtype=None,
                 group_ndims=0,
                 check_numerics=False,
                 **kwargs):
        self._logits = tf.convert_to_tensor(logits)
        param_dtype = assert_same_float_dtype([(self._logits,
                                                'Binomial.logits')])

        if dtype is None:
            dtype = tf.int32
        assert_same_float_and_int_dtype([], dtype)

        sign_err_msg = "n_experiments must be positive"
        if isinstance(n_experiments, int):
            if n_experiments <= 0:
                raise ValueError(sign_err_msg)
            self._n_experiments = n_experiments
        else:
            try:
                n_experiments = tf.convert_to_tensor(n_experiments, tf.int32)
            except ValueError:
                raise TypeError('n_experiments must be int32')
            _assert_rank_op = tf.assert_rank(
                n_experiments,
                0,
                message="n_experiments should be a scalar (0-D Tensor).")
            _assert_positive_op = tf.assert_greater(n_experiments,
                                                    0,
                                                    message=sign_err_msg)
            with tf.control_dependencies(
                [_assert_rank_op, _assert_positive_op]):
                self._n_experiments = tf.identity(n_experiments)

        self._check_numerics = check_numerics
        super(Binomial, self).__init__(dtype=dtype,
                                       param_dtype=param_dtype,
                                       is_continuous=False,
                                       is_reparameterized=False,
                                       group_ndims=group_ndims,
                                       **kwargs)
Example #18
    def __init__(self,
                 rate,
                 dtype=None,
                 group_event_ndims=0,
                 check_numerics=False):
        self._rate = tf.convert_to_tensor(rate)
        param_dtype = assert_same_float_dtype([(self._rate, 'Poisson.rate')])

        if dtype is None:
            dtype = tf.int32
        assert_same_float_and_int_dtype([], dtype)

        self._check_numerics = check_numerics

        super(Poisson, self).__init__(dtype=dtype,
                                      param_dtype=param_dtype,
                                      is_continuous=False,
                                      is_reparameterized=False,
                                      group_event_ndims=group_event_ndims)
Example #19
    def __init__(self, logits, dtype=None, group_ndims=0, **kwargs):
        self._logits = tf.convert_to_tensor(logits)
        param_dtype = assert_same_float_dtype([(self._logits,
                                                'OnehotCategorical.logits')])

        if dtype is None:
            dtype = tf.int32
        assert_same_float_and_int_dtype([], dtype)

        self._logits = assert_rank_at_least_one(self._logits,
                                                'OnehotCategorical.logits')
        self._n_categories = get_shape_at(self._logits, -1)

        super(OnehotCategorical, self).__init__(dtype=dtype,
                                                param_dtype=param_dtype,
                                                is_continuous=False,
                                                is_reparameterized=False,
                                                group_ndims=group_ndims,
                                                **kwargs)
Example #20
    def __init__(self,
                 mean,
                 u=None,
                 v=None,
                 u_c=None,
                 v_c=None,
                 u_c_logdet=None,
                 v_c_logdet=None,
                 group_event_ndims=0,
                 is_reparameterized=True,
                 check_numerics=False):

        mean = tf.convert_to_tensor(mean)
        _assert_rank_op = tf.assert_greater_equal(
            tf.rank(mean), 2, message="mean should be at least a 2-D tensor.")
        with tf.control_dependencies([_assert_rank_op]):
            self._mean = mean

        def _eig_decomp(mat):
            mat_t = transpose_last2dims(mat)
            e, v = tf.self_adjoint_eig((mat + mat_t) / 2 +
                                       tf.eye(tf.shape(mat)[-1]) * 1e-8)
            e = tf.maximum(e, 1e-10)**0.5
            return tf.matmul(v,
                             tf.matrix_diag(e)), tf.reduce_sum(tf.log(e), -1)

        if u is not None and v is not None:
            # assert_same_rank([(self._mean, 'MatrixVariateNormal.mean'),
            #                   (u, 'MatrixVariateNormal.u'),
            #                   (v, 'MatrixVariateNormal.v')])
            u = tf.convert_to_tensor(u)
            _assert_shape_op_1 = tf.assert_equal(
                tf.shape(mean)[-2],
                tf.shape(u)[-1],
                message='the second last dimension of mean should match '
                        'the last dimension of the U matrix')
            _assert_shape_op_2 = tf.assert_equal(
                tf.shape(u)[-1],
                tf.shape(u)[-2],
                message='the U matrix should be square in its last two '
                        'dimensions')
            with tf.control_dependencies([
                    _assert_shape_op_1, _assert_shape_op_2,
                    tf.check_numerics(u, 'U matrix')
            ]):
                self._u = u
            v = tf.convert_to_tensor(v)
            _assert_shape_op_1 = tf.assert_equal(
                tf.shape(mean)[-1],
                tf.shape(v)[-1],
                message='the last dimension of mean should match the last '
                        'dimension of the V matrix')
            _assert_shape_op_2 = tf.assert_equal(
                tf.shape(v)[-1],
                tf.shape(v)[-2],
                message='the V matrix should be square in its last two '
                        'dimensions')
            with tf.control_dependencies([
                    _assert_shape_op_1, _assert_shape_op_2,
                    tf.check_numerics(v, 'V matrix')
            ]):
                self._v = v
            dtype = assert_same_float_dtype([
                (self._mean, 'MatrixVariateNormal.mean'),
                (self._u, 'MatrixVariateNormal.u'),
                (self._v, 'MatrixVariateNormal.v')
            ])

            self._u_c, self._u_c_log_determinant = _eig_decomp(self._u)
            self._v_c, self._v_c_log_determinant = _eig_decomp(self._v)

        elif u_c is not None and v_c is not None:
            # assert_same_rank([(self._mean, 'MatrixVariateNormal.mean'),
            #                   (u_c, 'MatrixVariateNormal.u_c'),
            #                   (v_c, 'MatrixVariateNormal.v_c')])
            dtype = assert_same_float_dtype([(self._mean,
                                              'MatrixVariateNormal.mean'),
                                             (u_c, 'MatrixVariateNormal.u_c'),
                                             (v_c, 'MatrixVariateNormal.v_c')])
            self._u_c = u_c
            self._v_c = v_c
            self._u = tf.matmul(self._u_c, transpose_last2dims(self._u_c))
            self._v = tf.matmul(self._v_c, transpose_last2dims(self._v_c))
            if u_c_logdet is not None:
                self._u_c_log_determinant = u_c_logdet
            else:
                _, self._u_c_log_determinant = _eig_decomp(self._u)
            if v_c_logdet is not None:
                self._v_c_log_determinant = v_c_logdet
            else:
                _, self._v_c_log_determinant = _eig_decomp(self._v)
        else:
            raise ValueError("Either the pair (u, v) or the pair (u_c, v_c) "
                             "should be passed.")

        super(DMatrixVariateNormal,
              self).__init__(dtype=dtype,
                             param_dtype=dtype,
                             is_continuous=True,
                             is_reparameterized=is_reparameterized,
                             group_ndims=group_event_ndims)
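The _eig_decomp helper in Example #20 symmetrizes its input, adds 1e-8 jitter, floors the eigenvalues, and returns a factor C with C·Cᵀ ≈ mat together with log|det C|. A quick numerical check of that contract (TensorFlow 1.x, NumPy only for the test matrix):

import numpy as np
import tensorflow as tf

mat = tf.constant(np.array([[2.0, 0.5],
                            [0.5, 1.0]], dtype=np.float32))
e, v = tf.self_adjoint_eig(mat)
sqrt_e = tf.sqrt(tf.maximum(e, 1e-10))
c = tf.matmul(v, tf.matrix_diag(sqrt_e))
# c @ c^T reconstructs mat; sum(log(sqrt_e)) is log|det c|.
recon = tf.matmul(c, c, transpose_b=True)
logdet = tf.reduce_sum(tf.log(sqrt_e))
with tf.Session() as sess:
    print(sess.run([recon, logdet]))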
Example #21
    def __init__(self,
                 mean,
                 u_b=None,
                 v_b=None,
                 r=None,
                 group_event_ndims=0,
                 is_reparameterized=True,
                 check_numerics=False):

        mean = tf.convert_to_tensor(mean)
        _assert_rank_op = tf.assert_greater_equal(
            tf.rank(mean), 2, message="mean should be at least a 2-D tensor.")
        with tf.control_dependencies([_assert_rank_op]):
            self._mean = mean

        # Rank checks across mean, u_b, v_b, and r are currently disabled.
        u_b = tf.convert_to_tensor(u_b)
        self._u_b = u_b

        # Shape checks between mean and u_b are currently disabled.

        v_b = tf.convert_to_tensor(v_b)
        self._v_b = v_b

        # Shape checks between mean and v_b are currently disabled.

        r = tf.convert_to_tensor(r)
        self._r = r
        # Shape checks between mean and r are currently disabled.

        dtype = assert_same_float_dtype([
            (self._mean, 'EigenMultivariateNormal.mean'),
            (self._u_b, 'EigenMultivariateNormal.u_b'),
            (self._v_b, 'EigenMultivariateNormal.v_b'),
            (self._r, 'EigenMultivariateNormal.r')
        ])

        # R should have been damped beforehand; its square root gives the
        # standard deviation used for sampling.
        self.log_std = 0.5 * tf.log(self._r)
        self.std = tf.exp(self.log_std)

        super(EigenMultivariateNormal,
              self).__init__(dtype=dtype,
                             param_dtype=dtype,
                             is_continuous=True,
                             is_reparameterized=is_reparameterized,
                             group_ndims=group_event_ndims)
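The final two lines of Example #21 keep the scale in log-space: log_std = 0.5·log(r) means std = exp(log_std) = sqrt(r). A one-line equivalence check (hypothetical tensor values):

import tensorflow as tf

r = tf.constant([1.0, 4.0, 9.0])
std = tf.exp(0.5 * tf.log(r))  # identical to tf.sqrt(r) elementwise
with tf.Session() as sess:
    print(sess.run([std, tf.sqrt(r)]))  # [1., 2., 3.] twice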