Example #1
 def init(key, shape, dtype=np.float32):
   if len(shape) < 2:
     raise ValueError("orthogonal initializer requires at least a 2D shape")
   n_rows, n_cols = onp.prod(shape) // shape[column_axis], shape[column_axis]
   matrix_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)
   A = random.normal(key, matrix_shape, dtype)
   Q, R = np.linalg.qr(A)
   Q *= np.sign(np.diag(R)) # needed for a uniform distribution
   if n_rows < n_cols: Q = Q.T
   Q = np.reshape(Q, tuple(onp.delete(shape, column_axis)) + (shape[column_axis],))
   Q = np.moveaxis(Q, -1, column_axis)
   return scale * Q
Example #2
def reverse_zipped(x: _ArrayOrShape, start_axis: int = 0) -> _ArrayOrShape:
    if x is not None:
        ndim = _get_ndim(x)
        source_axes = tuple(j for i in range(ndim - 2, start_axis - 1, -2)
                            for j in (i, i + 1))

        if isinstance(x, (onp.ndarray, np.ndarray)):
            target_axes = range(start_axis, ndim)
            x = np.moveaxis(x, source_axes, target_axes)
        else:
            x = x[:start_axis] + type(x)(x[i] for i in source_axes)
    return x
Example #3
def extract_images_patches(images,
                           window_size,
                           stride = (1, 1)):
  """Extracts patches from an image using a convolution operator.

  Args:
    images: A tensor of images of shape (B, H, W, C).
    window_size: The size of the patches to extract (h, w).
    stride: The shift between extracted patches (s1, s2).

  Returns:
    All the patches in a tensor of dimension
      (B, (H - h + 1) // s1, (W - w + 1) // s2, h, w, C).
  """
  # batch, channels, height, width
  images = jnp.moveaxis(images, -1, 1)
  d = images.shape[1]
  h, w = window_size

  # construct the lookup conv weights
  dim_out = jnp.arange(d * h * w).reshape((-1, 1, 1, 1))
  dim_in = jnp.arange(d).reshape((1, -1, 1, 1))
  i = jnp.arange(h).reshape((1, 1, -1, 1))
  j = jnp.arange(w).reshape((1, 1, 1, -1))
  weights = ((w * i + j) * d + dim_in == dim_out).astype(jnp.float32)

  # batch, h * w * d, (H - h + 1) // s1, (W - w + 1) // s2
  concatenated_patches = jax.lax.conv(images,
                                      weights,
                                      window_strides=stride,
                                      padding="VALID")

  # batch, (H - h + 1) // s1, (W - w + 1) // s2, h * w * d
  concatenated_patches = jnp.moveaxis(concatenated_patches, 1, -1)

  # batch, (H - h + 1) // s1, (W - w + 1) // s2, h, w, d
  shape = concatenated_patches.shape[:3] + (h, w, d)
  patches = concatenated_patches.reshape(shape)
  return patches
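A minimal usage sketch (assuming `jax` and `jax.numpy as jnp` are imported and the function above is in scope), confirming the documented output shape:

import jax.numpy as jnp

images = jnp.ones((2, 8, 8, 3))                        # (B, H, W, C)
patches = extract_images_patches(images, window_size=(3, 3))
assert patches.shape == (2, 6, 6, 3, 3, 3)             # (B, H-h+1, W-w+1, h, w, C) with stride (1, 1)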
Example #4
File: ddpg.py Project: minalspatil/coax
def shared(S, is_training):
    seq = hk.Sequential([
        coax.utils.diff_transform,
        hk.Conv2D(16, kernel_shape=8, stride=4),
        jax.nn.relu,
        hk.Conv2D(32, kernel_shape=4, stride=2),
        jax.nn.relu,
        hk.Flatten(),
    ])
    X = jnp.moveaxis(S / 255., 1, -1)  # shape: (batch, frames, h, w) --> (batch, h, w, frames)
    return seq(X)
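The frame-stack transpose used here (and again in Examples #11 and #27), shown in isolation with an assumed Atari-style batch shape (a sketch):

import jax.numpy as jnp

S = jnp.zeros((32, 4, 84, 84))          # (batch, frames, h, w)
X = jnp.moveaxis(S / 255., 1, -1)
assert X.shape == (32, 84, 84, 4)       # channels-last layout expected by hk.Conv2D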
Example #5
def reverse_zipped(x: Union[np.ndarray, Sequence[int]],
                   start_axis: int = 0) -> Union[np.ndarray, Sequence[int]]:
    if x is not None:
        ndim = _get_ndim(x)
        source_axes = tuple(j for i in range(ndim - 2, start_axis - 1, -2)
                            for j in (i, i + 1))

        if isinstance(x, np.ndarray):
            target_axes = range(start_axis, ndim)
            x = np.moveaxis(x, source_axes, target_axes)
        else:
            x = x[:start_axis] + type(x)(x[i] for i in source_axes)
    return x
Example #6
    def test_conv_local_general_dilated(self, n, padding, lhs_spec, rhs_spec,
                                        out_spec):
        """Make sure LCN with tiled CNN kernel matches CNN."""
        if xla_bridge.get_backend().platform == 'cpu' and n > 1:
            raise absltest.SkipTest('Skipping large tests on CPU.')

        lhs_spec_default = 'NCHWDX'[:n + 2]
        rhs_spec_default = 'OIHWDX'[:n + 2]

        lhs_default = random.normal(random.PRNGKey(1),
                                    (2, 4, 7, 6, 5, 8)[:n + 2])
        rhs_default = random.normal(random.PRNGKey(2),
                                    (3, 4, 2, 3, 1, 2)[:n + 2])

        window_strides = (1, 2, 3, 4)[:n]
        rhs_dilation = (2, 1, 3, 2)[:n]

        lhs_perm = [lhs_spec_default.index(c) for c in lhs_spec]
        lhs = np.transpose(lhs_default, lhs_perm)

        rhs_perm = [rhs_spec_default.index(c) for c in rhs_spec]
        rhs = np.transpose(rhs_default, rhs_perm)

        kwargs = dict(lhs=lhs,
                      window_strides=window_strides,
                      padding=padding,
                      rhs_dilation=rhs_dilation,
                      dimension_numbers=(lhs_spec, rhs_spec, out_spec))

        out_conv = lax.conv_general_dilated(rhs=rhs, **kwargs)

        rhs_local = np.moveaxis(rhs,
                                (rhs_spec.index('O'), rhs_spec.index('I')),
                                (0, 1))
        rhs_local = rhs_local.reshape((rhs_local.shape[0], -1) + (1, ) * n)

        rhs_shape = (rhs_local.shape[:2] +
                     tuple(out_conv.shape[out_spec.index(c)]
                           for c in rhs_spec_default[2:]))

        rhs_local = np.broadcast_to(rhs_local, rhs_shape)
        rhs_local = np.transpose(rhs_local, rhs_perm)

        filter_shape = [
            rhs.shape[i] for i in range(n + 2) if rhs_spec[i] not in ('O', 'I')
        ]
        out_local = utils.conv_local_general_dilated(rhs=rhs_local,
                                                     filter_shape=filter_shape,
                                                     **kwargs)

        self.assertAllClose(out_conv, out_local, atol=1e-5, rtol=1e-5)
Example #7
    def _batch_mahalanobis(bL, bx):
        # NB: The following procedure handles the case: bL.shape = (i, 1, n, n), bx.shape = (i, j, n)
        # because we don't want to broadcast bL to the shape (i, j, n, n).

        # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
        # we are going to make bx have shape (..., 1, j,  i, 1, n) to apply batched tril_solve
        sample_ndim = bx.ndim - bL.ndim + 1  # size of sample_shape
        out_shape = np.shape(bx)[:-1]  # shape of output
        # Reshape bx with the shape (..., 1, i, j, 1, n)
        bx_new_shape = out_shape[:sample_ndim]
        for (sL, sx) in zip(bL.shape[:-2], out_shape[sample_ndim:]):
            bx_new_shape += (sx // sL, sL)
        bx_new_shape += (-1, )
        bx = np.reshape(bx, bx_new_shape)
        # Permute bx to make it have shape (..., 1, j, i, 1, n)
        permute_dims = (tuple(range(sample_ndim)) +
                        tuple(range(sample_ndim, bx.ndim - 1, 2)) +
                        tuple(range(sample_ndim + 1, bx.ndim - 1, 2)) +
                        (bx.ndim - 1, ))
        bx = np.transpose(bx, permute_dims)

        # reshape to (-1, i, 1, n)
        xt = np.reshape(bx, (-1, ) + bL.shape[:-1])
        # permute to (i, 1, n, -1)
        xt = np.moveaxis(xt, 0, -1)
        solve_bL_bx = solve_triangular(bL, xt,
                                       lower=True)  # shape: (i, 1, n, -1)
        M = np.sum(solve_bL_bx**2, axis=-2)  # shape: (i, 1, -1)
        # permute back to (-1, i, 1)
        M = np.moveaxis(M, -1, 0)
        # reshape back to (..., 1, j, i, 1)
        M = np.reshape(M, bx.shape[:-1])
        # permute back to (..., 1, i, j, 1)
        permute_inv_dims = tuple(range(sample_ndim))
        for i in range(bL.ndim - 2):
            permute_inv_dims += (sample_ndim + i, len(out_shape) + i)
        M = np.transpose(M, permute_inv_dims)
        return np.reshape(M, out_shape)
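A quick numerical check of the unbatched case (a sketch, assuming the helper above is in scope together with numpyro's imports: `jax.numpy as np` and `solve_triangular` from `jax.scipy.linalg`):

import numpy as onp
import jax.numpy as np

A = onp.array([[2.0, 0.3], [0.3, 1.0]])
L = onp.linalg.cholesky(A)
x = onp.array([0.5, -1.2])
expected = x @ onp.linalg.inv(A) @ x    # Mahalanobis distance x^T A^{-1} x
assert np.allclose(_batch_mahalanobis(np.asarray(L), np.asarray(x)), expected, atol=1e-5)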
Example #8
 def tensor(self, other):
     if not isinstance(other, Tensor):
         raise TypeError(messages.type_err(Tensor, other))
     dom, cod = self.dom @ other.dom, self.cod @ other.cod
     array = np.tensordot(self.array, other.array, 0)\
         if self.array.shape and other.array.shape\
         else self.array * other.array
     source = range(len(dom @ cod))
     target = [
         i if i < len(self.dom) or i >= len(self.dom @ self.cod @ other.dom)
         else i - len(self.cod) if i >= len(self.dom @ self.cod)
         else i + len(other.dom)
         for i in source
     ]
     return Tensor(dom, cod, np.moveaxis(array, source, target))
Example #9
def encode_coordinate(x, basis, legacy_posenc_order=False):
  """Concatenate `x` with Fourier features of `x` projected onto `basis`."""
  xb = x @ basis
  # Instead of computing [sin(x), cos(x)], we use the trig identity
  # cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
  four_feat = jnp.sin(jnp.concatenate([xb, xb + 0.5 * jnp.pi], axis=-1))
  # TODO(bydeng): remove this re-ordering when a new batch of pre-trained
  # models is ready.
  if legacy_posenc_order:
    four_feat = jnp.moveaxis(
        jnp.reshape(four_feat,
                    list(four_feat.shape[:-1]) + [2, -1, x.shape[-1]]), -3, -2)
    four_feat = jnp.reshape(four_feat, list(four_feat.shape[:-3]) + [-1])
  return jnp.concatenate([x] + [four_feat], axis=-1)
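A shape sketch (assuming `jnp` is `jax.numpy`); with a trivial identity basis the output simply concatenates `x` with one sine and one shifted-sine block:

import jax.numpy as jnp

x = jnp.zeros((4, 3))                    # 4 points in 3D
basis = jnp.eye(3)                       # (3, F) projection with F = 3
feat = encode_coordinate(x, basis)
assert feat.shape == (4, 3 + 2 * 3)      # x plus sin(xb) and sin(xb + pi/2) features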
Example #10
 def setUp(self):
     super().setUp()
     self.batch_size = 8
     self.params = {
         'key_a': (jnp.zeros((2, 3, 4)), jnp.zeros([])),
         'key_b': jnp.zeros((6, 7))
     }
     # Example `i`'s grads are full of `i`s. Important to include 0 to ensure
     # there are no divisions by 0 (e.g. in norm clipping)
     a = jnp.arange(self.batch_size)
     self.per_eg_grads = jax.tree_map(
         lambda p: jnp.moveaxis(
             a * jnp.ones(p.shape + (self.batch_size,)), -1, 0),
         self.params)
Example #11
def func(S, A, is_training):
    """ type-1 q-function: (s,a) -> q(s,a) """
    body = hk.Sequential((
        coax.utils.diff_transform,
        hk.Conv2D(16, kernel_shape=8, stride=4), jax.nn.relu,
        hk.Conv2D(32, kernel_shape=4, stride=2), jax.nn.relu,
        hk.Flatten(),
    ))
    head = hk.Sequential((
        hk.Linear(256), jax.nn.relu,
        hk.Linear(1, w_init=jnp.zeros), jnp.ravel
    ))
    X = jnp.moveaxis(S / 255., 1, -1)  # shape: (batch, frames, h, w) --> (batch, h, w, frames)
    return head(jax.vmap(jnp.kron)(body(X), A))
Example #12
 def gate(self, params):
     # start with identity scalar and tensor-product all the gates
     mat = 1
     for g in self.gates:
         g = g.gate(params)
         mat = jnp.kron(mat, g)
     # if we have interleaved gates (i.e. a gate acting on qudit #1 and #3 but not #2)
     # then we need to un-permute the indices
     if self.permuted != self.regInfo.unpermuted:
         # reshape 2D unitary into qudit indices
         mat = mat.reshape(self.regInfo.shape)
         mat = jnp.moveaxis(mat, self.permuted, self.regInfo.unpermuted)
         mat = mat.reshape((self.regInfo.dim, self.regInfo.dim))
     return mat
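The reshape/moveaxis un-permutation used above, illustrated standalone for two qubits (a sketch): permuting the qudit indices of kron(A, B) yields kron(B, A).

import jax.numpy as jnp

A = jnp.array([[0., 1.], [1., 0.]])
B = jnp.eye(2)
mat = jnp.kron(A, B).reshape(2, 2, 2, 2)   # axes: (row qubit 0, row qubit 1, col qubit 0, col qubit 1)
swapped = jnp.moveaxis(mat, (0, 1, 2, 3), (1, 0, 3, 2)).reshape(4, 4)
assert jnp.allclose(swapped, jnp.kron(B, A))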
Example #13
        def predict_fn_inf(fx_train_0, fx_test_0, k_test_train):
            fx_train_t = y_train.astype(k_train_train.dtype)
            if fx_test_0 is None:
                return fx_train_t

            rhs = y_train if fx_train_0 is None else y_train - fx_train_0
            dfx_test = np.tensordot(k_test_train, solve(rhs, trace_axes),
                                    (odd, first))
            dfx_test = np.moveaxis(dfx_test, last_t_axes, trace_axes)
            fx_test_t = fx_test_0 + dfx_test

            if fx_train_0 is None:
                return fx_test_t
            return fx_train_t, fx_test_t
Example #14
def _zip_axes(x: np.ndarray,
              start_axis: int = 0,
              end_axis: int = None,
              unzip: bool = False) -> np.ndarray:
  """Zip/unzip (interleave/de-interleave) axes starting from `start_axis`.

  Changes the shape as follows:
    If `unzip == True`:
    `[..., X, X, ..., Y, Y, ..., Z, Z, ...] -> [..., X, Y, Z, ..., X, Y, Z, ..]`
    If `unzip == False`:
    `[..., X, Y, Z, ..., X, Y, Z, ...] -> [..., X, X, ..., Y, Y, ..., Z, Z, ..]`

  Args:
    x: `np.ndarray` with an even number of dimensions following `start_axis`.
    start_axis: `int`, number of axis from which to zip/unzip.
    end_axis: `int`, number of axis until which to zip/unzip.
    unzip: `bool`, set to `True` to unzip instead of zip.

  Returns:
    A `np.ndarray` with a new shape.
  """
  if end_axis is None:
    end_axis = x.ndim

  half_ndim, ragged = divmod(end_axis - start_axis, 2)
  if ragged:
    raise ValueError(
        f'Need even number of axes to zip, got {end_axis - start_axis}.')

  odd_axes = range(start_axis + 1, end_axis, 2)
  last_axes = range(end_axis - half_ndim, end_axis)

  if unzip:
    x = np.moveaxis(x, odd_axes, last_axes)
  else:
    x = np.moveaxis(x, last_axes, odd_axes)
  return x
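A shape sketch of the zip/unzip behaviour (here `np` can be NumPy or `jax.numpy`, both of which provide `moveaxis`):

import numpy as np

x = np.zeros((5, 2, 3, 2, 3))                    # batch axis, then X, Y, X, Y
zipped = _zip_axes(x, start_axis=1)              # -> batch, X, X, Y, Y
assert zipped.shape == (5, 2, 2, 3, 3)
assert _zip_axes(zipped, start_axis=1, unzip=True).shape == x.shape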
Example #15
def _trace_and_diagonal(ntk: np.ndarray, trace_axes: Axes,
                        diagonal_axes: Axes) -> np.ndarray:
    """Extract traces and diagonals along respective pairs of axes from the `ntk`.

  Args:
    ntk:
      input empirical NTK of shape `(N1, X, Y, Z, ..., N2, X, Y, Z, ...)`.
    trace_axes:
      axes (among `X, Y, Z, ...`) to trace over, i.e. compute the trace along
      and remove the  respective pairs of axes from the `ntk`.
    diagonal_axes:
      axes (among `X, Y, Z, ...`) to take the diagonal along, i.e. extract the
      diagonal along the respective pairs of axes from the `ntk` (and hence
      reduce the resulting `ntk` axes count by 2).
  Returns:
    An array of shape, for example, `(N1, N2, Y, Z, Z, ...)` if
    `trace_axes=(1,)` (`X` axes removed), and `diagonal_axes=(2,)` (`Y` axes
    replaced with a single `Y` axis).
  """

    if ntk.ndim % 2 == 1:
        raise ValueError(
            'Expected an even-dimensional kernel. Please file a bug at '
            'https://github.com/google/neural-tangents/issues/new')

    output_ndim = ntk.ndim // 2

    trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)
    diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)

    n_diag, n_trace = len(diagonal_axes), len(trace_axes)
    contract_size = utils.size_at(ntk.shape[:output_ndim], trace_axes)

    for i, c in enumerate(reversed(trace_axes)):
        ntk = np.trace(ntk, axis1=c, axis2=output_ndim + c - i)

    for i, d in enumerate(diagonal_axes):
        axis1 = d - i
        axis2 = output_ndim + d - 2 * i - n_trace
        for c in trace_axes:
            if c < d:
                axis1 -= 1
                axis2 -= 1
        ntk = np.diagonal(ntk, axis1=axis1, axis2=axis2)

    ntk = utils.zip_axes(ntk, 0, ntk.ndim - n_diag)
    res_diagonal_axes = utils.get_res_batch_dims(trace_axes, diagonal_axes)
    ntk = np.moveaxis(ntk, range(-n_diag, 0), res_diagonal_axes)
    return ntk / contract_size
Example #16
def cholesky_update(L, x, coef=1):
    """
    Finds cholesky of L @ L.T + coef * x @ x.T.

    **References:**

        1. A more efficient rank-one covariance matrix update for evolution strategies,
           Oswin Krause and Christian Igel
    """
    batch_shape = lax.broadcast_shapes(L.shape[:-2], x.shape[:-1])
    L = np.broadcast_to(L, batch_shape + L.shape[-2:])
    x = np.broadcast_to(x, batch_shape + x.shape[-1:])
    diag = np.diagonal(L, axis1=-2, axis2=-1)
    # convert to unit diagonal triangular matrix: L @ D @ T.t
    L = L / diag[..., None, :]
    D = np.square(diag)

    def scan_fn(carry, val):
        b, w = carry
        j, Dj, L_j = val
        wj = w[..., j]
        gamma = b * Dj + coef * np.square(wj)
        Dj_new = gamma / b
        b = gamma / Dj_new

        # update vectors w and L_j
        w = w - wj[..., None] * L_j
        L_j = L_j + (coef * wj / gamma)[..., None] * w
        return (b, w), (Dj_new, L_j)

    D, L = np.moveaxis(D, -1, 0), np.moveaxis(L, -1, 0)  # move scan dim to front
    _, (D, L) = lax.scan(scan_fn, (np.ones(batch_shape), x),
                         (np.arange(D.shape[0]), D, L))
    D, L = np.moveaxis(D, 0, -1), np.moveaxis(L, 0, -1)  # move scan dim back
    return L * np.sqrt(D)[..., None, :]
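A numerical sanity check, sketched under numpyro's import convention (`import jax.numpy as np`, `from jax import lax`) with the function above in scope:

import numpy as onp
import jax.numpy as np

A = onp.array([[4.0, 1.0, 0.5],
               [1.0, 3.0, 0.2],
               [0.5, 0.2, 2.0]])
L = onp.linalg.cholesky(A)
x = onp.array([0.3, -0.7, 1.1])
L_new = cholesky_update(np.asarray(L), np.asarray(x), coef=1)
assert np.allclose(L_new @ L_new.T, A + onp.outer(x, x), atol=1e-4)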
Example #17
def make_image_grid(images, nrow=10):
  """Given a list of images, tile into a single grid image."""
  ncol = int(math.ceil(len(images) / nrow))
  to_pad = nrow - len(images) % nrow
  images = np.stack(images)
  if images.ndim == 3:
    # Add channel dimension
    images = images[Ellipsis, None]
  H, W, C = images.shape[1:]  # pylint: disable=invalid-name
  if to_pad and to_pad != nrow:
    padding_frames = np.zeros((to_pad, H, W, C), dtype=images.dtype)
    images = np.concatenate([images, padding_frames], axis=0)
  images = np.reshape(images, (ncol, nrow, H, W, C))
  images = np.moveaxis(images, 1, 2)  # nc, nr, h, w, c --> nc, h, nr, w, c
  return images.reshape(1, ncol * H, nrow * W, C)
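A minimal usage sketch (NumPy and `math` assumed imported, as the function requires):

import numpy as np

frames = [np.zeros((4, 4, 3)) for _ in range(25)]
grid = make_image_grid(frames, nrow=10)
# 25 images are padded to 30 and tiled as 3 rows of 10: (1, 3 * 4, 10 * 4, 3)
assert grid.shape == (1, 12, 40, 3)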
Example #18
        def dstate_dt(state_t: ODEState, unused_t) -> ODEState:
            fx_train_t, fx_test_t, qx_train_t, qx_test_t = (state_t.fx_train,
                                                            state_t.fx_test,
                                                            state_t.qx_train,
                                                            state_t.qx_test)

            dy_df_t = grad_loss(fx_train_t)

            fx_train_t = -np.moveaxis(
                np.tensordot(k_train_train, dy_df_t,
                             (odd, non_t_axes)), last_t_axes, trace_axes)
            if fx_test_t is not None:
                fx_test_t = -np.moveaxis(
                    np.tensordot(k_test_train, dy_df_t,
                                 (odd, non_t_axes)), last_t_axes, trace_axes)

            if momentum is None:
                return ODEState(fx_train_t, fx_test_t)  # pytype: disable=wrong-arg-count

            fx_train_t += momentum * qx_train_t
            if qx_test_t is not None:
                fx_test_t += momentum * qx_test_t

            return ODEState(qx_train_t, qx_test_t, fx_train_t, fx_test_t)  # pytype: disable=wrong-arg-count
Example #19
 def __call__(self, shape: Shape, dtype: DType) -> np.ndarray:
     if len(shape) < 2:
         raise ValueError("Orthogonal initializer requires at least a 2D shape.")
     n_rows = shape[self.axis]
     n_cols = np.prod(shape) // n_rows
     matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
     norm_dst = jax.random.normal(hooks.next_rng_key(), matrix_shape, dtype)
     q_mat, r_mat = jnp.linalg.qr(norm_dst)
     # Enforce Q is uniformly distributed
     q_mat *= jnp.sign(jnp.diag(r_mat))
     if n_rows < n_cols:
         q_mat = q_mat.T
     q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))
     q_mat = jnp.moveaxis(q_mat, 0, self.axis)
     return jax.lax.convert_element_type(self.scale, dtype) * q_mat
Example #20
def reverse_zipped(
    mat: Union[np.ndarray, Sequence[int]],
    start_axis: int = 0) -> Union[np.ndarray, Sequence[int]]:
  if mat is not None:
    ndim = _get_ndim(mat)
    source_axes = tuple(j
                        for i in range(ndim - 2, start_axis - 1, -2)
                        for j in (i, i + 1))

    if isinstance(mat, np.ndarray):
      target_axes = range(start_axis, ndim)
      mat = np.moveaxis(mat, source_axes, target_axes)
    else:
      mat = mat[:start_axis] + type(mat)(mat[i] for i in source_axes)
  return mat
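A shape sketch (assuming the `_get_ndim` helper, not shown, returns `x.ndim` for arrays and `len(x)` for shape tuples):

import numpy as np

x = np.zeros((2, 3, 4, 5))
assert reverse_zipped(x).shape == (4, 5, 2, 3)         # zipped pairs of axes reversed
assert reverse_zipped((2, 3, 4, 5)) == (4, 5, 2, 3)    # also works on shape tuples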
Example #21
def twiddle_factor_to_matrix(twiddle_factor, stride):
    """
    twiddle_factor: (n // 2, 2, 2)
    stride: int
    Return:
        (n, n)
    """
    n = twiddle_factor.shape[0] * 2
    assert twiddle_factor.shape == (n // 2, 2, 2)
    assert stride == 1 << int(math.log2(stride)), 'stride must be a power of 2'
    x = jnp.eye(n)
    t = jnp.moveaxis(twiddle_factor.reshape(n // (2 * stride), stride, 2, 2),
                     -3, -1)
    y = x.reshape(n, n // (2 * stride), 1, 2, stride)
    y = (t * y).sum(axis=-2).reshape(n, n)
    return y.T
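A sanity-check sketch (assuming `math` and `jax.numpy as jnp` are imported): with identity 2x2 blocks the resulting butterfly factor is the identity matrix for any valid stride.

import jax.numpy as jnp

n = 8
identity_blocks = jnp.broadcast_to(jnp.eye(2), (n // 2, 2, 2))
for stride in (1, 2, 4):
    assert jnp.allclose(twiddle_factor_to_matrix(identity_blocks, stride), jnp.eye(n))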
Example #22
  def _compute_routing_instructions(self, router_probs,
                                    expert_capacity):
    """Computes masks for the highest probability token per expert.

    Args:
      router_probs: <float32>[NUM_GROUPS, TOKENS_PER_GROUP, NUM_EXPERTS]
        probabilities used to determine the routing of tokens to the experts.
      expert_capacity: Each group will send this many tokens to each expert.

    Returns:
      Dispatch and combine arrays for routing with masked matmuls.
    """
    tokens_per_group = router_probs.shape[1]

    # vmap over group dimension.
    router_probs_t = jax.vmap(lambda m: m.transpose())(router_probs)

    # Top expert_capacity router probability and corresponding token indices for
    # each expert. Shapes: [NUM_GROUPS, NUM_EXPERTS, EXPERT_CAPACITY].
    expert_gate, expert_index = _top_k(router_probs_t, k=expert_capacity)

    # Convert to one-hot mask of expert indices for each token in each group.
    # Shape: [NUM_GROUPS, NUM_EXPERTS, EXPERT_CAPACITY, TOKENS_PER_GROUP].
    dispatch_mask = jax.nn.one_hot(
        expert_index, tokens_per_group, dtype=jnp.int32)

    # Move axes to conform with shape expected by MoeLayer API.
    # Shape: [NUM_GROUPS, TOKENS_PER_GROUP, NUM_EXPERTS, EXPERT_CAPACITY]
    dispatch_mask = jnp.moveaxis(dispatch_mask, 3, 1)

    # The combine array will be used for combining expert outputs, scaled by the
    # router probabilities. Shape: [NUM_GROUPS, TOKENS_PER_GROUP, NUM_EXPERTS,
    # EXPERT_CAPACITY].
    combine_array = jnp.einsum(
        "...ec,...tec->...tec",
        expert_gate,
        dispatch_mask,
        precision=jax.lax.Precision.DEFAULT)

    # Return to default dtype now that router computation is complete.
    combine_array = jax.lax.convert_element_type(combine_array, self.dtype)

    # Each expert is choosing tokens until it reaches full capacity, so we don't
    # need an auxiliary load balancing loss for expert choice routing.
    auxiliary_loss = 0.0

    return RouterMask(dispatch_mask, combine_array, auxiliary_loss)
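The dispatch-mask axis manipulation in isolation, sketched with dummy sizes:

import jax
import jax.numpy as jnp

num_groups, num_experts, expert_capacity, tokens_per_group = 2, 4, 3, 16
expert_index = jnp.zeros((num_groups, num_experts, expert_capacity), dtype=jnp.int32)
dispatch_mask = jax.nn.one_hot(expert_index, tokens_per_group, dtype=jnp.int32)
assert dispatch_mask.shape == (num_groups, num_experts, expert_capacity, tokens_per_group)
dispatch_mask = jnp.moveaxis(dispatch_mask, 3, 1)
assert dispatch_mask.shape == (num_groups, tokens_per_group, num_experts, expert_capacity)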
Example #23
def svgd_log(log, style="-", full=False):
    """plot metrics logged during SVGD run."""
    # plot mean and var
    titles = log["metric_names"]
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    colors = colors + colors + colors  # avoid index out of bounds
    for key, dic in log.items():
        if key == "desc":
            plotobject(dic, colors, style=style, xlabel="step")
            colors = colors[len(dic):]

        elif key == "metrics" and full:
            for k, v in dic.items():
                v = np.moveaxis(v, 0, 1)  # swap axes 0 and 1
                plotobject(v, colors, titles[k], yscale="log", style=style)
                colors = colors[len(v):]
Example #24
def logistic_mix_sample(nn_out, rng):
    m, t, inv_scales, logit_weights = logistic_preprocess(nn_out)
    rng_mix, rng_logistic = random.split(rng)
    mix_idx = random.categorical(rng_mix, logit_weights, -3)

    def select_mix(arr):
        return jnp.squeeze(
            jnp.take_along_axis(arr, jnp.expand_dims(mix_idx, (-4, -1)), -4),
            -4)

    m, t, inv_scales = map(lambda x: jnp.moveaxis(select_mix(x), -1, 0),
                           (m, t, inv_scales))
    l = random.logistic(rng_logistic, m.shape) / inv_scales
    img_red = m[0] + l[0]
    img_green = m[1] + t[0] * img_red + l[1]
    img_blue = m[2] + t[1] * img_red + t[2] * img_green + l[2]
    return jnp.stack([img_red, img_green, img_blue], -1)
Example #25
 def init(key, shape, dtype=dtype):
     if len(shape) < 2:
         raise ValueError(
             "orthogonal initializer requires at least a 2D shape")
     n_rows, n_cols = prod(shape) // shape[column_axis], shape[column_axis]
     matrix_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)
     A = random.normal(key, matrix_shape, dtype)
     Q, R = jnp.linalg.qr(A)
     diag_sign = lax.broadcast_to_rank(jnp.sign(jnp.diag(R)), rank=Q.ndim)
     Q *= diag_sign  # needed for a uniform distribution
     if n_rows < n_cols: Q = Q.T
     Q = jnp.reshape(
         Q,
         tuple(np.delete(shape, column_axis)) + (shape[column_axis], ))
     Q = jnp.moveaxis(Q, -1, column_axis)
     return scale * Q
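The closure above appears to be the body of `jax.nn.initializers.orthogonal`; a quick orthogonality check through that public API (a sketch):

import jax
import jax.numpy as jnp

init = jax.nn.initializers.orthogonal()                  # scale=1.0, column_axis=-1
W = init(jax.random.PRNGKey(0), (64, 32))
assert jnp.allclose(W.T @ W, jnp.eye(32), atol=1e-4)     # columns are orthonormal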
Example #26
File: mip_test.py Project: wx-b/mipnerf
    def test_control_points(self):
        rng = random.PRNGKey(0)
        batch_size = 10
        for num_dims in [1, 2, 3]:
            key, rng = random.split(rng)
            mean = jax.random.normal(key, [batch_size, num_dims])
            key, rng = random.split(rng)
            half_cov = jax.random.normal(key, [batch_size] + [num_dims] * 2)
            cov = half_cov @ jnp.moveaxis(half_cov, -1, -2)

            sqrtm_cov = sqrtm(cov)
            self.assertArraysAllClose(sqrtm_cov @ sqrtm_cov, cov, atol=1e-5)

            points = control_points(mean, cov)
            mean_recon, cov_recon = surface_stats(points)
            self.assertArraysAllClose(mean, mean_recon)
            self.assertArraysAllClose(cov, cov_recon, atol=1e-5)
Example #27
def func(S, is_training):
    """ type-2 q-function: s -> q(s,.) """
    seq = hk.Sequential((
        coax.utils.diff_transform,
        hk.Conv2D(16, kernel_shape=8, stride=4),
        jax.nn.relu,
        hk.Conv2D(32, kernel_shape=4, stride=2),
        jax.nn.relu,
        hk.Flatten(),
        hk.Linear(256),
        jax.nn.relu,
        hk.Linear(env.action_space.n, w_init=jnp.zeros),
    ))
    X = jnp.moveaxis(S / 255., 1, -1)  # shape: (batch, frames, h, w) --> (batch, h, w, frames)
    return seq(X)
Example #28
def _reorder_reshape_inputs(arr, shape):
    """ Function to reorder axes and reshape dimensions of input data.
    This takes input data, assumed to be of shape: (Nfreq, Npol, Npix)
    and converts to shape (Npix * Npol, Nfreq), which is easier to work
    with in the likelihood.

    Parameters
    ----------
    arr: ndarray
        Numpy array with three dimensions.
    shape: tuple of int
        Target two-dimensional shape, (Npix * Npol, Nfreq).

    Returns
    -------
    ndarray
        Numpy array with two dimensions.
    """
    return np.moveaxis(arr, (0, 1, 2),
                       (2, 0, 1)).reshape(shape).astype(np.float32)
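A shape sketch (assuming `np` here is NumPy):

import numpy as np

Nfreq, Npol, Npix = 5, 3, 16
arr = np.ones((Nfreq, Npol, Npix))
out = _reorder_reshape_inputs(arr, (Npol * Npix, Nfreq))
assert out.shape == (Npol * Npix, Nfreq) and out.dtype == np.float32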
Example #29
def dot_general(lhs: np.ndarray,
                rhs: np.ndarray,
                contracting_dims: Axes,
                batch_dims: Axes,
                precision=None) -> np.ndarray:
    """`jax.lax.dot_general` with preserved dims order and shared lhs / rhs dims.

  Precisely, returns `jax.lax.dot_general(lhs, rhs, dimension_numbers)` where
  `dimension_numbers == ((contracting_dims, contracting_dims),
                         (batch_dims, batch_dims))`,
  but preserves the dimension order in the output. See XLA's
   `DotGeneral<https://www.tensorflow.org/xla/operation_semantics#dotgeneral>`.

  Args:
    lhs: array.
    rhs: array, must have the same dimensionality as `lhs`.
    contracting_dims: contracting dimensions.
    batch_dims: batch dimensions.
    precision: Optional. Either `None`, which means the default precision for
      the backend, or a `Precision` enum value.

  Returns:
    Dot product result with preserved dimension order.
  """
    if lhs.ndim != rhs.ndim:
        raise ValueError(
            f'`lhs` and `rhs` must have the same dimensionality, got '
            f'`lhs.ndim == {lhs.ndim}` and `rhs.ndim == {rhs.ndim}`.')

    contracting_dims = canonicalize_axis(contracting_dims, lhs)
    batch_dims = canonicalize_axis(batch_dims, lhs)

    n_batch_dims = len(batch_dims)
    leading_batch_dims = range(n_batch_dims)

    dimension_numbers = ((contracting_dims, contracting_dims), (batch_dims,
                                                                batch_dims))

    prod = lax.dot_general(lhs, rhs, dimension_numbers, precision)
    prod = zip_axes(prod, n_batch_dims)

    res_batch_dims = get_res_batch_dims(contracting_dims, batch_dims)
    prod = np.moveaxis(prod, leading_batch_dims, res_batch_dims)
    return prod
Example #30
  def update_fn(updates, state, params=None):
    del params
    grads_flat, grads_treedef = jax.tree_flatten(updates)
    bsize = grads_flat[0].shape[0]

    if any(g.ndim == 0 or bsize != g.shape[0] for g in grads_flat):
      raise ValueError(
          'Unlike other transforms, `differentially_private_aggregate` expects'
          ' `updates` to have a batch dimension in the 0th axis. That is, this'
          ' function expects per-example gradients as input.')

    new_key, *rngs = jax.random.split(state.rng_key, len(grads_flat)+1)
    global_grad_norms = jax.vmap(linear_algebra.global_norm)(grads_flat)
    divisors = jnp.maximum(global_grad_norms / l2_norm_clip, 1.0)
    clipped = [(jnp.moveaxis(g, 0, -1) / divisors).sum(-1) for g in grads_flat]
    noised = [(g + noise_std * jax.random.normal(r, g.shape, g.dtype)) / bsize
              for g, r in zip(clipped, rngs)]
    return (jax.tree_unflatten(grads_treedef, noised),
            DifferentiallyPrivateAggregateState(rng_key=new_key))
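The per-example clipping and summing step in isolation, sketched for a single parameter array (the real transform computes `linear_algebra.global_norm` over the whole gradient pytree):

import jax
import jax.numpy as jnp

per_eg_grads = jnp.arange(8.0).reshape(8, 1) * jnp.ones((8, 3))   # (batch, param_dim)
norms = jax.vmap(jnp.linalg.norm)(per_eg_grads)                   # one norm per example
divisors = jnp.maximum(norms / 1.0, 1.0)                          # l2_norm_clip = 1.0
clipped_sum = (jnp.moveaxis(per_eg_grads, 0, -1) / divisors).sum(-1)
assert clipped_sum.shape == (3,)                                  # batch axis reduced away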