def __init__(self,
             mean=0.,
             logstd=None,
             std=None,
             group_event_ndims=0,
             is_reparameterized=True,
             check_numerics=False):
    """
    Set up a Normal distribution from `mean` plus exactly one of `std` or
    `logstd`; whichever is missing is derived from the one that was given.

    :param mean: Mean; converted with `tf.convert_to_tensor`.
    :param logstd: Log standard deviation. Pass this or `std`, not both.
    :param std: Standard deviation. Pass this or `logstd`, not both.
    :param group_event_ndims: Forwarded to the base-class constructor.
    :param is_reparameterized: Forwarded to the base-class constructor.
    :param check_numerics: If True, the derived tensor (`log(std)` or
        `exp(logstd)`) is given a `tf.check_numerics` control dependency.
        Also stored as `self._check_numerics`.
    :raises ValueError: If both or neither of `std` and `logstd` are passed,
        or if the static shapes of `mean` and the std are not broadcastable.
    """
    self._mean = tf.convert_to_tensor(mean)
    warnings.warn("Normal: The order of arguments logstd/std will change "
                  "to std/logstd in the coming version.")

    has_std = std is not None
    has_logstd = logstd is not None
    # Exactly one of the two must be supplied.
    if has_std == has_logstd:
        raise ValueError("Either std or logstd should be passed but not "
                         "both of them.")

    if has_std:
        # std supplied directly: keep it and derive logstd = log(std).
        self._std = tf.convert_to_tensor(std)
        dtype = assert_same_float_dtype([(self._mean, 'Normal.mean'),
                                         (self._std, 'Normal.std')])
        derived = tf.log(self._std)
        if check_numerics:
            numerics_op = tf.check_numerics(derived, "log(std)")
            with tf.control_dependencies([numerics_op]):
                derived = tf.identity(derived)
        self._logstd = derived
    else:
        # logstd supplied: keep it and derive std = exp(logstd).
        self._logstd = tf.convert_to_tensor(logstd)
        dtype = assert_same_float_dtype([(self._mean, 'Normal.mean'),
                                         (self._logstd, 'Normal.logstd')])
        derived = tf.exp(self._logstd)
        if check_numerics:
            numerics_op = tf.check_numerics(derived, "exp(logstd)")
            with tf.control_dependencies([numerics_op]):
                derived = tf.identity(derived)
        self._std = derived

    mean_shape = self._mean.get_shape()
    std_shape = self._std.get_shape()
    try:
        tf.broadcast_static_shape(mean_shape, std_shape)
    except ValueError:
        raise ValueError(
            "mean and std/logstd should be broadcastable to match each "
            "other. ({} vs. {})".format(mean_shape, std_shape))

    self._check_numerics = check_numerics
    super(Normal, self).__init__(
        dtype=dtype,
        param_dtype=dtype,
        is_continuous=True,
        is_reparameterized=is_reparameterized,
        group_event_ndims=group_event_ndims)
def __init__(self,
             temperature,
             logits,
             group_ndims=0,
             is_reparameterized=True,
             use_path_derivative=False,
             check_numerics=False,
             **kwargs):
    """
    Set up a Concrete distribution from a scalar `temperature` and `logits`.

    :param temperature: Converted to a tensor and checked with
        `assert_scalar`.
    :param logits: Converted to a tensor and checked with
        `assert_rank_at_least_one`. That helper's second return value is
        stored as `self._n_categories`.
    :param group_ndims: Forwarded to the base-class constructor.
    :param is_reparameterized: Forwarded to the base-class constructor.
    :param use_path_derivative: Forwarded to the base-class constructor.
    :param check_numerics: Stored as `self._check_numerics`.
    :param kwargs: Forwarded to the base-class constructor.
    """
    self._logits = tf.convert_to_tensor(logits)
    self._temperature = tf.convert_to_tensor(temperature)

    # Both parameters must share one float dtype; it is used for samples too.
    float_dtype = assert_same_float_dtype([
        (self._logits, 'Concrete.logits'),
        (self._temperature, 'Concrete.temperature'),
    ])

    checked_logits, n_categories = assert_rank_at_least_one(
        self._logits, 'Concrete.logits')
    self._logits = checked_logits
    self._n_categories = n_categories
    self._temperature = assert_scalar(self._temperature,
                                      'Concrete.temperature')
    self._check_numerics = check_numerics

    super(Concrete, self).__init__(
        dtype=float_dtype,
        param_dtype=float_dtype,
        is_continuous=True,
        is_reparameterized=is_reparameterized,
        use_path_derivative=use_path_derivative,
        group_ndims=group_ndims,
        **kwargs)
