Esempio n. 1
0
    def forward(self, support, query, sup_len, que_len):
        """Score query samples against prototypes obtained by dynamic routing.

        Args:
            support: Support-set input; it is flattened to
                ``[n * k, sup_seq_len]`` before being embedded.
            query: Query-set input, embedded as given.
            sup_len: Lengths handed to ``self._embed`` together with ``support``.
            que_len: Lengths handed to ``self._embed`` together with ``query``.

        Returns:
            The output of ``self.NTN`` viewed as ``[qk, n]``.
        """
        if self.DataParallel:
            # NOTE(review): presumably the data-parallel loader adds a leading
            # axis of size 1 to the support inputs -- confirm with the caller.
            support = support.squeeze(0)
            sup_len = sup_len[0]

        n, k, qk, sup_seq_len, _ = extractTaskStructFromInput(support, query)

        # Flatten the task structure so all support samples form one batch.
        support = support.view(n * k, sup_seq_len)

        support, query = self._embed(support, sup_len), \
                         self._embed(query, que_len)

        # Compute the class prototype vectors.
        # shape: [n, k, d]
        support = support.view(n, k, -1)

        # coupling shape: [n, k] (one zero-initialised entry per support
        # sample); allocated directly instead of summing a zero [n, k, d] tensor.
        coupling = support.new_zeros(support.shape[:2])
        proto = None
        # Use dynamic routing to compute the prototype vectors.
        # NOTE(review): proto stays None when self.Iters is 0.
        for _ in range(self.Iters):
            coupling, proto = dynamicRouting(self.Transformer,
                                             support, coupling,
                                             k)

        support = repeatProtoToCompShape(proto, qk, n)
        query = repeatQueryToCompShape(query, qk, n)

        return self.NTN(support, query, n).view(qk, n)
Esempio n. 2
0
    def _feature_forward(self,
                         support_features,
                         query_features,
                         support_labels,
                         query_labels,
                         feature_name='none') -> dict:
        """Run the prototype comparison and loss for one set of features.

        Args:
            support_features: Support features; viewed as ``[n, k, dim]`` with
                ``dim`` taken from axis 1. Must not be ``None``.
            query_features: Query features compared against the prototypes.
            support_labels: Accepted but not used by this method.
            query_labels: Targets passed to ``self.LossFunc``.
            feature_name: Name used only in the assertion message.

        Returns:
            A dict with ``"logits"`` (log-softmax over axis 1 of the
            similarity), ``"loss"`` and ``"predicts"`` (always ``None``).
        """
        assert support_features is not None, f"[MLossProtoNet] {feature_name} is None, " \
                                             f"which is not allowed in multi-loss fusion models"

        task = self.TaskParams
        shots, ways, n_query = task.k, task.n, task.qk
        feat_dim = support_features.size(1)

        # The per-class mean of the support features serves as the prototype.
        class_means = support_features.view(ways, shots, feat_dim).mean(dim=1)

        # Bring both sides into a comparable shape ([qk, n, dim] per the
        # original author's note) and measure their similarity.
        scores = protoDisAdapter(
            repeatProtoToCompShape(class_means, n_query, ways),
            repeatQueryToCompShape(query_features, n_query, ways),
            n_query,
            ways,
            feat_dim,
            dis_type=self.DistType,
            temperature=self.DistTemp)

        log_probs = F.log_softmax(scores, dim=1)
        loss = self.LossFunc(log_probs, query_labels)
        return {"logits": log_probs, "loss": loss, "predicts": None}
