예제 #1
0
    def Indexes_of_inliers(self, Keypoints, Descriptors, buffersize):
        """Flag keypoints whose descriptors are not outliers.

        A 256-dimensional FAISS IVF-Flat index with 100 lists is moved to
        GPU 0, then trained on and filled with the first ``buffersize``
        descriptors. Every descriptor is queried against that index, and its
        outlier score is the median of the distances to the 100 neighbours
        returned by the search. The threshold is the score found at the
        ``1 - self.remove_superpoint_outliers_percentage`` position of the
        sorted scores.

        Args:
            Keypoints: only ``len(Keypoints)`` is used, to size the score
                array. NOTE(review): assumes ``len(Keypoints) ==
                len(Descriptors)`` -- confirm against the caller.
            Descriptors: 2-D sliceable array of descriptors (sliced as
                ``Descriptors[a:b, :]``).
            buffersize: number of leading descriptors used to train and
                populate the index.

        Returns:
            Boolean numpy array of length ``len(Keypoints)``; ``True`` where
            the outlier score is strictly below the threshold.
        """
        res = faiss.StandardGpuResources()
        nlist = 100
        quantizer = faiss.IndexFlatL2(256)
        index = faiss.IndexIVFFlat(quantizer, 256, nlist)

        gpu_index_flat = faiss.index_cpu_to_gpu(res, 0, index)

        # Read and preprocess the reference subset once; it is used both for
        # training the index and for populating it.
        reference_features = clustering.preprocess_features(
            Descriptors[:buffersize])
        gpu_index_flat.train(reference_features)
        gpu_index_flat.add(reference_features)

        # we process the descriptors in batches of 10000 vectors
        rg = np.linspace(0,
                         len(Descriptors),
                         math.ceil(len(Descriptors) / 10000) + 1,
                         dtype=int)
        keypoints_outlier_score = np.zeros(len(Keypoints))
        for start, stop in zip(rg[:-1], rg[1:]):
            descr = clustering.preprocess_features(Descriptors[start:stop, :])
            distance_to_closest_points, _ = gpu_index_flat.search(descr, 100)
            keypoints_outlier_score[start:stop] = np.median(
                distance_to_closest_points, axis=1)

        # np.sort already returns a sorted copy, so the scores stay in
        # keypoint order for the final comparison.
        sorted_scores = np.sort(keypoints_outlier_score)
        threshold = sorted_scores[int(
            (1 - self.remove_superpoint_outliers_percentage) *
            (len(sorted_scores) - 1))]
        return keypoints_outlier_score < threshold
    def GetThresholdsPerCluster(self, Descriptors):
        """Compute one distance threshold per cluster.

        Every descriptor is assigned to its nearest centroid through
        ``self.KmeansClustering.index.search(..., 1)``. The threshold of a
        cluster is the mean plus one standard deviation of the distances of
        the descriptors assigned to it.

        Args:
            Descriptors: 2-D sliceable array of descriptors (sliced as
                ``Descriptors[a:b, :]``).

        Returns:
            numpy array of length ``self.number_of_clusters``; clusters that
            received no descriptor keep a threshold of 0.
        """
        # descriptors are processed in batches of at most 10000 vectors
        rg = np.linspace(0,
                         len(Descriptors),
                         math.ceil(len(Descriptors) / 10000) + 1,
                         dtype=int)
        distance_to_centroid_per_cluster = [
            [] for _ in range(self.number_of_clusters)
        ]

        for start, stop in zip(rg[:-1], rg[1:]):
            # Slice once with the global offsets. Re-slicing the batch with
            # those same offsets would select (almost) nothing for every
            # batch after the first one.
            descriptors = clustering.preprocess_features(
                Descriptors[start:stop, :])
            distancesFromCenter, clustering_assingments = self.KmeansClustering.index.search(
                descriptors, 1)
            # A search with k=1 yields a single column; read it explicitly
            # instead of converting 1-element rows with int().
            for cluster_id, distance in zip(clustering_assingments[:, 0],
                                            distancesFromCenter[:, 0]):
                distance_to_centroid_per_cluster[int(cluster_id)].append(
                    distance)

        thresholds = np.zeros(self.number_of_clusters)
        for cluster_id, distances in enumerate(
                distance_to_centroid_per_cluster):
            if distances:
                thresholds[cluster_id] = np.average(
                    np.array(distances)) + np.std(distances)

        return thresholds
