Exemple #1
0
def PerformPositionOperations(pos, positions=None):
    """Gets pos and returns (q1, ..., q5).

    Builds key/value tables over the rows of `positions` and applies five
    `QueryPositionKV` queries to `pos`:

      q1: no table (keys=None).
      q2: successor -- key row i maps to value row i + 1.
      q3: predecessor -- key row i + 1 maps to value row i.
      q4: addition -- concatenated (row i, row j) maps to row i + j, for
          0 <= i, j < l with l = len(positions) // 2, so i + j stays in range.
      q5: clipped subtraction -- concatenated (row i, row j) maps to row
          max(i - j, 0), for 0 <= i, j < l.

    Args:
      pos: position vector(s) to query with; combined with each query via the
        `@` operator. NOTE(review): `QueryPositionKV()` is called here without
        `x`, so it is presumably used as a layer factory whose results compose
        with `pos` through `@` -- confirm against the decorator/definition.
      positions: 2-D array whose rows are the position vectors. Despite the
        `None` default it is required: the first slice below fails on `None`.

    Returns:
      A list with one result per query, in the order q1..q5.
    """
    # Successor table: row i -> row i + 1.
    succ_keys = positions[:-1, :]
    succ_values = positions[1:, :]
    # Predecessor table: row i + 1 -> row i.
    subtract_1_keys = positions[1:, :]
    subtract_1_values = positions[:-1, :]
    # Only the first half of the rows is used as operands, so i + j < 2 * l
    # never indexes past the end of `positions`.
    l = int(positions.shape[0]) // 2
    add_keys = np.array([
        np.concatenate([positions[i, :], positions[j, :]]) for i in range(l)
        for j in range(l)
    ])
    add_values = np.array(
        [positions[i + j, :] for i in range(l) for j in range(l)])
    # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
    # Subtraction clips at 0 (i - j < 0 maps to row 0).
    sub_keys = np.array([
        np.concatenate([positions[i, :], positions[j, :]]) for j in range(l)
        for i in range(l)
    ])
    sub_values = np.array(
        [positions[max(i - j, 0), :] for j in range(l) for i in range(l)])
    # The add/sub tables key on concatenated position pairs, hence binary=True.
    query_types = [
        QueryPositionKV(),
        QueryPositionKV(keys=succ_keys, values=succ_values),
        QueryPositionKV(keys=subtract_1_keys, values=subtract_1_values),
        QueryPositionKV(keys=add_keys, values=add_values, binary=True),
        QueryPositionKV(keys=sub_keys, values=sub_values, binary=True)
    ]
    return [qt @ pos for qt in query_types]  # pylint: disable=syntax-error
Exemple #2
0
 def forward_with_state(self, x, weights, state, rng):
     """Concatenates multi-base digit embeddings of the positions to `x`.

     Every position is written with `self._n_digits` digits in each base of
     `self._bases`, least-significant digit first. Digit d of base b is
     scaled by `weights[b][d]`. The digits of one base are concatenated along
     the last axis, and the per-base results are summed.

     In 'train' mode:
       * positions of a batch element start at a random offset in
         [0, max_pos - length), except when the draw from
         [0, _start_from_zero_one_in) is 0, in which case they start at 0;
       * the embeddings of each base are zeroed for a batch element when the
         draw from [0, _base_dropout_one_in) is 0. Every base uses its own
         PRNG key, so bases are dropped independently of one another.

     Args:
       x: input activations with batch on axis 0 and length on axis 1.
       weights: per-base, per-digit scales indexed as weights[base][digit].
       state: layer state, returned unchanged.
       rng: PRNG key for the random offsets and base dropout.

     Returns:
       (concatenate([x, embeddings], axis=-1), state).
     """
     batch_size, length = x.shape[0], x.shape[1]
     max_pos = min(self._bases)**self._n_digits
     rng1, rng2, rng3 = math.random.split(rng, 3)
     # One independent key per base. Reusing rng3 on every loop iteration
     # would draw the identical dropout mask for all bases, so they would
     # always be dropped (or kept) together.
     base_rngs = math.random.split(rng3, len(self._bases))
     assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length,
                                                               max_pos)
     positions = jnp.arange(0, length)[None, :]
     if self._mode == 'train':
         # With probability 1/_start_from_zero_one_in still start from 0,
         # exactly as in eval.
         start_from_nonzero = jax.random.randint(
             rng1, (batch_size, ), 0, self._start_from_zero_one_in)
         start_from_nonzero = jnp.minimum(1, start_from_nonzero)
         random_start = jax.random.randint(rng2, (batch_size, ), 0,
                                           max_pos - length)
         random_start *= start_from_nonzero
         positions += random_start[:, None]
     res = []
     for bn, base in enumerate(self._bases):
         pos_embeddings = []
         cur_positions = positions
         for i in range(self._n_digits):
             # Peel off the least-significant remaining digit in `base`.
             cur_indices = jnp.mod(cur_positions, base)
             cur_positions = cur_positions // base
             s = weights[bn][i]
             pos_embeddings.append(
                 cur_indices.astype(jnp.float32)[:, :, None] * s)
         embeddings = jnp.concatenate(pos_embeddings, axis=-1)
         if self._mode == 'train':
             base_dropout = jax.random.randint(base_rngs[bn], (batch_size, ),
                                               0, self._base_dropout_one_in)
             base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32)
             embeddings *= base_dropout[:, None, None]
         res.append(embeddings)
     # Adding zeros_like(x) presumably broadcasts the batch-1 eval embeddings
     # up to x's batch size; the embedding width must broadcast against x's
     # last axis -- TODO confirm intended shapes.
     res = sum(res) + jnp.zeros_like(x)
     return jnp.concatenate([x, res], axis=-1), state
