Пример #1
0
def render_image(render_fn, rays, rng, normalize_disp, chunk=8192):
  """Render a full image by pushing its rays through `render_fn` in chunks.

  Args:
    render_fn: function, jit-ed render function.
    rays: a `Rays` namedtuple, the rays to be rendered.
    rng: jnp.ndarray, random number generator (used in training mode only).
    normalize_disp: bool, if true then rescale `disp` to [0, 1].
    chunk: int, number of rays rendered per sequential step.

  Returns:
    A 5-tuple `(rgb, disp, acc, features, specular)`; every entry is reshaped
    to `(height, width, -1)`.
  """
  height, width = rays[0].shape[:2]
  num_rays = height * width
  rays = namedtuple_map(lambda r: r.reshape((num_rays, -1)), rays)

  unused_rng, key_0, key_1 = jax.random.split(rng, 3)
  host_id = jax.host_id()
  per_chunk_outputs = []
  for begin in range(0, num_rays, chunk):
    # pylint: disable=cell-var-from-loop
    end = begin + chunk
    chunk_rays = namedtuple_map(lambda r: r[begin:end], rays)
    # Pad (by repeating the last ray) up to a multiple of the device count.
    leftover = chunk_rays[0].shape[0] % jax.device_count()
    padding = jax.device_count() - leftover if leftover != 0 else 0
    if padding:
      chunk_rays = namedtuple_map(
          lambda r: jnp.pad(r, ((0, padding), (0, 0)), mode="edge"), chunk_rays)
    # Each host keeps only its own contiguous slice of the padded chunk.
    # NOTE(review): this relies on the padded size being divisible by
    # host_count, as the original comment states -- confirm for your topology.
    rays_per_host = chunk_rays[0].shape[0] // jax.host_count()
    start = host_id * rays_per_host
    stop = start + rays_per_host
    chunk_rays = namedtuple_map(lambda r: shard(r[start:stop]), chunk_rays)
    # Only the last element of the render output is kept.
    chunk_results = render_fn(key_0, key_1, chunk_rays)[-1]
    per_chunk_outputs.append([unshard(x[0], padding) for x in chunk_results])
    # pylint: enable=cell-var-from-loop

  # Transpose "list of chunks" into "list of outputs" and stitch each output.
  rgb, disp, acc, _, features, specular = [
      jnp.concatenate(parts, axis=0) for parts in zip(*per_chunk_outputs)
  ]
  # Normalize disp for visualization for ndc_rays in llff front-facing scenes.
  if normalize_disp:
    disp_min = disp.min()
    disp = (disp - disp_min) / (disp.max() - disp_min)

  image_shape = (height, width, -1)
  return (rgb.reshape(image_shape), disp.reshape(image_shape),
          acc.reshape(image_shape), features.reshape(image_shape),
          specular.reshape(image_shape))
Пример #2
0
 def __call__(self, y: ndarray) -> ndarray:
   """Compute log-mel conditioning features from a 2-D batch of signals.

   Args:
     y: 2-D array; the rearrange pattern treats it as `(n, s)`
       (presumably batch x samples -- confirm against the caller).

   Returns:
     Natural log of the mel magnitudes, floored at 1e-5 before the log.
   """
   hop_length = self.n_fft // 4
   window_length = self.n_fft
   assert len(y.shape) == 2
   y = rearrange(y, 'n s -> s n')
   # Reflect-pad the leading axis by half of (n_fft - hop) on each side.
   p = (self.n_fft - hop_length) // 2
   y = jnp.pad(y, ((p, p), (0, 0)), mode='reflect')
   spec = batched_stft(y, self.n_fft, hop_length, window_length, 'hann', False, 'reflect')
   # The small epsilon keeps the sqrt (and its gradient) finite at zero.
   mag = jnp.sqrt(jnp.square(spec.real) + jnp.square(spec.imag) + 1e-9)
   mel = jnp.einsum('ms,sfn->nfm', self.melfb, mag)
   # Bounds are passed positionally: the `a_min`/`a_max` keywords are
   # deprecated in recent JAX releases, positional bounds work everywhere.
   cond = jnp.log(jnp.clip(mel, 1e-5))
   return cond
Пример #3
0
    def _inverse(self, y):
        """Invert the tanh + stick-breaking transform along the last axis.

        Args:
            y: array whose last axis holds the transformed values.

        Returns:
            Array of the same shape as `y` with the unconstrained values.
        """
        # inverse stick-breaking: the remaining stick before entry i is
        # 1 - sum(|y[..., :i]|); the first entry sees the full stick (1).
        remainder = 1 - jnp.cumsum(jnp.abs(y[..., :-1]), axis=-1)
        pad_width = [(0, 0)] * (y.ndim - 1) + [(1, 0)]
        remainder = jnp.pad(remainder, pad_width, mode="constant", constant_values=1.0)
        finfo = jnp.finfo(y.dtype)
        # Bounds are positional: the `a_min`/`a_max` keywords are deprecated in
        # recent JAX releases. Flooring at `tiny` avoids dividing by zero.
        remainder = jnp.clip(remainder, finfo.tiny)
        t = y / remainder

        # inverse of tanh; keep t strictly inside (-1, 1) so arctanh is finite
        t = jnp.clip(t, -1 + finfo.eps, 1 - finfo.eps)
        return jnp.arctanh(t)