예제 #3
0
    def cluster(self,
                args,
                features,
                dataloader,
                num_imgs,
                model,
                proc_feat=False,
                verbose=False):
        """Run k-means through ``run_kmeans`` and keep its results on ``self``.

        Args:
            args: forwarded unchanged to ``run_kmeans``.
            features: data to cluster; replaced by the first output of
                ``preprocess_features`` when ``proc_feat`` is true.
            dataloader: forwarded unchanged to ``run_kmeans``.
            num_imgs: forwarded unchanged to ``run_kmeans``.
            model: forwarded unchanged to ``run_kmeans``.
            proc_feat: when true, preprocess ``features`` first and forward
                the resulting PCA matrix; otherwise ``None`` is forwarded.
            verbose: forwarded unchanged to ``run_kmeans``.

        Returns:
            The loss value returned by ``run_kmeans``.
        """
        # The PCA matrix has to be handed to run_kmeans because the data used
        # for inference is not the data the clusterer is trained on.
        pca_mat = None
        if proc_feat:
            features, pca_mat = preprocess_features(features)

        # Cluster the features and run inference on the spatially
        # uncollapsed dataset.
        kmeans_output = run_kmeans(args, features, self.k, dataloader,
                                   num_imgs, model, pca_mat, verbose)
        pseudolabelled_x, loss, centroids = kmeans_output

        # Masks are not stored here; the dataloader reloads them later.
        self.centroids = centroids
        self.pseudolabelled_x = pseudolabelled_x

        return loss
