def forward(self, text, relationship_label, mlm_labels):
    """Text-only forward pass: every visual input is blanked out so the
    model runs on language alone (translation-retrieval setting).

    Args:
        text: (batch, seq_len) token ids; 0 is treated as padding.
        relationship_label: targets for the (optional) relationship head.
        mlm_labels: (batch, seq_len) MLM targets, -1 = ignore.

    Returns:
        (outputs, loss) where loss is the mean MLM loss.
    """
    ###########################################
    # Blank out visual feature extraction
    ############################################

    # prepare text
    text_input_ids = text
    # creates a text_tags tensor of the same shape as text tensor
    text_tags = text.new_zeros(text.shape)

    # ***** FM edit: blank out visual embeddings for translation retrieval task
    # NOTE(review): 768 / 1536 are hard-coded hidden sizes assumed to match
    # the VL-BERT config — confirm against NETWORK settings.
    text_visual_embeddings = text_input_ids.new_zeros(
        (text_input_ids.shape[0], text_input_ids.shape[1], 768),
        dtype=torch.float)
    # text_visual_embeddings[:] = self.aux_text_visual_embedding.weight[0]

    # ****** FM edit: blank visual embeddings (use known dimensions)
    object_vl_embeddings = text_input_ids.new_zeros(
        (text_input_ids.shape[0], 1, 1536), dtype=torch.float)

    max_text_len = text_input_ids.shape[1]
    text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
    text_mask = (text_input_ids > 0)
    # FM edit: set to zero to ignore vision
    box_mask = text_input_ids.new_zeros((text_input_ids.shape[0], 1),
                                        dtype=torch.uint8)

    ###########################################
    # Visual Linguistic BERT
    relationship_logits, mlm_logits, mvrc_logits = self.vlbert(
        text_input_ids, text_token_type_ids, text_visual_embeddings,
        text_mask, object_vl_embeddings, box_mask)
    ###########################################
    outputs = {}

    # FIX: default the losses to zero scalars so that disabled config flags
    # cannot leave names undefined — `mlm_loss` is read unconditionally in
    # `outputs` and in the final `loss` below (the original code raised
    # NameError with WITH_MLM_LOSS off).
    relationship_loss = text_visual_embeddings.new_zeros(())
    mlm_loss = text_visual_embeddings.new_zeros(())

    if self.config.NETWORK.WITH_REL_LOSS:
        # NOTE(review): relationship_loss is computed but deliberately NOT
        # added to the returned loss (original behaviour preserved).
        relationship_loss = F.cross_entropy(relationship_logits,
                                            relationship_label)
    if self.config.NETWORK.WITH_MLM_LOSS:
        # Pad logits along the sequence dim up to the label length; the
        # large negative fill gives padded positions ~zero probability.
        mlm_logits_padded = mlm_logits.new_zeros(
            (*mlm_labels.shape, mlm_logits.shape[-1])).fill_(-10000.0)
        mlm_logits_padded[:, :mlm_logits.shape[1]] = mlm_logits
        mlm_logits = mlm_logits_padded
        if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
            # Per-sentence mean over masked tokens, then mean over
            # sentences that contain at least one masked token.
            mlm_loss = F.cross_entropy(mlm_logits.transpose(1, 2),
                                       mlm_labels,
                                       ignore_index=-1,
                                       reduction='none')
            num_mlm = (mlm_labels != -1).sum(
                1, keepdim=True).to(dtype=mlm_loss.dtype)
            num_has_mlm = (num_mlm != 0).sum().to(dtype=mlm_loss.dtype)
            mlm_loss = (mlm_loss / (num_mlm + 1e-4)).sum() / (num_has_mlm +
                                                              1e-4)
        else:
            mlm_loss = F.cross_entropy(mlm_logits.view(
                (-1, mlm_logits.shape[-1])),
                mlm_labels.view(-1),
                ignore_index=-1)
    if self.config.NETWORK.WITH_MVRC_LOSS:
        # FIX: the original branch referenced `mvrc_labels` and
        # `origin_len`, neither of which exists in this text-only forward,
        # and would always crash with NameError. Fail explicitly instead.
        raise NotImplementedError(
            'MVRC loss is not supported in the text-only forward pass')

    outputs.update({
        'relationship_logits':
            relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None,
        'relationship_label':
            relationship_label if self.config.NETWORK.WITH_REL_LOSS else None,
        'mlm_logits':
            mlm_logits if self.config.NETWORK.WITH_MLM_LOSS else None,
        'mlm_label':
            mlm_labels if self.config.NETWORK.WITH_MLM_LOSS else None,
        # MVRC is unsupported in text-only mode (see raise above), so these
        # are always None — same values callers observed before.
        'mvrc_logits': None,
        'mvrc_label': None,
        'mlm_loss': mlm_loss,
    })

    loss = mlm_loss.mean()

    return outputs, loss
