Example 1
    def forward(self, support, query, sup_len, que_len, support_labels):

        if self.DataParallel:
            support = support.squeeze(0)
            sup_len = sup_len[0]
            support_labels = support_labels[0]

        n, k, qk, sup_seq_len, que_seq_len = extractTaskStructFromInput(
            support, query)

        # after extracting the task structure, flatten all samples into one batch
        support = support.view(n * k, sup_seq_len)

        s_predict = self.Learner(support, sup_len)
        loss = self.LossFn(s_predict, support_labels)
        self.Learner.zero_grad()  # clear the base learner's gradients first
        grads = t.autograd.grad(loss,
                                self.Learner.parameters(),
                                create_graph=True)
        adapted_state_dict = self.Learner.clone_state_dict()

        # compute the adapted parameters
        for (key, val), grad in zip(self.Learner.named_parameters(), grads):
            # use the current parameters and each parameter's alpha coefficient to compute the adapted value
            adapted_state_dict[key] = val - self.alpha(key) * grad

        # use the adapted parameters to produce predictions on the query set
        return self.Learner(query, que_len,
                            params=adapted_state_dict).contiguous()
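
For context, a minimal sketch of what extractTaskStructFromInput might look like, inferred purely from the call sites in these examples (support shaped [n, k, sup_seq_len], query shaped [qk, que_seq_len], with trailing [height, width] dims in the matrix case of Example 10); the real helper may differ.

def extractTaskStructFromInput(support, query, unsqueezed=False, is_matrix=False):
    # support: [n, k, sup_seq_len, ...], query: [qk, que_seq_len, ...];
    # unsqueezed/is_matrix are kept only for signature compatibility in this sketch
    n, k, sup_seq_len = support.size(0), support.size(1), support.size(2)
    qk, que_seq_len = query.size(0), query.size(1)
    return n, k, qk, sup_seq_len, que_seq_len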
Example 2
    def forward(self, support, query, sup_len, que_len):

        if self.DataParallel:
            support = support.squeeze(0)
            sup_len = sup_len[0]

        n, k, qk, sup_seq_len, que_seq_len = extractTaskStructFromInput(support, query)

        support = support.view(n * k, sup_seq_len)

        support, query = self._embed(support, sup_len), \
                         self._embed(query, que_len)
        # compute the class prototype vectors
        # shape: [n, k, d]
        support = support.view(n, k, -1)
        d = support.size(2)

        # coupling shape: [n, k]
        coupling = t.zeros_like(support).sum(dim=2)
        proto = None
        # use dynamic routing to compute the prototype vectors
        for i in range(self.Iters):
            coupling, proto = dynamicRouting(self.Transformer,
                                             support, coupling,
                                             k)

        support = repeatProtoToCompShape(proto, qk, n)
        query = repeatQueryToCompShape(query, qk, n)

        return self.NTN(support, query, n).view(qk, n)
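
A hedged sketch of a single dynamicRouting step in the spirit of capsule-style routing, assuming support is [n, k, d] and coupling [n, k]; treating the Transformer argument as a per-instance map and using a squash-style normalization are both assumptions, not the confirmed implementation.

import torch.nn.functional as F

def dynamicRouting(transformer, support, coupling, k):
    # softmax the coupling logits over the k shots of each class
    c = F.softmax(coupling, dim=1).unsqueeze(-1)             # [n, k, 1]
    mapped = transformer(support)                            # [n, k, d]
    proto = (c * mapped).sum(dim=1)                          # [n, d]
    # squash-style normalization (assumed)
    norm = proto.norm(dim=1, keepdim=True)
    proto = proto * norm / (1 + norm ** 2)
    # raise coupling where a shot agrees with its class prototype
    coupling = coupling + (mapped * proto.unsqueeze(1)).sum(dim=-1)
    return coupling, proto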
Example 3
    def forward(self, support, query, sup_len, que_len, metric='euc'):

        if self.DataParallel:
            support = support.squeeze(0)
            sup_len = sup_len[0]

        n, k, qk, sup_seq_len, que_seq_len = extractTaskStructFromInput(
            support, query)

        # after extracting the task structure, flatten all samples into one batch
        support = support.view(n * k, sup_seq_len)

        support = self._embed(support, sup_len)
        query = self._embed(query, que_len)

        assert support.size(1) == query.size(1), \
            'support set dim %d and query set dim %d must match!' % \
            (support.size(1), query.size(1))

        dim = support.size(1)

        # reshape into a comparable shape: [qk, n*k, dim]
        support = support.repeat((qk, 1, 1)).view(qk, n * k, -1)
        query = query.repeat(n * k, 1,
                             1).transpose(0,
                                          1).contiguous().view(qk, n * k, -1)

        # directly compare with support samples, instead of prototypes
        # shape: [qk, n*k, dim]->[qk, n, k, dim] -> [qk, n]
        similarity = ((support - query)**2).neg().view(qk, n, k, -1).sum(
            (-1, -2))

        return F.log_softmax(similarity, dim=1)
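
The negated-distance similarity above can be cross-checked with an equivalent torch.cdist formulation on the flat embeddings; the sizes below are hypothetical.

import torch as t

n, k, qk, dim = 5, 5, 75, 64            # hypothetical task sizes
support = t.randn(n * k, dim)           # flat support embeddings
query = t.randn(qk, dim)                # flat query embeddings

sq_dist = t.cdist(query, support).pow(2)              # [qk, n*k]
similarity = sq_dist.view(qk, n, k).sum(dim=2).neg()  # [qk, n], same math as above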
Example 4
    def forward(self, support, query, sup_len, que_len, metric='euc'):
        n, k, qk, sup_seq_len, que_seq_len = extractTaskStructFromInput(
            support, query)

        # preliminary forward pass to obtain the task prototype
        f_support = support.view(n * k, sup_seq_len)
        f_support = self.Embedding(f_support)
        f_support = self.EmbedNorm(f_support)
        f_support = self.Encoder(f_support, sup_len)
        f_support = self.CNN(f_support, sup_len)

        f_support = f_support.view(n, k, -1)

        task_proto = f_support.mean((0, 1))

        # after extracting the task structure, flatten all samples into one batch
        support = support.view(n * k, sup_seq_len)

        # shape: [batch, seq, dim]
        support = self.Embedding(support)
        query = self.Embedding(query)

        support = self.EmbedNorm(support)
        query = self.EmbedNorm(query)

        # shape: [batch, dim]
        support = self.Encoder(support, sup_len)
        query = self.Encoder(query, que_len)

        # task-conditioning affine
        support = self.TEN(support, task_proto)
        query = self.TEN(query, task_proto)

        support = self.CNN(support, sup_len)
        query = self.CNN(query, que_len)

        assert support.size(1) == query.size(1), \
            'support set dim %d and query set dim %d must match!' % \
            (support.size(1), query.size(1))
        dim = support.size(1)

        # prototype vectors
        # shape: [n, dim]
        support = support.view(n, k, dim).mean(dim=1)

        # reshape into a comparable shape: [qk, n, dim]
        support = repeatProtoToCompShape(support, qk, n)
        query = repeatQueryToCompShape(query, qk, n)

        similarity = protoDisAdapter(support,
                                     query,
                                     qk,
                                     n,
                                     dim,
                                     dis_type='euc')

        # return t.softmax(similarity, dim=1)
        return F.log_softmax(similarity, dim=1)
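
A minimal sketch of what the task-conditioning affine step (TEN) could be: a FiLM-style scale and shift predicted from the task prototype. Only the call signature TEN(x, task_proto) is visible above; the module name, layers, and identity-centered scaling are assumptions.

import torch.nn as nn

class TaskConditionedAffine(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = nn.Linear(dim, dim)
        self.shift = nn.Linear(dim, dim)

    def forward(self, x, task_proto):
        # x: [batch, dim]; task_proto: [dim], broadcast over the batch
        gamma = 1 + self.scale(task_proto)   # scaling centered at identity
        beta = self.shift(task_proto)
        return gamma * x + beta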
Example 5
    def forward(self,
                support,
                query,
                sup_len,
                que_len,
                support_labels,
                adapt_iter=1):
        method = self.Method

        if self.DataParallel:
            support = support.squeeze(0)
            sup_len = sup_len[0]
            support_labels = support_labels[0]

        n, k, qk, sup_seq_len, que_seq_len = extractTaskStructFromInput(
            support, query)

        # after extracting the task structure, flatten all samples into one batch
        support = support.view(n * k, sup_seq_len)

        # ---------------------------------------------------------------
        # fix for a bug where 'adapted_par' was reset in every adapt iteration
        # ---------------------------------------------------------------
        adapted_state_dict = self.Learner.clone_state_dict()
        for name, p in adapted_state_dict.items():
            p.requires_grad_(True)

        for a_i in range(self.AdaptIter):
            # ---------------------------------------------------------------
            # fix for a bug where the original parameters, rather than the
            # adapted ones, were used in every adapt iteration
            # ---------------------------------------------------------------
            adapted_pars = collectParamsFromStateDict(adapted_state_dict)

            s_predict = self.Learner(support,
                                     sup_len,
                                     params=adapted_state_dict)
            loss = self.LossFn(s_predict, support_labels)
            grads = t.autograd.grad(loss, adapted_pars, create_graph=True)

            # compute the adapted parameters
            for (key, val), grad in zip(adapted_state_dict.items(), grads):
                adapted_state_dict[key] = val - self.MetaLr * grad

        if method == 'fomaml':
            adapted_par = collectParamsFromStateDict(adapted_state_dict)

            # for first-order MAML, the query loss and the adapted parameters are needed to compute gradients
            return self.Learner(query, que_len,
                                params=adapted_state_dict), adapted_par

        # for vanilla MAML, only the query loss is needed
        else:
            return self.Learner(query, que_len, params=adapted_state_dict)
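
A plausible one-liner for collectParamsFromStateDict, inferred from the fact that its result is zipped against grads in the same order as adapted_state_dict.items(); treat this as an assumed contract.

def collectParamsFromStateDict(state_dict):
    # preserve dict order so zip(state_dict.items(), grads) lines up
    return list(state_dict.values())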
Example 6
    def forward(self,
                support,
                query,
                sup_len,
                que_len,
                metric='euc',
                return_embeddings=False,
                **kwargs):

        if self.DataParallel:
            support = support.squeeze(0)
            sup_len = sup_len[0]

        n, k, qk, sup_seq_len, que_seq_len = extractTaskStructFromInput(
            support, query)

        # after extracting the task structure, flatten all samples into one batch
        support = support.view(n * k, sup_seq_len)

        support, query = self._embed(support, sup_len), \
                         self._embed(query, que_len)


        assert support.size(1) == query.size(1), \
            'support set dim %d and query set dim %d must match!' % \
            (support.size(1), query.size(1))

        dim = support.size(1)

        # prototype vectors
        # shape: [n, dim]
        orig_protos = support.view(n, k, dim).mean(dim=1)

        # reshape into a comparable shape: [qk, n, dim]
        protos = repeatProtoToCompShape(orig_protos, qk, n)
        rep_query = repeatQueryToCompShape(query, qk, n)

        similarity = protoDisAdapter(protos,
                                     rep_query,
                                     qk,
                                     n,
                                     dim,
                                     dis_type='euc',
                                     temperature=self.DistTemp)

        # return t.softmax(similarity, dim=1)
        if return_embeddings:
            return support, query.view(qk, -1), orig_protos, F.log_softmax(
                similarity, dim=1)
        return F.log_softmax(similarity, dim=1)
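
A hedged sketch of protoDisAdapter covering the two dis_type values these examples use, assuming both inputs are already repeated to [qk, n, dim], that larger values mean more similar (negated squared distance for 'euc'), and that temperature divides the result.

import torch as t

def protoDisAdapter(support, query, qk, n, dim, dis_type='euc', temperature=1):
    if dis_type == 'euc':
        sim = ((support - query) ** 2).sum(dim=2).neg()   # negated squared distance
    elif dis_type == 'cos':
        sim = t.cosine_similarity(support, query, dim=2)
    return sim / temperature                              # [qk, n]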
Example 7
    def forward(self, support, query, sup_len, que_len, metric='euc'):
        n, k, qk, sup_seq_len, que_seq_len = extractTaskStructFromInput(
            support, query)

        # after extracting the task structure, flatten all samples into one batch
        support = support.view(n * k, sup_seq_len)

        # shape: [batch, seq, dim]
        support = self.Embedding(support)
        query = self.Embedding(query)

        support = self.EmbedNorm(support)
        query = self.EmbedNorm(query)

        # shape: [batch, dim]
        support = self.Encoder(support, sup_len)
        query = self.Encoder(query, que_len)

        support = self.Decoder(support, sup_len)
        query = self.Decoder(query, que_len)

        assert support.size(1) == query.size(1), \
            'support set dim %d and query set dim %d must match!' % \
            (support.size(1), query.size(1))

        dim = support.size(1)

        # support set2set
        support_weight = t.softmax(self.SetFunc(support.view(1, n * k,
                                                             dim)).view(n, k),
                                   dim=1)
        support_weight = support_weight.unsqueeze(-1).repeat(1, 1, dim)

        # shape: [n, k, dim] -> [n, dim]
        support = support.view(n, k, dim)
        support = (support * support_weight).sum(dim=1)

        support = repeatProtoToCompShape(support, qk, n)
        query = repeatQueryToCompShape(query, qk, n)

        similarity = protoDisAdapter(support,
                                     query,
                                     qk,
                                     n,
                                     dim,
                                     dis_type='euc',
                                     temperature=self.DisTempr)

        return F.log_softmax(similarity, dim=1)
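
The two repeat helpers used throughout can be sketched as follows, inferred from the recurring "reshape into a comparable shape: [qk, n, dim]" comments; a real implementation might use expand instead of repeat.

def repeatProtoToCompShape(proto, qk, n):
    # proto: [n, dim] -> [qk, n, dim]
    return proto.unsqueeze(0).repeat(qk, 1, 1)

def repeatQueryToCompShape(query, qk, n):
    # query: [qk, dim] -> [qk, n, dim]
    return query.unsqueeze(1).repeat(1, n, 1)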
Example 8
    def forward(self, support, query, sup_len, que_len, metric='euc'):

        if self.DataParallel:
            support = support.squeeze(0)
            sup_len = sup_len[0]

        n, k, qk, sup_seq_len, que_seq_len = extractTaskStructFromInput(support, query)

        # after extracting the task structure, flatten all samples into one batch
        support = support.view(n*k, sup_seq_len)
        support = self._embed(support, sup_len)
        query = self._embed(query, que_len)

        # ------------------------------------------------------
        # shape: [batch, seq, dim]
        # support = self.EmbedDrop(self.Embedding(support))
        # query = self.EmbedDrop(self.Embedding(query))
        #
        # # support = self.EmbedDrop(self.EmbedNorm(support))
        # # query = self.EmbedDrop(self.EmbedNorm(query))
        #
        # # shape: [batch, dim]
        # support = self.Encoder(support, sup_len)
        # query = self.Encoder(query, que_len)
        #
        # support = self.Decoder(support, sup_len)
        # query = self.Decoder(query, que_len)
        # ------------------------------------------------------

        assert support.size(1) == query.size(1), \
            'support set dim %d and query set dim %d must match!' % \
            (support.size(1), query.size(1))

        dim = support.size(1)

        # prototype vectors
        # shape: [n, dim]
        support = support.view(n, k, dim)
        support = self.Induction(support.unsqueeze(1)).squeeze()

        # reshape into a comparable shape: [qk, n, dim]
        support = repeatProtoToCompShape(support, qk, n)
        query = repeatQueryToCompShape(query, qk, n)

        similarity = protoDisAdapter(support, query, qk, n, dim, dis_type='cos')

        # return t.softmax(similarity, dim=1)
        return F.log_softmax(similarity, dim=1)
Example 9
    def forward(self, support, query, sup_len, que_len, support_labels):

        if self.DataParallel:
            support = support.squeeze(0)
            sup_len = sup_len[0]
            support_labels = support_labels[0]

        n, k, qk, sup_seq_len, que_seq_len = extractTaskStructFromInput(
            support, query)

        # after extracting the task structure, flatten all samples into one batch
        support = support.view(n * k, sup_seq_len)

        # ---------------------------------------------------------------
        # fix for a bug where 'adapted_par' was reset in every adapt iteration
        # ---------------------------------------------------------------
        adapted_state_dict = self.Learner.clone_state_dict()
        for name, p in adapted_state_dict.items():
            p.requires_grad_(True)

        for a_i in range(self.AdaptIter):
            # ---------------------------------------------------------------
            # fix for a bug where the original parameters, rather than the
            # adapted ones, were used in every adapt iteration
            # ---------------------------------------------------------------
            adapted_pars = collectParamsFromStateDict(
                adapted_state_dict
            )  #self.Learner.adapt_parameters(with_named=False)

            s_predict = self.Learner(support,
                                     sup_len,
                                     params=adapted_state_dict)
            loss = self.LossFn(s_predict, support_labels)
            grads = t.autograd.grad(loss, adapted_pars, create_graph=True)

            # compute the adapted parameters
            for (key, val), grad in zip(adapted_state_dict.items(), grads):
                adapted_lr = self.PreLayerLr[a_i].expand_as(grad)
                adapted_state_dict[key] = val - adapted_lr * grad

        # use the adapted parameters to produce predictions on the query set
        return self.Learner(query, que_len, params=adapted_state_dict)
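
One plausible declaration of the per-step learning rates PreLayerLr, assuming a single learnable scalar per adaptation step; only the indexing self.PreLayerLr[a_i] and the expand_as(grad) broadcast are visible above, so shape and initial value are guesses.

import torch as t
import torch.nn as nn

class PerStepLr(nn.Module):
    def __init__(self, adapt_iter, init_lr=1e-2):
        super().__init__()
        # one learnable scalar step size per adaptation iteration;
        # a 0-dim tensor still broadcasts via expand_as(grad)
        self.PreLayerLr = nn.ParameterList(
            [nn.Parameter(t.tensor(init_lr)) for _ in range(adapt_iter)]
        )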
Example 10
    def forward(self, support, query, sup_len=None, que_len=None):
        # input shape:
        # sup=[n, k, sup_seq_len, height, width]
        # que=[qk, que_seq_len, height, width]
        n, k, qk, sup_seq_len, que_seq_len = extractTaskStructFromInput(
            support, query, unsqueezed=False, is_matrix=True)
        height, width = query.size(2), query.size(3)

        support = support.view(n * k, sup_seq_len, height, width)

        # output shape: [batch, seq_len, feature]
        support = self.Embedding(support)
        query = self.Embedding(query)

        support = self.LstmEncoder(support)
        query = self.LstmEncoder(query)

        # TODO: simply flatten the sequence as the feature
        support = support.view(n, k, -1).mean(dim=1)
        query = query.view(qk, -1)

        assert support.size(1) == query.size(1), \
            'embedded feature dims of support and query sets differ!'

        dim = support.size(1)

        # reshape into a comparable shape: [qk, n, dim]
        support = repeatProtoToCompShape(support, qk, n)
        query = repeatQueryToCompShape(query, qk, n)

        similarity = protoDisAdapter(support,
                                     query,
                                     qk,
                                     n,
                                     dim,
                                     dis_type='cos')

        # return t.softmax(similarity, dim=1)
        return F.log_softmax(similarity, dim=1)
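
A purely illustrative call matching the input-shape comment at the top of this example, for a hypothetical 5-way 5-shot task with 15 queries per class:

import torch as t

support = t.randn(5, 5, 30, 64, 64)   # [n, k, sup_seq_len, height, width]
query = t.randn(75, 30, 64, 64)       # [qk, que_seq_len, height, width]
# with model an instance of the module above (hypothetical):
# log_probs = model(support, query)   # [qk, n] log-probabilities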
Example 11
    def forward(self,
                support,
                query,
                sup_len,
                que_len,
                support_labels,
                query_labels,
                if_cache_data=False):

        # under DataParallel, keep the support set intact by
        # fixing support's batch dimension to 1
        if self.DataParallel:
            support = support.squeeze(0)
            sup_len = sup_len[0]
            support_labels = support_labels[0]

        # support shape: [n, k, seq]
        # query shape: [qk, seq]
        # print(f'sup={support.size()},que={query.size()},sup_len={sup_len.size()},que_len={que_len.size()},'
        #       f'sup_lab={support_labels.size()},que_lab={query_labels.size()}')

        n, k, qk, sup_seq_len, que_seq_len = extractTaskStructFromInput(
            support, query)

        nClusters = n  # the initial number of clusters equals the number of classes
        # batch = self._process_batch(sample, super_classes=super_classes)

        nInitialClusters = nClusters

        # run data through network
        support = self._embed(support.view(n * k, sup_seq_len),
                              sup_len)  # embed the data
        query = self._embed(query, que_len)

        # batch is fixed to 1 here
        # support_labels = t.arange(0, n)[:, None].repeat(1, k).view(1, -1)
        support_labels = support_labels.unsqueeze(0)
        query_labels = query_labels.unsqueeze(0)
        support = support.view(n * k, -1).unsqueeze(0)
        query = query.view(qk, -1).unsqueeze(0)

        # create probabilities for points
        # _, idx = np.unique(batch.y_train.squeeze().data.cpu().numpy(), return_inverse=True)
        prob_support = one_hot(support_labels,
                               nClusters).cuda()  # init cluster-membership probs as the labels' one-hot

        # make initial radii for labeled clusters
        bsize = support.size()[0]
        radii = t.ones(bsize, nClusters).cuda()  # initial radii come from log_sigma_l (a learnable parameter), applied below

        if self.Sigma is not None:
            radii *= t.exp(self.Sigma)

        cluster_labels = t.arange(0, nClusters).cuda().long()

        # compute initial prototypes from labeled examples
        # initially there is one cluster per class with one-hot assignments, so the initial clusters are exactly the class centroids
        # shape: [batch, cluster, dim]
        protos = self._compute_protos(support, prob_support)

        # estimate lamda
        # lamda = self.estimate_lambda(protos.data, False)

        # loop for a given number of clustering steps
        for ii in range(self.NumClusterSteps):
            # protos = protos.data
            # iterate over labeled examples to reassign first
            for i, ex in enumerate(support[0]):
                # find the index of the cluster matching this sample's label
                idxs = t.nonzero(
                    support_labels[0, i] == cluster_labels)[0]  # TODO: take [0]?

                #****************************************************************************
                # distance to the label's matching cluster only (the non-matching
                # clusters sit at +inf, so the min ignores them):
                # distances = self._compute_distances(protos[:, idxs, :], ex.data)
                # if t.min(distances) > lamda:
                #****************************************************************************

                distances = self._compute_distances(protos, ex)
                # if the nearest cluster is not one of this class's own clusters, add a new cluster
                if not t.any(t.min(distances, dim=1).indices == idxs).item():

                    nClusters, protos, radii = self._add_cluster(
                        nClusters,
                        protos,
                        radii,
                        cluster_type='labeled',
                        ex=ex.data)
                    cluster_labels = t.cat(
                        [cluster_labels, support_labels[0, [i]].data],
                        dim=0)  # use the sample's label as the new cluster's label

            # perform partial reassignment based on newly created labeled clusters
            if nClusters > nInitialClusters:
                support_targets = support_labels.data[
                    0, :,
                    None] == cluster_labels  # boolean mask: each sample's own clusters
                prob_support = assign_cluster_radii_limited(
                    protos, support, radii, support_targets)  # per-sample cluster probabilities

            nTrainClusters = nClusters
            protos = protos.cuda()
            protos = self._compute_protos(support, prob_support)
            protos, radii, cluster_labels = self.delete_empty_clusters(
                protos, prob_support, radii, cluster_labels)

        # compute the query samples' cluster logits
        logits = compute_logits_radii(protos,
                                      query,
                                      radii,
                                      use_sigma=self.Sigma
                                      is not None).squeeze()

        # convert class targets into indicators for supports in each class
        labels = query_labels  # batch.y_test.data
        labels[labels >= nInitialClusters] = -1

        support_targets = labels[0, :,
                                 None] == cluster_labels  # clusters matching the query labels
        loss = self.loss(
            logits, support_targets, cluster_labels
        )  # support_targets: cluster indicators for the query labels; cluster_labels: each cluster's label

        # map support predictions back into classes to check accuracy
        _, support_preds = t.max(logits.data, dim=1)
        y_pred = cluster_labels[support_preds]

        if if_cache_data:
            self.Clusters = protos
            self.ClusterLabels = cluster_labels
            return support, query, (
                y_pred == query_labels).sum().item() / query.size(1)

        return y_pred, loss
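
A hedged sketch of _compute_protos as a soft-assignment weighted mean, assuming support is [batch, n*k, dim] and prob_support is [batch, n*k, nClusters] as set up above; the epsilon guard is an assumption.

def _compute_protos(support, prob):
    # weighted sum of samples per cluster: [batch, nClusters, dim]
    weighted = prob.transpose(1, 2) @ support
    # normalize by each cluster's total assignment mass
    counts = prob.sum(dim=1).unsqueeze(-1).clamp(min=1e-8)
    return weighted / counts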
Example 12
    def forward(self,
                support,
                query,
                sup_len,
                que_len,
                metric='euc',
                return_unadapted=False):

        if self.DataParallel:
            support = support.squeeze(0)
            sup_len = sup_len[0]

        n, k, qk, sup_seq_len, que_seq_len = extractTaskStructFromInput(
            support, query)

        qk_per_class = qk // n

        # after extracting the task structure, flatten all samples into one batch
        support = support.view(n * k, sup_seq_len)

        # ------------------------------------------------------
        support, query = self._embed(support, sup_len), \
                         self._embed(query, que_len)
        # ------------------------------------------------------

        # support = self.Encoder(support, sup_len)
        # query = self.Encoder(query, que_len)

        assert support.size(1) == query.size(1), \
            'support set dim %d and query set dim %d must match!' % \
            (support.size(1), query.size(1))

        dim = support.size(1)

        # contrastive loss as a regularizer during training
        if self.training and self.ContraFac is not None:
            # union shape: [n, qk_per_class + k, dim]
            # here we assume the query set is grouped by class
            union = t.cat(
                (support.view(n, k, dim), query.view(n, qk_per_class, dim)),
                dim=1)  # TODO: make it capable to process in batch

            adapted_union = self.SetFunc(union)

            # post-avg in default
            adapted_proto = adapted_union.mean(dim=1)

            # union shape: [(qk+k)*n, dim]
            adapted_union = adapted_union.view((qk_per_class + k) * n, dim)

            # let the whole dataset execute classification task based on the adapted prototypes
            adapted_proto = repeatProtoToCompShape(adapted_proto,
                                                   (qk_per_class + k) * n, n)
            adapted_union = repeatQueryToCompShape(adapted_union,
                                                   (qk_per_class + k) * n, n)

            adapted_sim = protoDisAdapter(adapted_proto,
                                          adapted_union,
                                          (qk_per_class + k) * n,
                                          n,
                                          dim,
                                          dis_type='euc')

            # the target label set covers both support and query items,
            # ordered class by class (each class id repeats 'qk_per_class + k' times)
            adapted_res = F.log_softmax(adapted_sim, dim=1)

        if return_unadapted:
            unada_support = support.view(n, k, -1).mean(1)
            unada_support = repeatProtoToCompShape(unada_support, qk, n)

        ################################################################
        if self.Avg == 'post':

            # support set2set
            support = self.SetFunc(support.view(1, n * k, dim))

            # shape: [n, dim]
            support = support.view(n, k, dim).mean(dim=1)

        elif self.Avg == 'pre':

            # shape: [n, dim]
            support = support.view(n, k, dim).mean(dim=1)
            # support set2set
            support = self.SetFunc(support.unsqueeze(0))
        ################################################################

        # shape: [n, dim] -> [1, n, dim]
        # pre-avg in default, treat prototypes as sequence
        # support = support.view(n, k, dim).mean(dim=1).unsqueeze(0)
        # # support set2set
        # support = self.SetFunc(support)

        support = repeatProtoToCompShape(support, qk, n)
        query = repeatQueryToCompShape(query, qk, n)

        similarity = protoDisAdapter(support,
                                     query,
                                     qk,
                                     n,
                                     dim,
                                     dis_type='euc',
                                     temperature=self.DisTempr)

        if self.training and self.ContraFac is not None:
            return F.log_softmax(similarity, dim=1), adapted_res

        else:
            if return_unadapted:
                unada_sim = protoDisAdapter(unada_support,
                                            query,
                                            qk,
                                            n,
                                            dim,
                                            dis_type='euc',
                                            temperature=self.DisTempr)
                return F.log_softmax(similarity,
                                     dim=1), F.log_softmax(unada_sim, dim=1)

            else:
                return F.log_softmax(similarity, dim=1)
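
The label-layout comment above pins down the targets for adapted_res: the union tensor is class-major, with each class contributing its k support items and then its qk_per_class query items. A hypothetical construction:

import torch as t

n, k, qk_per_class = 5, 5, 15           # hypothetical task sizes
# one class id per union row, repeated (qk_per_class + k) times
union_labels = t.arange(n).repeat_interleave(qk_per_class + k)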
Example 13
    def forward(self, support, query, sup_len, que_len):

        if self.DataParallel:
            support = support.squeeze(0)
            sup_len = sup_len[0]

        n, k, qk, sup_seq_len, que_seq_len = extractTaskStructFromInput(
            support, query)

        support = support.view(n * k, sup_seq_len)

        # ------------------------------------------------------
        # shape: [batch, dim]
        support, query = self._embed(support, sup_len), \
                         self._embed(query, que_len)
        # ------------------------------------------------------

        assert support.size(1) == query.size(1), \
            'support set dim %d and query set dim %d must match!' % \
            (support.size(1), query.size(1))

        dim = support.size(1)

        # reshape the embedded support set to a convenient form
        # support shape: [n*k, d] -> [n, k, d]
        support = support.view(n, k, dim)
        # query shape: [qk, d]
        query = query.view(qk, -1)

        # feed the support embeddings into the feature-attention module as a
        # single-channel matrix, then repeat qk times so the support-based
        # feature attention is shared across all qk queries
        # input:  [n,k,d] -> [n,1,k,d]
        # output: [n,1,1,d] -> [n,d] -> [qk,n,d]
        feature_attentions = self.FeatureAttention(
            support.unsqueeze(dim=1)).squeeze().repeat(qk, 1, 1)

        # repeat the support set qk times and the query set n*k times to get
        # qk*n*k samples, so each query gets its own instance attention;
        # fold qk, n and k into one dimension to feed the linear layer
        # query_expand shape: [qk,d] -> [n*k,qk,d] -> [qk,n,k,d]
        # support_expand shape: [n,k,d] -> [qk,n,k,d]
        support_expand = support.repeat((qk, 1, 1, 1)).view(qk * n * k, -1)
        query_expand = query.repeat(
            (n * k, 1, 1)).transpose(0, 1).contiguous().view(qk * n * k, -1)

        # align the support samples with instance attention
        # shape: [qk,n,k,d]
        support = self.InstanceAttention(support_expand, query_expand, k, qk,
                                         n)

        # produce a class prototype that differs for each of the qk queries:
        # after attention alignment, sum the weighted vectors within each class
        # proto shape: [qk,n,k,d] -> [qk,n,d]
        support = support.sum(dim=2).squeeze()
        # support = support.mean(dim=1).repeat((qk,1,1)).view(qk,n,-1)

        # query shape: [qk,d]->[qk,n,d]
        query = query.unsqueeze(dim=1).repeat(1, n, 1)

        # dis_attented shape: [qk,n,d] -> [qk,n]
        # dis_attented = (((support-query)**2)).sum(dim=2).neg()
        dis_attented = (((support - query)**2) *
                        feature_attentions).sum(dim=2).neg()

        return t.log_softmax(dis_attented, dim=1)
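
A hedged sketch of an InstanceAttention module consistent with the shapes above: score each support instance against its paired query, softmax over the k shots, and reweight. The tanh-of-linear scoring is borrowed from hybrid-attention prototypical networks and is an assumption here.

import torch as t
import torch.nn as nn

class InstanceAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, support, query, k, qk, n):
        # support, query: [qk*n*k, d], already paired element-wise
        score = t.tanh(self.proj(support * query)).sum(dim=-1)   # [qk*n*k]
        attn = t.softmax(score.view(qk, n, k), dim=-1)           # over the k shots
        return attn.unsqueeze(-1) * support.view(qk, n, k, -1)   # [qk, n, k, d]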
Example 14
    def forward(self,
                support,
                query,
                sup_len,
                que_len,
                support_labels,
                query_labels,
                if_cache_data=False):

        if self.DataParallel:
            support = support.squeeze(0)
            sup_len = sup_len[0]
            support_labels = support_labels[0]

        n, k, qk, sup_seq_len, que_seq_len = extractTaskStructFromInput(
            support, query)

        nClusters = n  # the initial number of clusters equals the number of classes
        # batch = self._process_batch(sample, super_classes=super_classes)

        nInitialClusters = nClusters

        # run data through network
        support = self._embed(support.view(n * k, sup_seq_len),
                              sup_len)  # embed the data
        query = self._embed(query, que_len)

        # batch is fixed to 1 here
        # support_labels = t.arange(0, n)[:, None].repeat(1, k).view(1, -1)
        support_labels = support_labels.unsqueeze(0)
        query_labels = query_labels.unsqueeze(0)
        support = support.view(n * k, -1).unsqueeze(0)
        query = query.view(qk, -1).unsqueeze(0)

        # create probabilities for points
        # _, idx = np.unique(batch.y_train.squeeze().data.cpu().numpy(), return_inverse=True)
        prob_support = one_hot(support_labels,
                               nClusters).cuda()  # init cluster-membership probs as the labels' one-hot

        # make initial radii for labeled clusters
        bsize = support.size()[0]
        radii = t.ones(bsize, nClusters).cuda() * t.exp(
            self.Sigma)  # initial radii come from log_sigma_l (a learnable parameter)

        cluster_labels = t.arange(0, nClusters).cuda().long()

        # compute initial prototypes from labeled examples
        # initially there is one cluster per class with one-hot assignments, so the initial clusters are exactly the class centroids
        # shape: [batch, cluster, dim]
        protos = self._compute_protos(support, prob_support)

        # estimate lamda
        lamda = self.estimate_lambda(protos.data, False)

        # loop for a given number of clustering steps
        for ii in range(self.NumClusterSteps):
            # protos = protos.data
            # iterate over labeled examples to reassign first
            for i, ex in enumerate(support[0]):
                # find the index of the cluster matching this sample's label
                idxs = t.nonzero(
                    support_labels[0, i] == cluster_labels).squeeze(0)  # TODO: take [0]?
                # distance to the label's matching clusters (non-matching clusters sit at +inf, so the min ignores them)
                distances = self._compute_distances(protos[:, idxs, :],
                                                    ex.data)
                # print(distances.tolist(), lamda.item())
                if t.min(distances) > lamda:
                    nClusters, protos, radii = self._add_cluster(
                        nClusters,
                        protos,
                        radii,
                        cluster_type='labeled',
                        ex=ex.data)
                    cluster_labels = t.cat(
                        [cluster_labels, support_labels[0, [i]].data],
                        dim=0)  # use the sample's label as the new cluster's label

            # perform partial reassignment based on newly created labeled clusters
            if nClusters > nInitialClusters:
                support_targets = support_labels.data[
                    0, :,
                    None] == cluster_labels  # boolean mask: each sample's own clusters
                prob_support = assign_cluster_radii_limited(
                    protos, support, radii, support_targets)  # per-sample cluster probabilities

            nTrainClusters = nClusters

            # # iterate over unlabeled examples
            # if batch.x_unlabel is not None:
            #     h_unlabel = self._run_forward(batch.x_unlabel)
            #     h_all = t.cat([h_train, h_unlabel], dim=1)
            #     unlabeled_flag = t.LongTensor([-1]).cuda()
            #
            #     for i, ex in enumerate(h_unlabel[0]):
            #         distances = self._compute_distances(protos, ex.data)
            #         if t.min(distances) > lamda:
            #             nClusters, protos, radii = self._add_cluster(nClusters, protos, radii,
            #                                                                cluster_type='unlabeled', ex=ex.data)
            #             cluster_labels = t.cat([cluster_labels, unlabeled_flag], dim=0)
            #
            #     # add new, unlabeled clusters to the total probability
            #     if nClusters > nTrainClusters:
            #         unlabeled_clusters = t.zeros(prob_support.size(0), prob_support.size(1), nClusters - nTrainClusters)
            #         prob_support = t.cat([prob_support, Variable(unlabeled_clusters).cuda()], dim=2)
            #
            #     prob_unlabel = assign_cluster_radii(Variable(protos).cuda(), h_unlabel, radii)
            #     prob_unlabel_nograd = Variable(prob_unlabel.data, requires_grad=False).cuda()
            #     prob_all = t.cat([Variable(prob_support.data, requires_grad=False), prob_unlabel_nograd], dim=1)
            #
            #     protos = self._compute_protos(h_all, prob_all)
            #     protos, radii, cluster_labels = self.delete_empty_clusters(protos, prob_all, radii, cluster_labels)
            # else:
            protos = protos.cuda()
            protos = self._compute_protos(support, prob_support)
            protos, radii, cluster_labels = self.delete_empty_clusters(
                protos, prob_support, radii, cluster_labels)

        # compute the query samples' cluster logits
        logits = compute_logits_radii(protos, query, radii).squeeze()

        # convert class targets into indicators for supports in each class
        labels = query_labels  #batch.y_test.data
        labels[labels >= nInitialClusters] = -1

        support_targets = labels[0, :,
                                 None] == cluster_labels  # clusters matching the query labels
        loss = self.loss(
            logits, support_targets, cluster_labels
        )  # support_targets: cluster indicators for the query labels; cluster_labels: each cluster's label

        # map support predictions back into classes to check accuracy
        _, support_preds = t.max(logits.data, dim=1)
        y_pred = cluster_labels[support_preds]

        if if_cache_data:
            self.Clusters = protos
            self.ClusterLabels = cluster_labels
            return support, query, (
                y_pred == query_labels).sum().item() / query.size(1)

        return y_pred, loss
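
A hedged sketch of compute_logits_radii with Gaussian-style logits, where each cluster's squared distance is scaled by its radius; the exact form (and whether a log-normalizer term is added) is an assumption, since only the call sites in Examples 11 and 14 are visible.

def compute_logits_radii(protos, query, radii, use_sigma=True):
    # protos: [batch, nClusters, dim], query: [batch, qk, dim], radii: [batch, nClusters]
    diff = query.unsqueeze(2) - protos.unsqueeze(1)   # [batch, qk, nClusters, dim]
    sq_dist = diff.pow(2).sum(dim=-1)                 # [batch, qk, nClusters]
    if use_sigma:
        return -sq_dist / (2 * radii.unsqueeze(1))
    return -sq_dist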