Example #1
    def __init__(self, config):

        super(ResNetVLBERT, self).__init__(config)

        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=config.NETWORK.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path)

        transform = VisualLinguisticBertMVRCHeadTransform(
            config.NETWORK.VLBERT)
        # OIM head over 12003 identity classes with 768-d embeddings
        # (other class counts tried during development: 331, 1000, 35, 100)
        self.OIM_loss = OIM_Module(12003, 768)
        self.linear = nn.Sequential(
            nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
            nn.Linear(config.NETWORK.VLBERT.hidden_size, 768),
        )

        linear = nn.Linear(config.NETWORK.VLBERT.hidden_size, 1)
        self.final_mlp = nn.Sequential(
            transform,
            nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
            linear)

        # init weights
        self.init_weight()

        self.fix_params()
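
The pretrained-checkpoint lookup above reappears verbatim in every variant below: an explicit BERT_PRETRAINED prefix plus a zero-padded epoch wins, otherwise a weights file inside a local model directory is used, and None means training from scratch. A minimal standalone sketch of the same resolution logic (the value of BERT_WEIGHTS_NAME is an assumption; the repo imports it as a constant):

import os

BERT_WEIGHTS_NAME = 'pytorch_model.bin'  # assumed value

def resolve_language_pretrained_path(bert_pretrained, bert_pretrained_epoch,
                                     bert_model_name):
    # explicit checkpoint prefix + zero-padded epoch, e.g. 'ckpt-0003.model'
    if bert_pretrained != '':
        return '{}-{:04d}.model'.format(bert_pretrained, bert_pretrained_epoch)
    # otherwise look for a weights file inside a local model directory
    if os.path.isdir(bert_model_name):
        weight_path = os.path.join(bert_model_name, BERT_WEIGHTS_NAME)
        if os.path.isfile(weight_path):
            return weight_path
    return None  # caller prints a warning and trains from scratch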
Example #2
    def __init__(self, config):

        super(ResNetVLBERT, self).__init__(config)

        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=config.NETWORK.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path)

        self.task1_head = Task1Head(config.NETWORK.VLBERT)
        self.task2_head = Task2Head(config.NETWORK.VLBERT)
        self.task3_head = Task3Head(config.NETWORK.VLBERT)

        # init weights
        self.init_weight()

        self.fix_params()
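
Task1Head, Task2Head, and Task3Head are defined elsewhere in the repo; Example #9 below shows how they are consumed (the [CLS]-level output feeds F.binary_cross_entropy, so it must emit probabilities). A hypothetical sketch of such a sentence-level head, consistent with that usage but not taken from the repo:

import torch
import torch.nn as nn

class Task1Head(nn.Module):  # hypothetical reconstruction, not the repo's code
    """Sentence-level binary head: maps the [CLS] vector to a probability."""

    def __init__(self, vlbert_config):
        super().__init__()
        self.fc = nn.Linear(vlbert_config.hidden_size, 1)

    def forward(self, cls_rep):
        # sigmoid so the output can feed F.binary_cross_entropy directly
        return torch.sigmoid(self.fc(cls_rep)).squeeze(-1)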
Example #3
    def __init__(self, config):

        super(ResNetVLBERTForAttentionVis, self).__init__(config)

        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=config.NETWORK.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        if config.NETWORK.IMAGE_FEAT_PRECOMPUTED or (
                not config.NETWORK.MASK_RAW_PIXELS):
            self.object_mask_visual_embedding = nn.Embedding(1, 2048)
        if config.NETWORK.WITH_MVRC_LOSS:
            self.object_mask_word_embedding = nn.Embedding(
                1, config.NETWORK.VLBERT.hidden_size)
        self.aux_text_visual_embedding = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)
        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path

        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=(
                None if config.NETWORK.VLBERT.from_scratch
                else language_pretrained_model_path))

        # init weights
        self.init_weight()

        self.fix_params()
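
A reading of the one-entry embeddings in this variant: they act as learned placeholder vectors. object_mask_visual_embedding stands in for the 2048-d RoI feature of a masked region when features are precomputed or raw pixels are not masked, object_mask_word_embedding replaces the linguistic embedding of a masked object when the MVRC (masked visual region classification) loss is enabled, and aux_text_visual_embedding is paired with auxiliary text-only samples that carry no real image features.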
Example #4
    def __init__(self, config):

        super(ResNetVLBERT, self).__init__(config)

        self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS
        if not config.NETWORK.BLIND:
            self.image_feature_extractor = FastRCNN(config,
                                                    average_pool=True,
                                                    final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                                                    enable_cnn_reg_loss=self.enable_cnn_reg_loss)
            if config.NETWORK.VLBERT.object_word_embed_mode == 1:
                self.object_linguistic_embeddings = nn.Embedding(81, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 2:
                self.object_linguistic_embeddings = nn.Embedding(1, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN

        self.tokenizer = BertTokenizer.from_pretrained(config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(config.NETWORK.BERT_PRETRAINED,
                                                                      config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print("Warning: no pretrained language model found, training from scratch!!!")

        # Also pass the finetuning strategy
        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path,
            finetune_strategy=config.FINETUNE_STRATEGY)

        # self.hm_out = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size)
        # self.hi_out = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size)

        dim = config.NETWORK.VLBERT.hidden_size
        if config.NETWORK.CLASSIFIER_TYPE == "2fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(dim, config.NETWORK.CLASSIFIER_HIDDEN_SIZE),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(config.NETWORK.CLASSIFIER_HIDDEN_SIZE, config.DATASET.ANSWER_VOCAB_SIZE),
            )
        elif config.NETWORK.CLASSIFIER_TYPE == "1fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(dim, config.DATASET.ANSWER_VOCAB_SIZE)
            )
        elif config.NETWORK.CLASSIFIER_TYPE == 'mlm':
            transform = BertPredictionHeadTransform(config.NETWORK.VLBERT)
            linear = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.DATASET.ANSWER_VOCAB_SIZE)
            self.final_mlp = nn.Sequential(
                transform,
                nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                linear
            )
        else:
            raise ValueError("Not support classifier type: {}!".format(config.NETWORK.CLASSIFIER_TYPE))

        # init weights
        self.init_weight()

        self.fix_params()
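
All three classifier types map the pooled hidden state to the answer vocabulary: "2fc" is a two-layer MLP with ReLU and dropout, "1fc" is a single dropout-plus-linear projection, and 'mlm' reuses BERT's masked-LM BertPredictionHeadTransform so answers are scored the way vocabulary tokens are scored during pretraining.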
Example #5
    def __init__(self, config):

        super(ResNetVLBERT, self).__init__(config)

        self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS
        self.cnn_loss_top = config.NETWORK.CNN_LOSS_TOP
        if not config.NETWORK.BLIND:
            self.image_feature_extractor = FastRCNN(
                config,
                average_pool=True,
                final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                enable_cnn_reg_loss=(self.enable_cnn_reg_loss
                                     and not self.cnn_loss_top))
            if config.NETWORK.VLBERT.object_word_embed_mode == 1:
                self.object_linguistic_embeddings = nn.Embedding(
                    81, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 2:
                self.object_linguistic_embeddings = nn.Embedding(
                    1, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError
            if self.enable_cnn_reg_loss and self.cnn_loss_top:
                self.cnn_loss_reg = nn.Sequential(
                    VisualLinguisticBertMVRCHeadTransform(
                        config.NETWORK.VLBERT),
                    nn.Dropout(config.NETWORK.CNN_REG_DROPOUT, inplace=False),
                    nn.Linear(config.NETWORK.VLBERT.hidden_size, 81))
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN

        if 'roberta' in config.NETWORK.BERT_MODEL_NAME:
            self.tokenizer = RobertaTokenizer.from_pretrained(
                config.NETWORK.BERT_MODEL_NAME)
        else:
            self.tokenizer = BertTokenizer.from_pretrained(
                config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path

        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = TimeDistributed(
            VisualLinguisticBert(
                config.NETWORK.VLBERT,
                language_pretrained_model_path=language_pretrained_model_path))

        self.for_pretrain = config.NETWORK.FOR_MASK_VL_MODELING_PRETRAIN
        assert not self.for_pretrain, "Pretrain mode is not implemented yet!"

        if not self.for_pretrain:
            dim = config.NETWORK.VLBERT.hidden_size
            if config.NETWORK.CLASSIFIER_TYPE == "2fc":
                self.final_mlp = torch.nn.Sequential(
                    torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(dim,
                                    config.NETWORK.CLASSIFIER_HIDDEN_SIZE),
                    torch.nn.ReLU(inplace=True),
                    torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(config.NETWORK.CLASSIFIER_HIDDEN_SIZE, 1),
                )
            elif config.NETWORK.CLASSIFIER_TYPE == "1fc":
                self.final_mlp = torch.nn.Sequential(
                    torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                     inplace=False), torch.nn.Linear(dim, 1))
            else:
                raise ValueError("Not support classifier type: {}!".format(
                    config.NETWORK.CLASSIFIER_TYPE))

        # init weights
        self.init_weight()

        self.fix_params()
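
Unlike the other variants, this one wraps VL-BERT in TimeDistributed, which lets a single encoder score every answer choice of a multiple-choice sample in one fused batch. A minimal sketch of the core idea (the repo's version also handles multiple inputs and outputs):

import torch.nn as nn

class TimeDistributed(nn.Module):
    """Fold (batch, time, ...) into the batch dim, apply the wrapped module
    once, then unfold the result back to (batch, time, ...)."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        b, t = x.shape[:2]
        y = self.module(x.reshape(b * t, *x.shape[2:]))
        return y.reshape(b, t, *y.shape[1:])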
Example #6
    def __init__(self, config):

        super(ResNetVLBERT, self).__init__(config)
        self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS
        self.cnn_loss_top = config.NETWORK.CNN_LOSS_TOP
        self.align_caption_img = config.DATASET.ALIGN_CAPTION_IMG
        self.use_phrasal_paraphrases = config.DATASET.PHRASE_CLS
        self.supervise_attention = config.NETWORK.SUPERVISE_ATTENTION
        self.normalization = config.NETWORK.ATTENTION_NORM_METHOD
        self.ewc_reg = config.NETWORK.EWC_REG
        self.importance_hparam = 0.
        if config.NETWORK.EWC_REG:
            self.fisher = pickle.load(open(config.NETWORK.FISHER_PATH, "rb"))
            self.pretrain_param = torch.load(config.NETWORK.PARAM_PRETRAIN)
            self.importance_hparam = config.NETWORK.EWC_IMPORTANCE
        if not config.NETWORK.BLIND:
            self.image_feature_extractor = FastRCNN(
                config,
                average_pool=True,
                final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                enable_cnn_reg_loss=(self.enable_cnn_reg_loss
                                     and not self.cnn_loss_top))
            if config.NETWORK.VLBERT.object_word_embed_mode == 1:
                self.object_linguistic_embeddings = nn.Embedding(
                    81, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 2:
                self.object_linguistic_embeddings = nn.Embedding(
                    1, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError

        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN

        if 'roberta' in config.NETWORK.BERT_MODEL_NAME:
            self.tokenizer = RobertaTokenizer.from_pretrained(
                config.NETWORK.BERT_MODEL_NAME)
        else:
            self.tokenizer = BertTokenizer.from_pretrained(
                config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path

        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path)

        self.for_pretrain = False
        dim = config.NETWORK.VLBERT.hidden_size
        if self.align_caption_img:
            sentence_logits_shape = 3
        else:
            sentence_logits_shape = 1
        if config.NETWORK.SENTENCE.CLASSIFIER_TYPE == "2fc":
            self.sentence_cls = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.SENTENCE.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(
                    dim, config.NETWORK.SENTENCE.CLASSIFIER_HIDDEN_SIZE),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(config.NETWORK.SENTENCE.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(config.NETWORK.SENTENCE.CLASSIFIER_HIDDEN_SIZE,
                                sentence_logits_shape),
            )
        elif config.NETWORK.SENTENCE.CLASSIFIER_TYPE == "1fc":
            self.sentence_cls = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.SENTENCE.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(dim, sentence_logits_shape))
        else:
            raise ValueError("Classifier type: {} not supported!".format(
                config.NETWORK.SENTENCE.CLASSIFIER_TYPE))

        if self.use_phrasal_paraphrases:
            if config.NETWORK.PHRASE.CLASSIFIER_TYPE == "2fc":
                self.phrasal_cls = torch.nn.Sequential(
                    torch.nn.Dropout(config.NETWORK.PHRASE.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(
                        4 * dim, config.NETWORK.PHRASE.CLASSIFIER_HIDDEN_SIZE),
                    torch.nn.ReLU(inplace=True),
                    torch.nn.Dropout(config.NETWORK.PHRASE.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(
                        config.NETWORK.PHRASE.CLASSIFIER_HIDDEN_SIZE, 5),
                )
            elif config.NETWORK.PHRASE.CLASSIFIER_TYPE == "1fc":
                self.phrasal_cls = torch.nn.Sequential(
                    torch.nn.Dropout(config.NETWORK.PHRASE.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(4 * dim, 5))
            else:
                raise ValueError("Classifier type: {} not supported!".format(
                    config.NETWORK.PHRASE.CLASSIFIER_TYPE))

        if self.supervise_attention == "indirect":
            if config.NETWORK.VG.CLASSIFIER_TYPE == "2fc":
                self.vg_cls = torch.nn.Sequential(
                    torch.nn.Dropout(config.NETWORK.VG.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(2 * dim,
                                    config.NETWORK.VG.CLASSIFIER_HIDDEN_SIZE),
                    torch.nn.ReLU(inplace=True),
                    torch.nn.Dropout(config.NETWORK.VG.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(config.NETWORK.VG.CLASSIFIER_HIDDEN_SIZE,
                                    1),
                )
            elif config.NETWORK.VG.CLASSIFIER_TYPE == "1fc":
                self.vg_cls = torch.nn.Sequential(
                    torch.nn.Dropout(config.NETWORK.VG.CLASSIFIER_DROPOUT,
                                     inplace=False),
                    torch.nn.Linear(2 * dim, 1))
            else:
                raise ValueError("Classifier type: {} not supported!".format(
                    config.NETWORK.PHRASE.CLASSIFIER_TYPE))

        # init weights
        self.init_weight()

        self.fix_params()
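
phrasal_cls consumes a 4 * dim feature for a pair of phrase representations. One common construction that matches this shape is concatenating the two vectors with their element-wise product and absolute difference; whether this repo uses exactly these four terms is an assumption:

import torch

def fuse_phrase_pair(a, b):
    # (batch, dim) x 2 -> (batch, 4 * dim), matching the phrasal_cls input
    return torch.cat([a, b, a * b, (a - b).abs()], dim=-1)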
Example #7
    def __init__(self, config):

        super(ResNetVLBERT, self).__init__(config)

        self.predict_on_cls = config.NETWORK.VLBERT.predict_on_cls  # make prediction on [CLS]?

        self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS
        if not config.NETWORK.BLIND:
            self.image_feature_extractor = FastRCNN(
                config,
                average_pool=True,
                final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                enable_cnn_reg_loss=self.enable_cnn_reg_loss)
            if config.NETWORK.VLBERT.object_word_embed_mode == 1:
                self.object_linguistic_embeddings = nn.Embedding(
                    81, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 2:  # default: class-agnostic
                self.object_linguistic_embeddings = nn.Embedding(
                    1, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN

        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path)

        dim = config.NETWORK.VLBERT.hidden_size
        if config.NETWORK.CLASSIFIER_TYPE == "2fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(dim, config.NETWORK.CLASSIFIER_HIDDEN_SIZE),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(config.NETWORK.CLASSIFIER_HIDDEN_SIZE,
                                config.DATASET.ANSWER_VOCAB_SIZE),
            )
        elif config.NETWORK.CLASSIFIER_TYPE == "1fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(dim, config.DATASET.ANSWER_VOCAB_SIZE))
        elif config.NETWORK.CLASSIFIER_TYPE == 'mlm':
            transform = BertPredictionHeadTransform(config.NETWORK.VLBERT)
            linear = nn.Linear(config.NETWORK.VLBERT.hidden_size,
                               config.DATASET.ANSWER_VOCAB_SIZE)
            self.final_mlp = nn.Sequential(
                transform,
                nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                linear)
        else:
            raise ValueError("Not support classifier type: {}!".format(
                config.NETWORK.CLASSIFIER_TYPE))

        self.use_spatial_model = False
        if config.NETWORK.USE_SPATIAL_MODEL:
            self.use_spatial_model = True
            # self.simple_spatial_model = SimpleSpatialModel(4, config.NETWORK.VLBERT.hidden_size, 9, config)

            self.use_coord_vector = False
            if config.NETWORK.USE_COORD_VECTOR:
                self.use_coord_vector = True
                self.loc_fcs = nn.Sequential(
                    nn.Linear(2 * 5 + 9, config.NETWORK.VLBERT.hidden_size),
                    nn.ReLU(True),
                    nn.Linear(config.NETWORK.VLBERT.hidden_size,
                              config.NETWORK.VLBERT.hidden_size))
            else:
                self.simple_spatial_model = SimpleSpatialModel(
                    4, config.NETWORK.VLBERT.hidden_size, 9)

            self.spa_add = bool(config.NETWORK.SPA_ADD)
            self.spa_concat = bool(config.NETWORK.SPA_CONCAT)

            if self.spa_add:
                self.spa_feat_weight = 0.5
                if config.NETWORK.USE_SPA_WEIGHT:
                    self.spa_feat_weight = config.NETWORK.SPA_FEAT_WEIGHT
                self.spa_fusion_linear = nn.Linear(
                    config.NETWORK.VLBERT.hidden_size,
                    config.NETWORK.VLBERT.hidden_size)
            elif self.spa_concat:
                if self.use_coord_vector:
                    self.spa_fusion_linear = nn.Linear(
                        config.NETWORK.VLBERT.hidden_size +
                        config.NETWORK.VLBERT.hidden_size,
                        config.NETWORK.VLBERT.hidden_size)
                else:
                    self.spa_fusion_linear = nn.Linear(
                        config.NETWORK.VLBERT.hidden_size * 2,
                        config.NETWORK.VLBERT.hidden_size)
            self.spa_linear = nn.Linear(config.NETWORK.VLBERT.hidden_size,
                                        config.NETWORK.VLBERT.hidden_size)
            self.dropout = nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT)

            self.spa_one_more_layer = config.NETWORK.SPA_ONE_MORE_LAYER
            if self.spa_one_more_layer:
                self.spa_linear_hidden = nn.Linear(
                    config.NETWORK.VLBERT.hidden_size,
                    config.NETWORK.VLBERT.hidden_size)

        self.enhanced_img_feature = False
        if config.NETWORK.VLBERT.ENHANCED_IMG_FEATURE:
            self.enhanced_img_feature = True
            self.mask_weight = config.NETWORK.VLBERT.mask_weight
            self.mask_loss_sum = config.NETWORK.VLBERT.mask_loss_sum
            self.mask_loss_mse = config.NETWORK.VLBERT.mask_loss_mse
            self.no_predicate = config.NETWORK.VLBERT.NO_PREDICATE

        self.all_proposals_test = False
        if config.DATASET.ALL_PROPOSALS_TEST:
            self.all_proposals_test = True

        self.use_uvtranse = False
        if config.NETWORK.USE_UVTRANSE:
            self.use_uvtranse = True
            self.union_vec_fc = nn.Linear(config.NETWORK.VLBERT.hidden_size,
                                          config.NETWORK.VLBERT.hidden_size)
            self.uvt_add = bool(config.NETWORK.UVT_ADD)
            self.uvt_concat = bool(config.NETWORK.UVT_CONCAT)
            # exactly one of the two fusion modes must be enabled
            assert self.uvt_add ^ self.uvt_concat, \
                "Set exactly one of NETWORK.UVT_ADD and NETWORK.UVT_CONCAT"
            if self.uvt_add:
                self.uvt_feat_weight = config.NETWORK.UVT_FEAT_WEIGHT
                self.uvt_fusion_linear = nn.Linear(
                    config.NETWORK.VLBERT.hidden_size,
                    config.NETWORK.VLBERT.hidden_size)
            elif self.uvt_concat:
                self.uvt_fusion_linear = nn.Linear(
                    config.NETWORK.VLBERT.hidden_size * 2,
                    config.NETWORK.VLBERT.hidden_size)
            self.uvt_linear = nn.Linear(config.NETWORK.VLBERT.hidden_size,
                                        config.NETWORK.VLBERT.hidden_size)
            self.dropout_uvt = nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT)

        # init weights
        self.init_weight()
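
The SPA_ADD / SPA_CONCAT flags select how the spatial feature is merged with the VL-BERT feature, and the layer shapes above pin down the two modes: add uses a Linear(hidden, hidden) on a weighted sum, concat a Linear(2 * hidden, hidden) on the joined vector. A sketch of what the fusion step likely does (the forward pass itself is not shown in this snippet):

import torch

def fuse_spatial(vl_feat, spa_feat, fusion_linear, spa_add, spa_feat_weight=0.5):
    if spa_add:
        # weighted sum, then project; mirrors spa_fusion_linear(hidden -> hidden)
        mixed = (1 - spa_feat_weight) * vl_feat + spa_feat_weight * spa_feat
        return fusion_linear(mixed)
    # concat mode; mirrors spa_fusion_linear(2 * hidden -> hidden)
    return fusion_linear(torch.cat([vl_feat, spa_feat], dim=-1))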
Example #8
    def __init__(self, config):

        super(ResNetVLBERTv5, self).__init__(config)

        self.enable_cnn_reg_loss = config.NETWORK.ENABLE_CNN_REG_LOSS
        if not config.NETWORK.BLIND:
            self.image_feature_extractor = FastRCNN(
                config,
                average_pool=True,
                final_dim=config.NETWORK.IMAGE_FINAL_DIM,
                enable_cnn_reg_loss=self.enable_cnn_reg_loss)
            if config.NETWORK.VLBERT.object_word_embed_mode == 1:
                self.object_linguistic_embeddings = nn.Embedding(
                    601, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 2:
                self.object_linguistic_embeddings = nn.Embedding(
                    1, config.NETWORK.VLBERT.hidden_size)
            elif config.NETWORK.VLBERT.object_word_embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN

        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path)

        # self.hm_out = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size)
        # self.hi_out = nn.Linear(config.NETWORK.VLBERT.hidden_size, config.NETWORK.VLBERT.hidden_size)

        self.hidden_dropout = nn.Dropout(0.2)
        if config.NETWORK.VLBERT.num_hidden_layers == 24:
            self.gating = nn.Parameter(torch.tensor([
                0.0067, 0.0070, 0.0075, 0.0075, 0.0075, 0.0074, 0.0076, 0.0075,
                0.0076, 0.0080, 0.0079, 0.0086, 0.0096, 0.0101, 0.0104, 0.0105,
                0.0111, 0.0120, 0.0126, 0.0115, 0.0108, 0.0105, 0.0104, 0.0117
            ]),
                                       requires_grad=True)
        else:
            self.gating = nn.Parameter(
                torch.ones(config.NETWORK.VLBERT.num_hidden_layers, ) * 1e-2,
                requires_grad=True)
        self.train_steps = 0

        dim = config.NETWORK.VLBERT.hidden_size
        if config.NETWORK.CLASSIFIER_TYPE == "2fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(dim, config.NETWORK.CLASSIFIER_HIDDEN_SIZE),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(config.NETWORK.CLASSIFIER_HIDDEN_SIZE,
                                config.NETWORK.CLASSIFIER_CLASS),
            )
        elif config.NETWORK.CLASSIFIER_TYPE == "1fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT,
                                 inplace=False),
                torch.nn.Linear(dim, config.NETWORK.CLASSIFIER_CLASS))
        elif config.NETWORK.CLASSIFIER_TYPE == 'mlm':
            transform = BertPredictionHeadTransform(config.NETWORK.VLBERT)
            linear = nn.Linear(config.NETWORK.VLBERT.hidden_size,
                               config.NETWORK.CLASSIFIER_CLASS)
            self.final_mlp = nn.Sequential(
                transform,
                nn.Dropout(config.NETWORK.CLASSIFIER_DROPOUT, inplace=False),
                linear)
        else:
            raise ValueError("Not support classifier type: {}!".format(
                config.NETWORK.CLASSIFIER_TYPE))

        # init weights
        self.init_weight()

        self.fix_params()
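
The per-layer gating parameter (one scalar per transformer layer, initialized to tuned values for the 24-layer model and to 1e-2 otherwise) is sized to weight the stack of encoded layers. One plausible use is a gated sum over all hidden layers; whether the repo normalizes the gate before summing is an assumption:

import torch

def gated_layer_pool(all_layers, gating):
    # all_layers: (num_layers, batch, seq, dim); gating: (num_layers,)
    return torch.einsum('l,lbsd->bsd', gating, all_layers)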
Example #9
class ResNetVLBERT(Module):
    def __init__(self, config):

        super(ResNetVLBERT, self).__init__(config)

        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=config.NETWORK.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=False)
        self.object_linguistic_embeddings = nn.Embedding(
            1, config.NETWORK.VLBERT.hidden_size)
        self.image_feature_bn_eval = config.NETWORK.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(
            config.NETWORK.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if config.NETWORK.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                config.NETWORK.BERT_PRETRAINED,
                config.NETWORK.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(config.NETWORK.BERT_MODEL_NAME):
            weight_path = os.path.join(config.NETWORK.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            config.NETWORK.VLBERT,
            language_pretrained_model_path=language_pretrained_model_path)

        self.task1_head = Task1Head(config.NETWORK.VLBERT)
        self.task2_head = Task2Head(config.NETWORK.VLBERT)
        self.task3_head = Task3Head(config.NETWORK.VLBERT)

        # init weights
        self.init_weight()

        self.fix_params()

    def init_weight(self):
        self.image_feature_extractor.init_weight()

        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        for m in self.task1_head.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                torch.nn.init.constant_(m.bias, 0)

        for m in self.task2_head.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                torch.nn.init.constant_(m.bias, 0)

        for m in self.task3_head.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                torch.nn.init.constant_(m.bias, 0)

    def train(self, mode=True):
        super(ResNetVLBERT, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()

    def fix_params(self):

        for param in self.image_feature_extractor.parameters():
            param.requires_grad = False

        for param in self.vlbert.parameters():
            param.requires_grad = False

        for param in self.object_linguistic_embeddings.parameters():
            param.requires_grad = False

    def train_forward(self, image, boxes, im_info, expression, label, pos,
                      target, mask):
        ###########################################

        if self.vlbert.training:
            self.vlbert.eval()

        if self.image_feature_extractor.training:
            self.image_feature_extractor.eval()

        # visual feature extraction
        images = image
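        # padded boxes carry negative coordinates, so x1 > -1.5 selects real boxes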
        box_mask = (boxes[:, :, 0] > -1.5)
        max_len = int(box_mask.sum(1).max().item())
        origin_len = boxes.shape[1]
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        # FOIL labels are binary and one-dimensional, so no per-box slicing:
        # label = label[:, :max_len]

        obj_reps = self.image_feature_extractor(images=images,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None)

        ############################################
        # prepare text
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(
            ['[CLS]', '[SEP]'])
        text_input_ids = expression.new_zeros(
            (expression.shape[0], expression.shape[1] + 2))
        text_input_ids[:, 0] = cls_id
        text_input_ids[:, 1:-1] = expression
        _sep_pos = (text_input_ids > 0).sum(1)
        _batch_inds = torch.arange(expression.shape[0],
                                   device=expression.device)
        text_input_ids[_batch_inds, _sep_pos] = sep_id
        text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
        text_mask = text_input_ids > 0
        text_visual_embeddings = obj_reps['obj_reps'][:, 0].unsqueeze(
            1).repeat((1, text_input_ids.shape[1], 1))

        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

        ###########################################

        # Visual Linguistic BERT

        sequence_reps, object_reps, _ = self.vlbert(
            text_input_ids,
            text_token_type_ids,
            text_visual_embeddings,
            text_mask,
            object_vl_embeddings,
            box_mask,
            output_text_and_object_separately=True,
            output_all_encoded_layers=False)

        cls_rep = sequence_reps[:, 0, :]
        sequence_reps = sequence_reps[:, 1:, :]

        cls_log_probs = self.task1_head(cls_rep)

        pos_log_probs = self.task2_head(sequence_reps, text_mask[:, 1:])

        # boolean mask marking the error span, shifted by 1 to skip [CLS]
        mask_len = mask.shape[1]
        error_cor_mask = torch.zeros_like(text_input_ids)
        error_cor_mask[:, 1:mask_len + 1] = mask
        error_cor_mask = error_cor_mask.bool()
        text_input_ids[error_cor_mask] = self.tokenizer.convert_tokens_to_ids(
            ['[MASK]'])[0]

        sequence_reps, _, _ = self.vlbert(
            text_input_ids,
            text_token_type_ids,
            text_visual_embeddings,
            text_mask,
            object_vl_embeddings,
            box_mask,
            output_text_and_object_separately=True,
            output_all_encoded_layers=False)

        sequence_reps = sequence_reps[:, 1:, :]
        select_index = pos.view(-1, 1).unsqueeze(2).repeat(
            1, 1, sequence_reps.shape[-1])
        select_index[select_index < 0] = 0
        masked_reps = torch.gather(sequence_reps, 1, select_index).squeeze(1)
        cor_log_probs = self.task3_head(masked_reps)

        loss_mask = label.view(-1).float()

        cls_loss = F.binary_cross_entropy(cls_log_probs.view(-1),
                                          label.view(-1).float(),
                                          reduction="none")
        pos_loss = F.nll_loss(
            pos_log_probs, pos, ignore_index=-1,
            reduction="none").view(-1) * loss_mask
        cor_loss = F.nll_loss(cor_log_probs.view(-1, cor_log_probs.shape[-1]),
                              target,
                              ignore_index=0,
                              reduction="none").view(-1) * loss_mask

        loss = cls_loss.mean() + pos_loss.mean() + cor_loss.mean()

        outputs = {
            "cls_logits": cls_log_probs,
            "pos_logits": pos_log_probs,
            "cor_logits": cor_log_probs,
            "cls_label": label,
            "pos_label": pos,
            "cor_label": target
        }

        return outputs, loss

    def inference_forward(self, image, boxes, im_info, expression, label, pos,
                          target, mask):

        # visual feature extraction
        images = image
        box_mask = (boxes[:, :, 0] > -1.5)
        max_len = int(box_mask.sum(1).max().item())
        origin_len = boxes.shape[1]
        box_mask = box_mask[:, :max_len]
        boxes = boxes[:, :max_len]
        # FOIL labels are binary and one-dimensional, so no per-box slicing:
        # label = label[:, :max_len]
        with torch.no_grad():
            obj_reps = self.image_feature_extractor(images=images,
                                                    boxes=boxes,
                                                    box_mask=box_mask,
                                                    im_info=im_info,
                                                    classes=None,
                                                    segms=None)

            ############################################
            # prepare text
            cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(
                ['[CLS]', '[SEP]'])
            text_input_ids = expression.new_zeros(
                (expression.shape[0], expression.shape[1] + 2))
            text_input_ids[:, 0] = cls_id
            text_input_ids[:, 1:-1] = expression
            _sep_pos = (text_input_ids > 0).sum(1)
            _batch_inds = torch.arange(expression.shape[0],
                                       device=expression.device)
            text_input_ids[_batch_inds, _sep_pos] = sep_id
            text_token_type_ids = text_input_ids.new_zeros(
                text_input_ids.shape)
            text_mask = text_input_ids > 0
            text_visual_embeddings = obj_reps['obj_reps'][:, 0].unsqueeze(
                1).repeat((1, text_input_ids.shape[1], 1))

            object_linguistic_embeddings = self.object_linguistic_embeddings(
                boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
            object_vl_embeddings = torch.cat(
                (obj_reps['obj_reps'], object_linguistic_embeddings), -1)

            ###########################################

            # Visual Linguistic BERT

            sequence_reps, object_reps, _ = self.vlbert(
                text_input_ids,
                text_token_type_ids,
                text_visual_embeddings,
                text_mask,
                object_vl_embeddings,
                box_mask,
                output_text_and_object_separately=True,
                output_all_encoded_layers=False)

        cls_rep = sequence_reps[:, 0, :]
        sequence_reps = sequence_reps[:, 1:, :]

        cls_log_probs = self.task1_head(cls_rep)

        pos_log_probs = self.task2_head(sequence_reps, text_mask[:, 1:])

        # boolean mask marking the error span, shifted by 1 to skip [CLS]
        mask_len = mask.shape[1]
        error_cor_mask = torch.zeros_like(text_input_ids)
        error_cor_mask[:, 1:mask_len + 1] = mask
        error_cor_mask = error_cor_mask.bool()
        text_input_ids[error_cor_mask] = self.tokenizer.convert_tokens_to_ids(
            ['[MASK]'])[0]

        with torch.no_grad():
            sequence_reps, _, _ = self.vlbert(
                text_input_ids,
                text_token_type_ids,
                text_visual_embeddings,
                text_mask,
                object_vl_embeddings,
                box_mask,
                output_text_and_object_separately=True,
                output_all_encoded_layers=False)

        sequence_reps = sequence_reps[:, 1:, :]
        select_index = pos.view(-1, 1).unsqueeze(2).repeat(
            1, 1, sequence_reps.shape[-1])
        select_index[select_index < 0] = 0
        masked_reps = torch.gather(sequence_reps, 1, select_index).squeeze(1)
        cor_log_probs = self.task3_head(masked_reps)

        loss_mask = label.view(-1).float()

        cls_loss = F.binary_cross_entropy(cls_log_probs.view(-1),
                                          label.view(-1).float(),
                                          reduction="none")
        pos_loss = F.nll_loss(
            pos_log_probs, pos, ignore_index=-1,
            reduction="none").view(-1) * loss_mask
        cor_loss = F.nll_loss(cor_log_probs.view(-1, cor_log_probs.shape[-1]),
                              target,
                              ignore_index=0,
                              reduction="none").view(-1) * loss_mask

        loss = cls_loss.mean() + pos_loss.mean() + cor_loss.mean()

        outputs = {
            "cls_logits": cls_log_probs,
            "pos_logits": pos_log_probs,
            "cor_logits": cor_log_probs,
            "cls_label": label,
            "pos_label": pos,
            "cor_label": target
        }

        return outputs, loss
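
The position gather used in both forward passes selects, for every sample, the representation of the token to be corrected; positions below zero (no error) are clamped to index 0 and their loss is zeroed out by loss_mask. The same trick on toy shapes:

import torch

sequence_reps = torch.randn(2, 5, 4)   # (batch, seq, hidden)
pos = torch.tensor([3, -1])            # -1 marks "no error" for that sample

select_index = pos.view(-1, 1).unsqueeze(2).repeat(1, 1, sequence_reps.shape[-1])
select_index[select_index < 0] = 0     # dummy index; loss is masked anyway
masked_reps = torch.gather(sequence_reps, 1, select_index).squeeze(1)  # (2, 4)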