Esempio n. 3
0
    def forward(
            self,  # forward accepts every argument that might be needed
            support_seqs,
            support_imgs,
            support_lens,
            support_labels,
            query_seqs,
            query_imgs,
            query_lens,
            query_labels,
            epoch=None,
            return_embeddings=False):
        """Fuse sequence/image embeddings and classify queries by prototype.

        ``support_labels`` and ``epoch`` are accepted but not used here.

        Returns:
            When ``return_embeddings`` is true, the tuple
            ``(support_seqs, query_seqs.view(qk, -1), prototypes, log_probs)``;
            otherwise a dict with ``"logits"``, ``"loss"`` and ``"predicts"``
            (always ``None``).
        """
        seq_sup, seq_que, img_sup, img_que = self.embed(support_seqs, query_seqs,
                                                        support_lens, query_lens,
                                                        support_imgs, query_imgs)

        # Per the original notes: support seqs/imgs are [n*k, dim] and
        # query seqs/imgs are [qk, dim].
        task = self.TaskParams
        shots, ways, n_query = task.k, task.n, task.qk

        # Fuse the raw sequence and image outputs directly.
        fused_sup = self._fuse(seq_sup, img_sup, fuse_dim=1)
        fused_que = self._fuse(seq_que, img_que, fuse_dim=1)
        feat_dim = fused_sup.size(1)

        # Prototype vectors, shape [n, dim]: mean over the k shots of a class.
        class_protos = fused_sup.view(ways, shots, feat_dim).mean(dim=1)

        # Reshape into a comparable form ([qk, n, dim] per the original note).
        scores = protoDisAdapter(
            repeatProtoToCompShape(class_protos, n_query, ways),
            repeatQueryToCompShape(fused_que, n_query, ways),
            n_query,
            ways,
            feat_dim,
            dis_type=self.DistType,
            temperature=self.DistTemp)
        log_probs = F.log_softmax(scores, dim=1)

        if return_embeddings:
            # NOTE(review): this hands back the raw *input* sequences, not the
            # embedded/fused features -- confirm that this is intended.
            return support_seqs, query_seqs.view(n_query, -1), class_protos, log_probs

        return {
            "logits": log_probs,
            "loss": self.LossFunc(log_probs, query_labels),
            "predicts": None
        }
Esempio n. 4
0
    def forward(
            self,  # forward accepts every argument that might be needed
            support_seqs,
            support_imgs,
            support_lens,
            support_labels,
            query_seqs,
            query_imgs,
            query_lens,
            query_labels,
            epoch=None,
            metric='euc',
            return_embeddings=False):
        """Fuse embeddings, derive class vectors via ``self.Induction`` and score.

        ``support_labels``, ``epoch``, ``metric`` and ``return_embeddings`` are
        accepted but not used; the comparison always uses ``dis_type='cos'``.

        Returns:
            A dict with ``'logits'`` (log-softmax over axis 1), ``'loss'`` and
            ``'predict'`` (always ``None``).
        """
        seq_sup, seq_que, img_sup, img_que = self.embed(support_seqs, query_seqs,
                                                        support_lens, query_lens,
                                                        support_imgs, query_imgs)

        task = self.TaskParams
        shots, ways, n_query = task.k, task.n, task.qk

        # Fuse the raw sequence and image outputs directly.
        fused_sup = self._fuse(seq_sup, img_sup, fuse_dim=1)
        fused_que = self._fuse(seq_que, img_que, fuse_dim=1)
        feat_dim = fused_sup.size(1)

        # Group the support features per class as [n, k, dim], add an axis at
        # position 1 and let the induction module produce the class vectors.
        # NOTE(review): the argument-less squeeze() drops every size-1 axis,
        # which would include n or dim if either equals 1 -- confirm.
        grouped = fused_sup.view(ways, shots, feat_dim)
        class_vectors = self.Induction(grouped.unsqueeze(1)).squeeze()

        # Reshape into a comparable form ([qk, n, dim] per the original note).
        scores = protoDisAdapter(
            repeatProtoToCompShape(class_vectors, n_query, ways),
            repeatQueryToCompShape(fused_que, n_query, ways),
            n_query,
            ways,
            feat_dim,
            dis_type='cos')
        log_probs = torch.log_softmax(scores, dim=1)
        loss = self.LossFunc(log_probs, query_labels)

        return {'logits': log_probs, 'loss': loss, 'predict': None}
