def __init__(self, config, language_pretrained_model_path=None):
    super(VisualLinguisticBert, self).__init__(config)
    self.config = config

    # embeddings
    self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
    self.end_embedding = nn.Embedding(1, config.hidden_size)
    self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
    self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
    self.embedding_LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
    self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)
    # for compatibility of roberta
    self.position_padding_idx = config.position_padding_idx

    # visual transform
    self.visual_1x1_text = None
    self.visual_1x1_object = None
    if config.visual_size != config.hidden_size:
        self.visual_1x1_text = nn.Linear(config.visual_size, config.hidden_size)
        self.visual_1x1_object = nn.Linear(config.visual_size, config.hidden_size)
    if config.visual_ln:
        self.visual_ln_text = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.visual_ln_object = BertLayerNorm(config.hidden_size, eps=1e-12)
    else:
        visual_scale_text = nn.Parameter(
            torch.as_tensor(self.config.visual_scale_text_init, dtype=torch.float),
            requires_grad=True)
        self.register_parameter('visual_scale_text', visual_scale_text)
        visual_scale_object = nn.Parameter(
            torch.as_tensor(self.config.visual_scale_object_init, dtype=torch.float),
            requires_grad=True)
        self.register_parameter('visual_scale_object', visual_scale_object)

    self.encoder = BertEncoder(config)
    if self.config.with_pooler:
        self.pooler = BertPooler(config)

    # init weights
    self.apply(self.init_weights)
    if config.visual_ln:
        self.visual_ln_text.weight.data.fill_(self.config.visual_scale_text_init)
        self.visual_ln_object.weight.data.fill_(self.config.visual_scale_object_init)

    # load language pretrained model
    if language_pretrained_model_path is not None:
        self.load_language_pretrained_model(language_pretrained_model_path)

    if config.word_embedding_frozen:
        for p in self.word_embeddings.parameters():
            p.requires_grad = False
        self.special_word_embeddings = nn.Embedding(NUM_SPECIAL_WORDS, config.hidden_size)
        self.special_word_embeddings.weight.data.copy_(
            self.word_embeddings.weight.data[:NUM_SPECIAL_WORDS])
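
# --- Illustration (not part of the original file) ---
# A minimal, self-contained sketch of the "frozen word embeddings + trainable
# special-word embeddings" trick set up above: the main embedding table is frozen,
# while a small table covering the first NUM_SPECIAL_WORDS ids stays trainable and
# overrides the frozen rows at lookup time. Sizes below are made up for the demo.
import torch
import torch.nn as nn

NUM_SPECIAL_WORDS_DEMO = 5
vocab_size, hidden_size = 100, 16

word_embeddings = nn.Embedding(vocab_size, hidden_size)
for p in word_embeddings.parameters():
    p.requires_grad = False
special_word_embeddings = nn.Embedding(NUM_SPECIAL_WORDS_DEMO, hidden_size)
special_word_embeddings.weight.data.copy_(word_embeddings.weight.data[:NUM_SPECIAL_WORDS_DEMO])

input_ids = torch.tensor([[1, 7, 3, 42]])
emb = word_embeddings(input_ids)
special = input_ids < NUM_SPECIAL_WORDS_DEMO
emb[special] = special_word_embeddings(input_ids[special])  # trainable rows override frozen ones
print(emb.shape)  # torch.Size([1, 4, 16])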
def __init__(self, config):
    super(RobertaLMHead, self).__init__()
    self.dense = nn.Linear(config.hidden_size, config.hidden_size)
    self.layer_norm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
    self.bias = nn.Parameter(torch.zeros(config.vocab_size))
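
# --- Illustration (not part of the original file) ---
# Only __init__ is shown above. The standard RoBERTa-style LM head applies
# dense -> GELU -> layer norm -> vocabulary projection (+ bias); a hedged sketch
# of the corresponding forward method, assuming the modules created above:
import torch

def roberta_lm_head_forward(self, features):
    x = self.dense(features)
    x = torch.nn.functional.gelu(x)
    x = self.layer_norm(x)
    # project back onto the vocabulary with the shared bias term
    x = self.decoder(x) + self.bias
    return x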
def __init__(self, bert_config, input_size, output_all_encoded_layers=False):
    super(BertEncoderWrapper, self).__init__()
    self.bert_config = bert_config
    self.output_all_encoded_layers = output_all_encoded_layers
    self.input_transform = nn.Linear(input_size, bert_config.hidden_size)
    self.with_position_embeddings = False if 'with_position_embeddings' not in bert_config \
        else bert_config.with_position_embeddings
    if self.with_position_embeddings:
        self.position_embedding = nn.Embedding(bert_config.max_position_embeddings,
                                               bert_config.hidden_size)
        self.LayerNorm = BertLayerNorm(bert_config.hidden_size, eps=1e-12)
    self.dropout = nn.Dropout(bert_config.hidden_dropout_prob)
    self.bert_encoder = BertEncoder(bert_config)
    self.apply(self.init_bert_weights)
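
# --- Illustration (not part of the original file) ---
# The wrapper's forward is not included above. Based on the modules created in
# __init__ and the additive attention-mask handling used elsewhere in this
# codebase, it plausibly looks like the sketch below; argument names and the
# exact call into self.bert_encoder are assumptions.
import torch

def bert_encoder_wrapper_forward(self, inputs, attention_mask):
    x = self.input_transform(inputs)
    if self.with_position_embeddings:
        position_ids = torch.arange(x.size(1), dtype=torch.long, device=x.device)
        x = x + self.position_embedding(position_ids).unsqueeze(0)
        x = self.LayerNorm(x)
    x = self.dropout(x)
    # 1 -> attend, 0 -> masked; converted to an additive bias for the encoder
    extended_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(x.dtype)
    extended_mask = (1.0 - extended_mask) * -10000.0
    encoded_layers = self.bert_encoder(x, extended_mask,
                                       output_all_encoded_layers=self.output_all_encoded_layers)
    return encoded_layers if self.output_all_encoded_layers else encoded_layers[-1]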
def __init__(self, dummy_config):
    super(LXMERT, self).__init__(dummy_config)
    frcnn_cfg = Config.from_pretrained("unc-nlp/frcnn-vg-finetuned")
    # self.frcnn = GeneralizedRCNN.from_pretrained("unc-nlp/frcnn-vg-finetuned", config=frcnn_cfg)
    self.backbone, self.roi_heads = build_image_encoder()
    self.lxmert_vqa = LxmertForPreTraining.from_pretrained("unc-nlp/lxmert-base-uncased")
    # self.lxmert_vqa = LxmertForQuestionAnswering.from_pretrained("unc-nlp/lxmert-vqa-uncased")
    self.tokenizer = LxmertTokenizer.from_pretrained("unc-nlp/lxmert-base-uncased")
    self.image_preprocess = Preprocess(frcnn_cfg)

    hid_dim = self.lxmert_vqa.config.hidden_size
    # transform = BertPredictionHeadTransform(self.config.NETWORK.VLBERT)
    self.logit_fc = nn.Sequential(
        nn.Linear(hid_dim, hid_dim),
        GELU(),
        BertLayerNorm(hid_dim),
        nn.Dropout(self.config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
        nn.Linear(hid_dim, self.config.NETWORK.CLASSIFIER_CLASS),
    )
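
# --- Illustration (not part of the original file) ---
# A hedged, self-contained sketch of how the classifier head above is applied:
# a pooled cross-modal feature goes through Linear -> GELU -> LayerNorm ->
# Dropout -> Linear. The tensor, dropout rate, and class count below are
# stand-ins for illustration only (the real values come from the LXMERT forward
# pass and self.config.NETWORK).
import torch
import torch.nn as nn

hid_dim = 768                       # LXMERT-base hidden size
num_classes = 3129                  # assumed answer-vocabulary size, for the demo
head = nn.Sequential(
    nn.Linear(hid_dim, hid_dim),
    nn.GELU(),                      # stand-in for the repo's GELU
    nn.LayerNorm(hid_dim),          # stand-in for BertLayerNorm
    nn.Dropout(0.1),
    nn.Linear(hid_dim, num_classes),
)
pooled_output = torch.randn(2, hid_dim)   # placeholder for the pooled cross-modal feature
logits = head(pooled_output)              # shape: [2, num_classes]
print(logits.shape)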
class VisualLinguisticBert(BaseModel):
    def __init__(self, config, language_pretrained_model_path=None,
                 finetune_strategy='standard', is_policy_net=False):
        super(VisualLinguisticBert, self).__init__(config)
        self.config = config

        # embeddings
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.end_embedding = nn.Embedding(1, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.embedding_LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)
        # for compatibility of roberta
        self.position_padding_idx = config.position_padding_idx

        # visual transform
        self.visual_1x1_text = None
        self.visual_1x1_object = None
        if config.visual_size != config.hidden_size:
            self.visual_1x1_text = nn.Linear(config.visual_size, config.hidden_size)
            self.visual_1x1_object = nn.Linear(config.visual_size, config.hidden_size)
        if config.visual_ln:
            self.visual_ln_text = BertLayerNorm(config.hidden_size, eps=1e-12)
            self.visual_ln_object = BertLayerNorm(config.hidden_size, eps=1e-12)
        else:
            visual_scale_text = nn.Parameter(
                torch.as_tensor(self.config.visual_scale_text_init, dtype=torch.float),
                requires_grad=True)
            self.register_parameter('visual_scale_text', visual_scale_text)
            visual_scale_object = nn.Parameter(
                torch.as_tensor(self.config.visual_scale_object_init, dtype=torch.float),
                requires_grad=True)
            self.register_parameter('visual_scale_object', visual_scale_object)

        self.encoder = BertEncoder(config, finetune_strategy=finetune_strategy)
        if self.config.with_pooler:
            self.pooler = BertPooler(config)

        # init weights
        self.apply(self.init_weights)
        if config.visual_ln:
            self.visual_ln_text.weight.data.fill_(self.config.visual_scale_text_init)
            self.visual_ln_object.weight.data.fill_(self.config.visual_scale_object_init)

        # self.is_policy_net
        self.is_policy_net = is_policy_net

        # load language pretrained model
        if language_pretrained_model_path is not None and not is_policy_net:
            self.load_language_pretrained_model(language_pretrained_model_path)

        if config.word_embedding_frozen:
            for p in self.word_embeddings.parameters():
                p.requires_grad = False
            self.special_word_embeddings = nn.Embedding(NUM_SPECIAL_WORDS, config.hidden_size)
            self.special_word_embeddings.weight.data.copy_(
                self.word_embeddings.weight.data[:NUM_SPECIAL_WORDS])

    def word_embeddings_wrapper(self, input_ids):
        if self.config.word_embedding_frozen:
            word_embeddings = self.word_embeddings(input_ids)
            word_embeddings[input_ids < NUM_SPECIAL_WORDS] \
                = self.special_word_embeddings(input_ids[input_ids < NUM_SPECIAL_WORDS])
            return word_embeddings
        else:
            return self.word_embeddings(input_ids)

    def forward(self,
                text_input_ids,
                text_token_type_ids,
                text_visual_embeddings,
                text_mask,
                object_vl_embeddings,
                object_mask,
                output_all_encoded_layers=True,
                output_text_and_object_separately=False,
                output_attention_probs=False,
                policy=None):
        # get seamless concatenate embeddings and mask
        embedding_output, attention_mask, text_mask_new, object_mask_new = self.embedding(
            text_input_ids, text_token_type_ids, text_visual_embeddings, text_mask,
            object_vl_embeddings, object_mask)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        # extended_attention_mask = 1.0 - extended_attention_mask
        # extended_attention_mask[extended_attention_mask != 0] = float('-inf')

        if output_attention_probs:
            encoded_layers, attention_probs = self.encoder(
                embedding_output,
                extended_attention_mask,
                output_all_encoded_layers=output_all_encoded_layers,
                output_attention_probs=output_attention_probs,
                policy=policy)
        else:
            encoded_layers = self.encoder(
                embedding_output,
                extended_attention_mask,
                output_all_encoded_layers=output_all_encoded_layers,
                output_attention_probs=output_attention_probs,
                policy=policy)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output) if self.config.with_pooler else None
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]

        if output_text_and_object_separately:
            if not output_all_encoded_layers:
                encoded_layers = [encoded_layers]
            encoded_layers_text = []
            encoded_layers_object = []
            for encoded_layer in encoded_layers:
                max_text_len = text_input_ids.shape[1]
                max_object_len = object_vl_embeddings.shape[1]
                encoded_layer_text = encoded_layer[:, :max_text_len]
                encoded_layer_object = encoded_layer.new_zeros(
                    (encoded_layer.shape[0], max_object_len, encoded_layer.shape[2]))
                encoded_layer_object[object_mask] = encoded_layer[object_mask_new]
                encoded_layers_text.append(encoded_layer_text)
                encoded_layers_object.append(encoded_layer_object)
            if not output_all_encoded_layers:
                encoded_layers_text = encoded_layers_text[0]
                encoded_layers_object = encoded_layers_object[0]
            if output_attention_probs:
                return encoded_layers_text, encoded_layers_object, pooled_output, attention_probs
            else:
                return encoded_layers_text, encoded_layers_object, pooled_output
        else:
            if output_attention_probs:
                return encoded_layers, pooled_output, attention_probs
            else:
                return encoded_layers, pooled_output

    def embedding(self, text_input_ids, text_token_type_ids, text_visual_embeddings,
                  text_mask, object_vl_embeddings, object_mask):
        text_linguistic_embedding = self.word_embeddings_wrapper(text_input_ids)
        if self.visual_1x1_text is not None:
            text_visual_embeddings = self.visual_1x1_text(text_visual_embeddings)
        if self.config.visual_ln:
            text_visual_embeddings = self.visual_ln_text(text_visual_embeddings)
        else:
            text_visual_embeddings *= self.visual_scale_text
        text_vl_embeddings = text_linguistic_embedding + text_visual_embeddings

        object_visual_embeddings = object_vl_embeddings[:, :, :self.config.visual_size]
        if self.visual_1x1_object is not None:
            object_visual_embeddings = self.visual_1x1_object(object_visual_embeddings)
        if self.config.visual_ln:
            object_visual_embeddings = self.visual_ln_object(object_visual_embeddings)
        else:
            object_visual_embeddings *= self.visual_scale_object
        object_linguistic_embeddings = object_vl_embeddings[:, :, self.config.visual_size:]
        object_vl_embeddings = object_linguistic_embeddings + object_visual_embeddings

        bs = text_vl_embeddings.size(0)
        vl_embed_size = text_vl_embeddings.size(-1)
        max_length = (text_mask.sum(1) + object_mask.sum(1)).max() + 1
        grid_ind, grid_pos = torch.meshgrid(
            torch.arange(bs, dtype=torch.long, device=text_vl_embeddings.device),
            torch.arange(max_length, dtype=torch.long, device=text_vl_embeddings.device))
        text_end = text_mask.sum(1, keepdim=True)
        object_end = text_end + object_mask.sum(1, keepdim=True)

        # seamlessly concatenate visual linguistic embeddings of text and object
        _zero_id = torch.zeros((bs, ), dtype=torch.long, device=text_vl_embeddings.device)
        vl_embeddings = text_vl_embeddings.new_zeros((bs, max_length, vl_embed_size))
        vl_embeddings[grid_pos < text_end] = text_vl_embeddings[text_mask]
        vl_embeddings[(grid_pos >= text_end) & (grid_pos < object_end)] = object_vl_embeddings[object_mask]
        vl_embeddings[grid_pos == object_end] = self.end_embedding(_zero_id)

        # token type embeddings/ segment embeddings
        token_type_ids = text_token_type_ids.new_zeros((bs, max_length))
        token_type_ids[grid_pos < text_end] = text_token_type_ids[text_mask]
        token_type_ids[(grid_pos >= text_end) & (grid_pos <= object_end)] = 2
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # position embeddings
        position_ids = grid_pos + self.position_padding_idx + 1
        if self.config.obj_pos_id_relative:
            position_ids[(grid_pos >= text_end) & (grid_pos < object_end)] \
                = text_end.expand((bs, max_length))[(grid_pos >= text_end) & (grid_pos < object_end)] \
                + self.position_padding_idx + 1
            position_ids[grid_pos == object_end] = (text_end + 1).squeeze(1) + self.position_padding_idx + 1
        else:
            assert False, "Don't use position id 510/511 for objects and [END]!!!"
            position_ids[(grid_pos >= text_end) & (grid_pos < object_end)] = self.config.max_position_embeddings - 2
            position_ids[grid_pos == object_end] = self.config.max_position_embeddings - 1

        position_embeddings = self.position_embeddings(position_ids)
        mask = text_mask.new_zeros((bs, max_length))
        mask[grid_pos <= object_end] = 1

        embeddings = vl_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.embedding_LayerNorm(embeddings)
        embeddings = self.embedding_dropout(embeddings)

        return embeddings, mask, grid_pos < text_end, \
            (grid_pos >= text_end) & (grid_pos < object_end)

    def load_language_pretrained_model(self, language_pretrained_model_path):
        pretrained_state_dict = torch.load(language_pretrained_model_path,
                                           map_location=lambda storage, loc: storage)
        encoder_pretrained_state_dict = {}
        pooler_pretrained_state_dict = {}
        embedding_ln_pretrained_state_dict = {}
        unexpected_keys = []
        for k, v in pretrained_state_dict.items():
            if k.startswith('bert.'):
                k = k[len('bert.'):]
            elif k.startswith('roberta.'):
                k = k[len('roberta.'):]
            else:
                unexpected_keys.append(k)
                continue
            if 'gamma' in k:
                k = k.replace('gamma', 'weight')
            if 'beta' in k:
                k = k.replace('beta', 'bias')
            if k.startswith('encoder.'):
                k_ = k[len('encoder.'):]
                if k_ in self.encoder.state_dict():
                    encoder_pretrained_state_dict[k_] = v
                else:
                    unexpected_keys.append(k)
            elif k.startswith('embeddings.'):
                k_ = k[len('embeddings.'):]
                if k_ == 'word_embeddings.weight':
                    self.word_embeddings.weight.data = v.to(
                        dtype=self.word_embeddings.weight.data.dtype,
                        device=self.word_embeddings.weight.data.device)
                elif k_ == 'position_embeddings.weight':
                    self.position_embeddings.weight.data = v.to(
                        dtype=self.position_embeddings.weight.data.dtype,
                        device=self.position_embeddings.weight.data.device)
                elif k_ == 'token_type_embeddings.weight':
                    self.token_type_embeddings.weight.data[:v.size(0)] = v.to(
                        dtype=self.token_type_embeddings.weight.data.dtype,
                        device=self.token_type_embeddings.weight.data.device)
                    if v.size(0) == 1:
                        # Todo: roberta token type embedding
                        self.token_type_embeddings.weight.data[1] = v[0].clone().to(
                            dtype=self.token_type_embeddings.weight.data.dtype,
                            device=self.token_type_embeddings.weight.data.device)
                        self.token_type_embeddings.weight.data[2] = v[0].clone().to(
                            dtype=self.token_type_embeddings.weight.data.dtype,
                            device=self.token_type_embeddings.weight.data.device)
                elif k_.startswith('LayerNorm.'):
                    k__ = k_[len('LayerNorm.'):]
                    if k__ in self.embedding_LayerNorm.state_dict():
                        embedding_ln_pretrained_state_dict[k__] = v
                    else:
                        unexpected_keys.append(k)
                else:
                    unexpected_keys.append(k)
            elif self.config.with_pooler and k.startswith('pooler.'):
                k_ = k[len('pooler.'):]
                if k_ in self.pooler.state_dict():
                    pooler_pretrained_state_dict[k_] = v
                else:
                    unexpected_keys.append(k)
            else:
                unexpected_keys.append(k)
        if len(unexpected_keys) > 0:
            print("Warnings: Unexpected keys: {}.".format(unexpected_keys))
        self.embedding_LayerNorm.load_state_dict(embedding_ln_pretrained_state_dict)
        # preprocess encoder state dict for parallel blocks
        if not self.is_policy_net:
            for k in self.encoder.state_dict():
                if 'parallel_' in str(k):
                    encoder_pretrained_state_dict[k] = \
                        encoder_pretrained_state_dict[k.replace('parallel_', '')]
        self.encoder.load_state_dict(encoder_pretrained_state_dict)
        if self.config.with_pooler and len(pooler_pretrained_state_dict) > 0:
            self.pooler.load_state_dict(pooler_pretrained_state_dict)
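
# --- Illustration (not part of the original file) ---
# Standalone demo of the additive attention-mask trick described in forward()
# above: a {0, 1} padding mask is broadcast to [batch, 1, 1, seq_len] and turned
# into a large negative bias on padded positions before the softmax.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]], dtype=torch.float)
extended = attention_mask.unsqueeze(1).unsqueeze(2)   # [2, 1, 1, 5]
extended = (1.0 - extended) * -10000.0                # 0 where attended, -1e4 where masked
scores = torch.randn(2, 12, 5, 5)                     # [batch, heads, from_len, to_len]
probs = torch.softmax(scores + extended, dim=-1)
print(probs[0, 0, 0])  # the last two (padded) positions get near-zero attention weight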
def __init__(self, config, language_pretrained_model_path=None):
    super(VisualLinguisticBert, self).__init__(config)
    self.config = config

    # embeddings
    self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
    self.end_embedding = nn.Embedding(1, config.hidden_size)
    self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
    self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
    self.embedding_LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
    self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)
    # for compatibility of roberta
    self.position_padding_idx = config.position_padding_idx

    # visual transform
    self.visual_1x1_text = None
    self.visual_1x1_object = None
    if config.visual_size != config.hidden_size:  # Always False
        self.visual_1x1_text = nn.Linear(config.visual_size, config.hidden_size)
        self.visual_1x1_object = nn.Linear(config.visual_size, config.hidden_size)
    if config.visual_ln:  # Always True
        self.visual_ln_text = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.visual_ln_object = BertLayerNorm(config.hidden_size, eps=1e-12)
    else:
        visual_scale_text = nn.Parameter(
            torch.as_tensor(self.config.visual_scale_text_init, dtype=torch.float),
            requires_grad=True)
        self.register_parameter('visual_scale_text', visual_scale_text)
        visual_scale_object = nn.Parameter(
            torch.as_tensor(self.config.visual_scale_object_init, dtype=torch.float),
            requires_grad=True)
        self.register_parameter('visual_scale_object', visual_scale_object)

    self.encoder = BertEncoder(config)
    if self.config.with_pooler:
        self.pooler = BertPooler(config)

    # init weights
    self.apply(self.init_weights)
    if config.visual_ln:
        self.visual_ln_text.weight.data.fill_(self.config.visual_scale_text_init)
        self.visual_ln_object.weight.data.fill_(self.config.visual_scale_object_init)

    # load language pretrained model
    if language_pretrained_model_path is not None:
        self.load_language_pretrained_model(language_pretrained_model_path)

    if config.word_embedding_frozen:  # False by default
        for p in self.word_embeddings.parameters():
            p.requires_grad = False
        self.special_word_embeddings = nn.Embedding(NUM_SPECIAL_WORDS, config.hidden_size)
        self.special_word_embeddings.weight.data.copy_(
            self.word_embeddings.weight.data[:NUM_SPECIAL_WORDS])

    self.enhanced_img_feature = False
    self.no_predicate = False
    if config.ENHANCED_IMG_FEATURE:
        self.enhanced_img_feature = True
        if config.NO_PREDICATE:  # VRD
            self.no_predicate = True
            self.lan_img_conv3 = nn.Conv2d(768, 1, kernel_size=(1, 1))
        else:  # SpatialSense
            self.lan_img_conv3 = nn.Conv2d(768, 768, kernel_size=(1, 1))
            self.lan_img_conv4 = nn.Conv2d(768, 1, kernel_size=(1, 1))
        self.obj_feat_downsample = nn.Conv2d(2048, 768, kernel_size=(1, 1))
        self.obj_feat_batchnorm = nn.BatchNorm2d(768)
        self.lan_img_conv1 = nn.Conv2d(768, 768, kernel_size=(1, 1))
        self.lan_img_conv2 = nn.Conv2d(768, 768, kernel_size=(1, 1))
        # self.lan_img_conv3 = nn.Conv2d(768, 768, kernel_size=(1, 1))
        # self.lan_img_bn1 = nn.BatchNorm2d(768)
        self.lan_img_avgpool = nn.AvgPool2d(14, stride=1)
class VisualLinguisticBert(BaseModel):
    def __init__(self, config, language_pretrained_model_path=None):
        super(VisualLinguisticBert, self).__init__(config)
        self.config = config

        # embeddings
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.end_embedding = nn.Embedding(1, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.embedding_LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)
        # for compatibility of roberta
        self.position_padding_idx = config.position_padding_idx

        # visual transform
        self.visual_1x1_text = None
        self.visual_1x1_object = None
        if config.visual_size != config.hidden_size:  # Always False
            self.visual_1x1_text = nn.Linear(config.visual_size, config.hidden_size)
            self.visual_1x1_object = nn.Linear(config.visual_size, config.hidden_size)
        if config.visual_ln:  # Always True
            self.visual_ln_text = BertLayerNorm(config.hidden_size, eps=1e-12)
            self.visual_ln_object = BertLayerNorm(config.hidden_size, eps=1e-12)
        else:
            visual_scale_text = nn.Parameter(
                torch.as_tensor(self.config.visual_scale_text_init, dtype=torch.float),
                requires_grad=True)
            self.register_parameter('visual_scale_text', visual_scale_text)
            visual_scale_object = nn.Parameter(
                torch.as_tensor(self.config.visual_scale_object_init, dtype=torch.float),
                requires_grad=True)
            self.register_parameter('visual_scale_object', visual_scale_object)

        self.encoder = BertEncoder(config)
        if self.config.with_pooler:
            self.pooler = BertPooler(config)

        # init weights
        self.apply(self.init_weights)
        if config.visual_ln:
            self.visual_ln_text.weight.data.fill_(self.config.visual_scale_text_init)
            self.visual_ln_object.weight.data.fill_(self.config.visual_scale_object_init)

        # load language pretrained model
        if language_pretrained_model_path is not None:
            self.load_language_pretrained_model(language_pretrained_model_path)

        if config.word_embedding_frozen:  # False by default
            for p in self.word_embeddings.parameters():
                p.requires_grad = False
            self.special_word_embeddings = nn.Embedding(NUM_SPECIAL_WORDS, config.hidden_size)
            self.special_word_embeddings.weight.data.copy_(
                self.word_embeddings.weight.data[:NUM_SPECIAL_WORDS])

        self.enhanced_img_feature = False
        self.no_predicate = False
        if config.ENHANCED_IMG_FEATURE:
            self.enhanced_img_feature = True
            if config.NO_PREDICATE:  # VRD
                self.no_predicate = True
                self.lan_img_conv3 = nn.Conv2d(768, 1, kernel_size=(1, 1))
            else:  # SpatialSense
                self.lan_img_conv3 = nn.Conv2d(768, 768, kernel_size=(1, 1))
                self.lan_img_conv4 = nn.Conv2d(768, 1, kernel_size=(1, 1))
            self.obj_feat_downsample = nn.Conv2d(2048, 768, kernel_size=(1, 1))
            self.obj_feat_batchnorm = nn.BatchNorm2d(768)
            self.lan_img_conv1 = nn.Conv2d(768, 768, kernel_size=(1, 1))
            self.lan_img_conv2 = nn.Conv2d(768, 768, kernel_size=(1, 1))
            # self.lan_img_conv3 = nn.Conv2d(768, 768, kernel_size=(1, 1))
            # self.lan_img_bn1 = nn.BatchNorm2d(768)
            self.lan_img_avgpool = nn.AvgPool2d(14, stride=1)
            # TODO: make these layers trainable during finetuning!

    def word_embeddings_wrapper(self, input_ids):
        if self.config.word_embedding_frozen:  # False by default
            word_embeddings = self.word_embeddings(input_ids)
            word_embeddings[input_ids < NUM_SPECIAL_WORDS] \
                = self.special_word_embeddings(input_ids[input_ids < NUM_SPECIAL_WORDS])
            return word_embeddings
        else:
            return self.word_embeddings(input_ids)

    def forward(self,
                text_input_ids,
                text_token_type_ids,
                text_visual_embeddings,
                text_mask,
                object_vl_embeddings,
                object_mask,
                object_visual_feat=None,
                spo_len=None,
                output_all_encoded_layers=True,
                output_text_and_object_separately=False,
                output_attention_probs=False):
        # get seamless concatenate embeddings and mask
        embedding_output, attention_mask, text_mask_new, object_mask_new, spo_fused_masks = self.embedding(
            text_input_ids, text_token_type_ids, text_visual_embeddings, text_mask,
            object_vl_embeddings, object_mask, object_visual_feat, spo_len)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        # extended_attention_mask = 1.0 - extended_attention_mask
        # extended_attention_mask[extended_attention_mask != 0] = float('-inf')

        if output_attention_probs:
            encoded_layers, attention_probs = self.encoder(
                embedding_output,
                extended_attention_mask,
                output_all_encoded_layers=output_all_encoded_layers,
                output_attention_probs=output_attention_probs)
        else:
            encoded_layers = self.encoder(
                embedding_output,
                extended_attention_mask,
                output_all_encoded_layers=output_all_encoded_layers,
                output_attention_probs=output_attention_probs)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output) if self.config.with_pooler else None
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]

        if output_text_and_object_separately:  # False
            if not output_all_encoded_layers:
                encoded_layers = [encoded_layers]
            encoded_layers_text = []
            encoded_layers_object = []
            for encoded_layer in encoded_layers:
                max_text_len = text_input_ids.shape[1]
                max_object_len = object_vl_embeddings.shape[1]
                encoded_layer_text = encoded_layer[:, :max_text_len]
                encoded_layer_object = encoded_layer.new_zeros(
                    (encoded_layer.shape[0], max_object_len, encoded_layer.shape[2]))
                encoded_layer_object[object_mask] = encoded_layer[object_mask_new]
                encoded_layers_text.append(encoded_layer_text)
                encoded_layers_object.append(encoded_layer_object)
            if not output_all_encoded_layers:
                encoded_layers_text = encoded_layers_text[0]
                encoded_layers_object = encoded_layers_object[0]
            if output_attention_probs:
                return encoded_layers_text, encoded_layers_object, pooled_output, attention_probs
            else:
                return encoded_layers_text, encoded_layers_object, pooled_output
        else:
            if output_attention_probs:  # False
                return encoded_layers, pooled_output, attention_probs
            else:
                if spo_fused_masks is not None:
                    return encoded_layers, pooled_output, spo_fused_masks
                else:
                    return encoded_layers, pooled_output, None

    def embedding(self,
                  text_input_ids,
                  text_token_type_ids,
                  text_visual_embeddings,
                  text_mask,
                  object_vl_embeddings,
                  object_mask,
                  object_visual_feat=None,
                  spo_len=None):
        # (Text) Token Embedding + Visual Feature Embedding
        text_linguistic_embedding = self.word_embeddings_wrapper(text_input_ids)

        # if object_visual_feat is not None and spo_len is not None:
        if self.enhanced_img_feature and object_visual_feat is not None and spo_len is not None:
            # memo the shape
            object_visual_feat_shape = object_visual_feat.shape
            # reshape to [batch_size * nb_imgs, 2048, 14, 14] & downsample from 2048 to 768
            object_visual_feat = object_visual_feat.view(
                -1, object_visual_feat.shape[2], object_visual_feat.shape[3],
                object_visual_feat.shape[4])
            object_visual_feat = self.obj_feat_downsample(object_visual_feat)
            object_visual_feat = self.obj_feat_batchnorm(object_visual_feat)
            # restore to [batch_size, nb_imgs, 768, 14, 14]
            object_visual_feat = object_visual_feat.view(
                (object_visual_feat_shape[0], object_visual_feat_shape[1], 768,
                 object_visual_feat_shape[3], object_visual_feat_shape[4]))

            spo_fused_masks = []
            if self.no_predicate:  # VRD
                for i in range(text_input_ids.shape[0]):  # for each img sample
                    subj_end = 1 + spo_len[i, 0]
                    obj_end = subj_end + spo_len[i, 1]

                    sub_text_emb = (text_linguistic_embedding[i, 1:subj_end].view(-1, 1, 1)
                                    if spo_len[i, 0] == 1 else
                                    text_linguistic_embedding[i, 1:subj_end].sum(dim=0).view(-1, 1, 1))
                    sub_text_emb = torch.cat(([sub_text_emb] * 14), dim=1)
                    sub_text_emb = torch.cat(([sub_text_emb] * 14), dim=2)
                    fused_sub = object_visual_feat[i, 0] + sub_text_emb

                    obj_text_emb = (text_linguistic_embedding[i, subj_end:obj_end].view(-1, 1, 1)
                                    if spo_len[i, 1] == 1 else
                                    text_linguistic_embedding[i, subj_end:obj_end].sum(dim=0).view(-1, 1, 1))
                    obj_text_emb = torch.cat(([obj_text_emb] * 14), dim=1)
                    obj_text_emb = torch.cat(([obj_text_emb] * 14), dim=2)
                    fused_obj = object_visual_feat[i, 0] + obj_text_emb

                    # spo_fused = torch.cat((fused_sub.unsqueeze(0), fused_obj.unsqueeze(0)))
                    fused_sub = self.lan_img_conv1(fused_sub.unsqueeze(0))
                    fused_sub = nn.functional.relu(fused_sub)
                    fused_obj = self.lan_img_conv2(fused_obj.unsqueeze(0))
                    fused_obj = nn.functional.relu(fused_obj)
                    spo_fused = torch.cat((fused_sub, fused_obj))
                    spo_fused = self.lan_img_conv3(spo_fused).squeeze()
                    # spo_fused[spo_fused > 1] = 1
                    # spo_fused[spo_fused < 0] = 0

                    # Save the mask for computing mask loss
                    # spo_fused_masks.append(spo_fused)
                    spo_fused_norm = torch.zeros_like(spo_fused)
                    for j in range(2):
                        max1, min1 = spo_fused[j].max(), spo_fused[j].min()
                        denominator = max1 - min1
                        if not self.config.mask_loss_sum and not self.config.mask_loss_mse:
                            # For BCE loss
                            # import pdb; pdb.set_trace()
                            denominator += 1e-10
                        spo_fused_norm[j] = (spo_fused[j] - min1) / denominator
                        # spo_fused[j] = torch.sigmoid(spo_fused[j])
                    spo_fused_masks.append(spo_fused_norm)

                    text_visual_embeddings[i, 1:subj_end] = self.lan_img_avgpool(
                        object_visual_feat[i, 0] * spo_fused_norm[0]).squeeze()  # spo_fused[0]
                    text_visual_embeddings[i, subj_end:obj_end] = self.lan_img_avgpool(
                        object_visual_feat[i, 0] * spo_fused_norm[1]).squeeze()  # spo_fused[1]
            else:  # SpatialSense
                for i in range(text_input_ids.shape[0]):  # for each img sample
                    subj_end = 1 + spo_len[i, 0]
                    pred_end = subj_end + spo_len[i, 1]
                    obj_end = pred_end + spo_len[i, 2]

                    sub_text_emb = (text_linguistic_embedding[i, 1:subj_end].view(-1, 1, 1)
                                    if spo_len[i, 0] == 1 else
                                    text_linguistic_embedding[i, 1:subj_end].sum(dim=0).view(-1, 1, 1))
                    sub_text_emb = torch.cat(([sub_text_emb] * 14), dim=1)
                    sub_text_emb = torch.cat(([sub_text_emb] * 14), dim=2)
                    fused_sub = object_visual_feat[i, 0] + sub_text_emb

                    pred_text_emb = (text_linguistic_embedding[i, subj_end:pred_end].view(-1, 1, 1)
                                     if spo_len[i, 1] == 1 else
                                     text_linguistic_embedding[i, subj_end:pred_end].sum(dim=0).view(-1, 1, 1))
                    pred_text_emb = torch.cat(([pred_text_emb] * 14), dim=1)
                    pred_text_emb = torch.cat(([pred_text_emb] * 14), dim=2)
                    fused_pred = object_visual_feat[i, 0] + pred_text_emb

                    obj_text_emb = (text_linguistic_embedding[i, pred_end:obj_end].view(-1, 1, 1)
                                    if spo_len[i, 2] == 1 else
                                    text_linguistic_embedding[i, pred_end:obj_end].sum(dim=0).view(-1, 1, 1))
                    obj_text_emb = torch.cat(([obj_text_emb] * 14), dim=1)
                    obj_text_emb = torch.cat(([obj_text_emb] * 14), dim=2)
                    fused_obj = object_visual_feat[i, 0] + obj_text_emb

                    fused_sub = self.lan_img_conv1(fused_sub.unsqueeze(0))
                    fused_sub = nn.functional.relu(fused_sub)
                    fused_pred = self.lan_img_conv2(fused_pred.unsqueeze(0))
                    fused_pred = nn.functional.relu(fused_pred)
                    fused_obj = self.lan_img_conv3(fused_obj.unsqueeze(0))
                    fused_obj = nn.functional.relu(fused_obj)
                    spo_fused = torch.cat((fused_sub, fused_pred, fused_obj))
                    spo_fused = self.lan_img_conv4(spo_fused).squeeze()
                    # spo_fused[spo_fused > 1] = 1
                    # spo_fused[spo_fused < 0] = 0
                    # Save the mask for computing mask loss
                    # spo_fused_masks.append(spo_fused)
                    # spo_fused += object_visual_feat[i,0].unsqueeze(0)
                    # spo_fused = self.lan_img_avgpool(spo_fused).squeeze()
                    # import pdb; pdb.set_trace()

                    spo_fused_norm = torch.zeros_like(spo_fused)
                    for j in range(3):
                        max1, min1 = spo_fused[j].max(), spo_fused[j].min()
                        denominator = max1 - min1
                        # if not self.config.mask_loss_sum and not self.config.mask_loss_mse:  # For BCE loss
                        #     # import pdb; pdb.set_trace()
                        #     denominator += 1e-10
                        spo_fused_norm[j] = (spo_fused[j] - min1) / denominator
                        # spo_fused[j] = torch.sigmoid(spo_fused[j])
                    spo_fused_masks.append(spo_fused_norm)

                    text_visual_embeddings[i, 1:subj_end] = self.lan_img_avgpool(
                        object_visual_feat[i, 0] * spo_fused_norm[0]).squeeze()
                    text_visual_embeddings[i, subj_end:pred_end] = self.lan_img_avgpool(
                        object_visual_feat[i, 0] * spo_fused_norm[1]).squeeze()  # spo_fused[1]
                    text_visual_embeddings[i, pred_end:obj_end] = self.lan_img_avgpool(
                        object_visual_feat[i, 0] * spo_fused_norm[2]).squeeze()  # spo_fused[1]
            spo_fused_masks = torch.cat(spo_fused_masks)

        if self.visual_1x1_text is not None:  # always False
            text_visual_embeddings = self.visual_1x1_text(text_visual_embeddings)
        if self.config.visual_ln:  # always True
            text_visual_embeddings = self.visual_ln_text(text_visual_embeddings)
        else:
            text_visual_embeddings *= self.visual_scale_text
        text_vl_embeddings = text_linguistic_embedding + text_visual_embeddings

        # (Object) Token Embedding + Visual Feature Embedding
        object_visual_embeddings = object_vl_embeddings[:, :, :self.config.visual_size]
        if self.visual_1x1_object is not None:  # always False
            object_visual_embeddings = self.visual_1x1_object(object_visual_embeddings)
        if self.config.visual_ln:  # always True
            object_visual_embeddings = self.visual_ln_object(object_visual_embeddings)
        else:
            object_visual_embeddings *= self.visual_scale_object
        object_linguistic_embeddings = object_vl_embeddings[:, :, self.config.visual_size:]
        # import pdb; pdb.set_trace()
        object_vl_embeddings = object_linguistic_embeddings + object_visual_embeddings

        # Some indices setup for following process
        bs = text_vl_embeddings.size(0)
        vl_embed_size = text_vl_embeddings.size(-1)
        max_length = (text_mask.sum(1) + object_mask.sum(1)).max() + 1
        grid_ind, grid_pos = torch.meshgrid(
            torch.arange(bs, dtype=torch.long, device=text_vl_embeddings.device),
            torch.arange(max_length, dtype=torch.long, device=text_vl_embeddings.device))
        text_end = text_mask.sum(1, keepdim=True)
        object_end = text_end + object_mask.sum(1, keepdim=True)

        # seamlessly concatenate visual linguistic embeddings of text and object
        _zero_id = torch.zeros((bs, ), dtype=torch.long, device=text_vl_embeddings.device)
        vl_embeddings = text_vl_embeddings.new_zeros((bs, max_length, vl_embed_size))
        vl_embeddings[grid_pos < text_end] = text_vl_embeddings[text_mask]
        vl_embeddings[(grid_pos >= text_end) & (grid_pos < object_end)] = object_vl_embeddings[object_mask]
        vl_embeddings[grid_pos == object_end] = self.end_embedding(_zero_id)  # '[END]'

        # segment embeddings / token type embeddings
        # import pdb; pdb.set_trace()
        token_type_ids = text_token_type_ids.new_zeros((bs, max_length))
        token_type_ids[grid_pos < text_end] = text_token_type_ids[text_mask]
        token_type_ids[(grid_pos >= text_end) & (grid_pos <= object_end)] = 2
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # position embeddings
        position_ids = grid_pos + self.position_padding_idx + 1
        if self.config.use_img_region_order:
            pass
        elif self.config.obj_pos_id_relative:  # always True!
            position_ids[(grid_pos >= text_end) & (grid_pos < object_end)] \
                = text_end.expand((bs, max_length))[(grid_pos >= text_end) & (grid_pos < object_end)] \
                + self.position_padding_idx + 1
            position_ids[grid_pos == object_end] = (text_end + 1).squeeze(1) + self.position_padding_idx + 1
        else:
            assert False, "Don't use position id 510/511 for objects and [END]!!!"
            position_ids[(grid_pos >= text_end) & (grid_pos < object_end)] = self.config.max_position_embeddings - 2
            position_ids[grid_pos == object_end] = self.config.max_position_embeddings - 1

        position_embeddings = self.position_embeddings(position_ids)
        # import pdb; pdb.set_trace()
        mask = text_mask.new_zeros((bs, max_length))
        mask[grid_pos <= object_end] = 1

        embeddings = vl_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.embedding_LayerNorm(embeddings)
        embeddings = self.embedding_dropout(embeddings)

        if self.enhanced_img_feature and object_visual_feat is not None and spo_len is not None:
            return embeddings, mask, grid_pos < text_end, \
                (grid_pos >= text_end) & (grid_pos < object_end), spo_fused_masks
        else:
            return embeddings, mask, grid_pos < text_end, \
                (grid_pos >= text_end) & (grid_pos < object_end), None

    def load_language_pretrained_model(self, language_pretrained_model_path):
        pretrained_state_dict = torch.load(language_pretrained_model_path,
                                           map_location=lambda storage, loc: storage)
        encoder_pretrained_state_dict = {}
        pooler_pretrained_state_dict = {}
        embedding_ln_pretrained_state_dict = {}
        unexpected_keys = []
        for k, v in pretrained_state_dict.items():
            if k.startswith('bert.'):
                k = k[len('bert.'):]
            elif k.startswith('roberta.'):
                k = k[len('roberta.'):]
            else:
                unexpected_keys.append(k)
                continue
            if 'gamma' in k:
                k = k.replace('gamma', 'weight')
            if 'beta' in k:
                k = k.replace('beta', 'bias')
            if k.startswith('encoder.'):
                k_ = k[len('encoder.'):]
                if k_ in self.encoder.state_dict():
                    encoder_pretrained_state_dict[k_] = v
                else:
                    unexpected_keys.append(k)
            elif k.startswith('embeddings.'):
                k_ = k[len('embeddings.'):]
                if k_ == 'word_embeddings.weight':
                    self.word_embeddings.weight.data = v.to(
                        dtype=self.word_embeddings.weight.data.dtype,
                        device=self.word_embeddings.weight.data.device)
                elif k_ == 'position_embeddings.weight':
                    self.position_embeddings.weight.data = v.to(
                        dtype=self.position_embeddings.weight.data.dtype,
                        device=self.position_embeddings.weight.data.device)
                elif k_ == 'token_type_embeddings.weight':
                    self.token_type_embeddings.weight.data[:v.size(0)] = v.to(
                        dtype=self.token_type_embeddings.weight.data.dtype,
                        device=self.token_type_embeddings.weight.data.device)
                    if v.size(0) == 1:
                        # Todo: roberta token type embedding
                        self.token_type_embeddings.weight.data[1] = v[0].clone().to(
                            dtype=self.token_type_embeddings.weight.data.dtype,
                            device=self.token_type_embeddings.weight.data.device)
                        self.token_type_embeddings.weight.data[2] = v[0].clone().to(
                            dtype=self.token_type_embeddings.weight.data.dtype,
                            device=self.token_type_embeddings.weight.data.device)
                elif k_.startswith('LayerNorm.'):
                    k__ = k_[len('LayerNorm.'):]
                    if k__ in self.embedding_LayerNorm.state_dict():
                        embedding_ln_pretrained_state_dict[k__] = v
                    else:
                        unexpected_keys.append(k)
                else:
                    unexpected_keys.append(k)
            elif self.config.with_pooler and k.startswith('pooler.'):
                k_ = k[len('pooler.'):]
                if k_ in self.pooler.state_dict():
                    pooler_pretrained_state_dict[k_] = v
                else:
                    unexpected_keys.append(k)
            else:
                unexpected_keys.append(k)
        if len(unexpected_keys) > 0:
            print("Warnings: Unexpected keys: {}.".format(unexpected_keys))
        self.embedding_LayerNorm.load_state_dict(embedding_ln_pretrained_state_dict)
        self.encoder.load_state_dict(encoder_pretrained_state_dict)
        if self.config.with_pooler and len(pooler_pretrained_state_dict) > 0:
            self.pooler.load_state_dict(pooler_pretrained_state_dict)
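
# --- Illustration (not part of the original file) ---
# Standalone demo of the per-map min-max normalisation applied to the fused
# subject/predicate/object attention maps in embedding() above; the small
# epsilon mirrors the guard used for the BCE-style mask loss to avoid division
# by zero on constant maps. The random maps below are stand-ins.
import torch

spo_fused = torch.randn(3, 14, 14)            # one 14x14 map per S/P/O token group
spo_fused_norm = torch.zeros_like(spo_fused)
for j in range(spo_fused.size(0)):
    max1, min1 = spo_fused[j].max(), spo_fused[j].min()
    denominator = (max1 - min1) + 1e-10
    spo_fused_norm[j] = (spo_fused[j] - min1) / denominator
print(spo_fused_norm.min().item(), spo_fused_norm.max().item())  # ~0.0, ~1.0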
def __init__(self, config, language_pretrained_model_path=None):
    super(VisualLinguisticBertDecoder, self).__init__(config)
    self.config = config

    # embeddings
    self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
    self.end_embedding = nn.Embedding(1, config.hidden_size)
    self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
    self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
    self.embedding_LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
    self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)
    # for compatibility of roberta
    self.position_padding_idx = config.position_padding_idx

    # visual transform
    self.visual_1x1_text = None
    self.visual_1x1_object = None
    if config.visual_size != config.hidden_size:
        self.visual_1x1_text = nn.Linear(config.visual_size, config.hidden_size)
        self.visual_1x1_object = nn.Linear(config.visual_size, config.hidden_size)
    if config.visual_ln:
        self.visual_ln_text = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.visual_ln_object = BertLayerNorm(config.hidden_size, eps=1e-12)
    else:
        visual_scale_text = nn.Parameter(
            torch.as_tensor(self.config.visual_scale_text_init, dtype=torch.float),
            requires_grad=True)
        self.register_parameter('visual_scale_text', visual_scale_text)
        visual_scale_object = nn.Parameter(
            torch.as_tensor(self.config.visual_scale_object_init, dtype=torch.float),
            requires_grad=True)
        self.register_parameter('visual_scale_object', visual_scale_object)

    # *********************************************
    # FM addition - Set-up decoder layer for MT
    # Initializing a BERT bert-base-uncased style configuration
    configuration = BertConfig()
    configuration.vocab_size = config.vocab_size
    # FM edit: reduce size - 12 layers doesn't fit in single 12GB GPU
    configuration.num_hidden_layers = 6
    configuration.is_decoder = True
    # Initializing a model from the bert-base-uncased style configuration
    self.decoder = BertModel(configuration)
    # *********************************************

    if self.config.with_pooler:
        self.pooler = BertPooler(config)

    # init weights
    self.apply(self.init_weights)
    if config.visual_ln:
        self.visual_ln_text.weight.data.fill_(self.config.visual_scale_text_init)
        self.visual_ln_object.weight.data.fill_(self.config.visual_scale_object_init)

    # load language pretrained model
    if language_pretrained_model_path is not None:
        self.load_language_pretrained_model(language_pretrained_model_path)

    if config.word_embedding_frozen:
        for p in self.word_embeddings.parameters():
            p.requires_grad = False
        self.special_word_embeddings = nn.Embedding(NUM_SPECIAL_WORDS, config.hidden_size)
        self.special_word_embeddings.weight.data.copy_(
            self.word_embeddings.weight.data[:NUM_SPECIAL_WORDS])
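
# --- Illustration (not part of the original file) ---
# A hedged sketch of how the 6-layer BERT decoder created above could be driven
# with encoder outputs during MT decoding. Whether cross-attention layers are
# built depends on the installed transformers version (newer releases also need
# add_cross_attention=True); treat this as an assumption, not the repo's code.
import torch
from transformers import BertConfig, BertModel

cfg = BertConfig(num_hidden_layers=6, is_decoder=True, add_cross_attention=True)
decoder = BertModel(cfg)

input_ids = torch.tensor([[101, 2023, 2003, 102]])            # toy target-side token ids
encoder_hidden_states = torch.randn(1, 20, cfg.hidden_size)   # stand-in for VL-BERT encoder output
outputs = decoder(input_ids=input_ids, encoder_hidden_states=encoder_hidden_states)
print(outputs.last_hidden_state.shape)  # [1, 4, 768]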