def _StripFromConcatenateWithPadding(vec_ed, tok_e, tok_d):
  """Strip concatenate with padding: see the layer below for details."""
  # pylint: disable=invalid-name
  B, L, H = vec_ed.shape
  L1, L2 = tok_e.shape[1], tok_d.shape[1]
  # pylint: enable=invalid-name
  if L1 + L2 != L:
    raise ValueError(
        f'Length from encoder-decoder vectors ({L}) does not'
        f' equal sum of lengths from encoder ({L1}) and decoder'
        f' ({L2}).')
  # Both token arrays must be (batch, length); validate each in turn so the
  # error message names the offending side.
  for side, tok, length in (('encoder', tok_e, L1), ('decoder', tok_d, L2)):
    if tok.shape != (B, length):
      raise ValueError(f'Shape of {side} tokens, {tok.shape}, does not'
                       f' equal {(B, length)}.')

  def _SliceOutDecoderPart(row):
    # row is a triple of per-example slices: (L, H), (L1, H) and (L2, H).
    row_ed, row_e, _ = row
    # Count the real (nonzero) encoder tokens; the decoder part of `row_ed`
    # starts right after them.
    n_enc = jnp.sum(row_e != 0, dtype=jnp.int32)
    start_col = jnp.array(0, dtype=n_enc.dtype)  # avoid int32/int64 mismatch
    return fastmath.dynamic_slice(row_ed, (n_enc, start_col), (L2, H))

  return fastmath.map(_SliceOutDecoderPart, [vec_ed, tok_e, tok_d])
def _ConcatWithPadding(vec_e, vec_d, mask_e):
  """Concatenate with padding: see the ConcatWithPadding layer for details."""
  # pylint: disable=invalid-name
  B, L1, H = vec_e.shape
  L2 = vec_d.shape[1]
  # pylint: enable=invalid-name
  if vec_d.shape != (B, L2, H):
    raise ValueError(f'Shape of decoder vector, {vec_d.shape}, does not'
                     f' equal {(B, L2, H)}.')
  if mask_e.shape != (B, L1):
    raise ValueError(f'Shape of encoder mask, {mask_e.shape}, does not'
                     f' equal {(B, L1)}.')

  def _PasteDecoderRow(row):
    # row is a triple of per-example slices: (L1, H), (L2, H) and (L1,).
    enc_row, dec_row, enc_mask = row
    # Lay out the encoder part followed by an all-zero decoder-sized region,
    # giving an (L1+L2, H) row.
    padded = jnp.concatenate([enc_row, jnp.zeros_like(dec_row)], axis=0)
    # The mask sum is the index just past the last real encoder vector.
    n_enc = jnp.sum(enc_mask, dtype=jnp.int32)
    start_col = jnp.array(0, dtype=n_enc.dtype)  # avoid int32/int64 mismatch
    # Overwrite the zero region (starting at n_enc) with the decoder row.
    return fastmath.dynamic_update_slice(padded, dec_row, (n_enc, start_col))

  return fastmath.map(_PasteDecoderRow, [vec_e, vec_d, mask_e])