Exemple #3
0
def Interleave(inputs, **unused_kwargs):
    """Interleaves and flattens two serialized sequences.

    The first sequence can be longer by 1 than the second one. This is so we
    can interleave sequences of observations and actions, when there's 1 extra
    observation at the end.

    For serialized sequences [[x_1_1, ..., x_1_R1], ..., [x_L1_1, ..., x_L1_R1]]
    and [[y_1_1, ..., y_1_R2], ..., [y_L2_1, ..., y_L2_R2]], where L1 = L2 + 1,
    the result is [x_1_1, ..., x_1_R1, y_1_1, ..., y_1_R2, ..., x_L2_1, ...,
    x_L2_R1, y_L2_1, ..., y_L2_R2, x_L1_1, ..., x_L1_R1] (batch dimension
    omitted for clarity).

    Args:
      inputs: Pair of sequences of shapes (B, L1, R1) and (B, L2, R2), where B
        is batch size, L* is the length of the sequence and R* is the
        representation length of each element in the sequence. L1 must be
        L2 or L2 + 1.

    Returns:
      Interleaved sequence of shape (B, L1 * R1 + L2 * R2).
    """
    (x, y) = inputs
    (batch_size, _, _) = x.shape
    (_, length, _) = y.shape
    assert x.shape[1] in (length, length + 1)

    # Pair the first L2 x-elements with the L2 y-elements along the
    # representation axis, then flatten the pairs in order.
    reprs = np.concatenate((x[:, :length], y), axis=2)
    reprs = np.reshape(reprs, (batch_size, -1))
    # The optional trailing x-element (empty when L1 == L2).
    remainder = np.reshape(x[:, length:], (batch_size, -1))
    return np.concatenate((reprs, remainder), axis=1)
Exemple #5
0
    def reverse(self, output, weights=(), state=(), new_state=(), **kwargs):
        """Inverts the split: recovers (x1, x2) from the section outputs.

        Each section in `output` is split in half along its last axis. The
        first halves are concatenated along `self._axis` to form x1, and the
        second halves are concatenated the same way to form x2.

        Args:
          output: a list/tuple of section arrays, or a single bare array,
            which is treated as a one-section output.
          weights: unused.
          state: unused.
          new_state: unused.
          **kwargs: unused.

        Returns:
          Tuple (x1, x2).
        """
        del weights, kwargs
        # A bare array is wrapped so it counts as one section; iterating it
        # directly would walk its leading axis instead.
        if not isinstance(output, (list, tuple)):
            output = [output]

        halves = [np.split(y, 2, -1) for y in output]
        x1 = np.concatenate([h[0] for h in halves], self._axis)
        x2 = np.concatenate([h[1] for h in halves], self._axis)
        return (x1, x2)
Exemple #6
0
    def forward(self, inputs, weights):
        """Runs one GRU step.

        Args:
          inputs: pair (x, gru_state).
          weights: (w1, b1, w2, b2): the gate dense layer and the candidate
            dense layer parameters.

        Returns:
          (new_gru_state, new_gru_state): the new state serves both as the
          output and as the carried state.
        """
        x, gru_state = inputs
        w1, b1, w2, b2 = weights

        # Both gates come from one dense layer over [x, h], split in half.
        gates = math.sigmoid(
            jnp.dot(jnp.concatenate([x, gru_state], axis=-1), w1) + b1)
        update_gate, reset_gate = jnp.split(gates, 2, axis=-1)

        # The candidate sees the reset-gated previous state.
        gated_state = reset_gate * gru_state
        candidate = jnp.tanh(
            jnp.dot(jnp.concatenate([x, gated_state], axis=-1), w2) + b2)

        # Interpolate between the old state and the candidate.
        new_gru_state = update_gate * gru_state + (1 - update_gate) * candidate
        return new_gru_state, new_gru_state