Пример #4
0
def nlm(img, search_window_radius, filter_radius, h, sigma):
    """Denoise a 2-D image with a patch-similarity weighted average.

    For every pixel, patches of side `2 * filter_radius + 1` inside a search
    window of side `2 * search_window_radius + 1` are compared with the patch
    around the pixel itself. Each candidate gets the weight
    `exp(-max(d2 - 2 * sigma**2, 0) / h**2)`, where `d2` is the mean squared
    patch difference; the pixel's own value is weighted with the largest
    weight found among the candidates.

    Args:
        img: 2-D array (the shape is unpacked into exactly two values).
        search_window_radius: radius of the search window; also the zero
            padding added around `img`.
        filter_radius: radius of the compared patches.
        h: scale of the exponential weight decay.
        sigma: value whose square, doubled, is subtracted from `d2`
            (presumably the noise standard deviation -- confirm).

    Returns:
        Whatever `_vmap_2d` produces for the per-pixel results (expected to be
        an array with the shape of `img` -- confirm against `_vmap_2d`).
    """
    height, width = img.shape
    pad = search_window_radius
    img_pad = jnp.pad(img, pad)

    patch_side = 2 * filter_radius + 1
    window_side = 2 * search_window_radius + 1
    patch_shape = (patch_side, patch_side)
    # Offsets of the top-left corner of every candidate patch in one window.
    offsets = jnp.arange(window_side - patch_side + 1)

    def extract_patch(center_y, center_x):
        # Patch of `patch_shape` centred on the given padded coordinates.
        corner = (center_y - filter_radius, center_x - filter_radius)
        return lax.dynamic_slice(img_pad, corner, patch_shape)

    def compute(y, x):
        # (y + pad, x + pad) are the center of the current neighborhood
        win_center_y = y + pad
        win_center_x = x + pad
        center_patch = extract_patch(win_center_y, win_center_x)

        def _compare(center):
            center_y, center_x = center
            patch = extract_patch(center_y, center_x)
            d2 = jnp.sum((patch - center_patch) ** 2) / (patch_side ** 2)
            weight = jnp.exp(-(jnp.maximum(d2 - 2 * (sigma**2), 0) / (h**2)))
            intensity = img_pad[center_y, center_x]
            return (weight, intensity)

        def compare(patch_y, patch_x):
            patch_center_y = patch_y + filter_radius
            patch_center_x = patch_x + filter_radius
            # Skip if patch is out of image boundaries or this is the center patch
            outside = (lax.lt(patch_center_y, pad)
                       | lax.ge(patch_center_y, height + pad)
                       | lax.lt(patch_center_x, pad)
                       | lax.ge(patch_center_x, width + pad))
            is_center = (lax.eq(patch_center_y, win_center_y)
                         & lax.eq(patch_center_x, win_center_x))
            skip = outside | is_center
            return lax.cond(skip, lambda _: (0., 0.), _compare,
                            (patch_center_y, patch_center_x))

        weights, intensities = _vmap_2d(compare, y + offsets, x + offsets)

        # Use max weight for the center patch
        max_weight = jnp.max(weights)
        total_weight = jnp.sum(weights) + max_weight
        weighted_sum = (jnp.sum((weights * intensities)) +
                        max_weight * img_pad[win_center_y, win_center_x])
        return weighted_sum / total_weight

    rows = jnp.arange(height)
    cols = jnp.arange(width)
    return _vmap_2d(compute, rows, cols)
Пример #5
0
    def __call__(self, x):
        """Apply tanh followed by a stick-breaking scaling along the last axis.

        Entry `i` of the result is `tanh(x[i])` multiplied by the product of
        `1 - |tanh(x[j])|` over all `j < i` (the empty product for `i = 0` is 1).
        """
        # squash to the (-1, 1) interval
        squashed = jnp.tanh(x)

        # running product of what is left of the stick after each entry
        leftover = jnp.cumprod(1 - jnp.abs(squashed[..., :-1]), axis=-1)
        widths = [(0, 0)] * squashed.ndim
        widths[-1] = (1, 0)
        # the first entry sees the whole stick, hence the leading 1.0
        stick = jnp.pad(leftover, widths, mode="constant", constant_values=1.0)
        return squashed * stick
Пример #6
0
def crop(key, image_and_label):
    """Randomly crop an image after zero-padding its two leading axes.

    The image is padded by 4 pixels on each side of axes 0 and 1, then a
    window with the original image shape is cut out at a random offset.
    No flipping is performed here.

    Args:
        key: JAX PRNG key used to draw the crop offset.
        image_and_label: pair `(image, label)`; `image` must be 3-D
            (the pad widths cover exactly three axes).

    Returns:
        `(cropped_image, label)` where `cropped_image` has the shape of
        `image` and `label` is passed through unchanged.
    """
    image, label = image_and_label
    pixels = 4
    pixpad = (pixels, pixels)
    zero = (0, 0)
    # `constant_values` must be a keyword: current numpy/jax.numpy `pad`
    # accepts only (array, pad_width, mode) positionally.
    padded_image = np.pad(image, (pixpad, pixpad, zero), 'constant',
                          constant_values=0.0)
    # `randint`'s upper bound is exclusive; +1 makes every one of the
    # 2 * pixels + 1 valid offsets (0 .. 2 * pixels) reachable.
    corner = random.randint(key, (2, ), 0, 2 * pixels + 1)
    corner = np.concatenate((corner, np.zeros((1, ), np.int32)))
    # Crop back to the input's own shape instead of a hard-coded (32, 32, 3).
    cropped_image = lax.dynamic_slice(padded_image, corner, image.shape)
    return cropped_image, label
Пример #7
0
 def inv(self, y):
     """Invert the tanh + stick-breaking construction for the matrix `y`.

     Returns the unconstrained vector obtained from the strictly-lower
     triangle (`diagonal=-1`) of `y`.
     NOTE(review): `cumsum` and `matrix_to_tril_vec` are defined outside this
     block; `cumsum` is assumed to accumulate along the last axis -- confirm.
     """
     # inverse stick-breaking
     z1m_cumprod = 1 - cumsum(y * y)
     pad_width = [(0, 0)] * y.ndim
     pad_width[-1] = (1, 0)
     # Shift right by one along the last axis so that entry i only sees the
     # squares of the entries before it; the first entry sees 1.
     z1m_cumprod_shifted = np.pad(z1m_cumprod[..., :-1], pad_width,
                                  mode="constant", constant_values=1.)
     t = matrix_to_tril_vec(y, diagonal=-1) / np.sqrt(
         matrix_to_tril_vec(z1m_cumprod_shifted, diagonal=-1))
     # inverse of tanh; arctanh is the same function as log((1+t)/(1-t))/2
     # but does not lose precision for t close to 0.
     return np.arctanh(t)
