def __init__(self, mean, u_tril, v_tril, group_ndims=0,
             is_reparameterized=True, use_path_derivative=False,
             check_numerics=False, **kwargs):
    """
    Initialize a matrix-variate normal distribution parameterized by its
    mean matrix and the Cholesky factors of its among-row (``u``) and
    among-column (``v``) covariance matrices.

    :param mean: A float Tensor of shape ``[..., n_row, n_col]``.
    :param u_tril: A float Tensor of shape ``[..., n_row, n_row]`` — the
        Cholesky factor of the row covariance. Only its shape is checked
        here; lower-triangularity is assumed, not verified.
    :param v_tril: A float Tensor of shape ``[..., n_col, n_col]`` — the
        Cholesky factor of the column covariance. Shape-checked only, as
        with ``u_tril``.
    :param group_ndims: Passed through to the base distribution.
    :param is_reparameterized: Passed through to the base distribution.
    :param use_path_derivative: Passed through to the base distribution.
    :param check_numerics: Whether runtime numeric checks are enabled
        (stored; consumed elsewhere in the class).
    """
    self._check_numerics = check_numerics
    self._mean = tf.convert_to_tensor(mean)
    # mean must carry at least the two matrix dims [n_row, n_col].
    self._mean = assert_rank_at_least(
        self._mean, 2, 'MatrixVariateNormalCholesky.mean')
    self._n_row = get_shape_at(self._mean, -2)
    self._n_col = get_shape_at(self._mean, -1)
    self._u_tril = tf.convert_to_tensor(u_tril)
    self._u_tril = assert_rank_at_least(
        self._u_tril, 2, 'MatrixVariateNormalCholesky.u_tril')
    self._v_tril = tf.convert_to_tensor(v_tril)
    self._v_tril = assert_rank_at_least(
        self._v_tril, 2, 'MatrixVariateNormalCholesky.v_tril')

    # Static shape check: u_tril must match mean.shape[:-1] + [n_row],
    # i.e. batch + [n_row, n_row]. An unknown n_row becomes None, which
    # is compatible with anything.
    expected_u_shape = self._mean.get_shape()[:-1].concatenate(
        [self._n_row if isinstance(self._n_row, int) else None])
    self._u_tril.get_shape().assert_is_compatible_with(expected_u_shape)
    # v_tril must match mean.shape[:-2] + [n_col, n_col].
    expected_v_shape = self._mean.get_shape()[:-2].concatenate(
        [self._n_col if isinstance(self._n_col, int) else None] * 2)
    self._v_tril.get_shape().assert_is_compatible_with(expected_v_shape)

    # Dynamic (run-time) shape checks, covering dims unknown statically.
    expected_u_shape = tf.concat(
        [tf.shape(self._mean)[:-1], [self._n_row]], axis=0)
    actual_u_shape = tf.shape(self._u_tril)
    msg = ['MatrixVariateNormalCholesky.u_tril should have compatible '
           'shape with mean. Expected', expected_u_shape, ' got ',
           actual_u_shape]
    assert_u_ops = tf.assert_equal(expected_u_shape, actual_u_shape, msg)
    expected_v_shape = tf.concat(
        [tf.shape(self._mean)[:-2], [self._n_col, self._n_col]], axis=0)
    actual_v_shape = tf.shape(self._v_tril)
    msg = ['MatrixVariateNormalCholesky.v_tril should have compatible '
           'shape with mean. Expected', expected_v_shape, ' got ',
           actual_v_shape]
    assert_v_ops = tf.assert_equal(expected_v_shape, actual_v_shape, msg)
    # Tie the assert ops into the graph so they execute before u_tril /
    # v_tril are consumed downstream.
    with tf.control_dependencies([assert_u_ops, assert_v_ops]):
        self._u_tril = tf.identity(self._u_tril)
        self._v_tril = tf.identity(self._v_tril)

    # All three parameters must share a single float dtype.
    dtype = assert_same_float_dtype(
        [(self._mean, 'MatrixVariateNormalCholesky.mean'),
         (self._u_tril, 'MatrixVariateNormalCholesky.u_tril'),
         (self._v_tril, 'MatrixVariateNormalCholesky.v_tril')])

    super(MatrixVariateNormalCholesky, self).__init__(
        dtype=dtype,
        param_dtype=dtype,
        is_continuous=True,
        is_reparameterized=is_reparameterized,
        use_path_derivative=use_path_derivative,
        group_ndims=group_ndims,
        **kwargs)