Exemple #7
0
    def forward_with_state(self,
                           inputs,
                           weights=layer_base.EMPTY_WEIGHTS,
                           state=layer_base.EMPTY_STATE,
                           rng=None,
                           **kwargs):
        """Overwrites the trailing `num_features` channels with embeddings.

        In 'predict' mode, `state` is the current time step. It selects the
        embedding and is then incremented. In every other mode, embeddings
        for steps 0..len-1 (len = inputs.shape[-2]) are broadcast over the
        leading axes of `inputs` with 3 feature channels.

        Returns:
          (inputs with replaced trailing channels, new state). The size of
          the last axis of `inputs` is preserved; extra channels are dropped
          from the front with a warning.
        """
        depth = inputs.shape[-1]

        if self._mode != 'predict':
            n_steps = inputs.shape[-2]
            steps = np.arange(n_steps, dtype=np.int32)
            # Size-1 batch axis first so the embeddings broadcast over it.
            per_step = self._get_embeddings(t=steps)[np.newaxis, :, :]
            emb = np.broadcast_to(per_step, inputs.shape[:-1] + (3, ))
        else:
            emb = self._get_embeddings(t=state)[:, np.newaxis, :]
            state = state + 1

        kept = inputs[..., :-self.num_features]
        inputs = np.concatenate([kept, emb], -1)
        if inputs.shape[-1] > depth:
            logging.warning('dropping feature(s): %d down to %d',
                            inputs.shape[-1], depth)
            inputs = inputs[..., -depth:]

        assert inputs.shape[-1] == depth, inputs.shape
        return inputs, state
Exemple #8
0
 def test_weights_state(self):
   """A two-output Fn layer has empty weights and empty state."""
   def add_and_stack(x, y):
     return x + y, jnp.concatenate([x, y], axis=0)

   layer = base.Fn('2in2out', add_and_stack, n_out=2)
   weights, state = layer.new_weights_and_state(None)
   self.assertEmpty(weights)
   self.assertEmpty(state)
Exemple #9
0
    def forward_with_state(self,
                           inputs,
                           weights=base.EMPTY_WEIGHTS,
                           state=base.EMPTY_STATE,
                           rng=None,
                           **kwargs):
        """Adds axial positional embeddings to `inputs`.

        Each per-axis embedding in `weights` is broadcast to
        (batch,) + self._shape + (its own depth,). The results are
        concatenated along the last axis into the full embedding.

        * 'predict' mode: dropout must be 0. The embedding is flattened over
          positions, the single position `state` is added, and state is
          incremented.
        * dropout == 0: the embedding is reshaped to `inputs.shape` and added.
        * otherwise: inverted dropout (kept values scaled by 1/keep_prob) is
          applied before adding. Axes in `self._dropout_broadcast_dims` share
          one mask value.

        Returns:
          (inputs + embedding, new state).
        """
        n_batch = inputs.shape[0]
        emb = np.concatenate([
            np.broadcast_to(ax_emb,
                            (n_batch, ) + self._shape + (ax_emb.shape[-1], ))
            for ax_emb in weights
        ], -1)

        if self._mode == 'predict':
            assert self._dropout == 0.0
            flat = np.reshape(emb, (n_batch, -1, emb.shape[-1]))
            return inputs + flat[:, state, :][:, None, :], state + 1

        if self._dropout == 0:
            return inputs + np.reshape(emb, inputs.shape), state

        noise_shape = list(emb.shape)
        for axis in self._dropout_broadcast_dims:
            noise_shape[axis] = 1
        keep_prob = 1.0 - self._dropout
        if math.backend_name() == 'jax':
            keep_prob = jax.lax.tie_in(
                inputs, np.full((), keep_prob, dtype=inputs.dtype))
        keep = math.random.bernoulli(rng, keep_prob, tuple(noise_shape))
        multiplier = keep.astype(inputs.dtype) / keep_prob
        return inputs + np.reshape(emb * multiplier, inputs.shape), state
Exemple #10
0
def CombineHeadsPos(x, n_heads=1, **unused_kwargs):
    """Mix x = (x0, p0, ..., xH, pH) into (x0, ...., xH), p_combined.

    The positions are averaged as vectors.

    Args:
      x: input reshaped internally as (-1, n_heads, seqlen, d_head); each
        head's last POS_VECTOR_SIZE channels hold its position vector.
      n_heads: number of heads.

    Returns:
      Pair (contents, combined_position): the per-head content channels
      concatenated along the last axis, and the mean of the per-head
      position vectors.
    """
    seqlen, d_head = x.shape[1], x.shape[2]
    # (n_batch * n_heads, seqlen, d_head) -> (n_batch, seqlen, n_heads * d_head)
    split_heads = np.reshape(x, (-1, n_heads, seqlen, d_head))
    merged = np.reshape(np.transpose(split_heads, (0, 2, 1, 3)),
                        (-1, seqlen, n_heads * d_head))

    head_size = int(d_head) - POS_VECTOR_SIZE
    stride = head_size + POS_VECTOR_SIZE
    contents = []
    positions = []
    for head in range(n_heads):
        start = head * stride
        boundary = start + head_size
        contents.append(merged[:, :, start:boundary])
        positions.append(merged[:, :, boundary:boundary + POS_VECTOR_SIZE])

    combined_position = sum(positions) / float(len(positions))
    return np.concatenate(contents, axis=-1), combined_position
