Пример #1
0
def iterative_masked_reconstruct(reconstruct, x, mask, iter_count,
                                 back_prop=True, name=None):
    """
    Iteratively reconstruct `x` with `mask` for `iter_count` times.

    This method calls :func:`masked_reconstruct` `iter_count` times, feeding
    the output of each iteration back in as the input `x` of the next one.
    The output of the final iteration is returned.

    Args:
        reconstruct: Function for reconstructing `x`; it is passed through
            to :func:`masked_reconstruct` on every iteration.
        x: The tensor to be reconstructed by `reconstruct`.
        mask: int32 mask, must be broadcastable against `x`.  Indicates,
            for each element of `x`, whether it should be overwritten by
            the reconstruction.
        iter_count (int or tf.Tensor): Number of iterations.  Must be a
            positive int32, i.e. at least 1.
        back_prop (bool): Whether or not to support back-propagation through
            all the iterations? (default :obj:`True`)
        name (str): Name of this operation in TensorFlow graph.
            (default "iterative_masked_reconstruct")

    Returns:
        tf.Tensor: The iteratively reconstructed `x`.

    Raises:
        TypeError or ValueError: If a static `iter_count` is not a positive
            int32.  For a tensor `iter_count` the check is part of the graph
            and fails when the graph is evaluated.
    """
    with tf.name_scope(name, default_name='iterative_masked_reconstruct'):
        # Validate the iteration count: a positive int32, so at least one
        # reconstruction pass is always performed.
        v = TensorArgValidator('iter_count')
        iter_count = v.require_positive(v.require_int32(iter_count))

        def _not_done(x_i, i):
            # Loop condition: continue until `iter_count` passes have run.
            return i < iter_count

        def _reconstruct_once(x_i, i):
            # Loop body: one masked reconstruction, whose output becomes the
            # input `x` of the next iteration.
            return masked_reconstruct(reconstruct, x_i, mask), i + 1

        # Loop state is (current x, iteration counter starting at 0).
        x_r, _ = tf.while_loop(
            _not_done,
            _reconstruct_once,
            [x, tf.constant(0, dtype=tf.int32)],
            back_prop=back_prop
        )
        return x_r
Пример #2
0
def validate_n_samples(value, name):
    """
    Validate the `n_samples` argument.

    :obj:`None` is returned untouched.  Any other value is checked to be a
    positive int32; when `value` is a tensor, the check is built inside a
    ``validate_n_samples`` name scope.

    Args:
        value: An int32 value, a int32 :class:`tf.Tensor`, or :obj:`None`.
        name (str): Name of the argument (used in error messages).

    Returns:
        int or tf.Tensor or None: The validated `n_samples` argument value,
            or :obj:`None` if `value` is :obj:`None`.

    Raises:
        TypeError or ValueError: If the value cannot be validated.
    """
    # Decide on the name scope up front (the tensor check runs for every
    # input, including None).
    is_tensor = is_tensor_object(value)
    if value is None:
        return value

    def _check():
        # int32 check first, then positivity on the converted result.
        checker = TensorArgValidator(name=name)
        return checker.require_positive(checker.require_int32(value))

    if not is_tensor:
        return _check()
    with tf.name_scope('validate_n_samples'):
        return _check()
Пример #3
0
    def test_require_positive(self):
        validator = TensorArgValidator('xyz')
        error_pattern = 'xyz must be positive'

        # Static (plain Python) values: accepted ones come back unchanged,
        # rejected ones raise ValueError.
        accepted = [1e-7, 1, 1.]
        rejected = [-1., -1, -1e-7, 0., 0]

        for value in accepted:
            self.assertEqual(validator.require_positive(value), value)

        for value in rejected:
            with pytest.raises(ValueError, match=error_pattern):
                _ = validator.require_positive(value)

        # Dynamic (tensor) values, evaluated in a session; failures are
        # matched as a generic Exception.
        accepted_tensors = [
            (1e-7, tf.float32),
            (1, tf.int32),
            (1., tf.float32),
        ]
        rejected_tensors = [
            (-1., tf.float32),
            (-1, tf.int32),
            (-1e-7, tf.float32),
            (0., tf.float32),
            (0, tf.int32),
        ]
        with self.test_session():
            for value, dtype in accepted_tensors:
                checked = validator.require_positive(
                    tf.constant(value, dtype=dtype))
                self.assertAllClose(checked.eval(), value)

            for value, dtype in rejected_tensors:
                with pytest.raises(Exception, match=error_pattern):
                    _ = validator.require_positive(
                        tf.constant(value, dtype=dtype)).eval()
Пример #4
0
def iterative_masked_reconstruct(reconstruct,
                                 x,
                                 mask,
                                 iter_count,
                                 back_prop=True,
                                 name=None):
    """
    Iteratively reconstruct `x` with `mask` for `iter_count` times.

    This method will call :func:`masked_reconstruct` for `iter_count` times,
    with the output from previous iteration as the input `x` for the next
    iteration.  The output of the final iteration would be returned.

    Args:
        reconstruct: Function for reconstructing `x`; it is passed through
            to :func:`masked_reconstruct` on every iteration.
        x: The tensor to be reconstructed by `reconstruct`.
        mask: int32 mask, must be broadcastable against `x`.
            Indicating whether or not to mask each element of `x`.
        iter_count (int or tf.Tensor):
            Number of iterations.  Must be a positive int32, i.e. at
            least 1.
        back_prop (bool): Whether or not to support back-propagation through
            all the iterations? (default :obj:`True`)
        name (str): Name of this operation in TensorFlow graph.
            (default "iterative_masked_reconstruct")

    Returns:
        tf.Tensor: The iteratively reconstructed `x`.

    Raises:
        TypeError or ValueError: If a static `iter_count` is not a positive
            int32.  For a tensor `iter_count` the check is part of the graph
            and fails when the graph is evaluated.
    """
    with tf.name_scope(name, default_name='iterative_masked_reconstruct'):
        # Validate the iteration count: a positive int32, so at least one
        # reconstruction pass is always performed.
        v = TensorArgValidator('iter_count')
        iter_count = v.require_positive(v.require_int32(iter_count))

        def _not_done(x_i, i):
            # Loop condition: continue until `iter_count` passes have run.
            return i < iter_count

        def _reconstruct_once(x_i, i):
            # Loop body: one masked reconstruction, whose output becomes the
            # input `x` of the next iteration.
            return masked_reconstruct(reconstruct, x_i, mask), i + 1

        # Loop state is (current x, iteration counter starting at 0).
        x_r, _ = tf.while_loop(
            _not_done,
            _reconstruct_once,
            [x, tf.constant(0, dtype=tf.int32)],
            back_prop=back_prop)

        return x_r