def test_get_shape_at(self):
    """get_shape_at yields a plain int for a statically-known dimension
    and an evaluatable tensor for a dynamic one."""
    with self.test_session(use_gpu=True):
        placeholder = tf.placeholder(tf.float32, [2, None])
        # Dimension 0 is fixed at graph-construction time.
        self.assertEqual(get_shape_at(placeholder, 0), 2)
        # Dimension 1 is only determined by the fed value at run time.
        feed = {placeholder: np.ones([2, 9])}
        self.assertEqual(get_shape_at(placeholder, 1).eval(feed), 9)
def __init__(self, temperature, logits, group_ndims=0,
             is_reparameterized=True, use_path_derivative=False,
             check_numerics=False, **kwargs):
    """
    Initialize a Concrete distribution from a relaxation temperature and
    unnormalized log-probabilities.

    :param temperature: A scalar float Tensor.
    :param logits: A float Tensor of shape ``[..., n_categories]``.
    :param group_ndims: Passed through to the base distribution.
    :param is_reparameterized: Passed through to the base distribution.
    :param use_path_derivative: Passed through to the base distribution.
    :param check_numerics: Whether runtime numeric checks are enabled
        (stored; consumed elsewhere in the class).
    """
    self._logits = tf.convert_to_tensor(logits)
    self._temperature = tf.convert_to_tensor(temperature)
    # Both parameters must share one float dtype.
    named_params = [(self._logits, 'Concrete.logits'),
                    (self._temperature, 'Concrete.temperature')]
    float_dtype = assert_same_float_dtype(named_params)
    # logits needs a trailing category axis; temperature must be scalar.
    self._logits = assert_rank_at_least_one(
        self._logits, 'Concrete.logits')
    self._n_categories = get_shape_at(self._logits, -1)
    self._temperature = assert_scalar(
        self._temperature, 'Concrete.temperature')
    self._check_numerics = check_numerics
    super(Concrete, self).__init__(
        dtype=float_dtype,
        param_dtype=float_dtype,
        is_continuous=True,
        is_reparameterized=is_reparameterized,
        use_path_derivative=use_path_derivative,
        group_ndims=group_ndims,
        **kwargs)
def __init__(self, logits, normalize_logits=True, dtype=tf.int32,
             group_ndims=0, **kwargs):
    """
    Initialize an unnormalized multinomial from unnormalized
    log-probabilities.

    :param logits: A float Tensor of shape ``[..., n_categories]``.
    :param normalize_logits: Whether logits should be normalized before
        use (stored; consumed elsewhere in the class).
    :param dtype: Sample dtype; must be an int or float TF dtype.
    :param group_ndims: Passed through to the base distribution.
    """
    self._logits = tf.convert_to_tensor(logits)
    # logits must be a float tensor; samples may be int or float.
    logits_dtype = assert_same_float_dtype(
        [(self._logits, 'UnnormalizedMultinomial.logits')])
    assert_dtype_is_int_or_float(dtype)
    # A trailing category axis is required.
    self._logits = assert_rank_at_least_one(
        self._logits, 'UnnormalizedMultinomial.logits')
    self._n_categories = get_shape_at(self._logits, -1)
    self.normalize_logits = normalize_logits
    super(UnnormalizedMultinomial, self).__init__(
        dtype=dtype,
        param_dtype=logits_dtype,
        is_continuous=False,
        is_reparameterized=False,
        group_ndims=group_ndims,
        **kwargs)