Пример #8
0
def _maybe_pad_uneven_sharding(x: JTensor, partition_spec: pjit.PartitionSpec,
                               shape: Sequence[int], mesh_shape: Sequence[int],
                               mesh_axis_names: Sequence[str]) -> JTensor:
  """Returns `x` padded at the end of each axis so it shards evenly.

  The per-axis padding amounts come from `_get_uneven_sharding_paddings`; when
  they are all zero `x` is returned untouched.
  """
  paddings = _get_uneven_sharding_paddings(partition_spec, shape, mesh_shape,
                                           mesh_axis_names)
  needs_padding = any(amount != 0 for amount in paddings)
  if not needs_padding:
    return x
  # Annotate before pad to make sure they have the same sharding. (Pad does not
  # have the highest sharding propagation priority.)
  constrained = base_layer.with_sharding_constraint(x, partition_spec)
  trailing_pads = [[0, amount] for amount in paddings]
  return jnp.pad(constrained, trailing_pads)
Пример #9
0
def combined_loss_given_predictions(log_probab_actions_new,
                                    log_probab_actions_old,
                                    value_prediction_new,
                                    value_prediction_old,
                                    padded_actions,
                                    rewards_to_actions,
                                    padded_rewards,
                                    reward_mask,
                                    gamma=0.99,
                                    lambda_=0.95,
                                    epsilon=0.2,
                                    c1=1.0,
                                    c2=0.01):
    """Computes `ppo_loss + c1 * value_loss - c2 * entropy_bonus`.

    Returns:
        A triple `(combined_loss, (ppo_loss, value_loss, entropy_bonus),
        summaries)`, where `summaries` merges the combined loss, the entropy
        bonus and the summaries reported by the value and PPO losses.
    """
    # Sum values over symbols in an action's representation, because it's a
    # simple way of going from AT to RT+1 and does not decrease the expressive
    # power.
    actions_to_rewards = rewards_to_actions.transpose()
    value_prediction_old = np.dot(value_prediction_old, actions_to_rewards)
    value_prediction_new = np.dot(value_prediction_new, actions_to_rewards)

    value_loss, value_summaries = value_loss_given_predictions(
        value_prediction_new,
        padded_rewards,
        reward_mask,
        gamma=gamma,
        value_prediction_old=value_prediction_old,
        epsilon=epsilon)
    ppo_loss, ppo_summaries = ppo_loss_given_predictions(
        log_probab_actions_new,
        log_probab_actions_old,
        value_prediction_old,
        padded_actions,
        rewards_to_actions,
        padded_rewards,
        reward_mask,
        gamma=gamma,
        lambda_=lambda_,
        epsilon=epsilon)

    # Pad the reward mask to be compatible with rewards_to_actions.
    widened_mask = np.pad(reward_mask, ((0, 0), (0, 1)))
    action_mask = np.dot(widened_mask, rewards_to_actions)
    entropy_bonus = masked_entropy(log_probab_actions_new, action_mask)

    total = ppo_loss + (c1 * value_loss) - (c2 * entropy_bonus)

    summaries = {
        "combined_loss": total,
        "entropy_bonus": entropy_bonus,
    }
    # Same merge order as before: value summaries first, PPO summaries last.
    summaries.update(value_summaries)
    summaries.update(ppo_summaries)

    return (total, (ppo_loss, value_loss, entropy_bonus), summaries)
Пример #10
0
def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None):
    """
    Vectorizing map that maps a function `fn` over `batch_ndims` leading axes
    of `xs`. This uses jax.vmap over smaller chunks of the batch dimensions
    to keep memory usage constant.

    The flattened batch is zero-padded up to a multiple of `chunk_size` only
    when it does not already divide evenly; results computed on the padding
    are discarded before returning.

    :param callable fn: The function to map over.
    :param xs: JAX pytree (e.g. an array, a list/tuple/dict of arrays,...)
    :param int batch_ndims: The number of leading dimensions of `xs`
        to apply `fn` element-wise over them.
    :param int chunk_size: Size of each chunk of `xs`.
        Defaults to the size of batch dimensions.
    :returns: output of `fn(xs)`.
    """
    flatten_xs = tree_flatten(xs)[0]
    batch_shape = np.shape(flatten_xs[0])[:batch_ndims]
    for x in flatten_xs[1:]:
        assert np.shape(x)[:batch_ndims] == batch_shape

    # we'll do map(vmap(fn), xs) and make xs.shape = (num_chunks, chunk_size, ...)
    num_chunks = batch_size = int(np.prod(batch_shape))
    prepend_shape = (batch_size, ) if batch_size > 1 else ()
    xs = tree_map(
        lambda x: jnp.reshape(x, prepend_shape + jnp.shape(x)[batch_ndims:]),
        xs)
    # XXX: probably for the default behavior with chunk_size=None,
    # it is better to catch OOM error and reduce chunk_size by half until OOM disappears.
    chunk_size = batch_size if chunk_size is None else min(
        batch_size, chunk_size)
    if chunk_size > 1:
        # Rows missing to reach the next multiple of chunk_size. This is 0 when
        # the batch already divides evenly; the previous formula
        # `chunk_size - batch_size % chunk_size` added a whole extra chunk in
        # that case (doubling the work for the default chunk_size).
        pad = -batch_size % chunk_size
        if pad:
            xs = tree_map(
                lambda x: jnp.pad(x, ((0, pad), ) + ((0, 0), ) *
                                  (np.ndim(x) - 1)),
                xs)
        num_chunks = (batch_size + pad) // chunk_size
        prepend_shape = (-1, ) if num_chunks > 1 else ()
        xs = tree_map(
            lambda x: jnp.reshape(
                x, prepend_shape + (chunk_size, ) + jnp.shape(x)[1:]),
            xs,
        )
        fn = vmap(fn)

    ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
    # Number of leading axes introduced by the chunking (0, 1 or 2).
    map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
    # Collapse those axes into one and drop the rows that came from padding.
    ys = tree_map(
        lambda y: jnp.reshape(y, (int(np.prod(jnp.shape(y)[:map_ndims])), ) +
                              jnp.shape(y)[map_ndims:])[:batch_size],
        ys,
    )
    return tree_map(lambda y: jnp.reshape(y, batch_shape + jnp.shape(y)[1:]),
                    ys)
