Example No. 1
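This listing omits the module's import header; a minimal sketch of what Example No. 1 relies on follows (module paths are an assumption and differ between the original Pythia repository and current MMF releases):

# Assumed imports for the Pythia example below (paths may vary by MMF version)
import copy

import torch
from torch import nn

from mmf.common.registry import registry
from mmf.models.base_model import BaseModel
from mmf.modules.embeddings import (
    ImageFeatureEmbedding,
    PreExtractedEmbedding,
    TextEmbedding,
)
from mmf.modules.encoders import ImageFeatureEncoder
from mmf.modules.layers import ClassifierLayer, ModalCombineLayer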
class Pythia(BaseModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self._global_config = registry.get("config")
        self._datasets = self._global_config.datasets.split(",")

    @classmethod
    def config_path(cls):
        return "configs/models/pythia/defaults.yaml"

    @classmethod
    def format_state_key(cls, key):
        return key.replace("fa_history", "fa_context")

    def build(self):
        self._build_word_embedding()
        self._init_text_embeddings("text")
        self._init_feature_encoders("image")
        self._init_feature_embeddings("image")
        self._init_combine_layer("image", "text")
        self._init_classifier(self._get_classifier_input_dim())
        self._init_extras()

    def _build_word_embedding(self):
        assert len(self._datasets) > 0
        text_processor = registry.get(self._datasets[0] + "_text_processor")
        vocab = text_processor.vocab
        self.word_embedding = vocab.get_embedding(torch.nn.Embedding,
                                                  embedding_dim=300)

    def _init_text_embeddings(self, attr="text"):
        if "embeddings" not in attr:
            attr += "_embeddings"

        text_embeddings = []
        text_embeddings_list_config = self.config[attr]

        embeddings_out_dim = 0

        for text_embedding in text_embeddings_list_config:
            embedding_type = text_embedding.type
            embedding_kwargs = copy.deepcopy(text_embedding.params)

            self._update_text_embedding_args(embedding_kwargs)

            embedding = TextEmbedding(embedding_type, **embedding_kwargs)

            text_embeddings.append(embedding)
            embeddings_out_dim += embedding.text_out_dim

        setattr(self, attr + "_out_dim", embeddings_out_dim)
        setattr(self, attr, nn.ModuleList(text_embeddings))

    def _update_text_embedding_args(self, args):
        # Add model_data_dir to kwargs
        args.model_data_dir = self.config.model_data_dir

    def _init_feature_encoders(self, attr):
        feat_encoders = []
        feat_encoders_list_config = self.config[attr + "_feature_encodings"]
        feature_dim = self.config[attr + "_feature_dim"]
        setattr(self, attr + "_feature_dim", feature_dim)

        for feat_encoder in feat_encoders_list_config:
            encoder_type = feat_encoder.type
            encoder_kwargs = copy.deepcopy(feat_encoder.params)
            encoder_kwargs.model_data_dir = self.config.model_data_dir

            feat_model = ImageFeatureEncoder(encoder_type, feature_dim,
                                             **encoder_kwargs)

            feat_encoders.append(feat_model)
            setattr(self, attr + "_feature_dim", feat_model.out_dim)

        setattr(self, attr + "_feature_encoders", nn.ModuleList(feat_encoders))

    def _init_feature_embeddings(self, attr):
        feature_embeddings_list = []
        num_feature_feat = len(
            getattr(self.config, "{}_feature_encodings".format(attr)))

        self.feature_embeddings_out_dim = 0

        for _ in range(num_feature_feat):
            feature_embeddings = []
            feature_attn_model_list = self.config[attr + "_feature_embeddings"]

            for feature_attn_model_params in feature_attn_model_list:
                feature_embedding = ImageFeatureEmbedding(
                    getattr(self, attr + "_feature_dim"),
                    self.text_embeddings_out_dim, **feature_attn_model_params)
                feature_embeddings.append(feature_embedding)
                self.feature_embeddings_out_dim += feature_embedding.out_dim

            feature_embeddings = nn.ModuleList(feature_embeddings)
            feature_embeddings_list.append(feature_embeddings)

        self.feature_embeddings_out_dim *= getattr(self, attr + "_feature_dim")

        setattr(self, attr + "_feature_embeddings_out_dim",
                self.feature_embeddings_out_dim)
        del self.feature_embeddings_out_dim
        setattr(
            self,
            attr + "_feature_embeddings_list",
            nn.ModuleList(feature_embeddings_list),
        )

    def _get_embeddings_attr(self, attr):
        embedding_attr1 = attr
        if hasattr(self, attr + "_embeddings_out_dim"):
            embedding_attr1 = attr + "_embeddings_out_dim"
        else:
            embedding_attr1 = attr + "_feature_embeddings_out_dim"

        return embedding_attr1

    def _init_combine_layer(self, attr1, attr2):
        config_attr = attr1 + "_" + attr2 + "_modal_combine"

        multi_modal_combine_layer = ModalCombineLayer(
            self.config[config_attr].type,
            getattr(self, self._get_embeddings_attr(attr1)),
            getattr(self, self._get_embeddings_attr(attr2)),
            **self.config[config_attr].params)

        setattr(
            self,
            attr1 + "_" + attr2 + "_multi_modal_combine_layer",
            multi_modal_combine_layer,
        )

    def _init_classifier(self, combined_embedding_dim):
        # TODO: Later support multihead
        num_choices = registry.get(self._datasets[0] + "_num_final_outputs")

        self.classifier = ClassifierLayer(self.config.classifier.type,
                                          in_dim=combined_embedding_dim,
                                          out_dim=num_choices,
                                          **self.config.classifier.params)

    def _init_extras(self):
        self.inter_model = None

    def get_optimizer_parameters(self, config):
        combine_layer = self.image_text_multi_modal_combine_layer
        params = [
            {
                "params": self.word_embedding.parameters()
            },
            {
                "params": self.image_feature_embeddings_list.parameters()
            },
            {
                "params": self.text_embeddings.parameters()
            },
            {
                "params": combine_layer.parameters()
            },
            {
                "params": self.classifier.parameters()
            },
            {
                "params": self.image_feature_encoders.parameters(),
                "lr": (config.optimizer.params.lr * 0.1),
            },
        ]

        return params

    def _get_classifier_input_dim(self):
        return self.image_text_multi_modal_combine_layer.out_dim

    def process_text_embedding(self,
                               sample_list,
                               embedding_attr="text_embeddings",
                               info=None):
        text_embeddings = []

        # Get "text" attribute in case of "text_embeddings" case
        # and "context" attribute in case of "context_embeddings"
        texts = getattr(sample_list, embedding_attr.split("_")[0])

        # Get embedding models
        text_embedding_models = getattr(self, embedding_attr)

        for text_embedding_model in text_embedding_models:
            # TODO: Move this logic inside
            if isinstance(text_embedding_model, PreExtractedEmbedding):
                embedding = text_embedding_model(sample_list.question_id)
            else:
                embedding = text_embedding_model(texts)
            text_embeddings.append(embedding)

        text_embedding_total = torch.cat(text_embeddings, dim=1)

        return text_embedding_total

    def process_feature_embedding(self,
                                  attr,
                                  sample_list,
                                  text_embedding_total,
                                  extra=None,
                                  batch_size_t=None):
        if extra is None:
            extra = []
        feature_embeddings = []
        feature_attentions = []
        features = []
        batch_size_t = (sample_list.get_batch_size()
                        if batch_size_t is None else batch_size_t)

        # Convert list of keys to the actual values
        extra = sample_list.get_fields(extra)

        feature_idx = 0

        # Get all of the features, which are in the form, "image_feature_0"
        # "image_feature_1" ...
        while True:
            feature = getattr(sample_list,
                              "{}_feature_{:d}".format(attr,
                                                       feature_idx), None)
            if feature is None:
                break
            feature_idx += 1
            feature = feature[:batch_size_t]
            features.append(feature)

        feature_encoders = getattr(self, attr + "_feature_encoders")
        # Each feature should have its own image feature encoder
        assert len(features) == len(feature_encoders), (
            "Number of feature encoders, {}, is not equal "
            "to number of features, {}.".format(len(feature_encoders),
                                                len(features)))

        # Now, iterate to get final attended image features
        for i, feature in enumerate(features):
            # Get info related to the current feature. Info is generally
            # stored under a key of the form "image_info_0" for the 0th feature
            feature_info = getattr(sample_list, "{}_info_{:d}".format(attr, i),
                                   {})
            # For Pythia, we need max_features to mask attention
            feature_dim = getattr(feature_info, "max_features", None)
            if feature_dim is not None:
                feature_dim = feature_dim[:batch_size_t]

            # Attribute in which encoders are saved, for "image" it
            # will be "image_feature_encoders", other example is
            # "context_feature_encoders"
            encoders_attr = attr + "_feature_encoders"
            feature_encoder = getattr(self, encoders_attr)[i]

            # Encode the features
            encoded_feature = feature_encoder(feature)

            # Get all of the feature embeddings
            list_attr = attr + "_feature_embeddings_list"
            feature_embedding_models = getattr(self, list_attr)[i]

            # Forward through these embeddings one by one
            for feature_embedding_model in feature_embedding_models:
                inp = (encoded_feature, text_embedding_total, feature_dim,
                       extra)

                embedding, attention = feature_embedding_model(*inp)
                feature_embeddings.append(embedding)
                feature_attentions.append(attention.squeeze(-1))

        # Concatenate all features embeddings and return along with attention
        feature_embedding_total = torch.cat(feature_embeddings, dim=1)
        return feature_embedding_total, feature_attentions

    def combine_embeddings(self, *args):
        feature_names = args[0]
        feature_embeddings = args[1]

        layer = "_".join(feature_names) + "_multi_modal_combine_layer"
        return getattr(self, layer)(*feature_embeddings)

    def calculate_logits(self, joint_embedding, **kwargs):
        return self.classifier(joint_embedding)

    def forward(self, sample_list):
        sample_list.text = self.word_embedding(sample_list.text)
        text_embedding_total = self.process_text_embedding(sample_list)

        image_embedding_total, _ = self.process_feature_embedding(
            "image", sample_list, text_embedding_total)

        if self.inter_model is not None:
            image_embedding_total = self.inter_model(image_embedding_total)

        joint_embedding = self.combine_embeddings(
            ["image", "text"], [image_embedding_total, text_embedding_total])

        model_output = {"scores": self.calculate_logits(joint_embedding)}

        return model_output
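For orientation, here is a rough smoke-test of the forward() contract above. Sample and SampleList are MMF's batching helpers; the vocabulary size, sequence length, and number of regions are illustrative assumptions, and the model is assumed to be already built:

# Hypothetical usage sketch, not part of the original example
import torch
from mmf.common.sample import Sample, SampleList

sample = Sample()
sample.text = torch.randint(0, 1000, (20,))       # padded token ids
sample.image_feature_0 = torch.rand(100, 2048)    # pre-extracted region features
sample_list = SampleList([sample])                # adds the batch dimension

# model = Pythia(config); model.build()           # assuming a valid config
# output = model(sample_list)
# output["scores"].shape -> [1, num_final_outputs]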
Example No. 2
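As above, the imports are not part of the listing; a hedged sketch of what this BUTD example needs (paths assumed, Pythia being the class from Example No. 1):

# Assumed imports for the BUTD example below (paths may vary by MMF version)
import torch

from mmf.common.registry import registry
from mmf.models.pythia import Pythia
from mmf.modules.layers import ClassifierLayer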
class BUTD(Pythia):
    def __init__(self, config):
        super().__init__(config)

    @classmethod
    def config_path(cls):
        return "configs/models/butd/defaults.yaml"

    def build(self):
        self._build_word_embedding()
        self._init_feature_encoders("image")
        self._init_feature_embeddings("image")
        self._init_classifier()
        self._init_extras()

    def _build_word_embedding(self):
        self.text_processor = registry.get(self._datasets[0] + "_text_processor")
        self.vocab = self.text_processor.vocab
        self.vocab_size = self.vocab.get_size()
        self.word_embedding = self.vocab.get_embedding(
            torch.nn.Embedding, embedding_dim=self.config.embedding_dim
        )
        self.text_embeddings_out_dim = self.config.embedding_dim

    def _init_classifier(self):
        self.classifier = ClassifierLayer(
            self.config.classifier.type,
            in_dim=self.config.classifier.params.feature_dim,
            out_dim=self.vocab_size,
            **self.config.classifier.params,
        )

    def get_optimizer_parameters(self, config):
        params = [
            {"params": self.word_embedding.parameters()},
            {"params": self.image_feature_embeddings_list.parameters()},
            {"params": self.classifier.parameters()},
            {
                "params": self.image_feature_encoders.parameters(),
                "lr": (config.optimizer.params.lr * 0.1),
            },
        ]
        return params

    def prepare_data(self, sample_list, batch_size):
        # turn off teacher forcing during beam search
        # (otherwise one cannot run beam search on val set)
        self.teacher_forcing = self.config.inference.type != "beam_search" and hasattr(
            sample_list, "text"
        )
        data = {}
        if self.teacher_forcing:
            caption_lengths, sort_ind = sample_list.caption_len.sort(
                dim=0, descending=True
            )
            data["decode_lengths"] = (caption_lengths - 1).tolist()
            sample_list.text = sample_list.text[sort_ind]
            sample_list.answers = sample_list.answers[sort_ind]
            sample_list.image_feature_0 = sample_list.image_feature_0[sort_ind]
            data["texts"] = sample_list.text
            timesteps = max(data["decode_lengths"])
            sample_list.add_field("targets", sample_list.text[:, 1:])
        else:
            data["texts"] = sample_list.answers.new_full(
                (batch_size, 1), self.vocab.SOS_INDEX, dtype=torch.long
            )
            timesteps = self.text_processor.max_length
            sample_list.add_field("targets", sample_list.answers[:, 0, 1:])
        return data, sample_list, timesteps

    def init_hidden_state(self, features):
        h = features.new_zeros(
            (features.size(0), self.config.classifier.params.hidden_dim),
            dtype=torch.float,
        )
        c = features.new_zeros(
            (features.size(0), self.config.classifier.params.hidden_dim),
            dtype=torch.float,
        )
        return h, c

    def get_data_t(self, t, data, batch_size_t, prev_output):
        if self.teacher_forcing:
            # Modify batch_size for timestep t
            batch_size_t = sum([l > t for l in data["decode_lengths"]])
        elif prev_output is not None and self.config.inference.type == "greedy":
            # Add the output words from timestep t-1 to data["texts"] for greedy decoding
            output_softmax = torch.log_softmax(prev_output, dim=1)
            _, indices = torch.max(output_softmax, dim=1, keepdim=True)
            data["texts"] = torch.cat(
                (data["texts"], indices.view(batch_size_t, 1)), dim=1
            )

        # Slice data based on batch_size at timestep t
        data["texts"] = data["texts"][:batch_size_t]
        if "state" in data:
            h1 = data["state"]["td_hidden"][0][:batch_size_t]
            c1 = data["state"]["td_hidden"][1][:batch_size_t]
            h2 = data["state"]["lm_hidden"][0][:batch_size_t]
            c2 = data["state"]["lm_hidden"][1][:batch_size_t]
        else:
            h1, c1 = self.init_hidden_state(data["texts"])
            h2, c2 = self.init_hidden_state(data["texts"])
        data["state"] = {"td_hidden": (h1, c1), "lm_hidden": (h2, c2)}
        registry.register(f"{h1.device}_lstm_state", data["state"])

        return data, batch_size_t

    def forward(self, sample_list):
        # Stores the output probabilities.
        scores = sample_list.answers.new_ones(
            (
                sample_list.answers.size(0),
                self.text_processor.max_length,
                self.vocab_size,
            ),
            dtype=torch.float,
        )

        if self.config["inference"]["type"] in ["beam_search", "nucleus_sampling"]:
            decoder = registry.get_decoder_class(self.config["inference"]["type"])(
                self.vocab, self.config
            )
            sample_list = decoder.init_batch(sample_list)

        batch_size = sample_list.image_feature_0.size(0)
        data, sample_list, timesteps = self.prepare_data(sample_list, batch_size)
        output = None
        batch_size_t = batch_size
        for t in range(timesteps):
            data, batch_size_t = self.get_data_t(t, data, batch_size_t, output)
            if self.config.inference.type in ["beam_search", "nucleus_sampling"]:
                pi_t = data["texts"]
            else:
                pi_t = data["texts"][:, t].unsqueeze(-1)
            embedding = self.word_embedding(pi_t)
            attention_feature, _ = self.process_feature_embedding(
                "image", sample_list, embedding[:, 0, :], batch_size_t=batch_size_t
            )
            output = self.classifier(attention_feature)
            # Compute decoding
            if self.config.inference.type in ["beam_search", "nucleus_sampling"]:
                finish, data, batch_size_t = decoder.decode(t, data, output)
                if finish:
                    break
            else:
                scores[:batch_size_t, t] = output

        model_output = {"scores": scores}
        if self.config.inference.type in ["beam_search", "nucleus_sampling"]:
            model_output["captions"] = decoder.get_result()

        return model_output
Example No. 3
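Import sketch for this example as well; TwoBranchEmbedding, BranchCombineLayer, and filter_grads are placed where MMF usually keeps them, but treat the paths as assumptions:

# Assumed imports for the MoVieMcan example below (paths may vary by MMF version)
import copy
from typing import Any, Dict, List, Optional, Tuple

import torch
from omegaconf import DictConfig

from mmf.common.registry import registry
from mmf.models.base_model import BaseModel
from mmf.modules.embeddings import (
    PreExtractedEmbedding,
    TextEmbedding,
    TwoBranchEmbedding,
)
from mmf.modules.encoders import ImageFeatureEncoder
from mmf.modules.layers import BranchCombineLayer, ClassifierLayer
from mmf.utils.general import filter_grads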
class MoVieMcan(BaseModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self._global_config = registry.get("config")
        self._datasets = self._global_config.datasets.split(",")

    @classmethod
    def config_path(cls):
        return "configs/models/movie_mcan/defaults.yaml"

    def build(self):
        self.image_feature_dim = 2048
        self._build_word_embedding()
        self._init_text_embeddings("text")
        self._init_feature_encoders("image")
        self._init_feature_embeddings("image")
        self._init_combine_layer("image", "text")
        self._init_classifier(self._get_classifier_input_dim())
        self._init_extras()

    def _build_word_embedding(self):
        assert len(self._datasets) > 0
        text_processor = registry.get(self._datasets[0] + "_text_processor")
        vocab = text_processor.vocab
        self.word_embedding = vocab.get_embedding(torch.nn.Embedding,
                                                  embedding_dim=300)

    def _init_text_embeddings(self, attr: str = "text"):
        if "embeddings" not in attr:
            attr += "_embeddings"

        module_config = self.config[attr]
        embedding_type = module_config.type
        embedding_kwargs = copy.deepcopy(module_config.params)
        self._update_text_embedding_args(embedding_kwargs)
        embedding = TextEmbedding(embedding_type, **embedding_kwargs)
        embeddings_out_dim = embedding.text_out_dim

        setattr(self, attr + "_out_dim", embeddings_out_dim)
        setattr(self, attr, embedding)

    def _update_text_embedding_args(self, args):
        # Add model_data_dir to kwargs
        args.model_data_dir = self.config.model_data_dir

    def _init_feature_encoders(self, attr: str):
        feat_encoder = self.config[attr + "_feature_encodings"]
        feature_dim = self.config[attr + "_feature_dim"]
        setattr(self, attr + "_feature_dim", feature_dim)

        encoder_type = feat_encoder.type
        encoder_kwargs = copy.deepcopy(feat_encoder.params)
        encoder_kwargs.model_data_dir = self.config.model_data_dir
        encoder_kwargs.cond_features = self.text_embeddings_out_dim

        feat_model = ImageFeatureEncoder(encoder_type, feature_dim,
                                         **encoder_kwargs)

        setattr(self, attr + "_feature_dim", feat_model.out_dim)
        setattr(self, attr + "_feature_encoders", feat_model)

    def _init_feature_embeddings(self, attr: str):
        embedding_kwargs = self.config[attr + "_feature_embeddings"]["params"]
        setattr(self, attr + "_feature_embeddings_out_dim",
                embedding_kwargs["hidden_dim"])
        assert (getattr(self, attr + "_feature_embeddings_out_dim") ==
                self.text_embeddings_out_dim), "dim1: {}, dim2: {}".format(
                    getattr(self, attr + "_feature_embeddings_out_dim"),
                    self.text_embeddings_out_dim,
                )

        feature_embedding = TwoBranchEmbedding(
            getattr(self, attr + "_feature_dim"), **embedding_kwargs)
        setattr(
            self,
            attr + "_feature_embeddings_list",
            feature_embedding,
        )

    def _get_embeddings_attr(self, attr: str):
        embedding_attr1 = attr
        if hasattr(self, attr + "_embeddings_out_dim"):
            embedding_attr1 = attr + "_embeddings_out_dim"
        else:
            embedding_attr1 = attr + "_feature_embeddings_out_dim"

        return embedding_attr1

    def _init_combine_layer(self, attr1: str, attr2: str):
        multi_modal_combine_layer = BranchCombineLayer(
            getattr(self, self._get_embeddings_attr(attr1)),
            getattr(self, self._get_embeddings_attr(attr2)),
        )

        setattr(
            self,
            attr1 + "_" + attr2 + "_multi_modal_combine_layer",
            multi_modal_combine_layer,
        )

    def _init_classifier(self, combined_embedding_dim: int):
        # TODO: Later support multihead
        num_choices = registry.get(self._datasets[0] + "_num_final_outputs")
        params = self.config["classifier"].get("params")
        if params is None:
            params = {}

        self.classifier = ClassifierLayer(self.config.classifier.type,
                                          in_dim=combined_embedding_dim,
                                          out_dim=num_choices,
                                          **params)

    def _init_extras(self):
        self.inter_model = None

    def get_optimizer_parameters(self,
                                 config: DictConfig) -> List[Dict[str, Any]]:
        combine_layer = self.image_text_multi_modal_combine_layer
        params = [
            {
                "params": filter_grads(self.word_embedding.parameters())
            },
            {
                "params":
                filter_grads(
                    self.image_feature_embeddings_list.sga.parameters())
            },
            {
                "params":
                filter_grads(
                    self.image_feature_embeddings_list.sga_pool.parameters())
            },
            {
                "params":
                filter_grads(
                    self.image_feature_embeddings_list.cbn.parameters()),
                "lr": (config.optimizer.params.lr *
                       config.training.encoder_lr_multiply),
            },
            {
                "params": filter_grads(self.text_embeddings.parameters())
            },
            {
                "params": filter_grads(combine_layer.parameters())
            },
            {
                "params": filter_grads(self.classifier.parameters())
            },
            {
                "params":
                filter_grads(self.image_feature_encoders.parameters())
            },
        ]

        return params

    def get_mapping(self):
        mapping = [
            "word_embedding",
            "image_feature_embeddings_list_sga",
            "image_feature_embeddings_list_sga_pool",
            "image_feature_embeddings_list_cbn",
            "text_embeddings",
            "combine_layer",
            "classifier",
            "image_feature_encoders",
        ]
        return mapping

    def _get_classifier_input_dim(self):
        return self.image_text_multi_modal_combine_layer.out_dim

    def process_text_embedding(
        self,
        sample_list: Dict[str, Any],
        embedding_attr: str = "text_embeddings"
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        # Get "text" attribute in case of "text_embeddings" case
        # and "context" attribute in case of "context_embeddings"
        texts = getattr(sample_list, embedding_attr.split("_")[0])

        # Get embedding models
        text_embedding_model = getattr(self, embedding_attr)

        # TODO: Move this logic inside
        if isinstance(text_embedding_model, PreExtractedEmbedding):
            text_embedding_total = text_embedding_model(
                sample_list.question_id)
            # Pre-extracted embeddings provide no pooled vector
            text_embedding_vec = None
        else:
            text_embedding_total, text_embedding_vec = text_embedding_model(
                texts, sample_list.text_mask)

        return text_embedding_total, text_embedding_vec

    def process_feature_embedding(
        self,
        attr: str,
        sample_list: Dict[str, Any],
        text_embedding_total: torch.Tensor,
        text_embedding_vec: torch.Tensor,
        extra: Optional[list] = None,
        batch_size_t: Optional[int] = None,
    ):
        batch_size_t = (sample_list.get_batch_size()
                        if batch_size_t is None else batch_size_t)

        # Use raw grid features when available, otherwise fall back to
        # pre-extracted region features
        if hasattr(sample_list, "image"):
            feature = sample_list.image

            feature_encoder = getattr(self, attr + "_feature_encoders")
            encoded_feature = feature_encoder(feature, text_embedding_vec)
        else:
            feature = sample_list.image_feature_0

            feature_encoder = getattr(self, attr + "_feature_encoders")
            encoded_feature = feature_encoder(feature)

        feature_embedding = getattr(self, attr + "_feature_embeddings_list")
        feature_sga, feature_cbn = feature_embedding(
            encoded_feature,
            text_embedding_total,
            text_embedding_vec,
            None,
            sample_list.text_mask,
        )

        return feature_sga, feature_cbn

    def combine_embeddings(self, *args):
        feature_names = args[0]
        v1, v2, q = args[1]

        layer = "_".join(feature_names) + "_multi_modal_combine_layer"
        return getattr(self, layer)(v1, v2, q)

    def calculate_logits(self, joint_embedding: torch.Tensor, **kwargs):
        return self.classifier(joint_embedding)

    def forward(self, sample_list: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        sample_list.text_mask = sample_list.text.eq(0)
        sample_list.text = self.word_embedding(sample_list.text)
        text_embedding_total, text_embedding_vec = self.process_text_embedding(
            sample_list)

        feature_sga, feature_cbn = self.process_feature_embedding(
            "image", sample_list, text_embedding_total,
            text_embedding_vec[:, 0])

        joint_embedding = self.combine_embeddings(
            ["image", "text"],
            [feature_sga, feature_cbn, text_embedding_vec[:, 1]])

        model_output = {"scores": self.calculate_logits(joint_embedding)}

        return model_output
Example No. 4 (the Pythia class from Example No. 1, instrumented with debug printouts)
class Pythia(BaseModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self._global_config = registry.get("config")
        self._datasets = self._global_config.datasets.split(",")

    @classmethod
    def config_path(cls):
        return "configs/models/pythia/defaults.yaml"

    @classmethod
    def format_state_key(cls, key):
        return key.replace("fa_history", "fa_context")

    def build(self):
        self._build_word_embedding()
        self._init_text_embeddings("text")
        self._init_feature_encoders("image")
        self._init_feature_embeddings("image")
        self._init_combine_layer("image", "text")
        self._init_classifier(self._get_classifier_input_dim())
        self._init_extras()

    def _build_word_embedding(self):
        assert len(self._datasets) > 0
        text_processor = registry.get(self._datasets[0] + "_text_processor")
        vocab = text_processor.vocab
        self.word_embedding = vocab.get_embedding(torch.nn.Embedding,
                                                  embedding_dim=300)

    def _init_text_embeddings(self, attr="text"):
        if "embeddings" not in attr:
            attr += "_embeddings"

        text_embeddings = []
        text_embeddings_list_config = self.config[attr]

        embeddings_out_dim = 0

        for text_embedding in text_embeddings_list_config:
            embedding_type = text_embedding.type
            embedding_kwargs = copy.deepcopy(text_embedding.params)

            self._update_text_embedding_args(embedding_kwargs)

            embedding = TextEmbedding(embedding_type, **embedding_kwargs)

            text_embeddings.append(embedding)
            embeddings_out_dim += embedding.text_out_dim

        setattr(self, attr + "_out_dim", embeddings_out_dim)
        setattr(self, attr, nn.ModuleList(text_embeddings))

    def _update_text_embedding_args(self, args):
        # Add model_data_dir to kwargs
        args.model_data_dir = self.config.model_data_dir

    def _init_feature_encoders(self, attr):
        feat_encoders = []
        feat_encoders_list_config = self.config[attr + "_feature_encodings"]
        feature_dim = self.config[attr + "_feature_dim"]
        setattr(self, attr + "_feature_dim", feature_dim)

        # print("feat_encoders_list_config", feat_encoders_list_config)

        for feat_encoder in feat_encoders_list_config:
            encoder_type = feat_encoder.type
            encoder_kwargs = copy.deepcopy(feat_encoder.params)
            encoder_kwargs.model_data_dir = self.config.model_data_dir

            feat_model = ImageFeatureEncoder(encoder_type, feature_dim,
                                             **encoder_kwargs)

            feat_encoders.append(feat_model)
            setattr(self, attr + "_feature_dim", feat_model.out_dim)

        setattr(self, attr + "_feature_encoders", nn.ModuleList(feat_encoders))

    def _init_feature_embeddings(self, attr):
        feature_embeddings_list = []
        num_feature_feat = len(
            getattr(self.config, f"{attr}_feature_encodings"))

        # print("num_feature_feat", num_feature_feat)
        # print(getattr(self.config, f"{attr}_feature_encodings"))
        # [{'type': 'finetune_faster_rcnn_fpn_fc7', 'params': {'bias_file': 'models/detectron.defaults/fc7_b.pkl', 'weights_file': 'models/detectron.defaults/fc7_w.pkl', 'model_data_dir': '/media/ubuntu/MyDisk/data_mmf/vg'}}, {'type': 'default', 'params': {'model_data_dir': '/media/ubuntu/MyDisk/data_mmf/vg'}}]

        self.feature_embeddings_out_dim = 0

        for _ in range(num_feature_feat):  #2
            feature_embeddings = []
            feature_attn_model_list = self.config[attr + "_feature_embeddings"]

            # print ("feature_attn_model_list", feature_attn_model_list)
            # # [{'modal_combine': {'type': 'non_linear_element_multiply', 'params': {'dropout': 0, 'hidden_dim': 5000}}, 'normalization': 'softmax', 'transform': {'type': 'linear', 'params': {'out_dim': 1}}}]
            # print("attr_feat_dim", getattr(self, attr + "_feature_dim")) #2048
            # print("text_embeddings_out_dim", self.text_embeddings_out_dim) # 2048

            for feature_attn_model_params in feature_attn_model_list:
                feature_embedding = ImageFeatureEmbedding(
                    getattr(self, attr + "_feature_dim"),  #2048
                    self.text_embeddings_out_dim,
                    **feature_attn_model_params,
                )
                # print ("feature_embedding", feature_embedding) #a embedding model
                feature_embeddings.append(feature_embedding)
                self.feature_embeddings_out_dim += feature_embedding.out_dim

            feature_embeddings = nn.ModuleList(feature_embeddings)
            feature_embeddings_list.append(feature_embeddings)

        self.feature_embeddings_out_dim *= getattr(self, attr + "_feature_dim")

        setattr(self, attr + "_feature_embeddings_out_dim",
                self.feature_embeddings_out_dim)
        del self.feature_embeddings_out_dim
        setattr(
            self,
            attr + "_feature_embeddings_list",
            nn.ModuleList(feature_embeddings_list),
        )

    def _get_embeddings_attr(self, attr):
        embedding_attr1 = attr
        if hasattr(self, attr + "_embeddings_out_dim"):
            embedding_attr1 = attr + "_embeddings_out_dim"
        else:
            embedding_attr1 = attr + "_feature_embeddings_out_dim"

        return embedding_attr1

    def _init_combine_layer(self, attr1, attr2):
        config_attr = attr1 + "_" + attr2 + "_modal_combine"

        multi_modal_combine_layer = ModalCombineLayer(
            self.config[config_attr].type,
            getattr(self, self._get_embeddings_attr(attr1)),
            getattr(self, self._get_embeddings_attr(attr2)),
            **self.config[config_attr].params,
        )

        setattr(
            self,
            attr1 + "_" + attr2 + "_multi_modal_combine_layer",
            multi_modal_combine_layer,
        )

    def _init_classifier(self, combined_embedding_dim):
        # TODO: Later support multihead
        num_choices = registry.get(self._datasets[0] + "_num_final_outputs")
        # print (num_choices)
        self.classifier = ClassifierLayer(
            self.config.classifier.type,
            in_dim=combined_embedding_dim,
            out_dim=num_choices,
            **self.config.classifier.params,
        )

    def _init_extras(self):
        self.inter_model = None

    def get_optimizer_parameters(self, config):
        combine_layer = self.image_text_multi_modal_combine_layer
        params = [
            {
                "params": self.word_embedding.parameters()
            },
            {
                "params": self.image_feature_embeddings_list.parameters()
            },
            {
                "params": self.text_embeddings.parameters()
            },
            {
                "params": combine_layer.parameters()
            },
            {
                "params": self.classifier.parameters()
            },
            {
                "params": self.image_feature_encoders.parameters(),
                "lr": (config.optimizer.params.lr * 0.1),
            },
        ]

        return params

    def _get_classifier_input_dim(self):
        return self.image_text_multi_modal_combine_layer.out_dim

    def process_text_embedding(self,
                               sample_list,
                               embedding_attr="text_embeddings",
                               info=None):

        # print("=====text embedding=====")

        text_embeddings = []

        # Get "text" attribute in case of "text_embeddings" case
        # and "context" attribute in case of "context_embeddings"
        texts = getattr(sample_list, embedding_attr.split("_")[0])

        # print("text", texts.size())  # bs*20*300

        # Get embedding models
        text_embedding_models = getattr(self, embedding_attr)

        for text_embedding_model in text_embedding_models:
            # print("text_model", text_embedding_model)
            '''
            text_model TextEmbedding(
                (module): AttentionTextEmbedding(
                    (recurrent_unit): LSTM(300, 1024, batch_first=True)
                    (dropout): Dropout(p=0, inplace=False)
                    (conv1): Conv1d(1024, 512, kernel_size=(1,), stride=(1,))
                    (conv2): Conv1d(512, 2, kernel_size=(1,), stride=(1,))
                    (relu): ReLU()
                )
            )
            '''
            # TODO: Move this logic inside
            if isinstance(text_embedding_model, PreExtractedEmbedding):
                embedding = text_embedding_model(sample_list.question_id)
            else:
                embedding = text_embedding_model(texts)

            # print("text_embedding: ", embedding.size())
            # torch.Size([4(bs), 2048])
            text_embeddings.append(embedding)

        text_embedding_total = torch.cat(text_embeddings, dim=1)
        # print("text_embedding_tot: ", text_embedding_total.size()) # torch.Size([4(bs), 2048])

        return text_embedding_total

    def process_feature_embedding(self,
                                  attr,
                                  sample_list,
                                  text_embedding_total,
                                  extra=None,
                                  batch_size_t=None):
        if extra is None:
            extra = []
        feature_embeddings = []
        feature_attentions = []
        features = []
        batch_size_t = (sample_list.get_batch_size()
                        if batch_size_t is None else batch_size_t)

        # Convert list of keys to the actual values
        extra = sample_list.get_fields(extra)

        feature_idx = 0
        # print("=====feature encoder=====")

        # Get all of the features, which are in the form, "image_feature_0"
        # "image_feature_1" ...
        while True:
            feature = getattr(sample_list, f"{attr}_feature_{feature_idx:d}",
                              None)
            if feature is None:
                break
            feature_idx += 1
            feature = feature[:batch_size_t]
            features.append(feature)

        feature_encoders = getattr(self, attr + "_feature_encoders")
        # Each feature should have its own image feature encoder
        assert len(features) == len(feature_encoders), (
            "Number of feature encoders, {}, is not equal "
            "to number of features, {}.".format(len(feature_encoders),
                                                len(features)))

        # Now, iterate to get final attended image features
        for i, feature in enumerate(features):
            # Get info related to the current feature. Info is generally
            # stored under a key of the form "image_info_0" for the 0th feature
            feature_info = getattr(sample_list, f"{attr}_info_{i:d}", {})
            # print("feature_i: ", i, feature.size())
            # print("feature_info", feature_info)

            # For Pythia, we need max_features to mask attention
            feature_dim = getattr(feature_info, "max_features", None)
            if feature_dim is not None:
                feature_dim = feature_dim[:batch_size_t]

            # print("feat_dim", feature_dim)  # none

            # Attribute in which encoders are saved, for "image" it
            # will be "image_feature_encoders", other example is
            # "context_feature_encoders"
            encoders_attr = attr + "_feature_encoders"
            feature_encoder = getattr(self, encoders_attr)[i]
            # same modules as feature_encoders fetched above
            # print("feature_encoder", feature_encoder)
            '''
            feature_i:  0 torch.Size([64, 100, 2048])
            feature_info {}
            feature_encoder ImageFeatureEncoder(
            (module): FinetuneFasterRcnnFpnFc7(
                (lc): Linear(in_features=2048, out_features=2048, bias=True)
            )
            )
            feature_i:  1 torch.Size([64, 196, 2048])
            feature_info {}
            feature_encoder ImageFeatureEncoder(
            (module): Identity()
            )
            '''

            # Encode the features
            encoded_feature = feature_encoder(feature)

            # print("encoded_feat:", encoded_feature.size()) # torch.Size([64, 100, 2048])
            # print("=====feat_embedding", i, "===== ")

            # Get all of the feature embeddings
            list_attr = attr + "_feature_embeddings_list"
            feature_embedding_models = getattr(
                self, list_attr)[i]  # image_feature_embeddings_list

            # Forward through these embeddings one by one
            for feature_embedding_model in feature_embedding_models:
                inp = (encoded_feature, text_embedding_total, feature_dim,
                       extra)
                # torch.Size([64, 100, 2048]), [64,2048], none, samplelist()
                # print(feature_embedding_model)
                # print(encoded_feature.size())
                # print(text_embedding_total.size())

                embedding, attention = feature_embedding_model(*inp)
                feature_embeddings.append(embedding)
                feature_attentions.append(attention.squeeze(-1))

                print("feature_embeddings_&_attns")
                print(embedding.size())  # torch.Size([64, 2048])
                print(attention.size())  # torch.Size([64, 196, 1])

        # Concatenate all features embeddings and return along with attention
        feature_embedding_total = torch.cat(feature_embeddings, dim=1)

        # print("feature_embeddings_tot")
        # print(feature_embedding_total.size())

        return feature_embedding_total, feature_attentions

    def combine_embeddings(self, *args):
        feature_names = args[0]
        feature_embeddings = args[1]

        layer = "_".join(feature_names) + "_multi_modal_combine_layer"
        return getattr(self, layer)(*feature_embeddings)

    def calculate_logits(self, joint_embedding, **kwargs):
        return self.classifier(joint_embedding)

    def forward(self, sample_list):
        print("=====sample_list=====")
        print(sample_list.fields())
        for key in sample_list.keys():
            print(key + ":")
            # print(type(sample_list[key]))
            if isinstance(sample_list[key], str):
                print("str:", sample_list[key])
            elif isinstance(sample_list[key],
                            dict):  # region description: dict
                for key2 in sample_list[key].keys():
                    print("    " + key2 + ":")
                    # if type(sample_list[key][key2]) is np.ndarray:
                    #     print(sample_list[key][key2].shape)
                    # else:
                    #     print(sample_list[key][key2])
                    print(sample_list[key][key2])
            elif isinstance(sample_list[key], list):
                for i in sample_list[key]:  # image info 1: [none, ]
                    if i is not None:
                        print(i.keys(), i.values())
                    else:
                        print(i)
            else:
                print(sample_list[key].size())

        sample_list.text = self.word_embedding(sample_list.text)
        text_embedding_total = self.process_text_embedding(sample_list)

        image_embedding_total, _ = self.process_feature_embedding(
            "image", sample_list, text_embedding_total)

        if self.inter_model is not None:
            image_embedding_total = self.inter_model(image_embedding_total)

        # print("image_embedding", image_embedding_total.size())  # [batch*4096]
        # print("text_embedding:" , text_embedding_total.size())  # [batch*?]
        # print("=====combine layer=====")

        joint_embedding = self.combine_embeddings(
            ["image", "text"], [image_embedding_total, text_embedding_total])

        # print("joint_embedding:", joint_embedding.size()) # [batch*5000]

        model_output = {"scores": self.calculate_logits(joint_embedding)}

        # print("model_output:", model_output['scores'].size()) # [64, 3129]
        return model_output
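Collecting the shapes recorded in the debug comments above, the data flow for a batch of 64 is roughly:

# Shape trace implied by the debug output in this example (batch size 64)
#   text ids                [64, 20]         -> word_embedding        -> [64, 20, 300]
#   AttentionTextEmbedding  [64, 20, 300]    -> text_embedding_total  -> [64, 2048]
#   image_feature_0         [64, 100, 2048]  -> encoder + attention   -> [64, 2048]
#   image_feature_1         [64, 196, 2048]  -> encoder + attention   -> [64, 2048]
#   concatenated image embeddings            -> image_embedding_total -> [64, 4096]
#   multi-modal combine layer                -> joint_embedding       -> [64, 5000]
#   classifier                               -> scores                -> [64, 3129]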