예제 #4
0
    def Indexes_of_BackgroundPoints(self, Keypoints, Descriptors,
                                    keypoint_indexes):
        """Flag in-box keypoints that resemble in-box rather than background points.

        Two k-means models are fitted on the first 500000 descriptors: one
        with 250 clusters on descriptors whose keypoint label (column 2 of
        ``Keypoints``) is -1, and one with 100 clusters on descriptors whose
        label is 1. A keypoint is kept when its label is 1 and its distance
        to the nearest "inside" centroid is strictly smaller than its
        distance to the nearest "background" centroid.

        Despite its name, the method returns the mask of points to keep.

        Args:
            Keypoints: 2-D numpy array; column 2 holds the label (1 or -1).
            Descriptors: 2-D sliceable array aligned row-wise with
                ``Keypoints``.
            keypoint_indexes: mapping whose values are ``[start, end]`` row
                ranges, one per image.

        Returns:
            Boolean numpy array of length ``len(Keypoints)``; rows not covered
            by any range in ``keypoint_indexes`` stay ``False``.
        """
        backgroundpoitnsIndex = Keypoints[:, 2] == -1
        insideboxPoitnsIndex = Keypoints[:, 2] == 1

        # Index with the boolean mask itself. Wrapping the mask in a list
        # (``arr[[mask]]``) is interpreted as a 2-D index array by recent
        # NumPy releases and raises IndexError.
        backgroundDescriptors = clustering.preprocess_features(
            Descriptors[:500000][backgroundpoitnsIndex[:500000]])

        insideboxDescriptors = clustering.preprocess_features(
            Descriptors[:500000][insideboxPoitnsIndex[:500000]])

        number_of_insideClusters = 100
        number_of_outsideClusters = 250
        backgroundclustering = clustering.Kmeans(number_of_outsideClusters,
                                                 centroids=None)
        insideboxclustering = clustering.Kmeans(number_of_insideClusters,
                                                centroids=None)

        backgroundclustering.cluster(backgroundDescriptors, verbose=False)
        insideboxclustering.cluster(insideboxDescriptors, verbose=False)

        foregroundpointindex = np.zeros(len(Keypoints), dtype=bool)
        for start, end in keypoint_indexes.values():
            # Preprocess the image descriptors once and query both models
            # with the same array.
            image_features = clustering.preprocess_features(
                Descriptors[start:end, :])

            distanceinside, _ = insideboxclustering.index.search(
                image_features, 1)
            distanceoutside, _ = backgroundclustering.index.search(
                image_features, 1)

            closer_to_inside = (distanceinside < distanceoutside).reshape(-1)
            foregroundpointindex[start:end] = np.logical_and(
                closer_to_inside, insideboxPoitnsIndex[start:end])

        return foregroundpointindex
    def cluster(self, data1, data2, verbose=False):
        """Cluster two feature sets with ``new_run_kmeans`` and record the result.

        Features with at least 256 columns go through
        ``clustering.preprocess_features``; narrower features are cast to
        float32 and each row is divided by its L2 norm.

        Side effects: sets ``self.feature`` (processed ``data1``),
        ``self.I``, ``self.J``, ``self.images_lists`` (row indices of
        ``data1`` grouped by their entry in ``I``) and ``self.spe_number``,
        and prints ``len(I)`` and ``len(J)``.

        Args:
            data1: 2-D numpy array of features.
            data2: second 2-D numpy array of features, processed like
                ``data1``.
            verbose: forwarded to ``new_run_kmeans``; when true the elapsed
                time is printed as well.

        Returns:
            The loss value returned by ``new_run_kmeans``.
        """
        started_at = time.time()

        def _rows_to_unit_norm(data):
            # float32 copy of ``data`` with every row divided by its L2 norm
            as_float = data.astype('float32')
            norms = np.linalg.norm(as_float, axis=1)
            return as_float / norms[:, np.newaxis]

        if data1.shape[1] < 256:
            xb = _rows_to_unit_norm(data1)
            yb = _rows_to_unit_norm(data2)
        else:
            xb = clustering.preprocess_features(data1)
            yb = clustering.preprocess_features(data2)

        # cluster the data
        I, J, loss = new_run_kmeans(xb, yb, self.k, verbose)

        self.feature = xb
        self.I = I
        self.J = J

        print(len(I), len(J))

        # group the row indices of data1 by their assigned cluster
        self.images_lists = [[] for _ in range(self.k)]
        for sample_idx in range(len(data1)):
            self.images_lists[I[sample_idx]].append(sample_idx)

        self.spe_number = self.cnt_strong_frame(self.J)

        if verbose:
            print('k-means time: {0:.0f} s'.format(time.time() - started_at))

        return loss
