def __init__(self, config):

        super(ResNetVLBERTForPretrainingEncDecGenerate, self).__init__(config)

        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=config.NETWORK.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        if config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            self.object_mask_visual_embedding = nn.Embedding(1, 2048)
        if config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding = nn.Embedding(
                1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)
        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path

        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBertEncoder(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=None
            if config.NETWORK.VLBERT.from_scratch else
            language_pretrained_model_path,
            with_rel_head=False,
            with_mlm_head=False,
            with_mvrc_head=False,
        )

        # FM addition: add decoder
        self.decoder = VisualLinguisticBertForPretrainingDecoder(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=None
            if config.NETWORK.VLBERT.from_scratch else
            language_pretrained_model_path,
            with_rel_head=config.NETWORK.WITH_REL_LOSS,
            with_mlm_head=config.NETWORK.WITH_MLM_LOSS,
            with_mvrc_head=config.NETWORK.WITH_MVRC_LOSS,
        )

        # init weights
        self.init_weight()

        self.fix_params()
Example 2
    def __init__(self, config):

        super(ResNetVLBERTForPretrainingMultitaskNoVision,
              self).__init__(config)

        # Constructs/initialises model elements
        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=config.NETWORK.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        if config.NETWORK.IMAGE_FEAT_PRECOMPUTED or (
                not config.NETWORK.MASK_RAW_PIXELS):
            self.object_mask_visual_embedding = nn.Embedding(1, 2048)
        if config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding = nn.Embedding(
                1, config.NETWORK.VLBERT.hidden_size)
        self.aux_text_visual_embedding = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)

        # Either specify a pretrained model explicitly, or use the downloaded pretrained model specified in the .yaml file
        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            # language_pretrained_model_path = '{}-{:04d}.model'.format(config.NETWORK.BERT_PRETRAINED,
            #                                                           config.NETWORK.BERT_PRETRAINED_EPOCH)
            # FM edit: just use the path of the pretrained model
            language_pretrained_model_path = config.NETWORK.BERT_PRETRAINED
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path

        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBertForPretraining(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=None
            if config.NETWORK.VLBERT.from_scratch else
            language_pretrained_model_path,
            with_rel_head=config.NETWORK.WITH_REL_LOSS,
            with_mlm_head=config.NETWORK.WITH_MLM_LOSS,
            with_mvrc_head=config.NETWORK.WITH_MVRC_LOSS,
            with_MLT_head=config.NETWORK.WITH_MLT_LOSS)

        # init weights
        self.init_weight()

        self.fix_params()
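
Every one of these constructors repeats the same checkpoint-resolution logic: use BERT_PRETRAINED when it is set (either as a '{prefix}-{epoch:04d}.model' file or, per the FM edit above, as a plain path), otherwise fall back to the weights bundled in the BERT_MODEL_NAME directory, otherwise train from scratch. Below is a minimal sketch (not from the repository) of that logic folded into one helper; resolve_language_pretrained_path is a hypothetical name, weights_name stands in for BERT_WEIGHTS_NAME, and an isfile check is used to distinguish the two BERT_PRETRAINED conventions:

import os

def resolve_language_pretrained_path(network_cfg, weights_name='pytorch_model.bin'):
    """Return a path to pretrained language-model weights, or None to train from scratch."""
    if network_cfg.BERT_PRETRAINED != '':
        # An explicit checkpoint wins: either a plain file path (FM edit above) or a
        # '<prefix>-<epoch>.model' name built from BERT_PRETRAINED_EPOCH.
        if os.path.isfile(network_cfg.BERT_PRETRAINED):
            return network_cfg.BERT_PRETRAINED
        return '{}-{:04d}.model'.format(network_cfg.BERT_PRETRAINED,
                                        network_cfg.BERT_PRETRAINED_EPOCH)
    if os.path.isdir(network_cfg.BERT_MODEL_NAME):
        # Otherwise look for the weights shipped with the BERT_MODEL_NAME directory.
        weight_path = os.path.join(network_cfg.BERT_MODEL_NAME, weights_name)
        if os.path.isfile(weight_path):
            return weight_path
    return None
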
Example 3
    def __init__(self, config):

        super(ResNetVLBERT, self).__init__(config)

        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=config.NETWORK.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path)

        transform = VisualLinguisticBertMVRCHeadTransform(
            config.NETWORK.VLBERT)
        # self.linear = nn.Linear(config.NETWORK.VLBERT.hidden_size, 768) #331 1000 35 100 12003 lihui
        # self.OIM_loss = OIM_Module(331, 768)  # config.NETWORK.VLBERT.hidden_size)
        self.OIM_loss = OIM_Module(12003, 768)
        self.linear = nn.Sequential(
            # transform,
            nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
            nn.Linear(config.NETWORK.VLBERT.hidden_size,
                      768)  #331 1000 35 100 12003 lihui
        )

        linear = nn.Linear(config.NETWORK.VLBERT.hidden_size, 1)
        self.final_mlp = nn.Sequential(
            transform,
            nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
            linear)

        # init weights
        self.init_weight()

        self.fix_params()
Example 4
    def __init__(self, config):

        super(ResNetVLBERT, self).__init__(config)

        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=config.NETWORK.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path)

        self.task1_head = Task1Head(config.NETWORK.VLBERT)
        self.task2_head = Task2Head(config.NETWORK.VLBERT)
        self.task3_head = Task3Head(config.NETWORK.VLBERT)

        # init weights
        self.init_weight()

        self.fix_params()
Example 5
    def __init__(self, config):

        super(ResNetVLBERT, self).__init__(config)
        self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS
        self.cnn_loss_top = config.NETWORK.CNN_LOSS_TOP
        self.align_caption_img = config.DATASET.ALIGN_CAPTION_IMG
        self.use_phrasal_paraphrases = config.DATASET.PHRASE_CLS
        self.supervise_attention = config.NETWORK.SUPERVISE_ATTENTION
        self.normalization = config.NETWORK.ATTENTION_NORM_METHOD
        self.ewc_reg = config.NETWORK.EWC_REG
        self.importance_hparam = 0.
        if config.NETWORK.EWC_REG:
            self.fisher = pickle.load(open(config.NETWORK.FISHER_PATH, "rb"))
            self.pretrain_param = torch.load(config.NETWORK.PARAM_PRETRAIN)
            self.importance_hparam = config.NETWORK.EWC_IMPORTANCE
        if not config.NETWORK.BLIND:
            self.image_feature_extractor = FastRCNN(
                config,
                average_pool=True,
                final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                enable_cnn_reg_loss=(self.enable_cnn_reg_loss
                                     and not self.cnn_loss_top))
            if config.NETWORK.VLBERT.object_word_embed_mode == 1:
                self.object_linguistic_embeddings = nn.Embedding(
                    81, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 2:
                self.object_linguistic_embeddings = nn.Embedding(
                    1, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError

        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN

        if 'roberta' in config.NETWORK.BERT_MODEL_NAME:
            self.tokenizer = RobertaTokenizer.from_pretrained(
                config.NETWORK.BERT_MODEL_NAME)
        else:
            self.tokenizer = BertTokenizer.from_pretrained(
                config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path

        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path)

        self.for_pretrain = False
        dim = config.NETWORK.VLBERT.hidden_size
        if self.align_caption_img:
            sentence_logits_shape = 3
        else:
            sentence_logits_shape = 1
        if config.NETWORK.SENTENCE.CLASSIFIER_TYPE == "2fc":
            self.sentence_cls = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.SENTENCE.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(
                    dim, config.NETWORK.SENTENCE.CLASSIFIER_HIDDEN_SIZE),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(config.NETWORK.SENTENCE.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(config.NETWORK.SENTENCE.CLASSIFIER_HIDDEN_SIZE,
                                sentence_logits_shape),
            )
        elif config.NETWORK.SENTENCE.CLASSIFIER_TYPE == "1fc":
            self.sentence_cls = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.SENTENCE.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(dim, sentence_logits_shape))
        else:
            raise ValueError("Classifier type: {} not supported!".format(
                config.NETWORK.SENTENCE.CLASSIFIER_TYPE))

        if self.use_phrasal_paraphrases:
            if config.NETWORK.PHRASE.CLASSIFIER_TYPE == "2fc":
                self.phrasal_cls = torch.nn.Sequential(
                    torch.nn.Dropout(config.NETWORK.PHRASE.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(
                        4 * dim, config.NETWORK.PHRASE.CLASSIFIER_HIDDEN_SIZE),
                    torch.nn.ReLU(inplace=True),
                    torch.nn.Dropout(config.NETWORK.PHRASE.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(
                        config.NETWORK.PHRASE.CLASSIFIER_HIDDEN_SIZE, 5),
                )
            elif config.NETWORK.PHRASE.CLASSIFIER_TYPE == "1fc":
                self.phrasal_cls = torch.nn.Sequential(
                    torch.nn.Dropout(config.NETWORK.PHRASE.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(4 * dim, 5))
            else:
                raise ValueError("Classifier type: {} not supported!".format(
                    config.NETWORK.PHRASE.CLASSIFIER_TYPE))

        if self.supervise_attention == "indirect":
            if config.NETWORK.VG.CLASSIFIER_TYPE == "2fc":
                self.vg_cls = torch.nn.Sequential(
                    torch.nn.Dropout(config.NETWORK.VG.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(2 * dim,
                                    config.NETWORK.VG.CLASSIFIER_HIDDEN_SIZE),
                    torch.nn.ReLU(inplace=True),
                    torch.nn.Dropout(config.NETWORK.VG.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(config.NETWORK.VG.CLASSIFIER_HIDDEN_SIZE,
                                    1),
                )
            elif config.NETWORK.VG.CLASSIFIER_TYPE == "1fc":
                self.vg_cls = torch.nn.Sequential(
                    torch.nn.Dropout(config.NETWORK.VG.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(2 * dim, 1))
            else:
                raise ValueError("Classifier type: {} not supported!".format(
                    config.NETWORK.VG.CLASSIFIER_TYPE))

        # init weights
        self.init_weight()

        self.fix_params()
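
The sentence, phrase and grounding heads above all share the same shape: Dropout → Linear → ReLU → Dropout → Linear for "2fc", or Dropout → Linear for "1fc". A minimal sketch of that pattern as a small factory; make_classifier_head is a hypothetical name, not part of the repository:

import torch

def make_classifier_head(in_dim, out_dim, classifier_type, hidden_size=None, dropout=0.1):
    """Build a '1fc' or '2fc' classification head in the style used above."""
    if classifier_type == "2fc":
        return torch.nn.Sequential(
            torch.nn.Dropout(dropout, inplace=False),
            torch.nn.Linear(in_dim, hidden_size),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(dropout, inplace=False),
            torch.nn.Linear(hidden_size, out_dim),
        )
    elif classifier_type == "1fc":
        return torch.nn.Sequential(
            torch.nn.Dropout(dropout, inplace=False),
            torch.nn.Linear(in_dim, out_dim))
    raise ValueError("Classifier type: {} not supported!".format(classifier_type))
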
Example 6
class ResNetVLBERT(Module):
    def __init__(self, config):

        super(ResNetVLBERT, self).__init__(config)
        self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS
        self.cnn_loss_top = config.NETWORK.CNN_LOSS_TOP
        self.align_caption_img = config.DATASET.ALIGN_CAPTION_IMG
        self.use_phrasal_paraphrases = config.DATASET.PHRASE_CLS
        self.supervise_attention = config.NETWORK.SUPERVISE_ATTENTION
        self.normalization = config.NETWORK.ATTENTION_NORM_METHOD
        self.ewc_reg = config.NETWORK.EWC_REG
        self.importance_hparam = 0.
        if config.NETWORK.EWC_REG:
            self.fisher = pickle.load(open(config.NETWORK.FISHER_PATH, "rb"))
            self.pretrain_param = torch.load(config.NETWORK.PARAM_PRETRAIN)
            self.importance_hparam = config.NETWORK.EWC_IMPORTANCE
        if not config.NETWORK.BLIND:
            self.image_feature_extractor = FastRCNN(
                config,
                average_pool=True,
                final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                enable_cnn_reg_loss=(self.enable_cnn_reg_loss
                                     and not self.cnn_loss_top))
            if config.NETWORK.VLBERT.object_word_embed_mode == 1:
                self.object_linguistic_embeddings = nn.Embedding(
                    81, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 2:
                self.object_linguistic_embeddings = nn.Embedding(
                    1, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError

        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN

        if 'roberta' in config.NETWORK.BERT_MODEL_NAME:
            self.tokenizer = RobertaTokenizer.from_pretrained(
                config.NETWORK.BERT_MODEL_NAME)
        else:
            self.tokenizer = BertTokenizer.from_pretrained(
                config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path

        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path)

        self.for_pretrain = False
        dim = config.NETWORK.VLBERT.hidden_size
        if self.align_caption_img:
            sentence_logits_shape = 3
        else:
            sentence_logits_shape = 1
        if config.NETWORK.SENTENCE.CLASSIFIER_TYPE == "2fc":
            self.sentence_cls = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.SENTENCE.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(
                    dim, config.NETWORK.SENTENCE.CLASSIFIER_HIDDEN_SIZE),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(config.NETWORK.SENTENCE.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(config.NETWORK.SENTENCE.CLASSIFIER_HIDDEN_SIZE,
                                sentence_logits_shape),
            )
        elif config.NETWORK.SENTENCE.CLASSIFIER_TYPE == "1fc":
            self.sentence_cls = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.SENTENCE.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(dim, sentence_logits_shape))
        else:
            raise ValueError("Classifier type: {} not supported!".format(
                config.NETWORK.SENTENCE.CLASSIFIER_TYPE))

        if self.use_phrasal_paraphrases:
            if config.NETWORK.PHRASE.CLASSIFIER_TYPE == "2fc":
                self.phrasal_cls = torch.nn.Sequential(
                    torch.nn.Dropout(config.NETWORK.PHRASE.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(
                        4 * dim, config.NETWORK.PHRASE.CLASSIFIER_HIDDEN_SIZE),
                    torch.nn.ReLU(inplace=True),
                    torch.nn.Dropout(config.NETWORK.PHRASE.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(
                        config.NETWORK.PHRASE.CLASSIFIER_HIDDEN_SIZE, 5),
                )
            elif config.NETWORK.PHRASE.CLASSIFIER_TYPE == "1fc":
                self.phrasal_cls = torch.nn.Sequential(
                    torch.nn.Dropout(config.NETWORK.PHRASE.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(4 * dim, 5))
            else:
                raise ValueError("Classifier type: {} not supported!".format(
                    config.NETWORK.PHRASE.CLASSIFIER_TYPE))

        if self.supervise_attention == "indirect":
            if config.NETWORK.VG.CLASSIFIER_TYPE == "2fc":
                self.vg_cls = torch.nn.Sequential(
                    torch.nn.Dropout(config.NETWORK.VG.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(2 * dim,
                                    config.NETWORK.VG.CLASSIFIER_HIDDEN_SIZE),
                    torch.nn.ReLU(inplace=True),
                    torch.nn.Dropout(config.NETWORK.VG.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(config.NETWORK.VG.CLASSIFIER_HIDDEN_SIZE,
                                    1),
                )
            elif config.NETWORK.VG.CLASSIFIER_TYPE == "1fc":
                self.vg_cls = torch.nn.Sequential(
                    torch.nn.Dropout(config.NETWORK.VG.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(2 * dim, 1))
            else:
                raise ValueError("Classifier type: {} not supported!".format(
                    config.NETWORK.VG.CLASSIFIER_TYPE))

        # init weights
        self.init_weight()

        self.fix_params()

    def init_weight(self):
        if not self.config.NETWORK.BLIND:
            self.image_feature_extractor.init_weight()
            if self.object_linguistic_embeddings is not None:
                self.object_linguistic_embeddings.weight.data.normal_(mean=0.0,
                                                                      std=0.02)

        if not self.for_pretrain:
            for m in self.sentence_cls.modules():
                if isinstance(m, torch.nn.Linear):
                    torch.nn.init.xavier_uniform_(m.weight)
                    torch.nn.init.constant_(m.bias, 0)

    def train(self, mode=True):
        super(ResNetVLBERT, self).train(mode)
        # turn some frozen layers to eval mode
        if (not self.config.NETWORK.BLIND) and self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        if self.config.NETWORK.BLIND:
            self.vlbert._module.visual_scale_text.requires_grad = False
            self.vlbert._module.visual_scale_object.requires_grad = False

    def _collect_obj_reps(self, span_tags, object_reps):
        """
        Collect span-level object representations
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return:
        """

        span_tags_fixed = torch.clamp(
            span_tags, min=0)  # In case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0,
                                          row_id.shape[0],
                                          step=1,
                                          device=row_id.device)[:, None]

        # Add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        return object_reps[row_id.view(-1),
                           span_tags_fixed.view(-1)].view(
                               *span_tags_fixed.shape, -1)
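
    # Shape note for _collect_obj_reps (hypothetical sizes): with object_reps of
    # shape [B, num_objs, obj_dim] and span_tags of shape [B, L], token t of
    # example b receives object_reps[b, span_tags[b, t]], giving an output of
    # shape [B, L, obj_dim]; tags of -1 (masked / no grounding) are clamped to 0,
    # i.e. they fall back to the first object's features.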

    def prepare_text(self, sentence1, sentence2, mask1, mask2, sentence1_tags,
                     sentence2_tags, phrase1_mask, phrase2_mask):
        batch_size, max_len1 = sentence1.shape
        _, max_len2 = sentence2.shape
        max_len = (mask1.sum(1) + mask2.sum(1)).max() + 3
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(
            ['[CLS]', '[SEP]'])
        end_1 = 1 + mask1.sum(1, keepdim=True)
        end_2 = end_1 + 1 + mask2.sum(1, keepdim=True)
        input_ids = torch.zeros((batch_size, max_len),
                                dtype=sentence1.dtype,
                                device=sentence1.device)
        input_mask = torch.ones((batch_size, max_len),
                                dtype=torch.uint8,
                                device=sentence1.device)
        input_type_ids = torch.zeros((batch_size, max_len),
                                     dtype=sentence1.dtype,
                                     device=sentence1.device)
        text_tags = input_type_ids.new_zeros((batch_size, max_len))
        phr_mask = None
        grid_i, grid_k = torch.meshgrid(
            torch.arange(batch_size, device=sentence1.device),
            torch.arange(max_len, device=sentence1.device))

        input_mask[grid_k > end_2] = 0
        input_type_ids[(grid_k > end_1) & (grid_k <= end_2)] = 1
        input_mask1 = (grid_k > 0) & (grid_k < end_1)
        input_mask2 = (grid_k > end_1) & (grid_k < end_2)
        input_ids[:, 0] = cls_id
        input_ids[grid_k == end_1] = sep_id
        input_ids[grid_k == end_2] = sep_id
        input_ids[input_mask1] = sentence1[mask1]
        input_ids[input_mask2] = sentence2[mask2]
        text_tags[input_mask1] = sentence1_tags[mask1]
        text_tags[input_mask2] = sentence2_tags[mask2]
        if self.use_phrasal_paraphrases:
            phr_mask = phrase1_mask.new_zeros(
                (batch_size, max_len, phrase1_mask.size(-1)))
            phr_mask[input_mask1] = phrase1_mask[mask1]
            phr_mask[input_mask2] = phrase2_mask[mask2]

            # add offsets so that every pair of phrases gets a unique id in the batch
            no_phr_mask = (phr_mask == 0)
            n_phr = torch.max(phr_mask, dim=1)[0]
            offsets = phr_mask.new_zeros(
                (phr_mask.size(0) * phr_mask.size(-1)))
            offsets[1:] = torch.cumsum(n_phr.view(-1)[:-1], dim=0)
            offsets = offsets.view((phr_mask.size(0), phr_mask.size(-1)))
            phr_mask += offsets.unsqueeze(1)
            phr_mask[no_phr_mask] = 0

        return input_ids, input_type_ids, text_tags, input_mask, phr_mask
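
    # Layout produced by prepare_text: each packed row reads
    #   [CLS] <sentence1 tokens> [SEP] <sentence2 tokens> [SEP] <padding>,
    # input_type_ids are 0 up to and including the first [SEP] and 1 from there up
    # to and including the second [SEP], input_mask covers everything up to and
    # including the second [SEP], and text_tags / phr_mask scatter the per-token
    # object tags and (offset) phrase-pair ids into the same positions.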

    def train_forward(self, images, boxes, sentence1, sentence2, im_info,
                      label):
        ###########################################
        # visual feature extraction

        box_mask = (boxes[:, :, -1] > -0.5)
        max_len = int(box_mask.sum(1).max().item())

        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len].type(torch.float32)

        # segms = segms[:, :max_len]
        if self.config.NETWORK.BLIND:
            obj_reps = {
                'obj_reps':
                boxes.new_zeros(
                    (*boxes.shape[:-1], self.config.NETWORK.IMAGE_FINAL_DIM))
            }
        else:
            obj_reps = self.image_feature_extractor(images=images,
                                                    boxes=boxes,
                                                    box_mask=box_mask,
                                                    im_info=im_info,
                                                    classes=None,
                                                    segms=None)

        sentence1_ids = sentence1[:, :, 0]
        mask1 = (sentence1[:, :, 0] > 0.5)
        sentence1_tags = sentence1[:, :, 1]

        sentence2_ids = sentence2[:, :, 0]
        mask2 = (sentence2[:, :, 0] > 0.5)
        sentence2_tags = sentence2[:, :, 1]

        if self.use_phrasal_paraphrases:
            phrase1_mask = sentence1[:, :, 2:]
            phrase2_mask = sentence2[:, :, 2:]
            sentence_label = label[:, 0, 0].view(-1)
            phrase_labels = label[:, :, 1]
        else:
            phrase1_mask, phrase2_mask = None, None
            sentence_label = label.view(-1)

        ############################################

        # prepare text
        text_input_ids, text_token_type_ids, text_tags, text_mask, phrase_mask = self.prepare_text(
            sentence1_ids, sentence2_ids, mask1, mask2, sentence1_tags,
            sentence2_tags, phrase1_mask, phrase2_mask)

        # Add visual feature to text elements
        if self.config.NETWORK.NO_GROUNDING:
            text_visual_embeddings = self._collect_obj_reps(
                text_tags.new_zeros(text_tags.size()), obj_reps['obj_reps'])
        else:
            text_visual_embeddings = self._collect_obj_reps(
                text_tags, obj_reps['obj_reps'])
        # Add textual feature to image element
        if self.config.NETWORK.BLIND:
            object_linguistic_embeddings = boxes.new_zeros(
                (*boxes.shape[:-1], self.config.NETWORK.VLBERT.hidden_size))
        else:
            object_linguistic_embeddings = self.object_linguistic_embeddings(
                boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################

        # Visual Linguistic BERT
        if self.config.NETWORK.NO_OBJ_ATTENTION or self.config.NETWORK.BLIND:
            box_mask.zero_()

        if self.supervise_attention in ["direct", "semi-direct"]:
            hidden_states_text, hidden_states_objects, pooled_rep, attention_probs = \
                self.vlbert(text_input_ids,
                            text_token_type_ids,
                            text_visual_embeddings,
                            text_mask,
                            object_vl_embeddings,
                            box_mask,
                            output_all_encoded_layers=False,
                            output_text_and_object_separately=True,
                            output_attention_probs=True)
        else:
            hidden_states_text, hidden_states_objects, pooled_rep = self.vlbert(
                text_input_ids,
                text_token_type_ids,
                text_visual_embeddings,
                text_mask,
                object_vl_embeddings,
                box_mask,
                output_all_encoded_layers=False,
                output_text_and_object_separately=True,
                output_attention_probs=False)

        ###########################################
        outputs = {}

        # sentence classification
        sentence_logits = self.sentence_cls(pooled_rep)
        if self.align_caption_img:
            sentence_logits = sentence_logits.view((-1, 3))
            sentence_cls_loss = F.cross_entropy(sentence_logits,
                                                sentence_label)
        else:
            sentence_logits = sentence_logits.view(-1)
            sentence_cls_loss = F.binary_cross_entropy_with_logits(
                sentence_logits, sentence_label.type(torch.float32))

        outputs.update({
            'sentence_label_logits': sentence_logits,
            'sentence_label': sentence_label.long(),
            'sentence_cls_loss': sentence_cls_loss
        })

        # phrasal paraphrases classification
        phrase_cls_loss = sentence_logits.new_zeros(())
        if self.use_phrasal_paraphrases:
            phrase_labels = phrase_labels.view((-1))
            phrase_cls_logits = sentence_logits.new_zeros(
                (phrase_labels.size(0), 5))
            outputs.update({
                "phrase_label": phrase_labels,
                "phrase_label_logits": phrase_cls_logits,
                "phrase_cls_loss": phrase_cls_loss
            })
            if phrase_mask.max() > 0:
                logits = self.get_phrase_cls(hidden_states_text, phrase_mask,
                                             text_token_type_ids)
                phrase_cls_loss = F.cross_entropy(
                    logits,
                    phrase_labels[phrase_labels > -1],
                    reduction="mean")
                phrase_cls_logits[(phrase_labels > -1)] = logits
                outputs.update({
                    "phrase_label_logits": phrase_cls_logits,
                    "phrase_cls_loss": phrase_cls_loss
                })

        # Handle attention supervision, suffix 1 refers to text-to-roi attention and suffix 2 refers to roi-to-text
        attention_loss = 0.
        if self.supervise_attention in ["direct", "semi-direct"]:
            use_raw = self.supervise_attention == "direct"
            attention_loss_1, attention_loss_2 = get_attention_supervision_loss(
                attention_probs,
                text_tags,
                text_mask,
                box_mask,
                use_raw=use_raw,
                normalization=self.normalization)
            outputs.update({
                "attention_loss_1": attention_loss_1,
                "attention_loss_2": attention_loss_2
            })
            attention_loss = attention_loss_1 + attention_loss_2

        elif self.supervise_attention == "indirect":
            attention_loss = self.get_indirect_vg_loss(hidden_states_text,
                                                       hidden_states_objects,
                                                       text_tags, text_mask,
                                                       box_mask)
            outputs.update({"vg_loss": attention_loss})

        # EWC regularization loss against catastrophic forgetting
        ewc_loss = 0.
        if self.ewc_reg:
            for n, p in self.named_parameters():
                name = "module." + n
                if name in self.fisher.keys():
                    ewc_loss += (
                        self.fisher[name].to(p.device) *
                        (p - self.pretrain_param[name].to(p.device))**2).sum()
            outputs.update({"ewc_loss": ewc_loss})

        loss = sentence_cls_loss.mean() + self.config.NETWORK.PHRASE_LOSS_WEIGHT * phrase_cls_loss + \
               self.config.NETWORK.ATTENTION_LOSS_WEIGHT * attention_loss + self.importance_hparam * ewc_loss

        return outputs, loss

    def inference_forward(self, images, boxes, sentence1, sentence2, im_info):
        ###########################################
        # visual feature extraction

        # For now use all boxes
        box_mask = torch.ones(boxes[:, :, -1].size(),
                              dtype=torch.uint8,
                              device=boxes.device)

        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len].type(torch.float32)

        if self.config.NETWORK.BLIND:
            obj_reps = {
                'obj_reps':
                boxes.new_zeros(
                    (*boxes.shape[:-1], self.config.NETWORK.IMAGE_FINAL_DIM))
            }
        else:
            obj_reps = self.image_feature_extractor(images=images,
                                                    boxes=boxes,
                                                    box_mask=box_mask,
                                                    im_info=im_info,
                                                    classes=None,
                                                    segms=None)

        # For now no tags
        sentence1_ids = sentence1[:, :, 0]
        mask1 = (sentence1[:, :, 0] > 0.5)
        sentence1_tags = sentence1[:, :, 1]
        sentence2_ids = sentence2[:, :, 0]
        mask2 = (sentence2[:, :, 0] > 0.5)
        sentence2_tags = sentence2[:, :, 1]

        if self.use_phrasal_paraphrases:
            phrase1_mask = sentence1[:, :, 2:]
            phrase2_mask = sentence2[:, :, 2:]
        else:
            phrase1_mask, phrase2_mask = None, None

        ############################################

        # prepare text
        text_input_ids, text_token_type_ids, text_tags, text_mask, phrase_mask = self.prepare_text(
            sentence1_ids, sentence2_ids, mask1, mask2, sentence1_tags,
            sentence2_tags, phrase1_mask, phrase2_mask)

        # Add visual feature to text elements
        if self.config.NETWORK.NO_GROUNDING:
            text_tags.zero_()
        text_visual_embeddings = self._collect_obj_reps(
            text_tags, obj_reps['obj_reps'])
        # Add textual feature to image element
        if self.config.NETWORK.BLIND:
            object_linguistic_embeddings = boxes.new_zeros(
                (*boxes.shape[:-1], self.config.NETWORK.VLBERT.hidden_size))
        else:
            object_linguistic_embeddings = self.object_linguistic_embeddings(
                boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################

        # Visual Linguistic BERT
        if self.config.NETWORK.NO_OBJ_ATTENTION or self.config.NETWORK.BLIND:
            box_mask.zero_()
        hidden_states_text, hidden_states_objects, pooled_rep = self.vlbert(
            text_input_ids,
            text_token_type_ids,
            text_visual_embeddings,
            text_mask,
            object_vl_embeddings,
            box_mask,
            output_all_encoded_layers=False,
            output_text_and_object_separately=True)

        ###########################################
        outputs = {}
        # sentence classification
        sentence_logits = self.sentence_cls(pooled_rep)
        if self.align_caption_img:
            sentence_logits = sentence_logits.view((-1, 3))
        else:
            sentence_logits = sentence_logits.view(-1)
        outputs.update({'sentence_label_logits': sentence_logits})

        if self.use_phrasal_paraphrases:
            phrase_cls_logits = sentence_logits.new_zeros((1, 5)) + 100000
            outputs.update({"phrase_label_logits": phrase_cls_logits})
            if phrase_mask.max() > 0:
                phrase_cls_logits = self.get_phrase_cls(
                    hidden_states_text, phrase_mask, text_token_type_ids)
                outputs.update({"phrase_label_logits": phrase_cls_logits})

        return outputs

    def get_phrase_cls(self, encoded_rep, phr_mask, token_type):
        n_pairs = phr_mask.max().item()
        phr_reps = encoded_rep.new_zeros((n_pairs, 2, encoded_rep.size(-1)))
        for i in range(n_pairs):
            # max pool representation of first phrase
            shaped_phr_mask = (phr_mask == i + 1).any(2)
            phr_reps[i, 0] = encoded_rep[(token_type == 0)
                                         & shaped_phr_mask].max(dim=0)[0]
            # max pool representation of second phrase
            phr_reps[i, 1] = encoded_rep[(token_type == 1)
                                         & shaped_phr_mask].max(dim=0)[0]
        final_phrases_rep = torch.cat(
            (phr_reps[:, 0], phr_reps[:, 1],
             torch.abs(phr_reps[:, 0] - phr_reps[:, 1]),
             torch.mul(phr_reps[:, 0], phr_reps[:, 1])),
            dim=1)
        output_logits = self.phrasal_cls(final_phrases_rep)
        return output_logits
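
    # The feature fed to phrasal_cls is the familiar pair-matching vector
    # [a; b; |a - b|; a * b] built from the two max-pooled phrase representations,
    # which is why the phrase classifier takes an input of size 4 * dim.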

    def get_indirect_vg_loss(self, encoded_text, encoded_objects, text_tags,
                             text_mask, box_mask):
        if text_tags.max() <= 0:
            return encoded_text.new_zeros((1)).sum()
        else:
            vg_inputs = []
            vg_labels = []
            indexes = find_phrases(text_tags)
            for i, k, length, tag in indexes:
                phrases_rep = encoded_text[i, k:k + length].max(
                    dim=0)[0]  # max pool encoding of the words in the phrase
                objects_reps = encoded_objects[i][box_mask[i]][1:]
                vg_inputs.append(
                    torch.cat((phrases_rep.unsqueeze(0).repeat(
                        len(objects_reps), 1), objects_reps),
                              dim=1))
                vg_lbl = text_tags.new_zeros((len(objects_reps)))
                vg_lbl[tag - 1] = 1
                vg_labels.append(vg_lbl)
            vg_inputs = torch.cat(vg_inputs, dim=0)
            vg_labels = torch.cat(vg_labels, dim=0)
            vg_logits = self.vg_cls(vg_inputs).view(-1)
            vg_loss = F.binary_cross_entropy_with_logits(
                vg_logits, vg_labels.float())
            return vg_loss
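
A minimal sketch of the EWC penalty computed inside train_forward above, pulled out as a standalone function; ewc_penalty is a hypothetical name, and fisher / pretrained_params are dictionaries keyed by "module." + parameter name, as in the snippet:

import torch

def ewc_penalty(model: torch.nn.Module, fisher: dict, pretrained_params: dict):
    """Fisher-weighted squared drift of the current parameters from the pretrained ones."""
    penalty = 0.
    for name, p in model.named_parameters():
        key = "module." + name
        if key in fisher:
            penalty = penalty + (fisher[key].to(p.device) *
                                 (p - pretrained_params[key].to(p.device)) ** 2).sum()
    return penalty

The result is scaled by EWC_IMPORTANCE (importance_hparam) when it is added to the total loss.
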
Example 7
class ResNetVLBERT(Module):
    def __init__(self, config):

        super(ResNetVLBERT, self).__init__(config)

        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=config.NETWORK.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path)

        self.task1_head = Task1Head(config.NETWORK.VLBERT)
        self.task2_head = Task2Head(config.NETWORK.VLBERT)
        self.task3_head = Task3Head(config.NETWORK.VLBERT)

        # init weights
        self.init_weight()

        self.fix_params()

    def init_weight(self):
        self.image_feature_extractor.init_weight()

        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        for m in self.task1_head.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                torch.nn.init.constant_(m.bias, 0)

        for m in self.task2_head.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                torch.nn.init.constant_(m.bias, 0)

        for m in self.task3_head.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                torch.nn.init.constant_(m.bias, 0)

    def train(self, mode=True):
        super(ResNetVLBERT, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):

        for param in self.image_feature_extractor.parameters():
            param.requires_grad = False

        for param in self.vlbert.parameters():
            param.requires_grad = False

        for param in self.object_linguistic_embeddings.parameters():
            param.requires_grad = False

    def train_forward(self, image, boxes, im_info, expression, label, pos,
                      target, mask):
        ###########################################

        if self.vlbert.training:
            self.vlbert.eval()

        if self.image_feature_extractor.training:
            self.image_feature_extractor.eval()

        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        max_len = int(box_mask.sum(1).max().item())
        origin_len = boxes.shape[1]
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        # our labels for foil are binary and one-dimensional
        # label = label[:, :max_len]

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None)

        ############################################
        # prepare text
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(
            ['[CLS]', '[SEP]'])
        text_input_ids = expression.new_zeros(
            (expression.shape[0], expression.shape[1] + 2))
        text_input_ids[:, 0] = cls_id
        text_input_ids[:, 1:-1] = expression
        _sep_pos = (text_input_ids > 0).sum(1)
        _batch_inds = torch.arange(expression.shape[0],
                                   device=expression.device)
        text_input_ids[_batch_inds, _sep_pos] = sep_id
        text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
        text_mask = text_input_ids > 0
        text_visual_embeddings = obj_reps['obj_reps'][:, 0].unsqueeze(
            1).repeat((1, text_input_ids.shape[1], 1))

        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################

        # Visual Linguistic BERT

        sequence_reps, object_reps, _ = self.vlbert(
            text_input_ids,
            text_token_type_ids,
            text_visual_embeddings,
            text_mask,
            object_vl_embeddings,
            box_mask,
            output_text_and_object_separately=True,
            output_all_encoded_layers=False)

        cls_rep = sequence_reps[:, 0, :].squeeze(1)
        sequence_reps = sequence_reps[:, 1:, :]

        cls_log_probs = self.task1_head(cls_rep)

        pos_log_probs = self.task2_head(sequence_reps, text_mask[:, 1:])

        zeros = torch.zeros_like(text_input_ids)
        mask_len = mask.shape[1]
        error_cor_mask = zeros
        error_cor_mask[:, 1:mask_len + 1] = mask
        error_cor_mask = error_cor_mask.bool()
        text_input_ids[error_cor_mask] = self.tokenizer.convert_tokens_to_ids(
            ['[MASK]'])[0]

        sequence_reps, _, _ = self.vlbert(
            text_input_ids,
            text_token_type_ids,
            text_visual_embeddings,
            text_mask,
            object_vl_embeddings,
            box_mask,
            output_text_and_object_separately=True,
            output_all_encoded_layers=False)

        sequence_reps = sequence_reps[:, 1:, :]
        select_index = pos.view(-1, 1).unsqueeze(2).repeat(1, 1, 768)
        select_index[select_index < 0] = 0
        masked_reps = torch.gather(sequence_reps, 1, select_index).squeeze(1)
        cor_log_probs = self.task3_head(masked_reps)
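        # (pos_loss and cor_loss below are scaled by the binary cls label, so the
        #  localisation and correction objectives only contribute on examples whose
        #  cls label is 1; the classification loss applies to every example.)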

        loss_mask = label.view(-1).float()

        cls_loss = F.binary_cross_entropy(cls_log_probs.view(-1),
                                          label.view(-1).float(),
                                          reduction="none")
        pos_loss = F.nll_loss(
            pos_log_probs, pos, ignore_index=-1,
            reduction="none").view(-1) * loss_mask
        cor_loss = F.nll_loss(cor_log_probs.view(-1, cor_log_probs.shape[-1]),
                              target,
                              ignore_index=0,
                              reduction="none").view(-1) * loss_mask

        loss = cls_loss.mean() + pos_loss.mean() + cor_loss.mean()

        outputs = {
            "cls_logits": cls_log_probs,
            "pos_logits": pos_log_probs,
            "cor_logits": cor_log_probs,
            "cls_label": label,
            "pos_label": pos,
            "cor_label": target
        }

        return outputs, loss

    def inference_forward(self, image, boxes, im_info, expression, label, pos,
                          target, mask):

        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        max_len = int(box_mask.sum(1).max().item())
        origin_len = boxes.shape[1]
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        # our labels for foil are binary and one-dimensional
        # label = label[:, :max_len]
        with torch.no_grad():
            obj_reps = self.image_feature_extractor(images=images,
                                                    boxes=boxes,
                                                    box_mask=box_mask,
                                                    im_info=im_info,
                                                    classes=None,
                                                    segms=None)

            ############################################
            # prepare text
            cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(
                ['[CLS]', '[SEP]'])
            text_input_ids = expression.new_zeros(
                (expression.shape[0], expression.shape[1] + 2))
            text_input_ids[:, 0] = cls_id
            text_input_ids[:, 1:-1] = expression
            _sep_pos = (text_input_ids > 0).sum(1)
            _batch_inds = torch.arange(expression.shape[0],
                                       device=expression.device)
            text_input_ids[_batch_inds, _sep_pos] = sep_id
            text_token_type_ids = text_input_ids.new_zeros(
                text_input_ids.shape)
            text_mask = text_input_ids > 0
            text_visual_embeddings = obj_reps['obj_reps'][:, 0].unsqueeze(
                1).repeat((1, text_input_ids.shape[1], 1))

            object_linguistic_embeddings = self.object_linguistic_embeddings(
                boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
            object_vl_embeddings = torch.cat(
                (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

            ###########################################

            # Visual Linguistic BERT

            sequence_reps, object_reps, _ = self.vlbert(
                text_input_ids,
                text_token_type_ids,
                text_visual_embeddings,
                text_mask,
                object_vl_embeddings,
                box_mask,
                output_text_and_object_separately=True,
                output_all_encoded_layers=False)

        cls_rep = sequence_reps[:, 0, :].squeeze(1)
        sequence_reps = sequence_reps[:, 1:, :]

        cls_log_probs = self.task1_head(cls_rep)

        pos_log_probs = self.task2_head(sequence_reps, text_mask[:, 1:])

        zeros = torch.zeros_like(text_input_ids)
        mask_len = mask.shape[1]
        error_cor_mask = zeros
        error_cor_mask[:, 1:mask_len + 1] = mask
        error_cor_mask = error_cor_mask.bool()
        text_input_ids[error_cor_mask] = self.tokenizer.convert_tokens_to_ids(
            ['[MASK]'])[0]

        with torch.no_grad():
            sequence_reps, _, _ = self.vlbert(
                text_input_ids,
                text_token_type_ids,
                text_visual_embeddings,
                text_mask,
                object_vl_embeddings,
                box_mask,
                output_text_and_object_separately=True,
                output_all_encoded_layers=False)

        sequence_reps = sequence_reps[:, 1:, :]
        select_index = pos.view(-1, 1).unsqueeze(2).repeat(1, 1, 768)
        select_index[select_index < 0] = 0
        masked_reps = torch.gather(sequence_reps, 1, select_index).squeeze(1)
        cor_log_probs = self.task3_head(masked_reps)

        loss_mask = label.view(-1).float()

        cls_loss = F.binary_cross_entropy(cls_log_probs.view(-1),
                                          label.view(-1).float(),
                                          reduction="none")
        pos_loss = F.nll_loss(
            pos_log_probs, pos, ignore_index=-1,
            reduction="none").view(-1) * loss_mask
        cor_loss = F.nll_loss(cor_log_probs.view(-1, cor_log_probs.shape[-1]),
                              target,
                              ignore_index=0,
                              reduction="none").view(-1) * loss_mask

        loss = cls_loss.mean() + pos_loss.mean() + cor_loss.mean()

        outputs = {
            "cls_logits": cls_log_probs,
            "pos_logits": pos_log_probs,
            "cor_logits": cor_log_probs,
            "cls_label": label,
            "pos_label": pos,
            "cor_label": target
        }

        return outputs, loss
class ResNetVLBERT(Module):
    def __init__(self, config):

        super(ResNetVLBERT, self).__init__(config)

        self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS
        if not config.NETWORK.BLIND:
            self.image_feature_extractor = FastRCNN(config,
                                                    average_pool=True,
                                                    final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                                                    enable_cnn_reg_loss=self.enable_cnn_reg_loss)
            if config.NETWORK.VLBERT.object_word_embed_mode == 1:
                self.object_linguistic_embeddings = nn.Embedding(81, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 2:
                self.object_linguistic_embeddings = nn.Embedding(1, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN

        self.tokenizer = BertTokenizer.from_pretrained(config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(config.NETWORK.BERT_PRETRAINED,
                                                                      config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print("Warning: no pretrained language model found, training from scratch!!!")

        # Also pass the finetuning strategy
        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path,
            finetune_strategy=config.FINETUNE_STRATEGY)

        # self.hm_out = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size)
        # self.hi_out = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size)

        dim = config.NETWORK.VLBERT.hidden_size
        if config.NETWORK.CLASSIFIER_TYPE == "2fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(dim, config.NETWORK.CLASSIFIER_HIDDEN_SIZE),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(config.NETWORK.CLASSIFIER_HIDDEN_SIZE, config.DATASET.ANSWER_VOCAB_SIZE),
            )
        elif config.NETWORK.CLASSIFIER_TYPE == "1fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(dim, config.DATASET.ANSWER_VOCAB_SIZE)
            )
        elif config.NETWORK.CLASSIFIER_TYPE == 'mlm':
            transform = BertPredictionHeadTransform(config.NETWORK.VLBERT)
            linear = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.DATASET.ANSWER_VOCAB_SIZE)
            self.final_mlp = nn.Sequential(
                transform,
                nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                linear
            )
        else:
            raise ValueError("Not support classifier type: {}!".format(config.NETWORK.CLASSIFIER_TYPE))

        # init weights
        self.init_weight()

        self.fix_params()

    def init_weight(self):
        # self.hm_out.weight.data.normal_(mean=0.0, std=0.02)
        # self.hm_out.bias.data.zero_()
        # self.hi_out.weight.data.normal_(mean=0.0, std=0.02)
        # self.hi_out.bias.data.zero_()
        self.image_feature_extractor.init_weight()
        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(mean=0.0, std=0.02)
        for m in self.final_mlp.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                torch.nn.init.constant_(m.bias, 0)
        if self.config.NETWORK.CLASSIFIER_TYPE == 'mlm':
            language_pretrained = torch.load(self.language_pretrained_model_path)
            mlm_transform_state_dict = {}
            pretrain_keys = []
            for k, v in language_pretrained.items():
                if k.startswith('cls.predictions.transform.'):
                    pretrain_keys.append(k)
                    k_ = k[len('cls.predictions.transform.'):]
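                    # older BERT checkpoints store LayerNorm parameters as 'gamma'/'beta';
                    # rename them to the 'weight'/'bias' names expected by nn.LayerNorm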
                    if 'gamma' in k_:
                        k_ = k_.replace('gamma', 'weight')
                    if 'beta' in k_:
                        k_ = k_.replace('beta', 'bias')
                    mlm_transform_state_dict[k_] = v
            print("loading pretrained classifier transform keys: {}.".format(pretrain_keys))
            self.final_mlp[0].load_state_dict(mlm_transform_state_dict)

    def train(self, mode=True):
        super(ResNetVLBERT, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        pass

    def _collect_obj_reps(self, span_tags, object_reps):
        """
        Collect span-level object representations
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return:
        """

        span_tags_fixed = torch.clamp(span_tags, min=0)  # In case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0, row_id.shape[0], step=1, device=row_id.device)[:, None]

        # Add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        return object_reps[row_id.view(-1), span_tags_fixed.view(-1)].view(*span_tags_fixed.shape, -1)
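        # Illustrative note (assumed shapes, not from the original code): with span_tags of
        # shape [B, L] and object_reps of shape [B, N, D], each text position t of sample b
        # is mapped to object_reps[b, span_tags[b, t]], giving a [B, L, D] tensor; negative
        # (masked) tags are clamped to object 0.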

    def prepare_text_from_qa(self, question, question_tags, question_mask, answer, answer_tags, answer_mask):
        batch_size, max_q_len = question.shape
        _, max_a_len = answer.shape
        max_len = (question_mask.sum(1) + answer_mask.sum(1)).max() + 3
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]'])
        q_end = 1 + question_mask.sum(1, keepdim=True)
        a_end = q_end + 1 + answer_mask.sum(1, keepdim=True)
        input_ids = torch.zeros((batch_size, max_len), dtype=question.dtype, device=question.device)
        input_mask = torch.ones((batch_size, max_len), dtype=torch.bool, device=question.device)
        input_type_ids = torch.zeros((batch_size, max_len), dtype=question.dtype, device=question.device)
        text_tags = input_type_ids.new_zeros((batch_size, max_len))
        grid_i, grid_j = torch.meshgrid(torch.arange(batch_size, device=question.device),
                                        torch.arange(max_len, device=question.device))

        input_mask[grid_j > a_end] = 0
        input_type_ids[(grid_j > q_end) & (grid_j <= a_end)] = 1
        q_input_mask = (grid_j > 0) & (grid_j < q_end)
        a_input_mask = (grid_j > q_end) & (grid_j < a_end)
        input_ids[:, 0] = cls_id
        input_ids[grid_j == q_end] = sep_id
        input_ids[grid_j == a_end] = sep_id
        input_ids[q_input_mask] = question[question_mask]
        input_ids[a_input_mask] = answer[answer_mask]
        text_tags[q_input_mask] = question_tags[question_mask]
        text_tags[a_input_mask] = answer_tags[answer_mask]

        return input_ids, input_type_ids, text_tags, input_mask, (a_end - 1).squeeze(1)
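        # Illustrative layout of the returned sequence (a reading of the code above, not
        # part of the original source):
        #   tokens: [CLS] q_1 .. q_n [SEP] a_1 .. a_m [SEP] pad ..
        #   types :   0    0  ..  0    0    1  ..  1    1    0  ..
        # ans_pos points at the last answer token (a_end - 1), and input_mask covers
        # everything up to and including the final [SEP].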

    def train_forward(self,
                      image,
                      boxes,
                      im_info,
                      question,
                      label,
                      policy=None
                      ):
        ###########################################

        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None)

        question_ids = question
        question_tags = question.new_zeros(question_ids.shape)
        question_mask = (question > 0.5)

        answer_ids = question_ids.new_zeros((question_ids.shape[0], 1)).fill_(
            self.tokenizer.convert_tokens_to_ids(['[MASK]'])[0])
        answer_mask = question_mask.new_zeros(answer_ids.shape).fill_(1)
        answer_tags = question_tags.new_zeros(answer_ids.shape)

        ############################################

        # prepare text
        text_input_ids, text_token_type_ids, text_tags, text_mask, ans_pos = self.prepare_text_from_qa(question_ids,
                                                                                                       question_tags,
                                                                                                       question_mask,
                                                                                                       answer_ids,
                                                                                                       answer_tags,
                                                                                                       answer_mask)
        if self.config.NETWORK.NO_GROUNDING:
            obj_rep_zeroed = obj_reps['obj_reps'].new_zeros(obj_reps['obj_reps'].shape)
            text_tags.zero_()
            text_visual_embeddings = self._collect_obj_reps(text_tags, obj_rep_zeroed)
        else:
            text_visual_embeddings = self._collect_obj_reps(text_tags, obj_reps['obj_reps'])

        assert self.config.NETWORK.VLBERT.object_word_embed_mode == 2
        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long()
        )
        object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################

        # Visual Linguistic BERT

        hidden_states, hc = self.vlbert(text_input_ids,
                                        text_token_type_ids,
                                        text_visual_embeddings,
                                        text_mask,
                                        object_vl_embeddings,
                                        box_mask,
                                        output_all_encoded_layers=False,
                                        policy=policy)
        _batch_inds = torch.arange(question.shape[0], device=question.device)

        hm = hidden_states[_batch_inds, ans_pos]
        # hm = F.tanh(self.hm_out(hidden_states[_batch_inds, ans_pos]))
        # hi = F.tanh(self.hi_out(hidden_states[_batch_inds, ans_pos + 2]))

        ###########################################
        outputs = {}

        # classifier
        # logits = self.final_mlp(hc * hm * hi)
        # logits = self.final_mlp(hc)
        logits = self.final_mlp(hm)

        # loss
        ans_loss = F.binary_cross_entropy_with_logits(logits, label) * label.size(1)


        outputs.update({'label_logits': logits,
                        'label': label,
                        'ans_loss': ans_loss})

        loss = ans_loss.mean()

        # check for auxiliary losses
        if policy is not None: 
            if self.config.USE_CONSTRAIN_K_LOSS:
                loss_k = constrain_k_loss(policy, self.config.CONSTRAIN_K_NUM_BLOCKS, self.config.CONSTRAIN_K_SCALE)
                loss += loss_k
                outputs.update({'loss_k': loss_k})

            if self.config.USE_DETERMINISTIC_POLICY_LOSS:
                loss_d = deterministic_policy_loss(policy, self.config.DETERMINISTIC_POLICY_SCALE)
                loss += loss_d
                outputs.update({'loss_d': loss_d})

        return outputs, loss

    def inference_forward(self,
                          image,
                          boxes,
                          im_info,
                          question,
                          policy=None):

        ###########################################

        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None)

        question_ids = question
        question_tags = question.new_zeros(question_ids.shape)
        question_mask = (question > 0.5)

        answer_ids = question_ids.new_zeros((question_ids.shape[0], 1)).fill_(
            self.tokenizer.convert_tokens_to_ids(['[MASK]'])[0])
        answer_mask = question_mask.new_zeros(answer_ids.shape).fill_(1)
        answer_tags = question_tags.new_zeros(answer_ids.shape)

        ############################################

        # prepare text
        text_input_ids, text_token_type_ids, text_tags, text_mask, ans_pos = self.prepare_text_from_qa(question_ids,
                                                                                                       question_tags,
                                                                                                       question_mask,
                                                                                                       answer_ids,
                                                                                                       answer_tags,
                                                                                                       answer_mask)
        if self.config.NETWORK.NO_GROUNDING:
            obj_rep_zeroed = obj_reps['obj_reps'].new_zeros(obj_reps['obj_reps'].shape)
            text_tags.zero_()
            text_visual_embeddings = self._collect_obj_reps(text_tags, obj_rep_zeroed)
        else:
            text_visual_embeddings = self._collect_obj_reps(text_tags, obj_reps['obj_reps'])

        assert self.config.NETWORK.VLBERT.object_word_embed_mode == 2
        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long()
        )
        object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################

        # Visual Linguistic BERT

        hidden_states, hc = self.vlbert(text_input_ids,
                                        text_token_type_ids,
                                        text_visual_embeddings,
                                        text_mask,
                                        object_vl_embeddings,
                                        box_mask,
                                        output_all_encoded_layers=False,
                                        policy=policy)
        _batch_inds = torch.arange(question.shape[0], device=question.device)

        hm = hidden_states[_batch_inds, ans_pos]
        # hm = F.tanh(self.hm_out(hidden_states[_batch_inds, ans_pos]))
        # hi = F.tanh(self.hi_out(hidden_states[_batch_inds, ans_pos + 2]))

        ###########################################
        outputs = {}

        # classifier
        # logits = self.final_mlp(hc * hm * hi)
        # logits = self.final_mlp(hc)
        logits = self.final_mlp(hm)

        outputs.update({'label_logits': logits})

        return outputs
Example #9
    def __init__(self, config):

        super(ResNetVLBERT, self).__init__(config)

        self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS
        self.cnn_loss_top = config.NETWORK.CNN_LOSS_TOP
        if not config.NETWORK.BLIND:
            self.image_feature_extractor = FastRCNN(
                config,
                average_pool=True,
                final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                enable_cnn_reg_loss=(self.enable_cnn_reg_loss
                                     and not self.cnn_loss_top))
            if config.NETWORK.VLBERT.object_word_embed_mode == 1:
                self.object_linguistic_embeddings = nn.Embedding(
                    81, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 2:
                self.object_linguistic_embeddings = nn.Embedding(
                    1, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError
            if self.enable_cnn_reg_loss and self.cnn_loss_top:
                self.cnn_loss_reg = nn.Sequential(
                    VisualLinguisticBertMVRCHeadTransform(
                        config.NETWORK.VLBERT),
                    nn.Dropout(config.NETWORK.CNN_REG_DROPOUT, inplace=False),
                    nn.Linear(config.NETWORK.VLBERT.hidden_size, 81))
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN

        if 'roberta' in config.NETWORK.BERT_MODEL_NAME:
            self.tokenizer = RobertaTokenizer.from_pretrained(
                config.NETWORK.BERT_MODEL_NAME)
        else:
            self.tokenizer = BertTokenizer.from_pretrained(
                config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path

        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = TimeDistributed(
            VisualLinguisticBert(
                config.NETWORK.VLBERT,
                language_pretrained_model_path=language_pretrained_model_path))

        self.for_pretrain = config.NETWORK.FOR_MASK_VL_MODELING_PRETRAIN
        assert not self.for_pretrain, "Pretrain mode is not implemented yet!"

        if not self.for_pretrain:
            dim = config.NETWORK.VLBERT.hidden_size
            if config.NETWORK.CLASSIFIER_TYPE == "2fc":
                self.final_mlp = torch.nn.Sequential(
                    torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(dim,
                                    config.NETWORK.CLASSIFIER_HIDDEN_SIZE),
                    torch.nn.ReLU(inplace=True),
                    torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(config.NETWORK.CLASSIFIER_HIDDEN_SIZE, 1),
                )
            elif config.NETWORK.CLASSIFIER_TYPE == "1fc":
                self.final_mlp = torch.nn.Sequential(
                    torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                     inplace=False), torch.nn.Linear(dim, 1))
            else:
                raise ValueError("Unsupported classifier type: {}!".format(
                    config.NETWORK.CLASSIFIER_TYPE))

        # init weights
        self.init_weight()

        self.fix_params()
class ResNetVLBERTDistanceTranslationWithVision(Module):
    def __init__(self, config):

        super(ResNetVLBERTDistanceTranslationWithVision, self).__init__(config)

        # Constructs/initialises model elements
        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=config.NETWORK.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        if config.NETWORK.IMAGE_FEAT_PRECOMPUTED or (
                not config.NETWORK.MASK_RAW_PIXELS):
            self.object_mask_visual_embedding = nn.Embedding(1, 2048)
        if config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding = nn.Embedding(
                1, config.NETWORK.VLBERT.hidden_size)
        self.aux_text_visual_embedding = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)

        # Either specify a pre-trained model or use the downloaded pretrained model specified in the .yaml file
        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            # language_pretrained_model_path = '{}-{:04d}.model'.format(config.NETWORK.BERT_PRETRAINED,
            #                                                           config.NETWORK.BERT_PRETRAINED_EPOCH)
            # FM edit: just use the path of the pretrained model
            language_pretrained_model_path = config.NETWORK.BERT_PRETRAINED
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path

        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBertForDistance(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=None
            if config.NETWORK.VLBERT.from_scratch else
            language_pretrained_model_path,
            with_rel_head=config.NETWORK.WITH_REL_LOSS,
            with_mlm_head=config.NETWORK.WITH_MLM_LOSS,
            with_mvrc_head=config.NETWORK.WITH_MVRC_LOSS,
        )

        # init weights
        self.init_weight()

        self.fix_params()

    def init_weight(self):
        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED or (
                not self.config.NETWORK.MASK_RAW_PIXELS):
            self.object_mask_visual_embedding.weight.data.fill_(0.0)
        if self.config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        self.aux_text_visual_embedding.weight.data.normal_(
            mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        self.image_feature_extractor.init_weight()
        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)

    def train(self, mode=True):
        super(ResNetVLBERTDistanceTranslationWithVision, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        pass

    def _collect_obj_reps(self, span_tags, object_reps):
        """
        Collect span-level object representations
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return:
        """

        span_tags_fixed = torch.clamp(
            span_tags, min=0)  # In case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0,
                                          row_id.shape[0],
                                          step=1,
                                          device=row_id.device)[:, None]

        # Add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        return object_reps[row_id.view(-1),
                           span_tags_fixed.view(-1)].view(
                               *span_tags_fixed.shape, -1)

    def forward(self, image, boxes, im_info, text, relationship_label,
                mlm_labels, mvrc_ops, mvrc_labels):

        ###########################################

        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        origin_len = boxes.shape[1]
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        mvrc_ops = mvrc_ops[:, :max_len]
        mvrc_labels = mvrc_labels[:, :max_len]

        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            box_features = boxes[:, :, 4:]
            box_features[mvrc_ops ==
                         1] = self.object_mask_visual_embedding.weight[0]
            boxes[:, :, 4:] = box_features
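            # with precomputed features, masked regions (mvrc_ops == 1) are overwritten
            # in place with the learned mask embedding before feature extraction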

        obj_reps = self.image_feature_extractor(
            images=images,
            boxes=boxes,
            box_mask=box_mask,
            im_info=im_info,
            classes=None,
            segms=None,
            mvrc_ops=mvrc_ops,
            mask_visual_embed=self.object_mask_visual_embedding.weight[0] if
            (not self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED) and
            (not self.config.NETWORK.MASK_RAW_PIXELS) else None)

        ############################################

        # prepare text
        text_input_ids = text
        # creates a text_tags tensor of the same shape as text tensor
        text_tags = text.new_zeros(text.shape)
        text_visual_embeddings = self._collect_obj_reps(
            text_tags, obj_reps['obj_reps'])

        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        if self.config.NETWORK.WITH_MVRC_LOSS:
            object_linguistic_embeddings[
                mvrc_ops == 1] = self.object_mask_word_embedding.weight[0]
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        # FM edit: no auxiliary text is used in the text-only setting.
        # (Adding auxiliary text would concatenate the batches from the two dataloaders;
        # the visual features for the text-only corpus are just the single aux_visual_embedding.)
        max_text_len = text_input_ids.shape[1]
        text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
        text_mask = (text_input_ids > 0)
        # FM edit: removed; not needed since box_mask is already defined above
        # box_mask = box_mask.new_zeros((text_input_ids.shape[0], *box_mask.shape[1:]))

        ###########################################

        # Visual Linguistic BERT

        relationship_logits_multi, mlm_logits_multi, mvrc_logits_multi, pooled_rep, text_out = self.vlbert(
            text_input_ids, text_token_type_ids, text_visual_embeddings,
            text_mask, object_vl_embeddings, box_mask)

        ###########################################
        outputs = {}

        # FM edit: removed other two losses that are not defined
        outputs.update({'cls_output': text_out[:, 0, :]})

        # FM edit: removed addition of other losses which are not defined
        loss = 0

        return outputs, loss
Example #11
class ResNetVLBERTv4(Module):
    def __init__(self, config):

        super(ResNetVLBERTv4, self).__init__(config)

        self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS
        if not config.NETWORK.BLIND:
            self.image_feature_extractor = FastRCNN(
                config,
                average_pool=True,
                final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                enable_cnn_reg_loss=self.enable_cnn_reg_loss)
            if config.NETWORK.VLBERT.object_word_embed_mode == 1:
                self.object_linguistic_embeddings = nn.Embedding(
                    601, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 2:
                self.object_linguistic_embeddings = nn.Embedding(
                    1, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN

        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path)

        # self.hm_out = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size)
        # self.hi_out = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size)

        dim = config.NETWORK.VLBERT.hidden_size
        if config.NETWORK.CLASSIFIER_TYPE == "2fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(dim, config.NETWORK.CLASSIFIER_HIDDEN_SIZE),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(config.NETWORK.CLASSIFIER_HIDDEN_SIZE,
                                config.NETWORK.CLASSIFIER_CLASS),
            )
        elif config.NETWORK.CLASSIFIER_TYPE == "1fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(dim, config.NETWORK.CLASSIFIER_CLASS))
        elif config.NETWORK.CLASSIFIER_TYPE == 'mlm':
            transform = BertPredictionHeadTransform(config.NETWORK.VLBERT)
            linear = nn.Linear(config.NETWORK.VLBERT.hidden_size,
                               config.NETWORK.CLASSIFIER_CLASS)
            self.final_mlp = nn.Sequential(
                transform,
                nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                linear)
        else:
            raise ValueError("Unsupported classifier type: {}!".format(
                config.NETWORK.CLASSIFIER_TYPE))

        # init weights
        self.init_weight()

        self.fix_params()

    def init_weight(self):
        # self.hm_out.weight.data.normal_(mean=0.0, std=0.02)
        # self.hm_out.bias.data.zero_()
        # self.hi_out.weight.data.normal_(mean=0.0, std=0.02)
        # self.hi_out.bias.data.zero_()
        self.image_feature_extractor.init_weight()
        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(mean=0.0,
                                                                  std=0.02)
        for m in self.final_mlp.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                torch.nn.init.constant_(m.bias, 0)
        if self.config.NETWORK.CLASSIFIER_TYPE == 'mlm':
            language_pretrained = torch.load(
                self.language_pretrained_model_path)
            mlm_transform_state_dict = {}
            pretrain_keys = []
            for k, v in language_pretrained.items():
                if k.startswith('cls.predictions.transform.'):
                    pretrain_keys.append(k)
                    k_ = k[len('cls.predictions.transform.'):]
                    if 'gamma' in k_:
                        k_ = k_.replace('gamma', 'weight')
                    if 'beta' in k_:
                        k_ = k_.replace('beta', 'bias')
                    mlm_transform_state_dict[k_] = v
            print("loading pretrained classifier transform keys: {}.".format(
                pretrain_keys))
            self.final_mlp[0].load_state_dict(mlm_transform_state_dict)

    def train(self, mode=True):
        super(ResNetVLBERTv4, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        pass

    def _collect_obj_reps(self, span_tags, object_reps):
        """
        Collect span-level object representations
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return:
        """

        span_tags_fixed = torch.clamp(
            span_tags, min=0)  # In case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0,
                                          row_id.shape[0],
                                          step=1,
                                          device=row_id.device)[:, None]

        # Add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        object_select = object_reps[row_id.view(-1), span_tags_fixed.view(-1)]
        return object_select.view(*span_tags_fixed.shape, -1)

    def prepare_text(self, question, question_tags, question_mask):
        batch_size, max_q_len = question.shape
        max_len = question_mask.sum(1).max() + 2
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(
            ['[CLS]', '[SEP]'])
        q_end = 1 + question_mask.sum(1, keepdim=True)
        input_ids = torch.zeros((batch_size, max_len),
                                dtype=question.dtype,
                                device=question.device)
        input_mask = torch.ones((batch_size, max_len),
                                dtype=torch.bool,
                                device=question.device)

        input_type_ids = torch.zeros((batch_size, max_len),
                                     dtype=question.dtype,
                                     device=question.device)
        text_tags = input_type_ids.new_zeros((batch_size, max_len))
        grid_i, grid_j = torch.meshgrid(
            torch.arange(batch_size, device=question.device),
            torch.arange(max_len, device=question.device))

        input_mask[grid_j > q_end] = 0
        # input_type_ids[(grid_j > q_end) & (grid_j <= a_end)] = 1
        q_input_mask = (grid_j > 0) & (grid_j < q_end)
        sep_idx = (question == sep_id).nonzero()
        for index in sep_idx:
            input_type_ids[index[0], index[1] + 1:] = \
                self.config.NETWORK.VLBERT.visual_tag_type

        input_ids[:, 0] = cls_id
        input_ids[grid_j == q_end] = sep_id

        input_ids[q_input_mask] = question[question_mask]
        text_tags[q_input_mask] = question_tags[question_mask]

        return input_ids, input_type_ids, text_tags, input_mask
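        # Note (a reading of the indexing above, not part of the original source): the
        # sequence is [CLS] question tokens [SEP]; for every [SEP] that occurs inside
        # `question`, that position and everything after it is assigned token type
        # VLBERT.visual_tag_type, while text before the first in-question [SEP] keeps type 0.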

    def train_forward(
        self,
        image,
        boxes,
        im_info,
        text,
        img_boxes,
        text_tags,
        label: torch.Tensor,
        *sample_id_and_more,
        loss_fn=F.binary_cross_entropy_with_logits,
    ):
        ###########################################

        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)  # NOTE: clip_pad_boxes(pad=-2)
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        objects = boxes[:, :max_len, 4]
        boxes = boxes[:, :max_len, :4]

        obj_and_img_boxes = torch.cat([img_boxes, boxes], axis=1)
        _box_mask = (obj_and_img_boxes[:, :, 0] > -1.5)  # NOTE: clip_pad_boxes(pad=-2)

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=obj_and_img_boxes,
                                                box_mask=_box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None)
        img_block_reps = obj_reps["obj_reps"][:, :img_boxes.shape[1], :]
        obj_reps["obj_reps"] = obj_reps["obj_reps"][:, img_boxes.shape[1]:, :]

        if self.config.NETWORK.IMAGE_FROZEN_BACKBONE_ALL:
            obj_reps = {k: v.detach() for k, v in obj_reps.items()}

        text_ids = text
        # text_tags = text.new_zeros(text_ids.shape)
        text_mask = (text > 0.5)

        ############################################

        # prepare text
        text_input_ids, text_token_type_ids, text_tags, text_mask = self.prepare_text(
            text_ids, text_tags, text_mask)
        if self.config.NETWORK.NO_GROUNDING:
            obj_rep_zeroed = obj_reps['obj_reps'].new_zeros(
                obj_reps['obj_reps'].shape)
            text_tags.zero_()
            text_visual_embeddings = self._collect_obj_reps(
                text_tags, obj_rep_zeroed)
        else:
            text_visual_embeddings = self._collect_obj_reps(
                text_tags,
                torch.cat([img_block_reps, obj_reps['obj_reps']], dim=1))

        assert self.config.NETWORK.VLBERT.object_word_embed_mode in [1, 2]
        if self.config.NETWORK.VLBERT.object_word_embed_mode == 1:
            object_linguistic_embeddings = self.object_linguistic_embeddings(
                objects.long().clamp(
                    min=0,
                    max=self.object_linguistic_embeddings.weight.data.shape[0]
                    - 1))
        else:
            object_linguistic_embeddings = self.object_linguistic_embeddings(
                boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################

        # Visual Linguistic BERT

        hidden_states, hc = self.vlbert(text_input_ids,
                                        text_token_type_ids,
                                        text_visual_embeddings,
                                        text_mask,
                                        object_vl_embeddings,
                                        box_mask,
                                        output_all_encoded_layers=False)
        _batch_inds = torch.arange(text.shape[0], device=text.device)

        hm = hidden_states[_batch_inds, 0]
        # hm = F.tanh(self.hm_out(hidden_states[_batch_inds, ans_pos]))
        # hi = F.tanh(self.hi_out(hidden_states[_batch_inds, ans_pos + 2]))

        ###########################################
        outputs = {}

        # classifier
        # logits = self.final_mlp(hc * hm * hi)
        # logits = self.final_mlp(hc)
        logits = self.final_mlp(hm)
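        # Label shaping (as implied by the branches below, an interpretive note): without
        # CLASSIFIER_SIGMOID the labels become 1-D long class indices, presumably for a
        # cross-entropy style loss_fn; with CLASSIFIER_SIGMOID they are kept as float
        # targets of shape [batch, 1] for the default binary_cross_entropy_with_logits.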
        if not self.config.NETWORK.CLASSIFIER_SIGMOID:
            if label.ndim == 2:
                label = label.squeeze(1)
            label = label.long()
        else:
            if label.ndim == 1:
                label = label.unsqueeze(1)
        # loss
        ans_loss = loss_fn(logits, label)

        outputs.update({
            'label_logits': logits,
            'label': label,
            'ans_loss': ans_loss
        })

        loss = ans_loss.mean()

        return outputs, loss

    def _inference_forward(self, image, boxes, im_info, text, *args):

        ###########################################

        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        objects = boxes[:, :max_len, 4]
        boxes = boxes[:, :max_len, :4]

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None)

        text_ids = text
        text_tags = text.new_zeros(text_ids.shape)
        text_mask = (text > 0.5)

        ############################################

        # prepare text
        text_input_ids, text_token_type_ids, text_tags, text_mask = self.prepare_text(
            text_ids, text_tags, text_mask)

        if self.config.NETWORK.NO_GROUNDING:
            obj_rep_zeroed = obj_reps['obj_reps'].new_zeros(
                obj_reps['obj_reps'].shape)
            text_tags.zero_()
            text_visual_embeddings = self._collect_obj_reps(
                text_tags, obj_rep_zeroed)
        else:
            text_visual_embeddings = self._collect_obj_reps(
                text_tags, obj_reps['obj_reps'])

        assert self.config.NETWORK.VLBERT.object_word_embed_mode in [1, 2]
        if self.config.NETWORK.VLBERT.object_word_embed_mode == 1:
            object_linguistic_embeddings = self.object_linguistic_embeddings(
                objects.long().clamp(
                    min=0,
                    max=self.object_linguistic_embeddings.weight.data.shape[0]
                    - 1))
        else:
            object_linguistic_embeddings = self.object_linguistic_embeddings(
                boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())

        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################

        # Visual Linguistic BERT

        hidden_states, hc = self.vlbert(text_input_ids,
                                        text_token_type_ids,
                                        text_visual_embeddings,
                                        text_mask,
                                        object_vl_embeddings,
                                        box_mask,
                                        output_all_encoded_layers=False)
        _batch_inds = torch.arange(text.shape[0], device=text.device)

        hm = hidden_states[_batch_inds, 0]
        # hm = F.tanh(self.hm_out(hidden_states[_batch_inds, ans_pos]))
        # hi = F.tanh(self.hi_out(hidden_states[_batch_inds, ans_pos + 2]))

        ###########################################
        outputs = {}

        # classifier
        # logits = self.final_mlp(hc * hm * hi)
        # logits = self.final_mlp(hc)
        logits = self.final_mlp(hm)

        outputs.update({'label_logits': logits})

        return outputs

    def inference_forward(self, image, boxes, im_info, text, img_boxes,
                          text_tags, *args):

        ###########################################

        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        objects = boxes[:, :max_len, 4]
        boxes = boxes[:, :max_len, :4]

        obj_and_img_boxes = torch.cat([img_boxes, boxes], axis=1)
        _box_mask = (obj_and_img_boxes[:, :, 0] > -1.5)  # NOTE: clip_pad_boxes(pad=-2)

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=obj_and_img_boxes,
                                                box_mask=_box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None)
        img_block_reps = obj_reps["obj_reps"][:, :img_boxes.shape[1], :]
        obj_reps["obj_reps"] = obj_reps["obj_reps"][:, img_boxes.shape[1]:, :]

        text_ids = text
        # text_tags = text.new_zeros(text_ids.shape)
        text_mask = (text > 0.5)

        ############################################

        # prepare text
        text_input_ids, text_token_type_ids, text_tags, text_mask = self.prepare_text(
            text_ids, text_tags, text_mask)

        if self.config.NETWORK.NO_GROUNDING:
            obj_rep_zeroed = obj_reps['obj_reps'].new_zeros(
                obj_reps['obj_reps'].shape)
            text_tags.zero_()
            text_visual_embeddings = self._collect_obj_reps(
                text_tags, obj_rep_zeroed)
        else:
            text_visual_embeddings = self._collect_obj_reps(
                text_tags,
                torch.cat([img_block_reps, obj_reps['obj_reps']], dim=1))

        assert self.config.NETWORK.VLBERT.object_word_embed_mode in [1, 2]
        if self.config.NETWORK.VLBERT.object_word_embed_mode == 1:
            object_linguistic_embeddings = self.object_linguistic_embeddings(
                objects.long().clamp(
                    min=0,
                    max=self.object_linguistic_embeddings.weight.data.shape[0]
                    - 1))
        else:
            object_linguistic_embeddings = self.object_linguistic_embeddings(
                boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())

        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################

        # Visual Linguistic BERT

        hidden_states, hc = self.vlbert(text_input_ids,
                                        text_token_type_ids,
                                        text_visual_embeddings,
                                        text_mask,
                                        object_vl_embeddings,
                                        box_mask,
                                        output_all_encoded_layers=False)
        _batch_inds = torch.arange(text.shape[0], device=text.device)

        hm = hidden_states[_batch_inds, 0]
        # hm = F.tanh(self.hm_out(hidden_states[_batch_inds, ans_pos]))
        # hi = F.tanh(self.hi_out(hidden_states[_batch_inds, ans_pos + 2]))

        ###########################################
        outputs = {}

        # classifier
        # logits = self.final_mlp(hc * hm * hi)
        # logits = self.final_mlp(hc)
        logits = self.final_mlp(hm)

        outputs.update({'label_logits': logits})

        return outputs
Example #12
    def __init__(self, config):

        super(ResNetVLBERTv5, self).__init__(config)

        self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS
        if not config.NETWORK.BLIND:
            self.image_feature_extractor = FastRCNN(
                config,
                average_pool=True,
                final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                enable_cnn_reg_loss=self.enable_cnn_reg_loss)
            if config.NETWORK.VLBERT.object_word_embed_mode == 1:
                self.object_linguistic_embeddings = nn.Embedding(
                    601, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 2:
                self.object_linguistic_embeddings = nn.Embedding(
                    1, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN

        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path)

        # self.hm_out = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size)
        # self.hi_out = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size)

        self.hidden_dropout = nn.Dropout(0.2)
        if config.NETWORK.VLBERT.num_hidden_layers == 24:
            self.gating = nn.Parameter(torch.tensor([
                0.0067, 0.0070, 0.0075, 0.0075, 0.0075, 0.0074, 0.0076, 0.0075,
                0.0076, 0.0080, 0.0079, 0.0086, 0.0096, 0.0101, 0.0104, 0.0105,
                0.0111, 0.0120, 0.0126, 0.0115, 0.0108, 0.0105, 0.0104, 0.0117
            ]),
                                       requires_grad=True)
        else:
            self.gating = nn.Parameter(
                torch.ones(config.NETWORK.VLBERT.num_hidden_layers, ) * 1e-2,
                requires_grad=True)
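        # NOTE (assumption): `gating` appears to hold one mixing weight per encoder layer,
        # presumably used elsewhere to combine the per-layer hidden states; the hard-coded
        # 24-entry initialisation matches num_hidden_layers == 24.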
        self.train_steps = 0

        dim = config.NETWORK.VLBERT.hidden_size
        if config.NETWORK.CLASSIFIER_TYPE == "2fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(dim, config.NETWORK.CLASSIFIER_HIDDEN_SIZE),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(config.NETWORK.CLASSIFIER_HIDDEN_SIZE,
                                config.NETWORK.CLASSIFIER_CLASS),
            )
        elif config.NETWORK.CLASSIFIER_TYPE == "1fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(dim, config.NETWORK.CLASSIFIER_CLASS))
        elif config.NETWORK.CLASSIFIER_TYPE == 'mlm':
            transform = BertPredictionHeadTransform(config.NETWORK.VLBERT)
            linear = nn.Linear(config.NETWORK.VLBERT.hidden_size,
                               config.NETWORK.CLASSIFIER_CLASS)
            self.final_mlp = nn.Sequential(
                transform,
                nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                linear)
        else:
            raise ValueError("Unsupported classifier type: {}!".format(
                config.NETWORK.CLASSIFIER_TYPE))

        # init weights
        self.init_weight()

        self.fix_params()
Example #13
class ResNetVLBERT(Module):
    def __init__(self, config):

        super(ResNetVLBERT, self).__init__(config)

        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=config.NETWORK.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path)

        transform = VisualLinguisticBertMVRCHeadTransform(
            config.NETWORK.VLBERT)
        # self.linear = nn.Linear(config.NETWORK.VLBERT.hidden_size, 768) #331 1000 35 100 12003 lihui
        # self.OIM_loss = OIM_Module(331, 768)  # config.NETWORK.VLBERT.hidden_size)
        self.OIM_loss = OIM_Module(12003, 768)
        self.linear = nn.Sequential(
            # transform,
            nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
            nn.Linear(config.NETWORK.VLBERT.hidden_size,
                      768)  #331 1000 35 100 12003 lihui
        )

        linear = nn.Linear(config.NETWORK.VLBERT.hidden_size, 1)
        self.final_mlp = nn.Sequential(
            transform,
            nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
            linear)

        # init weights
        self.init_weight()

        self.fix_params()

        # self.embeddings_word = torch.nn.Conv1d(in_channels=40, out_channels=1, kernel_size=1)
        # self.embeddings_box = torch.nn.Conv1d(in_channels=6, out_channels=1, kernel_size=1)
        # self.line_cls = nn.utils.weight_norm(nn.Linear(config.NETWORK.VLBERT.hidden_size, 1000), name='weight') #12003

    def init_weight(self):
        self.image_feature_extractor.init_weight()
        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        for m in self.final_mlp.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                torch.nn.init.constant_(m.bias, 0)

    def train(self, mode=True):
        super(ResNetVLBERT, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        pass

    def train_forward(
        self,
        image,
        boxes,
        im_info,
        expression,
        label,
    ):
        ###########################################

        # visual feature extraction
        batch_size = image.size(0)
        num_options = image.size(1)
        image = image.view(-1, image.size(2), image.size(3), image.size(4))
        boxes = boxes.view(-1, boxes.size(2), boxes.size(3))
        #boxes = boxes
        im_info = im_info.view(-1, im_info.size(2))
        expression = expression.view(-1, expression.size(2))

        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        # max_len = int(box_mask.sum(1).max().item())
        # origin_len = boxes.shape[1]
        # box_mask = box_mask[:, :max_len]
        # boxes = boxes[:, :max_len]
        # label = label[:, :max_len]

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None)

        ############################################
        # prepare text
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(
            ['[CLS]', '[SEP]'])
        text_input_ids = expression.new_zeros(
            (expression.shape[0], expression.shape[1] + 2))
        text_input_ids[:, 0] = cls_id
        text_input_ids[:, 1:-1] = expression
        _sep_pos = (text_input_ids > 0).sum(1)
        _batch_inds = torch.arange(expression.shape[0],
                                   device=expression.device)
        text_input_ids[_batch_inds, _sep_pos] = sep_id
        text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
        text_mask = text_input_ids > 0
        text_visual_embeddings = obj_reps['obj_reps'][:, 0].unsqueeze(
            1).repeat((1, text_input_ids.shape[1], 1))
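        # the visual embedding for every text token is the feature of box 0 (presumably
        # the whole-image box), repeated across the text length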

        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################

        # Visual Linguistic BERT

        _, pooled_output = self.vlbert(text_input_ids,
                                       text_token_type_ids,
                                       text_visual_embeddings,
                                       text_mask,
                                       object_vl_embeddings,
                                       box_mask,
                                       output_all_encoded_layers=False,
                                       output_text_and_object_separately=False)

        ###########################################
        outputs = {}

        # classifier
        #logits = self.final_mlp(pooled_output)
        ''' 
        logits = self.linear(pooled_output)
        # vil_logit = logits.view(batch_size, num_options)
        score_OIM = self.OIM_loss(logits, label.view(-1))
        loss_c = nn.CrossEntropyLoss(ignore_index=-1)
        cmpc_loss = loss_c(F.softmax(score_OIM, dim=1)*10, label.view(-1))# + criterion(text_logits, label_text)
        # cmpc_loss = loss_c(logits, label.view(-1))
        cls_pred = torch.argmax(score_OIM, dim=1)
        cls_precision = torch.mean((cls_pred[label.view(-1) != -1] == label.view(-1)[label.view(-1) != -1]).float())
        return cls_precision, cmpc_loss
        '''

        # loss
        logits = self.final_mlp(pooled_output)
        vil_logit = logits.view(batch_size, num_options)
        loss = nn.CrossEntropyLoss(ignore_index=-1)
        cls_loss = loss(vil_logit, torch.zeros(batch_size).long().cuda())
        _, preds = torch.max(vil_logit, 1)
        batch_score = float((preds == torch.zeros(batch_size).long().cuda()
                             ).sum()) / float(batch_size)

        return batch_score, cls_loss
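    # Note on train_forward above: each batch entry carries `num_options` candidate
    # pairs, with the ground-truth pair assumed to sit at index 0; the cross-entropy
    # over `vil_logit` against an all-zeros target and `batch_score` (the fraction of
    # rows whose argmax is 0) both rely on that ordering convention.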

    def inference_forward(self,
                          image,
                          boxes,
                          im_info,
                          expression,
                          label,
                          feat=None):

        ###########################################

        # visual feature extraction
        batch_size = boxes.size(0)
        num_options = boxes.size(1)

        if feat is None:
            image = image.view(-1, image.size(2), image.size(3), image.size(4))
            boxes = boxes.view(-1, boxes.size(2), boxes.size(3))
            im_info = im_info.view(-1, im_info.size(2))

            images = image
            box_mask = (boxes[:, :, 0] > -1.5)
            # max_len = int(box_mask.sum(1).max().item())
            # origin_len = boxes.shape[1]
            # box_mask = box_mask[:, :max_len]
            # boxes = boxes[:, :max_len]

            obj_reps = self.image_feature_extractor(images=images,
                                                    boxes=boxes,
                                                    box_mask=box_mask,
                                                    im_info=im_info,
                                                    classes=None,
                                                    segms=None)
        else:
            boxes = boxes.view(-1, boxes.size(2), boxes.size(3))
            box_mask = (boxes[:, :, 0] > -1.5)
            # obj_reps = feat
        ############################################
        # prepare text
        expression = expression.view(-1, expression.size(2))
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(
            ['[CLS]', '[SEP]'])
        text_input_ids = expression.new_zeros(
            (expression.shape[0], expression.shape[1] + 2))
        text_input_ids[:, 0] = cls_id
        text_input_ids[:, 1:-1] = expression
        _sep_pos = (text_input_ids > 0).sum(1)
        _batch_inds = torch.arange(expression.shape[0],
                                   device=expression.device)
        text_input_ids[_batch_inds, _sep_pos] = sep_id
        text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
        text_mask = text_input_ids > 0
        if feat is None:
            text_visual_embeddings = obj_reps['obj_reps'][:, 0].unsqueeze(
                1).repeat((1, text_input_ids.shape[1], 1))
            #text_visual_embeddings = feat[:, 0].unsqueeze(1).repeat((1, text_input_ids.shape[1], 1))

            object_linguistic_embeddings = self.object_linguistic_embeddings(
                boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
            object_vl_embeddings = torch.cat(
                (obj_reps['obj_reps'], object_linguistic_embeddings), -1)
            #object_vl_embeddings = torch.cat((feat, object_linguistic_embeddings), -1)
        else:
            # text_visual_embeddings = obj_reps['obj_reps'][:, 0].unsqueeze(1).repeat((1, text_input_ids.shape[1], 1))
            text_visual_embeddings = feat[:, 0].unsqueeze(1).repeat(
                (1, text_input_ids.shape[1], 1))

            object_linguistic_embeddings = self.object_linguistic_embeddings(
                boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
            # object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)
            object_vl_embeddings = torch.cat(
                (feat, object_linguistic_embeddings), -1)

        ###########################################

        # Visual Linguistic BERT

        encoded_layers, pooled_output, att = self.vlbert(
            text_input_ids,
            text_token_type_ids,
            text_visual_embeddings,
            text_mask,
            object_vl_embeddings,
            box_mask,
            output_all_encoded_layers=False,
            output_text_and_object_separately=False,
            output_attention_probs=True)

        ###########################################
        outputs = {}

        # classifier
        logits = self.final_mlp(pooled_output)  #.squeeze(-1)

        # loss
        vil_logit = logits.view(batch_size, num_options)
        _, preds = torch.max(vil_logit, 1)

        return att, logits

    def compute_cmpc_loss(self, image_embeddings, text_embeddings, labels):
        """
        Cross-Modal Projection Classification loss (CMPC)
        :param image_embeddings: Tensor with dtype torch.float32
        :param text_embeddings: Tensor with dtype torch.float32
        :param labels: Tensor with dtype torch.int32
        :return:
        """
        criterion = nn.CrossEntropyLoss()
        # labels_onehot = one_hot_coding(labels, self.num_classes).float()
        # image_norm = image_embeddings / image_embeddings.norm(dim=1, keepdim=True)
        # text_norm = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)

        # image_proj_text = torch.sum(image_embeddings * text_norm, dim=1, keepdim=True) * text_norm
        # text_proj_image = torch.sum(text_embeddings * image_norm, dim=1, keepdim=True) * image_norm

        image_logits = image_embeddings  #self.line_cls(image_embeddings)
        text_logits = text_embeddings  #self.line_cls(text_embeddings)

        label_img = labels[:, 1, :].contiguous().view(-1)
        label_text = labels[:, 0, :].contiguous().view(-1)

        cmpc_loss = criterion(
            image_logits, label_img)  # + criterion(text_logits, label_text)
        # cmpc_loss = - (F.log_softmax(image_logits, dim=1) + F.log_softmax(text_logits, dim=1)) * labels_onehot
        # cmpc_loss = torch.mean(torch.sum(cmpc_loss, dim=1))
        # classification accuracy for observation
        image_pred = torch.argmax(image_logits, dim=1)
        text_pred = torch.argmax(text_logits, dim=1)

        image_precision = torch.mean((image_pred == label_img).float())
        text_precision = torch.mean((text_pred == label_text).float())

        return cmpc_loss, image_precision, text_precision
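    # Minimal usage sketch for compute_cmpc_loss (shapes and the `model` handle are
    # illustrative assumptions, not taken from a dataloader):
    #
    #   img_emb = torch.randn(4, 1000)                 # treated directly as class logits
    #   txt_emb = torch.randn(4, 1000)
    #   labels  = torch.randint(0, 1000, (4, 2, 1))    # [:, 0] text ids, [:, 1] image ids
    #   loss, img_prec, txt_prec = model.compute_cmpc_loss(img_emb, txt_emb, labels)
    #
    # As written, the embeddings themselves serve as logits (the projection / line_cls
    # layer is commented out above) and only the image branch contributes to the loss.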
Example No. 14
class ResNetVLBERT(Module):
    def __init__(self, config):

        super(ResNetVLBERT, self).__init__(config)

        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=config.NETWORK.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path)

        transform = VisualLinguisticBertMVRCHeadTransform(
            config.NETWORK.VLBERT)
        linear = nn.Linear(config.NETWORK.VLBERT.hidden_size, 1)
        self.final_mlp = nn.Sequential(
            transform,
            nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
            linear)

        # init weights
        self.init_weight()

        self.fix_params()

    def init_weight(self):
        self.image_feature_extractor.init_weight()
        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        for m in self.final_mlp.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                torch.nn.init.constant_(m.bias, 0)

    def train(self, mode=True):
        super(ResNetVLBERT, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        pass

    def train_forward(
        self,
        image,
        boxes,
        im_info,
        expression,
        label,
    ):
        ###########################################

        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        max_len = int(box_mask.sum(1).max().item())
        origin_len = boxes.shape[1]
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        label = label[:, :max_len]

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None)

        ############################################
        # prepare text
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(
            ['[CLS]', '[SEP]'])
        text_input_ids = expression.new_zeros(
            (expression.shape[0], expression.shape[1] + 2))
        text_input_ids[:, 0] = cls_id
        text_input_ids[:, 1:-1] = expression
        _sep_pos = (text_input_ids > 0).sum(1)
        _batch_inds = torch.arange(expression.shape[0],
                                   device=expression.device)
        text_input_ids[_batch_inds, _sep_pos] = sep_id
        text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
        text_mask = text_input_ids > 0
        text_visual_embeddings = obj_reps['obj_reps'][:, 0].unsqueeze(
            1).repeat((1, text_input_ids.shape[1], 1))

        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################

        # Visual Linguistic BERT

        hidden_states_text, hidden_states_regions, _ = self.vlbert(
            text_input_ids,
            text_token_type_ids,
            text_visual_embeddings,
            text_mask,
            object_vl_embeddings,
            box_mask,
            output_all_encoded_layers=False,
            output_text_and_object_separately=True)

        ###########################################
        outputs = {}

        # classifier
        logits = self.final_mlp(hidden_states_regions).squeeze(-1)

        # loss
        cls_loss = F.binary_cross_entropy_with_logits(logits[box_mask],
                                                      label[box_mask])
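        # This head scores every region independently: `logits` is [batch, num_boxes]
        # and `label` is expected to hold a per-box 0/1 target (e.g. whether the box
        # matches the referring expression), so the BCE runs only over the valid boxes
        # selected by `box_mask`.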

        # pad back to origin len for compatibility with DataParallel
        logits_ = logits.new_zeros(
            (logits.shape[0], origin_len)).fill_(-10000.0)
        logits_[:, :logits.shape[1]] = logits
        logits = logits_
        label_ = label.new_zeros((logits.shape[0], origin_len)).fill_(-1)
        label_[:, :label.shape[1]] = label
        label = label_

        outputs.update({
            'label_logits': logits,
            'label': label,
            'cls_loss': cls_loss
        })

        loss = cls_loss.mean()

        return outputs, loss

    def inference_forward(self, image, boxes, im_info, expression):

        ###########################################

        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        max_len = int(box_mask.sum(1).max().item())
        origin_len = boxes.shape[1]
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None)

        ############################################
        # prepare text
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(
            ['[CLS]', '[SEP]'])
        text_input_ids = expression.new_zeros(
            (expression.shape[0], expression.shape[1] + 2))
        text_input_ids[:, 0] = cls_id
        text_input_ids[:, 1:-1] = expression
        _sep_pos = (text_input_ids > 0).sum(1)
        _batch_inds = torch.arange(expression.shape[0],
                                   device=expression.device)
        text_input_ids[_batch_inds, _sep_pos] = sep_id
        text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
        text_mask = text_input_ids > 0
        text_visual_embeddings = obj_reps['obj_reps'][:, 0].unsqueeze(
            1).repeat((1, text_input_ids.shape[1], 1))

        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################

        # Visual Linguistic BERT

        hidden_states_text, hidden_states_regions, _ = self.vlbert(
            text_input_ids,
            text_token_type_ids,
            text_visual_embeddings,
            text_mask,
            object_vl_embeddings,
            box_mask,
            output_all_encoded_layers=False,
            output_text_and_object_separately=True)

        ###########################################
        outputs = {}

        # classifier
        logits = self.final_mlp(hidden_states_regions).squeeze(-1)

        # pad back to origin len for compatibility with DataParallel
        logits_ = logits.new_zeros(
            (logits.shape[0], origin_len)).fill_(-10000.0)
        logits_[:, :logits.shape[1]] = logits
        logits = logits_

        w_ratio = im_info[:, 2]
        h_ratio = im_info[:, 3]
        pred_boxes = boxes[_batch_inds, logits.argmax(1), :4]
        pred_boxes[:, [0, 2]] /= w_ratio.unsqueeze(1)
        pred_boxes[:, [1, 3]] /= h_ratio.unsqueeze(1)
        outputs.update({'label_logits': logits, 'pred_boxes': pred_boxes})

        return outputs
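
# Note on inference_forward above: the predicted box is the argmax region over the
# per-box logits, and im_info[:, 2] / im_info[:, 3] are assumed to hold the width/height
# resize ratios applied during preprocessing, so dividing by them maps the box back to
# original-image coordinates. A rough calling sketch (shapes are illustrative assumptions):
#
#   out = model.inference_forward(image, boxes, im_info, expression)
#   out['pred_boxes']    # [batch, 4], top-scoring box in original-image pixels
#   out['label_logits']  # [batch, num_boxes], padded back to the original box count
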
class ResNetVLBERTForPretrainingGenerate(Module):
    def __init__(self, config):

        super(ResNetVLBERTForPretrainingGenerate, self).__init__(config)

        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=config.NETWORK.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        if config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            self.object_mask_visual_embedding = nn.Embedding(1, 2048)
        if config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding = nn.Embedding(
                1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)
        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path

        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBertForPretraining(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=None
            if config.NETWORK.VLBERT.from_scratch else
            language_pretrained_model_path,
            with_rel_head=config.NETWORK.WITH_REL_LOSS,
            with_mlm_head=config.NETWORK.WITH_MLM_LOSS,
            with_mvrc_head=config.NETWORK.WITH_MVRC_LOSS,
        )

        # init weights
        self.init_weight()

        self.fix_params()

    def init_weight(self):
        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            self.object_mask_visual_embedding.weight.data.fill_(0.0)
        if self.config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        self.image_feature_extractor.init_weight()
        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)

    def train(self, mode=True):
        super(ResNetVLBERTForPretrainingGenerate, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        pass

    def _collect_obj_reps(self, span_tags, object_reps):
        """
        Collect span-level object representations
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return:
        """

        span_tags_fixed = torch.clamp(
            span_tags, min=0)  # In case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0,
                                          row_id.shape[0],
                                          step=1,
                                          device=row_id.device)[:, None]

        # Add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        return object_reps[row_id.view(-1),
                           span_tags_fixed.view(-1)].view(
                               *span_tags_fixed.shape, -1)
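    # Small worked example of _collect_obj_reps (shapes chosen purely for illustration):
    # with span_tags of shape [2, 5] and object_reps of shape [2, 10, 768], each text
    # position i in batch b receives object_reps[b, span_tags[b, i]], i.e. the feature of
    # the region that token is tagged with (masked / negative tags are clamped to region 0,
    # which in this codebase is typically the whole-image box), giving a [2, 5, 768] tensor.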

    def forward(self, image, boxes, im_info, text, relationship_label,
                mlm_labels, mvrc_ops, mvrc_labels):
        ###########################################

        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        origin_len = boxes.shape[1]
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        mvrc_ops = mvrc_ops[:, :max_len]
        mvrc_labels = mvrc_labels[:, :max_len]

        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            box_features = boxes[:, :, 4:]
            box_features[mvrc_ops ==
                         1] = self.object_mask_visual_embedding.weight[0]
            boxes[:, :, 4:] = box_features

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None,
                                                mvrc_ops=mvrc_ops,
                                                mask_visual_embed=None)

        ############################################

        # prepare text
        text_input_ids = text
        text_tags = text.new_zeros(text.shape)
        text_token_type_ids = text.new_zeros(text.shape)
        text_mask = (text_input_ids > 0)
        text_visual_embeddings = self._collect_obj_reps(
            text_tags, obj_reps['obj_reps'])

        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        if self.config.NETWORK.WITH_MVRC_LOSS:
            object_linguistic_embeddings[
                mvrc_ops == 1] = self.object_mask_word_embedding.weight[0]
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################
        # Visual Linguistic BERT
        # Generation loop (test mode): iteratively fill the [MASK] position and append the predicted token
        generated = []
        stop = [False] * text.shape[0]
        curr_len = 0
        max_len = 48  # maximum number of generation steps (shadows the box-sequence max_len above)
        while not all(stop) and curr_len <= max_len:
            relationship_logits, mlm_logits, mvrc_logits = self.vlbert(
                text_input_ids, text_token_type_ids, text_visual_embeddings,
                text_mask, object_vl_embeddings, box_mask)
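            # Token id 103 is BERT's [MASK] token in the standard bert-base-uncased
            # vocabulary, so the line below takes the top-1 prediction at the masked
            # position of every sequence (one mask per sequence is assumed here).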
            answers = torch.topk(mlm_logits[mlm_labels == 103], k=1, dim=1)

            # Build per-row position indices to locate the [MASK] position in each sequence
            position_tensor = torch.arange(mlm_labels.shape[1])
            position_tensor = position_tensor.repeat(mlm_labels.shape[0]).view(
                mlm_labels.shape[0], -1)
            indeces = position_tensor[mlm_labels == 103]

            # 1. Update mlm_labels:
            mlm_labels_new = mlm_labels.new_zeros(mlm_labels.shape[0],
                                                  mlm_labels.shape[1] + 1)
            mlm_labels_new = mlm_labels_new - 1
            mlm_labels_new[torch.arange(mlm_labels.shape[0]),
                           indeces + 1] = 103
            mlm_labels = mlm_labels_new

            # 2. Update text_input_ids:
            text_input_ids_new = text_input_ids.new_zeros(
                text_input_ids.shape[0], text_input_ids.shape[1] + 1)
            text_input_ids_new[:, :-1] = text_input_ids
            text_input_ids_new[torch.arange(text_input_ids.shape[0]),
                               indeces] = answers[1][:, 0]
            text_input_ids_new[torch.arange(text_input_ids.shape[0]), indeces +
                               1] = (self.tokenizer.convert_tokens_to_ids(
                                   ['[MASK]'])[0])
            text_input_ids_new[torch.arange(text_input_ids.shape[0]), indeces +
                               2] = (self.tokenizer.convert_tokens_to_ids(
                                   ['[PAD]'])[0])
            text_input_ids_new[torch.arange(text_input_ids.shape[0]), indeces +
                               3] = (self.tokenizer.convert_tokens_to_ids(
                                   ['[SEP]'])[0])
            text_input_ids = text_input_ids_new

            # 3. Update text_token_type_ids:
            text_token_type_ids = text_token_type_ids.new_zeros(
                text_token_type_ids.shape[0], text_token_type_ids.shape[1] + 1)

            # 4. Update text_visual_embeddings:
            text_visual_embeddings_new = text_visual_embeddings.new_zeros(
                text_visual_embeddings.shape[0],
                text_visual_embeddings.shape[1] + 1,
                text_visual_embeddings.shape[2])
            text_visual_embeddings_new = text_visual_embeddings_new.transpose(
                0, 1)
            text_visual_embeddings_new[:] = text_visual_embeddings[:, 0, :]
            text_visual_embeddings = text_visual_embeddings_new.transpose(0, 1)

            # 5. Update text_mask:
            text_mask = (text_input_ids > 0)

            # 6. Append generated words from each sentence in the batch to list - terminate if all [STOP]
            for nid, row in enumerate(answers[1]):
                if curr_len == 0:
                    generated.append([])
                for ele in row:
                    # try:
                    if not stop[nid]:
                        if self.tokenizer.ids_to_tokens[
                                ele.item()] == '[STOP]':
                            stop[nid] = True
                        else:
                            # print('generated: ', ele.item())
                            generated[nid].append(
                                self.tokenizer.ids_to_tokens[ele.item()])
                    # except:
                    #     generated[nid].append(self.tokenizer.ids_to_tokens[100])
            curr_len += 1

        # Join generated tokens into sentences, merging WordPiece sub-tokens (' ##')
        generated_sentences = []
        for sentence in generated:
            new_sentence = ' '.join(sentence)
            generated_sentences.append(new_sentence.replace(' ##', ''))
        # print(generated_sentences)
        # exit()

        ###########################################
        outputs = {}

        # loss
        relationship_loss = im_info.new_zeros(())
        mlm_loss = im_info.new_zeros(())
        mvrc_loss = im_info.new_zeros(())
        if self.config.NETWORK.WITH_REL_LOSS:
            relationship_loss = F.cross_entropy(relationship_logits,
                                                relationship_label)
        if self.config.NETWORK.WITH_MLM_LOSS:
            mlm_logits_padded = mlm_logits.new_zeros(
                (*mlm_labels.shape, mlm_logits.shape[-1])).fill_(-10000.0)
            mlm_logits_padded[:, :mlm_logits.shape[1]] = mlm_logits
            mlm_logits = mlm_logits_padded
            if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
                mlm_loss = F.cross_entropy(mlm_logits.transpose(1, 2),
                                           mlm_labels,
                                           ignore_index=-1,
                                           reduction='none')
                num_mlm = (mlm_labels != -1).sum(
                    1, keepdim=True).to(dtype=mlm_loss.dtype)
                num_has_mlm = (num_mlm != 0).sum().to(dtype=mlm_loss.dtype)
                mlm_loss = (mlm_loss /
                            (num_mlm + 1e-4)).sum() / (num_has_mlm + 1e-4)
            else:
                mlm_loss = F.cross_entropy(mlm_logits.view(
                    (-1, mlm_logits.shape[-1])),
                                           mlm_labels.view(-1),
                                           ignore_index=-1)
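        # The MLM_LOSS_NORM_IN_BATCH_FIRST branch above first averages the token-level
        # cross-entropy within each sample (dividing by that sample's number of masked
        # tokens) and then averages over the samples that contain at least one mask,
        # roughly: loss = (1 / N_masked_samples) * sum_b [ (1 / n_mask_b) * sum_t CE_bt ];
        # the 1e-4 terms guard against division by zero for samples without masks.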
        # mvrc_loss = F.cross_entropy(mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
        #                             mvrc_labels.contiguous().view(-1),
        #                             ignore_index=-1)
        if self.config.NETWORK.WITH_MVRC_LOSS:
            if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]),
                    reduction='none').view(mvrc_logits.shape[:-1])
                valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1
                mvrc_loss = (mvrc_loss / (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \
                                .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4)
            else:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]))

            mvrc_logits_padded = mvrc_logits.new_zeros(
                (mvrc_logits.shape[0], origin_len,
                 mvrc_logits.shape[2])).fill_(-10000.0)
            mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits
            mvrc_logits = mvrc_logits_padded
            mvrc_labels_padded = mvrc_labels.new_zeros(
                (mvrc_labels.shape[0], origin_len,
                 mvrc_labels.shape[2])).fill_(0.0)
            mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels
            mvrc_labels = mvrc_labels_padded

        outputs.update({
            'relationship_logits':
            relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None,
            'relationship_label':
            relationship_label if self.config.NETWORK.WITH_REL_LOSS else None,
            'mlm_logits':
            mlm_logits if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_label':
            mlm_labels if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mvrc_logits':
            mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'mvrc_label':
            mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'relationship_loss':
            relationship_loss,
            'mlm_loss':
            mlm_loss,
            'mvrc_loss':
            mvrc_loss,
            'generated_sentences':
            generated_sentences
        })

        loss = relationship_loss.mean() + mlm_loss.mean() + mvrc_loss.mean()

        return outputs, loss
Example No. 16
class ResNetVLBERTForPretrainingMultitaskNoVision(Module):
    def __init__(self, config):

        super(ResNetVLBERTForPretrainingMultitaskNoVision,
              self).__init__(config)

        # Constructs/initialises model elements
        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=config.NETWORK.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        if config.NETWORK.IMAGE_FEAT_PRECOMPUTED or (
                not config.NETWORK.MASK_RAW_PIXELS):
            self.object_mask_visual_embedding = nn.Embedding(1, 2048)
        if config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding = nn.Embedding(
                1, config.NETWORK.VLBERT.hidden_size)
        self.aux_text_visual_embedding = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)

        # Either specify a pre-trained language model explicitly, or fall back to the
        # downloaded pretrained weights specified in the .yaml file
        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            # language_pretrained_model_path = '{}-{:04d}.model'.format(config.NETWORK.BERT_PRETRAINED,
            #                                                           config.NETWORK.BERT_PRETRAINED_EPOCH)
            # FM edit: just use the path of the pretrained model directly
            language_pretrained_model_path = config.NETWORK.BERT_PRETRAINED
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path

        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBertForPretraining(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=None
            if config.NETWORK.VLBERT.from_scratch else
            language_pretrained_model_path,
            with_rel_head=config.NETWORK.WITH_REL_LOSS,
            with_mlm_head=config.NETWORK.WITH_MLM_LOSS,
            with_mvrc_head=config.NETWORK.WITH_MVRC_LOSS,
            with_MLT_head=config.NETWORK.WITH_MLT_LOSS)

        # init weights
        self.init_weight()

        self.fix_params()

    def init_weight(self):
        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED or (
                not self.config.NETWORK.MASK_RAW_PIXELS):
            self.object_mask_visual_embedding.weight.data.fill_(0.0)
        if self.config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        self.aux_text_visual_embedding.weight.data.normal_(
            mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        self.image_feature_extractor.init_weight()
        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)

    def train(self, mode=True):
        super(ResNetVLBERTForPretrainingMultitaskNoVision, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        pass

    def _collect_obj_reps(self, span_tags, object_reps):
        """
        Collect span-level object representations
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return:
        """

        span_tags_fixed = torch.clamp(
            span_tags, min=0)  # In case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0,
                                          row_id.shape[0],
                                          step=1,
                                          device=row_id.device)[:, None]

        # Add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        return object_reps[row_id.view(-1),
                           span_tags_fixed.view(-1)].view(
                               *span_tags_fixed.shape, -1)

    def forward(self, image, boxes, im_info, text, relationship_label,
                mlm_labels, mvrc_ops, mvrc_labels, word_de_ids):

        ###########################################

        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        origin_len = boxes.shape[1]
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        mvrc_ops = mvrc_ops[:, :max_len]
        mvrc_labels = mvrc_labels[:, :max_len]

        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            box_features = boxes[:, :, 4:]
            box_features[mvrc_ops ==
                         1] = self.object_mask_visual_embedding.weight[0]
            boxes[:, :, 4:] = box_features

        obj_reps = self.image_feature_extractor(
            images=images,
            boxes=boxes,
            box_mask=box_mask,
            im_info=im_info,
            classes=None,
            segms=None,
            mvrc_ops=mvrc_ops,
            mask_visual_embed=self.object_mask_visual_embedding.weight[0] if
            (not self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED) and
            (not self.config.NETWORK.MASK_RAW_PIXELS) else None)

        ############################################

        # prepare text
        text_input_ids = text
        # create a text_tags tensor of zeros with the same shape as the text tensor
        text_tags = text.new_zeros(text.shape)
        text_visual_embeddings = self._collect_obj_reps(
            text_tags, obj_reps['obj_reps'])
        # ***** FM edit: blank out visual embeddings for translation retrieval task
        text_visual_embeddings[:] = self.aux_text_visual_embedding.weight[0]

        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        if self.config.NETWORK.WITH_MVRC_LOSS:
            object_linguistic_embeddings[
                mvrc_ops == 1] = self.object_mask_word_embedding.weight[0]
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)
        # ****** FM edit: blank out all visual embeddings
        object_vl_embeddings = object_vl_embeddings.new_zeros(
            object_vl_embeddings.shape)

        # FM edit: no auxiliary text is used in this text-only variant.
        # (In the multitask version, auxiliary text batches from a second dataloader are
        # concatenated here, with aux_text_visual_embedding supplying the single "visual"
        # embedding used for the text-only corpus.)
        max_text_len = text_input_ids.shape[1]
        text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
        text_mask = (text_input_ids > 0)
        # FM edit: removed; box_mask is already defined above
        # box_mask = box_mask.new_zeros((text_input_ids.shape[0], *box_mask.shape[1:]))

        ###########################################

        # Visual Linguistic BERT

        relationship_logits_multi, mlm_logits_multi, mvrc_logits_multi, MLT_logits = self.vlbert(
            text_input_ids, text_token_type_ids, text_visual_embeddings,
            text_mask, object_vl_embeddings, box_mask)

        ###########################################
        outputs = {}

        # loss
        relationship_loss = im_info.new_zeros(())
        mlm_loss = im_info.new_zeros(())
        mvrc_loss = im_info.new_zeros(())
        MLT_loss = im_info.new_zeros(())
        if self.config.NETWORK.WITH_REL_LOSS:
            relationship_logits = relationship_logits_multi[:text_input_ids.
                                                            shape[0]]
            # FM edit - change cross_entropy to bce/sigmoid
            relationship_loss = F.binary_cross_entropy(
                torch.sigmoid(relationship_logits),
                relationship_label.unsqueeze(1))
        if self.config.NETWORK.WITH_MLM_LOSS:
            # Note: this branch still expects the multitask auxiliary inputs
            # (aux_text, aux_text_mlm_labels), which this no-vision forward()
            # never receives, so enabling WITH_MLM_LOSS raises a NameError;
            # only the MLT loss below feeds the returned `loss`.
            mlm_labels_multi = mlm_labels.new_zeros(
                (text_input_ids.shape[0] + aux_text.shape[0],
                 max_text_len)).fill_(-1)
            mlm_labels_multi[:text_input_ids.shape[0], :mlm_labels.
                             shape[1]] = mlm_labels
            mlm_labels_multi[text_input_ids.shape[0]:, :aux_text_mlm_labels.
                             shape[1]] = aux_text_mlm_labels

            mlm_logits_multi_padded = \
                mlm_logits_multi.new_zeros((*mlm_labels_multi.shape, mlm_logits_multi.shape[-1])).fill_(-10000.0)
            mlm_logits_multi_padded[:, :mlm_logits_multi.
                                    shape[1]] = mlm_logits_multi
            mlm_logits_multi = mlm_logits_multi_padded
            mlm_logits_wvc = mlm_logits_multi_padded[:text_input_ids.shape[0]]
            mlm_labels_wvc = mlm_labels_multi[:text_input_ids.shape[0]]
            mlm_logits_aux = mlm_logits_multi_padded[text_input_ids.shape[0]:]
            mlm_labels_aux = mlm_labels_multi[text_input_ids.shape[0]:]
            if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
                mlm_loss_wvc = F.cross_entropy(mlm_logits_wvc.transpose(1, 2),
                                               mlm_labels_wvc,
                                               ignore_index=-1,
                                               reduction='none')
                num_mlm_wvc = (mlm_labels_wvc != -1).sum(
                    1, keepdim=True).to(dtype=mlm_loss_wvc.dtype)
                num_has_mlm_wvc = (num_mlm_wvc != 0).sum().to(
                    dtype=mlm_loss_wvc.dtype)
                mlm_loss_wvc = (mlm_loss_wvc / (num_mlm_wvc + 1e-4)).sum() / (
                    num_has_mlm_wvc + 1e-4)
                mlm_loss_aux = F.cross_entropy(mlm_logits_aux.transpose(1, 2),
                                               mlm_labels_aux,
                                               ignore_index=-1,
                                               reduction='none')
                num_mlm_aux = (mlm_labels_aux != -1).sum(
                    1, keepdim=True).to(dtype=mlm_loss_aux.dtype)
                num_has_mlm_aux = (num_mlm_aux != 0).sum().to(
                    dtype=mlm_loss_aux.dtype)
                mlm_loss_aux = (mlm_loss_aux / (num_mlm_aux + 1e-4)).sum() / (
                    num_has_mlm_aux + 1e-4)
            else:
                # mlm_loss = F.cross_entropy(mlm_logits_multi_padded.view((-1, mlm_logits_multi_padded.shape[-1])),
                #                            mlm_labels_multi.view(-1),
                #                            ignore_index=-1)
                mlm_loss_wvc = F.cross_entropy(mlm_logits_wvc.view(
                    (-1, mlm_logits_multi_padded.shape[-1])),
                                               mlm_labels_wvc.view(-1),
                                               ignore_index=-1)
                mlm_loss_aux = F.cross_entropy(mlm_logits_aux.view(
                    (-1, mlm_logits_multi_padded.shape[-1])),
                                               mlm_labels_aux.view(-1),
                                               ignore_index=-1)

        # mvrc_loss = F.cross_entropy(mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
        #                             mvrc_labels.contiguous().view(-1),
        #                             ignore_index=-1)
        if self.config.NETWORK.WITH_MVRC_LOSS:
            mvrc_logits = mvrc_logits_multi[:mvrc_labels.
                                            shape[0], :mvrc_labels.shape[1]]
            if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]),
                    reduction='none').view(mvrc_logits.shape[:-1])
                valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1
                mvrc_loss = (mvrc_loss / (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \
                                .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4)
            else:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]))

            mvrc_logits_padded = mvrc_logits.new_zeros(
                (mvrc_logits.shape[0], origin_len,
                 mvrc_logits.shape[2])).fill_(-10000.0)
            mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits
            mvrc_logits = mvrc_logits_padded
            mvrc_labels_padded = mvrc_labels.new_zeros(
                (mvrc_labels.shape[0], origin_len,
                 mvrc_labels.shape[2])).fill_(0.0)
            mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels
            mvrc_labels = mvrc_labels_padded

        # MLT loss applied
        if self.config.NETWORK.WITH_MLT_LOSS:
            MLT_loss = F.cross_entropy(MLT_logits, word_de_ids)

        # FM edit: removed other two losses that are not defined
        outputs.update({
            'relationship_logits':
            relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None,
            'relationship_label':
            relationship_label if self.config.NETWORK.WITH_REL_LOSS else None,
            'mlm_logits_wvc':
            mlm_logits_wvc if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_label_wvc':
            mlm_labels_wvc if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_logits_aux':
            mlm_logits_aux if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_label_aux':
            mlm_labels_aux if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mvrc_logits':
            mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'mvrc_label':
            mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'MLT_logits':
            MLT_logits if self.config.NETWORK.WITH_MLT_LOSS else None,
            'MLT_label':
            word_de_ids if self.config.NETWORK.WITH_MLT_LOSS else None,
            'MLT_loss':
            MLT_loss,
        })

        # FM edit: removed addition of other losses which are not defined
        loss = MLT_loss.mean()

        return outputs, loss


class ResNetVLBERTForPretrainingEncDec(Module):
    def __init__(self, config):

        super(ResNetVLBERTForPretrainingEncDec, self).__init__(config)

        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=config.NETWORK.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        if config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            self.object_mask_visual_embedding = nn.Embedding(1, 2048)
        if config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding = nn.Embedding(
                1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)
        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path

        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBertEncoder(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=None
            if config.NETWORK.VLBERT.from_scratch else
            language_pretrained_model_path,
            with_rel_head=False,
            with_mlm_head=False,
            with_mvrc_head=False,
        )

        # FM edit: add decoder
        self.decoder = VisualLinguisticBertForPretrainingDecoder(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=None
            if config.NETWORK.VLBERT.from_scratch else
            language_pretrained_model_path,
            with_rel_head=config.NETWORK.WITH_REL_LOSS,
            with_mlm_head=config.NETWORK.WITH_MLM_LOSS,
            with_mvrc_head=config.NETWORK.WITH_MVRC_LOSS,
        )

        # init weights
        self.init_weight()

        self.fix_params()

    def init_weight(self):
        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            self.object_mask_visual_embedding.weight.data.fill_(0.0)
        if self.config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        self.image_feature_extractor.init_weight()
        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)

    def train(self, mode=True):
        super(ResNetVLBERTForPretrainingEncDec, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        pass

    def _collect_obj_reps(self, span_tags, object_reps):
        """
        Collect span-level object representations
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return:
        """

        span_tags_fixed = torch.clamp(
            span_tags, min=0)  # In case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0,
                                          row_id.shape[0],
                                          step=1,
                                          device=row_id.device)[:, None]

        # Add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        return object_reps[row_id.view(-1),
                           span_tags_fixed.view(-1)].view(
                               *span_tags_fixed.shape, -1)

    def forward(self, image, boxes, im_info, text_en, text_de,
                relationship_label, mlm_labels_en, mlm_labels_de, mvrc_ops,
                mvrc_labels):
        ###########################################

        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        origin_len = boxes.shape[1]
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        mvrc_ops = mvrc_ops[:, :max_len]
        mvrc_labels = mvrc_labels[:, :max_len]

        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            box_features = boxes[:, :, 4:]
            box_features[mvrc_ops ==
                         1] = self.object_mask_visual_embedding.weight[0]
            boxes[:, :, 4:] = box_features

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None,
                                                mvrc_ops=mvrc_ops,
                                                mask_visual_embed=None)

        ############################################

        # prepare text - English
        text_input_ids = text_en
        text_tags = text_en.new_zeros(text_en.shape)
        text_token_type_ids = text_en.new_zeros(text_en.shape)
        text_mask = (text_input_ids > 0)
        text_visual_embeddings = self._collect_obj_reps(
            text_tags, obj_reps['obj_reps'])

        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        if self.config.NETWORK.WITH_MVRC_LOSS:
            object_linguistic_embeddings[
                mvrc_ops == 1] = self.object_mask_word_embedding.weight[0]
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ############################################

        # prepare text - German
        text_input_ids_de = text_de
        text_tags_de = text_de.new_zeros(text_de.shape)
        text_token_type_ids_de = text_de.new_zeros(text_de.shape)
        text_mask_de = (text_input_ids_de > 0)
        text_visual_embeddings_de = self._collect_obj_reps(
            text_tags_de, obj_reps['obj_reps'])

        object_linguistic_embeddings_de = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        if self.config.NETWORK.WITH_MVRC_LOSS:
            # Note: __init__ only defines object_mask_word_embedding (there is no separate
            # "_de" embedding), so reuse it here to avoid an AttributeError.
            object_linguistic_embeddings_de[
                mvrc_ops == 1] = self.object_mask_word_embedding.weight[0]
        object_vl_embeddings_de = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings_de), -1)

        ###########################################

        # Visual Linguistic BERT - Encoder
        relationship_logits_en, mlm_logits_en, mvrc_logits_en, encoder_hidden_states = self.vlbert(
            text_input_ids, text_token_type_ids, text_visual_embeddings,
            text_mask, object_vl_embeddings, box_mask)

        ###########################################

        # Visual Linguistic BERT - Decoder
        relationship_logits, mlm_logits, mvrc_logits = self.decoder(
            text_input_ids_de, text_token_type_ids_de,
            text_visual_embeddings_de, text_mask_de, object_vl_embeddings_de,
            box_mask, encoder_hidden_states)

        ###########################################
        outputs = {}

        # loss
        relationship_loss = im_info.new_zeros(())
        mlm_loss = im_info.new_zeros(())
        mvrc_loss = im_info.new_zeros(())
        if self.config.NETWORK.WITH_REL_LOSS:
            relationship_loss = F.cross_entropy(relationship_logits,
                                                relationship_label)
        if self.config.NETWORK.WITH_MLM_LOSS:
            mlm_logits_padded = mlm_logits.new_zeros(
                (*mlm_labels_de.shape, mlm_logits.shape[-1])).fill_(-10000.0)
            mlm_logits_padded[:, :mlm_logits.shape[1]] = mlm_logits
            mlm_logits = mlm_logits_padded
            if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
                mlm_loss = F.cross_entropy(mlm_logits.transpose(1, 2),
                                           mlm_labels_de,
                                           ignore_index=-1,
                                           reduction='none')
                num_mlm = (mlm_labels_de != -1).sum(
                    1, keepdim=True).to(dtype=mlm_loss.dtype)
                num_has_mlm = (num_mlm != 0).sum().to(dtype=mlm_loss.dtype)
                mlm_loss = (mlm_loss /
                            (num_mlm + 1e-4)).sum() / (num_has_mlm + 1e-4)
            else:
                mlm_loss = F.cross_entropy(mlm_logits.view(
                    (-1, mlm_logits.shape[-1])),
                                           mlm_labels_de.view(-1),
                                           ignore_index=-1)
        # mvrc_loss = F.cross_entropy(mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
        #                             mvrc_labels.contiguous().view(-1),
        #                             ignore_index=-1)
        if self.config.NETWORK.WITH_MVRC_LOSS:
            if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]),
                    reduction='none').view(mvrc_logits.shape[:-1])
                valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1
                mvrc_loss = (mvrc_loss / (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \
                                .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4)
            else:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]))

            mvrc_logits_padded = mvrc_logits.new_zeros(
                (mvrc_logits.shape[0], origin_len,
                 mvrc_logits.shape[2])).fill_(-10000.0)
            mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits
            mvrc_logits = mvrc_logits_padded
            mvrc_labels_padded = mvrc_labels.new_zeros(
                (mvrc_labels.shape[0], origin_len,
                 mvrc_labels.shape[2])).fill_(0.0)
            mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels
            mvrc_labels = mvrc_labels_padded

        outputs.update({
            'relationship_logits':
            relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None,
            'relationship_label':
            relationship_label if self.config.NETWORK.WITH_REL_LOSS else None,
            'mlm_logits':
            mlm_logits if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_label':
            mlm_labels_de if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mvrc_logits':
            mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'mvrc_label':
            mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'relationship_loss':
            relationship_loss,
            'mlm_loss':
            mlm_loss,
            'mvrc_loss':
            mvrc_loss,
        })

        loss = relationship_loss.mean() + mlm_loss.mean() + mvrc_loss.mean()

        return outputs, loss


class ResNetVLBERTForPretrainingNoVision(Module):
    def __init__(self, config):

        super(ResNetVLBERTForPretrainingNoVision, self).__init__(config)

        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=config.NETWORK.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        if config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            self.object_mask_visual_embedding = nn.Embedding(1, 2048)
        if config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding = nn.Embedding(
                1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)
        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path

        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBertForPretraining(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=None
            if config.NETWORK.VLBERT.from_scratch else
            language_pretrained_model_path,
            with_rel_head=config.NETWORK.WITH_REL_LOSS,
            with_mlm_head=config.NETWORK.WITH_MLM_LOSS,
            with_mvrc_head=config.NETWORK.WITH_MVRC_LOSS,
        )

        # init weights
        self.init_weight()

        self.fix_params()

    def init_weight(self):
        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            self.object_mask_visual_embedding.weight.data.fill_(0.0)
        if self.config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        self.image_feature_extractor.init_weight()
        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)

    def train(self, mode=True):
        super(ResNetVLBERTForPretrainingNoVision, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        pass

    def _collect_obj_reps(self, span_tags, object_reps):
        """
        Collect span-level object representations
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return:
        """

        span_tags_fixed = torch.clamp(
            span_tags, min=0)  # In case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0,
                                          row_id.shape[0],
                                          step=1,
                                          device=row_id.device)[:, None]

        # Add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        return object_reps[row_id.view(-1),
                           span_tags_fixed.view(-1)].view(
                               *span_tags_fixed.shape, -1)

    def forward(self, text, relationship_label, mlm_labels):
        ###########################################

        # Blank out visual feature extraction

        ############################################

        # prepare text
        text_input_ids = text
        # creates a text_tags tensor of the same shape as text tensor
        text_tags = text.new_zeros(text.shape)
        # ***** FM edit: blank out visual embeddings for translation retrieval task
        text_visual_embeddings = text_input_ids.new_zeros(
            (text_input_ids.shape[0], text_input_ids.shape[1], 768),
            dtype=torch.float)
        # text_visual_embeddings[:] = self.aux_text_visual_embedding.weight[0]

        # ****** FM edit: blank visual embeddings (use known dimensions)
        object_vl_embeddings = text_input_ids.new_zeros(
            (text_input_ids.shape[0], 1, 1536), dtype=torch.float)

        # FM edit: no auxiliary text is used in this text-only model.
        # In the original multitask model, batches from a second (text-only)
        # dataloader are concatenated here and their visual features are the
        # single aux_text_visual_embedding; that step is skipped here.
        max_text_len = text_input_ids.shape[1]
        text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
        text_mask = (text_input_ids > 0)
        # FM edit: box_mask set to zero so the single placeholder box is ignored
        box_mask = text_input_ids.new_zeros((text_input_ids.shape[0], 1),
                                            dtype=torch.uint8)

        ###########################################

        # Visual Linguistic BERT
        relationship_logits, mlm_logits, mvrc_logits = self.vlbert(
            text_input_ids, text_token_type_ids, text_visual_embeddings,
            text_mask, object_vl_embeddings, box_mask)

        ###########################################
        outputs = {}

        # losses
        # Initialise mlm_loss so the unconditional uses below (in outputs and
        # the final loss) do not raise a NameError when WITH_MLM_LOSS is
        # disabled in the config.
        mlm_loss = text_input_ids.new_zeros((), dtype=torch.float)
        if self.config.NETWORK.WITH_REL_LOSS:
            relationship_loss = F.cross_entropy(relationship_logits,
                                                relationship_label)
        if self.config.NETWORK.WITH_MLM_LOSS:
            mlm_logits_padded = mlm_logits.new_zeros(
                (*mlm_labels.shape, mlm_logits.shape[-1])).fill_(-10000.0)
            mlm_logits_padded[:, :mlm_logits.shape[1]] = mlm_logits
            mlm_logits = mlm_logits_padded
            if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
                mlm_loss = F.cross_entropy(mlm_logits.transpose(1, 2),
                                           mlm_labels,
                                           ignore_index=-1,
                                           reduction='none')
                num_mlm = (mlm_labels != -1).sum(
                    1, keepdim=True).to(dtype=mlm_loss.dtype)
                num_has_mlm = (num_mlm != 0).sum().to(dtype=mlm_loss.dtype)
                mlm_loss = (mlm_loss /
                            (num_mlm + 1e-4)).sum() / (num_has_mlm + 1e-4)
            else:
                mlm_loss = F.cross_entropy(mlm_logits.view(
                    (-1, mlm_logits.shape[-1])),
                                           mlm_labels.view(-1),
                                           ignore_index=-1)

        # Note: this text-only forward() receives neither mvrc_labels nor
        # origin_len, so WITH_MVRC_LOSS must be disabled for this model; the
        # branch is kept only for parity with the full pretraining model.
        if self.config.NETWORK.WITH_MVRC_LOSS:
            if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]),
                    reduction='none').view(mvrc_logits.shape[:-1])
                valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1
                mvrc_loss = (mvrc_loss / (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \
                                .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4)
            else:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]))

            mvrc_logits_padded = mvrc_logits.new_zeros(
                (mvrc_logits.shape[0], origin_len,
                 mvrc_logits.shape[2])).fill_(-10000.0)
            mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits
            mvrc_logits = mvrc_logits_padded
            mvrc_labels_padded = mvrc_labels.new_zeros(
                (mvrc_labels.shape[0], origin_len,
                 mvrc_labels.shape[2])).fill_(0.0)
            mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels
            mvrc_labels = mvrc_labels_padded

        outputs.update({
            'relationship_logits':
            relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None,
            'relationship_label':
            relationship_label if self.config.NETWORK.WITH_REL_LOSS else None,
            'mlm_logits':
            mlm_logits if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_label':
            mlm_labels if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mvrc_logits':
            mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'mvrc_label':
            mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'mlm_loss':
            mlm_loss,
        })

        loss = mlm_loss.mean()

        return outputs, loss
Example #19
class ResNetVLBERT(Module):
    def __init__(self, config):

        super(ResNetVLBERT, self).__init__(config)

        self.predict_on_cls = config.NETWORK.VLBERT.predict_on_cls  # make prediction on [CLS]?

        self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS
        if not config.NETWORK.BLIND:
            self.image_feature_extractor = FastRCNN(
                config,
                average_pool=True,
                final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                enable_cnn_reg_loss=self.enable_cnn_reg_loss)
            if config.NETWORK.VLBERT.object_word_embed_mode == 1:
                self.object_linguistic_embeddings = nn.Embedding(
                    81, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 2:  # default: class-agnostic
                self.object_linguistic_embeddings = nn.Embedding(
                    1, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN

        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path)

        dim = config.NETWORK.VLBERT.hidden_size
        if config.NETWORK.CLASSIFIER_TYPE == "2fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(dim, config.NETWORK.CLASSIFIER_HIDDEN_SIZE),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(config.NETWORK.CLASSIFIER_HIDDEN_SIZE,
                                config.DATASET.ANSWER_VOCAB_SIZE),
            )
        elif config.NETWORK.CLASSIFIER_TYPE == "1fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(dim, config.DATASET.ANSWER_VOCAB_SIZE))
        elif config.NETWORK.CLASSIFIER_TYPE == 'mlm':
            transform = BertPredictionHeadTransform(config.NETWORK.VLBERT)
            linear = nn.Linear(config.NETWORK.VLBERT.hidden_size,
                               config.DATASET.ANSWER_VOCAB_SIZE)
            self.final_mlp = nn.Sequential(
                transform,
                nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                linear)
        else:
            raise ValueError("Not support classifier type: {}!".format(
                config.NETWORK.CLASSIFIER_TYPE))

        self.use_spatial_model = False
        if config.NETWORK.USE_SPATIAL_MODEL:
            self.use_spatial_model = True
            # self.simple_spatial_model = SimpleSpatialModel(4, config.NETWORK.VLBERT.hidden_size, 9, config)

            self.use_coord_vector = False
            if config.NETWORK.USE_COORD_VECTOR:
                self.use_coord_vector = True
                self.loc_fcs = nn.Sequential(
                    nn.Linear(2 * 5 + 9, config.NETWORK.VLBERT.hidden_size),
                    nn.ReLU(True),
                    nn.Linear(config.NETWORK.VLBERT.hidden_size,
                              config.NETWORK.VLBERT.hidden_size))
            else:
                self.simple_spatial_model = SimpleSpatialModel(
                    4, config.NETWORK.VLBERT.hidden_size, 9)

            self.spa_add = bool(config.NETWORK.SPA_ADD)
            self.spa_concat = bool(config.NETWORK.SPA_CONCAT)

            if self.spa_add:
                self.spa_feat_weight = 0.5
                if config.NETWORK.USE_SPA_WEIGHT:
                    self.spa_feat_weight = config.NETWORK.SPA_FEAT_WEIGHT
                self.spa_fusion_linear = nn.Linear(
                    config.NETWORK.VLBERT.hidden_size,
                    config.NETWORK.VLBERT.hidden_size)
            elif self.spa_concat:
                if self.use_coord_vector:
                    self.spa_fusion_linear = nn.Linear(
                        config.NETWORK.VLBERT.hidden_size +
                        config.NETWORK.VLBERT.hidden_size,
                        config.NETWORK.VLBERT.hidden_size)
                else:
                    self.spa_fusion_linear = nn.Linear(
                        config.NETWORK.VLBERT.hidden_size * 2,
                        config.NETWORK.VLBERT.hidden_size)
            self.spa_linear = nn.Linear(config.NETWORK.VLBERT.hidden_size,
                                        config.NETWORK.VLBERT.hidden_size)
            self.dropout = nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT)

            self.spa_one_more_layer = config.NETWORK.SPA_ONE_MORE_LAYER
            if self.spa_one_more_layer:
                self.spa_linear_hidden = nn.Linear(
                    config.NETWORK.VLBERT.hidden_size,
                    config.NETWORK.VLBERT.hidden_size)

        self.enhanced_img_feature = False
        if config.NETWORK.VLBERT.ENHANCED_IMG_FEATURE:
            self.enhanced_img_feature = True
            self.mask_weight = config.NETWORK.VLBERT.mask_weight
            self.mask_loss_sum = config.NETWORK.VLBERT.mask_loss_sum
            self.mask_loss_mse = config.NETWORK.VLBERT.mask_loss_mse
            self.no_predicate = config.NETWORK.VLBERT.NO_PREDICATE

        self.all_proposals_test = False
        if config.DATASET.ALL_PROPOSALS_TEST:
            self.all_proposals_test = True

        self.use_uvtranse = False
        if config.NETWORK.USE_UVTRANSE:
            self.use_uvtranse = True
            self.union_vec_fc = nn.Linear(config.NETWORK.VLBERT.hidden_size,
                                          config.NETWORK.VLBERT.hidden_size)
            self.uvt_add = bool(config.NETWORK.UVT_ADD)
            self.uvt_concat = bool(config.NETWORK.UVT_CONCAT)
            # exactly one of the two fusion modes must be enabled
            assert self.uvt_add ^ self.uvt_concat, \
                "Exactly one of UVT_ADD / UVT_CONCAT must be set"
            if self.uvt_add:
                self.uvt_feat_weight = config.NETWORK.UVT_FEAT_WEIGHT
                self.uvt_fusion_linear = nn.Linear(
                    config.NETWORK.VLBERT.hidden_size,
                    config.NETWORK.VLBERT.hidden_size)
            elif self.uvt_concat:
                self.uvt_fusion_linear = nn.Linear(
                    config.NETWORK.VLBERT.hidden_size * 2,
                    config.NETWORK.VLBERT.hidden_size)
            self.uvt_linear = nn.Linear(config.NETWORK.VLBERT.hidden_size,
                                        config.NETWORK.VLBERT.hidden_size)
            self.dropout_uvt = nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT)

        # init weights
        self.init_weight()

    def init_weight(self):
        # self.hm_out.weight.data.normal_(mean=0.0, std=0.02)
        # self.hm_out.bias.data.zero_()
        # self.hi_out.weight.data.normal_(mean=0.0, std=0.02)
        # self.hi_out.bias.data.zero_()
        self.image_feature_extractor.init_weight()
        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(mean=0.0,
                                                                  std=0.02)
        for m in self.final_mlp.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                torch.nn.init.constant_(m.bias, 0)
        if self.config.NETWORK.CLASSIFIER_TYPE == 'mlm':
            language_pretrained = torch.load(
                self.language_pretrained_model_path)
            mlm_transform_state_dict = {}
            pretrain_keys = []
            for k, v in language_pretrained.items():
                if k.startswith('cls.predictions.transform.'):
                    pretrain_keys.append(k)
                    k_ = k[len('cls.predictions.transform.'):]
                    if 'gamma' in k_:
                        k_ = k_.replace('gamma', 'weight')
                    if 'beta' in k_:
                        k_ = k_.replace('beta', 'bias')
                    mlm_transform_state_dict[k_] = v
            print("loading pretrained classifier transform keys: {}.".format(
                pretrain_keys))
            self.final_mlp[0].load_state_dict(mlm_transform_state_dict)

    def train(self, mode=True):
        super(ResNetVLBERT, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def _collect_obj_reps(self, span_tags, object_reps, spo_len):
        """
        Collect span-level object representations
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return:
        """

        span_tags_fixed = torch.clamp(
            span_tags, min=0)  # In case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0,
                                          row_id.shape[0],
                                          step=1,
                                          device=row_id.device)[:, None]

        # Add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster

        if self.enhanced_img_feature:
            # for i in range(span_tags_fixed.shape[0]):
            #     span_tags_fixed[i, 1:1 + spo_len[i, 0]] = 1
            #     span_tags_fixed[i, 1 + spo_len[i, 0]:1 + spo_len[i, 0] + spo_len[i, 1]] = 2
            #     span_tags_fixed[i, 1 + spo_len[i, 0] + spo_len[i, 1]:1 + spo_len[i, 0] + spo_len[i, 1] + spo_len[i, 2]] = 3
            pass

        text_visual_embeddings = object_reps[row_id.view(-1),
                                             span_tags_fixed.view(-1)].view(
                                                 *span_tags_fixed.shape, -1)

        return text_visual_embeddings

    def prepare_text_from_qa(self, question, question_tags, question_mask,
                             answer, answer_tags, answer_mask):
        batch_size, max_q_len = question.shape
        _, max_a_len = answer.shape

        if self.predict_on_cls:
            answer_mask = answer_mask.new_zeros(
                answer_mask.shape)  # remove answer_mask
            max_len = (question_mask.sum(1) +
                       answer_mask.sum(1)).max() + 2  # [CLS] & 1*[SEP]
        else:
            max_len = (question_mask.sum(1) +
                       answer_mask.sum(1)).max() + 3  # [CLS] & 2*[SEP]
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(
            ['[CLS]', '[SEP]'])
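        # Rough layout of the sequence assembled below:
        #   default:         [CLS] q_1 ... q_n [SEP] a_1 ... a_m [SEP]
        #   predict_on_cls:  [CLS] q_1 ... q_n [SEP]  (answer span dropped,
        #                    prediction read from the [CLS] position instead)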
        q_end = 1 + question_mask.sum(1, keepdim=True)
        a_end = q_end if self.predict_on_cls else q_end + 1 + answer_mask.sum(
            1, keepdim=True)
        input_ids = torch.zeros((batch_size, max_len),
                                dtype=question.dtype,
                                device=question.device)
        input_mask = torch.ones((batch_size, max_len),
                                dtype=torch.uint8,
                                device=question.device)
        input_type_ids = torch.zeros((batch_size, max_len),
                                     dtype=question.dtype,
                                     device=question.device)
        text_tags = input_type_ids.new_zeros((batch_size, max_len))
        grid_i, grid_j = torch.meshgrid(
            torch.arange(batch_size, device=question.device),
            torch.arange(max_len, device=question.device))

        input_mask[grid_j > a_end] = 0
        if not self.predict_on_cls:
            input_type_ids[(grid_j > q_end) & (grid_j <= a_end)] = 1
        q_input_mask = (grid_j > 0) & (grid_j < q_end)
        a_input_mask = (grid_j > q_end) & (grid_j < a_end)
        input_ids[:, 0] = cls_id
        input_ids[grid_j == q_end] = sep_id
        input_ids[grid_j == a_end] = sep_id
        input_ids[q_input_mask] = question[question_mask]
        input_ids[a_input_mask] = answer[answer_mask]
        text_tags[q_input_mask] = question_tags[question_mask]
        text_tags[a_input_mask] = answer_tags[answer_mask]

        ans_pos = a_end.new_zeros(
            a_end.shape).squeeze(1) if self.predict_on_cls else (a_end -
                                                                 1).squeeze(1)

        return input_ids, input_type_ids, text_tags, input_mask, ans_pos

    def train_forward(self, img, im_info, boxes, labels, spo_ids, spo_lens,
                      img_path):
        boxes, labels, spo_ids, spo_lens, im_info = boxes.squeeze(
            0), labels.squeeze(0), spo_ids.squeeze(0), spo_lens.squeeze(
                0), im_info.squeeze(0)

        # e.g. images.shape == torch.Size([4, 3, 895, 899])
        images = torch.cat([img for _ in range(boxes.shape[0])])
        # e.g. box_mask.shape == torch.Size([4, 54])
        box_mask = (boxes[:, :, 0] > -1.5)
        max_len = int(box_mask.sum(1).max().item())  # max_len = 54
        box_mask = box_mask[:, :max_len]  # no-op when max_len already equals the padded length
        boxes = boxes[:, :max_len]  # no-op when max_len already equals the padded length

        boxes[boxes < 0] = 0  # rectify those coordinates < 0 to 0

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None,
                                                copy_images=True)
        # obj_reps['obj_reps'].shape = torch.Size([4, 54, 768])

        question_ids = spo_ids
        question_tags = spo_ids.new_zeros(question_ids.shape)
        question_mask = (spo_ids > 0.5)

        answer_ids = question_ids.new_zeros(
            (question_ids.shape[0],
             1)).fill_(self.tokenizer.convert_tokens_to_ids(['[MASK]'])[0])
        answer_mask = question_mask.new_zeros(answer_ids.shape).fill_(1)
        answer_tags = question_tags.new_zeros(answer_ids.shape)
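        # The "answer" is a single [MASK] token per sample; the hidden state at
        # that position (ans_pos below) is what the relation classifier reads,
        # unless predict_on_cls routes the prediction through [CLS] instead.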

        ############################################

        # prepare text
        text_input_ids, text_token_type_ids, text_tags, text_mask, ans_pos = self.prepare_text_from_qa(
            question_ids, question_tags, question_mask, answer_ids,
            answer_tags, answer_mask)
        if self.config.NETWORK.NO_GROUNDING:  # always False
            obj_rep_zeroed = obj_reps['obj_reps'].new_zeros(
                obj_reps['obj_reps'].shape)
            text_tags.zero_()
            text_visual_embeddings = self._collect_obj_reps(
                text_tags, obj_rep_zeroed)
        else:
            text_visual_embeddings = self._collect_obj_reps(
                text_tags, obj_reps['obj_reps'], spo_lens)

        assert self.config.NETWORK.VLBERT.object_word_embed_mode == 2
        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())

        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings),
            -1)  # concatenation of obj visual & linguistic

        ###########################################

        # Visual Linguistic BERT

        hidden_states, hc, spo_fused_masks = self.vlbert(
            text_input_ids,
            text_token_type_ids,
            text_visual_embeddings,
            text_mask,
            object_vl_embeddings,
            box_mask,
            object_visual_feat=obj_reps['obj_reps_rawraw'],
            spo_len=spo_lens,
            output_all_encoded_layers=False)
        _batch_inds = torch.arange(spo_ids.shape[0], device=spo_ids.device)

        hm = hidden_states[_batch_inds, ans_pos]
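        # hm: one hidden vector per sample, taken at the answer ([MASK]) or
        # [CLS] position; the optional spatial / UVTransE branches below fuse
        # extra geometric features into it before classification.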

        if self.use_spatial_model:
            if self.use_coord_vector:
                # import pdb; pdb.set_trace()
                spa_feat = torch.zeros((boxes.shape[0], 5 * 2 + 9),
                                       dtype=boxes.dtype,
                                       layout=boxes.layout,
                                       device=boxes.device)
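                # 19-dim hand-crafted geometry per sample (layout inferred from
                # the code below): 5 subject features (box corners normalised by
                # image size + area ratio), 9 pairwise features (offsets and log
                # size ratios between the subject and object boxes + predicate-box
                # area ratio), and 5 object features.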
                for i in range(boxes.shape[0]):
                    area_subj_ratio = (boxes[i, 1, 2] - boxes[i, 1, 0]) * (
                        boxes[i, 1, 3] - boxes[i, 1, 1]) / (im_info[i, 0] *
                                                            im_info[i, 1])
                    subj = torch.tensor([
                        boxes[i, 1, 0] / im_info[i, 0],
                        boxes[i, 1, 1] / im_info[i, 1],
                        boxes[i, 1, 2] / im_info[i, 0],
                        boxes[i, 1, 3] / im_info[i, 1], area_subj_ratio
                    ])

                    area_pred_ratio = (boxes[i, 2, 2] - boxes[i, 2, 0]) * (
                        boxes[i, 2, 3] - boxes[i, 2, 1]) / (im_info[i, 0] *
                                                            im_info[i, 1])
                    w_s = (boxes[i, 1, 2] - boxes[i, 1, 0])
                    h_s = (boxes[i, 1, 3] - boxes[i, 1, 1])
                    x_s = (boxes[i, 1, 2] + boxes[i, 1, 0]) / 2
                    y_s = (boxes[i, 1, 3] + boxes[i, 1, 1]) / 2
                    w_o = (boxes[i, 3, 2] - boxes[i, 3, 0])
                    h_o = (boxes[i, 3, 3] - boxes[i, 3, 1])
                    x_o = (boxes[i, 3, 2] + boxes[i, 3, 0]) / 2
                    y_o = (boxes[i, 3, 3] + boxes[i, 3, 1]) / 2
                    pred = torch.tensor([(x_s - x_o) / w_o, (y_s - y_o) / h_o,
                                         torch.log(w_s / w_o),
                                         torch.log(h_s / h_o),
                                         (x_o - x_s) / w_s, (y_o - y_s) / h_s,
                                         torch.log(w_o / w_s),
                                         torch.log(h_o / h_s),
                                         area_pred_ratio])

                    area_obj_ratio = (boxes[i, 3, 2] - boxes[i, 3, 0]) * (
                        boxes[i, 3, 3] - boxes[i, 3, 1]) / (im_info[i, 0] *
                                                            im_info[i, 1])
                    obj = torch.tensor([
                        boxes[i, 3, 0] / im_info[i, 0],
                        boxes[i, 3, 1] / im_info[i, 1],
                        boxes[i, 3, 2] / im_info[i, 0],
                        boxes[i, 3, 3] / im_info[i, 1], area_obj_ratio
                    ])

                    # fill the row for sample i (the original assigned only row 0)
                    spa_feat[i] = torch.cat((subj, pred, obj))
                spa_feat = self.loc_fcs(spa_feat)
                # assert self.spa_concat # Currently coord_vec only works with concatenation!
            else:
                for i in range(boxes.shape[0]):
                    boxes[:, :, 0][i] /= im_info[:, 0][i]
                    boxes[:, :, 1][i] /= im_info[:, 1][i]
                    boxes[:, :, 2][i] /= im_info[:, 0][i]
                    boxes[:, :, 3][i] /= im_info[:, 1][i]
                spa_feat = self.simple_spatial_model(boxes[:, 1], boxes[:, 3],
                                                     labels)

            if self.spa_add:
                hm = hm * (
                    1 - self.spa_feat_weight) + spa_feat * self.spa_feat_weight
            elif self.spa_concat:
                hm = torch.cat((hm, spa_feat), dim=1)
            hm = self.spa_fusion_linear(hm)
            hm = F.relu(hm)
            hm = self.dropout(hm)

            if self.spa_one_more_layer:  # if VLBERT is kept frozen, add one more layer (and lower the dropout rate to 0.2)
                hm = self.spa_linear_hidden(hm)
                hm = F.relu(hm)

            hm = self.spa_linear(hm)

        if self.use_uvtranse:
            union_vec = obj_reps['obj_reps'][:, 2] - obj_reps[
                'obj_reps'][:,
                            1] - obj_reps['obj_reps'][:,
                                                      3]  # pred - subj - obj
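            # UVTransE-style translational feature: the predicate (union) box
            # representation minus the subject and object representations,
            # assuming obj_reps indices 1/2/3 hold subject/predicate/object as
            # the inline comment above indicates.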
            union_vec = self.union_vec_fc(union_vec)
            union_vec = F.relu(union_vec)

            if self.uvt_add:
                hm = hm * (1 - self.uvt_feat_weight
                           ) + union_vec * self.uvt_feat_weight
            elif self.uvt_concat:
                hm = torch.cat((hm, union_vec), dim=1)

            hm = self.uvt_fusion_linear(hm)
            hm = F.relu(hm)
            hm = self.dropout_uvt(hm)
            hm = self.uvt_linear(hm)

            # import pdb; pdb.set_trace()

        ###########################################
        outputs = {}

        # classifier
        logits = self.final_mlp(hm)

        # loss
        # import pdb; pdb.set_trace()
        ans_loss = F.cross_entropy(logits, labels.view(-1))  # * label.size(1)

        # Normalise logits with softmax before they are consumed in spasen_metrics.py
        logits = F.softmax(logits, dim=1)

        # mask loss
        if spo_fused_masks is not None:
            nb_of_tokens = 2 if self.no_predicate else 3
            spo_fused_masks = spo_fused_masks.view(-1, nb_of_tokens, 14, 14)
            # spo_fused_masks_norm = spo_fused_masks.new_zeros(size=spo_fused_masks.shape)
            boxes_mask = torch.zeros_like(spo_fused_masks)
            rounded_14x14_boxes = torch.round(boxes * 14).to(torch.int)
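            # Target masks on the 14x14 grid: boxes are assumed to be normalised
            # to [0, 1] at this point, so scaling by 14 and rounding gives grid
            # coordinates; each target is a filled rectangle, and j + 1 skips
            # what is presumably the full-image box at index 0.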
            for i in range(boxes.shape[0]):  # for each sample
                for j in range(nb_of_tokens):  # sub, pred, obj
                    # Create a mask
                    boxes_mask[i, j, rounded_14x14_boxes[
                        i, j + 1, 0].item():rounded_14x14_boxes[i, j + 1,
                                                                2].item(),
                               rounded_14x14_boxes[
                                   i, j + 1,
                                   1].item():rounded_14x14_boxes[i, j + 1,
                                                                 3].item()] = 1

            if self.mask_loss_sum:
                mask_loss = F.binary_cross_entropy_with_logits(
                    spo_fused_masks, boxes_mask,
                    reduction='sum') / spo_fused_masks.shape[0]
            elif self.mask_loss_mse:
                mask_loss = F.mse_loss(spo_fused_masks, boxes_mask)
            else:
                mask_loss = F.binary_cross_entropy_with_logits(
                    spo_fused_masks, boxes_mask)

            outputs.update({
                'label_logits': logits,
                'label': labels,
                'ans_loss': ans_loss,
                'mask_loss': mask_loss
            })

            if self.mask_weight < 0:
                loss = (ans_loss + mask_loss).mean()
            else:
                loss = (ans_loss * (1 - self.mask_weight) +
                        mask_loss * self.mask_weight).mean()
        else:
            outputs.update({
                'label_logits': logits,
                'label': labels,
                'ans_loss': ans_loss
            })
            loss = ans_loss.mean()

        return outputs, loss

    def inference_forward(self, img, im_info, boxes, labels, spo_ids, spo_lens,
                          img_path, rels_cand, labels_so_ids,
                          subj_obj_classes):
        boxes, labels, spo_ids, spo_lens, im_info, rels_cand, labels_so_ids, subj_obj_classes = boxes.squeeze(
            0), labels.squeeze(0), spo_ids.squeeze(0), spo_lens.squeeze(
                0), im_info.squeeze(0), rels_cand.squeeze(
                    0), labels_so_ids.squeeze(0), subj_obj_classes.squeeze(0)

        # visual feature extraction
        images = torch.cat([img for _ in range(boxes.shape[0])])
        box_mask = (boxes[:, :, 0] > -1.5)
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None,
                                                copy_images=True)

        question_ids = spo_ids
        question_tags = spo_ids.new_zeros(question_ids.shape)
        question_mask = (spo_ids > 0.5)

        answer_ids = question_ids.new_zeros(
            (question_ids.shape[0],
             1)).fill_(self.tokenizer.convert_tokens_to_ids(['[MASK]'])[0])
        answer_mask = question_mask.new_zeros(answer_ids.shape).fill_(1)
        answer_tags = question_tags.new_zeros(answer_ids.shape)

        ############################################

        # prepare text
        text_input_ids, text_token_type_ids, text_tags, text_mask, ans_pos = self.prepare_text_from_qa(
            question_ids, question_tags, question_mask, answer_ids,
            answer_tags, answer_mask)
        if self.config.NETWORK.NO_GROUNDING:
            obj_rep_zeroed = obj_reps['obj_reps'].new_zeros(
                obj_reps['obj_reps'].shape)
            text_tags.zero_()
            text_visual_embeddings = self._collect_obj_reps(
                text_tags, obj_rep_zeroed)
        else:
            text_visual_embeddings = self._collect_obj_reps(
                text_tags, obj_reps['obj_reps'], spo_lens)

        assert self.config.NETWORK.VLBERT.object_word_embed_mode == 2
        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################

        # Visual Linguistic BERT

        hidden_states, hc, spo_fused_masks = self.vlbert(
            text_input_ids,
            text_token_type_ids,
            text_visual_embeddings,
            text_mask,
            object_vl_embeddings,
            box_mask,
            object_visual_feat=obj_reps['obj_reps_rawraw'],
            spo_len=spo_lens,
            output_all_encoded_layers=False)
        _batch_inds = torch.arange(spo_ids.shape[0], device=spo_ids.device)

        hm = hidden_states[_batch_inds, ans_pos]

        if self.use_spatial_model:
            if self.use_coord_vector:
                # import pdb; pdb.set_trace()
                spa_feat = torch.zeros((boxes.shape[0], 5 * 2 + 9),
                                       dtype=boxes.dtype,
                                       layout=boxes.layout,
                                       device=boxes.device)
                for i in range(boxes.shape[0]):
                    area_subj_ratio = (boxes[i, 1, 2] - boxes[i, 1, 0]) * (
                        boxes[i, 1, 3] - boxes[i, 1, 1]) / (im_info[i, 0] *
                                                            im_info[i, 1])
                    subj = torch.tensor([
                        boxes[i, 1, 0] / im_info[i, 0],
                        boxes[i, 1, 1] / im_info[i, 1],
                        boxes[i, 1, 2] / im_info[i, 0],
                        boxes[i, 1, 3] / im_info[i, 1], area_subj_ratio
                    ])

                    area_pred_ratio = (boxes[i, 2, 2] - boxes[i, 2, 0]) * (
                        boxes[i, 2, 3] - boxes[i, 2, 1]) / (im_info[i, 0] *
                                                            im_info[i, 1])
                    w_s = (boxes[i, 1, 2] - boxes[i, 1, 0])
                    h_s = (boxes[i, 1, 3] - boxes[i, 1, 1])
                    x_s = (boxes[i, 1, 2] + boxes[i, 1, 0]) / 2
                    y_s = (boxes[i, 1, 3] + boxes[i, 1, 1]) / 2
                    w_o = (boxes[i, 3, 2] - boxes[i, 3, 0])
                    h_o = (boxes[i, 3, 3] - boxes[i, 3, 1])
                    x_o = (boxes[i, 3, 2] + boxes[i, 3, 0]) / 2
                    y_o = (boxes[i, 3, 3] + boxes[i, 3, 1]) / 2
                    pred = torch.tensor([(x_s - x_o) / w_o, (y_s - y_o) / h_o,
                                         torch.log(w_s / w_o),
                                         torch.log(h_s / h_o),
                                         (x_o - x_s) / w_s, (y_o - y_s) / h_s,
                                         torch.log(w_o / w_s),
                                         torch.log(h_o / h_s),
                                         area_pred_ratio])

                    area_obj_ratio = (boxes[i, 3, 2] - boxes[i, 3, 0]) * (
                        boxes[i, 3, 3] - boxes[i, 3, 1]) / (im_info[i, 0] *
                                                            im_info[i, 1])
                    obj = torch.tensor([
                        boxes[i, 3, 0] / im_info[i, 0],
                        boxes[i, 3, 1] / im_info[i, 1],
                        boxes[i, 3, 2] / im_info[i, 0],
                        boxes[i, 3, 3] / im_info[i, 1], area_obj_ratio
                    ])

                    # fill the row for sample i (the original assigned only row 0)
                    spa_feat[i] = torch.cat((subj, pred, obj))
                spa_feat = self.loc_fcs(spa_feat)
                # assert self.spa_concat # Currently coord_vec only works with concatenation!
            else:
                for i in range(boxes.shape[0]):
                    boxes[:, :, 0][i] /= im_info[:, 0][i]
                    boxes[:, :, 1][i] /= im_info[:, 1][i]
                    boxes[:, :, 2][i] /= im_info[:, 0][i]
                    boxes[:, :, 3][i] /= im_info[:, 1][i]
                spa_feat = self.simple_spatial_model(boxes[:, 1], boxes[:, 3],
                                                     labels)

            if self.spa_add:
                hm = hm * (
                    1 - self.spa_feat_weight) + spa_feat * self.spa_feat_weight
            elif self.spa_concat:
                hm = torch.cat((hm, spa_feat), dim=1)
            hm = self.spa_fusion_linear(hm)
            hm = F.relu(hm)
            hm = self.dropout(hm)

            if self.spa_one_more_layer:  # if VLBERT is kept frozen, add one more layer (and lower the dropout rate to 0.2)
                hm = self.spa_linear_hidden(hm)
                hm = F.relu(hm)

            hm = self.spa_linear(hm)

        if self.use_uvtranse:
            union_vec = obj_reps['obj_reps'][:, 2] - obj_reps[
                'obj_reps'][:,
                            1] - obj_reps['obj_reps'][:,
                                                      3]  # pred - subj - obj
            union_vec = self.union_vec_fc(union_vec)
            union_vec = F.relu(union_vec)

            if self.uvt_add:
                hm = hm * (1 - self.uvt_feat_weight
                           ) + union_vec * self.uvt_feat_weight
            elif self.uvt_concat:
                hm = torch.cat((hm, union_vec), dim=1)

            hm = self.uvt_fusion_linear(hm)
            hm = F.relu(hm)
            hm = self.dropout_uvt(hm)
            hm = self.uvt_linear(hm)

        ###########################################
        outputs = {}

        # classifier
        logits = self.final_mlp(hm)
        logits = F.softmax(logits, dim=1)

        # mask loss
        if spo_fused_masks is not None:
            nb_of_tokens = 2 if self.no_predicate else 3
            spo_fused_masks = spo_fused_masks.view(-1, nb_of_tokens, 14, 14)

            boxes_mask = boxes.new_zeros(size=(boxes.shape[0], nb_of_tokens,
                                               14, 14))
            rounded_14x14_boxes = torch.round(boxes * 14).to(torch.int)
            for i in range(boxes.shape[0]):  # for each sample
                for j in range(nb_of_tokens):  # sub, pred, obj
                    # Create a mask
                    boxes_mask[i, j, rounded_14x14_boxes[
                        i, j + 1, 0].item():rounded_14x14_boxes[i, j + 1,
                                                                2].item(),
                               rounded_14x14_boxes[
                                   i, j + 1,
                                   1].item():rounded_14x14_boxes[i, j + 1,
                                                                 3].item()] = 1

            if self.mask_loss_sum:
                mask_loss = F.binary_cross_entropy_with_logits(
                    spo_fused_masks, boxes_mask,
                    reduction='sum') / spo_fused_masks.shape[0]
            elif self.mask_loss_mse:
                mask_loss = F.mse_loss(spo_fused_masks, boxes_mask)
                # import pdb; pdb.set_trace()
                # self.show_cam_on_image(spo_fused_masks, img_path)
            else:
                mask_loss = F.binary_cross_entropy_with_logits(
                    spo_fused_masks, boxes_mask)

            outputs.update({
                'label_logits': logits,
                'label': labels,
                'labels_so_ids': labels_so_ids,
                'rels_cand': rels_cand,
                'mask_loss': mask_loss,
                'img_path': img_path,
                'spo_fused_masks': spo_fused_masks,
                'subj_obj_classes': subj_obj_classes,
                'prediction': logits.argmax(1)
            })
        else:
            outputs.update({
                'label_logits': logits,
                'label': labels,
                'labels_so_ids': labels_so_ids,
                'rels_cand': rels_cand,
                'prediction': logits.argmax(1)
            })

        return outputs
Example #20
class ResNetVLBERT(Module):
    def __init__(self, config):

        super(ResNetVLBERT, self).__init__(config)

        self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS
        self.cnn_loss_top = config.NETWORK.CNN_LOSS_TOP
        if not config.NETWORK.BLIND:
            self.image_feature_extractor = FastRCNN(
                config,
                average_pool=True,
                final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                enable_cnn_reg_loss=(self.enable_cnn_reg_loss
                                     and not self.cnn_loss_top))
            if config.NETWORK.VLBERT.object_word_embed_mode == 1:
                self.object_linguistic_embeddings = nn.Embedding(
                    81, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 2:
                self.object_linguistic_embeddings = nn.Embedding(
                    1, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError
            if self.enable_cnn_reg_loss and self.cnn_loss_top:
                self.cnn_loss_reg = nn.Sequential(
                    VisualLinguisticBertMVRCHeadTransform(
                        config.NETWORK.VLBERT),
                    nn.Dropout(config.NETWORK.CNN_REG_DROPOUT, inplace=False),
                    nn.Linear(config.NETWORK.VLBERT.hidden_size, 81))
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN

        if 'roberta' in config.NETWORK.BERT_MODEL_NAME:
            self.tokenizer = RobertaTokenizer.from_pretrained(
                config.NETWORK.BERT_MODEL_NAME)
        else:
            self.tokenizer = BertTokenizer.from_pretrained(
                config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path

        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = TimeDistributed(
            VisualLinguisticBert(
                config.NETWORK.VLBERT,
                language_pretrained_model_path=language_pretrained_model_path))

        self.for_pretrain = config.NETWORK.FOR_MASK_VL_MODELING_PRETRAIN
        assert not self.for_pretrain, "Pretrain mode is not implemented yet!"

        if not self.for_pretrain:
            dim = config.NETWORK.VLBERT.hidden_size
            if config.NETWORK.CLASSIFIER_TYPE == "2fc":
                self.final_mlp = torch.nn.Sequential(
                    torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(dim,
                                    config.NETWORK.CLASSIFIER_HIDDEN_SIZE),
                    torch.nn.ReLU(inplace=True),
                    torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(config.NETWORK.CLASSIFIER_HIDDEN_SIZE, 1),
                )
            elif config.NETWORK.CLASSIFIER_TYPE == "1fc":
                self.final_mlp = torch.nn.Sequential(
                    torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                     inplace=False), torch.nn.Linear(dim, 1))
            else:
                raise ValueError("Not support classifier type: {}!".format(
                    config.NETWORK.CLASSIFIER_TYPE))

        # init weights
        self.init_weight()

        self.fix_params()

    def init_weight(self):
        if not self.config.NETWORK.BLIND:
            self.image_feature_extractor.init_weight()
            if self.object_linguistic_embeddings is not None:
                self.object_linguistic_embeddings.weight.data.normal_(mean=0.0,
                                                                      std=0.02)
            if self.enable_cnn_reg_loss and self.cnn_loss_top:
                self.cnn_loss_reg.apply(self.vlbert._module.init_weights)

        if not self.for_pretrain:
            for m in self.final_mlp.modules():
                if isinstance(m, torch.nn.Linear):
                    torch.nn.init.xavier_uniform_(m.weight)
                    torch.nn.init.constant_(m.bias, 0)

    def train(self, mode=True):
        super(ResNetVLBERT, self).train(mode)
        # turn some frozen layers to eval mode
        if (not self.config.NETWORK.BLIND) and self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        if self.config.NETWORK.BLIND:
            self.vlbert._module.visual_scale_text.requires_grad = False
            self.vlbert._module.visual_scale_object.requires_grad = False

    def _collect_obj_reps(self, span_tags, object_reps):
        """
        Collect span-level object representations
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return:
        """

        span_tags_fixed = torch.clamp(
            span_tags, min=0)  # In case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0,
                                          row_id.shape[0],
                                          step=1,
                                          device=row_id.device)[:, None]

        # Add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        return object_reps[row_id.view(-1),
                           span_tags_fixed.view(-1)].view(
                               *span_tags_fixed.shape, -1)

    def prepare_text_from_qa(self, question, question_tags, question_mask,
                             answers, answers_tags, answers_mask):
        batch_size, max_q_len = question.shape
        _, num_choices, max_a_len = answers.shape
        max_len = (question_mask.sum(1) +
                   answers_mask.sum(2).max(1)[0]).max() + 3
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(
            ['[CLS]', '[SEP]'])
        question = question.repeat(1,
                                   num_choices).view(-1, num_choices,
                                                     max_q_len)
        question_mask = question_mask.repeat(1, num_choices).view(
            -1, num_choices, max_q_len)
        q_end = 1 + question_mask.sum(2, keepdim=True)
        a_end = q_end + 1 + answers_mask.sum(2, keepdim=True)
        input_ids = torch.zeros((batch_size, num_choices, max_len),
                                dtype=question.dtype,
                                device=question.device)
        input_mask = torch.ones((batch_size, num_choices, max_len),
                                dtype=torch.uint8,
                                device=question.device)
        input_type_ids = torch.zeros((batch_size, num_choices, max_len),
                                     dtype=question.dtype,
                                     device=question.device)
        text_tags = input_type_ids.new_zeros(
            (batch_size, num_choices, max_len))
        grid_i, grid_j, grid_k = torch.meshgrid(
            torch.arange(batch_size, device=question.device),
            torch.arange(num_choices, device=question.device),
            torch.arange(max_len, device=question.device))

        input_mask[grid_k > a_end] = 0
        input_type_ids[(grid_k > q_end) & (grid_k <= a_end)] = 1
        q_input_mask = (grid_k > 0) & (grid_k < q_end)
        a_input_mask = (grid_k > q_end) & (grid_k < a_end)
        input_ids[:, :, 0] = cls_id
        input_ids[grid_k == q_end] = sep_id
        input_ids[grid_k == a_end] = sep_id
        input_ids[q_input_mask] = question[question_mask]
        input_ids[a_input_mask] = answers[answers_mask]
        text_tags[q_input_mask] = question_tags[question_mask]
        text_tags[a_input_mask] = answers_tags[answers_mask]

        return input_ids, input_type_ids, text_tags, input_mask

    def prepare_text_from_qa_onesent(self, question, question_tags,
                                     question_mask, answers, answers_tags,
                                     answers_mask):
        batch_size, max_q_len = question.shape
        _, num_choices, max_a_len = answers.shape
        max_len = (question_mask.sum(1) +
                   answers_mask.sum(2).max(1)[0]).max() + 2
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(
            ['[CLS]', '[SEP]'])
        question = question.repeat(1,
                                   num_choices).view(-1, num_choices,
                                                     max_q_len)
        question_mask = question_mask.repeat(1, num_choices).view(
            -1, num_choices, max_q_len)
        q_end = 1 + question_mask.sum(2, keepdim=True)
        a_end = q_end + answers_mask.sum(2, keepdim=True)
        input_ids = torch.zeros((batch_size, num_choices, max_len),
                                dtype=question.dtype,
                                device=question.device)
        input_mask = torch.ones((batch_size, num_choices, max_len),
                                dtype=torch.uint8,
                                device=question.device)
        input_type_ids = torch.zeros((batch_size, num_choices, max_len),
                                     dtype=question.dtype,
                                     device=question.device)
        text_tags = input_type_ids.new_zeros(
            (batch_size, num_choices, max_len))
        grid_i, grid_j, grid_k = torch.meshgrid(
            torch.arange(batch_size, device=question.device),
            torch.arange(num_choices, device=question.device),
            torch.arange(max_len, device=question.device))

        input_mask[grid_k > a_end] = 0
        q_input_mask = (grid_k > 0) & (grid_k < q_end)
        a_input_mask = (grid_k >= q_end) & (grid_k < a_end)
        input_ids[:, :, 0] = cls_id
        input_ids[grid_k == a_end] = sep_id
        input_ids[q_input_mask] = question[question_mask]
        input_ids[a_input_mask] = answers[answers_mask]
        text_tags[q_input_mask] = question_tags[question_mask]
        text_tags[a_input_mask] = answers_tags[answers_mask]

        return input_ids, input_type_ids, text_tags, input_mask

    def prepare_text_from_aq(self, question, question_tags, question_mask,
                             answers, answers_tags, answers_mask):
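        """
        Answer-first layout: [CLS] answer [SEP] question [SEP], with
        token_type_ids set to 1 for the question segment. Returns input ids,
        token type ids, grounding tags and the attention mask, each of shape
        (batch_size, num_choices, max_len).
        """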
        batch_size, max_q_len = question.shape
        _, num_choices, max_a_len = answers.shape
        max_len = (question_mask.sum(1) +
                   answers_mask.sum(2).max(1)[0]).max() + 3
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(
            ['[CLS]', '[SEP]'])
        question = question.repeat(1,
                                   num_choices).view(-1, num_choices,
                                                     max_q_len)
        question_mask = question_mask.repeat(1, num_choices).view(
            -1, num_choices, max_q_len)
        a_end = 1 + answers_mask.sum(2, keepdim=True)
        q_end = a_end + 1 + question_mask.sum(2, keepdim=True)
        input_ids = torch.zeros((batch_size, num_choices, max_len),
                                dtype=question.dtype,
                                device=question.device)
        input_mask = torch.ones((batch_size, num_choices, max_len),
                                dtype=torch.uint8,
                                device=question.device)
        input_type_ids = torch.zeros((batch_size, num_choices, max_len),
                                     dtype=question.dtype,
                                     device=question.device)
        text_tags = input_type_ids.new_zeros(
            (batch_size, num_choices, max_len))
        grid_i, grid_j, grid_k = torch.meshgrid(
            torch.arange(batch_size, device=question.device),
            torch.arange(num_choices, device=question.device),
            torch.arange(max_len, device=question.device))

        input_mask[grid_k > q_end] = 0
        input_type_ids[(grid_k > a_end) & (grid_k <= q_end)] = 1
        q_input_mask = (grid_k > a_end) & (grid_k < q_end)
        a_input_mask = (grid_k > 0) & (grid_k < a_end)
        input_ids[:, :, 0] = cls_id
        input_ids[grid_k == a_end] = sep_id
        input_ids[grid_k == q_end] = sep_id
        input_ids[q_input_mask] = question[question_mask]
        input_ids[a_input_mask] = answers[answers_mask]
        text_tags[q_input_mask] = question_tags[question_mask]
        text_tags[a_input_mask] = answers_tags[answers_mask]

        return input_ids, input_type_ids, text_tags, input_mask

    def train_forward(self,
                      image,
                      boxes,
                      masks,
                      question,
                      question_align_matrix,
                      answer_choices,
                      answer_align_matrix,
                      answer_label,
                      im_info,
                      mask_position=None,
                      mask_type=None,
                      mask_label=None):
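        """
        Training forward pass: extract region features for the boxes, build a
        text sequence for every answer choice (order controlled by
        NETWORK.ANSWER_FIRST / NETWORK.QA_ONE_SENT), run VL-BERT, and score
        each choice with final_mlp on the pooled representation. Returns
        (outputs dict, scalar loss).
        """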
        ###########################################

        # visual feature extraction
        images = image
        objects = boxes[:, :, -1]
        segms = masks
        boxes = boxes[:, :, :4]
        box_mask = (boxes[:, :, -1] > -0.5)
        max_len = int(box_mask.sum(1).max().item())
        objects = objects[:, :max_len]
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        segms = segms[:, :max_len]

        if self.config.NETWORK.BLIND:
            obj_reps = {
                'obj_reps':
                boxes.new_zeros(
                    (*boxes.shape[:-1], self.config.NETWORK.IMAGE_FINAL_DIM))
            }
        else:
            obj_reps = self.image_feature_extractor(images=images,
                                                    boxes=boxes,
                                                    box_mask=box_mask,
                                                    im_info=im_info,
                                                    classes=objects,
                                                    segms=segms)

        num_choices = answer_choices.shape[1]
        question_ids = question[:, :, 0]
        question_tags = question[:, :, 1]
        question_tags = question_tags.repeat(1, num_choices).view(
            question_tags.shape[0], num_choices, -1)
        question_mask = (question[:, :, 0] > 0.5)
        answer_ids = answer_choices[:, :, :, 0]
        answer_tags = answer_choices[:, :, :, 1]
        answer_mask = (answer_choices[:, :, :, 0] > 0.5)

        ############################################

        # prepare text
        if self.config.NETWORK.ANSWER_FIRST:
            if self.config.NETWORK.QA_ONE_SENT:
                raise NotImplementedError
            else:
                text_input_ids, text_token_type_ids, text_tags, text_mask = self.prepare_text_from_aq(
                    question_ids, question_tags, question_mask, answer_ids,
                    answer_tags, answer_mask)
        else:
            if self.config.NETWORK.QA_ONE_SENT:
                text_input_ids, text_token_type_ids, text_tags, text_mask = self.prepare_text_from_qa_onesent(
                    question_ids, question_tags, question_mask, answer_ids,
                    answer_tags, answer_mask)
            else:
                text_input_ids, text_token_type_ids, text_tags, text_mask = self.prepare_text_from_qa(
                    question_ids, question_tags, question_mask, answer_ids,
                    answer_tags, answer_mask)

        if self.config.NETWORK.NO_GROUNDING:
            text_tags.zero_()
        text_visual_embeddings = self._collect_obj_reps(
            text_tags, obj_reps['obj_reps'])
        if self.config.NETWORK.BLIND:
            object_linguistic_embeddings = boxes.new_zeros(
                (*boxes.shape[:-1], self.config.NETWORK.VLBERT.hidden_size))
            object_linguistic_embeddings = object_linguistic_embeddings.unsqueeze(
                1).repeat(1, num_choices, 1, 1)
        else:
            if self.config.NETWORK.VLBERT.object_word_embed_mode in [1, 2]:
                object_linguistic_embeddings = self.object_linguistic_embeddings(
                    objects.long().clamp(min=0,
                                         max=self.object_linguistic_embeddings.
                                         weight.data.shape[0] - 1))
                object_linguistic_embeddings = object_linguistic_embeddings.unsqueeze(
                    1).repeat(1, num_choices, 1, 1)
            elif self.config.NETWORK.VLBERT.object_word_embed_mode == 3:
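                # mode 3: use the mean of the text word embeddings (excluding
                # [CLS]/[SEP]) as a shared linguistic embedding for every box.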
                cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(
                    ['[CLS]', '[SEP]'])
                global_context_mask = text_mask & (
                    text_input_ids != cls_id) & (text_input_ids != sep_id)
                word_embedding = self.vlbert._module.word_embeddings(
                    text_input_ids)
                word_embedding[global_context_mask == 0] = 0
                object_linguistic_embeddings = word_embedding.sum(
                    dim=2) / global_context_mask.sum(
                        dim=2, keepdim=True).to(dtype=word_embedding.dtype)
                object_linguistic_embeddings = object_linguistic_embeddings.unsqueeze(
                    2).repeat((1, 1, max_len, 1))
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'].unsqueeze(1).repeat(
                1, num_choices, 1, 1), object_linguistic_embeddings), -1)

        ###########################################

        # Visual Linguistic BERT

        if self.config.NETWORK.NO_OBJ_ATTENTION or self.config.NETWORK.BLIND:
            box_mask.zero_()

        hidden_states_text, hidden_states_objects, pooled_rep = self.vlbert(
            text_input_ids,
            text_token_type_ids,
            text_visual_embeddings,
            text_mask,
            object_vl_embeddings,
            box_mask.unsqueeze(1).repeat(1, num_choices, 1),
            output_all_encoded_layers=False,
            output_text_and_object_separately=True)

        ###########################################
        outputs = {}

        # classifier
        logits = self.final_mlp(pooled_rep).squeeze(2)

        # loss
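        # CLASSIFIER_SIGMOID treats each choice as an independent binary
        # decision: positives are up-weighted by
        # CLASSIFIER_SIGMOID_LOSS_POSITIVE_WEIGHT (= w) and the BCE loss is
        # rescaled by (w + 1) / (2 * w), which keeps the weighted loss roughly
        # on the original scale.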
        if self.config.NETWORK.CLASSIFIER_SIGMOID:
            _, choice_ind = torch.meshgrid(
                torch.arange(logits.shape[0], device=logits.device),
                torch.arange(num_choices, device=logits.device))
            label_binary = (choice_ind == answer_label.unsqueeze(1))
            if mask_type is not None and self.config.NETWORK.REPLACE_OBJECT_CHANGE_LABEL:
                label_binary = label_binary * (mask_type != 1).unsqueeze(1)
            weight = logits.new_zeros(logits.shape).fill_(1.0)
            weight[
                label_binary ==
                1] = self.config.NETWORK.CLASSIFIER_SIGMOID_LOSS_POSITIVE_WEIGHT
            rescale = (self.config.NETWORK.CLASSIFIER_SIGMOID_LOSS_POSITIVE_WEIGHT + 1.0) \
                / (2.0 * self.config.NETWORK.CLASSIFIER_SIGMOID_LOSS_POSITIVE_WEIGHT)
            ans_loss = rescale * F.binary_cross_entropy_with_logits(
                logits, label_binary.to(dtype=logits.dtype), weight=weight)
            outputs['positive_fraction'] = label_binary.to(
                dtype=logits.dtype).sum() / label_binary.numel()
        else:
            ans_loss = F.cross_entropy(logits, answer_label.long().view(-1))

        outputs.update({
            'label_logits': logits,
            'label': answer_label.long().view(-1),
            'ans_loss': ans_loss
        })

        loss = ans_loss.mean() * self.config.NETWORK.ANS_LOSS_WEIGHT

        if mask_position is not None:
            assert False, "Todo: align to original position."
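            # NOTE: unreachable (the assert above always fires); the block
            # below still references hidden_states and origin_len, which are
            # not defined in this scope.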
            _batch_ind = torch.arange(images.shape[0],
                                      dtype=torch.long,
                                      device=images.device)
            mask_pos_rep = hidden_states[_batch_ind, answer_label,
                                         mask_position]
            mask_pred_logits = (
                obj_reps['obj_reps'] @ mask_pos_rep.unsqueeze(-1)).squeeze(-1)
            mask_pred_logits[1 - box_mask] -= 10000.0
            mask_object_loss = F.cross_entropy(mask_pred_logits,
                                               mask_label,
                                               ignore_index=-1)
            logits_padded = mask_pred_logits.new_zeros(
                (mask_pred_logits.shape[0], origin_len)).fill_(-10000.0)
            logits_padded[:, :mask_pred_logits.shape[1]] = mask_pred_logits
            mask_pred_logits = logits_padded
            outputs.update({
                'mask_object_loss': mask_object_loss,
                'mask_object_logits': mask_pred_logits,
                'mask_object_label': mask_label
            })
            loss = loss + mask_object_loss.mean(
            ) * self.config.NETWORK.MASK_OBJECT_LOSS_WEIGHT

        if self.enable_cnn_reg_loss:
            if not self.cnn_loss_top:
                loss = loss + obj_reps['cnn_regularization_loss'].mean(
                ) * self.config.NETWORK.CNN_LOSS_WEIGHT
                outputs['cnn_regularization_loss'] = obj_reps[
                    'cnn_regularization_loss']
            else:
                objects = objects.unsqueeze(1).repeat(1, num_choices, 1)
                box_mask = box_mask.unsqueeze(1).repeat(1, num_choices, 1)
                cnn_reg_logits = self.cnn_loss_reg(
                    hidden_states_objects[box_mask])
                cnn_reg_loss = F.cross_entropy(cnn_reg_logits,
                                               objects[box_mask].long())
                loss = loss + cnn_reg_loss.mean(
                ) * self.config.NETWORK.CNN_LOSS_WEIGHT
                outputs['cnn_regularization_loss'] = cnn_reg_loss

        return outputs, loss

    def inference_forward(self, image, boxes, masks, question,
                          question_align_matrix, answer_choices,
                          answer_align_matrix, *args):
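        """
        Inference forward pass: same text/visual preparation as train_forward,
        but only returns the per-choice logits ('label_logits') and computes
        no loss.
        """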

        if self.for_pretrain:
            answer_label, im_info, mask_position, mask_type = args
        else:
            assert len(args) == 1
            im_info = args[0]

        ###########################################

        # visual feature extraction
        images = image
        objects = boxes[:, :, -1]
        segms = masks
        boxes = boxes[:, :, :4]
        box_mask = (boxes[:, :, -1] > -0.5)
        max_len = int(box_mask.sum(1).max().item())
        objects = objects[:, :max_len]
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        segms = segms[:, :max_len]

        if self.config.NETWORK.BLIND:
            obj_reps = {
                'obj_reps':
                boxes.new_zeros(
                    (*boxes.shape[:-1], self.config.NETWORK.IMAGE_FINAL_DIM))
            }
        else:
            obj_reps = self.image_feature_extractor(images=images,
                                                    boxes=boxes,
                                                    box_mask=box_mask,
                                                    im_info=im_info,
                                                    classes=objects,
                                                    segms=segms)

        num_choices = answer_choices.shape[1]
        question_ids = question[:, :, 0]
        question_tags = question[:, :, 1]
        question_tags = question_tags.repeat(1, num_choices).view(
            question_tags.shape[0], num_choices, -1)
        question_mask = (question[:, :, 0] > 0.5)
        answer_ids = answer_choices[:, :, :, 0]
        answer_tags = answer_choices[:, :, :, 1]
        answer_mask = (answer_choices[:, :, :, 0] > 0.5)

        ############################################

        # prepare text
        if self.config.NETWORK.ANSWER_FIRST:
            if self.config.NETWORK.QA_ONE_SENT:
                raise NotImplementedError
            else:
                text_input_ids, text_token_type_ids, text_tags, text_mask = self.prepare_text_from_aq(
                    question_ids, question_tags, question_mask, answer_ids,
                    answer_tags, answer_mask)
        else:
            if self.config.NETWORK.QA_ONE_SENT:
                text_input_ids, text_token_type_ids, text_tags, text_mask = self.prepare_text_from_qa_onesent(
                    question_ids, question_tags, question_mask, answer_ids,
                    answer_tags, answer_mask)
            else:
                text_input_ids, text_token_type_ids, text_tags, text_mask = self.prepare_text_from_qa(
                    question_ids, question_tags, question_mask, answer_ids,
                    answer_tags, answer_mask)

        if self.config.NETWORK.NO_GROUNDING:
            text_tags.zero_()
        text_visual_embeddings = self._collect_obj_reps(
            text_tags, obj_reps['obj_reps'])
        if self.config.NETWORK.BLIND:
            object_linguistic_embeddings = boxes.new_zeros(
                (*boxes.shape[:-1], self.config.NETWORK.VLBERT.hidden_size))
            object_linguistic_embeddings = object_linguistic_embeddings.unsqueeze(
                1).repeat(1, num_choices, 1, 1)
        else:
            if self.config.NETWORK.VLBERT.object_word_embed_mode in [1, 2]:
                object_linguistic_embeddings = self.object_linguistic_embeddings(
                    objects.long().clamp(min=0,
                                         max=self.object_linguistic_embeddings.
                                         weight.data.shape[0] - 1))
                object_linguistic_embeddings = object_linguistic_embeddings.unsqueeze(
                    1).repeat(1, num_choices, 1, 1)
            elif self.config.NETWORK.VLBERT.object_word_embed_mode == 3:
                cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(
                    ['[CLS]', '[SEP]'])
                global_context_mask = text_mask & (
                    text_input_ids != cls_id) & (text_input_ids != sep_id)
                word_embedding = self.vlbert._module.word_embeddings(
                    text_input_ids)
                word_embedding[global_context_mask == 0] = 0
                object_linguistic_embeddings = word_embedding.sum(
                    dim=2) / global_context_mask.sum(
                        dim=2, keepdim=True).to(dtype=word_embedding.dtype)
                object_linguistic_embeddings = object_linguistic_embeddings.unsqueeze(
                    2).repeat((1, 1, max_len, 1))
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'].unsqueeze(1).repeat(
                1, num_choices, 1, 1), object_linguistic_embeddings), -1)

        ###########################################

        # Visual Linguistic BERT

        if self.config.NETWORK.NO_OBJ_ATTENTION or self.config.NETWORK.BLIND:
            box_mask.zero_()

        hidden_states_text, hidden_states_objects, pooled_rep = self.vlbert(
            text_input_ids,
            text_token_type_ids,
            text_visual_embeddings,
            text_mask,
            object_vl_embeddings,
            box_mask.unsqueeze(1).repeat(1, num_choices, 1),
            output_all_encoded_layers=False,
            output_text_and_object_separately=True)

        ###########################################

        # classifier
        logits = self.final_mlp(pooled_rep).squeeze(2)

        outputs = {'label_logits': logits}

        return outputs
Example No. 21
    def __init__(self, config):

        super(ResNetVLBERT, self).__init__(config)

        self.predict_on_cls = config.NETWORK.VLBERT.predict_on_cls  # make prediction on [CLS]?

        self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS
        if not config.NETWORK.BLIND:
            self.image_feature_extractor = FastRCNN(
                config,
                average_pool=True,
                final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                enable_cnn_reg_loss=self.enable_cnn_reg_loss)
            if config.NETWORK.VLBERT.object_word_embed_mode == 1:
                self.object_linguistic_embeddings = nn.Embedding(
                    81, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 2:  # default: class-agnostic
                self.object_linguistic_embeddings = nn.Embedding(
                    1, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN

        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path)

        dim = config.NETWORK.VLBERT.hidden_size
        if config.NETWORK.CLASSIFIER_TYPE == "2fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(dim, config.NETWORK.CLASSIFIER_HIDDEN_SIZE),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(config.NETWORK.CLASSIFIER_HIDDEN_SIZE,
                                config.DATASET.ANSWER_VOCAB_SIZE),
            )
        elif config.NETWORK.CLASSIFIER_TYPE == "1fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(dim, config.DATASET.ANSWER_VOCAB_SIZE))
        elif config.NETWORK.CLASSIFIER_TYPE == 'mlm':
            transform = BertPredictionHeadTransform(config.NETWORK.VLBERT)
            linear = nn.Linear(config.NETWORK.VLBERT.hidden_size,
                               config.DATASET.ANSWER_VOCAB_SIZE)
            self.final_mlp = nn.Sequential(
                transform,
                nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                linear)
        else:
            raise ValueError("Unsupported classifier type: {}!".format(
                config.NETWORK.CLASSIFIER_TYPE))

        self.use_spatial_model = False
        if config.NETWORK.USE_SPATIAL_MODEL:
            self.use_spatial_model = True
            # self.simple_spatial_model = SimpleSpatialModel(4, config.NETWORK.VLBERT.hidden_size, 9, config)

            self.use_coord_vector = False
            if config.NETWORK.USE_COORD_VECTOR:
                self.use_coord_vector = True
                self.loc_fcs = nn.Sequential(
                    nn.Linear(2 * 5 + 9, config.NETWORK.VLBERT.hidden_size),
                    nn.ReLU(True),
                    nn.Linear(config.NETWORK.VLBERT.hidden_size,
                              config.NETWORK.VLBERT.hidden_size))
            else:
                self.simple_spatial_model = SimpleSpatialModel(
                    4, config.NETWORK.VLBERT.hidden_size, 9)

            self.spa_add = bool(config.NETWORK.SPA_ADD)
            self.spa_concat = bool(config.NETWORK.SPA_CONCAT)

            if self.spa_add:
                self.spa_feat_weight = 0.5
                if config.NETWORK.USE_SPA_WEIGHT:
                    self.spa_feat_weight = config.NETWORK.SPA_FEAT_WEIGHT
                self.spa_fusion_linear = nn.Linear(
                    config.NETWORK.VLBERT.hidden_size,
                    config.NETWORK.VLBERT.hidden_size)
            elif self.spa_concat:
                if self.use_coord_vector:
                    self.spa_fusion_linear = nn.Linear(
                        config.NETWORK.VLBERT.hidden_size +
                        config.NETWORK.VLBERT.hidden_size,
                        config.NETWORK.VLBERT.hidden_size)
                else:
                    self.spa_fusion_linear = nn.Linear(
                        config.NETWORK.VLBERT.hidden_size * 2,
                        config.NETWORK.VLBERT.hidden_size)
            self.spa_linear = nn.Linear(config.NETWORK.VLBERT.hidden_size,
                                        config.NETWORK.VLBERT.hidden_size)
            self.dropout = nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT)

            self.spa_one_more_layer = config.NETWORK.SPA_ONE_MORE_LAYER
            if self.spa_one_more_layer:
                self.spa_linear_hidden = nn.Linear(
                    config.NETWORK.VLBERT.hidden_size,
                    config.NETWORK.VLBERT.hidden_size)

        self.enhanced_img_feature = False
        if config.NETWORK.VLBERT.ENHANCED_IMG_FEATURE:
            self.enhanced_img_feature = True
            self.mask_weight = config.NETWORK.VLBERT.mask_weight
            self.mask_loss_sum = config.NETWORK.VLBERT.mask_loss_sum
            self.mask_loss_mse = config.NETWORK.VLBERT.mask_loss_mse
            self.no_predicate = config.NETWORK.VLBERT.NO_PREDICATE

        self.all_proposals_test = False
        if config.DATASET.ALL_PROPOSALS_TEST:
            self.all_proposals_test = True

        self.use_uvtranse = False
        if config.NETWORK.USE_UVTRANSE:
            self.use_uvtranse = True
            self.union_vec_fc = nn.Linear(config.NETWORK.VLBERT.hidden_size,
                                          config.NETWORK.VLBERT.hidden_size)
            self.uvt_add = bool(config.NETWORK.UVT_ADD)
            self.uvt_concat = bool(config.NETWORK.UVT_CONCAT)
            assert self.uvt_add ^ self.uvt_concat, \
                "Exactly one of NETWORK.UVT_ADD and NETWORK.UVT_CONCAT must be set"
            if self.uvt_add:
                self.uvt_feat_weight = config.NETWORK.UVT_FEAT_WEIGHT
                self.uvt_fusion_linear = nn.Linear(
                    config.NETWORK.VLBERT.hidden_size,
                    config.NETWORK.VLBERT.hidden_size)
            elif self.uvt_concat:
                self.uvt_fusion_linear = nn.Linear(
                    config.NETWORK.VLBERT.hidden_size * 2,
                    config.NETWORK.VLBERT.hidden_size)
            self.uvt_linear = nn.Linear(config.NETWORK.VLBERT.hidden_size,
                                        config.NETWORK.VLBERT.hidden_size)
            self.dropout_uvt = nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT)

        # init weights
        self.init_weight()
class ResNetVLBERTForPretrainingMultitask(Module):
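    """VL-BERT pretraining model that mixes image-text pairs with auxiliary
    text-only samples, combining relationship, MLM and MVRC objectives."""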
    def __init__(self, config):

        super(ResNetVLBERTForPretrainingMultitask, self).__init__(config)

        self.image_feature_extractor = FastRCNN(config,
                                                average_pool=True,
                                                final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                                                enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(1, config.NETWORK.VLBERT.hidden_size)
        if config.NETWORK.IMAGE_FEAT_PRECOMPUTED or (not config.NETWORK.MASK_RAW_PIXELS):
            self.object_mask_visual_embedding = nn.Embedding(1, 2048)
        if config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding = nn.Embedding(1, config.NETWORK.VLBERT.hidden_size)
        self.aux_text_visual_embedding = nn.Embedding(1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(config.NETWORK.BERT_MODEL_NAME)
        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(config.NETWORK.BERT_PRETRAINED,
                                                                      config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path

        if language_pretrained_model_path is None:
            print("Warning: no pretrained language model found, training from scratch!!!")

        self.vlbert = VisualLinguisticBertForPretraining(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=None if config.NETWORK.VLBERT.from_scratch else language_pretrained_model_path,
            with_rel_head=config.NETWORK.WITH_REL_LOSS,
            with_mlm_head=config.NETWORK.WITH_MLM_LOSS,
            with_mvrc_head=config.NETWORK.WITH_MVRC_LOSS,
        )

        # init weights
        self.init_weight()

        self.fix_params()

    def init_weight(self):
        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED or (not self.config.NETWORK.MASK_RAW_PIXELS):
            self.object_mask_visual_embedding.weight.data.fill_(0.0)
        if self.config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding.weight.data.normal_(mean=0.0,
                                                                std=self.config.NETWORK.VLBERT.initializer_range)
        self.aux_text_visual_embedding.weight.data.normal_(mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        self.image_feature_extractor.init_weight()
        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(mean=0.0,
                                                                  std=self.config.NETWORK.VLBERT.initializer_range)

    def train(self, mode=True):
        super(ResNetVLBERTForPretrainingMultitask, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        pass

    def _collect_obj_reps(self, span_tags, object_reps):
        """
        Collect span-level object representations
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return:
        """

        span_tags_fixed = torch.clamp(span_tags, min=0)  # In case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0, row_id.shape[0], step=1, device=row_id.device)[:, None]

        # Add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        return object_reps[row_id.view(-1), span_tags_fixed.view(-1)].view(*span_tags_fixed.shape, -1)

    def forward(self,
                image,
                boxes,
                im_info,
                text,
                relationship_label,
                mlm_labels,
                mvrc_ops,
                mvrc_labels,
                *aux):
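        """
        Multitask pretraining forward pass: image-text pairs plus auxiliary
        text-only samples (passed in *aux as alternating (text, mlm_labels)
        tensors) are packed into one batch; VL-BERT produces relationship,
        MLM and MVRC logits, and the corresponding losses are summed into a
        single scalar.
        """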

        # concat aux texts from different dataset
        assert len(aux) > 0 and len(aux) % 2 == 0
        aux_text_list = aux[0::2]
        aux_text_mlm_labels_list = aux[1::2]
        num_aux_text = sum([_text.shape[0] for _text in aux_text_list])
        max_aux_text_len = max([_text.shape[1] for _text in aux_text_list])
        aux_text = aux_text_list[0].new_zeros((num_aux_text, max_aux_text_len))
        aux_text_mlm_labels = aux_text_mlm_labels_list[0].new_zeros((num_aux_text, max_aux_text_len)).fill_(-1)
        _cur = 0
        for _text, _mlm_labels in zip(aux_text_list, aux_text_mlm_labels_list):
            _num = _text.shape[0]
            aux_text[_cur:(_cur + _num), :_text.shape[1]] = _text
            aux_text_mlm_labels[_cur:(_cur + _num), :_text.shape[1]] = _mlm_labels
            _cur += _num

        ###########################################

        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        origin_len = boxes.shape[1]
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        mvrc_ops = mvrc_ops[:, :max_len]
        mvrc_labels = mvrc_labels[:, :max_len]

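        # With precomputed box features, regions selected for masking
        # (mvrc_ops == 1) have their features replaced by the learned mask
        # embedding before being fed to the feature extractor.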
        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            box_features = boxes[:, :, 4:]
            box_features[mvrc_ops == 1] = self.object_mask_visual_embedding.weight[0]
            boxes[:, :, 4:] = box_features

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None,
                                                mvrc_ops=mvrc_ops,
                                                mask_visual_embed=self.object_mask_visual_embedding.weight[0]
                                                if (not self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED)
                                                   and (not self.config.NETWORK.MASK_RAW_PIXELS)
                                                else None)

        ############################################

        # prepare text
        text_input_ids = text
        text_tags = text.new_zeros(text.shape)
        text_visual_embeddings = self._collect_obj_reps(text_tags, obj_reps['obj_reps'])

        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long()
        )
        if self.config.NETWORK.WITH_MVRC_LOSS:
            object_linguistic_embeddings[mvrc_ops == 1] = self.object_mask_word_embedding.weight[0]
        object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        # add auxiliary text
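        # Text-only auxiliary samples are appended to the batch: they receive
        # the dedicated aux_text_visual_embedding as their "visual" input and
        # an all-zero box mask, so they attend to no image regions.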
        max_text_len = max(text_input_ids.shape[1], aux_text.shape[1])
        text_input_ids_multi = text_input_ids.new_zeros((text_input_ids.shape[0] + aux_text.shape[0], max_text_len))
        text_input_ids_multi[:text_input_ids.shape[0], :text_input_ids.shape[1]] = text_input_ids
        text_input_ids_multi[text_input_ids.shape[0]:, :aux_text.shape[1]] = aux_text
        text_token_type_ids_multi = text_input_ids_multi.new_zeros(text_input_ids_multi.shape)
        text_mask_multi = (text_input_ids_multi > 0)
        text_visual_embeddings_multi = text_visual_embeddings.new_zeros((text_input_ids.shape[0] + aux_text.shape[0],
                                                                         max_text_len,
                                                                         text_visual_embeddings.shape[-1]))
        text_visual_embeddings_multi[:text_visual_embeddings.shape[0], :text_visual_embeddings.shape[1]] \
            = text_visual_embeddings
        text_visual_embeddings_multi[text_visual_embeddings.shape[0]:] = self.aux_text_visual_embedding.weight[0]
        object_vl_embeddings_multi = object_vl_embeddings.new_zeros((text_input_ids.shape[0] + aux_text.shape[0],
                                                                     *object_vl_embeddings.shape[1:]))
        object_vl_embeddings_multi[:object_vl_embeddings.shape[0]] = object_vl_embeddings
        box_mask_multi = box_mask.new_zeros((text_input_ids.shape[0] + aux_text.shape[0], *box_mask.shape[1:]))
        box_mask_multi[:box_mask.shape[0]] = box_mask

        ###########################################

        # Visual Linguistic BERT

        relationship_logits_multi, mlm_logits_multi, mvrc_logits_multi = self.vlbert(text_input_ids_multi,
                                                                                     text_token_type_ids_multi,
                                                                                     text_visual_embeddings_multi,
                                                                                     text_mask_multi,
                                                                                     object_vl_embeddings_multi,
                                                                                     box_mask_multi)

        ###########################################
        outputs = {}

        # loss
        relationship_loss = im_info.new_zeros(())
        mlm_loss_wvc = im_info.new_zeros(())
        mlm_loss_aux = im_info.new_zeros(())
        mvrc_loss = im_info.new_zeros(())
        if self.config.NETWORK.WITH_REL_LOSS:
            relationship_logits = relationship_logits_multi[:text_input_ids.shape[0]]
            relationship_loss = F.cross_entropy(relationship_logits, relationship_label)
        if self.config.NETWORK.WITH_MLM_LOSS:
            mlm_labels_multi = mlm_labels.new_zeros((text_input_ids.shape[0] + aux_text.shape[0], max_text_len)).fill_(
                -1)
            mlm_labels_multi[:text_input_ids.shape[0], :mlm_labels.shape[1]] = mlm_labels
            mlm_labels_multi[text_input_ids.shape[0]:, :aux_text_mlm_labels.shape[1]] = aux_text_mlm_labels

            mlm_logits_multi_padded = \
                mlm_logits_multi.new_zeros((*mlm_labels_multi.shape, mlm_logits_multi.shape[-1])).fill_(-10000.0)
            mlm_logits_multi_padded[:, :mlm_logits_multi.shape[1]] = mlm_logits_multi
            mlm_logits_multi = mlm_logits_multi_padded
            mlm_logits_wvc = mlm_logits_multi_padded[:text_input_ids.shape[0]]
            mlm_labels_wvc = mlm_labels_multi[:text_input_ids.shape[0]]
            mlm_logits_aux = mlm_logits_multi_padded[text_input_ids.shape[0]:]
            mlm_labels_aux = mlm_labels_multi[text_input_ids.shape[0]:]
            if self.config.NETWORK.MLM_LOSS_NORM_IN_BATCH_FIRST:
                mlm_loss_wvc = F.cross_entropy(mlm_logits_wvc.transpose(1, 2),
                                               mlm_labels_wvc,
                                               ignore_index=-1, reduction='none')
                num_mlm_wvc = (mlm_labels_wvc != -1).sum(1, keepdim=True).to(dtype=mlm_loss_wvc.dtype)
                num_has_mlm_wvc = (num_mlm_wvc != 0).sum().to(dtype=mlm_loss_wvc.dtype)
                mlm_loss_wvc = (mlm_loss_wvc / (num_mlm_wvc + 1e-4)).sum() / (num_has_mlm_wvc + 1e-4)
                mlm_loss_aux = F.cross_entropy(mlm_logits_aux.transpose(1, 2),
                                               mlm_labels_aux,
                                               ignore_index=-1, reduction='none')
                num_mlm_aux = (mlm_labels_aux != -1).sum(1, keepdim=True).to(dtype=mlm_loss_aux.dtype)
                num_has_mlm_aux = (num_mlm_aux != 0).sum().to(dtype=mlm_loss_aux.dtype)
                mlm_loss_aux = (mlm_loss_aux / (num_mlm_aux + 1e-4)).sum() / (num_has_mlm_aux + 1e-4)
            else:
                # mlm_loss = F.cross_entropy(mlm_logits_multi_padded.view((-1, mlm_logits_multi_padded.shape[-1])),
                #                            mlm_labels_multi.view(-1),
                #                            ignore_index=-1)
                mlm_loss_wvc = F.cross_entropy(
                    mlm_logits_wvc.view((-1, mlm_logits_multi_padded.shape[-1])),
                    mlm_labels_wvc.view(-1),
                    ignore_index=-1
                )
                mlm_loss_aux = F.cross_entropy(
                    mlm_logits_aux.view((-1, mlm_logits_multi_padded.shape[-1])),
                    mlm_labels_aux.view(-1),
                    ignore_index=-1
                )

        # mvrc_loss = F.cross_entropy(mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
        #                             mvrc_labels.contiguous().view(-1),
        #                             ignore_index=-1)
        if self.config.NETWORK.WITH_MVRC_LOSS:
            mvrc_logits = mvrc_logits_multi[:mvrc_labels.shape[0], :mvrc_labels.shape[1]]
            if self.config.NETWORK.MVRC_LOSS_NORM_IN_BATCH_FIRST:
                mvrc_loss = soft_cross_entropy(
                    mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                    mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]),
                    reduction='none').view(mvrc_logits.shape[:-1])
                valid = (mvrc_labels.sum(-1) - 1).abs() < 1.0e-1
                mvrc_loss = (mvrc_loss / (valid.sum(1, keepdim=True).to(dtype=mvrc_loss.dtype) + 1e-4)) \
                                .sum() / ((valid.sum(1) != 0).sum().to(dtype=mvrc_loss.dtype) + 1e-4)
            else:
                mvrc_loss = soft_cross_entropy(mvrc_logits.contiguous().view(-1, mvrc_logits.shape[-1]),
                                               mvrc_labels.contiguous().view(-1, mvrc_logits.shape[-1]))

            mvrc_logits_padded = mvrc_logits.new_zeros((mvrc_logits.shape[0], origin_len, mvrc_logits.shape[2])).fill_(
                -10000.0)
            mvrc_logits_padded[:, :mvrc_logits.shape[1]] = mvrc_logits
            mvrc_logits = mvrc_logits_padded
            mvrc_labels_padded = mvrc_labels.new_zeros((mvrc_labels.shape[0], origin_len, mvrc_labels.shape[2])).fill_(
                0.0)
            mvrc_labels_padded[:, :mvrc_labels.shape[1]] = mvrc_labels
            mvrc_labels = mvrc_labels_padded

        outputs.update({
            'relationship_logits': relationship_logits if self.config.NETWORK.WITH_REL_LOSS else None,
            'relationship_label': relationship_label if self.config.NETWORK.WITH_REL_LOSS else None,
            'mlm_logits_wvc': mlm_logits_wvc if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_label_wvc': mlm_labels_wvc if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_logits_aux': mlm_logits_aux if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mlm_label_aux': mlm_labels_aux if self.config.NETWORK.WITH_MLM_LOSS else None,
            'mvrc_logits': mvrc_logits if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'mvrc_label': mvrc_labels if self.config.NETWORK.WITH_MVRC_LOSS else None,
            'relationship_loss': relationship_loss,
            'mlm_loss_wvc': mlm_loss_wvc,
            'mlm_loss_aux': mlm_loss_aux,
            'mvrc_loss': mvrc_loss,
        })

        loss = relationship_loss.mean() + mlm_loss_wvc.mean() + mlm_loss_aux.mean() + mvrc_loss.mean()

        return outputs, loss
Example No. 23
class ResNetVLBERTForAttentionVis(Module):
    def __init__(self, config):

        super(ResNetVLBERTForAttentionVis, self).__init__(config)

        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=config.NETWORK.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        if config.NETWORK.IMAGE_FEAT_PRECOMPUTED or (
                not config.NETWORK.MASK_RAW_PIXELS):
            self.object_mask_visual_embedding = nn.Embedding(1, 2048)
        if config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding = nn.Embedding(
                1, config.NETWORK.VLBERT.hidden_size)
        self.aux_text_visual_embedding = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)
        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path

        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=None
            if config.NETWORK.VLBERT.from_scratch else
            language_pretrained_model_path)

        # init weights
        self.init_weight()

        self.fix_params()

    def init_weight(self):
        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED or (
                not self.config.NETWORK.MASK_RAW_PIXELS):
            self.object_mask_visual_embedding.weight.data.fill_(0.0)
        if self.config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        self.aux_text_visual_embedding.weight.data.normal_(
            mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        self.image_feature_extractor.init_weight()
        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)

    def train(self, mode=True):
        super(ResNetVLBERTForAttentionVis, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        pass

    def _collect_obj_reps(self, span_tags, object_reps):
        """
        Collect span-level object representations
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return:
        """

        span_tags_fixed = torch.clamp(
            span_tags, min=0)  # In case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0,
                                          row_id.shape[0],
                                          step=1,
                                          device=row_id.device)[:, None]

        # Add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        return object_reps[row_id.view(-1),
                           span_tags_fixed.view(-1)].view(
                               *span_tags_fixed.shape, -1)

    def forward(self, image, boxes, im_info, text, relationship_label,
                mlm_labels, mvrc_ops, mvrc_labels, *aux):
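        """
        Runs VL-BERT with output_all_encoded_layers and output_attention_probs
        enabled and returns the stacked per-layer hidden states and attention
        probabilities for visualisation; no loss is computed.
        """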

        # concat aux texts from different dataset
        # assert len(aux) > 0 and len(aux) % 2 == 0
        aux_text_list = aux[0::2]
        aux_text_mlm_labels_list = aux[1::2]
        num_aux_text = sum([_text.shape[0] for _text in aux_text_list])
        max_aux_text_len = max([_text.shape[1] for _text in aux_text_list
                                ]) if len(aux_text_list) > 0 else 0
        aux_text = text.new_zeros((num_aux_text, max_aux_text_len))
        aux_text_mlm_labels = mlm_labels.new_zeros(
            (num_aux_text, max_aux_text_len)).fill_(-1)
        _cur = 0
        for _text, _mlm_labels in zip(aux_text_list, aux_text_mlm_labels_list):
            _num = _text.shape[0]
            aux_text[_cur:(_cur + _num), :_text.shape[1]] = _text
            aux_text_mlm_labels[_cur:(_cur +
                                      _num), :_text.shape[1]] = _mlm_labels
            _cur += _num

        ###########################################

        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        origin_len = boxes.shape[1]
        max_len = int(box_mask.sum(1).max().item())
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        mvrc_ops = mvrc_ops[:, :max_len]
        mvrc_labels = mvrc_labels[:, :max_len]

        if self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED:
            box_features = boxes[:, :, 4:]
            box_features[mvrc_ops ==
                         1] = self.object_mask_visual_embedding.weight[0]
            boxes[:, :, 4:] = box_features

        obj_reps = self.image_feature_extractor(
            images=images,
            boxes=boxes,
            box_mask=box_mask,
            im_info=im_info,
            classes=None,
            segms=None,
            mvrc_ops=mvrc_ops,
            mask_visual_embed=self.object_mask_visual_embedding.weight[0] if
            (not self.config.NETWORK.IMAGE_FEAT_PRECOMPUTED) and
            (not self.config.NETWORK.MASK_RAW_PIXELS) else None)

        ############################################

        # prepare text
        text_input_ids = text
        text_tags = text.new_zeros(text.shape)
        text_visual_embeddings = self._collect_obj_reps(
            text_tags, obj_reps['obj_reps'])

        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        if self.config.NETWORK.WITH_MVRC_LOSS:
            object_linguistic_embeddings[
                mvrc_ops == 1] = self.object_mask_word_embedding.weight[0]
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        # add auxiliary text
        max_text_len = max(text_input_ids.shape[1], aux_text.shape[1])
        text_input_ids_multi = text_input_ids.new_zeros(
            (text_input_ids.shape[0] + aux_text.shape[0], max_text_len))
        text_input_ids_multi[:text_input_ids.shape[0], :text_input_ids.
                             shape[1]] = text_input_ids
        text_input_ids_multi[
            text_input_ids.shape[0]:, :aux_text.shape[1]] = aux_text
        text_token_type_ids_multi = text_input_ids_multi.new_zeros(
            text_input_ids_multi.shape)
        text_mask_multi = (text_input_ids_multi > 0)
        text_visual_embeddings_multi = text_visual_embeddings.new_zeros(
            (text_input_ids.shape[0] + aux_text.shape[0], max_text_len,
             text_visual_embeddings.shape[-1]))
        text_visual_embeddings_multi[:text_visual_embeddings.shape[0], :text_visual_embeddings.shape[1]] \
            = text_visual_embeddings
        text_visual_embeddings_multi[
            text_visual_embeddings.
            shape[0]:] = self.aux_text_visual_embedding.weight[0]
        object_vl_embeddings_multi = object_vl_embeddings.new_zeros(
            (text_input_ids.shape[0] + aux_text.shape[0],
             *object_vl_embeddings.shape[1:]))
        object_vl_embeddings_multi[:object_vl_embeddings.
                                   shape[0]] = object_vl_embeddings
        box_mask_multi = box_mask.new_zeros(
            (text_input_ids.shape[0] + aux_text.shape[0], *box_mask.shape[1:]))
        box_mask_multi[:box_mask.shape[0]] = box_mask

        ###########################################

        # Visual Linguistic BERT

        encoder_layers, _, attention_probs = self.vlbert(
            text_input_ids_multi,
            text_token_type_ids_multi,
            text_visual_embeddings_multi,
            text_mask_multi,
            object_vl_embeddings_multi,
            box_mask_multi,
            output_all_encoded_layers=True,
            output_attention_probs=True)
        hidden_states = torch.stack(encoder_layers,
                                    dim=0).transpose(0, 1).contiguous()
        attention_probs = torch.stack(attention_probs,
                                      dim=0).transpose(0, 1).contiguous()

        return {
            'attention_probs': attention_probs,
            'hidden_states': hidden_states
        }
    def __init__(self, config):

        super(ResNetVLBERT, self).__init__(config)

        self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS
        if not config.NETWORK.BLIND:
            self.image_feature_extractor = FastRCNN(config,
                                                    average_pool=True,
                                                    final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                                                    enable_cnn_reg_loss=self.enable_cnn_reg_loss)
            if config.NETWORK.VLBERT.object_word_embed_mode == 1:
                self.object_linguistic_embeddings = nn.Embedding(81, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 2:
                self.object_linguistic_embeddings = nn.Embedding(1, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN

        self.tokenizer = BertTokenizer.from_pretrained(config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(config.NETWORK.BERT_PRETRAINED,
                                                                      config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print("Warning: no pretrained language model found, training from scratch!!!")

        # Also pass the finetuning strategy
        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path,
            finetune_strategy=config.FINETUNE_STRATEGY)

        # self.hm_out = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size)
        # self.hi_out = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size)

        dim = config.NETWORK.VLBERT.hidden_size
        if config.NETWORK.CLASSIFIER_TYPE == "2fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(dim, config.NETWORK.CLASSIFIER_HIDDEN_SIZE),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(config.NETWORK.CLASSIFIER_HIDDEN_SIZE, config.DATASET.ANSWER_VOCAB_SIZE),
            )
        elif config.NETWORK.CLASSIFIER_TYPE == "1fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(dim, config.DATASET.ANSWER_VOCAB_SIZE)
            )
        elif config.NETWORK.CLASSIFIER_TYPE == 'mlm':
            transform = BertPredictionHeadTransform(config.NETWORK.VLBERT)
            linear = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.DATASET.ANSWER_VOCAB_SIZE)
            self.final_mlp = nn.Sequential(
                transform,
                nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                linear
            )
        else:
            raise ValueError("Unsupported classifier type: {}!".format(config.NETWORK.CLASSIFIER_TYPE))

        # init weights
        self.init_weight()

        self.fix_params()
Example No. 25
class ResNetVLBERT(Module):
    def __init__(self, config):

        super(ResNetVLBERT, self).__init__(config)
        self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS
        self.cnn_loss_top = config.NETWORK.CNN_LOSS_TOP
        if not config.NETWORK.BLIND:
            self.image_feature_extractor = FastRCNN(config,
                                                    average_pool=True,
                                                    final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                                                    enable_cnn_reg_loss=(self.enable_cnn_reg_loss and not self.cnn_loss_top))
            if config.NETWORK.VLBERT.object_word_embed_mode == 1:
                self.object_linguistic_embeddings = nn.Embedding(81, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 2:
                self.object_linguistic_embeddings = nn.Embedding(1, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError

        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN

        if 'roberta' in config.NETWORK.BERT_MODEL_NAME:
            self.tokenizer = RobertaTokenizer.from_pretrained(config.NETWORK.BERT_MODEL_NAME)
        else:
            self.tokenizer = BertTokenizer.from_pretrained(config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(config.NETWORK.BERT_PRETRAINED,
                                                                      config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path

        if language_pretrained_model_path is None:
            print("Warning: no pretrained language model found, training from scratch!!!")

        self.vlbert = VisualLinguisticBert(config.NETWORK.VLBERT,
                                           language_pretrained_model_path=language_pretrained_model_path)
        
        self.for_pretrain = False
        dim = config.NETWORK.VLBERT.hidden_size
        if config.NETWORK.SENTENCE.CLASSIFIER_TYPE == "2fc":
            self.sentence_cls = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.SENTENCE.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(dim, config.NETWORK.SENTENCE.CLASSIFIER_HIDDEN_SIZE),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(config.NETWORK.SENTENCE.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(config.NETWORK.SENTENCE.CLASSIFIER_HIDDEN_SIZE, 3),
            )
        elif config.NETWORK.SENTENCE.CLASSIFIER_TYPE == "1fc":
            self.sentence_cls = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.SENTENCE.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(dim, 3)
            )
        else:
            raise ValueError("Classifier type: {} not supported!".format(config.NETWORK.SENTENCE.CLASSIFIER_TYPE))

        # init weights
        self.init_weight()

        self.fix_params()
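
    # Hedged summary (added, not from the original repo) of the config switches
    # read by this constructor; names come from the code above, values shown are
    # only illustrative:
    #   NETWORK.BLIND                             - True: drop the visual stream entirely
    #   NETWORK.ENABLE_CNN_REG_LOSS, CNN_LOSS_TOP - control the auxiliary CNN regression loss
    #   NETWORK.IMAGE_FINAL_DIM                   - dimension of the pooled box features
    #   NETWORK.VLBERT.object_word_embed_mode     - 1: per-class embeddings (81 entries),
    #                                               2: one shared embedding, 3: none
    #   NETWORK.BERT_MODEL_NAME, BERT_PRETRAINED(_EPOCH) - where the language weights come from
    #   NETWORK.SENTENCE.CLASSIFIER_TYPE          - "2fc" or "1fc", always a 3-way output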

    def init_weight(self):
        if not self.config.NETWORK.BLIND:
            self.image_feature_extractor.init_weight()
            if self.object_linguistic_embeddings is not None:
                self.object_linguistic_embeddings.weight.data.normal_(mean=0.0, std=0.02)

        if not self.for_pretrain:
            for m in self.sentence_cls.modules():
                if isinstance(m, torch.nn.Linear):
                    torch.nn.init.xavier_uniform_(m.weight)
                    torch.nn.init.constant_(m.bias, 0)

    def train(self, mode=True):
        super(ResNetVLBERT, self).train(mode)
        # turn some frozen layers to eval mode
        if (not self.config.NETWORK.BLIND) and self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):
        if self.config.NETWORK.BLIND:
            # self.vlbert is the bare VisualLinguisticBert module here (it is not
            # wrapped in TimeDistributed), so its parameters are addressed directly
            self.vlbert.visual_scale_text.requires_grad = False
            self.vlbert.visual_scale_object.requires_grad = False

    def _collect_obj_reps(self, span_tags, object_reps):
        """
        Collect span-level object representations
        :param span_tags: [batch_size, ..leading_dims.., L]
        :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim]
        :return:
        """

        span_tags_fixed = torch.clamp(span_tags, min=0)  # In case there were masked values here
        row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape)
        row_id_broadcaster = torch.arange(0, row_id.shape[0], step=1, device=row_id.device)[:, None]

        # Add extra dimensions to the row broadcaster so it matches row_id
        leading_dims = len(span_tags.shape) - 2
        for i in range(leading_dims):
            row_id_broadcaster = row_id_broadcaster[..., None]
        row_id += row_id_broadcaster
        return object_reps[row_id.view(-1), span_tags_fixed.view(-1)].view(*span_tags_fixed.shape, -1)
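
    # Hedged usage sketch (toy values, not from the original repo), assuming
    # `model` is an instance of this class:
    #
    #   >>> object_reps = torch.tensor([[[0., 0.], [1., 1.], [2., 2.]]])  # [B=1, O=3, D=2]
    #   >>> span_tags = torch.tensor([[2, 0, -1]])                        # [B=1, L=3], -1 = untagged
    #   >>> model._collect_obj_reps(span_tags, object_reps)
    #   tensor([[[2., 2.], [0., 0.], [0., 0.]]])
    #
    # Each text position receives the feature of the object it is tagged with;
    # negative tags are clamped to 0 and fall back to the first object slot.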

    def prepare_text(self, sentence, mask):
        batch_size, max_len = sentence.shape
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]'])
        sep_pos = 1 + mask.sum(1)  # position right after the last real token
        input_ids = torch.zeros((batch_size, max_len + 2), dtype=sentence.dtype, device=sentence.device)
        input_ids[:, 0] = cls_id
        input_ids[:, 1:-1] = sentence
        # write [SEP] after copying the sentence so trailing padding cannot overwrite it
        _batch_inds = torch.arange(batch_size, device=sentence.device)
        input_ids[_batch_inds, sep_pos] = sep_id
        input_mask = input_ids > 0
        return input_ids, input_mask
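
    # Hedged example of the layout produced above (with a standard BERT uncased
    # vocab, [CLS]=101 and [SEP]=102; not from the original repo):
    #
    #   sentence   = [[1037, 2158, 3773, 0, 0]]         # 3 real tokens, 2 padding
    #   mask       = sentence > 0.5                     # [[1, 1, 1, 0, 0]]
    #   input_ids  = [[101, 1037, 2158, 3773, 102, 0, 0]]
    #   input_mask = [[1, 1, 1, 1, 1, 0, 0]]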

    def train_forward(self,
                      images,
                      boxes,
                      hypothesis,
                      im_info,
                      label):
        ###########################################
        # visual feature extraction

        # Don't know what segments are for
        # segms = masks
        
        box_mask = (boxes[:, :, -1] > - 0.5)
        max_len = int(box_mask.sum(1).max().item())

        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len].type(torch.float32)

        # segms = segms[:, :max_len]
        if self.config.NETWORK.BLIND:
            obj_reps = {'obj_reps': boxes.new_zeros((*boxes.shape[:-1], self.config.NETWORK.IMAGE_FINAL_DIM))}
        else:
            obj_reps = self.image_feature_extractor(images=images,
                                                    boxes=boxes,
                                                    box_mask=box_mask,
                                                    im_info=im_info,
                                                    classes=None,
                                                    segms=None)

        # For now no tags
        mask = (hypothesis > 0.5)
        sentence_label = label.view(-1)


        ############################################
        
        # prepare text
        text_input_ids, text_mask = self.prepare_text(hypothesis, mask)
        text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)

        # Add visual feature to text elements
        text_visual_embeddings = self._collect_obj_reps(text_input_ids.new_zeros(text_input_ids.size()),
                                                        obj_reps['obj_reps'])
        # Add textual feature to image element
        if self.config.NETWORK.BLIND:
            object_linguistic_embeddings = boxes.new_zeros((*boxes.shape[:-1], self.config.NETWORK.VLBERT.hidden_size))
        else:
            object_linguistic_embeddings = self.object_linguistic_embeddings(
                boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################

        # Visual Linguistic BERT
        if self.config.NETWORK.NO_OBJ_ATTENTION or self.config.NETWORK.BLIND:
            box_mask.zero_()

        _, pooled_rep = self.vlbert(text_input_ids,
                                    text_token_type_ids,
                                    text_visual_embeddings,
                                    text_mask,
                                    object_vl_embeddings,
                                    box_mask,
                                    output_all_encoded_layers=False,
                                    output_text_and_object_separately=False,
                                    output_attention_probs=False)

        ###########################################
        outputs = {}
        
        # sentence classification
        sentence_logits = self.sentence_cls(pooled_rep).view((-1, 3))
        sentence_cls_loss = F.cross_entropy(sentence_logits, sentence_label)

        outputs.update({'sentence_label_logits': sentence_logits,
                        'sentence_label': sentence_label.long(),
                        'sentence_cls_loss': sentence_cls_loss})

        loss = sentence_cls_loss.mean()

        return outputs, loss
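
    # Hedged shape summary for train_forward (inferred from the code above, not
    # stated in the original repo):
    #   images:     [B, 3, H, W] raw images consumed by FastRCNN (unused when BLIND)
    #   boxes:      [B, max_boxes, C], padded with negative values; boxes[:, :, -1] > -0.5
    #               marks the valid boxes
    #   hypothesis: [B, L] token ids, 0 = padding
    #   im_info:    per-image size/scale information forwarded to the box feature extractor
    #   label:      [B] integers in {0, 1, 2} (3-way entailment label)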

    def inference_forward(self,
                          images,
                          boxes,
                          hypothesis,
                          im_info):
        ###########################################
        # visual feature extraction

        # Don't know what segments are for
        # segms = masks

        box_mask = (boxes[:, :, -1] > - 0.5)
        max_len = int(box_mask.sum(1).max().item())

        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len].type(torch.float32)

        # segms = segms[:, :max_len]
        if self.config.NETWORK.BLIND:
            obj_reps = {'obj_reps': boxes.new_zeros((*boxes.shape[:-1], self.config.NETWORK.IMAGE_FINAL_DIM))}
        else:
            obj_reps = self.image_feature_extractor(images=images,
                                                    boxes=boxes,
                                                    box_mask=box_mask,
                                                    im_info=im_info,
                                                    classes=None,
                                                    segms=None)

        # For now no tags
        mask = (hypothesis > 0.5)

        ############################################

        # prepare text
        text_input_ids, text_mask = self.prepare_text(hypothesis, mask)
        text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)

        # Add visual feature to text elements
        text_visual_embeddings = self._collect_obj_reps(text_input_ids.new_zeros(text_input_ids.size()),
                                                        obj_reps['obj_reps'])
        # Add textual feature to image element
        if self.config.NETWORK.BLIND:
            object_linguistic_embeddings = boxes.new_zeros((*boxes.shape[:-1], self.config.NETWORK.VLBERT.hidden_size))
        else:
            object_linguistic_embeddings = self.object_linguistic_embeddings(
                boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        object_vl_embeddings = torch.cat((obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################

        # Visual Linguistic BERT
        if self.config.NETWORK.NO_OBJ_ATTENTION or self.config.NETWORK.BLIND:
            box_mask.zero_()

        _, pooled_rep = self.vlbert(text_input_ids,
                                    text_token_type_ids,
                                    text_visual_embeddings,
                                    text_mask,
                                    object_vl_embeddings,
                                    box_mask,
                                    output_all_encoded_layers=False,
                                    output_text_and_object_separately=False,
                                    output_attention_probs=False)

        ###########################################
        outputs = {}

        # sentence classification
        sentence_logits = self.sentence_cls(pooled_rep).view((-1, 3))

        outputs.update({'sentence_label_logits': sentence_logits})

        return outputs
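
# Hedged usage sketch (not part of the original repo): turning the 3-way logits
# returned by inference_forward into label strings. The label ordering below is
# an assumption; check the dataset definition actually used with this model.
import torch

ASSUMED_LABELS = ['contradiction', 'neutral', 'entailment']

def decode_predictions(outputs):
    logits = outputs['sentence_label_logits']   # [batch_size, 3]
    probs = torch.softmax(logits, dim=-1)
    pred_ids = probs.argmax(dim=-1)             # [batch_size]
    return [ASSUMED_LABELS[i] for i in pred_ids.tolist()], probs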