def __init__(self,
             mean,
             u_tril,
             v_tril,
             group_ndims=0,
             is_reparameterized=True,
             use_path_derivative=False,
             check_numerics=False,
             **kwargs):
    """
    Set up a matrix-variate Normal from a mean matrix and the factors
    `u_tril` (row side) and `v_tril` (column side).

    Shapes are validated twice:

    * statically, against whatever is known of the TensorShapes at graph
      construction time;
    * dynamically, with `tf.assert_equal` ops attached as control
      dependencies of the stored `u_tril`/`v_tril` tensors.

    :param mean: Tensor of rank >= 2; its last two axes give
        (n_row, n_col).
    :param u_tril: Expected shape `mean.shape[:-1] + [n_row]`.
    :param v_tril: Expected shape `mean.shape[:-2] + [n_col, n_col]`.
    :param group_ndims: Forwarded to the base-class constructor.
    :param is_reparameterized: Forwarded to the base-class constructor.
    :param use_path_derivative: Forwarded to the base-class constructor.
    :param check_numerics: Stored as `self._check_numerics`.
    :param kwargs: Forwarded to the base-class constructor.
    """
    self._check_numerics = check_numerics

    self._mean = assert_rank_at_least(
        tf.convert_to_tensor(mean), 2, 'MatrixVariateNormalCholesky.mean')
    self._n_row = get_shape_at(self._mean, -2)
    self._n_col = get_shape_at(self._mean, -1)
    self._u_tril = assert_rank_at_least(
        tf.convert_to_tensor(u_tril), 2,
        'MatrixVariateNormalCholesky.u_tril')
    self._v_tril = assert_rank_at_least(
        tf.convert_to_tensor(v_tril), 2,
        'MatrixVariateNormalCholesky.v_tril')

    # Static check. A non-int size (i.e. not known statically) becomes
    # None, an unknown dimension in the expected TensorShape.
    row_dim = self._n_row if isinstance(self._n_row, int) else None
    col_dim = self._n_col if isinstance(self._n_col, int) else None
    mean_static = self._mean.get_shape()
    self._u_tril.get_shape().assert_is_compatible_with(
        mean_static[:-1].concatenate([row_dim]))
    self._v_tril.get_shape().assert_is_compatible_with(
        mean_static[:-2].concatenate([col_dim, col_dim]))

    # Dynamic check, evaluated when the graph runs.
    u_want = tf.concat([tf.shape(self._mean)[:-1], [self._n_row]], axis=0)
    u_have = tf.shape(self._u_tril)
    u_guard = tf.assert_equal(
        u_want, u_have,
        ['MatrixVariateNormalCholesky.u_tril should have compatible '
         'shape with mean. Expected', u_want, ' got ', u_have])
    v_want = tf.concat(
        [tf.shape(self._mean)[:-2], [self._n_col, self._n_col]], axis=0)
    v_have = tf.shape(self._v_tril)
    v_guard = tf.assert_equal(
        v_want, v_have,
        ['MatrixVariateNormalCholesky.v_tril should have compatible '
         'shape with mean. Expected', v_want, ' got ', v_have])
    with tf.control_dependencies([u_guard, v_guard]):
        self._u_tril = tf.identity(self._u_tril)
        self._v_tril = tf.identity(self._v_tril)

    dtype = assert_same_float_dtype([
        (self._mean, 'MatrixVariateNormalCholesky.mean'),
        (self._u_tril, 'MatrixVariateNormalCholesky.u_tril'),
        (self._v_tril, 'MatrixVariateNormalCholesky.v_tril'),
    ])
    super(MatrixVariateNormalCholesky, self).__init__(
        dtype=dtype,
        param_dtype=dtype,
        is_continuous=True,
        is_reparameterized=is_reparameterized,
        use_path_derivative=use_path_derivative,
        group_ndims=group_ndims,
        **kwargs)
def __init__(self,
             minval=0.,
             maxval=1.,
             group_event_ndims=0,
             is_reparameterized=True,
             check_numerics=False):
    """
    Set up a Uniform distribution on [`minval`, `maxval`].

    :param minval: Lower bound; converted with `tf.convert_to_tensor`.
    :param maxval: Upper bound; converted with `tf.convert_to_tensor`.
    :param group_event_ndims: Forwarded to the base-class constructor.
    :param is_reparameterized: Forwarded to the base-class constructor.
    :param check_numerics: Stored as `self._check_numerics`.
    :raises ValueError: If the static shapes of `minval` and `maxval` are not
        broadcastable.
    """
    self._minval = tf.convert_to_tensor(minval)
    self._maxval = tf.convert_to_tensor(maxval)
    dtype = assert_same_float_dtype([
        (self._minval, 'Uniform.minval'),
        (self._maxval, 'Uniform.maxval'),
    ])

    low_shape = self._minval.get_shape()
    high_shape = self._maxval.get_shape()
    try:
        tf.broadcast_static_shape(low_shape, high_shape)
    except ValueError:
        raise ValueError(
            "minval and maxval should be broadcastable to match each "
            "other. ({} vs. {})".format(low_shape, high_shape))

    self._check_numerics = check_numerics
    super(Uniform, self).__init__(
        dtype=dtype,
        param_dtype=dtype,
        is_continuous=True,
        is_reparameterized=is_reparameterized,
        group_event_ndims=group_event_ndims)