Пример #11
0
def render_image(state, data, render_fn, rng, chunk=8192):
  """Render a full image by pushing its rays through `render_fn` in chunks.

  Args:
    state: model_utils.TrainState.
    data: dict, test example; only `data["rays"]` is read.
    render_fn: function, jit-ed render function.
    rng: jnp.ndarray, random number generator (used in training mode only).
    chunk: int, number of rays rendered per sequential step.

  Returns:
    A 3-tuple `(rgb, disp, acc)`; every entry is reshaped to `(h, w, -1)`.
  """
  rays = data["rays"]
  h, w = rays.shape[:2]
  rays = rays.reshape((h * w, -1))
  unused_rng, key_0, key_1 = jax.random.split(rng, 3)
  model = state.optimizer.target
  model_state = state.model_state
  host_id = jax.host_id()
  rgb_parts, disp_parts, acc_parts = [], [], []
  with nn.stateful(model_state, mutable=False):
    for begin in range(0, rays.shape[0], chunk):
      chunk_rays = rays[begin:begin + chunk]
      # Pad (by repeating the last ray) up to a multiple of the device count.
      leftover = chunk_rays.shape[0] % jax.device_count()
      padding = jax.device_count() - leftover if leftover != 0 else 0
      if padding:
        chunk_rays = jnp.pad(chunk_rays, ((0, padding), (0, 0)), mode="edge")
      # Each host keeps only its own contiguous slice of the padded chunk.
      # NOTE(review): this relies on the padded size being divisible by
      # host_count, as the original comment states -- confirm.
      per_host_rays = chunk_rays.shape[0] // jax.host_count()
      first = host_id * per_host_rays
      chunk_rays = shard(chunk_rays[first:first + per_host_rays])
      # Only the last element of the render output is kept.
      outputs = render_fn(key_0, key_1, model, chunk_rays)[-1]
      rgb_parts.append(unshard(outputs[0][0], padding))
      disp_parts.append(unshard(outputs[1][0], padding))
      acc_parts.append(unshard(outputs[2][0], padding))
  rgb = jnp.concatenate(rgb_parts, axis=0)
  disp = jnp.concatenate(disp_parts, axis=0)
  acc = jnp.concatenate(acc_parts, axis=0)
  # Normalize disp for visualization for ndc_rays in llff front-facing scenes.
  if rays.shape[-1] > 6:
    disp_min = disp.min()
    disp = (disp - disp_min) / (disp.max() - disp_min)
  image_shape = (h, w, -1)
  return (rgb.reshape(image_shape), disp.reshape(image_shape),
          acc.reshape(image_shape))
Пример #12
0
def _update_block(rng_key, num_blocks, subsample_idx, plate_size):
    """Resample one randomly chosen block of `subsample_idx`.

    `subsample_idx` is treated as `num_blocks` consecutive blocks of equal
    size (the last one possibly partial). One block is picked uniformly and
    overwritten with fresh indices drawn from `[0, size)`.

    Returns:
        `(rng_key, updated_idx, pad, new_idx, start)`: the advanced key, the
        updated indices (length `subsample_size`), the number of padding
        slots used internally, the freshly drawn block and its start offset.
    """
    size, subsample_size = plate_size
    rng_key, subkey, block_key = random.split(rng_key, 3)

    # ceil(subsample_size / num_blocks)
    block_size = (subsample_size - 1) // num_blocks + 1
    # Slots needed to extend the index vector to a multiple of block_size.
    pad = -subsample_size % block_size

    chosen_block = random.randint(block_key, shape=(), minval=0,
                                  maxval=num_blocks)
    new_idx = random.randint(subkey, minval=0, maxval=size,
                             shape=(block_size,))
    start = chosen_block * block_size

    # Pad so that writing a full block at `start` never runs off the end,
    # then cut the padding away again.
    extended = jnp.pad(subsample_idx, (0, pad))
    extended = lax.dynamic_update_slice_in_dim(extended, new_idx, start, 0)
    updated = extended[:subsample_size]
    return rng_key, updated, pad, new_idx, start
Пример #13
0
def _fftconvolve(x, h):
    """Convolve two 1-D arrays by multiplying their zero-padded FFTs.

    Returns the first `len(x) + len(h) - 1` samples of the inverse FFT; the
    result comes straight from `jnp.fft.ifft`, so it is complex-valued.
    NOTE(review): the FFT length is chosen by `_fft_size_factor(out_length, 5)`,
    which is defined elsewhere -- presumably a size >= out_length; confirm.
    """
    signal_len = x.shape[0]
    kernel_len = h.shape[0]
    out_length = signal_len + kernel_len - 1

    fft_size = _fft_size_factor(out_length, 5)

    # Zero-pad both inputs at the end up to the common FFT length.
    padded_x = jnp.pad(x, [0, fft_size - signal_len])
    padded_h = jnp.pad(h, [0, fft_size - kernel_len])

    spectrum = jnp.fft.fft(padded_x) * jnp.fft.fft(padded_h)
    return jnp.fft.ifft(spectrum)[:out_length]
Пример #14
0
def dilated_conv3x3(x, features, strides=(1, 1), groups=1, dilation=1, name='dilated_conv3x3'):
  """3x3 bias-free convolution using PyTorch's padding style and dilation.

  The two middle axes of the 4-D input are zero-padded by the dilation on
  each side and a 'VALID' convolution is applied on top.

  Args:
    x: 4-D input; axes 1 and 2 are the ones that get padded.
    features: number of output features of the convolution.
    strides: strides of the convolution.
    groups: passed as `feature_group_count`.
    dilation: kernel dilation; values below 1 are treated as 1.
    name: name of the `nn.Conv` module.

  Returns:
    The output of the convolution.
  """
  _d = max(1, dilation)
  # `constant_values` must be a keyword: current `jnp.pad` accepts only
  # (array, pad_width, mode) positionally.
  x = jnp.pad(x, [(0, 0), (_d, _d), (_d, _d), (0, 0)], 'constant',
              constant_values=(0, 0))
  return nn.Conv(
      features, (3, 3),
      strides,
      padding='VALID',
      kernel_dilation=(_d, _d),
      feature_group_count=groups,
      use_bias=False,
      name=name)(
          x)
