コード例 #1
0
ファイル: utils.py プロジェクト: 897615138/tfsnippet-jill
def validate_group_ndims(group_ndims, name=None):
    """
    Validate a `group_ndims` argument and return the checked value.

    The value must be a non-negative int32.  When `group_ndims` is a dynamic
    :class:`tf.Tensor`, the check is attached to the graph and the resulting
    nodes are grouped under a TensorFlow name scope.  Static values are
    validated without opening any name scope.

    Args:
        group_ndims: Object to be validated (a Python int or an int32
            tensor is expected).
        name: TensorFlow name scope of the graph nodes, only used when
            `group_ndims` is a tensor. (default "validate_group_ndims")

    Returns:
        tf.Tensor or int: The validated `group_ndims`.

    Raises:
        TypeError: If `group_ndims` cannot be converted to int32 (raised by
            `TensorArgValidator.require_int32`).
        ValueError: If the specified `group_ndims` cannot pass validation.
    """
    def _check(value):
        # int32 conversion first, then the non-negativity constraint.
        checker = TensorArgValidator('group_ndims')
        return checker.require_non_negative(checker.require_int32(value))

    if not is_tensor_object(group_ndims):
        return _check(group_ndims)

    # Only dynamic tensors produce graph nodes worth scoping.
    with tf.name_scope(name, default_name='validate_group_ndims'):
        return _check(group_ndims)
コード例 #2
0
def validate_n_samples(value, name):
    """
    Validate an `n_samples` argument.

    :obj:`None` is passed through untouched.  Any other value must be a
    positive int32.  For a dynamic tensor, the check is built into the graph
    under the ``validate_n_samples`` name scope.

    Args:
        value: An int32 value, an int32 :class:`tf.Tensor`, or :obj:`None`.
        name (str): Name of the argument (used in error messages).

    Returns:
        int or tf.Tensor or None: The validated `n_samples` argument value.

    Raises:
        TypeError or ValueError: If the value cannot be validated.
    """
    if value is None:
        return value

    def _check():
        checker = TensorArgValidator(name=name)
        return checker.require_positive(checker.require_int32(value))

    if is_tensor_object(value):
        # Group the graph-side assertion nodes under a dedicated scope.
        with tf.name_scope('validate_n_samples'):
            return _check()
    return _check()
コード例 #3
0
ファイル: reconstruction.py プロジェクト: 897615138/donut
def iterative_masked_reconstruct(reconstruct, x, mask, iter_count,
                                 back_prop=True, name=None):
    """
    Iteratively reconstruct `x` with `mask` for `iter_count` times.

    Each iteration calls :func:`masked_reconstruct`, feeding the output of
    the previous iteration back in as the next input `x`.  The output of the
    final iteration is returned.

    Args:
        reconstruct: Function used to reconstruct `x`.
        x: The tensor to be reconstructed by `reconstruct`.
        mask: int32 mask, must be broadcastable against `x`.  Indicates
            whether or not each element of `x` is to be overwritten.
        iter_count (int or tf.Tensor): Number of iterations.  Must be a
            positive int32, i.e. at least 1.
        back_prop (bool): Whether or not to support back-propagation through
            all the iterations? (default :obj:`True`)
        name (str): Name of this operation in the TensorFlow graph.
            (default "iterative_masked_reconstruct")

    Returns:
        tf.Tensor: The iteratively reconstructed `x`.
    """
    with tf.name_scope(name, default_name='iterative_masked_reconstruct'):
        # The iteration count must be a positive int32.
        checker = TensorArgValidator('iter_count')
        n_iters = checker.require_positive(checker.require_int32(iter_count))

        def keep_going(x_cur, step):
            return step < n_iters

        def one_step(x_cur, step):
            # Reconstruct the masked entries of the current estimate.
            return masked_reconstruct(reconstruct, x_cur, mask), step + 1

        loop_vars = [x, tf.constant(0, dtype=tf.int32)]
        reconstructed, _ = tf.while_loop(
            keep_going, one_step, loop_vars, back_prop=back_prop)
    return reconstructed