def __init__(self, logits, n_experiments, dtype=None, group_ndims=0,
             **kwargs):
    """
    Initialize a multinomial distribution.

    :param logits: A float Tensor of shape ``[..., n_categories]``.
    :param n_experiments: A positive int32 scalar — the number of trials.
    :param dtype: Sample dtype; defaults to ``tf.int32`` when ``None``.
        Must be an int or float TF dtype.
    :param group_ndims: Passed through to the base distribution.
    """
    self._logits = tf.convert_to_tensor(logits)
    logits_dtype = assert_same_float_dtype(
        [(self._logits, 'Multinomial.logits')])
    # Fall back to int32 samples, then validate whatever dtype we have.
    dtype = tf.int32 if dtype is None else dtype
    assert_same_float_and_int_dtype([], dtype)
    # A trailing category axis is required.
    self._logits = assert_rank_at_least_one(
        self._logits, 'Multinomial.logits')
    self._n_categories = get_shape_at(self._logits, -1)
    self._n_experiments = assert_positive_int32_integer(
        n_experiments, 'Multinomial.n_experiments')
    super(Multinomial, self).__init__(
        dtype=dtype,
        param_dtype=logits_dtype,
        is_continuous=False,
        is_reparameterized=False,
        group_ndims=group_ndims,
        **kwargs)
def __init__(self, mean, cov_tril, group_ndims=0, is_reparameterized=True,
             use_path_derivative=False, check_numerics=False, **kwargs):
    """
    Initialize a multivariate normal distribution parameterized by its
    mean and the Cholesky factor of its covariance matrix.

    :param mean: A float Tensor of shape ``[..., n_dim]``.
    :param cov_tril: A float Tensor of shape ``[..., n_dim, n_dim]`` —
        the Cholesky factor of the covariance. Only its shape is checked
        here; lower-triangularity is assumed, not verified.
    :param group_ndims: Passed through to the base distribution.
    :param is_reparameterized: Passed through to the base distribution.
    :param use_path_derivative: Passed through to the base distribution.
    :param check_numerics: Whether runtime numeric checks are enabled
        (stored; consumed elsewhere in the class).
    """
    self._check_numerics = check_numerics
    self._mean = tf.convert_to_tensor(mean)
    # mean must carry at least the event axis [n_dim].
    self._mean = assert_rank_at_least_one(
        self._mean, 'MultivariateNormalCholesky.mean')
    self._n_dim = get_shape_at(self._mean, -1)
    self._cov_tril = tf.convert_to_tensor(cov_tril)
    self._cov_tril = assert_rank_at_least(
        self._cov_tril, 2, 'MultivariateNormalCholesky.cov_tril')

    # Static shape check: cov_tril must match mean.shape + [n_dim],
    # i.e. batch + [n_dim, n_dim]. An unknown n_dim becomes None, which
    # is compatible with anything.
    expected_shape = self._mean.get_shape().concatenate(
        [self._n_dim if isinstance(self._n_dim, int) else None])
    self._cov_tril.get_shape().assert_is_compatible_with(expected_shape)
    # Dynamic (run-time) shape check, covering dims unknown statically.
    expected_shape = tf.concat(
        [tf.shape(self._mean), [self._n_dim]], axis=0)
    actual_shape = tf.shape(self._cov_tril)
    msg = [
        'MultivariateNormalCholesky.cov_tril should have compatible '
        'shape with mean. Expected', expected_shape, ' got ', actual_shape
    ]
    assert_ops = [tf.assert_equal(expected_shape, actual_shape, msg)]
    # Tie the assert op into the graph so it executes before cov_tril is
    # consumed downstream.
    with tf.control_dependencies(assert_ops):
        self._cov_tril = tf.identity(self._cov_tril)

    # Both parameters must share a single float dtype.
    dtype = assert_same_float_dtype([
        (self._mean, 'MultivariateNormalCholesky.mean'),
        (self._cov_tril, 'MultivariateNormalCholesky.cov_tril')
    ])

    super(MultivariateNormalCholesky, self).__init__(
        dtype=dtype,
        param_dtype=dtype,
        is_continuous=True,
        is_reparameterized=is_reparameterized,
        use_path_derivative=use_path_derivative,
        group_ndims=group_ndims,
        **kwargs)