Example #1
from trax import layers as tl          # imports assumed from the Trax codebase this snippet comes from
from trax.fastmath import numpy as jnp

def LocallyConvDense(n_modules, n_units, kernel_size=1, length_kernel_size=1):
    """Layer using local convolutions for approximation of Dense layer.

  The layer splits the last axis of a tensor into `n_modules`, then runs
  a convolution on all those modules, and concatenates their results.
  It is similar to LocallyConnectedDense above, but shares weights.

  Args:
    n_modules: Indicates how many modules (pixels) should be input and output
        split into for processing.
    n_units: how many outputs (filters) should each module generate.
    kernel_size: The size of the kernel to be used.
    length_kernel_size: If > 1, also do causal convolution on the previous axis,
      which is often the sentence length in sequence models.

  Returns:
      LocallyConvDense base.Layer.
  """
    if n_modules == 1:
        return tl.Dense(n_units)
    if kernel_size % 2 != 1:
        raise ValueError('Currently we only handle odd kernel sizes.')
    half = (kernel_size - 1) // 2
    pad_widths = [[0, 0], [length_kernel_size - 1, 0], [half, half], [0, 0]]
    return tl.Serial(
        tl.SplitLastAxis(n_modules),
        tl.Fn('Pad', lambda x: jnp.pad(x, pad_width=pad_widths)),
        tl.Conv(n_units, kernel_size=(length_kernel_size, kernel_size)),
        tl.MergeLastTwoAxes())
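A minimal sketch (made-up shapes, plain NumPy since jnp.pad mirrors np.pad) of what the Pad layer above does after SplitLastAxis: causal padding of length_kernel_size - 1 positions on the length axis and symmetric padding of half on the module axis, so the following Conv keeps both axes aligned.

import numpy as np

# Hypothetical shape after SplitLastAxis: (batch, length, n_modules, d_module).
x = np.zeros((2, 5, 4, 8))
kernel_size, length_kernel_size = 3, 2
half = (kernel_size - 1) // 2
pad_widths = [[0, 0], [length_kernel_size - 1, 0], [half, half], [0, 0]]
padded = np.pad(x, pad_width=pad_widths)
print(padded.shape)  # (2, 6, 6, 8): length padded on the left only, modules on both sides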
Example #2
 def pad(z):
     pad_widths = [(0, 0)] * len(z.shape)
     pad_widths[0] = (0, self._n_devices - remainder)
     return jnp.pad(z,
                    pad_widths,
                    mode='constant',
                    constant_values=z.dtype.type(0))
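A quick illustration (made-up numbers, not part of the original snippet) of the padding above: a batch of 5 examples is padded with zero rows so it divides evenly across 4 devices.

import numpy as np

n_devices = 4
z = np.arange(10).reshape(5, 2)      # batch of 5 examples
remainder = z.shape[0] % n_devices   # 1
pad_widths = [(0, 0)] * len(z.shape)
pad_widths[0] = (0, n_devices - remainder)
padded = np.pad(z, pad_widths, mode='constant', constant_values=z.dtype.type(0))
print(padded.shape)  # (8, 2): three all-zero examples appended at the end of the batch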
Example #3
def beam_init(batch_size, beam_size, max_decode_len, cache, start_tokens=None):
  """Initializes the beam search state data structure."""
  cur_index0 = jnp.array(0)
  live_logprobs0 = jnp.tile(
      jnp.array([0.0] + [NEG_INF] * (beam_size - 1)),
      [batch_size, 1])
  finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF
  if start_tokens is None:
    live_seqs0 = jnp.zeros(
        (batch_size, beam_size, max_decode_len), jnp.int32)
  else:
    live_seqs0 = add_beam_dim(
        np.pad(start_tokens[:, None],
               ((0, 0), (0, max_decode_len - 1)), mode='constant'),
        beam_size)
  finished_seqs0 = jnp.zeros(
      (batch_size, beam_size, max_decode_len), jnp.int32)
  finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_)
  # add beam dimension to attention cache pytree elements
  beam_cache0 = jax.tree_map(lambda x: add_beam_dim(x, beam_size), cache)
  return BeamState(cur_index=cur_index0,
                   live_logprobs=live_logprobs0,
                   finished_scores=finished_scores0,
                   live_seqs=live_seqs0,
                   finished_seqs=finished_seqs0,
                   finished_flags=finished_flags0,
                   cache=beam_cache0)
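The live_logprobs0 initialization above is the part worth pausing on: only the first beam starts at log-probability 0.0 and all other beams start at NEG_INF, so the first decoding step effectively expands a single hypothesis per batch item. A small stand-alone check (made-up sizes, NEG_INF assumed to be a large negative constant as in the original module):

import numpy as np

NEG_INF = -1.0e7   # assumed value for illustration
batch_size, beam_size = 2, 3
live_logprobs0 = np.tile(np.array([0.0] + [NEG_INF] * (beam_size - 1)),
                         [batch_size, 1])
print(live_logprobs0.shape)  # (2, 3); first column is 0.0, the rest NEG_INF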
Example #4
 def shift_right(x):
     pad_widths = [(0, 0)] * len(x.shape)
     pad_widths[1] = (1, 0)
     padded = jnp.pad(x,
                      pad_widths,
                      mode='constant',
                      constant_values=x.dtype.type(cls_id))
     return padded[:, :-1]
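A toy run (with an assumed cls_id, outside the original code) showing what shift_right does: every sequence gains cls_id at position 0 and drops its last token, so the length stays the same.

import numpy as np

cls_id = 101   # assumed start-of-sequence id, for illustration only
x = np.array([[7, 8, 9],
              [4, 5, 6]])
pad_widths = [(0, 0)] * len(x.shape)
pad_widths[1] = (1, 0)
padded = np.pad(x, pad_widths, mode='constant',
                constant_values=x.dtype.type(cls_id))
print(padded[:, :-1])
# [[101   7   8]
#  [101   4   5]]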
Example #5
def _zero_pad(x, pad, axis):
    """Helper for jnp.pad with 0s for single-axis case."""
    pad_widths = [(0, 0)] * len(x.shape)
    pad_widths[axis] = pad  # Padding on axis.
    return jnp.pad(x,
                   pad_widths,
                   mode='constant',
                   constant_values=x.dtype.type(0))
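Example call of the helper above (assuming jnp here is jax.numpy, as in Trax's fastmath backend); pad is a (before, after) pair for the chosen axis:

import jax.numpy as jnp

x = jnp.ones((2, 3), dtype=jnp.float32)
padded = _zero_pad(x, (1, 2), axis=1)   # one zero column on the left, two on the right
print(padded.shape)  # (2, 6)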
Example #6
 def forward(self, x):
   assert self._padding == 'VALID'
   # Left pad with 0s. Applying an unmasked valid convolution on top of this
   # yields a causal convolution.
   # TODO(ddohan): Support strided and dilated convolutions.
   rate = 1
   effective_kernel_size = int((self._kernel_size[0] - 1) * rate + 1)
   pad = effective_kernel_size - 1
   x_leftpad = (
       jnp.pad(x, pad_width=[[0, 0], [pad, 0], [0, 0]], mode='constant'))
   return super().forward(x_leftpad)
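The left padding above is exactly what turns the VALID convolution into a causal one: with kernel size k, output position t only sees inputs t - k + 1, ..., t. A rough check of the padding arithmetic (made-up shapes, no actual convolution run):

import numpy as np

kernel_size, rate = 3, 1
effective_kernel_size = int((kernel_size - 1) * rate + 1)
pad = effective_kernel_size - 1    # 2 zeros prepended along the length axis
x = np.ones((4, 10, 8))            # (batch, length, features)
x_leftpad = np.pad(x, pad_width=[[0, 0], [pad, 0], [0, 0]], mode='constant')
print(x_leftpad.shape)  # (4, 12, 8); a VALID conv with kernel 3 then gives length 10 again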
Example #7
 def f(x):  # pylint: disable=invalid-name
   # x : [batch, 1, length, depth]
   x = jnp.pad(x, [(0, 0), (0, 0), (1, 1), (0, 0)],
               mode='constant', constant_values=0.0)
   depth = x.shape[-1] // 3
   assert 3 * depth == x.shape[-1], ('Depth must be divisible by 3', depth,
                                     x.shape)
   xs = [
       x[:, :, :-2, :depth], x[:, :, 1:-1, depth:2 * depth],
       x[:, :, 2:, 2 * depth:3 * depth]
   ]
   return jnp.concatenate(xs, axis=3)
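What f above computes, in words: the length axis is padded by one position on each side, and the three depth thirds are read at offsets -1, 0 and +1, so each output position concatenates the first third of its left neighbour, its own middle third, and the last third of its right neighbour. A tiny shape check with made-up sizes:

import numpy as np

x = np.zeros((2, 1, 5, 6))   # (batch, 1, length, depth) with depth divisible by 3
x = np.pad(x, [(0, 0), (0, 0), (1, 1), (0, 0)], mode='constant', constant_values=0.0)
depth = x.shape[-1] // 3
xs = [x[:, :, :-2, :depth],               # previous position's first third
      x[:, :, 1:-1, depth:2 * depth],     # current position's middle third
      x[:, :, 2:, 2 * depth:3 * depth]]   # next position's last third
print(np.concatenate(xs, axis=3).shape)   # (2, 1, 5, 6): length and depth both preserved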
Example #8
 def pure_fn(self, x, weights, state, rng, use_cache=False):
     """Calls self.sublayer.pure_fn in an accelerated way."""
     # Check if we can divide x evenly across devices.
     remainder = x.shape[0] % self._n_devices
     if remainder == 0:  # If yes, run the accelerated sublayer.pure_fn.
         return self._jit_pure_fn(x, weights, state, rng)
     # If not, pad first.
     pad_widths = [(0, 0)] * len(x.shape)
     pad_widths[0] = (0, self._n_devices - remainder)
     padded_x = jnp.pad(x,
                        pad_widths,
                        mode='constant',
                        constant_values=x.dtype.type(0))
     # Run and un-pad.
     padded_y, state = self._jit_pure_fn(padded_x, weights, state, rng)
     return padded_y[:x.shape[0]], state
Example #9
  def _get_initial_state(self, inputs, targets_prefix, batch_size):
    """Get initial state for beam search."""
    if targets_prefix is None:
      prompt = np.zeros((batch_size, 1), dtype=np.int32)
    else:
      prompt = np.pad(
          targets_prefix[:, :-1], ((0, 0), (1, 0)), mode='constant')

    # Get state prior to running the encoder or incorporating targets_prefix
    if inputs is None:
      signature = ShapeDtype((batch_size, 1), prompt.dtype)
    else:
      signature = (ShapeDtype(inputs.shape, inputs.dtype),
                   ShapeDtype((batch_size, 1), prompt.dtype))
    # Trax's model.init is stateful as opposed to functional. Calling it on an
    # already-existing model instance doesn't work.
    # TODO(lukaszkaiser): add purely functional init to Trax.
    _, initial_state = self.model(mode='predict').init(signature)

    # Incorporate encoder and prompt into state
    _, prompted_state = self.model_infer.pure_fn(
        prompt if inputs is None else (inputs, prompt),
        self.model_weights,
        initial_state,
        jax.random.PRNGKey(0))
    state_structure = jax.tree_structure(prompted_state)

    if targets_prefix is not None:
      initial_state = prompted_state
    elif self.encoder_idx is not None:
      initial_state = (tuple(prompted_state[:self.encoder_idx])
                       + tuple(initial_state[self.encoder_idx:]))

    # Fix tree structure of the state (there's a tuple vs. list mismatch)
    initial_state = jax.tree_unflatten(
        state_structure, trax.fastmath.tree_leaves(initial_state))

    return initial_state
Example #10
    def forward(self, x):
        """Executes this layer as part of a forward pass through the model.

    Args:
      x: Tensor of same shape and dtype as the input signature used to
          initialize this layer.

    Returns:
      Tensor of same shape and dtype as the input, except the final dimension
      is the layer's `filters` value, and the second to last dimension is
      shrinked if 'VALID' padding is used with kernel_size bigger than one.
    """
        if self._use_bias:
            if not isinstance(self.weights, (tuple, list)):
                raise ValueError(f'Weights should be a (w, b) tuple or list; '
                                 f'instead got: {self.weights}')
            w, b = self.weights
        else:
            w = self.weights

        linear_results_before_shifting = jnp.einsum('...lp,lkpd->...lkd', x, w)
        # TODO(jaszczur): this could be run after padding for better efficiency

        if self._kernel_size == 1:
            # With kernel size 1 we don't have to split or shift anything.
            linear_result = jnp.squeeze(linear_results_before_shifting,
                                        axis=-2)
        else:
            # We computed a result for every "pixel", but each direction from the
            # receptive field (there are 'self._kernel_size' such directions) must be
            # shifted by a different amount. The easiest way to do it is to split
            # the tensor to 'self._kernel_size' smaller tensors, shift each one
            # appropriately, and then sum them together.
            split_shifting_linear_results = jnp.split(
                linear_results_before_shifting, self._kernel_size, axis=-2)

            for i in range(self._kernel_size):
                # Each tensor has to be shifted a different amount.
                if self._padding == 'WRAP':
                    # We can shift by padding and cutting. With 'wrap' padding we
                    # essentially have a torus.
                    padding = [(0, 0) for _ in
                               split_shifting_linear_results[i].shape]
                    padding[-3] = ((self._kernel_size - 1) - i, i)
                    padded = jnp.pad(split_shifting_linear_results[i],
                                     padding, mode='wrap')
                    split_shifting_linear_results[i] = padded[
                        ..., (self._kernel_size - 1) // 2:
                        -(self._kernel_size - 1) // 2, :, :]
                elif self._padding == 'SAME':
                    # We can shift by padding and cutting.
                    padding = [(0, 0) for _ in
                               split_shifting_linear_results[i].shape]
                    padding[-3] = ((self._kernel_size - 1) - i, i)
                    padded = jnp.pad(split_shifting_linear_results[i], padding)
                    split_shifting_linear_results[i] = padded[
                        ..., (self._kernel_size - 1) // 2:
                        -(self._kernel_size - 1) // 2, :, :]
                    # TODO(jaszczur): improve efficiency by not padding things to cut
                elif self._padding == 'VALID':
                    # We don't need to shift - just cut the leftmost and rightmost values.
                    cut_left = (self._kernel_size - 1) - i
                    cut_right = split_shifting_linear_results[i].shape[-3] - i
                    split_shifting_linear_results[i] = (
                        split_shifting_linear_results[i]
                        [..., cut_left:cut_right, :, :])
                else:
                    raise ValueError(f'Invalid padding {self._padding}')
            # After shifting.
            shifted_linear_results = jnp.concatenate(
                split_shifting_linear_results, axis=-2)
            linear_result = jnp.sum(shifted_linear_results, axis=-2)

        if self._use_bias:
            return linear_result + b
        else:
            return linear_result
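One way to see the shifting logic above: for kernel_size k, output position t collects direction i's result computed at position t + i - (k - 1) // 2, which is exactly where that tap of a centred kernel looks. A toy check of the pad-then-cut step for the 'SAME' case (assumed small sizes, single feature axes, plain NumPy):

import numpy as np

kernel_size = 3
x = np.arange(5.0).reshape(5, 1, 1)   # one direction's per-pixel results; length axis is -3
for i in range(kernel_size):
  padding = [(0, 0) for _ in x.shape]
  padding[-3] = ((kernel_size - 1) - i, i)
  shifted = np.pad(x, padding)[(kernel_size - 1) // 2:-(kernel_size - 1) // 2, :, :]
  print(i, shifted.ravel())
# i=0 -> [0. 0. 1. 2. 3.]  (output position t receives the value from t-1)
# i=1 -> [0. 1. 2. 3. 4.]  (value from t itself)
# i=2 -> [1. 2. 3. 4. 0.]  (value from t+1)

Summing the three shifted tensors then reproduces what a centred kernel of size 3 would aggregate at each position.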
Example #11
def _zero_pad(x, pad, axis):  # pylint: disable = invalid-name
    """Helper for jnp.pad with 0s for single-axis case."""
    pad_widths = [(0, 0)] * len(x.shape)
    pad_widths[axis] = pad  # Padding on axis.
    return jnp.pad(x, pad_widths, mode='constant')
Example #12
 def pad_right(x):
   pad_widths = [(0, 0), (0, n_to_pad)] + [(0, 0)] * (x.ndim - 2)
   return jnp.pad(
       x, pad_widths, mode='constant', constant_values=x.dtype.type(0))
Example #13
 def pad_to_chunk_len(v):
   width = [(0, 0)] * v.ndim
   width[2] = (0, chunk_len - v.shape[2])
   return jnp.pad(v, width, mode='constant', constant_values=0.0)
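Unlike the earlier pad-to-a-multiple helpers, this one pads axis 2 up to a fixed chunk_len; a quick check with assumed sizes:

import numpy as np

chunk_len = 8
v = np.zeros((2, 4, 5, 16))   # e.g. (batch, heads, chunk, d_head); chunk is short by 3
width = [(0, 0)] * v.ndim
width[2] = (0, chunk_len - v.shape[2])
print(np.pad(v, width, mode='constant', constant_values=0.0).shape)  # (2, 4, 8, 16)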