def forward(self, image, boxes, im_info, text, relationship_label,
            mlm_labels, mvrc_ops, mvrc_labels, *aux):
    """Pre-training forward over an image-text batch plus auxiliary
    text-only batches.

    Args:
        image, boxes, im_info: visual inputs; with IMAGE_FEAT_PRECOMPUTED,
            boxes[:, :, 4:] carries the precomputed region features.
        text, relationship_label, mlm_labels: grounded text batch.
        mvrc_ops, mvrc_labels: masked-region ops (1 = masked) and
            soft-label targets.
        *aux: flat sequence of (text, mlm_labels) tensor pairs from extra
            text-only dataloaders.

    Returns:
        (outputs, loss) — loss sums the enabled relationship, grounded
        ("wvc") MLM, auxiliary MLM and MVRC losses.
    """
    # concat aux texts from different dataset
    assert len(aux) > 0 and len(aux) % 2 == 0
    aux_text_list = aux[0::2]
    aux_text_mlm_labels_list = aux[1::2]
    num_aux_text = sum([_text.shape[0] for _text in aux_text_list])
    max_aux_text_len = max([_text.shape[1] for _text in aux_text_list])
    # Right-pad all aux batches into one tensor (pad id 0; label -1 is
    # ignored by the MLM loss).
    aux_text = aux_text_list[0].new_zeros((num_aux_text, max_aux_text_len))
    aux_text_mlm_labels = aux_text_mlm_labels_list[0].new_zeros(
        (num_aux_text, max_aux_text_len)).fill_(-1)
    _cur = 0
    for _text, _mlm_labels in zip(aux_text_list, aux_text_mlm_labels_list):
        _num = _text.shape[0]
        aux_text[_cur:(_cur + _num), :_text.shape[1]] = _text
        aux_text_mlm_labels[_cur:(_cur + _num), :_text.shape[1]] = _mlm_labels
        _cur += _num

    ###########################################
    # visual feature extraction
    images = image
    box_mask = (boxes[:, :, 0] > -1.5)  # padded box rows have coords < -1.5
    origin_len = boxes.shape[1]
    max_len = int(box_mask.sum(1).max().item())
    # Trim box padding down to the longest sequence in the batch.
    box_mask = box_mask[:, :max_len]
    boxes = boxes[:, :max_len]
    mvrc_ops = mvrc_ops[:, :max_len]
    mvrc_labels = mvrc_labels[:, :max_len]

    if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
        # Overwrite masked regions (mvrc_ops == 1) with the learned mask
        # feature; precomputed features live in boxes[:, :, 4:].
        box_features = boxes[:, :, 4:]
        box_features[mvrc_ops == 1] = self.object_mask_visual_embedding.weight[0]
        boxes[:, :, 4:] = box_features

    obj_reps = self.image_feature_extractor(
        images=images,
        boxes=boxes,
        box_mask=box_mask,
        im_info=im_info,
        classes=None,
        segms=None,
        mvrc_ops=mvrc_ops,
        mask_visual_embed=self.object_mask_visual_embedding.weight[0]
        if (not self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED)
        and (not self.config.NETWORK.MASK_RAW_PIXELS) else None)

    ############################################
    # prepare text
    text_input_ids = text
    text_tags = text.new_zeros(text.shape)
    text_visual_embeddings = self._collect_obj_reps(text_tags,
                                                    obj_reps['obj_reps'])

    object_linguistic_embeddings = self.object_linguistic_embeddings(
        boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
    if self.config.NETWORK.WITH_MVRC_LOSS:
        object_linguistic_embeddings[mvrc_ops == 1] = \
            self.object_mask_word_embedding.weight[0]
    object_vl_embeddings = torch.cat(
        (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

    # add auxiliary text: aux rows are appended behind the grounded batch,
    # get the learned placeholder visual embedding and an all-zero box mask
    max_text_len = max(text_input_ids.shape[1], aux_text.shape[1])
    text_input_ids_multi = text_input_ids.new_zeros(
        (text_input_ids.shape[0] + aux_text.shape[0], max_text_len))
    text_input_ids_multi[:text_input_ids.shape[0], :text_input_ids.shape[1]] = text_input_ids
    text_input_ids_multi[text_input_ids.shape[0]:, :aux_text.shape[1]] = aux_text
    text_token_type_ids_multi = text_input_ids_multi.new_zeros(
        text_input_ids_multi.shape)
    text_mask_multi = (text_input_ids_multi > 0)
    text_visual_embeddings_multi = text_visual_embeddings.new_zeros(
        (text_input_ids.shape[0] + aux_text.shape[0], max_text_len,
         text_visual_embeddings.shape[-1]))
    text_visual_embeddings_multi[:text_visual_embeddings.shape[0], :text_visual_embeddings.shape[1]] \
        = text_visual_embeddings
    text_visual_embeddings_multi[text_visual_embeddings.shape[0]:] = \
        self.aux_text_visual_embedding.weight[0]
    object_vl_embeddings_multi = object_vl_embeddings.new_zeros(
        (text_input_ids.shape[0] + aux_text.shape[0],
         *object_vl_embeddings.shape[1:]))
    object_vl_embeddings_multi[:object_vl_embeddings.shape[0]] = object_vl_embeddings
    box_mask_multi = box_mask.new_zeros(
        (text_input_ids.shape[0] + aux_text.shape[0], *box_mask.shape[1:]))
    box_mask_multi[:box_mask.shape[0]] = box_mask

    ###########################################
    # Visual Linguistic BERT
    relationship_logits_multi, mlm_logits_multi, mvrc_logits_multi = self.vlbert(
        text_input_ids_multi,
        text_token_type_ids_multi,
        text_visual_embeddings_multi,
        text_mask_multi,
        object_vl_embeddings_multi,
        box_mask_multi)

    ###########################################
    outputs = {}

    # loss defaults: zero scalars on the right device/dtype.
    # FIX: mlm_loss_wvc / mlm_loss_aux are read unconditionally in the
    # outputs dict and the final loss sum, so they must exist even when
    # WITH_MLM_LOSS is disabled (original code raised NameError there; the
    # initialised-but-unused `mlm_loss` local has been dropped).
    relationship_loss = im_info.new_zeros(())
    mlm_loss_wvc = im_info.new_zeros(())
    mlm_loss_aux = im_info.new_zeros(())
    mvrc_loss = im_info.new_zeros(())

    if self.config.NETWORK.WITH_REL_LOSS:
        relationship_logits = relationship_logits_multi[:text_input_ids.shape[0]]
        relationship_loss = F.cross_entropy(relationship_logits,
                                            relationship_label)
    if self.config.NETWORK.WITH_MLM_LOSS:
        # Assemble labels matching the concatenated (grounded + aux) batch,
        # then pad logits to the label length with a large negative fill.
        mlm_labels_multi = mlm_labels.new_zeros(
            (text_input_ids.shape[0] + aux_text.shape[0],
             max_text_len)).fill_(-1)
        mlm_labels_multi[:text_input_ids.shape[0], :mlm_labels.shape[1]] = mlm_labels
        mlm_labels_multi[text_input_ids.shape[0]:, :aux_text_mlm_labels.shape[1]] = aux_text_mlm_labels
        mlm_logits_multi_padded = \
            mlm_logits_multi.new_zeros((*mlm_labels_multi.shape,
                                        mlm_logits_multi.shape[-1])).fill_(-10000.0)
        mlm_logits_multi_padded[:, :mlm_logits_multi.shape[1]] = mlm_logits_multi
        mlm_logits_multi = mlm_logits_multi_padded
        # Split back into grounded ("wvc") and auxiliary halves.
        mlm_logits_wvc = mlm_logits_multi_padded[:text_input_ids.shape[0]]
        mlm_labels_wvc = mlm_labels_multi[:text_input_ids.shape[0]]
        mlm_logits_aux = mlm_logits_multi_padded[text_input_ids.shape[0]:]
        mlm_labels_aux = mlm_labels_multi[text_input_ids.shape[0]:]
        if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
            # Per-sentence mean over masked tokens, then mean over
            # sentences containing at least one mask.
            mlm_loss_wvc = F.cross_entropy(mlm_logits_wvc.transpose(1, 2),
                                           mlm_labels_wvc,
                                           ignore_index=-1,
                                           reduction='none')
            num_mlm_wvc = (mlm_labels_wvc != -1).sum(
                1, keepdim=True).to(dtype=mlm_loss_wvc.dtype)
            num_has_mlm_wvc = (num_mlm_wvc != 0).sum().to(
                dtype=mlm_loss_wvc.dtype)
            mlm_loss_wvc = (mlm_loss_wvc / (num_mlm_wvc + 1e-4)).sum() / (
                num_has_mlm_wvc + 1e-4)
            mlm_loss_aux = F.cross_entropy(mlm_logits_aux.transpose(1, 2),
                                           mlm_labels_aux,
                                           ignore_index=-1,
                                           reduction='none')
            num_mlm_aux = (mlm_labels_aux != -1).sum(
                1, keepdim=True).to(dtype=mlm_loss_aux.dtype)
            num_has_mlm_aux = (num_mlm_aux != 0).sum().to(
                dtype=mlm_loss_aux.dtype)
            mlm_loss_aux = (mlm_loss_aux / (num_mlm_aux + 1e-4)).sum() / (
                num_has_mlm_aux + 1e-4)
        else:
            mlm_loss_wvc = F.cross_entropy(
                mlm_logits_wvc.view((-1, mlm_logits_multi_padded.shape[-1])),
                mlm_labels_wvc.view(-1),
                ignore_index=-1)
            mlm_loss_aux = F.cross_entropy(
                mlm_logits_aux.view((-1, mlm_logits_multi_padded.shape[-1])),
                mlm_labels_aux.view(-1),
                ignore_index=-1)
    if self.config.NETWORK.WITH_MVRC_LOSS:
        mvrc_logits = mvrc_logits_multi[:mvrc_labels.shape[0], :mvrc_labels.shape[1]]
        if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST:
            # Per-region soft cross-entropy, normalised per image over
            # regions whose soft label sums to ~1, then averaged over
            # images having any such region.
            mvrc_loss = soft_cross_entropy(
                mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]),
                reduction='none').view(mvrc_logits.shape[:-1])
            valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1
            mvrc_loss = (mvrc_loss / (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \
                .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4)
        else:
            mvrc_loss = soft_cross_entropy(
                mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]))
        # Pad logits/labels back to the untrimmed box count for downstream
        # consumers of the outputs dict.
        mvrc_logits_padded = mvrc_logits.new_zeros(
            (mvrc_logits.shape[0], origin_len,
             mvrc_logits.shape[2])).fill_(-10000.0)
        mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits
        mvrc_logits = mvrc_logits_padded
        mvrc_labels_padded = mvrc_labels.new_zeros(
            (mvrc_labels.shape[0], origin_len,
             mvrc_labels.shape[2])).fill_(0.0)
        mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels
        mvrc_labels = mvrc_labels_padded

    outputs.update({
        'relationship_logits':
            relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None,
        'relationship_label':
            relationship_label if self.config.NETWORK.WITH_REL_LOSS else None,
        'mlm_logits_wvc':
            mlm_logits_wvc if self.config.NETWORK.WITH_MLM_LOSS else None,
        'mlm_label_wvc':
            mlm_labels_wvc if self.config.NETWORK.WITH_MLM_LOSS else None,
        'mlm_logits_aux':
            mlm_logits_aux if self.config.NETWORK.WITH_MLM_LOSS else None,
        'mlm_label_aux':
            mlm_labels_aux if self.config.NETWORK.WITH_MLM_LOSS else None,
        'mvrc_logits':
            mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None,
        'mvrc_label':
            mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None,
        'relationship_loss': relationship_loss,
        'mlm_loss_wvc': mlm_loss_wvc,
        'mlm_loss_aux': mlm_loss_aux,
        'mvrc_loss': mvrc_loss,
    })

    loss = relationship_loss.mean() + mlm_loss_wvc.mean() \
        + mlm_loss_aux.mean() + mvrc_loss.mean()

    return outputs, loss