def __init__(self,
             loc,
             scale,
             group_event_ndims=0,
             is_reparameterized=True,
             check_numerics=False):
    """
    Set up a Laplace distribution from `loc` and `scale`.

    :param loc: Location; converted with `tf.convert_to_tensor`.
    :param scale: Scale; converted with `tf.convert_to_tensor`.
    :param group_event_ndims: Forwarded to the base-class constructor.
    :param is_reparameterized: Forwarded to the base-class constructor.
    :param check_numerics: Stored as `self._check_numerics`.
    :raises ValueError: If the static shapes of `loc` and `scale` are not
        broadcastable.
    """
    self._loc = tf.convert_to_tensor(loc)
    self._scale = tf.convert_to_tensor(scale)
    dtype = assert_same_float_dtype([
        (self._loc, 'Laplace.loc'),
        (self._scale, 'Laplace.scale'),
    ])

    loc_shape = self._loc.get_shape()
    scale_shape = self._scale.get_shape()
    try:
        tf.broadcast_static_shape(loc_shape, scale_shape)
    except ValueError:
        raise ValueError(
            "loc and scale should be broadcastable to match each "
            "other. ({} vs. {})".format(loc_shape, scale_shape))

    self._check_numerics = check_numerics
    super(Laplace, self).__init__(
        dtype=dtype,
        param_dtype=dtype,
        is_continuous=True,
        is_reparameterized=is_reparameterized,
        group_event_ndims=group_event_ndims)
def __init__(self,
             logits,
             normalize_logits=True,
             dtype=tf.int32,
             group_ndims=0,
             **kwargs):
    """
    Set up an UnnormalizedMultinomial distribution from `logits`.

    :param logits: Float tensor of rank >= 1. The size of its last axis is
        stored as `self._n_categories`.
    :param normalize_logits: Stored unchanged as the public attribute
        `self.normalize_logits`.
    :param dtype: Sample dtype; validated with
        `assert_dtype_is_int_or_float`.
    :param group_ndims: Forwarded to the base-class constructor.
    :param kwargs: Forwarded to the base-class constructor.
    """
    logits_label = 'UnnormalizedMultinomial.logits'

    self._logits = tf.convert_to_tensor(logits)
    param_dtype = assert_same_float_dtype([(self._logits, logits_label)])
    assert_dtype_is_int_or_float(dtype)

    self._logits = assert_rank_at_least_one(self._logits, logits_label)
    self._n_categories = get_shape_at(self._logits, -1)
    self.normalize_logits = normalize_logits

    super(UnnormalizedMultinomial, self).__init__(
        dtype=dtype,
        param_dtype=param_dtype,
        is_continuous=False,
        is_reparameterized=False,
        group_ndims=group_ndims,
        **kwargs)
def __init__(self,
             alpha,
             beta,
             dtype=None,
             group_event_ndims=0,
             check_numerics=False):
    """
    Set up a Beta distribution from `alpha` and `beta`.

    :param alpha: Converted with `tf.convert_to_tensor`.
    :param beta: Converted with `tf.convert_to_tensor`.
    :param dtype: Optional float dtype the parameters are required to have.
        If None (default), the dtype is inferred from `alpha`/`beta`.
        Earlier revisions accepted this argument but discarded it.
        NOTE(review): assumes `assert_same_float_dtype` takes the expected
        dtype as its second positional argument — confirm against utils.
    :param group_event_ndims: Forwarded to the base-class constructor.
    :param check_numerics: Stored as `self._check_numerics`.
    :raises ValueError: If the static shapes of `alpha` and `beta` are not
        broadcastable.
    """
    self._alpha = tf.convert_to_tensor(alpha)
    self._beta = tf.convert_to_tensor(beta)
    # Forward the caller's dtype (if any) instead of silently overwriting it.
    dtype = assert_same_float_dtype([(self._alpha, 'Beta.alpha'),
                                     (self._beta, 'Beta.beta')], dtype)
    try:
        tf.broadcast_static_shape(self._alpha.get_shape(),
                                  self._beta.get_shape())
    except ValueError:
        raise ValueError(
            "alpha and beta should be broadcastable to match each "
            "other. ({} vs. {})".format(self._alpha.get_shape(),
                                        self._beta.get_shape()))
    self._check_numerics = check_numerics
    super(Beta, self).__init__(
        dtype=dtype,
        param_dtype=dtype,
        is_continuous=True,
        is_reparameterized=False,
        group_event_ndims=group_event_ndims)
def __init__(self, alpha, group_ndims=0, check_numerics=False, **kwargs):
    """
    Set up a Dirichlet distribution from concentration `alpha`.

    `alpha` must have rank >= 1, and its last axis (`n_categories`) must have
    length >= 2. These are checked in Python when the static shape allows it.
    Otherwise they are enforced by graph-time assertions attached to
    `self._alpha`.

    :param alpha: Float tensor of concentrations.
    :param group_ndims: Forwarded to the base-class constructor.
    :param check_numerics: Stored as `self._check_numerics`.
    :param kwargs: Forwarded to the base-class constructor.
    :raises ValueError: If the static shape already violates either rule.
    """
    self._alpha = tf.convert_to_tensor(alpha)
    dtype = assert_same_float_dtype([(self._alpha, 'Dirichlet.alpha')])

    rank_msg = "alpha should have rank >= 1."
    categories_msg = ("n_categories (length of the last axis "
                      "of alpha) should be at least 2.")
    alpha_shape = self._alpha.get_shape()

    if alpha_shape and alpha_shape.ndims < 1:
        raise ValueError(rank_msg)

    if alpha_shape and alpha_shape[-1].value is not None:
        # Last-axis length is known statically: validate it right away.
        self._n_categories = alpha_shape[-1].value
        if self._n_categories < 2:
            raise ValueError(categories_msg)
    else:
        # Not enough static info: defer both checks to graph run time.
        rank_guard = tf.assert_rank_at_least(self._alpha, 1,
                                             message=rank_msg)
        with tf.control_dependencies([rank_guard]):
            self._alpha = tf.identity(self._alpha)
        self._n_categories = tf.shape(self._alpha)[-1]
        categories_guard = tf.assert_greater_equal(
            self._n_categories, 2, message=categories_msg)
        with tf.control_dependencies([categories_guard]):
            self._alpha = tf.identity(self._alpha)

    self._check_numerics = check_numerics
    super(Dirichlet, self).__init__(
        dtype=dtype,
        param_dtype=dtype,
        is_continuous=True,
        is_reparameterized=False,
        group_ndims=group_ndims,
        **kwargs)