Exemple #11
0
    def forward(self, inputs, weights):
        """Replaces the trailing `num_features` channels with time embeddings.

        The step counter is read from `self.state`. In 'predict' mode the
        embedding for that step is used and the counter is incremented.
        Otherwise, embeddings for steps 0..len-1 (len = inputs.shape[-2]) are
        broadcast over the leading axes of `inputs`. The counter is written
        back to `self.state` before returning.

        Args:
          inputs: activations; the size of the last axis is preserved.
          weights: unused.

        Returns:
          Array whose last-axis size equals that of `inputs`.
        """
        depth = inputs.shape[-1]
        step = self.state

        if self._mode != 'predict':
            times = jnp.arange(inputs.shape[-2], dtype=jnp.int32)
            # Size-1 batch axis first so the embeddings broadcast over it.
            emb = self._get_embeddings(t=times)[jnp.newaxis, :, :]
            emb = jnp.broadcast_to(emb, inputs.shape[:-1] + (3, ))
        else:
            emb = self._get_embeddings(t=step)[:, jnp.newaxis, :]
            step = step + 1

        combined = jnp.concatenate(
            [inputs[..., :-self.num_features], emb], -1)
        if combined.shape[-1] > depth:
            logging.warning('dropping feature(s): %d down to %d',
                            combined.shape[-1], depth)
            combined = combined[..., -depth:]

        assert combined.shape[-1] == depth, combined.shape
        self.state = step
        return combined
Exemple #12
0
  def reverse(self, output, weights=(), state=(), new_state=(), **kwargs):
    """Reassembles (x1, x2) from split section outputs.

    A bare (non-list, non-tuple) `output` is treated as a single section.
    Every section is halved along its last axis. The first halves are joined
    along `self._axis` into x1, and the second halves into x2.
    """
    del weights, kwargs
    sections = output if isinstance(output, (list, tuple)) else [output]

    firsts, seconds = [], []
    for section in sections:
      first, second = np.split(section, 2, -1)
      firsts.append(first)
      seconds.append(second)

    return (np.concatenate(firsts, self._axis),
            np.concatenate(seconds, self._axis))
Exemple #13
0
    def forward(self, inputs, weights):
        """Runs one LSTM step.

        Args:
          inputs: pair (x, lstm_state), where lstm_state is the last-axis
            concatenation of the cell state c and the hidden state h.
          weights: (w, b) of the single dense layer producing all four gates.

        Returns:
          (new_h, concatenate([new_c, new_h], axis=-1)).
        """
        x, lstm_state = inputs
        w, b = weights
        cell, hidden = jnp.split(lstm_state, 2, axis=-1)

        # One dense layer over [x, h] yields all four gate pre-activations:
        # input gate, new input, forget gate, output gate.
        pre = jnp.dot(jnp.concatenate([x, hidden], axis=-1), w) + b
        in_gate, new_input, forget_gate, out_gate = jnp.split(pre, 4, axis=-1)

        next_cell = (cell * math.sigmoid(forget_gate) +
                     math.sigmoid(in_gate) * jnp.tanh(new_input))
        next_hidden = jnp.tanh(next_cell) * math.sigmoid(out_gate)
        return next_hidden, jnp.concatenate([next_cell, next_hidden], axis=-1)