def forward(self, image, boxes, im_info, text, relationship_label,
            mlm_labels, mvrc_ops, mvrc_labels, word_de_ids):
    """Forward pass for the machine-language-translation (MLT) task.

    Visual embeddings are deliberately blanked out; the extra `MLT_logits`
    head predicts the German token ids `word_de_ids`. Only the MLT loss
    contributes to the optimised objective.

    Returns:
        (outputs, loss) with loss = mean MLT loss.
    """
    ###########################################
    # visual feature extraction
    images = image
    box_mask = (boxes[:, :, 0] > -1.5)  # padded box rows have coords < -1.5
    origin_len = boxes.shape[1]
    max_len = int(box_mask.sum(1).max().item())
    # Trim box padding down to the longest sequence in the batch.
    box_mask = box_mask[:, :max_len]
    boxes = boxes[:, :max_len]
    mvrc_ops = mvrc_ops[:, :max_len]
    mvrc_labels = mvrc_labels[:, :max_len]

    if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
        # Precomputed features live in boxes[:, :, 4:]; masked regions
        # (mvrc_ops == 1) are replaced by the learned mask feature.
        box_features = boxes[:, :, 4:]
        box_features[mvrc_ops == 1] = self.object_mask_visual_embedding.weight[0]
        boxes[:, :, 4:] = box_features

    obj_reps = self.image_feature_extractor(
        images=images,
        boxes=boxes,
        box_mask=box_mask,
        im_info=im_info,
        classes=None,
        segms=None,
        mvrc_ops=mvrc_ops,
        mask_visual_embed=self.object_mask_visual_embedding.weight[0]
        if (not self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED)
        and (not self.config.NETWORK.MASK_RAW_PIXELS) else None)

    ############################################
    # prepare text
    text_input_ids = text
    # creates a text_tags tensor of the same shape as text tensor
    text_tags = text.new_zeros(text.shape)
    text_visual_embeddings = self._collect_obj_reps(text_tags,
                                                    obj_reps['obj_reps'])
    # ***** FM edit: blank out visual embeddings for translation retrieval task
    text_visual_embeddings[:] = self.aux_text_visual_embedding.weight[0]

    object_linguistic_embeddings = self.object_linguistic_embeddings(
        boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
    if self.config.NETWORK.WITH_MVRC_LOSS:
        object_linguistic_embeddings[
            mvrc_ops == 1] = self.object_mask_word_embedding.weight[0]
    object_vl_embeddings = torch.cat(
        (obj_reps['obj_reps'], object_linguistic_embeddings), -1)
    # ****** FM edit: blank out all visual embeddings
    object_vl_embeddings = object_vl_embeddings.new_zeros(
        object_vl_embeddings.shape)

    # FM edit: no auxiliary text is used in this forward
    max_text_len = text_input_ids.shape[1]
    text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
    text_mask = (text_input_ids > 0)

    ###########################################
    # Visual Linguistic BERT (returns an extra MLT head here)
    relationship_logits_multi, mlm_logits_multi, mvrc_logits_multi, MLT_logits = self.vlbert(
        text_input_ids, text_token_type_ids, text_visual_embeddings,
        text_mask, object_vl_embeddings, box_mask)

    ###########################################
    outputs = {}

    # loss defaults: zero scalars on the right device/dtype
    relationship_loss = im_info.new_zeros(())
    mlm_loss = im_info.new_zeros(())
    mvrc_loss = im_info.new_zeros(())
    MLT_loss = im_info.new_zeros(())

    if self.config.NETWORK.WITH_REL_LOSS:
        relationship_logits = relationship_logits_multi[:text_input_ids.shape[0]]
        # FM edit - change cross_entropy to bce/sigmoid
        relationship_loss = F.binary_cross_entropy(
            torch.sigmoid(relationship_logits),
            relationship_label.unsqueeze(1))
    if self.config.NETWORK.WITH_MLM_LOSS:
        # FIX: the original code referenced `aux_text` /
        # `aux_text_mlm_labels`, which do not exist in this forward (there
        # is no auxiliary dataloader) and always raised NameError. With no
        # aux batch, the "multi" tensors are just the padded grounded batch
        # and the aux slices below are empty.
        mlm_labels_multi = mlm_labels.new_zeros(
            (text_input_ids.shape[0], max_text_len)).fill_(-1)
        mlm_labels_multi[:text_input_ids.shape[0], :mlm_labels.shape[1]] = mlm_labels
        mlm_logits_multi_padded = \
            mlm_logits_multi.new_zeros((*mlm_labels_multi.shape,
                                        mlm_logits_multi.shape[-1])).fill_(-10000.0)
        mlm_logits_multi_padded[:, :mlm_logits_multi.shape[1]] = mlm_logits_multi
        mlm_logits_multi = mlm_logits_multi_padded
        mlm_logits_wvc = mlm_logits_multi_padded[:text_input_ids.shape[0]]
        mlm_labels_wvc = mlm_labels_multi[:text_input_ids.shape[0]]
        mlm_logits_aux = mlm_logits_multi_padded[text_input_ids.shape[0]:]  # empty
        mlm_labels_aux = mlm_labels_multi[text_input_ids.shape[0]:]  # empty
        # No auxiliary sentences -> aux loss is identically zero.
        mlm_loss_aux = im_info.new_zeros(())
        if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
            # Per-sentence mean over masked tokens, then mean over
            # sentences containing at least one mask.
            mlm_loss_wvc = F.cross_entropy(mlm_logits_wvc.transpose(1, 2),
                                           mlm_labels_wvc,
                                           ignore_index=-1,
                                           reduction='none')
            num_mlm_wvc = (mlm_labels_wvc != -1).sum(
                1, keepdim=True).to(dtype=mlm_loss_wvc.dtype)
            num_has_mlm_wvc = (num_mlm_wvc != 0).sum().to(
                dtype=mlm_loss_wvc.dtype)
            mlm_loss_wvc = (mlm_loss_wvc / (num_mlm_wvc + 1e-4)).sum() / (
                num_has_mlm_wvc + 1e-4)
        else:
            mlm_loss_wvc = F.cross_entropy(
                mlm_logits_wvc.view((-1, mlm_logits_multi_padded.shape[-1])),
                mlm_labels_wvc.view(-1),
                ignore_index=-1)
    if self.config.NETWORK.WITH_MVRC_LOSS:
        mvrc_logits = mvrc_logits_multi[:mvrc_labels.shape[0], :mvrc_labels.shape[1]]
        if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST:
            # Per-region soft cross-entropy, normalised per image over
            # regions whose soft label sums to ~1.
            mvrc_loss = soft_cross_entropy(
                mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]),
                reduction='none').view(mvrc_logits.shape[:-1])
            valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1
            mvrc_loss = (mvrc_loss / (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \
                .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4)
        else:
            mvrc_loss = soft_cross_entropy(
                mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]))
        # Pad logits/labels back to the untrimmed box count.
        mvrc_logits_padded = mvrc_logits.new_zeros(
            (mvrc_logits.shape[0], origin_len,
             mvrc_logits.shape[2])).fill_(-10000.0)
        mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits
        mvrc_logits = mvrc_logits_padded
        mvrc_labels_padded = mvrc_labels.new_zeros(
            (mvrc_labels.shape[0], origin_len,
             mvrc_labels.shape[2])).fill_(0.0)
        mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels
        mvrc_labels = mvrc_labels_padded

    # MLT loss applied
    if self.config.NETWORK.WITH_MLT_LOSS:
        MLT_loss = F.cross_entropy(MLT_logits, word_de_ids)

    outputs.update({
        'relationship_logits':
            relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None,
        'relationship_label':
            relationship_label if self.config.NETWORK.WITH_REL_LOSS else None,
        'mlm_logits_wvc':
            mlm_logits_wvc if self.config.NETWORK.WITH_MLM_LOSS else None,
        'mlm_label_wvc':
            mlm_labels_wvc if self.config.NETWORK.WITH_MLM_LOSS else None,
        'mlm_logits_aux':
            mlm_logits_aux if self.config.NETWORK.WITH_MLM_LOSS else None,
        'mlm_label_aux':
            mlm_labels_aux if self.config.NETWORK.WITH_MLM_LOSS else None,
        'mvrc_logits':
            mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None,
        'mvrc_label':
            mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None,
        'MLT_logits':
            MLT_logits if self.config.NETWORK.WITH_MLT_LOSS else None,
        'MLT_label':
            word_de_ids if self.config.NETWORK.WITH_MLT_LOSS else None,
        'MLT_loss': MLT_loss,
    })

    # FM edit: only the MLT loss contributes to the optimised objective
    loss = MLT_loss.mean()

    return outputs, loss
