Example #1
    def __init__(self, opt):
        super(JointMatching, self).__init__()
        num_layers = opt['rnn_num_layers']
        hidden_size = opt['rnn_hidden_size']
        num_dirs = 2 if opt['bidirectional'] > 0 else 1
        jemb_dim = opt['jemb_dim']
        self.unk_token = opt['unk_token']

        # language rnn encoder
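        # load pretrained GloVe word vectors for this dataset
        # (presumably a vocab_size x word_embedding_size numpy array)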
        word_emb_path = os.path.join(os.getcwd(), 'glove_emb',
                                     opt['dataset'] + '.npy')
        dict_emb = np.load(word_emb_path)

        self.rnn_encoder = RNNEncoderImgGlove(
            dict_emb,
            vocab_size=opt['vocab_size'],
            word_embedding_size=opt['word_embedding_size'],
            word_vec_size=opt['word_vec_size'],
            hidden_size=opt['rnn_hidden_size'],
            bidirectional=opt['bidirectional'] > 0,
            input_dropout_p=opt['word_drop_out'],
            dropout_p=opt['rnn_drop_out'],
            n_layers=opt['rnn_num_layers'],
            rnn_type=opt['rnn_type'],
            variable_lengths=opt['variable_lengths'] > 0)

        # [vis; loc] weighter
        self.weight_fc = nn.Linear(num_layers * num_dirs * hidden_size, 3)

        # phrase attender
        self.sub_attn = PhraseAttention(hidden_size * num_dirs)
        self.loc_attn = PhraseAttention(hidden_size * num_dirs)
        self.rel_attn = PhraseAttention(hidden_size * num_dirs)

        # visual matching
        self.sub_encoder = SubjectEncoder(opt)
        self.sub_matching = Matching(opt['fc7_dim'] + opt['jemb_dim'],
                                     opt['word_vec_size'], opt['jemb_dim'],
                                     opt['jemb_drop_out'])

        # location matching
        self.loc_encoder = LocationEncoderAtt(opt)
        self.loc_matching = Matching(opt['jemb_dim'], opt['word_vec_size'],
                                     opt['jemb_dim'], opt['jemb_drop_out'])

        # relation matching
        self.rel_encoder = RelationEncoderAtt(opt)
        self.rel_matching = Matching(opt['jemb_dim'], opt['word_vec_size'],
                                     opt['jemb_dim'], opt['jemb_drop_out'])

        self.fc7_dim = opt['fc7_dim']
        self.process_pool5 = nn.Sequential(nn.Conv2d(opt['pool5_dim'], 256, 1),
                                           nn.BatchNorm2d(256), nn.ReLU(),
                                           nn.Conv2d(256, 256, 3, padding=1),
                                           nn.BatchNorm2d(256), nn.ReLU(),
                                           nn.Conv2d(256, opt['pool5_dim'], 1),
                                           nn.BatchNorm2d(opt['pool5_dim']))
        self.proj_img = nn.Linear(opt['pool5_dim'], opt['jemb_dim'])
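
Each of these constructors pulls every hyperparameter from a single `opt` dict, and this example additionally expects pretrained GloVe vectors at `glove_emb/<dataset>.npy` under the working directory. A minimal sketch of such a dict (the keys are the ones read above; the values, and the instantiation itself, are illustrative assumptions, not taken from the source):

opt = {
    'rnn_num_layers': 1, 'rnn_hidden_size': 512, 'bidirectional': 1,
    'rnn_type': 'lstm', 'rnn_drop_out': 0.2, 'variable_lengths': 1,
    'vocab_size': 2000, 'word_embedding_size': 300, 'word_vec_size': 512,
    'word_drop_out': 0.5, 'unk_token': 'UNK', 'dataset': 'refcoco',
    'jemb_dim': 512, 'jemb_drop_out': 0.1,
    'fc7_dim': 2048, 'pool5_dim': 1024,
}
model = JointMatching(opt)  # hypothetical; the sub-encoders may read further keys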
Example #2
    def __init__(self, opt):
        super(JointMatching, self).__init__()
        num_layers = opt['rnn_num_layers']
        hidden_size = opt['rnn_hidden_size']
        num_dirs = 2 if opt['bidirectional'] > 0 else 1
        jemb_dim = opt['jemb_dim']

        # language rnn encoder
        self.rnn_encoder = RNNEncoder(
            vocab_size=opt['vocab_size'],
            word_embedding_size=opt['word_embedding_size'],
            word_vec_size=opt['word_vec_size'],
            hidden_size=opt['rnn_hidden_size'],
            bidirectional=opt['bidirectional'] > 0,
            input_dropout_p=opt['word_drop_out'],
            dropout_p=opt['rnn_drop_out'],
            n_layers=opt['rnn_num_layers'],
            rnn_type=opt['rnn_type'],
            variable_lengths=opt['variable_lengths'] > 0)

        # [vis; loc] weighter
        self.weight_fc = nn.Linear(num_layers * num_dirs * hidden_size, 3)

        # phrase attender
        self.sub_attn = PhraseAttention(hidden_size * num_dirs)
        self.loc_attn = PhraseAttention(hidden_size * num_dirs)
        self.rel_attn = PhraseAttention(hidden_size * num_dirs)

        # visual matching
        self.sub_encoder = SubjectEncoder(opt)
        self.sub_matching = Matching(opt['fc7_dim'] + opt['jemb_dim'],
                                     opt['word_vec_size'], opt['jemb_dim'],
                                     opt['jemb_drop_out'])

        # location matching
        self.loc_encoder = LocationEncoder(opt)
        self.loc_matching = Matching(opt['jemb_dim'], opt['word_vec_size'],
                                     opt['jemb_dim'], opt['jemb_drop_out'])

        # relation matching
        self.rel_encoder = RelationEncoder(opt)
        self.rel_matching = RelationMatching(opt['jemb_dim'],
                                             opt['word_vec_size'],
                                             opt['jemb_dim'],
                                             opt['jemb_drop_out'])
Example #3
    def __init__(self, opt):
        super(JointMatching, self).__init__()
        num_layers = opt['rnn_num_layers']
        hidden_size = opt['rnn_hidden_size']
        num_dirs = 2 if opt['bidirectional'] > 0 else 1
        jemb_dim = opt['jemb_dim']

        # language rnn encoder
        self.rnn_encoder = RNNEncoder(
            vocab_size=opt['vocab_size'],
            word_embedding_size=opt['word_embedding_size'],
            word_vec_size=opt['word_vec_size'],
            hidden_size=opt['rnn_hidden_size'],
            bidirectional=opt['bidirectional'] > 0,
            input_dropout_p=opt['word_drop_out'],
            dropout_p=opt['rnn_drop_out'],
            n_layers=opt['rnn_num_layers'],
            rnn_type=opt['rnn_type'],
            variable_lengths=opt['variable_lengths'] > 0)

        # dynamic filter generator
        self.dynamic_fc = nn.Linear(num_layers * num_dirs * hidden_size,
                                    opt['C4_feat_dim'])

        # dynamic filter convolution

        # [vis; loc] weighter
        #self.weight_fc = nn.Linear(num_layers * num_dirs * hidden_size, 3)  #### remove

        # phrase attender
        self.sub_attn = PhraseAttention(hidden_size * num_dirs)
        #self.loc_attn = PhraseAttention(hidden_size * num_dirs)
        #self.rel_attn = PhraseAttention(hidden_size * num_dirs)

        # visual matching
        self.sub_encoder = VisualEncoder(opt)
        self.sub_matching = Matching(opt['fc7_dim'] + opt['jemb_dim'],
                                     opt['word_vec_size'], opt['jemb_dim'],
                                     opt['jemb_drop_out'])
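
Here the `[vis; loc]` weighter is commented out and a dynamic filter generator takes its place; the `# dynamic filter convolution` step itself is left empty in the snippet. A minimal sketch of one common way such a filter is applied, as a per-sample 1x1 convolution over the C4 feature map; every name and shape below is an assumption, not taken from the source:

import torch
import torch.nn.functional as F

batch, c4_dim, h, w = 4, 512, 14, 14          # assumed C4 feature shape
hidden_flat = torch.randn(batch, 1024)        # num_layers * num_dirs * hidden_size
dynamic_fc = torch.nn.Linear(1024, c4_dim)

filt = F.normalize(dynamic_fc(hidden_flat), dim=1)  # one 1x1 filter per sentence
c4_feats = torch.randn(batch, c4_dim, h, w)
# a 1x1 convolution with a per-sample filter reduces to a channel-wise
# dot product at every spatial location
response = (c4_feats * filt[:, :, None, None]).sum(dim=1)  # (batch, h, w)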
Example #4
File: match.py  Project: amanishr/ARN
class AdaptiveReconstruct(nn.Module):
    def __init__(self, opt):
        super(AdaptiveReconstruct, self).__init__()
        num_layers = opt['rnn_num_layers']
        hidden_size = opt['rnn_hidden_size']
        num_dirs = 2 if opt['bidirectional'] > 0 else 1
        self.word_vec_size = opt['word_vec_size']
        self.pool5_dim, self.fc7_dim = opt['pool5_dim'], opt['fc7_dim']

        self.lang_res_weight = opt['lang_res_weight']
        self.vis_res_weight = opt['vis_res_weight']
        self.att_res_weight = opt['att_res_weight']
        self.loss_combined = opt['loss_combined']
        self.loss_divided = opt['loss_divided']

        # language rnn encoder
        self.rnn_encoder = RNNEncoder(
            vocab_size=opt['vocab_size'],
            word_embedding_size=opt['word_embedding_size'],
            word_vec_size=opt['word_vec_size'],
            hidden_size=opt['rnn_hidden_size'],
            bidirectional=opt['bidirectional'] > 0,
            input_dropout_p=opt['word_drop_out'],
            dropout_p=opt['rnn_drop_out'],
            n_layers=opt['rnn_num_layers'],
            rnn_type=opt['rnn_type'],
            variable_lengths=opt['variable_lengths'] > 0)

        self.weight_fc = nn.Linear(num_layers * num_dirs * hidden_size, 3)

        self.sub_attn = PhraseAttention(hidden_size * num_dirs)
        self.loc_attn = PhraseAttention(hidden_size * num_dirs)
        self.rel_attn = PhraseAttention(hidden_size * num_dirs)

        self.sub_encoder = SubjectEncoder(opt)
        self.loc_encoder = LocationEncoder(opt)
        self.rel_encoder = RelationEncoder(opt)

        self.sub_score = Score(self.pool5_dim + self.fc7_dim,
                               opt['word_vec_size'], opt['jemb_dim'])
        self.loc_score = Score(25 + 5, opt['word_vec_size'], opt['jemb_dim'])
        self.rel_score = RelationScore(self.fc7_dim + 5, opt['word_vec_size'],
                                       opt['jemb_dim'])

        self.sub_decoder = SubjectDecoder(opt)
        self.loc_decoder = LocationDecoder(opt)
        self.rel_decoder = RelationDecoder(opt)

        self.att_res_loss = AttributeReconstructLoss(opt)
        self.vis_res_loss = AdapVisualReconstructLoss(opt)
        self.lang_res_loss = AdapLangReconstructLoss(opt)
        self.rec_loss = LangReconstructionLoss(opt)

        # self.sub_mlp = nn.Sequential(nn.Linear(opt['jemb_dim'], self.pool5_dim + self.fc7_dim))
        # self.loc_mlp = nn.Sequential(nn.Linear(opt['jemb_dim'], 25 + 5))
        # self.rel_mlp = nn.Sequential(nn.Linear(opt['jemb_dim'], self.fc7_dim + 5))

        self.feat_fuse = nn.Sequential(
            nn.Linear(
                self.fc7_dim + self.pool5_dim + 25 + 5 + self.fc7_dim + 5,
                opt['jemb_dim']), nn.ReLU())

    def forward(self, pool5, fc7, lfeats, dif_lfeats, cxt_fc7, cxt_lfeats,
                labels, enc_labels, dec_labels, att_labels, select_ixs,
                att_weights):

        context, hidden, embedded = self.rnn_encoder(labels)

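        # predict a per-sentence weight for each of the three modules
        # (subject / location / relation) from the RNN hidden state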
        weights = F.softmax(self.weight_fc(hidden), dim=1)
        sub_attn, sub_phrase_emb = self.sub_attn(context, embedded, labels)
        loc_attn, loc_phrase_emb = self.loc_attn(context, embedded, labels)
        rel_attn, rel_phrase_emb = self.rel_attn(context, embedded, labels)

        sent_num = pool5.size(0)
        ann_num = pool5.size(1)

        # subject matching score
        sub_feats = self.sub_encoder(pool5, fc7, sub_phrase_emb)
        sub_ann_attn = self.sub_score(sub_feats, sub_phrase_emb)

        # location matching score
        loc_feats = self.loc_encoder(lfeats, dif_lfeats)
        loc_ann_attn = self.loc_score(loc_feats, loc_phrase_emb)

        # relation matching score
        rel_feats, masks = self.rel_encoder(cxt_fc7, cxt_lfeats)
        rel_ann_attn, rel_ixs = self.rel_score(rel_feats, rel_phrase_emb,
                                               masks)

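        # fuse the three module scores into one score per (sentence, annotation)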
        weights_expand = weights.unsqueeze(1).expand(sent_num, ann_num, 3)
        total_ann_score = (
            weights_expand *
            torch.cat([sub_ann_attn, loc_ann_attn, rel_ann_attn], 2)).sum(2)

        loss = 0
        att_res_loss = 0
        lang_res_loss = 0
        vis_res_loss = 0

        # divided_loss
        sub_phrase_recons = self.sub_decoder(sub_feats, total_ann_score)
        loc_phrase_recons = self.loc_decoder(loc_feats, total_ann_score)
        rel_phrase_recons = self.rel_decoder(rel_feats, total_ann_score,
                                             rel_ixs)

        if self.vis_res_weight > 0:
            vis_res_loss = self.vis_res_loss(sub_phrase_emb, sub_phrase_recons,
                                             loc_phrase_emb, loc_phrase_recons,
                                             rel_phrase_emb, rel_phrase_recons,
                                             weights)
            loss = self.vis_res_weight * vis_res_loss

        if self.lang_res_weight > 0:
            lang_res_loss = self.lang_res_loss(sub_phrase_emb, loc_phrase_emb,
                                               rel_phrase_emb, enc_labels,
                                               dec_labels)

            loss += self.lang_res_weight * lang_res_loss

        # combined_loss
        loss = self.loss_divided * loss
        ann_score = total_ann_score.unsqueeze(1)

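        # gather, for every (sentence, annotation), the context feature
        # selected by rel_ixs (the attended relation index)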
        ixs = rel_ixs.view(sent_num, ann_num,
                           1).unsqueeze(3).expand(sent_num, ann_num, 1,
                                                  self.fc7_dim + 5)
        rel_feats_max = torch.gather(rel_feats, 2, ixs)
        rel_feats_max = rel_feats_max.squeeze(2)

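        # score-weighted pooling over annotations, then a joint projection
        # whose output feeds the language reconstruction loss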
        fuse_feats = torch.cat([sub_feats, loc_feats, rel_feats_max], 2)
        fuse_feats = torch.bmm(ann_score, fuse_feats)
        fuse_feats = fuse_feats.squeeze(1)
        fuse_feats = self.feat_fuse(fuse_feats)
        rec_loss = self.rec_loss(fuse_feats, enc_labels, dec_labels)
        loss += self.loss_combined * rec_loss

        if self.att_res_weight > 0:
            att_scores, att_res_loss = self.att_res_loss(
                sub_feats, total_ann_score, att_labels, select_ixs,
                att_weights)
            loss += self.att_res_weight * att_res_loss

        # 'nocon' branch: the same losses recomputed with complemented
        # (1 - attn) module scores
        total_ann_score_nocon = (weights_expand * torch.cat(
            [1 - sub_ann_attn, 1 - loc_ann_attn, 1 - rel_ann_attn], 2)).sum(2)

        loss_nocon = 0
        att_res_loss_nocon = 0
        lang_res_loss_nocon = 0
        vis_res_loss_nocon = 0

        # divided_loss
        sub_phrase_recons_nocon = self.sub_decoder(sub_feats,
                                                   total_ann_score_nocon)
        loc_phrase_recons_nocon = self.loc_decoder(loc_feats,
                                                   total_ann_score_nocon)
        rel_phrase_recons_nocon = self.rel_decoder(rel_feats,
                                                   total_ann_score_nocon,
                                                   rel_ixs)

        if self.vis_res_weight > 0:
            vis_res_loss_nocon = self.vis_res_loss(
                sub_phrase_emb, sub_phrase_recons_nocon, loc_phrase_emb,
                loc_phrase_recons_nocon, rel_phrase_emb,
                rel_phrase_recons_nocon, weights)
            loss_nocon = self.vis_res_weight * vis_res_loss_nocon

        # combined_loss
        loss_nocon = self.loss_divided * loss_nocon

        ann_score_nocon = total_ann_score_nocon.unsqueeze(1)

        fuse_feats = torch.cat([sub_feats, loc_feats, rel_feats_max], 2)
        fuse_feats_nocon = torch.bmm(ann_score_nocon, fuse_feats)
        fuse_feats_nocon = fuse_feats_nocon.squeeze(1)
        fuse_feats_nocon = self.feat_fuse(fuse_feats_nocon)
        rec_loss_nocon = self.rec_loss(fuse_feats_nocon, enc_labels,
                                       dec_labels)
        loss_nocon += self.loss_combined * rec_loss_nocon

        if self.att_res_weight > 0:
            att_scores_nocon, att_res_loss_nocon = self.att_res_loss(
                sub_feats, total_ann_score_nocon, att_labels, select_ixs,
                att_weights)
            loss_nocon += self.att_res_weight * att_res_loss_nocon

        losses = {}
        losses['loss'] = loss
        losses['vis_res_loss'] = vis_res_loss
        losses['att_res_loss'] = att_res_loss
        losses['lang_res_loss'] = lang_res_loss
        losses['rec_loss'] = rec_loss
        losses['loss_nocon'] = loss_nocon
        losses['vis_res_loss_nocon'] = vis_res_loss_nocon
        losses['att_res_loss_nocon'] = att_res_loss_nocon
        losses['lang_res_loss_nocon'] = lang_res_loss_nocon
        losses['rec_loss_nocon'] = rec_loss_nocon
        return total_ann_score, losses, rel_ixs, sub_attn, loc_attn, rel_attn, weights

    def recon_zero_grad(self):
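        # clear gradients accumulated in every reconstruction-path submodule
        # (presumably called before a separate backward pass)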
        self.rnn_encoder.zero_grad()
        self.weight_fc.zero_grad()
        self.sub_attn.zero_grad()
        self.loc_attn.zero_grad()
        self.rel_attn.zero_grad()
        self.sub_encoder.zero_grad()
        self.loc_encoder.zero_grad()
        self.rel_encoder.zero_grad()
        self.sub_decoder.zero_grad()
        self.loc_decoder.zero_grad()
        self.rel_decoder.zero_grad()
        self.vis_res_loss.zero_grad()
        self.feat_fuse.zero_grad()
        self.rec_loss.zero_grad()
        self.att_res_loss.zero_grad()