Esempio n. 5
0
    def forward(self, support, query, sup_len, que_len, metric='euc'):
        """Classify queries with task-conditioned prototype features.

        A first pass over the support set yields a task-level vector; a second
        pass over support and query applies ``self.TEN`` with that vector
        between the encoder and the CNN. ``metric`` is accepted but unused;
        the comparison always uses ``dis_type='euc'``.

        Returns:
            Log-softmax over axis 1 of the similarity.
        """
        n, k, qk, sup_seq_len, que_seq_len = extractTaskStructFromInput(
            support, query)

        # With the task structure extracted, flatten all support samples
        # into a single batch.
        flat_support = support.view(n * k, sup_seq_len)

        # Preliminary pass to obtain the task prototype: the mean over both
        # the class axis and the shot axis of the support features.
        probe = self.Embedding(flat_support)
        probe = self.EmbedNorm(probe)
        probe = self.Encoder(probe, sup_len)
        probe = self.CNN(probe, sup_len)
        task_proto = probe.view(n, k, -1).mean((0, 1))

        # Main pass. Each stage handles support first, then query, in the
        # same order as before so stateful/stochastic layers behave the same.
        sup_x, que_x = self.Embedding(flat_support), self.Embedding(query)
        sup_x, que_x = self.EmbedNorm(sup_x), self.EmbedNorm(que_x)
        sup_x, que_x = self.Encoder(sup_x, sup_len), self.Encoder(que_x, que_len)

        # Task-conditioning affine step.
        sup_x, que_x = self.TEN(sup_x, task_proto), self.TEN(que_x, task_proto)

        sup_x, que_x = self.CNN(sup_x, sup_len), self.CNN(que_x, que_len)

        assert sup_x.size(1)==que_x.size(1), '支持集维度 %d 和查询集维度 %d 必须相同!'%\
                                             (sup_x.size(1),que_x.size(1))
        feat_dim = sup_x.size(1)

        # Prototype vectors, shape [n, dim]: mean over the k shots of a class.
        class_protos = sup_x.view(n, k, feat_dim).mean(dim=1)

        # Reshape into a comparable form ([qk, n, dim] per the original note).
        scores = protoDisAdapter(repeatProtoToCompShape(class_protos, qk, n),
                                 repeatQueryToCompShape(que_x, qk, n),
                                 qk,
                                 n,
                                 feat_dim,
                                 dis_type='euc')

        return F.log_softmax(scores, dim=1)
Esempio n. 6
0
    def forward(
            self,  # forward accepts every argument that might be needed
            support_seqs,
            support_imgs,
            support_lens,
            support_labels,
            query_seqs,
            query_imgs,
            query_lens,
            query_labels,
            epoch=None,
            metric='euc',
            return_embeddings=False):
        """Fuse embeddings, build prototypes by dynamic routing, score with NTN.

        ``support_labels``, ``epoch``, ``metric`` and ``return_embeddings`` are
        accepted but not used by this method.

        Returns:
            A dict with ``"logits"`` (the ``self.NTN`` output viewed as
            ``[qk, n]``), ``"loss"`` and ``"predict"`` (always ``None``).
        """
        embedded_support_seqs, embedded_query_seqs, \
        embedded_support_imgs, embedded_query_imgs = self.embed(support_seqs, query_seqs,
                                                                support_lens, query_lens,
                                                                support_imgs, query_imgs)

        k, n, qk = self.TaskParams.k, self.TaskParams.n, self.TaskParams.qk

        # Fuse the raw sequence and image outputs directly.
        support_fused_features = self._fuse(embedded_support_seqs,
                                            embedded_support_imgs,
                                            fuse_dim=1)
        query_fused_features = self._fuse(embedded_query_seqs,
                                          embedded_query_imgs,
                                          fuse_dim=1)
        dim = support_fused_features.size(1)

        # shape: [n, k, dim]
        support_fused_features = support_fused_features.view(n, k, dim)

        # coupling shape: [n, k] (one zero-initialised entry per support
        # sample); allocated directly instead of summing a zero [n, k, dim]
        # tensor.
        coupling = support_fused_features.new_zeros(
            support_fused_features.shape[:2])
        proto = None
        # Use dynamic routing to compute the prototype vectors.
        # NOTE(review): proto stays None when self.DynamicRoutingIter is 0.
        for _ in range(self.DynamicRoutingIter):
            coupling, proto = dynamicRouting(self.Transformer,
                                             support_fused_features, coupling,
                                             k)

        support_fused_features = repeatProtoToCompShape(proto, qk, n)
        query_fused_features = repeatQueryToCompShape(query_fused_features, qk,
                                                      n)

        logits = self.NTN(support_fused_features,
                          query_fused_features).view(qk, n)
        return {
            "logits": logits,
            "loss": self.LossFunc(logits, query_labels),
            "predict": None
        }
