def validate_group_ndims(group_ndims, name=None):
    """
    Validate the specified value for `group_ndims` argument.

    Static values are checked immediately.  If `group_ndims` is a dynamic
    :class:`~tf.Tensor`, the validation graph nodes (including the extra
    assertion) are created inside a TensorFlow name scope.

    Args:
        group_ndims: Object to be validated.
        name: TensorFlow name scope of the graph nodes, used only when
            `group_ndims` is a tensor. (default "validate_group_ndims")

    Returns:
        tf.Tensor or int: The validated `group_ndims`.

    Raises:
        TypeError: If `group_ndims` cannot be converted to int32.
        ValueError: If the specified `group_ndims` is negative.
    """
    def _check(value):
        # int32 conversion first, then the non-negativity requirement.
        checker = TensorArgValidator('group_ndims')
        return checker.require_non_negative(checker.require_int32(value))

    if not is_tensor_object(group_ndims):
        # Plain Python values: no graph nodes, hence no name scope needed.
        return _check(group_ndims)

    with tf.name_scope(name, default_name='validate_group_ndims'):
        return _check(group_ndims)
def validate_n_samples(value, name):
    """
    Validate the `n_samples` argument.

    Args:
        value: An int32 value, a int32 :class:`tf.Tensor`, or :obj:`None`.
        name (str): Name of the argument (in error message).

    Returns:
        int or tf.Tensor or None: The validated `n_samples` argument value,
            or :obj:`None` if `value` is :obj:`None`.

    Raises:
        TypeError or ValueError: If `value` cannot be validated, i.e. it
            cannot be converted to int32 or it is not positive.
    """
    if value is None:
        # Nothing to validate.
        return value

    def _positive_int32(v):
        checker = TensorArgValidator(name=name)
        return checker.require_positive(checker.require_int32(v))

    if is_tensor_object(value):
        # Dynamic values get their validation ops grouped under one scope.
        with tf.name_scope('validate_n_samples'):
            return _positive_int32(value)
    return _positive_int32(value)
def iterative_masked_reconstruct(reconstruct, x, mask, iter_count,
                                 back_prop=True, name=None):
    """
    Iteratively reconstruct `x` with `mask` for `iter_count` times.

    This method calls :func:`masked_reconstruct` `iter_count` times, feeding
    the output of each iteration as the input `x` of the next iteration, and
    returns the output of the final iteration.

    Args:
        reconstruct: Function for reconstructing `x`.
        x: The tensor to be reconstructed by `reconstruct`.
        mask: int32 mask, must be broadcastable against `x`.  Indicates
            whether or not each element of `x` is to be overwritten.
        iter_count (int or tf.Tensor): Number of iterations; must be a
            positive int32 (i.e. at least 1).
        back_prop (bool): Whether or not to support back-propagation
            through all the iterations? (default :obj:`True`)
        name (str): Name of this operation in TensorFlow graph.
            (default "iterative_masked_reconstruct")

    Returns:
        tf.Tensor: The iteratively reconstructed `x`.
    """
    with tf.name_scope(name, default_name='iterative_masked_reconstruct'):
        # Validate the iteration count.
        checker = TensorArgValidator('iter_count')
        total_steps = checker.require_positive(
            checker.require_int32(iter_count))

        def _not_done(x_cur, step):
            return step < total_steps

        def _reconstruct_once(x_cur, step):
            # One masked reconstruction pass, then advance the counter.
            return masked_reconstruct(reconstruct, x_cur, mask), step + 1

        x_out, _ = tf.while_loop(
            _not_done,
            _reconstruct_once,
            [x, tf.constant(0, dtype=tf.int32)],
            back_prop=back_prop,
        )
        return x_out