コード例 #4
0
    def __init__(self,
                 distribution,
                 tensor,
                 n_samples=None,
                 group_ndims=0,
                 is_reparameterized=None,
                 flow_origin=None,
                 log_prob=None):
        """
        Construct the :class:`StochasticTensor`.

        Args:
            distribution (tfsnippet.distributions.Distribution): The
                distribution of this :class:`StochasticTensor`.
            tensor (tf.Tensor or TensorWrapper): The samples or observations
                of this :class:`StochasticTensor`.  Stored after conversion
                with :func:`tf.convert_to_tensor`.
            n_samples (tf.Tensor or int): The number of samples taken in
                :class:`Distribution.sample`.  If not :obj:`None`, the first
                dimension of `tensor` should be the sampling dimension.
                Validated as a non-negative int32 before being stored.
            group_ndims (int or tf.Tensor): The number of dimensions to be
                considered as events group in samples. (default 0)
            is_reparameterized (bool): Whether or not the samples are
                re-parameterized?  If not specified, will inherit from
                :attr:`tfsnippet.distributions.Distribution.is_reparameterized`.
            flow_origin (StochasticTensor): The original stochastic tensor
                from the base distribution of a
                :class:`tfsnippet.FlowDistribution`.
            log_prob (Tensor or None): Pre-computed log-density of `tensor`,
                given `group_ndims`.  Non-tensor values are converted with
                :func:`tf.convert_to_tensor`.
        """
        from tfsnippet.utils import TensorArgValidator, validate_group_ndims_arg

        # Inherit the reparameterization flag unless explicitly given.
        if is_reparameterized is None:
            is_reparameterized = distribution.is_reparameterized

        # A pre-computed log-density is always kept as a tensor.
        if log_prob is not None:
            if not is_tensor_object(log_prob):
                log_prob = tf.convert_to_tensor(log_prob)

        n_samples = validate_n_samples_arg(n_samples, 'n_samples')
        if n_samples is not None:
            # NOTE(review): this re-checks a value that has just gone through
            # `validate_n_samples_arg`; possibly redundant depending on what
            # that helper enforces.  Kept to preserve the original graph.
            with tf.name_scope('validate_n_samples'):
                checker = TensorArgValidator('n_samples')
                n_samples_int32 = checker.require_int32(n_samples)
                n_samples = checker.require_non_negative(n_samples_int32)

        group_ndims = validate_group_ndims_arg(group_ndims)

        super(StochasticTensor, self).__init__()
        self._self_distribution = distribution
        self._self_tensor = tf.convert_to_tensor(tensor)
        self._self_n_samples = n_samples
        self._self_group_ndims = group_ndims
        self._self_is_reparameterized = is_reparameterized
        self._self_flow_origin = flow_origin
        self._self_log_prob = log_prob
        # No pre-computed probability is stored at construction time.
        self._self_prob = None
コード例 #5
0
    def __init__(self,
                 distribution,
                 tensor,
                 n_samples=None,
                 group_ndims=0,
                 is_reparameterized=None):
        """
        Construct the :class:`StochasticTensor`.

        Args:
            distribution (tfsnippet.distributions.Distribution): The
                distribution of this :class:`StochasticTensor`.
            tensor (tf.Tensor or TensorWrapper): The samples or observations
                of this :class:`StochasticTensor`.  Stored after conversion
                with :func:`tf.convert_to_tensor`.
            n_samples (tf.Tensor or int): The number of samples taken in
                :class:`Distribution.sample`.  If not :obj:`None`, the first
                dimension of `tensor` should be the sampling dimension.
                Validated as a non-negative int32 before being stored.
            group_ndims (int or tf.Tensor): The number of dimensions to be
                considered as events group in samples. (default 0)
            is_reparameterized (bool): Whether or not the samples are
                re-parameterized?  If not specified, will inherit from
                :attr:`tfsnippet.distributions.Distribution.is_reparameterized`.
        """
        from tfsnippet.distributions import validate_group_ndims

        # Inherit the reparameterization flag unless explicitly given.
        if is_reparameterized is None:
            is_reparameterized = distribution.is_reparameterized

        n_samples = validate_n_samples(n_samples, 'n_samples')
        if n_samples is not None:
            # NOTE(review): this re-checks a value that has just gone through
            # `validate_n_samples`; possibly redundant depending on what that
            # helper enforces.  Kept to preserve the original graph.
            with tf.name_scope('validate_n_samples'):
                checker = TensorArgValidator('n_samples')
                n_samples_int32 = checker.require_int32(n_samples)
                n_samples = checker.require_non_negative(n_samples_int32)

        group_ndims = validate_group_ndims(group_ndims)

        super(StochasticTensor, self).__init__()
        self._self_distribution = distribution
        self._self_tensor = tf.convert_to_tensor(tensor)
        self._self_n_samples = n_samples
        self._self_group_ndims = group_ndims
        self._self_is_reparameterized = is_reparameterized
        # Neither density is known at construction time.
        self._self_log_prob = None
        self._self_prob = None
