def to_default_dtype(x):
    """Cast a tensor to the currently configured default datatype.

    PyTorch tensors are converted with ``Tensor.type``; anything else is
    routed through ``tf.cast``.
    """
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.cast(x, get_datatype())
    return x.type(get_datatype())
def eye(dims):
    """Identity matrix.

    Returns a ``dims`` x ``dims`` identity matrix in the default datatype,
    built with the active backend.
    """
    use_torch = get_backend() == "pytorch"
    if not use_torch:
        import tensorflow as tf

        return tf.eye(dims, dtype=get_datatype())
    import torch

    return torch.eye(dims, dtype=get_datatype())
def randn(shape):
    """Tensor full of random values drawn from a standard normal.

    Values are sampled with the active backend's normal sampler and use the
    default datatype.
    """
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.random.normal(shape, dtype=get_datatype())
    import torch

    return torch.randn(shape, dtype=get_datatype())
def full(shape, value):
    """Tensor full of some value.

    Every element of the returned ``shape``-shaped tensor equals ``value``;
    the result uses the default datatype.
    """
    if get_backend() != "pytorch":
        import tensorflow as tf

        # tf.fill infers its dtype from ``value``, so cast afterwards.
        filled = tf.fill(shape, value)
        return tf.cast(filled, dtype=get_datatype())
    import torch

    return torch.full(shape, value, dtype=get_datatype())
def zeros(shape):
    """Tensor full of zeros.

    The returned tensor has the requested ``shape`` and the default datatype.
    """
    use_torch = get_backend() == "pytorch"
    if not use_torch:
        import tensorflow as tf

        return tf.zeros(shape, dtype=get_datatype())
    import torch

    return torch.zeros(shape, dtype=get_datatype())
def test_datatype():
    """Tests get and set_datatype.

    Checks that the default is ``tf.float32`` and that switching to
    ``tf.float64`` is reflected by ``get_datatype``. Also checks that a
    non-dtype argument raises ``TypeError``. The default datatype is restored
    even if an assertion fails, so later tests are not affected.
    """
    assert isinstance(settings.get_datatype(), tf.DType)
    assert settings.get_datatype() == tf.float32
    try:
        settings.set_datatype(tf.float64)
        assert isinstance(settings.get_datatype(), tf.DType)
        assert settings.get_datatype() == tf.float64
    finally:
        # Restore the global default; otherwise a failure here would leak
        # float64 into every subsequent test.
        settings.set_datatype(tf.float32)
    with pytest.raises(TypeError):
        settings.set_datatype("lala")
def insert_col_of(vals, val):
    """Add a column of a value to the left side of a tensor.

    Parameters
    ----------
    vals : tensor
        Tensor whose last axis receives the new leading column.
    val : scalar
        Value written into every entry of the new column.

    Returns
    -------
    tensor
        ``vals`` with one extra entry at index 0 of the last axis, filled
        with ``val`` (the column itself is built in the default datatype).
    """
    if get_backend() != "pytorch":
        import tensorflow as tf

        col_shape = tf.concat([vals.shape[:-1], [1]], axis=-1)
        new_col = val * tf.ones(col_shape, dtype=get_datatype())
        return tf.concat([new_col, vals], axis=-1)
    import torch

    col_shape = list(vals.shape[:-1]) + [1]
    new_col = val * torch.ones(col_shape, dtype=get_datatype())
    return torch.cat([new_col, vals], dim=-1)
def xavier(shape):
    """Xavier initializer.

    Draws values with standard deviation ``sqrt(2 / sum(shape))`` in the
    default datatype. TensorFlow samples a truncated normal; PyTorch
    currently samples an untruncated normal scaled by the same stddev.
    """
    stddev = np.sqrt(2 / sum(shape))
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.random.truncated_normal(
            shape, mean=0.0, stddev=stddev, dtype=get_datatype()
        )
    # TODO: use truncated normal for torch
    import torch

    return torch.randn(shape, dtype=get_datatype()) * stddev
