Example #1
    # Excerpt from a test case; `pytest`, `register_tensor_wrapper_class` and the
    # `_NonTensorWrapperClass` dummy class are defined in the original test module.
    def test_register_non_tensor_wrapper_class(self):
        with pytest.raises(TypeError,
                           match='`.*_NonTensorWrapperClass.*` is not a type, '
                                 'or not a subclass of `TensorWrapper`'):
            register_tensor_wrapper_class(_NonTensorWrapperClass)
        with pytest.raises(TypeError,
                           match='`123` is not a type, or not a subclass of '
                                 '`TensorWrapper`'):
            register_tensor_wrapper_class(123)
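For contrast with the TypeError cases above, a minimal sketch of a registration that succeeds, patterned on _MyTensorWrapper in Example #3 below; the _GoodWrapper name is invented for illustration:

import tensorflow as tf
from tfsnippet.utils import TensorWrapper, register_tensor_wrapper_class


class _GoodWrapper(TensorWrapper):
    """A wrapper that simply exposes the tensor it wraps (illustration only)."""

    def __init__(self, wrapped):
        self._self_wrapped = wrapped

    @property
    def tensor(self):
        return self._self_wrapped


# A genuine `TensorWrapper` subclass is accepted without raising `TypeError`.
register_tensor_wrapper_class(_GoodWrapper)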
Example #2
    def __repr__(self):
        return 'ZeroLogDet({},{})'.format(self._self_shape, self.dtype.name)

    @property
    def dtype(self):
        """Get the data type of the log-det."""
        return self._self_dtype

    @property
    def log_det_shape(self):
        """Get the shape of the log-det."""
        return self._self_shape

    @property
    def tensor(self):
        """Get the zero log-det tensor (created lazily on first access)."""
        if self._self_tensor is None:
            self._self_tensor = tf.zeros(self.log_det_shape, dtype=self.dtype)
        return self._self_tensor

    def __neg__(self):
        # The log-det is identically zero, so negation is a no-op.
        return self

    def __add__(self, other):
        # 0 + other == other, broadcast to the log-det shape.
        return broadcast_to_shape(other, self.log_det_shape)

    def __sub__(self, other):
        # 0 - other == -other, broadcast to the log-det shape.
        return -broadcast_to_shape(other, self.log_det_shape)


register_tensor_wrapper_class(ZeroLogDet)
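A short sketch of how the operator overloads above behave. The usage is hypothetical: the ZeroLogDet(shape, dtype) constructor signature is inferred from __repr__ and is not shown in this excerpt.

# Hypothetical usage; the constructor signature is assumed from __repr__ above.
zero = ZeroLogDet([3], tf.float32)
other = tf.ones([3])
assert (-zero) is zero        # negating a zero log-det is a no-op
added = zero + other          # broadcast_to_shape(other, [3]), i.e. 0 + other
subtracted = zero - other     # -broadcast_to_shape(other, [3]), i.e. 0 - other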
Example #3
import tensorflow as tf
from mock import Mock

from tfsnippet.stochastic import StochasticTensor
from tfsnippet.utils import TensorWrapper, register_tensor_wrapper_class


class _MyTensorWrapper(TensorWrapper):
    def __init__(self, wrapped):
        self._self_wrapped = wrapped

    @property
    def tensor(self):
        return self._self_wrapped


register_tensor_wrapper_class(_MyTensorWrapper)


class StochasticTensorTestCase(tf.test.TestCase):
    def test_equality(self):
        distrib = Mock(is_reparameterized=False)
        samples = tf.constant(0.)
        t = StochasticTensor(distrib, samples)
        self.assertEqual(t, t)
        self.assertEqual(hash(t), hash(t))
        self.assertNotEqual(StochasticTensor(distrib, samples), t)

    def test_construction(self):
        distrib = Mock(is_reparameterized=True, is_continuous=True)
        samples = tf.constant(12345678., dtype=tf.float32)
Example #4

import six
from tfsnippet.utils import TensorWrapper, register_tensor_wrapper_class

# `TestCase` and the `regular_div` / `true_div` / `floor_div` helpers are
# provided by the surrounding test module, which is not shown in this excerpt.


class _SimpleTensor(TensorWrapper):

    def __init__(self, wrapped, flag=None):
        self._self_flag_ = flag
        super(_SimpleTensor, self).__init__(wrapped)

    @property
    def flag(self):
        return self._self_flag_

    def get_flag(self):
        return self._self_flag_

register_tensor_wrapper_class(_SimpleTensor)


class TensorWrapperArithTestCase(TestCase):

    def test_prerequisite(self):
        if six.PY2:
            self.assertAlmostEqual(regular_div(3, 2), 1)
            self.assertAlmostEqual(regular_div(3.3, 1.6), 2.0625)
        else:
            self.assertAlmostEqual(regular_div(3, 2), 1.5)
            self.assertAlmostEqual(regular_div(3.3, 1.6), 2.0625)
        self.assertAlmostEqual(true_div(3, 2), 1.5)
        self.assertAlmostEqual(true_div(3.3, 1.6), 2.0625)
        self.assertAlmostEqual(floor_div(3, 2), 1)
        self.assertAlmostEqual(floor_div(3.3, 1.6), 2.0)
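A minimal sketch of how _SimpleTensor might be used, assuming (as the TensorWrapperArithTestCase name suggests) that TensorWrapper delegates arithmetic operators to the wrapped tensor:

import tensorflow as tf

t = _SimpleTensor(tf.constant(3.0), flag='foo')
assert t.flag == 'foo'            # `_self_`-prefixed state lives on the wrapper
assert t.get_flag() == 'foo'
result = t / tf.constant(2.0)     # assumed to build the same op as 3.0 / 2.0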
Example #5
        """
        if group_ndims is None or group_ndims == self.group_ndims:
            if self._self_log_prob is None:
                self._self_log_prob = \
                    self.distribution.log_prob(self.tensor, self.group_ndims)
            return self._self_log_prob
        else:
            return self.distribution.log_prob(self.tensor, group_ndims)

    def prob(self, group_ndims=None):
        """
        Compute the probability densities of this :class:`StochasticTensor`.

        Args:
            group_ndims (int or tf.Tensor): If specified, overriding the
                configured `group_ndims`.

        Returns:
            tf.Tensor: The probability densities.
        """
        if group_ndims is None or group_ndims == self.group_ndims:
            if self._self_prob is None:
                self._self_prob = \
                    self.distribution.prob(self.tensor, self.group_ndims)
            return self._self_prob
        else:
            return self.distribution.prob(self.tensor, group_ndims)


register_tensor_wrapper_class(StochasticTensor)
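A sketch of the caching behaviour implemented above, using a Mock distribution in the same style as Example #3; this is hypothetical usage, not taken from the library's own tests:

import tensorflow as tf
from mock import Mock

distrib = Mock(is_reparameterized=False)
t = StochasticTensor(distrib, tf.constant(0.))
a = t.log_prob()            # computed with the configured group_ndims, then cached
b = t.log_prob()            # the cached result is returned
assert a is b
c = t.prob(group_ndims=1)   # a different group_ndims bypasses the cache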
Example #6
    def _init(self, name, initial_value, dtype, collections, ratio, min_value):
        ratio = tf.convert_to_tensor(ratio)
        if ratio.dtype != dtype:
            ratio = tf.cast(ratio, dtype=dtype)

        if min_value is not None:
            min_value = tf.convert_to_tensor(min_value)
            if min_value.dtype != dtype:
                min_value = tf.cast(min_value, dtype=dtype)
            initial_value = tf.maximum(initial_value, min_value)

        super(AnnealingVariable, self)._init(name, initial_value, dtype,
                                             collections)

        with tf.name_scope('anneal_op'):
            # Each call to `anneal()` multiplies the variable by `ratio`,
            # clipping the result at `min_value` when a minimum is given.
            if min_value is not None:
                self._self_anneal_op = tf.assign(
                    self._self_var,
                    tf.maximum(min_value, self._self_var * ratio))
            else:
                self._self_anneal_op = tf.assign(self._self_var,
                                                 self._self_var * ratio)

    def anneal(self):
        """Anneal the value."""
        get_default_session_or_error().run(self._self_anneal_op)


register_tensor_wrapper_class(ScheduledVariable)
Example #7
    @property
    def flow_origin(self):
        """
        Get the original stochastic tensor from the base distribution.

        Returns:
            StochasticTensor: The original stochastic tensor.
        """
        return self._self_flow_origin

    @property
    def tensor(self):
        """Get the underlying tensor."""
        return self._self_tensor


register_tensor_wrapper_class(FlowDistributionDerivedTensor)


class FlowDistribution(Distribution):
    """
    Transform a :class:`Distribution` by a :class:`BaseFlow`, as a new
    distribution.
    """
    def __init__(self, distribution, flow):
        """
        Construct a new :class:`FlowDistribution` from the given `distribution`.

        Args:
            distribution (Distribution): The distribution to transform from.
                It must be continuous.
            flow (BaseFlow): A normalizing flow to transform the `distribution`.