Example 1
    def forward(self, frcn_feat, grid_feat, bbox_feat, ques_ix):

        # Pre-process Language Feature
        if self.__C.USE_BERT:
            lang_feat_mask = make_mask(ques_ix[:, 1:-1].unsqueeze(2))
        else:
            lang_feat_mask = make_mask(ques_ix.unsqueeze(2))

        if self.__C.BERT_ENCODER:
            outputs = self.encoder(ques_ix)
            last_hidden_state = outputs[0]
            lang_feat = last_hidden_state[:, 1:-1, :]  # drop CLS and SEP so the sequence length matches max_token=14
        elif self.__C.USE_BERT:
            outputs = self.bert_layer(ques_ix)
            # Uncomment the two lines below to use only the last hidden layer
            # last_hidden_state = outputs[0][:, 1:-1, :]  # drop CLS and SEP, matching max_token=14
            # lang_feat, _ = self.lstm(last_hidden_state)
            # Concatenate the last four hidden layers instead
            hidden_states = outputs[2]
            concat_layers = torch.cat([hidden_states[i] for i in [-1, -2, -3, -4]], dim=-1)
            concat_layers = concat_layers[:, 1:-1, :]
            lang_feat, _ = self.lstm(concat_layers)
        elif self.__C.USE_GLOVE:
            lang_feat = self.embedding(ques_ix)
            lang_feat, _ = self.lstm(lang_feat)

        img_feat, img_feat_mask = self.adapter(frcn_feat, grid_feat, bbox_feat)

        # Backbone Framework
        lang_feat, img_feat = self.backbone(
            lang_feat,
            img_feat,
            lang_feat_mask,
            img_feat_mask
        )

        # Flatten to vector
        lang_feat, _ = self.attflat_lang(
            lang_feat,
            lang_feat_mask
        )

        img_feat, img_att = self.attflat_img(
            img_feat,
            img_feat_mask
        )

        # Classification layers
        if self.__C.FUSION == "sum":
            proj_feat = lang_feat + img_feat
        elif self.__C.FUSION == "product":
            proj_feat = lang_feat * img_feat
        proj_feat = self.proj_norm(proj_feat)
        proj_feat = self.proj(proj_feat)

        return proj_feat, img_att
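
Every snippet in this listing calls a make_mask helper that is not shown. A minimal sketch of the usual convention in these VQA codebases (mark a position as padding when its feature vector is all zeros, and shape the mask so it broadcasts over multi-head attention scores) is given below; the exact definition in the source repository may differ.

import torch

def make_mask(feature):
    # feature: [batch, seq_len, dim]; True where a row is entirely zero (padding).
    # Output shape [batch, 1, 1, seq_len] broadcasts over attention score tensors.
    return (torch.abs(feature).sum(dim=-1) == 0).unsqueeze(1).unsqueeze(2)

# Usage as in Example 1: ques_ix is [batch, seq_len] of token ids with 0 as padding,
# so ques_ix.unsqueeze(2) gives the [batch, seq_len, 1] "feature" the helper expects.
ques_ix = torch.tensor([[5, 8, 2, 0, 0]])
print(make_mask(ques_ix.unsqueeze(2)))  # tensor([[[[False, False, False,  True,  True]]]])
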
Example 2
    def gqa_forward(self, feat_dict):
        frcn_feat = feat_dict['FRCN_FEAT']
        bbox_feat = feat_dict['BBOX_FEAT']
        grid_feat = feat_dict['GRID_FEAT']

        img_feat_mask = torch.cat((make_mask(frcn_feat), make_mask(grid_feat)), dim=-1)
        bbox_feat = self.bbox_linear(bbox_feat)
        frcn_feat = torch.cat((frcn_feat, bbox_feat), dim=-1)
        frcn_feat = self.frcn_linear(frcn_feat)
        grid_feat = self.grid_linear(grid_feat)
        img_feat = torch.cat((frcn_feat, grid_feat), dim=1)

        return img_feat, img_feat_mask
Example 3
    def forward(self, frcn_feat, grid_feat, bbox_feat, ques_ix):
        batch_size = ques_ix.shape[0]
        device = ques_ix.device

        # Pre-process Language Feature
        text_feat_mask = make_mask(ques_ix.unsqueeze(2))
        text_feat = self.embedding(ques_ix)
        text_feat, _ = self.lstm(text_feat)
        text_feat = self.lstm_proj(text_feat)

        img_feat, img_feat_mask = self.adapter(frcn_feat, grid_feat, bbox_feat)

        cls_token = torch.tensor(self.token_to_ix['CLS'],
                                 device=device).repeat(batch_size, 1)
        cls_token = self.embedding(cls_token)
        cls_token = self.token_proj(cls_token)

        img_token = torch.tensor(self.token_to_ix['IMG'],
                                 device=device).repeat(batch_size, 1)
        img_token = self.embedding(img_token)
        img_token = self.token_proj(img_token)

        text_feat = torch.cat([cls_token, text_feat], dim=1)
        img_feat = torch.cat([img_token, img_feat], dim=1)

        img_mask = make_mask(img_token)
        cls_mask = make_mask(cls_token)
        text_feat_mask = torch.cat([cls_mask, text_feat_mask], dim=-1)
        img_feat_mask = torch.cat([img_mask, img_feat_mask], dim=-1)

        # Backbone Framework
        lang_feat, img_feat, text_attention_map, img_attention_map = self.backbone(
            text_feat, img_feat, text_feat_mask, img_feat_mask)

        img_attention_map = img_attention_map[:, :, :, 0, :]
        txt_attention_map = text_attention_map[:, :, :, 0, :]

        text_pool = self.text_pooler(lang_feat)
        img_pool = self.img_pooler(img_feat)

        # Classification layers
        pooled_output = self.dropout(text_pool * img_pool)
        pooled_output = self.dense(pooled_output)
        pooled_output = self.activation(pooled_output)
        pooled_output = self.layer_norm(pooled_output)
        output = self.cls(pooled_output)

        return output, img_attention_map, txt_attention_map
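
The text_pooler and img_pooler modules used in Example 3 are not defined in the snippet. Because a CLS token and an IMG token are prepended at position 0 of the language and image streams, a plausible reading is a BERT-style first-token pooler; the class name, hidden_size argument, and Linear-plus-Tanh layout below are illustrative assumptions, not taken from the repository.

import torch.nn as nn

class FirstTokenPooler(nn.Module):
    # BERT-style pooling: take the hidden state of the first token (the prepended
    # CLS or IMG token) and pass it through a dense layer with a tanh activation.
    def __init__(self, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # hidden_states: [batch, seq_len, hidden] -> [batch, hidden]
        return self.activation(self.dense(hidden_states[:, 0]))
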
Example 4
    def clevr_forward(self, feat_dict):
        grid_feat = feat_dict['GRID_FEAT']

        img_feat_mask = make_mask(grid_feat)
        img_feat = self.grid_linear(grid_feat)

        return img_feat, img_feat_mask
Example 5
    def forward(self, frcn_feat, grid_feat, bbox_feat, ques_ix):
        # Prepend a CLS token (vocabulary index 2) to every question
        cls_tensor = torch.full((ques_ix.shape[0], 1), 2,
                                dtype=torch.long, device=ques_ix.device)
        ques_ix = torch.cat((cls_tensor, ques_ix), dim=1)

        # Pre-process Language Feature
        lang_feat_mask = make_mask(ques_ix.unsqueeze(2))
        lang_feat = self.embedding(ques_ix)
        lang_feat, _ = self.lstm(lang_feat)

        img_feat, img_feat_mask = self.adapter(frcn_feat, grid_feat, bbox_feat)

        lang_feat = self.norm1(lang_feat)
        img_feat = self.norm2(img_feat)

        fuse_feat = torch.cat((lang_feat, img_feat), dim=1)
        fuse_feat_mask = torch.cat((lang_feat_mask, img_feat_mask), dim=-1)

        # Backbone Framework
        fuse_feat = self.backbone(fuse_feat, fuse_feat_mask)

        # Flatten to vector
        fuse_flat = self.flat(fuse_feat, fuse_feat_mask)

        # Classification layers
        # proj_feat = lang_feat + img_feat
        # proj_feat = self.proj_norm(proj_feat)
        proj_feat = self.proj(fuse_flat)

        return proj_feat
Example 6
    def forward(self, frcn_feat, grid_feat, bbox_feat, ques_ix):

        # Pre-process Language Feature
        lang_feat_mask = make_mask(ques_ix.unsqueeze(2))
        lang_feat = self.embedding(ques_ix)
        lang_feat, _ = self.lstm(lang_feat)

        img_feat, rel_embed, img_feat_mask = self.adapter(
            frcn_feat, grid_feat, bbox_feat)
        rela = self.relu(self.linear_rel(rel_embed))

        # Backbone Framework
        lang_feat, img_feat = self.backbone(lang_feat, img_feat,
                                            lang_feat_mask, img_feat_mask,
                                            rela)

        # Flatten to vector
        lang_feat = self.attflat_lang(lang_feat, lang_feat_mask)

        img_feat = self.attflat_img(img_feat, img_feat_mask)

        # Classification layers
        proj_feat = lang_feat + img_feat
        proj_feat = self.proj_norm(proj_feat)
        proj_feat = self.proj(proj_feat)

        return proj_feat
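
Examples 1 and 6 reduce the variable-length language and image sequences to single vectors with attflat_lang / attflat_img. A minimal sketch in the spirit of MCAN-style attentional flattening is shown below; the layer sizes and the single attention glimpse are assumptions, and Example 6's variant apparently returns only the flattened vector rather than the (vector, attention) pair unpacked in Example 1.

import torch
import torch.nn as nn
import torch.nn.functional as F

class AttFlat(nn.Module):
    # Attention-based flattening: score every position with a small MLP,
    # softmax over the sequence, and return the weighted sum plus the weights.
    def __init__(self, hidden_size, flat_mlp_size, flat_out_size, dropout=0.1):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, flat_mlp_size),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(flat_mlp_size, 1),
        )
        self.linear_merge = nn.Linear(hidden_size, flat_out_size)

    def forward(self, x, x_mask):
        # x: [batch, seq, hidden]; x_mask: [batch, 1, 1, seq] with True at padding.
        att = self.mlp(x)                                               # [batch, seq, 1]
        att = att.masked_fill(x_mask.squeeze(1).squeeze(1).unsqueeze(2), -1e9)
        att = F.softmax(att, dim=1)
        flat = self.linear_merge(torch.sum(att * x, dim=1))             # [batch, flat_out]
        return flat, att
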
Example 7
    def vqa_forward(self, feat_dict):
        frcn_feat = feat_dict['FRCN_FEAT']
        bbox_feat = feat_dict['BBOX_FEAT']

        img_feat_mask = make_mask(frcn_feat)
        # img_feat = self.frcn_linear(frcn_feat)

        return frcn_feat, img_feat_mask
Example 8
    def vqa_forward(self, feat_dict):
        frcn_feat = feat_dict['FRCN_FEAT']
        bbox_feat = feat_dict['BBOX_FEAT']

        img_feat_mask = make_mask(frcn_feat)
        img_feat = frcn_feat
        #[N, C, W] = img_feat.shape
        #img_feat = F.normalize(img_feat.view(N, -1)).view(N, C, W)
        return img_feat, img_feat_mask
Example 9
    def clevr_forward(self, feat_dict):
        grid_feat = feat_dict['GRID_FEAT']
        # bbox_feat is needed for the relation embedding below; it is assumed to be
        # provided under the same key the VQA/GQA adapters use, since the original
        # snippet referenced it without looking it up from feat_dict.
        bbox_feat = feat_dict['BBOX_FEAT']

        img_feat_mask = make_mask(grid_feat)
        img_feat = self.grid_linear(grid_feat)

        rel_embed = self.relation_embedding(bbox_feat)

        return img_feat, rel_embed, img_feat_mask
Example 10
    def gqa_forward(self, feat_dict):
        frcn_feat = feat_dict['FRCN_FEAT']
        bbox_feat = feat_dict['BBOX_FEAT']
        grid_feat = feat_dict['GRID_FEAT']

        img_feat_mask = make_mask(frcn_feat)

        if self.__C.USE_BBOX_FEAT:
            bbox_feat = self.bbox_linear(bbox_feat)
            frcn_feat = torch.cat((frcn_feat, bbox_feat), dim=-1)
        img_feat = self.frcn_linear(frcn_feat)

        if self.__C.USE_AUX_FEAT:
            grid_feat_mask = make_mask(grid_feat)
            img_feat_mask = torch.cat((img_feat_mask, grid_feat_mask), dim=-1)
            grid_feat = self.grid_linear(grid_feat)
            img_feat = torch.cat((img_feat, grid_feat), dim=1)

        return img_feat, img_feat_mask
Example 11
    def clevr_forward(self, feat_dict):
        grid_feat = feat_dict['GRID_FEAT']

        img_feat_mask = make_mask(grid_feat)

        # grid_feat: [N, 196, 1024] -> [N, 1024, 196] -> [N, 1024, 14, 14] for the conv layer
        img_feat = grid_feat.permute(0, 2, 1)
        img_feat = img_feat.view(-1, 1024, 14, 14)
        img_feat = self.conv(img_feat)

        return img_feat, img_feat_mask
Example 12
    def vqa_forward(self, feat_dict):
        frcn_feat = feat_dict['FRCN_FEAT']
        bbox_feat = feat_dict['BBOX_FEAT']

        img_feat_mask = make_mask(frcn_feat)

        if self.__C.USE_BBOX_FEAT:
            bbox_feat = self.bbox_linear(bbox_feat)
            frcn_feat = torch.cat((frcn_feat, bbox_feat), dim=-1)
        img_feat = self.frcn_linear(frcn_feat)

        return img_feat, img_feat_mask
Example 13
    def forward(self, frcn_feat, grid_feat, bbox_feat, ques_ix):

        # Pre-process Language Feature
        lang_feat_mask = make_mask(ques_ix.unsqueeze(2))
        lang_feat = self.embedding(ques_ix)
        lang_feat, _ = self.gru(lang_feat)

        img_feat, _ = self.adapter(frcn_feat, grid_feat, bbox_feat)

        # Backbone Framework

        lang_feat = self.backbone(lang_feat, img_feat)

        # Classification layers
        proj_feat = self.classifer(lang_feat.sum(1))

        return proj_feat
Example 14
    def forward(self, feat, bbox):
        mask = make_mask(feat)
        b_feat = self.bbox_linear(bbox)
        feat = torch.cat((feat, b_feat), dim=-1)
        feat = self.frcn_linear(feat)
        return feat, mask
Example 15
    def forward(self, feat):
        mask = make_mask(feat)
        feat = self.frcn_linear(feat)
        return feat, mask
Example 16
    def forward(self, frcn_feat, grid_feat, bbox_feat, ques_ix):
        batch_size = ques_ix.shape[0]
        device = ques_ix.device

        # create text feature
        text_feat_mask = make_mask(ques_ix.unsqueeze(2))
        text_feat = self.word_embedding(ques_ix)
        text_feat, _ = self.lstm(text_feat)
        text_feat = self.lstm_proj(text_feat)

        # seq_length = text_feat.size()[1]
        # text_position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
        # text_position_ids = text_position_ids.unsqueeze(0).expand(ques_ix.size())
        # text_position_embeddings = self.text_position_embeddings(text_position_ids)

        # create text segment embedding
        text_seg_ids = torch.zeros(text_feat.size()[:-1],
                                   dtype=torch.long,
                                   device=device)
        text_seg_embedding = self.segment_embedding(text_seg_ids)

        # text embedding
        text_feat = text_feat + text_seg_embedding  # + text_position_embeddings

        # image features and mask
        img_feat, img_feat_mask = self.img_encoder(frcn_feat, grid_feat,
                                                   bbox_feat)

        # create image segment embedding
        img_seg_ids = torch.ones(img_feat.size()[:-1],
                                 dtype=torch.long,
                                 device=device)
        img_seg_embedding = self.segment_embedding(img_seg_ids)

        # image position embedding
        width = 14
        height = 14
        img_pos = torch.meshgrid([
            torch.arange(width, dtype=torch.float, device=device),
            torch.arange(height, dtype=torch.float, device=device)
        ])
        img_pos = torch.stack([img_pos[1], img_pos[0]], dim=-1).view(
            width * height, 2).unsqueeze(0).expand(batch_size, width * height,
                                                   2)
        img_pos_emb = self.img_pos_emb(img_pos)

        # image embedding
        img_feat = img_feat + img_seg_embedding + img_pos_emb

        # CLS embedding
        cls_token = torch.tensor(self.token_to_ix['CLS'],
                                 device=device).repeat(batch_size, 1)
        cls_token = self.word_embedding(cls_token)
        cls_token = self.cls_project(cls_token)

        # prepare input embedding for transformer
        embeddings = torch.cat([cls_token, text_feat, img_feat], dim=1)
        embeddings = self.layer_norm1(embeddings)
        embeddings = self.embbeding_dropout(embeddings)

        # prepare mask for self attention
        cls_mask = make_mask(cls_token)
        attention_mask = torch.cat([cls_mask, text_feat_mask, img_feat_mask],
                                   dim=-1)

        # Backbone Framework
        feat = self.transformer(embeddings, attention_mask)

        # Classification layers
        pooled_output = self.pooler(feat)
        pooled_output = self.cls_dropout(pooled_output)
        pooled_output = self.dense(pooled_output)
        pooled_output = self.activation(pooled_output)
        pooled_output = self.layer_norm2(pooled_output)
        output = self.classifier(pooled_output)
        return output
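
The 2-D position embedding in Example 16 enumerates a 14 x 14 grid of (x, y) coordinates and projects them with img_pos_emb, presumably a small learned layer such as nn.Linear(2, hidden_size) (an assumption; its definition is not shown). The standalone check below reproduces the coordinate construction and its shapes.

import torch

# Reproduce Example 16's coordinate grid ('ij' meshgrid indexing, the PyTorch
# default, is assumed) and confirm it yields 196 (x, y) pairs.
width, height = 14, 14
grid = torch.meshgrid([torch.arange(width, dtype=torch.float),
                       torch.arange(height, dtype=torch.float)])
pos = torch.stack([grid[1], grid[0]], dim=-1).view(width * height, 2)
print(pos.shape)   # torch.Size([196, 2])
print(pos[:3])     # tensor([[0., 0.], [1., 0.], [2., 0.]]) -- x varies fastest
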