def __init__(self, attention_mask_func, init_method,
             output_layer_init_method, num_tokentypes=0,
             add_pooler=False):
    super(TransformerLanguageModel, self).__init__()
    args = get_args()

    self.hidden_size = args.hidden_size
    self.num_tokentypes = num_tokentypes
    self.init_method = init_method
    self.add_pooler = add_pooler

    # Embeddings
    self.embedding = Embedding(self.hidden_size,
                               args.padded_vocab_size,
                               args.max_position_embeddings,
                               args.hidden_dropout,
                               self.init_method,
                               self.num_tokentypes,
                               args.sinusoidal_pos_emb)
    self._embedding_key = 'embedding'

    # Transformer
    self.transformer = ParallelTransformer(attention_mask_func,
                                           self.init_method,
                                           output_layer_init_method)
    self._transformer_key = 'transformer'

    # Pooler
    if self.add_pooler:
        self.pooler = Pooler(self.hidden_size, self.init_method)
        self._pooler_key = 'pooler'
def __init__(self, num_classes, finetune=False):
    super(VitModel, self).__init__()
    args = get_args()

    self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
    if args.init_method_xavier_uniform:
        self.init_method = torch.nn.init.xavier_uniform_
        self.scaled_init_method = torch.nn.init.xavier_uniform_
    else:
        self.init_method = init_method_normal(args.init_method_std)
        self.scaled_init_method = scaled_init_method_normal(
            args.init_method_std, args.num_layers)

    self.hidden_size = args.hidden_size
    self.num_classes = num_classes
    self.patch_dim = args.patch_dim
    self.img_dim = args.img_dim
    self.finetune = finetune

    assert self.img_dim % self.patch_dim == 0
    self.num_patches_per_dim = self.img_dim // self.patch_dim
    self.num_patches = self.num_patches_per_dim**2
    self.seq_length = self.num_patches + 1
    self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels

    # cls_token
    self.cls_token = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
    torch.nn.init.zeros_(self.cls_token)

    # Linear encoder
    self.linear_encoder = torch.nn.Linear(self.flatten_dim, self.hidden_size)

    # embedding
    self.position_embeddings = torch.nn.Embedding(self.seq_length,
                                                  self.hidden_size)
    init_method_normal(args.init_method_std)(
        self.position_embeddings.weight)
    self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
    self.position_embeddings._register_load_state_dict_pre_hook(
        twod_interpolate_position_embeddings_hook)
    self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)

    # Transformer
    self.transformer = ParallelTransformer(self.init_method,
                                           self.scaled_init_method)

    # MLP head
    if not self.finetune:
        self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
    else:
        self.class_head = get_linear_layer(self.hidden_size, num_classes,
                                           torch.nn.init.zeros_)
class TransformerLanguageModelBase(MegatronModule):
    """Transformer language model.

    Arguments:
        transformer_hparams: transformer hyperparameters
        attention_mask_func: a function that takes `unmasked-attention-scores`
            with size [b, np, s, s] and an `attention-mask` and will apply
            the masking. The function should return a masked score of the
            same size [b, np, s, s].
               masked-attention-scores = attention_mask_func(
                                     unmasked-attention-scores, attention-mask)
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(self,
                 attention_mask_func,
                 init_method,
                 output_layer_init_method,
                 num_tokentypes=0,
                 add_pooler=False):
        super(TransformerLanguageModelBase, self).__init__()
        args = get_args()

        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.add_pooler = add_pooler

        # Embeddings.
        if mpu.is_pipeline_first_stage():
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
                                       args.hidden_dropout,
                                       self.init_method,
                                       self.num_tokentypes)
            self._embedding_key = 'embedding'

        # Transformer.
        self.transformer = ParallelTransformer(attention_mask_func,
                                               self.init_method,
                                               output_layer_init_method)
        self._transformer_key = 'transformer'

        # Pooler.
        if mpu.is_pipeline_last_stage() and self.add_pooler:
            self.pooler = Pooler(self.hidden_size, self.init_method)
            self._pooler_key = 'pooler'

    def forward(self, language_model_input, attention_mask,
                tokentype_ids=None, layer_past=None, get_key_value=False,
                pooling_sequence_index=0):

        # Embeddings.
        if mpu.is_pipeline_first_stage():
            (input_ids, position_ids) = language_model_input
            embedding_output = self.embedding(input_ids, position_ids,
                                              tokentype_ids=tokentype_ids)
            transformer_input = embedding_output
        else:
            transformer_input = language_model_input

        # Transformer.
        transformer_output = self.transformer(transformer_input,
                                              attention_mask,
                                              layer_past=layer_past,
                                              get_key_value=get_key_value)

        if mpu.is_pipeline_last_stage() and self.add_pooler:
            pooled_output = self.pooler(transformer_output,
                                        pooling_sequence_index)
            return transformer_output, pooled_output

        return transformer_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        if mpu.is_pipeline_first_stage():
            state_dict_[self._embedding_key] \
                = self.embedding.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        state_dict_[self._transformer_key] \
            = self.transformer.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if mpu.is_pipeline_last_stage() and self.add_pooler:
            state_dict_[self._pooler_key] \
                = self.pooler.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Embedding.
        if mpu.is_pipeline_first_stage():
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if '_embeddings' in key:
                        state_dict_[key] = state_dict[key]
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Transformer.
        if self._transformer_key in state_dict:
            state_dict_ = state_dict[self._transformer_key]
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'transformer.' in key:
                    state_dict_[key.split('transformer.')[1]] = state_dict[key]
        self.transformer.load_state_dict(state_dict_, strict=strict)

        # Pooler.
        if mpu.is_pipeline_last_stage() and self.add_pooler:
            assert 'pooler' in state_dict, \
                'could not find data for pooler in the checkpoint'
            self.pooler.load_state_dict(state_dict[self._pooler_key],
                                        strict=strict)
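# --- Usage sketch (not part of the original source) ---
# A minimal, hypothetical way to drive TransformerLanguageModelBase, assuming
# Megatron has already been initialized so that get_args() is populated, and
# that `attention_mask_func`, `init_method`, `output_layer_init_method`,
# `input_ids`, `position_ids`, and `attention_mask` come from the surrounding
# training setup. With a pipeline size of 1, this rank is both the first and
# the last stage, so forward() takes the (input_ids, position_ids) tuple and,
# with add_pooler=True, returns (transformer_output, pooled_output).
def build_and_run_base_lm(attention_mask_func, init_method,
                          output_layer_init_method,
                          input_ids, position_ids, attention_mask):
    # Hypothetical helper, for illustration only.
    model = TransformerLanguageModelBase(attention_mask_func,
                                         init_method,
                                         output_layer_init_method,
                                         num_tokentypes=0,
                                         add_pooler=True)
    output, pooled = model((input_ids, position_ids), attention_mask,
                           pooling_sequence_index=0)
    return output, pooled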
class VitModel(MegatronModule):
    """Vision Transformer Model."""

    def __init__(self, num_classes, finetune=False,
                 pre_process=True, post_process=True):
        super(VitModel, self).__init__(share_word_embeddings=False)
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        if args.init_method_xavier_uniform:
            self.init_method = torch.nn.init.xavier_uniform_
            self.scaled_init_method = torch.nn.init.xavier_uniform_
        else:
            self.init_method = init_method_normal(args.init_method_std)
            self.scaled_init_method = scaled_init_method_normal(
                args.init_method_std, args.num_layers)

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.patch_dim = args.patch_dim
        self.img_dim = args.img_dim
        self.finetune = finetune

        assert self.img_dim % self.patch_dim == 0
        self.num_patches_per_dim = self.img_dim // self.patch_dim
        self.num_patches = self.num_patches_per_dim**2
        self.seq_length = self.num_patches + 1
        self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels

        if self.pre_process:
            # cls_token
            self.cls_token = torch.nn.Parameter(
                torch.randn(1, 1, self.hidden_size))
            torch.nn.init.zeros_(self.cls_token)

            # Linear encoder
            self.linear_encoder = torch.nn.Linear(self.flatten_dim,
                                                  self.hidden_size)

            # embedding
            self.position_embeddings = torch.nn.Embedding(
                self.seq_length, self.hidden_size)
            init_method_normal(args.init_method_std)(
                self.position_embeddings.weight)
            self.position_ids = torch.arange(self.seq_length).expand(
                1, -1).cuda()

            self.position_embeddings._register_load_state_dict_pre_hook(
                twod_interpolate_position_embeddings_hook)

            self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)

        # Transformer
        self.transformer = ParallelTransformer(self.init_method,
                                               self.scaled_init_method,
                                               pre_process=self.pre_process,
                                               post_process=self.post_process)

        if self.post_process:
            # MLP head
            if not self.finetune:
                self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
            else:
                self.class_head = get_linear_layer(self.hidden_size,
                                                   num_classes,
                                                   torch.nn.init.zeros_)

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.transformer.set_input_tensor(input_tensor)

    def forward(self, input):
        if self.pre_process:
            rearranged_input = einops.rearrange(
                input,
                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
                p1=self.patch_dim,
                p2=self.patch_dim,
            )

            assert rearranged_input.dtype == torch.half
            encoder_output = self.linear_encoder(rearranged_input)

            cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
            concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)

            token_embeddings = concatenated_tokens + \
                self.position_embeddings(self.position_ids)
            hidden_states = self.embedding_dropout(token_embeddings)
        else:
            hidden_states = input

        hidden_states = self.transformer(hidden_states, None)

        if self.post_process:
            if not self.finetune:
                hidden_states = self.mlp_head(hidden_states)
            else:
                hidden_states = self.class_head(hidden_states[:, 0, :])

        return hidden_states
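# --- Shape sketch (not part of the original source) ---
# Illustrates the patch arithmetic VitModel relies on, using hypothetical
# numbers (img_dim=224, patch_dim=16, num_channels=3). Each image is cut into
# (img_dim // patch_dim) ** 2 patches, each patch is flattened to
# patch_dim * patch_dim * num_channels values, and one extra position is
# reserved for the [CLS] token:
#
#   num_patches_per_dim = 224 // 16      # 14
#   num_patches         = 14 ** 2        # 196
#   seq_length          = 196 + 1        # 197 (patches + CLS token)
#   flatten_dim         = 16 * 16 * 3    # 768, input width of linear_encoder
#
# so einops.rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)") turns a
# [b, 3, 224, 224] image batch into a [b, 196, 768] patch sequence before
# linear_encoder projects it to [b, 196, hidden_size].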
class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    Arguments:
        transformer_hparams: transformer hyperparameters
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(self,
                 init_method,
                 output_layer_init_method,
                 encoder_attn_mask_type,
                 num_tokentypes=0,
                 add_decoder=False,
                 decoder_attn_mask_type=AttnMaskType.causal,
                 add_pooler=False,
                 pre_process=True,
                 post_process=True):
        super(TransformerLanguageModel, self).__init__()
        args = get_args()

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
                                       args.hidden_dropout,
                                       self.init_method,
                                       self.num_tokentypes)
            self._embedding_key = 'embedding'

        # Transformer.
        self.encoder = ParallelTransformer(
            self.init_method,
            output_layer_init_method,
            self_attn_mask_type=self.encoder_attn_mask_type,
            pre_process=self.pre_process,
            post_process=self.post_process)
        self._encoder_key = 'encoder'

        # Decoder
        if self.add_decoder:
            assert args.pipeline_model_parallel_size == 1, \
                'pipeline parallelism is not supported in the presence of decoder'
            self.decoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type)
            self._decoder_key = 'decoder'

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = 'pooler'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.encoder.set_input_tensor(input_tensor)

    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
                get_key_value=False, pooling_sequence_index=0,
                enc_hidden_states=None, output_enc_hidden=False):

        # Embeddings.
        if self.pre_process:
            embedding_output = self.embedding(enc_input_ids, enc_position_ids,
                                              tokentype_ids=tokentype_ids)
            encoder_input = embedding_output
        else:
            encoder_input = None

        # encoder.
        if enc_hidden_states is None:
            encoder_output = self.encoder(encoder_input,
                                          enc_attn_mask,
                                          layer_past=layer_past,
                                          get_key_value=get_key_value)
        else:
            encoder_output = enc_hidden_states.to(encoder_input.dtype)

        if self.post_process:
            if self.add_pooler:
                pooled_output = self.pooler(encoder_output,
                                            pooling_sequence_index)

        # output_enc_hidden refers to when we just need the encoder's
        # output. For example, it is helpful to compute
        # similarity between two sequences by average pooling
        if not self.add_decoder or output_enc_hidden:
            if self.add_pooler and self.post_process:
                return encoder_output, pooled_output
            else:
                return encoder_output

        # Decoder Embedding
        dec_embedding_output = self.embedding(dec_input_ids,
                                              dec_position_ids)
        # decoder
        decoder_output = self.decoder(dec_embedding_output,
                                      dec_attn_mask,
                                      layer_past=layer_past,
                                      get_key_value=get_key_value,
                                      encoder_output=encoder_output,
                                      enc_dec_attn_mask=enc_dec_attn_mask)

        if self.add_pooler and self.post_process:
            return decoder_output, encoder_output, pooled_output
        else:
            return decoder_output, encoder_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        if self.pre_process:
            state_dict_[self._embedding_key] \
                = self.embedding.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        state_dict_[self._encoder_key] \
            = self.encoder.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if self.post_process:
            if self.add_pooler:
                state_dict_[self._pooler_key] \
                    = self.pooler.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)
        if self.add_decoder:
            state_dict_[self._decoder_key] \
                = self.decoder.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Embedding.
        if self.pre_process:
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if '_embeddings' in key:
                        state_dict_[key] = state_dict[key]
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Encoder.
        if self._encoder_key in state_dict:
            state_dict_ = state_dict[self._encoder_key]
        # for backward compatibility.
        elif 'transformer' in state_dict:
            state_dict_ = state_dict['transformer']
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'transformer.' in key:
                    state_dict_[key.split('transformer.')[1]] = state_dict[key]

        # for backward compatibility.
        state_dict_self_attention = {}
        for key in state_dict_.keys():
            if '.attention.' in key:
                state_dict_self_attention[key.replace(
                    ".attention.", ".self_attention.")] = state_dict_[key]
            else:
                state_dict_self_attention[key] = state_dict_[key]
        state_dict_ = state_dict_self_attention

        self.encoder.load_state_dict(state_dict_, strict=strict)

        if self.post_process:
            # pooler
            if self.add_pooler:
                assert 'pooler' in state_dict, \
                    'could not find data for pooler in the checkpoint'
                self.pooler.load_state_dict(state_dict[self._pooler_key],
                                            strict=strict)
        # decoder
        if self.add_decoder:
            assert 'decoder' in state_dict, \
                'could not find data for decoder in the checkpoint'
            self.decoder.load_state_dict(state_dict[self._decoder_key],
                                         strict=strict)
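# --- Return-value sketch (not part of the original source) ---
# Summarizes, assuming a single pipeline stage (pre_process=post_process=True),
# what the forward() above returns:
#
#   encoder only, no pooler:           encoder_output
#   encoder only, add_pooler=True:     (encoder_output, pooled_output)
#   add_decoder=True, no pooler:       (decoder_output, encoder_output)
#   add_decoder=True, add_pooler=True: (decoder_output, encoder_output, pooled_output)
#
# A hypothetical encoder-decoder call, with ids and masks assumed to come
# from the surrounding data pipeline:
#
#   dec_out, enc_out = lm(enc_input_ids, enc_position_ids, enc_attn_mask,
#                         dec_input_ids=dec_input_ids,
#                         dec_position_ids=dec_position_ids,
#                         dec_attn_mask=dec_attn_mask,
#                         enc_dec_attn_mask=enc_dec_attn_mask)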
class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    Arguments:
        transformer_hparams: transformer hyperparameters
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(self,
                 init_method,
                 output_layer_init_method,
                 encoder_attn_mask_type,
                 num_tokentypes=0,
                 add_encoder=True,
                 add_decoder=False,
                 decoder_attn_mask_type=AttnMaskType.causal,
                 add_pooler=False,
                 pre_process=True,
                 post_process=True):
        super(TransformerLanguageModel, self).__init__()
        args = get_args()

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.add_encoder = add_encoder
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler
        self.encoder_hidden_state = None

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
                                       args.hidden_dropout,
                                       self.init_method,
                                       self.num_tokentypes)
            self._embedding_key = 'embedding'

        # Transformer.
        # Encoder (usually set to True, False if part of an encoder-decoder
        # architecture and in encoder-only stage).
        if self.add_encoder:
            self.encoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                self_attn_mask_type=self.encoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process)
            self._encoder_key = 'encoder'
        else:
            self.encoder = None

        # Decoder (usually set to False, True if part of an encoder-decoder
        # architecture and in decoder-only stage).
        if self.add_decoder:
            self.decoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process)
            self._decoder_key = 'decoder'
        else:
            self.decoder = None

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = 'pooler'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        # This is usually handled in schedules.py but some inference code
        # still gives us non-lists or None
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]

        if self.add_encoder and self.add_decoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only be length 1 for stage with both encoder and decoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_encoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only be length 1 for stage with only encoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_decoder:
            if len(input_tensor) == 2:
                self.decoder.set_input_tensor(input_tensor[0])
                self.encoder_hidden_state = input_tensor[1]
            elif len(input_tensor) == 1:
                self.decoder.set_input_tensor(None)
                self.encoder_hidden_state = input_tensor[0]
            else:
                raise Exception('input_tensor must have either length 1 or 2')
        else:
            raise Exception(
                'Stage must have at least either encoder or decoder')

    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                enc_dec_attn_mask=None, tokentype_ids=None,
                inference_params=None, pooling_sequence_index=0,
                enc_hidden_states=None, output_enc_hidden=False):

        # Encoder embedding.
        if self.pre_process:
            encoder_input = self.embedding(enc_input_ids, enc_position_ids,
                                           tokentype_ids=tokentype_ids)
        else:
            encoder_input = None

        # Run encoder.
        if enc_hidden_states is None:
            if self.encoder is not None:
                encoder_output = self.encoder(
                    encoder_input,
                    enc_attn_mask,
                    inference_params=inference_params)
            else:
                encoder_output = self.encoder_hidden_state
        else:
            encoder_output = enc_hidden_states.to(encoder_input.dtype)

        if self.post_process:
            if self.add_pooler:
                pooled_output = self.pooler(encoder_output,
                                            pooling_sequence_index)

        # output_enc_hidden refers to when we just need the encoder's
        # output. For example, it is helpful to compute
        # similarity between two sequences by average pooling
        if not self.add_decoder or output_enc_hidden:
            if self.add_pooler and self.post_process:
                return encoder_output, pooled_output
            else:
                return encoder_output

        # Decoder embedding.
        if self.pre_process:
            decoder_input = self.embedding(dec_input_ids,
                                           dec_position_ids)
        else:
            decoder_input = None

        # Run decoder.
        decoder_output = self.decoder(
            decoder_input,
            dec_attn_mask,
            encoder_output=encoder_output,
            enc_dec_attn_mask=enc_dec_attn_mask,
            inference_params=inference_params)

        if self.add_pooler and self.post_process:
            return decoder_output, encoder_output, pooled_output
        else:
            return decoder_output, encoder_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        if self.pre_process:
            state_dict_[self._embedding_key] \
                = self.embedding.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        if self.add_encoder:
            state_dict_[self._encoder_key] \
                = self.encoder.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        if self.post_process:
            if self.add_pooler:
                state_dict_[self._pooler_key] \
                    = self.pooler.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)
        if self.add_decoder:
            state_dict_[self._decoder_key] \
                = self.decoder.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Embedding.
        if self.pre_process:
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if '_embeddings' in key:
                        state_dict_[key] = state_dict[key]
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Encoder.
        if self.add_encoder:
            if self._encoder_key in state_dict:
                state_dict_ = state_dict[self._encoder_key]
            # For backward compatibility.
            elif 'transformer' in state_dict:
                state_dict_ = state_dict['transformer']
            else:
                # For backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if 'transformer.' in key:
                        state_dict_[key.split('transformer.')[1]] \
                            = state_dict[key]

            # For backward compatibility.
            state_dict_self_attention = {}
            for key in state_dict_.keys():
                if '.attention.' in key:
                    state_dict_self_attention[key.replace(
                        ".attention.", ".self_attention.")] = state_dict_[key]
                else:
                    state_dict_self_attention[key] = state_dict_[key]
            state_dict_ = state_dict_self_attention

            self.encoder.load_state_dict(state_dict_, strict=strict)

        # Pooler.
        if self.post_process:
            if self.add_pooler:
                assert 'pooler' in state_dict, \
                    'could not find data for pooler in the checkpoint'
                self.pooler.load_state_dict(state_dict[self._pooler_key],
                                            strict=strict)
        # Decoder.
        if self.add_decoder:
            assert 'decoder' in state_dict, \
                'could not find data for decoder in the checkpoint'
            self.decoder.load_state_dict(state_dict[self._decoder_key],
                                         strict=strict)
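# --- set_input_tensor sketch (not part of the original source) ---
# With pipeline parallelism, the schedule hands each stage the activations
# produced by the previous stage via set_input_tensor() rather than through
# forward() arguments. Assuming a hypothetical two-stage encoder-decoder
# split, the calls would look roughly like:
#
#   # Encoder-only (or encoder+decoder) stage: exactly one tensor,
#   # routed into the encoder.
#   encoder_stage_lm.set_input_tensor([prev_stage_hidden])
#
#   # Decoder-only stage: either [decoder_hidden, encoder_hidden] or, on the
#   # first decoder stage, just [encoder_hidden]; the encoder output is then
#   # cached in self.encoder_hidden_state for cross-attention.
#   decoder_stage_lm.set_input_tensor([decoder_hidden, encoder_hidden])
#
# The tensor names here are placeholders for whatever the schedule in
# schedules.py actually passes.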
class VitBackbone(MegatronModule):
    """Vision Transformer Model."""

    def __init__(self,
                 pre_process=True,
                 post_process=True,
                 class_token=True,
                 single_token_output=False):
        super(VitBackbone, self).__init__(share_word_embeddings=False)
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        if args.init_method_xavier_uniform:
            self.init_method = torch.nn.init.xavier_uniform_
            self.scaled_init_method = torch.nn.init.xavier_uniform_
        else:
            self.init_method = init_method_normal(args.init_method_std)
            self.scaled_init_method = scaled_init_method_normal(
                args.init_method_std, args.num_layers)

        self.pre_process = pre_process
        self.post_process = post_process
        self.class_token = class_token
        self.hidden_size = args.hidden_size
        self.patch_dim = args.patch_dim
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.micro_batch_size = args.micro_batch_size
        self.single_token_output = single_token_output

        assert self.img_h % self.patch_dim == 0
        assert self.img_w % self.patch_dim == 0
        self.num_patches_per_dim_h = self.img_h // self.patch_dim
        self.num_patches_per_dim_w = self.img_w // self.patch_dim
        self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w
        self.seq_length = self.num_patches + \
            (CLASS_TOKEN_LENGTH if self.class_token else 0)
        self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
        self.input_tensor = None
        self.position_ids = None

        if self.pre_process:
            # cls_token
            if self.class_token:
                self.cls_token = torch.nn.Parameter(
                    torch.randn(1, CLASS_TOKEN_LENGTH, self.hidden_size))
                torch.nn.init.zeros_(self.cls_token)
            self.position_ids = torch.arange(self.seq_length).expand(
                1, -1).cuda()

            # Linear encoder
            self.linear_encoder = torch.nn.Linear(self.flatten_dim,
                                                  self.hidden_size)

            # embedding
            self.position_embeddings = torch.nn.Embedding(
                self.seq_length, self.hidden_size)
            init_method_normal(args.init_method_std)(
                self.position_embeddings.weight)

            args.class_token_present = self.class_token
            self.position_embeddings._register_load_state_dict_pre_hook(
                twod_interpolate_position_embeddings_hook)

            self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)

        # Transformer
        self.transformer = ParallelTransformer(
            self.init_method,
            self.scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process,
        )

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.transformer.set_input_tensor(input_tensor)

    def forward(self, input):
        if self.pre_process:
            rearranged_input = einops.rearrange(
                input,
                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
                p1=self.patch_dim,
                p2=self.patch_dim,
            )

            assert rearranged_input.dtype == torch.half
            encoder_output = self.linear_encoder(rearranged_input)

            concatenated_tokens = encoder_output
            if self.class_token:
                cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
                concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)

            token_embeddings = concatenated_tokens + \
                self.position_embeddings(
                    self.position_ids[:, :concatenated_tokens.shape[1]])
            hidden_states = self.embedding_dropout(token_embeddings)
        else:
            hidden_states = input

        hidden_states = self.transformer(hidden_states, None)

        if self.single_token_output:
            hidden_states = hidden_states[:, 0, :]

        return hidden_states
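# --- Wrapper sketch (not part of the original source) ---
# A minimal, hypothetical classification wrapper around VitBackbone, shown only
# to illustrate how single_token_output and the pre/post-process flags compose;
# the classification head Megatron actually ships may differ. Assumes Megatron
# has been initialized so VitBackbone can read get_args(), and that images are
# provided in fp16 as the backbone's forward() asserts.
import torch


class VitClassifierSketch(torch.nn.Module):
    """Hypothetical head: backbone CLS hidden state -> class logits."""

    def __init__(self, num_classes, hidden_size):
        super().__init__()
        # With single_token_output=True the backbone returns only the
        # [CLS] hidden state, shaped [b, hidden_size].
        self.backbone = VitBackbone(pre_process=True, post_process=True,
                                    class_token=True, single_token_output=True)
        self.head = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, images):
        # images: [b, c, img_h, img_w], expected in fp16 by the backbone.
        cls_hidden = self.backbone(images)   # [b, hidden_size]
        return self.head(cls_hidden)         # [b, num_classes]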