def __init__(self, config):
        """Build every sub-module of the network from ``config``.

        The constructor reads ``config.NETWORK``, ``config.NETWORK.VLBERT`` and
        ``config.DATASET`` and creates, in this order:

        * the image feature extractor and object linguistic embeddings
          (skipped entirely when ``config.NETWORK.BLIND`` is truthy);
        * the BERT tokenizer and the ``VisualLinguisticBert`` backbone;
        * the classifier head ``self.final_mlp`` (``"2fc"``, ``"1fc"`` or
          ``'mlm'``);
        * the optional spatial layers (``config.NETWORK.USE_SPATIAL_MODEL``);
        * the optional enhanced-image-feature settings
          (``config.NETWORK.VLBERT.ENHANCED_IMG_FEATURE``);
        * the optional union-vector layers (``config.NETWORK.USE_UVTRANSE``).

        Attributes that belong to an optional feature are only set when that
        feature is enabled. Finally ``self.init_weight()`` is called.

        Args:
            config: Project configuration object exposing the ``NETWORK`` and
                ``DATASET`` sections read below.

        Raises:
            NotImplementedError: If the image branch is enabled and
                ``object_word_embed_mode`` is not 1, 2 or 3.
            ValueError: If ``CLASSIFIER_TYPE`` is not one of the supported
                head types.
            AssertionError: If ``USE_UVTRANSE`` is set and ``UVT_ADD`` /
                ``UVT_CONCAT`` are not exactly one truthy value.
        """
        super(ResNetVLBERT, self).__init__(config)

        net_cfg = config.NETWORK
        vlbert_cfg = net_cfg.VLBERT

        # Whether the prediction is made on the [CLS] token.
        self.predict_on_cls = vlbert_cfg.predict_on_cls

        # --- image branch -------------------------------------------------
        self.enable_cnn_reg_loss = net_cfg.ENABLE_CNN_REG_LOSS
        if not net_cfg.BLIND:
            self.image_feature_extractor = FastRCNN(
                config,
                average_pool=True,
                final_dim=net_cfg.IMAGE_FINAL_DIM,
                enable_cnn_reg_loss=self.enable_cnn_reg_loss)
            if vlbert_cfg.object_word_embed_mode == 1:
                self.object_linguistic_embeddings = nn.Embedding(
                    81, vlbert_cfg.hidden_size)
            elif vlbert_cfg.object_word_embed_mode == 2:  # default: class-agnostic
                self.object_linguistic_embeddings = nn.Embedding(
                    1, vlbert_cfg.hidden_size)
            elif vlbert_cfg.object_word_embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError
        self.image_feature_bn_eval = net_cfg.IMAGE_FROZEN_BN

        # --- language branch ----------------------------------------------
        self.tokenizer = BertTokenizer.from_pretrained(net_cfg.BERT_MODEL_NAME)

        # An explicit BERT_PRETRAINED prefix takes priority; otherwise fall
        # back to a weights file inside the BERT_MODEL_NAME directory.
        language_pretrained_model_path = None
        if net_cfg.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                net_cfg.BERT_PRETRAINED,
                net_cfg.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(net_cfg.BERT_MODEL_NAME):
            weight_path = os.path.join(net_cfg.BERT_MODEL_NAME,
                                       BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            vlbert_cfg,
            language_pretrained_model_path=language_pretrained_model_path)

        # --- classifier head ----------------------------------------------
        dim = vlbert_cfg.hidden_size
        if net_cfg.CLASSIFIER_TYPE == "2fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(net_cfg.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(dim, net_cfg.CLASSIFIER_HIDDEN_SIZE),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(net_cfg.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(net_cfg.CLASSIFIER_HIDDEN_SIZE,
                                config.DATASET.ANSWER_VOCAB_SIZE),
            )
        elif net_cfg.CLASSIFIER_TYPE == "1fc":
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Dropout(net_cfg.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(dim, config.DATASET.ANSWER_VOCAB_SIZE))
        elif net_cfg.CLASSIFIER_TYPE == 'mlm':
            transform = BertPredictionHeadTransform(vlbert_cfg)
            linear = nn.Linear(dim, config.DATASET.ANSWER_VOCAB_SIZE)
            self.final_mlp = nn.Sequential(
                transform,
                nn.Dropout(net_cfg.CLASSIFIER_DROPOUT, inplace=False),
                linear)
        else:
            raise ValueError("Not support classifier type: {}!".format(
                net_cfg.CLASSIFIER_TYPE))

        # --- optional spatial layers --------------------------------------
        self.use_spatial_model = bool(net_cfg.USE_SPATIAL_MODEL)
        if self.use_spatial_model:
            self.use_coord_vector = bool(net_cfg.USE_COORD_VECTOR)
            if self.use_coord_vector:
                # Input width is 2 * 5 + 9 = 19; presumably two 5-d box
                # descriptors plus a 9-d relation vector -- TODO confirm.
                self.loc_fcs = nn.Sequential(
                    nn.Linear(2 * 5 + 9, dim),
                    nn.ReLU(True),
                    nn.Linear(dim, dim))
            else:
                self.simple_spatial_model = SimpleSpatialModel(4, dim, 9)

            self.spa_add = bool(net_cfg.SPA_ADD)
            self.spa_concat = bool(net_cfg.SPA_CONCAT)

            # NOTE(review): when neither SPA_ADD nor SPA_CONCAT is truthy no
            # ``spa_fusion_linear`` is created -- confirm the forward pass
            # tolerates that configuration.
            if self.spa_add:
                self.spa_feat_weight = 0.5
                if net_cfg.USE_SPA_WEIGHT:
                    self.spa_feat_weight = net_cfg.SPA_FEAT_WEIGHT
                self.spa_fusion_linear = nn.Linear(dim, dim)
            elif self.spa_concat:
                # Same 2 * hidden -> hidden projection whether or not the
                # coordinate vector is used.
                self.spa_fusion_linear = nn.Linear(dim * 2, dim)
            self.spa_linear = nn.Linear(dim, dim)
            self.dropout = nn.Dropout(net_cfg.CLASSIFIER_DROPOUT)

            self.spa_one_more_layer = net_cfg.SPA_ONE_MORE_LAYER
            if self.spa_one_more_layer:
                self.spa_linear_hidden = nn.Linear(dim, dim)

        # --- optional enhanced image feature settings ---------------------
        self.enhanced_img_feature = bool(vlbert_cfg.ENHANCED_IMG_FEATURE)
        if self.enhanced_img_feature:
            self.mask_weight = vlbert_cfg.mask_weight
            self.mask_loss_sum = vlbert_cfg.mask_loss_sum
            self.mask_loss_mse = vlbert_cfg.mask_loss_mse
            self.no_predicate = vlbert_cfg.NO_PREDICATE

        self.all_proposals_test = bool(config.DATASET.ALL_PROPOSALS_TEST)

        # --- optional union-vector (USE_UVTRANSE) layers ------------------
        self.use_uvtranse = bool(net_cfg.USE_UVTRANSE)
        if self.use_uvtranse:
            self.union_vec_fc = nn.Linear(dim, dim)
            self.uvt_add = bool(net_cfg.UVT_ADD)
            self.uvt_concat = bool(net_cfg.UVT_CONCAT)
            # Raised explicitly instead of ``assert False`` so the check is
            # not stripped when Python runs with -O.
            if self.uvt_add == self.uvt_concat:
                raise AssertionError(
                    "Exactly one of NETWORK.UVT_ADD and NETWORK.UVT_CONCAT "
                    "must be enabled when NETWORK.USE_UVTRANSE is set")
            if self.uvt_add:
                self.uvt_feat_weight = net_cfg.UVT_FEAT_WEIGHT
                self.uvt_fusion_linear = nn.Linear(dim, dim)
            else:
                # The check above guarantees uvt_concat is True here.
                self.uvt_fusion_linear = nn.Linear(dim * 2, dim)
            self.uvt_linear = nn.Linear(dim, dim)
            self.dropout_uvt = nn.Dropout(net_cfg.CLASSIFIER_DROPOUT)

        # init weights
        self.init_weight()
    def __init__(self, config):
        """Create the model's sub-modules from ``config``.

        Builds the image feature extractor and object linguistic embeddings
        (unless ``config.NETWORK.BLIND`` is truthy), the BERT tokenizer, the
        ``VisualLinguisticBert`` backbone (which also receives
        ``config.FINETUNE_STRATEGY``) and the classifier head
        ``self.final_mlp``. Ends by calling ``self.init_weight()`` and then
        ``self.fix_params()``.

        Args:
            config: Project configuration object exposing ``NETWORK``,
                ``DATASET`` and ``FINETUNE_STRATEGY``.

        Raises:
            NotImplementedError: If the image branch is enabled and
                ``object_word_embed_mode`` is not 1, 2 or 3.
            ValueError: If ``CLASSIFIER_TYPE`` is not ``"2fc"``, ``"1fc"``
                or ``'mlm'``.
        """
        super(ResNetVLBERT, self).__init__(config)

        network = config.NETWORK
        bert_cfg = network.VLBERT

        # Image branch: region features plus per-object word embeddings.
        self.enable_cnn_reg_loss = network.ENABLE_CNN_REG_LOSS
        if not network.BLIND:
            self.image_feature_extractor = FastRCNN(
                config,
                average_pool=True,
                final_dim=network.IMAGE_FINAL_DIM,
                enable_cnn_reg_loss=self.enable_cnn_reg_loss,
            )
            # Embedding-table size for each mode that uses a table.
            table_size_by_mode = {1: 81, 2: 1}
            embed_mode = bert_cfg.object_word_embed_mode
            if embed_mode in table_size_by_mode:
                self.object_linguistic_embeddings = nn.Embedding(
                    table_size_by_mode[embed_mode], bert_cfg.hidden_size)
            elif embed_mode == 3:
                self.object_linguistic_embeddings = None
            else:
                raise NotImplementedError
        self.image_feature_bn_eval = network.IMAGE_FROZEN_BN

        self.tokenizer = BertTokenizer.from_pretrained(network.BERT_MODEL_NAME)

        # Resolve the pretrained language weights: an explicit prefix wins,
        # otherwise look for a weights file inside the model directory.
        pretrained_path = None
        if network.BERT_PRETRAINED != '':
            pretrained_path = '{}-{:04d}.model'.format(
                network.BERT_PRETRAINED, network.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(network.BERT_MODEL_NAME):
            candidate = os.path.join(network.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
            if os.path.isfile(candidate):
                pretrained_path = candidate
        self.language_pretrained_model_path = pretrained_path
        if pretrained_path is None:
            print("Warning: no pretrained language model found, training from scratch!!!")

        # The backbone is also told which finetuning strategy to use.
        self.vlbert = VisualLinguisticBert(
            bert_cfg,
            language_pretrained_model_path=pretrained_path,
            finetune_strategy=config.FINETUNE_STRATEGY,
        )

        # Classifier head: collect the layers first, wrap them once below.
        hidden = bert_cfg.hidden_size
        head_type = network.CLASSIFIER_TYPE
        if head_type == "2fc":
            head_layers = [
                nn.Dropout(network.CLASSIFIER_DROPOUT, inplace=False),
                nn.Linear(hidden, network.CLASSIFIER_HIDDEN_SIZE),
                nn.ReLU(inplace=True),
                nn.Dropout(network.CLASSIFIER_DROPOUT, inplace=False),
                nn.Linear(network.CLASSIFIER_HIDDEN_SIZE,
                          config.DATASET.ANSWER_VOCAB_SIZE),
            ]
        elif head_type == "1fc":
            head_layers = [
                nn.Dropout(network.CLASSIFIER_DROPOUT, inplace=False),
                nn.Linear(hidden, config.DATASET.ANSWER_VOCAB_SIZE),
            ]
        elif head_type == 'mlm':
            head_transform = BertPredictionHeadTransform(bert_cfg)
            projection = nn.Linear(hidden, config.DATASET.ANSWER_VOCAB_SIZE)
            head_layers = [
                head_transform,
                nn.Dropout(network.CLASSIFIER_DROPOUT, inplace=False),
                projection,
            ]
        else:
            raise ValueError("Not support classifier type: {}!".format(head_type))
        self.final_mlp = nn.Sequential(*head_layers)

        # Initialise weights, then apply the parameter-fixing step.
        self.init_weight()
        self.fix_params()
示例#3
0
    def __init__(self, config):
        """Create the model's sub-modules from ``config``.

        Builds the image feature extractor and object linguistic embeddings
        (unless ``config.NETWORK.BLIND`` is truthy), the BERT tokenizer, the
        ``VisualLinguisticBert`` backbone, a dropout layer and a learnable
        ``gating`` vector with one entry per hidden layer, and the classifier
        head ``self.final_mlp`` whose output width is
        ``config.NETWORK.CLASSIFIER_CLASS``. Ends by calling
        ``self.init_weight()`` and then ``self.fix_params()``.

        Args:
            config: Project configuration object exposing the ``NETWORK``
                section read below.

        Raises:
            NotImplementedError: If the image branch is enabled and
                ``object_word_embed_mode`` is not 1, 2 or 3.
            ValueError: If ``CLASSIFIER_TYPE`` is not ``"2fc"``, ``"1fc"``
                or ``'mlm'``.
        """
        super(ResNetVLBERTv5, self).__init__(config)

        net = config.NETWORK
        vl_cfg = net.VLBERT

        # Image branch: region features plus per-object word embeddings.
        self.enable_cnn_reg_loss = net.ENABLE_CNN_REG_LOSS
        if not net.BLIND:
            self.image_feature_extractor = FastRCNN(
                config,
                average_pool=True,
                final_dim=net.IMAGE_FINAL_DIM,
                enable_cnn_reg_loss=self.enable_cnn_reg_loss,
            )
            word_embed_mode = vl_cfg.object_word_embed_mode
            if word_embed_mode == 1:
                object_embeddings = nn.Embedding(601, vl_cfg.hidden_size)
            elif word_embed_mode == 2:
                object_embeddings = nn.Embedding(1, vl_cfg.hidden_size)
            elif word_embed_mode == 3:
                object_embeddings = None
            else:
                raise NotImplementedError
            self.object_linguistic_embeddings = object_embeddings
        self.image_feature_bn_eval = net.IMAGE_FROZEN_BN

        self.tokenizer = BertTokenizer.from_pretrained(net.BERT_MODEL_NAME)

        # Resolve the pretrained language weights: an explicit prefix wins,
        # otherwise look for a weights file inside the model directory.
        weights_file = None
        if net.BERT_PRETRAINED != '':
            weights_file = '{}-{:04d}.model'.format(
                net.BERT_PRETRAINED, net.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(net.BERT_MODEL_NAME):
            in_model_dir = os.path.join(net.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
            if os.path.isfile(in_model_dir):
                weights_file = in_model_dir
        self.language_pretrained_model_path = weights_file
        if weights_file is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            vl_cfg, language_pretrained_model_path=weights_file)

        # Learnable vector with one value per hidden layer: a fixed table of
        # 24 starting values for 24-layer models, otherwise 1e-2 everywhere.
        self.hidden_dropout = nn.Dropout(0.2)
        layer_count = vl_cfg.num_hidden_layers
        if layer_count == 24:
            gate_init = torch.tensor([
                0.0067, 0.0070, 0.0075, 0.0075, 0.0075, 0.0074, 0.0076, 0.0075,
                0.0076, 0.0080, 0.0079, 0.0086, 0.0096, 0.0101, 0.0104, 0.0105,
                0.0111, 0.0120, 0.0126, 0.0115, 0.0108, 0.0105, 0.0104, 0.0117
            ])
        else:
            gate_init = torch.ones(layer_count, ) * 1e-2
        self.gating = nn.Parameter(gate_init, requires_grad=True)
        self.train_steps = 0

        # Classifier head: collect the layers first, wrap them once below.
        width = vl_cfg.hidden_size
        classifier_kind = net.CLASSIFIER_TYPE
        if classifier_kind == "2fc":
            stack = [
                nn.Dropout(net.CLASSIFIER_DROPOUT, inplace=False),
                nn.Linear(width, net.CLASSIFIER_HIDDEN_SIZE),
                nn.ReLU(inplace=True),
                nn.Dropout(net.CLASSIFIER_DROPOUT, inplace=False),
                nn.Linear(net.CLASSIFIER_HIDDEN_SIZE, net.CLASSIFIER_CLASS),
            ]
        elif classifier_kind == "1fc":
            stack = [
                nn.Dropout(net.CLASSIFIER_DROPOUT, inplace=False),
                nn.Linear(width, net.CLASSIFIER_CLASS),
            ]
        elif classifier_kind == 'mlm':
            mlm_transform = BertPredictionHeadTransform(vl_cfg)
            mlm_projection = nn.Linear(width, net.CLASSIFIER_CLASS)
            stack = [
                mlm_transform,
                nn.Dropout(net.CLASSIFIER_DROPOUT, inplace=False),
                mlm_projection,
            ]
        else:
            raise ValueError("Not support classifier type: {}!".format(
                classifier_kind))
        self.final_mlp = nn.Sequential(*stack)

        # Initialise weights, then apply the parameter-fixing step.
        self.init_weight()
        self.fix_params()