def __init__(self, config):
    """Build a VL-BERT model with an OIM embedding head and a scalar scoring MLP.

    Args:
        config: experiment config. Reads ``config.NETWORK`` (image backbone,
            BERT weights location, ``VLBERT`` sub-config, classifier dropout).
    """
    super(ResNetVLBERT, self).__init__(config)
    net = config.NETWORK
    hidden = net.VLBERT.hidden_size

    def _locate_language_weights():
        # Priority: explicit checkpoint prefix/epoch, then a weights file inside
        # a local BERT_MODEL_NAME directory, otherwise nothing.
        if net.BERT_PRETRAINED != '':
            return '{}-{:04d}.model'.format(net.BERT_PRETRAINED,
                                            net.BERT_PRETRAINED_EPOCH)
        if os.path.isdir(net.BERT_MODEL_NAME):
            candidate = os.path.join(net.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
            if os.path.isfile(candidate):
                return candidate
        return None

    self.image_feature_extractor = FastRCNN(
        config,
        average_pool=True,
        final_dim=net.IMAGE_FINAL_DIM,
        enable_cnn_reg_loss=False,
    )
    # Single learned "object word" embedding (vocabulary of size 1).
    self.object_linguistic_embeddings = nn.Embedding(1, hidden)
    self.image_feature_bn_eval = net.IMAGE_FROZEN_BN
    self.tokenizer = BertTokenizer.from_pretrained(net.BERT_MODEL_NAME)

    lm_weights = _locate_language_weights()
    self.language_pretrained_model_path = lm_weights
    if lm_weights is None:
        print(
            "Warning: no pretrained language model found, training from scratch!!!"
        )

    self.vlbert = VisualLinguisticBert(
        net.VLBERT, language_pretrained_model_path=lm_weights)

    # Built before the OIM head so module construction order is unchanged.
    head_transform = VisualLinguisticBertMVRCHeadTransform(net.VLBERT)

    # NOTE(review): 12003 and 768 are hard-coded. 768 matches the output width
    # of self.linear; the meaning of 12003 is not visible here.
    self.OIM_loss = OIM_Module(12003, 768)
    self.linear = nn.Sequential(
        nn.Dropout(net.CLASSIFIER_DROPOUT, inplace=False),
        nn.Linear(hidden, 768),
    )
    score_layer = nn.Linear(hidden, 1)
    self.final_mlp = nn.Sequential(
        head_transform,
        nn.Dropout(net.CLASSIFIER_DROPOUT, inplace=False),
        score_layer,
    )

    self.init_weight()
    self.fix_params()
def __init__(self, config):
    """Build VL-BERT plus three task heads (Task1Head, Task2Head, Task3Head).

    Args:
        config: experiment config. Reads ``config.NETWORK`` (image backbone,
            BERT weights location, ``VLBERT`` sub-config).
    """
    super(ResNetVLBERT, self).__init__(config)
    net = config.NETWORK
    vl_cfg = net.VLBERT

    def _locate_language_weights():
        # Explicit checkpoint first, then a weights file in a local model dir.
        if net.BERT_PRETRAINED != '':
            return '{}-{:04d}.model'.format(net.BERT_PRETRAINED,
                                            net.BERT_PRETRAINED_EPOCH)
        if os.path.isdir(net.BERT_MODEL_NAME):
            candidate = os.path.join(net.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
            if os.path.isfile(candidate):
                return candidate
        return None

    self.image_feature_extractor = FastRCNN(
        config,
        average_pool=True,
        final_dim=net.IMAGE_FINAL_DIM,
        enable_cnn_reg_loss=False,
    )
    self.object_linguistic_embeddings = nn.Embedding(1, vl_cfg.hidden_size)
    self.image_feature_bn_eval = net.IMAGE_FROZEN_BN
    self.tokenizer = BertTokenizer.from_pretrained(net.BERT_MODEL_NAME)

    lm_weights = _locate_language_weights()
    self.language_pretrained_model_path = lm_weights
    if lm_weights is None:
        print(
            "Warning: no pretrained language model found, training from scratch!!!"
        )

    self.vlbert = VisualLinguisticBert(
        vl_cfg, language_pretrained_model_path=lm_weights)
    # NOTE(review): the heads' individual roles are defined elsewhere — confirm
    # against Task1Head/Task2Head/Task3Head.
    self.task1_head = Task1Head(vl_cfg)
    self.task2_head = Task2Head(vl_cfg)
    self.task3_head = Task3Head(vl_cfg)

    self.init_weight()
    self.fix_params()
def __init__(self, config):
    """Build the attention-visualisation variant of VL-BERT.

    Optional embeddings are created depending on the config:
    ``object_mask_visual_embedding`` (2048-d) when image features are
    precomputed or raw pixels are not masked, and ``object_mask_word_embedding``
    when ``WITH_MVRC_LOSS`` is set. If ``VLBERT.from_scratch`` is set, no
    language weights are passed to VL-BERT even if they were found.

    Args:
        config: experiment config; reads ``config.NETWORK``.
    """
    super(ResNetVLBERTForAttentionVis, self).__init__(config)
    net = config.NETWORK
    vl_cfg = net.VLBERT
    hidden = vl_cfg.hidden_size

    def _locate_language_weights():
        if net.BERT_PRETRAINED != '':
            return '{}-{:04d}.model'.format(net.BERT_PRETRAINED,
                                            net.BERT_PRETRAINED_EPOCH)
        if os.path.isdir(net.BERT_MODEL_NAME):
            candidate = os.path.join(net.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
            if os.path.isfile(candidate):
                return candidate
        return None

    self.image_feature_extractor = FastRCNN(
        config,
        average_pool=True,
        final_dim=net.IMAGE_FINAL_DIM,
        enable_cnn_reg_loss=False,
    )
    self.object_linguistic_embeddings = nn.Embedding(1, hidden)
    if net.IMAGE_FEAT_PRECOMPUTED or not net.MASK_RAW_PIXELS:
        self.object_mask_visual_embedding = nn.Embedding(1, 2048)
    if net.WITH_MVRC_LOSS:
        self.object_mask_word_embedding = nn.Embedding(1, hidden)
    self.aux_text_visual_embedding = nn.Embedding(1, hidden)
    self.image_feature_bn_eval = net.IMAGE_FROZEN_BN
    self.tokenizer = BertTokenizer.from_pretrained(net.BERT_MODEL_NAME)

    lm_weights = _locate_language_weights()
    if lm_weights is None:
        print(
            "Warning: no pretrained language model found, training from scratch!!!"
        )

    if vl_cfg.from_scratch:
        vlbert_weights = None
    else:
        vlbert_weights = lm_weights
    self.vlbert = VisualLinguisticBert(
        vl_cfg, language_pretrained_model_path=vlbert_weights)

    self.init_weight()
    self.fix_params()
def __init__(self, config):
    """Build VL-BERT with an answer classifier over ``DATASET.ANSWER_VOCAB_SIZE``.

    Args:
        config: experiment config; reads ``config.NETWORK``,
            ``config.DATASET.ANSWER_VOCAB_SIZE`` and ``config.FINETUNE_STRATEGY``.

    Raises:
        NotImplementedError: unknown ``VLBERT.object_word_embed_mode`` (non-blind).
        ValueError: unknown ``NETWORK.CLASSIFIER_TYPE``.
    """
    super(ResNetVLBERT, self).__init__(config)
    net = config.NETWORK
    vl_cfg = net.VLBERT

    def _locate_language_weights():
        if net.BERT_PRETRAINED != '':
            return '{}-{:04d}.model'.format(net.BERT_PRETRAINED,
                                            net.BERT_PRETRAINED_EPOCH)
        if os.path.isdir(net.BERT_MODEL_NAME):
            candidate = os.path.join(net.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
            if os.path.isfile(candidate):
                return candidate
        return None

    self.enable_cnn_reg_loss = net.ENABLE_CNN_REG_LOSS
    if not net.BLIND:
        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=net.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=self.enable_cnn_reg_loss,
        )
        embed_mode = vl_cfg.object_word_embed_mode
        if embed_mode == 1:
            self.object_linguistic_embeddings = nn.Embedding(81, vl_cfg.hidden_size)
        elif embed_mode == 2:
            self.object_linguistic_embeddings = nn.Embedding(1, vl_cfg.hidden_size)
        elif embed_mode == 3:
            self.object_linguistic_embeddings = None
        else:
            raise NotImplementedError
    self.image_feature_bn_eval = net.IMAGE_FROZEN_BN
    self.tokenizer = BertTokenizer.from_pretrained(net.BERT_MODEL_NAME)

    lm_weights = _locate_language_weights()
    self.language_pretrained_model_path = lm_weights
    if lm_weights is None:
        print("Warning: no pretrained language model found, training from scratch!!!")

    # The finetuning strategy is forwarded to VL-BERT as well.
    self.vlbert = VisualLinguisticBert(
        vl_cfg,
        language_pretrained_model_path=lm_weights,
        finetune_strategy=config.FINETUNE_STRATEGY,
    )

    width = vl_cfg.hidden_size
    head_type = net.CLASSIFIER_TYPE
    if head_type == "2fc":
        self.final_mlp = torch.nn.Sequential(
            torch.nn.Dropout(net.CLASSIFIER_DROPOUT, inplace=False),
            torch.nn.Linear(width, net.CLASSIFIER_HIDDEN_SIZE),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(net.CLASSIFIER_DROPOUT, inplace=False),
            torch.nn.Linear(net.CLASSIFIER_HIDDEN_SIZE,
                            config.DATASET.ANSWER_VOCAB_SIZE),
        )
    elif head_type == "1fc":
        self.final_mlp = torch.nn.Sequential(
            torch.nn.Dropout(net.CLASSIFIER_DROPOUT, inplace=False),
            torch.nn.Linear(width, config.DATASET.ANSWER_VOCAB_SIZE),
        )
    elif head_type == 'mlm':
        mlm_transform = BertPredictionHeadTransform(vl_cfg)
        mlm_out = nn.Linear(vl_cfg.hidden_size, config.DATASET.ANSWER_VOCAB_SIZE)
        self.final_mlp = nn.Sequential(
            mlm_transform,
            nn.Dropout(net.CLASSIFIER_DROPOUT, inplace=False),
            mlm_out,
        )
    else:
        raise ValueError("Not support classifier type: {}!".format(net.CLASSIFIER_TYPE))

    self.init_weight()
    self.fix_params()
def __init__(self, config):
    """Build a TimeDistributed VL-BERT with a scalar (1-logit) classifier.

    Args:
        config: experiment config; reads ``config.NETWORK``.

    Raises:
        NotImplementedError: ``NETWORK.FOR_MASK_VL_MODELING_PRETRAIN`` is set
            (pretraining mode is not implemented), or an unknown
            ``VLBERT.object_word_embed_mode`` is configured (non-blind).
        ValueError: unknown ``NETWORK.CLASSIFIER_TYPE``.
    """
    super(ResNetVLBERT, self).__init__(config)
    net = config.NETWORK
    vl_cfg = net.VLBERT

    # Validate with an explicit raise rather than ``assert`` (stripped under
    # ``python -O``), and fail before building/loading the backbones.
    self.for_pretrain = net.FOR_MASK_VL_MODELING_PRETRAIN
    if self.for_pretrain:
        raise NotImplementedError("Not implement pretrain mode now!")

    self.enable_cnn_reg_loss = net.ENABLE_CNN_REG_LOSS
    self.cnn_loss_top = net.CNN_LOSS_TOP
    if not net.BLIND:
        # When the CNN regression loss is computed on top of VL-BERT
        # (cnn_loss_top), the backbone's own regression loss is disabled.
        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=net.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=(self.enable_cnn_reg_loss and not self.cnn_loss_top))
        if vl_cfg.object_word_embed_mode == 1:
            self.object_linguistic_embeddings = nn.Embedding(81, vl_cfg.hidden_size)
        elif vl_cfg.object_word_embed_mode == 2:
            self.object_linguistic_embeddings = nn.Embedding(1, vl_cfg.hidden_size)
        elif vl_cfg.object_word_embed_mode == 3:
            self.object_linguistic_embeddings = None
        else:
            raise NotImplementedError
        if self.enable_cnn_reg_loss and self.cnn_loss_top:
            # 81-way output head applied to VL-BERT object features.
            self.cnn_loss_reg = nn.Sequential(
                VisualLinguisticBertMVRCHeadTransform(vl_cfg),
                nn.Dropout(net.CNN_REG_DROPOUT, inplace=False),
                nn.Linear(vl_cfg.hidden_size, 81))
    self.image_feature_bn_eval = net.IMAGE_FROZEN_BN

    if 'roberta' in net.BERT_MODEL_NAME:
        self.tokenizer = RobertaTokenizer.from_pretrained(net.BERT_MODEL_NAME)
    else:
        self.tokenizer = BertTokenizer.from_pretrained(net.BERT_MODEL_NAME)

    language_pretrained_model_path = None
    if net.BERT_PRETRAINED != '':
        language_pretrained_model_path = '{}-{:04d}.model'.format(
            net.BERT_PRETRAINED, net.BERT_PRETRAINED_EPOCH)
    elif os.path.isdir(net.BERT_MODEL_NAME):
        weight_path = os.path.join(net.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
        if os.path.isfile(weight_path):
            language_pretrained_model_path = weight_path
    if language_pretrained_model_path is None:
        print(
            "Warning: no pretrained language model found, training from scratch!!!"
        )

    self.vlbert = TimeDistributed(
        VisualLinguisticBert(
            vl_cfg,
            language_pretrained_model_path=language_pretrained_model_path))

    dim = vl_cfg.hidden_size
    if net.CLASSIFIER_TYPE == "2fc":
        self.final_mlp = torch.nn.Sequential(
            torch.nn.Dropout(net.CLASSIFIER_DROPOUT, inplace=False),
            torch.nn.Linear(dim, net.CLASSIFIER_HIDDEN_SIZE),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(net.CLASSIFIER_DROPOUT, inplace=False),
            torch.nn.Linear(net.CLASSIFIER_HIDDEN_SIZE, 1),
        )
    elif net.CLASSIFIER_TYPE == "1fc":
        self.final_mlp = torch.nn.Sequential(
            torch.nn.Dropout(net.CLASSIFIER_DROPOUT, inplace=False),
            torch.nn.Linear(dim, 1))
    else:
        raise ValueError("Not support classifier type: {}!".format(
            net.CLASSIFIER_TYPE))

    # init weights
    self.init_weight()
    self.fix_params()
def __init__(self, config):
    """Build VL-BERT with a sentence classifier and optional phrase / VG heads.

    - ``sentence_cls`` outputs 3 logits when ``DATASET.ALIGN_CAPTION_IMG`` is
      set, else 1.
    - ``phrasal_cls`` (5 logits over a ``4 * hidden`` input) exists only when
      ``DATASET.PHRASE_CLS`` is set.
    - ``vg_cls`` (1 logit over a ``2 * hidden`` input) exists only when
      ``NETWORK.SUPERVISE_ATTENTION == "indirect"``.
    - When ``NETWORK.EWC_REG`` is set, Fisher information and pretrained
      parameters are loaded for EWC regularisation.

    Args:
        config: experiment config; reads ``config.NETWORK`` and ``config.DATASET``.

    Raises:
        NotImplementedError: unknown ``VLBERT.object_word_embed_mode`` (non-blind).
        ValueError: unknown classifier type for any of the heads.
    """
    super(ResNetVLBERT, self).__init__(config)
    net = config.NETWORK
    vl_cfg = net.VLBERT

    self.enable_cnn_reg_loss = net.ENABLE_CNN_REG_LOSS
    self.cnn_loss_top = net.CNN_LOSS_TOP
    self.align_caption_img = config.DATASET.ALIGN_CAPTION_IMG
    self.use_phrasal_paraphrases = config.DATASET.PHRASE_CLS
    self.supervise_attention = net.SUPERVISE_ATTENTION
    self.normalization = net.ATTENTION_NORM_METHOD
    self.ewc_reg = net.EWC_REG
    self.importance_hparam = 0.
    if net.EWC_REG:
        # NOTE(review): pickle.load / torch.load can execute arbitrary code
        # from the file; FISHER_PATH and PARAM_PRETRAIN must be trusted.
        # ``with`` closes the file handle (the previous bare open() leaked it).
        with open(net.FISHER_PATH, "rb") as fisher_file:
            self.fisher = pickle.load(fisher_file)
        self.pretrain_param = torch.load(net.PARAM_PRETRAIN)
        self.importance_hparam = net.EWC_IMPORTANCE

    if not net.BLIND:
        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=net.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=(self.enable_cnn_reg_loss and not self.cnn_loss_top))
        if vl_cfg.object_word_embed_mode == 1:
            self.object_linguistic_embeddings = nn.Embedding(81, vl_cfg.hidden_size)
        elif vl_cfg.object_word_embed_mode == 2:
            self.object_linguistic_embeddings = nn.Embedding(1, vl_cfg.hidden_size)
        elif vl_cfg.object_word_embed_mode == 3:
            self.object_linguistic_embeddings = None
        else:
            raise NotImplementedError
    self.image_feature_bn_eval = net.IMAGE_FROZEN_BN

    if 'roberta' in net.BERT_MODEL_NAME:
        self.tokenizer = RobertaTokenizer.from_pretrained(net.BERT_MODEL_NAME)
    else:
        self.tokenizer = BertTokenizer.from_pretrained(net.BERT_MODEL_NAME)

    language_pretrained_model_path = None
    if net.BERT_PRETRAINED != '':
        language_pretrained_model_path = '{}-{:04d}.model'.format(
            net.BERT_PRETRAINED, net.BERT_PRETRAINED_EPOCH)
    elif os.path.isdir(net.BERT_MODEL_NAME):
        weight_path = os.path.join(net.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
        if os.path.isfile(weight_path):
            language_pretrained_model_path = weight_path
    if language_pretrained_model_path is None:
        print(
            "Warning: no pretrained language model found, training from scratch!!!"
        )

    self.vlbert = VisualLinguisticBert(
        vl_cfg, language_pretrained_model_path=language_pretrained_model_path)

    self.for_pretrain = False
    dim = vl_cfg.hidden_size

    # Sentence-level head.
    sentence_logits_shape = 3 if self.align_caption_img else 1
    sent = net.SENTENCE
    if sent.CLASSIFIER_TYPE == "2fc":
        self.sentence_cls = torch.nn.Sequential(
            torch.nn.Dropout(sent.CLASSIFIER_DROPOUT, inplace=False),
            torch.nn.Linear(dim, sent.CLASSIFIER_HIDDEN_SIZE),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(sent.CLASSIFIER_DROPOUT, inplace=False),
            torch.nn.Linear(sent.CLASSIFIER_HIDDEN_SIZE, sentence_logits_shape),
        )
    elif sent.CLASSIFIER_TYPE == "1fc":
        self.sentence_cls = torch.nn.Sequential(
            torch.nn.Dropout(sent.CLASSIFIER_DROPOUT, inplace=False),
            torch.nn.Linear(dim, sentence_logits_shape))
    else:
        raise ValueError("Classifier type: {} not supported!".format(
            sent.CLASSIFIER_TYPE))

    # Optional phrase-level head over a 4 * hidden input.
    if self.use_phrasal_paraphrases:
        phrase = net.PHRASE
        if phrase.CLASSIFIER_TYPE == "2fc":
            self.phrasal_cls = torch.nn.Sequential(
                torch.nn.Dropout(phrase.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(4 * dim, phrase.CLASSIFIER_HIDDEN_SIZE),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(phrase.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(phrase.CLASSIFIER_HIDDEN_SIZE, 5),
            )
        elif phrase.CLASSIFIER_TYPE == "1fc":
            self.phrasal_cls = torch.nn.Sequential(
                torch.nn.Dropout(phrase.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(4 * dim, 5))
        else:
            raise ValueError("Classifier type: {} not supported!".format(
                phrase.CLASSIFIER_TYPE))

    # Optional visual-grounding head over a 2 * hidden input.
    if self.supervise_attention == "indirect":
        vg = net.VG
        if vg.CLASSIFIER_TYPE == "2fc":
            self.vg_cls = torch.nn.Sequential(
                torch.nn.Dropout(vg.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(2 * dim, vg.CLASSIFIER_HIDDEN_SIZE),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(vg.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(vg.CLASSIFIER_HIDDEN_SIZE, 1),
            )
        elif vg.CLASSIFIER_TYPE == "1fc":
            self.vg_cls = torch.nn.Sequential(
                torch.nn.Dropout(vg.CLASSIFIER_DROPOUT, inplace=False),
                torch.nn.Linear(2 * dim, 1))
        else:
            # Report the VG classifier type (previously reported PHRASE's).
            raise ValueError("Classifier type: {} not supported!".format(
                vg.CLASSIFIER_TYPE))

    # init weights
    self.init_weight()
    self.fix_params()
def __init__(self, config):
    """Build VL-BERT with an answer classifier plus optional spatial / UVTransE fusion.

    Args:
        config: experiment config; reads ``config.NETWORK`` and ``config.DATASET``.

    Raises:
        NotImplementedError: unknown ``VLBERT.object_word_embed_mode`` (non-blind).
        ValueError: unknown ``NETWORK.CLASSIFIER_TYPE``, or ``USE_UVTRANSE`` is
            set without exactly one of ``UVT_ADD`` / ``UVT_CONCAT``.
    """
    super(ResNetVLBERT, self).__init__(config)
    net = config.NETWORK
    vl_cfg = net.VLBERT
    hidden = vl_cfg.hidden_size

    self.predict_on_cls = vl_cfg.predict_on_cls  # make prediction on [CLS]?
    self.enable_cnn_reg_loss = net.ENABLE_CNN_REG_LOSS
    if not net.BLIND:
        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=net.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=self.enable_cnn_reg_loss)
        if vl_cfg.object_word_embed_mode == 1:
            self.object_linguistic_embeddings = nn.Embedding(81, hidden)
        elif vl_cfg.object_word_embed_mode == 2:
            # default: class-agnostic
            self.object_linguistic_embeddings = nn.Embedding(1, hidden)
        elif vl_cfg.object_word_embed_mode == 3:
            self.object_linguistic_embeddings = None
        else:
            raise NotImplementedError
    self.image_feature_bn_eval = net.IMAGE_FROZEN_BN
    self.tokenizer = BertTokenizer.from_pretrained(net.BERT_MODEL_NAME)

    language_pretrained_model_path = None
    if net.BERT_PRETRAINED != '':
        language_pretrained_model_path = '{}-{:04d}.model'.format(
            net.BERT_PRETRAINED, net.BERT_PRETRAINED_EPOCH)
    elif os.path.isdir(net.BERT_MODEL_NAME):
        weight_path = os.path.join(net.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
        if os.path.isfile(weight_path):
            language_pretrained_model_path = weight_path
    self.language_pretrained_model_path = language_pretrained_model_path
    if language_pretrained_model_path is None:
        print(
            "Warning: no pretrained language model found, training from scratch!!!"
        )

    self.vlbert = VisualLinguisticBert(
        vl_cfg, language_pretrained_model_path=language_pretrained_model_path)

    dim = hidden
    if net.CLASSIFIER_TYPE == "2fc":
        self.final_mlp = torch.nn.Sequential(
            torch.nn.Dropout(net.CLASSIFIER_DROPOUT, inplace=False),
            torch.nn.Linear(dim, net.CLASSIFIER_HIDDEN_SIZE),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(net.CLASSIFIER_DROPOUT, inplace=False),
            torch.nn.Linear(net.CLASSIFIER_HIDDEN_SIZE,
                            config.DATASET.ANSWER_VOCAB_SIZE),
        )
    elif net.CLASSIFIER_TYPE == "1fc":
        self.final_mlp = torch.nn.Sequential(
            torch.nn.Dropout(net.CLASSIFIER_DROPOUT, inplace=False),
            torch.nn.Linear(dim, config.DATASET.ANSWER_VOCAB_SIZE))
    elif net.CLASSIFIER_TYPE == 'mlm':
        transform = BertPredictionHeadTransform(vl_cfg)
        linear = nn.Linear(hidden, config.DATASET.ANSWER_VOCAB_SIZE)
        self.final_mlp = nn.Sequential(
            transform,
            nn.Dropout(net.CLASSIFIER_DROPOUT, inplace=False),
            linear)
    else:
        raise ValueError("Not support classifier type: {}!".format(
            net.CLASSIFIER_TYPE))

    # Optional spatial-feature branch.
    self.use_spatial_model = False
    if net.USE_SPATIAL_MODEL:
        self.use_spatial_model = True
        self.use_coord_vector = False
        if net.USE_COORD_VECTOR:
            self.use_coord_vector = True
            # Input width 2 * 5 + 9 as configured; its breakdown is not shown here.
            self.loc_fcs = nn.Sequential(
                nn.Linear(2 * 5 + 9, hidden),
                nn.ReLU(True),
                nn.Linear(hidden, hidden))
        else:
            self.simple_spatial_model = SimpleSpatialModel(4, hidden, 9)
        self.spa_add = bool(net.SPA_ADD)
        self.spa_concat = bool(net.SPA_CONCAT)
        if self.spa_add:
            self.spa_feat_weight = 0.5
            if net.USE_SPA_WEIGHT:
                self.spa_feat_weight = net.SPA_FEAT_WEIGHT
            self.spa_fusion_linear = nn.Linear(hidden, hidden)
        elif self.spa_concat:
            # Both the coord-vector and simple-spatial inputs are hidden-wide,
            # so the concatenation is always 2 * hidden.
            self.spa_fusion_linear = nn.Linear(hidden * 2, hidden)
        self.spa_linear = nn.Linear(hidden, hidden)
        self.dropout = nn.Dropout(net.CLASSIFIER_DROPOUT)
        self.spa_one_more_layer = net.SPA_ONE_MORE_LAYER
        if self.spa_one_more_layer:
            self.spa_linear_hidden = nn.Linear(hidden, hidden)

    self.enhanced_img_feature = False
    if vl_cfg.ENHANCED_IMG_FEATURE:
        self.enhanced_img_feature = True
        self.mask_weight = vl_cfg.mask_weight
        self.mask_loss_sum = vl_cfg.mask_loss_sum
        self.mask_loss_mse = vl_cfg.mask_loss_mse

    self.no_predicate = vl_cfg.NO_PREDICATE
    self.all_proposals_test = bool(config.DATASET.ALL_PROPOSALS_TEST)

    # Optional UVTransE union-vector branch.
    self.use_uvtranse = False
    if net.USE_UVTRANSE:
        self.use_uvtranse = True
        self.union_vec_fc = nn.Linear(hidden, hidden)
        self.uvt_add = bool(net.UVT_ADD)
        self.uvt_concat = bool(net.UVT_CONCAT)
        # Explicit raise instead of ``assert False`` (stripped under -O).
        if self.uvt_add == self.uvt_concat:
            raise ValueError(
                "USE_UVTRANSE requires exactly one of NETWORK.UVT_ADD and "
                "NETWORK.UVT_CONCAT to be set")
        if self.uvt_add:
            self.uvt_feat_weight = net.UVT_FEAT_WEIGHT
            self.uvt_fusion_linear = nn.Linear(hidden, hidden)
        else:  # uvt_concat
            self.uvt_fusion_linear = nn.Linear(hidden * 2, hidden)
        self.uvt_linear = nn.Linear(hidden, hidden)
        self.dropout_uvt = nn.Dropout(net.CLASSIFIER_DROPOUT)

    # init weights
    self.init_weight()
def __init__(self, config):
    """Build VL-BERT v5: per-layer learnable gating plus a classifier head.

    ``gating`` holds one learnable weight per VL-BERT layer. It uses a fixed
    24-value initialisation when there are 24 layers, otherwise 1e-2 everywhere.
    The classifier outputs ``NETWORK.CLASSIFIER_CLASS`` logits.

    Args:
        config: experiment config; reads ``config.NETWORK``.

    Raises:
        NotImplementedError: unknown ``VLBERT.object_word_embed_mode`` (non-blind).
        ValueError: unknown ``NETWORK.CLASSIFIER_TYPE``.
    """
    super(ResNetVLBERTv5, self).__init__(config)
    net = config.NETWORK
    vl_cfg = net.VLBERT

    def _locate_language_weights():
        if net.BERT_PRETRAINED != '':
            return '{}-{:04d}.model'.format(net.BERT_PRETRAINED,
                                            net.BERT_PRETRAINED_EPOCH)
        if os.path.isdir(net.BERT_MODEL_NAME):
            candidate = os.path.join(net.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
            if os.path.isfile(candidate):
                return candidate
        return None

    self.enable_cnn_reg_loss = net.ENABLE_CNN_REG_LOSS
    if not net.BLIND:
        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=net.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=self.enable_cnn_reg_loss,
        )
        embed_mode = vl_cfg.object_word_embed_mode
        if embed_mode == 1:
            self.object_linguistic_embeddings = nn.Embedding(601, vl_cfg.hidden_size)
        elif embed_mode == 2:
            self.object_linguistic_embeddings = nn.Embedding(1, vl_cfg.hidden_size)
        elif embed_mode == 3:
            self.object_linguistic_embeddings = None
        else:
            raise NotImplementedError
    self.image_feature_bn_eval = net.IMAGE_FROZEN_BN
    self.tokenizer = BertTokenizer.from_pretrained(net.BERT_MODEL_NAME)

    lm_weights = _locate_language_weights()
    self.language_pretrained_model_path = lm_weights
    if lm_weights is None:
        print(
            "Warning: no pretrained language model found, training from scratch!!!"
        )

    self.vlbert = VisualLinguisticBert(
        vl_cfg, language_pretrained_model_path=lm_weights)

    self.hidden_dropout = nn.Dropout(0.2)
    if vl_cfg.num_hidden_layers == 24:
        gate_init = torch.tensor([
            0.0067, 0.0070, 0.0075, 0.0075, 0.0075, 0.0074, 0.0076, 0.0075,
            0.0076, 0.0080, 0.0079, 0.0086, 0.0096, 0.0101, 0.0104, 0.0105,
            0.0111, 0.0120, 0.0126, 0.0115, 0.0108, 0.0105, 0.0104, 0.0117
        ])
    else:
        gate_init = torch.ones(vl_cfg.num_hidden_layers, ) * 1e-2
    self.gating = nn.Parameter(gate_init, requires_grad=True)
    self.train_steps = 0

    width = vl_cfg.hidden_size
    head_type = net.CLASSIFIER_TYPE
    if head_type == "2fc":
        self.final_mlp = torch.nn.Sequential(
            torch.nn.Dropout(net.CLASSIFIER_DROPOUT, inplace=False),
            torch.nn.Linear(width, net.CLASSIFIER_HIDDEN_SIZE),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(net.CLASSIFIER_DROPOUT, inplace=False),
            torch.nn.Linear(net.CLASSIFIER_HIDDEN_SIZE, net.CLASSIFIER_CLASS),
        )
    elif head_type == "1fc":
        self.final_mlp = torch.nn.Sequential(
            torch.nn.Dropout(net.CLASSIFIER_DROPOUT, inplace=False),
            torch.nn.Linear(width, net.CLASSIFIER_CLASS),
        )
    elif head_type == 'mlm':
        mlm_transform = BertPredictionHeadTransform(vl_cfg)
        mlm_out = nn.Linear(vl_cfg.hidden_size, net.CLASSIFIER_CLASS)
        self.final_mlp = nn.Sequential(
            mlm_transform,
            nn.Dropout(net.CLASSIFIER_DROPOUT, inplace=False),
            mlm_out,
        )
    else:
        raise ValueError("Not support classifier type: {}!".format(
            net.CLASSIFIER_TYPE))

    self.init_weight()
    self.fix_params()
class ResNetVLBERT(Module):
    """VL-BERT with three task heads on top of a frozen image backbone and VL-BERT.

    ``fix_params`` freezes ``image_feature_extractor`` and ``vlbert``. The three
    task heads and ``object_linguistic_embeddings`` stay trainable.

    Each forward pass runs VL-BERT twice:
      1. On the unmodified text. ``task1_head`` scores the [CLS] representation
         and ``task2_head`` scores the remaining token positions.
      2. On the text with ``mask``-selected tokens replaced by [MASK].
         ``task3_head`` scores the representation at position ``pos``.
    """

    def __init__(self, config):
        super(ResNetVLBERT, self).__init__(config)
        net = config.NETWORK
        vl_cfg = net.VLBERT

        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=net.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=False)
        # One shared learned word embedding attached to every object box.
        self.object_linguistic_embeddings = nn.Embedding(1, vl_cfg.hidden_size)
        self.image_feature_bn_eval = net.IMAGE_FROZEN_BN
        self.tokenizer = BertTokenizer.from_pretrained(net.BERT_MODEL_NAME)

        language_pretrained_model_path = None
        if net.BERT_PRETRAINED != '':
            language_pretrained_model_path = '{}-{:04d}.model'.format(
                net.BERT_PRETRAINED, net.BERT_PRETRAINED_EPOCH)
        elif os.path.isdir(net.BERT_MODEL_NAME):
            weight_path = os.path.join(net.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
            if os.path.isfile(weight_path):
                language_pretrained_model_path = weight_path
        self.language_pretrained_model_path = language_pretrained_model_path
        if language_pretrained_model_path is None:
            print(
                "Warning: no pretrained language model found, training from scratch!!!"
            )

        self.vlbert = VisualLinguisticBert(
            vl_cfg,
            language_pretrained_model_path=language_pretrained_model_path)
        self.task1_head = Task1Head(vl_cfg)
        self.task2_head = Task2Head(vl_cfg)
        self.task3_head = Task3Head(vl_cfg)

        # init weights
        self.init_weight()
        self.fix_params()

    def init_weight(self):
        """Initialise the backbone, the object embedding and all task-head Linears."""
        self.image_feature_extractor.init_weight()
        if self.object_linguistic_embeddings is not None:
            self.object_linguistic_embeddings.weight.data.normal_(
                mean=0.0, std=self.config.NETWORK.VLBERT.initializer_range)
        # Same order as before (task1, task2, task3) so RNG consumption matches.
        for head in (self.task1_head, self.task2_head, self.task3_head):
            for m in head.modules():
                if isinstance(m, torch.nn.Linear):
                    torch.nn.init.xavier_uniform_(m.weight)
                    # Linear layers built with bias=False have ``bias is None``.
                    if m.bias is not None:
                        torch.nn.init.constant_(m.bias, 0)

    def train(self, mode=True):
        """Switch train/eval mode, keeping the backbone's BN frozen if configured.

        Returns ``self``, as ``torch.nn.Module.train`` does.
        """
        super(ResNetVLBERT, self).train(mode)
        # turn some frozen layers to eval mode
        if self.image_feature_bn_eval:
            self.image_feature_extractor.bn_eval()
        return self

    def fix_params(self):
        """Freeze the image backbone and VL-BERT; the task heads stay trainable."""
        for param in self.image_feature_extractor.parameters():
            param.requires_grad = False
        for param in self.vlbert.parameters():
            param.requires_grad = False
        # NOTE(review): object_linguistic_embeddings is left trainable. The
        # former loop over it only evaluated ``param.requires_grad`` (a no-op).
        # Confirm whether it was meant to be frozen as well.

    @staticmethod
    def _trim_padded_boxes(boxes):
        """Return ``(boxes, box_mask)`` cut to the largest real-box count in the batch.

        A box counts as real when its first coordinate is > -1.5. Padding
        presumably uses a negative sentinel — TODO confirm.
        """
        box_mask = boxes[:, :, 0] > -1.5
        max_len = int(box_mask.sum(1).max().item())
        return boxes[:, :max_len], box_mask[:, :max_len]

    def _prepare_text(self, expression):
        """Wrap each expression as ``[CLS] tokens [SEP]``.

        Returns ``(text_input_ids, text_token_type_ids, text_mask)``. [SEP] is
        written at index ``count(ids > 0)``. This assumes ``expression`` is
        right-padded with id 0 — TODO confirm.
        """
        cls_id, sep_id = self.tokenizer.convert_tokens_to_ids(
            ['[CLS]', '[SEP]'])
        text_input_ids = expression.new_zeros(
            (expression.shape[0], expression.shape[1] + 2))
        text_input_ids[:, 0] = cls_id
        text_input_ids[:, 1:-1] = expression
        sep_pos = (text_input_ids > 0).sum(1)
        batch_inds = torch.arange(expression.shape[0], device=expression.device)
        text_input_ids[batch_inds, sep_pos] = sep_id
        text_token_type_ids = text_input_ids.new_zeros(text_input_ids.shape)
        text_mask = text_input_ids > 0
        return text_input_ids, text_token_type_ids, text_mask

    def _visual_embeddings(self, obj_reps, boxes, text_len):
        """Build the per-token visual embeddings and the object VL embeddings.

        Every text token receives ``obj_reps['obj_reps'][:, 0]``, the first
        box's feature.
        """
        text_visual_embeddings = obj_reps['obj_reps'][:, 0].unsqueeze(
            1).repeat((1, text_len, 1))
        object_linguistic_embeddings = self.object_linguistic_embeddings(
            boxes.new_zeros((boxes.shape[0], boxes.shape[1])).long())
        object_vl_embeddings = torch.cat(
            (obj_reps['obj_reps'], object_linguistic_embeddings), -1)
        return text_visual_embeddings, object_vl_embeddings

    def _encode_text(self, text_input_ids, text_token_type_ids,
                     text_visual_embeddings, text_mask, object_vl_embeddings,
                     box_mask):
        """Run VL-BERT and return only the text-part representations."""
        sequence_reps, _, _ = self.vlbert(
            text_input_ids,
            text_token_type_ids,
            text_visual_embeddings,
            text_mask,
            object_vl_embeddings,
            box_mask,
            output_text_and_object_separately=True,
            output_all_encoded_layers=False)
        return sequence_reps

    def _mask_error_tokens(self, text_input_ids, mask):
        """Return a copy of ``text_input_ids`` with ``mask``-selected tokens set to [MASK].

        ``mask`` is aligned with the tokens after [CLS] (column j of ``mask``
        maps to column j + 1 of the ids). A copy is returned so the pass-1 ids
        are not mutated in place.
        """
        error_cor_mask = torch.zeros_like(text_input_ids)
        error_cor_mask[:, 1:mask.shape[1] + 1] = mask
        error_cor_mask = error_cor_mask.bool()
        masked_ids = text_input_ids.clone()
        masked_ids[error_cor_mask] = self.tokenizer.convert_tokens_to_ids(
            ['[MASK]'])[0]
        return masked_ids

    @staticmethod
    def _gather_at_positions(token_reps, pos):
        """Pick one token representation per example at index ``pos``.

        ``pos`` presumably holds one index per example (it is viewed as
        ``(-1, 1)``). Negative indices are clamped to 0 so ``gather`` stays in
        range. The gathered width is the model's hidden size (previously
        hard-coded to 768).
        """
        index = pos.view(-1, 1).unsqueeze(2).repeat(1, 1, token_reps.shape[-1])
        index[index < 0] = 0  # ``repeat`` copies, so ``pos`` is not modified
        return torch.gather(token_reps, 1, index).squeeze(1)

    @staticmethod
    def _outputs_and_loss(cls_log_probs, pos_log_probs, cor_log_probs, label,
                          pos, target):
        """Combine the three task losses and package the outputs dict.

        The position and correction losses are multiplied by ``label`` (as
        float), so they contribute only for examples whose label is non-zero.
        NOTE(review): ``binary_cross_entropy`` expects probabilities, so
        ``task1_head`` presumably ends in a sigmoid — confirm.
        """
        loss_mask = label.view(-1).float()
        cls_loss = F.binary_cross_entropy(cls_log_probs.view(-1),
                                          label.view(-1).float(),
                                          reduction="none")
        pos_loss = F.nll_loss(
            pos_log_probs, pos, ignore_index=-1,
            reduction="none").view(-1) * loss_mask
        cor_loss = F.nll_loss(cor_log_probs.view(-1, cor_log_probs.shape[-1]),
                              target,
                              ignore_index=0,
                              reduction="none").view(-1) * loss_mask
        loss = cls_loss.mean() + pos_loss.mean() + cor_loss.mean()
        outputs = {
            "cls_logits": cls_log_probs,
            "pos_logits": pos_log_probs,
            "cor_logits": cor_log_probs,
            "cls_label": label,
            "pos_label": pos,
            "cor_label": target
        }
        return outputs, loss

    def train_forward(self, image, boxes, im_info, expression, label, pos,
                      target, mask):
        """Training pass; returns ``(outputs, loss)``.

        The frozen backbones are kept in eval mode (dropout/BN disabled) even
        while the module itself is training.
        """
        if self.vlbert.training:
            self.vlbert.eval()
        if self.image_feature_extractor.training:
            self.image_feature_extractor.eval()

        boxes, box_mask = self._trim_padded_boxes(boxes)
        obj_reps = self.image_feature_extractor(images=image,
                                                boxes=boxes,
                                                box_mask=box_mask,
                                                im_info=im_info,
                                                classes=None,
                                                segms=None)
        text_input_ids, text_token_type_ids, text_mask = self._prepare_text(
            expression)
        text_visual_embeddings, object_vl_embeddings = self._visual_embeddings(
            obj_reps, boxes, text_input_ids.shape[1])

        # Pass 1: unmodified text -> [CLS] head and per-position head.
        sequence_reps = self._encode_text(text_input_ids, text_token_type_ids,
                                          text_visual_embeddings, text_mask,
                                          object_vl_embeddings, box_mask)
        cls_log_probs = self.task1_head(sequence_reps[:, 0, :].squeeze(1))
        pos_log_probs = self.task2_head(sequence_reps[:, 1:, :],
                                        text_mask[:, 1:])

        # Pass 2: masked text -> correction head at ``pos``.
        masked_ids = self._mask_error_tokens(text_input_ids, mask)
        sequence_reps = self._encode_text(masked_ids, text_token_type_ids,
                                          text_visual_embeddings, text_mask,
                                          object_vl_embeddings, box_mask)
        cor_log_probs = self.task3_head(
            self._gather_at_positions(sequence_reps[:, 1:, :], pos))

        return self._outputs_and_loss(cls_log_probs, pos_log_probs,
                                      cor_log_probs, label, pos, target)

    def inference_forward(self, image, boxes, im_info, expression, label, pos,
                          target, mask=None):
        """Evaluation pass; returns ``(outputs, loss)``.

        ``mask`` was previously used without being a parameter, which raised a
        NameError. It now mirrors ``train_forward``'s signature. When ``mask``
        is None, no tokens are replaced with [MASK] in the second pass.
        """
        boxes, box_mask = self._trim_padded_boxes(boxes)
        with torch.no_grad():
            obj_reps = self.image_feature_extractor(images=image,
                                                    boxes=boxes,
                                                    box_mask=box_mask,
                                                    im_info=im_info,
                                                    classes=None,
                                                    segms=None)
        text_input_ids, text_token_type_ids, text_mask = self._prepare_text(
            expression)
        text_visual_embeddings, object_vl_embeddings = self._visual_embeddings(
            obj_reps, boxes, text_input_ids.shape[1])

        sequence_reps = self._encode_text(text_input_ids, text_token_type_ids,
                                          text_visual_embeddings, text_mask,
                                          object_vl_embeddings, box_mask)
        cls_log_probs = self.task1_head(sequence_reps[:, 0, :].squeeze(1))
        pos_log_probs = self.task2_head(sequence_reps[:, 1:, :],
                                        text_mask[:, 1:])

        if mask is not None:
            masked_ids = self._mask_error_tokens(text_input_ids, mask)
        else:
            masked_ids = text_input_ids
        with torch.no_grad():
            sequence_reps = self._encode_text(masked_ids, text_token_type_ids,
                                              text_visual_embeddings, text_mask,
                                              object_vl_embeddings, box_mask)
        cor_log_probs = self.task3_head(
            self._gather_at_positions(sequence_reps[:, 1:, :], pos))

        return self._outputs_and_loss(cls_log_probs, pos_log_probs,
                                      cor_log_probs, label, pos, target)