def __init__(self,
                 mean,
                 u_tril,
                 v_tril,
                 group_ndims=0,
                 is_reparameterized=True,
                 use_path_derivative=False,
                 check_numerics=False,
                 **kwargs):
        """Validate the parameters and initialise the base distribution.

        :param mean: Converted to a tensor and passed through
            `assert_rank_at_least(..., 2, ...)`. Its last two axes are
            recorded as `n_row` and `n_col`.
        :param u_tril: Converted to a tensor and checked to have shape
            `mean.shape[:-1] + [n_row]`.
        :param v_tril: Converted to a tensor and checked to have shape
            `mean.shape[:-2] + [n_col, n_col]`.
        :param group_ndims: Forwarded unchanged to the base constructor.
        :param is_reparameterized: Forwarded unchanged to the base
            constructor.
        :param use_path_derivative: Forwarded unchanged to the base
            constructor.
        :param check_numerics: Stored on the instance as
            `_check_numerics`.
        :param kwargs: Extra keyword arguments forwarded to the base
            constructor.
        """
        self._check_numerics = check_numerics

        # Convert each parameter, then enforce the minimum rank of 2.
        mean_tensor = tf.convert_to_tensor(mean)
        self._mean = assert_rank_at_least(
            mean_tensor, 2, 'MatrixVariateNormalCholesky.mean')
        self._n_row = get_shape_at(self._mean, -2)
        self._n_col = get_shape_at(self._mean, -1)
        u_tensor = tf.convert_to_tensor(u_tril)
        self._u_tril = assert_rank_at_least(
            u_tensor, 2, 'MatrixVariateNormalCholesky.u_tril')
        v_tensor = tf.convert_to_tensor(v_tril)
        self._v_tril = assert_rank_at_least(
            v_tensor, 2, 'MatrixVariateNormalCholesky.v_tril')

        # Static check: an axis size that is not a plain int is treated as
        # unknown (None) so it stays compatible with anything.
        mean_static_shape = self._mean.get_shape()
        row_dim = self._n_row if isinstance(self._n_row, int) else None
        col_dim = self._n_col if isinstance(self._n_col, int) else None
        static_u_shape = mean_static_shape[:-1].concatenate([row_dim])
        self._u_tril.get_shape().assert_is_compatible_with(static_u_shape)
        static_v_shape = mean_static_shape[:-2].concatenate(
            [col_dim, col_dim])
        self._v_tril.get_shape().assert_is_compatible_with(static_v_shape)

        # Dynamic check: compare the shapes as tensors.
        wanted_u_shape = tf.concat(
            [tf.shape(self._mean)[:-1], [self._n_row]], axis=0)
        got_u_shape = tf.shape(self._u_tril)
        u_shape_check = tf.assert_equal(
            wanted_u_shape, got_u_shape,
            ['MatrixVariateNormalCholesky.u_tril should have compatible '
             'shape with mean. Expected', wanted_u_shape, ' got ',
             got_u_shape])
        wanted_v_shape = tf.concat(
            [tf.shape(self._mean)[:-2], [self._n_col, self._n_col]], axis=0)
        got_v_shape = tf.shape(self._v_tril)
        v_shape_check = tf.assert_equal(
            wanted_v_shape, got_v_shape,
            ['MatrixVariateNormalCholesky.v_tril should have compatible '
             'shape with mean. Expected', wanted_v_shape, ' got ',
             got_v_shape])
        # Re-create both factors under the assertions so that any later use
        # of them depends on the shape checks.
        with tf.control_dependencies([u_shape_check, v_shape_check]):
            self._u_tril = tf.identity(self._u_tril)
            self._v_tril = tf.identity(self._v_tril)

        named_params = [
            (self._mean, 'MatrixVariateNormalCholesky.mean'),
            (self._u_tril, 'MatrixVariateNormalCholesky.u_tril'),
            (self._v_tril, 'MatrixVariateNormalCholesky.v_tril'),
        ]
        dtype = assert_same_float_dtype(named_params)
        super(MatrixVariateNormalCholesky, self).__init__(
            dtype=dtype,
            param_dtype=dtype,
            is_continuous=True,
            is_reparameterized=is_reparameterized,
            use_path_derivative=use_path_derivative,
            group_ndims=group_ndims,
            **kwargs)