def __init__(self, logits, dtype=None, group_event_ndims=0):
    """
    Set up a Categorical distribution from `logits`.

    :param logits: Float tensor of rank >= 1. The length of its last axis
        becomes `self._n_categories`: a Python int if known statically,
        otherwise a scalar tensor.
    :param dtype: Sample dtype; defaults to `tf.int32`. Validated with
        `assert_same_float_and_int_dtype`.
    :param group_event_ndims: Forwarded to the base-class constructor.
    :raises ValueError: If `logits` is statically known to have rank 0.
    """
    self._logits = tf.convert_to_tensor(logits)
    param_dtype = assert_same_float_dtype(
        [(self._logits, 'Categorical.logits')])
    dtype = tf.int32 if dtype is None else dtype
    assert_same_float_and_int_dtype([], dtype)

    logits_shape = self._logits.get_shape()
    rank_msg = "logits should have rank >= 1."
    if logits_shape and logits_shape.ndims < 1:
        raise ValueError(rank_msg)

    if logits_shape and logits_shape[-1].value is not None:
        self._n_categories = logits_shape[-1].value
    else:
        # Fall back to a runtime rank assertion and a dynamic size.
        rank_guard = tf.assert_rank_at_least(self._logits, 1,
                                             message=rank_msg)
        with tf.control_dependencies([rank_guard]):
            self._logits = tf.identity(self._logits)
        self._n_categories = tf.shape(self._logits)[-1]

    super(Categorical, self).__init__(
        dtype=dtype,
        param_dtype=param_dtype,
        is_continuous=False,
        is_reparameterized=False,
        group_event_ndims=group_event_ndims)
def __init__(self, logits, n_experiments, dtype=None, group_ndims=0,
             **kwargs):
    """
    Set up a Multinomial distribution from `logits` and `n_experiments`.

    :param logits: Float tensor checked by `assert_rank_at_least_one`. That
        helper's second return value is stored as `self._n_categories`.
    :param n_experiments: Validated and stored by
        `assert_positive_int32_integer`.
    :param dtype: Sample dtype; defaults to `tf.int32`.
    :param group_ndims: Forwarded to the base-class constructor.
    :param kwargs: Forwarded to the base-class constructor.
    """
    self._logits = tf.convert_to_tensor(logits)
    param_dtype = assert_same_float_dtype(
        [(self._logits, 'Multinomial.logits')])

    sample_dtype = dtype
    if sample_dtype is None:
        sample_dtype = tf.int32
    assert_same_float_and_int_dtype([], sample_dtype)

    checked = assert_rank_at_least_one(self._logits, 'Multinomial.logits')
    self._logits, self._n_categories = checked
    self._n_experiments = assert_positive_int32_integer(
        n_experiments, 'Multinomial.n_experiments')

    super(Multinomial, self).__init__(
        dtype=sample_dtype,
        param_dtype=param_dtype,
        is_continuous=False,
        is_reparameterized=False,
        group_ndims=group_ndims,
        **kwargs)
def __init__(self,
             temperature,
             logits,
             group_ndims=0,
             is_reparameterized=True,
             check_numerics=False,
             **kwargs):
    """
    Set up a BinConcrete distribution from a scalar `temperature` and
    `logits`.

    :param temperature: Converted to a tensor and checked with
        `assert_scalar`.
    :param logits: Converted with `tf.convert_to_tensor` (no rank check).
    :param group_ndims: Forwarded to the base-class constructor.
    :param is_reparameterized: Forwarded to the base-class constructor.
    :param check_numerics: Stored as `self._check_numerics`.
    :param kwargs: Forwarded to the base-class constructor.
    """
    self._logits = tf.convert_to_tensor(logits)
    self._temperature = tf.convert_to_tensor(temperature)

    float_dtype = assert_same_float_dtype([
        (self._logits, 'BinConcrete.logits'),
        (self._temperature, 'BinConcrete.temperature'),
    ])
    self._temperature = assert_scalar(self._temperature,
                                      'BinConcrete.temperature')
    self._check_numerics = check_numerics

    super(BinConcrete, self).__init__(
        dtype=float_dtype,
        param_dtype=float_dtype,
        is_continuous=True,
        is_reparameterized=is_reparameterized,
        group_ndims=group_ndims,
        **kwargs)
