def __init__(
    self,
    encoder_arch,
    decoder_arch,
    vocab_size,
    hidden_size,
    max_position_embeddings,
    num_layers,
    num_attention_heads,
    ffn_hidden_size,
    apply_query_key_layer_scaling=True,
    kv_channels=None,
    num_tokentypes=0,
    parallel_output=True,
    pre_process=True,
    post_process=True,
    init_method_std=0.02,
    fp16_cross_entropy=False,
    use_cpu_initialization=False,
    hidden_dropout=0.1,
    attention_dropout=0.1,
    position_embedding_type='learned_absolute',
    relative_attention_num_buckets=32,
    relative_attention_max_distance=128,
    precision=16,
    fp32_residual_connection=False,
    activations_checkpoint_method=None,
    activations_checkpoint_num_layers=1,
    layernorm_epsilon=1e-5,
    persist_layer_norm=False,
    bias_gelu_fusion=True,
    bias_dropout_add_fusion=True,
    masked_softmax_fusion=True,
    openai_gelu=False,
    activation='gelu',
    onnx_safe=False,
    bias=True,
    normalization='layernorm',
    transformer_block_type='pre_ln',
    hidden_steps=-1,
    hidden_blocks=1,
    headscale=False,
    add_encoder=True,
    add_decoder=True,
):
    super(MegatronTokenLevelEncoderDecoderModule, self).__init__()

    self.parallel_output = parallel_output
    self.pre_process = pre_process
    self.post_process = post_process
    self.fp16_cross_entropy = fp16_cross_entropy
    self.precision = precision
    self.add_encoder = add_encoder
    self.add_decoder = add_decoder
    self.normalization = normalization
    self.position_embedding_type = position_embedding_type
    self.relative_attention_num_buckets = relative_attention_num_buckets
    self.relative_attention_max_distance = relative_attention_max_distance

    if self.position_embedding_type == 'learned_absolute':
        add_position_embedding = True
    elif self.position_embedding_type == 'relative':
        add_position_embedding = False
    else:
        raise ValueError('Unknown position embedding type. Options: [learned_absolute | relative]')

    if kv_channels is None:
        assert (
            hidden_size % num_attention_heads == 0
        ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
        kv_channels = hidden_size // num_attention_heads

    encoder, decoder = None, None

    if add_encoder:
        if pre_process:
            self.encoder_embedding = Embedding(
                hidden_size=hidden_size,
                vocab_size=vocab_size,
                max_sequence_length=max_position_embeddings,
                init_method=init_method_normal(init_method_std),
                num_tokentypes=num_tokentypes,
                use_cpu_initialization=use_cpu_initialization,
                embedding_dropout_prob=hidden_dropout,
                add_position_embedding=add_position_embedding,
            )
            self._encoder_embedding_key = "encoder_embedding"

        encoder = get_encoder_model(
            arch=encoder_arch,
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            init_method=init_method_normal(init_method_std),
            scaled_init_method=scaled_init_method_normal(init_method_std, num_layers),
            encoder_attn_mask_type=AttnMaskType.padding,
            pre_process=pre_process,
            post_process=post_process,
            init_method_std=init_method_std,
            use_cpu_initialization=use_cpu_initialization,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            position_embedding_type=position_embedding_type,
            relative_attention_num_buckets=relative_attention_num_buckets,
            relative_attention_max_distance=relative_attention_max_distance,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            bias_dropout_add_fusion=bias_dropout_add_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
            hidden_steps=hidden_steps,
            hidden_blocks=hidden_blocks,
            activation=activation,
            bias=bias,
            normalization=normalization,
            transformer_block_type=transformer_block_type,
            headscale=headscale,
            parent_model_type=ModelType.encoder_and_decoder,
        )

    if add_decoder:
        # If this is the first stage of the decoder
        if pre_process:
            # If the encoder also lies on this rank (PP = 1), then just assign the embeddings directly.
            if hasattr(self, 'encoder_embedding'):
                self.decoder_embedding = self.encoder_embedding
            else:
                # This is the case where PP > 1 and this rank holds the first decoder stage.
                # We initialize decoder embeddings but set them to zero, since they're tied with
                # the encoder embeddings: the initialize_word_embeddings call below synchronizes them.
                self.decoder_embedding = Embedding(
                    hidden_size=hidden_size,
                    vocab_size=vocab_size,
                    max_sequence_length=max_position_embeddings,
                    init_method=init_method_normal(init_method_std),
                    num_tokentypes=num_tokentypes,
                    use_cpu_initialization=use_cpu_initialization,
                    embedding_dropout_prob=hidden_dropout,
                    add_position_embedding=add_position_embedding,
                )
                self.decoder_embedding.zero_parameters()

            self._decoder_embedding_key = "decoder_embedding"

        decoder = get_decoder_model(
            arch=decoder_arch,
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            init_method=init_method_normal(init_method_std),
            scaled_init_method=scaled_init_method_normal(init_method_std, num_layers),
            decoder_attn_mask_type=AttnMaskType.causal,
            pre_process=pre_process,
            post_process=post_process,
            init_method_std=init_method_std,
            use_cpu_initialization=use_cpu_initialization,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            position_embedding_type=position_embedding_type,
            relative_attention_num_buckets=relative_attention_num_buckets,
            relative_attention_max_distance=relative_attention_max_distance,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            bias_dropout_add_fusion=bias_dropout_add_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
            hidden_steps=hidden_steps,
            hidden_blocks=hidden_blocks,
            activation=activation,
            bias=bias,
            normalization=normalization,
            transformer_block_type=transformer_block_type,
            headscale=headscale,
            parent_model_type=ModelType.encoder_and_decoder,
        )

    self.enc_dec_model = MegatronTransformerEncoderDecoderModule(encoder=encoder, decoder=decoder)
    self._enc_dec_model_key = "enc_dec_model"

    self.initialize_word_embeddings(
        init_method=init_method_normal(init_method_std), vocab_size=vocab_size, hidden_size=hidden_size
    )

    if add_decoder and post_process:
        self.tokens_head = MegatronTokenLevelHead(self.word_embeddings_weight().size(0), parallel_output)
        self._tokens_head_key = 'tokens_head'
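# Illustrative usage sketch (not part of the original file): constructing a small
# T5-style module on a single pipeline stage. All argument values below are
# assumptions chosen for illustration; encoder_arch/decoder_arch must be names
# understood by get_encoder_model/get_decoder_model (e.g. "transformer").
#
#   model = MegatronTokenLevelEncoderDecoderModule(
#       encoder_arch='transformer',
#       decoder_arch='transformer',
#       vocab_size=29184,
#       hidden_size=768,
#       max_position_embeddings=512,
#       num_layers=12,
#       num_attention_heads=12,
#       ffn_hidden_size=3072,
#       position_embedding_type='learned_absolute',
#   )
#
# With the defaults pre_process=True and post_process=True (i.e. PP = 1), the
# decoder embedding is the same object as the encoder embedding, and kv_channels
# falls back to hidden_size // num_attention_heads = 768 // 12 = 64.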
def __init__(
    self,
    vocab_size,
    hidden_size,
    max_position_embeddings,
    num_attention_heads,
    ffn_hidden_size,
    apply_query_key_layer_scaling=True,
    kv_channels=None,
    num_tokentypes=0,
    parallel_output=True,
    pre_process=True,
    post_process=True,
    init_method_std=0.02,
    fp16_cross_entropy=False,
    use_cpu_initialization=False,
    hidden_dropout=0.1,
    attention_dropout=0.1,
    precision=16,
    fp32_residual_connection=False,
    activations_checkpoint_method=None,
    activations_checkpoint_num_layers=1,
    layernorm_epsilon=1e-5,
    persist_layer_norm=False,
    bias_gelu_fusion=True,
    bias_dropout_add_fusion=True,
    masked_softmax_fusion=True,
    openai_gelu=False,
    activation='gelu',
    onnx_safe=False,
    bias=True,
    normalization='layernorm',
    headscale=False,
    transformer_block_type='pre_ln',
    hidden_steps=-1,
    hidden_blocks=1,
    add_encoder=True,
    add_decoder=True,
    chunk_size=64,
    enc_num_layers=4,  # total number of encoder layers
    dec_num_layers=6,  # total number of decoder layers
    enc_cross_attention=[3],  # layer numbers for cross attention
    dec_cross_attention=[3, 5],  # layer numbers for chunked cross attention
    add_position_embedding=False,
    tokenizer=None,
):
    super(MegatronRetrievalTokenLevelEncoderDecoderModule, self).__init__()

    self.parallel_output = parallel_output
    self.pre_process = pre_process
    self.post_process = post_process
    self.fp16_cross_entropy = fp16_cross_entropy
    self.precision = precision
    self.add_encoder = add_encoder
    self.add_decoder = add_decoder
    self.add_abs_position_embedding = add_position_embedding  # whether to use absolute position embeddings
    self.tokenizer = tokenizer
    self.eod_id = tokenizer.eos_id

    if kv_channels is None:
        assert (
            hidden_size % num_attention_heads == 0
        ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
        kv_channels = hidden_size // num_attention_heads

    if pre_process:
        self.encoder_embedding = Embedding(
            hidden_size=hidden_size,
            vocab_size=vocab_size,
            max_sequence_length=max_position_embeddings,
            init_method=init_method_normal(init_method_std),
            num_tokentypes=num_tokentypes,
            use_cpu_initialization=use_cpu_initialization,
            embedding_dropout_prob=hidden_dropout,
            add_position_embedding=add_position_embedding,
        )
        self._embedding_key = "embedding"

    if add_encoder:
        enc_layer_types = []
        for i in range(enc_num_layers):
            if i in enc_cross_attention:
                enc_layer_types.append(LayerType.retrieval_encoder)
            else:
                enc_layer_types.append(LayerType.encoder)

        self.encoder = get_encoder_model(
            arch="retro",
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            num_layers=enc_num_layers,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            init_method=init_method_normal(init_method_std),
            # the encoder is not independent of the decoder, so use the decoder's number of layers
            scaled_init_method=scaled_init_method_normal(init_method_std, dec_num_layers),
            pre_process=pre_process,
            post_process=post_process,
            init_method_std=init_method_std,
            use_cpu_initialization=use_cpu_initialization,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            bias_dropout_add_fusion=bias_dropout_add_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
            hidden_steps=hidden_steps,
            hidden_blocks=hidden_blocks,
            activation=activation,
            bias=bias,
            normalization=normalization,
            transformer_block_type=transformer_block_type,
            headscale=headscale,
            parent_model_type=ModelType.encoder_and_decoder,
            layer_type=enc_layer_types,
            chunk_size=chunk_size,
            layer_number_offset=0,
        )
        self._encoder_key = "encoder"

    if add_decoder:
        pre_decoder_num_layers = min(dec_cross_attention)
        pre_decoder_layer_types = []
        for i in range(pre_decoder_num_layers):
            pre_decoder_layer_types.append(LayerType.encoder)
        pre_decoder_layer_types.append(LayerType.decoder_pre_mlp)

        post_decoder_num_layers = dec_num_layers - pre_decoder_num_layers
        post_decoder_layer_types = []
        # the first layer in the post decoder has to be chunked cross attention without self attention
        assert pre_decoder_num_layers in dec_cross_attention
        for i in range(post_decoder_num_layers):
            if i == 0:
                post_decoder_layer_types.append(LayerType.retrieval_decoder_after_self_attn)
            elif i + pre_decoder_num_layers in dec_cross_attention:
                post_decoder_layer_types.append(LayerType.retrieval_decoder)
            else:
                post_decoder_layer_types.append(LayerType.encoder)

        # processes the inputs that the encoder uses as context (H in the paper)
        self.pre_decoder = get_decoder_model(
            arch="retro",
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            num_layers=pre_decoder_num_layers + 1,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            init_method=init_method_normal(init_method_std),
            scaled_init_method=scaled_init_method_normal(init_method_std, dec_num_layers),
            pre_process=pre_process,
            post_process=False,  # no need for post process
            init_method_std=init_method_std,
            use_cpu_initialization=use_cpu_initialization,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            bias_dropout_add_fusion=bias_dropout_add_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
            hidden_steps=hidden_steps,
            hidden_blocks=hidden_blocks,
            activation=activation,
            bias=bias,
            normalization=normalization,
            transformer_block_type=transformer_block_type,
            headscale=headscale,
            parent_model_type=ModelType.encoder_and_decoder,
            layer_type=pre_decoder_layer_types,
            chunk_size=chunk_size,
            layer_number_offset=0,
        )

        # this is where the chunked cross attention happens
        self.post_decoder = get_decoder_model(
            arch="retro",
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            num_layers=post_decoder_num_layers,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            init_method=init_method_normal(init_method_std),
            scaled_init_method=scaled_init_method_normal(init_method_std, dec_num_layers),
            pre_process=False,  # directly takes the pre_decoder output, skipping preprocessing
            post_process=post_process,
            init_method_std=init_method_std,
            use_cpu_initialization=use_cpu_initialization,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            bias_dropout_add_fusion=bias_dropout_add_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
            hidden_steps=hidden_steps,
            hidden_blocks=hidden_blocks,
            activation=activation,
            bias=bias,
            normalization=normalization,
            headscale=headscale,
            transformer_block_type=transformer_block_type,
            parent_model_type=ModelType.encoder_and_decoder,
            layer_type=post_decoder_layer_types,
            chunk_size=chunk_size,
            layer_number_offset=pre_decoder_num_layers + 1,
        )
        self._pre_decoder_key = "pre_decoder"
        self._post_decoder_key = "post_decoder"

    self.initialize_word_embeddings(
        init_method=init_method_normal(init_method_std), vocab_size=vocab_size, hidden_size=hidden_size
    )

    if add_decoder and post_process:
        self.tokens_head = MegatronTokenLevelHead(self.word_embeddings_weight().size(0), parallel_output)
        self._tokens_head_key = 'tokens_head'
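# Worked example (illustrative, not part of the original file) of how the RETRO
# decoder is split around the chunked cross attention, using the constructor
# defaults enc_num_layers=4, enc_cross_attention=[3], dec_num_layers=6,
# dec_cross_attention=[3, 5]:
#
#   enc_layer_types          = [encoder, encoder, encoder, retrieval_encoder]
#   pre_decoder_num_layers   = min([3, 5]) = 3
#   pre_decoder_layer_types  = [encoder, encoder, encoder, decoder_pre_mlp]  # 3 + 1 layers
#   post_decoder_num_layers  = 6 - 3 = 3
#   post_decoder_layer_types = [
#       retrieval_decoder_after_self_attn,  # i = 0: first chunked-cross-attention layer
#       encoder,                            # i = 1: layer 4, no cross attention
#       retrieval_decoder,                  # i = 2: layer 5 is in dec_cross_attention
#   ]
#
# so self.pre_decoder covers layers 0..3 (layer_number_offset=0) and
# self.post_decoder continues from layer number pre_decoder_num_layers + 1 = 4.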
def __init__(
    self,
    encoder_arch,
    decoder_arch,
    vocab_size,
    hidden_size,
    max_position_embeddings,
    num_layers,
    num_attention_heads,
    ffn_hidden_size,
    apply_query_key_layer_scaling=True,
    kv_channels=None,
    num_tokentypes=0,
    parallel_output=True,
    pre_process=True,
    post_process=True,
    init_method_std=0.02,
    fp16_cross_entropy=False,
    use_cpu_initialization=False,
    hidden_dropout=0.1,
    attention_dropout=0.1,
    precision=16,
    fp32_residual_connection=False,
    activations_checkpoint_method=None,
    activations_checkpoint_num_layers=1,
    layernorm_epsilon=1e-5,
    persist_layer_norm=False,
    bias_gelu_fusion=True,
    masked_softmax_fusion=True,
    openai_gelu=False,
    activation='gelu',
    onnx_safe=False,
    hidden_steps=-1,
    hidden_blocks=1,
):
    super(MegatronTokenLevelEncoderDecoderModule, self).__init__()

    self.parallel_output = parallel_output
    self.pre_process = pre_process
    self.post_process = post_process
    self.fp16_cross_entropy = fp16_cross_entropy
    self.precision = precision

    if kv_channels is None:
        assert (
            hidden_size % num_attention_heads == 0
        ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
        kv_channels = hidden_size // num_attention_heads

    # TODO: add get_embedding function to support various embedders (like prompt tuning)
    self.encoder_embedding = Embedding(
        hidden_size=hidden_size,
        vocab_size=vocab_size,
        max_sequence_length=max_position_embeddings,
        init_method=init_method_normal(init_method_std),
        num_tokentypes=num_tokentypes,
        use_cpu_initialization=use_cpu_initialization,
        embedding_dropout_prob=hidden_dropout,
    )
    self.decoder_embedding = self.encoder_embedding
    self._encoder_embedding_key = "encoder_embedding"
    self._decoder_embedding_key = "decoder_embedding"

    encoder = get_encoder_model(
        arch=encoder_arch,
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
        kv_channels=kv_channels,
        init_method=init_method_normal(init_method_std),
        scaled_init_method=scaled_init_method_normal(init_method_std, num_layers),
        encoder_attn_mask_type=AttnMaskType.padding,
        pre_process=pre_process,
        post_process=post_process,
        init_method_std=init_method_std,
        use_cpu_initialization=use_cpu_initialization,
        hidden_dropout=hidden_dropout,
        attention_dropout=attention_dropout,
        precision=precision,
        fp32_residual_connection=fp32_residual_connection,
        activations_checkpoint_method=activations_checkpoint_method,
        activations_checkpoint_num_layers=activations_checkpoint_num_layers,
        layernorm_epsilon=layernorm_epsilon,
        bias_gelu_fusion=bias_gelu_fusion,
        masked_softmax_fusion=masked_softmax_fusion,
        persist_layer_norm=persist_layer_norm,
        openai_gelu=openai_gelu,
        onnx_safe=onnx_safe,
        hidden_steps=hidden_steps,
        hidden_blocks=hidden_blocks,
        activation=activation,
    )

    decoder = get_decoder_model(
        arch=decoder_arch,
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
        kv_channels=kv_channels,
        init_method=init_method_normal(init_method_std),
        scaled_init_method=scaled_init_method_normal(init_method_std, num_layers),
        decoder_attn_mask_type=AttnMaskType.causal,
        pre_process=pre_process,
        post_process=post_process,
        init_method_std=init_method_std,
        use_cpu_initialization=use_cpu_initialization,
        hidden_dropout=hidden_dropout,
        attention_dropout=attention_dropout,
        precision=precision,
        fp32_residual_connection=fp32_residual_connection,
        activations_checkpoint_method=activations_checkpoint_method,
        activations_checkpoint_num_layers=activations_checkpoint_num_layers,
        layernorm_epsilon=layernorm_epsilon,
        bias_gelu_fusion=bias_gelu_fusion,
        masked_softmax_fusion=masked_softmax_fusion,
        persist_layer_norm=persist_layer_norm,
        openai_gelu=openai_gelu,
        onnx_safe=onnx_safe,
        hidden_steps=hidden_steps,
        hidden_blocks=hidden_blocks,
        activation=activation,
    )

    self.enc_dec_model = MegatronTransformerEncoderDecoderModule(encoder=encoder, decoder=decoder)
    self._enc_dec_model_key = "enc_dec_model"

    self.tokens_head = MegatronTokenLevelHead(
        self.decoder_embedding.word_embeddings.weight.size(0), parallel_output
    )
    self._tokens_head_key = 'tokens_head'
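# Illustrative note (not part of the original file): unlike the pipeline-aware
# variant above, this constructor always builds both encoder and decoder on the
# same rank and ties their embeddings unconditionally. Assuming a module `m`
# built with this constructor:
#
#   assert m.decoder_embedding is m.encoder_embedding  # one shared Embedding object
#   # tokens_head is sized from the shared word-embedding matrix, so its output
#   # dimension equals the word-embedding row count:
#   out_dim = m.decoder_embedding.word_embeddings.weight.size(0)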