예제 #6
0
def get_means_and_variances(dc, features, args):
    """Compute per-cluster statistics of the preprocessed features.

    For each of the ``args.nmb_cluster`` clusters, the rows of ``features``
    listed in ``dc.images_lists[i]`` are passed through
    ``clustering.preprocess_features`` (with ``mat=dc.mat``) and summarised.

    Args:
        dc: object exposing ``images_lists`` (row indices per cluster) and
            ``mat`` (forwarded to ``preprocess_features``).
        features: array indexable with a list of row indices.
        args: object exposing ``nmb_cluster``.

    Returns:
        Three lists with one entry per cluster: the mean feature vector, the
        matrix ``centred.T @ centred / cluster_size``, and the mean Euclidean
        norm of the centred features.
    """
    means = []
    covariances = []
    spreads = []
    for cluster_id in range(args.nmb_cluster):
        members = dc.images_lists[cluster_id]
        feats, _ = clustering.preprocess_features(features[members],
                                                  mat=dc.mat)
        centre = feats.mean(0)
        centred = feats - centre

        means.append(centre)
        covariances.append((centred.transpose() @ centred) / len(members))

        # mean distance of the cluster members to the cluster mean
        spreads.append(((centred ** 2).sum(-1) ** 0.5).mean())

    return means, covariances, spreads
    def Update_pseudoLabels(self, dataloader, oldkeypoints=None):
        """Re-detect keypoints with the current model and re-cluster them.

        ``self.model`` is run over ``dataloader``. The detected keypoints
        (merged with ``oldkeypoints`` where available) and their descriptors
        are collected in temporary file arrays. The first 500000 descriptors
        are clustered with ``self.KmeansClustering``, and for every image the
        keypoints are matched one-to-one to clusters. A matched keypoint is
        kept only if its distance is below the threshold of its cluster.

        Side effects: sets ``self.centroid``, replaces
        ``self.KmeansClustering`` with a new model initialised from those
        centroids, writes log lines through ``LogText`` and saves the result
        with ``self.save_keypoints``.

        Args:
            dataloader: iterable of batches indexed with ``'image'`` and
                ``'filename'``.
            oldkeypoints: optional mapping from filename to a numpy array of
                keypoints from a previous round; merged into the new
                detections through ``MergePoints``.

        Returns:
            dict mapping filename to a numpy array of the kept keypoints,
            whose column 2 holds the assigned cluster index.
        """
        LogText(f"Clustering stage for iteration {self.iterations}",
                self.experiment_name, self.log_path)
        self.model.eval()

        heatmapsize = 64
        numberoffeatures = 256
        buffersize = 500000

        # Paths of the temporary keypoint/descriptor arrays, resolved once.
        checkpoint_dir = GetCheckPointsPath(self.experiment_name,
                                            self.log_path)
        keypoints_path = str(checkpoint_dir / 'keypoints')
        descriptors_path = str(checkpoint_dir / 'descriptors')

        # allocation of 2 buffers for temporary storage of keypoints and
        # descriptors
        Keypoint_buffer = torch.zeros(buffersize, 3)
        Descriptor__buffer = torch.zeros(buffersize, numberoffeatures)

        # arrays to which the buffer content is appended periodically
        CreateFileArray(keypoints_path, 3)
        CreateFileArray(descriptors_path, numberoffeatures)

        # first/last index: row range of the current image in the file
        # arrays; buffer_* indices: the same range inside the in-memory buffer
        first_index = 0
        last_index = 0
        buffer_first_index = 0
        buffer_last_index = 0
        keypoint_indexes = {}

        pointsperimage = 0
        LogText("Inference of keypoints and descriptors begins",
                self.experiment_name, self.log_path)
        for sample in dataloader:
            images = Cuda(sample['image'])
            names = sample['filename']

            with torch.no_grad():
                output = self.model.forward(images)
            outputHeatmap = output[0]
            descriptors_volume = output[1]

            batch_keypoints = GetBatchMultipleHeatmap(
                outputHeatmap, self.confidence_thres_FAN)

            for i in range(images.size(0)):

                # column 0 of batch_keypoints is the index of the image
                # inside the batch
                indexes = batch_keypoints[:, 0] == i
                sample_keypoints = batch_keypoints[indexes, 1:][:, :3]

                # counted before merging, i.e. detector output only
                pointsperimage += len(sample_keypoints)
                if oldkeypoints is not None and names[i] in oldkeypoints:
                    keypoints_previous_round = Cuda(
                        torch.from_numpy(
                            oldkeypoints[names[i]].copy())).float()
                    sample_keypoints = MergePoints(sample_keypoints,
                                                   keypoints_previous_round)

                descriptors = GetDescriptors(descriptors_volume[i],
                                             sample_keypoints[:, :2],
                                             heatmapsize, heatmapsize)

                numofpoints = sample_keypoints.shape[0]
                last_index += numofpoints
                buffer_last_index += numofpoints

                Keypoint_buffer[buffer_first_index:buffer_last_index, :
                                2] = sample_keypoints.cpu()[:, :2]
                Descriptor__buffer[
                    buffer_first_index:buffer_last_index, :] = descriptors

                keypoint_indexes[names[i]] = [first_index, last_index]
                first_index += numofpoints
                buffer_first_index += numofpoints

            # flush the buffer to file once it is more than 80% full
            if buffer_last_index > int(buffersize * 0.8):
                AppendFileArray(np.array(Keypoint_buffer[:buffer_last_index]),
                                keypoints_path)
                AppendFileArray(
                    np.array(Descriptor__buffer[:buffer_last_index]),
                    descriptors_path)

                Keypoint_buffer = torch.zeros(buffersize, 3)
                Descriptor__buffer = torch.zeros(buffersize, numberoffeatures)
                buffer_first_index = 0
                buffer_last_index = 0

        # store any keypoints left on the buffers
        AppendFileArray(np.array(Keypoint_buffer[:buffer_last_index]),
                        keypoints_path)
        AppendFileArray(np.array(Descriptor__buffer[:buffer_last_index]),
                        descriptors_path)

        # load handlers to the Keypoints and Descriptor files
        Descriptors, descriptors_handler = OpenreadFileArray(descriptors_path)
        Keypoints, keypoints_handler = OpenreadFileArray(keypoints_path)
        # NOTE(review): presumably this reads the whole keypoint array into
        # memory -- confirm against OpenreadFileArray.
        Keypoints = Keypoints[:, :]
        LogText(
            f"Keypoints Detected per image Only detector {pointsperimage / len(keypoint_indexes)}",
            self.experiment_name, self.log_path)
        LogText("Inference of keypoints and descriptors completed",
                self.experiment_name, self.log_path)
        LogText(
            f"Keypoints Detected per image {len(Keypoints)/len(keypoint_indexes)}",
            self.experiment_name, self.log_path)

        # we use a subset of all the descriptors for clustering, based on the
        # recommendation of the Faiss repository
        numberOfPointsForClustering = 500000

        descriptors = clustering.preprocess_features(
            Descriptors[:numberOfPointsForClustering])
        _, self.centroid = self.KmeansClustering.cluster(descriptors,
                                                         verbose=False)

        self.KmeansClustering.clus.nredo = 1

        thresholds = self.GetThresholdsPerCluster(Descriptors)

        Image_Keypoints = {}

        averagepointsperimage = 0

        for image, (start, end) in keypoint_indexes.items():
            keypoints = Keypoints[start:end, :]

            image_descriptors = clustering.preprocess_features(
                Descriptors[start:end])

            # distance of each keypoint to every centroid
            distanceMatrix, clustering_assignments = self.KmeansClustering.index.search(
                image_descriptors, self.number_of_clusters)

            # reorder every row so that column j holds the distance to
            # cluster j
            distanceMatrix = np.take_along_axis(
                distanceMatrix, np.argsort(clustering_assignments), axis=-1)

            # assign keypoints to centroids using the Hungarian algorithm.
            # This ensures that each image has at most one instance of each
            # cluster
            keypointIndex, clusterAssignment = linear_sum_assignment(
                distanceMatrix)

            # fancy indexing copies, so Keypoints itself is left untouched
            tempKeypoints = keypoints[keypointIndex]

            clusterAssignmentDistance = distanceMatrix[keypointIndex,
                                                       clusterAssignment]

            # keep only points that lie below their cluster-specific threshold
            clusterstokeep = (clusterAssignmentDistance <
                              thresholds[clusterAssignment])

            tempKeypoints[:, 2] = clusterAssignment

            Image_Keypoints[image] = tempKeypoints[clusterstokeep]

            averagepointsperimage += np.count_nonzero(clusterstokeep)

        # initialise centroids for next clustering round
        self.KmeansClustering = clustering.Kmeans(self.number_of_clusters,
                                                  self.centroid)
        LogText(
            f"Keypoints Detected per image {averagepointsperimage/len(Image_Keypoints)}",
            self.experiment_name, self.log_path)

        self.save_keypoints(Image_Keypoints,
                            f'UpdatedKeypoints{self.iterations}.pickle')
        # Close every handler together with the path it was opened from.
        ClosereadFileArray(descriptors_handler, descriptors_path)
        ClosereadFileArray(keypoints_handler, keypoints_path)
        LogText("Clustering stage completed", self.experiment_name,
                self.log_path)
        return Image_Keypoints
