Example #1
    def __init__(self,
                 text_dim,
                 use_ce,
                 l2renorm,
                 vlad_clusters,
                 disable_nan_checks,
                 expert_dims,
                 keep_missing_modalities,
                 test_caption_mode,
                 randomise_feats,
                 freeze_weights=False,
                 verbose=False,
                 mimic_ce_dims=False,
                 concat_experts=False,
                 concat_mix_experts=False):
        super().__init__()

        self.expert_dims = expert_dims
        self.l2renorm = l2renorm
        self.disable_nan_checks = disable_nan_checks
        self.text_pooling = NetVLAD(
            feature_size=text_dim,
            cluster_size=vlad_clusters["text"],
        )
        if randomise_feats:
            # randomise_feats is a comma-separated list of modality names
            self.random_feats = set(randomise_feats.split(","))
        else:
            self.random_feats = set()

        # sanity checks
        expected_feat_sizes = {"audio": 128, "speech": 300, "ocr": 300}
        self.pooling = nn.ModuleDict()
        for mod, expected in expected_feat_sizes.items():
            if mod in expert_dims.keys():
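                # expert_dims stores the flattened VLAD dimension, so dividing
                # by the cluster count recovers the raw per-frame feature size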
                feature_size = expert_dims[mod][0] // vlad_clusters[mod]
                msg = f"expected {expected} for {mod} features atm"
                assert feature_size == expected, msg
                self.pooling[mod] = NetVLAD(
                    feature_size=feature_size,
                    cluster_size=vlad_clusters[mod],
                )

        self.ce = CEModule(
            use_ce=use_ce,
            verbose=verbose,
            l2renorm=l2renorm,
            random_feats=self.random_feats,
            freeze_weights=freeze_weights,
            text_dim=self.text_pooling.out_dim,
            test_caption_mode=test_caption_mode,
            concat_experts=concat_experts,
            concat_mix_experts=concat_mix_experts,
            expert_dims=expert_dims,
            disable_nan_checks=disable_nan_checks,
            keep_missing_modalities=keep_missing_modalities,
            mimic_ce_dims=mimic_ce_dims,
        )
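
A note on the NetVLAD interface assumed throughout these examples: the real module is not shown, so the following is a minimal, hypothetical sketch. It only assumes the constructor takes feature_size and cluster_size keyword arguments and that out_dim equals feature_size * cluster_size, which is consistent with the sanity check feature_size = expert_dims[mod][0] // vlad_clusters[mod] above.

import torch
import torch.nn as nn

class NetVLADSketch(nn.Module):
    """Hypothetical stand-in matching the interface used in Example #1."""

    def __init__(self, feature_size, cluster_size):
        super().__init__()
        self.feature_size = feature_size
        self.cluster_size = cluster_size
        # Flattened VLAD descriptor: one residual vector per cluster.
        self.out_dim = feature_size * cluster_size
        self.clusters = nn.Parameter(torch.randn(cluster_size, feature_size))

    def forward(self, x):
        # x: (batch, num_tokens, feature_size)
        assign = torch.softmax(x @ self.clusters.t(), dim=-1)   # (B, T, K)
        residuals = x.unsqueeze(2) - self.clusters               # (B, T, K, D)
        vlad = (assign.unsqueeze(-1) * residuals).sum(dim=1)     # (B, K, D)
        return vlad.flatten(1)                                   # (B, K * D)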
Example #2
def get_aggregation(agg, feature_size):
    if agg['type'] == 'net_vlad':
        cluster_size = agg['cluster_size']
        ghost_clusters = agg['ghost_clusters']
        return NetVLAD(cluster_size, feature_size, ghost_clusters)
    elif agg['type'] == 'mean':
        return MeanToken(1)
    elif agg['type'] == 'max':
        return MaxToken(1)
    else:
        raise NotImplementedError(f"aggregation type: {agg['type']}")
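
A short usage sketch for the factory above, assuming the config-dict layout implied by the function ('type' plus the NetVLAD hyper-parameters). Note that, unlike Example #1, this NetVLAD variant is constructed positionally as (cluster_size, feature_size, ghost_clusters):

# Hypothetical config values; only the dict keys come from the code above.
agg_cfg = {'type': 'net_vlad', 'cluster_size': 28, 'ghost_clusters': 1}
pooling = get_aggregation(agg_cfg, feature_size=768)

# Mean and max pooling need no extra hyper-parameters.
mean_pool = get_aggregation({'type': 'mean'}, feature_size=768)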
Example #3
    def __init__(self,
                 l2renorm,
                 expert_dims,
                 tokenizer,
                 keep_missing_modalities,
                 test_caption_mode,
                 freeze_weights=False,
                 mimic_ce_dims=False,
                 concat_experts=False,
                 concat_mix_experts=False,
                 use_experts='origfeat',
                 txt_inp=None,
                 txt_agg=None,
                 txt_pro=None,
                 txt_wgh=None,
                 vid_inp=None,
                 vid_cont=None,
                 vid_wgh=None,
                 pos_enc=None,
                 out_tok=None,
                 use_mask='nomask',
                 same_dim=512,
                 vid_bert_params=None,
                 txt_bert_params=None,
                 agg_dims=None,
                 normalize_experts=True):
        super().__init__()

        self.sanity_checks = False
        modalities = list(expert_dims.keys())
        self.expert_dims = expert_dims
        self.modalities = modalities
        logger.debug(self.modalities)
        self.mimic_ce_dims = mimic_ce_dims
        self.concat_experts = concat_experts
        self.concat_mix_experts = concat_mix_experts
        self.test_caption_mode = test_caption_mode
        self.freeze_weights = freeze_weights
        self.use_experts = use_experts
        self.use_mask = use_mask
        self.keep_missing_modalities = keep_missing_modalities
        self.l2renorm = l2renorm
        self.same_dim = same_dim
        self.txt_inp = txt_inp
        self.txt_agg = txt_agg
        self.txt_pro = txt_pro
        self.txt_wgh = txt_wgh
        self.vid_inp = vid_inp
        self.vid_cont = vid_cont
        self.vid_wgh = vid_wgh
        self.pos_enc = pos_enc
        self.out_tok = out_tok
        self.vid_bert_params = vid_bert_params
        self.normalize_experts = normalize_experts

        self.video_dim_reduce = nn.ModuleDict()
        for mod in self.modalities:
            in_dim = expert_dims[mod]['dim']
            if self.vid_inp in ['agg', 'both', 'all', 'temp']:
                self.video_dim_reduce[mod] = ReduceDim(in_dim, same_dim)

        if self.vid_cont == 'coll':
            self.g_reason_1 = nn.Linear(same_dim * 2, same_dim)
            dout_prob = vid_bert_params['hidden_dropout_prob']
            self.coll_g_dropout = nn.Dropout(dout_prob)
            self.g_reason_2 = nn.Linear(same_dim, same_dim)

            self.f_reason_1 = nn.Linear(same_dim, same_dim)
            self.coll_f_dropout = nn.Dropout(dout_prob)
            self.f_reason_2 = nn.Linear(same_dim, same_dim)
            self.f_reason_3 = nn.Linear(same_dim, same_dim)
            self.batch_norm_g1 = nn.BatchNorm1d(same_dim)
            self.batch_norm_g2 = nn.BatchNorm1d(same_dim)

            self.batch_norm_f1 = nn.BatchNorm1d(same_dim)
            self.batch_norm_f2 = nn.BatchNorm1d(same_dim)

            self.video_GU = nn.ModuleDict()
            for mod in self.modalities:
                self.video_GU[mod] = GatedEmbeddingUnitReasoning(same_dim)

        # If a BERT architecture is employed for the video stream
        elif self.vid_cont == 'bert':
            vid_bert_config = types.SimpleNamespace(**vid_bert_params)
            self.vid_bert = BertModel(vid_bert_config)

        elif self.vid_cont == 'none':
            pass

        if self.txt_agg[:4] in ['bert']:
            z = re.match(r'bert([a-z]{3})(\d*)(\D*)', txt_agg)
            assert z
            state = z.groups()[0]
            freeze_until = z.groups()[1]

            # Post aggregation: Use [CLS] token ("cls") or aggregate all tokens
            # (mxp, mnp)
            if z.groups()[2] and z.groups()[2] != 'cls':
                self.post_agg = z.groups()[2]
            else:
                self.post_agg = 'cls'

            if state in ['ftn', 'frz']:
                # State is finetune or frozen; use a pretrained BERT model
                txt_bert_config = 'bert-base-cased'

                # Overwrite config
                if txt_bert_params is None:
                    dout_prob = vid_bert_params['hidden_dropout_prob']
                    txt_bert_params = {
                        'hidden_dropout_prob': dout_prob,
                        'attention_probs_dropout_prob': dout_prob,
                    }
                self.txt_bert = TxtBertModel.from_pretrained(
                    txt_bert_config, **txt_bert_params)

                if state == 'frz':
                    if freeze_until:
                        # Freeze only certain layers
                        freeze_until = int(freeze_until)
                        logger.debug(
                            'Freezing text bert until layer %d excluded',
                            freeze_until)
                        # Freeze net until given layer
                        for name, param in self.txt_bert.named_parameters():
                            module = name.split('.')[0]
                            if name.split('.')[2].isdigit():
                                layer_nb = int(name.split('.')[2])
                            else:
                                continue
                            if module == 'encoder' and layer_nb in range(
                                    freeze_until):
                                param.requires_grad = False
                                logger.debug(name)
                    else:
                        # Freeze the whole model
                        for name, param in self.txt_bert.named_parameters():
                            module = name.split('.')[0]
                            if module == 'encoder':
                                param.requires_grad = False
                else:
                    assert not freeze_until

            if self.txt_inp == 'bertfrz':
                # Freeze model
                for param in self.txt_bert.embeddings.parameters():
                    param.requires_grad = False
            elif self.txt_inp not in ['bertftn']:
                logger.error('Unknown txt_inp value for the text encoder: %s',
                             self.txt_inp)
            text_dim = self.txt_bert.config.hidden_size
        elif self.txt_agg in ['vlad', 'mxp', 'mnp', 'lstm']:
            # Need to get text embeddings
            if self.txt_inp == 'bertfrz':
                ckpt = 'data/word_embeddings/bert/ckpt_from_huggingface.pth'
                self.word_embeddings = TxtEmbeddings(ckpt=ckpt, freeze=True)
            elif self.txt_inp == 'bertftn':
                ckpt = 'data/word_embeddings/bert/ckpt_from_huggingface.pth'
                self.word_embeddings = TxtEmbeddings(ckpt=ckpt)
            elif self.txt_inp == 'bertscr':
                vocab_size = 28996
                emb_dim = 768
                self.word_embeddings = TxtEmbeddings(vocab_size, emb_dim)
            else:
                self.word_embeddings = tokenizer.we_model
            emb_dim = self.word_embeddings.text_dim

            if self.txt_agg == 'vlad':
                self.text_pooling = NetVLAD(
                    feature_size=emb_dim,
                    cluster_size=28,
                )
                text_dim = self.text_pooling.out_dim
            elif self.txt_agg in ('mxp', 'mnp'):
                # Max/mean pooling keep the embedding dimension unchanged
                text_dim = emb_dim
            elif self.txt_agg == 'lstm':
                input_dim = self.word_embeddings.text_dim
                hidden_dim = 512
                layer_dim = 1
                output_dim = hidden_dim
                self.text_pooling = LSTMModel(input_dim, hidden_dim, layer_dim,
                                              output_dim)
                text_dim = output_dim

        self.text_GU = nn.ModuleDict()
        for mod in self.modalities:
            if self.txt_pro == 'gbn':
                self.text_GU[mod] = GatedEmbeddingUnit(
                    text_dim,
                    same_dim,
                    use_bn=True,
                    normalize=self.normalize_experts)
            elif self.txt_pro == 'gem':
                self.text_GU[mod] = GatedEmbeddingUnit(
                    text_dim,
                    same_dim,
                    use_bn=False,
                    normalize=self.normalize_experts)
            elif self.txt_pro == 'lin':
                self.text_GU[mod] = ReduceDim(text_dim, same_dim)

        # Weighting of each modality similarity
        if self.txt_wgh == 'emb':
            self.moe_fc_txt = nn.ModuleDict()
            dout_prob = txt_bert_params['hidden_dropout_prob']
            self.moe_txt_dropout = nn.Dropout(dout_prob)
            for mod in self.modalities:
                self.moe_fc_txt[mod] = nn.Linear(text_dim, 1)
        if self.vid_wgh == 'emb':
            self.moe_fc_vid = nn.ModuleDict()
            dout_prob = vid_bert_params['hidden_dropout_prob']
            self.moe_vid_dropout = nn.Dropout(dout_prob)
            for mod in self.modalities:
                self.moe_fc_vid[mod] = nn.Linear(self.same_dim, 1)

        self.debug_dataloader = False
        if self.debug_dataloader:
            self.tokenizer = tokenizer
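
To make the txt_agg parsing above concrete, here is a standalone sketch of how the regex decomposes a few plausible flag values into state, freeze_until and post-aggregation mode (the flag strings are illustrative, not an exhaustive list):

import re

for flag in ('bertftn', 'bertfrz', 'bertfrz6', 'bertfrzmxp'):
    state, freeze_until, post = re.match(
        r'bert([a-z]{3})(\d*)(\D*)', flag).groups()
    post_agg = post if post and post != 'cls' else 'cls'
    print(flag, '->', state, freeze_until or None, post_agg)

# bertftn -> ftn None cls
# bertfrz -> frz None cls
# bertfrz6 -> frz 6 cls
# bertfrzmxp -> frz None mxp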
Example #4
    def __init__(self, modalities, txt_agg, txt_inp, txt_bert_params,
                 vid_bert_params, tokenizer=None):
        """modalities: all modalities used in the video.
           txt_agg: text aggregation method. For BERT: 'bert' + state
               ('ftn': finetune, 'frz': freeze); otherwise a text pooling
               method ('vlad', 'mxp', 'mnp', 'lstm').
           txt_inp: how to build the embeddings from the word, position and
               token-type embeddings.
           tokenizer: tokenizer exposing a word2vec model (we_model); used
               when txt_inp is none of the BERT options."""
        super().__init__()

        self.modalities = modalities
        self.txt_agg = txt_agg
        self.txt_inp = txt_inp
        self.vid_bert_params = vid_bert_params

        if self.txt_agg[:4] in ['bert']:
            z = re.match(r'bert([a-z]{3})(\d*)(\D*)', txt_agg)
            assert z
            state = z.groups()[0]
            freeze_until = z.groups()[1]

            # Post aggregation: Use [CLS] token ("cls") or aggregate all tokens
            # (mxp, mnp)
            if z.groups()[2] and z.groups()[2] != 'cls':
                self.post_agg = z.groups()[2]
            else:
                self.post_agg = 'cls'

            if state in ['ftn', 'frz']:
                # State is finetune or frozen; use a pretrained BERT model
                txt_bert_config = 'bert-base-cased'

                # Overwrite config
                if txt_bert_params is None:
                    dout_prob = vid_bert_params['hidden_dropout_prob']
                    txt_bert_params = {
                        'hidden_dropout_prob': dout_prob,
                        'attention_probs_dropout_prob': dout_prob,
                    }

                self.txt_bert = TxtBertModel.from_pretrained(
                    txt_bert_config,
                    cache_dir='/youtu_pedestrian_detection/wenzhewang/mmt_data/cache_dir',
                    **txt_bert_params)

                if state == 'frz':
                    if freeze_until:
                        # Freeze only certain layers
                        freeze_until = int(freeze_until)
                        logger.debug(
                            'Freezing text bert until layer %d excluded',
                            freeze_until)
                        # Freeze net until given layer
                        for name, param in self.txt_bert.named_parameters():
                            module = name.split('.')[0]
                            if name.split('.')[2].isdigit():
                                layer_nb = int(name.split('.')[2])
                            else:
                                continue
                            if module == 'encoder' and layer_nb in range(
                                    freeze_until):
                                param.requires_grad = False
                                logger.debug(name)
                    else:
                        # Freeze the whole model
                        for name, param in self.txt_bert.named_parameters():
                            module = name.split('.')[0]
                            if module == 'encoder':
                                param.requires_grad = False
                else:
                    assert not freeze_until

            if self.txt_inp == 'bertfrz':
                # Freeze model
                for param in self.txt_bert.embeddings.parameters():
                    param.requires_grad = False
            elif self.txt_inp not in ['bertftn']:
                logger.error('Unknown txt_inp value for the text encoder: %s',
                             self.txt_inp)
            self.text_dim = self.txt_bert.config.hidden_size

        elif self.txt_agg in ['vlad', 'mxp', 'mnp', 'lstm']:
            # Need to get text embeddings
            if self.txt_inp == 'bertfrz':
                ckpt = '/youtu_pedestrian_detection/wenzhewang/mmt_data/word_embeddings/bert/ckpt_from_huggingface.pth'
                self.word_embeddings = TxtEmbeddings(ckpt=ckpt, freeze=True)
            elif self.txt_inp == 'bertftn':
                ckpt = '/youtu_pedestrian_detection/wenzhewang/mmt_data/word_embeddings/bert/ckpt_from_huggingface.pth'
                self.word_embeddings = TxtEmbeddings(ckpt=ckpt)
            elif self.txt_inp == 'bertscr':
                vocab_size = 28996
                emb_dim = 768
                self.word_embeddings = TxtEmbeddings(
                    vocab_size, emb_dim)  # returns an nn.Embedding
            else:
                # Fall back to the tokenizer's word2vec embeddings
                self.word_embeddings = tokenizer.we_model
            emb_dim = self.word_embeddings.text_dim

            if self.txt_agg == 'vlad':
                self.text_pooling = NetVLAD(
                    feature_size=emb_dim,
                    cluster_size=28,
                )
                self.text_dim = self.text_pooling.out_dim
            elif self.txt_agg in ('mxp', 'mnp'):
                # Max/mean pooling keep the embedding dimension unchanged
                self.text_dim = emb_dim
            elif self.txt_agg == 'lstm':
                input_dim = self.word_embeddings.text_dim
                hidden_dim = 512
                layer_dim = 1
                output_dim = hidden_dim
                self.text_pooling = LSTMModel(input_dim, hidden_dim, layer_dim,
                                              output_dim)
                self.text_dim = output_dim
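
The layer-freezing loop shared by Examples #3 and #4 relies on HuggingFace-style parameter names of the form 'encoder.layer.<n>...'. A minimal sketch of that name parsing, using a few representative names (illustrative, not dumped from a real checkpoint):

freeze_until = 6
names = [
    'embeddings.word_embeddings.weight',
    'encoder.layer.0.attention.self.query.weight',
    'encoder.layer.11.output.dense.bias',
    'pooler.dense.weight',
]
for name in names:
    parts = name.split('.')
    # Freeze only encoder layers whose index is below freeze_until.
    frozen = (parts[0] == 'encoder'
              and parts[2].isdigit()
              and int(parts[2]) < freeze_until)
    print(name, '-> frozen' if frozen else '-> trainable')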
Example #5
    def __init__(self, backbone, expert_dims, use_ce, mimic_ce_dims, concat_experts, concat_mix_experts,
                 vlad_clusters, attr_fusion_name, attr_vocab_size, same_dim=512):
        super().__init__()

        modalities = list(expert_dims.keys())
        self.expert_dims = expert_dims
        self.modalities = modalities
        self.use_ce = use_ce
        self.mimic_ce_dims = mimic_ce_dims
        self.concat_experts = concat_experts
        self.concat_mix_experts = concat_mix_experts
        self.attr_fusion_name = attr_fusion_name
        self.backbone_name = backbone

        in_dims = [expert_dims[mod][0] for mod in modalities]
        agg_dims = [expert_dims[mod][1] for mod in modalities]
        use_bns = [True for modality in self.modalities]

        if self.use_ce or self.mimic_ce_dims:
            dim_reducers = [ReduceDim(in_dim, same_dim) for in_dim in in_dims]
            self.video_dim_reduce = nn.ModuleList(dim_reducers)

        if self.use_ce:
            self.g_reason_1 = nn.Linear(same_dim * 2, same_dim)
            self.g_reason_2 = nn.Linear(same_dim, same_dim)

            self.f_reason_1 = nn.Linear(same_dim, same_dim)
            self.f_reason_2 = nn.Linear(same_dim, same_dim)

            gated_vid_embds = [GatedEmbeddingUnitReasoning(same_dim) for _ in in_dims]

        elif self.mimic_ce_dims:  # ablation study
            gated_vid_embds = [MimicCEGatedEmbeddingUnit(same_dim, same_dim, use_bn=True) for _ in modalities]

        elif self.concat_mix_experts:  # ablation study
            in_dim, out_dim = sum(in_dims), sum(agg_dims)
            gated_vid_embds = [GatedEmbeddingUnit(in_dim, out_dim, use_bn=True)]

        elif self.concat_experts:  # ablation study
            gated_vid_embds = []

        else:
            gated_vid_embds = [GatedEmbeddingUnit(in_dim, dim, use_bn) for
                               in_dim, dim, use_bn in zip(in_dims, agg_dims, use_bns)]

        self.video_GU = nn.ModuleList(gated_vid_embds)

        if backbone == 'resnet':
            resnet = models.resnet152(pretrained=True)
            modules = list(resnet.children())[:-2]
            self.backbone = nn.Sequential(*modules)
        elif backbone == 'densenet':
            densenet = models.densenet169(pretrained=True)
            modules = list(densenet.children())[:-1]
            self.backbone = nn.Sequential(*modules)
        elif backbone in ['inceptionresnetv2', 'pnasnet5large', 'nasnetalarge', 'senet154', 'polynet']:
            self.backbone = pretrainedmodels.__dict__[backbone](num_classes=1000, pretrained='imagenet')
        else:
            raise ValueError(f'Unknown backbone: {backbone}')
        self.dropout = nn.Dropout(p=0.2)
        self.avg_pooling = nn.AdaptiveAvgPool2d((1, 1))

        if 'keypoint' in self.expert_dims.keys():
            self.effnet = EffNet()
            self.keypoint_pooling = NetVLAD(
                feature_size=512,
                cluster_size=vlad_clusters['keypoint'],
            )

        if 'attr0' in self.expert_dims.keys():
            self.attr_embed = nn.Embedding(attr_vocab_size, 300, padding_idx=0)
            attr_pooling_list = [NetVLAD(feature_size=300, cluster_size=vlad_clusters['attr']) for _ in range(6)]
            self.attr_pooling = nn.ModuleList(attr_pooling_list)

            if attr_fusion_name == 'attrmlb':
                self.attr_fusion = AttrMLB()
            else:
                self.attr_fusion = TIRG(attr_fusion_name, embed_dim=same_dim)
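
The forward pass is not shown, but the resnet branch above would plausibly be used as below. The (2048, 7, 7) feature-map shape is what torchvision's resnet152 trunk yields for a 224x224 input once the final pooling and fc children are stripped, as in the constructor; the function name is illustrative:

import torch

def extract_image_features(model, images):
    # images: (batch, 3, 224, 224); model is the module built above.
    feat_map = model.backbone(images)        # (batch, 2048, 7, 7)
    pooled = model.avg_pooling(feat_map)     # (batch, 2048, 1, 1)
    return model.dropout(pooled.flatten(1))  # (batch, 2048)

features = extract_image_features(model, torch.randn(2, 3, 224, 224))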
Example #6
    def __init__(self, text_dim, use_ce, use_film, vlad_clusters, composition, target_comp, fusion, attr_fusion,
                 expert_dims, same_dim, text_feat, norm_scale=5.0, vocab_size=None, we_parameter=None,
                 attr_vocab_size=None, mimic_ce_dims=False, concat_experts=False, concat_mix_experts=False,
                 backbone='resnet'):
        super().__init__()

        self.composition_type = composition
        self.fusion = fusion
        self.expert_dims = expert_dims
        self.text_feat = text_feat
        self.vocab_size = vocab_size

        if text_feat == 'learnable':
            self.embed = nn.Embedding(vocab_size, text_dim, padding_idx=0)
            if we_parameter is not None:
                self.embed.weight.data.copy_(torch.from_numpy(we_parameter))

            text_pooling_list = [
                [mode, TextMultilevelEncoding(word_dim=text_dim, hidden_dim=text_dim)] for mode in ('src', 'trg')]
            self.text_pooling = nn.ModuleDict(text_pooling_list)

            encoder_text_dim = text_dim + vocab_size + 1536
        elif text_feat == 'w2v':
            text_pooling_list = [
                [mode, NetVLAD(feature_size=text_dim, cluster_size=vlad_clusters["text"])] for mode in ('src', 'trg')]
            self.text_pooling = nn.ModuleDict(text_pooling_list)

            encoder_text_dim = self.text_pooling['src'].out_dim
        else:
            raise ValueError(f'Unknown text_feat: {text_feat}')

        text_encoder_list = [
            [mode,
             TextCEModule(expert_dims=expert_dims,
                          text_dim=encoder_text_dim,
                          concat_experts=concat_experts,
                          concat_mix_experts=concat_mix_experts,
                          same_dim=same_dim)] for mode in ('src', 'trg')]

        self.text_encoder = nn.ModuleDict(text_encoder_list)

        self.image_encoder = VideoCEModule(
            backbone=backbone,
            use_ce=use_ce,
            concat_experts=concat_experts,
            concat_mix_experts=concat_mix_experts,
            expert_dims=expert_dims,
            mimic_ce_dims=mimic_ce_dims,
            same_dim=same_dim,
            vlad_clusters=vlad_clusters,
            attr_vocab_size=attr_vocab_size,
            attr_fusion_name=attr_fusion
        )

        if self.composition_type == 'multi':
            composition_list = [TIRG(fusion, embed_dim=same_dim) for _ in self.expert_dims]
            self.composition_layer = nn.ModuleList(composition_list)
        else:
            self.composition_layer = TIRG(fusion, embed_dim=same_dim)

        self.normalization_layer = NormalizationLayer(normalize_scale=norm_scale, learn_scale=True)
        self.trg_normalization_layer = NormalizationLayer(normalize_scale=norm_scale, learn_scale=True)

        if target_comp == 'cnn':
            target_comp_list = [CNNAttention() for _ in self.expert_dims]
            self.target_composition = nn.ModuleList(target_comp_list)
        elif target_comp == 'film':
            target_comp_list = [DualTIRG2Film() for _ in self.expert_dims]
            self.target_composition = nn.ModuleList(target_comp_list)
        elif target_comp == 'tirg':
            target_comp_list = [TIRG(fusion, embed_dim=same_dim) for _ in self.expert_dims]
            self.target_composition = nn.ModuleList(target_comp_list)
        elif target_comp == 'ba':
            target_comp_list = [BilinearAttention(dim=same_dim) for _ in self.expert_dims]
            self.target_composition = nn.ModuleList(target_comp_list)
        elif target_comp == 'cbpa':
            target_comp_list = [CBPAttention(dim=same_dim) for _ in self.expert_dims]
            self.target_composition = nn.ModuleList(target_comp_list)
        else:
            raise ValueError(f'Unknown target_comp: {target_comp}')
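
The forward logic is likewise omitted, but the 'multi' composition path above builds one TIRG layer per expert, which suggests a per-modality query composition. A hypothetical sketch (all names are illustrative; img_feats is assumed to map each modality to a (batch, same_dim) feature, and text_feat to be a pooled (batch, same_dim) text embedding):

def compose_query(model, img_feats, text_feat):
    composed = []
    for layer, mod in zip(model.composition_layer, model.expert_dims):
        fused = layer(img_feats[mod], text_feat)  # TIRG fusion per modality
        composed.append(model.normalization_layer(fused))
    return composed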