Пример #15
0
def shift_right(x, train=True):
    """Shift `x` one step to the right along axis 1, filling with zeros.

    When `train` is false the input is returned unchanged.
    """
    if not train:
        # Do nothing in predict mode, as then the sequence length is 1.
        return x

    zero = x.dtype.type(0)
    widths = [(0, 0) for _ in x.shape]
    widths[1] = (1, 0)  # one leading slot on axis=1
    shifted = jnp.pad(x, widths, mode='constant', constant_values=zero)
    # Drop the last position so the length along axis 1 is unchanged.
    return shifted[:, :-1]
Пример #16
0
def add_translation(dataset, pixel):
    """Randomly translate every image of a 4-D dataset by a few pixels.

    Each image is padded by `pixel` on both sides of axes 1 and 2 (filled with
    the value of `dataset[0, 0, 0]`), rolled by a random offset and cropped
    back to its original height and width.

    Args:
        dataset: 4-D array of images; axes 1 and 2 are translated.
        pixel: padding size and bound of the random offsets.

    Returns:
        A JAX array with the shape of `dataset` holding the translated images.
    """
    # `.at[...]` below needs a JAX array; this is a no-op for JAX inputs.
    dataset = jax.numpy.asarray(dataset)
    # Crop back to the input's own size instead of a hard-coded 28 x 28.
    height, width = dataset.shape[1], dataset.shape[2]
    dataset_padded = np.pad(dataset,
                            ((0, 0), (pixel, pixel), (pixel, pixel), (0, 0)),
                            constant_values=dataset[0, 0, 0])
    for n in range(dataset.shape[0]):
        # The key depends only on the image index, so results are reproducible.
        key = random.PRNGKey(n)
        # NOTE(review): randint's upper bound is exclusive, so offsets lie in
        # [-pixel, pixel - 1]; left unchanged to keep existing results.
        offset = random.randint(key, [2], -pixel, pixel)
        rolled = np.roll(dataset_padded[n], offset, axis=(0, 1))
        # `jax.ops.index_update` has been removed from JAX; `.at[...].set` is
        # the supported functional update.
        dataset = dataset.at[n].set(
            rolled[pixel:(height + pixel), pixel:(width + pixel)])
    return dataset
Пример #17
0
 def _inverse(self, y):
     """Invert the tanh + stick-breaking construction for the matrix `y`.

     Returns the unconstrained vector obtained from the strictly-lower
     triangle (`diagonal=-1`) of `y`.
     """
     # inverse stick-breaking: what is left after the squares of the
     # preceding entries along the last axis
     leftover = 1 - jnp.cumsum(y * y, axis=-1)
     widths = [(0, 0)] * (y.ndim - 1) + [(1, 0)]
     # Shift right by one so entry i only sees entries before it; the first
     # entry sees 1.
     leftover_shifted = jnp.pad(leftover[..., :-1],
                                widths,
                                mode="constant",
                                constant_values=1.0)
     numerator = matrix_to_tril_vec(y, diagonal=-1)
     denominator = jnp.sqrt(matrix_to_tril_vec(leftover_shifted, diagonal=-1))
     # inverse of tanh
     return jnp.arctanh(numerator / denominator)
Пример #18
0
 def __init__(self, predictor, cutpoints, validate_args=None):
     """Build the category probabilities from `predictor` and `cutpoints`.

     The cumulative probabilities are `expit(cutpoints - predictor)`, extended
     with the boundary values 0 and 1; their successive differences are the
     probabilities handed to the parent constructor.
     """
     # Give `predictor` a trailing axis so it broadcasts against the cutpoints.
     if jnp.ndim(predictor) == 0:
         (predictor,) = promote_shapes(predictor, shape=(1,))
     else:
         predictor = predictor[..., None]
     predictor, self.cutpoints = promote_shapes(predictor, cutpoints)
     self.predictor = predictor[..., 0]

     cdf = expit(cutpoints - predictor)
     # add two boundary points 0 and 1
     leading_axes = jnp.ndim(cdf) - 1
     boundary_pad = [(0, 0)] * leading_axes + [(1, 1)]
     cdf = jnp.pad(cdf, boundary_pad, constant_values=(0, 1))
     probs = cdf[..., 1:] - cdf[..., :-1]
     super(OrderedLogistic, self).__init__(probs, validate_args=validate_args)
Пример #19
0
  def apply(self,
            x,
            channels,
            strides=(1, 1),
            train=True):
    """Implements the forward pass in the module.

    When the input already has `channels` features it is returned as is.
    Otherwise two 1x1-convolution branches, each producing `channels // 2`
    features, are concatenated and passed through `utils.activation`.

    Args:
      x: Input to the module. Should have shape [batch_size, dim, dim, features]
        where dim is the resolution (width and height if the input is an image).
      channels: How many channels to use in the convolutional layers.
      strides: Strides for the pooling.
      train: If False, will use the moving average for batch norm statistics.

    Returns:
      The output of the block. Will have shape
        [batch_size, dim, dim, channels] if strides = (1, 1) or
        [batch_size, dim/2, dim/2, channels] if strides = (2, 2).
    """
    if x.shape[-1] == channels:
      return x

    def pooled_half_conv(inputs, name):
      # 1x1 pooling (i.e. pure striding) followed by a bias-free 1x1 conv.
      pooled = nn.avg_pool(inputs, (1, 1), strides=strides, padding='VALID')
      return nn.Conv(
          pooled,
          channels // 2, (1, 1),
          strides=(1, 1),
          padding='SAME',
          bias=False,
          kernel_init=utils.conv_kernel_init_fn,
          name=name)

    # Skip path 1
    h1 = pooled_half_conv(x, 'conv_h1')

    # Skip path 2
    # Pad one pixel at the end of both spatial axes and drop the first
    # row/column, so this branch samples a grid offset by one pixel
    # (see Shake-Shake regularization, Xavier Gastaldi for details).
    shifted = jnp.pad(x, [[0, 0], [0, 1], [0, 1], [0, 0]])[:, 1:, 1:, :]
    h2 = pooled_half_conv(shifted, 'conv_h2')

    merged = jnp.concatenate([h1, h2], axis=3)
    return utils.activation(
        merged, apply_relu=False, train=train, name='bn_residual')