Esempio n. 7
0
    def forward(self,
                support,
                query,
                sup_len,
                que_len,
                metric='euc',
                return_embeddings=False,
                **kwargs):
        """Classify queries against class-mean prototypes.

        ``metric`` and ``**kwargs`` are accepted but not used; the comparison
        always uses ``dis_type='euc'`` with ``self.DistTemp`` as temperature.

        Returns:
            Log-softmax over axis 1 of the similarity. When
            ``return_embeddings`` is true, the tuple
            ``(support_emb, query_emb.view(qk, -1), prototypes, log_probs)``.
        """
        if self.DataParallel:
            # NOTE(review): presumably the data-parallel loader adds a leading
            # axis of size 1 to the support inputs -- confirm with the caller.
            support = support.squeeze(0)
            sup_len = sup_len[0]

        n, k, qk, sup_seq_len, que_seq_len = extractTaskStructFromInput(
            support, query)

        # With the task structure extracted, flatten all support samples
        # into a single batch before embedding.
        sup_emb = self._embed(support.view(n * k, sup_seq_len), sup_len)
        que_emb = self._embed(query, que_len)

        assert sup_emb.size(1)==que_emb.size(1), '支持集维度 %d 和查询集维度 %d 必须相同!'%\
                                                 (sup_emb.size(1),que_emb.size(1))

        feat_dim = sup_emb.size(1)

        # Prototype vectors, shape [n, dim]: mean over the k shots of a class.
        class_protos = sup_emb.view(n, k, feat_dim).mean(dim=1)

        # Reshape into a comparable form ([qk, n, dim] per the original note).
        scores = protoDisAdapter(repeatProtoToCompShape(class_protos, qk, n),
                                 repeatQueryToCompShape(que_emb, qk, n),
                                 qk,
                                 n,
                                 feat_dim,
                                 dis_type='euc',
                                 temperature=self.DistTemp)
        log_probs = F.log_softmax(scores, dim=1)

        if not return_embeddings:
            return log_probs
        return sup_emb, que_emb.view(qk, -1), class_protos, log_probs
Esempio n. 8
0
    def forward(self, support, query, sup_len, que_len, metric='euc'):
        """Classify queries against softmax-weighted class prototypes.

        ``self.SetFunc`` scores every support sample; the scores are
        normalised with a softmax over the ``k`` samples of each class and
        used as weights for the per-class sum. ``metric`` is accepted but
        unused; the comparison always uses ``dis_type='euc'``.

        Returns:
            Log-softmax over axis 1 of the similarity.
        """
        n, k, qk, sup_seq_len, que_seq_len = extractTaskStructFromInput(
            support, query)

        # With the task structure extracted, flatten all support samples
        # into a single batch.
        support = support.view(n * k, sup_seq_len)

        # shape: [batch, seq, dim]
        support = self.Embedding(support)
        query = self.Embedding(query)

        support = self.EmbedNorm(support)
        query = self.EmbedNorm(query)

        # shape: [batch, dim]
        support = self.Encoder(support, sup_len)
        query = self.Encoder(query, que_len)

        support = self.Decoder(support, sup_len)
        query = self.Decoder(query, que_len)

        assert support.size(1)==query.size(1), '支持集维度 %d 和查询集维度 %d 必须相同!'%\
                                               (support.size(1),query.size(1))

        dim = support.size(1)

        # Per-sample weights, shape [n, k], normalised within each class.
        support_weight = t.softmax(self.SetFunc(support.view(1, n * k,
                                                             dim)).view(n, k),
                                   dim=1)

        # shape: [n, k, dim] -> [n, dim]
        # The trailing singleton axis broadcasts over dim, so the weights do
        # not need to be materialised as a full [n, k, dim] copy via repeat().
        support = support.view(n, k, dim)
        support = (support * support_weight.unsqueeze(-1)).sum(dim=1)

        support = repeatProtoToCompShape(support, qk, n)
        query = repeatQueryToCompShape(query, qk, n)

        similarity = protoDisAdapter(support,
                                     query,
                                     qk,
                                     n,
                                     dim,
                                     dis_type='euc',
                                     temperature=self.DisTempr)

        return F.log_softmax(similarity, dim=1)
