Beispiel #1
0
def beam_init(batch_size, beam_size, max_decode_len, cache, start_tokens=None):
  """Builds the initial `BeamState` used at step 0 of beam search.

  Args:
    batch_size: Number of examples decoded in parallel.
    beam_size: Number of beam slots kept per example.
    max_decode_len: Length of the last axis of the sequence buffers.
    cache: Pytree of decoding cache arrays; every leaf is passed through
      `add_beam_dim(leaf, beam_size)`.
    start_tokens: Optional 1-D array with one start token per batch element
      (it is indexed as `start_tokens[:, None]`). When omitted, the live
      sequences start as all zeros.

  Returns:
    A `BeamState` with:
      cur_index: scalar 0.
      live_logprobs: (batch_size, beam_size); the first beam slot is 0.0 and
        every other slot is NEG_INF.
      finished_scores: (batch_size, beam_size) filled with NEG_INF.
      live_seqs: (batch_size, beam_size, max_decode_len) int32 zeros, or the
        zero-padded start tokens with a beam dimension added.
      finished_seqs: (batch_size, beam_size, max_decode_len) int32 zeros.
      finished_flags: (batch_size, beam_size) booleans, all False.
      cache: `cache` with the beam dimension added to each leaf.
  """
  cur_index0 = jnp.array(0)
  # Only the first beam slot starts "alive" (log-prob 0); the rest are NEG_INF.
  live_logprobs0 = jnp.tile(
      jnp.array([0.0] + [NEG_INF] * (beam_size - 1)),
      [batch_size, 1])
  finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF
  if start_tokens is None:
    live_seqs0 = jnp.zeros(
        (batch_size, beam_size, max_decode_len), jnp.int32)
  else:
    # Place each start token at position 0 and zero-fill the remaining
    # max_decode_len - 1 positions before adding the beam dimension.
    # NOTE(review): this uses `np.pad` while the rest of the function uses
    # `jnp`; presumably intentional for this file's `np` alias -- confirm.
    live_seqs0 = add_beam_dim(
        np.pad(start_tokens[:, None],
               ((0, 0), (0, max_decode_len - 1)), mode='constant'),
        beam_size)
  finished_seqs0 = jnp.zeros(
      (batch_size, beam_size, max_decode_len), jnp.int32)
  finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_)
  # Add beam dimension to attention cache pytree elements.
  # `jax.tree_util.tree_map` is used instead of the removed top-level
  # `jax.tree_map` alias so this works on both old and current JAX releases.
  beam_cache0 = jax.tree_util.tree_map(
      lambda x: add_beam_dim(x, beam_size), cache)
  return BeamState(cur_index=cur_index0,
                   live_logprobs=live_logprobs0,
                   finished_scores=finished_scores0,
                   live_seqs=live_seqs0,
                   finished_seqs=finished_seqs0,
                   finished_flags=finished_flags0,
                   cache=beam_cache0)
Beispiel #2
0
def zero_pad(x, pad, axis):
    """Zero-pads `x` along a single axis, leaving all other axes untouched.

    Args:
      x: Array to pad.
      pad: `(before, after)` pair giving the padding amounts for `axis`.
      axis: Index of the axis to pad (negative indices are accepted).

    Returns:
      The padded array; the fill value is a zero of `x`'s own dtype.
    """
    widths = [(0, 0) for _ in x.shape]
    widths[axis] = pad  # Only this axis grows.
    fill = x.dtype.type(0)
    return np.pad(x, widths, mode='constant', constant_values=fill)
Beispiel #3
0
 def forward(self, x, weights):
   """Runs a causal convolution by left-padding and delegating to the parent.

   The input is padded with zeros on the left of axis 1 only, so that the
   parent's unmasked 'VALID' convolution never sees positions to the right
   of the current one.

   Args:
     x: Rank-3 input (the pad spec has three entries; axis 1 is padded).
     weights: Weights forwarded unchanged to the parent `forward`.

   Returns:
     Whatever the parent class's `forward` returns for the padded input.
   """
   assert self._padding == 'VALID'
   # TODO(ddohan): Support strided and dilated convolutions.
   dilation = 1
   receptive_field = int((self._kernel_size[0] - 1) * dilation + 1)
   left_pad = receptive_field - 1
   pad_spec = [[0, 0], [left_pad, 0], [0, 0]]
   causal_input = np.pad(x, pad_width=pad_spec, mode='constant')
   return super(CausalConv, self).forward(causal_input, weights)