def __init__(self,
             mean=0.,
             logstd=0.,
             group_event_ndims=0,
             is_reparameterized=True,
             check_numerics=False):
    """
    Set up a Normal distribution parameterized by `mean` and `logstd`.

    :param mean: Converted with `tf.convert_to_tensor`.
    :param logstd: Log standard deviation; converted with
        `tf.convert_to_tensor`.
    :param group_event_ndims: Forwarded to the base-class constructor.
    :param is_reparameterized: Forwarded to the base-class constructor.
    :param check_numerics: Stored as `self._check_numerics`.
    :raises ValueError: If the static shapes of `mean` and `logstd` are not
        broadcastable.
    """
    self._mean = tf.convert_to_tensor(mean)
    self._logstd = tf.convert_to_tensor(logstd)
    dtype = assert_same_float_dtype([
        (self._mean, 'Normal.mean'),
        (self._logstd, 'Normal.logstd'),
    ])

    mean_shape = self._mean.get_shape()
    logstd_shape = self._logstd.get_shape()
    try:
        tf.broadcast_static_shape(mean_shape, logstd_shape)
    except ValueError:
        raise ValueError(
            "mean and logstd should be broadcastable to match each "
            "other. ({} vs. {})".format(mean_shape, logstd_shape))

    self._check_numerics = check_numerics
    super(Normal, self).__init__(
        dtype=dtype,
        param_dtype=dtype,
        is_continuous=True,
        is_reparameterized=is_reparameterized,
        group_event_ndims=group_event_ndims)
def __init__(self, logits, n_experiments, dtype=None, group_event_ndims=0):
    """
    Set up a Multinomial distribution from `logits` and `n_experiments`.

    :param logits: Float tensor of rank >= 1. The length of its last axis
        becomes `self._n_categories`: a Python int if known statically,
        otherwise a scalar tensor.
    :param n_experiments: A positive Python int, or anything convertible to
        a scalar int32 tensor. In the tensor case, rank and positivity are
        enforced by graph-time assertions.
    :param dtype: Sample dtype; defaults to `tf.int32`.
    :param group_event_ndims: Forwarded to the base-class constructor.
    :raises ValueError: If `logits` is statically rank 0, or `n_experiments`
        is a non-positive Python int.
    :raises TypeError: If `n_experiments` cannot be converted to int32.
    """
    self._logits = tf.convert_to_tensor(logits)
    param_dtype = assert_same_float_dtype(
        [(self._logits, 'Multinomial.logits')])
    if dtype is None:
        dtype = tf.int32
    assert_same_float_and_int_dtype([], dtype)

    static_logits_shape = self._logits.get_shape()
    shape_err_msg = "logits should have rank >= 1."
    if static_logits_shape and (static_logits_shape.ndims < 1):
        raise ValueError(shape_err_msg)
    elif static_logits_shape and (
            static_logits_shape[-1].value is not None):
        self._n_categories = static_logits_shape[-1].value
    else:
        _assert_shape_op = tf.assert_rank_at_least(
            self._logits, 1, message=shape_err_msg)
        with tf.control_dependencies([_assert_shape_op]):
            self._logits = tf.identity(self._logits)
        self._n_categories = tf.shape(self._logits)[-1]

    sign_err_msg = "n_experiments must be positive"
    if isinstance(n_experiments, int):
        if n_experiments <= 0:
            raise ValueError(sign_err_msg)
        self._n_experiments = n_experiments
    else:
        try:
            n_experiments = tf.convert_to_tensor(n_experiments, tf.int32)
        except (TypeError, ValueError):
            # tf.convert_to_tensor documents both TypeError and ValueError
            # for failed conversions; report either one as the same error.
            raise TypeError('n_experiments must be int32')
        _assert_rank_op = tf.assert_rank(
            n_experiments, 0,
            message="n_experiments should be a scalar (0-D Tensor).")
        _assert_positive_op = tf.assert_greater(
            n_experiments, 0, message=sign_err_msg)
        with tf.control_dependencies([_assert_rank_op,
                                      _assert_positive_op]):
            self._n_experiments = tf.identity(n_experiments)

    super(Multinomial, self).__init__(
        dtype=dtype,
        param_dtype=param_dtype,
        is_continuous=False,
        is_reparameterized=False,
        group_event_ndims=group_event_ndims)
def __init__(self, logits, dtype=None, group_event_ndims=0):
    """
    Set up a Bernoulli distribution from `logits`.

    :param logits: Float tensor; converted with `tf.convert_to_tensor`.
    :param dtype: Sample dtype; defaults to `tf.int32`. Validated with
        `assert_same_float_and_int_dtype`.
    :param group_event_ndims: Forwarded to the base-class constructor.
    """
    self._logits = tf.convert_to_tensor(logits)
    param_dtype = assert_same_float_dtype(
        [(self._logits, 'Bernoulli.logits')])

    sample_dtype = tf.int32 if dtype is None else dtype
    assert_same_float_and_int_dtype([], sample_dtype)

    super(Bernoulli, self).__init__(
        dtype=sample_dtype,
        param_dtype=param_dtype,
        is_continuous=False,
        is_reparameterized=False,
        group_event_ndims=group_event_ndims)