예제 #8
0
    def CreateInitialPseudoGroundtruth(self, dataloader):
        """Build the initial pseudo ground truth from SuperPoint detections.

        Keypoints and descriptors are extracted for every image through
        ``self.GetSuperpointOutput`` / ``self.GetPoints`` and stored in
        temporary file arrays. They are optionally filtered (outlier removal,
        background removal) and the remaining descriptors are clustered with
        k-means. For every image the keypoints are matched one-to-one to
        clusters. A matched keypoint is kept only if its distance is below
        the threshold of its cluster.

        Side effects: writes log lines through ``LogText`` and saves the
        result with ``self.save_keypoints`` as "SuperPointKeypoints.pickle".

        Args:
            dataloader: iterable of batches indexed with ``'image_gray'``,
                ``'filename'`` and, when boxes are used, ``'bounding_box'``.
                Its ``dataset`` must provide ``keypointsToFANResolution``.

        Returns:
            dict mapping filename to a numpy array of the kept keypoints,
            whose column 2 holds the assigned cluster index.
        """
        LogText("Extraction of initial Superpoint pseudo groundtruth",
                self.experiment_name, self.log_path)

        imagesize = 256
        heatmapsize = 64
        numberoffeatures = 256
        buffersize = 500000

        # Paths of the temporary keypoint/descriptor arrays, resolved once.
        checkpoint_dir = GetCheckPointsPath(self.experiment_name,
                                            self.log_path)
        keypoints_path = str(checkpoint_dir / 'keypoints')
        descriptors_path = str(checkpoint_dir / 'descriptors')

        # allocation of 2 buffers for temporary storage of keypoints and
        # descriptors
        Keypoint_buffer = torch.zeros(buffersize, 3)
        Descriptor__buffer = torch.zeros(buffersize, numberoffeatures)

        # arrays to which the buffer content is appended periodically
        CreateFileArray(keypoints_path, 3)
        CreateFileArray(descriptors_path, numberoffeatures)

        # first/last index: row range of the current image in the file
        # arrays; buffer_* indices: the same range inside the in-memory buffer
        first_index = 0
        last_index = 0
        buffer_first_index = 0
        buffer_last_index = 0
        keypoint_indexes = {}

        LogText("Inference of Keypoints begins", self.experiment_name,
                self.log_path)
        for sample in dataloader:
            images = Cuda(sample['image_gray'])
            names = sample['filename']
            bsize = images.size(0)

            if self.UseScales:
                # fold the extra channels into the batch dimension
                images = images.view(-1, 1, images.shape[2], images.shape[3])

            with torch.no_grad():
                detectorOutput, descriptorOutput = self.GetSuperpointOutput(
                    images)

            if self.UseScales:
                # restore the per-image grouping; only the first descriptor
                # map of every image is kept
                detectorOutput = detectorOutput.view(bsize, -1,
                                                     detectorOutput.shape[2],
                                                     detectorOutput.shape[3])
                images = images.view(bsize, -1, images.shape[2],
                                     images.shape[3])
                descriptorOutput = descriptorOutput.view(
                    bsize, -1, descriptorOutput.size(1),
                    descriptorOutput.size(2), descriptorOutput.size(3))[:, 0]

            for i in range(bsize):

                keypoints = self.GetPoints(detectorOutput[i].unsqueeze(0),
                                           self.confidence_thres_superpoint,
                                           self.nms_thres_superpoint)

                if self.RemoveBackgroundClusters or self.use_box:
                    # label every keypoint: 1 inside the box, -1 outside
                    bounding_box = sample['bounding_box'][i]
                    pointsinbox = torch.ones(len(keypoints))
                    pointsinbox[(keypoints[:, 0] < int(bounding_box[0]))] = -1
                    pointsinbox[(keypoints[:, 1] < int(bounding_box[1]))] = -1
                    pointsinbox[(keypoints[:, 0] > int(bounding_box[2]))] = -1
                    pointsinbox[(keypoints[:, 1] > int(bounding_box[3]))] = -1

                    # With background removal the labels are stored and used
                    # later; otherwise outside points are dropped right away.
                    if not self.RemoveBackgroundClusters:
                        keypoints = keypoints[pointsinbox == 1]

                descriptors = GetDescriptors(descriptorOutput[i], keypoints,
                                             images.shape[3], images.shape[2])

                # scale image keypoints to FAN resolution
                keypoints = dataloader.dataset.keypointsToFANResolution(
                    dataloader.dataset, names[i], keypoints)

                keypoints = ((heatmapsize / imagesize) * keypoints).round()

                numofpoints = len(keypoints)
                last_index += numofpoints
                buffer_last_index += numofpoints

                Keypoint_buffer[
                    buffer_first_index:buffer_last_index, :2] = keypoints
                Descriptor__buffer[
                    buffer_first_index:buffer_last_index] = descriptors

                if self.RemoveBackgroundClusters:
                    Keypoint_buffer[buffer_first_index:buffer_last_index,
                                    2] = pointsinbox

                keypoint_indexes[names[i]] = [first_index, last_index]
                first_index += numofpoints
                buffer_first_index += numofpoints

            # flush the buffer to file once it is more than 80% full
            if buffer_last_index > int(buffersize * 0.8):
                AppendFileArray(np.array(Keypoint_buffer[:buffer_last_index]),
                                keypoints_path)
                AppendFileArray(
                    np.array(Descriptor__buffer[:buffer_last_index]),
                    descriptors_path)

                Keypoint_buffer = torch.zeros(buffersize, 3)
                Descriptor__buffer = torch.zeros(buffersize, numberoffeatures)
                buffer_first_index = 0
                buffer_last_index = 0

        LogText("Inference of Keypoints completed", self.experiment_name,
                self.log_path)
        # store any keypoints left on the buffers
        AppendFileArray(np.array(Keypoint_buffer[:buffer_last_index]),
                        keypoints_path)
        AppendFileArray(np.array(Descriptor__buffer[:buffer_last_index]),
                        descriptors_path)

        # load handlers to the Keypoints and Descriptor files
        Descriptors, descriptors_handler = OpenreadFileArray(descriptors_path)
        Keypoints, keypoints_handler = OpenreadFileArray(keypoints_path)
        # NOTE(review): presumably this reads the whole keypoint array into
        # memory -- confirm against OpenreadFileArray.
        Keypoints = Keypoints[:, :]
        LogText(
            f"Keypoints Detected per image {len(Keypoints)/len(keypoint_indexes)}",
            self.experiment_name, self.log_path)

        # perform outlier detection
        inliersindexes = np.ones(len(Keypoints), dtype=bool)
        if self.remove_superpoint_outliers_percentage > 0:
            inliersindexes = self.Indexes_of_inliers(Keypoints, Descriptors,
                                                     buffersize)

        # extend outliers with background points for constant background
        # datasets
        if self.RemoveBackgroundClusters:
            foregroundpointindex = self.Indexes_of_BackgroundPoints(
                Keypoints, Descriptors, keypoint_indexes)
            inliersindexes = np.logical_and(inliersindexes,
                                            foregroundpointindex)

        LogText(
            f"Keypoints Detected per image(filtering) {np.count_nonzero(inliersindexes) / len(keypoint_indexes)}",
            self.experiment_name, self.log_path)
        # we use a subset of all the descriptors for clustering, based on the
        # recommendation of the Faiss repository
        numberOfPointsForClustering = 500000

        LogText("Clustering of keypoints", self.experiment_name,
                self.log_path)
        # clustering of superpoint features
        KmeansClustering = clustering.Kmeans(self.number_of_clusters,
                                             centroids=None)
        descriptors = clustering.preprocess_features(
            Descriptors[:numberOfPointsForClustering][
                inliersindexes[:numberOfPointsForClustering]])
        KmeansClustering.cluster(descriptors, verbose=False)

        thresholds = self.GetThresholdsPerCluster(inliersindexes, Descriptors,
                                                  KmeansClustering)

        Image_Keypoints = {}
        averagepointsperimage = 0
        for image, (start, end) in keypoint_indexes.items():
            # a view: the updates below also land in inliersindexes
            inliersinimage = inliersindexes[start:end]
            keypoints = Keypoints[start:end, :]

            # discard keypoints that fall outside the heatmap
            inliersinimage[np.sum(keypoints[:, :2] < 0, 1) > 0] = False
            inliersinimage[np.sum(keypoints[:, :2] > heatmapsize, 1) >
                           0] = False

            keypoints = keypoints[inliersinimage]

            image_descriptors = clustering.preprocess_features(
                Descriptors[start:end])
            image_descriptors = image_descriptors[inliersinimage]

            # distance of each keypoint to every centroid
            distanceMatrix, clustering_assignments = KmeansClustering.index.search(
                image_descriptors, self.number_of_clusters)

            # reorder every row so that column j holds the distance to
            # cluster j
            distanceMatrix = np.take_along_axis(
                distanceMatrix, np.argsort(clustering_assignments), axis=-1)

            # assign keypoints to centroids using the Hungarian algorithm.
            # This ensures that each image has at most one instance of each
            # cluster
            keypointIndex, clusterAssignment = linear_sum_assignment(
                distanceMatrix)

            tempKeypoints = keypoints[keypointIndex]

            clusterAssignmentDistance = distanceMatrix[keypointIndex,
                                                       clusterAssignment]

            # keep only points that lie below their cluster-specific threshold
            clusterstokeep = (clusterAssignmentDistance <
                              thresholds[clusterAssignment])

            tempKeypoints[:, 2] = clusterAssignment

            Image_Keypoints[image] = tempKeypoints[clusterstokeep]
            averagepointsperimage += np.count_nonzero(clusterstokeep)

        LogText(
            f"Keypoints Detected per image(clusteringAssignment) {averagepointsperimage / len(Image_Keypoints)}",
            self.experiment_name, self.log_path)
        # Close every handler together with the path it was opened from.
        ClosereadFileArray(descriptors_handler, descriptors_path)
        ClosereadFileArray(keypoints_handler, keypoints_path)
        self.save_keypoints(Image_Keypoints, "SuperPointKeypoints.pickle")
        LogText("Extraction of Initial pseudoGroundtruth completed",
                self.experiment_name, self.log_path)
        return Image_Keypoints
