def forward(self, text, relationship_label, mlm_labels):
    """Text-only forward pass: all visual inputs are zeroed out so the
    VL-BERT backbone effectively runs on the text stream alone.

    Args:
        text: (batch, seq_len) token ids; id 0 is treated as padding.
        relationship_label: per-sample relationship target (only used when
            ``NETWORK.WITH_REL_LOSS`` is enabled).
        mlm_labels: (batch, seq_len) masked-LM targets, -1 = not masked.

    Returns:
        (outputs, loss): outputs dict with logits/labels per enabled task,
        and the scalar MLM training loss.
    """
    ###########################################
    # Visual feature extraction is intentionally skipped (text-only task).
    ############################################

    # prepare text
    text_input_ids = text
    # Blank per-token visual embeddings.
    # NOTE(review): 768 is assumed to be the backbone hidden size — confirm.
    text_visual_embeddings = text_input_ids.new_zeros(
        (text_input_ids.shape[0], text_input_ids.shape[1], 768),
        dtype=torch.float)

    # One dummy object with a blank 1536-d visual+linguistic embedding.
    # NOTE(review): 1536 is assumed to be the concatenated obj-rep size — confirm.
    object_vl_embeddings = text_input_ids.new_zeros(
        (text_input_ids.shape[0], 1, 1536), dtype=torch.float)

    text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
    text_mask = (text_input_ids > 0)
    # Zero box mask so the single dummy object is ignored by attention.
    box_mask = text_input_ids.new_zeros((text_input_ids.shape[0], 1),
                                        dtype=torch.uint8)

    ###########################################

    # Visual Linguistic BERT
    relationship_logits, mlm_logits, mvrc_logits = self.vlbert(
        text_input_ids, text_token_type_ids, text_visual_embeddings,
        text_mask, object_vl_embeddings, box_mask)

    ###########################################
    outputs = {}

    # losses — default mlm_loss to zero so the outputs dict and the final
    # `mlm_loss.mean()` below never raise a NameError when WITH_MLM_LOSS is
    # disabled (bug fix: `mlm_loss` was previously undefined in that case).
    mlm_loss = text_visual_embeddings.new_zeros(())
    if self.config.NETWORK.WITH_REL_LOSS:
        relationship_loss = F.cross_entropy(relationship_logits,
                                            relationship_label)
    if self.config.NETWORK.WITH_MLM_LOSS:
        # Pad logits up to the label length with a large negative score so
        # padded positions never win the softmax.
        mlm_logits_padded = mlm_logits.new_zeros(
            (*mlm_labels.shape, mlm_logits.shape[-1])).fill_(-10000.0)
        mlm_logits_padded[:, :mlm_logits.shape[1]] = mlm_logits
        mlm_logits = mlm_logits_padded
        if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
            # Per-sample normalization: average over masked tokens within a
            # sample, then over samples that contain at least one mask.
            mlm_loss = F.cross_entropy(mlm_logits.transpose(1, 2),
                                       mlm_labels,
                                       ignore_index=-1,
                                       reduction='none')
            num_mlm = (mlm_labels != -1).sum(
                1, keepdim=True).to(dtype=mlm_loss.dtype)
            num_has_mlm = (num_mlm != 0).sum().to(dtype=mlm_loss.dtype)
            mlm_loss = (mlm_loss /
                        (num_mlm + 1e-4)).sum() / (num_has_mlm + 1e-4)
        else:
            mlm_loss = F.cross_entropy(mlm_logits.view(
                (-1, mlm_logits.shape[-1])),
                                       mlm_labels.view(-1),
                                       ignore_index=-1)

    if self.config.NETWORK.WITH_MVRC_LOSS:
        # NOTE(review): `mvrc_labels` and `origin_len` are NOT defined in
        # this text-only forward — this branch raises NameError if
        # WITH_MVRC_LOSS is enabled; confirm the flag is off for this config.
        if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST:
            mvrc_loss = soft_cross_entropy(
                mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]),
                reduction='none').view(mvrc_logits.shape[:-1])
            valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1
            mvrc_loss = (mvrc_loss / (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \
                            .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4)
        else:
            mvrc_loss = soft_cross_entropy(
                mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]))

        mvrc_logits_padded = mvrc_logits.new_zeros(
            (mvrc_logits.shape[0], origin_len,
             mvrc_logits.shape[2])).fill_(-10000.0)
        mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits
        mvrc_logits = mvrc_logits_padded
        mvrc_labels_padded = mvrc_labels.new_zeros(
            (mvrc_labels.shape[0], origin_len,
             mvrc_labels.shape[2])).fill_(0.0)
        mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels
        mvrc_labels = mvrc_labels_padded

    outputs.update({
        'relationship_logits':
        relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None,
        'relationship_label':
        relationship_label if self.config.NETWORK.WITH_REL_LOSS else None,
        'mlm_logits':
        mlm_logits if self.config.NETWORK.WITH_MLM_LOSS else None,
        'mlm_label':
        mlm_labels if self.config.NETWORK.WITH_MLM_LOSS else None,
        'mvrc_logits':
        mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None,
        'mvrc_label':
        mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None,
        'mlm_loss':
        mlm_loss,
    })

    loss = mlm_loss.mean()

    return outputs, loss
    def forward(self,
                image,
                boxes,
                im_info,
                text,
                relationship_label,
                mlm_labels,
                mvrc_ops,
                mvrc_labels,
                *aux):
        """Multi-task forward pass over image-text pairs plus auxiliary
        text-only batches.

        ``*aux`` is a flat alternating sequence, one pair per auxiliary
        dataloader: (aux_text_0, aux_mlm_labels_0, aux_text_1, ...).

        Returns:
            (outputs, loss): outputs dict with per-task logits/labels and
            losses, and the scalar sum of the relationship, MLM
            (with-visual-context + auxiliary) and MVRC losses.
        """

        # concat aux texts from different dataset
        assert len(aux) > 0 and len(aux) % 2 == 0
        aux_text_list = aux[0::2]
        aux_text_mlm_labels_list = aux[1::2]
        num_aux_text = sum([_text.shape[0] for _text in aux_text_list])
        max_aux_text_len = max([_text.shape[1] for _text in aux_text_list])
        aux_text = aux_text_list[0].new_zeros((num_aux_text, max_aux_text_len))
        # -1 is the MLM ignore_index, so label padding is ignored by the loss
        aux_text_mlm_labels = aux_text_mlm_labels_list[0].new_zeros((num_aux_text, max_aux_text_len)).fill_(-1)
        _cur = 0
        for _text, _mlm_labels in zip(aux_text_list, aux_text_mlm_labels_list):
            _num = _text.shape[0]
            aux_text[_cur:(_cur + _num), :_text.shape[1]] = _text
            aux_text_mlm_labels[_cur:(_cur + _num), :_text.shape[1]] = _mlm_labels
            _cur += _num

        ###########################################

        # visual feature extraction
        images = image
        # a real box has x1 > -1.5; padded boxes carry negative coordinates
        box_mask = (boxes[:, :, 0] > -1.5)
        origin_len = boxes.shape[1]
        # trim all box-aligned tensors to the longest real sequence in batch
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        mvrc_ops = mvrc_ops[:, :max_len]
        mvrc_labels = mvrc_labels[:, :max_len]

        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            # columns 4: hold precomputed features; overwrite MVRC-masked
            # boxes with the learned mask embedding
            box_features = boxes[:, :, 4:]
            box_features[mvrc_ops == 1] = self.object_mask_visual_embedding.weight[0]
            boxes[:, :, 4:] = box_features

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None,
                                                mvrc_ops=mvrc_ops,
                                                mask_visual_embed=self.object_mask_visual_embedding.weight[0]
                                                if (not self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED)
                                                   and (not self.config.NETWORK.MASK_RAW_PIXELS)
                                                else None)

        ############################################

        # prepare text
        text_input_ids = text
        text_tags = text.new_zeros(text.shape)
        text_visual_embeddings = self._collect_obj_reps(text_tags, obj_reps['obj_reps'])

        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long()
        )
        if self.config.NETWORK.WITH_MVRC_LOSS:
            object_linguistic_embeddings[mvrc_ops == 1] = self.object_mask_word_embedding.weight[0]
        object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        # add auxiliary text: the image-text batch and the text-only batch
        # are concatenated along the batch dimension; aux rows get the
        # single learned aux visual embedding and an all-zero box mask
        max_text_len = max(text_input_ids.shape[1], aux_text.shape[1])
        text_input_ids_multi = text_input_ids.new_zeros((text_input_ids.shape[0] + aux_text.shape[0], max_text_len))
        text_input_ids_multi[:text_input_ids.shape[0], :text_input_ids.shape[1]] = text_input_ids
        text_input_ids_multi[text_input_ids.shape[0]:, :aux_text.shape[1]] = aux_text
        text_token_type_ids_multi = text_input_ids_multi.new_zeros(text_input_ids_multi.shape)
        text_mask_multi = (text_input_ids_multi > 0)
        text_visual_embeddings_multi = text_visual_embeddings.new_zeros((text_input_ids.shape[0] + aux_text.shape[0],
                                                                         max_text_len,
                                                                         text_visual_embeddings.shape[-1]))
        text_visual_embeddings_multi[:text_visual_embeddings.shape[0], :text_visual_embeddings.shape[1]] \
            = text_visual_embeddings
        text_visual_embeddings_multi[text_visual_embeddings.shape[0]:] = self.aux_text_visual_embedding.weight[0]
        object_vl_embeddings_multi = object_vl_embeddings.new_zeros((text_input_ids.shape[0] + aux_text.shape[0],
                                                                     *object_vl_embeddings.shape[1:]))
        object_vl_embeddings_multi[:object_vl_embeddings.shape[0]] = object_vl_embeddings
        box_mask_multi = box_mask.new_zeros((text_input_ids.shape[0] + aux_text.shape[0], *box_mask.shape[1:]))
        box_mask_multi[:box_mask.shape[0]] = box_mask

        ###########################################

        # Visual Linguistic BERT

        relationship_logits_multi, mlm_logits_multi, mvrc_logits_multi = self.vlbert(text_input_ids_multi,
                                                                                     text_token_type_ids_multi,
                                                                                     text_visual_embeddings_multi,
                                                                                     text_mask_multi,
                                                                                     object_vl_embeddings_multi,
                                                                                     box_mask_multi)

        ###########################################
        outputs = {}

        # loss — default every term to zero so the outputs dict and the
        # final sum below never reference an undefined name when a loss
        # flag is disabled (bug fix: `mlm_loss_wvc` / `mlm_loss_aux` were
        # previously undefined when WITH_MLM_LOSS was False).
        relationship_loss = im_info.new_zeros(())
        mlm_loss_wvc = im_info.new_zeros(())
        mlm_loss_aux = im_info.new_zeros(())
        mvrc_loss = im_info.new_zeros(())
        if self.config.NETWORK.WITH_REL_LOSS:
            relationship_logits = relationship_logits_multi[:text_input_ids.shape[0]]
            relationship_loss = F.cross_entropy(relationship_logits, relationship_label)
        if self.config.NETWORK.WITH_MLM_LOSS:
            mlm_labels_multi = mlm_labels.new_zeros((text_input_ids.shape[0] + aux_text.shape[0], max_text_len)).fill_(
                -1)
            mlm_labels_multi[:text_input_ids.shape[0], :mlm_labels.shape[1]] = mlm_labels
            mlm_labels_multi[text_input_ids.shape[0]:, :aux_text_mlm_labels.shape[1]] = aux_text_mlm_labels

            # pad logits to the label length with a large negative score
            mlm_logits_multi_padded = \
                mlm_logits_multi.new_zeros((*mlm_labels_multi.shape, mlm_logits_multi.shape[-1])).fill_(-10000.0)
            mlm_logits_multi_padded[:, :mlm_logits_multi.shape[1]] = mlm_logits_multi
            mlm_logits_multi = mlm_logits_multi_padded
            # split back into with-visual-context (wvc) and aux-text rows
            mlm_logits_wvc = mlm_logits_multi_padded[:text_input_ids.shape[0]]
            mlm_labels_wvc = mlm_labels_multi[:text_input_ids.shape[0]]
            mlm_logits_aux = mlm_logits_multi_padded[text_input_ids.shape[0]:]
            mlm_labels_aux = mlm_labels_multi[text_input_ids.shape[0]:]
            if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
                # per-sample normalization: average over masked tokens within
                # a sample, then over samples that contain at least one mask
                mlm_loss_wvc = F.cross_entropy(mlm_logits_wvc.transpose(1, 2),
                                               mlm_labels_wvc,
                                               ignore_index=-1, reduction='none')
                num_mlm_wvc = (mlm_labels_wvc != -1).sum(1, keepdim=True).to(dtype=mlm_loss_wvc.dtype)
                num_has_mlm_wvc = (num_mlm_wvc != 0).sum().to(dtype=mlm_loss_wvc.dtype)
                mlm_loss_wvc = (mlm_loss_wvc / (num_mlm_wvc + 1e-4)).sum() / (num_has_mlm_wvc + 1e-4)
                mlm_loss_aux = F.cross_entropy(mlm_logits_aux.transpose(1, 2),
                                               mlm_labels_aux,
                                               ignore_index=-1, reduction='none')
                num_mlm_aux = (mlm_labels_aux != -1).sum(1, keepdim=True).to(dtype=mlm_loss_aux.dtype)
                num_has_mlm_aux = (num_mlm_aux != 0).sum().to(dtype=mlm_loss_aux.dtype)
                mlm_loss_aux = (mlm_loss_aux / (num_mlm_aux + 1e-4)).sum() / (num_has_mlm_aux + 1e-4)
            else:
                mlm_loss_wvc = F.cross_entropy(
                    mlm_logits_wvc.view((-1, mlm_logits_multi_padded.shape[-1])),
                    mlm_labels_wvc.view(-1),
                    ignore_index=-1
                )
                mlm_loss_aux = F.cross_entropy(
                    mlm_logits_aux.view((-1, mlm_logits_multi_padded.shape[-1])),
                    mlm_labels_aux.view(-1),
                    ignore_index=-1
                )

        if self.config.NETWORK.WITH_MVRC_LOSS:
            # only the image-text rows carry MVRC targets
            mvrc_logits = mvrc_logits_multi[:mvrc_labels.shape[0], :mvrc_labels.shape[1]]
            if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]),
                    reduction='none').view(mvrc_logits.shape[:-1])
                # a valid target distribution sums to ~1
                valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1
                mvrc_loss = (mvrc_loss / (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \
                                .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4)
            else:
                mvrc_loss = soft_cross_entropy(mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                                               mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]))

            # pad logits/labels back to the original (untrimmed) box length
            mvrc_logits_padded = mvrc_logits.new_zeros((mvrc_logits.shape[0], origin_len, mvrc_logits.shape[2])).fill_(
                -10000.0)
            mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits
            mvrc_logits = mvrc_logits_padded
            mvrc_labels_padded = mvrc_labels.new_zeros((mvrc_labels.shape[0], origin_len, mvrc_labels.shape[2])).fill_(
                0.0)
            mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels
            mvrc_labels = mvrc_labels_padded

        outputs.update({
            'relationship_logits': relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None,
            'relationship_label': relationship_label if self.config.NETWORK.WITH_REL_LOSS else None,
            'mlm_logits_wvc': mlm_logits_wvc if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_label_wvc': mlm_labels_wvc if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_logits_aux': mlm_logits_aux if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_label_aux': mlm_labels_aux if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mvrc_logits': mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'mvrc_label': mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'relationship_loss': relationship_loss,
            'mlm_loss_wvc': mlm_loss_wvc,
            'mlm_loss_aux': mlm_loss_aux,
            'mvrc_loss': mvrc_loss,
        })

        loss = relationship_loss.mean() + mlm_loss_wvc.mean() + mlm_loss_aux.mean() + mvrc_loss.mean()

        return outputs, loss
