Beispiel #1
0
def test_random_generators(f, t, dtype_transform, just_single_arg,
                           check_lazy_shapes):
    """Check the data types and shapes of the samples of a random generator.

    `f` is exercised with every calling convention below: with only a size,
    with a data type `t`, with a random state and a data type (returning a
    `(state, sample)` pair), with a reference tensor, and with a random state
    and a reference tensor. If `just_single_arg` is true, only the
    one-dimensional size `2` is tried and the reference conventions are
    skipped.
    """
    # Sizes to try. Generators taking just a single argument only get a
    # one-dimensional size.
    if just_single_arg:
        shapes = [(2,)]
    else:
        shapes = [(), (2,), (2, 3)]

    # Without a data type, samples take the (transformed) default data type.
    for shape in shapes:
        assert B.dtype(f(*shape)) is dtype_transform(B.default_dtype)
        assert B.shape(f(*shape)) == shape

    state = B.create_random_state(t, 0)

    # Data type given directly.
    for shape in shapes:
        assert B.dtype(f(t, *shape)) is dtype_transform(t)
        assert B.shape(f(t, *shape)) == shape

    # Random state and data type given: `f` returns a `(state, sample)` pair.
    for shape in shapes:
        assert isinstance(f(state, t, *shape)[0], B.RandomState)
        assert B.dtype(f(state, t, *shape)[1]) is dtype_transform(t)
        assert B.shape(f(state, t, *shape)[1]) == shape

    if just_single_arg:
        return

    # Reference given: the sample follows the data type and shape of the
    # reference tensor.
    for shape in shapes:
        assert B.dtype(f(f(t, *shape))) is dtype_transform(t)
        # NOTE(review): for the scalar case the shape check uses a reference
        # drawn with the default data type (`f()`) rather than with `t`;
        # confirm whether this is intended.
        reference = f() if shape == () else f(t, *shape)
        assert B.shape(f(reference)) == shape

    # Random state and reference given.
    for shape in shapes:
        assert isinstance(f(state, f(t, *shape))[0], B.RandomState)
        assert B.dtype(f(state, f(t, *shape))[1]) is dtype_transform(t)
        assert B.shape(f(state, f(t, *shape))[1]) == shape
Beispiel #2
0
def diag(a: Zero):
    """Diagonal of a zero matrix: a zero vector of the diagonal's length."""
    data_type = B.dtype(a)
    length = _diag_len(a)
    return B.zeros(data_type, length)
Beispiel #3
0
def _pad_zero_col(a):
    """Append one column of zeros, of the same data type, to the right of `a`."""
    data_type = B.dtype(a)
    num_rows = B.shape(a)[0]
    padding = B.zeros(data_type, num_rows, 1)
    return B.concat(a, padding, axis=1)
Beispiel #4
0
def test_properties(dense1):
    """The `T`, `shape`, and `dtype` attributes agree with the `B` functions."""
    transposed = dense1.T
    approx(transposed, B.transpose(dense1))

    # Each attribute must equal what the corresponding generic function gives.
    for attribute, accessor in (("shape", B.shape), ("dtype", B.dtype)):
        assert getattr(dense1, attribute) == accessor(dense1)
Beispiel #5
0
def dtype(a: Diagonal):
    """Data type of a diagonal matrix: that of its stored diagonal."""
    diagonal = a.diag
    return B.dtype(diagonal)
Beispiel #6
0
def test_cast(check_lazy_shapes):
    """Casting to a framework's data type yields exactly that data type."""

    def check(target, *values):
        # Every value, whatever its origin, must come out with dtype `target`.
        for value in values:
            assert B.dtype(B.cast(target, value)) is target

    # NumPy, including an autograd-boxed input.
    check(np.float64, 1, np.array(1), autograd_box(np.float32(1)))
    # TensorFlow.
    check(tf.float64, 1, np.array(1), tf.constant(1))
    # PyTorch.
    check(torch.float64, 1, np.array(1), torch.tensor(1))
    # JAX.
    check(jnp.float64, 1, np.array(1), jnp.array(1))
Beispiel #7
0
 def dtype(self):
     """Data type of this object, as reported by `B.dtype`."""
     data_type = B.dtype(self)
     return data_type
Beispiel #8
0
 def dtype(self):
     """dtype: Data type, determined by `B.dtype` from `lam` and `prec`."""
     lam, prec = self.lam, self.prec
     return B.dtype(lam, prec)
Beispiel #9
0
def test_normal_dtype(normal1):
    """The data type of a `Normal` follows from its mean and variance."""
    # Float variance with a scalar mean: float.
    dist = Normal(0, B.eye(3))
    assert B.dtype(dist) == np.float64

    # Float mean with an integer variance: still float.
    dist = Normal(B.ones(3), B.zeros(int, 3))
    assert B.dtype(dist) == np.float64

    # Integer mean and integer variance: integer.
    dist = Normal(B.ones(int, 3), B.zeros(int, 3))
    assert B.dtype(dist) == np.int64
