Example #1
    def svd_fun(tensor):
        # Permute the indices of tensor into something closer to the output
        tensor = jnp.einsum(f"{init_str}->{left_free+right_free}", tensor)

        # Flatten both sides of our tensor to give a single matrix
        left_shape = tensor.shape[:len(left_free)]
        right_shape = tensor.shape[len(left_free):]
        left_size = jnp.prod(left_shape)
        right_size = jnp.prod(right_shape)
        matrix = tensor.reshape((left_size, right_size))

        # Get SVD and format so that left_mat@diag(svs)@right_mat = matrix
        left_mat, sv_vec, right_mat = stable_svd(matrix)

        # Fold singular values into left/right matrices
        left_mat, right_mat = apply_sv(left_mat, right_mat, sv_vec)

        # Reshape the matrices to make them proper tensors
        left_tensor = left_mat.reshape(left_shape + sv_vec.shape)
        right_tensor = right_mat.reshape(sv_vec.shape + right_shape)

        # Move the new bond indices into the correct order
        left_tensor = jnp.einsum(f"{left_free+bond_char}->{left_str}",
                                 left_tensor)
        right_tensor = jnp.einsum(f"{bond_char+right_free}->{right_str}",
                                  right_tensor)

        return out_fun(left_tensor, right_tensor, sv_vec)
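
The split above reshapes a tensor into a matrix, takes an SVD, and reshapes the factors back. A minimal self-contained sketch of that core step (made-up shapes; jnp.linalg.svd standing in for the snippet's stable_svd, and singular values folded into the left factor only):

import jax.numpy as jnp

tensor = jnp.arange(24.0).reshape(2, 3, 4)        # indices "abc"
matrix = tensor.reshape(2 * 3, 4)                 # group "ab" | "c"
u, s, vt = jnp.linalg.svd(matrix, full_matrices=False)
left = (u * s).reshape(2, 3, s.shape[0])          # fold singular values into the left factor
right = vt.reshape(s.shape[0], 4)
assert jnp.allclose(jnp.einsum("abk,kc->abc", left, right), tensor, atol=1e-3)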
Example #2
def mnist(flatten: bool = False,
          one_hot_encoding: bool = False,
          data_dir: str = os.path.join("..", "datasets", "mnist")):
    path: Path = Path(data_dir)
    downloaded_data, info = tfds.load(name="mnist",
                                      batch_size=-1,
                                      data_dir=path,
                                      with_info=True)
    mnist_data: Dict[str, Dict[str, np.ndarray]] = tfds.as_numpy(downloaded_data)
    train_data, valid_data = mnist_data.get("train"), mnist_data.get("test")
    input_shape: Tuple[int, ...] = info.features["image"].shape
    train_images, train_labels = tensor.asarray(
        train_data.get("image"), dtype=tensor.float32), tensor.asarray(
            train_data.get("label"), dtype=tensor.float32).reshape(-1, 1)
    valid_images, valid_labels = tensor.asarray(
        valid_data.get("image"), dtype=tensor.float32), tensor.asarray(
            valid_data.get("label"), dtype=tensor.float32).reshape(-1, 1)
    if flatten:
        train_images = train_images.reshape(-1, tensor.prod(list(input_shape)))
        valid_images = valid_images.reshape(-1, tensor.prod(list(input_shape)))
    if one_hot_encoding:
        # pd.get_dummies expects 1-D input, so flatten the (n, 1) label arrays
        train_labels = tensor.asarray(pd.get_dummies(train_labels.ravel()),
                                      dtype=tensor.float32)
        valid_labels = tensor.asarray(pd.get_dummies(valid_labels.ravel()),
                                      dtype=tensor.float32)
    return (train_images, train_labels), (valid_images, valid_labels)
Example #3
    def test_discrete_barycenter_grid(self, lse_mode, debiased, epsilon):
        """Tests the discrete barycenters on a 5x5x5 grid.

    Puts two masses on opposing ends of the hypercube with small noise in
    between, and checks that their W barycenter sits (mostly) at the middle of
    the hypercube (i.e. index (5x5x5-1)/2).

    Args:
      lse_mode: bool, lse or scaling computations.
      debiased: bool, use (or not) debiasing as proposed in
      https://arxiv.org/abs/2006.02575
      epsilon: float, regularization parameter
    """
        size = jnp.array([5, 5, 5])
        grid_3d = grid.Grid(grid_size=size, epsilon=epsilon)
        a = jnp.ones(size)
        b = jnp.ones(size)
        a = a.ravel()
        b = b.ravel()
        a = jax.ops.index_update(a, 0, 10000)
        b = jax.ops.index_update(b, -1, 10000)
        a = a / jnp.sum(a)
        b = b / jnp.sum(b)
        threshold = 1e-2
        _, _, bar, errors = db.discrete_barycenter(grid_3d,
                                                   a=jnp.stack((a, b)),
                                                   threshold=threshold,
                                                   lse_mode=lse_mode,
                                                   debiased=debiased)
        self.assertGreater(bar[(jnp.prod(size) - 1) // 2], 0.7)
        self.assertGreater(1, bar[(jnp.prod(size) - 1) // 2])
        err = errors[jnp.isfinite(errors)][-1]
        self.assertGreater(threshold, err)
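
The test above relies on jax.ops.index_update, which newer JAX releases have removed. A sketch of the same corner-mass construction with the current .at[...] syntax:

import jax.numpy as jnp

size = (5, 5, 5)
a = jnp.ones(size).ravel()
b = jnp.ones(size).ravel()
a = a.at[0].set(10000.0)    # mass concentrated at one corner of the grid
b = b.at[-1].set(10000.0)   # mass concentrated at the opposite corner
a, b = a / a.sum(), b / b.sum()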
Example #4
def test_nFS():
    from pOP import nFS as pnFS
    dim = 2
    nC = -1 * np.ones((dim, 1), dtype=np.int32)
    d = np.zeros(dim, dtype=np.int32)
    c = np.ones(dim)
    d2 = np.array([2, 3], dtype=np.int32)
    nC2Py = np.array([4, 7], dtype=np.int32)
    nC2 = np.block([[np.arange(4), -1. * np.ones(3)],
                    [np.arange(7)]]).astype(np.int32)
    n = np.array([10] * dim)
    N = np.prod(n)
    z = np.linspace(0, 2. * np.pi, num=n[0])
    x = onp.zeros((N, dim))
    for k in range(dim):
        nProd = np.prod(n[k + 1:])
        nStack = np.prod(n[0:k])
        dark = np.hstack([z] * nProd)
        x[:, k] = onp.array([dark] * nStack).flatten()
    c = (2. * np.pi) / (x[-1, :] - x[0, :])
    z = (x - x[0, :]) * c - np.pi

    nfs1 = nFS(x[0, :], x[-1, :], nC, 5)
    nfs2 = nFS(x[0, :], x[-1, :], nC2, 10)
    Fc1 = nfs1.H(x.T, d, False)
    Fc2 = nfs2.H(x.T, d2, False)

    Fp1 = pnFS(z, 4, d, nC.flatten() * 0.)
    Fp2 = pnFS(z, 9, d2, nC2Py)

    assert (np.linalg.norm(Fc1 - Fp1, ord='fro') < 1e-14)
    assert (np.linalg.norm(Fc2 - Fp2, ord='fro') < 5e-13)
Example #5
def partial_trace(A, A_label):
    """ Partial trace on tensor A over repeated labels in A_label """

    num_cont = len(A_label) - len(np.unique(A_label))
    if num_cont > 0:
        dup_list = []
        for ele in np.unique(A_label):
            if sum(A_label == ele) > 1:
                dup_list.append([np.where(A_label == ele)[0]])

        cont_ind = np.array(dup_list).reshape(2*num_cont,order='F')
        free_ind = onp.delete(np.arange(len(A_label)),cont_ind)

        cont_dim = np.prod(np.array(A.shape)[cont_ind[:num_cont]])
        free_dim = np.array(A.shape)[free_ind]

        B_label = onp.delete(A_label, cont_ind)
        cont_label = np.unique(A_label[cont_ind])
        B = np.zeros(np.prod(free_dim))
        A = A.transpose(np.append(free_ind, cont_ind)).reshape(np.prod(free_dim),cont_dim,cont_dim)
        for ip in range(cont_dim):
            B = B + A[:,ip,ip]

        return B.reshape(free_dim), B_label, cont_label

    else:
        return A, A_label, []
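
A quick numeric check of partial_trace against an explicit einsum (a sketch; it assumes the function above is in scope together with the numpy-style aliases it uses):

import numpy as onp

A = onp.random.default_rng(0).normal(size=(2, 3, 2))
A_label = onp.array([7, 9, 7])                   # label 7 is repeated -> traced out
B, B_label, cont_label = partial_trace(A, A_label)
assert onp.allclose(B, onp.einsum("iji->j", A))  # trace over the repeated index
assert list(B_label) == [9] and list(cont_label) == [7]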
Example #6
def eval_polynomial(
    x: jnp.ndarray,
    coeff_a: float,
    coeff_b: float,
    mul_coeffs: jnp.ndarray,
    sub_coeffs: jnp.ndarray,
) -> jnp.ndarray:
    """Evaluate polynomial.

  Evaluate the polynomial corresponding to the rational equation
  (coeff_b * x - coeff_a) + sum_i mul_coeffs[i]/(x-sub_coeffs[i])
  at x.

  Args:
    x: (n,)
    coeff_a: Scalar
    coeff_b: Scalar
    mul_coeffs: (n,) numpy array of multiplicative coefficients
    sub_coeffs: (n,) numpy array of subtractive coefficients

  Returns:
    Values of polynomial at x (same shape as x).
  """
    result = 0.
    x = jnp.reshape(x, [-1, 1])
    for i in range(mul_coeffs.size):
        coeffs_not_i = (np.arange(mul_coeffs.size) != i)
        result += (mul_coeffs[i] * jnp.prod(
            x - jnp.reshape(sub_coeffs[coeffs_not_i], [1, -1]), axis=-1))
    result = jnp.reshape(result, [-1])
    result -= (jnp.reshape(coeff_b * x - coeff_a, [-1]) * jnp.reshape(
        jnp.prod(x - jnp.reshape(sub_coeffs, [1, -1]), axis=-1), [-1]))
    return jnp.reshape(result, [-1])
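
A numeric sanity check of the identity implemented above (a sketch; it assumes eval_polynomial is in scope along with jax.numpy as jnp): the result should equal prod_j(x - sub_coeffs[j]) times (sum_i mul_coeffs[i]/(x - sub_coeffs[i]) - (coeff_b * x - coeff_a)).

import jax.numpy as jnp

x = jnp.array([0.3, 1.7])
coeff_a, coeff_b = 0.5, 2.0
mul_coeffs = jnp.array([1.0, -2.0, 0.5])
sub_coeffs = jnp.array([-1.0, 2.0, 4.0])
lhs = eval_polynomial(x, coeff_a, coeff_b, mul_coeffs, sub_coeffs)
rational = jnp.sum(mul_coeffs / (x[:, None] - sub_coeffs), axis=-1) - (coeff_b * x - coeff_a)
rhs = jnp.prod(x[:, None] - sub_coeffs, axis=-1) * rational
assert jnp.allclose(lhs, rhs, atol=1e-4)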
Example #7
def upsample_posterior(x, b, log_diag_cov, repeats):
    """ Posterior of N(x|Az + b, Sigma) where A is an upsample matrix"""
    assert x.shape == b.shape
    assert x.shape == log_diag_cov.shape
    assert x.ndim == 3
    xmb = x - b
    one_over_diag_cov = jnp.exp(-log_diag_cov)

    # Compute the diagonal of the riemannian metric.  This is the diagonal of A^T Sigma^{-1} A
    hr, wr, cr = repeats
    assert cr == 1  # Haven't tested cr != 1
    Hx, Wx, C = x.shape
    H, W = Hx//hr, Wx//wr
    rm_diag = one_over_diag_cov.reshape((H, hr, W, wr, C)).transpose((0, 2, 4, 1, 3)).reshape((H, W, C, hr*wr)).sum(axis=-1)

    # Compute the mean of z
    z_mean = upsample_pseudo_inverse(xmb*one_over_diag_cov, (2, 2, 1))/rm_diag*(hr*wr)
    x_proj = upsample(repeats, z_mean)*one_over_diag_cov
    dim_x = jnp.prod(x.shape)
    dim_z = jnp.prod(z_mean.shape)

    # Compute the manifold error term
    log_hx = -0.5*jnp.sum(xmb*(xmb*one_over_diag_cov - x_proj))
    log_hx -= 0.5*jnp.sum(jnp.log(rm_diag))
    log_hx -= 0.5*log_diag_cov.sum()
    log_hx -= 0.5*(dim_x - dim_z)*jnp.log(2*jnp.pi)

    # return z_mean, log_hx, rm_diag, x_proj
    return z_mean, log_hx, rm_diag
Example #8
def gaussian_potential(x: jnp.ndarray,
                       mean: Union[float, jnp.ndarray] = 0.,
                       prec: Union[float, jnp.ndarray] = None,
                       sqrt_prec: Union[float, jnp.ndarray] = None,
                       det_prec: float = None) -> Union[float, jnp.ndarray]:
    # sqrt_prec such that prec = sqrt_prec @ sqrt_prec.T
    d = x.shape[-1]

    if prec is None and sqrt_prec is None:
        prec = 1.

    if isinstance(prec, float):
        prec = jnp.ones(d) * prec

    if isinstance(sqrt_prec, float):
        sqrt_prec = jnp.ones(d) * sqrt_prec

    if det_prec is None:
        if prec is not None and prec.ndim < 2:
            det_prec = jnp.prod(prec)
        elif sqrt_prec is not None and sqrt_prec.ndim < 2:
            det_prec = jnp.prod(sqrt_prec)**2

    if det_prec is None:
        # full precision matrix given but no det - computing without norm constant
        neg_log_z = 0
        warn(
            'gaussian_potential queried with non-diagonal prec (or sqrt-prec) but no det_prec given'
            ' -> executing without normalising constant term')
    else:
        neg_log_z = (d * jnp.log(2 * jnp.pi) - jnp.log(det_prec)) / 2

    if x.ndim == 1 and sqrt_prec is None:
        # Single value (not vectorised)
        if prec is None:
            out_val = _mv_gaussian_potential_diag(x, mean, 1.)
        elif prec.ndim < 2:
            out_val = _mv_gaussian_potential_diag(x, mean, prec)
        else:
            out_val = _mv_gaussian_potential(x, mean, prec)
    else:
        # Multiple values (vectorised)
        if prec is not None and sqrt_prec is None:
            if prec.ndim < 2:
                sqrt_prec = jnp.sqrt(prec)
            else:
                sqrt_prec = jnp.linalg.cholesky(prec)
                warn(
                    'vectorised gaussian_potential queried with prec rather than sqrt_prec'
                    ' -> executing using Cholesky decomp')

        if sqrt_prec is None:
            out_val = _mv_gaussian_potential_diag(x, mean, 1.)
        elif sqrt_prec.ndim < 2:
            out_val = _mv_gaussian_potential_diag(x, mean, sqrt_prec**2)
        else:
            out_val = _vectorised_gaussian_potential(x, mean, sqrt_prec)
    return out_val + neg_log_z
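
For reference, a self-contained sketch of the diagonal-precision branch (my own helper, not the snippet's _mv_gaussian_potential_* functions), checked against jax.scipy:

import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

def diag_gaussian_potential(x, mean, prec_diag):
    # negative log density of N(mean, diag(1 / prec_diag))
    d = x.shape[-1]
    quad = 0.5 * jnp.sum(prec_diag * (x - mean) ** 2, axis=-1)
    neg_log_z = 0.5 * (d * jnp.log(2 * jnp.pi) - jnp.sum(jnp.log(prec_diag)))
    return quad + neg_log_z

x = jnp.array([0.3, -1.2, 2.0])
mean = jnp.zeros(3)
prec_diag = jnp.array([1.0, 2.0, 0.5])
expected = -multivariate_normal.logpdf(x, mean, jnp.diag(1.0 / prec_diag))
assert jnp.allclose(diag_gaussian_potential(x, mean, prec_diag), expected, atol=1e-5)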
Example #9
def reduce(
    total: jnp.ndarray,
    count: tp.Optional[jnp.ndarray],
    values: jnp.ndarray,
    reduction: Reduction,
    sample_weight: tp.Optional[np.ndarray],
    dtype: jnp.dtype,
) -> tp.Tuple[jnp.ndarray, jnp.ndarray, tp.Optional[jnp.ndarray]]:

    if sample_weight is not None:
        sample_weight = sample_weight.astype(dtype)

        # Update dimensions of weights to match with values if possible.
        # values, _, sample_weight = tf_losses_utils.squeeze_or_expand_dimensions(
        #     values, sample_weight=sample_weight
        # )

        try:
            # Broadcast weights if possible.
            sample_weight = jnp.broadcast_to(sample_weight, values.shape)
        except ValueError:
            # Reduce values to same ndim as weight array
            ndim = values.ndim
            weight_ndim = sample_weight.ndim
            if reduction == Reduction.SUM:
                values = jnp.sum(values, axis=list(range(weight_ndim, ndim)))
            else:
                values = jnp.mean(values, axis=list(range(weight_ndim, ndim)))

        values = values * sample_weight

    value_sum = jnp.sum(values)

    total += value_sum

    # Exit early if the reduction doesn't have a denominator.
    if reduction == Reduction.SUM:
        num_values = None

    # Update `count` for reductions that require a denominator.
    elif reduction == Reduction.SUM_OVER_BATCH_SIZE:
        num_values = jnp.prod(values.shape).astype(dtype)

    else:
        if sample_weight is None:
            num_values = jnp.prod(jnp.array(values.shape)).astype(dtype)
        else:
            num_values = jnp.sum(sample_weight)

    if count is not None and num_values is not None:
        count += num_values

    if reduction == Reduction.SUM:
        value = total
    else:
        value = total / count

    return value, total, count
Example #10
 def _hessianopt(x, f):
     _, hvp = jax.linearize(jax.grad(f), x)
     hvp = jax.jit(hvp)
     n = np.prod(x.shape)
     idxs = np.arange(vsize, n, vsize)
     basis = np.eye(np.prod(x.shape)).reshape(-1, *x.shape)
     splitbasis = np.split(basis, idxs)
     vhvp = jax.vmap(hvp)
     vhvp = jax.jit(vhvp)
     return np.concatenate([vhvp(b)
                            for b in splitbasis]).reshape(x.shape + x.shape)
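
Here vsize appears to be a chunk size from the enclosing scope. A self-contained sketch of the same idea without chunking (essentially the variant that reappears as Example #14 below): push the identity basis through a Hessian-vector product and compare with jax.hessian.

import jax
import jax.numpy as jnp
import numpy as np

def dense_hessian(f, x):
    _, hvp = jax.linearize(jax.grad(f), x)        # hvp(v) == H @ v
    basis = np.eye(x.size).reshape(-1, *x.shape)  # one-hot perturbation directions
    return jax.vmap(hvp)(basis).reshape(x.shape + x.shape)

f = lambda x: jnp.sum(jnp.sin(x) ** 2)
x = jnp.arange(3.0)
assert jnp.allclose(dense_hessian(f, x), jax.hessian(f)(x), atol=1e-5)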
Example #11
 def prob_fn(sample: Array, mu: Array, sigma: Array, action_spec):
     # Support scalar and vector `sigma`. If vector, mu.shape==sigma.shape.
     mu = mu_activation(mu)
     sigma = sigma_activation(sigma)
     # Compute pdf for multivariate gaussian.
     d = mu.shape[-1]
     det = jnp.prod(sigma**2, axis=-1)
     z = ((2 * jnp.pi)**(0.5 * d)) * (det**0.5)
     exp = jnp.exp(-0.5 * jnp.sum(
         ((mu - inv_transform(sample, action_spec)) / sigma)**2, axis=-1))
     det_jacobian = jnp.prod(jnp.clip(1 - sample**2, 0., 1.) + 1e-6)
     return exp / (z * det_jacobian)
Example #12
def mvn_kl(mu_0, sigma_0, mu_1, sigma_1):

    # slogdet returns (sign, log|det|); for SPD inputs the sign is 1, so the
    # product of the pair is just the log-determinant.
    logdet_sigma_1 = jnp.prod(jnp.array(jnp.linalg.slogdet(sigma_1)))
    logdet_sigma_0 = jnp.prod(jnp.array(jnp.linalg.slogdet(sigma_0)))
    term_1 = 0.5 * (logdet_sigma_1 - logdet_sigma_0)

    # I wonder if there's a more efficient way?
    mu_outer = jnp.outer(mu_0 - mu_1, mu_0 - mu_1)
    inside_term = mu_outer + sigma_0 - sigma_1
    solved = jnp.linalg.solve(sigma_1, inside_term)
    term_2 = 0.5 * jnp.trace(solved)

    return term_1 + term_2
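
A quick sanity check (a sketch; assumes mvn_kl above is in scope): with equal covariances, the KL divergence reduces to half the Mahalanobis distance between the means.

import jax.numpy as jnp

mu_0 = jnp.array([1.0, -2.0])
mu_1 = jnp.array([0.5, 0.0])
eye = jnp.eye(2)
# KL(N(mu_0, I) || N(mu_1, I)) = 0.5 * ||mu_0 - mu_1||^2
assert jnp.allclose(mvn_kl(mu_0, eye, mu_1, eye),
                    0.5 * jnp.sum((mu_0 - mu_1) ** 2), atol=1e-6)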
Example #13
  def testPermutationArray(self, dtype, shape):
    key = random.PRNGKey(0)
    x = np.arange(np.prod(shape)).reshape(shape).astype(dtype)
    rand = lambda key: random.permutation(key, x)
    crand = api.jit(rand)

    perm1 = rand(key)
    perm2 = crand(key)

    self.assertAllClose(perm1, perm2, check_dtypes=True)
    self.assertFalse(onp.all(perm1 == x))  # seems unlikely!
    self.assertAllClose(onp.sort(perm1.ravel()), x.ravel(), check_dtypes=False)
    self.assertArraysAllClose(
      x, np.arange(np.prod(shape)).reshape(shape).astype(dtype),
      check_dtypes=True)
Example #14
 def _hessianopt(x, f):
     _, hvp = jax.linearize(jax.grad(f), x)
     hvp = jax.jit(hvp)
     vhvp = jax.vmap(hvp)
     vhvp = jax.jit(vhvp)
     basis = np.eye(np.prod(x.shape)).reshape(-1, *x.shape)
     return vhvp(basis).reshape(x.shape + x.shape)
Example #15
    def __call__(self, inputs, context=None, reverse=False):
        axes = tuple(i for i in range(len(inputs.shape) - 1))

        def dd_mean_initializer(key, shape):
            """Data-dependent init for mu"""
            nonlocal inputs
            x_mean = np.mean(inputs, axis=axes, keepdims=True)
            return -x_mean

        def dd_stddev_initializer(key, shape):
            """Data-dependent init for sigma"""
            nonlocal inputs
            x_var = np.mean(inputs**2, axis=axes, keepdims=True)
            var = self.scale / (np.sqrt(x_var) + self.eps)
            return var

        shape = (1, ) * len(axes) + (inputs.shape[-1], )
        mu = self.param("actnorm_mean", dd_mean_initializer, shape)
        sigma = self.param("actnorm_stddev", dd_stddev_initializer, shape)

        logsigma = np.log(np.abs(sigma))
        log_det_jacobian = np.prod(np.array(
            inputs.shape[1:-1])) * np.sum(logsigma)

        if reverse:
            outputs = inputs / (sigma + self.eps) - mu
            log_det_jacobian = -log_det_jacobian
        else:
            outputs = sigma * (inputs + mu)
            log_det_jacobian = log_det_jacobian

        return outputs, log_det_jacobian
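
A small numeric check of the log|det J| formula used above (a sketch with made-up shapes): for the elementwise map y = sigma * (x + mu) on an (H, W, C) input with per-channel sigma, the Jacobian log-determinant is H * W * sum(log|sigma|).

import jax
import jax.numpy as jnp

H, W, C = 2, 3, 4
x = jnp.arange(H * W * C, dtype=jnp.float32).reshape(H, W, C)
mu = jnp.zeros((1, 1, C))
sigma = jnp.linspace(0.5, 2.0, C).reshape(1, 1, C)

flat_map = lambda xf: (sigma * (xf.reshape(H, W, C) + mu)).ravel()
_, logdet = jnp.linalg.slogdet(jax.jacfwd(flat_map)(x.ravel()))
assert jnp.allclose(logdet, H * W * jnp.sum(jnp.log(jnp.abs(sigma))), atol=1e-4)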
Example #16
  def testThreadsafeIndexing(self):
    # NOTE(skye): I picked these values to be big enough to cause interesting
    # execution overlap, but small enough to not use too much memory. YMMV.
    shape = (8, 8000, 1000)

    if jax.device_count() < shape[0]:
      raise SkipTest(f"requires {shape[0]} devices")

    x = np.arange(np.prod(shape)).reshape(shape)
    sharded_x = pmap(lambda x: x)(x)

    num_threads = 10
    futures = []
    expected = []
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
      for i in range(num_threads):
        idx = i % shape[0]
        # Mix together different kinds of indices
        if i % 2 == 0:
          idx = slice(idx, idx + 1)
        # Use the "kwarg trick" to work around late-binding closures. See
        # https://docs.python-guide.org/writing/gotchas/#late-binding-closures.
        futures.append(executor.submit(
            lambda idx=idx: [sharded_x[idx] for _ in range(10)][0]))
        expected.append(x[idx])
      actual = [f.result() for f in futures]
    self.assertAllClose(actual, expected, check_dtypes=False)
Example #17
    def maxandargmax(x, axis=axis):
        if axis is None:
            axes = tuple(range(x.ndim))
        else:
            axes = tuple(int(ax) for ax in axis)

        max_res = jnp.max(x, axis)

        # NumPy does not support multiple axes for argmax; this is a
        # work-around
        keep_axes = jnp.array(
            [i for i in range(x.ndim) if i not in axes], dtype="int64"
        )
        # Not-reduced axes in front
        transposed_x = jnp.transpose(
            x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
        )
        kept_shape = transposed_x.shape[: len(keep_axes)]
        reduced_shape = transposed_x.shape[len(keep_axes) :]

        # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
        # Otherwise reshape would complain citing float arg
        new_shape = kept_shape + (
            jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
        )
        reshaped_x = transposed_x.reshape(new_shape)

        max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")

        return max_res, max_idx_res
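
A standalone sketch of the same multi-axis argmax trick (made-up shapes): move the reduced axes to the back, flatten them into a single axis, then take one argmax.

import jax.numpy as jnp

x = jnp.arange(2 * 3 * 4).reshape(2, 3, 4)
axes = (0, 2)                                    # axes to reduce over
keep = tuple(i for i in range(x.ndim) if i not in axes)
xt = jnp.transpose(x, keep + axes)               # kept axes in front
flat = xt.reshape(xt.shape[: len(keep)] + (-1,))
max_res = flat.max(axis=-1)
max_idx_res = flat.argmax(axis=-1)               # index into the flattened reduced axes
assert jnp.array_equal(max_res, x.max(axis=axes))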
Example #18
def generate_sample_grid(theta_mean, theta_std, n):
    """
    Create a meshgrid of n ** n_dim samples,
    tiling [theta_mean[i] - 5 * theta_std[i], theta_mean[i] + 5 * theta_std[i]]
    into n portions.
    Also returns the volume element.

    Parameters
    ----------
    theta_mean, theta_std : ndarray (n_dim)

    Returns
    -------
    theta_samples : ndarray (nobj, n_dim)

    vol_element: scalar
        Volume element

    """
    n_components = theta_mean.size
    xs = [
        np.linspace(
            theta_mean[i] - 5 * theta_std[i],
            theta_mean[i] + 5 * theta_std[i],
            n,
        )
        for i in range(n_components)
    ]
    mxs = np.meshgrid(*xs)
    orshape = mxs[0].shape
    mxsf = np.vstack([i.ravel() for i in mxs]).T
    dxs = np.vstack([np.diff(xs[i])[i] for i in range(n_components)])
    vol_element = np.prod(dxs)
    theta_samples = np.vstack(mxsf)
    return theta_samples, vol_element
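
Example call (a sketch; assumes generate_sample_grid above is in scope with a numpy-compatible np): a 5 x 5 grid around a 2-dimensional mean.

import numpy as np

theta_mean = np.array([0.0, 1.0])
theta_std = np.array([1.0, 2.0])
theta_samples, vol_element = generate_sample_grid(theta_mean, theta_std, n=5)
print(theta_samples.shape)   # (25, 2): one row per grid point
print(vol_element)           # product of the per-dimension grid spacings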
Example #19
 def init_state(self, a_shape, rng):
     # uses random as a hack to support vmap
     # we should find a non-hack approach to initializing the state
     dim_a = jnp.prod(a_shape)  # np.int32
     a_opt = 0.0 * jax.random.uniform(
         rng, shape=(self.n_steps, dim_a))  # [n_steps, dim_a]
     return a_opt
Example #20
def tmrca_sf(t: np.ndarray, y: np.ndarray, n: int) -> np.ndarray:
    """The survival function of the TMRCA at each time point

    Args:
        t: time grid (including zero and infinity)
        y: effective population size in each epoch
        n: number of sampled haplotypes

    """
    # epoch durations
    s = np.diff(t)
    logu = -s / y
    logu = np.concatenate((np.array([0]), logu))
    # the A_2j are the product of this matrix
    # NOTE: using letter  "l" as a variable name to match text
    l = onp.arange(2, n + 1)[:, onp.newaxis]  # noqa: E741
    with onp.errstate(divide='ignore'):
        A2_terms = l * (l - 1) / (l * (l - 1) - l.T * (l.T - 1))
    onp.fill_diagonal(A2_terms, 1)
    A2 = np.prod(A2_terms, axis=0)

    binom_vec = l * (l - 1) / 2

    result = np.zeros(len(t))
    result = index_update(result, index[:-1],
                          np.squeeze(A2[np.newaxis, :]
                                     @ np.exp(np.cumsum(logu[np.newaxis, :-1],
                                                        axis=1)) ** binom_vec))

    assert np.all(np.isfinite(result))

    return result
Example #21
def _triangular_solve(matrix, rhs, lower=True, adjoint=False, name=None):  # pylint: disable=redefined-outer-name
    """Scipy solve does not broadcast, so we must do so explicitly."""
    del name
    if JAX_MODE:  # But JAX uses XLA, which can do a batched solve.
        matrix = matrix + np.zeros(rhs.shape[:-2] + (1, 1), dtype=matrix.dtype)
        rhs = rhs + np.zeros(matrix.shape[:-2] + (1, 1), dtype=rhs.dtype)
        return scipy_linalg.solve_triangular(matrix,
                                             rhs,
                                             lower=lower,
                                             trans='C' if adjoint else 'N')
    try:
        bcast = onp.broadcast(matrix[..., :1], rhs)
    except ValueError as e:
        raise ValueError(
            'Error with inputs shaped `matrix`={}, rhs={}:\n{}'.format(
                matrix.shape, rhs.shape, str(e)))
    dim = matrix.shape[-1]
    matrix = onp.broadcast_to(matrix, bcast.shape[:-1] + (dim, ))
    rhs = onp.broadcast_to(rhs, bcast.shape)
    nbatch = int(np.prod(matrix.shape[:-2]))
    flat_mat = matrix.reshape(nbatch, dim, dim)
    flat_rhs = rhs.reshape(nbatch, dim, rhs.shape[-1])
    result = np.empty(flat_rhs.shape)
    if np.size(result):
        # ValueError: On entry to STRTRS parameter number 7 had an illegal value.
        for i, (mat, rh) in enumerate(zip(flat_mat, flat_rhs)):
            result[i] = scipy_linalg.solve_triangular(
                mat, rh, lower=lower, trans='C' if adjoint else 'N')
    return result.reshape(*rhs.shape)
Example #22
def get_bins_and_bincounts(samples, normalized=False):
    """take in samples, create a common set of bins, and compute the counts count(x in bin)
    for each bin and each sample x.
    Parameters
    ------------
    samples : np.array of shape (n,) or shape (k, n).
    - If shape (n,): interpreted as a set of n scalar-valued samples.
    - If shape (k, n): interpreted as k sets of n scalar-valued samples.

    Returns
    --------
    probabilities : bin counts (densities if normalized=True), one row per sample set
    bins : shared bin edges
    """
    nr_samples = np.prod(samples.shape)
    nr_bins = np.log2(nr_samples)
    nr_bins = int(max(nr_bins, 5))

    lims = [np.min(samples), np.max(samples)]
    bins = np.linspace(*lims, num=nr_bins)

    if samples.ndim == 2:
        out = np.asarray([
            np.histogram(x, bins=bins, density=normalized)[0] for x in samples
        ])
        return out, bins
    elif samples.ndim == 1:
        return np.histogram(samples, bins=bins, density=normalized)[0], bins
    else:
        raise ValueError(
            f"Input must have shape (n,) or shape (k,n). Instead received shape {samples.shape}"
        )
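
Example usage (a sketch; assumes get_bins_and_bincounts above is in scope): histogram two sets of samples over a shared set of bins.

import numpy as onp

samples = onp.random.default_rng(0).normal(size=(2, 1000))
counts, bins = get_bins_and_bincounts(samples)
print(counts.shape, bins.shape)   # (2, len(bins) - 1) counts per set, shared bin edges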
Example #23
def _get_inputs(key, is_conv, same_inputs, input_shape, fn=np.cos):
  key, split = random.split(key)
  shape = input_shape if is_conv else (input_shape[0], np.prod(input_shape[1:]))
  x1 = fn(random.normal(key, shape))
  x2 = None if same_inputs else 2 * fn(random.normal(split, shape))

  return x1, x2
Example #24
 def __init__(self, input_shape):
     super(MLPDynamics, self).__init__()
     self.input_shape = input_shape
     self.dim = jnp.prod(input_shape[1:])
     self.hidden_dim = 100
     self.lin1 = hk.Linear(self.hidden_dim)
     self.lin2 = hk.Linear(self.dim)
Example #25
 def update(self, mpc_state, env, env_state, rng, reward_fn=None,
            reward_params=None, reward_rng=None):
   # mpc_state: ([n_steps, dim_a], [n_steps, dim_a, dim_a])
   # env: {.step(s, a), .reward(s)}
   # env_state: [env_shape] np.float32
   # rng: rng key for mpc sampling
   # reward_fn: reward_fn(env, s, params, rng)
   # reward_params: params for reward function
   # reward_rng: rng key for reward function stochasticity, e.g. dropout
   dim_a = jnp.prod(env.a_shape)  # np.int32
   a_opt, a_cov = mpc_state
   a_opt = jnp.concatenate([a_opt[1:, :],
                            jnp.expand_dims(jnp.zeros((dim_a,)),
                                            axis=0)])  # [n_steps, dim_a]
   if self.adaptive_covariance:
     a_cov = jnp.concatenate([a_cov[1:, :],
                              jnp.expand_dims((self.a_std**2)*jnp.eye(dim_a),
                                              axis=0)])
   def iteration_step(input_, _):
     a_opt, a_cov, rng = input_
     rng_da, rng = jax.random.split(rng)
     if self.adaptive_covariance:
       da = jax.vmap(jax.random.multivariate_normal, (0, 0, 0, None), 1)(
           jax.random.split(rng_da, self.n_steps),  # [n_steps], rngs
           jnp.zeros((self.n_steps, dim_a)),  # [n_steps, dim_a] mean
           a_cov,  # [n_steps, dim_a, dim_a] covariance
           (self.n_samples,),
       )  # [n_samples, n_steps, dim_a]
     else:
       da = self.a_std*jax.random.normal(
           rng_da,
           shape=(self.n_samples, self.n_steps, dim_a)
       )  # [n_samples, n_steps, dim_a]
     # a: [n_samples, n_steps, dim_a]
     a = jnp.clip(jnp.expand_dims(a_opt, axis=0) + da, -1.0, 1.0)
     r = jax.vmap(self.rollout, in_axes=(0, None, None, None, None, None))(
         a, env, env_state, reward_fn, reward_params, reward_rng
     )  # [n_samples, n_steps]
     R = jax.vmap(self.returns)(r)  # [n_samples, n_steps], pylint: disable=invalid-name
     w = jax.vmap(self.weights, 1, 1)(R)  # [n_samples, n_steps]
     da_opt = jax.vmap(jnp.average, (1, None, 1))(da, 0, w)  # [n_steps, dim_a]
     a_opt = jnp.clip(a_opt + da_opt, -1.0, 1.0)  # [n_steps, dim_a]
     if self.adaptive_covariance:
       a_cov = jax.vmap(jax.vmap(jnp.outer))(
           da, da
       )  # [n_samples, n_steps, dim_a, dim_a]
       a_cov = jax.vmap(jnp.average, (1, None, 1))(
           a_cov, 0, w
       )  # a_cov: [n_steps, dim_a, dim_a]
       # prevent loss of rank when one sample is heavily weighted
       a_cov = a_cov + jnp.eye(dim_a)*0.00001
     return (a_opt, a_cov, rng), None
   if not self.scan:
     for _ in range(self.n_iterations):
       (a_opt, a_cov, rng), _ = iteration_step((a_opt, a_cov, rng), None)
   else:
     (a_opt, a_cov, rng), _ = jax.lax.scan(
         iteration_step, (a_opt, a_cov, rng), None, length=self.n_iterations
     )
   return (a_opt, a_cov)
Example #26
    def test_shuffled_neuron_no_input_ablation_mask_sparsity_full(self):
        """Tests shuffled mask generation, for 100% sparsity."""
        mask = masked.shuffled_neuron_no_input_ablation_mask(
            self._masked_model, self._rng, 1.0)

        with self.subTest(name='shuffled_full_mask'):
            self.assertIn('MaskedModule_0', mask)

        with self.subTest(name='shuffled_full_mask_values'):
            self.assertEqual(
                jnp.count_nonzero(mask['MaskedModule_0']['kernel']),
                jnp.prod(jnp.array(self._input_dimensions)))

        with self.subTest(name='shuffled_full_no_input_ablation'):
            # Check no row (neurons are columns) is completely ablated.
            self.assertTrue((jnp.count_nonzero(
                mask['MaskedModule_0']['kernel'], axis=0) != 0).all())

        with self.subTest(name='shuffled_full_mask_not_masked_values'):
            self.assertIsNone(mask['MaskedModule_0']['bias'])

        masked_output = self._masked_model(self._input, mask=mask)

        with self.subTest(name='shuffled_full_mask_dense_shape'):
            self.assertSequenceEqual(masked_output.shape,
                                     self._unmasked_output.shape)
Example #27
def compute_weighted_cross_entropy(logits, targets, weights=None):
    """Compute weighted cross entropy and entropy for log probs and targets.

  Args:
   logits: `[batch, length, num_classes]` float array.
   targets: categorical targets `[batch, length]` int array.
   weights: None or array of shape [batch, length, 1]

  Returns:
    Tuple of scalar loss and batch normalizing factor.
  """
    if logits.ndim != targets.ndim + 1:
        raise ValueError(
            'Incorrect shapes. Got shape %s logits and %s targets' %
            (str(logits.shape), str(targets.shape)))
    if logits.shape[1] != targets.shape[1]:  # Truncate logits.
        logits = logits[:, :targets.shape[1]]

    onehot_targets = common_utils.onehot(targets, logits.shape[-1])
    loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
    normalizing_factor = jnp.prod(jnp.asarray(targets.shape))
    if weights is not None:
        loss = loss * weights
        normalizing_factor = weights.sum()

    return loss.sum(), normalizing_factor
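
The onehot / log_softmax helpers above are library utilities not shown in the snippet; a self-contained sketch of the same unweighted loss using jax.nn instead:

import jax
import jax.numpy as jnp

logits = jnp.zeros((2, 3, 5))                      # [batch, length, num_classes]
targets = jnp.array([[0, 1, 2], [3, 4, 0]])        # [batch, length]
onehot_targets = jax.nn.one_hot(targets, logits.shape[-1])
loss = -jnp.sum(onehot_targets * jax.nn.log_softmax(logits), axis=-1)
normalizing_factor = jnp.prod(jnp.asarray(targets.shape))
print(loss.sum() / normalizing_factor)             # mean per-token loss; log(5) for uniform logits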
Example #28
def compute_weighted_accuracy(logits, targets, weights=None):
    """Compute weighted accuracy for log probs and targets.

  Args:
   logits: `[batch, length, num_classes]` float array.
   targets: categorical targets `[batch, length]` int array.
   weights: None or array of shape [batch, length, 1]

  Returns:
    Tuple of scalar accuracy and batch normalizing factor.
  """
    if logits.ndim != targets.ndim + 1:
        raise ValueError(
            'Incorrect shapes. Got shape %s logits and %s targets' %
            (str(logits.shape), str(targets.shape)))
    if logits.shape[1] != targets.shape[1]:  # Truncate logits.
        logits = logits[:, :targets.shape[1]]

    acc = jnp.equal(jnp.argmax(logits, axis=-1), targets)
    normalizing_factor = jnp.prod(jnp.asarray(targets.shape))
    if weights is not None:
        acc = acc * weights
        normalizing_factor = weights.sum()

    return acc.sum(), normalizing_factor
Example #29
    def loss_fn(variables):
        rays = batch["rays"]
        ret = model.apply(variables, key_0, key_1, rays, FLAGS.randomized)
        if len(ret) not in (1, 2):
            raise ValueError(
                "ret should contain either 1 set of output (coarse only), or 2 sets"
                "of output (coarse as ret[0] and fine as ret[1]).")
        # The main prediction is always at the end of the ret list.
        rgb, unused_disp, unused_acc = ret[-1]
        loss = ((rgb - batch["pixels"][Ellipsis, :3])**2).mean()
        psnr = utils.compute_psnr(loss)
        if len(ret) > 1:
            # If there are both coarse and fine predictions, we compute the loss for
            # the coarse prediction (ret[0]) as well.
            rgb_c, unused_disp_c, unused_acc_c = ret[0]
            loss_c = ((rgb_c - batch["pixels"][Ellipsis, :3])**2).mean()
            psnr_c = utils.compute_psnr(loss_c)
        else:
            loss_c = 0.
            psnr_c = 0.

        def tree_sum_fn(fn):
            return jax.tree_util.tree_reduce(lambda x, y: x + fn(y),
                                             variables,
                                             initializer=0)

        weight_l2 = (tree_sum_fn(lambda z: jnp.sum(z**2)) /
                     tree_sum_fn(lambda z: jnp.prod(jnp.array(z.shape))))

        stats = utils.Stats(loss=loss,
                            psnr=psnr,
                            loss_c=loss_c,
                            psnr_c=psnr_c,
                            weight_l2=weight_l2)
        return loss + loss_c + FLAGS.weight_decay_mult * weight_l2, stats
Example #30
def reduce_prod(x, axis=None, keepdims=False):
    if axis is None:
        num_dims = len(x.shape)
        axis = tuple(range(num_dims))
    elif isinstance(axis, list):
        axis = tuple(axis)
    return _jnp.prod(x, axis=axis, keepdims=keepdims)
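
Example usage (a sketch; assumes _jnp above is jax.numpy):

import jax.numpy as jnp

x = jnp.arange(1.0, 7.0).reshape(2, 3)
print(reduce_prod(x))                             # 720.0: product over every axis
print(reduce_prod(x, axis=[0], keepdims=True))    # shape (1, 3): [[4., 10., 18.]]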