def forward(self, image, boxes, im_info, text, relationship_label,
            mlm_labels, mvrc_ops, mvrc_labels):
    """Standard VL-BERT pre-training forward.

    Extracts region features, runs the joint visual-linguistic encoder and
    computes whichever of the relationship / MLM / MVRC losses the config
    enables; disabled losses stay at zero.

    Returns:
        (outputs, loss) where loss is the sum of the enabled losses.
    """
    net_cfg = self.config.NETWORK

    ###########################################
    # visual feature extraction
    imgs = image
    box_mask = boxes[:, :, 0] > -1.5  # padded box rows use coords < -1.5
    origin_len = boxes.shape[1]
    longest = int(box_mask.sum(1).max().item())
    # drop padding columns beyond the longest box sequence in the batch
    box_mask = box_mask[:, :longest]
    boxes = boxes[:, :longest]
    mvrc_ops = mvrc_ops[:, :longest]
    mvrc_labels = mvrc_labels[:, :longest]

    if net_cfg.IMAGE_FEAT_PRECOMPUTED:
        # precomputed features live in boxes[:, :, 4:]; masked regions
        # (mvrc_ops == 1) are replaced by the learned mask feature
        feats = boxes[:, :, 4:]
        feats[mvrc_ops == 1] = self.object_mask_visual_embedding.weight[0]
        boxes[:, :, 4:] = feats

    obj_reps = self.image_feature_extractor(images=imgs,
                                            boxes=boxes,
                                            box_mask=box_mask,
                                            im_info=im_info,
                                            classes=None,
                                            segms=None,
                                            mvrc_ops=mvrc_ops,
                                            mask_visual_embed=None)

    ############################################
    # prepare text
    text_input_ids = text
    text_tags = text.new_zeros(text.shape)
    text_token_type_ids = text.new_zeros(text.shape)
    text_mask = text_input_ids > 0
    text_visual_embeddings = self._collect_obj_reps(text_tags,
                                                    obj_reps['obj_reps'])

    object_linguistic_embeddings = self.object_linguistic_embeddings(
        boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
    if net_cfg.WITH_MVRC_LOSS:
        object_linguistic_embeddings[mvrc_ops == 1] = \
            self.object_mask_word_embedding.weight[0]
    object_vl_embeddings = torch.cat(
        (obj_reps['obj_reps'], object_linguistic_embeddings), dim=-1)

    ###########################################
    # Visual Linguistic BERT
    relationship_logits, mlm_logits, mvrc_logits = self.vlbert(
        text_input_ids, text_token_type_ids, text_visual_embeddings,
        text_mask, object_vl_embeddings, box_mask)

    ###########################################
    outputs = {}

    # zero scalar defaults so every loss exists regardless of config flags
    relationship_loss = im_info.new_zeros(())
    mlm_loss = im_info.new_zeros(())
    mvrc_loss = im_info.new_zeros(())

    if net_cfg.WITH_REL_LOSS:
        relationship_loss = F.cross_entropy(relationship_logits,
                                            relationship_label)

    if net_cfg.WITH_MLM_LOSS:
        # pad logits along the sequence axis up to the label length; the
        # large negative fill gives padded positions ~zero probability
        padded = mlm_logits.new_zeros(
            (*mlm_labels.shape, mlm_logits.shape[-1])).fill_(-10000.0)
        padded[:, :mlm_logits.shape[1]] = mlm_logits
        mlm_logits = padded
        if net_cfg.MLM_LOSS_NORM_IN_BATCH_FIRST:
            # mean over masked tokens per sentence, then over sentences
            # that actually contain a masked token
            per_token = F.cross_entropy(mlm_logits.transpose(1, 2),
                                        mlm_labels,
                                        ignore_index=-1,
                                        reduction='none')
            num_mlm = (mlm_labels != -1).sum(
                1, keepdim=True).to(dtype=per_token.dtype)
            num_has_mlm = (num_mlm != 0).sum().to(dtype=per_token.dtype)
            mlm_loss = (per_token / (num_mlm + 1e-4)).sum() / (num_has_mlm +
                                                               1e-4)
        else:
            mlm_loss = F.cross_entropy(
                mlm_logits.view((-1, mlm_logits.shape[-1])),
                mlm_labels.view(-1),
                ignore_index=-1)

    if net_cfg.WITH_MVRC_LOSS:
        if net_cfg.MVRC_LOSS_NORM_IN_BATCH_FIRST:
            flat_logits = mvrc_logits.contiguous().view(
                -1, mvrc_logits.shape[-1])
            flat_labels = mvrc_labels.contiguous().view(
                -1, mvrc_logits.shape[-1])
            per_region = soft_cross_entropy(
                flat_logits, flat_labels,
                reduction='none').view(mvrc_logits.shape[:-1])
            # a region counts as labelled when its soft-label distribution
            # sums to ~1
            valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1
            mvrc_loss = (per_region / (valid.sum(1, keepdim=True).to(dtype=per_region.dtype) + 1e-4)) \
                .sum() / ((valid.sum(1) != 0).sum().to(dtype=per_region.dtype) + 1e-4)
        else:
            mvrc_loss = soft_cross_entropy(
                mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]))
        # pad logits/labels back out to the untrimmed box count
        logits_padded = mvrc_logits.new_zeros(
            (mvrc_logits.shape[0], origin_len,
             mvrc_logits.shape[2])).fill_(-10000.0)
        logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits
        mvrc_logits = logits_padded
        labels_padded = mvrc_labels.new_zeros(
            (mvrc_labels.shape[0], origin_len,
             mvrc_labels.shape[2])).fill_(0.0)
        labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels
        mvrc_labels = labels_padded

    outputs.update({
        'relationship_logits':
            relationship_logits if net_cfg.WITH_REL_LOSS else None,
        'relationship_label':
            relationship_label if net_cfg.WITH_REL_LOSS else None,
        'mlm_logits': mlm_logits if net_cfg.WITH_MLM_LOSS else None,
        'mlm_label': mlm_labels if net_cfg.WITH_MLM_LOSS else None,
        'mvrc_logits': mvrc_logits if net_cfg.WITH_MVRC_LOSS else None,
        'mvrc_label': mvrc_labels if net_cfg.WITH_MVRC_LOSS else None,
        'relationship_loss': relationship_loss,
        'mlm_loss': mlm_loss,
        'mvrc_loss': mvrc_loss,
    })

    loss = relationship_loss.mean() + mlm_loss.mean() + mvrc_loss.mean()

    return outputs, loss
