Example #1
0
    def test_require_non_negative(self):
        """
        Exercise `TensorArgValidator.require_non_negative` with both plain
        Python scalars and `tf.constant` tensors, covering values that
        should be accepted and values that should be rejected.
        """
        validator = TensorArgValidator('xyz')

        accepted = [0, 0., 1e-7, 1, 1.]
        accepted_dtypes = [tf.int32, tf.float32, tf.float32, tf.int32,
                           tf.float32]
        rejected = [-1., -1, -1e-7]
        rejected_dtypes = [tf.float32, tf.int32, tf.float32]

        # Static (Python scalar) inputs: accepted values come back unchanged,
        # rejected ones raise a ValueError naming the argument.
        for value in accepted:
            self.assertEqual(validator.require_non_negative(value), value)
        for value in rejected:
            with pytest.raises(ValueError, match='xyz must be non-negative'):
                validator.require_non_negative(value)

        # Dynamic (tensor) inputs: evaluate inside a session.  For rejected
        # values, both building and evaluating the op happen inside
        # `pytest.raises(Exception, ...)`, so the test does not depend on
        # which step raises or on the exact exception type.
        with self.test_session():
            for value, dtype in zip(accepted, accepted_dtypes):
                checked = validator.require_non_negative(
                    tf.constant(value, dtype=dtype))
                self.assertAllClose(checked.eval(), value)

            for value, dtype in zip(rejected, rejected_dtypes):
                with pytest.raises(Exception,
                                   match='xyz must be non-negative'):
                    validator.require_non_negative(
                        tf.constant(value, dtype=dtype)).eval()
Example #2
0
def validate_group_ndims(group_ndims, name=None):
    """
    Validate a value passed as the `group_ndims` argument.

    The value is checked with :class:`TensorArgValidator` (``require_int32``
    followed by ``require_non_negative``).  When `group_ndims` is a tensor
    object, the validation ops are created inside a TensorFlow name scope.
    Plain Python values are validated without opening any scope.

    Args:
        group_ndims: The value to validate (an int or a :class:`tf.Tensor`).
        name: TensorFlow name scope for the validation ops, used only when
            `group_ndims` is a tensor.  (default "validate_group_ndims")

    Returns:
        tf.Tensor or int: The validated `group_ndims`.

    Raises:
        ValueError: If the specified `group_ndims` cannot pass validation.
    """
    def _checked(value):
        # Build the validator inside whatever scope is active at call time,
        # so that tensor ops land under the name scope when one is opened.
        checker = TensorArgValidator('group_ndims')
        return checker.require_non_negative(checker.require_int32(value))

    if not is_tensor_object(group_ndims):
        return _checked(group_ndims)

    with tf.name_scope(name, default_name='validate_group_ndims'):
        return _checked(group_ndims)
