示例#1
0
    def __init__(self,
                 temperature,
                 logits,
                 group_ndims=0,
                 is_reparameterized=True,
                 use_path_derivative=False,
                 check_numerics=False,
                 **kwargs):
        """Validate the parameters and initialise the parent distribution.

        :param temperature: Value convertible to a tensor. It is passed
            through ``assert_scalar`` and stored as ``self._temperature``.
        :param logits: Value convertible to a tensor. It is passed through
            ``assert_rank_at_least_one``, whose second return value is
            stored as ``self._n_categories``.
        :param group_ndims: Forwarded unchanged to the parent constructor.
        :param is_reparameterized: Forwarded unchanged to the parent
            constructor.
        :param use_path_derivative: Forwarded unchanged to the parent
            constructor.
        :param check_numerics: Stored as ``self._check_numerics``.
        :param kwargs: Extra keyword arguments forwarded to the parent
            constructor.
        """
        logits_name = 'Concrete.logits'
        temperature_name = 'Concrete.temperature'
        self._check_numerics = check_numerics

        logits_tensor = tf.convert_to_tensor(logits)
        temperature_tensor = tf.convert_to_tensor(temperature)

        # Both parameters must agree on one floating-point dtype; that
        # dtype is used for `dtype` and `param_dtype` of the parent.
        float_dtype = assert_same_float_dtype([
            (logits_tensor, logits_name),
            (temperature_tensor, temperature_name),
        ])

        logits_tensor, n_categories = assert_rank_at_least_one(
            logits_tensor, logits_name)
        self._logits = logits_tensor
        self._n_categories = n_categories

        self._temperature = assert_scalar(temperature_tensor,
                                          temperature_name)

        super(Concrete, self).__init__(
            dtype=float_dtype,
            param_dtype=float_dtype,
            is_continuous=True,
            is_reparameterized=is_reparameterized,
            use_path_derivative=use_path_derivative,
            group_ndims=group_ndims,
            **kwargs)
示例#2
0
    def __init__(self,
                 logits,
                 n_experiments,
                 dtype=None,
                 group_ndims=0,
                 **kwargs):
        """Validate the parameters and initialise the parent distribution.

        :param logits: Value convertible to a tensor. It is passed through
            ``assert_rank_at_least_one``, whose second return value is
            stored as ``self._n_categories``.
        :param n_experiments: Checked by ``assert_positive_int32_integer``
            and stored as ``self._n_experiments``.
        :param dtype: Passed to the parent constructor as ``dtype``;
            ``tf.int32`` is used when it is None.
        :param group_ndims: Forwarded unchanged to the parent constructor.
        :param kwargs: Extra keyword arguments forwarded to the parent
            constructor.
        """
        logits_name = 'Multinomial.logits'

        logits_tensor = tf.convert_to_tensor(logits)
        float_dtype = assert_same_float_dtype([(logits_tensor, logits_name)])

        # An omitted dtype falls back to int32 before it is validated.
        resolved_dtype = tf.int32 if dtype is None else dtype
        assert_same_float_and_int_dtype([], resolved_dtype)

        logits_tensor, n_categories = assert_rank_at_least_one(
            logits_tensor, logits_name)
        self._logits = logits_tensor
        self._n_categories = n_categories

        self._n_experiments = assert_positive_int32_integer(
            n_experiments, 'Multinomial.n_experiments')

        super(Multinomial, self).__init__(
            dtype=resolved_dtype,
            param_dtype=float_dtype,
            is_continuous=False,
            is_reparameterized=False,
            group_ndims=group_ndims,
            **kwargs)
