import os

import torch
import torch.nn as nn
import torch.nn.functional as F

# NOTE: the project-internal import paths below are assumptions based on the
# VL-BERT repository layout (plus this fork's custom policy-loss helpers);
# adjust them to match the surrounding code base.
from common.module import Module
from common.fast_rcnn import FastRCNN
from common.nlp.time_distributed import TimeDistributed
from common.nlp.misc import soft_cross_entropy
from common.visual_linguistic_bert import (VisualLinguisticBert,
                                           VisualLinguisticBertForPretraining,
                                           VisualLinguisticBertForDistance,
                                           VisualLinguisticBertMVRCHeadTransform)
from common.losses import constrain_k_loss, deterministic_policy_loss
from external.pytorch_pretrained_bert import BertTokenizer, RobertaTokenizer
from external.pytorch_pretrained_bert.modeling import BertPredictionHeadTransform

BERT_WEIGHTS_NAME = 'pytorch_model.bin'


class ResNetVLBERT(Module):
    def __init__(self, config):
        super(ResNetVLBERT, self).__init__(config)
        self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS
        if not config.NETWORK.BLIND:
            self.image_feature_extractor = FastRCNN(config,
                                                    average_pool=True,
                                                    final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                                                    enable_cnn_reg_loss=self.enable_cnn_reg_loss)
            if config.NETWORK.VLBERT.object_word_embed_mode == 1:
                self.object_linguistic_embeddings = nn.Embedding(81, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 2:
                self.object_linguistic_embeddings = nn.Embedding(1, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN

        self.tokenizer = BertTokenizer.from_pretrained(config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(config.NETWORK.BERT_PRETRAINED,
                                                                      config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print("Warning: no pretrained language model found, training from scratch!!!")

        # Also pass the finetuning strategy
        self.vlbert = VisualLinguisticBert(config.NETWORK.VLBERT,
                                           language_pretrained_model_path=language_pretrained_model_path,
                                           finetune_strategy=config.FINETUNE_STRATEGY)

        # self.hm_out = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size)
        # self.hi_out = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size)

        dim = config.NETWORK.VLBERT.hidden_size
        if config.NETWORK.CLASSIFIER_TYPE == "2fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(dim, config.NETWORK.CLASSIFIER_HIDDEN_SIZE),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(config.NETWORK.CLASSIFIER_HIDDEN_SIZE, config.DATASET.ANSWER_VOCAB_SIZE),
            )
        elif config.NETWORK.CLASSIFIER_TYPE == "1fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(dim, config.DATASET.ANSWER_VOCAB_SIZE)
            )
        elif config.NETWORK.CLASSIFIER_TYPE == 'mlm':
            transform = BertPredictionHeadTransform(config.NETWORK.VLBERT)
            linear = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.DATASET.ANSWER_VOCAB_SIZE)
            self.final_mlp = nn.Sequential(
                transform,
                nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                linear
            )
        else:
            raise ValueError("Unsupported classifier type: {}!".format(config.NETWORK.CLASSIFIER_TYPE))

        # init weights
        self.init_weight()

        self.fix_params()

    def init_weight(self):
        # self.hm_out.weight.data.normal_(mean=0.0, std=0.02)
        # self.hm_out.bias.data.zero_()
        # self.hi_out.weight.data.normal_(mean=0.0, std=0.02)
        # self.hi_out.bias.data.zero_()
        self.image_feature_extractor.init_weight()
        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(mean=0.0, std=0.02)
        for m in self.final_mlp.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                torch.nn.init.constant_(m.bias, 0)
        if self.config.NETWORK.CLASSIFIER_TYPE == 'mlm':
            language_pretrained = torch.load(self.language_pretrained_model_path)
            mlm_transform_state_dict = {}
            pretrain_keys = []
            for k, v in language_pretrained.items():
                if k.startswith('cls.predictions.transform.'):
                    pretrain_keys.append(k)
                    k_ = k[len('cls.predictions.transform.'):]
                    if 'gamma' in k_:
                        k_ = k_.replace('gamma', 'weight')
                    if 'beta' in k_:
                        k_ = k_.replace('beta', 'bias')
                    mlm_transform_state_dict[k_] = v
            print("loading pretrained classifier transform keys: {}.".format(pretrain_keys))
            self.final_mlp[0].load_state_dict(mlm_transform_state_dict)

    def train(self, mode=True):
        super(ResNetVLBERT, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        pass

    def _collect_obj_reps(self, span_tags, object_reps):
        """
        Collect span-level object representations
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return:
        """
        span_tags_fixed = torch.clamp(span_tags, min=0)  # In case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0, row_id.shape[0], step=1, device=row_id.device)[:, None]

        # Add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        return object_reps[row_id.view(-1), span_tags_fixed.view(-1)].view(*span_tags_fixed.shape, -1)

    def prepare_text_from_qa(self, question, question_tags, question_mask, answer, answer_tags, answer_mask):
        batch_size, max_q_len = question.shape
        _, max_a_len = answer.shape
        max_len = (question_mask.sum(1) + answer_mask.sum(1)).max() + 3
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]'])
        q_end = 1 + question_mask.sum(1, keepdim=True)
        a_end = q_end + 1 + answer_mask.sum(1, keepdim=True)
        input_ids = torch.zeros((batch_size, max_len), dtype=question.dtype, device=question.device)
        input_mask = torch.ones((batch_size, max_len), dtype=torch.bool, device=question.device)
        input_type_ids = torch.zeros((batch_size, max_len), dtype=question.dtype, device=question.device)
        text_tags = input_type_ids.new_zeros((batch_size, max_len))
        grid_i, grid_j = torch.meshgrid(torch.arange(batch_size, device=question.device),
                                        torch.arange(max_len, device=question.device))

        input_mask[grid_j > a_end] = 0
        input_type_ids[(grid_j > q_end) & (grid_j <= a_end)] = 1
        q_input_mask = (grid_j > 0) & (grid_j < q_end)
        a_input_mask = (grid_j > q_end) & (grid_j < a_end)
        input_ids[:, 0] = cls_id
        input_ids[grid_j == q_end] = sep_id
        input_ids[grid_j == a_end] = sep_id
        input_ids[q_input_mask] = question[question_mask]
        input_ids[a_input_mask] = answer[answer_mask]
        text_tags[q_input_mask] = question_tags[question_mask]
        text_tags[a_input_mask] = answer_tags[answer_mask]

        return input_ids, input_type_ids, text_tags, input_mask, (a_end - 1).squeeze(1)

    def train_forward(self,
                      image,
                      boxes,
                      im_info,
                      question,
                      label,
                      policy=None):
        ###########################################
        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None)

        question_ids = question
        question_tags = question.new_zeros(question_ids.shape)
        question_mask = (question > 0.5)

        answer_ids = question_ids.new_zeros((question_ids.shape[0], 1)).fill_(
            self.tokenizer.convert_tokens_to_ids(['[MASK]'])[0])
        answer_mask = question_mask.new_zeros(answer_ids.shape).fill_(1)
        answer_tags = question_tags.new_zeros(answer_ids.shape)

        ############################################
        # prepare text
        text_input_ids, text_token_type_ids, text_tags, text_mask, ans_pos = self.prepare_text_from_qa(
            question_ids,
            question_tags,
            question_mask,
            answer_ids,
            answer_tags,
            answer_mask)
        if self.config.NETWORK.NO_GROUNDING:
            obj_rep_zeroed = obj_reps['obj_reps'].new_zeros(obj_reps['obj_reps'].shape)
            text_tags.zero_()
            text_visual_embeddings = self._collect_obj_reps(text_tags, obj_rep_zeroed)
        else:
            text_visual_embeddings = self._collect_obj_reps(text_tags, obj_reps['obj_reps'])

        assert self.config.NETWORK.VLBERT.object_word_embed_mode == 2
        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long()
        )
        object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################
        # Visual Linguistic BERT
        hidden_states, hc = self.vlbert(text_input_ids,
                                        text_token_type_ids,
                                        text_visual_embeddings,
                                        text_mask,
                                        object_vl_embeddings,
                                        box_mask,
                                        output_all_encoded_layers=False,
                                        policy=policy)
        _batch_inds = torch.arange(question.shape[0], device=question.device)
        hm = hidden_states[_batch_inds, ans_pos]
        # hm = F.tanh(self.hm_out(hidden_states[_batch_inds, ans_pos]))
        # hi = F.tanh(self.hi_out(hidden_states[_batch_inds, ans_pos + 2]))

        ###########################################
        outputs = {}

        # classifier
        # logits = self.final_mlp(hc * hm * hi)
        # logits = self.final_mlp(hc)
        logits = self.final_mlp(hm)

        # loss
        ans_loss = F.binary_cross_entropy_with_logits(logits, label) * label.size(1)
        outputs.update({'label_logits': logits,
                        'label': label,
                        'ans_loss': ans_loss})

        loss = ans_loss.mean()

        # check for auxiliary losses
        if policy is not None:
            if self.config.USE_CONSTRAIN_K_LOSS:
                loss_k = constrain_k_loss(policy,
                                          self.config.CONSTRAIN_K_NUM_BLOCKS,
                                          self.config.CONSTRAIN_K_SCALE)
                loss += loss_k
                outputs.update({'loss_k': loss_k})
            if self.config.USE_DETERMINISTIC_POLICY_LOSS:
                loss_d = deterministic_policy_loss(policy,
                                                   self.config.DETERMINISTIC_POLICY_SCALE)
                loss += loss_d
                outputs.update({'loss_d': loss_d})

        return outputs, loss

    def inference_forward(self,
                          image,
                          boxes,
                          im_info,
                          question,
                          policy=None):
        ###########################################
        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None)

        question_ids = question
        question_tags = question.new_zeros(question_ids.shape)
        question_mask = (question > 0.5)

        answer_ids = question_ids.new_zeros((question_ids.shape[0], 1)).fill_(
            self.tokenizer.convert_tokens_to_ids(['[MASK]'])[0])
        answer_mask = question_mask.new_zeros(answer_ids.shape).fill_(1)
        answer_tags = question_tags.new_zeros(answer_ids.shape)

        ############################################
        # prepare text
        text_input_ids, text_token_type_ids, text_tags, text_mask, ans_pos = self.prepare_text_from_qa(
            question_ids,
            question_tags,
            question_mask,
            answer_ids,
            answer_tags,
            answer_mask)
        if self.config.NETWORK.NO_GROUNDING:
            obj_rep_zeroed = obj_reps['obj_reps'].new_zeros(obj_reps['obj_reps'].shape)
            text_tags.zero_()
            text_visual_embeddings = self._collect_obj_reps(text_tags, obj_rep_zeroed)
        else:
            text_visual_embeddings = self._collect_obj_reps(text_tags, obj_reps['obj_reps'])

        assert self.config.NETWORK.VLBERT.object_word_embed_mode == 2
        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long()
        )
        object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################
        # Visual Linguistic BERT
        hidden_states, hc = self.vlbert(text_input_ids,
                                        text_token_type_ids,
                                        text_visual_embeddings,
                                        text_mask,
                                        object_vl_embeddings,
                                        box_mask,
                                        output_all_encoded_layers=False,
                                        policy=policy)
        _batch_inds = torch.arange(question.shape[0], device=question.device)
        hm = hidden_states[_batch_inds, ans_pos]
        # hm = F.tanh(self.hm_out(hidden_states[_batch_inds, ans_pos]))
        # hi = F.tanh(self.hi_out(hidden_states[_batch_inds, ans_pos + 2]))

        ###########################################
        outputs = {}

        # classifier
        # logits = self.final_mlp(hc * hm * hi)
        # logits = self.final_mlp(hc)
        logits = self.final_mlp(hm)

        outputs.update({'label_logits': logits})

        return outputs
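
# Illustration only (not part of the original model): a minimal, self-contained
# sketch of what `_collect_obj_reps` computes -- for every text token, gather
# the feature of the box its span tag points to. All sizes and values below
# are made up purely for demonstration.
def _demo_collect_obj_reps():
    batch_size, num_boxes, obj_dim = 2, 4, 8
    object_reps = torch.randn(batch_size, num_boxes, obj_dim)
    # span_tags[b, t] = index of the box token t is grounded to (-1 = ungrounded)
    span_tags = torch.tensor([[0, 2, -1, 1, 3],
                              [1, -1, 0, 0, 2]])
    tags = span_tags.clamp(min=0)  # ungrounded tokens fall back to box 0
    rows = torch.arange(batch_size)[:, None].expand_as(tags)
    token_visual = object_reps[rows.reshape(-1), tags.reshape(-1)]
    return token_visual.view(*span_tags.shape, -1)  # [batch_size, L, obj_dim]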
# NOTE: a second variant with the same class name follows; this one handles
# multiple-choice (VCR-style) inputs, and in the original project the two
# variants live in separate task modules.
class ResNetVLBERT(Module):
    def __init__(self, config):
        super(ResNetVLBERT, self).__init__(config)
        self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS
        self.cnn_loss_top = config.NETWORK.CNN_LOSS_TOP
        if not config.NETWORK.BLIND:
            self.image_feature_extractor = FastRCNN(
                config,
                average_pool=True,
                final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                enable_cnn_reg_loss=(self.enable_cnn_reg_loss and not self.cnn_loss_top))
            if config.NETWORK.VLBERT.object_word_embed_mode == 1:
                self.object_linguistic_embeddings = nn.Embedding(81, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 2:
                self.object_linguistic_embeddings = nn.Embedding(1, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError
            if self.enable_cnn_reg_loss and self.cnn_loss_top:
                self.cnn_loss_reg = nn.Sequential(
                    VisualLinguisticBertMVRCHeadTransform(config.NETWORK.VLBERT),
                    nn.Dropout(config.NETWORK.CNN_REG_DROPOUT, inplace=False),
                    nn.Linear(config.NETWORK.VLBERT.hidden_size, 81))
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN

        if 'roberta' in config.NETWORK.BERT_MODEL_NAME:
            self.tokenizer = RobertaTokenizer.from_pretrained(config.NETWORK.BERT_MODEL_NAME)
        else:
            self.tokenizer = BertTokenizer.from_pretrained(config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(config.NETWORK.BERT_PRETRAINED,
                                                                      config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        if language_pretrained_model_path is None:
            print("Warning: no pretrained language model found, training from scratch!!!")

        self.vlbert = TimeDistributed(
            VisualLinguisticBert(config.NETWORK.VLBERT,
                                 language_pretrained_model_path=language_pretrained_model_path))

        self.for_pretrain = config.NETWORK.FOR_MASK_VL_MODELING_PRETRAIN
        assert not self.for_pretrain, "Pretrain mode is not implemented yet!"

        if not self.for_pretrain:
            dim = config.NETWORK.VLBERT.hidden_size
            if config.NETWORK.CLASSIFIER_TYPE == "2fc":
                self.final_mlp = torch.nn.Sequential(
                    torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                    torch.nn.Linear(dim, config.NETWORK.CLASSIFIER_HIDDEN_SIZE),
                    torch.nn.ReLU(inplace=True),
                    torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                    torch.nn.Linear(config.NETWORK.CLASSIFIER_HIDDEN_SIZE, 1),
                )
            elif config.NETWORK.CLASSIFIER_TYPE == "1fc":
                self.final_mlp = torch.nn.Sequential(
                    torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                    torch.nn.Linear(dim, 1))
            else:
                raise ValueError("Unsupported classifier type: {}!".format(config.NETWORK.CLASSIFIER_TYPE))

        # init weights
        self.init_weight()

        self.fix_params()

    def init_weight(self):
        if not self.config.NETWORK.BLIND:
            self.image_feature_extractor.init_weight()
            if self.object_linguistic_embeddings is not None:
                self.object_linguistic_embeddings.weight.data.normal_(mean=0.0, std=0.02)
            if self.enable_cnn_reg_loss and self.cnn_loss_top:
                self.cnn_loss_reg.apply(self.vlbert._module.init_weights)
        if not self.for_pretrain:
            for m in self.final_mlp.modules():
                if isinstance(m, torch.nn.Linear):
                    torch.nn.init.xavier_uniform_(m.weight)
                    torch.nn.init.constant_(m.bias, 0)

    def train(self, mode=True):
        super(ResNetVLBERT, self).train(mode)
        # turn some frozen layers to eval mode
        if (not self.config.NETWORK.BLIND) and self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        if self.config.NETWORK.BLIND:
            self.vlbert._module.visual_scale_text.requires_grad = False
            self.vlbert._module.visual_scale_object.requires_grad = False

    def _collect_obj_reps(self, span_tags, object_reps):
        """
        Collect span-level object representations
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return:
        """
        span_tags_fixed = torch.clamp(span_tags, min=0)  # In case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0, row_id.shape[0], step=1, device=row_id.device)[:, None]

        # Add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        return object_reps[row_id.view(-1), span_tags_fixed.view(-1)].view(*span_tags_fixed.shape, -1)

    def prepare_text_from_qa(self, question, question_tags, question_mask, answers, answers_tags, answers_mask):
        batch_size, max_q_len = question.shape
        _, num_choices, max_a_len = answers.shape
        max_len = (question_mask.sum(1) + answers_mask.sum(2).max(1)[0]).max() + 3
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]'])
        question = question.repeat(1, num_choices).view(-1, num_choices, max_q_len)
        question_mask = question_mask.repeat(1, num_choices).view(-1, num_choices, max_q_len)
        q_end = 1 + question_mask.sum(2, keepdim=True)
        a_end = q_end + 1 + answers_mask.sum(2, keepdim=True)
        input_ids = torch.zeros((batch_size, num_choices, max_len),
                                dtype=question.dtype, device=question.device)
        input_mask = torch.ones((batch_size, num_choices, max_len),
                                dtype=torch.uint8, device=question.device)
        input_type_ids = torch.zeros((batch_size, num_choices, max_len),
                                     dtype=question.dtype, device=question.device)
        text_tags = input_type_ids.new_zeros((batch_size, num_choices, max_len))
        grid_i, grid_j, grid_k = torch.meshgrid(torch.arange(batch_size, device=question.device),
                                                torch.arange(num_choices, device=question.device),
                                                torch.arange(max_len, device=question.device))

        input_mask[grid_k > a_end] = 0
        input_type_ids[(grid_k > q_end) & (grid_k <= a_end)] = 1
        q_input_mask = (grid_k > 0) & (grid_k < q_end)
        a_input_mask = (grid_k > q_end) & (grid_k < a_end)
        input_ids[:, :, 0] = cls_id
        input_ids[grid_k == q_end] = sep_id
        input_ids[grid_k == a_end] = sep_id
        input_ids[q_input_mask] = question[question_mask]
        input_ids[a_input_mask] = answers[answers_mask]
        text_tags[q_input_mask] = question_tags[question_mask]
        text_tags[a_input_mask] = answers_tags[answers_mask]

        return input_ids, input_type_ids, text_tags, input_mask

    def prepare_text_from_qa_onesent(self, question, question_tags, question_mask, answers, answers_tags,
                                     answers_mask):
        batch_size, max_q_len = question.shape
        _, num_choices, max_a_len = answers.shape
        max_len = (question_mask.sum(1) + answers_mask.sum(2).max(1)[0]).max() + 2
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]'])
        question = question.repeat(1, num_choices).view(-1, num_choices, max_q_len)
        question_mask = question_mask.repeat(1, num_choices).view(-1, num_choices, max_q_len)
        q_end = 1 + question_mask.sum(2, keepdim=True)
        a_end = q_end + answers_mask.sum(2, keepdim=True)
        input_ids = torch.zeros((batch_size, num_choices, max_len),
                                dtype=question.dtype, device=question.device)
        input_mask = torch.ones((batch_size, num_choices, max_len),
                                dtype=torch.uint8, device=question.device)
        input_type_ids = torch.zeros((batch_size, num_choices, max_len),
                                     dtype=question.dtype, device=question.device)
        text_tags = input_type_ids.new_zeros((batch_size, num_choices, max_len))
        grid_i, grid_j, grid_k = torch.meshgrid(torch.arange(batch_size, device=question.device),
                                                torch.arange(num_choices, device=question.device),
                                                torch.arange(max_len, device=question.device))

        input_mask[grid_k > a_end] = 0
        q_input_mask = (grid_k > 0) & (grid_k < q_end)
        a_input_mask = (grid_k >= q_end) & (grid_k < a_end)
        input_ids[:, :, 0] = cls_id
        input_ids[grid_k == a_end] = sep_id
        input_ids[q_input_mask] = question[question_mask]
        input_ids[a_input_mask] = answers[answers_mask]
        text_tags[q_input_mask] = question_tags[question_mask]
        text_tags[a_input_mask] = answers_tags[answers_mask]

        return input_ids, input_type_ids, text_tags, input_mask

    def prepare_text_from_aq(self, question, question_tags, question_mask, answers, answers_tags, answers_mask):
        batch_size, max_q_len = question.shape
        _, num_choices, max_a_len = answers.shape
        max_len = (question_mask.sum(1) + answers_mask.sum(2).max(1)[0]).max() + 3
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]'])
        question = question.repeat(1, num_choices).view(-1, num_choices, max_q_len)
        question_mask = question_mask.repeat(1, num_choices).view(-1, num_choices, max_q_len)
        a_end = 1 + answers_mask.sum(2, keepdim=True)
        q_end = a_end + 1 + question_mask.sum(2, keepdim=True)
        input_ids = torch.zeros((batch_size, num_choices, max_len),
                                dtype=question.dtype, device=question.device)
        input_mask = torch.ones((batch_size, num_choices, max_len),
                                dtype=torch.uint8, device=question.device)
        input_type_ids = torch.zeros((batch_size, num_choices, max_len),
                                     dtype=question.dtype, device=question.device)
        text_tags = input_type_ids.new_zeros((batch_size, num_choices, max_len))
        grid_i, grid_j, grid_k = torch.meshgrid(torch.arange(batch_size, device=question.device),
                                                torch.arange(num_choices, device=question.device),
                                                torch.arange(max_len, device=question.device))

        input_mask[grid_k > q_end] = 0
        input_type_ids[(grid_k > a_end) & (grid_k <= q_end)] = 1
        q_input_mask = (grid_k > a_end) & (grid_k < q_end)
        a_input_mask = (grid_k > 0) & (grid_k < a_end)
        input_ids[:, :, 0] = cls_id
        input_ids[grid_k == a_end] = sep_id
        input_ids[grid_k == q_end] = sep_id
        input_ids[q_input_mask] = question[question_mask]
        input_ids[a_input_mask] = answers[answers_mask]
        text_tags[q_input_mask] = question_tags[question_mask]
        text_tags[a_input_mask] = answers_tags[answers_mask]

        return input_ids, input_type_ids, text_tags, input_mask

    def train_forward(self,
                      image,
                      boxes,
                      masks,
                      question,
                      question_align_matrix,
                      answer_choices,
                      answer_align_matrix,
                      answer_label,
                      im_info,
                      mask_position=None,
                      mask_type=None,
                      mask_label=None):
        ###########################################
        # visual feature extraction
        images = image
        objects = boxes[:, :, -1]
        segms = masks
        boxes = boxes[:, :, :4]
        box_mask = (boxes[:, :, -1] > -0.5)
        max_len = int(box_mask.sum(1).max().item())
        objects = objects[:, :max_len]
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        segms = segms[:, :max_len]
        if self.config.NETWORK.BLIND:
            obj_reps = {'obj_reps': boxes.new_zeros((*boxes.shape[:-1], self.config.NETWORK.IMAGE_FINAL_DIM))}
        else:
            obj_reps = self.image_feature_extractor(images=images,
                                                    boxes=boxes,
                                                    box_mask=box_mask,
                                                    im_info=im_info,
                                                    classes=objects,
                                                    segms=segms)

        num_choices = answer_choices.shape[1]
        question_ids = question[:, :, 0]
        question_tags = question[:, :, 1]
        question_tags = question_tags.repeat(1, num_choices).view(question_tags.shape[0], num_choices, -1)
        question_mask = (question[:, :, 0] > 0.5)
        answer_ids = answer_choices[:, :, :, 0]
        answer_tags = answer_choices[:, :, :, 1]
        answer_mask = (answer_choices[:, :, :, 0] > 0.5)

        ############################################
        # prepare text
        if self.config.NETWORK.ANSWER_FIRST:
            if self.config.NETWORK.QA_ONE_SENT:
                raise NotImplementedError
            else:
                text_input_ids, text_token_type_ids, text_tags, text_mask = self.prepare_text_from_aq(
                    question_ids, question_tags, question_mask,
                    answer_ids, answer_tags, answer_mask)
        else:
            if self.config.NETWORK.QA_ONE_SENT:
                text_input_ids, text_token_type_ids, text_tags, text_mask = self.prepare_text_from_qa_onesent(
                    question_ids, question_tags, question_mask,
                    answer_ids, answer_tags, answer_mask)
            else:
                text_input_ids, text_token_type_ids, text_tags, text_mask = self.prepare_text_from_qa(
                    question_ids, question_tags, question_mask,
                    answer_ids, answer_tags, answer_mask)

        if self.config.NETWORK.NO_GROUNDING:
            text_tags.zero_()
        text_visual_embeddings = self._collect_obj_reps(text_tags, obj_reps['obj_reps'])
        if self.config.NETWORK.BLIND:
            object_linguistic_embeddings = boxes.new_zeros(
                (*boxes.shape[:-1], self.config.NETWORK.VLBERT.hidden_size))
            object_linguistic_embeddings = object_linguistic_embeddings.unsqueeze(1).repeat(1, num_choices, 1, 1)
        else:
            if self.config.NETWORK.VLBERT.object_word_embed_mode in [1, 2]:
                object_linguistic_embeddings = self.object_linguistic_embeddings(
                    objects.long().clamp(min=0, max=self.object_linguistic_embeddings.weight.data.shape[0] - 1))
                object_linguistic_embeddings = object_linguistic_embeddings.unsqueeze(1).repeat(1, num_choices, 1, 1)
            elif self.config.NETWORK.VLBERT.object_word_embed_mode == 3:
                cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]'])
                global_context_mask = text_mask & (text_input_ids != cls_id) & (text_input_ids != sep_id)
                word_embedding = self.vlbert._module.word_embeddings(text_input_ids)
                word_embedding[global_context_mask == 0] = 0
                object_linguistic_embeddings = word_embedding.sum(dim=2) \
                    / global_context_mask.sum(dim=2, keepdim=True).to(dtype=word_embedding.dtype)
                object_linguistic_embeddings = object_linguistic_embeddings.unsqueeze(2).repeat((1, 1, max_len, 1))
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'].unsqueeze(1).repeat(1, num_choices, 1, 1), object_linguistic_embeddings), -1)

        ###########################################
        # Visual Linguistic BERT
        if self.config.NETWORK.NO_OBJ_ATTENTION or self.config.NETWORK.BLIND:
            box_mask.zero_()
        hidden_states_text, hidden_states_objects, pooled_rep = self.vlbert(
            text_input_ids,
            text_token_type_ids,
            text_visual_embeddings,
            text_mask,
            object_vl_embeddings,
            box_mask.unsqueeze(1).repeat(1, num_choices, 1),
            output_all_encoded_layers=False,
            output_text_and_object_separately=True)

        ###########################################
        outputs = {}

        # classifier
        logits = self.final_mlp(pooled_rep).squeeze(2)

        # loss
        if self.config.NETWORK.CLASSIFIER_SIGMOID:
            _, choice_ind = torch.meshgrid(torch.arange(logits.shape[0], device=logits.device),
                                           torch.arange(num_choices, device=logits.device))
            label_binary = (choice_ind == answer_label.unsqueeze(1))
            if mask_type is not None and self.config.NETWORK.REPLACE_OBJECT_CHANGE_LABEL:
                label_binary = label_binary * (mask_type != 1).unsqueeze(1)
            weight = logits.new_zeros(logits.shape).fill_(1.0)
            weight[label_binary == 1] = self.config.NETWORK.CLASSIFIER_SIGMOID_LOSS_POSITIVE_WEIGHT
            rescale = (self.config.NETWORK.CLASSIFIER_SIGMOID_LOSS_POSITIVE_WEIGHT + 1.0) \
                / (2.0 * self.config.NETWORK.CLASSIFIER_SIGMOID_LOSS_POSITIVE_WEIGHT)
            ans_loss = rescale * F.binary_cross_entropy_with_logits(logits,
                                                                    label_binary.to(dtype=logits.dtype),
                                                                    weight=weight)
            outputs['positive_fraction'] = label_binary.to(dtype=logits.dtype).sum() / label_binary.numel()
        else:
            ans_loss = F.cross_entropy(logits, answer_label.long().view(-1))

        outputs.update({'label_logits': logits,
                        'label': answer_label.long().view(-1),
                        'ans_loss': ans_loss})

        loss = ans_loss.mean() * self.config.NETWORK.ANS_LOSS_WEIGHT

        if mask_position is not None:
            assert False, "Todo: align to original position."
            # NOTE: the block below is unreachable stale code until the todo
            # above is resolved; `hidden_states` and `origin_len` are undefined
            # in this scope.
            _batch_ind = torch.arange(images.shape[0], dtype=torch.long, device=images.device)
            mask_pos_rep = hidden_states[_batch_ind, answer_label, mask_position]
            mask_pred_logits = (obj_reps['obj_reps'] @ mask_pos_rep.unsqueeze(-1)).squeeze(-1)
            mask_pred_logits[1 - box_mask] -= 10000.0
            mask_object_loss = F.cross_entropy(mask_pred_logits, mask_label, ignore_index=-1)
            logits_padded = mask_pred_logits.new_zeros((mask_pred_logits.shape[0], origin_len)).fill_(-10000.0)
            logits_padded[:, :mask_pred_logits.shape[1]] = mask_pred_logits
            mask_pred_logits = logits_padded
            outputs.update({'mask_object_loss': mask_object_loss,
                            'mask_object_logits': mask_pred_logits,
                            'mask_object_label': mask_label})
            loss = loss + mask_object_loss.mean() * self.config.NETWORK.MASK_OBJECT_LOSS_WEIGHT

        if self.enable_cnn_reg_loss:
            if not self.cnn_loss_top:
                loss = loss + obj_reps['cnn_regularization_loss'].mean() * self.config.NETWORK.CNN_LOSS_WEIGHT
                outputs['cnn_regularization_loss'] = obj_reps['cnn_regularization_loss']
            else:
                objects = objects.unsqueeze(1).repeat(1, num_choices, 1)
                box_mask = box_mask.unsqueeze(1).repeat(1, num_choices, 1)
                cnn_reg_logits = self.cnn_loss_reg(hidden_states_objects[box_mask])
                cnn_reg_loss = F.cross_entropy(cnn_reg_logits, objects[box_mask].long())
                loss = loss + cnn_reg_loss.mean() * self.config.NETWORK.CNN_LOSS_WEIGHT
                outputs['cnn_regularization_loss'] = cnn_reg_loss

        return outputs, loss

    def inference_forward(self,
                          image,
                          boxes,
                          masks,
                          question,
                          question_align_matrix,
                          answer_choices,
                          answer_align_matrix,
                          *args):
        if self.for_pretrain:
            answer_label, im_info, mask_position, mask_type = args
        else:
            assert len(args) == 1
            im_info = args[0]

        ###########################################
        # visual feature extraction
        images = image
        objects = boxes[:, :, -1]
        segms = masks
        boxes = boxes[:, :, :4]
        box_mask = (boxes[:, :, -1] > -0.5)
        max_len = int(box_mask.sum(1).max().item())
        objects = objects[:, :max_len]
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        segms = segms[:, :max_len]
        if self.config.NETWORK.BLIND:
            obj_reps = {'obj_reps': boxes.new_zeros((*boxes.shape[:-1], self.config.NETWORK.IMAGE_FINAL_DIM))}
        else:
            obj_reps = self.image_feature_extractor(images=images,
                                                    boxes=boxes,
                                                    box_mask=box_mask,
                                                    im_info=im_info,
                                                    classes=objects,
                                                    segms=segms)

        num_choices = answer_choices.shape[1]
        question_ids = question[:, :, 0]
        question_tags = question[:, :, 1]
        question_tags = question_tags.repeat(1, num_choices).view(question_tags.shape[0], num_choices, -1)
        question_mask = (question[:, :, 0] > 0.5)
        answer_ids = answer_choices[:, :, :, 0]
        answer_tags = answer_choices[:, :, :, 1]
        answer_mask = (answer_choices[:, :, :, 0] > 0.5)

        ############################################
        # prepare text
        if self.config.NETWORK.ANSWER_FIRST:
            if self.config.NETWORK.QA_ONE_SENT:
                raise NotImplementedError
            else:
                text_input_ids, text_token_type_ids, text_tags, text_mask = self.prepare_text_from_aq(
                    question_ids, question_tags, question_mask,
                    answer_ids, answer_tags, answer_mask)
        else:
            if self.config.NETWORK.QA_ONE_SENT:
                text_input_ids, text_token_type_ids, text_tags, text_mask = self.prepare_text_from_qa_onesent(
                    question_ids, question_tags, question_mask,
                    answer_ids, answer_tags, answer_mask)
            else:
                text_input_ids, text_token_type_ids, text_tags, text_mask = self.prepare_text_from_qa(
                    question_ids, question_tags, question_mask,
                    answer_ids, answer_tags, answer_mask)

        if self.config.NETWORK.NO_GROUNDING:
            text_tags.zero_()
        text_visual_embeddings = self._collect_obj_reps(text_tags, obj_reps['obj_reps'])
        if self.config.NETWORK.BLIND:
            object_linguistic_embeddings = boxes.new_zeros(
                (*boxes.shape[:-1], self.config.NETWORK.VLBERT.hidden_size))
            object_linguistic_embeddings = object_linguistic_embeddings.unsqueeze(1).repeat(1, num_choices, 1, 1)
        else:
            if self.config.NETWORK.VLBERT.object_word_embed_mode in [1, 2]:
                object_linguistic_embeddings = self.object_linguistic_embeddings(
                    objects.long().clamp(min=0, max=self.object_linguistic_embeddings.weight.data.shape[0] - 1))
                object_linguistic_embeddings = object_linguistic_embeddings.unsqueeze(1).repeat(1, num_choices, 1, 1)
            elif self.config.NETWORK.VLBERT.object_word_embed_mode == 3:
                cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]'])
                global_context_mask = text_mask & (text_input_ids != cls_id) & (text_input_ids != sep_id)
                word_embedding = self.vlbert._module.word_embeddings(text_input_ids)
                word_embedding[global_context_mask == 0] = 0
                object_linguistic_embeddings = word_embedding.sum(dim=2) \
                    / global_context_mask.sum(dim=2, keepdim=True).to(dtype=word_embedding.dtype)
                object_linguistic_embeddings = object_linguistic_embeddings.unsqueeze(2).repeat((1, 1, max_len, 1))
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'].unsqueeze(1).repeat(1, num_choices, 1, 1), object_linguistic_embeddings), -1)

        ###########################################
        # Visual Linguistic BERT
        if self.config.NETWORK.NO_OBJ_ATTENTION or self.config.NETWORK.BLIND:
            box_mask.zero_()
        hidden_states_text, hidden_states_objects, pooled_rep = self.vlbert(
            text_input_ids,
            text_token_type_ids,
            text_visual_embeddings,
            text_mask,
            object_vl_embeddings,
            box_mask.unsqueeze(1).repeat(1, num_choices, 1),
            output_all_encoded_layers=False,
            output_text_and_object_separately=True)

        ###########################################
        # classifier
        logits = self.final_mlp(pooled_rep).squeeze(2)

        outputs = {'label_logits': logits}

        return outputs
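
# Illustration only (not part of the original model): the token layout that
# `prepare_text_from_qa` builds for one question paired with one answer choice.
# The ids are made up except for 101/102, which are [CLS]/[SEP] in the standard
# BERT vocabulary.
def _demo_qa_packing():
    cls_id, sep_id = 101, 102
    question = torch.tensor([1001, 1002, 1003])  # 3 question tokens
    answer = torch.tensor([2001, 2002])          # 2 answer tokens
    # layout: [CLS] q1 q2 q3 [SEP] a1 a2 [SEP]
    input_ids = torch.cat([torch.tensor([cls_id]), question,
                           torch.tensor([sep_id]), answer,
                           torch.tensor([sep_id])])
    # segment 0 covers [CLS] + question + the first [SEP];
    # segment 1 covers the answer + the final [SEP]
    token_type_ids = torch.tensor([0] * (len(question) + 2) + [1] * (len(answer) + 1))
    return input_ids, token_type_ids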
class ResNetVLBERTForPretrainingMultitask(Module):
    def __init__(self, config):
        super(ResNetVLBERTForPretrainingMultitask, self).__init__(config)

        self.image_feature_extractor = FastRCNN(config,
                                                average_pool=True,
                                                final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                                                enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(1, config.NETWORK.VLBERT.hidden_size)
        if config.NETWORK.IMAGE_FEAT_PRECOMPUTED or (not config.NETWORK.MASK_RAW_PIXELS):
            self.object_mask_visual_embedding = nn.Embedding(1, 2048)
        if config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding = nn.Embedding(1, config.NETWORK.VLBERT.hidden_size)
        self.aux_text_visual_embedding = nn.Embedding(1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(config.NETWORK.BERT_PRETRAINED,
                                                                      config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        if language_pretrained_model_path is None:
            print("Warning: no pretrained language model found, training from scratch!!!")

        self.vlbert = VisualLinguisticBertForPretraining(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=None if config.NETWORK.VLBERT.from_scratch
            else language_pretrained_model_path,
            with_rel_head=config.NETWORK.WITH_REL_LOSS,
            with_mlm_head=config.NETWORK.WITH_MLM_LOSS,
            with_mvrc_head=config.NETWORK.WITH_MVRC_LOSS,
        )

        # init weights
        self.init_weight()

        self.fix_params()

    def init_weight(self):
        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED or (not self.config.NETWORK.MASK_RAW_PIXELS):
            self.object_mask_visual_embedding.weight.data.fill_(0.0)
        if self.config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        self.aux_text_visual_embedding.weight.data.normal_(
            mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        self.image_feature_extractor.init_weight()
        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)

    def train(self, mode=True):
        super(ResNetVLBERTForPretrainingMultitask, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        pass

    def _collect_obj_reps(self, span_tags, object_reps):
        """
        Collect span-level object representations
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return:
        """
        span_tags_fixed = torch.clamp(span_tags, min=0)  # In case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0, row_id.shape[0], step=1, device=row_id.device)[:, None]

        # Add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        return object_reps[row_id.view(-1), span_tags_fixed.view(-1)].view(*span_tags_fixed.shape, -1)

    def forward(self,
                image,
                boxes,
                im_info,
                text,
                relationship_label,
                mlm_labels,
                mvrc_ops,
                mvrc_labels,
                *aux):

        # concat aux texts from different datasets
        assert len(aux) > 0 and len(aux) % 2 == 0
        aux_text_list = aux[0::2]
        aux_text_mlm_labels_list = aux[1::2]
        num_aux_text = sum([_text.shape[0] for _text in aux_text_list])
        max_aux_text_len = max([_text.shape[1] for _text in aux_text_list])
        aux_text = aux_text_list[0].new_zeros((num_aux_text, max_aux_text_len))
        aux_text_mlm_labels = aux_text_mlm_labels_list[0].new_zeros((num_aux_text, max_aux_text_len)).fill_(-1)
        _cur = 0
        for _text, _mlm_labels in zip(aux_text_list, aux_text_mlm_labels_list):
            _num = _text.shape[0]
            aux_text[_cur:(_cur + _num), :_text.shape[1]] = _text
            aux_text_mlm_labels[_cur:(_cur + _num), :_text.shape[1]] = _mlm_labels
            _cur += _num

        ###########################################
        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        origin_len = boxes.shape[1]
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        mvrc_ops = mvrc_ops[:, :max_len]
        mvrc_labels = mvrc_labels[:, :max_len]

        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            box_features = boxes[:, :, 4:]
            box_features[mvrc_ops == 1] = self.object_mask_visual_embedding.weight[0]
            boxes[:, :, 4:] = box_features

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None,
                                                mvrc_ops=mvrc_ops,
                                                mask_visual_embed=self.object_mask_visual_embedding.weight[0]
                                                if (not self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED)
                                                and (not self.config.NETWORK.MASK_RAW_PIXELS)
                                                else None)

        ############################################
        # prepare text
        text_input_ids = text
        text_tags = text.new_zeros(text.shape)
        text_visual_embeddings = self._collect_obj_reps(text_tags, obj_reps['obj_reps'])

        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long()
        )
        if self.config.NETWORK.WITH_MVRC_LOSS:
            object_linguistic_embeddings[mvrc_ops == 1] = self.object_mask_word_embedding.weight[0]
        object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        # add auxiliary text: concatenate the batches from the two dataloaders;
        # the visual features for the text-only corpus are just the (single)
        # aux_text_visual_embedding
        max_text_len = max(text_input_ids.shape[1], aux_text.shape[1])
        text_input_ids_multi = text_input_ids.new_zeros(
            (text_input_ids.shape[0] + aux_text.shape[0], max_text_len))
        text_input_ids_multi[:text_input_ids.shape[0], :text_input_ids.shape[1]] = text_input_ids
        text_input_ids_multi[text_input_ids.shape[0]:, :aux_text.shape[1]] = aux_text
        text_token_type_ids_multi = text_input_ids_multi.new_zeros(text_input_ids_multi.shape)
        text_mask_multi = (text_input_ids_multi > 0)
        text_visual_embeddings_multi = text_visual_embeddings.new_zeros(
            (text_input_ids.shape[0] + aux_text.shape[0], max_text_len, text_visual_embeddings.shape[-1]))
        text_visual_embeddings_multi[:text_visual_embeddings.shape[0], :text_visual_embeddings.shape[1]] \
            = text_visual_embeddings
        text_visual_embeddings_multi[text_visual_embeddings.shape[0]:] = self.aux_text_visual_embedding.weight[0]
        object_vl_embeddings_multi = object_vl_embeddings.new_zeros(
            (text_input_ids.shape[0] + aux_text.shape[0], *object_vl_embeddings.shape[1:]))
        object_vl_embeddings_multi[:object_vl_embeddings.shape[0]] = object_vl_embeddings
        box_mask_multi = box_mask.new_zeros(
            (text_input_ids.shape[0] + aux_text.shape[0], *box_mask.shape[1:]))
        box_mask_multi[:box_mask.shape[0]] = box_mask

        ###########################################
        # Visual Linguistic BERT
        relationship_logits_multi, mlm_logits_multi, mvrc_logits_multi = self.vlbert(
            text_input_ids_multi,
            text_token_type_ids_multi,
            text_visual_embeddings_multi,
            text_mask_multi,
            object_vl_embeddings_multi,
            box_mask_multi)

        ###########################################
        outputs = {}

        # loss
        # initialize each loss term so the final sum is well defined even when
        # the corresponding objective is disabled
        relationship_loss = im_info.new_zeros(())
        mlm_loss_wvc = im_info.new_zeros(())
        mlm_loss_aux = im_info.new_zeros(())
        mvrc_loss = im_info.new_zeros(())
        if self.config.NETWORK.WITH_REL_LOSS:
            relationship_logits = relationship_logits_multi[:text_input_ids.shape[0]]
            relationship_loss = F.cross_entropy(relationship_logits, relationship_label)
        if self.config.NETWORK.WITH_MLM_LOSS:
            mlm_labels_multi = mlm_labels.new_zeros(
                (text_input_ids.shape[0] + aux_text.shape[0], max_text_len)).fill_(-1)
            mlm_labels_multi[:text_input_ids.shape[0], :mlm_labels.shape[1]] = mlm_labels
            mlm_labels_multi[text_input_ids.shape[0]:, :aux_text_mlm_labels.shape[1]] = aux_text_mlm_labels

            mlm_logits_multi_padded = \
                mlm_logits_multi.new_zeros((*mlm_labels_multi.shape, mlm_logits_multi.shape[-1])).fill_(-10000.0)
            mlm_logits_multi_padded[:, :mlm_logits_multi.shape[1]] = mlm_logits_multi
            mlm_logits_multi = mlm_logits_multi_padded
            mlm_logits_wvc = mlm_logits_multi_padded[:text_input_ids.shape[0]]
            mlm_labels_wvc = mlm_labels_multi[:text_input_ids.shape[0]]
            mlm_logits_aux = mlm_logits_multi_padded[text_input_ids.shape[0]:]
            mlm_labels_aux = mlm_labels_multi[text_input_ids.shape[0]:]
            if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
                mlm_loss_wvc = F.cross_entropy(mlm_logits_wvc.transpose(1, 2),
                                               mlm_labels_wvc,
                                               ignore_index=-1,
                                               reduction='none')
                num_mlm_wvc = (mlm_labels_wvc != -1).sum(1, keepdim=True).to(dtype=mlm_loss_wvc.dtype)
                num_has_mlm_wvc = (num_mlm_wvc != 0).sum().to(dtype=mlm_loss_wvc.dtype)
                mlm_loss_wvc = (mlm_loss_wvc / (num_mlm_wvc + 1e-4)).sum() / (num_has_mlm_wvc + 1e-4)
                mlm_loss_aux = F.cross_entropy(mlm_logits_aux.transpose(1, 2),
                                               mlm_labels_aux,
                                               ignore_index=-1,
                                               reduction='none')
                num_mlm_aux = (mlm_labels_aux != -1).sum(1, keepdim=True).to(dtype=mlm_loss_aux.dtype)
                num_has_mlm_aux = (num_mlm_aux != 0).sum().to(dtype=mlm_loss_aux.dtype)
                mlm_loss_aux = (mlm_loss_aux / (num_mlm_aux + 1e-4)).sum() / (num_has_mlm_aux + 1e-4)
            else:
                # mlm_loss = F.cross_entropy(mlm_logits_multi_padded.view((-1, mlm_logits_multi_padded.shape[-1])),
                #                            mlm_labels_multi.view(-1),
                #                            ignore_index=-1)
                mlm_loss_wvc = F.cross_entropy(
                    mlm_logits_wvc.view((-1, mlm_logits_multi_padded.shape[-1])),
                    mlm_labels_wvc.view(-1),
                    ignore_index=-1
                )
                mlm_loss_aux = F.cross_entropy(
                    mlm_logits_aux.view((-1, mlm_logits_multi_padded.shape[-1])),
                    mlm_labels_aux.view(-1),
                    ignore_index=-1
                )

        # mvrc_loss = F.cross_entropy(mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
        #                             mvrc_labels.contiguous().view(-1),
        #                             ignore_index=-1)
        if self.config.NETWORK.WITH_MVRC_LOSS:
            mvrc_logits = mvrc_logits_multi[:mvrc_labels.shape[0], :mvrc_labels.shape[1]]
            if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]),
                    reduction='none').view(mvrc_logits.shape[:-1])
                valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1
                mvrc_loss = (mvrc_loss / (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \
                    .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4)
            else:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]))

            mvrc_logits_padded = mvrc_logits.new_zeros(
                (mvrc_logits.shape[0], origin_len, mvrc_logits.shape[2])).fill_(-10000.0)
            mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits
            mvrc_logits = mvrc_logits_padded
            mvrc_labels_padded = mvrc_labels.new_zeros(
                (mvrc_labels.shape[0], origin_len, mvrc_labels.shape[2])).fill_(0.0)
            mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels
            mvrc_labels = mvrc_labels_padded

        outputs.update({
            'relationship_logits': relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None,
            'relationship_label': relationship_label if self.config.NETWORK.WITH_REL_LOSS else None,
            'mlm_logits_wvc': mlm_logits_wvc if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_label_wvc': mlm_labels_wvc if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_logits_aux': mlm_logits_aux if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_label_aux': mlm_labels_aux if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mvrc_logits': mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'mvrc_label': mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'relationship_loss': relationship_loss,
            'mlm_loss_wvc': mlm_loss_wvc,
            'mlm_loss_aux': mlm_loss_aux,
            'mvrc_loss': mvrc_loss,
        })

        loss = relationship_loss.mean() + mlm_loss_wvc.mean() + mlm_loss_aux.mean() + mvrc_loss.mean()

        return outputs, loss
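
# Illustration only (not part of the original model): a toy version of the
# MLM_LOSS_NORM_IN_BATCH_FIRST branch above. It first averages token losses
# within each sample, then averages over samples that contain at least one
# masked token; the 1e-4 terms guard against division by zero. All values are
# made up.
def _demo_mlm_loss_norm_in_batch_first():
    logits = torch.randn(2, 5, 30)                  # [batch, seq, vocab]
    labels = torch.tensor([[3, -1, -1, 7, -1],      # sample with 2 masked tokens
                           [-1, -1, -1, -1, -1]])   # sample with none
    loss = F.cross_entropy(logits.transpose(1, 2), labels,
                           ignore_index=-1, reduction='none')  # [batch, seq]
    num = (labels != -1).sum(1, keepdim=True).to(loss.dtype)   # masks per sample
    num_has = (num != 0).sum().to(loss.dtype)                  # samples with masks
    return (loss / (num + 1e-4)).sum() / (num_has + 1e-4)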
class ResNetVLBERTForPretrainingNoVision(Module):
    def __init__(self, config):
        super(ResNetVLBERTForPretrainingNoVision, self).__init__(config)

        self.image_feature_extractor = FastRCNN(config,
                                                average_pool=True,
                                                final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                                                enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(1, config.NETWORK.VLBERT.hidden_size)
        if config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            self.object_mask_visual_embedding = nn.Embedding(1, 2048)
        if config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding = nn.Embedding(1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(config.NETWORK.BERT_PRETRAINED,
                                                                      config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        if language_pretrained_model_path is None:
            print("Warning: no pretrained language model found, training from scratch!!!")

        self.vlbert = VisualLinguisticBertForPretraining(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=None if config.NETWORK.VLBERT.from_scratch
            else language_pretrained_model_path,
            with_rel_head=config.NETWORK.WITH_REL_LOSS,
            with_mlm_head=config.NETWORK.WITH_MLM_LOSS,
            with_mvrc_head=config.NETWORK.WITH_MVRC_LOSS,
        )

        # init weights
        self.init_weight()

        self.fix_params()

    def init_weight(self):
        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            self.object_mask_visual_embedding.weight.data.fill_(0.0)
        if self.config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        self.image_feature_extractor.init_weight()
        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)

    def train(self, mode=True):
        super(ResNetVLBERTForPretrainingNoVision, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        pass

    def _collect_obj_reps(self, span_tags, object_reps):
        """
        Collect span-level object representations
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return:
        """
        span_tags_fixed = torch.clamp(span_tags, min=0)  # In case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0, row_id.shape[0], step=1, device=row_id.device)[:, None]

        # Add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        return object_reps[row_id.view(-1), span_tags_fixed.view(-1)].view(*span_tags_fixed.shape, -1)

    def forward(self, text, relationship_label, mlm_labels):
        ###########################################
        # Blank out visual feature extraction

        ############################################
        # prepare text
        text_input_ids = text
        # creates a text_tags tensor of the same shape as the text tensor
        text_tags = text.new_zeros(text.shape)

        # ***** FM edit: blank out visual embeddings for the translation retrieval task
        text_visual_embeddings = text_input_ids.new_zeros(
            (text_input_ids.shape[0], text_input_ids.shape[1], 768), dtype=torch.float)
        # text_visual_embeddings[:] = self.aux_text_visual_embedding.weight[0]

        # ***** FM edit: blank visual embeddings (use known dimensions)
        object_vl_embeddings = text_input_ids.new_zeros(
            (text_input_ids.shape[0], 1, 1536), dtype=torch.float)

        # FM edit: no auxiliary text is used for text-only training
        max_text_len = text_input_ids.shape[1]
        text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
        text_mask = (text_input_ids > 0)

        # FM edit: set to zero to ignore vision
        box_mask = text_input_ids.new_zeros((text_input_ids.shape[0], 1), dtype=torch.uint8)

        ###########################################
        # Visual Linguistic BERT
        relationship_logits, mlm_logits, mvrc_logits = self.vlbert(text_input_ids,
                                                                   text_token_type_ids,
                                                                   text_visual_embeddings,
                                                                   text_mask,
                                                                   object_vl_embeddings,
                                                                   box_mask)

        ###########################################
        outputs = {}

        # losses
        if self.config.NETWORK.WITH_REL_LOSS:
            relationship_loss = F.cross_entropy(relationship_logits, relationship_label)
        if self.config.NETWORK.WITH_MLM_LOSS:
            mlm_logits_padded = mlm_logits.new_zeros((*mlm_labels.shape, mlm_logits.shape[-1])).fill_(-10000.0)
            mlm_logits_padded[:, :mlm_logits.shape[1]] = mlm_logits
            mlm_logits = mlm_logits_padded
            if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
                mlm_loss = F.cross_entropy(mlm_logits.transpose(1, 2),
                                           mlm_labels,
                                           ignore_index=-1,
                                           reduction='none')
                num_mlm = (mlm_labels != -1).sum(1, keepdim=True).to(dtype=mlm_loss.dtype)
                num_has_mlm = (num_mlm != 0).sum().to(dtype=mlm_loss.dtype)
                mlm_loss = (mlm_loss / (num_mlm + 1e-4)).sum() / (num_has_mlm + 1e-4)
            else:
                mlm_loss = F.cross_entropy(mlm_logits.view((-1, mlm_logits.shape[-1])),
                                           mlm_labels.view(-1),
                                           ignore_index=-1)
        if self.config.NETWORK.WITH_MVRC_LOSS:
            # NOTE: this branch is stale code carried over from the multitask
            # model: `mvrc_labels` and `origin_len` are not defined in this
            # text-only forward, so WITH_MVRC_LOSS must stay disabled here.
            if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]),
                    reduction='none').view(mvrc_logits.shape[:-1])
                valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1
                mvrc_loss = (mvrc_loss / (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \
                    .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4)
            else:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]))

            mvrc_logits_padded = mvrc_logits.new_zeros(
                (mvrc_logits.shape[0], origin_len, mvrc_logits.shape[2])).fill_(-10000.0)
            mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits
            mvrc_logits = mvrc_logits_padded
            mvrc_labels_padded = mvrc_labels.new_zeros(
                (mvrc_labels.shape[0], origin_len, mvrc_labels.shape[2])).fill_(0.0)
            mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels
            mvrc_labels = mvrc_labels_padded

        outputs.update({
            'relationship_logits': relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None,
            'relationship_label': relationship_label if self.config.NETWORK.WITH_REL_LOSS else None,
            'mlm_logits': mlm_logits if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_label': mlm_labels if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mvrc_logits': mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'mvrc_label': mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'mlm_loss': mlm_loss,
        })

        # the text-only objective; this assumes WITH_MLM_LOSS is enabled,
        # otherwise `mlm_loss` is undefined
        loss = mlm_loss.mean()

        return outputs, loss
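
# Illustration only (not part of the original model): the dummy visual inputs
# the text-only forward above feeds to the VL-BERT backbone -- zeroed token-level
# visual embeddings (hidden size 768 is hard-coded above, 1536 = visual +
# linguistic halves of the object embedding) and a single all-zero object with
# a zero box mask, so attention to vision is effectively disabled.
def _demo_text_only_vision_inputs(batch_size=2, seq_len=6):
    text_visual_embeddings = torch.zeros(batch_size, seq_len, 768)
    object_vl_embeddings = torch.zeros(batch_size, 1, 1536)
    box_mask = torch.zeros(batch_size, 1, dtype=torch.uint8)  # masks out the dummy object
    return text_visual_embeddings, object_vl_embeddings, box_mask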
class ResNetVLBERTDistanceTranslationWithVision(Module):
    def __init__(self, config):
        super(ResNetVLBERTDistanceTranslationWithVision, self).__init__(config)

        # Constructs/initialises model elements
        self.image_feature_extractor = FastRCNN(config,
                                                average_pool=True,
                                                final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                                                enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(1, config.NETWORK.VLBERT.hidden_size)
        if config.NETWORK.IMAGE_FEAT_PRECOMPUTED or (not config.NETWORK.MASK_RAW_PIXELS):
            self.object_mask_visual_embedding = nn.Embedding(1, 2048)
        if config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding = nn.Embedding(1, config.NETWORK.VLBERT.hidden_size)
        self.aux_text_visual_embedding = nn.Embedding(1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(config.NETWORK.BERT_MODEL_NAME)

        # Either specify a pretrained model directly, or use the downloaded
        # pretrained model specified in the .yaml config file
        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            # language_pretrained_model_path = '{}-{:04d}.model'.format(config.NETWORK.BERT_PRETRAINED,
            #                                                           config.NETWORK.BERT_PRETRAINED_EPOCH)
            # FM edit: just use the path of the pretrained model
            language_pretrained_model_path = config.NETWORK.BERT_PRETRAINED
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        if language_pretrained_model_path is None:
            print("Warning: no pretrained language model found, training from scratch!!!")

        self.vlbert = VisualLinguisticBertForDistance(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=None if config.NETWORK.VLBERT.from_scratch
            else language_pretrained_model_path,
            with_rel_head=config.NETWORK.WITH_REL_LOSS,
            with_mlm_head=config.NETWORK.WITH_MLM_LOSS,
            with_mvrc_head=config.NETWORK.WITH_MVRC_LOSS,
        )

        # init weights
        self.init_weight()

        self.fix_params()

    def init_weight(self):
        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED or (not self.config.NETWORK.MASK_RAW_PIXELS):
            self.object_mask_visual_embedding.weight.data.fill_(0.0)
        if self.config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        self.aux_text_visual_embedding.weight.data.normal_(
            mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        self.image_feature_extractor.init_weight()
        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)

    def train(self, mode=True):
        super(ResNetVLBERTDistanceTranslationWithVision, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        pass

    def _collect_obj_reps(self, span_tags, object_reps):
        """
        Collect span-level object representations
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return:
        """
        span_tags_fixed = torch.clamp(span_tags, min=0)  # In case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0, row_id.shape[0], step=1, device=row_id.device)[:, None]

        # Add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        return object_reps[row_id.view(-1), span_tags_fixed.view(-1)].view(*span_tags_fixed.shape, -1)

    def forward(self, image, boxes, im_info, text, relationship_label, mlm_labels, mvrc_ops, mvrc_labels):
        ###########################################
        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        origin_len = boxes.shape[1]
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        mvrc_ops = mvrc_ops[:, :max_len]
        mvrc_labels = mvrc_labels[:, :max_len]

        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            box_features = boxes[:, :, 4:]
            box_features[mvrc_ops == 1] = self.object_mask_visual_embedding.weight[0]
            boxes[:, :, 4:] = box_features

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None,
                                                mvrc_ops=mvrc_ops,
                                                mask_visual_embed=self.object_mask_visual_embedding.weight[0]
                                                if (not self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED)
                                                and (not self.config.NETWORK.MASK_RAW_PIXELS)
                                                else None)

        ############################################
        # prepare text
        text_input_ids = text
        # creates a text_tags tensor of the same shape as the text tensor
        text_tags = text.new_zeros(text.shape)
        text_visual_embeddings = self._collect_obj_reps(text_tags, obj_reps['obj_reps'])

        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        if self.config.NETWORK.WITH_MVRC_LOSS:
            object_linguistic_embeddings[mvrc_ops == 1] = self.object_mask_word_embedding.weight[0]
        object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        # FM edit: no auxiliary text is used here (in the multitask model this
        # is where the batches from the two dataloaders are concatenated, with
        # the single aux_visual_embedding standing in for the text-only corpus)
        max_text_len = text_input_ids.shape[1]
        text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
        text_mask = (text_input_ids > 0)
        # FM edit: removed re-zeroing of box_mask; it is already defined above
        # box_mask = box_mask.new_zeros((text_input_ids.shape[0], *box_mask.shape[1:]))

        ###########################################
        # Visual Linguistic BERT
        relationship_logits_multi, mlm_logits_multi, mvrc_logits_multi, pooled_rep, text_out = self.vlbert(
            text_input_ids,
            text_token_type_ids,
            text_visual_embeddings,
            text_mask,
            object_vl_embeddings,
            box_mask)

        ###########################################
        outputs = {}

        # FM edit: removed the other two losses, which are not defined here;
        # only the [CLS] representation is returned
        outputs.update({'cls_output': text_out[:, 0, :]})

        # FM edit: removed the addition of other losses, which are not defined
        loss = 0

        return outputs, loss
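
# Illustration only (an assumption about downstream usage, since the training
# objective is defined outside this class): given two batches of `cls_output`
# sentence embeddings as returned above, a retrieval-style score matrix can be
# computed with cosine similarity.
def _demo_cls_retrieval_score(cls_a, cls_b):
    a = F.normalize(cls_a, dim=-1)
    b = F.normalize(cls_b, dim=-1)
    return a @ b.t()  # [batch_a, batch_b] similarity matrix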
class ResNetVLBERTForPretrainingGenerate(Module): def __init__(self, config): super(ResNetVLBERTForPretrainingGenerate, self).__init__(config) self.image_feature_extractor = FastRCNN( config, average_pool=True, final_dim=config.NETWORK.IMAGE_FINAL_DIM, enable_cnn_reg_loss=False) self.object_linguistic_embeddings = nn.Embedding( 1, config.NETWORK.VLBERT.hidden_size) if config.NETWORK.IMAGE_FEAT_PRECOMPUTED: self.object_mask_visual_embedding = nn.Embedding(1, 2048) if config.NETWORK.WITH_MVRC_LOSS: self.object_mask_word_embedding = nn.Embedding( 1, config.NETWORK.VLBERT.hidden_size) self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN self.tokenizer = BertTokenizer.from_pretrained( config.NETWORK.BERT_MODEL_NAME) language_pretrained_model_path = None if config.NETWORK.BERT_PRETRAINED != '': language_pretrained_model_path = '{}-{:04d}.model'.format( config.NETWORK.BERT_PRETRAINED, config.NETWORK.BERT_PRETRAINED_EPOCH) elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME): weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME, BERT_WEIGHTS_NAME) if os.path.isfile(weight_path): language_pretrained_model_path = weight_path if language_pretrained_model_path is None: print( "Warning: no pretrained language model found, training from scratch!!!" ) self.vlbert = VisualLinguisticBertForPretraining( config.NETWORK.VLBERT, language_pretrained_model_path=None if config.NETWORK.VLBERT.from_scratch else language_pretrained_model_path, with_rel_head=config.NETWORK.WITH_REL_LOSS, with_mlm_head=config.NETWORK.WITH_MLM_LOSS, with_mvrc_head=config.NETWORK.WITH_MVRC_LOSS, ) # init weights self.init_weight() self.fix_params() def init_weight(self): if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED: self.object_mask_visual_embedding.weight.data.fill_(0.0) if self.config.NETWORK.WITH_MVRC_LOSS: self.object_mask_word_embedding.weight.data.normal_( mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range) self.image_feature_extractor.init_weight() if self.object_linguistic_embeddings is not None: self.object_linguistic_embeddings.weight.data.normal_( mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range) def train(self, mode=True): super(ResNetVLBERTForPretrainingGenerate, self).train(mode) # turn some frozen layers to eval mode if self.image_feature_bn_eval: self.image_feature_extractor.bn_eval() def fix_params(self): pass def _collect_obj_reps(self, span_tags, object_reps): """ Collect span-level object representations :param span_tags: [batch_size, ..leading_dims.., L] :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim] :return: """ span_tags_fixed = torch.clamp( span_tags, min=0) # In case there were masked values here row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape) row_id_broadcaster = torch.arange(0, row_id.shape[0], step=1, device=row_id.device)[:, None] # Add extra diminsions to the row broadcaster so it matches row_id leading_dims = len(span_tags.shape) - 2 for i in range(leading_dims): row_id_broadcaster = row_id_broadcaster[..., None] row_id += row_id_broadcaster return object_reps[row_id.view(-1), span_tags_fixed.view(-1)].view( *span_tags_fixed.shape, -1) def forward(self, image, boxes, im_info, text, relationship_label, mlm_labels, mvrc_ops, mvrc_labels): ########################################### # visual feature extraction images = image box_mask = (boxes[:, :, 0] > -1.5) origin_len = boxes.shape[1] max_len = int(box_mask.sum(1).max().item()) box_mask = box_mask[:, :max_len] boxes = boxes[:, :max_len] mvrc_ops = mvrc_ops[:, :max_len] mvrc_labels = mvrc_labels[:, 
:max_len]
        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            box_features = boxes[:, :, 4:]
            box_features[mvrc_ops == 1] = self.object_mask_visual_embedding.weight[0]
            boxes[:, :, 4:] = box_features
        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None,
                                                mvrc_ops=mvrc_ops,
                                                mask_visual_embed=None)

        ############################################
        # prepare text
        text_input_ids = text
        text_tags = text.new_zeros(text.shape)
        text_token_type_ids = text.new_zeros(text.shape)
        text_mask = (text_input_ids > 0)
        text_visual_embeddings = self._collect_obj_reps(text_tags, obj_reps['obj_reps'])
        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        if self.config.NETWORK.WITH_MVRC_LOSS:
            object_linguistic_embeddings[mvrc_ops == 1] = self.object_mask_word_embedding.weight[0]
        object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################
        # Visual Linguistic BERT
        # Generation loop for test mode: repeatedly fill the current [MASK]
        # position with the top-1 prediction and shift the mask one step right.
        # In the original code the mask marker was hardcoded as 103, which is
        # the id of '[MASK]' in the BERT vocabulary; look it up once instead.
        mask_id, pad_id, sep_id = self.tokenizer.convert_tokens_to_ids(['[MASK]', '[PAD]', '[SEP]'])
        generated = []
        stop = [False] * text.shape[0]
        curr_len = 0
        max_len = 48  # note: re-uses the name max_len as the generation budget
        while not all(stop) and curr_len <= max_len:
            relationship_logits, mlm_logits, mvrc_logits = self.vlbert(
                text_input_ids,
                text_token_type_ids,
                text_visual_embeddings,
                text_mask,
                object_vl_embeddings,
                box_mask)
            answers = torch.topk(mlm_logits[mlm_labels == mask_id], k=1, dim=1)

            # Get the position of the [MASK] token in each row
            position_tensor = torch.arange(mlm_labels.shape[1], device=mlm_labels.device)
            position_tensor = position_tensor.repeat(mlm_labels.shape[0]).view(mlm_labels.shape[0], -1)
            indices = position_tensor[mlm_labels == mask_id]

            # 1. Update mlm_labels: move the [MASK] marker one position right
            mlm_labels_new = mlm_labels.new_zeros(mlm_labels.shape[0], mlm_labels.shape[1] + 1)
            mlm_labels_new = mlm_labels_new - 1
            mlm_labels_new[torch.arange(mlm_labels.shape[0]), indices + 1] = mask_id
            mlm_labels = mlm_labels_new

            # 2. Update text_input_ids: write the predicted token, then [MASK],
            # [PAD] and [SEP] (assumes the mask never sits within 3 slots of the end)
            text_input_ids_new = text_input_ids.new_zeros(text_input_ids.shape[0],
                                                          text_input_ids.shape[1] + 1)
            text_input_ids_new[:, :-1] = text_input_ids
            text_input_ids_new[torch.arange(text_input_ids.shape[0]), indices] = answers[1][:, 0]
            text_input_ids_new[torch.arange(text_input_ids.shape[0]), indices + 1] = mask_id
            text_input_ids_new[torch.arange(text_input_ids.shape[0]), indices + 2] = pad_id
            text_input_ids_new[torch.arange(text_input_ids.shape[0]), indices + 3] = sep_id
            text_input_ids = text_input_ids_new

            # 3. Update text_token_type_ids:
            text_token_type_ids = text_token_type_ids.new_zeros(
                text_token_type_ids.shape[0], text_token_type_ids.shape[1] + 1)

            # 4. Update text_visual_embeddings:
            text_visual_embeddings_new = text_visual_embeddings.new_zeros(
                text_visual_embeddings.shape[0], text_visual_embeddings.shape[1] + 1,
                text_visual_embeddings.shape[2])
            text_visual_embeddings_new = text_visual_embeddings_new.transpose(0, 1)
            text_visual_embeddings_new[:] = text_visual_embeddings[:, 0, :]
            text_visual_embeddings = text_visual_embeddings_new.transpose(0, 1)

            # 5. Update text_mask:
            text_mask = (text_input_ids > 0)

            # 6. Append generated words from each sentence in the batch to the
            # list; a sentence terminates once it emits [STOP]
            for nid, row in enumerate(answers[1]):
                if curr_len == 0:
                    generated.append([])
                for ele in row:
                    if not stop[nid]:
                        if self.tokenizer.ids_to_tokens[ele.item()] == '[STOP]':
                            stop[nid] = True
                        else:
                            generated[nid].append(self.tokenizer.ids_to_tokens[ele.item()])
            curr_len += 1

        # Join the generated word pieces into sentences
        generated_sentences = []
        for sentence in generated:
            new_sentence = ' '.join(sentence)
            generated_sentences.append(new_sentence.replace(' ##', ''))

        ###########################################
        outputs = {}

        # loss
        relationship_loss = im_info.new_zeros(())
        mlm_loss = im_info.new_zeros(())
        mvrc_loss = im_info.new_zeros(())
        if self.config.NETWORK.WITH_REL_LOSS:
            relationship_loss = F.cross_entropy(relationship_logits, relationship_label)
        if self.config.NETWORK.WITH_MLM_LOSS:
            mlm_logits_padded = mlm_logits.new_zeros(
                (*mlm_labels.shape, mlm_logits.shape[-1])).fill_(-10000.0)
            mlm_logits_padded[:, :mlm_logits.shape[1]] = mlm_logits
            mlm_logits = mlm_logits_padded
            if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
                mlm_loss = F.cross_entropy(mlm_logits.transpose(1, 2),
                                           mlm_labels,
                                           ignore_index=-1,
                                           reduction='none')
                num_mlm = (mlm_labels != -1).sum(1, keepdim=True).to(dtype=mlm_loss.dtype)
                num_has_mlm = (num_mlm != 0).sum().to(dtype=mlm_loss.dtype)
                mlm_loss = (mlm_loss / (num_mlm + 1e-4)).sum() / (num_has_mlm + 1e-4)
            else:
                mlm_loss = F.cross_entropy(mlm_logits.view((-1, mlm_logits.shape[-1])),
                                           mlm_labels.view(-1),
                                           ignore_index=-1)
        if self.config.NETWORK.WITH_MVRC_LOSS:
            if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]),
                    reduction='none').view(mvrc_logits.shape[:-1])
                valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1
                mvrc_loss = (mvrc_loss / (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \
                    .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4)
            else:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]))

            mvrc_logits_padded = mvrc_logits.new_zeros(
                (mvrc_logits.shape[0], origin_len, mvrc_logits.shape[2])).fill_(-10000.0)
            mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits
            mvrc_logits = mvrc_logits_padded
            mvrc_labels_padded = mvrc_labels.new_zeros(
                (mvrc_labels.shape[0], origin_len, mvrc_labels.shape[2])).fill_(0.0)
            mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels
            mvrc_labels = mvrc_labels_padded

        outputs.update({
            'relationship_logits': relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None,
            'relationship_label': relationship_label if self.config.NETWORK.WITH_REL_LOSS else None,
            'mlm_logits': mlm_logits if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_label': mlm_labels if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mvrc_logits': mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'mvrc_label': mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'relationship_loss': relationship_loss,
            'mlm_loss': mlm_loss,
            'mvrc_loss': mvrc_loss,
            'generated_sentences': generated_sentences
        })

        loss = relationship_loss.mean() + mlm_loss.mean() + mvrc_loss.mean()

        return outputs, loss
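# ---------------------------------------------------------------------------
# The generation loop above decodes greedily by repeatedly filling the current
# [MASK] position and shifting the mask one step to the right. A stripped-down
# sketch of that control flow is given below; `toy_logits` is a hypothetical
# stand-in for the model's MLM head and the tiny vocabulary is invented purely
# for the demo.
def _demo_mask_shift_decoding(max_steps=5):
    import torch
    vocab = ['[PAD]', '[MASK]', '[STOP]', 'a', 'cat', 'sits']
    MASK, STOP = 1, 2

    def toy_logits(ids):
        # hypothetical scorer: always prefers the next word of a fixed script
        script = [3, 4, 5, 2]  # 'a cat sits [STOP]'
        step = (ids != 0).sum().item() - 1  # tokens emitted so far
        out = torch.zeros(len(vocab))
        out[script[min(step, len(script) - 1)]] = 1.0
        return out

    ids = torch.tensor([MASK])  # sequence starts as a single [MASK]
    generated = []
    for _ in range(max_steps):
        pred = int(toy_logits(ids).argmax())
        if pred == STOP:
            break
        generated.append(vocab[pred])
        # replace the mask with the prediction and append a fresh [MASK]
        ids = torch.cat([ids[:-1], torch.tensor([pred, MASK])])
    return ' '.join(generated)  # -> 'a cat sits'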
class ResNetVLBERTForPretrainingMultitaskNoVision(Module): def __init__(self, config): super(ResNetVLBERTForPretrainingMultitaskNoVision, self).__init__(config) # Constructs/initialises model elements self.image_feature_extractor = FastRCNN( config, average_pool=True, final_dim=config.NETWORK.IMAGE_FINAL_DIM, enable_cnn_reg_loss=False) self.object_linguistic_embeddings = nn.Embedding( 1, config.NETWORK.VLBERT.hidden_size) if config.NETWORK.IMAGE_FEAT_PRECOMPUTED or ( not config.NETWORK.MASK_RAW_PIXELS): self.object_mask_visual_embedding = nn.Embedding(1, 2048) if config.NETWORK.WITH_MVRC_LOSS: self.object_mask_word_embedding = nn.Embedding( 1, config.NETWORK.VLBERT.hidden_size) self.aux_text_visual_embedding = nn.Embedding( 1, config.NETWORK.VLBERT.hidden_size) self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN self.tokenizer = BertTokenizer.from_pretrained( config.NETWORK.BERT_MODEL_NAME) # Can specify pre-trained model or use the downloaded pretrained model specific in .yaml file language_pretrained_model_path = None if config.NETWORK.BERT_PRETRAINED != '': # language_pretrained_model_path = '{}-{:04d}.model'.format(config.NETWORK.BERT_PRETRAINED, # config.NETWORK.BERT_PRETRAINED_EPOCH) #FM edit: just use path of pretrained model language_pretrained_model_path = config.NETWORK.BERT_PRETRAINED elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME): weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME, BERT_WEIGHTS_NAME) if os.path.isfile(weight_path): language_pretrained_model_path = weight_path if language_pretrained_model_path is None: print( "Warning: no pretrained language model found, training from scratch!!!" ) self.vlbert = VisualLinguisticBertForPretraining( config.NETWORK.VLBERT, language_pretrained_model_path=None if config.NETWORK.VLBERT.from_scratch else language_pretrained_model_path, with_rel_head=config.NETWORK.WITH_REL_LOSS, with_mlm_head=config.NETWORK.WITH_MLM_LOSS, with_mvrc_head=config.NETWORK.WITH_MVRC_LOSS, with_MLT_head=config.NETWORK.WITH_MLT_LOSS) # init weights self.init_weight() self.fix_params() def init_weight(self): if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED or ( not self.config.NETWORK.MASK_RAW_PIXELS): self.object_mask_visual_embedding.weight.data.fill_(0.0) if self.config.NETWORK.WITH_MVRC_LOSS: self.object_mask_word_embedding.weight.data.normal_( mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range) self.aux_text_visual_embedding.weight.data.normal_( mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range) self.image_feature_extractor.init_weight() if self.object_linguistic_embeddings is not None: self.object_linguistic_embeddings.weight.data.normal_( mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range) def train(self, mode=True): super(ResNetVLBERTForPretrainingMultitaskNoVision, self).train(mode) # turn some frozen layers to eval mode if self.image_feature_bn_eval: self.image_feature_extractor.bn_eval() def fix_params(self): pass def _collect_obj_reps(self, span_tags, object_reps): """ Collect span-level object representations :param span_tags: [batch_size, ..leading_dims.., L] :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim] :return: """ span_tags_fixed = torch.clamp( span_tags, min=0) # In case there were masked values here row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape) row_id_broadcaster = torch.arange(0, row_id.shape[0], step=1, device=row_id.device)[:, None] # Add extra diminsions to the row broadcaster so it matches row_id leading_dims = len(span_tags.shape) - 2 for i in range(leading_dims): 
row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        return object_reps[row_id.view(-1), span_tags_fixed.view(-1)].view(*span_tags_fixed.shape, -1)

    def forward(self, image, boxes, im_info, text, relationship_label, mlm_labels, mvrc_ops,
                mvrc_labels, word_de_ids):
        ###########################################
        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        origin_len = boxes.shape[1]
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        mvrc_ops = mvrc_ops[:, :max_len]
        mvrc_labels = mvrc_labels[:, :max_len]

        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            box_features = boxes[:, :, 4:]
            box_features[mvrc_ops == 1] = self.object_mask_visual_embedding.weight[0]
            boxes[:, :, 4:] = box_features

        obj_reps = self.image_feature_extractor(
            images=images,
            boxes=boxes,
            box_mask=box_mask,
            im_info=im_info,
            classes=None,
            segms=None,
            mvrc_ops=mvrc_ops,
            mask_visual_embed=self.object_mask_visual_embedding.weight[0]
            if (not self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED)
            and (not self.config.NETWORK.MASK_RAW_PIXELS) else None)

        ############################################
        # prepare text
        text_input_ids = text
        # creates a text_tags tensor of the same shape as the text tensor
        text_tags = text.new_zeros(text.shape)
        text_visual_embeddings = self._collect_obj_reps(text_tags, obj_reps['obj_reps'])
        # FM edit: blank out visual embeddings for the translation-retrieval task
        text_visual_embeddings[:] = self.aux_text_visual_embedding.weight[0]
        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        if self.config.NETWORK.WITH_MVRC_LOSS:
            object_linguistic_embeddings[mvrc_ops == 1] = self.object_mask_word_embedding.weight[0]
        object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)
        # FM edit: blank out all visual embeddings
        object_vl_embeddings = object_vl_embeddings.new_zeros(object_vl_embeddings.shape)

        # FM edit: no auxiliary text-only batches are concatenated in this variant
        # (previously the two dataloaders' batches were concatenated, with the single
        # aux_text_visual_embedding vector standing in for the text-only visual features)
        max_text_len = text_input_ids.shape[1]
        text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
        text_mask = (text_input_ids > 0)
        # FM edit: removed redundant re-initialisation of box_mask (already defined above)

        ###########################################
        # Visual Linguistic BERT
        relationship_logits_multi, mlm_logits_multi, mvrc_logits_multi, MLT_logits = self.vlbert(
            text_input_ids,
            text_token_type_ids,
            text_visual_embeddings,
            text_mask,
            object_vl_embeddings,
            box_mask)

        ###########################################
        outputs = {}

        # loss
        relationship_loss = im_info.new_zeros(())
        mlm_loss = im_info.new_zeros(())
        mvrc_loss = im_info.new_zeros(())
        MLT_loss = im_info.new_zeros(())
        if self.config.NETWORK.WITH_REL_LOSS:
            relationship_logits = relationship_logits_multi[:text_input_ids.shape[0]]
            # FM edit - change cross_entropy to bce/sigmoid
            relationship_loss = F.binary_cross_entropy(torch.sigmoid(relationship_logits),
                                                       relationship_label.unsqueeze(1))
        if self.config.NETWORK.WITH_MLM_LOSS:
            # NOTE: this variant no longer receives auxiliary text-only batches, but the
            # bookkeeping below still expects aux_text / aux_text_mlm_labels (previously
            # this raised a NameError). As an assumption, fall back to empty aux tensors
            # so the branch stays runnable; the aux loss is only meaningful when real
            # auxiliary batches are supplied.
            aux_text = text_input_ids.new_zeros((0, max_text_len))
            aux_text_mlm_labels = mlm_labels.new_zeros((0, max_text_len)).fill_(-1)
            mlm_labels_multi = mlm_labels.new_zeros(
                (text_input_ids.shape[0] + aux_text.shape[0], max_text_len)).fill_(-1)
            mlm_labels_multi[:text_input_ids.shape[0], :mlm_labels.shape[1]] = mlm_labels
            mlm_labels_multi[text_input_ids.shape[0]:, :aux_text_mlm_labels.shape[1]] = aux_text_mlm_labels
            mlm_logits_multi_padded = \
                mlm_logits_multi.new_zeros((*mlm_labels_multi.shape,
                                            mlm_logits_multi.shape[-1])).fill_(-10000.0)
            mlm_logits_multi_padded[:, :mlm_logits_multi.shape[1]] = mlm_logits_multi
            mlm_logits_multi = mlm_logits_multi_padded
            mlm_logits_wvc = mlm_logits_multi_padded[:text_input_ids.shape[0]]
            mlm_labels_wvc = mlm_labels_multi[:text_input_ids.shape[0]]
            mlm_logits_aux = mlm_logits_multi_padded[text_input_ids.shape[0]:]
            mlm_labels_aux = mlm_labels_multi[text_input_ids.shape[0]:]
            if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
                mlm_loss_wvc = F.cross_entropy(mlm_logits_wvc.transpose(1, 2),
                                               mlm_labels_wvc,
                                               ignore_index=-1,
                                               reduction='none')
                num_mlm_wvc = (mlm_labels_wvc != -1).sum(1, keepdim=True).to(dtype=mlm_loss_wvc.dtype)
                num_has_mlm_wvc = (num_mlm_wvc != 0).sum().to(dtype=mlm_loss_wvc.dtype)
                mlm_loss_wvc = (mlm_loss_wvc / (num_mlm_wvc + 1e-4)).sum() / (num_has_mlm_wvc + 1e-4)
                mlm_loss_aux = F.cross_entropy(mlm_logits_aux.transpose(1, 2),
                                               mlm_labels_aux,
                                               ignore_index=-1,
                                               reduction='none')
                num_mlm_aux = (mlm_labels_aux != -1).sum(1, keepdim=True).to(dtype=mlm_loss_aux.dtype)
                num_has_mlm_aux = (num_mlm_aux != 0).sum().to(dtype=mlm_loss_aux.dtype)
                mlm_loss_aux = (mlm_loss_aux / (num_mlm_aux + 1e-4)).sum() / (num_has_mlm_aux + 1e-4)
            else:
                mlm_loss_wvc = F.cross_entropy(mlm_logits_wvc.view((-1, mlm_logits_multi_padded.shape[-1])),
                                               mlm_labels_wvc.view(-1),
                                               ignore_index=-1)
                mlm_loss_aux = F.cross_entropy(mlm_logits_aux.view((-1, mlm_logits_multi_padded.shape[-1])),
                                               mlm_labels_aux.view(-1),
                                               ignore_index=-1)
        if self.config.NETWORK.WITH_MVRC_LOSS:
            mvrc_logits = mvrc_logits_multi[:mvrc_labels.
shape[0], :mvrc_labels.shape[1]] if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST: mvrc_loss = soft_cross_entropy( mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]), mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]), reduction='none').view(mvrc_logits.shape[:-1]) valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1 mvrc_loss = (mvrc_loss / (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \ .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4) else: mvrc_loss = soft_cross_entropy( mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]), mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1])) mvrc_logits_padded = mvrc_logits.new_zeros( (mvrc_logits.shape[0], origin_len, mvrc_logits.shape[2])).fill_(-10000.0) mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits mvrc_logits = mvrc_logits_padded mvrc_labels_padded = mvrc_labels.new_zeros( (mvrc_labels.shape[0], origin_len, mvrc_labels.shape[2])).fill_(0.0) mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels mvrc_labels = mvrc_labels_padded # MLT loss applied if self.config.NETWORK.WITH_MLT_LOSS: MLT_loss = F.cross_entropy(MLT_logits, word_de_ids) # FM edit: removed other two losses that are not defined outputs.update({ 'relationship_logits': relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None, 'relationship_label': relationship_label if self.config.NETWORK.WITH_REL_LOSS else None, 'mlm_logits_wvc': mlm_logits_wvc if self.config.NETWORK.WITH_MLM_LOSS else None, 'mlm_label_wvc': mlm_labels_wvc if self.config.NETWORK.WITH_MLM_LOSS else None, 'mlm_logits_aux': mlm_logits_aux if self.config.NETWORK.WITH_MLM_LOSS else None, 'mlm_label_aux': mlm_labels_aux if self.config.NETWORK.WITH_MLM_LOSS else None, 'mvrc_logits': mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None, 'mvrc_label': mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None, 'MLT_logits': MLT_logits if self.config.NETWORK.WITH_MLT_LOSS else None, 'MLT_label': word_de_ids if self.config.NETWORK.WITH_MLT_LOSS else None, 'MLT_loss': MLT_loss, }) # FM edit: removed addition of other losses which are not defined loss = MLT_loss.mean() return outputs, loss
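# ---------------------------------------------------------------------------
# Both MLM branches above support `MLM_LOSS_NORM_IN_BATCH_FIRST`, which first
# averages the token losses inside each sample and then averages over the
# samples that contain at least one masked position. A small self-contained
# sketch of that weighting (toy shapes; the 1e-4 terms guard against empty
# samples, as in the code above):
def _demo_batch_first_mlm_norm():
    import torch
    import torch.nn.functional as F
    batch, seq, vocab = 2, 5, 11
    logits = torch.randn(batch, seq, vocab)
    labels = torch.full((batch, seq), -1, dtype=torch.long)
    labels[0, 1] = 7                      # sample 0: one masked token
    labels[1, 2], labels[1, 3] = 3, 9     # sample 1: two masked tokens
    per_tok = F.cross_entropy(logits.transpose(1, 2), labels,
                              ignore_index=-1, reduction='none')
    num_mlm = (labels != -1).sum(1, keepdim=True).to(per_tok.dtype)
    num_has_mlm = (num_mlm != 0).sum().to(per_tok.dtype)
    loss = (per_tok / (num_mlm + 1e-4)).sum() / (num_has_mlm + 1e-4)
    # every sample contributes equally, regardless of how many tokens it masks
    return loss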
class ResNetVLBERTForAttentionVis(Module): def __init__(self, config): super(ResNetVLBERTForAttentionVis, self).__init__(config) self.image_feature_extractor = FastRCNN( config, average_pool=True, final_dim=config.NETWORK.IMAGE_FINAL_DIM, enable_cnn_reg_loss=False) self.object_linguistic_embeddings = nn.Embedding( 1, config.NETWORK.VLBERT.hidden_size) if config.NETWORK.IMAGE_FEAT_PRECOMPUTED or ( not config.NETWORK.MASK_RAW_PIXELS): self.object_mask_visual_embedding = nn.Embedding(1, 2048) if config.NETWORK.WITH_MVRC_LOSS: self.object_mask_word_embedding = nn.Embedding( 1, config.NETWORK.VLBERT.hidden_size) self.aux_text_visual_embedding = nn.Embedding( 1, config.NETWORK.VLBERT.hidden_size) self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN self.tokenizer = BertTokenizer.from_pretrained( config.NETWORK.BERT_MODEL_NAME) language_pretrained_model_path = None if config.NETWORK.BERT_PRETRAINED != '': language_pretrained_model_path = '{}-{:04d}.model'.format( config.NETWORK.BERT_PRETRAINED, config.NETWORK.BERT_PRETRAINED_EPOCH) elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME): weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME, BERT_WEIGHTS_NAME) if os.path.isfile(weight_path): language_pretrained_model_path = weight_path if language_pretrained_model_path is None: print( "Warning: no pretrained language model found, training from scratch!!!" ) self.vlbert = VisualLinguisticBert( config.NETWORK.VLBERT, language_pretrained_model_path=None if config.NETWORK.VLBERT.from_scratch else language_pretrained_model_path) # init weights self.init_weight() self.fix_params() def init_weight(self): if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED or ( not self.config.NETWORK.MASK_RAW_PIXELS): self.object_mask_visual_embedding.weight.data.fill_(0.0) if self.config.NETWORK.WITH_MVRC_LOSS: self.object_mask_word_embedding.weight.data.normal_( mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range) self.aux_text_visual_embedding.weight.data.normal_( mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range) self.image_feature_extractor.init_weight() if self.object_linguistic_embeddings is not None: self.object_linguistic_embeddings.weight.data.normal_( mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range) def train(self, mode=True): super(ResNetVLBERTForAttentionVis, self).train(mode) # turn some frozen layers to eval mode if self.image_feature_bn_eval: self.image_feature_extractor.bn_eval() def fix_params(self): pass def _collect_obj_reps(self, span_tags, object_reps): """ Collect span-level object representations :param span_tags: [batch_size, ..leading_dims.., L] :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim] :return: """ span_tags_fixed = torch.clamp( span_tags, min=0) # In case there were masked values here row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape) row_id_broadcaster = torch.arange(0, row_id.shape[0], step=1, device=row_id.device)[:, None] # Add extra diminsions to the row broadcaster so it matches row_id leading_dims = len(span_tags.shape) - 2 for i in range(leading_dims): row_id_broadcaster = row_id_broadcaster[..., None] row_id += row_id_broadcaster return object_reps[row_id.view(-1), span_tags_fixed.view(-1)].view( *span_tags_fixed.shape, -1) def forward(self, image, boxes, im_info, text, relationship_label, mlm_labels, mvrc_ops, mvrc_labels, *aux): # concat aux texts from different dataset # assert len(aux) > 0 and len(aux) % 2 == 0 aux_text_list = aux[0::2] aux_text_mlm_labels_list = aux[1::2] num_aux_text = sum([_text.shape[0] for _text in 
aux_text_list]) max_aux_text_len = max([_text.shape[1] for _text in aux_text_list ]) if len(aux_text_list) > 0 else 0 aux_text = text.new_zeros((num_aux_text, max_aux_text_len)) aux_text_mlm_labels = mlm_labels.new_zeros( (num_aux_text, max_aux_text_len)).fill_(-1) _cur = 0 for _text, _mlm_labels in zip(aux_text_list, aux_text_mlm_labels_list): _num = _text.shape[0] aux_text[_cur:(_cur + _num), :_text.shape[1]] = _text aux_text_mlm_labels[_cur:(_cur + _num), :_text.shape[1]] = _mlm_labels _cur += _num ########################################### # visual feature extraction images = image box_mask = (boxes[:, :, 0] > -1.5) origin_len = boxes.shape[1] max_len = int(box_mask.sum(1).max().item()) box_mask = box_mask[:, :max_len] boxes = boxes[:, :max_len] mvrc_ops = mvrc_ops[:, :max_len] mvrc_labels = mvrc_labels[:, :max_len] if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED: box_features = boxes[:, :, 4:] box_features[mvrc_ops == 1] = self.object_mask_visual_embedding.weight[0] boxes[:, :, 4:] = box_features obj_reps = self.image_feature_extractor( images=images, boxes=boxes, box_mask=box_mask, im_info=im_info, classes=None, segms=None, mvrc_ops=mvrc_ops, mask_visual_embed=self.object_mask_visual_embedding.weight[0] if (not self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED) and (not self.config.NETWORK.MASK_RAW_PIXELS) else None) ############################################ # prepare text text_input_ids = text text_tags = text.new_zeros(text.shape) text_visual_embeddings = self._collect_obj_reps( text_tags, obj_reps['obj_reps']) object_linguistic_embeddings = self.object_linguistic_embeddings( boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long()) if self.config.NETWORK.WITH_MVRC_LOSS: object_linguistic_embeddings[ mvrc_ops == 1] = self.object_mask_word_embedding.weight[0] object_vl_embeddings = torch.cat( (obj_reps['obj_reps'], object_linguistic_embeddings), -1) # add auxiliary text max_text_len = max(text_input_ids.shape[1], aux_text.shape[1]) text_input_ids_multi = text_input_ids.new_zeros( (text_input_ids.shape[0] + aux_text.shape[0], max_text_len)) text_input_ids_multi[:text_input_ids.shape[0], :text_input_ids. shape[1]] = text_input_ids text_input_ids_multi[ text_input_ids.shape[0]:, :aux_text.shape[1]] = aux_text text_token_type_ids_multi = text_input_ids_multi.new_zeros( text_input_ids_multi.shape) text_mask_multi = (text_input_ids_multi > 0) text_visual_embeddings_multi = text_visual_embeddings.new_zeros( (text_input_ids.shape[0] + aux_text.shape[0], max_text_len, text_visual_embeddings.shape[-1])) text_visual_embeddings_multi[:text_visual_embeddings.shape[0], :text_visual_embeddings.shape[1]] \ = text_visual_embeddings text_visual_embeddings_multi[ text_visual_embeddings. shape[0]:] = self.aux_text_visual_embedding.weight[0] object_vl_embeddings_multi = object_vl_embeddings.new_zeros( (text_input_ids.shape[0] + aux_text.shape[0], *object_vl_embeddings.shape[1:])) object_vl_embeddings_multi[:object_vl_embeddings. 
shape[0]] = object_vl_embeddings box_mask_multi = box_mask.new_zeros( (text_input_ids.shape[0] + aux_text.shape[0], *box_mask.shape[1:])) box_mask_multi[:box_mask.shape[0]] = box_mask ########################################### # Visual Linguistic BERT encoder_layers, _, attention_probs = self.vlbert( text_input_ids_multi, text_token_type_ids_multi, text_visual_embeddings_multi, text_mask_multi, object_vl_embeddings_multi, box_mask_multi, output_all_encoded_layers=True, output_attention_probs=True) hidden_states = torch.stack(encoder_layers, dim=0).transpose(0, 1).contiguous() attention_probs = torch.stack(attention_probs, dim=0).transpose(0, 1).contiguous() return { 'attention_probs': attention_probs, 'hidden_states': hidden_states }
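# ---------------------------------------------------------------------------
# ResNetVLBERTForAttentionVis above returns the stacked per-layer attention
# maps with shape [batch, num_layers, num_heads, seq, seq] (after the
# transpose). A small sketch of one plausible way to reduce such an output for
# visualisation; the shapes used here are hypothetical:
def _demo_reduce_attention(layer=0):
    import torch
    attention_probs = torch.rand(2, 12, 12, 20, 20)  # stand-in for the model output
    head_avg = attention_probs.mean(dim=2)           # average over heads
    return head_avg[:, layer]                        # [batch, seq, seq] map for one layer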
class ResNetVLBERT(Module): def __init__(self, config): super(ResNetVLBERT, self).__init__(config) self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS self.cnn_loss_top = config.NETWORK.CNN_LOSS_TOP self.align_caption_img = config.DATASET.ALIGN_CAPTION_IMG self.use_phrasal_paraphrases = config.DATASET.PHRASE_CLS self.supervise_attention = config.NETWORK.SUPERVISE_ATTENTION self.normalization = config.NETWORK.ATTENTION_NORM_METHOD self.ewc_reg = config.NETWORK.EWC_REG self.importance_hparam = 0. if config.NETWORK.EWC_REG: self.fisher = pickle.load(open(config.NETWORK.FISHER_PATH, "rb")) self.pretrain_param = torch.load(config.NETWORK.PARAM_PRETRAIN) self.importance_hparam = config.NETWORK.EWC_IMPORTANCE if not config.NETWORK.BLIND: self.image_feature_extractor = FastRCNN( config, average_pool=True, final_dim=config.NETWORK.IMAGE_FINAL_DIM, enable_cnn_reg_loss=(self.enable_cnn_reg_loss and not self.cnn_loss_top)) if config.NETWORK.VLBERT.object_word_embed_mode == 1: self.object_linguistic_embeddings = nn.Embedding( 81, config.NETWORK.VLBERT.hidden_size) elif config.NETWORK.VLBERT.object_word_embed_mode == 2: self.object_linguistic_embeddings = nn.Embedding( 1, config.NETWORK.VLBERT.hidden_size) elif config.NETWORK.VLBERT.object_word_embed_mode == 3: self.object_linguistic_embeddings = None else: raise NotImplementedError self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN if 'roberta' in config.NETWORK.BERT_MODEL_NAME: self.tokenizer = RobertaTokenizer.from_pretrained( config.NETWORK.BERT_MODEL_NAME) else: self.tokenizer = BertTokenizer.from_pretrained( config.NETWORK.BERT_MODEL_NAME) language_pretrained_model_path = None if config.NETWORK.BERT_PRETRAINED != '': language_pretrained_model_path = '{}-{:04d}.model'.format( config.NETWORK.BERT_PRETRAINED, config.NETWORK.BERT_PRETRAINED_EPOCH) elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME): weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME, BERT_WEIGHTS_NAME) if os.path.isfile(weight_path): language_pretrained_model_path = weight_path if language_pretrained_model_path is None: print( "Warning: no pretrained language model found, training from scratch!!!" 
) self.vlbert = VisualLinguisticBert( config.NETWORK.VLBERT, language_pretrained_model_path=language_pretrained_model_path) self.for_pretrain = False dim = config.NETWORK.VLBERT.hidden_size if self.align_caption_img: sentence_logits_shape = 3 else: sentence_logits_shape = 1 if config.NETWORK.SENTENCE.CLASSIFIER_TYPE == "2fc": self.sentence_cls = torch.nn.Sequential( torch.nn.Dropout(config.NETWORK.SENTENCE.CLASSIFIER_DROPOUT, inplace=False), torch.nn.Linear( dim, config.NETWORK.SENTENCE.CLASSIFIER_HIDDEN_SIZE), torch.nn.ReLU(inplace=True), torch.nn.Dropout(config.NETWORK.SENTENCE.CLASSIFIER_DROPOUT, inplace=False), torch.nn.Linear(config.NETWORK.SENTENCE.CLASSIFIER_HIDDEN_SIZE, sentence_logits_shape), ) elif config.NETWORK.SENTENCE.CLASSIFIER_TYPE == "1fc": self.sentence_cls = torch.nn.Sequential( torch.nn.Dropout(config.NETWORK.SENTENCE.CLASSIFIER_DROPOUT, inplace=False), torch.nn.Linear(dim, sentence_logits_shape)) else: raise ValueError("Classifier type: {} not supported!".format( config.NETWORK.SENTENCE.CLASSIFIER_TYPE)) if self.use_phrasal_paraphrases: if config.NETWORK.PHRASE.CLASSIFIER_TYPE == "2fc": self.phrasal_cls = torch.nn.Sequential( torch.nn.Dropout(config.NETWORK.PHRASE.CLASSIFIER_DROPOUT, inplace=False), torch.nn.Linear( 4 * dim, config.NETWORK.PHRASE.CLASSIFIER_HIDDEN_SIZE), torch.nn.ReLU(inplace=True), torch.nn.Dropout(config.NETWORK.PHRASE.CLASSIFIER_DROPOUT, inplace=False), torch.nn.Linear( config.NETWORK.PHRASE.CLASSIFIER_HIDDEN_SIZE, 5), ) elif config.NETWORK.PHRASE.CLASSIFIER_TYPE == "1fc": self.phrasal_cls = torch.nn.Sequential( torch.nn.Dropout(config.NETWORK.PHRASE.CLASSIFIER_DROPOUT, inplace=False), torch.nn.Linear(4 * dim, 5)) else: raise ValueError("Classifier type: {} not supported!".format( config.NETWORK.PHRASE.CLASSIFIER_TYPE)) if self.supervise_attention == "indirect": if config.NETWORK.VG.CLASSIFIER_TYPE == "2fc": self.vg_cls = torch.nn.Sequential( torch.nn.Dropout(config.NETWORK.VG.CLASSIFIER_DROPOUT, inplace=False), torch.nn.Linear(2 * dim, config.NETWORK.VG.CLASSIFIER_HIDDEN_SIZE), torch.nn.ReLU(inplace=True), torch.nn.Dropout(config.NETWORK.VG.CLASSIFIER_DROPOUT, inplace=False), torch.nn.Linear(config.NETWORK.VG.CLASSIFIER_HIDDEN_SIZE, 1), ) elif config.NETWORK.VG.CLASSIFIER_TYPE == "1fc": self.vg_cls = torch.nn.Sequential( torch.nn.Dropout(config.NETWORK.VG.CLASSIFIER_DROPOUT, inplace=False), torch.nn.Linear(2 * dim, 1)) else: raise ValueError("Classifier type: {} not supported!".format( config.NETWORK.PHRASE.CLASSIFIER_TYPE)) # init weights self.init_weight() self.fix_params() def init_weight(self): if not self.config.NETWORK.BLIND: self.image_feature_extractor.init_weight() if self.object_linguistic_embeddings is not None: self.object_linguistic_embeddings.weight.data.normal_(mean=0.0, std=0.02) if not self.for_pretrain: for m in self.sentence_cls.modules(): if isinstance(m, torch.nn.Linear): torch.nn.init.xavier_uniform_(m.weight) torch.nn.init.constant_(m.bias, 0) def train(self, mode=True): super(ResNetVLBERT, self).train(mode) # turn some frozen layers to eval mode if (not self.config.NETWORK.BLIND) and self.image_feature_bn_eval: self.image_feature_extractor.bn_eval() def fix_params(self): if self.config.NETWORK.BLIND: self.vlbert._module.visual_scale_text.requires_grad = False self.vlbert._module.visual_scale_object.requires_grad = False def _collect_obj_reps(self, span_tags, object_reps): """ Collect span-level object representations :param span_tags: [batch_size, ..leading_dims.., L] :param object_reps: [batch_size, max_num_objs_per_batch, 
obj_dim] :return: """ span_tags_fixed = torch.clamp( span_tags, min=0) # In case there were masked values here row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape) row_id_broadcaster = torch.arange(0, row_id.shape[0], step=1, device=row_id.device)[:, None] # Add extra dimensions to the row broadcaster so it matches row_id leading_dims = len(span_tags.shape) - 2 for i in range(leading_dims): row_id_broadcaster = row_id_broadcaster[..., None] row_id += row_id_broadcaster return object_reps[row_id.view(-1), span_tags_fixed.view(-1)].view( *span_tags_fixed.shape, -1) def prepare_text(self, sentence1, sentence2, mask1, mask2, sentence1_tags, sentence2_tags, phrase1_mask, phrase2_mask): batch_size, max_len1 = sentence1.shape _, max_len2 = sentence2.shape max_len = (mask1.sum(1) + mask2.sum(1)).max() + 3 cls_id, sep_id = self.tokenizer.convert_tokens_to_ids( ['[CLS]', '[SEP]']) end_1 = 1 + mask1.sum(1, keepdim=True) end_2 = end_1 + 1 + mask2.sum(1, keepdim=True) input_ids = torch.zeros((batch_size, max_len), dtype=sentence1.dtype, device=sentence1.device) input_mask = torch.ones((batch_size, max_len), dtype=torch.uint8, device=sentence1.device) input_type_ids = torch.zeros((batch_size, max_len), dtype=sentence1.dtype, device=sentence1.device) text_tags = input_type_ids.new_zeros((batch_size, max_len)) phr_mask = None grid_i, grid_k = torch.meshgrid( torch.arange(batch_size, device=sentence1.device), torch.arange(max_len, device=sentence1.device)) input_mask[grid_k > end_2] = 0 input_type_ids[(grid_k > end_1) & (grid_k <= end_2)] = 1 input_mask1 = (grid_k > 0) & (grid_k < end_1) input_mask2 = (grid_k > end_1) & (grid_k < end_2) input_ids[:, 0] = cls_id input_ids[grid_k == end_1] = sep_id input_ids[grid_k == end_2] = sep_id input_ids[input_mask1] = sentence1[mask1] input_ids[input_mask2] = sentence2[mask2] text_tags[input_mask1] = sentence1_tags[mask1] text_tags[input_mask2] = sentence2_tags[mask2] if self.use_phrasal_paraphrases: phr_mask = phrase1_mask.new_zeros( (batch_size, max_len, phrase1_mask.size(-1))) phr_mask[input_mask1] = phrase1_mask[mask1] phr_mask[input_mask2] = phrase2_mask[mask2] # add offsets so that every pair of phrases gets a unique id in the batch no_phr_mask = (phr_mask == 0) n_phr = torch.max(phr_mask, dim=1)[0] offsets = phr_mask.new_zeros( (phr_mask.size(0) * phr_mask.size(-1))) offsets[1:] = torch.cumsum(n_phr.view(-1)[:-1], dim=0) offsets = offsets.view((phr_mask.size(0), phr_mask.size(-1))) phr_mask += offsets.unsqueeze(1) phr_mask[no_phr_mask] = 0 return input_ids, input_type_ids, text_tags, input_mask, phr_mask def train_forward(self, images, boxes, sentence1, sentence2, im_info, label): ########################################### # visual feature extraction box_mask = (boxes[:, :, -1] > -0.5) max_len = int(box_mask.sum(1).max().item()) box_mask = box_mask[:, :max_len] boxes = boxes[:, :max_len].type(torch.float32) # segms = segms[:, :max_len] if self.config.NETWORK.BLIND: obj_reps = { 'obj_reps': boxes.new_zeros( (*boxes.shape[:-1], self.config.NETWORK.IMAGE_FINAL_DIM)) } else: obj_reps = self.image_feature_extractor(images=images, boxes=boxes, box_mask=box_mask, im_info=im_info, classes=None, segms=None) sentence1_ids = sentence1[:, :, 0] mask1 = (sentence1[:, :, 0] > 0.5) sentence1_tags = sentence1[:, :, 1] sentence2_ids = sentence2[:, :, 0] mask2 = (sentence2[:, :, 0] > 0.5) sentence2_tags = sentence2[:, :, 1] if self.use_phrasal_paraphrases: phrase1_mask = sentence1[:, :, 2:] phrase2_mask = sentence2[:, :, 2:] sentence_label = label[:, 0, 0].view(-1) 
phrase_labels = label[:, :, 1] else: phrase1_mask, phrase2_mask = None, None sentence_label = label.view(-1) ############################################ # prepare text text_input_ids, text_token_type_ids, text_tags, text_mask, phrase_mask = self.prepare_text( sentence1_ids, sentence2_ids, mask1, mask2, sentence1_tags, sentence2_tags, phrase1_mask, phrase2_mask) # Add visual feature to text elements if self.config.NETWORK.NO_GROUNDING: text_visual_embeddings = self._collect_obj_reps( text_tags.new_zeros(text_tags.size()), obj_reps['obj_reps']) else: text_visual_embeddings = self._collect_obj_reps( text_tags, obj_reps['obj_reps']) # Add textual feature to image element if self.config.NETWORK.BLIND: object_linguistic_embeddings = boxes.new_zeros( (*boxes.shape[:-1], self.config.NETWORK.VLBERT.hidden_size)) else: object_linguistic_embeddings = self.object_linguistic_embeddings( boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long()) object_vl_embeddings = torch.cat( (obj_reps['obj_reps'], object_linguistic_embeddings), -1) ########################################### # Visual Linguistic BERT if self.config.NETWORK.NO_OBJ_ATTENTION or self.config.NETWORK.BLIND: box_mask.zero_() if self.supervise_attention in ["direct", "semi-direct"]: hidden_states_text, hidden_states_objects, pooled_rep, attention_probs = \ self.vlbert(text_input_ids, text_token_type_ids, text_visual_embeddings, text_mask, object_vl_embeddings, box_mask, output_all_encoded_layers=False, output_text_and_object_separately=True, output_attention_probs=True) else: hidden_states_text, hidden_states_objects, pooled_rep = self.vlbert( text_input_ids, text_token_type_ids, text_visual_embeddings, text_mask, object_vl_embeddings, box_mask, output_all_encoded_layers=False, output_text_and_object_separately=True, output_attention_probs=False) ########################################### outputs = {} # sentence classification sentence_logits = self.sentence_cls(pooled_rep) if self.align_caption_img: sentence_logits = sentence_logits.view((-1, 3)) sentence_cls_loss = F.cross_entropy(sentence_logits, sentence_label) else: sentence_logits = sentence_logits.view(-1) sentence_cls_loss = F.binary_cross_entropy_with_logits( sentence_logits, sentence_label.type(torch.float32)) outputs.update({ 'sentence_label_logits': sentence_logits, 'sentence_label': sentence_label.long(), 'sentence_cls_loss': sentence_cls_loss }) # phrasal paraphrases classification phrase_cls_loss = sentence_logits.new_zeros(()) if self.use_phrasal_paraphrases: phrase_labels = phrase_labels.view((-1)) phrase_cls_logits = sentence_logits.new_zeros( (phrase_labels.size(0), 5)) outputs.update({ "phrase_label": phrase_labels, "phrase_label_logits": phrase_cls_logits, "phrase_cls_loss": phrase_cls_loss }) if phrase_mask.max() > 0: logits = self.get_phrase_cls(hidden_states_text, phrase_mask, text_token_type_ids) phrase_cls_loss = F.cross_entropy( logits, phrase_labels[phrase_labels > -1], reduction="mean") phrase_cls_logits[(phrase_labels > -1)] = logits outputs.update({ "phrase_label_logits": phrase_cls_logits, "phrase_cls_loss": phrase_cls_loss }) # Handle attention supervision, suffix 1 refers to text-to-roi attention and suffix 2 refers to roi-to-text attention_loss = 0. 
if self.supervise_attention in ["direct", "semi-direct"]: use_raw = self.supervise_attention == "direct" attention_loss_1, attention_loss_2 = get_attention_supervision_loss( attention_probs, text_tags, text_mask, box_mask, use_raw=use_raw, normalization=self.normalization) outputs.update({ "attention_loss_1": attention_loss_1, "attention_loss_2": attention_loss_2 }) attention_loss = attention_loss_1 + attention_loss_2 elif self.supervise_attention == "indirect": attention_loss = self.get_indirect_vg_loss(hidden_states_text, hidden_states_objects, text_tags, text_mask, box_mask) outputs.update({"vg_loss": attention_loss}) # EWC regularization loss against catastrophic forgetting ewc_loss = 0. if self.ewc_reg: for n, p in self.named_parameters(): name = "module." + n if name in self.fisher.keys(): ewc_loss += ( self.fisher[name].to(p.device) * (p - self.pretrain_param[name].to(p.device))**2).sum() outputs.update({"ewc_loss": ewc_loss}) loss = sentence_cls_loss.mean() + self.config.NETWORK.PHRASE_LOSS_WEIGHT * phrase_cls_loss + \ self.config.NETWORK.ATTENTION_LOSS_WEIGHT * attention_loss + self.importance_hparam * ewc_loss return outputs, loss def inference_forward(self, images, boxes, sentence1, sentence2, im_info): ########################################### # visual feature extraction # For now use all boxes box_mask = torch.ones(boxes[:, :, -1].size(), dtype=torch.uint8, device=boxes.device) max_len = int(box_mask.sum(1).max().item()) box_mask = box_mask[:, :max_len] boxes = boxes[:, :max_len].type(torch.float32) if self.config.NETWORK.BLIND: obj_reps = { 'obj_reps': boxes.new_zeros( (*boxes.shape[:-1], self.config.NETWORK.IMAGE_FINAL_DIM)) } else: obj_reps = self.image_feature_extractor(images=images, boxes=boxes, box_mask=box_mask, im_info=im_info, classes=None, segms=None) # For now no tags sentence1_ids = sentence1[:, :, 0] mask1 = (sentence1[:, :, 0] > 0.5) sentence1_tags = sentence1[:, :, 1] sentence2_ids = sentence2[:, :, 0] mask2 = (sentence2[:, :, 0] > 0.5) sentence2_tags = sentence2[:, :, 1] if self.use_phrasal_paraphrases: phrase1_mask = sentence1[:, :, 2:] phrase2_mask = sentence2[:, :, 2:] else: phrase1_mask, phrase2_mask = None, None ############################################ # prepare text text_input_ids, text_token_type_ids, text_tags, text_mask, phrase_mask = self.prepare_text( sentence1_ids, sentence2_ids, mask1, mask2, sentence1_tags, sentence2_tags, phrase1_mask, phrase2_mask) # Add visual feature to text elements if self.config.NETWORK.NO_GROUNDING: text_tags.zero_() text_visual_embeddings = self._collect_obj_reps( text_tags, obj_reps['obj_reps']) # Add textual feature to image element if self.config.NETWORK.BLIND: object_linguistic_embeddings = boxes.new_zeros( (*boxes.shape[:-1], self.config.NETWORK.VLBERT.hidden_size)) else: object_linguistic_embeddings = self.object_linguistic_embeddings( boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long()) object_vl_embeddings = torch.cat( (obj_reps['obj_reps'], object_linguistic_embeddings), -1) ########################################### # Visual Linguistic BERT if self.config.NETWORK.NO_OBJ_ATTENTION or self.config.NETWORK.BLIND: box_mask.zero_() hidden_states_text, hidden_states_objects, pooled_rep = self.vlbert( text_input_ids, text_token_type_ids, text_visual_embeddings, text_mask, object_vl_embeddings, box_mask, output_all_encoded_layers=False, output_text_and_object_separately=True) ########################################### outputs = {} # sentence classification sentence_logits = self.sentence_cls(pooled_rep) if 
self.align_caption_img: sentence_logits = sentence_logits.view((-1, 3)) else: sentence_logits = sentence_logits.view(-1) outputs.update({'sentence_label_logits': sentence_logits}) if self.use_phrasal_paraphrases: phrase_cls_logits = sentence_logits.new_zeros((1, 5)) + 100000 outputs.update({"phrase_label_logits": phrase_cls_logits}) if phrase_mask.max() > 0: phrase_cls_logits = self.get_phrase_cls( hidden_states_text, phrase_mask, text_token_type_ids) outputs.update({"phrase_label_logits": phrase_cls_logits}) return outputs def get_phrase_cls(self, encoded_rep, phr_mask, token_type): n_pairs = phr_mask.max().item() phr_reps = encoded_rep.new_zeros((n_pairs, 2, encoded_rep.size(-1))) for i in range(n_pairs): # max pool representation of first phrase shaped_phr_mask = (phr_mask == i + 1).any(2) phr_reps[i, 0] = encoded_rep[(token_type == 0) & shaped_phr_mask].max(dim=0)[0] # max pool representation of second phrase phr_reps[i, 1] = encoded_rep[(token_type == 1) & shaped_phr_mask].max(dim=0)[0] final_phrases_rep = torch.cat( (phr_reps[:, 0], phr_reps[:, 1], torch.abs(phr_reps[:, 0] - phr_reps[:, 1]), torch.mul(phr_reps[:, 0], phr_reps[:, 1])), dim=1) output_logits = self.phrasal_cls(final_phrases_rep) return output_logits def get_indirect_vg_loss(self, encoded_text, encoded_objects, text_tags, text_mask, box_mask): if text_tags.max() <= 0: return encoded_text.new_zeros((1)).sum() else: vg_inputs = [] vg_labels = [] indexes = find_phrases(text_tags) for i, k, length, tag in indexes: phrases_rep = encoded_text[i, k:k + length].max( dim=0)[0] # max pool encoding of the words in the phrase objects_reps = encoded_objects[i][box_mask[i]][1:] vg_inputs.append( torch.cat((phrases_rep.unsqueeze(0).repeat( len(objects_reps), 1), objects_reps), dim=1)) vg_lbl = text_tags.new_zeros((len(objects_reps))) vg_lbl[tag - 1] = 1 vg_labels.append(vg_lbl) vg_inputs = torch.cat(vg_inputs, dim=0) vg_labels = torch.cat(vg_labels, dim=0) vg_logits = self.vg_cls(vg_inputs).view(-1) vg_loss = F.binary_cross_entropy_with_logits( vg_logits, vg_labels.float()) return vg_loss
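# ---------------------------------------------------------------------------
# The ewc_loss term in train_forward above is a standard diagonal-Fisher EWC
# (elastic weight consolidation) penalty. A self-contained sketch of the same
# quadratic penalty on a toy module; the fisher/anchor dictionaries here are
# hypothetical stand-ins for the pickled NETWORK.FISHER_PATH / PARAM_PRETRAIN
# files loaded in __init__:
def _demo_ewc_penalty():
    import torch
    import torch.nn as nn
    model = nn.Linear(4, 2)
    # anchor parameters and per-parameter importance from a previous task
    anchor = {n: p.detach().clone() for n, p in model.named_parameters()}
    fisher = {n: torch.ones_like(p) for n, p in model.named_parameters()}
    ewc_loss = sum((fisher[n] * (p - anchor[n]) ** 2).sum()
                   for n, p in model.named_parameters())
    # the drift of each weight is scaled by its estimated importance
    return ewc_loss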
class ResNetVLBERTForPretrainingEncDec(Module): def __init__(self, config): super(ResNetVLBERTForPretrainingEncDec, self).__init__(config) self.image_feature_extractor = FastRCNN( config, average_pool=True, final_dim=config.NETWORK.IMAGE_FINAL_DIM, enable_cnn_reg_loss=False) self.object_linguistic_embeddings = nn.Embedding( 1, config.NETWORK.VLBERT.hidden_size) if config.NETWORK.IMAGE_FEAT_PRECOMPUTED: self.object_mask_visual_embedding = nn.Embedding(1, 2048) if config.NETWORK.WITH_MVRC_LOSS: self.object_mask_word_embedding = nn.Embedding( 1, config.NETWORK.VLBERT.hidden_size) self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN self.tokenizer = BertTokenizer.from_pretrained( config.NETWORK.BERT_MODEL_NAME) language_pretrained_model_path = None if config.NETWORK.BERT_PRETRAINED != '': language_pretrained_model_path = '{}-{:04d}.model'.format( config.NETWORK.BERT_PRETRAINED, config.NETWORK.BERT_PRETRAINED_EPOCH) elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME): weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME, BERT_WEIGHTS_NAME) if os.path.isfile(weight_path): language_pretrained_model_path = weight_path if language_pretrained_model_path is None: print( "Warning: no pretrained language model found, training from scratch!!!" ) self.vlbert = VisualLinguisticBertEncoder( config.NETWORK.VLBERT, language_pretrained_model_path=None if config.NETWORK.VLBERT.from_scratch else language_pretrained_model_path, with_rel_head=False, with_mlm_head=False, with_mvrc_head=False, ) # FM edit: add decoder self.decoder = VisualLinguisticBertForPretrainingDecoder( config.NETWORK.VLBERT, language_pretrained_model_path=None if config.NETWORK.VLBERT.from_scratch else language_pretrained_model_path, with_rel_head=config.NETWORK.WITH_REL_LOSS, with_mlm_head=config.NETWORK.WITH_MLM_LOSS, with_mvrc_head=config.NETWORK.WITH_MVRC_LOSS, ) # init weights self.init_weight() self.fix_params() def init_weight(self): if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED: self.object_mask_visual_embedding.weight.data.fill_(0.0) if self.config.NETWORK.WITH_MVRC_LOSS: self.object_mask_word_embedding.weight.data.normal_( mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range) self.image_feature_extractor.init_weight() if self.object_linguistic_embeddings is not None: self.object_linguistic_embeddings.weight.data.normal_( mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range) def train(self, mode=True): super(ResNetVLBERTForPretrainingEncDec, self).train(mode) # turn some frozen layers to eval mode if self.image_feature_bn_eval: self.image_feature_extractor.bn_eval() def fix_params(self): pass def _collect_obj_reps(self, span_tags, object_reps): """ Collect span-level object representations :param span_tags: [batch_size, ..leading_dims.., L] :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim] :return: """ span_tags_fixed = torch.clamp( span_tags, min=0) # In case there were masked values here row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape) row_id_broadcaster = torch.arange(0, row_id.shape[0], step=1, device=row_id.device)[:, None] # Add extra diminsions to the row broadcaster so it matches row_id leading_dims = len(span_tags.shape) - 2 for i in range(leading_dims): row_id_broadcaster = row_id_broadcaster[..., None] row_id += row_id_broadcaster return object_reps[row_id.view(-1), span_tags_fixed.view(-1)].view( *span_tags_fixed.shape, -1) def forward(self, image, boxes, im_info, text_en, text_de, relationship_label, mlm_labels_en, mlm_labels_de, mvrc_ops, mvrc_labels): 
###########################################
        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        origin_len = boxes.shape[1]
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        mvrc_ops = mvrc_ops[:, :max_len]
        mvrc_labels = mvrc_labels[:, :max_len]

        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            box_features = boxes[:, :, 4:]
            box_features[mvrc_ops == 1] = self.object_mask_visual_embedding.weight[0]
            boxes[:, :, 4:] = box_features

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None,
                                                mvrc_ops=mvrc_ops,
                                                mask_visual_embed=None)

        ############################################
        # prepare text - English (encoder side)
        text_input_ids = text_en
        text_tags = text_en.new_zeros(text_en.shape)
        text_token_type_ids = text_en.new_zeros(text_en.shape)
        text_mask = (text_input_ids > 0)
        text_visual_embeddings = self._collect_obj_reps(text_tags, obj_reps['obj_reps'])
        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        if self.config.NETWORK.WITH_MVRC_LOSS:
            object_linguistic_embeddings[mvrc_ops == 1] = self.object_mask_word_embedding.weight[0]
        object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ############################################
        # prepare text - German (decoder side)
        text_input_ids_de = text_de
        text_tags_de = text_de.new_zeros(text_de.shape)
        text_token_type_ids_de = text_de.new_zeros(text_de.shape)
        text_mask_de = (text_input_ids_de > 0)
        text_visual_embeddings_de = self._collect_obj_reps(text_tags_de, obj_reps['obj_reps'])
        object_linguistic_embeddings_de = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        if self.config.NETWORK.WITH_MVRC_LOSS:
            # bugfix: __init__ only defines a single shared mask-word embedding
            # (there is no object_mask_word_embedding_de), so reuse it here
            object_linguistic_embeddings_de[mvrc_ops == 1] = self.object_mask_word_embedding.weight[0]
        object_vl_embeddings_de = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings_de), -1)

        ###########################################
        # Visual Linguistic BERT - Encoder
        relationship_logits_en, mlm_logits_en, mvrc_logits_en, encoder_hidden_states = self.vlbert(
            text_input_ids,
            text_token_type_ids,
            text_visual_embeddings,
            text_mask,
            object_vl_embeddings,
            box_mask)

        ###########################################
        # Visual Linguistic BERT - Decoder
        relationship_logits, mlm_logits, mvrc_logits = self.decoder(
            text_input_ids_de,
            text_token_type_ids_de,
            text_visual_embeddings_de,
            text_mask_de,
            object_vl_embeddings_de,
            box_mask,
            encoder_hidden_states)

        ###########################################
        outputs = {}

        # loss
        relationship_loss = im_info.new_zeros(())
        mlm_loss = im_info.new_zeros(())
        mvrc_loss = im_info.new_zeros(())
        if self.config.NETWORK.WITH_REL_LOSS:
            relationship_loss = F.cross_entropy(relationship_logits, relationship_label)
        if self.config.NETWORK.WITH_MLM_LOSS:
            mlm_logits_padded = mlm_logits.new_zeros(
                (*mlm_labels_de.shape, mlm_logits.shape[-1])).fill_(-10000.0)
            mlm_logits_padded[:, :mlm_logits.shape[1]] = mlm_logits
            mlm_logits = mlm_logits_padded
            if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
                mlm_loss = F.cross_entropy(mlm_logits.transpose(1, 2),
                                           mlm_labels_de,
                                           ignore_index=-1,
                                           reduction='none')
                num_mlm = (mlm_labels_de != -1).sum(1, keepdim=True).to(dtype=mlm_loss.dtype)
                num_has_mlm = (num_mlm != 0).sum().to(dtype=mlm_loss.dtype)
                mlm_loss = (mlm_loss / (num_mlm + 1e-4)).sum() / (num_has_mlm + 1e-4)
            else:
                mlm_loss = F.cross_entropy(mlm_logits.view((-1, mlm_logits.shape[-1])),
mlm_labels_de.view(-1), ignore_index=-1) # mvrc_loss = F.cross_entropy(mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]), # mvrc_labels.contiguous().view(-1), # ignore_index=-1) if self.config.NETWORK.WITH_MVRC_LOSS: if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST: mvrc_loss = soft_cross_entropy( mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]), mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]), reduction='none').view(mvrc_logits.shape[:-1]) valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1 mvrc_loss = (mvrc_loss / (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \ .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4) else: mvrc_loss = soft_cross_entropy( mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]), mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1])) mvrc_logits_padded = mvrc_logits.new_zeros( (mvrc_logits.shape[0], origin_len, mvrc_logits.shape[2])).fill_(-10000.0) mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits mvrc_logits = mvrc_logits_padded mvrc_labels_padded = mvrc_labels.new_zeros( (mvrc_labels.shape[0], origin_len, mvrc_labels.shape[2])).fill_(0.0) mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels mvrc_labels = mvrc_labels_padded outputs.update({ 'relationship_logits': relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None, 'relationship_label': relationship_label if self.config.NETWORK.WITH_REL_LOSS else None, 'mlm_logits': mlm_logits if self.config.NETWORK.WITH_MLM_LOSS else None, 'mlm_label': mlm_labels_de if self.config.NETWORK.WITH_MLM_LOSS else None, 'mvrc_logits': mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None, 'mvrc_label': mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None, 'relationship_loss': relationship_loss, 'mlm_loss': mlm_loss, 'mvrc_loss': mvrc_loss, }) loss = relationship_loss.mean() + mlm_loss.mean() + mvrc_loss.mean() return outputs, loss
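# ---------------------------------------------------------------------------
# soft_cross_entropy is called throughout this file with soft (distribution)
# targets and an optional reduction argument, but its definition is not part
# of this excerpt. A minimal implementation consistent with those call sites
# would look like the sketch below (an assumption, not necessarily the repo's
# exact helper):
def soft_cross_entropy_sketch(logits, soft_targets, reduction='mean'):
    import torch.nn.functional as F
    # per-row cross-entropy between a soft target distribution and the predictions;
    # rows whose target distribution is all zeros contribute zero loss, which is
    # why the callers weight by a `valid` mask
    loss = -(soft_targets * F.log_softmax(logits, dim=-1)).sum(-1)
    if reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    return loss  # reduction='none'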
class ResNetVLBERT(Module): def __init__(self, config): super(ResNetVLBERT, self).__init__(config) self.predict_on_cls = config.NETWORK.VLBERT.predict_on_cls # make prediction on [CLS]? self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS if not config.NETWORK.BLIND: self.image_feature_extractor = FastRCNN( config, average_pool=True, final_dim=config.NETWORK.IMAGE_FINAL_DIM, enable_cnn_reg_loss=self.enable_cnn_reg_loss) if config.NETWORK.VLBERT.object_word_embed_mode == 1: self.object_linguistic_embeddings = nn.Embedding( 81, config.NETWORK.VLBERT.hidden_size) elif config.NETWORK.VLBERT.object_word_embed_mode == 2: # default: class-agnostic self.object_linguistic_embeddings = nn.Embedding( 1, config.NETWORK.VLBERT.hidden_size) elif config.NETWORK.VLBERT.object_word_embed_mode == 3: self.object_linguistic_embeddings = None else: raise NotImplementedError self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN self.tokenizer = BertTokenizer.from_pretrained( config.NETWORK.BERT_MODEL_NAME) language_pretrained_model_path = None if config.NETWORK.BERT_PRETRAINED != '': language_pretrained_model_path = '{}-{:04d}.model'.format( config.NETWORK.BERT_PRETRAINED, config.NETWORK.BERT_PRETRAINED_EPOCH) elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME): weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME, BERT_WEIGHTS_NAME) if os.path.isfile(weight_path): language_pretrained_model_path = weight_path self.language_pretrained_model_path = language_pretrained_model_path if language_pretrained_model_path is None: print( "Warning: no pretrained language model found, training from scratch!!!" ) self.vlbert = VisualLinguisticBert( config.NETWORK.VLBERT, language_pretrained_model_path=language_pretrained_model_path) dim = config.NETWORK.VLBERT.hidden_size if config.NETWORK.CLASSIFIER_TYPE == "2fc": self.final_mlp = torch.nn.Sequential( torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False), torch.nn.Linear(dim, config.NETWORK.CLASSIFIER_HIDDEN_SIZE), torch.nn.ReLU(inplace=True), torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False), torch.nn.Linear(config.NETWORK.CLASSIFIER_HIDDEN_SIZE, config.DATASET.ANSWER_VOCAB_SIZE), ) elif config.NETWORK.CLASSIFIER_TYPE == "1fc": self.final_mlp = torch.nn.Sequential( torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False), torch.nn.Linear(dim, config.DATASET.ANSWER_VOCAB_SIZE)) elif config.NETWORK.CLASSIFIER_TYPE == 'mlm': transform = BertPredictionHeadTransform(config.NETWORK.VLBERT) linear = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.DATASET.ANSWER_VOCAB_SIZE) self.final_mlp = nn.Sequential( transform, nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False), linear) else: raise ValueError("Not support classifier type: {}!".format( config.NETWORK.CLASSIFIER_TYPE)) self.use_spatial_model = False if config.NETWORK.USE_SPATIAL_MODEL: self.use_spatial_model = True # self.simple_spatial_model = SimpleSpatialModel(4, config.NETWORK.VLBERT.hidden_size, 9, config) self.use_coord_vector = False if config.NETWORK.USE_COORD_VECTOR: self.use_coord_vector = True self.loc_fcs = nn.Sequential( nn.Linear(2 * 5 + 9, config.NETWORK.VLBERT.hidden_size), nn.ReLU(True), nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size)) else: self.simple_spatial_model = SimpleSpatialModel( 4, config.NETWORK.VLBERT.hidden_size, 9) self.spa_add = True if config.NETWORK.SPA_ADD else False self.spa_concat = True if config.NETWORK.SPA_CONCAT else False if self.spa_add: self.spa_feat_weight = 0.5 if 
config.NETWORK.USE_SPA_WEIGHT: self.spa_feat_weight = config.NETWORK.SPA_FEAT_WEIGHT self.spa_fusion_linear = nn.Linear( config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size) elif self.spa_concat: if self.use_coord_vector: self.spa_fusion_linear = nn.Linear( config.NETWORK.VLBERT.hidden_size + config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size) else: self.spa_fusion_linear = nn.Linear( config.NETWORK.VLBERT.hidden_size * 2, config.NETWORK.VLBERT.hidden_size) self.spa_linear = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size) self.dropout = nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT) self.spa_one_more_layer = config.NETWORK.SPA_ONE_MORE_LAYER if self.spa_one_more_layer: self.spa_linear_hidden = nn.Linear( config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size) self.enhanced_img_feature = False if config.NETWORK.VLBERT.ENHANCED_IMG_FEATURE: self.enhanced_img_feature = True self.mask_weight = config.NETWORK.VLBERT.mask_weight self.mask_loss_sum = config.NETWORK.VLBERT.mask_loss_sum self.mask_loss_mse = config.NETWORK.VLBERT.mask_loss_mse self.no_predicate = config.NETWORK.VLBERT.NO_PREDICATE self.all_proposals_test = False if config.DATASET.ALL_PROPOSALS_TEST: self.all_proposals_test = True self.use_uvtranse = False if config.NETWORK.USE_UVTRANSE: self.use_uvtranse = True self.union_vec_fc = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size) self.uvt_add = True if config.NETWORK.UVT_ADD else False self.uvt_concat = True if config.NETWORK.UVT_CONCAT else False if not (self.uvt_add ^ self.uvt_concat): assert False if self.uvt_add: self.uvt_feat_weight = config.NETWORK.UVT_FEAT_WEIGHT self.uvt_fusion_linear = nn.Linear( config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size) elif self.uvt_concat: self.uvt_fusion_linear = nn.Linear( config.NETWORK.VLBERT.hidden_size * 2, config.NETWORK.VLBERT.hidden_size) self.uvt_linear = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size) self.dropout_uvt = nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT) # init weights self.init_weight() def init_weight(self): # self.hm_out.weight.data.normal_(mean=0.0, std=0.02) # self.hm_out.bias.data.zero_() # self.hi_out.weight.data.normal_(mean=0.0, std=0.02) # self.hi_out.bias.data.zero_() self.image_feature_extractor.init_weight() if self.object_linguistic_embeddings is not None: self.object_linguistic_embeddings.weight.data.normal_(mean=0.0, std=0.02) for m in self.final_mlp.modules(): if isinstance(m, torch.nn.Linear): torch.nn.init.xavier_uniform_(m.weight) torch.nn.init.constant_(m.bias, 0) if self.config.NETWORK.CLASSIFIER_TYPE == 'mlm': language_pretrained = torch.load( self.language_pretrained_model_path) mlm_transform_state_dict = {} pretrain_keys = [] for k, v in language_pretrained.items(): if k.startswith('cls.predictions.transform.'): pretrain_keys.append(k) k_ = k[len('cls.predictions.transform.'):] if 'gamma' in k_: k_ = k_.replace('gamma', 'weight') if 'beta' in k_: k_ = k_.replace('beta', 'bias') mlm_transform_state_dict[k_] = v print("loading pretrained classifier transform keys: {}.".format( pretrain_keys)) self.final_mlp[0].load_state_dict(mlm_transform_state_dict) def train(self, mode=True): super(ResNetVLBERT, self).train(mode) # turn some frozen layers to eval mode if self.image_feature_bn_eval: self.image_feature_extractor.bn_eval() def _collect_obj_reps(self, span_tags, object_reps, spo_len): """ Collect span-level object representations :param 
span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return:
        """
        span_tags_fixed = torch.clamp(span_tags, min=0)  # in case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0, row_id.shape[0], step=1, device=row_id.device)[:, None]

        # add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster

        if self.enhanced_img_feature:
            # for i in range(span_tags_fixed.shape[0]):
            #     span_tags_fixed[i, 1:1 + spo_len[i, 0]] = 1
            #     span_tags_fixed[i, 1 + spo_len[i, 0]:1 + spo_len[i, 0] + spo_len[i, 1]] = 2
            #     span_tags_fixed[i, 1 + spo_len[i, 0] + spo_len[i, 1]:1 + spo_len[i, 0] + spo_len[i, 1] + spo_len[i, 2]] = 3
            pass

        text_visual_embeddings = object_reps[row_id.view(-1),
                                             span_tags_fixed.view(-1)].view(*span_tags_fixed.shape, -1)
        return text_visual_embeddings

    def prepare_text_from_qa(self, question, question_tags, question_mask,
                             answer, answer_tags, answer_mask):
        batch_size, max_q_len = question.shape
        _, max_a_len = answer.shape
        if self.predict_on_cls:
            answer_mask = answer_mask.new_zeros(answer_mask.shape)  # remove answer_mask
            max_len = (question_mask.sum(1) + answer_mask.sum(1)).max() + 2  # [CLS] & 1*[SEP]
        else:
            max_len = (question_mask.sum(1) + answer_mask.sum(1)).max() + 3  # [CLS] & 2*[SEP]
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]'])
        q_end = 1 + question_mask.sum(1, keepdim=True)
        a_end = q_end if self.predict_on_cls else q_end + 1 + answer_mask.sum(1, keepdim=True)
        input_ids = torch.zeros((batch_size, max_len), dtype=question.dtype, device=question.device)
        input_mask = torch.ones((batch_size, max_len), dtype=torch.uint8, device=question.device)
        input_type_ids = torch.zeros((batch_size, max_len), dtype=question.dtype, device=question.device)
        text_tags = input_type_ids.new_zeros((batch_size, max_len))
        grid_i, grid_j = torch.meshgrid(torch.arange(batch_size, device=question.device),
                                        torch.arange(max_len, device=question.device))
        input_mask[grid_j > a_end] = 0
        if not self.predict_on_cls:
            input_type_ids[(grid_j > q_end) & (grid_j <= a_end)] = 1
        q_input_mask = (grid_j > 0) & (grid_j < q_end)
        a_input_mask = (grid_j > q_end) & (grid_j < a_end)
        input_ids[:, 0] = cls_id
        input_ids[grid_j == q_end] = sep_id
        input_ids[grid_j == a_end] = sep_id
        input_ids[q_input_mask] = question[question_mask]
        input_ids[a_input_mask] = answer[answer_mask]
        text_tags[q_input_mask] = question_tags[question_mask]
        text_tags[a_input_mask] = answer_tags[answer_mask]
        ans_pos = a_end.new_zeros(a_end.shape).squeeze(1) if self.predict_on_cls else (a_end - 1).squeeze(1)
        return input_ids, input_type_ids, text_tags, input_mask, ans_pos

    def train_forward(self, img, im_info, boxes, labels, spo_ids, spo_lens, img_path):
        boxes, labels, spo_ids, spo_lens, im_info = boxes.squeeze(0), labels.squeeze(0), \
            spo_ids.squeeze(0), spo_lens.squeeze(0), im_info.squeeze(0)
        images = torch.cat([img for _ in range(boxes.shape[0])])  # e.g. images.shape = torch.Size([4, 3, 895, 899])
        box_mask = (boxes[:, :, 0] > -1.5)  # e.g. box_mask.shape = torch.Size([4, 54])
        max_len = int(box_mask.sum(1).max().item())  # e.g. max_len = 54
        box_mask = box_mask[:, :max_len]  # doesn't seem to have an effect
        boxes = boxes[:, :max_len]  # doesn't seem to have an effect
        boxes[boxes < 0] = 0  # rectify coordinates < 0 to 0
        obj_reps = self.image_feature_extractor(images=images,
boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None,
                                                copy_images=True)  # obj_reps['obj_reps'].shape = torch.Size([4, 54, 768])
        question_ids = spo_ids
        question_tags = spo_ids.new_zeros(question_ids.shape)
        question_mask = (spo_ids > 0.5)
        answer_ids = question_ids.new_zeros((question_ids.shape[0], 1)).fill_(
            self.tokenizer.convert_tokens_to_ids(['[MASK]'])[0])
        answer_mask = question_mask.new_zeros(answer_ids.shape).fill_(1)
        answer_tags = question_tags.new_zeros(answer_ids.shape)

        ############################################
        # prepare text
        text_input_ids, text_token_type_ids, text_tags, text_mask, ans_pos = self.prepare_text_from_qa(
            question_ids, question_tags, question_mask, answer_ids, answer_tags, answer_mask)
        if self.config.NETWORK.NO_GROUNDING:  # always False
            obj_rep_zeroed = obj_reps['obj_reps'].new_zeros(obj_reps['obj_reps'].shape)
            text_tags.zero_()
            text_visual_embeddings = self._collect_obj_reps(text_tags, obj_rep_zeroed)
        else:
            text_visual_embeddings = self._collect_obj_reps(text_tags, obj_reps['obj_reps'], spo_lens)

        assert self.config.NETWORK.VLBERT.object_word_embed_mode == 2
        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        # concatenation of object visual & linguistic features
        object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################
        # Visual Linguistic BERT
        hidden_states, hc, spo_fused_masks = self.vlbert(
            text_input_ids, text_token_type_ids, text_visual_embeddings, text_mask,
            object_vl_embeddings, box_mask,
            object_visual_feat=obj_reps['obj_reps_rawraw'],
            spo_len=spo_lens,
            output_all_encoded_layers=False)
        _batch_inds = torch.arange(spo_ids.shape[0], device=spo_ids.device)
        hm = hidden_states[_batch_inds, ans_pos]

        if self.use_spatial_model:
            if self.use_coord_vector:
                spa_feat = torch.zeros((boxes.shape[0], 5 * 2 + 9),
                                       dtype=boxes.dtype, layout=boxes.layout, device=boxes.device)
                for i in range(boxes.shape[0]):
                    # subject box: normalized xyxy + area ratio
                    area_subj_ratio = (boxes[i, 1, 2] - boxes[i, 1, 0]) * (boxes[i, 1, 3] - boxes[i, 1, 1]) \
                        / (im_info[i, 0] * im_info[i, 1])
                    subj = torch.tensor([boxes[i, 1, 0] / im_info[i, 0],
                                         boxes[i, 1, 1] / im_info[i, 1],
                                         boxes[i, 1, 2] / im_info[i, 0],
                                         boxes[i, 1, 3] / im_info[i, 1],
                                         area_subj_ratio])
                    # predicate: relative subject/object geometry + area ratio
                    area_pred_ratio = (boxes[i, 2, 2] - boxes[i, 2, 0]) * (boxes[i, 2, 3] - boxes[i, 2, 1]) \
                        / (im_info[i, 0] * im_info[i, 1])
                    w_s = boxes[i, 1, 2] - boxes[i, 1, 0]
                    h_s = boxes[i, 1, 3] - boxes[i, 1, 1]
                    x_s = (boxes[i, 1, 2] + boxes[i, 1, 0]) / 2
                    y_s = (boxes[i, 1, 3] + boxes[i, 1, 1]) / 2
                    w_o = boxes[i, 3, 2] - boxes[i, 3, 0]
                    h_o = boxes[i, 3, 3] - boxes[i, 3, 1]
                    x_o = (boxes[i, 3, 2] + boxes[i, 3, 0]) / 2
                    y_o = (boxes[i, 3, 3] + boxes[i, 3, 1]) / 2
                    pred = torch.tensor([(x_s - x_o) / w_o, (y_s - y_o) / h_o,
                                         torch.log(w_s / w_o), torch.log(h_s / h_o),
                                         (x_o - x_s) / w_s, (y_o - y_s) / h_s,
                                         torch.log(w_o / w_s), torch.log(h_o / h_s),
                                         area_pred_ratio])
                    # object box: normalized xyxy + area ratio
                    area_obj_ratio = (boxes[i, 3, 2] - boxes[i, 3, 0]) * (boxes[i, 3, 3] - boxes[i, 3, 1]) \
                        / (im_info[i, 0] * im_info[i, 1])
                    obj = torch.tensor([boxes[i, 3, 0] / im_info[i, 0],
                                        boxes[i, 3, 1] / im_info[i, 1],
                                        boxes[i, 3, 2] / im_info[i, 0],
                                        boxes[i, 3, 3] / im_info[i, 1],
                                        area_obj_ratio])
                    spa_feat[i] = torch.cat((subj, pred, obj))  # fill the i-th row of the batch
                spa_feat = self.loc_fcs(spa_feat)
                # assert self.spa_concat  # Currently coord_vec only works with concatenation!
else: for i in range(boxes.shape[0]): boxes[:, :, 0][i] /= im_info[:, 0][i] boxes[:, :, 1][i] /= im_info[:, 1][i] boxes[:, :, 2][i] /= im_info[:, 0][i] boxes[:, :, 3][i] /= im_info[:, 1][i] spa_feat = self.simple_spatial_model(boxes[:, 1], boxes[:, 3], labels) if self.spa_add: hm = hm * ( 1 - self.spa_feat_weight) + spa_feat * self.spa_feat_weight elif self.spa_concat: hm = torch.cat((hm, spa_feat), dim=1) hm = self.spa_fusion_linear(hm) hm = F.relu(hm) hm = self.dropout(hm) if self.spa_one_more_layer: # if no unfrozen VLBERT add one more layer and lower the dropout rate to 0.2 hm = self.spa_linear_hidden(hm) hm = F.relu(hm) hm = self.spa_linear(hm) if self.use_uvtranse: union_vec = obj_reps['obj_reps'][:, 2] - obj_reps[ 'obj_reps'][:, 1] - obj_reps['obj_reps'][:, 3] # pred - subj - obj union_vec = self.union_vec_fc(union_vec) union_vec = F.relu(union_vec) if self.uvt_add: hm = hm * (1 - self.uvt_feat_weight ) + union_vec * self.uvt_feat_weight elif self.uvt_concat: hm = torch.cat((hm, union_vec), dim=1) hm = self.uvt_fusion_linear(hm) hm = F.relu(hm) hm = self.dropout_uvt(hm) hm = self.uvt_linear(hm) # import pdb; pdb.set_trace() ########################################### outputs = {} # classifier logits = self.final_mlp(hm) # loss # import pdb; pdb.set_trace() ans_loss = F.cross_entropy(logits, labels.view(-1)) # * label.size(1) # Add sigmoid for binary prediction in spasen_metrics.py logits = F.softmax(logits, dim=1) # mask loss if spo_fused_masks is not None: nb_of_tokens = 2 if self.no_predicate else 3 spo_fused_masks = spo_fused_masks.view(-1, nb_of_tokens, 14, 14) # spo_fused_masks_norm = spo_fused_masks.new_zeros(size=spo_fused_masks.shape) boxes_mask = torch.zeros_like(spo_fused_masks) rounded_14x14_boxes = torch.round(boxes * 14).to(torch.int) for i in range(boxes.shape[0]): # for each sample for j in range(nb_of_tokens): # sub, pred, obj # Create a mask boxes_mask[i, j, rounded_14x14_boxes[ i, j + 1, 0].item():rounded_14x14_boxes[i, j + 1, 2].item(), rounded_14x14_boxes[ i, j + 1, 1].item():rounded_14x14_boxes[i, j + 1, 3].item()] = 1 if self.mask_loss_sum: mask_loss = F.binary_cross_entropy_with_logits( spo_fused_masks, boxes_mask, reduction='sum') / spo_fused_masks.shape[0] elif self.mask_loss_mse: mask_loss = F.mse_loss(spo_fused_masks, boxes_mask) else: mask_loss = F.binary_cross_entropy_with_logits( spo_fused_masks, boxes_mask) outputs.update({ 'label_logits': logits, 'label': labels, 'ans_loss': ans_loss, 'mask_loss': mask_loss }) if self.mask_weight < 0: loss = (ans_loss + mask_loss).mean() else: loss = (ans_loss * (1 - self.mask_weight) + mask_loss * self.mask_weight).mean() else: outputs.update({ 'label_logits': logits, 'label': labels, 'ans_loss': ans_loss }) loss = ans_loss.mean() return outputs, loss def inference_forward(self, img, im_info, boxes, labels, spo_ids, spo_lens, img_path, rels_cand, labels_so_ids, subj_obj_classes): boxes, labels, spo_ids, spo_lens, im_info, rels_cand, labels_so_ids, subj_obj_classes = boxes.squeeze( 0), labels.squeeze(0), spo_ids.squeeze(0), spo_lens.squeeze( 0), im_info.squeeze(0), rels_cand.squeeze( 0), labels_so_ids.squeeze(0), subj_obj_classes.squeeze(0) # visual feature extraction images = torch.cat([img for _ in range(boxes.shape[0])]) box_mask = (boxes[:, :, 0] > -1.5) max_len = int(box_mask.sum(1).max().item()) box_mask = box_mask[:, :max_len] boxes = boxes[:, :max_len] obj_reps = self.image_feature_extractor(images=images, boxes=boxes, box_mask=box_mask, im_info=im_info, classes=None, segms=None, copy_images=True) question_ids 
= spo_ids
        question_tags = spo_ids.new_zeros(question_ids.shape)
        question_mask = (spo_ids > 0.5)
        answer_ids = question_ids.new_zeros((question_ids.shape[0], 1)).fill_(
            self.tokenizer.convert_tokens_to_ids(['[MASK]'])[0])
        answer_mask = question_mask.new_zeros(answer_ids.shape).fill_(1)
        answer_tags = question_tags.new_zeros(answer_ids.shape)

        ############################################
        # prepare text
        text_input_ids, text_token_type_ids, text_tags, text_mask, ans_pos = self.prepare_text_from_qa(
            question_ids, question_tags, question_mask, answer_ids, answer_tags, answer_mask)
        if self.config.NETWORK.NO_GROUNDING:
            obj_rep_zeroed = obj_reps['obj_reps'].new_zeros(obj_reps['obj_reps'].shape)
            text_tags.zero_()
            text_visual_embeddings = self._collect_obj_reps(text_tags, obj_rep_zeroed)
        else:
            text_visual_embeddings = self._collect_obj_reps(text_tags, obj_reps['obj_reps'], spo_lens)

        assert self.config.NETWORK.VLBERT.object_word_embed_mode == 2
        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################
        # Visual Linguistic BERT
        hidden_states, hc, spo_fused_masks = self.vlbert(
            text_input_ids, text_token_type_ids, text_visual_embeddings, text_mask,
            object_vl_embeddings, box_mask,
            object_visual_feat=obj_reps['obj_reps_rawraw'],
            spo_len=spo_lens,
            output_all_encoded_layers=False)
        _batch_inds = torch.arange(spo_ids.shape[0], device=spo_ids.device)
        hm = hidden_states[_batch_inds, ans_pos]

        if self.use_spatial_model:
            if self.use_coord_vector:
                spa_feat = torch.zeros((boxes.shape[0], 5 * 2 + 9),
                                       dtype=boxes.dtype, layout=boxes.layout, device=boxes.device)
                for i in range(boxes.shape[0]):
                    area_subj_ratio = (boxes[i, 1, 2] - boxes[i, 1, 0]) * (boxes[i, 1, 3] - boxes[i, 1, 1]) \
                        / (im_info[i, 0] * im_info[i, 1])
                    subj = torch.tensor([boxes[i, 1, 0] / im_info[i, 0],
                                         boxes[i, 1, 1] / im_info[i, 1],
                                         boxes[i, 1, 2] / im_info[i, 0],
                                         boxes[i, 1, 3] / im_info[i, 1],
                                         area_subj_ratio])
                    area_pred_ratio = (boxes[i, 2, 2] - boxes[i, 2, 0]) * (boxes[i, 2, 3] - boxes[i, 2, 1]) \
                        / (im_info[i, 0] * im_info[i, 1])
                    w_s = boxes[i, 1, 2] - boxes[i, 1, 0]
                    h_s = boxes[i, 1, 3] - boxes[i, 1, 1]
                    x_s = (boxes[i, 1, 2] + boxes[i, 1, 0]) / 2
                    y_s = (boxes[i, 1, 3] + boxes[i, 1, 1]) / 2
                    w_o = boxes[i, 3, 2] - boxes[i, 3, 0]
                    h_o = boxes[i, 3, 3] - boxes[i, 3, 1]
                    x_o = (boxes[i, 3, 2] + boxes[i, 3, 0]) / 2
                    y_o = (boxes[i, 3, 3] + boxes[i, 3, 1]) / 2
                    pred = torch.tensor([(x_s - x_o) / w_o, (y_s - y_o) / h_o,
                                         torch.log(w_s / w_o), torch.log(h_s / h_o),
                                         (x_o - x_s) / w_s, (y_o - y_s) / h_s,
                                         torch.log(w_o / w_s), torch.log(h_o / h_s),
                                         area_pred_ratio])
                    area_obj_ratio = (boxes[i, 3, 2] - boxes[i, 3, 0]) * (boxes[i, 3, 3] - boxes[i, 3, 1]) \
                        / (im_info[i, 0] * im_info[i, 1])
                    obj = torch.tensor([boxes[i, 3, 0] / im_info[i, 0],
                                        boxes[i, 3, 1] / im_info[i, 1],
                                        boxes[i, 3, 2] / im_info[i, 0],
                                        boxes[i, 3, 3] / im_info[i, 1],
                                        area_obj_ratio])
                    spa_feat[i] = torch.cat((subj, pred, obj))  # fill the i-th row of the batch
                spa_feat = self.loc_fcs(spa_feat)
                # assert self.spa_concat  # Currently coord_vec only works with concatenation!
else: for i in range(boxes.shape[0]): boxes[:, :, 0][i] /= im_info[:, 0][i] boxes[:, :, 1][i] /= im_info[:, 1][i] boxes[:, :, 2][i] /= im_info[:, 0][i] boxes[:, :, 3][i] /= im_info[:, 1][i] spa_feat = self.simple_spatial_model(boxes[:, 1], boxes[:, 3], labels) if self.spa_add: hm = hm * ( 1 - self.spa_feat_weight) + spa_feat * self.spa_feat_weight elif self.spa_concat: hm = torch.cat((hm, spa_feat), dim=1) hm = self.spa_fusion_linear(hm) hm = F.relu(hm) hm = self.dropout(hm) if self.spa_one_more_layer: # if no unfrozen VLBERT add one more layer and lower the dropout rate to 0.2 hm = self.spa_linear_hidden(hm) hm = F.relu(hm) hm = self.spa_linear(hm) if self.use_uvtranse: union_vec = obj_reps['obj_reps'][:, 2] - obj_reps[ 'obj_reps'][:, 1] - obj_reps['obj_reps'][:, 3] # pred - subj - obj union_vec = self.union_vec_fc(union_vec) union_vec = F.relu(union_vec) if self.uvt_add: hm = hm * (1 - self.uvt_feat_weight ) + union_vec * self.uvt_feat_weight elif self.uvt_concat: hm = torch.cat((hm, union_vec), dim=1) hm = self.uvt_fusion_linear(hm) hm = F.relu(hm) hm = self.dropout_uvt(hm) hm = self.uvt_linear(hm) ########################################### outputs = {} # classifier logits = self.final_mlp(hm) logits = F.softmax(logits, dim=1) # mask loss if spo_fused_masks is not None: nb_of_tokens = 2 if self.no_predicate else 3 spo_fused_masks = spo_fused_masks.view(-1, nb_of_tokens, 14, 14) boxes_mask = boxes.new_zeros(size=(boxes.shape[0], nb_of_tokens, 14, 14)) rounded_14x14_boxes = torch.round(boxes * 14).to(torch.int) for i in range(boxes.shape[0]): # for each sample for j in range(nb_of_tokens): # sub, pred, obj # Create a mask boxes_mask[i, j, rounded_14x14_boxes[ i, j + 1, 0].item():rounded_14x14_boxes[i, j + 1, 2].item(), rounded_14x14_boxes[ i, j + 1, 1].item():rounded_14x14_boxes[i, j + 1, 3].item()] = 1 if self.mask_loss_sum: mask_loss = F.binary_cross_entropy_with_logits( spo_fused_masks, boxes_mask, reduction='sum') / spo_fused_masks.shape[0] elif self.mask_loss_mse: mask_loss = F.mse_loss(spo_fused_masks, boxes_mask) # import pdb; pdb.set_trace() # self.show_cam_on_image(spo_fused_masks, img_path) else: mask_loss = F.binary_cross_entropy_with_logits( spo_fused_masks, boxes_mask) outputs.update({ 'label_logits': logits, 'label': labels, 'labels_so_ids': labels_so_ids, 'rels_cand': rels_cand, 'mask_loss': mask_loss, 'img_path': img_path, 'spo_fused_masks': spo_fused_masks, 'subj_obj_classes': subj_obj_classes, 'prediction': logits.argmax(1) }) else: outputs.update({ 'label_logits': logits, 'label': labels, 'labels_so_ids': labels_so_ids, 'rels_cand': rels_cand, 'prediction': logits.argmax(1) }) return outputs
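# ----------------------------------------------------------------------------
# NOTE: the per-sample loop in the coord-vector branch above builds a 19-dim
# geometry vector (5 subject dims + 9 relative dims + 5 object dims). Below is
# an equivalent vectorized sketch, assuming the same layout as in the loop:
# boxes[:, 1] / boxes[:, 2] / boxes[:, 3] hold the subject/predicate/object
# xyxy boxes, im_info[:, 0] / im_info[:, 1] hold (width, height), and all
# boxes are non-degenerate. `spo_coord_features` is a hypothetical name.
# ----------------------------------------------------------------------------
import torch

def spo_coord_features(boxes, im_info):
    W, H = im_info[:, 0], im_info[:, 1]

    def norm5(b):  # normalized xyxy + area ratio -> [B, 5]
        area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1]) / (W * H)
        return torch.stack([b[:, 0] / W, b[:, 1] / H,
                            b[:, 2] / W, b[:, 3] / H, area], dim=1)

    s, p, o = boxes[:, 1], boxes[:, 2], boxes[:, 3]
    w_s, h_s = s[:, 2] - s[:, 0], s[:, 3] - s[:, 1]
    x_s, y_s = (s[:, 2] + s[:, 0]) / 2, (s[:, 3] + s[:, 1]) / 2
    w_o, h_o = o[:, 2] - o[:, 0], o[:, 3] - o[:, 1]
    x_o, y_o = (o[:, 2] + o[:, 0]) / 2, (o[:, 3] + o[:, 1]) / 2
    area_p = (p[:, 2] - p[:, 0]) * (p[:, 3] - p[:, 1]) / (W * H)
    rel = torch.stack([(x_s - x_o) / w_o, (y_s - y_o) / h_o,
                       torch.log(w_s / w_o), torch.log(h_s / h_o),
                       (x_o - x_s) / w_s, (y_o - y_s) / h_s,
                       torch.log(w_o / w_s), torch.log(h_o / h_s), area_p], dim=1)
    return torch.cat([norm5(s), rel, norm5(o)], dim=1)  # [B, 19]

# usage (hypothetical): spa_feat = self.loc_fcs(spo_coord_features(boxes, im_info))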
class ResNetVLBERTv4(Module): def __init__(self, config): super(ResNetVLBERTv4, self).__init__(config) self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS if not config.NETWORK.BLIND: self.image_feature_extractor = FastRCNN( config, average_pool=True, final_dim=config.NETWORK.IMAGE_FINAL_DIM, enable_cnn_reg_loss=self.enable_cnn_reg_loss) if config.NETWORK.VLBERT.object_word_embed_mode == 1: self.object_linguistic_embeddings = nn.Embedding( 601, config.NETWORK.VLBERT.hidden_size) elif config.NETWORK.VLBERT.object_word_embed_mode == 2: self.object_linguistic_embeddings = nn.Embedding( 1, config.NETWORK.VLBERT.hidden_size) elif config.NETWORK.VLBERT.object_word_embed_mode == 3: self.object_linguistic_embeddings = None else: raise NotImplementedError self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN self.tokenizer = BertTokenizer.from_pretrained( config.NETWORK.BERT_MODEL_NAME) language_pretrained_model_path = None if config.NETWORK.BERT_PRETRAINED != '': language_pretrained_model_path = '{}-{:04d}.model'.format( config.NETWORK.BERT_PRETRAINED, config.NETWORK.BERT_PRETRAINED_EPOCH) elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME): weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME, BERT_WEIGHTS_NAME) if os.path.isfile(weight_path): language_pretrained_model_path = weight_path self.language_pretrained_model_path = language_pretrained_model_path if language_pretrained_model_path is None: print( "Warning: no pretrained language model found, training from scratch!!!" ) self.vlbert = VisualLinguisticBert( config.NETWORK.VLBERT, language_pretrained_model_path=language_pretrained_model_path) # self.hm_out = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size) # self.hi_out = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size) dim = config.NETWORK.VLBERT.hidden_size if config.NETWORK.CLASSIFIER_TYPE == "2fc": self.final_mlp = torch.nn.Sequential( torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False), torch.nn.Linear(dim, config.NETWORK.CLASSIFIER_HIDDEN_SIZE), torch.nn.ReLU(inplace=True), torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False), torch.nn.Linear(config.NETWORK.CLASSIFIER_HIDDEN_SIZE, config.NETWORK.CLASSIFIER_CLASS), ) elif config.NETWORK.CLASSIFIER_TYPE == "1fc": self.final_mlp = torch.nn.Sequential( torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False), torch.nn.Linear(dim, config.NETWORK.CLASSIFIER_CLASS)) elif config.NETWORK.CLASSIFIER_TYPE == 'mlm': transform = BertPredictionHeadTransform(config.NETWORK.VLBERT) linear = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.CLASSIFIER_CLASS) self.final_mlp = nn.Sequential( transform, nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False), linear) else: raise ValueError("Not support classifier type: {}!".format( config.NETWORK.CLASSIFIER_TYPE)) # init weights self.init_weight() self.fix_params() def init_weight(self): # self.hm_out.weight.data.normal_(mean=0.0, std=0.02) # self.hm_out.bias.data.zero_() # self.hi_out.weight.data.normal_(mean=0.0, std=0.02) # self.hi_out.bias.data.zero_() self.image_feature_extractor.init_weight() if self.object_linguistic_embeddings is not None: self.object_linguistic_embeddings.weight.data.normal_(mean=0.0, std=0.02) for m in self.final_mlp.modules(): if isinstance(m, torch.nn.Linear): torch.nn.init.xavier_uniform_(m.weight) torch.nn.init.constant_(m.bias, 0) if self.config.NETWORK.CLASSIFIER_TYPE == 'mlm': language_pretrained = torch.load( 
self.language_pretrained_model_path) mlm_transform_state_dict = {} pretrain_keys = [] for k, v in language_pretrained.items(): if k.startswith('cls.predictions.transform.'): pretrain_keys.append(k) k_ = k[len('cls.predictions.transform.'):] if 'gamma' in k_: k_ = k_.replace('gamma', 'weight') if 'beta' in k_: k_ = k_.replace('beta', 'bias') mlm_transform_state_dict[k_] = v print("loading pretrained classifier transform keys: {}.".format( pretrain_keys)) self.final_mlp[0].load_state_dict(mlm_transform_state_dict) def train(self, mode=True): super(ResNetVLBERTv4, self).train(mode) # turn some frozen layers to eval mode if self.image_feature_bn_eval: self.image_feature_extractor.bn_eval() def fix_params(self): pass def _collect_obj_reps(self, span_tags, object_reps): """ Collect span-level object representations :param span_tags: [batch_size, ..leading_dims.., L] :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim] :return: """ span_tags_fixed = torch.clamp( span_tags, min=0) # In case there were masked values here row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape) row_id_broadcaster = torch.arange(0, row_id.shape[0], step=1, device=row_id.device)[:, None] # Add extra diminsions to the row broadcaster so it matches row_id leading_dims = len(span_tags.shape) - 2 for i in range(leading_dims): row_id_broadcaster = row_id_broadcaster[..., None] row_id += row_id_broadcaster object_select = object_reps[row_id.view(-1), span_tags_fixed.view(-1)] return object_select.view(*span_tags_fixed.shape, -1) def prepare_text(self, question, question_tags, question_mask): batch_size, max_q_len = question.shape max_len = question_mask.sum(1).max() + 2 cls_id, sep_id = self.tokenizer.convert_tokens_to_ids( ['[CLS]', '[SEP]']) q_end = 1 + question_mask.sum(1, keepdim=True) input_ids = torch.zeros((batch_size, max_len), dtype=question.dtype, device=question.device) input_mask = torch.ones((batch_size, max_len), dtype=torch.bool, device=question.device) input_type_ids = torch.zeros((batch_size, max_len), dtype=question.dtype, device=question.device) text_tags = input_type_ids.new_zeros((batch_size, max_len)) grid_i, grid_j = torch.meshgrid( torch.arange(batch_size, device=question.device), torch.arange(max_len, device=question.device)) input_mask[grid_j > q_end] = 0 # input_type_ids[(grid_j > q_end) & (grid_j <= a_end)] = 1 q_input_mask = (grid_j > 0) & (grid_j < q_end) sep_idx = (question == sep_id).nonzero() for index in sep_idx: input_type_ids[index[0], index[1] + 1:] = self.config.NETWORK.VLBERT.visual_tag_type input_ids[:, 0] = cls_id input_ids[grid_j == q_end] = sep_id input_ids[q_input_mask] = question[question_mask] text_tags[q_input_mask] = question_tags[question_mask] return input_ids, input_type_ids, text_tags, input_mask def train_forward( self, image, boxes, im_info, text, img_boxes, text_tags, label: torch.Tensor, *sample_id_and_more, loss_fn=F.binary_cross_entropy_with_logits, ): ########################################### # visual feature extraction images = image box_mask = (boxes[:, :, 0] > -1.5) # NOTE: clip_pad_boxes(pad=-2) max_len = int(box_mask.sum(1).max().item()) box_mask = box_mask[:, :max_len] objects = boxes[:, :max_len, 4] boxes = boxes[:, :max_len, :4] obj_and_img_boxes = torch.cat([img_boxes, boxes], axis=1) _box_mask = (obj_and_img_boxes[:, :, 0] > -1.5 ) # NOTE: clip_pad_boxes(pad=-2) obj_reps = self.image_feature_extractor(images=images, boxes=obj_and_img_boxes, box_mask=_box_mask, im_info=im_info, classes=None, segms=None) img_block_reps = 
obj_reps["obj_reps"][:, :img_boxes.shape[1], :] obj_reps["obj_reps"] = obj_reps["obj_reps"][:, img_boxes.shape[1]:, :] if self.config.NETWORK.IMAGE_FROZEN_BACKBONE_ALL: obj_reps = {k: v.detach() for k, v in obj_reps.items()} text_ids = text # text_tags = text.new_zeros(text_ids.shape) text_mask = (text > 0.5) ############################################ # prepare text text_input_ids, text_token_type_ids, text_tags, text_mask = self.prepare_text( text_ids, text_tags, text_mask) if self.config.NETWORK.NO_GROUNDING: obj_rep_zeroed = obj_reps['obj_reps'].new_zeros( obj_reps['obj_reps'].shape) text_tags.zero_() text_visual_embeddings = self._collect_obj_reps( text_tags, obj_rep_zeroed) else: text_visual_embeddings = self._collect_obj_reps( text_tags, torch.cat([img_block_reps, obj_reps['obj_reps']], dim=1)) assert self.config.NETWORK.VLBERT.object_word_embed_mode in [1, 2] if self.config.NETWORK.VLBERT.object_word_embed_mode == 1: object_linguistic_embeddings = self.object_linguistic_embeddings( objects.long().clamp( min=0, max=self.object_linguistic_embeddings.weight.data.shape[0] - 1)) else: object_linguistic_embeddings = self.object_linguistic_embeddings( boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long()) object_vl_embeddings = torch.cat( (obj_reps['obj_reps'], object_linguistic_embeddings), -1) ########################################### # Visual Linguistic BERT hidden_states, hc = self.vlbert(text_input_ids, text_token_type_ids, text_visual_embeddings, text_mask, object_vl_embeddings, box_mask, output_all_encoded_layers=False) _batch_inds = torch.arange(text.shape[0], device=text.device) hm = hidden_states[_batch_inds, 0] # hm = F.tanh(self.hm_out(hidden_states[_batch_inds, ans_pos])) # hi = F.tanh(self.hi_out(hidden_states[_batch_inds, ans_pos + 2])) ########################################### outputs = {} # classifier # logits = self.final_mlp(hc * hm * hi) # logits = self.final_mlp(hc) logits = self.final_mlp(hm) if not self.config.NETWORK.CLASSIFIER_SIGMOID: if label.ndim == 2: label = label.squeeze(1) label = label.long() else: if label.ndim == 1: label = label.unsqueeze(1) # loss ans_loss = loss_fn(logits, label) outputs.update({ 'label_logits': logits, 'label': label, 'ans_loss': ans_loss }) loss = ans_loss.mean() return outputs, loss def _inference_forward(self, image, boxes, im_info, text, *args): ########################################### # visual feature extraction images = image box_mask = (boxes[:, :, 0] > -1.5) max_len = int(box_mask.sum(1).max().item()) box_mask = box_mask[:, :max_len] objects = boxes[:, :max_len, 4] boxes = boxes[:, :max_len, :4] obj_reps = self.image_feature_extractor(images=images, boxes=boxes, box_mask=box_mask, im_info=im_info, classes=None, segms=None) text_ids = text text_tags = text.new_zeros(text_ids.shape) text_mask = (text > 0.5) ############################################ # prepare text text_input_ids, text_token_type_ids, text_tags, text_mask = self.prepare_text( text_ids, text_tags, text_mask) if self.config.NETWORK.NO_GROUNDING: obj_rep_zeroed = obj_reps['obj_reps'].new_zeros( obj_reps['obj_reps'].shape) text_tags.zero_() text_visual_embeddings = self._collect_obj_reps( text_tags, obj_rep_zeroed) else: text_visual_embeddings = self._collect_obj_reps( text_tags, obj_reps['obj_reps']) assert self.config.NETWORK.VLBERT.object_word_embed_mode in [1, 2] if self.config.NETWORK.VLBERT.object_word_embed_mode == 1: object_linguistic_embeddings = self.object_linguistic_embeddings( objects.long().clamp( min=0, 
max=self.object_linguistic_embeddings.weight.data.shape[0] - 1)) else: object_linguistic_embeddings = self.object_linguistic_embeddings( boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long()) object_vl_embeddings = torch.cat( (obj_reps['obj_reps'], object_linguistic_embeddings), -1) ########################################### # Visual Linguistic BERT hidden_states, hc = self.vlbert(text_input_ids, text_token_type_ids, text_visual_embeddings, text_mask, object_vl_embeddings, box_mask, output_all_encoded_layers=False) _batch_inds = torch.arange(text.shape[0], device=text.device) hm = hidden_states[_batch_inds, 0] # hm = F.tanh(self.hm_out(hidden_states[_batch_inds, ans_pos])) # hi = F.tanh(self.hi_out(hidden_states[_batch_inds, ans_pos + 2])) ########################################### outputs = {} # classifier # logits = self.final_mlp(hc * hm * hi) # logits = self.final_mlp(hc) logits = self.final_mlp(hm) outputs.update({'label_logits': logits}) return outputs def inference_forward(self, image, boxes, im_info, text, img_boxes, text_tags, *args): ########################################### # visual feature extraction images = image box_mask = (boxes[:, :, 0] > -1.5) max_len = int(box_mask.sum(1).max().item()) box_mask = box_mask[:, :max_len] objects = boxes[:, :max_len, 4] boxes = boxes[:, :max_len, :4] obj_and_img_boxes = torch.cat([img_boxes, boxes], axis=1) _box_mask = (obj_and_img_boxes[:, :, 0] > -1.5 ) # NOTE: clip_pad_boxes(pad=-2) obj_reps = self.image_feature_extractor(images=images, boxes=obj_and_img_boxes, box_mask=_box_mask, im_info=im_info, classes=None, segms=None) img_block_reps = obj_reps["obj_reps"][:, :img_boxes.shape[1], :] obj_reps["obj_reps"] = obj_reps["obj_reps"][:, img_boxes.shape[1]:, :] text_ids = text # text_tags = text.new_zeros(text_ids.shape) text_mask = (text > 0.5) ############################################ # prepare text text_input_ids, text_token_type_ids, text_tags, text_mask = self.prepare_text( text_ids, text_tags, text_mask) if self.config.NETWORK.NO_GROUNDING: obj_rep_zeroed = obj_reps['obj_reps'].new_zeros( obj_reps['obj_reps'].shape) text_tags.zero_() text_visual_embeddings = self._collect_obj_reps( text_tags, obj_rep_zeroed) else: text_visual_embeddings = self._collect_obj_reps( text_tags, torch.cat([img_block_reps, obj_reps['obj_reps']], dim=1)) assert self.config.NETWORK.VLBERT.object_word_embed_mode in [1, 2] if self.config.NETWORK.VLBERT.object_word_embed_mode == 1: object_linguistic_embeddings = self.object_linguistic_embeddings( objects.long().clamp( min=0, max=self.object_linguistic_embeddings.weight.data.shape[0] - 1)) else: object_linguistic_embeddings = self.object_linguistic_embeddings( boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long()) object_vl_embeddings = torch.cat( (obj_reps['obj_reps'], object_linguistic_embeddings), -1) ########################################### # Visual Linguistic BERT hidden_states, hc = self.vlbert(text_input_ids, text_token_type_ids, text_visual_embeddings, text_mask, object_vl_embeddings, box_mask, output_all_encoded_layers=False) _batch_inds = torch.arange(text.shape[0], device=text.device) hm = hidden_states[_batch_inds, 0] # hm = F.tanh(self.hm_out(hidden_states[_batch_inds, ans_pos])) # hi = F.tanh(self.hi_out(hidden_states[_batch_inds, ans_pos + 2])) ########################################### outputs = {} # classifier # logits = self.final_mlp(hc * hm * hi) # logits = self.final_mlp(hc) logits = self.final_mlp(hm) outputs.update({'label_logits': logits}) return outputs
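# ----------------------------------------------------------------------------
# NOTE: `prepare_text` above packs variable-length token rows into
# [CLS] tokens [SEP] with meshgrid-based masks. A self-contained sketch of the
# same packing pattern follows; 101/102 are the standard bert-base-uncased
# [CLS]/[SEP] ids, assumed here in place of the tokenizer lookup, and the
# token-type handling of the real method is omitted.
# ----------------------------------------------------------------------------
import torch

def pack_question(question, question_mask, cls_id=101, sep_id=102):
    B, _ = question.shape
    max_len = int(question_mask.sum(1).max().item()) + 2  # [CLS] + [SEP]
    q_end = 1 + question_mask.sum(1, keepdim=True)        # per-row [SEP] position
    grid_j = torch.arange(max_len).unsqueeze(0).expand(B, -1)
    input_ids = question.new_zeros((B, max_len))
    input_mask = grid_j <= q_end                          # attend up to [SEP]
    input_ids[:, 0] = cls_id
    input_ids[grid_j == q_end] = sep_id
    # boolean-mask assignment fills row-major, so tokens land left-packed
    input_ids[(grid_j > 0) & (grid_j < q_end)] = question[question_mask]
    return input_ids, input_mask

q = torch.tensor([[7, 8, 9, 0], [5, 6, 0, 0]])
ids, m = pack_question(q, q > 0)
print(ids)  # [[101, 7, 8, 9, 102], [101, 5, 6, 102, 0]]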
class ResNetVLBERT(Module): def __init__(self, config): super(ResNetVLBERT, self).__init__(config) self.image_feature_extractor = FastRCNN( config, average_pool=True, final_dim=config.NETWORK.IMAGE_FINAL_DIM, enable_cnn_reg_loss=False) self.object_linguistic_embeddings = nn.Embedding( 1, config.NETWORK.VLBERT.hidden_size) self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN self.tokenizer = BertTokenizer.from_pretrained( config.NETWORK.BERT_MODEL_NAME) language_pretrained_model_path = None if config.NETWORK.BERT_PRETRAINED != '': language_pretrained_model_path = '{}-{:04d}.model'.format( config.NETWORK.BERT_PRETRAINED, config.NETWORK.BERT_PRETRAINED_EPOCH) elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME): weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME, BERT_WEIGHTS_NAME) if os.path.isfile(weight_path): language_pretrained_model_path = weight_path self.language_pretrained_model_path = language_pretrained_model_path if language_pretrained_model_path is None: print( "Warning: no pretrained language model found, training from scratch!!!" ) self.vlbert = VisualLinguisticBert( config.NETWORK.VLBERT, language_pretrained_model_path=language_pretrained_model_path) transform = VisualLinguisticBertMVRCHeadTransform( config.NETWORK.VLBERT) # self.linear = nn.Linear(config.NETWORK.VLBERT.hidden_size, 768) #331 1000 35 100 12003 lihui # self.OIM_loss = OIM_Module(331, 768) # config.NETWORK.VLBERT.hidden_size) self.OIM_loss = OIM_Module(12003, 768) self.linear = nn.Sequential( # transform, nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False), nn.Linear(config.NETWORK.VLBERT.hidden_size, 768) #331 1000 35 100 12003 lihui ) linear = nn.Linear(config.NETWORK.VLBERT.hidden_size, 1) self.final_mlp = nn.Sequential( transform, nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False), linear) # init weights self.init_weight() self.fix_params() # self.embeddings_word = torch.nn.Conv1d(in_channels=40, out_channels=1, kernel_size=1) # self.embeddings_box = torch.nn.Conv1d(in_channels=6, out_channels=1, kernel_size=1) # self.line_cls = nn.utils.weight_norm(nn.Linear(config.NETWORK.VLBERT.hidden_size, 1000), name='weight') #12003 def init_weight(self): self.image_feature_extractor.init_weight() if self.object_linguistic_embeddings is not None: self.object_linguistic_embeddings.weight.data.normal_( mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range) for m in self.final_mlp.modules(): if isinstance(m, torch.nn.Linear): torch.nn.init.xavier_uniform_(m.weight) torch.nn.init.constant_(m.bias, 0) def train(self, mode=True): super(ResNetVLBERT, self).train(mode) # turn some frozen layers to eval mode if self.image_feature_bn_eval: self.image_feature_extractor.bn_eval() def fix_params(self): pass def train_forward( self, image, boxes, im_info, expression, label, ): ########################################### # visual feature extraction batch_size = image.size(0) num_options = image.size(1) image = image.view(-1, image.size(2), image.size(3), image.size(4)) boxes = boxes.view(-1, boxes.size(2), boxes.size(3)) #boxes = boxes im_info = im_info.view(-1, im_info.size(2)) expression = expression.view(-1, expression.size(2)) images = image box_mask = (boxes[:, :, 0] > -1.5) # max_len = int(box_mask.sum(1).max().item()) # origin_len = boxes.shape[1] # box_mask = box_mask[:, :max_len] # boxes = boxes[:, :max_len] # label = label[:, :max_len] obj_reps = self.image_feature_extractor(images=images, boxes=boxes, box_mask=box_mask, im_info=im_info, classes=None, segms=None) 
############################################ # prepare text cls_id, sep_id = self.tokenizer.convert_tokens_to_ids( ['[CLS]', '[SEP]']) text_input_ids = expression.new_zeros( (expression.shape[0], expression.shape[1] + 2)) text_input_ids[:, 0] = cls_id text_input_ids[:, 1:-1] = expression _sep_pos = (text_input_ids > 0).sum(1) _batch_inds = torch.arange(expression.shape[0], device=expression.device) text_input_ids[_batch_inds, _sep_pos] = sep_id text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape) text_mask = text_input_ids > 0 text_visual_embeddings = obj_reps['obj_reps'][:, 0].unsqueeze( 1).repeat((1, text_input_ids.shape[1], 1)) object_linguistic_embeddings = self.object_linguistic_embeddings( boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long()) object_vl_embeddings = torch.cat( (obj_reps['obj_reps'], object_linguistic_embeddings), -1) ########################################### # Visual Linguistic BERT _, pooled_output = self.vlbert(text_input_ids, text_token_type_ids, text_visual_embeddings, text_mask, object_vl_embeddings, box_mask, output_all_encoded_layers=False, output_text_and_object_separately=False) ########################################### outputs = {} # classifier #logits = self.final_mlp(pooled_output) ''' logits = self.linear(pooled_output) # vil_logit = logits.view(batch_size, num_options) score_OIM = self.OIM_loss(logits, label.view(-1)) loss_c = nn.CrossEntropyLoss(ignore_index=-1) cmpc_loss = loss_c(F.softmax(score_OIM, dim=1)*10, label.view(-1))# + criterion(text_logits, label_text) # cmpc_loss = loss_c(logits, label.view(-1)) cls_pred = torch.argmax(score_OIM, dim=1) cls_precision = torch.mean((cls_pred[label.view(-1) != -1] == label.view(-1)[label.view(-1) != -1]).float()) return cls_precision, cmpc_loss ''' # loss logits = self.final_mlp(pooled_output) vil_logit = logits.view(batch_size, num_options) loss = nn.CrossEntropyLoss(ignore_index=-1) cls_loss = loss(vil_logit, torch.zeros(batch_size).long().cuda()) _, preds = torch.max(vil_logit, 1) batch_score = float((preds == torch.zeros(batch_size).long().cuda() ).sum()) / float(batch_size) return batch_score, cls_loss def inference_forward(self, image, boxes, im_info, expression, label, feat=None): ########################################### # visual feature extraction batch_size = boxes.size(0) num_options = boxes.size(1) if feat is None: image = image.view(-1, image.size(2), image.size(3), image.size(4)) boxes = boxes.view(-1, boxes.size(2), boxes.size(3)) im_info = im_info.view(-1, im_info.size(2)) images = image box_mask = (boxes[:, :, 0] > -1.5) # max_len = int(box_mask.sum(1).max().item()) # origin_len = boxes.shape[1] # box_mask = box_mask[:, :max_len] # boxes = boxes[:, :max_len] obj_reps = self.image_feature_extractor(images=images, boxes=boxes, box_mask=box_mask, im_info=im_info, classes=None, segms=None) else: boxes = boxes.view(-1, boxes.size(2), boxes.size(3)) box_mask = (boxes[:, :, 0] > -1.5) # obj_reps = feat ############################################ # prepare text expression = expression.view(-1, expression.size(2)) cls_id, sep_id = self.tokenizer.convert_tokens_to_ids( ['[CLS]', '[SEP]']) text_input_ids = expression.new_zeros( (expression.shape[0], expression.shape[1] + 2)) text_input_ids[:, 0] = cls_id text_input_ids[:, 1:-1] = expression _sep_pos = (text_input_ids > 0).sum(1) _batch_inds = torch.arange(expression.shape[0], device=expression.device) text_input_ids[_batch_inds, _sep_pos] = sep_id text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape) text_mask = 
text_input_ids > 0
        if feat is None:
            text_visual_embeddings = obj_reps['obj_reps'][:, 0].unsqueeze(1).repeat(
                (1, text_input_ids.shape[1], 1))
            object_linguistic_embeddings = self.object_linguistic_embeddings(
                boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
            object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)
        else:
            text_visual_embeddings = feat[:, 0].unsqueeze(1).repeat(
                (1, text_input_ids.shape[1], 1))
            object_linguistic_embeddings = self.object_linguistic_embeddings(
                boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
            object_vl_embeddings = torch.cat((feat, object_linguistic_embeddings), -1)

        ###########################################
        # Visual Linguistic BERT
        encoded_layers, pooled_output, att = self.vlbert(
            text_input_ids, text_token_type_ids, text_visual_embeddings, text_mask,
            object_vl_embeddings, box_mask,
            output_all_encoded_layers=False,
            output_text_and_object_separately=False,
            output_attention_probs=True)

        ###########################################
        outputs = {}
        # classifier
        logits = self.final_mlp(pooled_output)
        vil_logit = logits.view(batch_size, num_options)
        _, preds = torch.max(vil_logit, 1)
        return att, logits

    def compute_cmpc_loss(self, image_embeddings, text_embeddings, labels):
        """
        Cross-Modal Projection Classification (CMPC) loss
        :param image_embeddings: Tensor with dtype torch.float32
        :param text_embeddings: Tensor with dtype torch.float32
        :param labels: Tensor with dtype torch.int32
        :return:
        """
        criterion = nn.CrossEntropyLoss()
        # labels_onehot = one_hot_coding(labels, self.num_classes).float()
        # image_norm = image_embeddings / image_embeddings.norm(dim=1, keepdim=True)
        # text_norm = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
        # image_proj_text = torch.sum(image_embeddings * text_norm, dim=1, keepdim=True) * text_norm
        # text_proj_image = torch.sum(text_embeddings * image_norm, dim=1, keepdim=True) * image_norm
        image_logits = image_embeddings  # self.line_cls(image_embeddings)
        text_logits = text_embeddings  # self.line_cls(text_embeddings)
        label_img = labels[:, 1, :].contiguous().view(-1)
        label_text = labels[:, 0, :].contiguous().view(-1)
        cmpc_loss = criterion(image_logits, label_img)  # + criterion(text_logits, label_text)
        # classification accuracy, for observation only
        image_pred = torch.argmax(image_logits, dim=1)
        text_pred = torch.argmax(text_logits, dim=1)
        image_precision = torch.mean((image_pred == label_img).float())
        text_precision = torch.mean((text_pred == label_text).float())
        return cmpc_loss, image_precision, text_precision
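# ----------------------------------------------------------------------------
# NOTE: the commented-out lines in compute_cmpc_loss sketch the projection
# form of CMPC: classify each embedding's projection onto the normalized
# embedding of the other modality. Written out under the assumption of a
# shared classification layer (a plain nn.Linear standing in for the
# commented `self.line_cls`), it would look like the sketch below.
# ----------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

def cmpc_loss_with_projection(image_emb, text_emb, labels, classifier):
    image_norm = F.normalize(image_emb, dim=1)
    text_norm = F.normalize(text_emb, dim=1)
    # project each modality onto the unit vector of the other
    image_proj_text = (image_emb * text_norm).sum(1, keepdim=True) * text_norm
    text_proj_image = (text_emb * image_norm).sum(1, keepdim=True) * image_norm
    criterion = nn.CrossEntropyLoss()
    return criterion(classifier(image_proj_text), labels) + \
           criterion(classifier(text_proj_image), labels)

cls_layer = nn.Linear(768, 1000)  # hypothetical vocabulary of 1000 identities
loss = cmpc_loss_with_projection(torch.randn(4, 768), torch.randn(4, 768),
                                 torch.randint(0, 1000, (4,)), cls_layer)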
class ResNetVLBERT(Module): def __init__(self, config): super(ResNetVLBERT, self).__init__(config) self.image_feature_extractor = FastRCNN( config, average_pool=True, final_dim=config.NETWORK.IMAGE_FINAL_DIM, enable_cnn_reg_loss=False) self.object_linguistic_embeddings = nn.Embedding( 1, config.NETWORK.VLBERT.hidden_size) self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN self.tokenizer = BertTokenizer.from_pretrained( config.NETWORK.BERT_MODEL_NAME) language_pretrained_model_path = None if config.NETWORK.BERT_PRETRAINED != '': language_pretrained_model_path = '{}-{:04d}.model'.format( config.NETWORK.BERT_PRETRAINED, config.NETWORK.BERT_PRETRAINED_EPOCH) elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME): weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME, BERT_WEIGHTS_NAME) if os.path.isfile(weight_path): language_pretrained_model_path = weight_path self.language_pretrained_model_path = language_pretrained_model_path if language_pretrained_model_path is None: print( "Warning: no pretrained language model found, training from scratch!!!" ) self.vlbert = VisualLinguisticBert( config.NETWORK.VLBERT, language_pretrained_model_path=language_pretrained_model_path) transform = VisualLinguisticBertMVRCHeadTransform( config.NETWORK.VLBERT) linear = nn.Linear(config.NETWORK.VLBERT.hidden_size, 1) self.final_mlp = nn.Sequential( transform, nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False), linear) # init weights self.init_weight() self.fix_params() def init_weight(self): self.image_feature_extractor.init_weight() if self.object_linguistic_embeddings is not None: self.object_linguistic_embeddings.weight.data.normal_( mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range) for m in self.final_mlp.modules(): if isinstance(m, torch.nn.Linear): torch.nn.init.xavier_uniform_(m.weight) torch.nn.init.constant_(m.bias, 0) def train(self, mode=True): super(ResNetVLBERT, self).train(mode) # turn some frozen layers to eval mode if self.image_feature_bn_eval: self.image_feature_extractor.bn_eval() def fix_params(self): pass def train_forward( self, image, boxes, im_info, expression, label, ): ########################################### # visual feature extraction images = image box_mask = (boxes[:, :, 0] > -1.5) max_len = int(box_mask.sum(1).max().item()) origin_len = boxes.shape[1] box_mask = box_mask[:, :max_len] boxes = boxes[:, :max_len] label = label[:, :max_len] obj_reps = self.image_feature_extractor(images=images, boxes=boxes, box_mask=box_mask, im_info=im_info, classes=None, segms=None) ############################################ # prepare text cls_id, sep_id = self.tokenizer.convert_tokens_to_ids( ['[CLS]', '[SEP]']) text_input_ids = expression.new_zeros( (expression.shape[0], expression.shape[1] + 2)) text_input_ids[:, 0] = cls_id text_input_ids[:, 1:-1] = expression _sep_pos = (text_input_ids > 0).sum(1) _batch_inds = torch.arange(expression.shape[0], device=expression.device) text_input_ids[_batch_inds, _sep_pos] = sep_id text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape) text_mask = text_input_ids > 0 text_visual_embeddings = obj_reps['obj_reps'][:, 0].unsqueeze( 1).repeat((1, text_input_ids.shape[1], 1)) object_linguistic_embeddings = self.object_linguistic_embeddings( boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long()) object_vl_embeddings = torch.cat( (obj_reps['obj_reps'], object_linguistic_embeddings), -1) ########################################### # Visual Linguistic BERT hidden_states_text, hidden_states_regions, _ = self.vlbert( 
text_input_ids, text_token_type_ids, text_visual_embeddings, text_mask, object_vl_embeddings, box_mask, output_all_encoded_layers=False, output_text_and_object_separately=True) ########################################### outputs = {} # classifier logits = self.final_mlp(hidden_states_regions).squeeze(-1) # loss cls_loss = F.binary_cross_entropy_with_logits(logits[box_mask], label[box_mask]) # pad back to origin len for compatibility with DataParallel logits_ = logits.new_zeros( (logits.shape[0], origin_len)).fill_(-10000.0) logits_[:, :logits.shape[1]] = logits logits = logits_ label_ = label.new_zeros((logits.shape[0], origin_len)).fill_(-1) label_[:, :label.shape[1]] = label label = label_ outputs.update({ 'label_logits': logits, 'label': label, 'cls_loss': cls_loss }) loss = cls_loss.mean() return outputs, loss def inference_forward(self, image, boxes, im_info, expression): ########################################### # visual feature extraction images = image box_mask = (boxes[:, :, 0] > -1.5) max_len = int(box_mask.sum(1).max().item()) origin_len = boxes.shape[1] box_mask = box_mask[:, :max_len] boxes = boxes[:, :max_len] obj_reps = self.image_feature_extractor(images=images, boxes=boxes, box_mask=box_mask, im_info=im_info, classes=None, segms=None) ############################################ # prepare text cls_id, sep_id = self.tokenizer.convert_tokens_to_ids( ['[CLS]', '[SEP]']) text_input_ids = expression.new_zeros( (expression.shape[0], expression.shape[1] + 2)) text_input_ids[:, 0] = cls_id text_input_ids[:, 1:-1] = expression _sep_pos = (text_input_ids > 0).sum(1) _batch_inds = torch.arange(expression.shape[0], device=expression.device) text_input_ids[_batch_inds, _sep_pos] = sep_id text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape) text_mask = text_input_ids > 0 text_visual_embeddings = obj_reps['obj_reps'][:, 0].unsqueeze( 1).repeat((1, text_input_ids.shape[1], 1)) object_linguistic_embeddings = self.object_linguistic_embeddings( boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long()) object_vl_embeddings = torch.cat( (obj_reps['obj_reps'], object_linguistic_embeddings), -1) ########################################### # Visual Linguistic BERT hidden_states_text, hidden_states_regions, _ = self.vlbert( text_input_ids, text_token_type_ids, text_visual_embeddings, text_mask, object_vl_embeddings, box_mask, output_all_encoded_layers=False, output_text_and_object_separately=True) ########################################### outputs = {} # classifier logits = self.final_mlp(hidden_states_regions).squeeze(-1) # pad back to origin len for compatibility with DataParallel logits_ = logits.new_zeros( (logits.shape[0], origin_len)).fill_(-10000.0) logits_[:, :logits.shape[1]] = logits logits = logits_ w_ratio = im_info[:, 2] h_ratio = im_info[:, 3] pred_boxes = boxes[_batch_inds, logits.argmax(1), :4] pred_boxes[:, [0, 2]] /= w_ratio.unsqueeze(1) pred_boxes[:, [1, 3]] /= h_ratio.unsqueeze(1) outputs.update({'label_logits': logits, 'pred_boxes': pred_boxes}) return outputs
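# ----------------------------------------------------------------------------
# NOTE: both forward passes above pad per-box logits back to `origin_len` with
# a large negative fill so torch.nn.DataParallel can concatenate outputs from
# replicas that trimmed their boxes to different max_len values. The pattern
# in isolation (`pad_logits` is a hypothetical name; new_full replaces the
# new_zeros().fill_() idiom with identical behavior):
# ----------------------------------------------------------------------------
import torch

def pad_logits(logits, origin_len, fill=-10000.0):
    padded = logits.new_full((logits.shape[0], origin_len), fill)
    padded[:, :logits.shape[1]] = logits
    return padded

print(pad_logits(torch.randn(2, 3), 5))  # columns 3..4 hold -10000.0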
# Variant of ResNetVLBERT with a frozen backbone and three task-specific heads:
# foil classification (task 1), error localization (task 2) and error
# correction (task 3).
class ResNetVLBERT(Module):
    def __init__(self, config):
        super(ResNetVLBERT, self).__init__(config)

        self.image_feature_extractor = FastRCNN(config,
                                                average_pool=True,
                                                final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                                                enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED, config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print("Warning: no pretrained language model found, training from scratch!!!")

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path)

        self.task1_head = Task1Head(config.NETWORK.VLBERT)
        self.task2_head = Task2Head(config.NETWORK.VLBERT)
        self.task3_head = Task3Head(config.NETWORK.VLBERT)

        # init weights
        self.init_weight()
        self.fix_params()

    def init_weight(self):
        self.image_feature_extractor.init_weight()
        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        for head in (self.task1_head, self.task2_head, self.task3_head):
            for m in head.modules():
                if isinstance(m, torch.nn.Linear):
                    torch.nn.init.xavier_uniform_(m.weight)
                    torch.nn.init.constant_(m.bias, 0)

    def train(self, mode=True):
        super(ResNetVLBERT, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        # freeze the visual backbone, the VL-BERT encoder and the object
        # linguistic embeddings; only the task heads remain trainable
        for param in self.image_feature_extractor.parameters():
            param.requires_grad = False
        for param in self.vlbert.parameters():
            param.requires_grad = False
        for param in self.object_linguistic_embeddings.parameters():
            param.requires_grad = False

    def train_forward(self, image, boxes, im_info, expression, label, pos, target, mask):
        ###########################################
        # keep the frozen modules in eval mode even while training the heads
        if self.vlbert.training:
            self.vlbert.eval()
        if self.image_feature_extractor.training:
            self.image_feature_extractor.eval()

        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        max_len = int(box_mask.sum(1).max().item())
        origin_len = boxes.shape[1]
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        # our labels for foil are binary and one-dimensional
        # label = label[:, :max_len]
        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None)

        ############################################
        # prepare text: [CLS] expression [SEP]
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]'])
        text_input_ids = expression.new_zeros((expression.shape[0], expression.shape[1] + 2))
        text_input_ids[:, 0] = cls_id
        text_input_ids[:, 1:-1] = expression
        _sep_pos = (text_input_ids > 0).sum(1)
        _batch_inds = torch.arange(expression.shape[0], device=expression.device)
        text_input_ids[_batch_inds, _sep_pos] = sep_id
        text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
        text_mask = text_input_ids > 0
        # every text token is paired with the whole-image feature (box 0)
        text_visual_embeddings = obj_reps['obj_reps'][:, 0].unsqueeze(1).repeat(
            (1, text_input_ids.shape[1], 1))

        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################
        # Visual Linguistic BERT, first pass: classification and localization
        sequence_reps, object_reps, _ = self.vlbert(text_input_ids,
                                                    text_token_type_ids,
                                                    text_visual_embeddings,
                                                    text_mask,
                                                    object_vl_embeddings,
                                                    box_mask,
                                                    output_text_and_object_separately=True,
                                                    output_all_encoded_layers=False)
        cls_rep = sequence_reps[:, 0, :]
        sequence_reps = sequence_reps[:, 1:, :]
        cls_log_probs = self.task1_head(cls_rep)
        pos_log_probs = self.task2_head(sequence_reps, text_mask[:, 1:])

        # second pass: replace the foil positions with [MASK] and re-encode
        error_cor_mask = torch.zeros_like(text_input_ids)
        mask_len = mask.shape[1]
        error_cor_mask[:, 1:mask_len + 1] = mask
        error_cor_mask = error_cor_mask.bool()
        text_input_ids[error_cor_mask] = self.tokenizer.convert_tokens_to_ids(['[MASK]'])[0]
        sequence_reps, _, _ = self.vlbert(text_input_ids,
                                          text_token_type_ids,
                                          text_visual_embeddings,
                                          text_mask,
                                          object_vl_embeddings,
                                          box_mask,
                                          output_text_and_object_separately=True,
                                          output_all_encoded_layers=False)
        sequence_reps = sequence_reps[:, 1:, :]
        # gather the hidden state at the error position for each example
        select_index = pos.view(-1, 1).unsqueeze(2).repeat(1, 1, sequence_reps.shape[-1])
        select_index[select_index < 0] = 0
        masked_reps = torch.gather(sequence_reps, 1, select_index).squeeze(1)
        cor_log_probs = self.task3_head(masked_reps)

        # localization and correction losses only apply to foiled examples
        loss_mask = label.view(-1).float()
        cls_loss = F.binary_cross_entropy(cls_log_probs.view(-1),
                                          label.view(-1).float(),
                                          reduction="none")
        pos_loss = F.nll_loss(pos_log_probs, pos, ignore_index=-1,
                              reduction="none").view(-1) * loss_mask
        cor_loss = F.nll_loss(cor_log_probs.view(-1, cor_log_probs.shape[-1]),
                              target, ignore_index=0,
                              reduction="none").view(-1) * loss_mask
        loss = cls_loss.mean() + pos_loss.mean() + cor_loss.mean()

        outputs = {
            "cls_logits": cls_log_probs,
            "pos_logits": pos_log_probs,
            "cor_logits": cor_log_probs,
            "cls_label": label,
            "pos_label": pos,
            "cor_label": target
        }
        return outputs, loss

    def inference_forward(self, image, boxes, im_info, expression, label, pos, target, mask):
        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        max_len = int(box_mask.sum(1).max().item())
        origin_len = boxes.shape[1]
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        # our labels for foil are binary and one-dimensional
        # label = label[:, :max_len]
        with torch.no_grad():
            obj_reps = self.image_feature_extractor(images=images,
                                                    boxes=boxes,
                                                    box_mask=box_mask,
                                                    im_info=im_info,
                                                    classes=None,
                                                    segms=None)

            ############################################
            # prepare text
            cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]'])
            text_input_ids = expression.new_zeros((expression.shape[0], expression.shape[1] + 2))
            text_input_ids[:, 0] = cls_id
            text_input_ids[:, 1:-1] = expression
            _sep_pos = (text_input_ids > 0).sum(1)
            _batch_inds = torch.arange(expression.shape[0], device=expression.device)
            text_input_ids[_batch_inds, _sep_pos] = sep_id
            text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
            text_mask = text_input_ids > 0
            text_visual_embeddings = obj_reps['obj_reps'][:, 0].unsqueeze(1).repeat(
                (1, text_input_ids.shape[1], 1))

            object_linguistic_embeddings = self.object_linguistic_embeddings(
                boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
            object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)

            ###########################################
            # Visual Linguistic BERT, first pass
            sequence_reps, object_reps, _ = self.vlbert(text_input_ids,
                                                        text_token_type_ids,
                                                        text_visual_embeddings,
                                                        text_mask,
                                                        object_vl_embeddings,
                                                        box_mask,
                                                        output_text_and_object_separately=True,
                                                        output_all_encoded_layers=False)
        cls_rep = sequence_reps[:, 0, :]
        sequence_reps = sequence_reps[:, 1:, :]
        cls_log_probs = self.task1_head(cls_rep)
        pos_log_probs = self.task2_head(sequence_reps, text_mask[:, 1:])

        # second pass: mask the foil positions and re-encode
        error_cor_mask = torch.zeros_like(text_input_ids)
        mask_len = mask.shape[1]
        error_cor_mask[:, 1:mask_len + 1] = mask
        error_cor_mask = error_cor_mask.bool()
        text_input_ids[error_cor_mask] = self.tokenizer.convert_tokens_to_ids(['[MASK]'])[0]
        with torch.no_grad():
            sequence_reps, _, _ = self.vlbert(text_input_ids,
                                              text_token_type_ids,
                                              text_visual_embeddings,
                                              text_mask,
                                              object_vl_embeddings,
                                              box_mask,
                                              output_text_and_object_separately=True,
                                              output_all_encoded_layers=False)
        sequence_reps = sequence_reps[:, 1:, :]
        select_index = pos.view(-1, 1).unsqueeze(2).repeat(1, 1, sequence_reps.shape[-1])
        select_index[select_index < 0] = 0
        masked_reps = torch.gather(sequence_reps, 1, select_index).squeeze(1)
        cor_log_probs = self.task3_head(masked_reps)

        loss_mask = label.view(-1).float()
        cls_loss = F.binary_cross_entropy(cls_log_probs.view(-1),
                                          label.view(-1).float(),
                                          reduction="none")
        pos_loss = F.nll_loss(pos_log_probs, pos, ignore_index=-1,
                              reduction="none").view(-1) * loss_mask
        cor_loss = F.nll_loss(cor_log_probs.view(-1, cor_log_probs.shape[-1]),
                              target, ignore_index=0,
                              reduction="none").view(-1) * loss_mask
        loss = cls_loss.mean() + pos_loss.mean() + cor_loss.mean()

        outputs = {
            "cls_logits": cls_log_probs,
            "pos_logits": pos_log_probs,
            "cor_logits": cor_log_probs,
            "cls_label": label,
            "pos_label": pos,
            "cor_label": target
        }
        return outputs, loss
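# ---------------------------------------------------------------------------
# Task1Head, Task2Head and Task3Head are used above but are not defined in this
# excerpt. The classes below are a minimal sketch inferred only from the call
# sites: Task1Head maps the [CLS] vector to a sigmoid probability (consumed by
# F.binary_cross_entropy), Task2Head maps token representations to masked
# log-probabilities over positions (consumed by F.nll_loss with the position
# index), and Task3Head maps the gathered hidden state to vocabulary
# log-probabilities. The single-linear architecture and the `vocab_size`
# attribute on the VL-BERT config are assumptions, not the original heads.
# ---------------------------------------------------------------------------
class Task1Head(nn.Module):
    def __init__(self, config):
        super(Task1Head, self).__init__()
        self.linear = nn.Linear(config.hidden_size, 1)

    def forward(self, cls_rep):
        # (batch, hidden) -> (batch,) probability that the expression is foiled
        return torch.sigmoid(self.linear(cls_rep)).squeeze(-1)


class Task2Head(nn.Module):
    def __init__(self, config):
        super(Task2Head, self).__init__()
        self.linear = nn.Linear(config.hidden_size, 1)

    def forward(self, sequence_reps, text_mask):
        # (batch, length, hidden) -> (batch, length) log-probs of the error
        # position; padding positions are pushed to -inf before the softmax
        scores = self.linear(sequence_reps).squeeze(-1)
        scores = scores.masked_fill(~text_mask, float('-inf'))
        return F.log_softmax(scores, dim=1)


class Task3Head(nn.Module):
    def __init__(self, config):
        super(Task3Head, self).__init__()
        self.linear = nn.Linear(config.hidden_size, config.vocab_size)  # assumed attribute

    def forward(self, masked_reps):
        # (batch, hidden) -> (batch, vocab) log-probs of the correct word
        return F.log_softmax(self.linear(masked_reps), dim=-1)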
# Variant of ResNetVLBERT for 3-way sentence classification over
# (image, hypothesis) pairs, with an optional "blind" (text-only) mode.
class ResNetVLBERT(Module):
    def __init__(self, config):
        super(ResNetVLBERT, self).__init__(config)
        self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS
        self.cnn_loss_top = config.NETWORK.CNN_LOSS_TOP
        if not config.NETWORK.BLIND:
            self.image_feature_extractor = FastRCNN(
                config,
                average_pool=True,
                final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                enable_cnn_reg_loss=(self.enable_cnn_reg_loss and not self.cnn_loss_top))
            if config.NETWORK.VLBERT.object_word_embed_mode == 1:
                self.object_linguistic_embeddings = nn.Embedding(81, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 2:
                self.object_linguistic_embeddings = nn.Embedding(1, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN

        if 'roberta' in config.NETWORK.BERT_MODEL_NAME:
            self.tokenizer = RobertaTokenizer.from_pretrained(config.NETWORK.BERT_MODEL_NAME)
        else:
            self.tokenizer = BertTokenizer.from_pretrained(config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED, config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        if language_pretrained_model_path is None:
            print("Warning: no pretrained language model found, training from scratch!!!")

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path)

        self.for_pretrain = False
        dim = config.NETWORK.VLBERT.hidden_size
        if config.NETWORK.SENTENCE.CLASSIFIER_TYPE == "2fc":
            self.sentence_cls = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.SENTENCE.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(dim, config.NETWORK.SENTENCE.CLASSIFIER_HIDDEN_SIZE),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(config.NETWORK.SENTENCE.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(config.NETWORK.SENTENCE.CLASSIFIER_HIDDEN_SIZE, 3),
            )
        elif config.NETWORK.SENTENCE.CLASSIFIER_TYPE == "1fc":
            self.sentence_cls = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.SENTENCE.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(dim, 3)
            )
        else:
            raise ValueError("Classifier type: {} not supported!".format(
                config.NETWORK.SENTENCE.CLASSIFIER_TYPE))

        # init weights
        self.init_weight()
        self.fix_params()

    def init_weight(self):
        if not self.config.NETWORK.BLIND:
            self.image_feature_extractor.init_weight()
            if self.object_linguistic_embeddings is not None:
                self.object_linguistic_embeddings.weight.data.normal_(mean=0.0, std=0.02)
        if not self.for_pretrain:
            for m in self.sentence_cls.modules():
                if isinstance(m, torch.nn.Linear):
                    torch.nn.init.xavier_uniform_(m.weight)
                    torch.nn.init.constant_(m.bias, 0)

    def train(self, mode=True):
        super(ResNetVLBERT, self).train(mode)
        # turn some frozen layers to eval mode
        if (not self.config.NETWORK.BLIND) and self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        if self.config.NETWORK.BLIND:
            self.vlbert._module.visual_scale_text.requires_grad = False
            self.vlbert._module.visual_scale_object.requires_grad = False

    def _collect_obj_reps(self, span_tags, object_reps):
        """
        Collect span-level object representations.
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return: [batch_size, ..leading_dims.., L, obj_dim]
        """
        span_tags_fixed = torch.clamp(span_tags, min=0)  # In case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0, row_id.shape[0], step=1, device=row_id.device)[:, None]

        # Add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        return object_reps[row_id.view(-1), span_tags_fixed.view(-1)].view(*span_tags_fixed.shape, -1)

    def prepare_text(self, sentence, mask):
        batch_size, max_len = sentence.shape
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]'])
        sep_pos = 1 + mask.sum(1)  # position right after the last real token
        input_ids = torch.zeros((batch_size, max_len + 2), dtype=sentence.dtype, device=sentence.device)
        input_ids[:, 0] = cls_id
        input_ids[:, 1:-1] = sentence  # copy tokens first so [SEP] is not overwritten by padding
        _batch_inds = torch.arange(batch_size, device=sentence.device)
        input_ids[_batch_inds, sep_pos] = sep_id
        input_mask = input_ids > 0
        return input_ids, input_mask

    def train_forward(self, images, boxes, hypothesis, im_info, label):
        ###########################################
        # visual feature extraction

        # Don't know what segments are for
        # segms = masks

        box_mask = (boxes[:, :, -1] > -0.5)
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len].type(torch.float32)
        # segms = segms[:, :max_len]

        if self.config.NETWORK.BLIND:
            obj_reps = {'obj_reps': boxes.new_zeros((*boxes.shape[:-1], self.config.NETWORK.IMAGE_FINAL_DIM))}
        else:
            obj_reps = self.image_feature_extractor(images=images,
                                                    boxes=boxes,
                                                    box_mask=box_mask,
                                                    im_info=im_info,
                                                    classes=None,
                                                    segms=None)

        # For now no tags
        mask = (hypothesis > 0.5)
        sentence_label = label.view(-1)

        ############################################
        # prepare text
        text_input_ids, text_mask = self.prepare_text(hypothesis, mask)
        text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)

        # Add visual feature to text elements (every token sees object 0, the whole image)
        text_visual_embeddings = self._collect_obj_reps(
            text_input_ids.new_zeros(text_input_ids.size()), obj_reps['obj_reps'])

        # Add textual feature to image elements
        if self.config.NETWORK.BLIND:
            object_linguistic_embeddings = boxes.new_zeros(
                (*boxes.shape[:-1], self.config.NETWORK.VLBERT.hidden_size))
        else:
            object_linguistic_embeddings = self.object_linguistic_embeddings(
                boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################
        # Visual Linguistic BERT
        if self.config.NETWORK.NO_OBJ_ATTENTION or self.config.NETWORK.BLIND:
            box_mask.zero_()
        _, pooled_rep = self.vlbert(text_input_ids,
                                    text_token_type_ids,
                                    text_visual_embeddings,
                                    text_mask,
                                    object_vl_embeddings,
                                    box_mask,
                                    output_all_encoded_layers=False,
                                    output_text_and_object_separately=False,
                                    output_attention_probs=False)

        ###########################################
        outputs = {}

        # sentence classification
        sentence_logits = self.sentence_cls(pooled_rep).view((-1, 3))
        sentence_cls_loss = F.cross_entropy(sentence_logits, sentence_label)
        outputs.update({'sentence_label_logits': sentence_logits,
                        'sentence_label': sentence_label.long(),
                        'sentence_cls_loss': sentence_cls_loss})
        loss = sentence_cls_loss.mean()
        return outputs, loss

    def inference_forward(self, images, boxes, hypothesis, im_info):
        ###########################################
        # visual feature extraction

        # Don't know what segments are for
        # segms = masks

        box_mask = (boxes[:, :, -1] > -0.5)
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len].type(torch.float32)
        # segms = segms[:, :max_len]

        if self.config.NETWORK.BLIND:
            obj_reps = {'obj_reps': boxes.new_zeros((*boxes.shape[:-1], self.config.NETWORK.IMAGE_FINAL_DIM))}
        else:
            obj_reps = self.image_feature_extractor(images=images,
                                                    boxes=boxes,
                                                    box_mask=box_mask,
                                                    im_info=im_info,
                                                    classes=None,
                                                    segms=None)

        # For now no tags
        mask = (hypothesis > 0.5)

        ############################################
        # prepare text
        text_input_ids, text_mask = self.prepare_text(hypothesis, mask)
        text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)

        # Add visual feature to text elements
        text_visual_embeddings = self._collect_obj_reps(
            text_input_ids.new_zeros(text_input_ids.size()), obj_reps['obj_reps'])

        # Add textual feature to image elements
        if self.config.NETWORK.BLIND:
            object_linguistic_embeddings = boxes.new_zeros(
                (*boxes.shape[:-1], self.config.NETWORK.VLBERT.hidden_size))
        else:
            object_linguistic_embeddings = self.object_linguistic_embeddings(
                boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################
        # Visual Linguistic BERT
        if self.config.NETWORK.NO_OBJ_ATTENTION or self.config.NETWORK.BLIND:
            box_mask.zero_()
        _, pooled_rep = self.vlbert(text_input_ids,
                                    text_token_type_ids,
                                    text_visual_embeddings,
                                    text_mask,
                                    object_vl_embeddings,
                                    box_mask,
                                    output_all_encoded_layers=False,
                                    output_text_and_object_separately=False,
                                    output_attention_probs=False)

        ###########################################
        outputs = {}

        # sentence classification
        sentence_logits = self.sentence_cls(pooled_rep).view((-1, 3))
        outputs.update({'sentence_label_logits': sentence_logits})
        return outputs
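# ---------------------------------------------------------------------------
# The indexing in _collect_obj_reps is easy to misread, so here is a small
# standalone check (toy shapes, not part of the model). It calls the method
# unbound with self=None, which works because the method never touches self:
# each span tag selects an object feature from its own batch row, and masked
# tags (-1) are clamped to object 0.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    demo_span_tags = torch.tensor([[0, 2, -1],
                                   [1, 0, 3]])    # (batch=2, L=3); -1 marks a masked slot
    demo_object_reps = torch.randn(2, 4, 8)       # (batch, num_objs=4, obj_dim=8)
    demo_out = ResNetVLBERT._collect_obj_reps(None, demo_span_tags, demo_object_reps)
    assert demo_out.shape == (2, 3, 8)
    assert torch.equal(demo_out[0, 1], demo_object_reps[0, 2])  # tag 2 -> object 2 of row 0
    assert torch.equal(demo_out[0, 2], demo_object_reps[0, 0])  # -1 clamped to object 0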