Example #3
0
    def __init__(self,
                 distribution,
                 tensor,
                 n_samples=None,
                 group_ndims=0,
                 is_reparameterized=None,
                 flow_origin=None,
                 log_prob=None):
        """
        Construct the :class:`StochasticTensor`.

        Args:
            distribution (tfsnippet.distributions.Distribution): The
                distribution of this :class:`StochasticTensor`.
            tensor (tf.Tensor or TensorWrapper): The samples or observations
                of this :class:`StochasticTensor`; stored after
                :func:`tf.convert_to_tensor`.
            n_samples (tf.Tensor or int): The number of samples taken in
                :class:`Distribution.sample`.  If not :obj:`None`, the first
                dimension of `tensor` should be the sampling dimension.
                Passed through `validate_n_samples_arg` and then
                `TensorArgValidator.require_int32` / `require_non_negative`.
            group_ndims (int or tf.Tensor): The number of dimensions to be
                considered as events group in samples. (default 0)
            is_reparameterized (bool): Whether or not the samples are
                re-parameterized?  If not specified, will inherit from
                :attr:`tfsnippet.distributions.Distribution.is_reparameterized`.
            flow_origin (StochasticTensor): The original stochastic tensor
                from the base distribution of a
                :class:`tfsnippet.FlowDistribution`.
            log_prob (Tensor or None): Pre-computed log-density of `tensor`,
                given `group_ndims`.  Non-tensor values are converted with
                :func:`tf.convert_to_tensor`.
        """
        # Imported here rather than at module level (presumably to avoid an
        # import cycle with tfsnippet.utils -- TODO confirm).
        from tfsnippet.utils import TensorArgValidator, validate_group_ndims_arg

        is_reparameterized = (distribution.is_reparameterized
                              if is_reparameterized is None
                              else is_reparameterized)
        if not (log_prob is None or is_tensor_object(log_prob)):
            log_prob = tf.convert_to_tensor(log_prob)

        n_samples = validate_n_samples_arg(n_samples, 'n_samples')
        if n_samples is not None:
            with tf.name_scope('validate_n_samples'):
                checker = TensorArgValidator('n_samples')
                n_samples = checker.require_int32(n_samples)
                n_samples = checker.require_non_negative(n_samples)

        group_ndims = validate_group_ndims_arg(group_ndims)

        super(StochasticTensor, self).__init__()
        # Attributes use the `_self_` prefix, presumably so the wrapper
        # stores them on itself rather than proxying them (verify against
        # the TensorWrapper base class).
        self._self_distribution = distribution
        self._self_tensor = tf.convert_to_tensor(tensor)
        self._self_n_samples = n_samples
        self._self_group_ndims = group_ndims
        self._self_is_reparameterized = is_reparameterized
        self._self_flow_origin = flow_origin
        self._self_log_prob = log_prob
        # Probability cache; starts empty.
        self._self_prob = None
Example #4
0
    def __init__(self,
                 distribution,
                 tensor,
                 n_samples=None,
                 group_ndims=0,
                 is_reparameterized=None):
        """
        Construct the :class:`StochasticTensor`.

        Args:
            distribution (tfsnippet.distributions.Distribution): The
                distribution of this :class:`StochasticTensor`.
            tensor (tf.Tensor or TensorWrapper): The samples or observations
                of this :class:`StochasticTensor`; stored after
                :func:`tf.convert_to_tensor`.
            n_samples (tf.Tensor or int): The number of samples taken in
                :class:`Distribution.sample`.  If not :obj:`None`, the first
                dimension of `tensor` should be the sampling dimension.
                Passed through `validate_n_samples` and then
                `TensorArgValidator.require_int32` / `require_non_negative`.
            group_ndims (int or tf.Tensor): The number of dimensions to be
                considered as events group in samples. (default 0)
            is_reparameterized (bool): Whether or not the samples are
                re-parameterized?  If not specified, will inherit from
                :attr:`tfsnippet.distributions.Distribution.is_reparameterized`.
        """
        # Imported here rather than at module level (presumably to avoid an
        # import cycle with tfsnippet.distributions -- TODO confirm).
        from tfsnippet.distributions import validate_group_ndims

        is_reparameterized = (distribution.is_reparameterized
                              if is_reparameterized is None
                              else is_reparameterized)

        n_samples = validate_n_samples(n_samples, 'n_samples')
        if n_samples is not None:
            with tf.name_scope('validate_n_samples'):
                checker = TensorArgValidator('n_samples')
                n_samples = checker.require_int32(n_samples)
                n_samples = checker.require_non_negative(n_samples)

        group_ndims = validate_group_ndims(group_ndims)

        super(StochasticTensor, self).__init__()
        self._self_distribution = distribution
        self._self_tensor = tf.convert_to_tensor(tensor)
        self._self_n_samples = n_samples
        self._self_group_ndims = group_ndims
        self._self_is_reparameterized = is_reparameterized
        # Log-density and probability caches; both start empty.
        self._self_log_prob = None
        self._self_prob = None