示例#3
0
    def __init__(self,
                 logits,
                 normalize_logits=True,
                 dtype=tf.int32,
                 group_ndims=0,
                 **kwargs):
        """Validate the parameters and initialise the parent distribution.

        :param logits: Value convertible to a tensor. It is passed through
            ``assert_rank_at_least_one``; ``get_shape_at(logits, -1)`` is
            stored as ``self._n_categories``.
        :param normalize_logits: Stored as the public attribute
            ``self.normalize_logits``.
        :param dtype: Checked by ``assert_dtype_is_int_or_float`` and passed
            to the parent constructor as ``dtype``.
        :param group_ndims: Forwarded unchanged to the parent constructor.
        :param kwargs: Extra keyword arguments forwarded to the parent
            constructor.
        """
        logits_name = 'UnnormalizedMultinomial.logits'
        self.normalize_logits = normalize_logits

        logits_tensor = tf.convert_to_tensor(logits)
        float_dtype = assert_same_float_dtype([(logits_tensor, logits_name)])

        assert_dtype_is_int_or_float(dtype)

        logits_tensor = assert_rank_at_least_one(logits_tensor, logits_name)
        self._logits = logits_tensor
        # The last axis of the validated logits gives the category count.
        self._n_categories = get_shape_at(logits_tensor, -1)

        parent_init = super(UnnormalizedMultinomial, self).__init__
        parent_init(dtype=dtype,
                    param_dtype=float_dtype,
                    is_continuous=False,
                    is_reparameterized=False,
                    group_ndims=group_ndims,
                    **kwargs)
    def __init__(self,
                 mean,
                 cov_tril,
                 group_ndims=0,
                 is_reparameterized=True,
                 use_path_derivative=False,
                 check_numerics=False,
                 **kwargs):
        """Validate ``mean`` / ``cov_tril`` and initialise the parent.

        The shape of ``cov_tril`` is required to equal the shape of ``mean``
        with one extra trailing axis of size ``get_shape_at(mean, -1)``.
        This is checked twice: once against the static shapes, and once
        with a ``tf.assert_equal`` op that the stored ``cov_tril`` tensor
        depends on.

        :param mean: Value convertible to a tensor; checked by
            ``assert_rank_at_least_one``.
        :param cov_tril: Value convertible to a tensor; checked by
            ``assert_rank_at_least(..., 2, ...)``.
        :param group_ndims: Forwarded unchanged to the parent constructor.
        :param is_reparameterized: Forwarded unchanged to the parent
            constructor.
        :param use_path_derivative: Forwarded unchanged to the parent
            constructor.
        :param check_numerics: Stored as ``self._check_numerics``.
        :param kwargs: Extra keyword arguments forwarded to the parent
            constructor.
        """
        mean_name = 'MultivariateNormalCholesky.mean'
        cov_name = 'MultivariateNormalCholesky.cov_tril'
        self._check_numerics = check_numerics

        mean_tensor = assert_rank_at_least_one(
            tf.convert_to_tensor(mean), mean_name)
        n_dim = get_shape_at(mean_tensor, -1)
        self._mean = mean_tensor
        self._n_dim = n_dim

        cov_tensor = assert_rank_at_least(
            tf.convert_to_tensor(cov_tril), 2, cov_name)

        # Static check: when n_dim is not a Python int the trailing
        # dimension is left unknown (None) so it stays compatible.
        if isinstance(n_dim, int):
            static_last_dim = n_dim
        else:
            static_last_dim = None
        static_expected = mean_tensor.get_shape().concatenate(
            [static_last_dim])
        cov_tensor.get_shape().assert_is_compatible_with(static_expected)

        # Dynamic check: compare the shapes inside the graph and make the
        # stored cov_tril depend on the assertion op.
        runtime_expected = tf.concat([tf.shape(mean_tensor), [n_dim]],
                                     axis=0)
        runtime_actual = tf.shape(cov_tensor)
        failure_data = [
            'MultivariateNormalCholesky.cov_tril should have compatible '
            'shape with mean. Expected',
            runtime_expected,
            ' got ',
            runtime_actual,
        ]
        shape_assertion = tf.assert_equal(runtime_expected, runtime_actual,
                                          failure_data)
        with tf.control_dependencies([shape_assertion]):
            cov_tensor = tf.identity(cov_tensor)
        self._cov_tril = cov_tensor

        float_dtype = assert_same_float_dtype([
            (mean_tensor, mean_name),
            (cov_tensor, cov_name),
        ])
        super(MultivariateNormalCholesky, self).__init__(
            dtype=float_dtype,
            param_dtype=float_dtype,
            is_continuous=True,
            is_reparameterized=is_reparameterized,
            use_path_derivative=use_path_derivative,
            group_ndims=group_ndims,
            **kwargs)
示例#5
0
    def __init__(self, logits, dtype=None, group_event_ndims=0):
        """Validate the parameters and initialise the parent distribution.

        :param logits: Value convertible to a tensor. It is passed through
            ``assert_rank_at_least_one``, whose second return value is
            stored as ``self._n_categories``.
        :param dtype: Passed to the parent constructor as ``dtype``;
            ``tf.int32`` is used when it is None.
        :param group_event_ndims: Forwarded unchanged to the parent
            constructor.
        """
        self._logits = tf.convert_to_tensor(logits)
        param_dtype = assert_same_float_dtype([(self._logits,
                                                'Categorical.logits')])

        # An omitted dtype falls back to int32 before it is validated.
        if dtype is None:
            dtype = tf.int32
        assert_same_float_and_int_dtype([], dtype)

        self._logits, self._n_categories = assert_rank_at_least_one(
            self._logits, 'Categorical.logits')
        # The former `static_logits_shape` local was assigned from
        # `self._logits.get_shape()` and never read, so it is dropped.

        super(Categorical, self).__init__(dtype=dtype,
                                          param_dtype=param_dtype,
                                          is_continuous=False,
                                          is_reparameterized=False,
                                          group_event_ndims=group_event_ndims)
示例#6
0
    def __init__(self, logits, dtype=None, group_ndims=0, **kwargs):
        """Validate the parameters and initialise the parent distribution.

        :param logits: Value convertible to a tensor. It is passed through
            ``assert_rank_at_least_one``; ``get_shape_at(logits, -1)`` is
            stored as ``self._n_categories``.
        :param dtype: Passed to the parent constructor as ``dtype``;
            ``tf.int32`` is used when it is None.
        :param group_ndims: Forwarded unchanged to the parent constructor.
        :param kwargs: Extra keyword arguments forwarded to the parent
            constructor.
        """
        logits_name = 'OnehotCategorical.logits'

        logits_tensor = tf.convert_to_tensor(logits)
        float_dtype = assert_same_float_dtype([(logits_tensor, logits_name)])

        # An omitted dtype falls back to int32 before it is validated.
        resolved_dtype = tf.int32 if dtype is None else dtype
        assert_same_float_and_int_dtype([], resolved_dtype)

        logits_tensor = assert_rank_at_least_one(logits_tensor, logits_name)
        self._logits = logits_tensor
        self._n_categories = get_shape_at(logits_tensor, -1)

        parent_init = super(OnehotCategorical, self).__init__
        parent_init(dtype=resolved_dtype,
                    param_dtype=float_dtype,
                    is_continuous=False,
                    is_reparameterized=False,
                    group_ndims=group_ndims,
                    **kwargs)