# Example 3
    def forward(self, image, boxes, im_info, text, relationship_label,
                mlm_labels, mvrc_ops, mvrc_labels, word_de_ids):
        """Forward pass for the translation-retrieval / MLT variant.

        Visual embeddings are deliberately blanked out and an extra MLT
        head predicts target-language word ids (``word_de_ids``); the
        returned loss is the MLT loss only.
        """

        ###########################################

        # visual feature extraction
        images = image
        # a real box has x1 > -1.5; padded boxes carry negative coordinates
        box_mask = (boxes[:, :, 0] > -1.5)
        origin_len = boxes.shape[1]
        # trim all box-aligned tensors to the longest real sequence in batch
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        mvrc_ops = mvrc_ops[:, :max_len]
        mvrc_labels = mvrc_labels[:, :max_len]

        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            # columns 4: hold precomputed features; overwrite MVRC-masked
            # boxes with the learned mask embedding
            box_features = boxes[:, :, 4:]
            box_features[mvrc_ops ==
                         1] = self.object_mask_visual_embedding.weight[0]
            boxes[:, :, 4:] = box_features

        obj_reps = self.image_feature_extractor(
            images=images,
            boxes=boxes,
            box_mask=box_mask,
            im_info=im_info,
            classes=None,
            segms=None,
            mvrc_ops=mvrc_ops,
            mask_visual_embed=self.object_mask_visual_embedding.weight[0] if
            (not self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED) and
            (not self.config.NETWORK.MASK_RAW_PIXELS) else None)

        ############################################

        # prepare text
        text_input_ids = text
        # creates a text_tags tensor of the same shape as text tensor
        text_tags = text.new_zeros(text.shape)
        text_visual_embeddings = self._collect_obj_reps(
            text_tags, obj_reps['obj_reps'])
        # ***** FM edit: blank out visual embeddings for translation retrieval task
        # (every token receives the single learned aux-text visual embedding)
        text_visual_embeddings[:] = self.aux_text_visual_embedding.weight[0]

        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        if self.config.NETWORK.WITH_MVRC_LOSS:
            object_linguistic_embeddings[
                mvrc_ops == 1] = self.object_mask_word_embedding.weight[0]
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)
        # ****** FM edit: blank out all visual embeddings
        object_vl_embeddings = object_vl_embeddings.new_zeros(
            object_vl_embeddings.shape)

        # FM edit: No auxiliary text is used for text only
        # add auxiliary text - Concatenates the batches from the two dataloaders
        # The visual features for the text only corpus is just the embedding of the aux_visual_embedding (only one embedding)
        max_text_len = text_input_ids.shape[1]
        text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
        text_mask = (text_input_ids > 0)
        #FM: Edit: i have taken this out, not needed i think since defined above
        # box_mask = box_mask.new_zeros((text_input_ids.shape[0], *box_mask.shape[1:]))

        ###########################################

        # Visual Linguistic BERT
        # NOTE(review): unlike the other variants this vlbert returns a 4th
        # output, the MLT (translation) logits.
        relationship_logits_multi, mlm_logits_multi, mvrc_logits_multi, MLT_logits = self.vlbert(
            text_input_ids, text_token_type_ids, text_visual_embeddings,
            text_mask, object_vl_embeddings, box_mask)

        ###########################################
        outputs = {}

        # loss (each term defaults to zero when its flag is disabled)
        relationship_loss = im_info.new_zeros(())
        mlm_loss = im_info.new_zeros(())
        mvrc_loss = im_info.new_zeros(())
        MLT_loss = im_info.new_zeros(())
        if self.config.NETWORK.WITH_REL_LOSS:
            relationship_logits = relationship_logits_multi[:text_input_ids.
                                                            shape[0]]
            # FM edit - change cross_entropy to bce/sigmoid
            # NOTE(review): F.binary_cross_entropy_with_logits would be more
            # numerically stable than sigmoid + binary_cross_entropy — consider.
            relationship_loss = F.binary_cross_entropy(
                torch.sigmoid(relationship_logits),
                relationship_label.unsqueeze(1))
        if self.config.NETWORK.WITH_MLM_LOSS:
            # NOTE(review): `aux_text` and `aux_text_mlm_labels` are NOT
            # defined in this method (it takes no *aux parameters) — this
            # branch raises NameError if WITH_MLM_LOSS is enabled; confirm
            # the flag is off for this configuration.
            mlm_labels_multi = mlm_labels.new_zeros(
                (text_input_ids.shape[0] + aux_text.shape[0],
                 max_text_len)).fill_(-1)
            mlm_labels_multi[:text_input_ids.shape[0], :mlm_labels.
                             shape[1]] = mlm_labels
            mlm_labels_multi[text_input_ids.shape[0]:, :aux_text_mlm_labels.
                             shape[1]] = aux_text_mlm_labels

            mlm_logits_multi_padded = \
                mlm_logits_multi.new_zeros((*mlm_labels_multi.shape, mlm_logits_multi.shape[-1])).fill_(-10000.0)
            mlm_logits_multi_padded[:, :mlm_logits_multi.
                                    shape[1]] = mlm_logits_multi
            mlm_logits_multi = mlm_logits_multi_padded
            mlm_logits_wvc = mlm_logits_multi_padded[:text_input_ids.shape[0]]
            mlm_labels_wvc = mlm_labels_multi[:text_input_ids.shape[0]]
            mlm_logits_aux = mlm_logits_multi_padded[text_input_ids.shape[0]:]
            mlm_labels_aux = mlm_labels_multi[text_input_ids.shape[0]:]
            if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
                # per-sample normalization: average over masked tokens within
                # a sample, then over samples containing at least one mask
                mlm_loss_wvc = F.cross_entropy(mlm_logits_wvc.transpose(1, 2),
                                               mlm_labels_wvc,
                                               ignore_index=-1,
                                               reduction='none')
                num_mlm_wvc = (mlm_labels_wvc != -1).sum(
                    1, keepdim=True).to(dtype=mlm_loss_wvc.dtype)
                num_has_mlm_wvc = (num_mlm_wvc != 0).sum().to(
                    dtype=mlm_loss_wvc.dtype)
                mlm_loss_wvc = (mlm_loss_wvc / (num_mlm_wvc + 1e-4)).sum() / (
                    num_has_mlm_wvc + 1e-4)
                mlm_loss_aux = F.cross_entropy(mlm_logits_aux.transpose(1, 2),
                                               mlm_labels_aux,
                                               ignore_index=-1,
                                               reduction='none')
                num_mlm_aux = (mlm_labels_aux != -1).sum(
                    1, keepdim=True).to(dtype=mlm_loss_aux.dtype)
                num_has_mlm_aux = (num_mlm_aux != 0).sum().to(
                    dtype=mlm_loss_aux.dtype)
                mlm_loss_aux = (mlm_loss_aux / (num_mlm_aux + 1e-4)).sum() / (
                    num_has_mlm_aux + 1e-4)
            else:
                # mlm_loss = F.cross_entropy(mlm_logits_multi_padded.view((-1, mlm_logits_multi_padded.shape[-1])),
                #                            mlm_labels_multi.view(-1),
                #                            ignore_index=-1)
                mlm_loss_wvc = F.cross_entropy(mlm_logits_wvc.view(
                    (-1, mlm_logits_multi_padded.shape[-1])),
                                               mlm_labels_wvc.view(-1),
                                               ignore_index=-1)
                mlm_loss_aux = F.cross_entropy(mlm_logits_aux.view(
                    (-1, mlm_logits_multi_padded.shape[-1])),
                                               mlm_labels_aux.view(-1),
                                               ignore_index=-1)

        # mvrc_loss = F.cross_entropy(mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
        #                             mvrc_labels.contiguous().view(-1),
        #                             ignore_index=-1)
        if self.config.NETWORK.WITH_MVRC_LOSS:
            mvrc_logits = mvrc_logits_multi[:mvrc_labels.
                                            shape[0], :mvrc_labels.shape[1]]
            if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]),
                    reduction='none').view(mvrc_logits.shape[:-1])
                # a valid soft-label distribution sums to ~1
                valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1
                mvrc_loss = (mvrc_loss / (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \
                                .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4)
            else:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]))

            # pad logits/labels back to the original (untrimmed) box length
            mvrc_logits_padded = mvrc_logits.new_zeros(
                (mvrc_logits.shape[0], origin_len,
                 mvrc_logits.shape[2])).fill_(-10000.0)
            mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits
            mvrc_logits = mvrc_logits_padded
            mvrc_labels_padded = mvrc_labels.new_zeros(
                (mvrc_labels.shape[0], origin_len,
                 mvrc_labels.shape[2])).fill_(0.0)
            mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels
            mvrc_labels = mvrc_labels_padded

        # MLT loss applied
        if self.config.NETWORK.WITH_MLT_LOSS:
            MLT_loss = F.cross_entropy(MLT_logits, word_de_ids)

        # FM edit: removed other two losses that are not defined
        outputs.update({
            'relationship_logits':
            relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None,
            'relationship_label':
            relationship_label if self.config.NETWORK.WITH_REL_LOSS else None,
            'mlm_logits_wvc':
            mlm_logits_wvc if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_label_wvc':
            mlm_labels_wvc if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_logits_aux':
            mlm_logits_aux if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_label_aux':
            mlm_labels_aux if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mvrc_logits':
            mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'mvrc_label':
            mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'MLT_logits':
            MLT_logits if self.config.NETWORK.WITH_MLT_LOSS else None,
            'MLT_label':
            word_de_ids if self.config.NETWORK.WITH_MLT_LOSS else None,
            'MLT_loss':
            MLT_loss,
        })

        # FM edit: removed addition of other losses which are not defined
        loss = MLT_loss.mean()

        return outputs, loss
    def forward(self,
                image,
                boxes,
                im_info,
                text,
                relationship_label,
                mlm_labels,
                mvrc_ops,
                mvrc_labels):
        """Standard single-stream pretraining forward pass over image-text
        pairs (no auxiliary text-only batches).

        Returns (outputs, loss) where loss sums the relationship, MLM and
        MVRC losses; each term is zero when its config flag is disabled.
        """
        ###########################################

        # visual feature extraction
        images = image
        # a real box has x1 > -1.5; padded boxes carry negative coordinates
        box_mask = (boxes[:, :, 0] > -1.5)
        origin_len = boxes.shape[1]
        # trim all box-aligned tensors to the longest real sequence in batch
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        mvrc_ops = mvrc_ops[:, :max_len]
        mvrc_labels = mvrc_labels[:, :max_len]

        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            # columns 4: hold precomputed features; overwrite MVRC-masked
            # boxes with the learned mask embedding (in-place on `boxes`)
            box_features = boxes[:, :, 4:]
            box_features[mvrc_ops == 1] = self.object_mask_visual_embedding.weight[0]
            boxes[:, :, 4:] = box_features

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None,
                                                mvrc_ops=mvrc_ops,
                                                mask_visual_embed=None)

        ############################################

        # prepare text
        text_input_ids = text
        # zero tags: no explicit token-to-box grounding is used here
        text_tags = text.new_zeros(text.shape)
        text_token_type_ids = text.new_zeros(text.shape)
        text_mask = (text_input_ids > 0)
        text_visual_embeddings = self._collect_obj_reps(text_tags, obj_reps['obj_reps'])

        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long()
        )
        if self.config.NETWORK.WITH_MVRC_LOSS:
            object_linguistic_embeddings[mvrc_ops == 1] = self.object_mask_word_embedding.weight[0]
        object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################

        # Visual Linguistic BERT

        relationship_logits, mlm_logits, mvrc_logits = self.vlbert(text_input_ids,
                                                                   text_token_type_ids,
                                                                   text_visual_embeddings,
                                                                   text_mask,
                                                                   object_vl_embeddings,
                                                                   box_mask)

        ###########################################
        outputs = {}

        # loss (each term defaults to zero when its flag is disabled)
        relationship_loss = im_info.new_zeros(())
        mlm_loss = im_info.new_zeros(())
        mvrc_loss = im_info.new_zeros(())
        if self.config.NETWORK.WITH_REL_LOSS:
            relationship_loss = F.cross_entropy(relationship_logits, relationship_label)
        if self.config.NETWORK.WITH_MLM_LOSS:
            # pad logits to the label length with a large negative score so
            # padded positions never win the softmax
            mlm_logits_padded = mlm_logits.new_zeros((*mlm_labels.shape, mlm_logits.shape[-1])).fill_(-10000.0)
            mlm_logits_padded[:, :mlm_logits.shape[1]] = mlm_logits
            mlm_logits = mlm_logits_padded
            if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
                # per-sample normalization: average over masked tokens within
                # a sample, then over samples containing at least one mask
                mlm_loss = F.cross_entropy(mlm_logits.transpose(1, 2),
                                           mlm_labels,
                                           ignore_index=-1, reduction='none')
                num_mlm = (mlm_labels != -1).sum(1, keepdim=True).to(dtype=mlm_loss.dtype)
                num_has_mlm = (num_mlm != 0).sum().to(dtype=mlm_loss.dtype)
                mlm_loss = (mlm_loss / (num_mlm + 1e-4)).sum() / (num_has_mlm + 1e-4)
            else:
                mlm_loss = F.cross_entropy(mlm_logits.view((-1, mlm_logits.shape[-1])),
                                           mlm_labels.view(-1),
                                           ignore_index=-1)
        # mvrc_loss = F.cross_entropy(mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
        #                             mvrc_labels.contiguous().view(-1),
        #                             ignore_index=-1)
        if self.config.NETWORK.WITH_MVRC_LOSS:
            if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]),
                    reduction='none').view(mvrc_logits.shape[:-1])
                # a valid soft-label distribution sums to ~1
                valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1
                mvrc_loss = (mvrc_loss / (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \
                                .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4)
            else:
                mvrc_loss = soft_cross_entropy(mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                                               mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]))

            # pad logits/labels back to the original (untrimmed) box length
            mvrc_logits_padded = mvrc_logits.new_zeros((mvrc_logits.shape[0], origin_len, mvrc_logits.shape[2])).fill_(-10000.0)
            mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits
            mvrc_logits = mvrc_logits_padded
            mvrc_labels_padded = mvrc_labels.new_zeros((mvrc_labels.shape[0], origin_len, mvrc_labels.shape[2])).fill_(0.0)
            mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels
            mvrc_labels = mvrc_labels_padded

        outputs.update({
            'relationship_logits': relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None,
            'relationship_label': relationship_label if self.config.NETWORK.WITH_REL_LOSS else None,
            'mlm_logits': mlm_logits if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_label': mlm_labels if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mvrc_logits': mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'mvrc_label': mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'relationship_loss': relationship_loss,
            'mlm_loss': mlm_loss,
            'mvrc_loss': mvrc_loss,
        })

        loss = relationship_loss.mean() + mlm_loss.mean() + mvrc_loss.mean()

        return outputs, loss
    def forward(self, image, boxes, im_info, text, relationship_label,
                mlm_labels, mvrc_ops, mvrc_labels):
        """Visual-linguistic forward pass with iterative text generation.

        Runs the image feature extractor, then repeatedly calls VL-BERT,
        reading the top-1 prediction at each sample's masked position
        (``mlm_labels == 103``) and inserting it into ``text_input_ids``,
        until every sample has emitted a ``'[STOP]'`` token or 48 iterations
        have run.  Finally computes the configured losses and assembles the
        outputs dict (including the decoded ``generated_sentences``).

        Returns:
            tuple: ``(outputs, loss)`` where ``loss`` is the sum of the mean
            relationship, MLM and MVRC losses (zero scalars when the
            corresponding ``config.NETWORK.WITH_*_LOSS`` flag is off).
        """
        ###########################################

        # visual feature extraction
        images = image
        # A box is treated as valid when its first coordinate is > -1.5
        # (padding boxes presumably use a smaller fill value — confirm
        # against the data loader); trim all per-box tensors to the longest
        # valid-box count in the batch.
        box_mask = (boxes[:, :, 0] > -1.5)
        origin_len = boxes.shape[1]
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        mvrc_ops = mvrc_ops[:, :max_len]
        mvrc_labels = mvrc_labels[:, :max_len]

        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            # Precomputed features are stored after the 4 box coordinates;
            # regions selected for masking (mvrc_ops == 1) are overwritten
            # with the learned mask embedding before feature extraction.
            box_features = boxes[:, :, 4:]
            box_features[mvrc_ops ==
                         1] = self.object_mask_visual_embedding.weight[0]
            boxes[:, :, 4:] = box_features

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None,
                                                mvrc_ops=mvrc_ops,
                                                mask_visual_embed=None)

        ############################################

        # prepare text
        text_input_ids = text
        # text_tags is all zeros, so _collect_obj_reps gathers the same
        # (index-0) object representation for every text position.
        text_tags = text.new_zeros(text.shape)
        text_token_type_ids = text.new_zeros(text.shape)
        # Positions with token id 0 ([PAD]) are masked out.
        text_mask = (text_input_ids > 0)
        text_visual_embeddings = self._collect_obj_reps(
            text_tags, obj_reps['obj_reps'])

        # One shared linguistic embedding (index 0) per box; masked regions
        # get the dedicated mask-word embedding instead.
        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        if self.config.NETWORK.WITH_MVRC_LOSS:
            object_linguistic_embeddings[
                mvrc_ops == 1] = self.object_mask_word_embedding.weight[0]
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################
        # Visual Linguistic BERT
        # #loop here for test mode:
        generated = []
        stop = [False] * text.shape[0]
        curr_len = 0
        # NOTE: this rebinds the `max_len` used above for box trimming; from
        # here on it caps the number of generation iterations.
        max_len = 48
        while not all(stop) and curr_len <= max_len:
            relationship_logits, mlm_logits, mvrc_logits = self.vlbert(
                text_input_ids, text_token_type_ids, text_visual_embeddings,
                text_mask, object_vl_embeddings, box_mask)
            # Top-1 token id at each sample's masked position (103 is the
            # [MASK] id); assumes exactly one 103 per row — TODO confirm
            # against the data loader, otherwise shapes below break.
            answers = torch.topk(mlm_logits[mlm_labels == 103], k=1, dim=1)

            # Column index of the masked position in every row.
            position_tensor = torch.arange(mlm_labels.shape[1])
            position_tensor = position_tensor.repeat(mlm_labels.shape[0]).view(
                mlm_labels.shape[0], -1)
            indeces = position_tensor[mlm_labels == 103]

            # 1. Update mlm_labels: grow by one column, fill with the ignore
            # index (-1) and move the mask marker one step to the right.
            mlm_labels_new = mlm_labels.new_zeros(mlm_labels.shape[0],
                                                  mlm_labels.shape[1] + 1)
            mlm_labels_new = mlm_labels_new - 1
            mlm_labels_new[torch.arange(mlm_labels.shape[0]),
                           indeces + 1] = 103
            mlm_labels = mlm_labels_new

            # 2. Update text_input_ids: write the predicted token at the old
            # mask position, then re-append [MASK], [PAD] and [SEP] after it.
            text_input_ids_new = text_input_ids.new_zeros(
                text_input_ids.shape[0], text_input_ids.shape[1] + 1)
            text_input_ids_new[:, :-1] = text_input_ids
            text_input_ids_new[torch.arange(text_input_ids.shape[0]),
                               indeces] = answers[1][:, 0]
            text_input_ids_new[torch.arange(text_input_ids.shape[0]), indeces +
                               1] = (self.tokenizer.convert_tokens_to_ids(
                                   ['[MASK]'])[0])
            text_input_ids_new[torch.arange(text_input_ids.shape[0]), indeces +
                               2] = (self.tokenizer.convert_tokens_to_ids(
                                   ['[PAD]'])[0])
            text_input_ids_new[torch.arange(text_input_ids.shape[0]), indeces +
                               3] = (self.tokenizer.convert_tokens_to_ids(
                                   ['[SEP]'])[0])
            text_input_ids = text_input_ids_new

            # 3. Update text_token_type_ids: all zeros at the new length.
            text_token_type_ids = text_token_type_ids.new_zeros(
                text_token_type_ids.shape[0], text_token_type_ids.shape[1] + 1)

            # 4. Update text_visual_embeddings: grow by one position and
            # broadcast the first position's embedding to every position
            # (all positions share the index-0 object representation).
            text_visual_embeddings_new = text_visual_embeddings.new_zeros(
                text_visual_embeddings.shape[0],
                text_visual_embeddings.shape[1] + 1,
                text_visual_embeddings.shape[2])
            text_visual_embeddings_new = text_visual_embeddings_new.transpose(
                0, 1)
            text_visual_embeddings_new[:] = text_visual_embeddings[:, 0, :]
            text_visual_embeddings = text_visual_embeddings_new.transpose(0, 1)

            # 5. Update text_mask:
            text_mask = (text_input_ids > 0)

            # 6. Append generated words from each sentence in the batch to
            # list - a sample stops contributing once it emits '[STOP]'.
            for nid, row in enumerate(answers[1]):
                if curr_len == 0:
                    generated.append([])
                for ele in row:
                    # try:
                    if not stop[nid]:
                        if self.tokenizer.ids_to_tokens[
                                ele.item()] == '[STOP]':
                            stop[nid] = True
                        else:
                            # print('generated: ', ele.item())
                            generated[nid].append(
                                self.tokenizer.ids_to_tokens[ele.item()])
                    # except:
                    #     generated[nid].append(self.tokenizer.ids_to_tokens[100])
            curr_len += 1

        # Join in sentences, merging WordPiece continuations (' ##').
        generated_sentences = []
        for sentence in generated:
            new_sentence = ' '.join(sentence)
            generated_sentences.append(new_sentence.replace(' ##', ''))
        # print(generated_sentences)
        # exit()

        ###########################################
        outputs = {}

        # loss — default each component to a zero scalar so the final sum
        # works when the corresponding flag is disabled.
        relationship_loss = im_info.new_zeros(())
        mlm_loss = im_info.new_zeros(())
        mvrc_loss = im_info.new_zeros(())
        if self.config.NETWORK.WITH_REL_LOSS:
            relationship_loss = F.cross_entropy(relationship_logits,
                                                relationship_label)
        if self.config.NETWORK.WITH_MLM_LOSS:
            # Pad logits back to the label length with large negatives so
            # padded positions cannot win the softmax.
            mlm_logits_padded = mlm_logits.new_zeros(
                (*mlm_labels.shape, mlm_logits.shape[-1])).fill_(-10000.0)
            mlm_logits_padded[:, :mlm_logits.shape[1]] = mlm_logits
            mlm_logits = mlm_logits_padded
            if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
                # Per-sample mean over masked positions, then mean over the
                # samples that actually contain masked positions.
                mlm_loss = F.cross_entropy(mlm_logits.transpose(1, 2),
                                           mlm_labels,
                                           ignore_index=-1,
                                           reduction='none')
                num_mlm = (mlm_labels != -1).sum(
                    1, keepdim=True).to(dtype=mlm_loss.dtype)
                num_has_mlm = (num_mlm != 0).sum().to(dtype=mlm_loss.dtype)
                mlm_loss = (mlm_loss /
                            (num_mlm + 1e-4)).sum() / (num_has_mlm + 1e-4)
            else:
                mlm_loss = F.cross_entropy(mlm_logits.view(
                    (-1, mlm_logits.shape[-1])),
                                           mlm_labels.view(-1),
                                           ignore_index=-1)
        # mvrc_loss = F.cross_entropy(mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
        #                             mvrc_labels.contiguous().view(-1),
        #                             ignore_index=-1)
        if self.config.NETWORK.WITH_MVRC_LOSS:
            if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]),
                    reduction='none').view(mvrc_logits.shape[:-1])
                # A region counts as valid when its soft-label distribution
                # sums to ~1 (within 0.1).
                valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1
                mvrc_loss = (mvrc_loss / (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \
                                .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4)
            else:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]))

            # Pad MVRC logits/labels back to the original (untrimmed) box
            # count so downstream consumers see a fixed shape.
            mvrc_logits_padded = mvrc_logits.new_zeros(
                (mvrc_logits.shape[0], origin_len,
                 mvrc_logits.shape[2])).fill_(-10000.0)
            mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits
            mvrc_logits = mvrc_logits_padded
            mvrc_labels_padded = mvrc_labels.new_zeros(
                (mvrc_labels.shape[0], origin_len,
                 mvrc_labels.shape[2])).fill_(0.0)
            mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels
            mvrc_labels = mvrc_labels_padded

        outputs.update({
            'relationship_logits':
            relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None,
            'relationship_label':
            relationship_label if self.config.NETWORK.WITH_REL_LOSS else None,
            'mlm_logits':
            mlm_logits if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_label':
            mlm_labels if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mvrc_logits':
            mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'mvrc_label':
            mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'relationship_loss':
            relationship_loss,
            'mlm_loss':
            mlm_loss,
            'mvrc_loss':
            mvrc_loss,
            'generated_sentences':
            generated_sentences
        })

        loss = relationship_loss.mean() + mlm_loss.mean() + mvrc_loss.mean()

        return outputs, loss
    def forward(self, text, relationship_label, mlm_labels):
        """Text-only forward pass with iterative (autoregressive) generation.

        Visual inputs are blanked out — zero visual embeddings and an
        all-zero box mask — so VL-BERT effectively runs on text alone.  The
        model is called repeatedly, reading the top-1 prediction at each
        sample's masked position (``mlm_labels == 103``, the [MASK] id) and
        inserting it into ``text_input_ids``, until every sample has emitted
        ``'[STOP]'`` or 48 iterations have run.

        Returns:
            tuple: ``(outputs, loss)`` where ``outputs`` also carries the
            decoded ``generated_sentences`` and ``loss`` is the mean MLM
            loss (a zero scalar when ``config.NETWORK.WITH_MLM_LOSS`` is
            disabled).
        """
        ###########################################

        ############################################

        # prepare text
        text_input_ids = text
        text_token_type_ids = text.new_zeros(text.shape)

        # ***** FM edit: blank out visual embeddings for translation retrieval task
        # (768 presumably matches the BERT hidden size — TODO confirm vs. config)
        text_visual_embeddings = text_input_ids.new_zeros(
            (text_input_ids.shape[0], text_input_ids.shape[1], 768),
            dtype=torch.float)
        # text_visual_embeddings[:] = self.aux_text_visual_embedding.weight[0]

        # ****** FM edit: blank visual embeddings (use known dimensions)
        object_vl_embeddings = text_input_ids.new_zeros(
            (text_input_ids.shape[0], 1, 1536), dtype=torch.float)

        # FM edit: No auxiliary text is used for text only.
        # Positions with token id 0 ([PAD]) are masked out.
        text_mask = (text_input_ids > 0)
        # FM edit: set to zero to ignore vision
        box_mask = text_input_ids.new_zeros((text_input_ids.shape[0], 1),
                                            dtype=torch.uint8)

        ###########################################
        # Visual Linguistic BERT
        # loop here for test mode:
        generated = []
        stop = [False] * text.shape[0]
        curr_len = 0
        max_len = 48
        while not all(stop) and curr_len <= max_len:
            relationship_logits, mlm_logits, mvrc_logits = self.vlbert(
                text_input_ids, text_token_type_ids, text_visual_embeddings,
                text_mask, object_vl_embeddings, box_mask)
            # Top-1 token id at each sample's masked position; assumes
            # exactly one 103 per row — TODO confirm against the data loader.
            answers = torch.topk(mlm_logits[mlm_labels == 103], k=1, dim=1)

            # Column index of the masked position in every row.
            position_tensor = torch.arange(mlm_labels.shape[1])
            position_tensor = position_tensor.repeat(mlm_labels.shape[0]).view(
                mlm_labels.shape[0], -1)
            indeces = position_tensor[mlm_labels == 103]

            # 1. Update mlm_labels: grow by one column, fill with the ignore
            # index (-1) and move the mask marker one step to the right.
            mlm_labels_new = mlm_labels.new_zeros(mlm_labels.shape[0],
                                                  mlm_labels.shape[1] + 1)
            mlm_labels_new = mlm_labels_new - 1
            mlm_labels_new[torch.arange(mlm_labels.shape[0]),
                           indeces + 1] = 103
            mlm_labels = mlm_labels_new

            # 2. Update text_input_ids: write the predicted token at the old
            # mask position, then re-append [MASK], [PAD] and [SEP] after it.
            text_input_ids_new = text_input_ids.new_zeros(
                text_input_ids.shape[0], text_input_ids.shape[1] + 1)
            text_input_ids_new[:, :-1] = text_input_ids
            text_input_ids_new[torch.arange(text_input_ids.shape[0]),
                               indeces] = answers[1][:, 0]
            text_input_ids_new[torch.arange(text_input_ids.shape[0]), indeces +
                               1] = (self.tokenizer.convert_tokens_to_ids(
                                   ['[MASK]'])[0])
            text_input_ids_new[torch.arange(text_input_ids.shape[0]), indeces +
                               2] = (self.tokenizer.convert_tokens_to_ids(
                                   ['[PAD]'])[0])
            text_input_ids_new[torch.arange(text_input_ids.shape[0]), indeces +
                               3] = (self.tokenizer.convert_tokens_to_ids(
                                   ['[SEP]'])[0])
            text_input_ids = text_input_ids_new

            # 3. Update text_token_type_ids: all zeros at the new length.
            text_token_type_ids = text_token_type_ids.new_zeros(
                text_token_type_ids.shape[0], text_token_type_ids.shape[1] + 1)

            # 4. Update text_visual_embeddings: grow by one position and
            # broadcast the first position's (zero) embedding everywhere.
            text_visual_embeddings_new = text_visual_embeddings.new_zeros(
                text_visual_embeddings.shape[0],
                text_visual_embeddings.shape[1] + 1,
                text_visual_embeddings.shape[2])
            text_visual_embeddings_new = text_visual_embeddings_new.transpose(
                0, 1)
            text_visual_embeddings_new[:] = text_visual_embeddings[:, 0, :]
            text_visual_embeddings = text_visual_embeddings_new.transpose(0, 1)

            # 5. Update text_mask:
            text_mask = (text_input_ids > 0)

            # 6. Append generated words from each sentence in the batch; a
            # sample stops contributing once it emits '[STOP]'.
            for nid, row in enumerate(answers[1]):
                if curr_len == 0:
                    generated.append([])
                for ele in row:
                    if not stop[nid]:
                        if self.tokenizer.ids_to_tokens[
                                ele.item()] == '[STOP]':
                            stop[nid] = True
                        else:
                            generated[nid].append(
                                self.tokenizer.ids_to_tokens[ele.item()])
            curr_len += 1

        # Join in sentences, merging WordPiece continuations (' ##').
        generated_sentences = []
        for sentence in generated:
            new_sentence = ' '.join(sentence)
            generated_sentences.append(new_sentence.replace(' ##', ''))

        ###########################################
        outputs = {}

        # FIX: default mlm_loss to a zero scalar; previously it was never
        # initialised, so `loss = mlm_loss.mean()` below raised NameError
        # whenever config.NETWORK.WITH_MLM_LOSS was disabled.  (No im_info
        # in this text-only forward, so derive the zero tensor from `text`.)
        mlm_loss = text.new_zeros((), dtype=torch.float)

        if self.config.NETWORK.WITH_REL_LOSS:
            relationship_loss = F.cross_entropy(relationship_logits,
                                                relationship_label)
        if self.config.NETWORK.WITH_MLM_LOSS:
            # Pad logits back to the label length with large negatives so
            # padded positions cannot win the softmax.
            mlm_logits_padded = mlm_logits.new_zeros(
                (*mlm_labels.shape, mlm_logits.shape[-1])).fill_(-10000.0)
            mlm_logits_padded[:, :mlm_logits.shape[1]] = mlm_logits
            mlm_logits = mlm_logits_padded
            if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
                # Per-sample mean over masked positions, then mean over the
                # samples that actually contain masked positions.
                mlm_loss = F.cross_entropy(mlm_logits.transpose(1, 2),
                                           mlm_labels,
                                           ignore_index=-1,
                                           reduction='none')
                num_mlm = (mlm_labels != -1).sum(
                    1, keepdim=True).to(dtype=mlm_loss.dtype)
                num_has_mlm = (num_mlm != 0).sum().to(dtype=mlm_loss.dtype)
                mlm_loss = (mlm_loss /
                            (num_mlm + 1e-4)).sum() / (num_has_mlm + 1e-4)
            else:
                mlm_loss = F.cross_entropy(mlm_logits.view(
                    (-1, mlm_logits.shape[-1])),
                                           mlm_labels.view(-1),
                                           ignore_index=-1)

        if self.config.NETWORK.WITH_MVRC_LOSS:
            # NOTE(review): this text-only forward receives neither
            # `mvrc_labels` nor boxes, so `mvrc_labels` and `origin_len`
            # below are undefined — this branch can only run without error
            # when WITH_MVRC_LOSS is disabled.  Confirm the config.
            if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]),
                    reduction='none').view(mvrc_logits.shape[:-1])
                valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1
                mvrc_loss = (mvrc_loss / (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \
                                .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4)
            else:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]))

            mvrc_logits_padded = mvrc_logits.new_zeros(
                (mvrc_logits.shape[0], origin_len,
                 mvrc_logits.shape[2])).fill_(-10000.0)
            mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits
            mvrc_logits = mvrc_logits_padded
            mvrc_labels_padded = mvrc_labels.new_zeros(
                (mvrc_labels.shape[0], origin_len,
                 mvrc_labels.shape[2])).fill_(0.0)
            mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels
            mvrc_labels = mvrc_labels_padded

        outputs.update({
            'relationship_logits':
            relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None,
            'relationship_label':
            relationship_label if self.config.NETWORK.WITH_REL_LOSS else None,
            'mlm_logits':
            mlm_logits if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_label':
            mlm_labels if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mvrc_logits':
            mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'mvrc_label':
            mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'mlm_loss':
            mlm_loss,
            'generated_sentences':
            generated_sentences
        })

        loss = mlm_loss.mean()

        return outputs, loss
    def forward(self, text, relationship_label, mlm_labels):
        """Text-only forward pass (single VL-BERT call, no generation loop).

        FM edit: visual feature extraction is removed; visual inputs are
        replaced by zero embeddings and an all-zero box mask, so VL-BERT
        effectively runs on text alone.

        Returns:
            tuple: ``(outputs, loss)`` where ``loss`` is the mean MLM loss
            (a zero scalar when ``config.NETWORK.WITH_MLM_LOSS`` is
            disabled).
        """
        ###########################################

        # FM edit: remove visual feature extraction

        ############################################

        # prepare text
        text_input_ids = text
        text_token_type_ids = text.new_zeros(text.shape)
        # Positions with token id 0 ([PAD]) are masked out.
        text_mask = (text_input_ids > 0)
        # text_visual_embeddings = self._collect_obj_reps(text_tags, obj_reps['obj_reps'])
        # ***** FM edit: blank out visual embeddings for translation retrieval task
        # (768 presumably matches the BERT hidden size — TODO confirm vs. config)
        text_visual_embeddings = text_input_ids.new_zeros(
            (text_input_ids.shape[0], text_input_ids.shape[1], 768),
            dtype=torch.float)

        # ****** FM edit: blank visual embeddings (use known dimensions)
        object_vl_embeddings = text_input_ids.new_zeros(
            (text_input_ids.shape[0], 1, 1536), dtype=torch.float)
        # FM edit: set to zero to ignore vision
        box_mask = text_input_ids.new_zeros((text_input_ids.shape[0], 1),
                                            dtype=torch.uint8)

        ###########################################

        # Visual Linguistic BERT

        relationship_logits, mlm_logits, mvrc_logits = self.vlbert(
            text_input_ids, text_token_type_ids, text_visual_embeddings,
            text_mask, object_vl_embeddings, box_mask)

        ###########################################
        outputs = {}

        # FIX: default mlm_loss to a zero scalar.  The previously
        # commented-out initialisers referenced `im_info`, which this
        # text-only forward does not receive, and without any initialiser
        # `loss = mlm_loss.mean()` below raised NameError whenever
        # config.NETWORK.WITH_MLM_LOSS was disabled.
        mlm_loss = text.new_zeros((), dtype=torch.float)

        if self.config.NETWORK.WITH_REL_LOSS:
            relationship_loss = F.cross_entropy(relationship_logits,
                                                relationship_label)
        if self.config.NETWORK.WITH_MLM_LOSS:
            # Pad logits back to the label length with large negatives so
            # padded positions cannot win the softmax.
            mlm_logits_padded = mlm_logits.new_zeros(
                (*mlm_labels.shape, mlm_logits.shape[-1])).fill_(-10000.0)
            mlm_logits_padded[:, :mlm_logits.shape[1]] = mlm_logits
            mlm_logits = mlm_logits_padded
            if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
                # Per-sample mean over masked positions, then mean over the
                # samples that actually contain masked positions.
                mlm_loss = F.cross_entropy(mlm_logits.transpose(1, 2),
                                           mlm_labels,
                                           ignore_index=-1,
                                           reduction='none')
                num_mlm = (mlm_labels != -1).sum(
                    1, keepdim=True).to(dtype=mlm_loss.dtype)
                num_has_mlm = (num_mlm != 0).sum().to(dtype=mlm_loss.dtype)
                mlm_loss = (mlm_loss /
                            (num_mlm + 1e-4)).sum() / (num_has_mlm + 1e-4)
            else:
                mlm_loss = F.cross_entropy(mlm_logits.view(
                    (-1, mlm_logits.shape[-1])),
                                           mlm_labels.view(-1),
                                           ignore_index=-1)

        if self.config.NETWORK.WITH_MVRC_LOSS:
            # NOTE(review): this text-only forward receives neither
            # `mvrc_labels` nor boxes, so `mvrc_labels` and `origin_len`
            # below are undefined — this branch can only run without error
            # when WITH_MVRC_LOSS is disabled.  Confirm the config.
            if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]),
                    reduction='none').view(mvrc_logits.shape[:-1])
                valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1
                mvrc_loss = (mvrc_loss / (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \
                                .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4)
            else:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]))

            mvrc_logits_padded = mvrc_logits.new_zeros(
                (mvrc_logits.shape[0], origin_len,
                 mvrc_logits.shape[2])).fill_(-10000.0)
            mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits
            mvrc_logits = mvrc_logits_padded
            mvrc_labels_padded = mvrc_labels.new_zeros(
                (mvrc_labels.shape[0], origin_len,
                 mvrc_labels.shape[2])).fill_(0.0)
            mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels
            mvrc_labels = mvrc_labels_padded

        outputs.update({
            'relationship_logits':
            relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None,
            'relationship_label':
            relationship_label if self.config.NETWORK.WITH_REL_LOSS else None,
            'mlm_logits':
            mlm_logits if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_label':
            mlm_labels if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mvrc_logits':
            mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'mvrc_label':
            mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'mlm_loss':
            mlm_loss,
        })

        loss = mlm_loss.mean()

        return outputs, loss