Example #1
    def _init_classifier(self, combined_embedding_dim):
        # TODO: Later support multihead
        num_choices = registry.get(self._datasets[0] + "_num_final_outputs")

        self.classifier = ClassifierLayer(
            self.config["classifier"]["type"],
            in_dim=combined_embedding_dim,
            out_dim=num_choices,
            **self.config["classifier"]["params"])
Example #2
    def _build_output(self):
        # dynamic OCR-copying scores with pointer network
        self.ocr_ptr_net = OcrPtrNet(**self.config.classifier.ocr_ptr_net)

        # fixed answer vocabulary scores
        num_choices = registry.get(self._datasets[0] + "_num_final_outputs")
        # remove the OCR copying dimensions in LoRRA's classifier output
        # (OCR copying will be handled separately)
        num_choices -= self.config.classifier.ocr_max_num
        self.classifier = ClassifierLayer(
            self.config["classifier"]["type"],
            in_dim=self.mmt_config.hidden_size,
            out_dim=num_choices,
            **self.config["classifier"]["params"])

        self.answer_processor = registry.get(self._datasets[0] +
                                             "_answer_processor")
Example #3
File: butd.py  Project: tajalagawani/PyTaj
 def _init_classifier(self):
     self.classifier = ClassifierLayer(
         self.config["classifier"]["type"],
         in_dim=self.config["classifier"]["params"]["feature_dim"],
         out_dim=self.vocab_size,
         **self.config["classifier"]["params"])
Example #4
File: butd.py  Project: tajalagawani/PyTaj
class BUTD(Pythia):
    def __init__(self, config):
        super().__init__(config)

    def build(self):
        self._build_word_embedding()
        self._init_feature_encoders("image")
        self._init_feature_embeddings("image")
        self._init_classifier()
        self._init_extras()

    def _build_word_embedding(self):
        self.text_processor = registry.get(self._datasets[0] +
                                           "_text_processor")
        self.vocab = self.text_processor.vocab
        self.vocab_size = self.vocab.get_size()
        self.word_embedding = self.vocab.get_embedding(
            torch.nn.Embedding, embedding_dim=self.config["embedding_dim"])
        setattr(self, "text_embeddings_out_dim", self.config["embedding_dim"])

    def _init_classifier(self):
        self.classifier = ClassifierLayer(
            self.config["classifier"]["type"],
            in_dim=self.config["classifier"]["params"]["feature_dim"],
            out_dim=self.vocab_size,
            **self.config["classifier"]["params"])

    def get_optimizer_parameters(self, config):
        params = [
            {
                "params": self.word_embedding.parameters()
            },
            {
                "params": self.image_feature_embeddings_list.parameters()
            },
            {
                "params": self.classifier.parameters()
            },
            {
                "params": self.image_feature_encoders.parameters(),
                "lr": (config["optimizer_attributes"]["params"]["lr"] * 0.1),
            },
        ]
        return params

    def prepare_data(self, sample_list, batch_size):
        setattr(self, "teacher_forcing", hasattr(sample_list, "text"))
        data = {}
        if self.teacher_forcing:
            caption_lengths, sort_ind = sample_list.caption_len.sort(
                dim=0, descending=True)
            data["decode_lengths"] = (caption_lengths - 1).tolist()
            sample_list.text = sample_list.text[sort_ind]
            sample_list.answers = sample_list.answers[sort_ind]
            sample_list.image_feature_0 = sample_list.image_feature_0[sort_ind]
            data["texts"] = sample_list.text
            timesteps = max(data["decode_lengths"])
            sample_list.add_field("targets", sample_list.text[:, 1:])
        else:
            data["texts"] = sample_list.answers.new_full((batch_size, 1),
                                                         self.vocab.SOS_INDEX,
                                                         dtype=torch.long)
            timesteps = self.text_processor.max_length
            sample_list.add_field("targets", sample_list.answers[:, 0, 1:])
        return data, sample_list, timesteps

    def init_hidden_state(self, features):
        h = features.new_zeros(
            (features.size(0),
             self.config["classifier"]["params"]["hidden_dim"]),
            dtype=torch.float,
        )
        c = features.new_zeros(
            (features.size(0),
             self.config["classifier"]["params"]["hidden_dim"]),
            dtype=torch.float,
        )
        return h, c

    def get_data_t(self, t, data, batch_size_t, prev_output):
        if self.teacher_forcing:
            # Modify batch_size for timestep t
            batch_size_t = sum([l > t for l in data["decode_lengths"]])
        elif prev_output is not None and self.config["inference"][
                "type"] == "greedy":
            # Add the t-1 output words to data["texts"] for greedy decoding
            output_softmax = torch.log_softmax(prev_output, dim=1)
            _, indices = torch.max(output_softmax, dim=1, keepdim=True)
            data["texts"] = torch.cat(
                (data["texts"], indices.view(batch_size_t, 1)), dim=1)

        # Slice data based on batch_size at timestep t
        data["texts"] = data["texts"][:batch_size_t]
        if "state" in data:
            h1 = data["state"]["td_hidden"][0][:batch_size_t]
            c1 = data["state"]["td_hidden"][1][:batch_size_t]
            h2 = data["state"]["lm_hidden"][0][:batch_size_t]
            c2 = data["state"]["lm_hidden"][1][:batch_size_t]
        else:
            h1, c1 = self.init_hidden_state(data["texts"])
            h2, c2 = self.init_hidden_state(data["texts"])
        data["state"] = {"td_hidden": (h1, c1), "lm_hidden": (h2, c2)}
        registry.register("{}_lstm_state".format(h1.device), data["state"])

        return data, batch_size_t

    def forward(self, sample_list):
        # Stores the output probabilities. Not used for beam_search inference
        scores = sample_list.answers.new_ones(
            (
                sample_list.answers.size(0),
                self.text_processor.max_length,
                self.vocab_size,
            ),
            dtype=torch.float,
        )

        # For beam search inference. Currently beam search for BUTD works only
        # with batch_size = 1 and should be used with run_type inference only.
        # TODO : Implement batch beam search
        if self.config["inference"]["type"] == "beam_search":
            beam_search = BeamSearch(
                self.vocab, self.config["inference"]["params"]["beam_length"])
            sample_list = beam_search.init_batch(sample_list)

        batch_size = sample_list.image_feature_0.size(0)
        data, sample_list, timesteps = self.prepare_data(
            sample_list, batch_size)
        output = None
        batch_size_t = batch_size
        for t in range(timesteps):
            data, batch_size_t = self.get_data_t(t, data, batch_size_t, output)
            if self.config["inference"]["type"] == "beam_search":
                pi_t = data["texts"]
            else:
                pi_t = data["texts"][:, t].unsqueeze(-1)
            embedding = self.word_embedding(pi_t)
            attention_feature, _ = self.process_feature_embedding(
                "image",
                sample_list,
                embedding[:, 0, :],
                batch_size_t=batch_size_t)
            output = self.classifier(attention_feature)

            # Compute Beam Search decoding
            if self.config["inference"]["type"] == "beam_search":
                finish, data, batch_size_t = beam_search.search(
                    t, data, output)
                if finish:
                    break
            else:
                scores[:batch_size_t, t] = output

        model_output = {"scores": scores}
        if self.config["inference"]["type"] == "beam_search":
            model_output["captions"] = beam_search.best_score()

        return model_output
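
During greedy inference, get_data_t() extends the running token sequence with the argmax of the previous timestep's scores. A minimal, self-contained sketch of that step with toy tensors; the SOS index and the shapes are assumptions.

import torch

# Toy greedy-decoding step mirroring get_data_t(); shapes and SOS index are assumed.
batch_size, vocab_size = 2, 11
texts = torch.full((batch_size, 1), 1, dtype=torch.long)   # start token (assumed index 1)
prev_output = torch.randn(batch_size, vocab_size)          # classifier scores at t-1
log_probs = torch.log_softmax(prev_output, dim=1)
_, indices = torch.max(log_probs, dim=1, keepdim=True)
texts = torch.cat((texts, indices.view(batch_size, 1)), dim=1)
print(texts.shape)  # torch.Size([2, 2]); fed back at timestep t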
Example #5
class Pythia(BaseModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self._global_config = registry.get("config")
        self._datasets = self._global_config.datasets.split(",")

    def build(self):
        self._build_word_embedding()
        self._init_text_embeddings("text")
        self._init_feature_encoders("image")
        self._init_feature_embeddings("image")
        self._init_combine_layer("image", "text")
        self._init_classifier(self._get_classifier_input_dim())
        self._init_extras()

    def _build_word_embedding(self):
        assert len(self._datasets) > 0
        text_processor = registry.get(self._datasets[0] + "_text_processor")
        vocab = text_processor.vocab
        self.word_embedding = vocab.get_embedding(torch.nn.Embedding,
                                                  embedding_dim=300)

    def _init_text_embeddings(self, attr="text"):
        if "embeddings" not in attr:
            attr += "_embeddings"

        text_embeddings = []
        text_embeddings_list_config = self.config[attr]

        embeddings_out_dim = 0

        for text_embedding in text_embeddings_list_config:
            embedding_type = text_embedding.type
            embedding_kwargs = ConfigNode(text_embedding.params)

            self._update_text_embedding_args(embedding_kwargs)

            embedding = TextEmbedding(embedding_type, **embedding_kwargs)

            text_embeddings.append(embedding)
            embeddings_out_dim += embedding.text_out_dim

        setattr(self, attr + "_out_dim", embeddings_out_dim)
        setattr(self, attr, nn.ModuleList(text_embeddings))

    def _update_text_embedding_args(self, args):
        # Add model_data_dir to kwargs
        args["model_data_dir"] = self.config["model_data_dir"]

    def _init_feature_encoders(self, attr):
        feat_encoders = []
        feat_encoders_list_config = self.config[attr + "_feature_encodings"]
        feature_dim = self.config[attr + "_feature_dim"]
        setattr(self, attr + "_feature_dim", feature_dim)

        for feat_encoder in feat_encoders_list_config:
            encoder_type = feat_encoder["type"]
            encoder_kwargs = feat_encoder["params"]
            encoder_kwargs["model_data_dir"] = self.config["model_data_dir"]

            feat_model = ImageEncoder(encoder_type, feature_dim,
                                      **encoder_kwargs)

            feat_encoders.append(feat_model)
            setattr(self, attr + "_feature_dim", feat_model.out_dim)

        setattr(self, attr + "_feature_encoders", nn.ModuleList(feat_encoders))

    def _init_feature_embeddings(self, attr):
        feature_embeddings_list = []
        num_feature_feat = len(
            getattr(self.config, "{}_feature_encodings".format(attr)))

        self.feature_embeddings_out_dim = 0

        for _ in range(num_feature_feat):
            feature_embeddings = []
            feature_attn_model_list = self.config[attr + "_feature_embeddings"]

            for feature_attn_model_params in feature_attn_model_list:
                feature_embedding = ImageEmbedding(
                    getattr(self, attr + "_feature_dim"),
                    self.text_embeddings_out_dim, **feature_attn_model_params)
                feature_embeddings.append(feature_embedding)
                self.feature_embeddings_out_dim += feature_embedding.out_dim

            feature_embeddings = nn.ModuleList(feature_embeddings)
            feature_embeddings_list.append(feature_embeddings)

        self.feature_embeddings_out_dim *= getattr(self, attr + "_feature_dim")

        setattr(self, attr + "_feature_embeddings_out_dim",
                self.feature_embeddings_out_dim)
        del self.feature_embeddings_out_dim
        setattr(
            self,
            attr + "_feature_embeddings_list",
            nn.ModuleList(feature_embeddings_list),
        )

    def _get_embeddings_attr(self, attr):
        embedding_attr1 = attr
        if hasattr(self, attr + "_embeddings_out_dim"):
            embedding_attr1 = attr + "_embeddings_out_dim"
        else:
            embedding_attr1 = attr + "_feature_embeddings_out_dim"

        return embedding_attr1

    def _init_combine_layer(self, attr1, attr2):
        config_attr = attr1 + "_" + attr2 + "_modal_combine"

        multi_modal_combine_layer = ModalCombineLayer(
            self.config[config_attr]["type"],
            getattr(self, self._get_embeddings_attr(attr1)),
            getattr(self, self._get_embeddings_attr(attr2)),
            **self.config[config_attr]["params"])

        setattr(
            self,
            attr1 + "_" + attr2 + "_multi_modal_combine_layer",
            multi_modal_combine_layer,
        )

    def _init_classifier(self, combined_embedding_dim):
        # TODO: Later support multihead
        num_choices = registry.get(self._datasets[0] + "_num_final_outputs")

        self.classifier = ClassifierLayer(
            self.config["classifier"]["type"],
            in_dim=combined_embedding_dim,
            out_dim=num_choices,
            **self.config["classifier"]["params"])

    def _init_extras(self):
        self.inter_model = None

    def get_optimizer_parameters(self, config):
        combine_layer = self.image_text_multi_modal_combine_layer
        params = [
            {
                "params": self.word_embedding.parameters()
            },
            {
                "params": self.image_feature_embeddings_list.parameters()
            },
            {
                "params": self.text_embeddings.parameters()
            },
            {
                "params": combine_layer.parameters()
            },
            {
                "params": self.classifier.parameters()
            },
            {
                "params": self.image_feature_encoders.parameters(),
                "lr": (config["optimizer_attributes"]["params"]["lr"] * 0.1),
            },
        ]

        return params

    def _get_classifier_input_dim(self):
        return self.image_text_multi_modal_combine_layer.out_dim

    def process_text_embedding(self,
                               sample_list,
                               embedding_attr="text_embeddings",
                               info=None):
        text_embeddings = []

        # Get "text" attribute in case of "text_embeddings" case
        # and "context" attribute in case of "context_embeddings"
        texts = getattr(sample_list, embedding_attr.split("_")[0])

        # Get embedding models
        text_embedding_models = getattr(self, embedding_attr)

        for text_embedding_model in text_embedding_models:
            # TODO: Move this logic inside
            if isinstance(text_embedding_model, PreExtractedEmbedding):
                embedding = text_embedding_model(sample_list.question_id)
            else:
                embedding = text_embedding_model(texts)
            text_embeddings.append(embedding)

        text_embedding_total = torch.cat(text_embeddings, dim=1)

        return text_embedding_total

    def process_feature_embedding(self,
                                  attr,
                                  sample_list,
                                  text_embedding_total,
                                  extra=[],
                                  batch_size_t=None):
        feature_embeddings = []
        feature_attentions = []
        features = []
        batch_size_t = (sample_list.get_batch_size()
                        if batch_size_t is None else batch_size_t)

        # Convert list of keys to the actual values
        extra = sample_list.get_fields(extra)

        feature_idx = 0

        # Get all of the features, which are in the form, "image_feature_0"
        # "image_feature_1" ...
        while True:
            feature = getattr(sample_list,
                              "{}_feature_{:d}".format(attr,
                                                       feature_idx), None)
            if feature is None:
                break
            feature_idx += 1
            feature = feature[:batch_size_t]
            features.append(feature)

        feature_encoders = getattr(self, attr + "_feature_encoders")
        # Each feature should have a separate image feature encoder
        assert len(features) == len(feature_encoders), (
            "Number of feature encoders, {}, is not equal "
            "to the number of features, {}.".format(
                len(feature_encoders), len(features)))

        # Now, iterate to get final attended image features
        for i, feature in enumerate(features):
            # Get info related to the current feature. info is generally
            # in key of format "image_info_0" for 0th feature
            feature_info = getattr(sample_list, "{}_info_{:d}".format(attr, i),
                                   {})
            # For Pythia, we need max_features to mask attention
            feature_dim = getattr(feature_info, "max_features", None)
            if feature_dim is not None:
                feature_dim = feature_dim[:batch_size_t]

            # Attribute in which encoders are saved, for "image" it
            # will be "image_feature_encoders", other example is
            # "context_feature_encoders"
            encoders_attr = attr + "_feature_encoders"
            feature_encoder = getattr(self, encoders_attr)[i]

            # Encode the features
            encoded_feature = feature_encoder(feature)

            # Get all of the feature embeddings
            list_attr = attr + "_feature_embeddings_list"
            feature_embedding_models = getattr(self, list_attr)[i]

            # Forward through these embeddings one by one
            for feature_embedding_model in feature_embedding_models:
                inp = (encoded_feature, text_embedding_total, feature_dim,
                       extra)

                embedding, attention = feature_embedding_model(*inp)
                feature_embeddings.append(embedding)
                feature_attentions.append(attention.squeeze(-1))

        # Concatenate all feature embeddings and return them along with the attentions
        feature_embedding_total = torch.cat(feature_embeddings, dim=1)
        return feature_embedding_total, feature_attentions

    def combine_embeddings(self, *args):
        feature_names = args[0]
        feature_embeddings = args[1]

        layer = "_".join(feature_names) + "_multi_modal_combine_layer"
        return getattr(self, layer)(*feature_embeddings)

    def calculate_logits(self, joint_embedding, **kwargs):
        return self.classifier(joint_embedding)

    def forward(self, sample_list):
        sample_list.text = self.word_embedding(sample_list.text)
        text_embedding_total = self.process_text_embedding(sample_list)

        image_embedding_total, _ = self.process_feature_embedding(
            "image", sample_list, text_embedding_total)

        if self.inter_model is not None:
            image_embedding_total = self.inter_model(image_embedding_total)

        joint_embedding = self.combine_embeddings(
            ["image", "text"], [image_embedding_total, text_embedding_total])

        model_output = {"scores": self.calculate_logits(joint_embedding)}

        return model_output
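
get_optimizer_parameters() above returns parameter groups so that the image feature encoders get a 10x smaller learning rate than everything else. The sketch below shows how such groups plug into a standard PyTorch optimizer; the modules and the base learning rate are placeholders, not the project's actual configuration.

import torch
import torch.nn as nn

# Placeholder modules standing in for the classifier and the feature encoders.
classifier = nn.Linear(512, 10)
feature_encoder = nn.Linear(2048, 512)

base_lr = 0.01  # assumed value; normally read from optimizer_attributes.params.lr
optimizer = torch.optim.Adamax(
    [
        {"params": classifier.parameters()},
        {"params": feature_encoder.parameters(), "lr": base_lr * 0.1},
    ],
    lr=base_lr,
)
print([group["lr"] for group in optimizer.param_groups])  # [0.01, 0.001]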
Example #6
class Pythia(BaseModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self._global_config = registry.get("config")
        self._datasets = self._global_config.datasets.split(",")

    def build(self):
        self._build_word_embedding()
        self._init_text_embeddings("text")
        self._init_feature_encoders("image")
        self._init_feature_embeddings("image")
        self._init_combine_layer("image", "text")
        self._init_classifier(self._get_classifier_input_dim())
        self._init_extras()

    def _build_word_embedding(self):
        assert len(self._datasets) > 0
        text_processor = registry.get(self._datasets[0] + "_text_processor")
        vocab = text_processor.vocab
        self.word_embedding = vocab.get_embedding(torch.nn.Embedding,
                                                  embedding_dim=300)

    def _init_text_embeddings(self, attr="text"):
        if "embeddings" not in attr:
            attr += "_embeddings"

        text_embeddings = []
        text_embeddings_list_config = self.config[attr]

        embeddings_out_dim = 0

        for text_embedding in text_embeddings_list_config:
            embedding_type = text_embedding.type
            embedding_kwargs = ConfigNode(text_embedding.params)

            self._update_text_embedding_args(embedding_kwargs)

            embedding = TextEmbedding(embedding_type, **embedding_kwargs)

            text_embeddings.append(embedding)
            embeddings_out_dim += embedding.text_out_dim

        setattr(self, attr + "_out_dim", embeddings_out_dim)
        setattr(self, attr, nn.ModuleList(text_embeddings))

    def _update_text_embedding_args(self, args):
        # Add model_data_dir to kwargs
        args["model_data_dir"] = self.config["model_data_dir"]

    def _init_feature_encoders(self, attr):
        feat_encoders = []
        feat_encoders_list_config = self.config[attr + "_feature_encodings"]
        feature_dim = self.config[attr + "_feature_dim"]
        setattr(self, attr + "_feature_dim", feature_dim)

        for feat_encoder in feat_encoders_list_config:
            encoder_type = feat_encoder["type"]
            encoder_kwargs = feat_encoder["params"]
            encoder_kwargs["model_data_dir"] = self.config["model_data_dir"]

            feat_model = ImageEncoder(encoder_type, feature_dim,
                                      **encoder_kwargs)

            feat_encoders.append(feat_model)
            setattr(self, attr + "_feature_dim", feat_model.out_dim)

        setattr(self, attr + "_feature_encoders", nn.ModuleList(feat_encoders))

    def _init_feature_embeddings(self, attr):
        feature_embeddings_list = []
        num_feature_feat = len(
            getattr(self.config, "{}_feature_encodings".format(attr)))

        self.feature_embeddings_out_dim = 0

        for _ in range(num_feature_feat):
            feature_embeddings = []
            feature_attn_model_list = self.config[attr + "_feature_embeddings"]

            for feature_attn_model_params in feature_attn_model_list:
                feature_embedding = ImageEmbedding(
                    getattr(self, attr + "_feature_dim"),
                    self.text_embeddings_out_dim, **feature_attn_model_params)
                feature_embeddings.append(feature_embedding)
                self.feature_embeddings_out_dim += feature_embedding.out_dim

            feature_embeddings = nn.ModuleList(feature_embeddings)
            feature_embeddings_list.append(feature_embeddings)

        self.feature_embeddings_out_dim *= getattr(self, attr + "_feature_dim")

        setattr(self, attr + "_feature_embeddings_out_dim",
                self.feature_embeddings_out_dim)
        del self.feature_embeddings_out_dim
        setattr(
            self,
            attr + "_feature_embeddings_list",
            nn.ModuleList(feature_embeddings_list),
        )

    def _get_embeddings_attr(self, attr):
        embedding_attr1 = attr
        if hasattr(self, attr + "_embeddings_out_dim"):
            embedding_attr1 = attr + "_embeddings_out_dim"
        else:
            embedding_attr1 = attr + "_feature_embeddings_out_dim"

        return embedding_attr1

    def _init_combine_layer(self, attr1, attr2):
        config_attr = attr1 + "_" + attr2 + "_modal_combine"

        multi_modal_combine_layer = ModalCombineLayer(
            self.config[config_attr]["type"],
            getattr(self, self._get_embeddings_attr(attr1)),
            getattr(self, self._get_embeddings_attr(attr2)),
            **self.config[config_attr]["params"])

        setattr(
            self,
            attr1 + "_" + attr2 + "_multi_modal_combine_layer",
            multi_modal_combine_layer,
        )

    def _init_classifier(self, combined_embedding_dim):
        # TODO: Later support multihead
        num_choices = registry.get(self._datasets[0] + "_num_final_outputs")

        self.classifier = ClassifierLayer(
            self.config["classifier"]["type"],
            in_dim=combined_embedding_dim,
            out_dim=num_choices,
            **self.config["classifier"]["params"])

    def _init_extras(self):
        self.inter_model = None

    def get_optimizer_parameters(self, config):
        combine_layer = self.image_text_multi_modal_combine_layer
        params = [
            {
                "params": self.word_embedding.parameters()
            },
            {
                "params": self.image_feature_embeddings_list.parameters()
            },
            {
                "params": self.text_embeddings.parameters()
            },
            {
                "params": combine_layer.parameters()
            },
            {
                "params": self.classifier.parameters()
            },
            {
                "params": self.image_feature_encoders.parameters(),
                "lr": (config["optimizer_attributes"]["params"]["lr"] * 0.1),
            },
        ]

        return params

    def _get_classifier_input_dim(self):
        return self.image_text_multi_modal_combine_layer.out_dim

    def process_text_embedding(self,
                               sample_list,
                               embedding_attr="text_embeddings",
                               info=None):
        text_embeddings = []
        #pdb.set_trace()

        # Get "text" attribute in case of "text_embeddings" case
        # and "context" attribute in case of "context_embeddings"
        if not info:
            texts = getattr(sample_list, embedding_attr.split("_")[0])
        elif info == "sub_question":
            texts = getattr(sample_list, embedding_attr.split("_")[0] + '_sq')
        elif info == "other_question":
            texts = getattr(sample_list, embedding_attr.split("_")[0] + '_oq')

        # Get embedding models
        text_embedding_models = getattr(self, embedding_attr)

        for text_embedding_model in text_embedding_models:
            # TODO: Move this logic inside
            if isinstance(text_embedding_model, PreExtractedEmbedding):
                embedding = text_embedding_model(sample_list.question_id)
            else:
                embedding = text_embedding_model(texts)
            text_embeddings.append(embedding)

        text_embedding_total = torch.cat(text_embeddings, dim=1)

        return text_embedding_total

    def process_feature_embedding(self,
                                  attr,
                                  sample_list,
                                  text_embedding_total,
                                  extra=[],
                                  batch_size_t=None):
        feature_embeddings = []
        feature_attentions = []
        features = []
        batch_size_t = (sample_list.get_batch_size()
                        if batch_size_t is None else batch_size_t)

        # Convert list of keys to the actual values
        extra = sample_list.get_fields(extra)

        feature_idx = 0

        # Get all of the features, which are in the form, "image_feature_0"
        # "image_feature_1" ...
        while True:
            feature = getattr(sample_list,
                              "{}_feature_{:d}".format(attr,
                                                       feature_idx), None)
            if feature is None:
                break
            feature_idx += 1
            feature = feature[:batch_size_t]
            features.append(feature)

        feature_encoders = getattr(self, attr + "_feature_encoders")
        # Each feature should have a separate image feature encoder
        assert len(features) == len(feature_encoders), (
            "Number of feature encoders, {}, is not equal "
            "to the number of features, {}.".format(
                len(feature_encoders), len(features)))

        # Now, iterate to get final attended image features
        for i, feature in enumerate(features):
            # Get info related to the current feature. info is generally
            # in key of format "image_info_0" for 0th feature
            feature_info = getattr(sample_list, "{}_info_{:d}".format(attr, i),
                                   {})
            # For Pythia, we need max_features to mask attention
            feature_dim = getattr(feature_info, "max_features", None)
            if feature_dim is not None:
                feature_dim = feature_dim[:batch_size_t]

            # Attribute in which encoders are saved, for "image" it
            # will be "image_feature_encoders", other example is
            # "context_feature_encoders"
            encoders_attr = attr + "_feature_encoders"
            feature_encoder = getattr(self, encoders_attr)[i]

            # Encode the features
            encoded_feature = feature_encoder(feature)

            # Get all of the feature embeddings
            list_attr = attr + "_feature_embeddings_list"
            feature_embedding_models = getattr(self, list_attr)[i]

            # Forward through these embeddings one by one
            for feature_embedding_model in feature_embedding_models:
                inp = (encoded_feature, text_embedding_total, feature_dim,
                       extra)

                embedding, attention = feature_embedding_model(*inp)
                feature_embeddings.append(embedding)
                feature_attentions.append(attention.squeeze(-1))

        # Concatenate all feature embeddings and return them along with the attentions
        feature_embedding_total = torch.cat(feature_embeddings, dim=1)
        return feature_embedding_total, feature_attentions

    def combine_embeddings(self, *args):
        feature_names = args[0]
        feature_embeddings = args[1]

        layer = "_".join(feature_names) + "_multi_modal_combine_layer"
        layer_model = getattr(self, layer)
        joint_embeddings = layer_model(*feature_embeddings)
        if args[2] == "main":
            self.question_embedding = layer_model.question_embedding
        elif args[2] == "sub_question":
            self.question_embedding_sq = layer_model.question_embedding
        elif args[2] == "other_question":
            self.question_embedding_oq = layer_model.question_embedding
        #pdb.set_trace()
        #self.combine_layer = self.layer
        #joint_embedding = self.combine_layer(feature_embeddings)
        #pdb.set_trace()
        return joint_embeddings
        #return getattr(self, layer)(*feature_embeddings)

    def calculate_logits(self, joint_embedding, **kwargs):
        return self.classifier(joint_embedding)

    def compute_grad_cam(self, sample_list, model_output, question=None):
        #pdb.set_trace()
        #pdb.set_trace()
        if question == "main":
            #self.importance_vectors_reas = []
            scores = model_output['scores']
            classes = sample_list['gt_answer_index']
            classes_one_hot = torch.zeros_like(scores)
            classes_one_hot[range(classes_one_hot.shape[0]), classes] = 1
            #grads = torch.autograd.grad(outputs = scores, inputs = self.joint_embedding, grad_outputs = classes_one_hot, create_graph=True)[0].to(self.device)
            grads = torch.autograd.grad(outputs=scores,
                                        inputs=self.joint_embedding,
                                        grad_outputs=classes_one_hot,
                                        create_graph=True)[0]
            importance_vectors_cam = grads * self.joint_embedding
            #self.importance_vectors_reas.append(self.question_embedding)
            #pdb.set_trace()
            self.importance_vectors_reas = importance_vectors_cam
            #self.importance_vectors_reas.append(torch.cat((importance_vectors_cam, self.question_embedding), 1))
        elif question == "sq":
            #self.importance_vectors_sq = []
            scores = model_output['scores_sq']
            classes = sample_list['gt_answer_index_sq']
            classes_one_hot = torch.zeros_like(scores)
            classes_one_hot[range(classes_one_hot.shape[0]), classes] = 1
            #grads = torch.autograd.grad(outputs = scores, inputs = self.joint_embedding_sq, grad_outputs = classes_one_hot, create_graph=True)[0].to(self.device)
            grads = torch.autograd.grad(outputs=scores,
                                        inputs=self.joint_embedding_sq,
                                        grad_outputs=classes_one_hot,
                                        create_graph=True)[0]
            importance_vectors_cam = grads * self.joint_embedding_sq
            #self.importance_vectors_sq.append(self.question_embedding_sq)
            self.importance_vectors_sq = importance_vectors_cam
            #self.importance_vectors_sq.append(torch.cat((importance_vectors_cam, self.question_embedding_sq), 1))
        elif question == "oq":
            #self.importance_vectors_oq = []
            scores = model_output['scores_oq']
            classes = sample_list['gt_answer_index_oq']
            classes_one_hot = torch.zeros_like(scores)
            classes_one_hot[range(classes_one_hot.shape[0]), classes] = 1
            #grads = torch.autograd.grad(outputs = scores, inputs = self.joint_embedding_oq, grad_outputs = classes_one_hot, create_graph=True)[0].to(self.device)
            grads = torch.autograd.grad(outputs=scores,
                                        inputs=self.joint_embedding_oq,
                                        grad_outputs=classes_one_hot,
                                        create_graph=True)[0]
            importance_vectors_cam = grads * self.joint_embedding_oq
            #self.importance_vectors_oq.append(self.question_embedding_oq)
            self.importance_vectors_oq = importance_vectors_cam
            #self.importance_vectors_oq.append(torch.cat((importance_vectors_cam, self.question_embedding_oq), 1))

    def cosine_distance(self, vec_1, vec_2):
        batched_distance_vector = []
        cos_similarity = nn.CosineSimilarity(dim=1, eps=1e-6)
        for i in range(vec_1.shape[0]):
            norm_vec_1 = vec_1[i] / torch.max(vec_1[i])
            norm_vec_2 = vec_2[i] / torch.max(vec_2[i])
            distance = 1 - cos_similarity(norm_vec_1.unsqueeze(0),
                                          norm_vec_2.unsqueeze(0))
            batched_distance_vector.append(distance)
        return torch.cat(batched_distance_vector)

    def compute_distances(self, sample_list, model_output):
        model_output['distance_reas_sub'] = self.cosine_distance(
            self.importance_vectors_reas, self.importance_vectors_sq)
        model_output['distance_reas_other'] = self.cosine_distance(
            self.importance_vectors_reas, self.importance_vectors_oq)

    def forward(self, sample_list):
        # Compute the scores for the reasoning question
        sample_list.text = self.word_embedding(sample_list.text)
        text_embedding_total = self.process_text_embedding(sample_list)

        image_embedding_total, _ = self.process_feature_embedding(
            "image", sample_list, text_embedding_total)

        if self.inter_model is not None:
            image_embedding_total = self.inter_model(image_embedding_total)

        joint_embedding = self.combine_embeddings(
            ["image", "text"], [image_embedding_total, text_embedding_total],
            "main")
        #pdb.set_trace()

        self.joint_embedding = joint_embedding

        model_output = {"scores": self.calculate_logits(joint_embedding)}

        # Compute the scores for the sub-question

        sample_list.text_sq = self.word_embedding(sample_list.text_sq)
        text_embedding_total = self.process_text_embedding(sample_list,
                                                           info="sub_question")
        image_embedding_total, _ = self.process_feature_embedding(
            "image", sample_list, text_embedding_total)
        joint_embedding_sq = self.combine_embeddings(
            ["image", "text"], [image_embedding_total, text_embedding_total],
            "sub_question")
        self.joint_embedding_sq = joint_embedding_sq
        model_output["scores_sq"] = self.calculate_logits(joint_embedding_sq)

        sample_list.text_oq = self.word_embedding(sample_list.text_oq)
        text_embedding_total = self.process_text_embedding(
            sample_list, info="other_question")
        image_embedding_total, _ = self.process_feature_embedding(
            "image", sample_list, text_embedding_total)
        joint_embedding_oq = self.combine_embeddings(
            ["image", "text"], [image_embedding_total, text_embedding_total],
            "other_question")
        self.joint_embedding_oq = joint_embedding_oq
        model_output["scores_oq"] = self.calculate_logits(joint_embedding_oq)
        self.compute_grad_cam(sample_list, model_output, question="main")
        self.compute_grad_cam(sample_list, model_output, question="sq")
        self.compute_grad_cam(sample_list, model_output, question="oq")

        self.compute_distances(sample_list, model_output)

        return model_output
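
compute_grad_cam() measures how much each dimension of the joint embedding contributed to the ground-truth class by back-propagating a one-hot vector through the scores. A self-contained toy version of that gradient trick is sketched below; the linear "classifier" and all shapes are assumptions.

import torch

# Toy Grad-CAM-style importance computation mirroring compute_grad_cam().
joint_embedding = torch.randn(4, 16, requires_grad=True)
weight = torch.randn(16, 10)
scores = joint_embedding @ weight                 # stand-in for the classifier
classes = torch.tensor([1, 3, 0, 7])              # assumed ground-truth answer indices
classes_one_hot = torch.zeros_like(scores)
classes_one_hot[torch.arange(classes_one_hot.shape[0]), classes] = 1
grads = torch.autograd.grad(outputs=scores,
                            inputs=joint_embedding,
                            grad_outputs=classes_one_hot,
                            create_graph=True)[0]
importance_vectors = grads * joint_embedding      # element-wise, as in the example
print(importance_vectors.shape)  # torch.Size([4, 16])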
Example #7
class Pythia(BaseModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self._global_config = registry.get("config")
        self._datasets = self._global_config.datasets.split(",")

    def build(self):
        self._build_word_embedding()
        self._init_text_embeddings("text")
        self._init_feature_encoders("image")
        self._init_feature_embeddings("image")
        self._init_combine_layer("image", "text")
        self._init_classifier(self._get_classifier_input_dim())
        self._init_extras()

    def _build_word_embedding(self):
        assert len(self._datasets) > 0
        text_processor = registry.get(self._datasets[0] + "_text_processor")
        vocab = text_processor.vocab
        self.word_embedding = vocab.get_embedding(torch.nn.Embedding,
                                                  embedding_dim=300)

    def _init_text_embeddings(self, attr="text"):
        if "embeddings" not in attr:
            attr += "_embeddings"

        text_embeddings = []
        text_embeddings_list_config = self.config[attr]

        embeddings_out_dim = 0

        for text_embedding in text_embeddings_list_config:
            embedding_type = text_embedding.type
            embedding_kwargs = ConfigNode(text_embedding.params)

            self._update_text_embedding_args(embedding_kwargs)

            embedding = TextEmbedding(embedding_type, **embedding_kwargs)

            text_embeddings.append(embedding)
            embeddings_out_dim += embedding.text_out_dim

        setattr(self, attr + "_out_dim", embeddings_out_dim)
        setattr(self, attr, nn.ModuleList(text_embeddings))

    def _update_text_embedding_args(self, args):
        # Add model_data_dir to kwargs
        args["model_data_dir"] = self.config["model_data_dir"]

    def _init_feature_encoders(self, attr):
        feat_encoders = []
        feat_encoders_list_config = self.config[attr + "_feature_encodings"]
        feature_dim = self.config[attr + "_feature_dim"]
        setattr(self, attr + "_feature_dim", feature_dim)

        for feat_encoder in feat_encoders_list_config:
            encoder_type = feat_encoder["type"]
            encoder_kwargs = feat_encoder["params"]
            encoder_kwargs["model_data_dir"] = self.config["model_data_dir"]

            feat_model = ImageEncoder(encoder_type, feature_dim,
                                      **encoder_kwargs)

            feat_encoders.append(feat_model)
            setattr(self, attr + "_feature_dim", feat_model.out_dim)

        setattr(self, attr + "_feature_encoders", nn.ModuleList(feat_encoders))

    def _init_feature_embeddings(self, attr):
        feature_embeddings_list = []
        num_feature_feat = len(
            getattr(self.config, "{}_feature_encodings".format(attr)))

        self.feature_embeddings_out_dim = 0

        for _ in range(num_feature_feat):
            feature_embeddings = []
            feature_attn_model_list = self.config[attr + "_feature_embeddings"]

            for feature_attn_model_params in feature_attn_model_list:
                feature_embedding = ImageEmbedding(
                    getattr(self, attr + "_feature_dim"),
                    self.text_embeddings_out_dim, **feature_attn_model_params)
                feature_embeddings.append(feature_embedding)
                self.feature_embeddings_out_dim += feature_embedding.out_dim

            feature_embeddings = nn.ModuleList(feature_embeddings)
            feature_embeddings_list.append(feature_embeddings)

        self.feature_embeddings_out_dim *= getattr(self, attr + "_feature_dim")

        setattr(self, attr + "_feature_embeddings_out_dim",
                self.feature_embeddings_out_dim)
        del self.feature_embeddings_out_dim
        setattr(
            self,
            attr + "_feature_embeddings_list",
            nn.ModuleList(feature_embeddings_list),
        )

    def _get_embeddings_attr(self, attr):
        embedding_attr1 = attr
        if hasattr(self, attr + "_embeddings_out_dim"):
            embedding_attr1 = attr + "_embeddings_out_dim"
        else:
            embedding_attr1 = attr + "_feature_embeddings_out_dim"

        return embedding_attr1

    def _init_combine_layer(self, attr1, attr2):
        config_attr = attr1 + "_" + attr2 + "_modal_combine"

        multi_modal_combine_layer = ModalCombineLayer(
            self.config[config_attr]["type"],
            getattr(self, self._get_embeddings_attr(attr1)),
            getattr(self, self._get_embeddings_attr(attr2)),
            **self.config[config_attr]["params"])

        setattr(
            self,
            attr1 + "_" + attr2 + "_multi_modal_combine_layer",
            multi_modal_combine_layer,
        )

    def _init_classifier(self, combined_embedding_dim):
        # TODO: Later support multihead
        num_choices = registry.get(self._datasets[0] + "_num_final_outputs")

        self.classifier = ClassifierLayer(
            self.config["classifier"]["type"],
            in_dim=combined_embedding_dim,
            out_dim=num_choices,
            **self.config["classifier"]["params"])

    def _init_extras(self):
        self.inter_model = None

    def get_optimizer_parameters(self, config):
        combine_layer = self.image_text_multi_modal_combine_layer
        params = [
            {
                "params": self.word_embedding.parameters()
            },
            {
                "params": self.image_feature_embeddings_list.parameters()
            },
            {
                "params": self.text_embeddings.parameters()
            },
            {
                "params": combine_layer.parameters()
            },
            {
                "params": self.classifier.parameters()
            },
            {
                "params": self.image_feature_encoders.parameters(),
                "lr": (config["optimizer_attributes"]["params"]["lr"] * 0.1),
            },
        ]

        return params

    def _get_classifier_input_dim(self):
        return self.image_text_multi_modal_combine_layer.out_dim

    def process_text_embedding(self,
                               sample_list,
                               embedding_attr="text_embeddings",
                               info=None):
        text_embeddings = []

        # Get "text" attribute in case of "text_embeddings" case
        # and "context" attribute in case of "context_embeddings"
        texts = getattr(sample_list, embedding_attr.split("_")[0])

        # Get embedding models
        text_embedding_models = getattr(self, embedding_attr)

        for text_embedding_model in text_embedding_models:
            # TODO: Move this logic inside
            if isinstance(text_embedding_model, PreExtractedEmbedding):
                embedding = text_embedding_model(sample_list.question_id)
            else:
                embedding = text_embedding_model(texts)
            text_embeddings.append(embedding)

        # # visualize decomposed question attention
        # image_id = getattr(sample_list, "image_id")
        # question_id = getattr(sample_list, "question_id").cpu()
        # question_id = question_id.numpy()
        # batch_size_t, _, _ = text_embeddings[0][7].shape
        # for cnt in range(0, batch_size_t):
        #     # image_path_org = './save/temp_check/'+question_id[cnt]+'image_id.pdh'
        #     # torch.save(image_id[cnt], image_path_org)
        #     attn_path_org = './save/temp_check/'+str(question_id[cnt])+'_a_o.pdh'
        #     torch.save(text_embeddings[0][7][cnt], attn_path_org)
        #     attn_path_org = './save/temp_check/'+str(question_id[cnt])+'_a_oo.pdh'
        #     torch.save(text_embeddings[0][8][cnt], attn_path_org)
        #     attn_path_org = './save/temp_check/'+str(question_id[cnt])+'_a_ot.pdh'
        #     torch.save(text_embeddings[0][9][cnt], attn_path_org)
        #     attn_path_org = './save/temp_check/'+str(question_id[cnt])+'_a_t.pdh'
        #     torch.save(text_embeddings[0][10][cnt], attn_path_org)
        #     attn_path_org = './save/temp_check/'+str(question_id[cnt])+'_a_tt.pdh'
        #     torch.save(text_embeddings[0][11][cnt], attn_path_org)
        #     attn_path_org = './save/temp_check/'+str(question_id[cnt])+'_a_to.pdh'
        #     torch.save(text_embeddings[0][12][cnt], attn_path_org)
        return (text_embeddings[0][0], text_embeddings[0][1],
                text_embeddings[0][2], text_embeddings[0][3],
                text_embeddings[0][4], text_embeddings[0][5],
                text_embeddings[0][6])

    def process_feature_embedding(self,
                                  attr,
                                  sample_list,
                                  s_central,
                                  s_homo=None,
                                  s_hetero=None,
                                  pre_ques_embed=None,
                                  obj_feats=None,
                                  ocr_feats=None):
        """
        parameters:

        input: 
        attr: "image" or "context"
        sample_list: just sample_list
        s_central: question features for guiding purpose, torch.Size([128, 2048])
                   s_o/s_t
        s_homo: s_oo/s_tt
        s_hetero: s_ot/s_to

        output:
        """
        # add obj bbox feats and image size
        batch, bbox_num, obj_feat_dim = obj_feats.shape
        _, _, ocr_feat_dim = ocr_feats.shape
        knn_k = 5
        loc_dim = 5
        # expand obj_feats
        temp_expand_obj_feat = obj_feats[0][0]
        temp_expand_obj_feat = temp_expand_obj_feat.expand(
            batch, 1, obj_feat_dim) * 0
        temp_expand_obj_feat = torch.cat((obj_feats, temp_expand_obj_feat), 1)

        # expand ocr_feats
        temp_expand_ocr_feat = ocr_feats[0][0]
        temp_expand_ocr_feat = temp_expand_ocr_feat.expand(
            batch, 1, ocr_feat_dim) * 0
        temp_expand_ocr_feat = torch.cat((ocr_feats, temp_expand_ocr_feat), 1)

        if attr == 'image':
            batch_size_t = (sample_list.get_batch_size())
            # Get "image_feature_0"
            feature = getattr(sample_list, "{}_feature_{:d}".format(attr, 0),
                              None)
            feature = feature[:batch_size_t]
            # Get info related to the current feature. info is generally
            # in key of format "image_info_0" for 0th feature
            feature_info = getattr(sample_list, "{}_info_{:d}".format(attr, 0),
                                   {})
            # For Pythia, we need max_features to mask attention
            feature_dim = getattr(feature_info, "max_features", None)
            if feature_dim is not None:
                feature_dim = feature_dim[:batch_size_t]
            # Get feature embedding
            feature_embedding_model = getattr(self,
                                              attr + "_feature_embedding")
            encoded_feature = obj_feats
            batch, bbox_num, obj_feat_dim = encoded_feature.shape

            # obj_obj_edge_feature = None
            # oo edge generation
            obj_obj_edge_feature = torch.zeros(
                (batch, bbox_num, knn_k, obj_feat_dim + loc_dim)).float()
            obj_obj_edge_feature = obj_obj_edge_feature.cuda()
            oo_edge = getattr(getattr(sample_list, "ocr_bbox"), "edge_oo")
            oo_edgefeats = getattr(getattr(sample_list, "ocr_bbox"),
                                   "edge_oofeats")
            for i in range(batch):
                obj_obj_edge_feature[i] = torch.cat(
                    (oo_edgefeats[i], temp_expand_obj_feat[i][oo_edge[i]]), 2)

            # obj_text_edge_feature = None
            # ot edge generation
            obj_text_edge_feature = torch.zeros(
                (batch, bbox_num, knn_k, ocr_feat_dim + loc_dim)).float()
            obj_text_edge_feature = obj_text_edge_feature.cuda()
            ot_edge = getattr(getattr(sample_list, "ocr_bbox"), "edge_ot")
            ot_edgefeats = getattr(getattr(sample_list, "ocr_bbox"),
                                   "edge_otfeats")
            for i in range(batch):
                obj_text_edge_feature[i] = torch.cat(
                    (ot_edgefeats[i], temp_expand_ocr_feat[i][ot_edge[i]]), 2)

            oo_edge_feature = obj_obj_edge_feature
            ot_edge_feature = obj_text_edge_feature

            s_o, s_oo, s_ot = s_central, s_homo, s_hetero
            # for ablation study purpose,
            # o feature + oo relation + ot relation
            if (s_oo is not None) and (oo_edge_feature is not None) and (
                    s_ot is not None) and (ot_edge_feature
                                           is not None) and (pre_ques_embed
                                                             is not None):
                inp = (attr, encoded_feature, s_o, feature_dim, s_oo,
                       oo_edge_feature, s_ot, ot_edge_feature, pre_ques_embed)
            # o feature + oo relation
            elif (s_oo is not None) and (oo_edge_feature
                                         is not None) and (pre_ques_embed
                                                           is not None):
                inp = (attr, encoded_feature, s_o, feature_dim, s_oo,
                       oo_edge_feature, pre_ques_embed)
            # o feature + ot relation
            elif (s_ot is not None) and (ot_edge_feature
                                         is not None) and (pre_ques_embed
                                                           is not None):
                inp = (attr, encoded_feature, s_o, feature_dim, s_ot,
                       ot_edge_feature, pre_ques_embed)
            # o feature only
            else:
                inp = (attr, encoded_feature, s_o, feature_dim)

            g_o = feature_embedding_model(*inp)
            return g_o

        elif attr == 'context':
            batch_size_t = (sample_list.get_batch_size())
            # Get "context_feature_0"
            feature = getattr(sample_list, "{}_feature_{:d}".format(attr, 0),
                              None)
            feature = feature[:batch_size_t]
            # Get info related to the current feature. info is generally
            # in key of format "image_info_0" for 0th feature
            feature_info = getattr(sample_list, "{}_info_{:d}".format(attr, 0),
                                   {})
            # For Pythia, we need max_features to mask attention
            feature_dim = getattr(feature_info, "max_features", None)
            if feature_dim is not None:
                feature_dim = feature_dim[:batch_size_t]
            # Get feature embedding
            feature_embedding_model = getattr(self,
                                              "context_feature_embedding")
            encoded_feature = ocr_feats
            batch, bbox_num, _ = encoded_feature.shape

            # text_text_edge_feature = None
            # tt edge generation
            text_text_edge_feature = torch.zeros(
                (batch, bbox_num, knn_k, ocr_feat_dim + loc_dim)).float()
            text_text_edge_feature = text_text_edge_feature.cuda()
            tt_edge = getattr(getattr(sample_list, "ocr_bbox"), "edge_tt")
            tt_edgefeats = getattr(getattr(sample_list, "ocr_bbox"),
                                   "edge_ttfeats")
            for i in range(batch):
                text_text_edge_feature[i] = torch.cat(
                    (tt_edgefeats[i], temp_expand_ocr_feat[i][tt_edge[i]]), 2)

            # text_obj_edge_feature = None
            # to edge generation
            text_obj_edge_feature = torch.zeros(
                (batch, bbox_num, knn_k, obj_feat_dim + loc_dim)).float()
            text_obj_edge_feature = text_obj_edge_feature.cuda()
            to_edge = getattr(getattr(sample_list, "ocr_bbox"), "edge_to")
            to_edgefeats = getattr(getattr(sample_list, "ocr_bbox"),
                                   "edge_tofeats")
            for i in range(batch):
                text_obj_edge_feature[i] = torch.cat(
                    (to_edgefeats[i], temp_expand_obj_feat[i][to_edge[i]]), 2)

            tt_edge_feature = text_text_edge_feature
            to_edge_feature = text_obj_edge_feature

            s_t, s_tt, s_to = s_central, s_homo, s_hetero
            # for ablation study purpose
            # t feature + tt relation + to relation
            if (s_tt is not None) and (tt_edge_feature is not None) and (
                    s_to is not None) and (to_edge_feature
                                           is not None) and (pre_ques_embed
                                                             is not None):
                inp = (attr, encoded_feature, s_t, feature_dim, s_tt,
                       tt_edge_feature, s_to, to_edge_feature, pre_ques_embed)
            # t feature + tt relation
            elif (s_tt is not None) and (tt_edge_feature
                                         is not None) and (pre_ques_embed
                                                           is not None):
                inp = (attr, encoded_feature, s_t, feature_dim, s_tt,
                       tt_edge_feature, pre_ques_embed)
            # t feature + to relation
            elif (s_to is not None) and (to_edge_feature
                                         is not None) and (pre_ques_embed
                                                           is not None):
                inp = (attr, encoded_feature, s_t, feature_dim, s_to,
                       to_edge_feature, pre_ques_embed)
            # t feature only
            else:
                inp = (attr, encoded_feature, s_t, feature_dim)

            g_t, updated_ocr = feature_embedding_model(*inp)
            return g_t, updated_ocr

    def combine_embeddings(self, *args):
        feature_names = args[0]
        feature_embeddings = args[1]

        layer = "_".join(feature_names) + "_multi_modal_combine_layer"
        return getattr(self, layer)(*feature_embeddings)

    def calculate_logits(self, joint_embedding, **kwargs):
        return self.classifier(joint_embedding)

    def forward(self, sample_list):
        sample_list.text = self.word_embedding(sample_list.text)
        text_embedding_total = self.process_text_embedding(sample_list)

        image_embedding_total, _ = self.process_feature_embedding(
            "image", sample_list, text_embedding_total)

        if self.inter_model is not None:
            image_embedding_total = self.inter_model(image_embedding_total)

        joint_embedding = self.combine_embeddings(
            ["image", "text"], [image_embedding_total, text_embedding_total])

        model_output = {"scores": self.calculate_logits(joint_embedding)}

        return model_output
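
The edge-feature construction in process_feature_embedding() gathers, for each box, the features of its k nearest neighbours and concatenates them with precomputed geometric edge features; a zero dummy row is appended so padded neighbour indices stay valid. A toy sketch of that gather-and-concat with assumed shapes:

import torch

# Toy shapes; the real values (knn_k=5, loc_dim=5, feature dims) come from the model.
num_boxes, knn_k, feat_dim, loc_dim = 6, 3, 8, 5
node_feats = torch.randn(num_boxes, feat_dim)
# Append a zero row so a padded neighbour index (== num_boxes) selects a dummy feature.
node_feats = torch.cat((node_feats, torch.zeros(1, feat_dim)), dim=0)
edge_idx = torch.randint(0, num_boxes + 1, (num_boxes, knn_k))   # kNN indices per box
edge_geom = torch.randn(num_boxes, knn_k, loc_dim)               # geometric edge features
edge_feats = torch.cat((edge_geom, node_feats[edge_idx]), dim=2)
print(edge_feats.shape)  # torch.Size([6, 3, 13]) == (num_boxes, knn_k, loc_dim + feat_dim)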
Example #8
File: sma.py  Project: SelinaFelton/SMA
 def build(self):
     self.mmt_config = BertConfig(**self.config.mmt)
     self.mmt = MMT(self.mmt_config)
     self.so_to_mmt_in = nn.Linear(3 * 1536, self.mmt_config.hidden_size)
     self.st_to_mmt_in = nn.Linear(3 * 1536, self.mmt_config.hidden_size)
     self.so_layer_norm = BertLayerNorm(self.mmt_config.hidden_size)
     self.st_layer_norm = BertLayerNorm(self.mmt_config.hidden_size)
     self.so_drop = nn.Dropout(0.1)
     self.st_drop = nn.Dropout(0.1)
     self.linear_go_to_mmt_in = nn.Linear(2048, self.mmt_config.hidden_size)
     self.linear_gt_to_mmt_in = nn.Linear(300, self.mmt_config.hidden_size)
     self.go_layer_norm = BertLayerNorm(self.mmt_config.hidden_size)
     self.gt_layer_norm = BertLayerNorm(self.mmt_config.hidden_size)
     self.go_drop = nn.Dropout(0.1)
     self.gt_drop = nn.Dropout(0.1)
     self.linear_updated_ocr_to_mmt_in = nn.Linear(
         300, self.mmt_config.hidden_size)
     self.updated_ocr_layer_norm = BertLayerNorm(
         self.mmt_config.hidden_size)
     self.updated_ocr_drop = nn.Dropout(self.config.ocr.dropout_prob)
     self.linear_joint = nn.Linear(1536, 768)
     self.answer_processor = registry.get(self._datasets[0] +
                                          "_answer_processor")
     self.ocr_ptr_net = OcrPtrNet(**self.config.classifier.ocr_ptr_net)
     # modules requiring custom learning rates (usually for finetuning)
     self.finetune_modules = []
     self._build_txt_encoding()
     self._build_obj_encoding()
     self._build_ocr_encoding()
     self._init_text_embeddings("text")
     # init feature embedding for "image"
     setattr(self, "image_feature_dim", self.config["image_feature_dim"])
     self.feature_embeddings_out_dim = 0
     feature_attn_model_params = self.config["image_feature_embeddings"][0]
     feature_embedding = ImageEmbedding(getattr(self, "image_feature_dim"),
                                        self.text_embeddings_out_dim,
                                        **feature_attn_model_params)
     self.feature_embeddings_out_dim += feature_embedding.out_dim
     self.feature_embeddings_out_dim *= getattr(self, "image_feature_dim")
     setattr(self, "image_feature_embeddings_out_dim",
             self.feature_embeddings_out_dim)
     del self.feature_embeddings_out_dim
     setattr(self, "image_feature_embedding", feature_embedding)
     # init feature embedding for "context"
     setattr(self, "context_feature_dim",
             self.config["context_feature_dim"])
     self.feature_embeddings_out_dim = 0
     feature_attn_model_params = self.config["context_feature_embeddings"][
         0]
     feature_embedding = ImageEmbedding(
         getattr(self, "context_feature_dim"), self.text_embeddings_out_dim,
         **feature_attn_model_params)
     self.feature_embeddings_out_dim += feature_embedding.out_dim
     self.feature_embeddings_out_dim *= getattr(self, "context_feature_dim")
     setattr(self, "context_feature_embeddings_out_dim",
             self.feature_embeddings_out_dim)
     del self.feature_embeddings_out_dim
     setattr(self, "context_feature_embedding", feature_embedding)
     self._init_combine_layer("image", "text")
     num_choices = registry.get(self._datasets[0] + "_num_final_outputs")
     self.classifier = ClassifierLayer(
         self.config["classifier"]["type"],
         in_dim=768,
         out_dim=num_choices - 50,
         **self.config["classifier"]["params"])