Пример #20
0
def _gmres_incremental(A, b, x0, unit_residual, residual_norm, ptol, restart,
                       M):
    """
    Implements a single restart of GMRES.

    A Krylov basis with up to `restart` vectors is built by Arnoldi
    iterations, and the correction computed in that subspace is added to
    `x0`. The QR factorization is built incrementally during the Arnoldi
    process (one row of `R` per step, via Givens rotations), and the loop
    stops early once the tracked error estimate drops to `ptol` or below.

    Args:
        A: callable applied as `A(x)`.
        b: right-hand side (a pytree; its leaves determine the working dtype).
        x0: current solution estimate.
        unit_residual: pytree used as the first basis vector.
        residual_norm: initial error estimate; also the first entry of the
            rotated right-hand-side vector.
        ptol: threshold on the error estimate that ends the inner loop.
        restart: maximum number of Arnoldi steps.
        M: callable applied to the residual `b - A(x)` (and passed to the
            Arnoldi helper) -- presumably the preconditioner; confirm.

    Returns:
        `(x, unit_residual, residual_norm)` for the updated solution, as
        produced by `_safe_normalize` on the new residual.
    """
    # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf

    # Basis storage: every leaf gets a trailing axis of length restart + 1,
    # with the unit residual in slot 0 and zeros in the remaining slots.
    V = tree_map(
        lambda x: jnp.pad(x[..., None], ((0, 0), ) * x.ndim +
                          ((0, restart), )),
        unit_residual,
    )
    dtype = jnp.result_type(*tree_leaves(b))
    # use eye() to avoid constructing a singular matrix in case of early
    # termination
    R = jnp.eye(restart, restart + 1, dtype=dtype)

    givens = jnp.zeros((restart, 2), dtype=dtype)
    beta_vec = jnp.zeros((restart + 1), dtype=dtype)
    beta_vec = beta_vec.at[0].set(residual_norm)

    def loop_cond(carry):
        # Continue while steps remain and the error estimate is above ptol.
        k, err, _, _, _, _ = carry
        return jnp.logical_and(k < restart, err > ptol)

    def arnoldi_qr_step(carry):
        # One Arnoldi step, then rotate the new row into R and update beta_vec.
        k, _, V, R, beta_vec, givens = carry
        V, H, _ = _kth_arnoldi_iteration(k, A, M, V, R)
        R_row, givens = _apply_givens_rotations(H[k, :], givens, k)
        R = R.at[k, :].set(R_row)
        beta_vec = _rotate_vectors(beta_vec, k, *givens[k, :])
        # The magnitude of the next entry of beta_vec is the error estimate.
        err = abs(beta_vec[k + 1])
        return k + 1, err, V, R, beta_vec, givens

    # Carry layout: (step, error estimate, V, R, beta_vec, givens).
    carry = (0, residual_norm, V, R, beta_vec, givens)
    carry = lax.while_loop(loop_cond, arnoldi_qr_step, carry)
    k, residual_norm, V, R, beta_vec, _ = carry
    del k  # Until we figure out how to pass this to the user.

    # Solve the triangular system for the coefficients of the correction and
    # combine the first `restart` basis vectors with them.
    y = jsp.linalg.solve_triangular(R[:, :-1].T, beta_vec[:-1])
    dx = tree_map(lambda X: _dot(X[..., :-1], y), V)

    x = _add(x0, dx)
    # Recompute the residual from the updated solution.
    residual = M(_sub(b, A(x)))
    unit_residual, residual_norm = _safe_normalize(residual)
    # TODO(shoyer): "Inner loop tolerance control" on ptol, like SciPy
    return x, unit_residual, residual_norm
    def __call__(self, patch_inputs, pixel_inputs):
        """Project flattened pixel features and add them to the patch features.

        Args:
            patch_inputs: array whose first axis is the batch; its last axis
                gives the output width when `self.out_ch` is falsy.
            pixel_inputs: array whose two trailing axes are merged before the
                dense projection (presumably laid out as
                (batch * patches, pixels, dim) -- confirm against the caller).

        Returns:
            `patch_inputs` plus the projected pixel features, with one zero row
            prepended along axis 1 before the addition.
        """
        b = patch_inputs.shape[0]
        out_ch = self.out_ch or patch_inputs.shape[-1]

        x = rearrange(pixel_inputs, '... n d -> ... (n d)')
        x = nn.Dense(features=out_ch,
                     dtype=self.dtype,
                     precision=self.precision,
                     kernel_init=self.kernel_init,
                     bias_init=self.bias_init)(x)
        # Split the flat leading axis into (batch, tokens). A single unknown
        # axis `n` is used because einops cannot infer two unknown sizes
        # (`h` and `w`) inside one group; the resulting layout is identical.
        x = rearrange(x, '(b n) d -> b n d', b=b)
        # Prepend one zero row along axis 1 so the shapes line up with
        # patch_inputs (presumably the slot of an extra leading token).
        x = jnp.pad(x, ((0, 0), (1, 0), (0, 0)))
        output = x + patch_inputs
        return output