コード例 #6
0
ファイル: reconstruction.py プロジェクト: WenweiGu/DONUT-SMD
def iterative_masked_reconstruct(reconstruct,
                                 x,
                                 mask,
                                 iter_count,
                                 back_prop=True,
                                 name=None):
    """
    Iteratively reconstruct `x` with `mask` for `iter_count` times.

    This method will call :func:`masked_reconstruct` for `iter_count` times,
    with the output from the previous iteration as the input `x` for the next
    iteration.  The output of the final iteration is returned.

    Args:
        reconstruct: Function for reconstructing `x`.
        x: The tensor to be reconstructed by `reconstruct`.
        mask: int32 mask, must be broadcastable against `x`.
            Indicating whether or not to mask each element of `x`.
        iter_count (int or tf.Tensor): Number of iterations.  Must be a
            positive int32, i.e. at least 1.
        back_prop (bool): Whether or not to support back-propagation through
            all the iterations? (default :obj:`True`)
        name (str): Name of this operation in TensorFlow graph.
            (default "iterative_masked_reconstruct")

    Returns:
        tf.Tensor: The iteratively reconstructed `x`.
    """
    with tf.name_scope(name, default_name='iterative_masked_reconstruct'):
        # Validate the iteration count: int32 first, then positivity.
        validator = TensorArgValidator('iter_count')
        iter_count = validator.require_int32(iter_count)
        iter_count = validator.require_positive(iter_count)

        def _cond(x_i, i):
            return i < iter_count

        def _body(x_i, i):
            x_next = masked_reconstruct(reconstruct, x_i, mask)
            return x_next, i + 1

        counter = tf.constant(0, dtype=tf.int32)
        x_r, _ = tf.while_loop(_cond, _body, [x, counter],
                               back_prop=back_prop)
        return x_r
コード例 #7
0
ファイル: test_typeutils.py プロジェクト: shliujing/tfsnippet
    def test_require_int32(self):
        validator = TensorArgValidator('xyz')
        err_msg = 'xyz cannot be converted to int32'

        # static values: valid ints are returned unchanged
        for value in (0, 1, -1):
            self.assertEqual(validator.require_int32(value), value)

        # static values: each of these must be rejected with TypeError
        # (LONG_MAX presumably lies outside the int32 range)
        bad_static = (object(), None, (), [], 1.2, LONG_MAX)
        for value in bad_static:
            with pytest.raises(TypeError, match=err_msg):
                _ = validator.require_int32(value)

        # dynamic values
        with self.test_session():
            for value in (0, 1, -1):
                checked = validator.require_int32(
                    tf.constant(value, dtype=tf.int32))
                self.assertEqual(checked.eval(), value)

            bad_dynamic = [
                tf.constant(1.2, dtype=tf.float32),
                tf.constant(LONG_MAX, dtype=tf.int64),
            ]
            for tensor in bad_dynamic:
                with pytest.raises(TypeError, match=err_msg):
                    _ = validator.require_int32(tensor)
コード例 #8
0
ファイル: test_typeutils.py プロジェクト: shliujing/tfsnippet
    def test_require_non_negative(self):
        validator = TensorArgValidator('xyz')
        err_msg = 'xyz must be non-negative'

        # (value, dtype) pairs shared by the static and dynamic checks
        valid_cases = [(0, tf.int32), (0., tf.float32), (1e-7, tf.float32),
                       (1, tf.int32), (1., tf.float32)]
        invalid_cases = [(-1., tf.float32), (-1, tf.int32),
                         (-1e-7, tf.float32)]

        # static values
        for value, _dtype in valid_cases:
            self.assertEqual(validator.require_non_negative(value), value)

        for value, _dtype in invalid_cases:
            with pytest.raises(ValueError, match=err_msg):
                _ = validator.require_non_negative(value)

        # dynamic values
        with self.test_session():
            for value, dtype in valid_cases:
                checked = validator.require_non_negative(
                    tf.constant(value, dtype=dtype))
                self.assertAllClose(checked.eval(), value)

            # presumably the failure only surfaces when the graph runs, so
            # `.eval()` is kept inside the `raises` block
            for value, dtype in invalid_cases:
                with pytest.raises(Exception, match=err_msg):
                    _ = validator.require_non_negative(
                        tf.constant(value, dtype=dtype)).eval()