class ResNetVLBERTv3(Module):
    def __init__(self, config):
        super(ResNetVLBERTv3, self).__init__(config)

        self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS
        if not config.NETWORK.BLIND:
            self.image_feature_extractor = FastRCNN(
                config,
                average_pool=True,
                final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                enable_cnn_reg_loss=self.enable_cnn_reg_loss)
            if config.NETWORK.VLBERT.object_word_embed_mode == 1:
                self.object_linguistic_embeddings = nn.Embedding(
                    601, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 2:
                self.object_linguistic_embeddings = nn.Embedding(
                    1, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN

        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path)

        # self.hm_out = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size)
        # self.hi_out = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size)

        dim = config.NETWORK.VLBERT.hidden_size
        if config.NETWORK.CLASSIFIER_TYPE == "2fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(dim, config.NETWORK.CLASSIFIER_HIDDEN_SIZE),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(config.NETWORK.CLASSIFIER_HIDDEN_SIZE,
                                config.NETWORK.CLASSIFIER_CLASS),
            )
        elif config.NETWORK.CLASSIFIER_TYPE == "1fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(dim, config.NETWORK.CLASSIFIER_CLASS))
        elif config.NETWORK.CLASSIFIER_TYPE == 'mlm':
            transform = BertPredictionHeadTransform(config.NETWORK.VLBERT)
            linear = nn.Linear(config.NETWORK.VLBERT.hidden_size,
                               config.NETWORK.CLASSIFIER_CLASS)
            self.final_mlp = nn.Sequential(
                transform,
                nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                linear)
        else:
            raise ValueError("Unsupported classifier type: {}!".format(
                config.NETWORK.CLASSIFIER_TYPE))

        # init weights
        self.init_weight()
        self.fix_params()

    def setup_adapter(self, freeze_rcnn=True):
        if hasattr(self, 'image_feature_extractor') and freeze_rcnn:
            for param in self.image_feature_extractor.parameters():
                param.requires_grad = False
        self.vlbert.setup_adapter()

    def init_weight(self):
        # self.hm_out.weight.data.normal_(mean=0.0, std=0.02)
        # self.hm_out.bias.data.zero_()
        # self.hi_out.weight.data.normal_(mean=0.0, std=0.02)
        # self.hi_out.bias.data.zero_()
        self.image_feature_extractor.init_weight()
        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(mean=0.0,
                                                                  std=0.02)
        for m in self.final_mlp.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                torch.nn.init.constant_(m.bias, 0)
        if self.config.NETWORK.CLASSIFIER_TYPE == 'mlm':
            language_pretrained = torch.load(
                self.language_pretrained_model_path)
            mlm_transform_state_dict = {}
            pretrain_keys = []
            for k, v in language_pretrained.items():
                if k.startswith('cls.predictions.transform.'):
                    pretrain_keys.append(k)
                    k_ = k[len('cls.predictions.transform.'):]
                    # older BERT checkpoints name LayerNorm params gamma/beta
                    if 'gamma' in k_:
                        k_ = k_.replace('gamma', 'weight')
                    if 'beta' in k_:
                        k_ = k_.replace('beta', 'bias')
                    mlm_transform_state_dict[k_] = v
            print("loading pretrained classifier transform keys: {}.".format(
                pretrain_keys))
            self.final_mlp[0].load_state_dict(mlm_transform_state_dict)

    def train(self, mode=True):
        super(ResNetVLBERTv3, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        pass

    def _collect_obj_reps(self, span_tags, object_reps):
        """
        Collect span-level object representations
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return:
        """
        span_tags_fixed = torch.clamp(
            span_tags, min=0)  # in case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0,
                                          row_id.shape[0],
                                          step=1,
                                          device=row_id.device)[:, None]

        # add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        object_select = object_reps[row_id.view(-1), span_tags_fixed.view(-1)]
        return object_select.view(*span_tags_fixed.shape, -1)

    def prepare_text(self, question, question_tags, question_mask):
        batch_size, max_q_len = question.shape
        max_len = int(question_mask.sum(1).max().item()) + 2
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(
            ['[CLS]', '[SEP]'])
        q_end = 1 + question_mask.sum(1, keepdim=True)
        input_ids = torch.zeros((batch_size, max_len),
                                dtype=question.dtype,
                                device=question.device)
        input_mask = torch.ones((batch_size, max_len),
                                dtype=torch.bool,
                                device=question.device)
        input_type_ids = torch.zeros((batch_size, max_len),
                                     dtype=question.dtype,
                                     device=question.device)
        text_tags = input_type_ids.new_zeros((batch_size, max_len))
        grid_i, grid_j = torch.meshgrid(
            torch.arange(batch_size, device=question.device),
            torch.arange(max_len, device=question.device))

        input_mask[grid_j > q_end] = 0
        # input_type_ids[(grid_j > q_end) & (grid_j <= a_end)] = 1
        q_input_mask = (grid_j > 0) & (grid_j < q_end)
        input_ids[:, 0] = cls_id
        input_ids[grid_j == q_end] = sep_id
        input_ids[q_input_mask] = question[question_mask]
        text_tags[q_input_mask] = question_tags[question_mask]

        return input_ids, input_type_ids, text_tags, input_mask

    def train_forward(self, image, boxes, im_info, text,
                      label: torch.Tensor, *args):
        ###########################################
        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        objects = boxes[:, :max_len, 4]
        boxes = boxes[:, :max_len, :4]

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None)
        if self.config.NETWORK.IMAGE_FROZEN_BACKBONE_ALL:
            obj_reps = {k: v.detach() for k, v in obj_reps.items()}

        text_ids = text
        text_tags = text.new_zeros(text_ids.shape)
        text_mask = (text > 0.5)

        ############################################
        # prepare text
        text_input_ids, text_token_type_ids, text_tags, text_mask = self.prepare_text(
            text_ids, text_tags, text_mask)
        if self.config.NETWORK.NO_GROUNDING:
            obj_rep_zeroed = obj_reps['obj_reps'].new_zeros(
                obj_reps['obj_reps'].shape)
            text_tags.zero_()
            text_visual_embeddings = self._collect_obj_reps(
                text_tags, obj_rep_zeroed)
        else:
            text_visual_embeddings = self._collect_obj_reps(
                text_tags, obj_reps['obj_reps'])

        assert self.config.NETWORK.VLBERT.object_word_embed_mode in [1, 2]
        if self.config.NETWORK.VLBERT.object_word_embed_mode == 1:
            object_linguistic_embeddings = self.object_linguistic_embeddings(
                objects.long().clamp(
                    min=0,
                    max=self.object_linguistic_embeddings.weight.data.shape[0]
                    - 1))
        else:
            object_linguistic_embeddings = self.object_linguistic_embeddings(
                boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################
        # Visual Linguistic BERT
        hidden_states, hc = self.vlbert(text_input_ids,
                                        text_token_type_ids,
                                        text_visual_embeddings,
                                        text_mask,
                                        object_vl_embeddings,
                                        box_mask,
                                        output_all_encoded_layers=False)
        _batch_inds = torch.arange(text.shape[0], device=text.device)
        hm = hidden_states[_batch_inds, 0]
        # hm = F.tanh(self.hm_out(hidden_states[_batch_inds, ans_pos]))
        # hi = F.tanh(self.hi_out(hidden_states[_batch_inds, ans_pos + 2]))

        ###########################################
        outputs = {}

        # classifier
        # logits = self.final_mlp(hc * hm * hi)
        # logits = self.final_mlp(hc)
        logits = self.final_mlp(hm)

        if label.ndim == 1:
            label = label.unsqueeze(1)
        # loss: BCE-with-logits expects float targets; the mean BCE is
        # rescaled by the number of classes
        ans_loss = F.binary_cross_entropy_with_logits(
            logits, label.float()) * label.size(1)

        outputs.update({
            'label_logits': logits,
            'label': label,
            'ans_loss': ans_loss
        })

        loss = ans_loss.mean()

        return outputs, loss

    def inference_forward(self, image, boxes, im_info, text, *args):
        ###########################################
        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        objects = boxes[:, :max_len, 4]
        boxes = boxes[:, :max_len, :4]

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None)

        text_ids = text
        text_tags = text.new_zeros(text_ids.shape)
        text_mask = (text > 0.5)

        ############################################
        # prepare text
        text_input_ids, text_token_type_ids, text_tags, text_mask = self.prepare_text(
            text_ids, text_tags, text_mask)
        if self.config.NETWORK.NO_GROUNDING:
            obj_rep_zeroed = obj_reps['obj_reps'].new_zeros(
                obj_reps['obj_reps'].shape)
            text_tags.zero_()
            text_visual_embeddings = self._collect_obj_reps(
                text_tags, obj_rep_zeroed)
        else:
            text_visual_embeddings = self._collect_obj_reps(
                text_tags, obj_reps['obj_reps'])

        assert self.config.NETWORK.VLBERT.object_word_embed_mode in [1, 2]
        if self.config.NETWORK.VLBERT.object_word_embed_mode == 1:
            object_linguistic_embeddings = self.object_linguistic_embeddings(
                objects.long().clamp(
                    min=0,
                    max=self.object_linguistic_embeddings.weight.data.shape[0]
                    - 1))
        else:
            object_linguistic_embeddings = self.object_linguistic_embeddings(
                boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################
        # Visual Linguistic BERT
        hidden_states, hc = self.vlbert(text_input_ids,
                                        text_token_type_ids,
                                        text_visual_embeddings,
                                        text_mask,
                                        object_vl_embeddings,
                                        box_mask,
                                        output_all_encoded_layers=False)
        _batch_inds = torch.arange(text.shape[0], device=text.device)
        hm = hidden_states[_batch_inds, 0]
        # hm = F.tanh(self.hm_out(hidden_states[_batch_inds, ans_pos]))
        # hi = F.tanh(self.hi_out(hidden_states[_batch_inds, ans_pos + 2]))

        ###########################################
        outputs = {}

        # classifier
        # logits = self.final_mlp(hc * hm * hi)
        # logits = self.final_mlp(hc)
        logits = self.final_mlp(hm)

        outputs.update({'label_logits': logits})

        return outputs
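
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original model: how _collect_obj_reps
# turns per-token object indices into per-token object features. The toy
# shapes and values below are hypothetical; -1 tags (ungrounded tokens) are
# clamped to 0, i.e. the whole-image box, exactly as in the method above.
# Relies on the module-level `torch` import.
def _demo_collect_obj_reps():
    object_reps = torch.arange(2 * 3 * 4,
                               dtype=torch.float).view(2, 3, 4)  # [batch, num_objs, dim]
    span_tags = torch.tensor([[0, 2, -1], [1, 1, 0]])            # [batch, num_tokens]

    span_tags_fixed = torch.clamp(span_tags, min=0)
    row_id = torch.arange(span_tags.shape[0])[:, None].expand_as(span_tags_fixed)
    picked = object_reps[row_id.reshape(-1), span_tags_fixed.reshape(-1)]
    picked = picked.view(*span_tags_fixed.shape, -1)             # [2, 3, 4]

    assert picked[0, 1].equal(object_reps[0, 2])  # token 1 of sample 0 -> object 2
    assert picked[0, 2].equal(object_reps[0, 0])  # -1 tag falls back to object 0
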
class ResNetVLBERT(Module):
    def __init__(self, config):
        super(ResNetVLBERT, self).__init__(config)

        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=config.NETWORK.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN

        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path)

        self.task1_head = Task1Head(config.NETWORK.VLBERT)
        self.task2_head = Task2Head(config.NETWORK.VLBERT)
        self.task3_head = Task3Head(config.NETWORK.VLBERT)

        # init weights
        self.init_weight()
        self.fix_params()

    def init_weight(self):
        self.image_feature_extractor.init_weight()
        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        for head in (self.task1_head, self.task2_head, self.task3_head):
            for m in head.modules():
                if isinstance(m, torch.nn.Linear):
                    torch.nn.init.xavier_uniform_(m.weight)
                    torch.nn.init.constant_(m.bias, 0)

    def train(self, mode=True):
        super(ResNetVLBERT, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        for param in self.image_feature_extractor.parameters():
            param.requires_grad = False
        for param in self.vlbert.parameters():
            param.requires_grad = False
        # fix: the original read `param.requires_grad` without assigning,
        # which silently left these embeddings trainable
        for param in self.object_linguistic_embeddings.parameters():
            param.requires_grad = False

    def train_forward(self, image, boxes, im_info, expression, label, pos,
                      target, mask):
        ###########################################
        # keep the frozen backbone and VL-BERT in eval mode during training
        if self.vlbert.training:
            self.vlbert.eval()
        if self.image_feature_extractor.training:
            self.image_feature_extractor.eval()

        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        max_len = int(box_mask.sum(1).max().item())
        origin_len = boxes.shape[1]
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        # our labels for foil are binary and 1-dimensional
        # label = label[:, :max_len]
        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None)

        ############################################
        # prepare text
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(
            ['[CLS]', '[SEP]'])
        text_input_ids = expression.new_zeros(
            (expression.shape[0], expression.shape[1] + 2))
        text_input_ids[:, 0] = cls_id
        text_input_ids[:, 1:-1] = expression
        _sep_pos = (text_input_ids > 0).sum(1)
        _batch_inds = torch.arange(expression.shape[0],
                                   device=expression.device)
        text_input_ids[_batch_inds, _sep_pos] = sep_id
        text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
        text_mask = text_input_ids > 0
        text_visual_embeddings = obj_reps['obj_reps'][:, 0].unsqueeze(
            1).repeat((1, text_input_ids.shape[1], 1))

        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################
        # Visual Linguistic BERT
        sequence_reps, object_reps, _ = self.vlbert(
            text_input_ids,
            text_token_type_ids,
            text_visual_embeddings,
            text_mask,
            object_vl_embeddings,
            box_mask,
            output_text_and_object_separately=True,
            output_all_encoded_layers=False)
        cls_rep = sequence_reps[:, 0, :].squeeze(1)
        sequence_reps = sequence_reps[:, 1:, :]

        # task 1: real/foil classification from the [CLS] representation
        cls_log_probs = self.task1_head(cls_rep)
        # task 2: locate the foil word position
        pos_log_probs = self.task2_head(sequence_reps, text_mask[:, 1:])

        # task 3: replace the foil span with [MASK] and predict the correction
        mask_len = mask.shape[1]
        error_cor_mask = torch.zeros_like(text_input_ids)
        error_cor_mask[:, 1:mask_len + 1] = mask
        error_cor_mask = error_cor_mask.bool()
        text_input_ids[error_cor_mask] = self.tokenizer.convert_tokens_to_ids(
            ['[MASK]'])[0]

        sequence_reps, _, _ = self.vlbert(
            text_input_ids,
            text_token_type_ids,
            text_visual_embeddings,
            text_mask,
            object_vl_embeddings,
            box_mask,
            output_text_and_object_separately=True,
            output_all_encoded_layers=False)
        sequence_reps = sequence_reps[:, 1:, :]
        select_index = pos.view(-1, 1).unsqueeze(2).repeat(1, 1, 768)
        select_index[select_index < 0] = 0
        masked_reps = torch.gather(sequence_reps, 1, select_index).squeeze(1)
        cor_log_probs = self.task3_head(masked_reps)

        # losses: position/correction terms only count for foil samples
        loss_mask = label.view(-1).float()
        cls_loss = F.binary_cross_entropy(cls_log_probs.view(-1),
                                          label.view(-1).float(),
                                          reduction="none")
        pos_loss = F.nll_loss(
            pos_log_probs, pos, ignore_index=-1,
            reduction="none").view(-1) * loss_mask
        cor_loss = F.nll_loss(cor_log_probs.view(-1, cor_log_probs.shape[-1]),
                              target,
                              ignore_index=0,
                              reduction="none").view(-1) * loss_mask
        loss = cls_loss.mean() + pos_loss.mean() + cor_loss.mean()

        outputs = {
            "cls_logits": cls_log_probs,
            "pos_logits": pos_log_probs,
            "cor_logits": cor_log_probs,
            "cls_label": label,
            "pos_label": pos,
            "cor_label": target
        }
        return outputs, loss

    def inference_forward(self, image, boxes, im_info, expression, label, pos,
                          target, mask):
        # NOTE: `mask` added to the signature; the original body used it
        # without declaring it, mirroring train_forward

        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        max_len = int(box_mask.sum(1).max().item())
        origin_len = boxes.shape[1]
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        # our labels for foil are binary and 1-dimensional
        # label = label[:, :max_len]
        with torch.no_grad():
            obj_reps = self.image_feature_extractor(images=images,
                                                    boxes=boxes,
                                                    box_mask=box_mask,
                                                    im_info=im_info,
                                                    classes=None,
                                                    segms=None)

        ############################################
        # prepare text
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(
            ['[CLS]', '[SEP]'])
        text_input_ids = expression.new_zeros(
            (expression.shape[0], expression.shape[1] + 2))
        text_input_ids[:, 0] = cls_id
        text_input_ids[:, 1:-1] = expression
        _sep_pos = (text_input_ids > 0).sum(1)
        _batch_inds = torch.arange(expression.shape[0],
                                   device=expression.device)
        text_input_ids[_batch_inds, _sep_pos] = sep_id
        text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
        text_mask = text_input_ids > 0
        text_visual_embeddings = obj_reps['obj_reps'][:, 0].unsqueeze(
            1).repeat((1, text_input_ids.shape[1], 1))
        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################
        # Visual Linguistic BERT
        sequence_reps, object_reps, _ = self.vlbert(
            text_input_ids,
            text_token_type_ids,
            text_visual_embeddings,
            text_mask,
            object_vl_embeddings,
            box_mask,
            output_text_and_object_separately=True,
            output_all_encoded_layers=False)
        cls_rep = sequence_reps[:, 0, :].squeeze(1)
        sequence_reps = sequence_reps[:, 1:, :]

        # task 1: real/foil classification from the [CLS] representation
        cls_log_probs = self.task1_head(cls_rep)
        # task 2: locate the foil word position
        pos_log_probs = self.task2_head(sequence_reps, text_mask[:, 1:])

        # task 3: replace the foil span with [MASK] and predict the correction
        mask_len = mask.shape[1]
        error_cor_mask = torch.zeros_like(text_input_ids)
        error_cor_mask[:, 1:mask_len + 1] = mask
        error_cor_mask = error_cor_mask.bool()
        text_input_ids[error_cor_mask] = self.tokenizer.convert_tokens_to_ids(
            ['[MASK]'])[0]

        with torch.no_grad():
            sequence_reps, _, _ = self.vlbert(
                text_input_ids,
                text_token_type_ids,
                text_visual_embeddings,
                text_mask,
                object_vl_embeddings,
                box_mask,
                output_text_and_object_separately=True,
                output_all_encoded_layers=False)
        sequence_reps = sequence_reps[:, 1:, :]
        select_index = pos.view(-1, 1).unsqueeze(2).repeat(1, 1, 768)
        select_index[select_index < 0] = 0
        masked_reps = torch.gather(sequence_reps, 1, select_index).squeeze(1)
        cor_log_probs = self.task3_head(masked_reps)

        # losses: position/correction terms only count for foil samples
        loss_mask = label.view(-1).float()
        cls_loss = F.binary_cross_entropy(cls_log_probs.view(-1),
                                          label.view(-1).float(),
                                          reduction="none")
        pos_loss = F.nll_loss(
            pos_log_probs, pos, ignore_index=-1,
            reduction="none").view(-1) * loss_mask
        cor_loss = F.nll_loss(cor_log_probs.view(-1, cor_log_probs.shape[-1]),
                              target,
                              ignore_index=0,
                              reduction="none").view(-1) * loss_mask
        loss = cls_loss.mean() + pos_loss.mean() + cor_loss.mean()

        outputs = {
            "cls_logits": cls_log_probs,
            "pos_logits": pos_log_probs,
            "cor_logits": cor_log_probs,
            "cls_label": label,
            "pos_label": pos,
            "cor_label": target
        }
        return outputs, loss
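
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original model: the torch.gather
# pattern used above to pull out the representation at the foil position
# `pos` for the correction head. The toy hidden size is hypothetical; the
# model hard-codes 768. Relies on the module-level `torch` import.
def _demo_gather_masked_reps():
    hidden = 4
    sequence_reps = torch.randn(2, 5, hidden)  # [batch, seq_len, hidden]
    pos = torch.tensor([3, -1])                # -1 marks samples without a foil word

    select_index = pos.view(-1, 1).unsqueeze(2).repeat(1, 1, hidden)
    # clamp -1 to 0 so gather stays in range; loss_mask zeroes those samples anyway
    select_index[select_index < 0] = 0
    masked_reps = torch.gather(sequence_reps, 1, select_index).squeeze(1)  # [batch, hidden]

    assert masked_reps[0].equal(sequence_reps[0, 3])  # picks the pos-th token rep
    assert masked_reps[1].equal(sequence_reps[1, 0])  # -1 clamped to position 0
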