Beispiel #10
0
def test_normal_cast(normal1):
    """Casting a `Normal` to another data type changes its data type."""
    assert B.dtype(normal1) == np.float64
    converted = B.cast(np.float32, normal1)
    assert B.dtype(converted) == np.float32
Beispiel #11
0
def check_function(
    f,
    args_spec,
    kw_args_spec=None,
    assert_dtype=True,
    skip=None,
    contains_nans=None,
):
    """Check that a function produces consistent output. Moreover, if the first
    argument is a data type, check that the result is exactly of that type.

    `f` is called with every combination of the forms of its positional
    arguments. This is repeated for every combination of the forms of its
    keyword arguments and once without keyword arguments. For each
    keyword-argument setting, every result is compared against the result for
    the first positional combination. Calls that mix arguments of two or more
    frameworks, or that involve a type in `skip`, are skipped.

    Args:
        f (function): Function to check.
        args_spec (list): Specifications of the positional arguments. Every
            element must have a `forms()` method giving the forms to try.
        kw_args_spec (dict, optional): Specifications of the keyword
            arguments, mapping names to objects with a `forms()` method.
        assert_dtype (bool, optional): Passed on to `approx`. Defaults to
            `True`.
        skip (list, optional): Types for which calls with an argument of
            that type (or a list or tuple thereof) are skipped.
        contains_nans (bool, optional): If given, assert that every checked
            result contains a NaN exactly when this is `True`.
    """
    skip = [] if skip is None else skip
    kw_args_spec = {} if kw_args_spec is None else kw_args_spec

    # Construct all combinations of the forms of the keyword arguments.
    kw_args_prod = [
        dict(kw_args)
        for kw_args in product(
            *[[(k, v) for v in vs.forms()] for k, vs in kw_args_spec.items()]
        )
    ]

    # Add the default call without keyword arguments. When no keyword
    # arguments are specified, `product()` yields a single empty combination,
    # so the default call is already present and should not be run twice.
    if kw_args_spec:
        kw_args_prod.append({})

    # Construct all combinations of the forms of the positional arguments.
    args_prod = list(product(*[arg.forms() for arg in args_spec]))

    # Framework types: calls mixing two or more of these are skipped.
    fw_types = [
        plum.Union(t, plum.List(t), plum.Tuple(t))
        for t in [B.AGNumeric, B.TorchNumeric, B.TFNumeric, B.JAXNumeric]
    ]

    # Types to skip entirely: calls with any argument of these are skipped.
    skip_types = [plum.Union(t, plum.List(t), plum.Tuple(t)) for t in skip]

    def count_matched(args, types):
        # Number of types in `types` matched by at least one of `args`.
        return sum(any(isinstance(arg, t) for arg in args) for t in types)

    # Whether a call is skipped depends only on the positional arguments, so
    # determine this once instead of once per keyword-argument setting.
    skipped = [
        count_matched(args, fw_types) >= 2
        or count_matched(args, skip_types) >= 1
        for args in args_prod
    ]

    # Check consistency of results.
    first_args = args_prod[0]
    for kw_args in kw_args_prod:
        # Compare everything against the result for the first combination.
        first_result = f(*first_args, **kw_args)

        # If the first argument is a data type, then check that. Guard
        # against functions that take no positional arguments.
        if first_args and isinstance(first_args[0], B.DType):
            assert B.dtype(first_result) is first_args[0]

        for args, skip_call in zip(args_prod, skipped):
            if skip_call:
                log.debug(f"Skipping call with arguments {args} and keyword "
                          f"arguments {kw_args}.")
                continue

            # Check consistency.
            log.debug(
                f"Call with arguments {args} and keyword arguments {kw_args}.")
            result = f(*args, **kw_args)
            approx(first_result, result, assert_dtype=assert_dtype)

            # If first argument is a data type, then again check that.
            if args and isinstance(args[0], B.DType):
                assert B.dtype(result) is args[0]

            # Check NaNs.
            if contains_nans is not None:
                assert B.any(B.isnan(result)) == contains_nans
Beispiel #12
0
def test_cast_shape_element(dtype, check_lazy_shapes):
    """Casting an element of a shape gives exactly the requested data type."""
    first_dim = B.shape(B.ones(dtype, 1))[0]
    assert B.dtype(B.cast(dtype, first_dim)) is dtype
Beispiel #13
0
def test_cast_own_dtype(x, check_lazy_shapes):
    """Casting a value to its own data type returns the very same object."""
    own_dtype = B.dtype(x)
    assert B.cast(own_dtype, x) is x
Beispiel #14
0
def diag(a: Constant):
    """Diagonal of a constant matrix: the constant repeated along the diagonal."""
    ones = B.ones(B.dtype(a), _diag_len(a))
    return a.const * ones