def __init__(self,
             mean,
             cov_tril,
             group_ndims=0,
             is_reparameterized=True,
             use_path_derivative=False,
             check_numerics=False,
             **kwargs):
    """
    Set up a multivariate Normal from `mean` and the covariance factor
    `cov_tril`.

    `cov_tril` is expected to have shape `mean.shape + [n_dim]`, where
    `n_dim` is the last axis of `mean`. This is checked statically where
    possible and also with a graph-time `tf.assert_equal` attached to
    `self._cov_tril`.

    :param mean: Tensor of rank >= 1.
    :param cov_tril: Tensor of rank >= 2.
    :param group_ndims: Forwarded to the base-class constructor.
    :param is_reparameterized: Forwarded to the base-class constructor.
    :param use_path_derivative: Forwarded to the base-class constructor.
    :param check_numerics: Stored as `self._check_numerics`.
    :param kwargs: Forwarded to the base-class constructor.
    """
    self._check_numerics = check_numerics

    self._mean = assert_rank_at_least_one(
        tf.convert_to_tensor(mean), 'MultivariateNormalCholesky.mean')
    self._n_dim = get_shape_at(self._mean, -1)
    self._cov_tril = assert_rank_at_least(
        tf.convert_to_tensor(cov_tril), 2,
        'MultivariateNormalCholesky.cov_tril')

    # Static check; a non-int n_dim is treated as an unknown dimension.
    known_dim = self._n_dim if isinstance(self._n_dim, int) else None
    self._cov_tril.get_shape().assert_is_compatible_with(
        self._mean.get_shape().concatenate([known_dim]))

    # Dynamic check, evaluated when the graph runs.
    want_shape = tf.concat([tf.shape(self._mean), [self._n_dim]], axis=0)
    have_shape = tf.shape(self._cov_tril)
    shape_guard = tf.assert_equal(
        want_shape, have_shape,
        ['MultivariateNormalCholesky.cov_tril should have compatible '
         'shape with mean. Expected', want_shape, ' got ', have_shape])
    with tf.control_dependencies([shape_guard]):
        self._cov_tril = tf.identity(self._cov_tril)

    dtype = assert_same_float_dtype([
        (self._mean, 'MultivariateNormalCholesky.mean'),
        (self._cov_tril, 'MultivariateNormalCholesky.cov_tril'),
    ])
    super(MultivariateNormalCholesky, self).__init__(
        dtype=dtype,
        param_dtype=dtype,
        is_continuous=True,
        is_reparameterized=is_reparameterized,
        use_path_derivative=use_path_derivative,
        group_ndims=group_ndims,
        **kwargs)
def __init__(self, logits, dtype=None, group_event_ndims=0):
    """
    Set up a Categorical distribution from `logits`.

    :param logits: Float tensor checked by `assert_rank_at_least_one`. That
        helper's second return value is stored as `self._n_categories`.
    :param dtype: Sample dtype; defaults to `tf.int32`. Validated with
        `assert_same_float_and_int_dtype`.
    :param group_event_ndims: Forwarded to the base-class constructor.
    """
    self._logits = tf.convert_to_tensor(logits)
    param_dtype = assert_same_float_dtype(
        [(self._logits, 'Categorical.logits')])
    if dtype is None:
        dtype = tf.int32
    assert_same_float_and_int_dtype([], dtype)
    # The rank check and category count come from the shared helper; the
    # previously computed-but-unused static shape local has been removed.
    self._logits, self._n_categories = assert_rank_at_least_one(
        self._logits, 'Categorical.logits')
    super(Categorical, self).__init__(
        dtype=dtype,
        param_dtype=param_dtype,
        is_continuous=False,
        is_reparameterized=False,
        group_event_ndims=group_event_ndims)
def __init__(self, logits, n_experiments, dtype=None, group_ndims=0,
             check_numerics=False, **kwargs):
    """
    Set up a Binomial distribution from `logits` and `n_experiments`.

    :param logits: Float tensor; converted with `tf.convert_to_tensor`.
    :param n_experiments: A positive Python int, or anything convertible to
        a scalar int32 tensor. In the tensor case, rank and positivity are
        enforced by graph-time assertions.
    :param dtype: Sample dtype; defaults to `tf.int32`.
    :param group_ndims: Forwarded to the base-class constructor.
    :param check_numerics: Stored as `self._check_numerics`.
    :param kwargs: Forwarded to the base-class constructor.
    :raises ValueError: If `n_experiments` is a non-positive Python int.
    :raises TypeError: If `n_experiments` cannot be converted to int32.
    """
    self._logits = tf.convert_to_tensor(logits)
    param_dtype = assert_same_float_dtype(
        [(self._logits, 'Binomial.logits')])
    if dtype is None:
        dtype = tf.int32
    assert_same_float_and_int_dtype([], dtype)

    sign_err_msg = "n_experiments must be positive"
    if isinstance(n_experiments, int):
        if n_experiments <= 0:
            raise ValueError(sign_err_msg)
        self._n_experiments = n_experiments
    else:
        try:
            n_experiments = tf.convert_to_tensor(n_experiments, tf.int32)
        except (TypeError, ValueError):
            # tf.convert_to_tensor documents both TypeError and ValueError
            # for failed conversions; report either one as the same error.
            raise TypeError('n_experiments must be int32')
        _assert_rank_op = tf.assert_rank(
            n_experiments, 0,
            message="n_experiments should be a scalar (0-D Tensor).")
        _assert_positive_op = tf.assert_greater(
            n_experiments, 0, message=sign_err_msg)
        with tf.control_dependencies([_assert_rank_op,
                                      _assert_positive_op]):
            self._n_experiments = tf.identity(n_experiments)

    self._check_numerics = check_numerics
    super(Binomial, self).__init__(
        dtype=dtype,
        param_dtype=param_dtype,
        is_continuous=False,
        is_reparameterized=False,
        group_ndims=group_ndims,
        **kwargs)