Beispiel #4
0
def ShiftRight(x, n_shifts=1, mode='train', **unused_kwargs):
    """Shifts `x` right along axis 1, filling the vacated positions with zeros.

    Args:
      x: Array with at least two axes; axis 1 is the one being shifted.
      n_shifts: Number of positions to shift by. `0` returns the values of
        `x` unchanged.
      mode: In `'predict'` mode the input is returned as is (the sequence
        length is 1 there, so there is nothing to shift).
      **unused_kwargs: Accepted and ignored.

    Returns:
      An array of the same shape as `x` whose first `n_shifts` entries along
      axis 1 are zeros of `x`'s dtype, followed by the leading entries of `x`.
    """
    if mode == 'predict':
        # Do nothing in predict mode, as then the sequence length is 1.
        return x

    pad_widths = [(0, 0)] * len(x.shape)
    pad_widths[1] = (n_shifts, 0)  # Padding on axis=1
    padded = np.pad(x,
                    pad_widths,
                    mode='constant',
                    constant_values=x.dtype.type(0))
    # Trim back to the original length. Slicing with `:-n_shifts` would be
    # `:0` for n_shifts == 0 and wrongly return an empty array.
    return padded[:, :x.shape[1]]
Beispiel #5
0
def DiagonalGate(x):
  """Mixes neighbouring positions by shifting channel thirds along axis 2.

  The channel axis is split into three equal parts. At every position along
  axis 2 the output takes the first third from the previous position, the
  middle third from the same position and the last third from the next
  position; positions beyond the edges read as zeros.

  Args:
    x: Rank-4 array, [batch, 1, length, depth], with depth divisible by 3.

  Returns:
    An array with the same shape as `x`.
  """
  # x : [batch, 1, length, depth]
  keep = (0, 0)
  x = np.pad(x, [keep, keep, (1, 1), keep],
             mode='constant', constant_values=0.0)
  depth = x.shape[-1] // 3
  assert 3 * depth == x.shape[-1], ('Depth must be divisible by 3', depth,
                                    x.shape)
  from_previous = x[:, :, :-2, :depth]
  from_current = x[:, :, 1:-1, depth:2 * depth]
  from_next = x[:, :, 2:, 2 * depth:3 * depth]
  return np.concatenate([from_previous, from_current, from_next], axis=3)
Beispiel #6
0
 def f(x):  # pylint: disable=invalid-name
     """Shifts the first and last channel thirds by one step along axis 2.

     After zero-padding one step on each side of axis 2, the first third of
     the channels is read one position earlier, the middle third in place,
     and the last third one position later; the three parts are then joined
     back along the channel axis.

     Args:
       x: Rank-4 array, [batch, 1, length, depth], with depth divisible by 3.

     Returns:
       An array with the same shape as `x`.
     """
     # x : [batch, 1, length, depth]
     pad_spec = [(0, 0), (0, 0), (1, 1), (0, 0)]
     x = np.pad(x, pad_spec, mode='constant', constant_values=0.0)
     depth = x.shape[-1] // 3
     assert 3 * depth == x.shape[-1], ('Depth must be divisible by 3',
                                       depth, x.shape)
     # One window over the padded axis per channel third, in channel order.
     windows = (slice(None, -2), slice(1, -1), slice(2, None))
     xs = [
         x[:, :, window, part * depth:(part + 1) * depth]
         for part, window in enumerate(windows)
     ]
     return np.concatenate(xs, axis=3)
Beispiel #7
0
def serialize_observations_and_actions(  # pylint: disable=invalid-name
    observations,
    actions,
    observation_serializer,
    action_serializer,
    representation_length,
):
    """Converts trajectories into fixed-length discrete symbol sequences.

    Observations and actions are each serialized with their own serializer,
    combined by the `Interleave` layer, and the result is right-padded along
    axis 1 up to `representation_length`.

    Args:
      observations: Array (B, T + 1, ...), of observations, where B is the
        batch size and T is the number of timesteps excluding the last
        observation.
      actions: Array (B, T, ...) of actions.
      observation_serializer: SpaceSerializer for observations.
      action_serializer: SpaceSerializer for actions.
      representation_length: Number of symbols in the serialized sequence.
        The sequence is padded up to this number.

    Returns:
      Serialized sequence of shape (B, R) where R = representation_length.

    Raises:
      AssertionError: If the leading dimensions of `observations` and
        `actions` do not match as described above, or if the serialized
        sequence is longer than `representation_length`.
    """
    batch_size, n_timesteps = actions.shape[:2]
    assert observations.shape[:2] == (batch_size, n_timesteps + 1)

    per_stream = tl.Parallel(
        Serialize(serializer=observation_serializer),  # pylint: disable=no-value-for-parameter
        Serialize(serializer=action_serializer),  # pylint: disable=no-value-for-parameter
    )
    merge = Interleave()  # pylint: disable=no-value-for-parameter
    pipeline = tl.Serial([per_stream, merge])

    trajectory = (observations, actions)
    pipeline.init(shapes.signature(trajectory))
    symbols = pipeline(trajectory)

    assert symbols.shape[1] <= representation_length
    n_missing = representation_length - symbols.shape[1]
    return np.pad(symbols, pad_width=((0, 0), (0, n_missing)), mode='constant')
