    def forward(
        self,
        dec_input,
        dec_attn_mask,
        enc_output,
        enc_attn_mask,
        layer_past=None,
        get_key_value=False,
    ):
        # convert to Megatron mask
        dec_attn_mask_3d = build_attention_mask_3d(
            source_mask=dec_attn_mask,
            target_mask=dec_attn_mask,
            attn_mask_type=self.model_attn_mask_type,
        )
        enc_dec_attn_mask_3d = build_attention_mask_3d(
            source_mask=dec_attn_mask,
            target_mask=enc_attn_mask,
            attn_mask_type=AttnMaskType.padding,
        )

        # transformer decoder
        dec_output = self.model(
            dec_input,
            attn_mask_postprocess(dec_attn_mask_3d),
            layer_past=layer_past,
            get_key_value=get_key_value,
            encoder_output=enc_output,
            enc_dec_attn_mask=attn_mask_postprocess(enc_dec_attn_mask_3d),
        )

        return dec_output
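build_attention_mask_3d returns a [batch, source_len, target_len] boolean mask. A minimal sketch of the padding and causal variants, assuming Megatron's convention that True marks positions to be dropped; the helper names below are illustrative, not NeMo's internals:

import torch

def padding_mask_3d(source_mask, target_mask):
    # source_mask: [b, sq], target_mask: [b, sk]; True = real (non-padding) token
    keep = source_mask[:, :, None] & target_mask[:, None, :]   # [b, sq, sk]
    return ~keep                                               # True = masked out

def causal_mask_3d(source_mask, target_mask):
    keep = source_mask[:, :, None] & target_mask[:, None, :]
    causal = torch.tril(torch.ones(source_mask.size(1), target_mask.size(1))).bool()
    return ~(keep & causal)

b, sq, sk = 2, 4, 6
dec_mask = torch.rand(b, sq) > 0.2
enc_mask = torch.rand(b, sk) > 0.2
assert padding_mask_3d(dec_mask, enc_mask).shape == (b, sq, sk)
assert causal_mask_3d(dec_mask, dec_mask).shape == (b, sq, sq)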
Example #2
    def forward(
        self,
        enc_input,
        enc_attn_mask,
        context_attn_mask=None,
        encoder_output=None,
        layer_past=None,
        get_key_value=False,
    ):
        # expected enc_input shape [batch, num_chunks, num_neighbors, retrieval_seq_len, dim]
        # expected enc_attn_mask shape [batch, num_chunks, num_neighbors, retrieval_seq_len]
        # expected encoder_output shape [batch, seq_len, dim]
        b, k, r, rn, dim = enc_input.shape

        # batch, seq_len, dim
        _, n, _ = encoder_output.shape

        num_seq_chunks = n // self.chunk_size
        assert k == num_seq_chunks, f'sequence requires {num_seq_chunks} retrieved chunks, but only {k} passed in'

        seq_index = num_seq_chunks * self.chunk_size

        retrieved = rearrange(enc_input, 'b k r n d -> (b k r) n d')
        enc_attn_mask = rearrange(enc_attn_mask, 'b k r n -> (b k r) n')
        embed_as_context = repeat(encoder_output[:, :seq_index], 'b (k n) d -> (b k r) n d', n=self.chunk_size, r=r)
        context_attn_mask = repeat(context_attn_mask[:, :seq_index], 'b (k n) -> (b k r) n', n=self.chunk_size, r=r)

        # need to add extra chunk size, since it will be shifted
        cross_attn_q_pos_emb = self.rotary_pos_emb(rn, offset=0)
        cross_attn_k_pos_emb = self.rotary_pos_emb(self.chunk_size)
        attn_pos_emb = (cross_attn_q_pos_emb, cross_attn_q_pos_emb, cross_attn_k_pos_emb)

        # convert to Megatron mask
        enc_attn_mask_3d = build_attention_mask_3d(
            source_mask=enc_attn_mask, target_mask=enc_attn_mask, attn_mask_type=self.model_attn_mask_type,
        )
        enc_attn_mask_3d = enc_attn_mask_3d[:, None, :, :]

        enc_dec_attn_mask_3d = build_attention_mask_3d(
            source_mask=enc_attn_mask, target_mask=context_attn_mask, attn_mask_type=AttnMaskType.padding,
        )
        enc_dec_attn_mask_3d = enc_dec_attn_mask_3d[:, None, :, :]

        # transformer encoder
        enc_output = self.model(
            retrieved,
            enc_attn_mask_3d,
            layer_past=layer_past,
            get_key_value=get_key_value,
            encoder_output=embed_as_context,
            enc_dec_attn_mask=enc_dec_attn_mask_3d,
            rotary_pos_emb=attn_pos_emb,
        )
        # revert back to original retrieved shape
        enc_output = rearrange(enc_output, '(b k r) n d -> b k r n d', b=b, k=k)
        return enc_output
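The rearrange/repeat calls above fan the decoder hidden states out so that every retrieved neighbor of a chunk cross-attends to that chunk's slice of the sequence. A standalone shape check with arbitrary illustrative sizes:

import torch
from einops import rearrange, repeat

b, k, r, rn, d, chunk = 2, 3, 2, 8, 16, 4       # illustrative sizes only
enc_input = torch.randn(b, k, r, rn, d)         # retrieved neighbors per chunk
encoder_output = torch.randn(b, k * chunk, d)   # decoder hidden states

retrieved = rearrange(enc_input, 'b k r n d -> (b k r) n d')
embed_as_context = repeat(encoder_output, 'b (k n) d -> (b k r) n d', n=chunk, r=r)

assert retrieved.shape == (b * k * r, rn, d)
assert embed_as_context.shape == (b * k * r, chunk, d)  # each neighbor sees only its own chunk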
Example #3
    def forward(
        self,
        dec_input,
        dec_attn_mask,
        retrieved_attn_mask=None,
        retrieved_emb=None,
        layer_past=None,
        get_key_value=False,
        eod_positions=None,  # this is a tuple of eod positions returned from tensor.where(tensor == eod_id)
    ):
        # expected dec_input shape [batch, seq_len, dim]
        # expected dec_attn_mask shape [batch, seq_len]
        # expected retrieved_input shape [batch, num_chunks, num_neighbors, retrieval_seq_len, dim]
        # expected retrieved_attn_mask shape [batch, num_chunks, num_neighbors, retrieval_seq_len]

        # batch, seq_len, dim
        _, n, _ = dec_input.shape

        num_seq_chunks = n // self.chunk_size

        if retrieved_emb is not None:
            b, k, r, rn, dim = retrieved_emb.shape
            assert (
                k == num_seq_chunks
            ), f'sequence requires {num_seq_chunks} retrieved chunks, but only {k} passed in'  # need to add extra chunk size, since it will be shifted
        self_attn_emb = self.rotary_pos_emb(n)

        if retrieved_emb is not None:
            cross_attn_q_pos_emb = self.rotary_pos_emb(self.chunk_size * 2 - 1)
            cross_attn_k_pos_emb = self.rotary_pos_emb(rn, offset=0)
            attn_pos_emb = (self_attn_emb, cross_attn_q_pos_emb, cross_attn_k_pos_emb)
        else:
            attn_pos_emb = (self_attn_emb, None, None)

        dec_attn_mask_3d = self._calculate_dec_att_mask(dec_attn_mask, eod_positions)

        if retrieved_emb is not None:
            dec_attn_mask = rearrange(dec_attn_mask, 'b (k n) -> (b k) n', k=k)
            retrieved_attn_mask = rearrange(retrieved_attn_mask, 'b k r n -> (b k) (r n)')

            enc_dec_attn_mask_3d = build_attention_mask_3d(
                source_mask=dec_attn_mask, target_mask=retrieved_attn_mask, attn_mask_type=AttnMaskType.padding,
            )
            enc_dec_attn_mask_3d = enc_dec_attn_mask_3d[:, None, :, :]
        else:
            enc_dec_attn_mask_3d = None

        # transformer encoder
        enc_output = self.model(
            dec_input,
            dec_attn_mask_3d,
            layer_past=layer_past,
            get_key_value=get_key_value,
            encoder_output=None,
            retrieved_emb=retrieved_emb,
            enc_dec_attn_mask=enc_dec_attn_mask_3d,
            rotary_pos_emb=attn_pos_emb,
        )

        return enc_output
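To make the per-chunk mask shapes concrete, here is a small shape walkthrough that mirrors the two rearrange calls in the retrieval branch (sizes are arbitrary; build_attention_mask_3d itself is not exercised):

import torch
from einops import rearrange

b, k, r, rn, chunk = 2, 3, 2, 8, 4
dec_attn_mask = torch.ones(b, k * chunk, dtype=torch.bool)
retrieved_attn_mask = torch.ones(b, k, r, rn, dtype=torch.bool)

dec_per_chunk = rearrange(dec_attn_mask, 'b (k n) -> (b k) n', k=k)        # [b*k, chunk]
ret_per_chunk = rearrange(retrieved_attn_mask, 'b k r n -> (b k) (r n)')   # [b*k, r*rn]
assert dec_per_chunk.shape == (b * k, chunk)
assert ret_per_chunk.shape == (b * k, r * rn)
# the padding mask built from these is [b*k, chunk, r*rn], then unsqueezed to [b*k, 1, chunk, r*rn]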
Example #4
    def forward(
        self,
        enc_input,
        enc_attn_mask,
        layer_past=None,
        get_key_value=False,
    ):
        # convert to Megatron mask
        enc_attn_mask_3d = build_attention_mask_3d(
            source_mask=enc_attn_mask,
            target_mask=enc_attn_mask,
            attn_mask_type=self.model_attn_mask_type,
        )

        # transformer encoder
        enc_output = self.model(
            enc_input,
            attn_mask_postprocess(enc_attn_mask_3d),
            layer_past=layer_past,
            get_key_value=get_key_value,
        )
        # we copy input mask for transformer
        enc_output_mask = enc_attn_mask

        return enc_output, enc_output_mask
Example #5
        def get_attn_mask_3d(hidden_mask, context_mask, chunks):
            causal_padding = text_chunk_size - 1
            reminder = (text_chunk_size -
                        (hidden_mask.shape[0] + 1)) % text_chunk_size
            hidden_mask = F.pad(hidden_mask, (0, 0, -causal_padding, reminder),
                                value=False)

            dec_attn_mask = rearrange(hidden_mask,
                                      '(k n) b -> (b k) n',
                                      k=chunks)
            context_attn_mask = rearrange(context_mask,
                                          'k r n b -> (b k) (r n)')
            enc_dec_attn_mask_3d = build_attention_mask_3d(
                source_mask=dec_attn_mask,
                target_mask=context_attn_mask,
                attn_mask_type=AttnMaskType.padding,
            )
            enc_dec_attn_mask_3d = enc_dec_attn_mask_3d[:, None, :, :]
            return enc_dec_attn_mask_3d
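F.pad with a negative pad width trims elements instead of adding them, which is how the first causal_padding rows of hidden_mask are dropped above. A minimal demonstration of that behaviour:

import torch
import torch.nn.functional as F

x = torch.arange(10).reshape(10, 1)      # stand-in for a [seq, batch] mask
trimmed = F.pad(x, (0, 0, -3, 2))        # drop 3 rows at the front, pad 2 rows of zeros at the end
assert trimmed.shape == (10 - 3 + 2, 1)
assert trimmed[0, 0] == 3                # the first three rows were removed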
Example #6
    def _calculate_dec_att_mask(self, dec_attn_mask, eod_positions):
        # convert to Megatron mask
        dec_attn_mask_3d = build_attention_mask_3d(
            source_mask=dec_attn_mask, target_mask=dec_attn_mask, attn_mask_type=self.model_attn_mask_type,
        )
        if eod_positions is not None:
            # to mask out the token ids [id, id,  eod, id, pad, eod, id, id]
            # so attention is not across eod, mask should be:
            # [false, true,  true,  true,  true, true,  true,  true]
            # [false, false, true,  true,  true, true,  true,  true]
            # [false, false, false, true,  true, true,  true,  true]
            # [true,  true,  true,  false, true, true,  true,  true]
            # [true,  true,  true,  true,  true, true,  true,  true]
            # [true,  true,  true,  false, true, false, true,  true]
            # [true,  true,  true,  true,  true, true,  false, true]
            # [true,  true,  true,  true,  true, true,  false, false]
            for batch, eod_pos in zip(*eod_positions):
                eod_plus_one = eod_pos.item() + 1
                dec_attn_mask_3d[batch][eod_plus_one:, :eod_plus_one] = True
        dec_attn_mask_3d = dec_attn_mask_3d[:, None, :, :]
        return dec_attn_mask_3d
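A self-contained toy version of the EOD blocking step, using the same True-means-masked convention but building the causal mask directly instead of calling build_attention_mask_3d:

import torch

eod_id = 2
tokens = torch.tensor([[5, 6, eod_id, 7, 8, 9]])    # [batch=1, seq=6]
seq = tokens.size(1)

# causal mask, True = masked (upper triangle blocked)
mask_3d = ~torch.tril(torch.ones(seq, seq)).bool()
mask_3d = mask_3d.unsqueeze(0).repeat(tokens.size(0), 1, 1)

eod_positions = torch.where(tokens == eod_id)
for batch, eod_pos in zip(*eod_positions):
    eod_plus_one = eod_pos.item() + 1
    mask_3d[batch][eod_plus_one:, :eod_plus_one] = True   # block attention across the EOD boundary

assert bool(mask_3d[0, 3, 1])   # a token after the EOD can no longer see tokens before it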
Example #7
    def test_cross_attn(self):
        num_layers = 1
        init_method_std = 0.02
        batch = 2
        neighbors = 2
        # rotary pos emb dim
        dim = 128
        pad_id = 19999
        num_attention_heads = 8
        chunks = 32
        text_chunk_size = 64
        context_chunk_size = 2 * text_chunk_size
        input_length = chunks * text_chunk_size
        vocab_size = 20000

        rot_dim = dim // num_attention_heads
        rotary_pos_emb = RotaryEmbedding(rot_dim).cuda().half()

        hidden = torch.randint(0, vocab_size, (input_length, batch)).cuda()  # (seq, batch)
        hidden_mask = (hidden != pad_id).cuda()
        hidden_emb = torch.rand(input_length, batch, dim).cuda().half()  # (seq, batch, dim)

        retrieved = torch.randint(0, vocab_size, (chunks, neighbors, context_chunk_size, batch)).cuda()
        # retrieved tokens - (num chunks, num retrieved neighbors, retrieved chunk with continuation, batch)

        # context attention mask - (num chunks, num retrieved neighbors, retrieved chunk with continuation, batch)
        context_mask = (retrieved != pad_id).cuda()
        retrieved_emb = torch.rand(chunks, neighbors, context_chunk_size, batch, dim).cuda().half()
        # retrieved tokens - (num chunks, num retrieved neighbors, retrieved chunk with continuation, batch, hidden)

        # need to add extra chunk size, since it will be shifted
        cross_attn_q_pos_emb = rotary_pos_emb(text_chunk_size + text_chunk_size - 1, offset=0)
        cross_attn_k_pos_emb = rotary_pos_emb(context_chunk_size)
        cross_attn_pos_emb = (cross_attn_q_pos_emb, cross_attn_k_pos_emb)

        dec_attn_mask = rearrange(hidden_mask, '(k n) b -> (b k) n', k=chunks)
        context_attn_mask = rearrange(context_mask, 'k r n b -> (b k) (r n)')
        enc_dec_attn_mask_3d = build_attention_mask_3d(
            source_mask=dec_attn_mask, target_mask=context_attn_mask, attn_mask_type=AttnMaskType.padding,
        )
        enc_dec_attn_mask_3d = enc_dec_attn_mask_3d[:, None, :, :]

        init_method = init_method_normal(init_method_std)

        scaled_init_method = scaled_init_method_normal(init_method_std, num_layers)
        cross_attn = (
            ParallelChunkedCrossAttention(
                init_method=init_method,
                output_layer_init_method=scaled_init_method,
                layer_number=0,
                num_attention_heads=num_attention_heads,
                hidden_size=dim,
                precision=16,
                chunk_size=text_chunk_size,
            )
            .cuda()
            .half()
        )

        out, bias = cross_attn(
            hidden_emb, enc_dec_attn_mask_3d, encoder_output=retrieved_emb, rotary_pos_emb=cross_attn_pos_emb
        )
        assert out.shape == torch.Size([input_length, batch, dim])
        assert bias.shape == torch.Size([dim])
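With the constants used in this test, the two mask rearranges give the following shapes (a quick check that mirrors the calls above):

import torch
from einops import rearrange

batch, chunks, neighbors = 2, 32, 2
text_chunk_size, context_chunk_size = 64, 2 * 64
input_length = chunks * text_chunk_size

hidden_mask = torch.ones(input_length, batch, dtype=torch.bool)
context_mask = torch.ones(chunks, neighbors, context_chunk_size, batch, dtype=torch.bool)

dec_attn_mask = rearrange(hidden_mask, '(k n) b -> (b k) n', k=chunks)
context_attn_mask = rearrange(context_mask, 'k r n b -> (b k) (r n)')
assert dec_attn_mask.shape == (batch * chunks, text_chunk_size)                      # (64, 64)
assert context_attn_mask.shape == (batch * chunks, neighbors * context_chunk_size)   # (64, 256)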
Example #8
    def forward(
        self,
        dec_input,
        dec_attn_mask,
        retrieved_attn_mask=None,
        retrieved_emb=None,
        layer_past=None,
        get_key_value=False,
        eod_positions=None,  # this is a tuple of eod positions returned from tensor.where(tensor == eod_id)
        set_inference_key_value_memory=False,
        inference_max_sequence_len=None,
    ):
        # expected dec_input shape [batch, seq_len, dim]
        # expected dec_attn_mask shape [batch, seq_len]
        # expected retrieved_input shape [batch, num_chunks, num_neighbors, retrieval_seq_len, dim]
        # expected retrieved_attn_mask shape [batch, num_chunks, num_neighbors, retrieval_seq_len]

        # batch, seq_len, dim
        if isinstance(dec_input, tuple):
            n, _, _ = dec_input[1].shape
        else:
            _, n, _ = dec_input.shape

        if set_inference_key_value_memory:
            # seq_index = (n // chunk_size) * chunk_size
            self.current_len = n
            num_seq_chunks = self.current_len // self.chunk_size
            self_attn_emb = self.rotary_pos_emb(self.current_len)
        elif inference_max_sequence_len is not None:
            # only handles single token increment
            assert n == 1
            self.current_len += n
            self_attn_emb = self.rotary_pos_emb(self.current_len)
            num_seq_chunks = self.current_len // self.chunk_size
        else:
            # this is normal forward without inference
            num_seq_chunks = n // self.chunk_size
            self_attn_emb = self.rotary_pos_emb(n)

        if retrieved_emb is not None:
            b, k, r, rn, dim = retrieved_emb.shape
            assert (
                k == num_seq_chunks
            ), f'sequence requires {num_seq_chunks} retrieved chunks, but only {k} passed in'  # need to add extra chunk size, since it will be shifted

        if retrieved_emb is not None:
            cross_attn_q_pos_emb = self.rotary_pos_emb(
                self.chunk_size * 2 - 1, offset=-self.chunk_size + 1)
            cross_attn_k_pos_emb = self.rotary_pos_emb(rn, offset=0)
            attn_pos_emb = (self_attn_emb, cross_attn_q_pos_emb,
                            cross_attn_k_pos_emb)
        else:
            attn_pos_emb = (self_attn_emb, None, None)

        dec_attn_mask_3d = self._calculate_dec_att_mask(
            dec_attn_mask, eod_positions)

        if retrieved_emb is not None:
            # need to shift the dec_attn_mask as first causal_padding elements are ignored
            # also pad it to be the multiple of self.chunk_size
            causal_padding = self.chunk_size - 1
            reminder = (self.chunk_size -
                        (dec_attn_mask.shape[1] + 1)) % self.chunk_size
            dec_attn_mask = F.pad(dec_attn_mask, (-causal_padding, reminder),
                                  value=False)

            dec_attn_mask = rearrange(dec_attn_mask, 'b (k n) -> (b k) n', k=k)
            retrieved_attn_mask = rearrange(retrieved_attn_mask,
                                            'b k r n -> (b k) (r n)')

            enc_dec_attn_mask_3d = build_attention_mask_3d(
                source_mask=dec_attn_mask,
                target_mask=retrieved_attn_mask,
                attn_mask_type=AttnMaskType.padding,
            )
            enc_dec_attn_mask_3d = enc_dec_attn_mask_3d[:, None, :, :]
        else:
            enc_dec_attn_mask_3d = None

        # transformer encoder
        enc_output = self.model(
            dec_input,
            dec_attn_mask_3d,
            layer_past=layer_past,
            get_key_value=get_key_value,
            encoder_output=None,
            retrieved_emb=retrieved_emb,
            enc_dec_attn_mask=enc_dec_attn_mask_3d,
            rotary_pos_emb=attn_pos_emb,
            set_inference_key_value_memory=set_inference_key_value_memory,
            inference_max_sequence_len=inference_max_sequence_len,
        )

        return enc_output
Example #9
    def forward(
        self,
        enc_input,
        enc_attn_mask,
        context_attn_mask=None,
        encoder_output=None,
        layer_past=None,
        get_key_value=False,
        set_inference_key_value_memory=False,  # when doing inference, set this to True once to allocate the cache tensors, then set it to False for incremental steps
        inference_max_sequence_len=None,
        neighbors=2,
    ):
        # expected enc_input shape [batch, num_chunks, num_neighbors, retrieval_seq_len, dim]
        # expected enc_attn_mask shape [batch, num_chunks, num_neighbors, retrieval_seq_len]
        # expected encoder_output shape [batch, seq_len, dim]

        # batch, seq_len, dim
        b, n, dim = encoder_output.shape

        if set_inference_key_value_memory:
            # run once to setup the cache
            chunk_start = 0
            num_seq_chunks = n // self.chunk_size
            num_chunks = inference_max_sequence_len // self.chunk_size
            self.cache_output = self._allocate_memory(
                b,
                num_chunks,
                neighbors,
                self.chunk_size * 2,
                dim,
                dtype=encoder_output.dtype)
            self.seq_pos_in_chunk = n
            self.current_chunk = n // self.chunk_size
            self.encoder_output = self._allocate_memory(b, self.chunk_size, dim, dtype=encoder_output.dtype)
            self.context_attn_mask = self._allocate_memory(b, self.chunk_size, dtype=context_attn_mask.dtype)
            chunk_beg = self.chunk_size * num_seq_chunks
            chunk_end = self.chunk_size * num_seq_chunks + self.seq_pos_in_chunk % self.chunk_size
            # store the remainders
            self.encoder_output[:, :self.seq_pos_in_chunk % self.chunk_size, :] = encoder_output[:, chunk_beg:chunk_end, :]
            self.context_attn_mask[:, :self.seq_pos_in_chunk % self.chunk_size] = context_attn_mask[:, chunk_beg:chunk_end]
        elif inference_max_sequence_len is not None:
            # subsequent incremental steps
            # only one token at a time is supported
            assert n == 1
            self.seq_pos_in_chunk += n
            self.current_chunk = self.seq_pos_in_chunk // self.chunk_size
            # if exceed the chunk size
            pos_beg = (self.seq_pos_in_chunk - 1) % self.chunk_size
            # if self.seq_pos_in_chunk - 1 >= self.chunk_size:
            #     self.current_chunk += 1
            #     self.seq_pos_in_chunk -= self.chunk_size
            chunk_start = self.current_chunk - 1
            self.encoder_output[:, pos_beg:pos_beg + 1, :] = encoder_output
            self.context_attn_mask[:, pos_beg:pos_beg + 1] = context_attn_mask[:, self.seq_pos_in_chunk - 1:self.seq_pos_in_chunk]
            encoder_output = self.encoder_output[:, :pos_beg + 1, :]
            context_attn_mask = self.context_attn_mask[:, :pos_beg + 1]
            num_seq_chunks = 1
            if self.seq_pos_in_chunk % self.chunk_size != 0:
                # still accumulate the encoder_output
                # return the cached results
                if self.current_chunk == 0:
                    return None
                return self.cache_output[:, :self.current_chunk]
            if enc_input is not None:
                # only need one chunk for the later calculation
                enc_input = enc_input[:, self.current_chunk - 1:self.current_chunk]
                enc_attn_mask = enc_attn_mask[:, self.current_chunk - 1:self.current_chunk]

        if enc_input is None:
            return None

        _, k, r, rn, _ = enc_input.shape

        assert r == neighbors
        if inference_max_sequence_len is None:
            num_seq_chunks = n // self.chunk_size
            assert k == num_seq_chunks, f'sequence requires {num_seq_chunks} retrieved chunks, but only {k} passed in'

        seq_index = num_seq_chunks * self.chunk_size

        retrieved = rearrange(enc_input, 'b k r n d -> (b k r) n d')
        enc_attn_mask = rearrange(enc_attn_mask, 'b k r n -> (b k r) n')
        # embed_as_context = repeat(encoder_output[:, :seq_index], 'b (k n) d -> (b k r) n d', n=self.chunk_size, r=r)
        # context_attn_mask = repeat(context_attn_mask[:, :seq_index], 'b (k n) -> (b k r) n', n=self.chunk_size, r=r)

        cross_attn_q_pos_emb = self.rotary_pos_emb(rn, offset=0)

        if inference_max_sequence_len is not None and not set_inference_key_value_memory:
            cross_attn_k_pos_emb = self.rotary_pos_emb(n % self.chunk_size,
                                                       offset=pos_beg)
            embed_as_context = repeat(encoder_output[:, :seq_index],
                                      'b (k n) d -> (b k r) n d',
                                      n=pos_beg + 1,
                                      r=r)
            context_attn_mask = repeat(context_attn_mask[:, :seq_index],
                                       'b (k n) -> (b k r) n',
                                       n=pos_beg + 1,
                                       r=r)
        else:
            embed_as_context = repeat(encoder_output[:, :seq_index],
                                      'b (k n) d -> (b k r) n d',
                                      n=self.chunk_size,
                                      r=r)
            context_attn_mask = repeat(context_attn_mask[:, :seq_index],
                                       'b (k n) -> (b k r) n',
                                       n=self.chunk_size,
                                       r=r)
            cross_attn_k_pos_emb = self.rotary_pos_emb(self.chunk_size,
                                                       offset=0)

        attn_pos_emb = (cross_attn_q_pos_emb, cross_attn_q_pos_emb,
                        cross_attn_k_pos_emb)

        # convert to Megatron mask
        enc_attn_mask_3d = build_attention_mask_3d(
            source_mask=enc_attn_mask,
            target_mask=enc_attn_mask,
            attn_mask_type=self.model_attn_mask_type,
        )
        enc_attn_mask_3d = enc_attn_mask_3d[:, None, :, :]

        enc_dec_attn_mask_3d = build_attention_mask_3d(
            source_mask=enc_attn_mask,
            target_mask=context_attn_mask,
            attn_mask_type=AttnMaskType.padding,
        )
        enc_dec_attn_mask_3d = enc_dec_attn_mask_3d[:, None, :, :]

        # transformer encoder
        enc_output = self.model(
            retrieved,
            enc_attn_mask_3d,
            layer_past=layer_past,
            get_key_value=get_key_value,
            encoder_output=embed_as_context,
            enc_dec_attn_mask=enc_dec_attn_mask_3d,
            rotary_pos_emb=attn_pos_emb,
        )
        # revert back to original retrieved shape
        enc_output = rearrange(enc_output,
                               '(b k r) n d -> b k r n d',
                               b=b,
                               k=k)

        if inference_max_sequence_len is not None:
            # update encoded for current chunk
            self.cache_output[:, chunk_start:self.current_chunk, :, :, :] = enc_output
            # read all encodings
            enc_output = self.cache_output[:, :self.current_chunk]
        return enc_output
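The incremental branch advances one token at a time and only re-encodes when a chunk boundary is crossed. A toy simulation of that bookkeeping (chunk_size kept small for clarity, no model calls):

chunk_size = 4
seq_pos_in_chunk = 6                      # tokens already seen after the priming pass
current_chunk = seq_pos_in_chunk // chunk_size

for _ in range(6):                        # feed six more single tokens
    seq_pos_in_chunk += 1
    current_chunk = seq_pos_in_chunk // chunk_size
    at_boundary = seq_pos_in_chunk % chunk_size == 0
    print(f'pos={seq_pos_in_chunk} chunk={current_chunk} re-encode={at_boundary}')
# only positions 8 and 12 (multiples of chunk_size) trigger a fresh encoder pass;
# in between, the cached chunk encodings are returned unchanged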