Пример #22
0
 def make_batch(examples):
     """Stack a structure of np arrays nested in lists/tuples.

     Leaves are stacked along a new leading axis and zero-padded at the end
     of that axis up to `batch_size` (taken from the enclosing scope).
     """
     assert examples
     first = examples[0]
     if not isinstance(first, (list, tuple)):
         stacked = np.stack(examples, axis=0)
         missing = batch_size - len(examples)
         widths = [(0, missing)] + [(0, 0)] * (len(stacked.shape) - 1)
         # Pad with zeros. This doesn't change anything, because we have
         # weights in the examples.
         return np.pad(stacked, widths, mode="constant")
     # Recurse position by position and rebuild the same container type.
     container = type(first)
     return container(
         make_batch([example[position] for example in examples])
         for position in range(len(first)))
Пример #23
0
def add_padded_translation(dataset, pixel):
    """Randomly translate padded images and rescale them to the input size.

    Each image is padded by `pixel` on both sides of axes 1 and 2 (filled with
    the value of `dataset[0, 0, 0]`), rolled by a random offset and then
    resized with cubic interpolation back to the shape of one input image.

    Args:
        dataset: 4-D array of images; axes 1 and 2 are padded and translated.
        pixel: padding size and bound of the random offsets.

    Returns:
        A JAX array with the shape of `dataset` holding the transformed images.
    """
    # `.at[...]` below needs a JAX array; this is a no-op for JAX inputs.
    dataset = jax.numpy.asarray(dataset)
    # Resize back to the input's own image shape instead of a hard-coded
    # [28, 28, 1].
    image_shape = dataset.shape[1:]
    dataset_padded = np.pad(dataset,
                            ((0, 0), (pixel, pixel), (pixel, pixel), (0, 0)),
                            constant_values=dataset[0, 0, 0])
    for n in range(dataset.shape[0]):
        # The key depends only on the image index, so results are reproducible.
        key = random.PRNGKey(n)
        # NOTE(review): randint's upper bound is exclusive, so offsets lie in
        # [-pixel, pixel - 1]; left unchanged to keep existing results.
        temp = np.roll(dataset_padded[n],
                       random.randint(key, [2], -pixel, pixel),
                       axis=(0, 1))
        # `jax.ops.index_update` has been removed from JAX; `.at[...].set` is
        # the supported functional update.
        dataset = dataset.at[n].set(
            jax.image.resize(temp, image_shape, "cubic"))
    return dataset
Пример #24
0
    def call(self, x, params=(), **kwargs):
        """Run the parent convolution on an input left-padded along axis 1.

        Only 'VALID' padding is supported; the left padding equals the kernel
        size minus one, so the output has as many steps as the input.
        """
        assert self._padding == 'VALID'
        # Left pad with 0s. Applying an unmasked valid convolution on top of
        # this yields a causal convolution.
        # TODO(ddohan): Support strided and dilated convolutions.
        dilation = 1
        receptive_field = int((self._kernel_size[0] - 1) * dilation + 1)
        left_pad = receptive_field - 1
        padded = np.pad(x,
                        pad_width=[[0, 0], [left_pad, 0], [0, 0]],
                        mode='constant')
        return super(CausalConv, self).call(padded, params)
Пример #25
0
    def test_cnn_2d(self):
        """The CNN output must be identical for all periodic shifts of a 4x4 input."""
        cnn = nets.CNN.partial(F=(3, 3), channels=[3, 2, 5], strides=[1, 1])
        _, params = cnn.init_by_shape(random.PRNGKey(0), [(4, 4)])
        model = nn.Model(cnn, params)

        base = jnp.array([[1, 0, 1, 1],
                          [0, 1, 1, 1],
                          [0, 0, 1, 0],
                          [1, 0, 0, 1]])
        # Wrap-pad so that every 4x4 window is a periodic shift of `base`.
        tiled = jnp.pad(base, [(0, 3), (0, 3)], 'wrap')
        shifts = [(row, col) for row in range(4) for col in range(4)]
        translated = jnp.array(
            [tiled[row:row + 4, col:col + 4] for row, col in shifts])

        outputs = jax.vmap(model)(translated)
        deviation = outputs - outputs[0]

        self.assertTrue(jnp.max(jnp.abs(deviation)) < 1e-12)
Пример #26
0
    def __call__(self, x, pos_emb, mask):
        """Multi-head self-attention over a single (unbatched) sequence.

        Args:
            x: 2-D array (positions, features); the first position is left
                out of the rotary embedding and is always visible in the mask.
            pos_emb: positional embedding handed to `apply_rotary_pos_emb`
                for positions 1 and onwards.
            mask: 1-D boolean mask for positions 1 and onwards; a `True` entry
                is prepended for the first position.

        Returns:
            Array with the same trailing dimension as `x` (output of the final
            dense layer).
        """
        dim_in, h = x.shape[-1], self.heads
        scale = dim_in**-0.5

        norm = nn.LayerNorm()
        to_qkv = nn.Dense(features=self.dim_head * h * 3, use_bias=False)
        to_out = nn.Dense(features=dim_in)

        x = norm(x)
        qkv = np.split(to_qkv(x), 3, axis=-1)
        q, k, v = map(lambda t: rearrange(t, "i (h d) -> i h d", h=h), qkv)

        # `jax.ops.index_update` / `jax.ops.index` have been removed from JAX;
        # `.at[...].set` is the supported functional update.
        q = q.at[1:].set(apply_rotary_pos_emb(q[1:], pos_emb))
        k = k.at[1:].set(apply_rotary_pos_emb(k[1:], pos_emb))

        sim = einsum("i h d, j h d -> i j h", q, k) * scale

        # The first position is always attendable.
        mask = np.pad(mask, (1, 0), constant_values=True)
        mask = rearrange(mask, "j -> () j ()")

        if self.causal:
            i, j = sim.shape[:2]
            # Entries that are True after `triu` are blocked; the first row
            # and column are padded with False, so they are never blocked.
            tri_mask = np.ones((i - 1, j - 1), dtype=bool)
            tri_mask = np.pad(tri_mask, ((1, 0), (1, 0)),
                              constant_values=False)
            causal_mask = np.triu(tri_mask, j - i + 1)
            causal_mask = rearrange(causal_mask, "i j -> i j ()")
            mask = ~causal_mask * mask

        sim = np.where(mask, sim, LARGE_NEG_VALUE)

        # Normalize over the key axis j.
        attn = nn.softmax(sim, axis=-2)

        out = einsum("i j h, j h d -> i h d", attn, v)

        out = rearrange(out, "i h d -> i (h d)")
        return to_out(out)