Beispiel #8
0
  def _get_initial_state(self, inputs, targets_prefix, batch_size):
    """Get initial state for beam search.

    Args:
      inputs: Optional encoder inputs. When `None`, the model is fed only
        the prompt.
      targets_prefix: Optional prefix of target tokens. When given, the
        prompt is the prefix without its last column, shifted right by one
        zero column.
      batch_size: Batch size used for the all-zero prompt and for the
        length-1 shape signature the predict-mode model is initialized with.

    Returns:
      The model state to start decoding from, laid out with the same pytree
      structure as the state produced by `self.model_infer.pure_fn`.
    """
    if targets_prefix is None:
      prompt = np.zeros((batch_size, 1), dtype=np.int32)
    else:
      # Drop the last prefix token and prepend a zero column.
      prompt = np.pad(
          targets_prefix[:, :-1], ((0, 0), (1, 0)), mode='constant')

    # Get state prior to running the encoder or incorporating targets_prefix
    if inputs is None:
      signature = ShapeDtype((batch_size, 1), prompt.dtype)
    else:
      signature = (ShapeDtype(inputs.shape, inputs.dtype),
                   ShapeDtype((batch_size, 1), prompt.dtype))
    # Trax's model.init is stateful as opposed to functional. Calling it on an
    # already-existing model instance doesn't work.
    # TODO(lukaszkaiser): add purely functional init to Trax.
    _, initial_state = self.model(mode='predict').init(signature)

    # Incorporate encoder and prompt into state
    _, prompted_state = self.model_infer.pure_fn(
        prompt if inputs is None else (inputs, prompt),
        self.model_weights,
        initial_state,
        jax.random.PRNGKey(0))
    # The `jax.tree_util` functions are used here because the top-level
    # `jax.tree_*` aliases were removed from recent JAX releases.
    state_structure = jax.tree_util.tree_structure(prompted_state)

    if targets_prefix is not None:
      # With a prefix, decoding continues from the fully prompted state.
      initial_state = prompted_state
    elif self.encoder_idx is not None:
      # Without a prefix, keep the prompted entries before `encoder_idx` and
      # the freshly initialized entries from `encoder_idx` onwards.
      initial_state = (tuple(prompted_state[:self.encoder_idx])
                       + tuple(initial_state[self.encoder_idx:]))

    # Fix tree structure of the state (there's a tuple vs. list mismatch)
    initial_state = jax.tree_util.tree_unflatten(
        state_structure, jax.tree_util.tree_leaves(initial_state))

    return initial_state
Beispiel #9
0
 def PadRight(x, n_to_pad, **unused_kwargs):
     """Appends `n_to_pad` zeros at the end of axis 1 of `x`.

     Args:
       x: Array to pad; all axes other than axis 1 are left unchanged.
       n_to_pad: Number of zero entries to append along axis 1.
       **unused_kwargs: Accepted and ignored.

     Returns:
       The padded array; the fill value is a zero of `x`'s own dtype.
     """
     untouched = (0, 0)
     pad_widths = [untouched, (0, n_to_pad)]
     pad_widths.extend([untouched] * (x.ndim - 2))
     zero = x.dtype.type(0)
     return np.pad(x, pad_widths, mode='constant', constant_values=zero)
Beispiel #10
0
 def pad_right(x):
     """Appends `n_to_pad` zeros at the end of axis 1 of `x`.

     `n_to_pad` is taken from the enclosing scope. All axes other than
     axis 1 are left unchanged, and the fill value is a zero of `x`'s dtype.
     """
     untouched = (0, 0)
     trailing = [untouched] * (x.ndim - 2)
     pad_widths = [untouched, (0, n_to_pad), *trailing]
     zero = x.dtype.type(0)
     return jnp.pad(x, pad_widths, mode='constant', constant_values=zero)