Esempio n. 9
0
    def forward(self, support, query, sup_len, que_len, metric='euc'):
        """Classify queries against class vectors produced by ``self.Induction``.

        ``metric`` is accepted but unused; the comparison always uses
        ``dis_type='cos'``.

        Returns:
            Log-softmax over axis 1 of the similarity.
        """
        if self.DataParallel:
            # NOTE(review): presumably the data-parallel loader adds a leading
            # axis of size 1 to the support inputs -- confirm with the caller.
            support = support.squeeze(0)
            sup_len = sup_len[0]

        n, k, qk, sup_seq_len, que_seq_len = extractTaskStructFromInput(support, query)

        # With the task structure extracted, flatten all support samples
        # into a single batch. An earlier inline Embedding -> Encoder ->
        # Decoder pipeline was commented out here in favour of self._embed.
        sup_emb = self._embed(support.view(n*k, sup_seq_len), sup_len)
        que_emb = self._embed(query, que_len)

        assert sup_emb.size(1)==que_emb.size(1), '支持集维度 %d 和查询集维度 %d 必须相同!'%\
                                                 (sup_emb.size(1),que_emb.size(1))

        feat_dim = sup_emb.size(1)

        # Group the support features per class as [n, k, dim], add an axis at
        # position 1 and let the induction module produce the class vectors.
        # NOTE(review): the argument-less squeeze() drops every size-1 axis,
        # which would include n or dim if either equals 1 -- confirm.
        grouped = sup_emb.view(n, k, feat_dim)
        class_vectors = self.Induction(grouped.unsqueeze(1)).squeeze()

        # Reshape into a comparable form ([qk, n, dim] per the original note).
        scores = protoDisAdapter(repeatProtoToCompShape(class_vectors, qk, n),
                                 repeatQueryToCompShape(que_emb, qk, n),
                                 qk, n, feat_dim, dis_type='cos')

        return F.log_softmax(scores, dim=1)
Esempio n. 10
0
    def forward(self, support, query, sup_len=None, que_len=None):
        """Classify matrix-sequence queries against class-mean prototypes.

        Input shapes per the original notes:
        ``support=[n, k, sup_seq_len, height, width]`` and
        ``query=[qk, que_seq_len, height, width]``. ``sup_len`` and
        ``que_len`` are accepted but not used.

        Returns:
            Log-softmax over axis 1 of the similarity (``dis_type='cos'``).
        """
        n, k, qk, sup_seq_len, que_seq_len = extractTaskStructFromInput(
            support, query, unsqueezed=False, is_matrix=True)
        height = query.size(2)
        width = query.size(3)

        flat_support = support.view(n * k, sup_seq_len, height, width)

        # Embedding output is [batch, seq_len, feature] per the original note.
        sup_x, que_x = self.Embedding(flat_support), self.Embedding(query)
        sup_x, que_x = self.LstmEncoder(sup_x), self.LstmEncoder(que_x)

        # TODO: the whole sequence is flattened and used directly as features.
        class_protos = sup_x.view(n, k, -1).mean(dim=1)
        que_feats = que_x.view(qk, -1)

        assert class_protos.size(1)==que_feats.size(1), \
            '支持集和查询集的嵌入后特征维度不相同!'

        feat_dim = class_protos.size(1)

        # Reshape into a comparable form ([qk, n, dim] per the original note).
        scores = protoDisAdapter(repeatProtoToCompShape(class_protos, qk, n),
                                 repeatQueryToCompShape(que_feats, qk, n),
                                 qk,
                                 n,
                                 feat_dim,
                                 dis_type='cos')

        return F.log_softmax(scores, dim=1)
