    def __init__(
        self,
        encoder_arch,
        decoder_arch,
        vocab_size,
        hidden_size,
        max_position_embeddings,
        num_layers,
        num_attention_heads,
        ffn_hidden_size,
        apply_query_key_layer_scaling=True,
        kv_channels=None,
        num_tokentypes=0,
        parallel_output=True,
        pre_process=True,
        post_process=True,
        init_method_std=0.02,
        fp16_cross_entropy=False,
        use_cpu_initialization=False,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        position_embedding_type='learned_absolute',
        relative_attention_num_buckets=32,
        relative_attention_max_distance=128,
        precision=16,
        fp32_residual_connection=False,
        activations_checkpoint_method=None,
        activations_checkpoint_num_layers=1,
        layernorm_epsilon=1e-5,
        persist_layer_norm=False,
        bias_gelu_fusion=True,
        bias_dropout_add_fusion=True,
        masked_softmax_fusion=True,
        openai_gelu=False,
        activation='gelu',
        onnx_safe=False,
        bias=True,
        normalization='layernorm',
        transformer_block_type='pre_ln',
        hidden_steps=-1,
        hidden_blocks=1,
        headscale=False,
        add_encoder=True,
        add_decoder=True,
    ):
        super(MegatronTokenLevelEncoderDecoderModule, self).__init__()

        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_cross_entropy = fp16_cross_entropy
        self.precision = precision
        self.add_encoder = add_encoder
        self.add_decoder = add_decoder
        self.normalization = normalization
        self.position_embedding_type = position_embedding_type
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance

        if self.position_embedding_type == 'learned_absolute':
            add_position_embedding = True
        elif self.position_embedding_type == 'relative':
            add_position_embedding = False
        else:
            raise ValueError('Unknown position embedding type. Options: '
                             '[learned_absolute | relative]')

        if kv_channels is None:
            assert (
                hidden_size % num_attention_heads == 0
            ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
            kv_channels = hidden_size // num_attention_heads
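            # e.g. hidden_size=768 with num_attention_heads=12 gives kv_channels=64 per attention head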

        encoder, decoder = None, None
        if add_encoder:
            if pre_process:
                self.encoder_embedding = Embedding(
                    hidden_size=hidden_size,
                    vocab_size=vocab_size,
                    max_sequence_length=max_position_embeddings,
                    init_method=init_method_normal(init_method_std),
                    num_tokentypes=num_tokentypes,
                    use_cpu_initialization=use_cpu_initialization,
                    embedding_dropout_prob=hidden_dropout,
                    add_position_embedding=add_position_embedding,
                )
                self._encoder_embedding_key = "encoder_embedding"

            encoder = get_encoder_model(
                arch=encoder_arch,
                hidden_size=hidden_size,
                ffn_hidden_size=ffn_hidden_size,
                num_layers=num_layers,
                num_attention_heads=num_attention_heads,
                apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                kv_channels=kv_channels,
                init_method=init_method_normal(init_method_std),
                scaled_init_method=scaled_init_method_normal(
                    init_method_std, num_layers),
                encoder_attn_mask_type=AttnMaskType.padding,
                pre_process=pre_process,
                post_process=post_process,
                init_method_std=init_method_std,
                use_cpu_initialization=use_cpu_initialization,
                hidden_dropout=hidden_dropout,
                attention_dropout=attention_dropout,
                position_embedding_type=position_embedding_type,
                relative_attention_num_buckets=relative_attention_num_buckets,
                relative_attention_max_distance=relative_attention_max_distance,
                precision=precision,
                fp32_residual_connection=fp32_residual_connection,
                activations_checkpoint_method=activations_checkpoint_method,
                activations_checkpoint_num_layers=activations_checkpoint_num_layers,
                layernorm_epsilon=layernorm_epsilon,
                bias_gelu_fusion=bias_gelu_fusion,
                bias_dropout_add_fusion=bias_dropout_add_fusion,
                masked_softmax_fusion=masked_softmax_fusion,
                persist_layer_norm=persist_layer_norm,
                openai_gelu=openai_gelu,
                onnx_safe=onnx_safe,
                hidden_steps=hidden_steps,
                hidden_blocks=hidden_blocks,
                activation=activation,
                bias=bias,
                normalization=normalization,
                transformer_block_type=transformer_block_type,
                headscale=headscale,
                parent_model_type=ModelType.encoder_and_decoder,
            )

        if add_decoder:
            # If this is the first pipeline stage of the decoder
            if pre_process:
                # If the encoder also lies on this rank (PP = 1), then just assign embeddings directly.
                if hasattr(self, 'encoder_embedding'):
                    self.decoder_embedding = self.encoder_embedding
                else:
                    # This is the case where PP > 1 and this is the first stage of the decoder.
                    # We initialize the decoder embeddings but set them to zero, since they are tied to the encoder embeddings.
                    # The later initialize_word_embeddings call will synchronize the two.
                    self.decoder_embedding = Embedding(
                        hidden_size=hidden_size,
                        vocab_size=vocab_size,
                        max_sequence_length=max_position_embeddings,
                        init_method=init_method_normal(init_method_std),
                        num_tokentypes=num_tokentypes,
                        use_cpu_initialization=use_cpu_initialization,
                        embedding_dropout_prob=hidden_dropout,
                        add_position_embedding=add_position_embedding,
                    )
                    self.decoder_embedding.zero_parameters()

                self._decoder_embedding_key = "decoder_embedding"

            decoder = get_decoder_model(
                arch=decoder_arch,
                hidden_size=hidden_size,
                ffn_hidden_size=ffn_hidden_size,
                num_layers=num_layers,
                num_attention_heads=num_attention_heads,
                apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                kv_channels=kv_channels,
                init_method=init_method_normal(init_method_std),
                scaled_init_method=scaled_init_method_normal(
                    init_method_std, num_layers),
                decoder_attn_mask_type=AttnMaskType.causal,
                pre_process=pre_process,
                post_process=post_process,
                init_method_std=init_method_std,
                use_cpu_initialization=use_cpu_initialization,
                hidden_dropout=hidden_dropout,
                attention_dropout=attention_dropout,
                position_embedding_type=position_embedding_type,
                relative_attention_num_buckets=relative_attention_num_buckets,
                relative_attention_max_distance=relative_attention_max_distance,
                precision=precision,
                fp32_residual_connection=fp32_residual_connection,
                activations_checkpoint_method=activations_checkpoint_method,
                activations_checkpoint_num_layers=activations_checkpoint_num_layers,
                layernorm_epsilon=layernorm_epsilon,
                bias_gelu_fusion=bias_gelu_fusion,
                bias_dropout_add_fusion=bias_dropout_add_fusion,
                masked_softmax_fusion=masked_softmax_fusion,
                persist_layer_norm=persist_layer_norm,
                openai_gelu=openai_gelu,
                onnx_safe=onnx_safe,
                hidden_steps=hidden_steps,
                hidden_blocks=hidden_blocks,
                activation=activation,
                bias=bias,
                normalization=normalization,
                transformer_block_type=transformer_block_type,
                headscale=headscale,
                parent_model_type=ModelType.encoder_and_decoder,
            )

        self.enc_dec_model = MegatronTransformerEncoderDecoderModule(
            encoder=encoder, decoder=decoder)
        self._enc_dec_model_key = "enc_dec_model"

        self.initialize_word_embeddings(
            init_method=init_method_normal(init_method_std),
            vocab_size=vocab_size,
            hidden_size=hidden_size)

        if add_decoder and post_process:
            self.tokens_head = MegatronTokenLevelHead(
                self.word_embeddings_weight().size(0), parallel_output)
            self._tokens_head_key = 'tokens_head'
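
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original source): a minimal illustration of how
# the constructor arguments above fit together. All concrete values are
# hypothetical, and actually instantiating the module also requires Megatron's
# model-parallel state to be initialized first.
# ---------------------------------------------------------------------------
enc_dec_kwargs = dict(
    encoder_arch='transformer',
    decoder_arch='transformer',
    vocab_size=30592,            # hypothetical; typically padded for tensor parallelism
    hidden_size=768,
    max_position_embeddings=512,
    num_layers=12,
    num_attention_heads=12,
    ffn_hidden_size=3072,
    kv_channels=None,            # derived in __init__ as hidden_size // num_attention_heads
    pre_process=True,            # this rank holds the first pipeline stage
    post_process=True,           # this rank holds the last pipeline stage
    add_encoder=True,            # with PP > 1, encoder-only ranks would set add_decoder=False
    add_decoder=True,
)
# model = MegatronTokenLevelEncoderDecoderModule(**enc_dec_kwargs)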
Example #2
    def __init__(
        self,
        vocab_size,
        hidden_size,
        max_position_embeddings,
        num_attention_heads,
        ffn_hidden_size,
        apply_query_key_layer_scaling=True,
        kv_channels=None,
        num_tokentypes=0,
        parallel_output=True,
        pre_process=True,
        post_process=True,
        init_method_std=0.02,
        fp16_cross_entropy=False,
        use_cpu_initialization=False,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        precision=16,
        fp32_residual_connection=False,
        activations_checkpoint_method=None,
        activations_checkpoint_num_layers=1,
        layernorm_epsilon=1e-5,
        persist_layer_norm=False,
        bias_gelu_fusion=True,
        bias_dropout_add_fusion=True,
        masked_softmax_fusion=True,
        openai_gelu=False,
        activation='gelu',
        onnx_safe=False,
        bias=True,
        normalization='layernorm',
        headscale=False,
        transformer_block_type='pre_ln',
        hidden_steps=-1,
        hidden_blocks=1,
        add_encoder=True,
        add_decoder=True,
        chunk_size=64,
        enc_num_layers=4,  # total number of encoder layers
        dec_num_layers=6,  # total number of decoder layers
        enc_cross_attention=[3],  # indices of encoder layers that use cross attention
        dec_cross_attention=[3, 5],  # indices of decoder layers that use chunked cross attention
        add_position_embedding=False,
        tokenizer=None,  # tokenizer
    ):
        super(MegatronRetrievalTokenLevelEncoderDecoderModule, self).__init__()

        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_cross_entropy = fp16_cross_entropy
        self.precision = precision
        self.add_encoder = add_encoder
        self.add_decoder = add_decoder
        self.add_abs_position_embedding = add_position_embedding  # whether to use absolute position embeddings
        self.tokenizer = tokenizer
        self.eod_id = tokenizer.eos_id

        if kv_channels is None:
            assert (
                hidden_size % num_attention_heads == 0
            ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
            kv_channels = hidden_size // num_attention_heads

        if pre_process:
            self.encoder_embedding = Embedding(
                hidden_size=hidden_size,
                vocab_size=vocab_size,
                max_sequence_length=max_position_embeddings,
                init_method=init_method_normal(init_method_std),
                num_tokentypes=num_tokentypes,
                use_cpu_initialization=use_cpu_initialization,
                embedding_dropout_prob=hidden_dropout,
                add_position_embedding=add_position_embedding,
            )
            self._embedding_key = "embedding"

        if add_encoder:
            enc_layer_types = []
            for i in range(enc_num_layers):
                if i in enc_cross_attention:
                    enc_layer_types.append(LayerType.retrieval_encoder)
                else:
                    enc_layer_types.append(LayerType.encoder)
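            # With the defaults enc_num_layers=4 and enc_cross_attention=[3], this yields
            # [encoder, encoder, encoder, retrieval_encoder]: only the last encoder layer cross-attends.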
            self.encoder = get_encoder_model(
                arch="retro",
                hidden_size=hidden_size,
                ffn_hidden_size=ffn_hidden_size,
                num_layers=enc_num_layers,
                num_attention_heads=num_attention_heads,
                apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                kv_channels=kv_channels,
                init_method=init_method_normal(init_method_std),
                scaled_init_method=scaled_init_method_normal(
                    init_method_std, dec_num_layers
                ),  # the encoder is not independent of the decoder, so use the decoder's layer count
                pre_process=pre_process,
                post_process=post_process,
                init_method_std=init_method_std,
                use_cpu_initialization=use_cpu_initialization,
                hidden_dropout=hidden_dropout,
                attention_dropout=attention_dropout,
                precision=precision,
                fp32_residual_connection=fp32_residual_connection,
                activations_checkpoint_method=activations_checkpoint_method,
                activations_checkpoint_num_layers=activations_checkpoint_num_layers,
                layernorm_epsilon=layernorm_epsilon,
                bias_gelu_fusion=bias_gelu_fusion,
                bias_dropout_add_fusion=bias_dropout_add_fusion,
                masked_softmax_fusion=masked_softmax_fusion,
                persist_layer_norm=persist_layer_norm,
                openai_gelu=openai_gelu,
                onnx_safe=onnx_safe,
                hidden_steps=hidden_steps,
                hidden_blocks=hidden_blocks,
                activation=activation,
                bias=bias,
                normalization=normalization,
                transformer_block_type=transformer_block_type,
                headscale=headscale,
                parent_model_type=ModelType.encoder_and_decoder,
                layer_type=enc_layer_types,
                chunk_size=chunk_size,
                layer_number_offset=0,
            )
            self._encoder_key = "encoder"

        if add_decoder:
            pre_decoder_num_layers = min(dec_cross_attention)
            pre_decoder_layer_types = []
            for i in range(pre_decoder_num_layers):
                pre_decoder_layer_types.append(LayerType.encoder)
            pre_decoder_layer_types.append(LayerType.decoder_pre_mlp)

            post_decoder_num_layers = dec_num_layers - pre_decoder_num_layers
            post_decoder_layer_types = []
            # the first layer in the post-decoder has to be chunked cross-attention without self-attention
            assert pre_decoder_num_layers in dec_cross_attention
            for i in range(post_decoder_num_layers):
                if i == 0:
                    post_decoder_layer_types.append(LayerType.retrieval_decoder_after_self_attn)
                elif i + pre_decoder_num_layers in dec_cross_attention:
                    post_decoder_layer_types.append(LayerType.retrieval_decoder)
                else:
                    post_decoder_layer_types.append(LayerType.encoder)
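            # With the defaults dec_num_layers=6 and dec_cross_attention=[3, 5]:
            #   pre_decoder_layer_types  = [encoder, encoder, encoder, decoder_pre_mlp]
            #   post_decoder_layer_types = [retrieval_decoder_after_self_attn, encoder, retrieval_decoder]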

            # the pre-decoder processes the input that the encoder uses as context (H in the RETRO paper)
            self.pre_decoder = get_decoder_model(
                arch="retro",
                hidden_size=hidden_size,
                ffn_hidden_size=ffn_hidden_size,
                num_layers=pre_decoder_num_layers + 1,
                num_attention_heads=num_attention_heads,
                apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                kv_channels=kv_channels,
                init_method=init_method_normal(init_method_std),
                scaled_init_method=scaled_init_method_normal(init_method_std, dec_num_layers),
                pre_process=pre_process,
                post_process=False,  # no need for post process
                init_method_std=init_method_std,
                use_cpu_initialization=use_cpu_initialization,
                hidden_dropout=hidden_dropout,
                attention_dropout=attention_dropout,
                precision=precision,
                fp32_residual_connection=fp32_residual_connection,
                activations_checkpoint_method=activations_checkpoint_method,
                activations_checkpoint_num_layers=activations_checkpoint_num_layers,
                layernorm_epsilon=layernorm_epsilon,
                bias_gelu_fusion=bias_gelu_fusion,
                bias_dropout_add_fusion=bias_dropout_add_fusion,
                masked_softmax_fusion=masked_softmax_fusion,
                persist_layer_norm=persist_layer_norm,
                openai_gelu=openai_gelu,
                onnx_safe=onnx_safe,
                hidden_steps=hidden_steps,
                hidden_blocks=hidden_blocks,
                activation=activation,
                bias=bias,
                normalization=normalization,
                transformer_block_type=transformer_block_type,
                headscale=headscale,
                parent_model_type=ModelType.encoder_and_decoder,
                layer_type=pre_decoder_layer_types,
                chunk_size=chunk_size,
                layer_number_offset=0,
            )

            # the post-decoder is where the chunked cross-attention happens
            self.post_decoder = get_decoder_model(
                arch="retro",
                hidden_size=hidden_size,
                ffn_hidden_size=ffn_hidden_size,
                num_layers=post_decoder_num_layers,
                num_attention_heads=num_attention_heads,
                apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                kv_channels=kv_channels,
                init_method=init_method_normal(init_method_std),
                scaled_init_method=scaled_init_method_normal(init_method_std, dec_num_layers),
                pre_process=False,  # directly take the pre_decoder output, skip preprocess
                post_process=post_process,
                init_method_std=init_method_std,
                use_cpu_initialization=use_cpu_initialization,
                hidden_dropout=hidden_dropout,
                attention_dropout=attention_dropout,
                precision=precision,
                fp32_residual_connection=fp32_residual_connection,
                activations_checkpoint_method=activations_checkpoint_method,
                activations_checkpoint_num_layers=activations_checkpoint_num_layers,
                layernorm_epsilon=layernorm_epsilon,
                bias_gelu_fusion=bias_gelu_fusion,
                bias_dropout_add_fusion=bias_dropout_add_fusion,
                masked_softmax_fusion=masked_softmax_fusion,
                persist_layer_norm=persist_layer_norm,
                openai_gelu=openai_gelu,
                onnx_safe=onnx_safe,
                hidden_steps=hidden_steps,
                hidden_blocks=hidden_blocks,
                activation=activation,
                bias=bias,
                normalization=normalization,
                headscale=headscale,
                transformer_block_type=transformer_block_type,
                parent_model_type=ModelType.encoder_and_decoder,
                layer_type=post_decoder_layer_types,
                chunk_size=chunk_size,
                layer_number_offset=pre_decoder_num_layers + 1,
            )
            self._pre_decoder_key = "pre_decoder"
            self._post_decoder_key = "post_decoder"

        self.initialize_word_embeddings(
            init_method=init_method_normal(init_method_std), vocab_size=vocab_size, hidden_size=hidden_size
        )

        if add_decoder and post_process:
            self.tokens_head = MegatronTokenLevelHead(self.word_embeddings_weight().size(0), parallel_output)
            self._tokens_head_key = 'tokens_head'
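
# ---------------------------------------------------------------------------
# Sketch (not part of the original source): the pre/post decoder layer planning
# from the constructor above, restated as a small standalone helper. Plain
# strings stand in for the LayerType enum members so the sketch needs no NeMo
# imports.
# ---------------------------------------------------------------------------
def plan_retro_decoder_layers(dec_num_layers, dec_cross_attention):
    pre_num = min(dec_cross_attention)
    # pre-decoder: ordinary self-attention layers, then one layer split before its MLP
    pre_types = ['encoder'] * pre_num + ['decoder_pre_mlp']
    post_types = []
    for i in range(dec_num_layers - pre_num):
        if i == 0:
            # first post-decoder layer: chunked cross-attention without self-attention
            post_types.append('retrieval_decoder_after_self_attn')
        elif i + pre_num in dec_cross_attention:
            post_types.append('retrieval_decoder')
        else:
            post_types.append('encoder')
    return pre_types, post_types

# With the defaults above, plan_retro_decoder_layers(6, [3, 5]) returns
# (['encoder', 'encoder', 'encoder', 'decoder_pre_mlp'],
#  ['retrieval_decoder_after_self_attn', 'encoder', 'retrieval_decoder']).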
Example #3
    def __init__(
        self,
        encoder_arch,
        decoder_arch,
        vocab_size,
        hidden_size,
        max_position_embeddings,
        num_layers,
        num_attention_heads,
        ffn_hidden_size,
        apply_query_key_layer_scaling=True,
        kv_channels=None,
        num_tokentypes=0,
        parallel_output=True,
        pre_process=True,
        post_process=True,
        init_method_std=0.02,
        fp16_cross_entropy=False,
        use_cpu_initialization=False,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        precision=16,
        fp32_residual_connection=False,
        activations_checkpoint_method=None,
        activations_checkpoint_num_layers=1,
        layernorm_epsilon=1e-5,
        persist_layer_norm=False,
        bias_gelu_fusion=True,
        masked_softmax_fusion=True,
        openai_gelu=False,
        activation='gelu',
        onnx_safe=False,
        hidden_steps=-1,
        hidden_blocks=1,
    ):
        super(MegatronTokenLevelEncoderDecoderModule, self).__init__()

        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_cross_entropy = fp16_cross_entropy
        self.precision = precision

        if kv_channels is None:
            assert (
                hidden_size % num_attention_heads == 0
            ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
            kv_channels = hidden_size // num_attention_heads

        # TODO: add get_embedding function to support various embedders (like prompt tuning)
        self.encoder_embedding = Embedding(
            hidden_size=hidden_size,
            vocab_size=vocab_size,
            max_sequence_length=max_position_embeddings,
            init_method=init_method_normal(init_method_std),
            num_tokentypes=num_tokentypes,
            use_cpu_initialization=use_cpu_initialization,
            embedding_dropout_prob=hidden_dropout,
        )
        self.decoder_embedding = self.encoder_embedding
        self._encoder_embedding_key = "encoder_embedding"
        self._decoder_embedding_key = "decoder_embedding"

        encoder = get_encoder_model(
            arch=encoder_arch,
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            init_method=init_method_normal(init_method_std),
            scaled_init_method=scaled_init_method_normal(
                init_method_std, num_layers),
            encoder_attn_mask_type=AttnMaskType.padding,
            pre_process=pre_process,
            post_process=post_process,
            init_method_std=init_method_std,
            use_cpu_initialization=use_cpu_initialization,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
            hidden_steps=hidden_steps,
            hidden_blocks=hidden_blocks,
            activation=activation,
        )

        decoder = get_decoder_model(
            arch=decoder_arch,
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            init_method=init_method_normal(init_method_std),
            scaled_init_method=scaled_init_method_normal(
                init_method_std, num_layers),
            decoder_attn_mask_type=AttnMaskType.causal,
            pre_process=pre_process,
            post_process=post_process,
            init_method_std=init_method_std,
            use_cpu_initialization=use_cpu_initialization,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
            hidden_steps=hidden_steps,
            hidden_blocks=hidden_blocks,
            activation=activation,
        )

        self.enc_dec_model = MegatronTransformerEncoderDecoderModule(
            encoder=encoder,
            decoder=decoder,
        )
        self._enc_dec_model_key = "enc_dec_model"

        self.tokens_head = MegatronTokenLevelHead(
            self.decoder_embedding.word_embeddings.weight.size(0),
            parallel_output)
        self._tokens_head_key = 'tokens_head'
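
# ---------------------------------------------------------------------------
# Sketch (not part of the original source): the embedding/head weight-tying
# pattern used above (decoder_embedding = encoder_embedding, tokens_head sized
# from the shared word-embedding weight), illustrated with plain PyTorch rather
# than Megatron's parallel layers.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn


class TiedEmbeddingAndHead(nn.Module):
    """Input embedding and output projection share a single weight matrix."""

    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)

    def embed(self, token_ids):
        # [batch, seq] -> [batch, seq, hidden]
        return self.word_embeddings(token_ids)

    def logits(self, hidden_states):
        # Project hidden states back onto the vocabulary with the same weight:
        # [batch, seq, hidden] -> [batch, seq, vocab]
        return hidden_states @ self.word_embeddings.weight.t()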