Beispiel #15
0
def kron(a: AbstractMatrix, b: Zero):
    """Kronecker product with a zero matrix: a zero matrix of the product shape.

    The data type is taken from the zero factor `b`.
    """
    data_type = B.dtype(b)
    shape = _product_shape(a, b)
    return Zero(data_type, *shape)
Beispiel #16
0
def dense(a: Constant):
    """Dense version of a constant matrix, computed once and cached on `a`."""
    if a.dense is not None:
        return a.dense
    a.dense = a.const * B.ones(B.dtype(a.const), a.rows, a.cols)
    return a.dense
Beispiel #17
0
 def check(noise, dtype, n, asserted_type):
     """Convert `noise` with `_noise_as_matrix` and check the result.

     The result must be an `asserted_type` of data type `dtype` and shape
     `(n, n)`.
     """
     matrix = _noise_as_matrix(noise, dtype, n)
     assert isinstance(matrix, asserted_type)
     assert B.dtype(matrix) == dtype
     assert B.shape(matrix) == (n, n)
Beispiel #18
0
def dtype(a: Union[Dense, LowerTriangular, UpperTriangular]):
    """Data type of a matrix backed by `a.mat`: that of the wrapped array."""
    wrapped = a.mat
    return B.dtype(wrapped)
Beispiel #19
0
def _pad_zero_row(a):
    """Append one row of zeros, of the same data type, below `a`."""
    data_type = B.dtype(a)
    num_cols = B.shape(a)[1]
    padding = B.zeros(data_type, 1, num_cols)
    return B.concat(a, padding, axis=0)
Beispiel #20
0
def dtype(a: Constant):
    """Data type of a constant matrix: that of its constant."""
    constant = a.const
    return B.dtype(constant)
Beispiel #21
0
    def _get_var(
        self,
        transform,
        inverse_transform,
        init,
        generate_init,
        shape,
        shape_latent,
        dtype,
        name,
    ):
        """Return the variable called `name`, creating it if it does not exist.

        A new variable is stored in latent form and returned in transformed
        form, i.e. as `transform(latent)`.

        Args:
            transform (function): Maps a latent value to the variable's value.
            inverse_transform (function): Maps an initial value to its latent
                value.
            init (tensor or None): Initial value. If `None`, one is generated
                with `generate_init`.
            generate_init (function): Called as
                `generate_init(shape=shape, dtype=dtype)`.
            shape (tuple or None): Shape of the variable. If given, the
                initial value is broadcast to it.
            shape_latent (tuple): Shape of the latent value; determines how
                many entries are read from `self.source`.
            dtype (dtype): Requested data type, passed through
                `self._resolve_dtype`.
            name (str or None): Name of the variable.

        Returns:
            tensor: Value of the variable.

        Raises:
            ValueError: If the broadcast initial value does not have shape
                `shape`.
        """
        # An existing variable with this name takes precedence.
        try:
            return self[name]
        except KeyError:
            pass

        # From here on a new variable is created, so the cached lookups of
        # latent variables are no longer valid.
        self._get_latent_vars_cache.clear()
        dtype = self._resolve_dtype(dtype)

        if self.source is not None:
            # Consume the next `prod(shape_latent)` entries of the source.
            length = reduce(mul, shape_latent, 1)
            start = self.source_index
            latent_flat = self.source[start:start + length]
            self.source_index += length
            latent = B.cast(dtype, B.reshape(latent_flat, *shape_latent))
        else:
            # Generate or cast the initial value and move it to the active
            # device.
            if init is None:
                init = generate_init(shape=shape, dtype=dtype)
            else:
                init = B.cast(dtype, init)
            init = B.to_active_device(init)

            if shape is not None:
                # Broadcast the initial value to the requested shape and check
                # that this indeed produced that shape.
                init = init * B.ones(B.dtype(init), *shape)
                if Shape(*shape) != Shape(*B.shape(init)):
                    raise ValueError(
                        f"Shape of initial value {B.shape(init)} is not equal to the "
                        f"desired shape {shape}.")

            # Wrap the latent value in the array type of the framework
            # selected by the container's own `self.dtype`.
            latent = inverse_transform(init)
            if isinstance(self.dtype, B.TFDType):
                latent = tf.Variable(latent)
            elif isinstance(self.dtype, B.TorchDType):
                pass  # PyTorch tensors are used as they are.
            elif isinstance(self.dtype, B.JAXDType):
                latent = jnp.array(latent)
            else:
                # Anything else must be a NumPy data type.
                assert isinstance(self.dtype, B.NPDType)
                latent = np.array(latent)

        # Register the latent value together with its transforms.
        self.vars.append(latent)
        self.transforms.append(transform)
        self.inverse_transforms.append(inverse_transform)
        if name is not None:
            # The new variable is the last entry of `self.vars`.
            self.name_to_index[name] = len(self.vars) - 1

        return transform(latent)