def __init__(self, distribution, tensor, n_samples=None, group_ndims=0,
             is_reparameterized=None, flow_origin=None, log_prob=None):
    """
    Construct the :class:`StochasticTensor`.

    Args:
        distribution (tfsnippet.distributions.Distribution): The
            distribution of this :class:`StochasticTensor`.
        tensor (tf.Tensor or TensorWrapper): The samples or observations
            of this :class:`StochasticTensor`.
        n_samples (tf.Tensor or int): The number of samples taken in
            :class:`Distribution.sample`.  If not :obj:`None`, the first
            dimension of `tensor` should be the sampling dimension.
        group_ndims (int or tf.Tensor): The number of dimensions to be
            considered as events group in samples. (default 0)
        is_reparameterized (bool): Whether or not the samples are
            re-parameterized?  If not specified, will inherit from
            :attr:`tfsnippet.distributions.Distribution.is_reparameterized`.
        flow_origin (StochasticTensor): The original stochastic tensor from
            the base distribution of a :class:`tfsnippet.FlowDistribution`.
        log_prob (Tensor or None): Pre-computed log-density of `tensor`,
            given `group_ndims`.  Non-tensor values are converted via
            :func:`tf.convert_to_tensor`.
    """
    from tfsnippet.utils import TensorArgValidator
    from tfsnippet.utils import validate_group_ndims_arg

    # Inherit the reparameterization flag from the distribution by default.
    is_reparameterized = (
        distribution.is_reparameterized
        if is_reparameterized is None else is_reparameterized
    )

    # Ensure a supplied pre-computed log-density is a tensor object.
    if not (log_prob is None or is_tensor_object(log_prob)):
        log_prob = tf.convert_to_tensor(log_prob)

    n_samples = validate_n_samples_arg(n_samples, 'n_samples')
    if n_samples is not None:
        # NOTE(review): this re-check may be redundant if
        # `validate_n_samples_arg` already enforces positivity -- confirm.
        with tf.name_scope('validate_n_samples'):
            n_samples_checker = TensorArgValidator('n_samples')
            n_samples = n_samples_checker.require_int32(n_samples)
            n_samples = n_samples_checker.require_non_negative(n_samples)

    group_ndims = validate_group_ndims_arg(group_ndims)

    super(StochasticTensor, self).__init__()

    # The wrapped tensor together with its sampling metadata.
    self._self_distribution = distribution
    self._self_tensor = tf.convert_to_tensor(tensor)
    self._self_n_samples = n_samples
    self._self_group_ndims = group_ndims
    self._self_is_reparameterized = is_reparameterized
    self._self_flow_origin = flow_origin
    # `log_prob` may be pre-computed by the caller; `prob` starts unset.
    self._self_log_prob = log_prob
    self._self_prob = None
def __init__(self, distribution, tensor, n_samples=None, group_ndims=0,
             is_reparameterized=None):
    """
    Construct the :class:`StochasticTensor`.

    Args:
        distribution (tfsnippet.distributions.Distribution): The
            distribution of this :class:`StochasticTensor`.
        tensor (tf.Tensor or TensorWrapper): The samples or observations
            of this :class:`StochasticTensor`.
        n_samples (tf.Tensor or int): The number of samples taken in
            :class:`Distribution.sample`; must be a positive int32 when
            specified.  If not :obj:`None`, the first dimension of
            `tensor` should be the sampling dimension.
        group_ndims (int or tf.Tensor): The number of dimensions to be
            considered as events group in samples; must be a non-negative
            int32. (default 0)
        is_reparameterized (bool): Whether or not the samples are
            re-parameterized?  If not specified, will inherit from
            :attr:`tfsnippet.distributions.Distribution.is_reparameterized`.
    """
    from tfsnippet.distributions import validate_group_ndims

    if is_reparameterized is None:
        is_reparameterized = distribution.is_reparameterized

    # `validate_n_samples` already requires a positive int32 (or None), so
    # no additional non-negative check is needed here; repeating it would
    # only add a duplicate assertion op that can never fire.
    n_samples = validate_n_samples(n_samples, 'n_samples')
    group_ndims = validate_group_ndims(group_ndims)

    super(StochasticTensor, self).__init__()
    self._self_distribution = distribution
    self._self_tensor = tf.convert_to_tensor(tensor)
    self._self_n_samples = n_samples
    self._self_group_ndims = group_ndims
    self._self_is_reparameterized = is_reparameterized
    # Densities are not pre-computed by this constructor.
    self._self_log_prob = None
    self._self_prob = None