def forward(self, image, boxes, im_info, text, relationship_label,
            mlm_labels, mvrc_ops, mvrc_labels):
    """Test-mode forward pass with an iterative generation loop.

    Repeatedly runs VL-BERT, takes the top-1 prediction at each row's
    current [MASK] position, appends it to the sequence and re-inserts a
    [MASK] one step to the right — until every sentence has emitted
    '[STOP]' or 48 steps have run. The usual pre-training losses are then
    computed from the final iteration's logits.

    Returns:
        (outputs, loss); outputs['generated_sentences'] holds the decoded
        strings.
    """
    ###########################################
    # visual feature extraction
    images = image
    box_mask = (boxes[:, :, 0] > -1.5)  # padded box rows have coords < -1.5
    origin_len = boxes.shape[1]
    max_len = int(box_mask.sum(1).max().item())
    # trim box padding down to the longest sequence in the batch
    box_mask = box_mask[:, :max_len]
    boxes = boxes[:, :max_len]
    mvrc_ops = mvrc_ops[:, :max_len]
    mvrc_labels = mvrc_labels[:, :max_len]

    if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
        # precomputed features live in boxes[:, :, 4:]; masked regions
        # (mvrc_ops == 1) are replaced by the learned mask feature
        box_features = boxes[:, :, 4:]
        box_features[mvrc_ops == 1] = self.object_mask_visual_embedding.weight[0]
        boxes[:, :, 4:] = box_features

    obj_reps = self.image_feature_extractor(images=images,
                                            boxes=boxes,
                                            box_mask=box_mask,
                                            im_info=im_info,
                                            classes=None,
                                            segms=None,
                                            mvrc_ops=mvrc_ops,
                                            mask_visual_embed=None)

    ############################################
    # prepare text
    text_input_ids = text
    text_tags = text.new_zeros(text.shape)
    text_token_type_ids = text.new_zeros(text.shape)
    text_mask = (text_input_ids > 0)
    text_visual_embeddings = self._collect_obj_reps(
        text_tags, obj_reps['obj_reps'])

    object_linguistic_embeddings = self.object_linguistic_embeddings(
        boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
    if self.config.NETWORK.WITH_MVRC_LOSS:
        object_linguistic_embeddings[
            mvrc_ops == 1] = self.object_mask_word_embedding.weight[0]
    object_vl_embeddings = torch.cat(
        (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

    ###########################################
    # Visual Linguistic BERT
    # loop here for test mode
    generated = []
    stop = [False] * text.shape[0]
    curr_len = 0
    # NOTE(review): this rebinds `max_len`, which previously held the box
    # sequence length — harmless (that value is not used again) but
    # confusing; 48 is the generation step cap.
    max_len = 48
    while not all(stop) and curr_len <= max_len:
        relationship_logits, mlm_logits, mvrc_logits = self.vlbert(
            text_input_ids, text_token_type_ids, text_visual_embeddings,
            text_mask, object_vl_embeddings, box_mask)
        # top-1 prediction at each row's [MASK] position
        # (assumes token id 103 == [MASK] — TODO confirm vocabulary)
        answers = torch.topk(mlm_logits[mlm_labels == 103], k=1, dim=1)
        # Get the column index of the [MASK] in every row
        position_tensor = torch.arange(mlm_labels.shape[1])
        position_tensor = position_tensor.repeat(mlm_labels.shape[0]).view(
            mlm_labels.shape[0], -1)
        indeces = position_tensor[mlm_labels == 103]
        # 1. Update mlm_labels: grow by one column; only the shifted mask
        # position carries 103, everything else is -1 (ignored)
        mlm_labels_new = mlm_labels.new_zeros(mlm_labels.shape[0],
                                              mlm_labels.shape[1] + 1)
        mlm_labels_new = mlm_labels_new - 1
        mlm_labels_new[torch.arange(mlm_labels.shape[0]), indeces + 1] = 103
        mlm_labels = mlm_labels_new
        # 2. Update text_input_ids: write the predicted token at the old
        # mask slot, then [MASK], [PAD], [SEP] behind it.
        # NOTE(review): indeces + 2 / + 3 assume at least that much room
        # after the mask in the extended row — confirm against data layout.
        text_input_ids_new = text_input_ids.new_zeros(
            text_input_ids.shape[0], text_input_ids.shape[1] + 1)
        text_input_ids_new[:, :-1] = text_input_ids
        text_input_ids_new[torch.arange(text_input_ids.shape[0]),
                           indeces] = answers[1][:, 0]
        text_input_ids_new[torch.arange(text_input_ids.shape[0]),
                           indeces + 1] = (self.tokenizer.convert_tokens_to_ids(
                               ['[MASK]'])[0])
        text_input_ids_new[torch.arange(text_input_ids.shape[0]),
                           indeces + 2] = (self.tokenizer.convert_tokens_to_ids(
                               ['[PAD]'])[0])
        text_input_ids_new[torch.arange(text_input_ids.shape[0]),
                           indeces + 3] = (self.tokenizer.convert_tokens_to_ids(
                               ['[SEP]'])[0])
        text_input_ids = text_input_ids_new
        # 3. Update text_token_type_ids: all zeros at the new length
        text_token_type_ids = text_token_type_ids.new_zeros(
            text_token_type_ids.shape[0], text_token_type_ids.shape[1] + 1)
        # 4. Update text_visual_embeddings (the original comment said
        # text_input_ids): broadcast each row's first embedding over the
        # grown sequence
        text_visual_embeddings_new = text_visual_embeddings.new_zeros(
            text_visual_embeddings.shape[0],
            text_visual_embeddings.shape[1] + 1,
            text_visual_embeddings.shape[2])
        text_visual_embeddings_new = text_visual_embeddings_new.transpose(
            0, 1)
        text_visual_embeddings_new[:] = text_visual_embeddings[:, 0, :]
        text_visual_embeddings = text_visual_embeddings_new.transpose(0, 1)
        # 5. Update text_mask:
        text_mask = (text_input_ids > 0)
        # 6. Append generated words from each sentence in the batch to the
        # list — a row terminates once it emits '[STOP]'
        for nid, row in enumerate(answers[1]):
            if curr_len == 0:
                generated.append([])
            for ele in row:
                if not stop[nid]:
                    if self.tokenizer.ids_to_tokens[
                            ele.item()] == '[STOP]':
                        stop[nid] = True
                    else:
                        generated[nid].append(
                            self.tokenizer.ids_to_tokens[ele.item()])
        curr_len += 1

    # Join word-pieces into sentences (undo '##' continuation markers)
    generated_sentences = []
    for sentence in generated:
        new_sentence = ' '.join(sentence)
        generated_sentences.append(new_sentence.replace(' ##', ''))

    ###########################################
    outputs = {}

    # loss (zero scalar defaults; computed from the LAST loop iteration's
    # logits)
    relationship_loss = im_info.new_zeros(())
    mlm_loss = im_info.new_zeros(())
    mvrc_loss = im_info.new_zeros(())
    if self.config.NETWORK.WITH_REL_LOSS:
        relationship_loss = F.cross_entropy(relationship_logits,
                                            relationship_label)
    if self.config.NETWORK.WITH_MLM_LOSS:
        # pad logits to the label length with a large negative fill
        mlm_logits_padded = mlm_logits.new_zeros(
            (*mlm_labels.shape, mlm_logits.shape[-1])).fill_(-10000.0)
        mlm_logits_padded[:, :mlm_logits.shape[1]] = mlm_logits
        mlm_logits = mlm_logits_padded
        if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
            # per-sentence mean over masked tokens, then mean over
            # sentences with at least one mask
            mlm_loss = F.cross_entropy(mlm_logits.transpose(1, 2),
                                       mlm_labels,
                                       ignore_index=-1,
                                       reduction='none')
            num_mlm = (mlm_labels != -1).sum(
                1, keepdim=True).to(dtype=mlm_loss.dtype)
            num_has_mlm = (num_mlm != 0).sum().to(dtype=mlm_loss.dtype)
            mlm_loss = (mlm_loss / (num_mlm + 1e-4)).sum() / (num_has_mlm +
                                                              1e-4)
        else:
            mlm_loss = F.cross_entropy(mlm_logits.view(
                (-1, mlm_logits.shape[-1])),
                mlm_labels.view(-1),
                ignore_index=-1)
    if self.config.NETWORK.WITH_MVRC_LOSS:
        if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST:
            # per-region soft cross-entropy, normalised per image over
            # regions whose soft label sums to ~1
            mvrc_loss = soft_cross_entropy(
                mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]),
                reduction='none').view(mvrc_logits.shape[:-1])
            valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1
            mvrc_loss = (mvrc_loss / (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \
                .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4)
        else:
            mvrc_loss = soft_cross_entropy(
                mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]))
        # pad logits/labels back to the untrimmed box count
        mvrc_logits_padded = mvrc_logits.new_zeros(
            (mvrc_logits.shape[0], origin_len,
             mvrc_logits.shape[2])).fill_(-10000.0)
        mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits
        mvrc_logits = mvrc_logits_padded
        mvrc_labels_padded = mvrc_labels.new_zeros(
            (mvrc_labels.shape[0], origin_len,
             mvrc_labels.shape[2])).fill_(0.0)
        mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels
        mvrc_labels = mvrc_labels_padded

    outputs.update({
        'relationship_logits':
        relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None,
        'relationship_label':
        relationship_label if self.config.NETWORK.WITH_REL_LOSS else None,
        'mlm_logits':
        mlm_logits if self.config.NETWORK.WITH_MLM_LOSS else None,
        'mlm_label':
        mlm_labels if self.config.NETWORK.WITH_MLM_LOSS else None,
        'mvrc_logits':
        mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None,
        'mvrc_label':
        mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None,
        'relationship_loss': relationship_loss,
        'mlm_loss': mlm_loss,
        'mvrc_loss': mvrc_loss,
        'generated_sentences': generated_sentences
    })

    loss = relationship_loss.mean() + mlm_loss.mean() + mvrc_loss.mean()

    return outputs, loss