Esempio n. 11
0
    def forward(self,
                support,
                query,
                sup_len,
                que_len,
                metric='euc',
                return_unadapted=False):
        """Classify queries against prototypes adapted by ``self.SetFunc``.

        ``metric`` is accepted but unused; every comparison uses
        ``dis_type='euc'``.

        Returns:
            * training with ``self.ContraFac`` set:
              ``(log_probs, adapted_log_probs)``, where the second item
              scores the whole support+query union against adapted prototypes;
            * otherwise, with ``return_unadapted``:
              ``(log_probs, unadapted_log_probs)``;
            * otherwise just ``log_probs``.
        """
        if self.DataParallel:
            # NOTE(review): presumably the data-parallel loader adds a leading
            # axis of size 1 to the support inputs -- confirm with the caller.
            support = support.squeeze(0)
            sup_len = sup_len[0]

        n, k, qk, sup_seq_len, que_seq_len = extractTaskStructFromInput(
            support, query)
        qk_per_class = qk // n

        # With the task structure extracted, flatten all support samples
        # into a single batch before embedding.
        support = self._embed(support.view(n * k, sup_seq_len), sup_len)
        query = self._embed(query, que_len)

        assert support.size(1)==query.size(1), '支持集维度 %d 和查询集维度 %d 必须相同!'%\
                                               (support.size(1),query.size(1))
        dim = support.size(1)

        use_contrastive = self.training and self.ContraFac is not None

        # Contrastive branch, used for regularisation during training.
        if use_contrastive:
            union_size = (qk_per_class + k) * n

            # union shape: [n, qk_per_class + k, dim]. This assumes the query
            # set is laid out grouped by class.
            # TODO: make it capable of processing in batch.
            union = t.cat(
                (support.view(n, k, dim), query.view(n, qk_per_class, dim)),
                dim=1)
            adapted_union = self.SetFunc(union)

            # Post-averaging: adapt first, then take the class mean.
            adapted_proto = adapted_union.mean(dim=1)

            # Flatten the union to [(qk_per_class + k) * n, dim] so every
            # sample is classified against the adapted prototypes.
            adapted_union = adapted_union.view(union_size, dim)
            adapted_sim = protoDisAdapter(
                repeatProtoToCompShape(adapted_proto, union_size, n),
                repeatQueryToCompShape(adapted_union, union_size, n),
                union_size,
                n,
                dim,
                dis_type='euc')

            # The matching targets cover support and query samples alike,
            # ordered by class in runs of 'qk_per_class + k'.
            adapted_res = F.log_softmax(adapted_sim, dim=1)

        if return_unadapted:
            # Plain class means taken before any SetFunc adaptation.
            unada_support = repeatProtoToCompShape(
                support.view(n, k, -1).mean(1), qk, n)

        avg_mode = self.Avg
        if avg_mode == 'post':
            # Adapt all n*k support samples as one set, then average per
            # class down to [n, dim].
            support = self.SetFunc(support.view(1, n * k, dim))
            support = support.view(n, k, dim).mean(dim=1)
        elif avg_mode == 'pre':
            # Average per class to [n, dim] first, then adapt the prototypes
            # as one set.
            support = self.SetFunc(
                support.view(n, k, dim).mean(dim=1).unsqueeze(0))

        support = repeatProtoToCompShape(support, qk, n)
        query = repeatQueryToCompShape(query, qk, n)

        similarity = protoDisAdapter(support,
                                     query,
                                     qk,
                                     n,
                                     dim,
                                     dis_type='euc',
                                     temperature=self.DisTempr)
        log_probs = F.log_softmax(similarity, dim=1)

        if use_contrastive:
            return log_probs, adapted_res

        if not return_unadapted:
            return log_probs

        unada_sim = protoDisAdapter(unada_support,
                                    query,
                                    qk,
                                    n,
                                    dim,
                                    dis_type='euc',
                                    temperature=self.DisTempr)
        return log_probs, F.log_softmax(unada_sim, dim=1)