def iterative_masked_reconstruct(reconstruct, x, mask, iter_count,
                                 back_prop=True, name=None):
    """
    Iteratively reconstruct `x` with `mask` for `iter_count` times.

    :func:`masked_reconstruct` is applied `iter_count` times in a
    :func:`tf.while_loop`; each iteration consumes the previous iteration's
    output as its `x`, and the final iteration's output is returned.

    Args:
        reconstruct: Function for reconstructing `x`.
        x: The tensor to be reconstructed by `reconstruct`.
        mask: int32 mask, must be broadcastable against `x`.  Indicating
            whether or not to mask each element of `x`.
        iter_count (int or tf.Tensor): Number of iterations; must be a
            positive int32 (i.e. at least 1).
        back_prop (bool): Whether or not to support back-propagation
            through all the iterations? (default :obj:`True`)
        name (str): Name of this operation in TensorFlow graph.
            (default "iterative_masked_reconstruct")

    Returns:
        tf.Tensor: The iteratively reconstructed `x`.
    """
    with tf.name_scope(name, default_name='iterative_masked_reconstruct'):
        iter_validator = TensorArgValidator('iter_count')
        n_iter = iter_validator.require_int32(iter_count)
        n_iter = iter_validator.require_positive(n_iter)

        def body(x_prev, counter):
            x_next = masked_reconstruct(reconstruct, x_prev, mask)
            return x_next, counter + 1

        counter_init = tf.constant(0, dtype=tf.int32)
        x_final, _ = tf.while_loop(
            cond=lambda x_prev, counter: counter < n_iter,
            body=body,
            loop_vars=[x, counter_init],
            back_prop=back_prop,
        )
    return x_final
def test_require_int32(self):
    validator = TensorArgValidator('xyz')
    err_pattern = 'xyz cannot be converted to int32'
    valid_ints = [0, 1, -1]

    # static values: valid ints pass through unchanged ...
    for value in valid_ints:
        self.assertEqual(validator.require_int32(value), value)
    # ... while non-int32-convertible objects are rejected
    for bad in [object(), None, (), [], 1.2, LONG_MAX]:
        with pytest.raises(TypeError, match=err_pattern):
            _ = validator.require_int32(bad)

    # dynamic values
    with self.test_session():
        for value in valid_ints:
            checked = validator.require_int32(
                tf.constant(value, dtype=tf.int32))
            self.assertEqual(checked.eval(), value)

        bad_tensors = [
            tf.constant(1.2, dtype=tf.float32),
            tf.constant(LONG_MAX, dtype=tf.int64),
        ]
        for bad in bad_tensors:
            with pytest.raises(TypeError, match=err_pattern):
                _ = validator.require_int32(bad)
def test_require_non_negative(self):
    validator = TensorArgValidator('xyz')
    err_pattern = 'xyz must be non-negative'

    # paired (value, dtype) tables shared by the static and dynamic checks
    ok_values = [0, 0., 1e-7, 1, 1.]
    ok_dtypes = [tf.int32, tf.float32, tf.float32, tf.int32, tf.float32]
    bad_values = [-1., -1, -1e-7]
    bad_dtypes = [tf.float32, tf.int32, tf.float32]

    # static values
    for value in ok_values:
        self.assertEqual(validator.require_non_negative(value), value)
    for value in bad_values:
        with pytest.raises(ValueError, match=err_pattern):
            _ = validator.require_non_negative(value)

    # dynamic values
    with self.test_session():
        for value, dtype in zip(ok_values, ok_dtypes):
            checked = validator.require_non_negative(
                tf.constant(value, dtype=dtype))
            self.assertAllClose(checked.eval(), value)
        for value, dtype in zip(bad_values, bad_dtypes):
            with pytest.raises(Exception, match=err_pattern):
                _ = validator.require_non_negative(
                    tf.constant(value, dtype=dtype)).eval()