def forward(self, text, relationship_label, mlm_labels):
    """Test-mode autoregressive generation forward pass (text-only).

    Repeatedly runs VL-BERT on the current token sequence, reads the top-1
    prediction at the masked position, writes it into the sequence, appends a
    fresh mask/pad/sep tail, and continues until every batch row has produced
    a '[STOP]' token or ``max_len`` steps have elapsed. Visual inputs are
    zeroed out so only the text stream contributes.

    Args:
        text: LongTensor of token ids, (batch, seq_len). Zero is treated as
            padding (used to build ``text_mask``).
        relationship_label: labels for the optional relationship loss.
        mlm_labels: LongTensor, (batch, seq_len); positions equal to 103 mark
            the slot to predict. NOTE(review): 103 appears to be the default
            BERT '[MASK]' id — confirm against ``self.tokenizer``.

    Returns:
        (outputs, loss) where ``outputs`` holds logits/labels/losses plus the
        decoded ``generated_sentences`` and ``loss`` is the (final-step) MLM
        loss mean.
    """
    ###########################################
    ############################################

    # prepare text
    text_input_ids = text
    # tags tensor of the same shape as text (all zeros: no visual grounding)
    text_tags = text.new_zeros(text.shape)
    text_token_type_ids = text.new_zeros(text.shape)

    # ***** FM edit: blank out visual embeddings for translation retrieval task
    text_visual_embeddings = text_input_ids.new_zeros(
        (text_input_ids.shape[0], text_input_ids.shape[1], 768),
        dtype=torch.float)
    # text_visual_embeddings[:] = self.aux_text_visual_embedding.weight[0]

    # ****** FM edit: blank visual embeddings (use known dimensions)
    # NOTE(review): 768 / 1536 look like BERT-base hidden sizes — confirm
    # they match self.vlbert's configured dimensions.
    object_vl_embeddings = text_input_ids.new_zeros(
        (text_input_ids.shape[0], 1, 1536), dtype=torch.float)

    # FM edit: No auxiliary text is used for text only
    # add auxiliary text - Concatenates the batches from the two dataloaders
    # The visual features for the text only corpus is just the embedding of
    # the aux_visual_embedding (only one embedding)
    max_text_len = text_input_ids.shape[1]
    text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
    # non-padding positions (id 0 assumed to be [PAD])
    text_mask = (text_input_ids > 0)

    # FM: Edit: set to zero to ignore vision
    box_mask = text_input_ids.new_zeros((text_input_ids.shape[0], 1),
                                        dtype=torch.uint8)

    ###########################################

    # Visual Linguistic BERT
    # #loop here for test mode:
    generated = []                      # per-row list of generated tokens
    stop = [False] * text.shape[0]      # per-row [STOP] flags
    curr_len = 0
    max_len = 48                        # hard cap on generation steps
    while not all(stop) and curr_len <= max_len:
        relationship_logits, mlm_logits, mvrc_logits = self.vlbert(
            text_input_ids, text_token_type_ids, text_visual_embeddings,
            text_mask, object_vl_embeddings, box_mask)

        # Top-1 token id at each masked position; answers[1] are the ids.
        answers = torch.topk(mlm_logits[mlm_labels == 103], k=1, dim=1)

        # Get size of each tensor: build a (batch, seq) grid of positions and
        # pick out the index of the masked slot in every row.
        # NOTE(review): indexing below (indeces + 1, per-row scatter) assumes
        # exactly one masked position per batch row — confirm with the caller.
        position_tensor = torch.arange(mlm_labels.shape[1])
        position_tensor = position_tensor.repeat(mlm_labels.shape[0]).view(
            mlm_labels.shape[0], -1)
        indeces = position_tensor[mlm_labels == 103]

        # 1. Update mlm_labels: grow by one column, fill with ignore-index -1,
        # and move the mask marker one slot to the right.
        mlm_labels_new = mlm_labels.new_zeros(mlm_labels.shape[0],
                                              mlm_labels.shape[1] + 1)
        mlm_labels_new = mlm_labels_new - 1
        mlm_labels_new[torch.arange(mlm_labels.shape[0]), indeces + 1] = 103
        mlm_labels = mlm_labels_new

        # 2. Update text_input_ids: write the predicted token at the old mask
        # position, then re-mask / pad / close the grown sequence.
        text_input_ids_new = text_input_ids.new_zeros(
            text_input_ids.shape[0], text_input_ids.shape[1] + 1)
        text_input_ids_new[:, :-1] = text_input_ids
        text_input_ids_new[torch.arange(text_input_ids.shape[0]),
                           indeces] = answers[1][:, 0]
        text_input_ids_new[torch.arange(text_input_ids.shape[0]),
                           indeces + 1] = (self.tokenizer.convert_tokens_to_ids(
                               ['[MASK]'])[0])
        text_input_ids_new[torch.arange(text_input_ids.shape[0]),
                           indeces + 2] = (self.tokenizer.convert_tokens_to_ids(
                               ['[PAD]'])[0])
        text_input_ids_new[torch.arange(text_input_ids.shape[0]),
                           indeces + 3] = (self.tokenizer.convert_tokens_to_ids(
                               ['[SEP]'])[0])
        # NOTE(review): indeces + 3 can exceed the new last column when the
        # mask sits near the end of the sequence — verify bounds upstream.
        text_input_ids = text_input_ids_new

        # 3. Update text_token_type_ids: re-created as zeros at the new length.
        text_token_type_ids = text_token_type_ids.new_zeros(
            text_token_type_ids.shape[0], text_token_type_ids.shape[1] + 1)

        # 4. Update text_visual_embeddings: grow by one position, broadcasting
        # the (blank) embedding of position 0 to every slot.
        text_visual_embeddings_new = text_visual_embeddings.new_zeros(
            text_visual_embeddings.shape[0],
            text_visual_embeddings.shape[1] + 1,
            text_visual_embeddings.shape[2])
        text_visual_embeddings_new = text_visual_embeddings_new.transpose(
            0, 1)
        text_visual_embeddings_new[:] = text_visual_embeddings[:, 0, :]
        text_visual_embeddings = text_visual_embeddings_new.transpose(0, 1)

        # 5. Update text_mask:
        text_mask = (text_input_ids > 0)

        # 6. Add to generated: append this step's token per row, flipping the
        # stop flag when '[STOP]' is produced.
        for nid, row in enumerate(answers[1]):
            if curr_len == 0:
                generated.append([])
            for ele in row:
                # try:
                if not stop[nid]:
                    if self.tokenizer.ids_to_tokens[
                            ele.item()] == '[STOP]':
                        stop[nid] = True
                    else:
                        generated[nid].append(
                            self.tokenizer.ids_to_tokens[ele.item()])
                # except:
                #     generated[nid].append(self.tokenizer.ids_to_tokens[100])
        curr_len += 1

    # Join in sentences, merging WordPiece continuations (' ##' -> '').
    generated_sentences = []
    for sentence in generated:
        new_sentence = ' '.join(sentence)
        generated_sentences.append(new_sentence.replace(' ##', ''))

    ###########################################
    outputs = {}

    # Losses below reuse the logits from the FINAL loop iteration.
    if self.config.NETWORK.WITH_REL_LOSS:
        relationship_loss = F.cross_entropy(relationship_logits,
                                            relationship_label)
    if self.config.NETWORK.WITH_MLM_LOSS:
        # Pad logits out to the (grown) label length with large negative
        # values so padded positions contribute ~zero probability.
        mlm_logits_padded = mlm_logits.new_zeros(
            (*mlm_labels.shape, mlm_logits.shape[-1])).fill_(-10000.0)
        mlm_logits_padded[:, :mlm_logits.shape[1]] = mlm_logits
        mlm_logits = mlm_logits_padded
        if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
            # Per-token loss, normalized per-sample then over samples that
            # actually contain masked tokens (1e-4 guards divide-by-zero).
            mlm_loss = F.cross_entropy(mlm_logits.transpose(1, 2),
                                       mlm_labels,
                                       ignore_index=-1,
                                       reduction='none')
            num_mlm = (mlm_labels != -1).sum(
                1, keepdim=True).to(dtype=mlm_loss.dtype)
            num_has_mlm = (num_mlm != 0).sum().to(dtype=mlm_loss.dtype)
            mlm_loss = (mlm_loss /
                        (num_mlm + 1e-4)).sum() / (num_has_mlm + 1e-4)
        else:
            mlm_loss = F.cross_entropy(mlm_logits.view(
                (-1, mlm_logits.shape[-1])),
                                       mlm_labels.view(-1),
                                       ignore_index=-1)
    if self.config.NETWORK.WITH_MVRC_LOSS:
        # NOTE(review): `mvrc_labels` and `origin_len` are not defined in
        # this method — this branch raises NameError if WITH_MVRC_LOSS is
        # enabled. Presumably dead code in this text-only variant; confirm.
        if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST:
            mvrc_loss = soft_cross_entropy(
                mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]),
                reduction='none').view(mvrc_logits.shape[:-1])
            valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1
            mvrc_loss = (mvrc_loss /
                         (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \
                .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4)
        else:
            mvrc_loss = soft_cross_entropy(
                mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]))
        mvrc_logits_padded = mvrc_logits.new_zeros(
            (mvrc_logits.shape[0], origin_len,
             mvrc_logits.shape[2])).fill_(-10000.0)
        mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits
        mvrc_logits = mvrc_logits_padded
        mvrc_labels_padded = mvrc_labels.new_zeros(
            (mvrc_labels.shape[0], origin_len,
             mvrc_labels.shape[2])).fill_(0.0)
        mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels
        mvrc_labels = mvrc_labels_padded

    outputs.update({
        'relationship_logits':
        relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None,
        'relationship_label':
        relationship_label if self.config.NETWORK.WITH_REL_LOSS else None,
        'mlm_logits':
        mlm_logits if self.config.NETWORK.WITH_MLM_LOSS else None,
        'mlm_label':
        mlm_labels if self.config.NETWORK.WITH_MLM_LOSS else None,
        'mvrc_logits':
        mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None,
        'mvrc_label':
        mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None,
        'mlm_loss': mlm_loss,
        'generated_sentences': generated_sentences
    })

    # NOTE(review): mlm_loss is unbound here if WITH_MLM_LOSS is disabled.
    loss = mlm_loss.mean()

    return outputs, loss