Beispiel #2
0
 def test_get_shape_at(self):
     """Check `get_shape_at` on a known axis and on an unknown axis."""
     with self.test_session(use_gpu=True):
         placeholder = tf.placeholder(tf.float32, [2, None])

         # Axis 0 is declared as 2, so the result compares equal to 2
         # without evaluating anything.
         known_dim = get_shape_at(placeholder, 0)
         self.assertEqual(known_dim, 2)

         # Axis 1 is declared as None, so the result has to be evaluated
         # with a concrete feed.
         feed = {placeholder: np.ones([2, 9])}
         unknown_dim = get_shape_at(placeholder, 1)
         self.assertEqual(unknown_dim.eval(feed), 9)
Beispiel #3
0
    def __init__(self,
                 temperature,
                 logits,
                 group_ndims=0,
                 is_reparameterized=True,
                 use_path_derivative=False,
                 check_numerics=False,
                 **kwargs):
        """Validate the parameters and initialise the base distribution.

        :param temperature: Converted to a tensor and passed through
            `assert_scalar`.
        :param logits: Converted to a tensor and passed through
            `assert_rank_at_least_one`. Its last axis is recorded as
            `_n_categories`.
        :param group_ndims: Forwarded unchanged to the base constructor.
        :param is_reparameterized: Forwarded unchanged to the base
            constructor.
        :param use_path_derivative: Forwarded unchanged to the base
            constructor.
        :param check_numerics: Stored on the instance as
            `_check_numerics`.
        :param kwargs: Extra keyword arguments forwarded to the base
            constructor.
        """
        logits_tensor = tf.convert_to_tensor(logits)
        temperature_tensor = tf.convert_to_tensor(temperature)

        # Both parameters must share one floating dtype; it is used as both
        # the sample dtype and the parameter dtype below.
        named_params = [
            (logits_tensor, 'Concrete.logits'),
            (temperature_tensor, 'Concrete.temperature'),
        ]
        param_dtype = assert_same_float_dtype(named_params)

        self._logits = assert_rank_at_least_one(
            logits_tensor, 'Concrete.logits')
        self._n_categories = get_shape_at(self._logits, -1)
        self._temperature = assert_scalar(
            temperature_tensor, 'Concrete.temperature')
        self._check_numerics = check_numerics

        super(Concrete, self).__init__(
            dtype=param_dtype,
            param_dtype=param_dtype,
            is_continuous=True,
            is_reparameterized=is_reparameterized,
            use_path_derivative=use_path_derivative,
            group_ndims=group_ndims,
            **kwargs)
Beispiel #4
0
    def __init__(self,
                 logits,
                 normalize_logits=True,
                 dtype=tf.int32,
                 group_ndims=0,
                 **kwargs):
        """Validate the parameters and initialise the base distribution.

        :param logits: Converted to a tensor, required to have a float
            dtype, and passed through `assert_rank_at_least_one`. Its last
            axis is recorded as `_n_categories`.
        :param normalize_logits: Stored on the instance as
            `normalize_logits`.
        :param dtype: Sample dtype; checked with
            `assert_dtype_is_int_or_float` and forwarded to the base
            constructor.
        :param group_ndims: Forwarded unchanged to the base constructor.
        :param kwargs: Extra keyword arguments forwarded to the base
            constructor.
        """
        logits_tensor = tf.convert_to_tensor(logits)
        named_params = [(logits_tensor, 'UnnormalizedMultinomial.logits')]
        param_dtype = assert_same_float_dtype(named_params)

        assert_dtype_is_int_or_float(dtype)

        self._logits = assert_rank_at_least_one(
            logits_tensor, 'UnnormalizedMultinomial.logits')
        self._n_categories = get_shape_at(self._logits, -1)

        self.normalize_logits = normalize_logits

        super(UnnormalizedMultinomial, self).__init__(
            dtype=dtype,
            param_dtype=param_dtype,
            is_continuous=False,
            is_reparameterized=False,
            group_ndims=group_ndims,
            **kwargs)