예제 #9
0
def main():
    """Extract features for the train/val/test splits and save them as .npz.

    For every split the images under ``<args.data>/<split>`` are passed
    through the headless model. The resulting features (PCA-preprocessed
    unless ``args.raw`` is set), the labels and the original images are
    written to ``<args.exp>/<dataset>_256_<split>[_raw].npz`` under the keys
    ``Z``, ``Y`` and ``X``.

    Raises:
        ValueError: if ``args.dataset`` is neither 'miniimagenet' nor
            'celeba'.
    """
    global args
    args = parser.parse_args()

    # create repo
    repo = os.path.join(args.exp, 'conv' + str(args.conv))
    os.makedirs(repo, exist_ok=True)

    # build model
    model = load_model(args.model)
    model.cuda()
    for params in model.parameters():
        params.requires_grad = False
    model.eval()

    # Remove the head exactly once. Doing this inside the split loop would
    # strip one more classifier layer per split, so every split would get
    # features from a different network.
    model.top_layer = None
    model.classifier = nn.Sequential(*list(model.classifier.children())[:-1])

    # load data
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    if args.dataset in ('miniimagenet', 'celeba'):
        tra = [transforms.CenterCrop(64), transforms.ToTensor(), normalize]
    else:
        raise ValueError('Unsupported dataset: %r' % (args.dataset,))

    pca_dim = 256
    for split in ['train', 'val', 'test']:
        # dataset
        dataset = datasets.ImageFolder(os.path.join(args.data, split),
                                       transform=transforms.Compose(tra))
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=256,
                                                 num_workers=args.workers)

        # compute features
        features, labels = compute_features(dataloader, model, len(dataset))
        if not args.raw:
            features = preprocess_features(features, pca=pca_dim)

        origs = dataset.imgs
        images = np.stack(
            [np.array(Image.open(filename)) for filename, _ in origs],
            axis=0)
        if args.dataset == 'celeba':
            # The label is the numeric stem of the file name. Derive it with
            # os.path so it does not depend on '/' separators or on a
            # literal '.jpg' substring.
            labels = np.array([
                int(os.path.splitext(os.path.basename(filename))[0])
                for filename, _ in origs
            ])

        suffix = '_raw' if args.raw else ''
        out_name = '%s_%d_%s%s.npz' % (args.dataset, pca_dim, split, suffix)
        np.savez(os.path.join(args.exp, out_name),
                 X=images,
                 Y=labels,
                 Z=features)