Exemplo n.º 1
0
def test_sampling():
    """Check that ``settings.Sampling`` applies and then reverts its settings.

    With no arguments the context manager should sample once without
    flipout; keyword arguments ``n`` and ``flipout`` override those values.
    After each context exits, the global defaults (TensorFlow backend, no
    sample count, flipout off) must be back in place.
    """

    def expect_defaults():
        # Global state expected outside of any Sampling context
        assert settings.get_backend() == "tensorflow"
        assert settings.get_samples() is None
        assert settings.get_flipout() is False

    expect_defaults()

    # No arguments: a single sample, flipout disabled
    with settings.Sampling():
        assert settings.get_samples() == 1
        assert settings.get_flipout() is False

    expect_defaults()

    # Keyword arguments override the sample count and the flipout flag
    with settings.Sampling(n=100, flipout=True):
        assert settings.get_samples() == 100
        assert settings.get_flipout() is True

    # Leaving the context (__exit__) restores the defaults once more
    expect_defaults()
Exemplo n.º 2
0
    def __call__(self, x):
        """Perform the forward pass.

        When ``get_flipout()`` is true (TensorFlow backend only), the output
        is computed with the Flipout estimator. One Gaussian draw is made
        for the weight perturbation and one for the bias perturbation; each
        is shared by the whole batch and decorrelated per example with
        random-sign (Rademacher) tensors. Otherwise ``self.weights()`` and
        ``self.bias()`` are each called once and ``x @ W + b`` is returned.

        Parameters
        ----------
        x : tensor
            Layer input. Presumably of shape ``(batch, d_in)`` — the matrix
            products below require its last dimension to equal
            ``self.d_in`` (TODO confirm against callers).

        Returns
        -------
        tensor
            The layer output.

        Raises
        ------
        NotImplementedError
            If flipout is requested while the PyTorch backend is active.
        """

        # Without Flipout: plain affine transform with the parameters'
        # own outputs (one call each)
        if not get_flipout():
            return x @ self.weights() + self.bias()

        # Flipout has no PyTorch implementation
        if get_backend() == "pytorch":
            raise NotImplementedError(
                "Flipout is not implemented for the PyTorch backend"
            )

        # With Tensorflow. Imported here rather than at module level,
        # presumably so TensorFlow is only needed when this path runs.
        import tensorflow as tf
        import tensorflow_probability as tfp

        # Use the dynamic shape: x.shape[0] is None when the batch size is
        # not statically known (e.g. inside a tf.function), which would
        # break the Rademacher draws below. tf.shape(x) is always defined.
        x_shape = tf.shape(x)
        batch_size = x_shape[0]

        # Flipout-estimated weight samples
        s = tfp.math.random_rademacher(x_shape)
        r = tfp.math.random_rademacher([batch_size, self.d_out])
        norm_samples = tf.random.normal([self.d_in, self.d_out])
        w_samples = self.weights.variables["scale"] * norm_samples
        w_noise = r * ((x * s) @ w_samples)
        w_outputs = x @ self.weights.variables["loc"] + w_noise

        # Flipout-estimated bias samples
        r = tfp.math.random_rademacher([batch_size, self.d_out])
        norm_samples = tf.random.normal([self.d_out])
        b_samples = self.bias.variables["scale"] * norm_samples
        b_outputs = self.bias.variables["loc"] + r * b_samples

        return w_outputs + b_outputs
