Example #1
 def f():
     return basic.Linear(output_size=2,
                         name=linear_name)(jnp.zeros([6]))
Example #2
 def log_abs_det_jacobian(self, x, y, intermediates=None):
     return jnp.zeros(jnp.shape(x)[:-1])
Example #3
def sincos_nonnegative_softmax_kernel_feature_creator0(
    data,
    projection_matrix,
    batch_dims_t,
    precision,
    is_query,
    normalize_data=False,
    eps=0.0001,
):
    """Constructs nonnegative kernel features for fast softmax attention.

    Args:
    data: input for which features are computed
    projection_matrix: random matrix used to compute features
    batch_dims_t: tuple of batch dimensions
    precision: precision parameter
    is_query: predicate indicating whether input data corresponds to queries or
      keys
    normalize_data: predicate indicating whether data should be normalized.
    eps: numerical stabilizer.

    Returns:
    Random features for fast softmax attention.
    """
    if normalize_data:
        # We have e^{qk^T/sqrt{d}} = e^{q_norm k_norm^T}, where
        # w_norm = w * data_normalizer for w in {q,k}.
        data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
    else:
        data_normalizer = 1.0
    #ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])
    ratio = 1.0
    data_mod_shape = data.shape[0:len(batch_dims_t)] + projection_matrix.shape
    data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix

    #"""
    data_dash = lax.dot_general(data_normalizer * data,
                                data_thick_random_matrix,
                                (((data.ndim - 1, ),
                                  (data_thick_random_matrix.ndim - 1, )),
                                 (batch_dims_t, batch_dims_t)),
                                precision=precision)
    #"""
    #data_dash = jnp.einsum("...bd,...fd->...bf", data_normalizer * data, projection_matrix)
    data_dash_cos = ratio * jnp.cos(data_dash)
    data_dash_sin = ratio * jnp.sin(data_dash)
    data_dash = jnp.concatenate((data_dash_cos, data_dash_sin), axis=-1)

    # Constructing D_data and data^{'}
    diag_data = jnp.square(data)
    diag_data = jnp.sum(diag_data, axis=data.ndim - 1)
    diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
    diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1)
    # Additional renormalization for numerical stability
    # Either a per-row max (commented out below) or the global max can be used here;
    # the global max is what this implementation uses.
    # data_renormalizer = jnp.max(diag_data, -1, keepdims=True)
    data_renormalizer = jnp.max(diag_data)
    diag_data -= data_renormalizer
    diag_data = jnp.exp(diag_data)
    data_prime = data_dash * diag_data

    return data_prime + eps
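A minimal usage sketch of the feature creator above (the shapes, seed, and the batch_dims_t choice are illustrative assumptions, not from the original source; it assumes the function and its jnp/lax imports are in scope):

import jax
import jax.numpy as jnp

# Toy inputs: (batch, length, dim) queries and a (num_features, dim) projection.
key_q, key_p = jax.random.split(jax.random.PRNGKey(0))
queries = jax.random.normal(key_q, (2, 5, 8))
projection = jax.random.normal(key_p, (4, 8))

features = sincos_nonnegative_softmax_kernel_feature_creator0(
    queries, projection, batch_dims_t=(0,), precision=None,
    is_query=True, normalize_data=True)
print(features.shape)  # (2, 5, 8): cos and sin features concatenated on the last axis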
Example #4
import jax
import jax.numpy as np
import tigercontrol as tc
import numpy.random as random
import matplotlib.pyplot as plt
from scipy.linalg import solve_discrete_are as dare
from tigercontrol.controllers import Controller

T, H, M, lr = 200, 10, 10, 0.001
n, m, A, B = 2, 1, np.array([[1., 1.], [0., 1.]]), np.array([[0.], [1.]])
Q, R = np.eye(N=n), np.eye(N=m)
x0 = np.zeros((n, 1))

Wproc = lambda n, x, u, w, t: random.normal(size=(n, 1))
Wproc = lambda n, x, u, w, t: np.sin(t / (2 * np.pi)) * np.ones((2, 1))

env = tc.environment('LDS')


class LQR(Controller):
    def __init__(self, A, B, Q, R):
        P = dare(A, B, Q, R)
        self.K = np.linalg.inv(R + B.T @ P @ B) @ (B.T @ P @ A)

    def plan(self, x):
        return -self.K @ x


x = env.initialize(n,
                   m,
                   noise_distribution=Wproc,
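A small standalone check of the LQR gain defined above, using plain NumPy/SciPy only (this exercises just the complete LQR part of the snippet; for these A, B, Q, R the closed-loop spectral radius should be below 1):

import numpy as onp
from scipy.linalg import solve_discrete_are as dare

A = onp.array([[1., 1.], [0., 1.]])
B = onp.array([[0.], [1.]])
Q, R = onp.eye(2), onp.eye(1)
P = dare(A, B, Q, R)
K = onp.linalg.inv(R + B.T @ P @ B) @ (B.T @ P @ A)
# The closed-loop matrix A - B K should have spectral radius below 1.
print(onp.max(onp.abs(onp.linalg.eigvals(A - B @ K))))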
Example #5
    def initialize(self, n, m, h=64):
        """
        Description: Randomly initialize the RNN.
        Args:
            n (int): Input dimension.
            m (int): Observation/output dimension.
            h (int): Default value 64. Hidden dimension of RNN.
        Returns:
            The first value in the time-series
        """
        self.T = 0
        self.initialized = True
        self.n, self.m, self.h = n, m, h

        glorot_init = stax.glorot(
        )  # returns a function that initializes weights
        self.W_h = glorot_init(generate_key(), (h, h))
        self.W_u = glorot_init(generate_key(), (h, n))
        self.W_out = glorot_init(generate_key(), (m, h))
        self.b_h = np.zeros(h)
        self.hid = np.zeros(h)

        self.rollout_controller = None
        self.target = jax.random.uniform(generate_key(),
                                         shape=(self.m, ),
                                         minval=-1,
                                         maxval=1)
        '''
        def _step(x, hid):
            next_hid = np.tanh(np.dot(self.W_h, hid) + np.dot(self.W_x, x) + self.b_h)
            y = np.dot(self.W_out, next_hid)
            return (next_hid, y)'''
        def _dynamics(hid, u):
            next_hid = np.tanh(
                np.dot(self.W_h, hid) + np.dot(self.W_u, u) + self.b_h)
            y = np.dot(self.W_out, next_hid)
            return (next_hid, y)

        # self._step = jax.jit(_step)
        self._dynamics = jax.jit(_dynamics)

        self._loss = lambda x, u: (self.target - self._dynamics(x, u))**2

        # stack the jacobians of environment dynamics gradient
        jacobian = jax.jacrev(self._dynamics, argnums=(0, 1))
        self._dynamics_jacobian = jax.jit(
            lambda x, u: np.hstack(jacobian(x, u)))

        # stack the gradients of environment loss
        loss_grad = jax.grad(self._loss, argnums=(0, 1))
        self._loss_grad = jax.jit(lambda x, u: np.hstack(loss_grad(x, u)))

        # block the hessian of environment loss
        block_hessian = lambda A: np.vstack(
            [np.hstack([A[0][0], A[0][1]]),
             np.hstack([A[1][0], A[1][1]])])
        hessian = jax.hessian(self._loss, argnums=(0, 1))
        self._loss_hessian = jax.jit(lambda x, u: block_hessian(hessian(x, u)))

        def _rollout(act, dyn, x_0, T):
            def f(x, i):
                u = act(x)
                x_next = dyn(x, u)
                return x_next, np.hstack((x, u))

            _, trajectory = jax.lax.scan(f, x_0, np.arange(T))
            return trajectory

        self._rollout = jax.jit(_rollout, static_argnums=(0, 1, 3))
        return np.dot(self.W_out, self.hid)
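A self-contained sketch of the jacobian-stacking pattern used in initialize above, with a toy dynamics function (weights and sizes here are illustrative):

import jax
import jax.numpy as jnp

W_h = 0.5 * jnp.eye(3)
W_u = 0.1 * jnp.ones((3, 2))

def dynamics(hid, u):
    return jnp.tanh(W_h @ hid + W_u @ u)

# Jacobians w.r.t. both arguments, stacked side by side as in the code above.
jacobian = jax.jacrev(dynamics, argnums=(0, 1))
dynamics_jacobian = jax.jit(lambda x, u: jnp.hstack(jacobian(x, u)))
print(dynamics_jacobian(jnp.zeros(3), jnp.zeros(2)).shape)  # (3, 5)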
Example #6
    def __init__(self, x, y, k=3, endpoints="not-a-knot", coefficients=None):
        """JAX implementation of kth-order spline interpolation.

        This class aims to reproduce scipy's InterpolatedUnivariateSpline
        functionality using JAX. Not all of the original class's features
        have been implemented yet, notably
        - `w`    : no weights are used in the spline fitting.
        - `bbox` : we assume the boundary to always be [x[0], x[-1]].
        - `ext`  : extrapolation is always active, i.e., `ext` = 0.
        - `k`    : orders `k` > 3 are not available.
        - `check_finite` : no such check is performed.

        (The relevant lines from the original docstring have been included
        in the following.)

        Fits a spline y = spl(x) of degree `k` to the provided `x`, `y` data.
        Spline function passes through all provided points. Equivalent to
        `UnivariateSpline` with s = 0.

        Parameters
        ----------
        x : (N,) array_like
            Input dimension of data points -- must be strictly increasing
        y : (N,) array_like
            input dimension of data points
        k : int, optional
            Degree of the smoothing spline.  Must be 1 <= `k` <= 3.
        endpoints : str, optional, one of {'natural', 'not-a-knot'}
            Endpoint condition for cubic splines, i.e., `k` = 3.
            'natural' endpoints enforce a vanishing second derivative
            of the spline at the two endpoints, while 'not-a-knot'
            ensures that the third derivatives are equal for the two
            left-most `x` of the domain, as well as for the two
            right-most `x`. The original scipy implementation uses
            'not-a-knot'.
        coefficients: list, optional
            Precomputed parameters for spline interpolation. Shouldn't be set
            manually.

        See Also
        --------
        UnivariateSpline : Superclass -- allows knots to be selected by a
            smoothing condition
        LSQUnivariateSpline : spline for which knots are user-selected
        splrep : An older, non object-oriented wrapping of FITPACK
        splev, sproot, splint, spalde
        BivariateSpline : A similar class for two-dimensional spline interpolation

        Notes
        -----
        The number of data points must be larger than the spline degree `k`.

        The general form of the spline can be written as
          f[i](x) = a[i] + b[i](x - x[i]) + c[i](x - x[i])^2 + d[i](x - x[i])^3,
          i = 0, ..., n-1,
        where d = 0 for `k` = 2, and c = d = 0 for `k` = 1.

        The unknown coefficients (a, b, c, d) define a banded
        linear system of equations, Az = s, where z = b for `k` = 1 and `k` = 2,
        and z = c for `k` = 3. In each case, the coefficients defining each
        spline piece can be expressed in terms of only z[i], z[i+1],
        y[i], and y[i+1]. The coefficients are solved for using
        `np.linalg.solve` when `k` = 2 and `k` = 3.

        """
        # Verify inputs
        k = int(k)
        assert k in (1, 2, 3), "Order k must be in {1, 2, 3}."
        x = np.atleast_1d(x)
        y = np.atleast_1d(y)
        assert len(x) == len(y), "Input arrays must be the same length."
        assert x.ndim == 1 and y.ndim == 1, "Input arrays must be 1D."
        n_data = len(x)

        # Difference vectors
        h = np.diff(x)  # x[i+1] - x[i] for i=0,...,n-1
        p = np.diff(y)  # y[i+1] - y[i]

        if coefficients is None:
            # Build the linear system of equations depending on k
            # (No matrix necessary for k=1)
            if k == 1:
                assert n_data > 1, "Not enough input points for linear spline."
                coefficients = p / h

            if k == 2:
                assert n_data > 2, "Not enough input points for quadratic spline."
                # Only the 'not-a-knot' condition has been validated for k = 2;
                # it is arguably the best choice of boundary condition here anyway.
                assert endpoints == "not-a-knot"

                # The knots are actually in between data points
                knots = (x[1:] + x[:-1]) / 2.0
                # We add 2 artificial knots before and after
                knots = np.concatenate([
                    np.array([x[0] - (x[1] - x[0]) / 2.0]),
                    knots,
                    np.array([x[-1] + (x[-1] - x[-2]) / 2.0]),
                ])
                n = len(knots)
                # Compute interval lengths for these new knots
                h = np.diff(knots)
                # position of the data point inside its interval
                dt = x - knots[:-1]

                # Now we build the system matrix
                A = np.diag(
                    np.concatenate([
                        np.ones(1),
                        (2 * dt[1:] - dt[1:]**2 / h[1:] - dt[:-1]**2 / h[:-1] +
                         h[:-1]),
                        np.ones(1),
                    ]))

                A += np.diag(
                    np.concatenate(
                        [-np.array([1 + h[0] / h[1]]), dt[1:]**2 / h[1:]]),
                    k=1,
                )
                A += np.diag(np.concatenate(
                    [np.atleast_1d(h[0] / h[1]),
                     np.zeros(n - 3)]),
                             k=2)

                A += np.diag(
                    np.concatenate([
                        h[:-1] - 2 * dt[:-1] + dt[:-1]**2 / h[:-1],
                        -np.array([1 + h[-1] / h[-2]]),
                    ]),
                    k=-1,
                )
                A += np.diag(
                    np.concatenate(
                        [np.zeros(n - 3),
                         np.atleast_1d(h[-1] / h[-2])]),
                    k=-2,
                )

                # And now we build the RHS vector
                s = np.concatenate([np.zeros(1), 2 * p, np.zeros(1)])

                # Compute spline coefficients by solving the system
                coefficients = np.linalg.solve(A, s)

            if k == 3:
                assert n_data > 3, "Not enough input points for cubic spline."
                if endpoints not in ("natural", "not-a-knot"):
                    print("Warning : endpoints not recognized. Using natural.")
                    endpoints = "natural"

                # Special values for the first and last equations
                zero = np.array([0.0])
                one = np.array([1.0])
                A00 = one if endpoints == "natural" else np.array([h[1]])
                A01 = zero if endpoints == "natural" else np.array(
                    [-(h[0] + h[1])])
                A02 = zero if endpoints == "natural" else np.array([h[0]])
                ANN = one if endpoints == "natural" else np.array([h[-2]])
                AN1 = (-one if endpoints == "natural" else np.array(
                    [-(h[-2] + h[-1])]))  # A[N, N-1]
                AN2 = zero if endpoints == "natural" else np.array(
                    [h[-1]])  # A[N, N-2]

                # Construct the tri-diagonal matrix A
                A = np.diag(np.concatenate((A00, 2 * (h[:-1] + h[1:]), ANN)))
                upper_diag1 = np.diag(np.concatenate((A01, h[1:])), k=1)
                upper_diag2 = np.diag(np.concatenate(
                    (A02, np.zeros(n_data - 3))), k=2)
                lower_diag1 = np.diag(np.concatenate((h[:-1], AN1)), k=-1)
                lower_diag2 = np.diag(np.concatenate(
                    (np.zeros(n_data - 3), AN2)), k=-2)
                A += upper_diag1 + upper_diag2 + lower_diag1 + lower_diag2

                # Construct RHS vector s
                center = 3 * (p[1:] / h[1:] - p[:-1] / h[:-1])
                s = np.concatenate((zero, center, zero))
                # Compute spline coefficients by solving the system
                coefficients = np.linalg.solve(A, s)

        # Saving spline parameters for evaluation later
        self.k = k
        self._x = x
        self._y = y
        self._coefficients = coefficients
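For comparison, a standalone sketch of the textbook natural cubic spline system (c[0] = c[n-1] = 0), assembled and solved the same way; note the endpoint rows used above differ slightly from this simplified variant, and the data here are illustrative:

import jax.numpy as np

x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
y = np.array([0.0, 1.0, 0.0, 1.0, 0.0])
h = np.diff(x)
p = np.diff(y)

# Tridiagonal system for the second-derivative coefficients c.
A = np.diag(np.concatenate((np.ones(1), 2 * (h[:-1] + h[1:]), np.ones(1))))
A += np.diag(np.concatenate((np.zeros(1), h[1:])), k=1)
A += np.diag(np.concatenate((h[:-1], np.zeros(1))), k=-1)
s = np.concatenate((np.zeros(1), 3 * (p[1:] / h[1:] - p[:-1] / h[:-1]), np.zeros(1)))
c = np.linalg.solve(A, s)
print(c)  # c[0] and c[-1] are zero by construction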
Example #7
def test_get_Q():
    Q = get_polynomial_form()
    import pylab as plt
    from jax import jit

    @jit
    def tomo_weight_ref(gamma, x1, x2, p1, p2):
        return tomographic_weight_function(gamma, x1, x2, p1, p2, S=150)

    @jit
    def cumulative_tomo_weight_function_dimensionless(gamma, x1, x2, p1, p2):
        x12 = x1 - x2
        h = jnp.linalg.norm(x12)
        n = x12 / h
        w1 = p1 / h
        w2 = p2 / h
        gamma_prime = gamma / h**2
        return cumulative_tomographic_weight_dimensionless_function(
            gamma_prime, n, w1, w2, S=150)

    @jit
    def cumulative_tomo_weight_polynomial_dimensionless(gamma, x1, x2, p1, p2):
        x12 = x1 - x2
        h = jnp.linalg.norm(x12)
        n = x12 / h
        w1 = p1 / h
        w2 = p2 / h
        gamma_prime = gamma / h**2
        return vmap(lambda gamma_prime:
                    cumulative_tomographic_weight_dimensionless_polynomial(
                        Q, gamma_prime, n, w1, w2))(gamma_prime)
        # return jnp.exp(log_tomographic_weight_dimensionless_function(gamma_prime, n, w1, w2, S=150)) / h ** 2

    for i in range(10):
        keys = random.split(random.PRNGKey(i), 6)
        x1 = jnp.concatenate(
            [10. * random.uniform(keys[0], shape=(2, )),
             jnp.zeros((1, ))],
            axis=-1)
        p1 = jnp.concatenate([
            4. * jnp.pi / 180. *
            random.uniform(keys[1], shape=(2, ), minval=-1, maxval=1),
            jnp.ones((1, ))
        ],
                             axis=-1)
        p1 = 4 * p1 / jnp.linalg.norm(p1, axis=-1, keepdims=True)

        x2 = jnp.concatenate(
            [4. * random.uniform(keys[2], shape=(2, )),
             jnp.zeros((1, ))],
            axis=-1)
        p2 = jnp.concatenate([
            4. * jnp.pi / 180. *
            random.uniform(keys[3], shape=(2, ), minval=-1, maxval=1),
            jnp.ones((1, ))
        ],
                             axis=-1)
        p2 = 4 * p2 / jnp.linalg.norm(p2, axis=-1, keepdims=True)

        t1 = random.uniform(keys[4], shape=(10000, ))
        t2 = random.uniform(keys[5], shape=(10000, ))
        u1 = x1 + t1[:, None] * p1
        u2 = x2 + t2[:, None] * p2
        gamma = jnp.linalg.norm(u1 - u2, axis=1)**2
        plt.hist(gamma.flatten(), bins=100, density=True, label='histogram')
        hist, bins = jnp.histogram(gamma.flatten(), density=True, bins=100)
        gamma = 0.5 * (bins[:-1] + bins[1:])
        w_ref = tomo_weight_ref(bins, x1, x2, p1, p2)
        plt.plot(gamma, w_ref, label='analytic ref')
        plt.legend()
        plt.show()
        cdf_ref = cumulative_tomo_weight_function_dimensionless(
            gamma, x1, x2, p1, p2)
        cdf_poly = cumulative_tomo_weight_polynomial_dimensionless(
            gamma, x1, x2, p1, p2)
        gamma_prime = gamma / jnp.linalg.norm(x1 - x2)
        plt.plot(gamma_prime, cdf_ref, label='ref')
        plt.plot(gamma_prime, cdf_poly, label='poly')
        plt.legend()
        plt.show()
Example #8
 def significance_map(self):
     return np.zeros(1, dtype=np.int32)
Example #9
 def significance_map(self):
     return np.zeros(self.representation_length, dtype=np.int32)
Example #10
    def dot_product_attention(self,
                              query,
                              key,
                              value,
                              dtype=jnp.float32,
                              bias=None,
                              axis=None,
                              broadcast_dropout=True,
                              dropout_rng=None,
                              dropout_rate=0.,
                              deterministic=False,
                              precision=None):

        assert key.shape[:-1] == value.shape[:-1]
        assert (query.shape[0:1] == key.shape[0:1]
                and query.shape[-1] == key.shape[-1])
        if axis is None:
            axis = tuple(range(1, key.ndim - 2))
        if not isinstance(axis, Iterable):
            axis = (axis, )
        assert key.ndim == query.ndim
        assert key.ndim == value.ndim
        for ax in axis:
            if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
                raise ValueError('Attention axis must be between the batch '
                                 'axis and the last-two axes.')
        n = key.ndim

        # Constructing projection tensor.
        if self.redraw_features:
            # TODO(kchoro): Get rid of the constant below.
            query_seed = lax.convert_element_type(
                jnp.ceil(jnp.sum(query) * 10000000.0), jnp.int32)
            rng = random.PRNGKey(query_seed)
            self.projection_matrix = self.draw_weights(rng)

        # batch_dims is  <bs, <non-attention dims>, num_heads>
        batch_dims = tuple(onp.delete(range(n), axis + (n - 1, )))
        # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
        qk_perm = batch_dims + axis + (n - 1, )
        k_extra_perm = axis + batch_dims + (n - 1, )
        key_extra = key.transpose(k_extra_perm)
        key = key.transpose(qk_perm)
        query = query.transpose(qk_perm)
        # v -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
        v_perm = batch_dims + axis + (n - 1, )
        value = value.transpose(v_perm)
        batch_dims_t = tuple(range(len(batch_dims)))
        attention_dims_t = tuple(
            range(len(batch_dims),
                  len(batch_dims) + len(axis)))

        # Constructing tensors Q^{'} and K^{'}.
        query_prime = self.kernel_feature_creator(query,
                                                  self.projection_matrix,
                                                  attention_dims_t,
                                                  batch_dims_t, precision,
                                                  True)
        key_prime = self.kernel_feature_creator(key, self.projection_matrix,
                                                attention_dims_t, batch_dims_t,
                                                precision, False)

        if self.unidirectional:
            index = attention_dims_t[0]
            z_slice_shape = key_prime.shape[0:len(batch_dims_t)] + (
                key_prime.shape[-1], ) + (value.shape[-1], )

            numerator_fn = _numerator(z_slice_shape, precision,
                                      self.lax_scan_unroll)
            W = numerator_fn(jnp.moveaxis(query_prime, index, 0),
                             jnp.moveaxis(key_prime, index, 0),
                             jnp.moveaxis(value, index, 0))

            # Constructing W = (Q^{'}(K^{'})^{T})_{masked}V
            W = jnp.moveaxis(W, 0, index)

            if not self.renormalize_attention:
                # Unidirectional, not-normalized attention.
                perm_inv = _invert_perm(qk_perm)
                result = W.transpose(perm_inv)
                return result
            else:
                # Unidirectional, normalized attention.
                thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones(
                    key_extra.shape[0:len(axis)])

                index = attention_dims_t[0]
                t_slice_shape = key_prime.shape[0:len(batch_dims_t)] + (
                    key_prime.shape[-1], )
                denominator_fn = _denominator(t_slice_shape, precision,
                                              self.lax_scan_unroll)
                R = denominator_fn(jnp.moveaxis(query_prime, index, 0),
                                   jnp.moveaxis(key_prime, index, 0))

                R = jnp.moveaxis(R, 0, index)
        else:
            contract_query = tuple(
                range(
                    len(batch_dims) + len(axis),
                    len(batch_dims) + len(axis) + 1))
            contract_z = tuple(range(len(batch_dims), len(batch_dims) + 1))
            # Constructing Z = (K^{'})^{T}V
            # Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
            Z = lax.dot_general(key_prime,
                                value, ((attention_dims_t, attention_dims_t),
                                        (batch_dims_t, batch_dims_t)),
                                precision=precision)
            # Constructing W = Q^{'}Z = Q^{'}(K^{'})^{T}V
            # q (bs, <non-attention dims>, num_heads, <attention dims>, channels_m)
            # Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
            # W (bs,  <non-attention dims>, num_heads, <attention dims>, channels_v)
            W = lax.dot_general(query_prime,
                                Z, ((contract_query, contract_z),
                                    (batch_dims_t, batch_dims_t)),
                                precision=precision)
            if not self.renormalize_attention:
                # Bidirectional, not-normalized attention.
                perm_inv = _invert_perm(qk_perm)
                result = W.transpose(perm_inv)
                return result
            else:
                # Bidirectional, normalized attention.
                thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones(
                    key_extra.shape[0:len(axis)])
                contract_key = tuple(
                    range(len(batch_dims),
                          len(batch_dims) + len(axis)))
                contract_thick_all_ones = tuple(
                    range(thick_all_ones.ndim - len(axis),
                          thick_all_ones.ndim))
                # Construct T = (K^{'})^{T} 1_L
                # k (bs, <non-attention dims>, num_heads, <attention dims>, channels)
                T = lax.dot_general(key_prime,
                                    thick_all_ones,
                                    ((contract_key, contract_thick_all_ones),
                                     (batch_dims_t, batch_dims_t)),
                                    precision=precision)

                # Construct partition function: R = Q^{'} T = Q^{'}(K^{'})^{T} 1_L
                # q_p (bs, <non-attention dims>, num_heads, <attention dims>, channs_m)
                # T   (bs, <non-attention dims>, num_heads, channels_m)
                R = lax.dot_general(query_prime,
                                    T,
                                    (((query_prime.ndim - 1, ),
                                      (T.ndim - 1, )),
                                     (batch_dims_t, range(0,
                                                          len(T.shape) - 1))),
                                    precision=precision)

        R = R + 2 * self.numerical_stabilizer * (jnp.abs(R) <=
                                                 self.numerical_stabilizer)
        R = jnp.reciprocal(R)
        R = jnp.expand_dims(R, len(R.shape))
        # W (bs, <non-attention dims>, num_heads, <attention dims>, channels_v)
        # R (bs, <non-attention dims>, num_heads, <attention dims>, extra_channel)
        result = W * R
        # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
        perm_inv = _invert_perm(qk_perm)
        result = result.transpose(perm_inv)
        return result
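A compact, self-contained sketch of the bidirectional branch above: attention is computed in feature space as Q'(K'^T V), with the partition function R = Q'(K'^T 1_L), without ever forming the L x L attention matrix (shapes and names here are illustrative only):

import jax
import jax.numpy as jnp

B, H, L, M, D = 2, 4, 16, 8, 32   # batch, heads, length, feature dim, value dim
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q_prime = jax.random.uniform(kq, (B, H, L, M))   # nonnegative kernel features
k_prime = jax.random.uniform(kk, (B, H, L, M))
value = jax.random.normal(kv, (B, H, L, D))

Z = jnp.einsum('bhlm,bhld->bhmd', k_prime, value)              # Z = K'^T V
W = jnp.einsum('bhlm,bhmd->bhld', q_prime, Z)                  # W = Q' Z
R = jnp.einsum('bhlm,bhm->bhl', q_prime, k_prime.sum(axis=2))  # R = Q' (K'^T 1_L)
print((W / R[..., None]).shape)  # (2, 4, 16, 32)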
Example #11
def fori_collect(lower,
                 upper,
                 body_fun,
                 init_val,
                 transform=identity,
                 progbar=True,
                 **progbar_opts):
    """
    This looping construct works like :func:`~jax.lax.fori_loop` but with the additional
    effect of collecting values from the loop body. In addition, this allows for
    post-processing of these samples via `transform`, and progress bar updates.
    Note that `progbar=False` will be faster, especially when collecting a
    lot of samples. Refer to example usage in :func:`~numpyro.mcmc.hmc`.

    :param int lower: the index to start the collective work. In other words,
        we will skip collecting the first `lower` values.
    :param int upper: number of times to run the loop body.
    :param body_fun: a callable that takes a collection of
        `np.ndarray` and returns a collection with the same shape and
        `dtype`.
    :param init_val: initial value to pass as argument to `body_fun`. Can
        be any Python collection type containing `np.ndarray` objects.
    :param transform: a callable to post-process the values returned by `body_fn`.
    :param progbar: whether to post progress bar updates.
    :param `**progbar_opts`: optional additional progress bar arguments. A
        `diagnostics_fn` can be supplied which when passed the current value
        from `body_fun` returns a string that is used to update the progress
        bar postfix. Also a `progbar_desc` keyword argument can be supplied
        which is used to label the progress bar.
    :return: collection with the same type as `init_val` with values
        collected along the leading axis of `np.ndarray` objects.
    """
    assert lower < upper
    init_val_flat, unravel_fn = ravel_pytree(transform(init_val))
    ravel_fn = lambda x: ravel_pytree(transform(x))[0]  # noqa: E731

    if not progbar:
        collection = np.zeros((upper - lower, ) + init_val_flat.shape)

        def _body_fn(i, vals):
            val, collection = vals
            val = body_fun(val)
            i = np.where(i >= lower, i - lower, 0)
            collection = ops.index_update(collection, i, ravel_fn(val))
            return val, collection

        _, collection = jit(fori_loop,
                            static_argnums=(2, ))(0, upper, _body_fn,
                                                  (init_val, collection))
    else:
        diagnostics_fn = progbar_opts.pop('diagnostics_fn', None)
        progbar_desc = progbar_opts.pop('progbar_desc', '')
        collection = []

        val = init_val
        with tqdm.trange(upper, desc=progbar_desc) as t:
            for i in t:
                val = body_fun(val)
                if i >= lower:
                    collection.append(jit(ravel_fn)(val))
                if diagnostics_fn:
                    t.set_postfix_str(diagnostics_fn(val), refresh=False)

        # XXX: jax.numpy.stack/concatenate is currently slow
        collection = onp.stack(collection)

    return vmap(unravel_fn)(collection)
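A hedged usage sketch (this assumes the fori_collect shown above matches numpyro.util.fori_collect in the installed version; the body function and shapes are illustrative):

import jax.numpy as jnp
from numpyro.util import fori_collect

# Run a trivial "sampler" body 10 times and collect the last 7 values.
samples = fori_collect(3, 10, lambda x: x + 1.0, jnp.zeros(2), progbar=False)
print(samples.shape)  # (7, 2)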
Example #12
 def onerow(i):
     n = m.shape[0]
     row = jnp.zeros(n + _W - 1)
     row = lax.dynamic_update_slice(row, m[i], (i, ))[_W // 2:-(_W // 2)]
     return row
Example #13
def coo2sparse(row: Array, col: Array, data: Array, n: i64) -> Array:

    disp = jnp.clip(col - row + _W // 2, 0, _W - 1)
    sparse = jnp.zeros((n, _W)).at[row, disp].set(data)
    return sparse
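A self-contained variant of the conversion above with the band width passed explicitly (the original relies on a module-level _W); it packs COO entries into a dense (n, w) band:

import jax.numpy as jnp

def coo2band(row, col, data, n, w):
    # Each nonzero is stored at column (col - row + w // 2) of its row's band.
    disp = jnp.clip(col - row + w // 2, 0, w - 1)
    return jnp.zeros((n, w)).at[row, disp].set(data)

row = jnp.array([0, 0, 1, 2])
col = jnp.array([0, 1, 1, 2])
data = jnp.array([1.0, 2.0, 3.0, 4.0])
print(coo2band(row, col, data, n=3, w=3))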
Example #14
 def f():
     return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))
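Examples #1 and #14 look like Haiku test snippets; a hedged, runnable sketch of the same pattern using the public hk.Linear (assuming basic.Linear here is Haiku's Linear module) is:

import haiku as hk
import jax
import jax.numpy as jnp

def f(x):
    return hk.Linear(output_size=2, b_init=jnp.ones)(x)

forward = hk.without_apply_rng(hk.transform(f))
params = forward.init(jax.random.PRNGKey(0), jnp.zeros([6]))
print(forward.apply(params, jnp.zeros([6])))  # [1. 1.]: zero input, so only the ones-initialized bias remains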
Example #15
    def train(self,
              bs,
              solutions=[None],
              retrain=False,
              tensorboard_writer=None,
              work_unit=None):

        if not retrain and not self.flaxd:
            opt_state = self.opt_init(self.net_params)
        if retrain:
            opt_state = self.opt_init(self.opt_params)
        loss = onp.zeros(self.training_iter // 10 + 1)
        gradients = onp.zeros(self.training_iter // 10 + 1)
        if not self.flaxd:
            param = self.get_params(opt_state)
        else:
            param = self.optimizer.target
            opt_state = self.optimizer
        og_loss = self.test_loss(
            self.preconditioner, self.n_test, self.mesh, param,
            np.zeros((bs.shape[1], self.n * self.n)), bs[0].reshape(
                bs.shape[1], self.n * self.n), 0, self.k) / 10000000
        print(og_loss)
        if work_unit is not None:
            work_unit.get_measurement_series(
                label='train/loss').create_measurement(objective_value=og_loss,
                                                       step=0)
        for i in range(self.training_iter):
            m = bs.shape[0]
            order = random.shuffle(random.PRNGKey(i), np.arange(m))
            for _ in range(50):
                for b in bs[order]:
                    current_loss, grad, opt_state = self.step(
                        i, opt_state, np.zeros((b.shape[0], self.n * self.n)),
                        b, solutions[min(m,
                                         len(solutions) - 1)])

            if i % 10 == 0:
                if not self.flaxd:
                    param = self.get_params(opt_state)
                else:
                    param = opt_state.target
                current_loss_test = self.test_loss(
                    self.preconditioner, self.n_test, self.mesh, param,
                    np.zeros((b.shape[0], self.n * self.n)), b, 0,
                    self.k) / 10000000
                current_loss = current_loss / 10000000
                avg_grad = onp.mean(onp.abs(onp_utils.flatten(grad)[-1]))
                print(f'step {i:5d}: loss {current_loss:1.5f} : '
                      f'avg_gradient {avg_grad:1.5f} : '
                      f'current_loss_test {current_loss_test:1.5f}')
                logging.info(f'step {i:5d}: loss {current_loss:1.5f} : '
                             f'avg_gradient {avg_grad:1.5f} : '
                             f'current_loss_test {current_loss_test:1.5f}')
                loss[i // 10] = current_loss
                gradients[i // 10] = avg_grad
                if work_unit is not None:
                    work_unit.get_measurement_series(
                        label='train/loss').create_measurement(
                            objective_value=current_loss_test, step=i)
                    tensorboard_writer.scalar('train/loss',
                                              current_loss_test,
                                              step=i + 1)
                    work_unit.get_measurement_series(
                        label='train/loss ' +
                        str(self.iter_gmres(i))).create_measurement(
                            objective_value=current_loss, step=i + 1)
                    tensorboard_writer.scalar('train/loss ' +
                                              str(self.iter_gmres(i)),
                                              current_loss,
                                              step=i + 1)
            if i % 50 == 0:
                if self.flaxd:
                    self.opt_params = opt_state.target.params
                else:
                    self.opt_params = self.get_params(opt_state)
                self.save(str(i))
        if self.flaxd:
            self.optimizer = opt_state
        else:
            self.opt_params = self.get_params(opt_state)
            self.opt_state = opt_state
        if self.model_dir is None:
            self.model_dir = ''

        with open(os.path.join(self.model_dir, 'train_loss.np'), 'wb') as f:
            onp.save(f, loss)
        with open(os.path.join(self.model_dir, 'train_gradients.np'),
                  'wb') as f:
            onp.save(f, gradients)
        self.save()
        if work_unit is not None:
            tensorboard_writer.close()
Example #16
# Here submodules are explicitly defined during init, but still materialized
# lazily only once a first input is passed through and shapes are known.
class MLP(Module):
    def setup(self):
        self.dense1 = Dense(features=2)
        self.dense2 = Dense(features=1)

        # shapes aren't yet known, so variables aren't materialized
        print(self.dense2.variables)
        # FrozenDict({})

    def __call__(self, x):
        return self.dense2(nn.relu(self.dense1(x)))


# Return an initialized instance of MLP by calling `__call__` with an input batch,
# initializing all variables.
#
# Variable shapes depend on the input shape passed in.
rngkey = jax.random.PRNGKey(10)
mlp_variables = MLP().init(rngkey, jnp.zeros((1, 3)))

pprint(mlp_variables)
# {'param': {'dense1': {'bias': DeviceArray([0., 0.], dtype=float32),
#                       'kernel': DeviceArray([[ 0.18307537, -0.38739476],
#              [-0.902451  , -0.5190721 ],
#              [ 0.51552075,  1.1169153 ]], dtype=float32)},
#            'dense2': {'bias': DeviceArray([0.], dtype=float32),
#                       'kernel': DeviceArray([[ 0.6704609 ],
#              [-0.90477365]], dtype=float32)}}}
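A hedged follow-up sketch applying the initialized MLP (this assumes flax.linen, where Module and Dense live, matching the nn.relu call above; only the output shape is asserted):

import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    def setup(self):
        self.dense1 = nn.Dense(features=2)
        self.dense2 = nn.Dense(features=1)

    def __call__(self, x):
        return self.dense2(nn.relu(self.dense1(x)))

model = MLP()
variables = model.init(jax.random.PRNGKey(10), jnp.zeros((1, 3)))
y = model.apply(variables, jnp.ones((1, 3)))
print(y.shape)  # (1, 1): dense2 has a single output feature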
Example #17
 def total_potential(xt):
     sum_potential = np.zeros(())
     for i in range(n - 1):
         sum_potential = sum_potential + G * vp(xt[i], xt[i + 1:]).sum()
     print(sum_potential)
     return sum_potential
Example #18
def build_tree(verlet_update,
               kinetic_fn,
               verlet_state,
               inverse_mass_matrix,
               step_size,
               rng,
               max_delta_energy=1000.,
               max_tree_depth=10):
    """
    Builds a binary tree from the `verlet_state`. This is used in NUTS sampler.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, Andrew Gelman
    2. *A Conceptual Introduction to Hamiltonian Monte Carlo*,
       Michael Betancourt

    :param verlet_update: A callable to get a new integrator state given a current
        integrator state.
    :param kinetic_fn: A callable to compute kinetic energy.
    :param verlet_state: Initial integrator state.
    :param inverse_mass_matrix: Inverse of the mass matrix.
    :param float step_size: Step size for the current trajectory.
    :param jax.random.PRNGKey rng: random key to be used as the source of
        randomness.
    :param float max_delta_energy: A threshold to decide if the new state diverges
        (based on the energy difference) too much from the initial integrator state.
    :return: information of the tree.
    :rtype: :data:`TreeInfo`
    """
    z, r, potential_energy, z_grad = verlet_state
    energy_current = potential_energy + kinetic_fn(inverse_mass_matrix, r)
    r_ckpts = np.zeros((max_tree_depth, inverse_mass_matrix.shape[-1]))
    r_sum_ckpts = np.zeros((max_tree_depth, inverse_mass_matrix.shape[-1]))

    tree = TreeInfo(z,
                    r,
                    z_grad,
                    z,
                    r,
                    z_grad,
                    z,
                    potential_energy,
                    z_grad,
                    depth=0,
                    weight=0.,
                    r_sum=r,
                    turning=False,
                    diverging=False,
                    sum_accept_probs=0.,
                    num_proposals=0)

    def _cond_fn(state):
        tree, _ = state
        return (tree.depth < max_tree_depth) & ~tree.turning & ~tree.diverging

    def _body_fn(state):
        tree, key = state
        key, direction_key, doubling_key = random.split(key, 3)
        going_right = random.bernoulli(direction_key)
        tree = _double_tree(tree, verlet_update, kinetic_fn,
                            inverse_mass_matrix, step_size, going_right,
                            doubling_key, energy_current, max_delta_energy,
                            r_ckpts, r_sum_ckpts)
        return tree, key

    state = (tree, rng)
    tree, _ = while_loop(_cond_fn, _body_fn, state)
    return tree
Example #19
def test_tomographic_weight_rel_err():
    import pylab as plt
    from jax import jit

    for S in range(5, 30, 5):

        @jit
        def tomo_weight(gamma, x1, x2, p1, p2):
            return tomographic_weight_function(gamma, x1, x2, p1, p2, S=S)

        @jit
        def tomo_weight_ref(gamma, x1, x2, p1, p2):
            return tomographic_weight_function(gamma, x1, x2, p1, p2, S=150)

        rel_error = []
        for i in range(400):
            keys = random.split(random.PRNGKey(i), 6)
            x1 = jnp.concatenate(
                [4. * random.uniform(keys[0], shape=(2, )),
                 jnp.zeros((1, ))],
                axis=-1)
            p1 = jnp.concatenate([
                4. * jnp.pi / 180. *
                random.uniform(keys[1], shape=(2, ), minval=-1, maxval=1),
                jnp.ones((1, ))
            ],
                                 axis=-1)
            p1 = 4 * p1 / jnp.linalg.norm(p1, axis=-1, keepdims=True)

            x2 = jnp.concatenate(
                [4. * random.uniform(keys[2], shape=(2, )),
                 jnp.zeros((1, ))],
                axis=-1)
            p2 = jnp.concatenate([
                4. * jnp.pi / 180. *
                random.uniform(keys[3], shape=(2, ), minval=-1, maxval=1),
                jnp.ones((1, ))
            ],
                                 axis=-1)
            p2 = 4 * p2 / jnp.linalg.norm(p2, axis=-1, keepdims=True)

            # x1 = random.normal(keys[0], shape_dict=(2,))
            # p1 = random.normal(keys[1], shape_dict=(2,))
            # x2 = random.normal(keys[2], shape_dict=(2,))
            # p2 = random.normal(keys[3], shape_dict=(2,))

            t1 = random.uniform(keys[4], shape=(10000, ))
            t2 = random.uniform(keys[5], shape=(10000, ))
            u1 = x1 + t1[:, None] * p1
            u2 = x2 + t2[:, None] * p2
            gamma = jnp.linalg.norm(u1 - u2, axis=1)**2
            hist, bins = jnp.histogram(gamma.flatten(), density=True, bins=100)
            bins = jnp.linspace(bins.min(), bins.max(), 20)
            w = tomo_weight(bins, x1, x2, p1, p2)
            w_ref = tomo_weight_ref(bins, x1, x2, p1, p2)
            rel_error.append(jnp.max(jnp.abs(w - w_ref)) / jnp.max(w_ref))
        rel_error = jnp.array(rel_error)
        plt.hist(rel_error, bins='auto')
        plt.title("{} : {:.2f}|{:.2f}|{:.2f}".format(
            S, *jnp.percentile(rel_error, [5, 50, 95])))
        plt.show()
Example #20
 def initial_state():
   return PopArtState(
       jnp.zeros([num_outputs]), jnp.ones([num_outputs]),
       jnp.ones([num_outputs]))
Example #21
def test_tomographic_weight():
    import pylab as plt
    from jax import jit

    # @jit
    def tomo_weight(gamma, x1, x2, p1, p2):
        return tomographic_weight_function(gamma, x1, x2, p1, p2, S=10)

    @jit
    def tomo_weight_ref(gamma, x1, x2, p1, p2):
        return tomographic_weight_function(gamma, x1, x2, p1, p2, S=150)

    @jit
    def _tomo_weight_ref(gamma, x1, x2, p1, p2):
        return _tomographic_weight_function(gamma, x1, x2, p1, p2, S=150)

    @jit
    def tomo_weight_dimensionless_ref(gamma, x1, x2, p1, p2):
        x12 = x1 - x2
        h = jnp.linalg.norm(x12)
        n = x12 / h
        w1 = p1 / h
        w2 = p2 / h
        gamma_prime = gamma / h**2
        return jnp.exp(
            log_tomographic_weight_dimensionless_function(
                gamma_prime, n, w1, w2, S=150)) / h**2

    for i in range(100):
        keys = random.split(random.PRNGKey(i), 6)
        x1 = jnp.concatenate(
            [10. * random.uniform(keys[0], shape=(2, )),
             jnp.zeros((1, ))],
            axis=-1)
        p1 = jnp.concatenate([
            4. * jnp.pi / 180. *
            random.uniform(keys[1], shape=(2, ), minval=-1, maxval=1),
            jnp.ones((1, ))
        ],
                             axis=-1)
        p1 = 4 * p1 / jnp.linalg.norm(p1, axis=-1, keepdims=True)

        x2 = jnp.concatenate(
            [4. * random.uniform(keys[2], shape=(2, )),
             jnp.zeros((1, ))],
            axis=-1)
        p2 = jnp.concatenate([
            4. * jnp.pi / 180. *
            random.uniform(keys[3], shape=(2, ), minval=-1, maxval=1),
            jnp.ones((1, ))
        ],
                             axis=-1)
        p2 = 4 * p2 / jnp.linalg.norm(p2, axis=-1, keepdims=True)

        t1 = random.uniform(keys[4], shape=(10000, ))
        t2 = random.uniform(keys[5], shape=(10000, ))
        u1 = x1 + t1[:, None] * p1
        u2 = x2 + t2[:, None] * p2
        gamma = jnp.linalg.norm(u1 - u2, axis=1)**2
        plt.hist(gamma.flatten(), bins=100, density=True, label='histogram')
        hist, bins = jnp.histogram(gamma.flatten(), density=True, bins=100)
        bins = jnp.linspace(bins.min(), bins.max(), 50)
        gamma = 0.5 * (bins[:-1] + bins[1:])
        w = tomo_weight(bins, x1, x2, p1, p2)
        plt.plot(gamma, w, label='analytic')
        w_ref = tomo_weight_ref(bins, x1, x2, p1, p2)
        _w_ref = _tomo_weight_ref(gamma, x1, x2, p1, p2)
        # w_ref = tomo_weight_dimensionless_ref(bins, x1,x2,p1,p2)
        plt.plot(gamma, w_ref, label='analytic ref')
        plt.legend()
        plt.savefig(
            '/home/albert/git/jaxns/debug_figs/pdf_fig{:03d}.png'.format(i))
        plt.close('all')

        plt.plot(gamma, jnp.cumsum(w), label='analytic')
        # w_ref = tomo_weight_dimensionless_ref(bins, x1,x2,p1,p2)
        plt.plot(gamma, jnp.cumsum(w_ref), label='analytic ref')
        plt.legend()
        plt.savefig(
            '/home/albert/git/jaxns/debug_figs/cdf_fig{:03d}.png'.format(i))
        plt.close('all')
Example #22
        self.reduced_kwargs["prefetch"] = tf.data.AUTOTUNE
        self.kwargs["cache"] = True
        self.reduced_kwargs["cache"] = True


test = aggregatedGradientTests(imnn=AggregatedGradientIMNN,
                               filename="aggregated_gradient")


@pytest.mark.parametrize("kwargs", [test.kwargs, test.reduced_kwargs])
@pytest.mark.parametrize("state", [True, False])
@pytest.mark.parametrize("validate", [True, False])
@pytest.mark.parametrize("input_variable", [
    None,
    list(), 1., 1,
    np.zeros((1, )), test.rng,
    tuple(), (0, 0), (test.model[0], 0), test.bad_model, test.state
])
@pytest.mark.parametrize("variable", test.kwargs.keys())
def test_initialisation_parameters_(variable, kwargs, input_variable, state,
                                    validate):
    test.initialise_parameters(variable,
                               kwargs,
                               input_variable,
                               state=state,
                               validate=validate)


@pytest.mark.parametrize("validate", [True, False])
@pytest.mark.parametrize("state", [False, True])
@pytest.mark.parametrize("variable", ["n_s", "n_d", "same"])
Example #23
 def f():
     data = jnp.zeros(input_shape)
     net = conv.ConvND(n, output_channels=3, kernel_shape=3, stride=3)
     return net(data)
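A hedged, self-contained sketch of the ConvND pattern in Example #23 using the two-dimensional alias (assuming conv.ConvND is Haiku's ConvND; input shape and seed are illustrative):

import haiku as hk
import jax
import jax.numpy as jnp

def f(x):
    return hk.Conv2D(output_channels=3, kernel_shape=3, stride=3)(x)

forward = hk.without_apply_rng(hk.transform(f))
params = forward.init(jax.random.PRNGKey(0), jnp.zeros((1, 9, 9, 1)))
print(forward.apply(params, jnp.zeros((1, 9, 9, 1))).shape)  # (1, 3, 3, 3) with SAME padding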
Example #24
 def init(self, x):
     vs = [np.zeros(sz, dtype=x.dtype) for sz in x.shape]
     return (np.zeros_like(x), vs)
Example #25
 def _inverse(self, y):
     size = self.permutation.size
     permutation_inv = ops.index_update(
         jnp.zeros(size, dtype=canonicalize_dtype(jnp.int64)),
         self.permutation, jnp.arange(size))
     return y[..., permutation_inv]
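A minimal standalone sketch of the permutation-inversion trick above, written with the .at[].set() API (equivalent to ops.index_update in newer JAX; the permutation is illustrative):

import jax.numpy as jnp

permutation = jnp.array([2, 0, 3, 1])
permutation_inv = jnp.zeros(4, dtype=jnp.int32).at[permutation].set(jnp.arange(4))
print(permutation_inv)                # [1 3 0 2]
print(permutation[permutation_inv])   # [0 1 2 3], i.e. the identity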
Example #26
 def model():
     with handlers.mask(mask=jnp.zeros(10, dtype=bool)):
         numpyro.factor("inf", -jnp.inf)
Example #27
 def init(init_state) -> DivergencesState:
     """Initialize the divergence counters."""
     num_chains, _ = init_state.position.shape
     num_divergences = jnp.zeros(num_chains)
     return DivergencesState(num_divergences, 0)
Example #28
 def guide(subsample):
     scale = numpyro.param("scale", 1.0)
     with handlers.substitute(data={"data": subsample}):
         with numpyro.plate("data", len(data), subsample_size):
             loc = numpyro.param("loc", jnp.zeros(len(data)), event_dim=0)
             numpyro.sample("z", dist.Normal(loc, scale))
Example #29
def class_wise_nms(boxes: Tensor,
                   scores: Tensor,
                   classes: Tensor,
                   n_classes: int,
                   overlap_threshold: float = .5,
                   score_threshold: float = .5,
                   boxes_fmt: BoxesFormat = BoxesFormat.xyxy) -> Tensor:
    """
    Performs Non-Maximum Suppression for each unique class over the given boxes.
    It keeps the boxes with the highest scores and discards the ones pointing to
    the same object.

    Parameters
    ----------
    boxes: Tensor of shape [N, 4]
        Boxes formatted according to the input parameter `boxes_fmt`
    scores: Tensor of shape [N]
        Box scores ranging from 0 to 1, where 1 is the highest confidence
    classes: Tensor of shape [N]
        Boxes classes
    n_classes: int
        Number of total classes
    boxes_fmt: BoxesFormat, default xyxy
        Format of the boxes; by default it is set to
        [x_min, y_min, x_max, y_max]
    overlap_threshold: float, default .5
        Overlapping boxes pointing to the same object with an IoU larger than
        this threshold are discarded; NMS keeps only the one with the highest
        score
    score_threshold: float, default 0.5
        Boxes with a lower score than score_threshold will be discarded
    """

    if boxes_fmt != BoxesFormat.xyxy:
        convert_fn = getattr(boxes_utils, f'{boxes_fmt.value}_to_xyxy')
        boxes = convert_fn(boxes)

    masks = np.zeros(boxes.shape[0], dtype='bool')
    classes = classes.reshape(-1).astype('int32')

    # Per class NMS
    # TODO: Should labels always start at 1?
    for c in np.arange(1, n_classes + 1):
        if c == -1:
            continue

        mask = (classes == c).reshape(-1)
        if np.sum(mask) == 0:
            continue

        current_scores = np.where(mask, scores, 0.)
        current_boxes = np.where(np.expand_dims(mask, -1), boxes, 0.)

        boxes_mask = nms(boxes=current_boxes,
                         scores=current_scores,
                         overlap_threshold=overlap_threshold,
                         score_threshold=score_threshold)

        masks = masks | boxes_mask

    return masks
Example #30
 def rnn_reference(W, xs, target):
     h = np.zeros(n)
     for x in xs:
         h = np.tanh(np.dot(W, h) + np.dot(W, x))
     predicted = h
     return np.sum((predicted - target)**2)
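A self-contained variant of the reference RNN loss above, with the hidden size passed explicitly so the snippet runs on its own, plus a gradient check via jax.grad (sizes and values are illustrative):

import jax
import jax.numpy as np

def rnn_reference(W, xs, target, n=3):
    h = np.zeros(n)
    for x in xs:
        h = np.tanh(np.dot(W, h) + np.dot(W, x))
    return np.sum((h - target) ** 2)

W = 0.1 * np.eye(3)
xs = np.ones((5, 3))
target = np.zeros(3)
print(rnn_reference(W, xs, target))
print(jax.grad(rnn_reference)(W, xs, target).shape)  # (3, 3)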