Esempio n. 1
0
    def build_encoders(self):
        """Create one encoder per configured modality, stored by modality key.

        For every entry in ``self.config.modalities`` the encoder config is
        resolved in this order:

        1. the modality's own ``encoder`` entry, when it has one;
        2. the ``image_encoder`` entry of ``self.config``, for modalities
           whose ``type`` is ``"image"``;
        3. an ``identity`` encoder config as the fallback.

        Each encoder is registered in ``self.encoders`` (an
        ``nn.ModuleDict``) under ``modality.key``. For image modalities,
        ``requires_grad`` is switched off on every encoder parameter when
        ``freeze_image_encoder`` is truthy on ``self.config``.
        """
        self.encoders = nn.ModuleDict()

        for modality in self.config.modalities:
            if "encoder" in modality:
                encoder_config = modality.encoder
            elif modality.type == "image" and "image_encoder" in self.config:
                # Support "image_encoder" given directly on the model config.
                encoder_config = self.config.image_encoder
            else:
                # Fall back to an identity encoder. in_dim=100 is an arbitrary
                # number that only exists to satisfy the identity encoder.
                encoder_config = OmegaConf.create(
                    {"type": "identity", "params": {"in_dim": 100}}
                )

            encoder = build_encoder(encoder_config)
            self.encoders[modality.key] = encoder

            is_image = modality.type == "image"
            if is_image and getattr(self.config, "freeze_image_encoder", False):
                for weight in encoder.parameters():
                    weight.requires_grad = False
Esempio n. 2
0
    def build(self):
        """Instantiate the model's sub-modules from ``self.config``.

        Sets, depending on the config:

        * ``image_feature_module`` - always built.
        * ``trace_feature_module`` - only when ``concate_trace`` is truthy.
        * ``encoderdecoder`` (and ``codec_config`` for the from-scratch
          variants) - chosen by ``base_model_name``.
        * ``trace_caption_contrastive`` - only when ``loop_contrastive`` is
          truthy.
        * ``attention_trans`` - only when ``pretrans_attention`` is present
          and truthy.
        * ``BOS_ID``, ``vae``, ``image_seq_len``, ``image_emb`` and
          ``image_pos_emb`` - always set.

        NOTE(review): there is no fallback branch for ``base_model_name``;
        any value other than the three handled below leaves
        ``self.encoderdecoder`` unset.
        """
        cfg = self.config

        self.image_feature_module = build_image_encoder(
            cfg.image_feature_processor, direct_features=True
        )
        if cfg.concate_trace:
            self.trace_feature_module = build_encoder(cfg.trace_feature_encoder)

        base_model_name = cfg.base_model_name
        if base_model_name == "bert-base-uncased":
            self.encoderdecoder = EncoderDecoderModel.from_encoder_decoder_pretrained(
                "bert-base-uncased", "bert-base-uncased"
            )
        elif base_model_name in ("2layer-base", "3layer-base"):
            # Randomly initialised BERT encoder/decoder pair with a reduced
            # number of hidden layers.
            depth = 2 if base_model_name == "2layer-base" else 3
            encoder_cfg = BertConfig()
            decoder_cfg = BertConfig()
            if depth == 2:
                # Only the 2-layer variant overrides the encoder's
                # position-embedding table size.
                encoder_cfg.max_position_embeddings = 1090
            encoder_cfg.num_hidden_layers = depth
            decoder_cfg.num_hidden_layers = depth
            self.codec_config = EncoderDecoderConfig.from_encoder_decoder_configs(
                encoder_cfg, decoder_cfg
            )
            self.encoderdecoder = EncoderDecoderModel(config=self.codec_config)

        if cfg.loop_contrastive:
            self.trace_caption_contrastive = TraceCaptionContrastiveModel(
                cfg.tc_contrastive_aggregate_method
            )

        if getattr(cfg, "pretrans_attention", False):
            # Size the attention transform from the encoder's own config.
            encoder_conf = self.encoderdecoder.config.encoder
            head_count = encoder_conf.num_attention_heads
            layer_count = encoder_conf.num_hidden_layers
            # NOTE(review): the meaning of the trailing 100 is not visible
            # from this method - confirm against AttentionTransform.
            self.attention_trans = AttentionTransform(layer_count, head_count, 100)

        # NOTE(review): presumably the tokenizer id used as begin-of-sequence
        # - confirm against the vocabulary in use.
        self.BOS_ID = 101

        # Discrete-VAE image tokens: one embedding per code plus an axial
        # positional embedding over the square feature map.
        self.vae = OpenAIDiscreteVAE()
        code_dim = 768
        fmap_side = self.vae.image_size // (2 ** self.vae.num_layers)
        self.image_seq_len = fmap_side ** 2
        self.image_emb = torch.nn.Embedding(self.vae.num_tokens, code_dim)
        self.image_pos_emb = AxialPositionalEmbedding(
            code_dim, axial_shape=(fmap_side, fmap_side)
        )