def __init__(self, rate, dtype=None, group_event_ndims=0,
             check_numerics=False):
    """
    Set up a Poisson distribution from `rate`.

    :param rate: Float tensor; converted with `tf.convert_to_tensor`.
    :param dtype: Sample dtype; defaults to `tf.int32`. Validated with
        `assert_same_float_and_int_dtype`.
    :param group_event_ndims: Forwarded to the base-class constructor.
    :param check_numerics: Stored as `self._check_numerics`.
    """
    self._rate = tf.convert_to_tensor(rate)
    param_dtype = assert_same_float_dtype([(self._rate, 'Poisson.rate')])

    sample_dtype = tf.int32 if dtype is None else dtype
    assert_same_float_and_int_dtype([], sample_dtype)

    self._check_numerics = check_numerics
    super(Poisson, self).__init__(
        dtype=sample_dtype,
        param_dtype=param_dtype,
        is_continuous=False,
        is_reparameterized=False,
        group_event_ndims=group_event_ndims)
def __init__(self, logits, dtype=None, group_ndims=0, **kwargs):
    """
    Set up a OnehotCategorical distribution from `logits`.

    :param logits: Float tensor checked by `assert_rank_at_least_one`. The
        size of its last axis is stored as `self._n_categories`.
    :param dtype: Sample dtype; defaults to `tf.int32`.
    :param group_ndims: Forwarded to the base-class constructor.
    :param kwargs: Forwarded to the base-class constructor.
    """
    logits_label = 'OnehotCategorical.logits'

    self._logits = tf.convert_to_tensor(logits)
    param_dtype = assert_same_float_dtype([(self._logits, logits_label)])

    sample_dtype = tf.int32 if dtype is None else dtype
    assert_same_float_and_int_dtype([], sample_dtype)

    self._logits = assert_rank_at_least_one(self._logits, logits_label)
    self._n_categories = get_shape_at(self._logits, -1)

    super(OnehotCategorical, self).__init__(
        dtype=sample_dtype,
        param_dtype=param_dtype,
        is_continuous=False,
        is_reparameterized=False,
        group_ndims=group_ndims,
        **kwargs)
def __init__(self,
             mean,
             u=None,
             v=None,
             u_c=None,
             v_c=None,
             u_c_logdet=None,
             v_c_logdet=None,
             group_event_ndims=0,
             is_reparameterized=True,
             check_numerics=False):
    """
    Set up a matrix-variate Normal from `mean` and either the matrices
    (`u`, `v`) or their factors (`u_c`, `v_c`).

    With (`u`, `v`), the factors and their log-determinants are computed by
    a symmetrized, jittered eigendecomposition. With (`u_c`, `v_c`),
    `u = u_c u_c^T` and `v = v_c v_c^T`. The log-determinants are then taken
    from `u_c_logdet`/`v_c_logdet` when given, and otherwise from the same
    eigendecomposition.

    :param mean: Tensor; a graph-time assertion requires rank >= 2.
    :param u: Square matrix; its size must equal `mean.shape[-2]`.
    :param v: Square matrix; its size must equal `mean.shape[-1]`.
    :param u_c: Factor of `u` (used only when `u`/`v` are not both given).
    :param v_c: Factor of `v` (used only when `u`/`v` are not both given).
    :param u_c_logdet: Optional precomputed log-determinant for `u_c`.
    :param v_c_logdet: Optional precomputed log-determinant for `v_c`.
    :param group_event_ndims: Forwarded as `group_ndims` to the base class.
    :param is_reparameterized: Forwarded to the base-class constructor.
    :param check_numerics: If True, `u` and `v` get `tf.check_numerics`
        dependencies. Also stored as `self._check_numerics`.
    :raises ValueError: If neither (`u`, `v`) nor (`u_c`, `v_c`) is fully
        specified.
    """
    self._check_numerics = check_numerics

    mean = tf.convert_to_tensor(mean)
    _assert_rank_op = tf.assert_greater_equal(
        tf.rank(mean), 2, message="mean should be at least a 2-D tensor.")
    # A plain assignment creates no op, so the assertion is attached to a
    # new identity tensor; otherwise it would never execute.
    with tf.control_dependencies([_assert_rank_op]):
        self._mean = tf.identity(mean)

    def _eig_decomp(mat):
        """Return (factor, sum of log sqrt-eigenvalues) of symmetrized mat."""
        mat_t = transpose_last2dims(mat)
        # Small diagonal jitter, built in mat's dtype so float64 also works.
        jitter = tf.eye(tf.shape(mat)[-1], dtype=mat.dtype) * 1e-8
        e, vecs = tf.self_adjoint_eig((mat + mat_t) / 2 + jitter)
        e = tf.maximum(e, 1e-10) ** 0.5
        return (tf.matmul(vecs, tf.matrix_diag(e)),
                tf.reduce_sum(tf.log(e), -1))

    if u is not None and v is not None:
        u = tf.convert_to_tensor(u)
        u_checks = [
            tf.assert_equal(
                tf.shape(mean)[-2], tf.shape(u)[-1],
                message='second last dimension of mean should be the same '
                        'as the last dimension of U matrix'),
            tf.assert_equal(
                tf.shape(u)[-1], tf.shape(u)[-2],
                message='second last dimension of U should be the same '
                        'as the last dimension of U matrix'),
        ]
        if check_numerics:
            u_checks.append(tf.check_numerics(u, 'U matrix'))
        with tf.control_dependencies(u_checks):
            self._u = tf.identity(u)

        v = tf.convert_to_tensor(v)
        v_checks = [
            tf.assert_equal(
                tf.shape(mean)[-1], tf.shape(v)[-1],
                message='last dimension of mean should be the same '
                        'as last dimension of V matrix'),
            tf.assert_equal(
                tf.shape(v)[-1], tf.shape(v)[-2],
                message='second last dimension of V should be the same '
                        'as last dimension of V matrix'),
        ]
        if check_numerics:
            v_checks.append(tf.check_numerics(v, 'V matrix'))
        with tf.control_dependencies(v_checks):
            self._v = tf.identity(v)

        dtype = assert_same_float_dtype([
            (self._mean, 'MatrixVariateNormal.mean'),
            (self._u, 'MatrixVariateNormal.u'),
            (self._v, 'MatrixVariateNormal.v'),
        ])
        self._u_c, self._u_c_log_determinant = _eig_decomp(self._u)
        self._v_c, self._v_c_log_determinant = _eig_decomp(self._v)
    elif u_c is not None and v_c is not None:
        # TODO(review): rank/shape checks between mean, u_c and v_c.
        self._u_c = tf.convert_to_tensor(u_c)
        self._v_c = tf.convert_to_tensor(v_c)
        dtype = assert_same_float_dtype([
            (self._mean, 'MatrixVariateNormal.mean'),
            (self._u_c, 'MatrixVariateNormal.u_c'),
            (self._v_c, 'MatrixVariateNormal.v_c'),
        ])
        self._u = tf.matmul(self._u_c, transpose_last2dims(self._u_c))
        self._v = tf.matmul(self._v_c, transpose_last2dims(self._v_c))
        if u_c_logdet is not None:
            self._u_c_log_determinant = u_c_logdet
        else:
            _, self._u_c_log_determinant = _eig_decomp(self._u)
        if v_c_logdet is not None:
            self._v_c_log_determinant = v_c_logdet
        else:
            _, self._v_c_log_determinant = _eig_decomp(self._v)
    else:
        # Previously this fell through to an unbound `dtype`.
        raise ValueError("Either both u and v, or both u_c and v_c, "
                         "should be passed.")

    super(DMatrixVariateNormal, self).__init__(
        dtype=dtype,
        param_dtype=dtype,
        is_continuous=True,
        is_reparameterized=is_reparameterized,
        group_ndims=group_event_ndims)