Exemple #14
0
    def hash_vectors(self, vecs, rng):
        """Assigns each vector a bucket id in every hashing round.

        For each of `self.n_hashes` rounds, `vecs` is projected with a
        Gaussian random matrix, and each vector's bucket is the argmax over
        [projection, -projection]. If `self.n_buckets` is a list or tuple
        with more than one factor, the projection columns are split into one
        sub-hash per factor, and the per-factor ids are combined in mixed
        radix.

        Args:
          vecs: 2-D array indexed as (t, f) by the einsum below; presumably
            (seqlen, depth) -- TODO confirm.
          rng: JAX PRNG key used to sample the rotations.

        Returns:
          1-D int array of length n_hashes * t. Ids from round h are offset
          by h * n_buckets, so different rounds never share a bucket id.
        """
        # See https://arxiv.org/pdf/1509.02897.pdf
        # We sample a different random rotation for each round of hashing to
        # decrease the probability of hash misses.
        if isinstance(self.n_buckets, int):
            # Each rotation column yields two buckets (its + and - side), so
            # the bucket count must be even.
            assert self.n_buckets % 2 == 0
            rot_size = self.n_buckets
            n_buckets = self.n_buckets
        else:
            # Factorize the hash if self.n_buckets is a list or tuple
            # rot_size is the sum of the (even) factors and n_buckets their
            # product.
            rot_size, n_buckets = 0, 1
            for factor in self.n_buckets:
                assert factor % 2 == 0
                rot_size += factor
                n_buckets *= factor

        # rot_size // 2 projection columns per round; each column gives 2 buckets.
        rotations_shape = (vecs.shape[-1], self.n_hashes, rot_size // 2)

        # tie_in gives rng a nominal data dependence on vecs; stop_gradient
        # keeps the key out of differentiation.
        rng = jax.lax.stop_gradient(jax.lax.tie_in(vecs, rng))
        random_rotations = jax.random.normal(rng,
                                             rotations_shape).astype('float32')
        # (t, f) x (f, h, b) -> (h, t, b): one projection per hashing round.
        rotated_vecs = np.einsum('tf,fhb->htb', vecs, random_rotations)

        if isinstance(self.n_buckets, int) or len(self.n_buckets) == 1:
            rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs],
                                          axis=-1)
            buckets = np.argmax(rotated_vecs, axis=-1)
        else:
            # Get the buckets for them and combine.
            # Factor k hashes its own factor // 2 columns; its id is weighted
            # by the product of the preceding factors (mixed radix).
            buckets, cur_sum, cur_product = None, 0, 1
            for factor in self.n_buckets:
                rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
                cur_sum += factor // 2
                rv = np.concatenate([rv, -rv], axis=-1)
                if buckets is None:
                    buckets = np.argmax(rv, axis=-1)
                else:
                    buckets += cur_product * np.argmax(rv, axis=-1)
                cur_product *= factor

        # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
        # bucket numbers from different hashing rounds don't overlap.
        offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
        offsets = np.reshape(offsets * n_buckets, (-1, 1))
        buckets = np.reshape(buckets + offsets, (-1, ))

        return buckets
 def forward(self, inp, weights):
   """Reshape input to have heads dimension and concatenate positions there.

   Args:
     inp: tuple (x, p_1, ..., p_k). x has shape (n_batch, seqlen, d_model),
       where d_model must be divisible by self._n_heads. If self._n_pos == 1,
       p_1 has shape (n_batch, seqlen, d_pos) and is shared by every head.
       Otherwise inp[1:] holds one position array per head.
     weights: unused.

   Returns:
     Array of shape (n_batch * n_heads, seqlen, d_head + d_pos), where
     d_head = d_model // n_heads.
   """
   x = inp[0]
   n_batches, seqlen = x.shape[0], x.shape[1]
   d_head = x.shape[-1] // self._n_heads
   res = np.reshape(x, (n_batches, seqlen, self._n_heads, d_head))
   res = np.transpose(res, (0, 2, 1, 3))  # (batch, heads, len, depth)
   if self._n_pos == 1:  # Just one position given, tile into each head.
     pos_shape = list(res.shape)[:-1] + [inp[1].shape[-1]]
     pos = inp[1][:, None, :, :] + np.zeros(pos_shape)  # Add 0 to broadcast.
   else:  # As many positions as heads, concatenate them in.
     pos = [p[:, None, :, :] for p in inp[1:]]
     pos = np.concatenate(pos, axis=1)
   res = np.concatenate([res, pos], axis=-1)
   # (n_batch, n_heads, seqlen, d_head + d_pos) -> (n_batch*n_heads, seqlen, .)
   # Use the actual concatenated width rather than assuming
   # d_pos == POS_VECTOR_SIZE. With -1 in the target shape, a mismatched width
   # could otherwise be silently re-chunked instead of raising.
   res = np.reshape(res, (-1, seqlen, res.shape[-1]))
   return res
Exemple #16
0
    def forward(self, inputs, weights):
        """Splits x1 and x2 into sections and pairs them up.

        Both inputs are split into `self._n_sections` equal parts along
        `self._axis`. The k-th parts of x1 and x2 are concatenated along the
        last axis.

        Returns:
          Tuple of `self._n_sections` arrays.
        """
        del weights
        x1, x2 = inputs
        pieces_1 = np.split(x1, self._n_sections, self._axis)
        pieces_2 = np.split(x2, self._n_sections, self._axis)
        paired = []
        for left, right in zip(pieces_1, pieces_2):
            paired.append(np.concatenate((left, right), -1))
        return tuple(paired)
Exemple #17
0
 def _UpdateRow(x):
     # row_e - (L1, H), row_d - (L2, H), row_mask_e - (L1,)
     row_e, row_d, row_mask_e = x
     # final_row - (L1+L2, H)
     final_row = jnp.concatenate([row_e, jnp.zeros_like(row_d)], axis=0)
     # Find the last real token/vector of the encoder.
     e_idx = jnp.sum(row_mask_e, dtype=jnp.int32)
     # Starting after that index, update with the decoder row.
     return jax.lax.dynamic_update_slice(final_row, row_d, (e_idx, 0))
Exemple #18
0
 def test_fn_layer_weights_state(self):
     """Fn layers report weights and state that are non-None and empty."""
     def sum_and_concat(x, y):
         return x + y, np.concatenate([x, y], axis=0)

     layer = Fn('2in2out', sum_and_concat, n_out=2)
     weights, state = layer.new_weights_and_state(None)
     for value in (weights, state):
         self.assertIsNotNone(value)
     for value in (weights, state):
         self.assertEmpty(value)