Esempio n. 3
0
    def _build_model(self):
        """Load the pretrained checkpoint and build everything needed to run it.

        Side effects:
            Sets ``self.model_items`` to the result of
            ``load_pretrained_model(self.checkpoint)`` and ``self.config`` to
            an OmegaConf config created from its ``"full_config"`` entry.

        Returns:
            A ``(processor, feature_extractor, model)`` tuple, where the
            model has already had the checkpoint's state dict loaded.
        """
        self.model_items = load_pretrained_model(self.checkpoint)
        items = self.model_items
        self.config = OmegaConf.create(items["full_config"])

        # Processors are taken from the first dataset listed in the config.
        dataset_config = self.config.dataset_config
        first_dataset = list(dataset_config.keys())[0]
        processor = build_processors(dataset_config[first_dataset].processors)

        model_config = items["config"]
        feature_extractor = build_encoder(model_config.image_feature_encodings)

        state_dict = items["checkpoint"]
        model = build_model(model_config)
        model.load_state_dict(state_dict)

        return processor, feature_extractor, model
Esempio n. 4
0
    def build(self):
        """Build embeddings, encoder, task heads and losses from ``self.config``.

        ``tasks`` comes from the ``tasks`` config entry and defaults to the
        keys of the ``heads`` entry; a comma-separated string is split into
        a list. ``self.losses`` is created empty and handed to
        ``build_heads_dict`` together with the head configs and tasks.
        """
        cfg = self.config

        self.text_embeddings = ViLTTextEmbedding(**cfg.text_embeddings)
        self.image_embeddings = ViLTImageEmbedding(**cfg.image_encoder.params)
        self.encoder = build_encoder(cfg.image_encoder)

        head_configs = cfg.get("heads", {})
        tasks = cfg.get("tasks", head_configs.keys())
        if isinstance(tasks, str):
            # Allow tasks to be given as "task_a,task_b".
            tasks = tasks.split(",")
        self.tasks = tasks

        self.losses = nn.ModuleDict()
        self.heads_dict = build_heads_dict(head_configs, self.tasks, self.losses)

        # Both attributes intentionally refer to the same list object.
        modalities = ["text", "image"]
        self.modality_keys = modalities
        self.modality_type = modalities
Esempio n. 5
0
    def build(self):
        """Instantiate the model's sub-modules from ``self.config``.

        Sets, depending on the config:

        * ``image_feature_module`` - always built.
        * ``trace_feature_module`` - only when ``concate_trace`` is truthy.
        * ``encoderdecoder`` (and ``codec_config`` for the from-scratch
          variants) - chosen by ``base_model_name``.
        * ``trace_caption_contrastive`` - only when ``loop_contrastive`` is
          truthy.
        * ``attention_trans`` - only when ``pretrans_attention`` is present
          and truthy.
        * ``BOS_ID`` - always set.

        NOTE(review): there is no fallback branch for ``base_model_name``;
        any value other than the three handled below leaves
        ``self.encoderdecoder`` unset.
        """
        cfg = self.config

        self.image_feature_module = build_image_encoder(
            cfg.image_feature_processor, direct_features=True)
        if cfg.concate_trace:
            self.trace_feature_module = build_encoder(
                cfg.trace_feature_encoder)

        base_model_name = cfg.base_model_name
        if base_model_name == "bert-base-uncased":
            self.encoderdecoder = EncoderDecoderModel.from_encoder_decoder_pretrained(
                "bert-base-uncased", "bert-base-uncased")
        elif base_model_name in ("2layer-base", "3layer-base"):
            # Randomly initialised BERT encoder/decoder pair with a reduced
            # number of hidden layers.
            depth = 2 if base_model_name == "2layer-base" else 3
            encoder_cfg = BertConfig()
            decoder_cfg = BertConfig()
            encoder_cfg.num_hidden_layers = depth
            decoder_cfg.num_hidden_layers = depth
            self.codec_config = EncoderDecoderConfig.from_encoder_decoder_configs(
                encoder_cfg, decoder_cfg)
            self.encoderdecoder = EncoderDecoderModel(config=self.codec_config)

        if cfg.loop_contrastive:
            self.trace_caption_contrastive = TraceCaptionContrastiveModel(
                cfg.tc_contrastive_aggregate_method)

        if getattr(cfg, "pretrans_attention", False):
            # Size the attention transform from the encoder's own config.
            encoder_conf = self.encoderdecoder.config.encoder
            head_count = encoder_conf.num_attention_heads
            layer_count = encoder_conf.num_hidden_layers
            # NOTE(review): the meaning of the trailing 100 is not visible
            # from this method - confirm against AttentionTransform.
            self.attention_trans = AttentionTransform(layer_count, head_count,
                                                      100)

        # NOTE(review): presumably the tokenizer id used as begin-of-sequence
        # - confirm against the vocabulary in use.
        self.BOS_ID = 101