def forward(self, text, relationship_label, mlm_labels):
    """Text-only pretraining forward pass with blanked visual features.

    Runs VL-BERT on the token sequence alone: text-side visual embeddings
    and the single object embedding are zero tensors, and the box mask is
    all zeros so the vision stream is ignored. Computes the losses enabled
    in ``self.config.NETWORK`` (relationship / MLM / MVRC).

    Args:
        text: LongTensor of token ids, (batch, seq_len); id 0 is treated as
            padding when building ``text_mask``.
        relationship_label: target for the optional relationship loss.
        mlm_labels: MLM targets, (batch, seq_len); -1 marks ignored positions.

    Returns:
        (outputs, loss): ``outputs`` maps logit/label/loss names to tensors
        (``None`` for disabled heads); ``loss`` is the mean MLM loss (zero
        scalar when the MLM head is disabled).
    """
    ###########################################
    # FM edit: remove visual feature extraction
    ############################################

    # prepare text
    text_input_ids = text
    text_tags = text.new_zeros(text.shape)
    text_token_type_ids = text.new_zeros(text.shape)
    text_mask = (text_input_ids > 0)
    # text_visual_embeddings = self._collect_obj_reps(text_tags, obj_reps['obj_reps'])

    # ***** FM edit: blank out visual embeddings for translation retrieval task
    # NOTE(review): 768 / 1536 look like BERT-base hidden sizes — confirm
    # they match self.vlbert's configured dimensions.
    text_visual_embeddings = text_input_ids.new_zeros(
        (text_input_ids.shape[0], text_input_ids.shape[1], 768),
        dtype=torch.float)

    # ****** FM edit: blank visual embeddings (use known dimensions)
    object_vl_embeddings = text_input_ids.new_zeros(
        (text_input_ids.shape[0], 1, 1536), dtype=torch.float)

    # FM: Edit: set to zero to ignore vision
    box_mask = text_input_ids.new_zeros((text_input_ids.shape[0], 1),
                                        dtype=torch.uint8)

    ###########################################

    # Visual Linguistic BERT
    relationship_logits, mlm_logits, mvrc_logits = self.vlbert(
        text_input_ids, text_token_type_ids, text_visual_embeddings,
        text_mask, object_vl_embeddings, box_mask)

    ###########################################
    outputs = {}

    # Fix: default every loss to a float zero scalar so that `outputs` and
    # `loss = mlm_loss.mean()` below never hit a NameError when the
    # corresponding config flag is disabled. (The previous zero-initializers
    # were commented out and referenced `im_info`, which is not a parameter
    # of this text-only signature.) Float dtype is required because .mean()
    # is not defined for integer tensors.
    relationship_loss = text.new_zeros((), dtype=torch.float)
    mlm_loss = text.new_zeros((), dtype=torch.float)
    mvrc_loss = text.new_zeros((), dtype=torch.float)

    if self.config.NETWORK.WITH_REL_LOSS:
        relationship_loss = F.cross_entropy(relationship_logits,
                                            relationship_label)
    if self.config.NETWORK.WITH_MLM_LOSS:
        # Pad logits out to the label length with large negative values so
        # padded positions contribute ~zero probability mass.
        mlm_logits_padded = mlm_logits.new_zeros(
            (*mlm_labels.shape, mlm_logits.shape[-1])).fill_(-10000.0)
        mlm_logits_padded[:, :mlm_logits.shape[1]] = mlm_logits
        mlm_logits = mlm_logits_padded
        if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
            # Per-token loss, normalized within each sample and then over
            # samples that actually contain masked tokens; the 1e-4 terms
            # guard against division by zero.
            mlm_loss = F.cross_entropy(mlm_logits.transpose(1, 2),
                                       mlm_labels,
                                       ignore_index=-1,
                                       reduction='none')
            num_mlm = (mlm_labels != -1).sum(
                1, keepdim=True).to(dtype=mlm_loss.dtype)
            num_has_mlm = (num_mlm != 0).sum().to(dtype=mlm_loss.dtype)
            mlm_loss = (mlm_loss /
                        (num_mlm + 1e-4)).sum() / (num_has_mlm + 1e-4)
        else:
            mlm_loss = F.cross_entropy(mlm_logits.view(
                (-1, mlm_logits.shape[-1])),
                                       mlm_labels.view(-1),
                                       ignore_index=-1)

    # mvrc_loss = F.cross_entropy(mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
    #                             mvrc_labels.contiguous().view(-1),
    #                             ignore_index=-1)
    if self.config.NETWORK.WITH_MVRC_LOSS:
        # NOTE(review): `mvrc_labels` and `origin_len` are not defined in
        # this method — this branch raises NameError if WITH_MVRC_LOSS is
        # enabled. Presumably dead code in this text-only variant; confirm
        # before enabling the flag.
        if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST:
            mvrc_loss = soft_cross_entropy(
                mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]),
                reduction='none').view(mvrc_logits.shape[:-1])
            valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1
            mvrc_loss = (mvrc_loss /
                         (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \
                .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4)
        else:
            mvrc_loss = soft_cross_entropy(
                mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]))
        mvrc_logits_padded = mvrc_logits.new_zeros(
            (mvrc_logits.shape[0], origin_len,
             mvrc_logits.shape[2])).fill_(-10000.0)
        mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits
        mvrc_logits = mvrc_logits_padded
        mvrc_labels_padded = mvrc_labels.new_zeros(
            (mvrc_labels.shape[0], origin_len,
             mvrc_labels.shape[2])).fill_(0.0)
        mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels
        mvrc_labels = mvrc_labels_padded

    outputs.update({
        'relationship_logits':
        relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None,
        'relationship_label':
        relationship_label if self.config.NETWORK.WITH_REL_LOSS else None,
        'mlm_logits':
        mlm_logits if self.config.NETWORK.WITH_MLM_LOSS else None,
        'mlm_label':
        mlm_labels if self.config.NETWORK.WITH_MLM_LOSS else None,
        'mvrc_logits':
        mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None,
        'mvrc_label':
        mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None,
        'mlm_loss': mlm_loss,
    })

    loss = mlm_loss.mean()

    return outputs, loss