Esempio n. 12
0
    def forward(
            self,  # forward accepts every argument that might be needed
            support_seqs,
            support_imgs,
            support_lens,
            support_labels,
            query_seqs,
            query_imgs,
            query_lens,
            query_labels,
            epoch=None,
            metric='euc',
            return_embeddings=False):
        """Fuse embeddings, adapt prototypes with ``self.SetFunc`` and score.

        While training with ``self.ContraFac`` set, an extra contrastive loss
        over the support+query union is added, scaled by ``self.ContraFac``.
        ``support_labels``, ``epoch``, ``metric`` and ``return_embeddings`` are
        accepted but not used; every comparison uses ``dis_type='euc'``.

        Returns:
            A dict with ``'logits'`` (log-softmax over axis 1), ``'loss'`` and
            ``'predict'`` (always ``None``).
        """
        embedded_support_seqs, embedded_query_seqs, \
        embedded_support_imgs, embedded_query_imgs = self.embed(support_seqs, query_seqs,
                                                                support_lens, query_lens,
                                                                support_imgs, query_imgs)

        k, n, qk = self.TaskParams.k, self.TaskParams.n, self.TaskParams.qk
        qk_per_class = qk // n

        # Fuse the raw sequence and image outputs directly.
        support_fused_features = self._fuse(embedded_support_seqs,
                                            embedded_support_imgs,
                                            fuse_dim=1)
        query_fused_features = self._fuse(embedded_query_seqs,
                                          embedded_query_imgs,
                                          fuse_dim=1)
        dim = support_fused_features.size(1)

        use_contrastive = self.training and self.ContraFac is not None

        # Contrastive branch, used for regularisation during training.
        if use_contrastive:
            union_size = (qk_per_class + k) * n

            # union shape: [n, qk_per_class + k, dim]. This assumes the query
            # set is laid out grouped by class.
            # TODO: make it capable of processing in batch.
            union = t.cat((support_fused_features.view(
                n, k, dim), query_fused_features.view(n, qk_per_class, dim)),
                          dim=1)

            adapted_union = self.SetFunc(union)

            # Post-averaging: adapt first, then take the class mean.
            adapted_proto = adapted_union.mean(dim=1)

            # Flatten the union to [(qk_per_class + k) * n, dim] so every
            # sample is classified against the adapted prototypes.
            adapted_union = adapted_union.view(union_size, dim)

            adapted_proto = repeatProtoToCompShape(adapted_proto, union_size, n)
            adapted_union = repeatQueryToCompShape(adapted_union, union_size, n)

            adapted_sim = protoDisAdapter(adapted_proto,
                                          adapted_union,
                                          union_size,
                                          n,
                                          dim,
                                          dis_type='euc')

            # The matching targets cover support and query samples alike,
            # ordered by class in runs of 'qk_per_class + k'.
            adapted_logits = F.log_softmax(adapted_sim, dim=1)

        if self.Avg == 'post':
            # Adapt all n*k support samples as one set, then average per
            # class down to [n, dim].
            support_fused_features = self.SetFunc(
                support_fused_features.view(1, n * k, dim))
            support_fused_features = support_fused_features.view(
                n, k, dim).mean(dim=1)

        elif self.Avg == 'pre':
            # Average per class to [n, dim] first, then adapt the prototypes
            # as one set.
            support_fused_features = support_fused_features.view(
                n, k, dim).mean(dim=1)
            support_fused_features = self.SetFunc(
                support_fused_features.unsqueeze(0))

        support_fused_features = repeatProtoToCompShape(
            support_fused_features, qk, n)
        query_fused_features = repeatQueryToCompShape(query_fused_features, qk,
                                                      n)

        similarity = protoDisAdapter(support_fused_features,
                                     query_fused_features,
                                     qk,
                                     n,
                                     dim,
                                     dis_type='euc',
                                     temperature=self.DisTemp)

        logits = F.log_softmax(similarity, dim=1)
        loss = self.LossFunc(logits, query_labels)

        # Add a contrastive term on top of the original loss to help training.
        if use_contrastive:
            # Assumes no shuffling: labels simply run 0..n-1, each repeated
            # for the 'qk_per_class + k' samples of its class. The labels are
            # created on the logits' device so this also works on CPU and on
            # non-default GPUs (previously a hard-coded .cuda()).
            adapted_labels = t.arange(0, n, dtype=t.long,
                                      device=adapted_logits.device)
            adapted_labels = adapted_labels.unsqueeze(1).expand(
                (n, (qk_per_class + k))).flatten()
            contrastive_loss = self.LossFunc(adapted_logits, adapted_labels)
            # Out-of-place add: avoids mutating the tensor returned by
            # LossFunc, which autograd may still reference.
            loss = loss + self.ContraFac * contrastive_loss

        return {'logits': logits, 'loss': loss, 'predict': None}