Exemplo n.º 3
0
    def __call__(self, x):
        """Perform the forward pass.

        ``x`` is first converted with ``to_tensor``. The Flipout estimator
        is used only when flipout is enabled globally (``get_flipout()``),
        this layer allows it (``self.flipout``), the layer is probabilistic
        (``self.probabilistic``) and exactly one sample is being drawn
        (``get_samples() == 1``). In that case one Gaussian draw is made for
        the weight perturbation and one for the bias perturbation; each is
        shared by the batch and decorrelated per example with random-sign
        (Rademacher) tensors. Otherwise ``self.weights()`` and
        ``self.bias()`` are each called once and ``x @ W + b`` is returned.

        Parameters
        ----------
        x : array-like
            Layer input. Presumably of shape ``(batch, d_in)`` — the matrix
            products below require its last dimension to equal
            ``self.d_in`` (TODO confirm against callers).

        Returns
        -------
        tensor
            The layer output.
        """

        x = to_tensor(x)

        # ``get_samples() == 1`` already rules out ``None``, so no separate
        # None check (or second get_samples() call) is needed.
        use_flipout = (
            get_flipout()
            and self.flipout
            and self.probabilistic
            and get_samples() == 1
        )

        # Without Flipout
        if not use_flipout:
            return x @ self.weights() + self.bias()

        # The input shape is needed by every Rademacher draw; compute it once
        x_shape = O.shape(x)
        batch_size = x_shape[0]

        # Flipout-estimated weight samples
        s = O.rand_rademacher(x_shape)
        r = O.rand_rademacher([batch_size, self.d_out])
        norm_samples = O.randn([self.d_in, self.d_out])
        w_samples = self.weights.variables["scale"] * norm_samples
        w_noise = r * ((x * s) @ w_samples)
        w_outputs = x @ self.weights.variables["loc"] + w_noise

        # Flipout-estimated bias samples
        r = O.rand_rademacher([batch_size, self.d_out])
        norm_samples = O.randn([self.d_out])
        b_samples = self.bias.variables["scale"] * norm_samples
        b_outputs = self.bias.variables["loc"] + r * b_samples

        return w_outputs + b_outputs
Exemplo n.º 4
0
def test_flipout():
    """Tests setting and getting the flipout setting.

    The default must be ``False``. ``settings.set_flipout`` must accept
    ``True`` and ``False`` and raise ``TypeError`` for non-bool values such
    as a float, the int ``1`` and a string. The flag is reset to ``False``
    on the way out, even if an assertion fails, so a failure here cannot
    leak ``flipout=True`` into other tests.
    """

    # Default should be False
    assert settings.get_flipout() is False

    try:
        # Should be able to change to True or False
        settings.set_flipout(True)
        assert settings.get_flipout() is True
        settings.set_flipout(False)
        assert settings.get_flipout() is False

        # But only bool (1 is rejected even though 1 == True)
        for bad_value in (3.14, 1, "lalala"):
            with pytest.raises(TypeError):
                settings.set_flipout(bad_value)
    finally:
        # Restore the global default for any tests that run afterwards
        settings.set_flipout(False)
Exemplo n.º 5
0
def test_sampling():
    """Tests the Sampling context manager.

    ``settings.Sampling`` should change nothing when called without
    arguments. While active it should apply ``n``, ``flipout`` and
    ``static`` (the last of these sets a ``uuid.UUID`` static-sampling id).
    On exit it should restore the previous settings. When nested, an inner
    context that does not pass ``static`` should inherit the outer static
    UUID, and leaving the inner context should restore the outer context's
    settings.
    """

    def check(samples=None, flipout=False, static=False):
        # Assert the current global sampling settings and return the
        # static-sampling UUID so callers can compare UUIDs across contexts.
        assert settings.get_backend() == "tensorflow"
        if samples is None:
            assert settings.get_samples() is None
        else:
            assert settings.get_samples() == samples
        assert settings.get_flipout() is flipout
        static_uuid = settings.get_static_sampling_uuid()
        if static:
            assert isinstance(static_uuid, uuid.UUID)
        else:
            assert static_uuid is None
        return static_uuid

    # Defaults before sampling
    check()

    # Default should be not to change anything
    with settings.Sampling():
        check()

    # Should be able to set samples and flipout via kwargs
    with settings.Sampling(n=100, flipout=True):
        check(samples=100, flipout=True)

    # Should return to defaults after sampling
    check()

    # Should be able to set static sampling uuid
    with settings.Sampling(static=True):
        check(static=True)

    # Should return to defaults after sampling
    check()

    # Should be able to nest sampling context managers
    with settings.Sampling(static=True):
        outer_uuid = check(static=True)
        with settings.Sampling(n=100, flipout=True):
            # The inner context does not set ``static``, so it keeps the
            # outer context's UUID
            inner_uuid = check(samples=100, flipout=True, static=True)
            assert inner_uuid == outer_uuid
        # Leaving the inner context restores the outer context's settings
        assert check(static=True) == outer_uuid

    # Should return to defaults after sampling
    check()