def _init_classifier(self, combined_embedding_dim):
    # TODO: Later support multihead
    num_choices = registry.get(self._datasets[0] + "_num_final_outputs")

    self.classifier = ClassifierLayer(
        self.config["classifier"]["type"],
        in_dim=combined_embedding_dim,
        out_dim=num_choices,
        **self.config["classifier"]["params"])
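# Illustrative sketch (not part of the original model): assuming the registry entry
# "<dataset>_num_final_outputs" is an int and a "linear" ClassifierLayer wraps a single
# nn.Linear, the wiring above reduces to the following shape contract.
import torch.nn as nn

def build_linear_classifier(combined_embedding_dim, num_choices):
    # maps (batch, combined_embedding_dim) -> (batch, num_choices) answer logits
    return nn.Linear(combined_embedding_dim, num_choices)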
def _build_output(self):
    # dynamic OCR-copying scores with pointer network
    self.ocr_ptr_net = OcrPtrNet(**self.config.classifier.ocr_ptr_net)

    # fixed answer vocabulary scores
    num_choices = registry.get(self._datasets[0] + "_num_final_outputs")
    # remove the OCR copying dimensions in LoRRA's classifier output
    # (OCR copying will be handled separately)
    num_choices -= self.config.classifier.ocr_max_num

    self.classifier = ClassifierLayer(
        self.config["classifier"]["type"],
        in_dim=self.mmt_config.hidden_size,
        out_dim=num_choices,
        **self.config["classifier"]["params"])

    self.answer_processor = registry.get(self._datasets[0] + "_answer_processor")
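# Hedged illustration (assumption, not the original forward pass): with the OCR-copying
# slots stripped from the classifier, the full answer distribution is typically the
# fixed-vocabulary logits concatenated with the pointer-network copy scores, restoring
# num_choices + ocr_max_num outputs in total.
import torch

def combine_vocab_and_ocr_scores(vocab_scores, ocr_copy_scores):
    # vocab_scores: (batch, num_choices); ocr_copy_scores: (batch, ocr_max_num)
    return torch.cat([vocab_scores, ocr_copy_scores], dim=-1)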
def _init_classifier(self):
    self.classifier = ClassifierLayer(
        self.config["classifier"]["type"],
        in_dim=self.config["classifier"]["params"]["feature_dim"],
        out_dim=self.vocab_size,
        **self.config["classifier"]["params"])
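# Hedged note (illustration only): unlike the answer classifiers above, BUTD's classifier
# scores every word in the captioning vocabulary at each decoding step, so one greedy
# step is just an argmax over the per-timestep vocabulary logits (mirroring get_data_t
# in the class below).
import torch

def greedy_step(step_logits):
    # step_logits: (batch, vocab_size) -> next token ids: (batch, 1)
    return torch.log_softmax(step_logits, dim=1).argmax(dim=1, keepdim=True)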
class BUTD(Pythia):
    def __init__(self, config):
        super().__init__(config)

    def build(self):
        self._build_word_embedding()
        self._init_feature_encoders("image")
        self._init_feature_embeddings("image")
        self._init_classifier()
        self._init_extras()

    def _build_word_embedding(self):
        self.text_processor = registry.get(self._datasets[0] + "_text_processor")
        self.vocab = self.text_processor.vocab
        self.vocab_size = self.vocab.get_size()
        self.word_embedding = self.vocab.get_embedding(
            torch.nn.Embedding, embedding_dim=self.config["embedding_dim"])
        setattr(self, "text_embeddings_out_dim", self.config["embedding_dim"])

    def _init_classifier(self):
        self.classifier = ClassifierLayer(
            self.config["classifier"]["type"],
            in_dim=self.config["classifier"]["params"]["feature_dim"],
            out_dim=self.vocab_size,
            **self.config["classifier"]["params"])

    def get_optimizer_parameters(self, config):
        params = [
            {"params": self.word_embedding.parameters()},
            {"params": self.image_feature_embeddings_list.parameters()},
            {"params": self.classifier.parameters()},
            {
                "params": self.image_feature_encoders.parameters(),
                "lr": (config["optimizer_attributes"]["params"]["lr"] * 0.1),
            },
        ]
        return params

    def prepare_data(self, sample_list, batch_size):
        setattr(self, "teacher_forcing", hasattr(sample_list, "text"))
        data = {}
        if self.teacher_forcing:
            caption_lengths, sort_ind = sample_list.caption_len.sort(
                dim=0, descending=True)
            data["decode_lengths"] = (caption_lengths - 1).tolist()
            sample_list.text = sample_list.text[sort_ind]
            sample_list.answers = sample_list.answers[sort_ind]
            sample_list.image_feature_0 = sample_list.image_feature_0[sort_ind]
            data["texts"] = sample_list.text
            timesteps = max(data["decode_lengths"])
            sample_list.add_field("targets", sample_list.text[:, 1:])
        else:
            data["texts"] = sample_list.answers.new_full(
                (batch_size, 1), self.vocab.SOS_INDEX, dtype=torch.long)
            timesteps = self.text_processor.max_length
            sample_list.add_field("targets", sample_list.answers[:, 0, 1:])
        return data, sample_list, timesteps

    def init_hidden_state(self, features):
        h = features.new_zeros(
            (features.size(0), self.config["classifier"]["params"]["hidden_dim"]),
            dtype=torch.float,
        )
        c = features.new_zeros(
            (features.size(0), self.config["classifier"]["params"]["hidden_dim"]),
            dtype=torch.float,
        )
        return h, c

    def get_data_t(self, t, data, batch_size_t, prev_output):
        if self.teacher_forcing:
            # Modify batch_size for timestep t
            batch_size_t = sum([l > t for l in data["decode_lengths"]])
        elif prev_output is not None and self.config["inference"]["type"] == "greedy":
            # Add the t-1 output words to data["texts"] for greedy decoding
            output_softmax = torch.log_softmax(prev_output, dim=1)
            _, indices = torch.max(output_softmax, dim=1, keepdim=True)
            data["texts"] = torch.cat(
                (data["texts"], indices.view(batch_size_t, 1)), dim=1)

        # Slice data based on batch_size at timestep t
        data["texts"] = data["texts"][:batch_size_t]
        if "state" in data:
            h1 = data["state"]["td_hidden"][0][:batch_size_t]
            c1 = data["state"]["td_hidden"][1][:batch_size_t]
            h2 = data["state"]["lm_hidden"][0][:batch_size_t]
            c2 = data["state"]["lm_hidden"][1][:batch_size_t]
        else:
            h1, c1 = self.init_hidden_state(data["texts"])
            h2, c2 = self.init_hidden_state(data["texts"])

        data["state"] = {"td_hidden": (h1, c1), "lm_hidden": (h2, c2)}
        registry.register("{}_lstm_state".format(h1.device), data["state"])

        return data, batch_size_t

    def forward(self, sample_list):
        # Stores the output probabilities. Not used for beam_search inference.
        scores = sample_list.answers.new_ones(
            (
                sample_list.answers.size(0),
                self.text_processor.max_length,
                self.vocab_size,
            ),
            dtype=torch.float,
        )

        # For beam search inference. Currently, beam search for BUTD works only
        # with batch_size = 1 and should be used with run_type inference only.
        # TODO: Implement batch beam search
        if self.config["inference"]["type"] == "beam_search":
            beam_search = BeamSearch(
                self.vocab, self.config["inference"]["params"]["beam_length"])
            sample_list = beam_search.init_batch(sample_list)

        batch_size = sample_list.image_feature_0.size(0)
        data, sample_list, timesteps = self.prepare_data(sample_list, batch_size)
        output = None
        batch_size_t = batch_size

        for t in range(timesteps):
            data, batch_size_t = self.get_data_t(t, data, batch_size_t, output)
            if self.config["inference"]["type"] == "beam_search":
                pi_t = data["texts"]
            else:
                pi_t = data["texts"][:, t].unsqueeze(-1)
            embedding = self.word_embedding(pi_t)
            attention_feature, _ = self.process_feature_embedding(
                "image", sample_list, embedding[:, 0, :],
                batch_size_t=batch_size_t)
            output = self.classifier(attention_feature)

            # Compute beam search decoding
            if self.config["inference"]["type"] == "beam_search":
                finish, data, batch_size_t = beam_search.search(t, data, output)
                if finish:
                    break
            else:
                scores[:batch_size_t, t] = output

        model_output = {"scores": scores}
        if self.config["inference"]["type"] == "beam_search":
            model_output["captions"] = beam_search.best_score()

        return model_output
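# Hedged illustration (helper name is mine, not in the original code): under teacher
# forcing the captions are sorted by length in prepare_data, so the "active" batch at
# timestep t is simply the number of captions still longer than t; only those rows get
# scores written at that step.
def active_batch_size(decode_lengths, t):
    # decode_lengths: per-sample decode lengths, sorted in descending order
    return sum(length > t for length in decode_lengths)

# e.g. active_batch_size([5, 4, 2], t=3) == 2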
class Pythia(BaseModel): def __init__(self, config): super().__init__(config) self.config = config self._global_config = registry.get("config") self._datasets = self._global_config.datasets.split(",") def build(self): self._build_word_embedding() self._init_text_embeddings("text") self._init_feature_encoders("image") self._init_feature_embeddings("image") self._init_combine_layer("image", "text") self._init_classifier(self._get_classifier_input_dim()) self._init_extras() def _build_word_embedding(self): assert len(self._datasets) > 0 text_processor = registry.get(self._datasets[0] + "_text_processor") vocab = text_processor.vocab self.word_embedding = vocab.get_embedding(torch.nn.Embedding, embedding_dim=300) def _init_text_embeddings(self, attr="text"): if "embeddings" not in attr: attr += "_embeddings" text_embeddings = [] text_embeddings_list_config = self.config[attr] embeddings_out_dim = 0 for text_embedding in text_embeddings_list_config: embedding_type = text_embedding.type embedding_kwargs = ConfigNode(text_embedding.params) self._update_text_embedding_args(embedding_kwargs) embedding = TextEmbedding(embedding_type, **embedding_kwargs) text_embeddings.append(embedding) embeddings_out_dim += embedding.text_out_dim setattr(self, attr + "_out_dim", embeddings_out_dim) setattr(self, attr, nn.ModuleList(text_embeddings)) def _update_text_embedding_args(self, args): # Add model_data_dir to kwargs args["model_data_dir"] = self.config["model_data_dir"] def _init_feature_encoders(self, attr): feat_encoders = [] feat_encoders_list_config = self.config[attr + "_feature_encodings"] feature_dim = self.config[attr + "_feature_dim"] setattr(self, attr + "_feature_dim", feature_dim) for feat_encoder in feat_encoders_list_config: encoder_type = feat_encoder["type"] encoder_kwargs = feat_encoder["params"] encoder_kwargs["model_data_dir"] = self.config["model_data_dir"] feat_model = ImageEncoder(encoder_type, feature_dim, **encoder_kwargs) feat_encoders.append(feat_model) setattr(self, attr + "_feature_dim", feat_model.out_dim) setattr(self, attr + "_feature_encoders", nn.ModuleList(feat_encoders)) def _init_feature_embeddings(self, attr): feature_embeddings_list = [] num_feature_feat = len( getattr(self.config, "{}_feature_encodings".format(attr))) self.feature_embeddings_out_dim = 0 for _ in range(num_feature_feat): feature_embeddings = [] feature_attn_model_list = self.config[attr + "_feature_embeddings"] for feature_attn_model_params in feature_attn_model_list: feature_embedding = ImageEmbedding( getattr(self, attr + "_feature_dim"), self.text_embeddings_out_dim, **feature_attn_model_params) feature_embeddings.append(feature_embedding) self.feature_embeddings_out_dim += feature_embedding.out_dim feature_embeddings = nn.ModuleList(feature_embeddings) feature_embeddings_list.append(feature_embeddings) self.feature_embeddings_out_dim *= getattr(self, attr + "_feature_dim") setattr(self, attr + "_feature_embeddings_out_dim", self.feature_embeddings_out_dim) del self.feature_embeddings_out_dim setattr( self, attr + "_feature_embeddings_list", nn.ModuleList(feature_embeddings_list), ) def _get_embeddings_attr(self, attr): embedding_attr1 = attr if hasattr(self, attr + "_embeddings_out_dim"): embedding_attr1 = attr + "_embeddings_out_dim" else: embedding_attr1 = attr + "_feature_embeddings_out_dim" return embedding_attr1 def _init_combine_layer(self, attr1, attr2): config_attr = attr1 + "_" + attr2 + "_modal_combine" multi_modal_combine_layer = ModalCombineLayer( self.config[config_attr]["type"], 
getattr(self, self._get_embeddings_attr(attr1)), getattr(self, self._get_embeddings_attr(attr2)), **self.config[config_attr]["params"]) setattr( self, attr1 + "_" + attr2 + "_multi_modal_combine_layer", multi_modal_combine_layer, ) def _init_classifier(self, combined_embedding_dim): # TODO: Later support multihead num_choices = registry.get(self._datasets[0] + "_num_final_outputs") self.classifier = ClassifierLayer( self.config["classifier"]["type"], in_dim=combined_embedding_dim, out_dim=num_choices, **self.config["classifier"]["params"]) def _init_extras(self): self.inter_model = None def get_optimizer_parameters(self, config): combine_layer = self.image_text_multi_modal_combine_layer params = [ { "params": self.word_embedding.parameters() }, { "params": self.image_feature_embeddings_list.parameters() }, { "params": self.text_embeddings.parameters() }, { "params": combine_layer.parameters() }, { "params": self.classifier.parameters() }, { "params": self.image_feature_encoders.parameters(), "lr": (config["optimizer_attributes"]["params"]["lr"] * 0.1), }, ] return params def _get_classifier_input_dim(self): return self.image_text_multi_modal_combine_layer.out_dim def process_text_embedding(self, sample_list, embedding_attr="text_embeddings", info=None): text_embeddings = [] # Get "text" attribute in case of "text_embeddings" case # and "context" attribute in case of "context_embeddings" texts = getattr(sample_list, embedding_attr.split("_")[0]) # Get embedding models text_embedding_models = getattr(self, embedding_attr) for text_embedding_model in text_embedding_models: # TODO: Move this logic inside if isinstance(text_embedding_model, PreExtractedEmbedding): embedding = text_embedding_model(sample_list.question_id) else: embedding = text_embedding_model(texts) text_embeddings.append(embedding) text_embeddding_total = torch.cat(text_embeddings, dim=1) return text_embeddding_total def process_feature_embedding(self, attr, sample_list, text_embedding_total, extra=[], batch_size_t=None): feature_embeddings = [] feature_attentions = [] features = [] batch_size_t = (sample_list.get_batch_size() if batch_size_t is None else batch_size_t) # Convert list of keys to the actual values extra = sample_list.get_fields(extra) feature_idx = 0 # Get all of the features, which are in the form, "image_feature_0" # "image_feature_1" ... while True: feature = getattr(sample_list, "{}_feature_{:d}".format(attr, feature_idx), None) if feature is None: break feature_idx += 1 feature = feature[:batch_size_t] features.append(feature) feature_encoders = getattr(self, attr + "_feature_encoders") # Each feature should have a separate image feature encoders assert len(features) == len(feature_encoders), ( "Number of feature encoders, {} are not equal " "to number of features, {}.".format(len(feature_encoders), len(features))) # Now, iterate to get final attended image features for i, feature in enumerate(features): # Get info related to the current feature. 
info is generally # in key of format "image_info_0" for 0th feature feature_info = getattr(sample_list, "{}_info_{:d}".format(attr, i), {}) # For Pythia, we need max_features to mask attention feature_dim = getattr(feature_info, "max_features", None) if feature_dim is not None: feature_dim = feature_dim[:batch_size_t] # Attribute in which encoders are saved, for "image" it # will be "image_feature_encoders", other example is # "context_feature_encoders" encoders_attr = attr + "_feature_encoders" feature_encoder = getattr(self, encoders_attr)[i] # Encode the features encoded_feature = feature_encoder(feature) # Get all of the feature embeddings list_attr = attr + "_feature_embeddings_list" feature_embedding_models = getattr(self, list_attr)[i] # Forward through these embeddings one by one for feature_embedding_model in feature_embedding_models: inp = (encoded_feature, text_embedding_total, feature_dim, extra) embedding, attention = feature_embedding_model(*inp) feature_embeddings.append(embedding) feature_attentions.append(attention.squeeze(-1)) # Concatenate all features embeddings and return along with attention feature_embedding_total = torch.cat(feature_embeddings, dim=1) return feature_embedding_total, feature_attentions def combine_embeddings(self, *args): feature_names = args[0] feature_embeddings = args[1] layer = "_".join(feature_names) + "_multi_modal_combine_layer" return getattr(self, layer)(*feature_embeddings) def calculate_logits(self, joint_embedding, **kwargs): return self.classifier(joint_embedding) def forward(self, sample_list): sample_list.text = self.word_embedding(sample_list.text) text_embedding_total = self.process_text_embedding(sample_list) image_embedding_total, _ = self.process_feature_embedding( "image", sample_list, text_embedding_total) if self.inter_model is not None: image_embedding_total = self.inter_model(image_embedding_total) joint_embedding = self.combine_embeddings( ["image", "text"], [image_embedding_total, text_embedding_total]) model_output = {"scores": self.calculate_logits(joint_embedding)} return model_output
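# Hedged summary sketch (toy stand-ins with hypothetical dimensions; the real fusion
# layer and dims come from the config): Pythia's forward pass above is essentially
# "embed question -> attend over image features with it -> fuse -> classify".
import torch
import torch.nn as nn

batch, text_dim, image_dim, hidden, num_choices = 2, 2048, 2048, 5000, 3000
q = torch.randn(batch, text_dim)           # stands in for text_embedding_total
v = torch.randn(batch, image_dim)          # stands in for image_embedding_total
fuse_q, fuse_v = nn.Linear(text_dim, hidden), nn.Linear(image_dim, hidden)
classifier = nn.Linear(hidden, num_choices)
joint = fuse_q(q) * fuse_v(v)              # a common multiplicative fusion; the actual
                                           # ModalCombineLayer type is config-driven
logits = classifier(joint)                 # (batch, num_choices)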
class Pythia(BaseModel): def __init__(self, config): super().__init__(config) self.config = config self._global_config = registry.get("config") self._datasets = self._global_config.datasets.split(",") def build(self): self._build_word_embedding() self._init_text_embeddings("text") self._init_feature_encoders("image") self._init_feature_embeddings("image") self._init_combine_layer("image", "text") self._init_classifier(self._get_classifier_input_dim()) self._init_extras() def _build_word_embedding(self): assert len(self._datasets) > 0 text_processor = registry.get(self._datasets[0] + "_text_processor") vocab = text_processor.vocab self.word_embedding = vocab.get_embedding(torch.nn.Embedding, embedding_dim=300) def _init_text_embeddings(self, attr="text"): if "embeddings" not in attr: attr += "_embeddings" text_embeddings = [] text_embeddings_list_config = self.config[attr] embeddings_out_dim = 0 for text_embedding in text_embeddings_list_config: embedding_type = text_embedding.type embedding_kwargs = ConfigNode(text_embedding.params) self._update_text_embedding_args(embedding_kwargs) embedding = TextEmbedding(embedding_type, **embedding_kwargs) text_embeddings.append(embedding) embeddings_out_dim += embedding.text_out_dim setattr(self, attr + "_out_dim", embeddings_out_dim) setattr(self, attr, nn.ModuleList(text_embeddings)) def _update_text_embedding_args(self, args): # Add model_data_dir to kwargs args["model_data_dir"] = self.config["model_data_dir"] def _init_feature_encoders(self, attr): feat_encoders = [] feat_encoders_list_config = self.config[attr + "_feature_encodings"] feature_dim = self.config[attr + "_feature_dim"] setattr(self, attr + "_feature_dim", feature_dim) for feat_encoder in feat_encoders_list_config: encoder_type = feat_encoder["type"] encoder_kwargs = feat_encoder["params"] encoder_kwargs["model_data_dir"] = self.config["model_data_dir"] feat_model = ImageEncoder(encoder_type, feature_dim, **encoder_kwargs) feat_encoders.append(feat_model) setattr(self, attr + "_feature_dim", feat_model.out_dim) setattr(self, attr + "_feature_encoders", nn.ModuleList(feat_encoders)) def _init_feature_embeddings(self, attr): feature_embeddings_list = [] num_feature_feat = len( getattr(self.config, "{}_feature_encodings".format(attr))) self.feature_embeddings_out_dim = 0 for _ in range(num_feature_feat): feature_embeddings = [] feature_attn_model_list = self.config[attr + "_feature_embeddings"] for feature_attn_model_params in feature_attn_model_list: feature_embedding = ImageEmbedding( getattr(self, attr + "_feature_dim"), self.text_embeddings_out_dim, **feature_attn_model_params) feature_embeddings.append(feature_embedding) self.feature_embeddings_out_dim += feature_embedding.out_dim feature_embeddings = nn.ModuleList(feature_embeddings) feature_embeddings_list.append(feature_embeddings) self.feature_embeddings_out_dim *= getattr(self, attr + "_feature_dim") setattr(self, attr + "_feature_embeddings_out_dim", self.feature_embeddings_out_dim) del self.feature_embeddings_out_dim setattr( self, attr + "_feature_embeddings_list", nn.ModuleList(feature_embeddings_list), ) def _get_embeddings_attr(self, attr): embedding_attr1 = attr if hasattr(self, attr + "_embeddings_out_dim"): embedding_attr1 = attr + "_embeddings_out_dim" else: embedding_attr1 = attr + "_feature_embeddings_out_dim" return embedding_attr1 def _init_combine_layer(self, attr1, attr2): config_attr = attr1 + "_" + attr2 + "_modal_combine" multi_modal_combine_layer = ModalCombineLayer( self.config[config_attr]["type"], 
getattr(self, self._get_embeddings_attr(attr1)), getattr(self, self._get_embeddings_attr(attr2)), **self.config[config_attr]["params"]) setattr( self, attr1 + "_" + attr2 + "_multi_modal_combine_layer", multi_modal_combine_layer, ) def _init_classifier(self, combined_embedding_dim): # TODO: Later support multihead num_choices = registry.get(self._datasets[0] + "_num_final_outputs") self.classifier = ClassifierLayer( self.config["classifier"]["type"], in_dim=combined_embedding_dim, out_dim=num_choices, **self.config["classifier"]["params"]) def _init_extras(self): self.inter_model = None def get_optimizer_parameters(self, config): combine_layer = self.image_text_multi_modal_combine_layer params = [ { "params": self.word_embedding.parameters() }, { "params": self.image_feature_embeddings_list.parameters() }, { "params": self.text_embeddings.parameters() }, { "params": combine_layer.parameters() }, { "params": self.classifier.parameters() }, { "params": self.image_feature_encoders.parameters(), "lr": (config["optimizer_attributes"]["params"]["lr"] * 0.1), }, ] return params def _get_classifier_input_dim(self): return self.image_text_multi_modal_combine_layer.out_dim def process_text_embedding(self, sample_list, embedding_attr="text_embeddings", info=None): text_embeddings = [] #pdb.set_trace() # Get "text" attribute in case of "text_embeddings" case # and "context" attribute in case of "context_embeddings" if not info: texts = getattr(sample_list, embedding_attr.split("_")[0]) elif info == "sub_question": texts = getattr(sample_list, embedding_attr.split("_")[0] + '_sq') elif info == "other_question": texts = getattr(sample_list, embedding_attr.split("_")[0] + '_oq') # Get embedding models text_embedding_models = getattr(self, embedding_attr) for text_embedding_model in text_embedding_models: # TODO: Move this logic inside if isinstance(text_embedding_model, PreExtractedEmbedding): embedding = text_embedding_model(sample_list.question_id) else: embedding = text_embedding_model(texts) text_embeddings.append(embedding) text_embeddding_total = torch.cat(text_embeddings, dim=1) return text_embeddding_total def process_feature_embedding(self, attr, sample_list, text_embedding_total, extra=[], batch_size_t=None): feature_embeddings = [] feature_attentions = [] features = [] batch_size_t = (sample_list.get_batch_size() if batch_size_t is None else batch_size_t) # Convert list of keys to the actual values extra = sample_list.get_fields(extra) feature_idx = 0 # Get all of the features, which are in the form, "image_feature_0" # "image_feature_1" ... while True: feature = getattr(sample_list, "{}_feature_{:d}".format(attr, feature_idx), None) if feature is None: break feature_idx += 1 feature = feature[:batch_size_t] features.append(feature) feature_encoders = getattr(self, attr + "_feature_encoders") # Each feature should have a separate image feature encoders assert len(features) == len(feature_encoders), ( "Number of feature encoders, {} are not equal " "to number of features, {}.".format(len(feature_encoders), len(features))) # Now, iterate to get final attended image features for i, feature in enumerate(features): # Get info related to the current feature. 
info is generally # in key of format "image_info_0" for 0th feature feature_info = getattr(sample_list, "{}_info_{:d}".format(attr, i), {}) # For Pythia, we need max_features to mask attention feature_dim = getattr(feature_info, "max_features", None) if feature_dim is not None: feature_dim = feature_dim[:batch_size_t] # Attribute in which encoders are saved, for "image" it # will be "image_feature_encoders", other example is # "context_feature_encoders" encoders_attr = attr + "_feature_encoders" feature_encoder = getattr(self, encoders_attr)[i] # Encode the features encoded_feature = feature_encoder(feature) # Get all of the feature embeddings list_attr = attr + "_feature_embeddings_list" feature_embedding_models = getattr(self, list_attr)[i] # Forward through these embeddings one by one for feature_embedding_model in feature_embedding_models: inp = (encoded_feature, text_embedding_total, feature_dim, extra) embedding, attention = feature_embedding_model(*inp) feature_embeddings.append(embedding) feature_attentions.append(attention.squeeze(-1)) # Concatenate all features embeddings and return along with attention feature_embedding_total = torch.cat(feature_embeddings, dim=1) return feature_embedding_total, feature_attentions def combine_embeddings(self, *args): feature_names = args[0] feature_embeddings = args[1] layer = "_".join(feature_names) + "_multi_modal_combine_layer" layer_model = getattr(self, layer) joint_embeddings = layer_model(*feature_embeddings) if args[2] == "main": self.question_embedding = layer_model.question_embedding elif args[2] == "sub_question": self.question_embedding_sq = layer_model.question_embedding elif args[2] == "other_question": self.question_embedding_oq = layer_model.question_embedding #pdb.set_trace() #self.combine_layer = self.layer #joint_embedding = self.combine_layer(feature_embeddings) #pdb.set_trace() return joint_embeddings #return getattr(self, layer)(*feature_embeddings) def calculate_logits(self, joint_embedding, **kwargs): return self.classifier(joint_embedding) def compute_grad_cam(self, sample_list, model_output, question=None): #pdb.set_trace() #pdb.set_trace() if question == "main": #self.importance_vectors_reas = [] scores = model_output['scores'] classes = sample_list['gt_answer_index'] classes_one_hot = torch.zeros_like(scores) classes_one_hot[range(classes_one_hot.shape[0]), classes] = 1 #grads = torch.autograd.grad(outputs = scores, inputs = self.joint_embedding, grad_outputs = classes_one_hot, create_graph=True)[0].to(self.device) grads = torch.autograd.grad(outputs=scores, inputs=self.joint_embedding, grad_outputs=classes_one_hot, create_graph=True)[0] importance_vectors_cam = grads * self.joint_embedding #self.importance_vectors_reas.append(self.question_embedding) #pdb.set_trace() self.importance_vectors_reas = importance_vectors_cam #self.importance_vectors_reas.append(torch.cat((importance_vectors_cam, self.question_embedding), 1)) elif question == "sq": #self.importance_vectors_sq = [] scores = model_output['scores_sq'] classes = sample_list['gt_answer_index_sq'] classes_one_hot = torch.zeros_like(scores) classes_one_hot[range(classes_one_hot.shape[0]), classes] = 1 #grads = torch.autograd.grad(outputs = scores, inputs = self.joint_embedding_sq, grad_outputs = classes_one_hot, create_graph=True)[0].to(self.device) grads = torch.autograd.grad(outputs=scores, inputs=self.joint_embedding_sq, grad_outputs=classes_one_hot, create_graph=True)[0] importance_vectors_cam = grads * self.joint_embedding_sq 
#self.importance_vectors_sq.append(self.question_embedding_sq) self.importance_vectors_sq = importance_vectors_cam #self.importance_vectors_sq.append(torch.cat((importance_vectors_cam, self.question_embedding_sq), 1)) elif question == "oq": #self.importance_vectors_oq = [] scores = model_output['scores_oq'] classes = sample_list['gt_answer_index_oq'] classes_one_hot = torch.zeros_like(scores) classes_one_hot[range(classes_one_hot.shape[0]), classes] = 1 #grads = torch.autograd.grad(outputs = scores, inputs = self.joint_embedding_oq, grad_outputs = classes_one_hot, create_graph=True)[0].to(self.device) grads = torch.autograd.grad(outputs=scores, inputs=self.joint_embedding_oq, grad_outputs=classes_one_hot, create_graph=True)[0] importance_vectors_cam = grads * self.joint_embedding_oq #self.importance_vectors_oq.append(self.question_embedding_oq) self.importance_vectors_oq = importance_vectors_cam #self.importance_vectors_oq.append(torch.cat((importance_vectors_cam, self.question_embedding_oq), 1)) def cosine_distance(self, vec_1, vec_2): batched_distance_vector = [] cos_similarity = nn.CosineSimilarity(dim=1, eps=1e-6) for i in range(vec_1.shape[0]): norm_vec_1 = vec_1[i] / torch.max(vec_1[i]) norm_vec_2 = vec_2[i] / torch.max(vec_2[i]) distance = 1 - cos_similarity(norm_vec_1.unsqueeze(0), norm_vec_2.unsqueeze(0)) batched_distance_vector.append(distance) return torch.cat(batched_distance_vector) def compute_distances(self, sample_list, model_output): model_output['distance_reas_sub'] = self.cosine_distance( self.importance_vectors_reas, self.importance_vectors_sq) model_output['distance_reas_other'] = self.cosine_distance( self.importance_vectors_reas, self.importance_vectors_oq) def forward(self, sample_list): # Compute the scores for the reasoning question sample_list.text = self.word_embedding(sample_list.text) text_embedding_total = self.process_text_embedding(sample_list) image_embedding_total, _ = self.process_feature_embedding( "image", sample_list, text_embedding_total) if self.inter_model is not None: image_embedding_total = self.inter_model(image_embedding_total) joint_embedding = self.combine_embeddings( ["image", "text"], [image_embedding_total, text_embedding_total], "main") #pdb.set_trace() self.joint_embedding = joint_embedding model_output = {"scores": self.calculate_logits(joint_embedding)} # Compute the scores for the sub-question sample_list.text_sq = self.word_embedding(sample_list.text_sq) text_embedding_total = self.process_text_embedding(sample_list, info="sub_question") image_embedding_total, _ = self.process_feature_embedding( "image", sample_list, text_embedding_total) joint_embedding_sq = self.combine_embeddings( ["image", "text"], [image_embedding_total, text_embedding_total], "sub_question") self.joint_embedding_sq = joint_embedding_sq model_output["scores_sq"] = self.calculate_logits(joint_embedding_sq) sample_list.text_oq = self.word_embedding(sample_list.text_oq) text_embedding_total = self.process_text_embedding( sample_list, info="other_question") image_embedding_total, _ = self.process_feature_embedding( "image", sample_list, text_embedding_total) joint_embedding_oq = self.combine_embeddings( ["image", "text"], [image_embedding_total, text_embedding_total], "other_question") self.joint_embedding_oq = joint_embedding_oq model_output["scores_oq"] = self.calculate_logits(joint_embedding_oq) self.compute_grad_cam(sample_list, model_output, question="main") self.compute_grad_cam(sample_list, model_output, question="sq") self.compute_grad_cam(sample_list, 
model_output, question="oq") self.compute_distances(sample_list, model_output) #self.compute_grad_cam() #pdb.set_trace() #image_embedding_total, _ = self.process_feature_embedding( # "image", sample_list, text_embedding_total #) #if self.inter_model is not None: # image_embedding_total = self.inter_model(image_embedding_total) #joint_embedding = self.combine_embeddings( # ["image", "text"], [image_embedding_total, text_embedding_total] #) #self.joint_embedding = joint_embedding #model_output = {"scores": self.calculate_logits(joint_embedding)} return model_output
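# Condensed standalone sketch of the Grad-CAM-style consistency signal used above
# (function names are mine; the behaviour mirrors compute_grad_cam and cosine_distance).
import torch
import torch.nn.functional as F

def importance_vectors(scores, joint_embedding, target_classes):
    # Backpropagate only the ground-truth class scores onto the joint embedding;
    # gradient * activation gives a Grad-CAM-style importance vector per sample.
    one_hot = torch.zeros_like(scores)
    one_hot[torch.arange(scores.shape[0]), target_classes] = 1
    grads = torch.autograd.grad(outputs=scores, inputs=joint_embedding,
                                grad_outputs=one_hot, create_graph=True)[0]
    return grads * joint_embedding

def cosine_distance(vec_1, vec_2, eps=1e-6):
    # 1 - cosine similarity per row, after normalising each row by its max (as above).
    vec_1 = vec_1 / vec_1.max(dim=1, keepdim=True).values
    vec_2 = vec_2 / vec_2.max(dim=1, keepdim=True).values
    return 1 - F.cosine_similarity(vec_1, vec_2, dim=1, eps=eps)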
class Pythia(BaseModel): def __init__(self, config): super().__init__(config) self.config = config self._global_config = registry.get("config") self._datasets = self._global_config.datasets.split(",") def build(self): self._build_word_embedding() self._init_text_embeddings("text") self._init_feature_encoders("image") self._init_feature_embeddings("image") self._init_combine_layer("image", "text") self._init_classifier(self._get_classifier_input_dim()) self._init_extras() def _build_word_embedding(self): assert len(self._datasets) > 0 text_processor = registry.get(self._datasets[0] + "_text_processor") vocab = text_processor.vocab self.word_embedding = vocab.get_embedding(torch.nn.Embedding, embedding_dim=300) def _init_text_embeddings(self, attr="text"): if "embeddings" not in attr: attr += "_embeddings" text_embeddings = [] text_embeddings_list_config = self.config[attr] embeddings_out_dim = 0 for text_embedding in text_embeddings_list_config: embedding_type = text_embedding.type embedding_kwargs = ConfigNode(text_embedding.params) self._update_text_embedding_args(embedding_kwargs) embedding = TextEmbedding(embedding_type, **embedding_kwargs) text_embeddings.append(embedding) embeddings_out_dim += embedding.text_out_dim setattr(self, attr + "_out_dim", embeddings_out_dim) setattr(self, attr, nn.ModuleList(text_embeddings)) def _update_text_embedding_args(self, args): # Add model_data_dir to kwargs args["model_data_dir"] = self.config["model_data_dir"] def _init_feature_encoders(self, attr): feat_encoders = [] feat_encoders_list_config = self.config[attr + "_feature_encodings"] feature_dim = self.config[attr + "_feature_dim"] setattr(self, attr + "_feature_dim", feature_dim) for feat_encoder in feat_encoders_list_config: encoder_type = feat_encoder["type"] encoder_kwargs = feat_encoder["params"] encoder_kwargs["model_data_dir"] = self.config["model_data_dir"] feat_model = ImageEncoder(encoder_type, feature_dim, **encoder_kwargs) feat_encoders.append(feat_model) setattr(self, attr + "_feature_dim", feat_model.out_dim) setattr(self, attr + "_feature_encoders", nn.ModuleList(feat_encoders)) def _init_feature_embeddings(self, attr): feature_embeddings_list = [] num_feature_feat = len( getattr(self.config, "{}_feature_encodings".format(attr))) self.feature_embeddings_out_dim = 0 for _ in range(num_feature_feat): feature_embeddings = [] feature_attn_model_list = self.config[attr + "_feature_embeddings"] for feature_attn_model_params in feature_attn_model_list: feature_embedding = ImageEmbedding( getattr(self, attr + "_feature_dim"), self.text_embeddings_out_dim, **feature_attn_model_params) feature_embeddings.append(feature_embedding) self.feature_embeddings_out_dim += feature_embedding.out_dim feature_embeddings = nn.ModuleList(feature_embeddings) feature_embeddings_list.append(feature_embeddings) self.feature_embeddings_out_dim *= getattr(self, attr + "_feature_dim") setattr(self, attr + "_feature_embeddings_out_dim", self.feature_embeddings_out_dim) del self.feature_embeddings_out_dim setattr( self, attr + "_feature_embeddings_list", nn.ModuleList(feature_embeddings_list), ) def _get_embeddings_attr(self, attr): embedding_attr1 = attr if hasattr(self, attr + "_embeddings_out_dim"): embedding_attr1 = attr + "_embeddings_out_dim" else: embedding_attr1 = attr + "_feature_embeddings_out_dim" return embedding_attr1 def _init_combine_layer(self, attr1, attr2): config_attr = attr1 + "_" + attr2 + "_modal_combine" multi_modal_combine_layer = ModalCombineLayer( self.config[config_attr]["type"], 
getattr(self, self._get_embeddings_attr(attr1)), getattr(self, self._get_embeddings_attr(attr2)), **self.config[config_attr]["params"]) setattr( self, attr1 + "_" + attr2 + "_multi_modal_combine_layer", multi_modal_combine_layer, ) def _init_classifier(self, combined_embedding_dim): # TODO: Later support multihead num_choices = registry.get(self._datasets[0] + "_num_final_outputs") self.classifier = ClassifierLayer( self.config["classifier"]["type"], in_dim=combined_embedding_dim, out_dim=num_choices, **self.config["classifier"]["params"]) def _init_extras(self): self.inter_model = None def get_optimizer_parameters(self, config): combine_layer = self.image_text_multi_modal_combine_layer params = [ { "params": self.word_embedding.parameters() }, { "params": self.image_feature_embeddings_list.parameters() }, { "params": self.text_embeddings.parameters() }, { "params": combine_layer.parameters() }, { "params": self.classifier.parameters() }, { "params": self.image_feature_encoders.parameters(), "lr": (config["optimizer_attributes"]["params"]["lr"] * 0.1), }, ] return params def _get_classifier_input_dim(self): return self.image_text_multi_modal_combine_layer.out_dim def process_text_embedding(self, sample_list, embedding_attr="text_embeddings", info=None): text_embeddings = [] # Get "text" attribute in case of "text_embeddings" case # and "context" attribute in case of "context_embeddings" texts = getattr(sample_list, embedding_attr.split("_")[0]) # Get embedding models text_embedding_models = getattr(self, embedding_attr) for text_embedding_model in text_embedding_models: # TODO: Move this logic inside if isinstance(text_embedding_model, PreExtractedEmbedding): embedding = text_embedding_model(sample_list.question_id) else: embedding = text_embedding_model(texts) text_embeddings.append(embedding) # # visualize decomposed question attention # image_id = getattr(sample_list, "image_id") # question_id = getattr(sample_list, "question_id").cpu() # question_id = question_id.numpy() # batch_size_t, _, _ = text_embeddings[0][7].shape # for cnt in range(0, batch_size_t): # # image_path_org = './save/temp_check/'+question_id[cnt]+'image_id.pdh' # # torch.save(image_id[cnt], image_path_org) # attn_path_org = './save/temp_check/'+str(question_id[cnt])+'_a_o.pdh' # torch.save(text_embeddings[0][7][cnt], attn_path_org) # attn_path_org = './save/temp_check/'+str(question_id[cnt])+'_a_oo.pdh' # torch.save(text_embeddings[0][8][cnt], attn_path_org) # attn_path_org = './save/temp_check/'+str(question_id[cnt])+'_a_ot.pdh' # torch.save(text_embeddings[0][9][cnt], attn_path_org) # attn_path_org = './save/temp_check/'+str(question_id[cnt])+'_a_t.pdh' # torch.save(text_embeddings[0][10][cnt], attn_path_org) # attn_path_org = './save/temp_check/'+str(question_id[cnt])+'_a_tt.pdh' # torch.save(text_embeddings[0][11][cnt], attn_path_org) # attn_path_org = './save/temp_check/'+str(question_id[cnt])+'_a_to.pdh' # torch.save(text_embeddings[0][12][cnt], attn_path_org) return text_embeddings[0][0], text_embeddings[0][1], text_embeddings[ 0][2], text_embeddings[0][3], text_embeddings[0][ 4], text_embeddings[0][5], text_embeddings[0][6] def process_feature_embedding(self, attr, sample_list, s_central, s_homo=None, s_hetero=None, pre_ques_embed=None, obj_feats=None, ocr_feats=None): """ parameters: input: attr: "image" or "context" sample_list: just sample_list s_central: question features for guiding purpose, torch.Size([128, 2048]) s_o/s_t s_homo: s_oo/s_tt s_hetero: s_ot/s_to output: """ # add obj bbox feats and image 
size batch, bbox_num, obj_feat_dim = obj_feats.shape _, _, ocr_feat_dim = ocr_feats.shape knn_k = 5 loc_dim = 5 # expand obj_feats temp_expand_obj_feat = obj_feats[0][0] temp_expand_obj_feat = temp_expand_obj_feat.expand( batch, 1, obj_feat_dim) * 0 temp_expand_obj_feat = torch.cat((obj_feats, temp_expand_obj_feat), 1) # expand ocr_feats temp_expand_ocr_feat = ocr_feats[0][0] temp_expand_ocr_feat = temp_expand_ocr_feat.expand( batch, 1, ocr_feat_dim) * 0 temp_expand_ocr_feat = torch.cat((ocr_feats, temp_expand_ocr_feat), 1) if attr == 'image': batch_size_t = (sample_list.get_batch_size()) # Get "image_feature_0" feature = getattr(sample_list, "{}_feature_{:d}".format(attr, 0), None) feature = feature[:batch_size_t] # Get info related to the current feature. info is generally # in key of format "image_info_0" for 0th feature feature_info = getattr(sample_list, "{}_info_{:d}".format(attr, 0), {}) # For Pythia, we need max_features to mask attention feature_dim = getattr(feature_info, "max_features", None) if feature_dim is not None: feature_dim = feature_dim[:batch_size_t] # Get feature embedding feature_embedding_model = getattr(self, attr + "_feature_embedding") encoded_feature = obj_feats batch, bbox_num, obj_feat_dim = encoded_feature.shape # obj_obj_edge_feature = None # oo edge generation obj_obj_edge_feature = torch.zeros( (batch, bbox_num, knn_k, obj_feat_dim + loc_dim)).float() obj_obj_edge_feature = obj_obj_edge_feature.cuda() oo_edge = getattr(getattr(sample_list, "ocr_bbox"), "edge_oo") oo_edgefeats = getattr(getattr(sample_list, "ocr_bbox"), "edge_oofeats") for i in range(batch): obj_obj_edge_feature[i] = torch.cat( (oo_edgefeats[i], temp_expand_obj_feat[i][oo_edge[i]]), 2) # obj_text_edge_feature = None # ot edge generation obj_text_edge_feature = torch.zeros( (batch, bbox_num, knn_k, ocr_feat_dim + loc_dim)).float() obj_text_edge_feature = obj_text_edge_feature.cuda() ot_edge = getattr(getattr(sample_list, "ocr_bbox"), "edge_ot") ot_edgefeats = getattr(getattr(sample_list, "ocr_bbox"), "edge_otfeats") for i in range(batch): obj_text_edge_feature[i] = torch.cat( (ot_edgefeats[i], temp_expand_ocr_feat[i][ot_edge[i]]), 2) oo_edge_feature = obj_obj_edge_feature ot_edge_feature = obj_text_edge_feature s_o, s_oo, s_ot = s_central, s_homo, s_hetero # for ablation study purpose, # o feature + oo relation + ot relation if (s_oo is not None) and (oo_edge_feature is not None) and ( s_ot is not None) and (ot_edge_feature is not None) and (pre_ques_embed is not None): inp = (attr, encoded_feature, s_o, feature_dim, s_oo, oo_edge_feature, s_ot, ot_edge_feature, pre_ques_embed) # o feature + oo relation elif (s_oo is not None) and (oo_edge_feature is not None) and (pre_ques_embed is not None): inp = (attr, encoded_feature, s_o, feature_dim, s_oo, oo_edge_feature, pre_ques_embed) # o feature + ot relation elif (s_ot is not None) and (ot_edge_feature is not None) and (pre_ques_embed is not None): inp = (attr, encoded_feature, s_o, feature_dim, s_ot, ot_edge_feature, pre_ques_embed) # o feature only else: inp = (attr, encoded_feature, s_o, feature_dim) g_o = feature_embedding_model(*inp) return g_o elif attr == 'context': batch_size_t = (sample_list.get_batch_size()) # Get "context_feature_0" feature = getattr(sample_list, "{}_feature_{:d}".format(attr, 0), None) feature = feature[:batch_size_t] # Get info related to the current feature. 
info is generally # in key of format "image_info_0" for 0th feature feature_info = getattr(sample_list, "{}_info_{:d}".format(attr, 0), {}) # For Pythia, we need max_features to mask attention feature_dim = getattr(feature_info, "max_features", None) if feature_dim is not None: feature_dim = feature_dim[:batch_size_t] # Get feature embedding feature_embedding_model = getattr(self, "context_feature_embedding") encoded_feature = ocr_feats batch, bbox_num, _ = encoded_feature.shape # text_text_edge_feature = None # tt edge generation text_text_edge_feature = torch.zeros( (batch, bbox_num, knn_k, ocr_feat_dim + loc_dim)).float() text_text_edge_feature = text_text_edge_feature.cuda() tt_edge = getattr(getattr(sample_list, "ocr_bbox"), "edge_tt") tt_edgefeats = getattr(getattr(sample_list, "ocr_bbox"), "edge_ttfeats") for i in range(batch): text_text_edge_feature[i] = torch.cat( (tt_edgefeats[i], temp_expand_ocr_feat[i][tt_edge[i]]), 2) # text_obj_edge_feature = None # to edge generation text_obj_edge_feature = torch.zeros( (batch, bbox_num, knn_k, obj_feat_dim + loc_dim)).float() text_obj_edge_feature = text_obj_edge_feature.cuda() to_edge = getattr(getattr(sample_list, "ocr_bbox"), "edge_to") to_edgefeats = getattr(getattr(sample_list, "ocr_bbox"), "edge_tofeats") for i in range(batch): text_obj_edge_feature[i] = torch.cat( (to_edgefeats[i], temp_expand_obj_feat[i][to_edge[i]]), 2) tt_edge_feature = text_text_edge_feature to_edge_feature = text_obj_edge_feature s_t, s_tt, s_to = s_central, s_homo, s_hetero # for ablation study purpose # t feature + tt relation + to relation if (s_tt is not None) and (tt_edge_feature is not None) and ( s_to is not None) and (to_edge_feature is not None) and (pre_ques_embed is not None): inp = (attr, encoded_feature, s_t, feature_dim, s_tt, tt_edge_feature, s_to, to_edge_feature, pre_ques_embed) # t feature + tt relation elif (s_tt is not None) and (tt_edge_feature is not None) and (pre_ques_embed is not None): inp = (attr, encoded_feature, s_t, feature_dim, s_tt, tt_edge_feature, pre_ques_embed) # t feature + to relation elif (s_to is not None) and (to_edge_feature is not None) and (pre_ques_embed is not None): inp = (attr, encoded_feature, s_t, feature_dim, s_to, to_edge_feature, pre_ques_embed) # t feature only else: inp = (attr, encoded_feature, s_t, feature_dim) g_t, updated_ocr = feature_embedding_model(*inp) return g_t, updated_ocr def combine_embeddings(self, *args): feature_names = args[0] feature_embeddings = args[1] layer = "_".join(feature_names) + "_multi_modal_combine_layer" return getattr(self, layer)(*feature_embeddings) def calculate_logits(self, joint_embedding, **kwargs): return self.classifier(joint_embedding) def forward(self, sample_list): sample_list.text = self.word_embedding(sample_list.text) text_embedding_total = self.process_text_embedding(sample_list) image_embedding_total, _ = self.process_feature_embedding( "image", sample_list, text_embedding_total) if self.inter_model is not None: image_embedding_total = self.inter_model(image_embedding_total) joint_embedding = self.combine_embeddings( ["image", "text"], [image_embedding_total, text_embedding_total]) model_output = {"scores": self.calculate_logits(joint_embedding)} return model_output
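# Compact, hypothetical re-statement of the per-batch edge-feature construction loops
# above (names and shapes are mine): neighbour indices gather node features from a
# zero-padded table, and the result is concatenated with the geometric edge features.
import torch

def build_edge_features(node_feats, edge_geom, nbr_idx):
    # node_feats: (B, N, D); edge_geom: (B, N, K, G); nbr_idx: (B, N, K), where the
    # padding index N means "no neighbour" and should gather an all-zero feature.
    B, N, D = node_feats.shape
    padded = torch.cat([node_feats, node_feats.new_zeros(B, 1, D)], dim=1)  # (B, N+1, D)
    nbr_feats = torch.stack([padded[b][nbr_idx[b]] for b in range(B)])      # (B, N, K, D)
    return torch.cat([edge_geom, nbr_feats], dim=-1)                        # (B, N, K, G+D)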
def build(self):
    self.mmt_config = BertConfig(**self.config.mmt)
    self.mmt = MMT(self.mmt_config)

    self.so_to_mmt_in = nn.Linear(3 * 1536, self.mmt_config.hidden_size)
    self.st_to_mmt_in = nn.Linear(3 * 1536, self.mmt_config.hidden_size)
    self.so_layer_norm = BertLayerNorm(self.mmt_config.hidden_size)
    self.st_layer_norm = BertLayerNorm(self.mmt_config.hidden_size)
    self.so_drop = nn.Dropout(0.1)
    self.st_drop = nn.Dropout(0.1)

    self.linear_go_to_mmt_in = nn.Linear(2048, self.mmt_config.hidden_size)
    self.linear_gt_to_mmt_in = nn.Linear(300, self.mmt_config.hidden_size)
    self.go_layer_norm = BertLayerNorm(self.mmt_config.hidden_size)
    self.gt_layer_norm = BertLayerNorm(self.mmt_config.hidden_size)
    self.go_drop = nn.Dropout(0.1)
    self.gt_drop = nn.Dropout(0.1)

    self.linear_updated_ocr_to_mmt_in = nn.Linear(300, self.mmt_config.hidden_size)
    self.updated_ocr_layer_norm = BertLayerNorm(self.mmt_config.hidden_size)
    self.updated_ocr_drop = nn.Dropout(self.config.ocr.dropout_prob)

    self.linear_joint = nn.Linear(1536, 768)

    self.answer_processor = registry.get(self._datasets[0] + "_answer_processor")
    self.ocr_ptr_net = OcrPtrNet(**self.config.classifier.ocr_ptr_net)

    # modules requiring custom learning rates (usually for finetuning)
    self.finetune_modules = []

    self._build_txt_encoding()
    self._build_obj_encoding()
    self._build_ocr_encoding()
    self._init_text_embeddings("text")

    # init feature embedding for "image"
    setattr(self, "image_feature_dim", self.config["image_feature_dim"])
    self.feature_embeddings_out_dim = 0
    feature_attn_model_params = self.config["image_feature_embeddings"][0]
    feature_embedding = ImageEmbedding(
        getattr(self, "image_feature_dim"),
        self.text_embeddings_out_dim,
        **feature_attn_model_params)
    self.feature_embeddings_out_dim += feature_embedding.out_dim
    self.feature_embeddings_out_dim *= getattr(self, "image_feature_dim")
    setattr(self, "image_feature_embeddings_out_dim",
            self.feature_embeddings_out_dim)
    del self.feature_embeddings_out_dim
    setattr(self, "image_feature_embedding", feature_embedding)

    # init feature embedding for "context"
    setattr(self, "context_feature_dim", self.config["context_feature_dim"])
    self.feature_embeddings_out_dim = 0
    feature_attn_model_params = self.config["context_feature_embeddings"][0]
    feature_embedding = ImageEmbedding(
        getattr(self, "context_feature_dim"),
        self.text_embeddings_out_dim,
        **feature_attn_model_params)
    self.feature_embeddings_out_dim += feature_embedding.out_dim
    self.feature_embeddings_out_dim *= getattr(self, "context_feature_dim")
    setattr(self, "context_feature_embeddings_out_dim",
            self.feature_embeddings_out_dim)
    del self.feature_embeddings_out_dim
    setattr(self, "context_feature_embedding", feature_embedding)

    self._init_combine_layer("image", "text")

    num_choices = registry.get(self._datasets[0] + "_num_final_outputs")
    self.classifier = ClassifierLayer(
        self.config["classifier"]["type"],
        in_dim=768,
        out_dim=num_choices - 50,
        **self.config["classifier"]["params"])
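# Hedged sketch (hypothetical helper, not in the original file): every input stream in
# build() above is projected into the multimodal transformer's hidden size with the same
# Linear -> LayerNorm -> Dropout pattern; nn.LayerNorm stands in for BertLayerNorm here.
import torch.nn as nn

class StreamProjection(nn.Module):
    def __init__(self, in_dim, hidden_size, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, hidden_size)
        self.norm = nn.LayerNorm(hidden_size)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        # (batch, seq, in_dim) -> (batch, seq, hidden_size)
        return self.drop(self.norm(self.proj(x)))

# e.g. obj_stream = StreamProjection(2048, 768); obj_in = obj_stream(object_features)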