Exemple #19
0
    def forward_with_state(self,
                           inputs,
                           weights=base.EMPTY_WEIGHTS,
                           state=base.EMPTY_STATE,
                           rng=None,
                           **kwargs):
        """Adds axial positional embeddings to `inputs`.

        Every per-axis embedding in `weights` is broadcast to
        (batch,) + self._shape + (its own depth,).

        * 'predict': requires dropout == 0. The concatenated embedding is
          flattened over positions, and the window of inputs.shape[1]
          positions starting at `state` is added. State advances by
          inputs.shape[1].
        * dropout == 0: each per-axis piece is reshaped to the leading shape
          of `inputs` separately, and the pieces are concatenated.
        * otherwise: inverted dropout (kept values scaled by 1/keep_prob).
          Axes in `self._dropout_broadcast_dims` share one mask value.

        Returns:
          (inputs + embedding, new state).
        """
        n_batch = inputs.shape[0]
        embs = [
            np.broadcast_to(ax_emb,
                            (n_batch, ) + self._shape + (ax_emb.shape[-1], ))
            for ax_emb in weights
        ]

        if self._mode == 'predict':
            assert self._dropout == 0.0
            full = np.concatenate(embs, -1)
            full = np.reshape(full, (n_batch, -1, full.shape[-1]))
            n_new = inputs.shape[1]
            window = jax.lax.dynamic_slice_in_dim(full, state, n_new, axis=1)
            return inputs + window, state + n_new

        if self._dropout == 0:
            # TODO(kitaev): concat-then-reshape (as is the case with dropout
            # enabled) leads to memory blow-up on TPU, so each per-axis piece
            # is reshaped first and concatenated afterwards.
            lead_shape = inputs.shape[:-1]
            pieces = [np.reshape(e, lead_shape + (e.shape[-1], )) for e in embs]
            return inputs + np.concatenate(pieces, -1), state

        full = np.concatenate(embs, -1)
        noise_shape = list(full.shape)
        for dim in self._dropout_broadcast_dims:
            noise_shape[dim] = 1
        keep_prob = 1.0 - self._dropout
        if math.backend_name() == 'jax':
            keep_prob = jax.lax.tie_in(
                inputs, np.full((), keep_prob, dtype=inputs.dtype))
        keep = math.random.bernoulli(rng, keep_prob, tuple(noise_shape))
        multiplier = keep.astype(inputs.dtype) / keep_prob
        return inputs + np.reshape(full * multiplier, inputs.shape), state
Exemple #20
0
def QueryPositionKV(x, keys=None, values=None, binary=False, **unused_kwargs):
    """Query a table with a position vector.

    Without `keys`, `x` is returned unchanged. Otherwise `x` is used as the
    query of a dot-product attention over the keys/values table. When
    `binary` is set, `x` is first duplicated along its last axis.
    """
    if keys is None:
        return x
    query = np.concatenate([x, x], axis=-1) if binary else x
    return tl.DotProductAttention(query, np.array(keys), np.array(values),
                                  None, 0.0, None, None)
Exemple #21
0
 def test_fn_layer_example(self):
     """Fn infers two outputs; checks output shapes and concrete values."""
     layer = base.Fn(lambda x, y: (x + y, np.concatenate([x, y], axis=0)))
     signature = (ShapeDtype((2, 7)), ShapeDtype((2, 7)))
     self.assertEqual(base.check_shape_agreement(layer, signature),
                      ((2, 7), (4, 7)))
     total, stacked = layer((np.array([2]), np.array([3])))
     self.assertEqual(int(total), 5)
     self.assertEqual(list(map(int, stacked)), [2, 3])
Exemple #22
0
 def test_fn_layer_difficult_n_out(self):
     """Fn rejects a layer whose n_out it cannot infer unless n_out is given."""
     with self.assertRaisesRegex(ValueError, 'n_out'):
         # Determining the output of this layer is hard with dummies.
         # Spelled correctly (not `np.concatencate`) so the failure comes from
         # concatenating Fn's dummy inputs on axis 4, not from a missing
         # attribute that would fail for any input.
         base.Fn(lambda x: np.concatenate([x, x], axis=4))
     # Check that this layer works when n_out is set.
     layer = base.Fn(lambda x: np.concatenate([x, x], axis=4), n_out=1)
     input_signature = ShapeDtype((2, 1, 2, 2, 3))
     expected_shape = (2, 1, 2, 2, 6)
     output_shape = base.check_shape_agreement(layer, input_signature)
     self.assertEqual(output_shape, expected_shape)
