Example 1
# Shared setup (an assumption: these helpers match the ones in Google's Trax
# library, whose fastmath module wraps jax.numpy and lax primitives).
from trax import fastmath
from trax.fastmath import numpy as jnp
def _StripFromConcatenateWithPadding(vec_ed, tok_e, tok_d):
    """Strip concatenate with padding: see the layer below for details."""
    # pylint: disable=invalid-name
    B, L, H = vec_ed.shape
    L1 = tok_e.shape[1]
    L2 = tok_d.shape[1]
    # pylint: enable=invalid-name
    if L != L1 + L2:
        raise ValueError(
            f'Length from encoder-decoder vectors ({L}) does not'
            f' equal sum of lengths from encoder ({L1}) and decoder'
            f' ({L2}).')
    if tok_e.shape != (B, L1):
        raise ValueError(f'Shape of encoder tokens, {tok_e.shape}, does not'
                         f' equal {(B, L1)}.')
    if tok_d.shape != (B, L2):
        raise ValueError(f'Shape of decoder tokens, {tok_d.shape}, does not'
                         f' equal {(B, L2)}.')

    def _UpdateRow(x):
        # row_ed - (L, H), row_e (token ids) - (L1,), row_d (unused) - (L2,)
        row_ed, row_e, _ = x
        mask_e = row_e != 0
        len_e = jnp.sum(mask_e, dtype=jnp.int32)
        # In `row_ed`, the decoder part starts where the encoder tokens/vecs
        # end, i.e. at index `len_e`; pick up the (L2, H) tensor slice there.
        zero = jnp.array(0, dtype=len_e.dtype)  # avoid int32/int64 mismatch
        return fastmath.dynamic_slice(row_ed, (len_e, zero), (L2, H))

    return fastmath.map(_UpdateRow, [vec_ed, tok_e, tok_d])
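
To make this concrete, here is a minimal usage sketch. The toy sizes and token values are illustrative assumptions, not from the source; zeros in tok_e mark padding, so the row-wise count of non-zero encoder tokens tells the function where the decoder slice starts.

import numpy as np

# Hypothetical toy sizes: B=1, L=5, L1=3, L2=2, H=2.
tok_e = np.array([[7, 8, 0]], dtype=np.int32)  # last encoder position is padding
tok_d = np.array([[1, 2]], dtype=np.int32)
# vec_ed holds the decoder vectors starting right after the last real encoder
# token, i.e. at index len_e = 2, so rows 2..3 carry the decoder part.
vec_ed = np.arange(1 * 5 * 2, dtype=np.float32).reshape(1, 5, 2)
vec_d = _StripFromConcatenateWithPadding(vec_ed, tok_e, tok_d)
print(vec_d.shape)  # (1, 2, 2)
print(vec_d[0])     # rows 2 and 3 of vec_ed[0]: [[4. 5.], [6. 7.]]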
Example 2
def _ConcatWithPadding(vec_e, vec_d, mask_e):
    """Concatenate with padding: see the ConcatWithPadding layer for details."""
    # pylint: disable=invalid-name
    B, L1, H = vec_e.shape
    L2 = vec_d.shape[1]
    # pylint: enable=invalid-name

    if vec_d.shape != (B, L2, H):
        raise ValueError(f'Shape of decoder vector, {vec_d.shape}, does not'
                         f' equal {(B, L2, H)}.')
    if mask_e.shape != (B, L1):
        raise ValueError(f'Shape of encoder mask, {mask_e.shape}, does not'
                         f' equal {(B, L1)}.')

    def _UpdateRow(x):
        # row_e - (L1, H), row_d - (L2, H), row_mask_e - (L1,)
        row_e, row_d, row_mask_e = x
        # final_row - (L1+L2, H)
        final_row = jnp.concatenate([row_e, jnp.zeros_like(row_d)], axis=0)
        # Find the last real token/vector of the encoder.
        e_idx = jnp.sum(row_mask_e, dtype=jnp.int32)
        # Starting after that index, update with the decoder row.
        zero = jnp.array(0, dtype=e_idx.dtype)  # avoid int32/int64 mismatch
        return fastmath.dynamic_update_slice(final_row, row_d, (e_idx, zero))

    return fastmath.map(_UpdateRow, [vec_e, vec_d, mask_e])
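
The two helpers are inverses under the usual convention that real encoder tokens precede padding. Below is a hedged round-trip sketch (batch size, lengths, and values are illustrative assumptions): concatenating with _ConcatWithPadding and then stripping with _StripFromConcatenateWithPadding recovers the decoder vectors exactly.

import numpy as np

# Hypothetical toy sizes: B=2, L1=4, L2=3, H=5.
vec_e = np.random.rand(2, 4, 5).astype(np.float32)
vec_d = np.random.rand(2, 3, 5).astype(np.float32)
tok_e = np.array([[5, 6, 7, 0],    # row 0: three real tokens, one pad
                  [8, 9, 0, 0]],   # row 1: two real tokens, two pads
                 dtype=np.int32)
tok_d = np.array([[1, 2, 3],
                  [4, 5, 6]], dtype=np.int32)
mask_e = (tok_e != 0).astype(np.float32)

vec_ed = _ConcatWithPadding(vec_e, vec_d, mask_e)  # (2, 7, 5)
recovered = _StripFromConcatenateWithPadding(vec_ed, tok_e, tok_d)
assert np.allclose(recovered, vec_d)  # stripping undoes the padded concatenation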