Example no. 1
    def build(self):

        # Image feature encoder that operates directly on precomputed features
        self.image_feature_module = build_image_encoder(
            self.config.image_feature_processor, direct_features=True
        )
        if self.config.concate_trace:
            self.trace_feature_module = build_encoder(self.config.trace_feature_encoder)

        if self.config.base_model_name == "bert-base-uncased":
            self.encoderdecoder = EncoderDecoderModel.from_encoder_decoder_pretrained(
                "bert-base-uncased", "bert-base-uncased"
            )
        elif self.config.base_model_name == "2layer-base":
            config_encoder = BertConfig()
            config_decoder = BertConfig()
            config_encoder.max_position_embeddings = 1090
            config_encoder.num_hidden_layers = 2
            config_decoder.num_hidden_layers = 2
            self.codec_config = EncoderDecoderConfig.from_encoder_decoder_configs(
                config_encoder, config_decoder
            )
            self.encoderdecoder = EncoderDecoderModel(config=self.codec_config)
        elif self.config.base_model_name == "3layer-base":
            config_encoder = BertConfig()
            config_decoder = BertConfig()
            config_encoder.num_hidden_layers = 3
            config_decoder.num_hidden_layers = 3
            self.codec_config = EncoderDecoderConfig.from_encoder_decoder_configs(
                config_encoder, config_decoder
            )
            self.encoderdecoder = EncoderDecoderModel(config=self.codec_config)
        if self.config.loop_contrastive:
            self.trace_caption_contrastive = TraceCaptionContrastiveModel(
                self.config.tc_contrastive_aggregate_method
            )
        if (
            hasattr(self.config, "pretrans_attention")
            and self.config.pretrans_attention
        ):

            # Size the attention transform from the encoder's layer/head counts
            tempconf = self.encoderdecoder.config.encoder
            num_heads = tempconf.num_attention_heads
            num_layers = tempconf.num_hidden_layers
            self.attention_trans = AttentionTransform(num_layers, num_heads, 100)
        self.BOS_ID = 101  # [CLS] token id in the bert-base-uncased vocabulary
        self.vae = OpenAIDiscreteVAE()  # pretrained OpenAI discrete VAE (dVAE)
        image_code_dim = 768
        # Spatial size of the dVAE code map and the resulting image token sequence length
        image_fmap_size = self.vae.image_size // (2 ** self.vae.num_layers)
        self.image_seq_len = image_fmap_size ** 2
        self.image_emb = torch.nn.Embedding(self.vae.num_tokens, image_code_dim)
        self.image_pos_emb = AxialPositionalEmbedding(
            image_code_dim, axial_shape=(image_fmap_size, image_fmap_size)
        )
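For the pretrained OpenAI discrete VAE wrapped by dalle_pytorch (image_size 256, num_layers 3, num_tokens 8192), the arithmetic above gives a 32 x 32 code map, i.e. 1024 image tokens per image. A quick sanity check, assuming only dalle_pytorch is installed:

from dalle_pytorch import OpenAIDiscreteVAE

vae = OpenAIDiscreteVAE()
image_fmap_size = vae.image_size // (2 ** vae.num_layers)  # 256 // 2**3 = 32
print(image_fmap_size ** 2, vae.num_tokens)                # 1024 tokens per image, 8192-entry codebook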
Example no. 2
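The script fragment below relies on a command-line `args` namespace and on several VAE classes that are not imported in the excerpt. A plausible import header, assuming dalle_pytorch plus standard-library and utility packages (VQGanVAE16384 and VQGanVAECustom look project-specific, so they are only noted in a comment):

import os
from pathlib import Path

import torch
from tqdm import tqdm
from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE1024

# VQGanVAE16384, VQGanVAECustom and the argparse-style `args` namespace come
# from project-specific code that is not part of this excerpt.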

def exists(val):
    return val is not None


IMAGE_SIZE = args.image_size
VAE_CLASS = args.vae_class
if VAE_CLASS == 'VQGAN1024':
    vae = VQGanVAE1024()
elif VAE_CLASS == 'VQGAN16384':
    vae = VQGanVAE16384()
elif VAE_CLASS == 'VQGAN_CUSTOM':
    vae = VQGanVAECustom()
elif VAE_CLASS == 'DALLE':
    vae = OpenAIDiscreteVAE()
elif VAE_CLASS == 'DALLE_TRAIN':
    VAE_PATH = args.vae_path
    vae_path = Path(VAE_PATH)
    loaded_obj = torch.load(str(vae_path), map_location='cuda')
    vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']
    vae = DiscreteVAE(**vae_params)
    vae.load_state_dict(weights)
    vae.to('cuda')

filenames = os.listdir(args.target)

for filename in tqdm(filenames):
    TARGET_IMG_PATH = args.target + '/' + filename
    TARGET_SAVE_PATH = args.target + '/output/'
    filename = filename.split('.')[0]
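The loop body is truncated in the excerpt above. A minimal sketch of a plausible continuation, assuming the goal is an encode/decode round trip through the selected VAE (the preprocessing and save logic is illustrative, not taken from the original; `Image` from PIL and `T` from torchvision.transforms would be imported at the top of the script):

    os.makedirs(TARGET_SAVE_PATH, exist_ok=True)
    image = Image.open(TARGET_IMG_PATH).convert('RGB')
    tensor = T.Compose([T.Resize((vae.image_size, vae.image_size)), T.ToTensor()])(image)
    tensor = tensor.unsqueeze(0).to(next(vae.parameters()).device)

    with torch.no_grad():
        codes = vae.get_codebook_indices(tensor)  # (1, seq_len) discrete code ids
        recon = vae.decode(codes)                 # (1, 3, H, W) reconstruction in [0, 1]

    T.ToPILImage()(recon.squeeze(0).clamp(0, 1).cpu()).save(
        TARGET_SAVE_PATH + filename + '_recon.png'
    )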
Example no. 3
import logging
import torch
import numpy as np
from dalle_pytorch import OpenAIDiscreteVAE, DALLE

logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)

vae = OpenAIDiscreteVAE()  # loads pretrained OpenAI VAE


def generate_image_code(dalle, text, mask):
    # Greedily decode a full image token sequence from a DALLE model given text tokens
    vae, text_seq_len, image_seq_len, num_text_tokens = (
        dalle.vae,
        dalle.text_seq_len,
        dalle.image_seq_len,
        dalle.num_text_tokens,
    )
    total_len = text_seq_len + image_seq_len
    out = text

    for cur_len in range(text.shape[1], total_len):
        is_image = cur_len >= text_seq_len

        text, image = out[:, :text_seq_len], out[:, text_seq_len:]

        logits = dalle(text, image, mask=mask)[:, -1, :]
        chosen = torch.argmax(logits, dim=1, keepdim=True)
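The sampling loop is cut off in the excerpt above. A plausible continuation, mirroring the greedy loop in dalle_pytorch's own generate_images (the token-offset handling below is an assumption, not shown in the original):

        # still inside the for-loop: picks in the image region are offset by
        # num_text_tokens in DALLE's shared vocabulary, so shift them back
        if is_image:
            chosen = chosen - num_text_tokens
        out = torch.cat((out, chosen), dim=-1)

    # after the loop, keep only the image portion as the generated image code
    return out[:, text_seq_len:]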
Example no. 4
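The class below references several names that are not imported in the snippet. A plausible import header, assuming the MMF framework, HuggingFace transformers, dalle_pytorch, and the axial_positional_embedding package (the module paths for the two project-local classes are placeholders, not taken from the original):

import torch
from axial_positional_embedding import AxialPositionalEmbedding
from dalle_pytorch import OpenAIDiscreteVAE
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

from mmf.models.base_model import BaseModel
from mmf.utils.build import build_encoder, build_image_encoder

# Project-local modules; placeholder paths for illustration only
# from .modules import AttentionTransform, TraceCaptionContrastiveModel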
class CrossVLGenerator(BaseModel):
    def __init__(self, config):
        super().__init__(config)
        self.build()

    @classmethod
    def config_path(cls):
        return "configs/models/cvlg/defaults.yaml"

    def build(self):

        # Image feature encoder that operates directly on precomputed features
        self.image_feature_module = build_image_encoder(
            self.config.image_feature_processor, direct_features=True
        )
        if self.config.concate_trace:
            self.trace_feature_module = build_encoder(self.config.trace_feature_encoder)

        if self.config.base_model_name == "bert-base-uncased":
            self.encoderdecoder = EncoderDecoderModel.from_encoder_decoder_pretrained(
                "bert-base-uncased", "bert-base-uncased"
            )
        elif self.config.base_model_name == "2layer-base":
            config_encoder = BertConfig()
            config_decoder = BertConfig()
            config_encoder.max_position_embeddings = 1090
            config_encoder.num_hidden_layers = 2
            config_decoder.num_hidden_layers = 2
            self.codec_config = EncoderDecoderConfig.from_encoder_decoder_configs(
                config_encoder, config_decoder
            )
            self.encoderdecoder = EncoderDecoderModel(config=self.codec_config)
        elif self.config.base_model_name == "3layer-base":
            config_encoder = BertConfig()
            config_decoder = BertConfig()
            config_encoder.num_hidden_layers = 3
            config_decoder.num_hidden_layers = 3
            self.codec_config = EncoderDecoderConfig.from_encoder_decoder_configs(
                config_encoder, config_decoder
            )
            self.encoderdecoder = EncoderDecoderModel(config=self.codec_config)
        if self.config.loop_contrastive:
            self.trace_caption_contrastive = TraceCaptionContrastiveModel(
                self.config.tc_contrastive_aggregate_method
            )
        if (
            hasattr(self.config, "pretrans_attention")
            and self.config.pretrans_attention
        ):

            # Size the attention transform from the encoder's layer/head counts
            tempconf = self.encoderdecoder.config.encoder
            num_heads = tempconf.num_attention_heads
            num_layers = tempconf.num_hidden_layers
            self.attention_trans = AttentionTransform(num_layers, num_heads, 100)
        self.BOS_ID = 101  # [CLS] token id in the bert-base-uncased vocabulary
        self.vae = OpenAIDiscreteVAE()  # pretrained OpenAI discrete VAE (dVAE)
        image_code_dim = 768
        # Spatial size of the dVAE code map and the resulting image token sequence length
        image_fmap_size = self.vae.image_size // (2 ** self.vae.num_layers)
        self.image_seq_len = image_fmap_size ** 2
        self.image_emb = torch.nn.Embedding(self.vae.num_tokens, image_code_dim)
        self.image_pos_emb = AxialPositionalEmbedding(
            image_code_dim, axial_shape=(image_fmap_size, image_fmap_size)
        )

    def forward(self, sample_list, *args, **kwargs):

        # Encode the raw image into discrete dVAE codes, embed them, and add
        # axial positional embeddings
        visual_code = self.vae.get_codebook_indices(sample_list["image"])
        visual_emb = self.image_emb(visual_code)
        visual_emb += self.image_pos_emb(visual_emb)

        decoder_input_ids = sample_list["input_ids"][:, :-1]
        # Using the default attention mask for the caption tokens
        other_kwargs = {}
        # The dVAE code embeddings are fed directly as the encoder input embeddings
        # (the alternative region-feature path through self.image_feature_module is unused here)
        inputs_embeds = visual_emb
        batch_size = inputs_embeds.shape[0]
        if self.config.concate_trace:
            trace_boxes = sample_list["trace_boxes"]
            trace_boxes_mask = sample_list["trace_boxes_mask"]
            trace_feature = self.trace_feature_module(trace_boxes)
            trace_seg_id = sample_list["trace_boxes_seg_id"]
            inputs_embeds = torch.cat((inputs_embeds, trace_feature), dim=1)
            image_feats_mask = trace_boxes_mask.new_ones(
                (batch_size, visual_code.shape[1])
            )
            image_feats_seg_id = trace_seg_id.new_zeros(
                (batch_size, visual_code.shape[1])
            )
            attention_mask = torch.cat((image_feats_mask, trace_boxes_mask), dim=1)
            token_type_ids = torch.cat((image_feats_seg_id, trace_seg_id), dim=1)
            position_ids = trace_seg_id.new_zeros((batch_size, attention_mask.shape[1]))
            other_kwargs.update(
                {
                    "attention_mask": attention_mask,
                    "token_type_ids": token_type_ids,
                    "position_ids": position_ids,
                }
            )

        if self.training:
            decoder_output = self.encoderdecoder(
                decoder_input_ids=decoder_input_ids,
                inputs_embeds=inputs_embeds,
                output_attentions=True,
                output_hidden_states=True,
                return_dict=True,
                **other_kwargs
            )

            logits = decoder_output["logits"]
            cross_attentions = []
            # Collect per-layer cross-attention maps from the decoder output
            for cross_attention in decoder_output["cross_attentions"]:
                if self.config.concate_trace:
                    # keep attention over the first 100 encoder positions only
                    cross_attention = cross_attention[:, :, :, :100]
                cross_attentions.append(cross_attention)
            if (
                hasattr(self.config, "pretrans_attention")
                and self.config.pretrans_attention
            ):
                cross_attentions = self.attention_trans(cross_attentions)
            else:
                cross_attentions = [crs.mean(dim=1) for crs in cross_attentions]
            model_output = {}
            model_output["captions"] = torch.max(logits, dim=-1)[1]
            model_output["scores"] = logits
            model_output["cross_attentions"] = cross_attentions
            sample_list["targets"] = sample_list["input_ids"][:, 1:]

            if self.config.loop_contrastive:
                cap_feat, vision_trace_feat = self.trace_caption_contrastive(
                    decoder_output["encoder_hidden_states"][-1],
                    sample_list["trace_boxes_loop_contrastive_seg_id"],
                    decoder_output["decoder_hidden_states"][-1],
                    sample_list["segment_ids"],
                )
                model_output["contrastive_a"] = cap_feat
                model_output["contrastive_b"] = vision_trace_feat
        else:
            if self.config.inference.type == "beam_search":
                generate_output = self.encoderdecoder.generate(
                    input_ids=None,
                    inputs_embeds=inputs_embeds,
                    bos_token_id=self.BOS_ID,
                    decoder_start_token_id=self.BOS_ID,
                    **self.config.inference.args,
                    **other_kwargs
                )
            elif self.config.inference.type == "greedy":
                generate_output = self.encoderdecoder.generate(
                    input_ids=None,
                    inputs_embeds=inputs_embeds,
                    max_length=self.config.max_gen_length,
                    bos_token_id=self.BOS_ID,
                    decoder_start_token_id=self.BOS_ID,
                    **other_kwargs
                )
            elif self.config.inference.type == "nucleus_sampling":
                generate_output = self.encoderdecoder.generate(
                    input_ids=None,
                    inputs_embeds=inputs_embeds,
                    bos_token_id=self.BOS_ID,
                    decoder_start_token_id=self.BOS_ID,
                    **self.config.inference.args,
                    **other_kwargs
                )
            model_output = {}
            if (
                "return_attention" in self.config.inference
                and self.config.inference.return_attention
            ):
                with torch.no_grad():
                    attention_temp_output = self.encoderdecoder(
                        decoder_input_ids=generate_output,
                        inputs_embeds=inputs_embeds,
                        output_attentions=True,
                        return_dict=True,
                    )
                    cross_attentions = []
                    for cross_attention in attention_temp_output["cross_attentions"]:
                        if self.config.concate_trace:
                            cross_attention = cross_attention[:, :, :, :100]
                        cross_attentions.append(cross_attention.mean(dim=1))
                    cross_attentions = (
                        torch.stack(cross_attentions).max(dim=0)[0].max(dim=-1)[1]
                    )
                    model_output["cross_attention"] = cross_attentions

            model_output["captions"] = generate_output
            model_output["losses"] = {}
            loss_key = "{}/{}".format(
                sample_list.dataset_name, sample_list.dataset_type
            )
            # Add a dummy loss so that loss calculation is not required
            model_output["losses"][loss_key + "/dummy_loss"] = torch.zeros(
                1, device=sample_list.image_feature_0.device
            )

        return model_output
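Since BOS_ID 101 is the [CLS] token of the bert-base-uncased vocabulary, the id sequences returned in model_output["captions"] can be decoded back to text with the matching tokenizer. A minimal sketch (the variable names are illustrative):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
captions = tokenizer.batch_decode(
    model_output["captions"], skip_special_tokens=True
)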