Exemple #23
0
    def _get_embeddings(self, lo: int, hi: int, depth, rng=None):
        """Get embeddings float[length, depth].

        Builds a cosine embedding for every time step t in [lo, hi). Channel c
        uses frequency index c // 2. Odd channels are shifted by a quarter
        cycle, so even channels are cos and odd channels are -sin. Noise from
        `self._get_noise` is added to the phase of each frequency.

        Args:
          lo: where to start sampling (first time step, inclusive)
          hi: where to stop sampling (exclusive)
          depth: embedding depth
          rng: rng for random phase; must be given when `self._affine` is set

        Returns:
          embeddings: float[hi - lo, depth]
        """
        # NOTE(review): the indexing below assumes _get_noise returns an array
        # of shape (hi - lo, (depth + 1) // 2) -- TODO confirm.
        noise = self._get_noise(lo, hi, (depth + 1) // 2)
        # Make the stddev around 1 after 1/drift.
        noise = noise * self._drift**.5

        # t[i, j] = lo + i and c[i, j] = j, both of shape (hi - lo, depth).
        t, c = onp.mgrid[lo:hi, :depth]
        # Make even channels cos, odd channels sin:
        c_div_2, c_mod_2 = divmod(c, 2)
        # Off-by-one correction for odd depth:
        drift = self._drift
        if depth > 2:
            drift = drift**(((depth + 1) // 2) / (depth // 2))
        # Spend roughly half the frequencies on noise:
        freq = np.geomspace(.5, .5 * drift**2, num=(depth + 1) // 2)[c_div_2]
        # Phase in cycles: +1/4 on odd channels (cos -> -sin), the frequency
        # times t, and a quarter of the noise for that frequency index.
        cycles = c_mod_2 / 4 + freq * t + noise[:, c_div_2[0, :]] / 4
        assert cycles.shape == (hi - lo, depth), cycles.shape

        # Get random phases:
        if self._affine:
            assert rng is not None
            cycles = cycles + trax.math.random.uniform(
                rng, (
                    1,
                    depth,
                ), minval=0, maxval=1)

        # Convert from cycles to radians:
        embeddings = np.cos(np.pi * 2 * cycles)

        # Set the last channels to the time bin features:
        if self._time_bin_length is not None:
            # The last 3 channels become, per time step: 1 / (1 + bin index),
            # the fractional position within the bin, and the bin-index parity.
            inter_bin_idx, intra_bin_idx = divmod(t[:, -1:],
                                                  self._time_bin_length)
            bin_parity = inter_bin_idx % 2
            bin_fraction = intra_bin_idx / self._time_bin_length
            embeddings = np.concatenate([
                embeddings[:, :-3],
                1 / (1 + inter_bin_idx),
                bin_fraction,
                bin_parity.astype(np.float32),
            ], -1)

        assert embeddings.shape == (hi - lo, depth), embeddings.shape
        return embeddings
Exemple #24
0
def look_adjacent(x, n_chunks_before, n_chunks_after):
    """Used to implement attention between consecutive chunks.

    Row k of the result concatenates, along axis 1, chunks
    k - n_chunks_before, ..., k, ..., k + n_chunks_after of `x`. Chunk
    indices wrap around cyclically when |offset| < n_chunks.

    Args:
      x: array of shape [n_chunks, chunk_len, ...]
      n_chunks_before: Number of previous chunks to attend to.
      n_chunks_after: Number of subsequent chunks to attend to.

    Returns:
      array of shape [n_chunks, N * chunk_len, ...], where
      N = (1 + n_chunks_before + n_chunks_after). When both counts are 0,
      `x` itself is returned.
    """
    if n_chunks_before == 0 and n_chunks_after == 0:
        return x

    def shifted(offset):
        # Rotate along the chunk axis so that row k holds chunk k + offset.
        if offset == 0:
            return x
        return np.concatenate([x[offset:, ...], x[:offset, ...]], axis=0)

    offsets = range(-n_chunks_before, n_chunks_after + 1)
    return np.concatenate([shifted(offset) for offset in offsets], axis=1)
Exemple #25
0
 def f(x):  # pylint: disable=invalid-name
     """Shifts the 1st and 3rd thirds of the channels along the length axis.

     x : [batch, 1, length, depth], with depth divisible by 3. The output has
     the same shape. At length position i, the first third of the channels
     comes from position i - 1, the middle third from i, and the last third
     from i + 1. Positions past the edges read zeros.
     """
     padded = np.pad(x, [(0, 0), (0, 0), (1, 1), (0, 0)],
                     mode='constant',
                     constant_values=0.0)
     third = padded.shape[-1] // 3
     assert 3 * third == padded.shape[-1], ('Depth must be divisible by 3',
                                            third, padded.shape)
     parts = []
     for k, (lo, hi) in enumerate(((0, -2), (1, -1), (2, None))):
         parts.append(padded[:, :, lo:hi, k * third:(k + 1) * third])
     return np.concatenate(parts, axis=3)
Exemple #26
0
def DiagonalGate(x):
  """Split channels in 3 parts. Shifts 1st and 3rd sections to left/right.

  Args:
    x: array of shape [batch, 1, length, depth] with depth divisible by 3.

  Returns:
    Array of the same shape. Along the length axis, the first third of the
    channels is read from position i - 1, the middle third from position i,
    and the last third from position i + 1. Out-of-range positions are 0.
  """
  padded = np.pad(
      x, [(0, 0), (0, 0), (1, 1), (0, 0)], mode='constant', constant_values=0.0)
  depth = padded.shape[-1] // 3
  assert 3 * depth == padded.shape[-1], ('Depth must be divisible by 3', depth,
                                         padded.shape)
  left, middle, right = (padded[..., k * depth:(k + 1) * depth]
                         for k in range(3))
  return np.concatenate(
      [left[:, :, :-2], middle[:, :, 1:-1], right[:, :, 2:]], axis=3)
Exemple #27
0
def Deinterleave(inputs, x_size, y_size, **unused_kwargs):
    """Inverse of Interleave.

    Args:
      inputs: array of shape (B, L, ...) holding a flattened interleaving of
        x elements (x_size entries each) and y elements (y_size entries
        each), optionally followed by one trailing x element.
      x_size: number of entries per x element.
      y_size: number of entries per y element.

    Returns:
      Pair (x_reprs, y_reprs) with shapes (B, n_x, x_size, ...) and
      (B, n_y, y_size, ...), where n_x is n_y, or n_y + 1 when a trailing
      x element is present.
    """
    reprs = inputs
    (batch_size, length) = reprs.shape[:2]
    shape_suffix = reprs.shape[2:]
    remainder_length = length % (x_size + y_size)
    # Only split off a remainder when there is one. With remainder_length == 0
    # the slices `[-0:]` and `[:-0]` select everything and nothing
    # respectively, which would discard every complete (x, y) pair.
    if remainder_length > 0:
        remainder = reprs[:, None, -remainder_length:]
        reprs = reprs[:, :-remainder_length]
    reprs = np.reshape(reprs, (batch_size, -1, x_size + y_size) + shape_suffix)
    x_reprs = reprs[:, :, :x_size]
    y_reprs = reprs[:, :, x_size:]
    if remainder_length > 0:
        x_reprs = np.concatenate((x_reprs, remainder), axis=1)
    return (x_reprs, y_reprs)
Exemple #28
0
  def F(x):
    """Returns `x != 0`, with the first `n_shifts` positions on axis 1 set True.

    `mode` and `n_shifts` are free variables from an enclosing scope that is
    not shown here.

    Raises:
      ValueError: in 'predict' mode, which is not implemented.
    """
    # TODO(afrozm): What to do in this case?
    if mode == 'predict':
      raise ValueError('MaskOfRightShiftedArray not implemented for predict.')

    nonzero = x != 0
    if n_shifts == 0:
      return nonzero

    # All-True prefix of shape (B, n_shifts, ...).
    prefix = jnp.full((x.shape[0], n_shifts) + nonzero.shape[2:], True)
    return jnp.concatenate([prefix, nonzero[:, n_shifts:, ...]], axis=1)
Exemple #29
0
  def F(vec_e, vec_d, mask_e, mask_d):
    """Concatenates encoder and decoder rows and moves padding to the end.

    Real positions get distinct negative sort keys that follow their original
    order (encoder before decoder). Padding positions get key 0, so sorting
    along axis 1 packs all real vectors to the left.

    Args:
      vec_e: encoder vectors; concatenated with vec_d into (B, L1+L2, H).
      vec_d: decoder vectors.
      mask_e: (B, L1) encoder mask.
      mask_d: (B, L2) decoder mask.

    Returns:
      (B, L1 + L2, H) array of the concatenated vectors, sorted by key.
    """
    # pylint: disable=invalid-name
    L1, L2 = mask_e.shape[1], mask_d.shape[1]
    # pylint: enable=invalid-name

    keys = jnp.concatenate([
        jnp.arange(-(L1 + L2), -L2) * mask_e,  # [-(L1+L2), -L2), pads -> 0
        jnp.arange(-L2, 0) * mask_d,  # [-L2, 0), pads -> 0
    ], axis=1)
    values = jnp.concatenate([vec_e, vec_d], axis=1)
    # Broadcast the (B, L1+L2) keys over the feature axis to match values.
    keys = keys[..., jnp.newaxis] + jnp.zeros_like(values, dtype=jnp.int32)
    _, packed = math.sort_key_val(keys, values, 1)
    return packed
Exemple #30
0
 def deinterleave(inputs):
     """Splits a flattened interleaving back into (x_reprs, y_reprs).

     `x_size` and `y_size` are free variables from an enclosing scope that is
     not shown here. A trailing partial group of length % (x_size + y_size)
     entries is appended to the x sequence as one extra element.
     """
     batch_size, length = inputs.shape[:2]
     suffix = inputs.shape[2:]
     group = x_size + y_size
     tail = length % group
     body = inputs
     if tail > 0:
         extra_x = inputs[:, None, -tail:]
         body = inputs[:, :-tail]
     grouped = jnp.reshape(body, (batch_size, -1, group) + suffix)
     x_reprs, y_reprs = grouped[:, :, :x_size], grouped[:, :, x_size:]
     if tail > 0:
         x_reprs = jnp.concatenate((x_reprs, extra_x), axis=1)
     return (x_reprs, y_reprs)