Пример #27
0
    def apply(self,
              x,
              F=[
                  8,
              ],
              channels=[10],
              strides=[1],
              actFun=[nn.elu],
              bias=True,
              firstLayerBias=False):
        """Forward pass of a convolutional network with periodic boundaries.

        Args:
            x: input configuration; it is mapped to `2 * x - 1` and gets a
                leading and a trailing axis of size 1 added.
            F: filter size per spatial dimension.
            channels: number of features of each convolutional layer.
            strides: strides of the convolutions (one per spatial dimension).
            actFun: activation function per layer; if shorter than `channels`
                the last entry is reused for the remaining layers. The list
                passed in is not modified.
            bias: whether layers after the first one use a bias.
            firstLayerBias: whether the first layer uses a bias.

        Returns:
            The sum of the last layer's output over its trailing
            `len(strides) + 2` axes, divided by the square root of the number
            of summed elements.
        """

        initFunction = partial(jax.nn.initializers.variance_scaling(
            scale=1.0, mode="fan_avg", distribution="uniform"),
                               dtype=global_defs.tReal)

        # Set up padding for periodic boundary conditions
        # Padding size must be filter diameter - 1
        pads = [(0, 0)]
        for f in F:
            pads.append((0, f - 1))
        pads.append((0, 0))

        bias = [bias] * len(channels)
        bias[0] = firstLayerBias

        # Work on a copy: extending the argument in place would modify the
        # shared default list (and any list the caller passed in), so the
        # change would leak into later calls.
        actFun = list(actFun)
        missing = len(channels) - len(actFun)
        if missing > 0:
            actFun.extend([actFun[-1]] * missing)

        # List of axes that will be summed for symmetrization
        reduceDims = tuple([-i - 1 for i in range(len(strides) + 2)])

        # Add feature dimension
        x = jnp.expand_dims(jnp.expand_dims(2 * x - 1, axis=0), axis=-1)
        for c, fun, b in zip(channels, actFun, bias):
            # Wrap-around padding followed by an unpadded convolution gives a
            # convolution with periodic boundary conditions.
            x = jnp.pad(x, pads, 'wrap')
            x = fun(
                nn.Conv(x,
                        features=c,
                        kernel_size=tuple(F),
                        strides=strides,
                        padding=[(0, 0)] * len(strides),
                        bias=b,
                        dtype=global_defs.tReal,
                        kernel_init=initFunction))

        nrm = jnp.sqrt(jnp.prod(jnp.array(x.shape[reduceDims[-1]:])))

        return jnp.sum(x, axis=reduceDims) / nrm
Пример #28
0
def shift(array, offset):
    """Shifts array by offset and pads zero on the edge.

    Args:
        array: Float numpy array with shape (num_grids,).
        offset: Integer, the offset moving to the left. Negative values move
            to the right. Offsets whose magnitude exceeds num_grids are
            treated as num_grids, i.e. the result is all zeros.

    Returns:
        Float numpy array with shape (num_grids,).
    """
    num_grids = array.shape[0]
    # Without clamping, an oversized offset slices everything away but still
    # pads by |offset|, returning an array longer than the input.
    offset = max(-num_grids, min(num_grids, offset))
    if offset >= 0:
        sliced = array[offset:]
    else:
        sliced = array[:offset]
    return jnp.pad(sliced,
                   pad_width=(-min(offset, 0), max(offset, 0)),
                   mode='constant',
                   constant_values=0)
Пример #29
0
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    """Build sinusoidal embeddings for a 1-D array of timesteps.

    The first `embedding_dim // 2` columns hold sines and the next
    `embedding_dim // 2` hold cosines of the timesteps scaled by geometrically
    spaced frequencies; an odd `embedding_dim` gets one trailing zero column.

    Args:
        timesteps: 1-D array of timesteps.
        embedding_dim: width of the returned embedding.
        max_positions: controls the smallest frequency
            (`1 / max_positions` when `embedding_dim // 2 > 1`).

    Returns:
        Array of shape `(len(timesteps), embedding_dim)`.
    """
    assert len(timesteps.shape) == 1
    half_dim = embedding_dim // 2
    # magic number 10000 is from transformers
    # With a single frequency (half_dim == 1) the exponent below is 0 whatever
    # the scale, so the denominator is clamped to 1 instead of dividing by zero.
    emb = math.log(max_positions) / max(half_dim - 1, 1)
    emb = jnp.exp(jnp.arange(half_dim, dtype=jnp.float32) * -emb)
    emb = timesteps[:, None] * emb[None, :]
    emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = jnp.pad(emb, [[0, 0], [0, 1]])
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb
Пример #30
0
 def apply(self,
           x,
           features,
           kernel_size=(3, 3),
           act_fn=nn.relu,
           num_layers=3,
           padding="constant"):
     """Apply a stack of padded 'VALID' convolutions.

     Every layer but the last is followed by `act_fn`. Before each convolution
     the two middle axes are padded by `(kernel_size[0] - 1) // 2` using the
     `jnp.pad` mode given in `padding`; with `padding=None` no padding is done.
     Layers are named `conv_1` ... `conv_{num_layers}`.
     """
     conv2d = nn.Conv.partial(features=features,
                              kernel_size=kernel_size,
                              padding="VALID")
     border = (kernel_size[0] - 1) // 2
     pad_widths = ((0, 0), (border, border), (border, border), (0, 0))

     def pad_spatial(inputs):
         # Explicit padding, since the convolution itself runs in VALID mode.
         if padding is None:
             return inputs
         return jnp.pad(inputs, pad_widths, mode=padding)

     for layer in range(1, num_layers):
         x = conv2d(pad_spatial(x), name=f"conv_{layer}")
         x = act_fn(x)
     # Final layer: same padding and convolution, but no activation.
     return conv2d(pad_spatial(x), name=f"conv_{num_layers}")