def test_retrieval_decoder(self):
    init_method_std = 0.02
    # rotary pos emb dim
    batch = 2
    neighbors = 2
    dim = 128
    pad_id = 19999
    num_attention_heads = 8
    chunks = 32
    text_chunk_size = 64
    input_length = chunks * text_chunk_size
    vocab_size = 20000
    # rot_dim = dim // num_attention_heads
    # rotary_pos_emb = RotaryEmbedding(rot_dim).cuda().half()
    hidden = torch.randint(0, vocab_size, (batch, input_length)).cuda()  # (batch, seq)
    hidden_mask = (hidden != pad_id).cuda()
    hidden_emb = torch.rand(batch, input_length, dim).cuda().half()  # (batch, seq, dim)
    # context_chunk_size = 128
    retrieved = torch.randint(0, vocab_size, (batch, chunks, neighbors, 2 * text_chunk_size)).cuda()
    # retrieved tokens - (batch, num chunks, num retrieved neighbors, retrieved chunk with continuation)
    # context attention mask [b, np, sq, sk]
    pad_id = vocab_size - 1
    context_mask = (retrieved != pad_id).cuda()
    retrieved_emb = torch.rand(batch, chunks, neighbors, 2 * text_chunk_size, dim).cuda().half()
    # retrieved tokens - (batch, num chunks, num retrieved neighbors, retrieved chunk with continuation, hidden)
    layer_type = [LayerType.encoder, LayerType.retrieval_decoder, LayerType.encoder, LayerType.retrieval_decoder]
    num_layers = len(layer_type)

    init_method = init_method_normal(init_method_std)
    scaled_init_method = scaled_init_method_normal(init_method_std, num_layers)
    decoder = (
        MegatronRetrievalTransformerDecoderModule(
            init_method=init_method,
            output_layer_init_method=scaled_init_method,
            hidden_size=dim,
            ffn_hidden_size=dim * 4,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            precision=16,
            chunk_size=text_chunk_size,
            layer_type=layer_type,
        )
        .cuda()
        .half()
    )
    out = decoder(hidden_emb, hidden_mask, retrieved_attn_mask=context_mask, retrieved_emb=retrieved_emb)
    assert out.shape == torch.Size([batch, input_length, dim])
def test_retrieval_encoder(self):
    init_method_std = 0.02
    batch = 2
    neighbors = 2
    # rotary pos emb dim
    dim = 128
    pad_id = 19999
    num_attention_heads = 8
    chunks = 32
    text_chunk_size = 64
    input_length = chunks * text_chunk_size
    vocab_size = 20000

    hidden = torch.randint(0, vocab_size, (batch, input_length)).cuda()  # (batch, seq)
    hidden_mask = (hidden != pad_id).cuda()
    hidden_emb = torch.rand(batch, input_length, dim).cuda().half()  # (batch, seq, dim)
    retrieved = torch.randint(0, vocab_size, (batch, chunks, neighbors, 2 * text_chunk_size)).cuda()
    pad_id = vocab_size - 1
    context_mask = (retrieved != pad_id).cuda()
    retrieved_emb = torch.rand(batch, chunks, neighbors, 2 * text_chunk_size, dim).cuda().half()
    layer_type = [LayerType.encoder, LayerType.retrieval_encoder, LayerType.encoder, LayerType.retrieval_encoder]
    num_layers = len(layer_type)

    init_method = init_method_normal(init_method_std)
    scaled_init_method = scaled_init_method_normal(init_method_std, num_layers)
    encoder = (
        MegatronRetrievalTransformerEncoderModule(
            init_method=init_method,
            output_layer_init_method=scaled_init_method,
            hidden_size=dim,
            ffn_hidden_size=dim * 4,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            precision=16,
            chunk_size=text_chunk_size,
            layer_type=layer_type,
        )
        .cuda()
        .half()
    )
    out = encoder(retrieved_emb, context_mask, context_attn_mask=hidden_mask, encoder_output=hidden_emb)
    assert out.shape == torch.Size([batch, chunks, neighbors, 2 * text_chunk_size, dim])
def __init__(self, num_tokentypes=2, add_binary_head=True, parallel_output=True, pre_process=True, post_process=True):
    super(BertModel, self).__init__()
    args = get_args()

    self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
    self.add_binary_head = add_binary_head
    self.parallel_output = parallel_output
    self.pre_process = pre_process
    self.post_process = post_process

    init_method = init_method_normal(args.init_method_std)
    scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)

    self.language_model, self._language_model_key = get_language_model(
        num_tokentypes=num_tokentypes,
        add_pooler=self.add_binary_head,
        encoder_attn_mask_type=AttnMaskType.padding,
        init_method=init_method,
        scaled_init_method=scaled_init_method,
        pre_process=self.pre_process,
        post_process=self.post_process,
    )

    self.initialize_word_embeddings(init_method_normal)

    if self.post_process:
        self.lm_head = BertLMHead(
            self.word_embeddings_weight().size(0),
            args.hidden_size,
            init_method,
            args.layernorm_epsilon,
            parallel_output,
        )
        self._lm_head_key = 'lm_head'
        self.binary_head = None
        if self.add_binary_head:
            self.binary_head = get_linear_layer(args.hidden_size, 2, init_method)
            self._binary_head_key = 'binary_head'
def __init__(self, num_tokentypes=0, parallel_output=True):
    super(T5Model, self).__init__()
    args = get_args()

    self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
    self.parallel_output = parallel_output
    init_method = init_method_normal(args.init_method_std)
    scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)

    self.language_model, self._language_model_key = get_language_model(
        num_tokentypes=num_tokentypes,
        add_pooler=False,
        add_decoder=True,
        encoder_attn_mask_type=AttnMaskType.padding,
        init_method=init_method,
        scaled_init_method=scaled_init_method,
    )

    self.lm_head = T5LMHead(
        self.language_model.embedding.word_embeddings.weight.size(0), parallel_output
    )
    self._lm_head_key = 'lm_head'
def get_decoder_model(
    arch,
    hidden_size,
    ffn_hidden_size,
    num_layers,
    num_attention_heads,
    apply_query_key_layer_scaling=True,
    kv_channels=None,
    init_method=None,
    scaled_init_method=None,
    add_decoder=False,
    decoder_attn_mask_type=AttnMaskType.causal,
    pre_process=True,
    post_process=True,
    init_method_std=0.02,
    use_cpu_initialization=False,
    hidden_dropout=0.1,
    attention_dropout=0.1,
    position_embedding_type='learned_absolute',
    relative_attention_num_buckets=32,
    relative_attention_max_distance=128,
    precision=16,
    fp32_residual_connection=False,
    activations_checkpoint_method=None,
    activations_checkpoint_num_layers=1,
    layernorm_epsilon=1e-5,
    bias_gelu_fusion=True,
    bias_dropout_add_fusion=True,
    masked_softmax_fusion=True,
    persist_layer_norm=False,
    openai_gelu=False,
    activation="gelu",
    onnx_safe=False,
    bias=True,
    normalization="layernorm",
    headscale=False,
    transformer_block_type="pre_ln",
    hidden_steps=-1,
    hidden_blocks=1,
    parent_model_type=ModelType.encoder_or_decoder,
    layer_type=None,
    chunk_size=64,
    layer_number_offset=0,  # this is used only for attention norm_factor scaling
):
    """Build language model and return along with the key to save."""

    if kv_channels is None:
        assert (
            hidden_size % num_attention_heads == 0
        ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
        kv_channels = hidden_size // num_attention_heads

    if init_method is None:
        init_method = init_method_normal(init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(init_method_std, num_layers)

    if arch == "transformer":
        # Language model.
        decoder = MegatronTransformerDecoderModule(
            init_method=init_method,
            output_layer_init_method=scaled_init_method,
            hidden_size=hidden_size,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            ffn_hidden_size=ffn_hidden_size,
            decoder_attn_mask_type=decoder_attn_mask_type,
            pre_process=pre_process,
            post_process=post_process,
            use_cpu_initialization=use_cpu_initialization,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            position_embedding_type=position_embedding_type,
            relative_attention_num_buckets=relative_attention_num_buckets,
            relative_attention_max_distance=relative_attention_max_distance,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            bias_dropout_add_fusion=bias_dropout_add_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
            activation=activation,
            bias=bias,
            normalization=normalization,
            transformer_block_type=transformer_block_type,
            headscale=headscale,
            parent_model_type=parent_model_type,
        )
    elif arch == "retro":
        decoder = MegatronRetrievalTransformerDecoderModule(
            init_method=init_method,
            output_layer_init_method=scaled_init_method,
            hidden_size=hidden_size,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            layer_type=layer_type,
            ffn_hidden_size=ffn_hidden_size,
            pre_process=pre_process,
            post_process=post_process,
            use_cpu_initialization=use_cpu_initialization,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            bias_dropout_add_fusion=bias_dropout_add_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
            activation=activation,
            bias=bias,
            normalization=normalization,
            transformer_block_type=transformer_block_type,
            parent_model_type=parent_model_type,
            chunk_size=chunk_size,
            layer_number_offset=layer_number_offset,
        )
    else:
        raise ValueError(f"Unknown decoder arch = {arch}. Available decoder arch = {AVAILABLE_DECODERS}")

    return decoder
def __init__(
    self,
    vocab_size,
    hidden_size,
    max_position_embeddings,
    num_layers,
    num_attention_heads,
    ffn_hidden_size,
    apply_query_key_layer_scaling=True,
    kv_channels=None,
    num_tokentypes=0,
    parallel_output=True,
    pre_process=True,
    post_process=True,
    init_method_std=0.02,
    fp16_lm_cross_entropy=False,
    use_cpu_initialization=False,
    hidden_dropout=0.1,
    precision=16,
    fp32_residual_connection=False,
    activations_checkpoint_method=None,
    activations_checkpoint_num_layers=1,
    layernorm_epsilon=1e-5,
    bias_gelu_fusion=True,
    persist_layer_norm=False,
    openai_gelu=False,
    onnx_safe=False,
):
    super(GPTModel, self).__init__()

    self.parallel_output = parallel_output
    self.pre_process = pre_process
    self.post_process = post_process
    self.fp16_lm_cross_entropy = fp16_lm_cross_entropy

    if kv_channels is None:
        assert (
            hidden_size % num_attention_heads == 0
        ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
        kv_channels = hidden_size // num_attention_heads

    self.language_model, self._language_model_key = get_language_model(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        hidden_dropout=hidden_dropout,
        num_tokentypes=num_tokentypes,
        max_position_embeddings=max_position_embeddings,
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
        kv_channels=kv_channels,
        ffn_hidden_size=ffn_hidden_size,
        add_pooler=False,
        encoder_attn_mask_type=AttnMaskType.causal,
        init_method=init_method_normal(init_method_std),
        scaled_init_method=scaled_init_method_normal(init_method_std, num_layers),
        pre_process=self.pre_process,
        post_process=self.post_process,
        init_method_std=init_method_std,
        use_cpu_initialization=use_cpu_initialization,
        precision=precision,
        fp32_residual_connection=fp32_residual_connection,
        activations_checkpoint_method=activations_checkpoint_method,
        activations_checkpoint_num_layers=activations_checkpoint_num_layers,
        layernorm_epsilon=layernorm_epsilon,
        bias_gelu_fusion=bias_gelu_fusion,
        persist_layer_norm=persist_layer_norm,
        openai_gelu=openai_gelu,
        onnx_safe=onnx_safe,
    )

    self.initialize_word_embeddings(
        init_method=init_method_normal(init_method_std), vocab_size=vocab_size, hidden_size=hidden_size
    )
def __init__(
    self,
    vocab_size,
    hidden_size,
    max_position_embeddings,
    num_layers,
    num_attention_heads,
    ffn_hidden_size,
    apply_query_key_layer_scaling=True,
    kv_channels=None,
    num_tokentypes=0,
    parallel_output=True,
    pre_process=True,
    post_process=True,
    init_method_std=0.02,
    fp16_lm_cross_entropy=False,
    use_cpu_initialization=False,
    hidden_dropout=0.1,
    precision=16,
    fp32_residual_connection=False,
    activations_checkpoint_method=None,
    activations_checkpoint_num_layers=1,
    layernorm_epsilon=1e-5,
    bias_gelu_fusion=True,
    openai_gelu=False,
    onnx_safe=False,
    add_binary_head=True,
):
    super(BertModel, self).__init__()
    # args = get_args()

    self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
    self.add_binary_head = add_binary_head
    self.parallel_output = parallel_output
    self.pre_process = pre_process
    self.post_process = post_process

    init_method = init_method_normal(init_method_std)
    scaled_init_method = scaled_init_method_normal(init_method_std, num_layers)

    self.language_model, self._language_model_key = get_language_model(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        hidden_dropout=hidden_dropout,
        num_tokentypes=num_tokentypes,
        max_position_embeddings=max_position_embeddings,
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
        kv_channels=kv_channels,
        ffn_hidden_size=ffn_hidden_size,
        add_pooler=self.add_binary_head,
        encoder_attn_mask_type=AttnMaskType.padding,
        init_method=init_method,
        scaled_init_method=scaled_init_method,
        pre_process=self.pre_process,
        post_process=self.post_process,
        init_method_std=init_method_std,
        use_cpu_initialization=use_cpu_initialization,
        precision=precision,
        fp32_residual_connection=fp32_residual_connection,
        activations_checkpoint_method=activations_checkpoint_method,
        activations_checkpoint_num_layers=activations_checkpoint_num_layers,
        layernorm_epsilon=layernorm_epsilon,
        bias_gelu_fusion=bias_gelu_fusion,
        openai_gelu=openai_gelu,
        onnx_safe=onnx_safe,
    )

    self.initialize_word_embeddings(
        init_method=init_method_normal(init_method_std), vocab_size=vocab_size, hidden_size=hidden_size
    )

    if self.post_process:
        self.lm_head = BertLMHead(
            self.word_embeddings_weight().size(0),
            hidden_size,
            init_method,
            layernorm_epsilon,
            parallel_output,
            openai_gelu,
            onnx_safe,
        )
        self._lm_head_key = 'lm_head'
        self.binary_head = None
        if self.add_binary_head:
            self.binary_head = get_linear_layer(hidden_size, 2, init_method)
            self._binary_head_key = 'binary_head'
def test_cross_attn(self):
    num_layers = 1
    init_method_std = 0.02
    batch = 2
    neighbors = 2
    # rotary pos emb dim
    dim = 128
    pad_id = 19999
    num_attention_heads = 8
    chunks = 32
    text_chunk_size = 64
    context_chunk_size = 2 * text_chunk_size
    input_length = chunks * text_chunk_size
    vocab_size = 20000

    rot_dim = dim // num_attention_heads
    rotary_pos_emb = RotaryEmbedding(rot_dim).cuda().half()
    hidden = torch.randint(0, vocab_size, (input_length, batch)).cuda()  # (seq, batch)
    hidden_mask = (hidden != pad_id).cuda()
    hidden_emb = torch.rand(input_length, batch, dim).cuda().half()  # (seq, batch, dim)
    retrieved = torch.randint(0, vocab_size, (chunks, neighbors, context_chunk_size, batch)).cuda()
    # retrieved tokens - (num chunks, num retrieved neighbors, retrieved chunk with continuation, batch)
    # context attention mask [b, np, sq, sk]
    context_mask = (retrieved != pad_id).cuda()
    retrieved_emb = torch.rand(chunks, neighbors, context_chunk_size, batch, dim).cuda().half()
    # retrieved tokens - (num chunks, num retrieved neighbors, retrieved chunk with continuation, batch, hidden)

    # need to add extra chunk size, since it will be shifted
    cross_attn_q_pos_emb = rotary_pos_emb(text_chunk_size + text_chunk_size - 1, offset=0)
    cross_attn_k_pos_emb = rotary_pos_emb(context_chunk_size)
    cross_attn_pos_emb = (cross_attn_q_pos_emb, cross_attn_k_pos_emb)

    dec_attn_mask = rearrange(hidden_mask, '(k n) b -> (b k) n', k=chunks)
    context_attn_mask = rearrange(context_mask, 'k r n b -> (b k) (r n)')
    enc_dec_attn_mask_3d = build_attention_mask_3d(
        source_mask=dec_attn_mask, target_mask=context_attn_mask, attn_mask_type=AttnMaskType.padding,
    )
    enc_dec_attn_mask_3d = enc_dec_attn_mask_3d[:, None, :, :]

    init_method = init_method_normal(init_method_std)
    scaled_init_method = scaled_init_method_normal(init_method_std, num_layers)
    cross_attn = (
        ParallelChunkedCrossAttention(
            init_method=init_method,
            output_layer_init_method=scaled_init_method,
            layer_number=0,
            num_attention_heads=num_attention_heads,
            hidden_size=dim,
            precision=16,
            chunk_size=text_chunk_size,
        )
        .cuda()
        .half()
    )

    out, bias = cross_attn(
        hidden_emb, enc_dec_attn_mask_3d, encoder_output=retrieved_emb, rotary_pos_emb=cross_attn_pos_emb
    )
    assert out.shape == torch.Size([input_length, batch, dim])
    assert bias.shape == torch.Size([dim])
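# Illustrative sketch (not part of the original test): the mask bookkeeping in test_cross_attn can be
# checked with plain shape arithmetic, using only the constants defined there. The helper name below
# is hypothetical and exists only for this sketch.
def _chunked_cross_attn_mask_shapes(batch=2, chunks=32, neighbors=2, text_chunk_size=64):
    context_chunk_size = 2 * text_chunk_size
    input_length = chunks * text_chunk_size
    # '(k n) b -> (b k) n' folds the chunk index into the batch dimension of the decoder mask
    dec_attn_mask_shape = (batch * chunks, text_chunk_size)
    # 'k r n b -> (b k) (r n)' concatenates all neighbors of a chunk along the key axis
    context_attn_mask_shape = (batch * chunks, neighbors * context_chunk_size)
    # build_attention_mask_3d yields [b*k, sq, sk]; the [:, None, :, :] unsqueeze gives [b*k, 1, sq, sk]
    enc_dec_attn_mask_3d_shape = (batch * chunks, 1, text_chunk_size, neighbors * context_chunk_size)
    return input_length, dec_attn_mask_shape, context_attn_mask_shape, enc_dec_attn_mask_3d_shape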
def get_language_model(
    hidden_size,
    ffn_hidden_size,
    num_layers,
    max_position_embeddings,
    num_tokentypes,
    add_pooler,
    vocab_size,
    num_attention_heads,
    encoder_attn_mask_type,
    apply_query_key_layer_scaling=True,
    kv_channels=None,
    init_method=None,
    scaled_init_method=None,
    add_decoder=False,
    decoder_attn_mask_type=AttnMaskType.causal,
    pre_process=True,
    post_process=True,
    init_method_std=0.02,
    use_cpu_initialization=False,
    hidden_dropout=0.1,
    precision=16,
    fp32_residual_connection=False,
    activations_checkpoint_method=None,
    activations_checkpoint_num_layers=1,
    layernorm_epsilon=1e-5,
    bias_gelu_fusion=True,
    masked_softmax_fusion=True,
    persist_layer_norm=False,
    openai_gelu=False,
    onnx_safe=False,
    megatron_legacy=False,
):
    """Build language model and return along with the key to save."""

    if kv_channels is None:
        assert (
            hidden_size % num_attention_heads == 0
        ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
        kv_channels = hidden_size // num_attention_heads

    if init_method is None:
        init_method = init_method_normal(init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(init_method_std, num_layers)

    # Language model.
    language_model = TransformerLanguageModel(
        init_method=init_method,
        output_layer_init_method=scaled_init_method,
        encoder_attn_mask_type=encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        vocab_size=vocab_size,
        max_position_embeddings=max_position_embeddings,
        hidden_size=hidden_size,
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
        kv_channels=kv_channels,
        ffn_hidden_size=ffn_hidden_size,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process,
        use_cpu_initialization=use_cpu_initialization,
        hidden_dropout=hidden_dropout,
        precision=precision,
        fp32_residual_connection=fp32_residual_connection,
        activations_checkpoint_method=activations_checkpoint_method,
        activations_checkpoint_num_layers=activations_checkpoint_num_layers,
        layernorm_epsilon=layernorm_epsilon,
        bias_gelu_fusion=bias_gelu_fusion,
        masked_softmax_fusion=masked_softmax_fusion,
        persist_layer_norm=persist_layer_norm,
        openai_gelu=openai_gelu,
        onnx_safe=onnx_safe,
        megatron_legacy=megatron_legacy,
    )
    # key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key
def __init__(
    self,
    vocab_size,
    hidden_size,
    max_position_embeddings,
    num_attention_heads,
    ffn_hidden_size,
    apply_query_key_layer_scaling=True,
    kv_channels=None,
    num_tokentypes=0,
    parallel_output=True,
    pre_process=True,
    post_process=True,
    init_method_std=0.02,
    fp16_cross_entropy=False,
    use_cpu_initialization=False,
    hidden_dropout=0.1,
    attention_dropout=0.1,
    precision=16,
    fp32_residual_connection=False,
    activations_checkpoint_method=None,
    activations_checkpoint_num_layers=1,
    layernorm_epsilon=1e-5,
    persist_layer_norm=False,
    bias_gelu_fusion=True,
    bias_dropout_add_fusion=True,
    masked_softmax_fusion=True,
    openai_gelu=False,
    activation='gelu',
    onnx_safe=False,
    bias=True,
    normalization='layernorm',
    headscale=False,
    transformer_block_type='pre_ln',
    hidden_steps=-1,
    hidden_blocks=1,
    add_encoder=True,
    add_decoder=True,
    chunk_size=64,
    enc_num_layers=4,  # total number of encoder layers
    dec_num_layers=6,  # total number of decoder layers
    enc_cross_attention=[3],  # layer numbers for cross attention
    dec_cross_attention=[3, 5],  # layer numbers for chunked cross attention
    add_position_embedding=False,
    tokenizer=None,  # tokenizer
):
    super(MegatronRetrievalTokenLevelEncoderDecoderModule, self).__init__()

    self.parallel_output = parallel_output
    self.pre_process = pre_process
    self.post_process = post_process
    self.fp16_cross_entropy = fp16_cross_entropy
    self.precision = precision
    self.add_encoder = add_encoder
    self.add_decoder = add_decoder
    self.add_abs_position_embedding = add_position_embedding  # whether to use absolute position embedding
    self.tokenizer = tokenizer
    self.eod_id = tokenizer.eos_id

    if kv_channels is None:
        assert (
            hidden_size % num_attention_heads == 0
        ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
        kv_channels = hidden_size // num_attention_heads

    if pre_process:
        self.encoder_embedding = Embedding(
            hidden_size=hidden_size,
            vocab_size=vocab_size,
            max_sequence_length=max_position_embeddings,
            init_method=init_method_normal(init_method_std),
            num_tokentypes=num_tokentypes,
            use_cpu_initialization=use_cpu_initialization,
            embedding_dropout_prob=hidden_dropout,
            add_position_embedding=add_position_embedding,
        )
        self._embedding_key = "embedding"

    if add_encoder:
        enc_layer_types = []
        for i in range(enc_num_layers):
            if i in enc_cross_attention:
                enc_layer_types.append(LayerType.retrieval_encoder)
            else:
                enc_layer_types.append(LayerType.encoder)

        self.encoder = get_encoder_model(
            arch="retro",
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            num_layers=enc_num_layers,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            init_method=init_method_normal(init_method_std),
            scaled_init_method=scaled_init_method_normal(
                init_method_std, dec_num_layers
            ),  # since the encoder is not independent of the decoder, use the decoder's number of layers
            pre_process=pre_process,
            post_process=post_process,
            init_method_std=init_method_std,
            use_cpu_initialization=use_cpu_initialization,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            bias_dropout_add_fusion=bias_dropout_add_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
            hidden_steps=hidden_steps,
            hidden_blocks=hidden_blocks,
            activation=activation,
            bias=bias,
            normalization=normalization,
            transformer_block_type=transformer_block_type,
            headscale=headscale,
            parent_model_type=ModelType.encoder_and_decoder,
            layer_type=enc_layer_types,
            chunk_size=chunk_size,
            layer_number_offset=0,
        )
        self._encoder_key = "encoder"

    if add_decoder:
        pre_decoder_num_layers = min(dec_cross_attention)
        pre_decoder_layer_types = []
        for i in range(pre_decoder_num_layers):
            pre_decoder_layer_types.append(LayerType.encoder)
        pre_decoder_layer_types.append(LayerType.decoder_pre_mlp)

        post_decoder_num_layers = dec_num_layers - pre_decoder_num_layers
        post_decoder_layer_types = []
        # the first layer in the post decoder has to be chunked cross attention without self attention
        assert pre_decoder_num_layers in dec_cross_attention
        for i in range(post_decoder_num_layers):
            if i == 0:
                post_decoder_layer_types.append(LayerType.retrieval_decoder_after_self_attn)
            elif i + pre_decoder_num_layers in dec_cross_attention:
                post_decoder_layer_types.append(LayerType.retrieval_decoder)
            else:
                post_decoder_layer_types.append(LayerType.encoder)

        # it is used to process the inputs for the encoder to use as context (H in the paper)
        self.pre_decoder = get_decoder_model(
            arch="retro",
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            num_layers=pre_decoder_num_layers + 1,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            init_method=init_method_normal(init_method_std),
            scaled_init_method=scaled_init_method_normal(init_method_std, dec_num_layers),
            pre_process=pre_process,
            post_process=False,  # no need for post process
            init_method_std=init_method_std,
            use_cpu_initialization=use_cpu_initialization,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            bias_dropout_add_fusion=bias_dropout_add_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
            hidden_steps=hidden_steps,
            hidden_blocks=hidden_blocks,
            activation=activation,
            bias=bias,
            normalization=normalization,
            transformer_block_type=transformer_block_type,
            headscale=headscale,
            parent_model_type=ModelType.encoder_and_decoder,
            layer_type=pre_decoder_layer_types,
            chunk_size=chunk_size,
            layer_number_offset=0,
        )

        # it is where the chunked cross attention happens
        self.post_decoder = get_decoder_model(
            arch="retro",
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            num_layers=post_decoder_num_layers,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            init_method=init_method_normal(init_method_std),
            scaled_init_method=scaled_init_method_normal(init_method_std, dec_num_layers),
            pre_process=False,  # directly take the pre_decoder output, skip preprocess
            post_process=post_process,
            init_method_std=init_method_std,
            use_cpu_initialization=use_cpu_initialization,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            bias_dropout_add_fusion=bias_dropout_add_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
            hidden_steps=hidden_steps,
            hidden_blocks=hidden_blocks,
            activation=activation,
            bias=bias,
            normalization=normalization,
            headscale=headscale,
            transformer_block_type=transformer_block_type,
            parent_model_type=ModelType.encoder_and_decoder,
            layer_type=post_decoder_layer_types,
            chunk_size=chunk_size,
            layer_number_offset=pre_decoder_num_layers + 1,
        )
        self._pre_decoder_key = "pre_decoder"
        self._post_decoder_key = "post_decoder"

    self.initialize_word_embeddings(
        init_method=init_method_normal(init_method_std), vocab_size=vocab_size, hidden_size=hidden_size
    )

    if add_decoder and post_process:
        self.tokens_head = MegatronTokenLevelHead(self.word_embeddings_weight().size(0), parallel_output)
        self._tokens_head_key = 'tokens_head'
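# Illustrative sketch (not part of the original module): with the default dec_num_layers=6 and
# dec_cross_attention=[3, 5], the split above produces pre_decoder_num_layers = 3. The hypothetical
# helper below only reproduces that layer-type bookkeeping with plain strings instead of LayerType.
def _split_decoder_layer_types(dec_num_layers=6, dec_cross_attention=(3, 5)):
    pre_n = min(dec_cross_attention)
    pre_types = ['encoder'] * pre_n + ['decoder_pre_mlp']
    post_types = []
    for i in range(dec_num_layers - pre_n):
        if i == 0:
            # the first post-decoder layer is chunked cross attention without self attention
            post_types.append('retrieval_decoder_after_self_attn')
        elif i + pre_n in dec_cross_attention:
            post_types.append('retrieval_decoder')
        else:
            post_types.append('encoder')
    return pre_types, post_types

# default config -> (['encoder', 'encoder', 'encoder', 'decoder_pre_mlp'],
#                    ['retrieval_decoder_after_self_attn', 'encoder', 'retrieval_decoder'])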
def get_encoder_model(
    arch,
    hidden_size,
    ffn_hidden_size,
    num_layers,
    num_attention_heads,
    apply_query_key_layer_scaling=True,
    kv_channels=None,
    init_method=None,
    scaled_init_method=None,
    encoder_attn_mask_type=AttnMaskType.padding,
    pre_process=True,
    post_process=True,
    init_method_std=0.02,
    use_cpu_initialization=False,
    hidden_dropout=0.1,
    attention_dropout=0.1,
    precision=16,
    fp32_residual_connection=False,
    activations_checkpoint_method=None,
    activations_checkpoint_num_layers=1,
    layernorm_epsilon=1e-5,
    bias_gelu_fusion=True,
    masked_softmax_fusion=True,
    persist_layer_norm=False,
    openai_gelu=False,
    activation="gelu",
    onnx_safe=False,
    hidden_steps=-1,
    hidden_blocks=1,
):
    """Build language model and return along with the key to save."""

    if kv_channels is None:
        assert (
            hidden_size % num_attention_heads == 0
        ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
        kv_channels = hidden_size // num_attention_heads

    if init_method is None:
        init_method = init_method_normal(init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(init_method_std, num_layers)

    if arch == "transformer":
        # Language encoder.
        encoder = MegatronTransformerEncoderModule(
            init_method=init_method,
            output_layer_init_method=scaled_init_method,
            hidden_size=hidden_size,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            ffn_hidden_size=ffn_hidden_size,
            encoder_attn_mask_type=encoder_attn_mask_type,
            pre_process=pre_process,
            post_process=post_process,
            use_cpu_initialization=use_cpu_initialization,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
            activation=activation,
        )
    else:
        raise ValueError(f"Unknown encoder arch = {arch}. Available encoder arch = {AVAILABLE_ENCODERS}")

    return encoder
def __init__(
    self,
    encoder_arch,
    decoder_arch,
    vocab_size,
    hidden_size,
    max_position_embeddings,
    num_layers,
    num_attention_heads,
    ffn_hidden_size,
    apply_query_key_layer_scaling=True,
    kv_channels=None,
    num_tokentypes=0,
    parallel_output=True,
    pre_process=True,
    post_process=True,
    init_method_std=0.02,
    fp16_cross_entropy=False,
    use_cpu_initialization=False,
    hidden_dropout=0.1,
    attention_dropout=0.1,
    position_embedding_type='learned_absolute',
    relative_attention_num_buckets=32,
    relative_attention_max_distance=128,
    precision=16,
    fp32_residual_connection=False,
    activations_checkpoint_method=None,
    activations_checkpoint_num_layers=1,
    layernorm_epsilon=1e-5,
    persist_layer_norm=False,
    bias_gelu_fusion=True,
    bias_dropout_add_fusion=True,
    masked_softmax_fusion=True,
    openai_gelu=False,
    activation='gelu',
    onnx_safe=False,
    bias=True,
    normalization='layernorm',
    transformer_block_type='pre_ln',
    hidden_steps=-1,
    hidden_blocks=1,
    headscale=False,
    add_encoder=True,
    add_decoder=True,
):
    super(MegatronTokenLevelEncoderDecoderModule, self).__init__()

    self.parallel_output = parallel_output
    self.pre_process = pre_process
    self.post_process = post_process
    self.fp16_cross_entropy = fp16_cross_entropy
    self.precision = precision
    self.add_encoder = add_encoder
    self.add_decoder = add_decoder
    self.normalization = normalization
    self.position_embedding_type = position_embedding_type
    self.relative_attention_num_buckets = relative_attention_num_buckets
    self.relative_attention_max_distance = relative_attention_max_distance

    if self.position_embedding_type == 'learned_absolute':
        add_position_embedding = True
    elif self.position_embedding_type == 'relative':
        add_position_embedding = False
    else:
        raise ValueError('Unknown position embedding type. Options: [learned_absolute | relative]')

    if kv_channels is None:
        assert (
            hidden_size % num_attention_heads == 0
        ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
        kv_channels = hidden_size // num_attention_heads

    encoder, decoder = None, None
    if add_encoder:
        if pre_process:
            self.encoder_embedding = Embedding(
                hidden_size=hidden_size,
                vocab_size=vocab_size,
                max_sequence_length=max_position_embeddings,
                init_method=init_method_normal(init_method_std),
                num_tokentypes=num_tokentypes,
                use_cpu_initialization=use_cpu_initialization,
                embedding_dropout_prob=hidden_dropout,
                add_position_embedding=add_position_embedding,
            )
            self._encoder_embedding_key = "encoder_embedding"

        encoder = get_encoder_model(
            arch=encoder_arch,
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            init_method=init_method_normal(init_method_std),
            scaled_init_method=scaled_init_method_normal(init_method_std, num_layers),
            encoder_attn_mask_type=AttnMaskType.padding,
            pre_process=pre_process,
            post_process=post_process,
            init_method_std=init_method_std,
            use_cpu_initialization=use_cpu_initialization,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            position_embedding_type=position_embedding_type,
            relative_attention_num_buckets=relative_attention_num_buckets,
            relative_attention_max_distance=relative_attention_max_distance,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            bias_dropout_add_fusion=bias_dropout_add_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
            hidden_steps=hidden_steps,
            hidden_blocks=hidden_blocks,
            activation=activation,
            bias=bias,
            normalization=normalization,
            transformer_block_type=transformer_block_type,
            headscale=headscale,
            parent_model_type=ModelType.encoder_and_decoder,
        )

    if add_decoder:
        # If this is the decoder's first stage
        if pre_process:
            # If the encoder also lies on this rank (PP = 1), then just assign embeddings directly.
            if hasattr(self, 'encoder_embedding'):
                self.decoder_embedding = self.encoder_embedding
            else:
                # This is the case where PP > 1 and this is the first stage of the decoder.
                # We initialize decoder embeddings, but set them to zero since they're tied with the encoder embeddings.
                # A later initialize_embedding call will synchronize the embeddings.
                self.decoder_embedding = Embedding(
                    hidden_size=hidden_size,
                    vocab_size=vocab_size,
                    max_sequence_length=max_position_embeddings,
                    init_method=init_method_normal(init_method_std),
                    num_tokentypes=num_tokentypes,
                    use_cpu_initialization=use_cpu_initialization,
                    embedding_dropout_prob=hidden_dropout,
                    add_position_embedding=add_position_embedding,
                )
                self.decoder_embedding.zero_parameters()
            self._decoder_embedding_key = "decoder_embedding"

        decoder = get_decoder_model(
            arch=decoder_arch,
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            init_method=init_method_normal(init_method_std),
            scaled_init_method=scaled_init_method_normal(init_method_std, num_layers),
            decoder_attn_mask_type=AttnMaskType.causal,
            pre_process=pre_process,
            post_process=post_process,
            init_method_std=init_method_std,
            use_cpu_initialization=use_cpu_initialization,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            position_embedding_type=position_embedding_type,
            relative_attention_num_buckets=relative_attention_num_buckets,
            relative_attention_max_distance=relative_attention_max_distance,
            precision=precision,
            fp32_residual_connection=fp32_residual_connection,
            activations_checkpoint_method=activations_checkpoint_method,
            activations_checkpoint_num_layers=activations_checkpoint_num_layers,
            layernorm_epsilon=layernorm_epsilon,
            bias_gelu_fusion=bias_gelu_fusion,
            bias_dropout_add_fusion=bias_dropout_add_fusion,
            masked_softmax_fusion=masked_softmax_fusion,
            persist_layer_norm=persist_layer_norm,
            openai_gelu=openai_gelu,
            onnx_safe=onnx_safe,
            hidden_steps=hidden_steps,
            hidden_blocks=hidden_blocks,
            activation=activation,
            bias=bias,
            normalization=normalization,
            transformer_block_type=transformer_block_type,
            headscale=headscale,
            parent_model_type=ModelType.encoder_and_decoder,
        )

    self.enc_dec_model = MegatronTransformerEncoderDecoderModule(encoder=encoder, decoder=decoder)
    self._enc_dec_model_key = "enc_dec_model"

    self.initialize_word_embeddings(
        init_method=init_method_normal(init_method_std), vocab_size=vocab_size, hidden_size=hidden_size
    )

    if add_decoder and post_process:
        self.tokens_head = MegatronTokenLevelHead(self.word_embeddings_weight().size(0), parallel_output)
        self._tokens_head_key = 'tokens_head'
def test_retrieval_encoder_inference(self):
    init_method_std = 0.02
    batch = 2
    neighbors = 2
    # rotary pos emb dim
    dim = 128
    pad_id = 19999
    num_attention_heads = 8
    chunks = 32
    text_chunk_size = 64
    input_length = chunks * text_chunk_size
    vocab_size = 20000

    hidden = torch.randint(0, vocab_size, (batch, input_length)).cuda()  # (batch, seq)
    hidden_mask = (hidden != pad_id).cuda()
    hidden_emb = torch.rand(batch, input_length, dim).cuda().half()  # (batch, seq, dim)
    retrieved = torch.randint(0, vocab_size, (batch, chunks, neighbors, 2 * text_chunk_size)).cuda()
    pad_id = vocab_size - 1
    context_mask = (retrieved != pad_id).cuda()
    retrieved_emb = torch.rand(batch, chunks, neighbors, 2 * text_chunk_size, dim).cuda().half()
    layer_type = [LayerType.encoder, LayerType.retrieval_encoder, LayerType.encoder, LayerType.retrieval_encoder]
    num_layers = len(layer_type)

    init_method = init_method_normal(init_method_std)
    scaled_init_method = scaled_init_method_normal(init_method_std, num_layers)
    encoder = (
        MegatronRetrievalTransformerEncoderModule(
            init_method=init_method,
            output_layer_init_method=scaled_init_method,
            hidden_size=dim,
            ffn_hidden_size=dim * 4,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            precision=16,
            chunk_size=text_chunk_size,
            layer_type=layer_type,
            hidden_dropout=0.0,
            attention_dropout=0.0,
        )
        .cuda()
        .half()
    )
    out_gt = encoder(retrieved_emb, context_mask, context_attn_mask=hidden_mask, encoder_output=hidden_emb)
    assert out_gt.shape == torch.Size([batch, chunks, neighbors, 2 * text_chunk_size, dim])

    out_1 = encoder(
        None, None,
        context_attn_mask=hidden_mask[:, :62], encoder_output=hidden_emb[:, :62, :],
        set_inference_key_value_memory=True, inference_max_sequence_len=input_length, neighbors=neighbors,
    )
    assert out_1 is None
    out_1 = encoder(
        None, None,
        context_attn_mask=hidden_mask[:, :63], encoder_output=hidden_emb[:, 62:63],
        set_inference_key_value_memory=False, inference_max_sequence_len=input_length, neighbors=neighbors,
    )
    assert out_1 is None
    out_2 = encoder(
        retrieved_emb[:, :1], context_mask[:, :1],
        context_attn_mask=hidden_mask[:, :64], encoder_output=hidden_emb[:, 63:64],
        set_inference_key_value_memory=False, inference_max_sequence_len=input_length, neighbors=neighbors,
    )
    assert (encoder.encoder_output - hidden_emb[:, :64]).abs().max().item() < 1e-5
    assert (out_gt[:, 0,] - out_2[:, 0]).abs().max().item() < 1e-2

    out_test = encoder(
        retrieved_emb[:, :1], context_mask[:, :1],
        context_attn_mask=hidden_mask[:, :64], encoder_output=hidden_emb[:, :64],
    )
    assert (out_gt[:, 0,] - out_test[:, 0]).abs().max().item() < 1e-2
    assert (out_gt[:, 0,] - out_2[:, 0]).abs().max().item() < 1e-2

    for i in range(64, 127):
        out_3 = encoder(
            retrieved_emb[:, :1], context_mask[:, :1],
            context_attn_mask=hidden_mask[:, :i + 1], encoder_output=hidden_emb[:, i:i + 1],
            set_inference_key_value_memory=False, inference_max_sequence_len=input_length, neighbors=neighbors,
        )
    i = 127
    out_3 = encoder(
        retrieved_emb[:, :2], context_mask[:, :2],
        context_attn_mask=hidden_mask[:, :i + 1], encoder_output=hidden_emb[:, i:i + 1],
        set_inference_key_value_memory=False, inference_max_sequence_len=input_length, neighbors=neighbors,
    )
    assert (encoder.encoder_output - hidden_emb[:, 64:128]).abs().max().item() < 1e-5
    assert (out_gt[:, :2,] - out_3).abs().max().item() < 1e-2

    # test inference
    for i in range(128, 191):
        out_4 = encoder(
            retrieved_emb[:, :2], context_mask[:, :2],
            context_attn_mask=hidden_mask[:, :i + 1], encoder_output=hidden_emb[:, i:i + 1],
            set_inference_key_value_memory=False, inference_max_sequence_len=input_length, neighbors=neighbors,
        )
    i = 191
    out_4 = encoder(
        retrieved_emb[:, :3], context_mask[:, :3],
        context_attn_mask=hidden_mask[:, :i + 1], encoder_output=hidden_emb[:, i:i + 1],
        set_inference_key_value_memory=False, inference_max_sequence_len=input_length, neighbors=neighbors,
    )
    assert (encoder.encoder_output - hidden_emb[:, 128:192]).abs().max().item() < 1e-5
    assert (out_gt[:, :3,] - out_4).abs().max().item() < 1e-2

    out_2 = encoder(
        retrieved_emb[:, :2], context_mask[:, :2],
        context_attn_mask=hidden_mask[:, :130], encoder_output=hidden_emb[:, :130, :],
        set_inference_key_value_memory=True, inference_max_sequence_len=input_length, neighbors=neighbors,
    )
    for i in range(130, 191):
        out_2 = encoder(
            retrieved_emb[:, :2], context_mask[:, :2],
            context_attn_mask=hidden_mask[:, :i + 1], encoder_output=hidden_emb[:, i:i + 1],
            set_inference_key_value_memory=False, inference_max_sequence_len=input_length, neighbors=neighbors,
        )
    i = 191
    out_4 = encoder(
        retrieved_emb[:, :3], context_mask[:, :3],
        context_attn_mask=hidden_mask[:, :i + 1], encoder_output=hidden_emb[:, i:i + 1],
        set_inference_key_value_memory=False, inference_max_sequence_len=input_length, neighbors=neighbors,
    )
    assert (encoder.encoder_output - hidden_emb[:, 128:192]).abs().max().item() < 1e-5
    assert (out_gt[:, :3,] - out_4).abs().max().item() < 1e-2
def test_retrieval_decoder_inference(self):
    init_method_std = 0.02
    # rotary pos emb dim
    batch = 2
    neighbors = 2
    dim = 128
    pad_id = 19999
    num_attention_heads = 8
    chunks = 32
    text_chunk_size = 64
    input_length = chunks * text_chunk_size
    vocab_size = 20000
    # rot_dim = dim // num_attention_heads
    # rotary_pos_emb = RotaryEmbedding(rot_dim).cuda().half()
    hidden = torch.randint(0, vocab_size, (batch, input_length)).cuda()  # (batch, seq)
    hidden_mask = (hidden != pad_id).cuda()
    hidden_emb = torch.rand(batch, input_length, dim).cuda().half()  # (batch, seq, dim)
    # context_chunk_size = 128
    retrieved = torch.randint(0, vocab_size, (batch, chunks, neighbors, 2 * text_chunk_size)).cuda()
    # retrieved tokens - (batch, num chunks, num retrieved neighbors, retrieved chunk with continuation)
    # context attention mask [b, np, sq, sk]
    pad_id = vocab_size - 1
    context_mask = (retrieved != pad_id).cuda()
    retrieved_emb = torch.rand(batch, chunks, neighbors, 2 * text_chunk_size, dim).cuda().half()
    # retrieved tokens - (batch, num chunks, num retrieved neighbors, retrieved chunk with continuation, hidden)
    layer_type = [LayerType.encoder, LayerType.retrieval_decoder, LayerType.encoder, LayerType.retrieval_decoder]
    num_layers = len(layer_type)

    init_method = init_method_normal(init_method_std)
    scaled_init_method = scaled_init_method_normal(init_method_std, num_layers)
    decoder = (
        MegatronRetrievalTransformerDecoderModule(
            init_method=init_method,
            output_layer_init_method=scaled_init_method,
            hidden_size=dim,
            ffn_hidden_size=dim * 4,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            precision=16,
            chunk_size=text_chunk_size,
            layer_type=layer_type,
            hidden_dropout=0.0,
            attention_dropout=0.0,
        )
        .cuda()
        .half()
    )
    out = decoder(hidden_emb, hidden_mask, retrieved_attn_mask=context_mask, retrieved_emb=retrieved_emb)
    assert out.shape == torch.Size([batch, input_length, dim])

    out_1 = decoder(
        hidden_emb[:, :62], hidden_mask[:, :62],
        retrieved_attn_mask=None, retrieved_emb=None,
        set_inference_key_value_memory=True, inference_max_sequence_len=input_length,
    )
    assert (out[:, :62] - out_1[:, :62]).abs().max().item() < 1e-2
    out_1 = decoder(
        hidden_emb[:, 62:63], hidden_mask[:, :63],
        retrieved_attn_mask=None, retrieved_emb=None,
        set_inference_key_value_memory=False, inference_max_sequence_len=input_length,
    )
    assert (out[:, 62] - out_1[:, 0]).abs().max().item() < 1e-2
    out_2 = decoder(
        hidden_emb[:, 63:64], hidden_mask[:, :64],
        retrieved_attn_mask=context_mask[:, :1], retrieved_emb=retrieved_emb[:, :1],
        set_inference_key_value_memory=False, inference_max_sequence_len=input_length,
    )
    assert (out[:, 63] - out_2[:, 0]).abs().max().item() < 1e-2

    for i in range(64, 127):
        out_2 = decoder(
            hidden_emb[:, i:i + 1], hidden_mask[:, :i + 1],
            retrieved_attn_mask=context_mask[:, :1], retrieved_emb=retrieved_emb[:, :1],
            set_inference_key_value_memory=False, inference_max_sequence_len=input_length,
        )
        assert (out[:, i] - out_2[:, 0]).abs().max().item() < 1e-2
    for i in range(127, 191):
        out_3 = decoder(
            hidden_emb[:, i:i + 1], hidden_mask[:, :i + 1],
            retrieved_attn_mask=context_mask[:, :2], retrieved_emb=retrieved_emb[:, :2],
            set_inference_key_value_memory=False, inference_max_sequence_len=input_length,
        )
        assert (out[:, i] - out_3[:, 0]).abs().max().item() < 1e-2

    out_1 = decoder(
        hidden_emb[:, :130], hidden_mask[:, :130],
        retrieved_attn_mask=context_mask[:, :2], retrieved_emb=retrieved_emb[:, :2],
        set_inference_key_value_memory=True, inference_max_sequence_len=input_length,
    )
    assert (out[:, :130] - out_1[:, :130]).abs().max().item() < 1e-2
    for i in range(130, 191):
        out_3 = decoder(
            hidden_emb[:, i:i + 1], hidden_mask[:, :i + 1],
            retrieved_attn_mask=context_mask[:, :2], retrieved_emb=retrieved_emb[:, :2],
            set_inference_key_value_memory=False, inference_max_sequence_len=input_length,
        )
        assert (out[:, i] - out_3[:, 0]).abs().max().item() < 1e-2
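# Illustrative sketch (not part of the original test): the indices 62/63, 127, and 191 used in the
# inference tests follow from the chunking rule of chunked cross attention -- a token at 0-based
# position i can only attend to retrieved neighbors of the complete chunks seen so far. The helper
# below is hypothetical and only reproduces that bookkeeping for chunk_size = 64.
def _num_retrieved_chunks(position, chunk_size=64):
    """Number of complete chunks (i.e. retrieved_emb[:, :n] slices) available at `position`."""
    return (position + 1) // chunk_size

# positions 0..62 -> 0 chunks (no retrieval yet), 63 -> 1 chunk, 127 -> 2 chunks, 191 -> 3 chunks
assert [_num_retrieved_chunks(p) for p in (62, 63, 127, 191)] == [0, 1, 2, 3]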
def test_cross_attn_inference(self):
    num_layers = 1
    init_method_std = 0.02
    batch = 2
    neighbors = 2
    # rotary pos emb dim
    dim = 128
    pad_id = 19999
    num_attention_heads = 8
    chunks = 32
    text_chunk_size = 64
    context_chunk_size = 2 * text_chunk_size
    input_length = chunks * text_chunk_size
    vocab_size = 20000

    rot_dim = dim // num_attention_heads
    rotary_pos_emb = RotaryEmbedding(rot_dim).cuda().half()
    hidden = torch.randint(0, vocab_size, (input_length, batch)).cuda()  # (seq, batch)
    hidden_mask = (hidden != pad_id).cuda()
    hidden_emb = torch.rand(input_length, batch, dim).cuda().half()  # (seq, batch, dim)
    retrieved = torch.randint(0, vocab_size, (chunks, neighbors, context_chunk_size, batch)).cuda()
    # retrieved tokens - (num chunks, num retrieved neighbors, retrieved chunk with continuation, batch)
    # context attention mask [b, np, sq, sk]
    context_mask = (retrieved != pad_id).cuda()
    retrieved_emb = torch.rand(chunks, neighbors, context_chunk_size, batch, dim).cuda().half()
    # retrieved tokens - (num chunks, num retrieved neighbors, retrieved chunk with continuation, batch, hidden)

    # need to add extra chunk size, since it will be shifted
    cross_attn_q_pos_emb = rotary_pos_emb(text_chunk_size + text_chunk_size - 1, offset=-text_chunk_size + 1)
    cross_attn_k_pos_emb = rotary_pos_emb(context_chunk_size)
    cross_attn_pos_emb = (cross_attn_q_pos_emb, cross_attn_k_pos_emb)

    def get_attn_mask_3d(hidden_mask, context_mask, chunks):
        causal_padding = text_chunk_size - 1
        remainder = (text_chunk_size - (hidden_mask.shape[0] + 1)) % text_chunk_size
        hidden_mask = F.pad(hidden_mask, (0, 0, -causal_padding, remainder), value=False)
        dec_attn_mask = rearrange(hidden_mask, '(k n) b -> (b k) n', k=chunks)
        context_attn_mask = rearrange(context_mask, 'k r n b -> (b k) (r n)')
        enc_dec_attn_mask_3d = build_attention_mask_3d(
            source_mask=dec_attn_mask, target_mask=context_attn_mask, attn_mask_type=AttnMaskType.padding,
        )
        enc_dec_attn_mask_3d = enc_dec_attn_mask_3d[:, None, :, :]
        return enc_dec_attn_mask_3d

    enc_dec_attn_mask_3d = get_attn_mask_3d(hidden_mask, context_mask, chunks)

    init_method = init_method_normal(init_method_std)
    scaled_init_method = scaled_init_method_normal(init_method_std, num_layers)
    cross_attn = (
        ParallelChunkedCrossAttention(
            init_method=init_method,
            output_layer_init_method=scaled_init_method,
            layer_number=1,
            num_attention_heads=num_attention_heads,
            hidden_size=dim,
            precision=16,
            chunk_size=text_chunk_size,
        )
        .cuda()
        .half()
    )

    out, bias = cross_attn(
        hidden_emb, enc_dec_attn_mask_3d, encoder_output=retrieved_emb, rotary_pos_emb=cross_attn_pos_emb
    )
    assert out.shape == torch.Size([input_length, batch, dim])
    assert bias.shape == torch.Size([dim])

    attn_mask_3d = None
    out_1, b = cross_attn(
        hidden_emb[:62], attn_mask_3d,
        encoder_output=None, rotary_pos_emb=cross_attn_pos_emb,
        set_inference_key_value_memory=True, inference_max_sequence_len=input_length,
    )
    assert (out_1 - torch.zeros_like(hidden_emb[:62])).abs().max() == 0
    out_1, b = cross_attn(
        hidden_emb[62:63], attn_mask_3d,
        encoder_output=None, rotary_pos_emb=cross_attn_pos_emb,
        set_inference_key_value_memory=False, inference_max_sequence_len=input_length,
    )
    assert (out_1 - torch.zeros_like(hidden_emb[62:63])).abs().max() == 0

    attn_mask_3d = get_attn_mask_3d(hidden_mask[:64], context_mask[:1], 1)
    out_2, b = cross_attn(
        hidden_emb[63:64], attn_mask_3d,
        encoder_output=retrieved_emb[:1], rotary_pos_emb=cross_attn_pos_emb,
        set_inference_key_value_memory=False, inference_max_sequence_len=input_length,
    )
    assert (out[63] - out_2[0]).abs().max().item() < 1e-2

    for i in range(64, 127):
        attn_mask_3d = get_attn_mask_3d(hidden_mask[:i + 1], context_mask[:1], 1)
        out_2, b = cross_attn(
            hidden_emb[i:i + 1], attn_mask_3d,
            encoder_output=retrieved_emb[:1], rotary_pos_emb=cross_attn_pos_emb,
            set_inference_key_value_memory=False, inference_max_sequence_len=input_length,
        )
    i = 127
    attn_mask_3d = get_attn_mask_3d(hidden_mask[:i + 1], context_mask[:2], 2)
    out_3, b = cross_attn(
        hidden_emb[i:i + 1], attn_mask_3d,
        encoder_output=retrieved_emb[:2], rotary_pos_emb=cross_attn_pos_emb,
        set_inference_key_value_memory=False, inference_max_sequence_len=input_length,
    )
    assert (out[i] - out_3[0]).abs().max().item() < 1e-2

    attn_mask_3d = get_attn_mask_3d(hidden_mask[:130], context_mask[:2], 2)
    out_1, b = cross_attn(
        hidden_emb[:130], attn_mask_3d,
        encoder_output=retrieved_emb[:2], rotary_pos_emb=cross_attn_pos_emb,
        set_inference_key_value_memory=True, inference_max_sequence_len=input_length,
    )
    assert (out[:130] - out_1[:130]).abs().max().item() < 1e-2

    for i in range(130, 191):
        attn_mask_3d = get_attn_mask_3d(hidden_mask[:i + 1], context_mask[:2], 2)
        out_2, b = cross_attn(
            hidden_emb[i:i + 1], attn_mask_3d,
            encoder_output=retrieved_emb[:2], rotary_pos_emb=cross_attn_pos_emb,
            set_inference_key_value_memory=False, inference_max_sequence_len=input_length,
        )
    i = 191
    attn_mask_3d = get_attn_mask_3d(hidden_mask[:i + 1], context_mask[:3], 3)
    out_4, b = cross_attn(
        hidden_emb[i:i + 1], attn_mask_3d,
        encoder_output=retrieved_emb[:3], rotary_pos_emb=cross_attn_pos_emb,
        set_inference_key_value_memory=False, inference_max_sequence_len=input_length,
    )
    assert (out[i] - out_4[0]).abs().max().item() < 1e-2
def __init__(
    self,
    encoder_arch,
    decoder_arch,
    vocab_size,
    hidden_size,
    max_position_embeddings,
    num_layers,
    num_attention_heads,
    ffn_hidden_size,
    apply_query_key_layer_scaling=True,
    kv_channels=None,
    num_tokentypes=0,
    parallel_output=True,
    pre_process=True,
    post_process=True,
    init_method_std=0.02,
    fp16_cross_entropy=False,
    use_cpu_initialization=False,
    hidden_dropout=0.1,
    attention_dropout=0.1,
    precision=16,
    fp32_residual_connection=False,
    activations_checkpoint_method=None,
    activations_checkpoint_num_layers=1,
    layernorm_epsilon=1e-5,
    persist_layer_norm=False,
    bias_gelu_fusion=True,
    masked_softmax_fusion=True,
    openai_gelu=False,
    activation='gelu',
    onnx_safe=False,
    hidden_steps=-1,
    hidden_blocks=1,
):
    super(MegatronTokenLevelEncoderDecoderModule, self).__init__()

    self.parallel_output = parallel_output
    self.pre_process = pre_process
    self.post_process = post_process
    self.fp16_cross_entropy = fp16_cross_entropy
    self.precision = precision

    if kv_channels is None:
        assert (
            hidden_size % num_attention_heads == 0
        ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
        kv_channels = hidden_size // num_attention_heads

    # TODO: add get_embedding function to support various embedders (like prompt tuning)
    self.encoder_embedding = Embedding(
        hidden_size=hidden_size,
        vocab_size=vocab_size,
        max_sequence_length=max_position_embeddings,
        init_method=init_method_normal(init_method_std),
        num_tokentypes=num_tokentypes,
        use_cpu_initialization=use_cpu_initialization,
        embedding_dropout_prob=hidden_dropout,
    )
    self.decoder_embedding = self.encoder_embedding
    self._encoder_embedding_key = "encoder_embedding"
    self._decoder_embedding_key = "decoder_embedding"

    encoder = get_encoder_model(
        arch=encoder_arch,
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
        kv_channels=kv_channels,
        init_method=init_method_normal(init_method_std),
        scaled_init_method=scaled_init_method_normal(init_method_std, num_layers),
        encoder_attn_mask_type=AttnMaskType.padding,
        pre_process=pre_process,
        post_process=post_process,
        init_method_std=init_method_std,
        use_cpu_initialization=use_cpu_initialization,
        hidden_dropout=hidden_dropout,
        attention_dropout=attention_dropout,
        precision=precision,
        fp32_residual_connection=fp32_residual_connection,
        activations_checkpoint_method=activations_checkpoint_method,
        activations_checkpoint_num_layers=activations_checkpoint_num_layers,
        layernorm_epsilon=layernorm_epsilon,
        bias_gelu_fusion=bias_gelu_fusion,
        masked_softmax_fusion=masked_softmax_fusion,
        persist_layer_norm=persist_layer_norm,
        openai_gelu=openai_gelu,
        onnx_safe=onnx_safe,
        hidden_steps=hidden_steps,
        hidden_blocks=hidden_blocks,
        activation=activation,
    )

    decoder = get_decoder_model(
        arch=decoder_arch,
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
        kv_channels=kv_channels,
        init_method=init_method_normal(init_method_std),
        scaled_init_method=scaled_init_method_normal(init_method_std, num_layers),
        decoder_attn_mask_type=AttnMaskType.causal,
        pre_process=pre_process,
        post_process=post_process,
        init_method_std=init_method_std,
        use_cpu_initialization=use_cpu_initialization,
        hidden_dropout=hidden_dropout,
        attention_dropout=attention_dropout,
        precision=precision,
        fp32_residual_connection=fp32_residual_connection,
        activations_checkpoint_method=activations_checkpoint_method,
        activations_checkpoint_num_layers=activations_checkpoint_num_layers,
        layernorm_epsilon=layernorm_epsilon,
        bias_gelu_fusion=bias_gelu_fusion,
        masked_softmax_fusion=masked_softmax_fusion,
        persist_layer_norm=persist_layer_norm,
        openai_gelu=openai_gelu,
        onnx_safe=onnx_safe,
        hidden_steps=hidden_steps,
        hidden_blocks=hidden_blocks,
        activation=activation,
    )

    self.enc_dec_model = MegatronTransformerEncoderDecoderModule(encoder=encoder, decoder=decoder)
    self._enc_dec_model_key = "enc_dec_model"

    self.tokens_head = MegatronTokenLevelHead(
        self.decoder_embedding.word_embeddings.weight.size(0), parallel_output
    )
    self._tokens_head_key = 'tokens_head'