Beispiel #5
0
    def __init__(self,
                 logits,
                 n_experiments,
                 dtype=None,
                 group_ndims=0,
                 **kwargs):
        """Validate the parameters and initialise the base distribution.

        :param logits: Converted to a tensor, required to have a float
            dtype, and passed through `assert_rank_at_least_one`. Its last
            axis is recorded as `_n_categories`.
        :param n_experiments: Passed through
            `assert_positive_int32_integer` and stored as
            `_n_experiments`.
        :param dtype: Sample dtype; `tf.int32` is used when None. Checked
            with `assert_same_float_and_int_dtype`.
        :param group_ndims: Forwarded unchanged to the base constructor.
        :param kwargs: Extra keyword arguments forwarded to the base
            constructor.
        """
        logits_tensor = tf.convert_to_tensor(logits)
        named_params = [(logits_tensor, 'Multinomial.logits')]
        param_dtype = assert_same_float_dtype(named_params)

        # Fall back to int32 samples when the caller gave no dtype.
        dtype = tf.int32 if dtype is None else dtype
        assert_same_float_and_int_dtype([], dtype)

        self._logits = assert_rank_at_least_one(
            logits_tensor, 'Multinomial.logits')
        self._n_categories = get_shape_at(self._logits, -1)
        self._n_experiments = assert_positive_int32_integer(
            n_experiments, 'Multinomial.n_experiments')

        super(Multinomial, self).__init__(
            dtype=dtype,
            param_dtype=param_dtype,
            is_continuous=False,
            is_reparameterized=False,
            group_ndims=group_ndims,
            **kwargs)
    def __init__(self,
                 mean,
                 cov_tril,
                 group_ndims=0,
                 is_reparameterized=True,
                 use_path_derivative=False,
                 check_numerics=False,
                 **kwargs):
        """Validate the parameters and initialise the base distribution.

        :param mean: Converted to a tensor and passed through
            `assert_rank_at_least_one`. Its last axis is recorded as
            `_n_dim`.
        :param cov_tril: Converted to a tensor, passed through
            `assert_rank_at_least(..., 2, ...)`, and checked to have shape
            `mean.shape + [n_dim]`.
        :param group_ndims: Forwarded unchanged to the base constructor.
        :param is_reparameterized: Forwarded unchanged to the base
            constructor.
        :param use_path_derivative: Forwarded unchanged to the base
            constructor.
        :param check_numerics: Stored on the instance as
            `_check_numerics`.
        :param kwargs: Extra keyword arguments forwarded to the base
            constructor.
        """
        self._check_numerics = check_numerics

        mean_tensor = tf.convert_to_tensor(mean)
        self._mean = assert_rank_at_least_one(
            mean_tensor, 'MultivariateNormalCholesky.mean')
        self._n_dim = get_shape_at(self._mean, -1)
        cov_tensor = tf.convert_to_tensor(cov_tril)
        self._cov_tril = assert_rank_at_least(
            cov_tensor, 2, 'MultivariateNormalCholesky.cov_tril')

        # Static check: an axis size that is not a plain int is treated as
        # unknown (None) so it stays compatible with anything.
        last_dim = self._n_dim if isinstance(self._n_dim, int) else None
        static_cov_shape = self._mean.get_shape().concatenate([last_dim])
        self._cov_tril.get_shape().assert_is_compatible_with(
            static_cov_shape)

        # Dynamic check: compare the shapes as tensors.
        wanted_shape = tf.concat(
            [tf.shape(self._mean), [self._n_dim]], axis=0)
        got_shape = tf.shape(self._cov_tril)
        shape_check = tf.assert_equal(
            wanted_shape, got_shape,
            ['MultivariateNormalCholesky.cov_tril should have compatible '
             'shape with mean. Expected', wanted_shape, ' got ', got_shape])
        # Re-create cov_tril under the assertion so that any later use of
        # it depends on the shape check.
        with tf.control_dependencies([shape_check]):
            self._cov_tril = tf.identity(self._cov_tril)

        named_params = [
            (self._mean, 'MultivariateNormalCholesky.mean'),
            (self._cov_tril, 'MultivariateNormalCholesky.cov_tril'),
        ]
        dtype = assert_same_float_dtype(named_params)
        super(MultivariateNormalCholesky, self).__init__(
            dtype=dtype,
            param_dtype=dtype,
            is_continuous=True,
            is_reparameterized=is_reparameterized,
            use_path_derivative=use_path_derivative,
            group_ndims=group_ndims,
            **kwargs)