def log_cholesky_transform(x):
    r"""Perform the log cholesky transform on a vector of values.

    This turns a vector of :math:`\frac{N(N+1)}{2}` unconstrained values into
    a valid :math:`N \times N` covariance matrix. The values fill a lower
    triangular matrix :math:`E`. Its diagonal is exponentiated so it is
    strictly positive, and :math:`E E^T` is returned.

    Parameters
    ----------
    x : tensor
        Unconstrained values. The PyTorch path builds a single 2-D
        ``(N, N)`` matrix, so it presumably expects a 1-D ``x``
        (NOTE(review): batched input is not handled there — confirm).

    Returns
    -------
    tensor
        The matrix ``E @ E^T``.

    References
    ----------

    - Jose C. Pinheiro & Douglas M. Bates.
      `Unconstrained Parameterizations for Variance-Covariance Matrices
      <https://dx.doi.org/10.1007/BF00140873>`_
      *Statistics and Computing*, 1996.
    """
    if get_backend() == "pytorch":
        import numpy as np
        import torch

        # Invert numel(x) = N(N+1)/2 to recover N. int() truncates, so a
        # non-triangular size presumably fails at the assignment below.
        N = int((np.sqrt(1 + 8 * torch.numel(x)) - 1) / 2)
        E = torch.zeros((N, N), dtype=get_datatype())
        # Row/column indices of the lower triangle, diagonal included.
        tril_ix = torch.tril_indices(row=N, col=N, offset=0)
        E[..., tril_ix[0], tril_ix[1]] = x
        # Replace the diagonal with its exponential so E has a strictly
        # positive diagonal. Order matters: this reads the values just
        # written into E.
        E[..., range(N), range(N)] = torch.exp(torch.diagonal(E))
        return E @ torch.transpose(E, -1, -2)
    else:
        import tensorflow as tf
        import tensorflow_probability as tfp

        # NOTE(review): fill_triangular's element ordering may differ from
        # torch.tril_indices above, so the same x could map to different
        # matrices across backends — verify if cross-backend parity matters.
        E = tfp.math.fill_triangular(x)
        E = tf.linalg.set_diag(E, tf.exp(tf.linalg.tensor_diag_part(E)))
        # tf.transpose without perm reverses all axes, i.e. assumes 2-D E.
        return E @ tf.transpose(E)
def additive_logistic_transform(vals):
    """The additive logistic transformation.

    Maps ``K`` unconstrained values on the last axis to ``K + 1`` positive
    values that sum to 1::

        out[..., i] = exp(vals[..., i]) / (1 + sum_j exp(vals[..., j]))   (i < K)
        out[..., K] = 1 / (1 + sum_j exp(vals[..., j]))

    This equals a softmax over ``vals`` with a 0 appended on the last axis
    (``exp(0) == 1``). Softmax subtracts the running max internally, so
    large inputs no longer overflow ``exp`` and produce ``inf / inf = nan``.
    """
    if get_backend() == "pytorch":
        import torch

        zeros_shape = list(vals.shape[:-1]) + [1]
        padded = torch.cat(
            [vals, torch.zeros(zeros_shape, dtype=get_datatype())],
            dim=-1,
        )
        return torch.softmax(padded, dim=-1)
    else:
        import tensorflow as tf

        zeros_shape = tf.concat([vals.shape[:-1], [1]], axis=-1)
        padded = tf.concat(
            [vals, tf.zeros(zeros_shape, dtype=get_datatype())],
            axis=-1,
        )
        return tf.nn.softmax(padded, axis=-1)
def rand_rademacher(shape):
    """Tensor full of random -1s or 1s (i.e. drawn from a Rademacher dist).

    Each element is -1 or +1 with equal probability. The result uses the
    default datatype on both backends.
    """
    if get_backend() == "pytorch":
        import torch

        # randint's upper bound is exclusive: (0, 2) draws from {0, 1}.
        # The previous (0, 1) always drew 0, so every element was -1.
        return 2 * torch.randint(0, 2, shape, dtype=get_datatype()) - 1
    else:
        import tensorflow_probability as tfp

        try:
            rademacher = tfp.random.rademacher
        except AttributeError:  # pragma: no cover
            # for older versions of tfp, fall back on older version
            rademacher = tfp.python.math.random_rademacher
        # Pass the dtype explicitly; tfp otherwise returns its own default
        # dtype rather than the configured one.
        return rademacher(shape, dtype=get_datatype())