Example No. 1
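All of the examples on this page exercise the Megatron/NeMo weight-initialization helpers init_method_normal and scaled_init_method_normal; the snippets assume torch, einops.rearrange, and the relevant NeMo/Megatron classes (LayerType, MegatronRetrievalTransformerDecoderModule, etc.) are already imported. For reference, a minimal sketch of what the two helpers are assumed to do (standard Megatron-style definitions; the exact source may differ):

import math
import torch

def init_method_normal(sigma):
    # Returns an initializer that fills a tensor from N(0, sigma).
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
    return init_

def scaled_init_method_normal(sigma, num_layers):
    # Same, but with the std scaled by 1/sqrt(2 * num_layers); used for output
    # (residual-branch) projections so activations do not grow with depth.
    std = sigma / math.sqrt(2.0 * num_layers)
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_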
    def test_retrival_decoder(self):

        init_method_std = 0.02

        # rotary pos emb dim
        batch = 2
        neighbors = 2
        dim = 128
        pad_id = 19999
        num_attention_heads = 8
        chunks = 32
        text_chunk_size = 64
        input_length = chunks * text_chunk_size
        vocab_size = 20000
        # rot_dim = dim // num_attention_heads
        # rotary_pos_emb = RotaryEmbedding(rot_dim).cuda().half()
        hidden = torch.randint(
            0, vocab_size, (batch, input_length)).cuda()  # (batch, seq) token ids
        hidden_mask = (hidden != pad_id).cuda()

        hidden_emb = torch.rand(batch, input_length,
                                dim).cuda().half()  # (batch, seq, dim)

        # context_chunk_size = 128
        retrieved = torch.randint(
            0, vocab_size, (batch, chunks, neighbors, 2 * text_chunk_size)).cuda()
        # retrieved tokens - (batch, num chunks, num retrieved neighbors, retrieved chunk with continuation)

        # context attention mask [b, np, sq, sk]
        pad_id = vocab_size - 1
        context_mask = (retrieved != pad_id).cuda()
        retrieved_emb = torch.rand(batch, chunks, neighbors, 2 * text_chunk_size,
                                   dim).cuda().half()
        # retrieved tokens - (batch, num chunks, num retrieved neighbors, retrieved chunk with continuation, hidden)

        layer_type = [
            LayerType.encoder, LayerType.retrieval_decoder, LayerType.encoder,
            LayerType.retrieval_decoder
        ]
        num_layers = len(layer_type)

        init_method = init_method_normal(init_method_std)
        scaled_init_method = scaled_init_method_normal(init_method_std,
                                                       num_layers)
        decoder = (MegatronRetrievalTransformerDecoderModule(
            init_method=init_method,
            output_layer_init_method=scaled_init_method,
            hidden_size=dim,
            ffn_hidden_size=dim * 4,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            precision=16,
            chunk_size=text_chunk_size,
            layer_type=layer_type,
        ).cuda().half())
        out = decoder(hidden_emb,
                      hidden_mask,
                      retrieved_attn_mask=context_mask,
                      retrieved_emb=retrieved_emb)
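This configuration matches Example No. 14 below (which additionally zeroes the dropouts); there, the same forward pass is checked with:

        assert out.shape == torch.Size([batch, input_length, dim])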
Example No. 2
    def test_retrieval_encoder(self):

        init_method_std = 0.02

        batch = 2
        neighbors = 2
        # rotary pos emb dim
        dim = 128
        pad_id = 19999
        num_attention_heads = 8
        chunks = 32
        text_chunk_size = 64
        input_length = chunks * text_chunk_size
        vocab_size = 20000

        hidden = torch.randint(
            0, vocab_size, (batch, input_length)).cuda()  # (batch, seq) token ids
        hidden_mask = (hidden != pad_id).cuda()

        hidden_emb = torch.rand(batch, input_length,
                                dim).cuda().half()  # (batch, seq, dim)
        retrieved = torch.randint(
            0, vocab_size,
            (batch, chunks, neighbors, 2 * text_chunk_size)).cuda()
        pad_id = vocab_size - 1
        context_mask = (retrieved != pad_id).cuda()
        retrieved_emb = torch.rand(batch, chunks, neighbors,
                                   2 * text_chunk_size, dim).cuda().half()

        layer_type = [
            LayerType.encoder, LayerType.retrieval_encoder, LayerType.encoder,
            LayerType.retrieval_encoder
        ]
        num_layers = len(layer_type)

        init_method = init_method_normal(init_method_std)
        scaled_init_method = scaled_init_method_normal(init_method_std,
                                                       num_layers)
        encoder = (MegatronRetrievalTransformerEncoderModule(
            init_method=init_method,
            output_layer_init_method=scaled_init_method,
            hidden_size=dim,
            ffn_hidden_size=dim * 4,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            precision=16,
            chunk_size=text_chunk_size,
            layer_type=layer_type,
        ).cuda().half())
        out = encoder(retrieved_emb,
                      context_mask,
                      context_attn_mask=hidden_mask,
                      encoder_output=hidden_emb)
        assert out.shape == torch.Size(
            [batch, chunks, neighbors, 2 * text_chunk_size, dim])
Example No. 3
    def __init__(self,
                 num_tokentypes=2,
                 add_binary_head=True,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True):
        super(BertModel, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process

        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=self.add_binary_head,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process,
        )

        self.initialize_word_embeddings(init_method_normal)
        if self.post_process:
            self.lm_head = BertLMHead(
                self.word_embeddings_weight().size(0),
                args.hidden_size,
                init_method,
                args.layernorm_epsilon,
                parallel_output,
            )
            self._lm_head_key = 'lm_head'
            self.binary_head = None
            if self.add_binary_head:
                self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                    init_method)
                self._binary_head_key = 'binary_head'
Example No. 4
    def __init__(self, num_tokentypes=0, parallel_output=True):
        super(T5Model, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            add_decoder=True,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
        )

        self.lm_head = T5LMHead(
            self.language_model.embedding.word_embeddings.weight.size(0),
            parallel_output)
        self._lm_head_key = 'lm_head'
Example No. 5
def get_decoder_model(
    arch,
    hidden_size,
    ffn_hidden_size,
    num_layers,
    num_attention_heads,
    apply_query_key_layer_scaling=True,
    kv_channels=None,
    init_method=None,
    scaled_init_method=None,
    add_decoder=False,
    decoder_attn_mask_type=AttnMaskType.causal,
    pre_process=True,
    post_process=True,
    init_method_std=0.02,
    use_cpu_initialization=False,
    hidden_dropout=0.1,
    attention_dropout=0.1,
    position_embedding_type='learned_absolute',
    relative_attention_num_buckets=32,
    relative_attention_max_distance=128,
    precision=16,
    fp32_residual_connection=False,
    activations_checkpoint_method=None,
    activations_checkpoint_num_layers=1,
    layernorm_epsilon=1e-5,
    bias_gelu_fusion=True,
    bias_dropout_add_fusion=True,
    masked_softmax_fusion=True,
    persist_layer_norm=False,
    openai_gelu=False,
    activation="gelu",
    onnx_safe=False,
    bias=True,
    normalization="layernorm",
    headscale=False,
    transformer_block_type="pre_ln",
    hidden_steps=-1,
    hidden_blocks=1,
    parent_model_type=ModelType.encoder_or_decoder,
    layer_type=None,
    chunk_size=64,
    layer_number_offset=0,  # used only for attention norm_factor scaling
):
    """Build language model and return along with the key to save."""

    if kv_channels is None:
        assert (
            hidden_size % num_attention_heads == 0
        ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
        kv_channels = hidden_size // num_attention_heads

    if init_method is None:
        init_method = init_method_normal(init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(init_method_std, num_layers)

    if arch == "transformer":
        # Language model.
        decoder = MegatronTransformerDecoderModule(
            init_method=init_method,
            output_layer_init_method=scaled_init_method,
            hidden_size=hidden_size,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            ffn_hidden_size=ffn_hidden_size,
            decoder_attn_mask_type=decoder_attn_mask_type,
            pre_process=pre_process,
            post_process=post_process,
            use_cpu_initialization=use_cpu_initialization,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            position_embedding_type=position_embedding_type,
            relative_attention_num_buckets=relative_attention_num_buckets,
            relative_attention_max_distance=relative_attention_max_distance,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            bias_dropout_add_fusion=bias_dropout_add_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
            activation=activation,
            bias=bias,
            normalization=normalization,
            transformer_block_type=transformer_block_type,
            headscale=headscale,
            parent_model_type=parent_model_type,
        )
    elif arch == "retro":
        decoder = MegatronRetrievalTransformerDecoderModule(
            init_method=init_method,
            output_layer_init_method=scaled_init_method,
            hidden_size=hidden_size,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            layer_type=layer_type,
            ffn_hidden_size=ffn_hidden_size,
            pre_process=pre_process,
            post_process=post_process,
            use_cpu_initialization=use_cpu_initialization,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            bias_dropout_add_fusion=bias_dropout_add_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
            activation=activation,
            bias=bias,
            normalization=normalization,
            transformer_block_type=transformer_block_type,
            parent_model_type=parent_model_type,
            chunk_size=chunk_size,
            layer_number_offset=layer_number_offset,
        )
    else:
        raise ValueError(f"Unknown decoder arch = {arch}. Available decoder arch = {AVAILABLE_DECODERS}")

    return decoder
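Example No. 10 below invokes this helper twice with arch="retro" to build the pre- and post-decoder stacks. A trimmed, hypothetical call for illustration (the sizes are made up, and every omitted keyword falls back to the defaults above):

decoder = get_decoder_model(
    arch="retro",
    hidden_size=768,
    ffn_hidden_size=3072,
    num_layers=4,
    num_attention_heads=12,
    layer_type=[LayerType.encoder, LayerType.retrieval_decoder,
                LayerType.encoder, LayerType.retrieval_decoder],
    chunk_size=64,
)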
Example No. 6
    def __init__(
        self,
        vocab_size,
        hidden_size,
        max_position_embeddings,
        num_layers,
        num_attention_heads,
        ffn_hidden_size,
        apply_query_key_layer_scaling=True,
        kv_channels=None,
        num_tokentypes=0,
        parallel_output=True,
        pre_process=True,
        post_process=True,
        init_method_std=0.02,
        fp16_lm_cross_entropy=False,
        use_cpu_initialization=False,
        hidden_dropout=0.1,
        precision=16,
        fp32_residual_connection=False,
        activations_checkpoint_method=None,
        activations_checkpoint_num_layers=1,
        layernorm_epsilon=1e-5,
        bias_gelu_fusion=True,
        persist_layer_norm=False,
        openai_gelu=False,
        onnx_safe=False,
    ):

        super(GPTModel, self).__init__()

        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy

        if kv_channels is None:
            assert (
                hidden_size % num_attention_heads == 0
            ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
            kv_channels = hidden_size // num_attention_heads

        self.language_model, self._language_model_key = get_language_model(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            hidden_dropout=hidden_dropout,
            num_tokentypes=num_tokentypes,
            max_position_embeddings=max_position_embeddings,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            ffn_hidden_size=ffn_hidden_size,
            add_pooler=False,
            encoder_attn_mask_type=AttnMaskType.causal,
            init_method=init_method_normal(init_method_std),
            scaled_init_method=scaled_init_method_normal(
                init_method_std, num_layers),
            pre_process=self.pre_process,
            post_process=self.post_process,
            init_method_std=init_method_std,
            use_cpu_initialization=use_cpu_initialization,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
        )

        self.initialize_word_embeddings(
            init_method=init_method_normal(init_method_std),
            vocab_size=vocab_size,
            hidden_size=hidden_size)
Example No. 7
    def __init__(
        self,
        vocab_size,
        hidden_size,
        max_position_embeddings,
        num_layers,
        num_attention_heads,
        ffn_hidden_size,
        apply_query_key_layer_scaling=True,
        kv_channels=None,
        num_tokentypes=0,
        parallel_output=True,
        pre_process=True,
        post_process=True,
        init_method_std=0.02,
        fp16_lm_cross_entropy=False,
        use_cpu_initialization=False,
        hidden_dropout=0.1,
        precision=16,
        fp32_residual_connection=False,
        activations_checkpoint_method=None,
        activations_checkpoint_num_layers=1,
        layernorm_epsilon=1e-5,
        bias_gelu_fusion=True,
        openai_gelu=False,
        onnx_safe=False,
        add_binary_head=True,
    ):
        super(BertModel, self).__init__()
        # args = get_args()
        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process

        init_method = init_method_normal(init_method_std)
        scaled_init_method = scaled_init_method_normal(init_method_std,
                                                       num_layers)

        self.language_model, self._language_model_key = get_language_model(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            hidden_dropout=hidden_dropout,
            num_tokentypes=num_tokentypes,
            max_position_embeddings=max_position_embeddings,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            ffn_hidden_size=ffn_hidden_size,
            add_pooler=self.add_binary_head,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process,
            init_method_std=init_method_std,
            use_cpu_initialization=use_cpu_initialization,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
        )

        self.initialize_word_embeddings(
            init_method=init_method_normal(init_method_std),
            vocab_size=vocab_size,
            hidden_size=hidden_size)

        if self.post_process:
            self.lm_head = BertLMHead(
                self.word_embeddings_weight().size(0),
                hidden_size,
                init_method,
                layernorm_epsilon,
                parallel_output,
                openai_gelu,
                onnx_safe,
            )
            self._lm_head_key = 'lm_head'
            self.binary_head = None
            if self.add_binary_head:
                self.binary_head = get_linear_layer(hidden_size, 2,
                                                    init_method)
                self._binary_head_key = 'binary_head'
Example No. 8
    def test_cross_attn(self):
        num_layers = 1
        init_method_std = 0.02
        batch = 2
        neighbors = 2
        # rotary pos emb dim
        dim = 128
        pad_id = 19999
        num_attention_heads = 8
        chunks = 32
        text_chunk_size = 64
        context_chunk_size = 2 * text_chunk_size
        input_length = chunks * text_chunk_size
        vocab_size = 20000

        rot_dim = dim // num_attention_heads
        rotary_pos_emb = RotaryEmbedding(rot_dim).cuda().half()

        hidden = torch.randint(0, vocab_size, (input_length, batch)).cuda()  # (seq, batch) token ids
        hidden_mask = (hidden != pad_id).cuda()
        hidden_emb = torch.rand(input_length, batch, dim).cuda().half()  # (seq, batch, dim)

        retrieved = torch.randint(0, vocab_size, (chunks, neighbors, context_chunk_size, batch)).cuda()
        # retrieved tokens - (num chunks, num retrieved neighbors, retrieved chunk with continuation, batch)

        # context attention mask [b, np, sq, sk]
        context_mask = (retrieved != pad_id).cuda()
        retrieved_emb = torch.rand(chunks, neighbors, context_chunk_size, batch, dim).cuda().half()
        # retrieved tokens - (num chunks, num retrieved neighbors, retrieved chunk with continuation, batch, hidden)

        # need to add an extra chunk size, since it will be shifted
        cross_attn_q_pos_emb = rotary_pos_emb(text_chunk_size + text_chunk_size - 1, offset=0)
        cross_attn_k_pos_emb = rotary_pos_emb(context_chunk_size)
        cross_attn_pos_emb = (cross_attn_q_pos_emb, cross_attn_k_pos_emb)

        dec_attn_mask = rearrange(hidden_mask, '(k n) b -> (b k) n', k=chunks)
        context_attn_mask = rearrange(context_mask, 'k r n b -> (b k) (r n)')
        enc_dec_attn_mask_3d = build_attention_mask_3d(
            source_mask=dec_attn_mask, target_mask=context_attn_mask, attn_mask_type=AttnMaskType.padding,
        )
        enc_dec_attn_mask_3d = enc_dec_attn_mask_3d[:, None, :, :]

        init_method = init_method_normal(init_method_std)

        scaled_init_method = scaled_init_method_normal(init_method_std, num_layers)
        cross_attn = (
            ParallelChunkedCrossAttention(
                init_method=init_method,
                output_layer_init_method=scaled_init_method,
                layer_number=0,
                num_attention_heads=num_attention_heads,
                hidden_size=dim,
                precision=16,
                chunk_size=text_chunk_size,
            )
            .cuda()
            .half()
        )

        out, bias = cross_attn(
            hidden_emb, enc_dec_attn_mask_3d, encoder_output=retrieved_emb, rotary_pos_emb=cross_attn_pos_emb
        )
        assert out.shape == torch.Size([input_length, batch, dim])
        assert bias.shape == torch.Size([dim])
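A note on the mask plumbing above: hidden_mask has shape (seq, batch), and the rearrange to '(b k) n' regroups it into one row of text_chunk_size positions per (batch, chunk) pair, while context_mask is flattened into one row of neighbors * context_chunk_size retrieved positions per (batch, chunk) pair. build_attention_mask_3d then combines the two into the padding mask that ParallelChunkedCrossAttention consumes, with the head dimension added via [:, None, :, :]. The query rotary embedding spans 2 * text_chunk_size - 1 positions because the chunked attention is shifted by one chunk, as the comment above indicates.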
Example No. 9
def get_language_model(
    hidden_size,
    ffn_hidden_size,
    num_layers,
    max_position_embeddings,
    num_tokentypes,
    add_pooler,
    vocab_size,
    num_attention_heads,
    encoder_attn_mask_type,
    apply_query_key_layer_scaling=True,
    kv_channels=None,
    init_method=None,
    scaled_init_method=None,
    add_decoder=False,
    decoder_attn_mask_type=AttnMaskType.causal,
    pre_process=True,
    post_process=True,
    init_method_std=0.02,
    use_cpu_initialization=False,
    hidden_dropout=0.1,
    precision=16,
    fp32_residual_connection=False,
    activations_checkpoint_method=None,
    activations_checkpoint_num_layers=1,
    layernorm_epsilon=1e-5,
    bias_gelu_fusion=True,
    masked_softmax_fusion=True,
    persist_layer_norm=False,
    openai_gelu=False,
    onnx_safe=False,
    megatron_legacy=False,
):
    """Build language model and return along with the key to save."""

    if kv_channels is None:
        assert (
            hidden_size % num_attention_heads == 0
        ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
        kv_channels = hidden_size // num_attention_heads

    if init_method is None:
        init_method = init_method_normal(init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(init_method_std, num_layers)

    # Language model.
    language_model = TransformerLanguageModel(
        init_method=init_method,
        output_layer_init_method=scaled_init_method,
        encoder_attn_mask_type=encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        vocab_size=vocab_size,
        max_position_embeddings=max_position_embeddings,
        hidden_size=hidden_size,
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
        kv_channels=kv_channels,
        ffn_hidden_size=ffn_hidden_size,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process,
        use_cpu_initialization=use_cpu_initialization,
        hidden_dropout=hidden_dropout,
        precision=precision,
        fp32_residual_connection=fp32_residual_connection,
        activations_checkpoint_method=activations_checkpoint_method,
        activations_checkpoint_num_layers=activations_checkpoint_num_layers,
        layernorm_epsilon=layernorm_epsilon,
        bias_gelu_fusion=bias_gelu_fusion,
        masked_softmax_fusion=masked_softmax_fusion,
        persist_layer_norm=persist_layer_norm,
        openai_gelu=openai_gelu,
        onnx_safe=onnx_safe,
        megatron_legacy=megatron_legacy,
    )
    # key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key
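A trimmed, hypothetical call for illustration (the sizes are made up; Examples No. 6 and No. 7 above show fuller calls from the GPT and BERT constructors):

language_model, lm_key = get_language_model(
    vocab_size=30000,
    hidden_size=768,
    ffn_hidden_size=3072,
    num_layers=12,
    max_position_embeddings=512,
    num_tokentypes=0,
    num_attention_heads=12,
    add_pooler=False,
    encoder_attn_mask_type=AttnMaskType.causal,
)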
Example No. 10
    def __init__(
        self,
        vocab_size,
        hidden_size,
        max_position_embeddings,
        num_attention_heads,
        ffn_hidden_size,
        apply_query_key_layer_scaling=True,
        kv_channels=None,
        num_tokentypes=0,
        parallel_output=True,
        pre_process=True,
        post_process=True,
        init_method_std=0.02,
        fp16_cross_entropy=False,
        use_cpu_initialization=False,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        precision=16,
        fp32_residual_connection=False,
        activations_checkpoint_method=None,
        activations_checkpoint_num_layers=1,
        layernorm_epsilon=1e-5,
        persist_layer_norm=False,
        bias_gelu_fusion=True,
        bias_dropout_add_fusion=True,
        masked_softmax_fusion=True,
        openai_gelu=False,
        activation='gelu',
        onnx_safe=False,
        bias=True,
        normalization='layernorm',
        headscale=False,
        transformer_block_type='pre_ln',
        hidden_steps=-1,
        hidden_blocks=1,
        add_encoder=True,
        add_decoder=True,
        chunk_size=64,
        enc_num_layers=4,  # total number of encoder layers
        dec_num_layers=6,  # total number of decoder layers
        enc_cross_attention=[3],  # layer numbers for cross attention
        dec_cross_attention=[3, 5],  # layer numbers for chunked cross attention
        add_position_embedding=False,
        tokenizer=None,  # tokenizer
    ):
        super(MegatronRetrievalTokenLevelEncoderDecoderModule, self).__init__()

        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_cross_entropy = fp16_cross_entropy
        self.precision = precision
        self.add_encoder = add_encoder
        self.add_decoder = add_decoder
        self.add_abs_position_embedding = add_position_embedding  # whether to use absolute position embeddings
        self.tokenizer = tokenizer
        self.eod_id = tokenizer.eos_id

        if kv_channels is None:
            assert (
                hidden_size % num_attention_heads == 0
            ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
            kv_channels = hidden_size // num_attention_heads

        if pre_process:
            self.encoder_embedding = Embedding(
                hidden_size=hidden_size,
                vocab_size=vocab_size,
                max_sequence_length=max_position_embeddings,
                init_method=init_method_normal(init_method_std),
                num_tokentypes=num_tokentypes,
                use_cpu_initialization=use_cpu_initialization,
                embedding_dropout_prob=hidden_dropout,
                add_position_embedding=add_position_embedding,
            )
            self._embedding_key = "embedding"

        if add_encoder:
            enc_layer_types = []
            for i in range(enc_num_layers):
                if i in enc_cross_attention:
                    enc_layer_types.append(LayerType.retrieval_encoder)
                else:
                    enc_layer_types.append(LayerType.encoder)
            self.encoder = get_encoder_model(
                arch="retro",
                hidden_size=hidden_size,
                ffn_hidden_size=ffn_hidden_size,
                num_layers=enc_num_layers,
                num_attention_heads=num_attention_heads,
                apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                kv_channels=kv_channels,
                init_method=init_method_normal(init_method_std),
                scaled_init_method=scaled_init_method_normal(
                    init_method_std, dec_num_layers
                ),  # since the encoder is not independent of the decoder, use the decoder's number of layers
                pre_process=pre_process,
                post_process=post_process,
                init_method_std=init_method_std,
                use_cpu_initialization=use_cpu_initialization,
                hidden_dropout=hidden_dropout,
                attention_dropout=attention_dropout,
                precision=precision,
                fp32_residual_connection=fp32_residual_connection,
                activations_checkpoint_method=activations_checkpoint_method,
                activations_checkpoint_num_layers=activations_checkpoint_num_layers,
                layernorm_epsilon=layernorm_epsilon,
                bias_gelu_fusion=bias_gelu_fusion,
                bias_dropout_add_fusion=bias_dropout_add_fusion,
                masked_softmax_fusion=masked_softmax_fusion,
                persist_layer_norm=persist_layer_norm,
                openai_gelu=openai_gelu,
                onnx_safe=onnx_safe,
                hidden_steps=hidden_steps,
                hidden_blocks=hidden_blocks,
                activation=activation,
                bias=bias,
                normalization=normalization,
                transformer_block_type=transformer_block_type,
                headscale=headscale,
                parent_model_type=ModelType.encoder_and_decoder,
                layer_type=enc_layer_types,
                chunk_size=chunk_size,
                layer_number_offset=0,
            )
            self._encoder_key = "encoder"

        if add_decoder:
            pre_decoder_num_layers = min(dec_cross_attention)
            pre_decoder_layer_types = []
            for i in range(pre_decoder_num_layers):
                pre_decoder_layer_types.append(LayerType.encoder)
            pre_decoder_layer_types.append(LayerType.decoder_pre_mlp)

            post_decoder_num_layers = dec_num_layers - pre_decoder_num_layers
            post_decoder_layer_types = []
            # the first layer in the post decoder has to be chunked cross attention without self attention
            assert pre_decoder_num_layers in dec_cross_attention
            for i in range(post_decoder_num_layers):
                if i == 0:
                    post_decoder_layer_types.append(LayerType.retrieval_decoder_after_self_attn)
                elif i + pre_decoder_num_layers in dec_cross_attention:
                    post_decoder_layer_types.append(LayerType.retrieval_decoder)
                else:
                    post_decoder_layer_types.append(LayerType.encoder)

            # the pre decoder processes the inputs that the retrieval encoder uses as context (H in the paper)
            self.pre_decoder = get_decoder_model(
                arch="retro",
                hidden_size=hidden_size,
                ffn_hidden_size=ffn_hidden_size,
                num_layers=pre_decoder_num_layers + 1,
                num_attention_heads=num_attention_heads,
                apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                kv_channels=kv_channels,
                init_method=init_method_normal(init_method_std),
                scaled_init_method=scaled_init_method_normal(init_method_std, dec_num_layers),
                pre_process=pre_process,
                post_process=False,  # no need for post process
                init_method_std=init_method_std,
                use_cpu_initialization=use_cpu_initialization,
                hidden_dropout=hidden_dropout,
                attention_dropout=attention_dropout,
                precision=precision,
                fp32_residual_connection=fp32_residual_connection,
                activations_checkpoint_method=activations_checkpoint_method,
                activations_checkpoint_num_layers=activations_checkpoint_num_layers,
                layernorm_epsilon=layernorm_epsilon,
                bias_gelu_fusion=bias_gelu_fusion,
                bias_dropout_add_fusion=bias_dropout_add_fusion,
                masked_softmax_fusion=masked_softmax_fusion,
                persist_layer_norm=persist_layer_norm,
                openai_gelu=openai_gelu,
                onnx_safe=onnx_safe,
                hidden_steps=hidden_steps,
                hidden_blocks=hidden_blocks,
                activation=activation,
                bias=bias,
                normalization=normalization,
                transformer_block_type=transformer_block_type,
                headscale=headscale,
                parent_model_type=ModelType.encoder_and_decoder,
                layer_type=pre_decoder_layer_types,
                chunk_size=chunk_size,
                layer_number_offset=0,
            )

            # this is where the chunked cross attention happens
            self.post_decoder = get_decoder_model(
                arch="retro",
                hidden_size=hidden_size,
                ffn_hidden_size=ffn_hidden_size,
                num_layers=post_decoder_num_layers,
                num_attention_heads=num_attention_heads,
                apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                kv_channels=kv_channels,
                init_method=init_method_normal(init_method_std),
                scaled_init_method=scaled_init_method_normal(init_method_std, dec_num_layers),
                pre_process=False,  # directly take the pre_decoder output, skip preprocess
                post_process=post_process,
                init_method_std=init_method_std,
                use_cpu_initialization=use_cpu_initialization,
                hidden_dropout=hidden_dropout,
                attention_dropout=attention_dropout,
                precision=precision,
                fp32_residual_connection=fp32_residual_connection,
                activations_checkpoint_method=activations_checkpoint_method,
                activations_checkpoint_num_layers=activations_checkpoint_num_layers,
                layernorm_epsilon=layernorm_epsilon,
                bias_gelu_fusion=bias_gelu_fusion,
                bias_dropout_add_fusion=bias_dropout_add_fusion,
                masked_softmax_fusion=masked_softmax_fusion,
                persist_layer_norm=persist_layer_norm,
                openai_gelu=openai_gelu,
                onnx_safe=onnx_safe,
                hidden_steps=hidden_steps,
                hidden_blocks=hidden_blocks,
                activation=activation,
                bias=bias,
                normalization=normalization,
                headscale=headscale,
                transformer_block_type=transformer_block_type,
                parent_model_type=ModelType.encoder_and_decoder,
                layer_type=post_decoder_layer_types,
                chunk_size=chunk_size,
                layer_number_offset=pre_decoder_num_layers + 1,
            )
            self._pre_decoder_key = "pre_decoder"
            self._post_decoder_key = "post_decoder"

        self.initialize_word_embeddings(
            init_method=init_method_normal(init_method_std), vocab_size=vocab_size, hidden_size=hidden_size
        )

        if add_decoder and post_process:
            self.tokens_head = MegatronTokenLevelHead(self.word_embeddings_weight().size(0), parallel_output)
            self._tokens_head_key = 'tokens_head'
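Summarizing the decoder wiring above: the dec_num_layers decoder layers are split at the first chunked-cross-attention index, min(dec_cross_attention). The pre_decoder runs plain self-attention layers plus a decoder_pre_mlp layer and produces the hidden states that the retrieval encoder consumes as context (H in the paper); the post_decoder then starts with a retrieval_decoder_after_self_attn layer and inserts further retrieval_decoder (chunked cross attention) layers at the remaining dec_cross_attention indices. Both halves use scaled_init_method_normal(init_method_std, dec_num_layers), so the output-layer scaling reflects the full decoder depth.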
Example No. 11
def get_encoder_model(
    arch,
    hidden_size,
    ffn_hidden_size,
    num_layers,
    num_attention_heads,
    apply_query_key_layer_scaling=True,
    kv_channels=None,
    init_method=None,
    scaled_init_method=None,
    encoder_attn_mask_type=AttnMaskType.padding,
    pre_process=True,
    post_process=True,
    init_method_std=0.02,
    use_cpu_initialization=False,
    hidden_dropout=0.1,
    attention_dropout=0.1,
    precision=16,
    fp32_residual_connection=False,
    activations_checkpoint_method=None,
    activations_checkpoint_num_layers=1,
    layernorm_epsilon=1e-5,
    bias_gelu_fusion=True,
    masked_softmax_fusion=True,
    persist_layer_norm=False,
    openai_gelu=False,
    activation="gelu",
    onnx_safe=False,
    hidden_steps=-1,
    hidden_blocks=1,
):
    """Build language model and return along with the key to save."""

    if kv_channels is None:
        assert (
            hidden_size % num_attention_heads == 0
        ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
        kv_channels = hidden_size // num_attention_heads

    if init_method is None:
        init_method = init_method_normal(init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(init_method_std,
                                                       num_layers)

    if arch == "transformer":
        # Language encoder.
        encoder = MegatronTransformerEncoderModule(
            init_method=init_method,
            output_layer_init_method=scaled_init_method,
            hidden_size=hidden_size,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            ffn_hidden_size=ffn_hidden_size,
            encoder_attn_mask_type=encoder_attn_mask_type,
            pre_process=pre_process,
            post_process=post_process,
            use_cpu_initialization=use_cpu_initialization,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
            activation=activation,
        )
    else:
        raise ValueError(
            f"Unknown encoder arch = {arch}. Available encoder arch = {AVAILABLE_ENCODERS}"
        )

    return encoder
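A trimmed, hypothetical call for illustration (the sizes are made up and the omitted keywords fall back to the defaults above):

encoder = get_encoder_model(
    arch="transformer",
    hidden_size=768,
    ffn_hidden_size=3072,
    num_layers=12,
    num_attention_heads=12,
    encoder_attn_mask_type=AttnMaskType.padding,
)

Note that this revision of the helper only accepts arch="transformer"; the arch="retro" request made in Example No. 10 above relies on a variant that also builds a retrieval encoder.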
Example No. 12
    def __init__(
        self,
        encoder_arch,
        decoder_arch,
        vocab_size,
        hidden_size,
        max_position_embeddings,
        num_layers,
        num_attention_heads,
        ffn_hidden_size,
        apply_query_key_layer_scaling=True,
        kv_channels=None,
        num_tokentypes=0,
        parallel_output=True,
        pre_process=True,
        post_process=True,
        init_method_std=0.02,
        fp16_cross_entropy=False,
        use_cpu_initialization=False,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        position_embedding_type='learned_absolute',
        relative_attention_num_buckets=32,
        relative_attention_max_distance=128,
        precision=16,
        fp32_residual_connection=False,
        activations_checkpoint_method=None,
        activations_checkpoint_num_layers=1,
        layernorm_epsilon=1e-5,
        persist_layer_norm=False,
        bias_gelu_fusion=True,
        bias_dropout_add_fusion=True,
        masked_softmax_fusion=True,
        openai_gelu=False,
        activation='gelu',
        onnx_safe=False,
        bias=True,
        normalization='layernorm',
        transformer_block_type='pre_ln',
        hidden_steps=-1,
        hidden_blocks=1,
        headscale=False,
        add_encoder=True,
        add_decoder=True,
    ):
        super(MegatronTokenLevelEncoderDecoderModule, self).__init__()

        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_cross_entropy = fp16_cross_entropy
        self.precision = precision
        self.add_encoder = add_encoder
        self.add_decoder = add_decoder
        self.normalization = normalization
        self.position_embedding_type = position_embedding_type
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance

        if self.position_embedding_type == 'learned_absolute':
            add_position_embedding = True
        elif self.position_embedding_type == 'relative':
            add_position_embedding = False
        else:
            raise ValueError('Unknown position embedding type. Options: '
                             '[learned_absolute | relative]')

        if kv_channels is None:
            assert (
                hidden_size % num_attention_heads == 0
            ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
            kv_channels = hidden_size // num_attention_heads

        encoder, decoder = None, None
        if add_encoder:
            if pre_process:
                self.encoder_embedding = Embedding(
                    hidden_size=hidden_size,
                    vocab_size=vocab_size,
                    max_sequence_length=max_position_embeddings,
                    init_method=init_method_normal(init_method_std),
                    num_tokentypes=num_tokentypes,
                    use_cpu_initialization=use_cpu_initialization,
                    embedding_dropout_prob=hidden_dropout,
                    add_position_embedding=add_position_embedding,
                )
                self._encoder_embedding_key = "encoder_embedding"

            encoder = get_encoder_model(
                arch=encoder_arch,
                hidden_size=hidden_size,
                ffn_hidden_size=ffn_hidden_size,
                num_layers=num_layers,
                num_attention_heads=num_attention_heads,
                apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                kv_channels=kv_channels,
                init_method=init_method_normal(init_method_std),
                scaled_init_method=scaled_init_method_normal(
                    init_method_std, num_layers),
                encoder_attn_mask_type=AttnMaskType.padding,
                pre_process=pre_process,
                post_process=post_process,
                init_method_std=init_method_std,
                use_cpu_initialization=use_cpu_initialization,
                hidden_dropout=hidden_dropout,
                attention_dropout=attention_dropout,
                position_embedding_type=position_embedding_type,
                relative_attention_num_buckets=relative_attention_num_buckets,
                relative_attention_max_distance=relative_attention_max_distance,
                precision=precision,
                fp32_residual_connection=fp32_residual_connection,
                activations_checkpoint_method=activations_checkpoint_method,
                activations_checkpoint_num_layers=activations_checkpoint_num_layers,
                layernorm_epsilon=layernorm_epsilon,
                bias_gelu_fusion=bias_gelu_fusion,
                bias_dropout_add_fusion=bias_dropout_add_fusion,
                masked_softmax_fusion=masked_softmax_fusion,
                persist_layer_norm=persist_layer_norm,
                openai_gelu=openai_gelu,
                onnx_safe=onnx_safe,
                hidden_steps=hidden_steps,
                hidden_blocks=hidden_blocks,
                activation=activation,
                bias=bias,
                normalization=normalization,
                transformer_block_type=transformer_block_type,
                headscale=headscale,
                parent_model_type=ModelType.encoder_and_decoder,
            )

        if add_decoder:
            # If this is the decoder's first stage
            if pre_process:
                # If the encoder also lies on this rank (PP = 1), then just assign embeddings directly.
                if hasattr(self, 'encoder_embedding'):
                    self.decoder_embedding = self.encoder_embedding
                else:
                    # This is the case where PP > 1 and this is the decoder's first stage.
                    # We initialize decoder embeddings, but set them to zero since they're tied with the encoder embeddings.
                    # A later initialize_embedding call will synchronize the embeddings.
                    self.decoder_embedding = Embedding(
                        hidden_size=hidden_size,
                        vocab_size=vocab_size,
                        max_sequence_length=max_position_embeddings,
                        init_method=init_method_normal(init_method_std),
                        num_tokentypes=num_tokentypes,
                        use_cpu_initialization=use_cpu_initialization,
                        embedding_dropout_prob=hidden_dropout,
                        add_position_embedding=add_position_embedding,
                    )
                    self.decoder_embedding.zero_parameters()

                self._decoder_embedding_key = "decoder_embedding"

            decoder = get_decoder_model(
                arch=decoder_arch,
                hidden_size=hidden_size,
                ffn_hidden_size=ffn_hidden_size,
                num_layers=num_layers,
                num_attention_heads=num_attention_heads,
                apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                kv_channels=kv_channels,
                init_method=init_method_normal(init_method_std),
                scaled_init_method=scaled_init_method_normal(
                    init_method_std, num_layers),
                decoder_attn_mask_type=AttnMaskType.causal,
                pre_process=pre_process,
                post_process=post_process,
                init_method_std=init_method_std,
                use_cpu_initialization=use_cpu_initialization,
                hidden_dropout=hidden_dropout,
                attention_dropout=attention_dropout,
                position_embedding_type=position_embedding_type,
                relative_attention_num_buckets=relative_attention_num_buckets,
                relative_attention_max_distance=relative_attention_max_distance,
                precision=precision,
                fp32_residual_connection=fp32_residual_connection,
                activations_checkpoint_method=activations_checkpoint_method,
                activations_checkpoint_num_layers=activations_checkpoint_num_layers,
                layernorm_epsilon=layernorm_epsilon,
                bias_gelu_fusion=bias_gelu_fusion,
                bias_dropout_add_fusion=bias_dropout_add_fusion,
                masked_softmax_fusion=masked_softmax_fusion,
                persist_layer_norm=persist_layer_norm,
                openai_gelu=openai_gelu,
                onnx_safe=onnx_safe,
                hidden_steps=hidden_steps,
                hidden_blocks=hidden_blocks,
                activation=activation,
                bias=bias,
                normalization=normalization,
                transformer_block_type=transformer_block_type,
                headscale=headscale,
                parent_model_type=ModelType.encoder_and_decoder,
            )

        self.enc_dec_model = MegatronTransformerEncoderDecoderModule(
            encoder=encoder, decoder=decoder)
        self._enc_dec_model_key = "enc_dec_model"

        self.initialize_word_embeddings(
            init_method=init_method_normal(init_method_std),
            vocab_size=vocab_size,
            hidden_size=hidden_size)

        if add_decoder and post_process:
            self.tokens_head = MegatronTokenLevelHead(
                self.word_embeddings_weight().size(0), parallel_output)
            self._tokens_head_key = 'tokens_head'
Example No. 13
    def test_retrieval_encoder_inference(self):

        init_method_std = 0.02

        batch = 2
        neighbors = 2
        # rotary pos emb dim
        dim = 128
        pad_id = 19999
        num_attention_heads = 8
        chunks = 32
        text_chunk_size = 64
        input_length = chunks * text_chunk_size
        vocab_size = 20000

        hidden = torch.randint(
            0, vocab_size, (batch, input_length)).cuda()  # (batch, seq) token ids
        hidden_mask = (hidden != pad_id).cuda()

        hidden_emb = torch.rand(batch, input_length,
                                dim).cuda().half()  # (batch, seq, dim)
        retrieved = torch.randint(
            0, vocab_size,
            (batch, chunks, neighbors, 2 * text_chunk_size)).cuda()
        pad_id = vocab_size - 1
        context_mask = (retrieved != pad_id).cuda()
        retrieved_emb = torch.rand(batch, chunks, neighbors,
                                   2 * text_chunk_size, dim).cuda().half()

        layer_type = [
            LayerType.encoder, LayerType.retrieval_encoder, LayerType.encoder,
            LayerType.retrieval_encoder
        ]
        num_layers = len(layer_type)

        init_method = init_method_normal(init_method_std)
        scaled_init_method = scaled_init_method_normal(init_method_std,
                                                       num_layers)
        encoder = (MegatronRetrievalTransformerEncoderModule(
            init_method=init_method,
            output_layer_init_method=scaled_init_method,
            hidden_size=dim,
            ffn_hidden_size=dim * 4,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            precision=16,
            chunk_size=text_chunk_size,
            layer_type=layer_type,
            hidden_dropout=0.0,
            attention_dropout=0.0,
        ).cuda().half())
        out_gt = encoder(retrieved_emb,
                         context_mask,
                         context_attn_mask=hidden_mask,
                         encoder_output=hidden_emb)
        assert out_gt.shape == torch.Size(
            [batch, chunks, neighbors, 2 * text_chunk_size, dim])

        out_1 = encoder(
            None,
            None,
            context_attn_mask=hidden_mask[:, :62],
            encoder_output=hidden_emb[:, :62, :],
            set_inference_key_value_memory=True,
            inference_max_sequence_len=input_length,
            neighbors=neighbors,
        )
        assert out_1 is None
        out_1 = encoder(
            None,
            None,
            context_attn_mask=hidden_mask[:, :63],
            encoder_output=hidden_emb[:, 62:63],
            set_inference_key_value_memory=False,
            inference_max_sequence_len=input_length,
            neighbors=neighbors,
        )
        assert out_1 is None
        out_2 = encoder(
            retrieved_emb[:, :1],
            context_mask[:, :1],
            context_attn_mask=hidden_mask[:, :64],
            encoder_output=hidden_emb[:, 63:64],
            set_inference_key_value_memory=False,
            inference_max_sequence_len=input_length,
            neighbors=neighbors,
        )
        assert (encoder.encoder_output -
                hidden_emb[:, :64]).abs().max().item() < 1e-5
        assert (out_gt[:, 0, ] - out_2[:, 0]).abs().max().item() < 1e-2
        out_test = encoder(
            retrieved_emb[:, :1],
            context_mask[:, :1],
            context_attn_mask=hidden_mask[:, :64],
            encoder_output=hidden_emb[:, :64],
        )
        assert (out_gt[:, 0, ] - out_test[:, 0]).abs().max().item() < 1e-2
        assert (out_gt[:, 0, ] - out_2[:, 0]).abs().max().item() < 1e-2

        for i in range(64, 127):
            out_3 = encoder(
                retrieved_emb[:, :1],
                context_mask[:, :1],
                context_attn_mask=hidden_mask[:, :i + 1],
                encoder_output=hidden_emb[:, i:i + 1],
                set_inference_key_value_memory=False,
                inference_max_sequence_len=input_length,
                neighbors=neighbors,
            )
        i = 127
        out_3 = encoder(
            retrieved_emb[:, :2],
            context_mask[:, :2],
            context_attn_mask=hidden_mask[:, :i + 1],
            encoder_output=hidden_emb[:, i:i + 1],
            set_inference_key_value_memory=False,
            inference_max_sequence_len=input_length,
            neighbors=neighbors,
        )
        assert (encoder.encoder_output -
                hidden_emb[:, 64:128]).abs().max().item() < 1e-5
        assert (out_gt[:, :2, ] - out_3).abs().max().item() < 1e-2
        # test inference
        for i in range(128, 191):
            out_4 = encoder(
                retrieved_emb[:, :2],
                context_mask[:, :2],
                context_attn_mask=hidden_mask[:, :i + 1],
                encoder_output=hidden_emb[:, i:i + 1],
                set_inference_key_value_memory=False,
                inference_max_sequence_len=input_length,
                neighbors=neighbors,
            )
        i = 191
        out_4 = encoder(
            retrieved_emb[:, :3],
            context_mask[:, :3],
            context_attn_mask=hidden_mask[:, :i + 1],
            encoder_output=hidden_emb[:, i:i + 1],
            set_inference_key_value_memory=False,
            inference_max_sequence_len=input_length,
            neighbors=neighbors,
        )

        assert (encoder.encoder_output -
                hidden_emb[:, 128:192]).abs().max().item() < 1e-5
        assert (out_gt[:, :3, ] - out_4).abs().max().item() < 1e-2

        out_2 = encoder(
            retrieved_emb[:, :2],
            context_mask[:, :2],
            context_attn_mask=hidden_mask[:, :130],
            encoder_output=hidden_emb[:, :130, :],
            set_inference_key_value_memory=True,
            inference_max_sequence_len=input_length,
            neighbors=neighbors,
        )
        for i in range(130, 191):
            out_2 = encoder(
                retrieved_emb[:, :2],
                context_mask[:, :2],
                context_attn_mask=hidden_mask[:, :i + 1],
                encoder_output=hidden_emb[:, i:i + 1],
                set_inference_key_value_memory=False,
                inference_max_sequence_len=input_length,
                neighbors=neighbors,
            )
        i = 191
        out_4 = encoder(
            retrieved_emb[:, :3],
            context_mask[:, :3],
            context_attn_mask=hidden_mask[:, :i + 1],
            encoder_output=hidden_emb[:, i:i + 1],
            set_inference_key_value_memory=False,
            inference_max_sequence_len=input_length,
            neighbors=neighbors,
        )
        assert (encoder.encoder_output -
                hidden_emb[:, 128:192]).abs().max().item() < 1e-5
        assert (out_gt[:, :3, ] - out_4).abs().max().item() < 1e-2
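The incremental path above works as follows: the first call with set_inference_key_value_memory=True allocates key/value memory for up to inference_max_sequence_len tokens and caches the 62-token prefix; each subsequent call feeds exactly one new token of encoder_output, and a freshly retrieved neighbor chunk is supplied once 64, 128 and then 192 tokens (multiples of text_chunk_size) have been consumed, which is why the retrieved_emb slice grows to 1, 2 and 3 chunks at those steps. Each cached result is compared against the full-sequence forward pass (out_gt) to within 1e-2, the tolerance used for these fp16 modules.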
Example No. 14
    def test_retrieval_decoder_inference(self):

        init_method_std = 0.02

        # rotary pos emb dim
        batch = 2
        neighbors = 2
        dim = 128
        pad_id = 19999
        num_attention_heads = 8
        chunks = 32
        text_chunk_size = 64
        input_length = chunks * text_chunk_size
        vocab_size = 20000
        # rot_dim = dim // num_attention_heads
        # rotary_pos_emb = RotaryEmbedding(rot_dim).cuda().half()
        hidden = torch.randint(
            0, vocab_size, (batch, input_length)).cuda()  # (batch, seq) token ids
        hidden_mask = (hidden != pad_id).cuda()

        hidden_emb = torch.rand(batch, input_length,
                                dim).cuda().half()  # (batch, seq, dim)

        # context_chunk_size = 128
        retrieved = torch.randint(
            0, vocab_size,
            (batch, chunks, neighbors, 2 * text_chunk_size)).cuda()
        # retrieved tokens - (batch, num chunks, num retrieved neighbors, retrieved chunk with continuation)

        # context attention mask [b, np, sq, sk]
        pad_id = vocab_size - 1
        context_mask = (retrieved != pad_id).cuda()
        retrieved_emb = torch.rand(batch, chunks, neighbors,
                                   2 * text_chunk_size, dim).cuda().half()
        # retrieved tokens - (batch, num chunks, num retrieved neighbors, retrieved chunk with continuation, hidden)

        layer_type = [
            LayerType.encoder, LayerType.retrieval_decoder, LayerType.encoder,
            LayerType.retrieval_decoder
        ]
        num_layers = len(layer_type)

        init_method = init_method_normal(init_method_std)
        scaled_init_method = scaled_init_method_normal(init_method_std,
                                                       num_layers)
        decoder = (MegatronRetrievalTransformerDecoderModule(
            init_method=init_method,
            output_layer_init_method=scaled_init_method,
            hidden_size=dim,
            ffn_hidden_size=dim * 4,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            precision=16,
            chunk_size=text_chunk_size,
            layer_type=layer_type,
            hidden_dropout=0.0,
            attention_dropout=0.0,
        ).cuda().half())
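        # Full-sequence forward pass; its output is the reference for the
        # incremental (cached) decoding steps below.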
        out = decoder(hidden_emb,
                      hidden_mask,
                      retrieved_attn_mask=context_mask,
                      retrieved_emb=retrieved_emb)
        assert out.shape == torch.Size([batch, input_length, dim])

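        # Incremental decoding: the first call primes the key/value cache with a
        # 62-token prompt (shorter than one chunk, so no retrieved context yet);
        # later calls feed one new token each with
        # set_inference_key_value_memory=False.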
        out_1 = decoder(
            hidden_emb[:, :62],
            hidden_mask[:, :62],
            retrieved_attn_mask=None,
            retrieved_emb=None,
            set_inference_key_value_memory=True,
            inference_max_sequence_len=input_length,
        )
        assert (out[:, :62] - out_1[:, :62]).abs().max().item() < 1e-2
        out_1 = decoder(
            hidden_emb[:, 62:63],
            hidden_mask[:, :63],
            retrieved_attn_mask=None,
            retrieved_emb=None,
            set_inference_key_value_memory=False,
            inference_max_sequence_len=input_length,
        )
        assert (out[:, 62] - out_1[:, 0]).abs().max().item() < 1e-2
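        # Token index 63 completes the first 64-token chunk, so the first
        # retrieved neighbor chunk becomes available for cross-attention.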
        out_2 = decoder(
            hidden_emb[:, 63:64],
            hidden_mask[:, :64],
            retrieved_attn_mask=context_mask[:, :1],
            retrieved_emb=retrieved_emb[:, :1],
            set_inference_key_value_memory=False,
            inference_max_sequence_len=input_length,
        )
        assert (out[:, 63] - out_2[:, 0]).abs().max().item() < 1e-2
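        # Decode through the second chunk one token at a time, still conditioning
        # on a single retrieved chunk.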
        for i in range(64, 127):
            out_2 = decoder(
                hidden_emb[:, i:i + 1],
                hidden_mask[:, :i + 1],
                retrieved_attn_mask=context_mask[:, :1],
                retrieved_emb=retrieved_emb[:, :1],
                set_inference_key_value_memory=False,
                inference_max_sequence_len=input_length,
            )
            assert (out[:, i] - out_2[:, 0]).abs().max().item() < 1e-2
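        # From token 127 onward the second chunk is complete, so two retrieved
        # chunks are passed to the decoder.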
        for i in range(127, 191):
            out_3 = decoder(
                hidden_emb[:, i:i + 1],
                hidden_mask[:, :i + 1],
                retrieved_attn_mask=context_mask[:, :2],
                retrieved_emb=retrieved_emb[:, :2],
                set_inference_key_value_memory=False,
                inference_max_sequence_len=input_length,
            )
            assert (out[:, i] - out_3[:, 0]).abs().max().item() < 1e-2

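        # Same check, but re-priming the cache with a 130-token prompt (two full
        # chunks plus two tokens) and two retrieved chunks up front.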
        out_1 = decoder(
            hidden_emb[:, :130],
            hidden_mask[:, :130],
            retrieved_attn_mask=context_mask[:, :2],
            retrieved_emb=retrieved_emb[:, :2],
            set_inference_key_value_memory=True,
            inference_max_sequence_len=input_length,
        )
        assert (out[:, :130] - out_1[:, :130]).abs().max().item() < 1e-2
        for i in range(130, 191):
            out_3 = decoder(
                hidden_emb[:, i:i + 1],
                hidden_mask[:, :i + 1],
                retrieved_attn_mask=context_mask[:, :2],
                retrieved_emb=retrieved_emb[:, :2],
                set_inference_key_value_memory=False,
                inference_max_sequence_len=input_length,
            )
            assert (out[:, i] - out_3[:, 0]).abs().max().item() < 1e-2
Example No. 15
    def test_cross_attn_inference(self):
        num_layers = 1
        init_method_std = 0.02
        batch = 2
        neighbors = 2
        # rotary pos emb dim
        dim = 128
        pad_id = 19999
        num_attention_heads = 8
        chunks = 32
        text_chunk_size = 64
        context_chunk_size = 2 * text_chunk_size
        input_length = chunks * text_chunk_size
        vocab_size = 20000

        rot_dim = dim // num_attention_heads
        rotary_pos_emb = RotaryEmbedding(rot_dim).cuda().half()

        hidden = torch.randint(
            0, vocab_size, (input_length, batch)).cuda()  # (seq, batch)
        hidden_mask = (hidden != pad_id).cuda()
        hidden_emb = torch.rand(input_length, batch,
                                dim).cuda().half()  # (seq, batch, dim)

        retrieved = torch.randint(
            0, vocab_size,
            (chunks, neighbors, context_chunk_size, batch)).cuda()
        # retrieved tokens - (num chunks, num retrieved neighbors, retrieved chunk with continuation, batch)

        # context attention mask [b, np, sq, sk]
        context_mask = (retrieved != pad_id).cuda()
        retrieved_emb = torch.rand(chunks, neighbors, context_chunk_size,
                                   batch, dim).cuda().half()
        # retrieved tokens - (num chunks, num retrieved neighbors, retrieved chunk with continuation, batch, hidden)

        # need to add extra chunk size, since it will be shifted
        cross_attn_q_pos_emb = rotary_pos_emb(text_chunk_size +
                                              text_chunk_size - 1,
                                              offset=-text_chunk_size + 1)
        cross_attn_k_pos_emb = rotary_pos_emb(context_chunk_size)
        cross_attn_pos_emb = (cross_attn_q_pos_emb, cross_attn_k_pos_emb)
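        # Query rotary positions run from -(chunk_size - 1) to chunk_size - 1
        # (accounting for the causal shift); key positions cover the retrieved
        # chunk plus its continuation (2 * chunk_size positions).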

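        # Build the per-chunk cross-attention mask: drop the first chunk_size - 1
        # positions (the causal shift), pad the token mask to a whole number of
        # chunks, fold the chunk dimension into the batch, and combine it with
        # the flattened neighbor mask into a [b * k, 1, sq, sk] padding mask.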
        def get_attn_mask_3d(hidden_mask, context_mask, chunks):
            causal_padding = text_chunk_size - 1
            remainder = (text_chunk_size -
                         (hidden_mask.shape[0] + 1)) % text_chunk_size
            hidden_mask = F.pad(hidden_mask, (0, 0, -causal_padding, remainder),
                                value=False)

            dec_attn_mask = rearrange(hidden_mask,
                                      '(k n) b -> (b k) n',
                                      k=chunks)
            context_attn_mask = rearrange(context_mask,
                                          'k r n b -> (b k) (r n)')
            enc_dec_attn_mask_3d = build_attention_mask_3d(
                source_mask=dec_attn_mask,
                target_mask=context_attn_mask,
                attn_mask_type=AttnMaskType.padding,
            )
            enc_dec_attn_mask_3d = enc_dec_attn_mask_3d[:, None, :, :]
            return enc_dec_attn_mask_3d

        enc_dec_attn_mask_3d = get_attn_mask_3d(hidden_mask, context_mask,
                                                chunks)

        init_method = init_method_normal(init_method_std)

        scaled_init_method = scaled_init_method_normal(init_method_std,
                                                       num_layers)
        cross_attn = (ParallelChunkedCrossAttention(
            init_method=init_method,
            output_layer_init_method=scaled_init_method,
            layer_number=1,
            num_attention_heads=num_attention_heads,
            hidden_size=dim,
            precision=16,
            chunk_size=text_chunk_size,
        ).cuda().half())

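        # Full-sequence chunked cross-attention; its output is the reference for
        # the cached single-token steps below.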
        out, bias = cross_attn(hidden_emb,
                               enc_dec_attn_mask_3d,
                               encoder_output=retrieved_emb,
                               rotary_pos_emb=cross_attn_pos_emb)
        assert out.shape == torch.Size([input_length, batch, dim])
        assert bias.shape == torch.Size([dim])

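        # Before the first chunk boundary there is nothing retrieved to attend
        # to, so the chunked cross-attention is expected to return zeros.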
        attn_mask_3d = None

        out_1, b = cross_attn(
            hidden_emb[:62],
            attn_mask_3d,
            encoder_output=None,
            rotary_pos_emb=cross_attn_pos_emb,
            set_inference_key_value_memory=True,
            inference_max_sequence_len=input_length,
        )
        assert (out_1 - torch.zeros_like(hidden_emb[:62])).abs().max() == 0
        out_1, b = cross_attn(
            hidden_emb[62:63],
            attn_mask_3d,
            encoder_output=None,
            rotary_pos_emb=cross_attn_pos_emb,
            set_inference_key_value_memory=False,
            inference_max_sequence_len=input_length,
        )
        assert (out_1 - torch.zeros_like(hidden_emb[62:63])).abs().max() == 0

        attn_mask_3d = get_attn_mask_3d(hidden_mask[:64], context_mask[:1], 1)
        out_2, b = cross_attn(
            hidden_emb[63:64],
            attn_mask_3d,
            encoder_output=retrieved_emb[:1],
            rotary_pos_emb=cross_attn_pos_emb,
            set_inference_key_value_memory=False,
            inference_max_sequence_len=input_length,
        )
        assert (out[63] - out_2[0]).abs().max().item() < 1e-2

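        # Step through the second chunk token by token, cross-attending to the
        # single completed retrieved chunk while the key/value cache accumulates.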
        for i in range(64, 127):
            attn_mask_3d = get_attn_mask_3d(hidden_mask[:i + 1],
                                            context_mask[:1], 1)
            out_2, b = cross_attn(
                hidden_emb[i:i + 1],
                attn_mask_3d,
                encoder_output=retrieved_emb[:1],
                rotary_pos_emb=cross_attn_pos_emb,
                set_inference_key_value_memory=False,
                inference_max_sequence_len=input_length,
            )
        i = 127
        attn_mask_3d = get_attn_mask_3d(hidden_mask[:i + 1], context_mask[:2],
                                        2)
        out_3, b = cross_attn(
            hidden_emb[i:i + 1],
            attn_mask_3d,
            encoder_output=retrieved_emb[:2],
            rotary_pos_emb=cross_attn_pos_emb,
            set_inference_key_value_memory=False,
            inference_max_sequence_len=input_length,
        )
        assert (out[i] - out_3[0]).abs().max().item() < 1e-2

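        # Re-prime the cache with a 130-token prompt and two retrieved chunks,
        # then continue stepping a single token at a time.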
        attn_mask_3d = get_attn_mask_3d(hidden_mask[:130], context_mask[:2], 2)

        out_1, b = cross_attn(
            hidden_emb[:130],
            attn_mask_3d,
            encoder_output=retrieved_emb[:2],
            rotary_pos_emb=cross_attn_pos_emb,
            set_inference_key_value_memory=True,
            inference_max_sequence_len=input_length,
        )

        assert (out[:130] - out_1[:130]).abs().max().item() < 1e-2

        for i in range(130, 191):
            attn_mask_3d = get_attn_mask_3d(hidden_mask[:i + 1],
                                            context_mask[:2], 2)
            out_2, b = cross_attn(
                hidden_emb[i:i + 1],
                attn_mask_3d,
                encoder_output=retrieved_emb[:2],
                rotary_pos_emb=cross_attn_pos_emb,
                set_inference_key_value_memory=False,
                inference_max_sequence_len=input_length,
            )
        i = 191
        attn_mask_3d = get_attn_mask_3d(hidden_mask[:i + 1], context_mask[:3],
                                        3)
        out_4, b = cross_attn(
            hidden_emb[i:i + 1],
            attn_mask_3d,
            encoder_output=retrieved_emb[:3],
            rotary_pos_emb=cross_attn_pos_emb,
            set_inference_key_value_memory=False,
            inference_max_sequence_len=input_length,
        )
        assert (out[i] - out_4[0]).abs().max().item() < 1e-2
Example No. 16
    def __init__(
        self,
        encoder_arch,
        decoder_arch,
        vocab_size,
        hidden_size,
        max_position_embeddings,
        num_layers,
        num_attention_heads,
        ffn_hidden_size,
        apply_query_key_layer_scaling=True,
        kv_channels=None,
        num_tokentypes=0,
        parallel_output=True,
        pre_process=True,
        post_process=True,
        init_method_std=0.02,
        fp16_cross_entropy=False,
        use_cpu_initialization=False,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        precision=16,
        fp32_residual_connection=False,
        activations_checkpoint_method=None,
        activations_checkpoint_num_layers=1,
        layernorm_epsilon=1e-5,
        persist_layer_norm=False,
        bias_gelu_fusion=True,
        masked_softmax_fusion=True,
        openai_gelu=False,
        activation='gelu',
        onnx_safe=False,
        hidden_steps=-1,
        hidden_blocks=1,
    ):
        super(MegatronTokenLevelEncoderDecoderModule, self).__init__()

        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_cross_entropy = fp16_cross_entropy
        self.precision = precision

        if kv_channels is None:
            assert (
                hidden_size % num_attention_heads == 0
            ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
            kv_channels = hidden_size // num_attention_heads
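            # e.g. hidden_size=1024 with num_attention_heads=16 gives
            # kv_channels=64 per attention head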

        # TODO: add get_embedding function to support various embedders (like prompt tuning)
        self.encoder_embedding = Embedding(
            hidden_size=hidden_size,
            vocab_size=vocab_size,
            max_sequence_length=max_position_embeddings,
            init_method=init_method_normal(init_method_std),
            num_tokentypes=num_tokentypes,
            use_cpu_initialization=use_cpu_initialization,
            embedding_dropout_prob=hidden_dropout,
        )
        self.decoder_embedding = self.encoder_embedding
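        # The decoder reuses the encoder's token/position embedding table
        # (tied weights), so only one Embedding module is instantiated.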
        self._encoder_embedding_key = "encoder_embedding"
        self._decoder_embedding_key = "decoder_embedding"

        encoder = get_encoder_model(
            arch=encoder_arch,
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            init_method=init_method_normal(init_method_std),
            scaled_init_method=scaled_init_method_normal(
                init_method_std, num_layers),
            encoder_attn_mask_type=AttnMaskType.padding,
            pre_process=pre_process,
            post_process=post_process,
            init_method_std=init_method_std,
            use_cpu_initialization=use_cpu_initialization,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
            hidden_steps=hidden_steps,
            hidden_blocks=hidden_blocks,
            activation=activation,
        )

        decoder = get_decoder_model(
            arch=decoder_arch,
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            init_method=init_method_normal(init_method_std),
            scaled_init_method=scaled_init_method_normal(
                init_method_std, num_layers),
            decoder_attn_mask_type=AttnMaskType.causal,
            pre_process=pre_process,
            post_process=post_process,
            init_method_std=init_method_std,
            use_cpu_initialization=use_cpu_initialization,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
            hidden_steps=hidden_steps,
            hidden_blocks=hidden_blocks,
            activation=activation,
        )

        self.enc_dec_model = MegatronTransformerEncoderDecoderModule(
            encoder=encoder,
            decoder=decoder,
        )
        self._enc_dec_model_key = "enc_dec_model"

        self.tokens_head = MegatronTokenLevelHead(
            self.decoder_embedding.word_embeddings.weight.size(0),
            parallel_output)
        self._tokens_head_key = 'tokens_head'