Exemplo n.º 1
0
    def solve(self):
        """Run the alternating clustering / network-update training loop.

        Every round refreshes the target labels via ``self.update_labels()``,
        records the clustering results in the history, optionally evaluates
        the clustering against ground truth, and then either stops (when
        ``self.complete_training()`` returns a truthy value) or filters the
        clustering result, rebuilds the dataloaders and updates the network.
        ``self.loop`` is incremented after every network update.
        """
        if self.resume:
            # When resuming, advance both counters once before looping.
            self.iters += 1
            self.loop += 1

        history_keys = ('target_centers', 'ts_center_dist', 'target_labels')

        while True:
            # Refresh the target label hypothesis through clustering; none of
            # this bookkeeping needs gradients.
            with torch.no_grad():
                print('Clustering based on %s...' % self.source_name)
                # Steps 1-3: generate pseudo-labels for the target domain.
                self.update_labels()
                self.clustered_target_samples = self.clustering.samples
                snapshots = (
                    self.clustering.centers,
                    self.clustering.center_change,
                    self.clustering.path2label,
                )

                # Record this round's clustering results in the history.
                for key, value in zip(history_keys, snapshots):
                    self.register_history(key, value,
                                          self.opt.CLUSTERING.HISTORY_LEN)

                samples = self.clustered_target_samples
                has_gt = samples is not None and samples['gt'] is not None
                if has_gt:
                    # Model evaluation (mean_acc / accuracy) of the clustering
                    # labels against the ground truth.
                    predictions = to_onehot(samples['label'],
                                            self.opt.DATASET.NUM_CLASSES)
                    score = self.model_eval(predictions, samples['gt'])
                    print('Clustering %s: %.4f' % (self.opt.EVAL_METRIC,
                                                   score))

                # Leave the loop as soon as the stop condition is met.
                if self.complete_training():
                    break

                # Step 4: filter out ambiguous samples and classes.
                hypothesis, kept_classes = self.filtering()

                # Rebuild the dataloaders and the per-loop iteration setting.
                self.construct_categorical_dataloader(hypothesis, kept_classes)
                self.compute_iters_per_loop(kept_classes)

            # Step 5: k-step update of the network parameters through the
            # forward-backward process.
            self.update_network(kept_classes)
            self.loop += 1

        print('Training Done!')
Exemplo n.º 2
0
    def feature_clustering(self, feature_extractor, loader):
        """Cluster the collected features and store labels and centers.

        Samples are gathered with ``self.collect_samples`` and the centers are
        re-estimated chunk by chunk until ``self.stop`` is set (presumably by
        ``self.clustering_stop`` -- confirm against its definition). Every
        sample is then assigned to a center, the centers are reordered with
        the index returned by ``self.align_centers()``, and the labels are
        remapped accordingly.

        Writes ``self.stop``, ``self.centers``, ``self.samples['label']``,
        ``self.samples['dist2center']``, ``self.center_change`` and
        ``self.path2label``; removes ``self.samples['feature']`` at the end.

        Args:
            feature_extractor: forwarded to ``self.collect_samples``.
            loader: forwarded to ``self.collect_samples``.
        """
        centers = None
        self.stop = False

        self.collect_samples(feature_extractor, loader)
        feature = self.samples['feature']

        # Column vector of class indices, shape (num_classes, 1); compared
        # against the labels below to build one membership mask per class.
        refs = to_cuda(torch.LongTensor(range(self.num_classes)).unsqueeze(1))
        num_samples = feature.size(0)
        # Features are processed in chunks of at most ``self.max_len`` rows;
        # this is the number of chunks needed to cover every sample.
        num_split = ceil(1.0 * num_samples / self.max_len)

        while True:
            self.clustering_stop(centers)
            if centers is not None:
                self.centers = centers
            if self.stop:
                break

            centers = 0
            count = 0

            start = 0
            for _ in range(num_split):
                cur_len = min(self.max_len, num_samples - start)
                cur_feature = feature.narrow(0, start, cur_len)
                _, labels = self.assign_labels(cur_feature)
                labels_onehot = to_onehot(labels, self.num_classes)
                count += torch.sum(labels_onehot, dim=0)
                labels = labels.unsqueeze(0)
                # Cast on the feature's own device instead of the hard-coded
                # torch.cuda.FloatTensor so this also runs without CUDA.
                mask = (labels == refs).unsqueeze(2).to(
                    device=cur_feature.device, dtype=torch.float32)
                reshaped_feature = cur_feature.unsqueeze(0)
                # update centers: per-class sum of the member features
                centers += torch.sum(reshaped_feature * mask, dim=1)
                start += cur_len

            # Classes that received no sample fall back to their initial
            # center.
            mask = (count.unsqueeze(1) > 0).to(
                device=centers.device, dtype=torch.float32)
            centers = mask * centers + (1 - mask) * self.init_centers

        # Final assignment pass with the converged centers.
        dist2center, labels = [], []
        start = 0
        for _ in range(num_split):
            cur_len = min(self.max_len, num_samples - start)
            cur_feature = feature.narrow(0, start, cur_len)
            cur_dist2center, cur_labels = self.assign_labels(cur_feature)

            dist2center.append(cur_dist2center)
            labels.append(cur_labels)
            start += cur_len

        self.samples['label'] = torch.cat(labels, dim=0)
        self.samples['dist2center'] = torch.cat(dist2center, dim=0)

        cluster2label = self.align_centers()
        # reorder the centers
        self.centers = self.centers[cluster2label, :]
        # re-label the data according to the index
        num_samples = len(self.samples['feature'])
        sample_labels = self.samples['label']
        for k in range(num_samples):
            sample_labels[k] = cluster2label[sample_labels[k]].item()

        self.center_change = torch.mean(
            self.Dist.get_dist(self.centers, self.init_centers))

        # One tolist() call replaces a per-sample .item() call.
        for path, label in zip(self.samples['data'], sample_labels.tolist()):
            self.path2label[path] = label

        del self.samples['feature']