def __init__(self, config):
    """Build the VL-BERT model together with its optional spatial / UVTransE heads.

    Sub-modules are created in this order: visual feature extractor (unless
    ``NETWORK.BLIND``), object word embeddings, tokenizer, VL-BERT encoder,
    classifier head (``final_mlp``), the optional spatial-feature layers, the
    optional UVTransE layers. ``self.init_weight()`` is called last.

    Args:
        config: experiment config object; this constructor reads attributes
            under ``config.NETWORK``, ``config.NETWORK.VLBERT`` and
            ``config.DATASET``.

    Raises:
        NotImplementedError: ``NETWORK.VLBERT.object_word_embed_mode`` is not
            1, 2 or 3.
        ValueError: ``NETWORK.CLASSIFIER_TYPE`` is not "2fc", "1fc" or "mlm",
            or ``NETWORK.USE_UVTRANSE`` is on but not exactly one of
            ``NETWORK.UVT_ADD`` / ``NETWORK.UVT_CONCAT`` is set.

    NOTE(review): the nesting of the spatial, enhanced-image-feature and
    UVTransE sections was recovered from a whitespace-collapsed copy of this
    file; confirm it against version control.
    """
    super(ResNetVLBERT, self).__init__(config)
    net_cfg = config.NETWORK
    vl_cfg = net_cfg.VLBERT
    hidden = vl_cfg.hidden_size

    # Config flag: make the prediction on the [CLS] output?
    self.predict_on_cls = vl_cfg.predict_on_cls
    self.enable_cnn_reg_loss = net_cfg.ENABLE_CNN_REG_LOSS
    # The visual feature extractor is only built when NETWORK.BLIND is off.
    if not net_cfg.BLIND:
        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=net_cfg.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=self.enable_cnn_reg_loss)

    embed_mode = vl_cfg.object_word_embed_mode
    if embed_mode == 1:
        self.object_linguistic_embeddings = nn.Embedding(81, hidden)
    elif embed_mode == 2:
        # default: class-agnostic (one shared embedding row)
        self.object_linguistic_embeddings = nn.Embedding(1, hidden)
    elif embed_mode == 3:
        self.object_linguistic_embeddings = None
    else:
        raise NotImplementedError

    self.image_feature_bn_eval = net_cfg.IMAGE_FROZEN_BN
    self.tokenizer = BertTokenizer.from_pretrained(net_cfg.BERT_MODEL_NAME)

    # Pretrained language weights: an explicit "<prefix>-<epoch>.model"
    # checkpoint wins; otherwise look for BERT_WEIGHTS_NAME inside a local
    # BERT_MODEL_NAME directory; otherwise None (train from scratch).
    language_pretrained_model_path = None
    if net_cfg.BERT_PRETRAINED != '':
        language_pretrained_model_path = '{}-{:04d}.model'.format(
            net_cfg.BERT_PRETRAINED, net_cfg.BERT_PRETRAINED_EPOCH)
    elif os.path.isdir(net_cfg.BERT_MODEL_NAME):
        weight_path = os.path.join(net_cfg.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
        if os.path.isfile(weight_path):
            language_pretrained_model_path = weight_path
    self.language_pretrained_model_path = language_pretrained_model_path
    if language_pretrained_model_path is None:
        print(
            "Warning: no pretrained language model found, training from scratch!!!"
        )
    self.vlbert = VisualLinguisticBert(
        vl_cfg,
        language_pretrained_model_path=language_pretrained_model_path)

    # Classifier head mapping a hidden-size vector to ANSWER_VOCAB_SIZE logits.
    dim = hidden
    if net_cfg.CLASSIFIER_TYPE == "2fc":
        self.final_mlp = torch.nn.Sequential(
            torch.nn.Dropout(net_cfg.CLASSIFIER_DROPOUT, inplace=False),
            torch.nn.Linear(dim, net_cfg.CLASSIFIER_HIDDEN_SIZE),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(net_cfg.CLASSIFIER_DROPOUT, inplace=False),
            torch.nn.Linear(net_cfg.CLASSIFIER_HIDDEN_SIZE,
                            config.DATASET.ANSWER_VOCAB_SIZE),
        )
    elif net_cfg.CLASSIFIER_TYPE == "1fc":
        self.final_mlp = torch.nn.Sequential(
            torch.nn.Dropout(net_cfg.CLASSIFIER_DROPOUT, inplace=False),
            torch.nn.Linear(dim, config.DATASET.ANSWER_VOCAB_SIZE))
    elif net_cfg.CLASSIFIER_TYPE == 'mlm':
        transform = BertPredictionHeadTransform(vl_cfg)
        linear = nn.Linear(hidden, config.DATASET.ANSWER_VOCAB_SIZE)
        self.final_mlp = nn.Sequential(
            transform,
            nn.Dropout(net_cfg.CLASSIFIER_DROPOUT, inplace=False),
            linear)
    else:
        raise ValueError("Not support classifier type: {}!".format(
            net_cfg.CLASSIFIER_TYPE))

    # Optional spatial-feature branch.
    self.use_spatial_model = bool(net_cfg.USE_SPATIAL_MODEL)
    if self.use_spatial_model:
        self.use_coord_vector = bool(net_cfg.USE_COORD_VECTOR)
        if self.use_coord_vector:
            # NOTE(review): 2 * 5 + 9 input features is assumed to match the
            # coordinate vector assembled in forward() — confirm there.
            self.loc_fcs = nn.Sequential(
                nn.Linear(2 * 5 + 9, hidden),
                nn.ReLU(True),
                nn.Linear(hidden, hidden))
        else:
            self.simple_spatial_model = SimpleSpatialModel(4, hidden, 9)

        self.spa_add = bool(net_cfg.SPA_ADD)
        self.spa_concat = bool(net_cfg.SPA_CONCAT)
        if self.spa_add:
            # Fusion weight defaults to 0.5 unless USE_SPA_WEIGHT is set.
            self.spa_feat_weight = (net_cfg.SPA_FEAT_WEIGHT
                                    if net_cfg.USE_SPA_WEIGHT else 0.5)
            self.spa_fusion_linear = nn.Linear(hidden, hidden)
        elif self.spa_concat:
            # Input is 2 * hidden wide whether or not the coord vector is used.
            self.spa_fusion_linear = nn.Linear(hidden * 2, hidden)
        self.spa_linear = nn.Linear(hidden, hidden)
        self.dropout = nn.Dropout(net_cfg.CLASSIFIER_DROPOUT)
        self.spa_one_more_layer = net_cfg.SPA_ONE_MORE_LAYER
        if self.spa_one_more_layer:
            self.spa_linear_hidden = nn.Linear(hidden, hidden)

    self.enhanced_img_feature = bool(vl_cfg.ENHANCED_IMG_FEATURE)
    if self.enhanced_img_feature:
        self.mask_weight = vl_cfg.mask_weight
        self.mask_loss_sum = vl_cfg.mask_loss_sum
        self.mask_loss_mse = vl_cfg.mask_loss_mse
    self.no_predicate = vl_cfg.NO_PREDICATE
    self.all_proposals_test = bool(config.DATASET.ALL_PROPOSALS_TEST)

    # Optional UVTransE branch.
    self.use_uvtranse = bool(net_cfg.USE_UVTRANSE)
    if self.use_uvtranse:
        self.union_vec_fc = nn.Linear(hidden, hidden)
        self.uvt_add = bool(net_cfg.UVT_ADD)
        self.uvt_concat = bool(net_cfg.UVT_CONCAT)
        # Exactly one fusion mode must be selected. Raise explicitly instead of
        # using `assert`, which is stripped under `python -O`.
        if self.uvt_add == self.uvt_concat:
            raise ValueError(
                "NETWORK.USE_UVTRANSE requires exactly one of NETWORK.UVT_ADD / "
                "NETWORK.UVT_CONCAT to be set (got UVT_ADD={}, UVT_CONCAT={})"
                .format(self.uvt_add, self.uvt_concat))
        if self.uvt_add:
            self.uvt_feat_weight = net_cfg.UVT_FEAT_WEIGHT
            self.uvt_fusion_linear = nn.Linear(hidden, hidden)
        else:
            self.uvt_fusion_linear = nn.Linear(hidden * 2, hidden)
        self.uvt_linear = nn.Linear(hidden, hidden)
        self.dropout_uvt = nn.Dropout(net_cfg.CLASSIFIER_DROPOUT)

    # init weights
    self.init_weight()
def __init__(self, config):
    """Assemble the visual backbone, VL-BERT encoder and answer classifier.

    Unlike a plain VL-BERT construction, ``config.FINETUNE_STRATEGY`` is
    forwarded to ``VisualLinguisticBert``. After all sub-modules are built,
    ``self.init_weight()`` and then ``self.fix_params()`` are called.

    Args:
        config: experiment config object; this constructor reads attributes
            under ``config.NETWORK``, ``config.NETWORK.VLBERT``,
            ``config.DATASET`` and ``config.FINETUNE_STRATEGY``.

    Raises:
        NotImplementedError: ``NETWORK.VLBERT.object_word_embed_mode`` is not
            1, 2 or 3.
        ValueError: ``NETWORK.CLASSIFIER_TYPE`` is not "2fc", "1fc" or "mlm".
    """
    super(ResNetVLBERT, self).__init__(config)
    net_cfg = config.NETWORK

    self.enable_cnn_reg_loss = net_cfg.ENABLE_CNN_REG_LOSS
    if not net_cfg.BLIND:
        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=net_cfg.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=self.enable_cnn_reg_loss,
        )

    vl_cfg = net_cfg.VLBERT
    embed_mode = vl_cfg.object_word_embed_mode
    if embed_mode in (1, 2):
        # Mode 1 uses an 81-row table, mode 2 a single shared row.
        table_rows = 81 if embed_mode == 1 else 1
        self.object_linguistic_embeddings = nn.Embedding(table_rows, vl_cfg.hidden_size)
    elif embed_mode == 3:
        self.object_linguistic_embeddings = None
    else:
        raise NotImplementedError

    self.image_feature_bn_eval = net_cfg.IMAGE_FROZEN_BN
    self.tokenizer = BertTokenizer.from_pretrained(net_cfg.BERT_MODEL_NAME)

    # Resolve pretrained language weights: explicit "<prefix>-<epoch>.model"
    # checkpoint first, then BERT_WEIGHTS_NAME inside a local model directory.
    if net_cfg.BERT_PRETRAINED != '':
        pretrained_path = '{}-{:04d}.model'.format(
            net_cfg.BERT_PRETRAINED, net_cfg.BERT_PRETRAINED_EPOCH)
    elif os.path.isdir(net_cfg.BERT_MODEL_NAME):
        candidate = os.path.join(net_cfg.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
        pretrained_path = candidate if os.path.isfile(candidate) else None
    else:
        pretrained_path = None
    self.language_pretrained_model_path = pretrained_path
    if pretrained_path is None:
        print("Warning: no pretrained language model found, training from scratch!!!")

    self.vlbert = VisualLinguisticBert(
        vl_cfg,
        language_pretrained_model_path=pretrained_path,
        finetune_strategy=config.FINETUNE_STRATEGY,
    )

    # Classifier head: hidden-size vector -> ANSWER_VOCAB_SIZE logits.
    hidden = vl_cfg.hidden_size
    head_type = net_cfg.CLASSIFIER_TYPE
    if head_type == "2fc":
        drop_p = net_cfg.CLASSIFIER_DROPOUT
        mid_width = net_cfg.CLASSIFIER_HIDDEN_SIZE
        self.final_mlp = torch.nn.Sequential(
            torch.nn.Dropout(drop_p, inplace=False),
            torch.nn.Linear(hidden, mid_width),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(drop_p, inplace=False),
            torch.nn.Linear(mid_width, config.DATASET.ANSWER_VOCAB_SIZE),
        )
    elif head_type == "1fc":
        self.final_mlp = torch.nn.Sequential(
            torch.nn.Dropout(net_cfg.CLASSIFIER_DROPOUT, inplace=False),
            torch.nn.Linear(hidden, config.DATASET.ANSWER_VOCAB_SIZE),
        )
    elif head_type == 'mlm':
        head_transform = BertPredictionHeadTransform(vl_cfg)
        head_proj = nn.Linear(hidden, config.DATASET.ANSWER_VOCAB_SIZE)
        self.final_mlp = nn.Sequential(
            head_transform,
            nn.Dropout(net_cfg.CLASSIFIER_DROPOUT, inplace=False),
            head_proj,
        )
    else:
        raise ValueError("Not support classifier type: {}!".format(head_type))

    # init weights
    self.init_weight()
    self.fix_params()
def __init__(self, config):
    """Assemble the v5 model: visual backbone, VL-BERT, layer gating, classifier.

    Besides the usual components this variant creates ``hidden_dropout``
    (p=0.2), a learnable ``gating`` vector with one entry per encoder layer
    (``NETWORK.VLBERT.num_hidden_layers``) and a ``train_steps`` counter.
    The classifier head outputs ``NETWORK.CLASSIFIER_CLASS`` logits.
    ``self.init_weight()`` and then ``self.fix_params()`` run last.

    Args:
        config: experiment config object; this constructor reads attributes
            under ``config.NETWORK`` and ``config.NETWORK.VLBERT``.

    Raises:
        NotImplementedError: ``NETWORK.VLBERT.object_word_embed_mode`` is not
            1, 2 or 3.
        ValueError: ``NETWORK.CLASSIFIER_TYPE`` is not "2fc", "1fc" or "mlm".
    """
    super(ResNetVLBERTv5, self).__init__(config)
    net_cfg = config.NETWORK

    self.enable_cnn_reg_loss = net_cfg.ENABLE_CNN_REG_LOSS
    if not net_cfg.BLIND:
        self.image_feature_extractor = FastRCNN(
            config,
            average_pool=True,
            final_dim=net_cfg.IMAGE_FINAL_DIM,
            enable_cnn_reg_loss=self.enable_cnn_reg_loss,
        )

    vl_cfg = net_cfg.VLBERT
    embed_mode = vl_cfg.object_word_embed_mode
    if embed_mode in (1, 2):
        # Mode 1 uses a 601-row table, mode 2 a single shared row.
        table_rows = 601 if embed_mode == 1 else 1
        self.object_linguistic_embeddings = nn.Embedding(table_rows, vl_cfg.hidden_size)
    elif embed_mode == 3:
        self.object_linguistic_embeddings = None
    else:
        raise NotImplementedError

    self.image_feature_bn_eval = net_cfg.IMAGE_FROZEN_BN
    self.tokenizer = BertTokenizer.from_pretrained(net_cfg.BERT_MODEL_NAME)

    # Resolve pretrained language weights: explicit "<prefix>-<epoch>.model"
    # checkpoint first, then BERT_WEIGHTS_NAME inside a local model directory.
    if net_cfg.BERT_PRETRAINED != '':
        pretrained_path = '{}-{:04d}.model'.format(
            net_cfg.BERT_PRETRAINED, net_cfg.BERT_PRETRAINED_EPOCH)
    elif os.path.isdir(net_cfg.BERT_MODEL_NAME):
        candidate = os.path.join(net_cfg.BERT_MODEL_NAME, BERT_WEIGHTS_NAME)
        pretrained_path = candidate if os.path.isfile(candidate) else None
    else:
        pretrained_path = None
    self.language_pretrained_model_path = pretrained_path
    if pretrained_path is None:
        print("Warning: no pretrained language model found, training from scratch!!!")

    self.vlbert = VisualLinguisticBert(
        vl_cfg, language_pretrained_model_path=pretrained_path)

    self.hidden_dropout = nn.Dropout(0.2)

    # One learnable gate per encoder layer. 24-layer encoders start from a
    # hard-coded profile; any other depth starts uniformly at 1e-2.
    # NOTE(review): presumably used in forward() to weight per-layer hidden
    # states — confirm there.
    layer_count = vl_cfg.num_hidden_layers
    if layer_count == 24:
        initial_gates = torch.tensor([
            0.0067, 0.0070, 0.0075, 0.0075, 0.0075, 0.0074, 0.0076, 0.0075,
            0.0076, 0.0080, 0.0079, 0.0086, 0.0096, 0.0101, 0.0104, 0.0105,
            0.0111, 0.0120, 0.0126, 0.0115, 0.0108, 0.0105, 0.0104, 0.0117
        ])
    else:
        initial_gates = torch.ones(layer_count) * 1e-2
    self.gating = nn.Parameter(initial_gates, requires_grad=True)
    self.train_steps = 0

    # Classifier head: hidden-size vector -> CLASSIFIER_CLASS logits.
    hidden = vl_cfg.hidden_size
    head_type = net_cfg.CLASSIFIER_TYPE
    if head_type == "2fc":
        drop_p = net_cfg.CLASSIFIER_DROPOUT
        mid_width = net_cfg.CLASSIFIER_HIDDEN_SIZE
        self.final_mlp = torch.nn.Sequential(
            torch.nn.Dropout(drop_p, inplace=False),
            torch.nn.Linear(hidden, mid_width),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(drop_p, inplace=False),
            torch.nn.Linear(mid_width, net_cfg.CLASSIFIER_CLASS),
        )
    elif head_type == "1fc":
        self.final_mlp = torch.nn.Sequential(
            torch.nn.Dropout(net_cfg.CLASSIFIER_DROPOUT, inplace=False),
            torch.nn.Linear(hidden, net_cfg.CLASSIFIER_CLASS),
        )
    elif head_type == 'mlm':
        head_transform = BertPredictionHeadTransform(vl_cfg)
        head_proj = nn.Linear(hidden, net_cfg.CLASSIFIER_CLASS)
        self.final_mlp = nn.Sequential(
            head_transform,
            nn.Dropout(net_cfg.CLASSIFIER_DROPOUT, inplace=False),
            head_proj,
        )
    else:
        raise ValueError("Not support classifier type: {}!".format(head_type))

    # init weights
    self.init_weight()
    self.fix_params()