def __init__(self,
             mean,
             u_b=None,
             v_b=None,
             r=None,
             group_event_ndims=0,
             is_reparameterized=True,
             check_numerics=False):
    """
    Set up the distribution from `mean`, the matrices `u_b`, `v_b` and the
    tensor `r`.

    `self.log_std` is set to `0.5 * log(r)` and `self.std` to
    `exp(self.log_std)`. Per the original author's note, `r` is expected to
    be already damped. Shape relations between `mean`, `u_b`, `v_b` and `r`
    are not yet validated (TODO).

    :param mean: Tensor; a graph-time assertion requires rank >= 2.
    :param u_b: Required despite the `None` default.
    :param v_b: Required despite the `None` default.
    :param r: Required despite the `None` default.
    :param group_event_ndims: Forwarded as `group_ndims` to the base class.
    :param is_reparameterized: Forwarded to the base-class constructor.
    :param check_numerics: If True, `log_std` gets a `tf.check_numerics`
        dependency. Also stored as `self._check_numerics`.
    :raises ValueError: If any of `u_b`, `v_b`, `r` is None.
    """
    # tf.convert_to_tensor(None) would fail anyway; fail early and clearly.
    if u_b is None or v_b is None or r is None:
        raise ValueError("u_b, v_b and r should all be passed.")
    self._check_numerics = check_numerics

    mean = tf.convert_to_tensor(mean)
    _assert_rank_op = tf.assert_greater_equal(
        tf.rank(mean), 2, message="mean should be at least a 2-D tensor.")
    # A plain assignment creates no op, so the assertion is attached to a
    # new identity tensor; otherwise it would never execute.
    with tf.control_dependencies([_assert_rank_op]):
        self._mean = tf.identity(mean)

    self._u_b = tf.convert_to_tensor(u_b)
    self._v_b = tf.convert_to_tensor(v_b)
    self._r = tf.convert_to_tensor(r)

    dtype = assert_same_float_dtype([
        (self._mean, 'MatrixVariateNormal.mean'),
        (self._u_b, 'MatrixVariateNormal.u_b'),
        (self._v_b, 'MatrixVariateNormal.v_b'),
        (self._r, 'MatrixVariateNormal.r'),
    ])

    # Square root of r, taken in log space, for sampling.
    log_std = 0.5 * tf.log(self._r)
    if check_numerics:
        with tf.control_dependencies(
                [tf.check_numerics(log_std, "0.5 * log(r)")]):
            log_std = tf.identity(log_std)
    self.log_std = log_std
    self.std = tf.exp(self.log_std)

    super(EigenMultivariateNormal, self).__init__(
        dtype=dtype,
        param_dtype=dtype,
        is_continuous=True,
        is_reparameterized=is_reparameterized,
        group_ndims=group_event_ndims)