Example #1
File: imagenet.py  Project: zialiu/BF3S
    def __init__(self, split="train", size256=False, transform=None):

        dataset_name = "ImageNet256" if size256 else "ImageNet"
        assert (split
                in ("train", "val")) or (split.find("train_subset") != -1)
        self.split = split
        self.name = f"{dataset_name}_Split_" + self.split

        data_dir = _IMAGENET256_DATASET_DIR if size256 else _IMAGENET_DATASET_DIR
        print(f"==> Loading {dataset_name} dataset - split {self.split}")
        print(f"==> {dataset_name} directory: {data_dir}")

        self.transform = transform
        print(f"==> transform: {self.transform}")
        train_dir = os.path.join(data_dir, "train")
        val_dir = os.path.join(data_dir, "val")
        split_dir = train_dir if (self.split.find("train") != -1) else val_dir
        self.data = datasets.ImageFolder(split_dir, self.transform)
        self.labels = [item[1] for item in self.data.imgs]

        if self.split.find("train_subset") != -1:
            # Keep only the first `subsetK` images of each training class.
            subsetK = int(self.split[len("train_subset"):])
            assert subsetK > 0
            self.split = "train"

            label2ind = utils.build_label_index(self.data.targets)
            all_indices = []
            for label, img_indices in label2ind.items():
                assert len(img_indices) >= subsetK
                all_indices += img_indices[:subsetK]

            self.data.imgs = [self.data.imgs[idx] for idx in all_indices]
            self.data.samples = [self.data.samples[idx] for idx in all_indices]
            self.data.targets = [self.data.targets[idx] for idx in all_indices]
            self.labels = [self.labels[idx] for idx in all_indices]
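
Note: `utils.build_label_index`, used above and throughout the examples below to group sample indices by class, is not shown on this page. The following is a minimal sketch of such a helper, inferred only from how it is called here; the project's actual implementation may differ.

from collections import defaultdict


def build_label_index(labels):
    """Map each label to the list of sample indices carrying that label."""
    label2ind = defaultdict(list)
    for idx, label in enumerate(labels):
        label2ind[label].append(idx)
    return dict(label2ind)


# Example: six samples from three classes.
print(build_label_index([0, 1, 0, 2, 1, 0]))  # {0: [0, 2, 5], 1: [1, 4], 2: [3]}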
Example #2
File: imagenet.py  Project: zialiu/BF3S
    def __init__(self, data_dir, image_split="train", phase="train"):
        # data_dir: path to the directory with the saved ImageNet features.
        # image_split: the image split of the ImageNet that will be loaded.
        # phase: whether the dataset will be used for training, validating, or
        # testing the few-shot model.
        assert image_split in ("train", "val")
        assert phase in ("train", "val", "test")

        self.phase = phase
        self.image_split = image_split
        self.name = (f"ImageNetLowShotFeatures_ImageSplit_{self.image_split}"
                     f"_Phase_{self.phase}")

        dataset_file = os.path.join(data_dir,
                                    "ImageNet_" + self.image_split + ".h5")
        self.data_file = h5py.File(dataset_file, "r")
        self.count = self.data_file["count"][0]
        self.features = self.data_file["all_features"][...]
        self.labels = self.data_file["all_labels"][:self.count].tolist()

        # ***********************************************************************
        data_tmp = datasets.ImageFolder(
            os.path.join(_IMAGENET_DATASET_DIR, "train"), None)
        (
            base_classes,
            base_classes_val,
            base_classes_test,
            novel_classes_val,
            novel_classes_test,
        ) = load_ImageNet_fewshot_split(data_tmp.classes)
        # ***********************************************************************

        self.label2ind = utils.build_label_index(self.labels)
        self.labelIds = sorted(self.label2ind.keys())
        self.num_cats = len(self.labelIds)
        assert self.num_cats == 1000

        self.labelIds_base = base_classes
        self.num_cats_base = len(self.labelIds_base)

        if self.phase == "val" or self.phase == "test":
            self.labelIds_novel = (novel_classes_val if (self.phase == "val")
                                   else novel_classes_test)
            self.num_cats_novel = len(self.labelIds_novel)

            intersection = set(self.labelIds_base) & set(self.labelIds_novel)
            assert len(intersection) == 0
            self.base_classes_eval_split = (
                base_classes_val if self.phase == "val" else base_classes_test)
Example #3
File: imagenet.py  Project: zialiu/BF3S
    def __init__(self,
                 phase="train",
                 split="train",
                 do_not_use_random_transf=False):

        assert phase in ("train", "test", "val")
        assert split in ("train", "val")

        use_aug = (phase == "train") and (not do_not_use_random_transf)

        super().__init__(split=split,
                         use_geometric_aug=use_aug,
                         use_color_aug=use_aug)

        self.phase = phase
        self.split = split
        self.name = "ImageNetLowShot_Phase_" + phase + "_Split_" + split
        print(f"==> Loading ImageNet few-shot benchmark - phase {phase}")

        # ***********************************************************************
        (
            base_classes,
            _,
            _,
            novel_classes_val,
            novel_classes_test,
        ) = load_ImageNet_fewshot_split(self.data.classes)
        # ***********************************************************************

        self.label2ind = utils.build_label_index(self.labels)
        self.labelIds = sorted(self.label2ind.keys())
        self.num_cats = len(self.labelIds)
        assert self.num_cats == 1000

        self.labelIds_base = base_classes
        self.num_cats_base = len(self.labelIds_base)
        if self.phase == "val" or self.phase == "test":
            self.labelIds_novel = (novel_classes_val if (self.phase == "val")
                                   else novel_classes_test)
            self.num_cats_novel = len(self.labelIds_novel)

            intersection = set(self.labelIds_base) & set(self.labelIds_novel)
            assert len(intersection) == 0
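
Examples #2 and #3 both rely on `load_ImageNet_fewshot_split`, which is not shown here. As used above, it returns a 5-tuple of class-index lists in which the base classes are disjoint from both novel-class sets. The toy stand-in below only illustrates that contract; the split sizes are arbitrary, and the real function reads the project's own split files.

def toy_fewshot_split(class_names):
    """Toy stand-in: five lists of class indices, with the base classes
    disjoint from both the novel-val and novel-test classes."""
    ids = list(range(len(class_names)))
    cut = len(ids) // 2                     # arbitrary split point
    base_classes = ids[:cut]
    base_classes_val = base_classes[::2]    # base classes evaluated in the "val" phase
    base_classes_test = base_classes[1::2]  # base classes evaluated in the "test" phase
    novel_classes_val = ids[cut:cut + len(ids) // 4]
    novel_classes_test = ids[cut + len(ids) // 4:]
    return (base_classes, base_classes_val, base_classes_test,
            novel_classes_val, novel_classes_test)


base, _, _, novel_val, novel_test = toy_fewshot_split([f"class_{i}" for i in range(20)])
assert set(base).isdisjoint(novel_val) and set(base).isdisjoint(novel_test)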
Example #4
    def __init__(self, phase="train", do_not_use_random_transf=False):
        assert phase in ("train", "val", "test")
        self.phase = phase
        self.name = "CIFAR100FewShot_" + phase

        normalize = transforms.Normalize(mean=_CIFAR_MEAN_PIXEL,
                                         std=_CIFAR_STD_PIXEL)

        if (self.phase == "test"
                or self.phase == "val") or (do_not_use_random_transf == True):
            self.transform = transforms.Compose(
                [lambda x: np.asarray(x),
                 transforms.ToTensor(), normalize])
        else:
            self.transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                lambda x: np.asarray(x),
                transforms.ToTensor(),
                normalize,
            ])

        cifar100_metadata_path = os.path.join(_CIFAR_DATASET_DIR,
                                              "cifar-100-python", "meta")
        with open(cifar100_metadata_path, "rb") as f:
            all_category_names = pickle.load(f)["fine_label_names"]

        def read_categories(filename):
            with open(filename) as f:
                categories = f.readlines()
            categories = [x.strip() for x in categories]
            return categories

        def get_label_ids(category_names):
            label_ids = [
                all_category_names.index(cname) for cname in category_names
            ]
            return label_ids

        train_category_names = read_categories(
            os.path.join(_CIFAR_CATEGORY_SPLITS_DIR, "train.txt"))
        val_category_names = read_categories(
            os.path.join(_CIFAR_CATEGORY_SPLITS_DIR, "val.txt"))
        test_category_names = read_categories(
            os.path.join(_CIFAR_CATEGORY_SPLITS_DIR, "test.txt"))

        train_category_ids = get_label_ids(train_category_names)
        val_category_ids = get_label_ids(val_category_names)
        test_category_ids = get_label_ids(test_category_names)

        print(f"Loading CIFAR-100 FewShot dataset - phase {phase}")

        if self.phase == "train":
            self.data_train = datasets.__dict__["CIFAR100"](
                _CIFAR_DATASET_DIR,
                train=True,
                download=True,
                transform=self.transform)
            self.labels = self.data_train.targets
            self.images = self.data_train.data

            self.label2ind = utils.build_label_index(self.labels)
            self.labelIds = sorted(self.label2ind.keys())
            self.num_cats = len(self.labelIds)
            self.labelIds_base = train_category_ids
            self.num_cats_base = len(self.labelIds_base)

        elif self.phase == "val" or self.phase == "test":
            self.data_train = datasets.__dict__["CIFAR100"](
                _CIFAR_DATASET_DIR,
                train=True,
                download=True,
                transform=self.transform)
            labels_train = self.data_train.targets
            images_train = self.data_train.data
            label2ind_train = utils.build_label_index(labels_train)
            self.labelIds_novel = (val_category_ids if (self.phase == "val")
                                   else test_category_ids)

            labels_novel = []
            images_novel = []
            for label_id in self.labelIds_novel:
                indices = label2ind_train[label_id]
                images_novel.append(images_train[indices])
                labels_novel += [labels_train[index] for index in indices]
            images_novel = np.concatenate(images_novel, axis=0)
            assert images_novel.shape[0] == len(labels_novel)

            self.data_test = datasets.__dict__["CIFAR100"](
                _CIFAR_DATASET_DIR,
                train=False,
                download=True,
                transform=self.transform)
            labels_test = self.data_test.targets
            images_test = self.data_test.data
            label2ind_test = utils.build_label_index(labels_test)
            self.labelIds_base = train_category_ids

            labels_base = []
            images_base = []
            for label_id in self.labelIds_base:
                indices = label2ind_test[label_id]
                images_base.append(images_test[indices])
                labels_base += [labels_test[index] for index in indices]
            images_base = np.concatenate(images_base, axis=0)
            assert images_base.shape[0] == len(labels_base)

            self.images = np.concatenate([images_base, images_novel], axis=0)
            self.labels = labels_base + labels_novel
            assert self.images.shape[0] == len(self.labels)

            self.num_cats_base = len(self.labelIds_base)
            self.num_cats_novel = len(self.labelIds_novel)
            intersection = set(self.labelIds_base) & set(self.labelIds_novel)
            assert len(intersection) == 0

            self.label2ind_base = utils.build_label_index(labels_base)
            assert len(self.label2ind_base) == self.num_cats_base

            self.label2ind_novel = utils.build_label_index(labels_novel)
            assert len(self.label2ind_novel) == self.num_cats_novel

            self.label2ind = utils.build_label_index(self.labels)
            assert len(
                self.label2ind) == self.num_cats_novel + self.num_cats_base
            self.labelIds = sorted(self.label2ind.keys())
            self.num_cats = len(self.labelIds)
        else:
            raise ValueError(f"Not valid phase {self.phase}")
Example #5
    def __init__(
        self,
        transform_test,
        transform_train,
        phase="train",
        load_single_file_split=False,
        file_split=None,
        do_not_use_random_transf=False,
    ):

        data_dir = _MINIIMAGENET_DATASET_DIR
        print(f"==> Download MiniImageNet dataset at {data_dir}")
        file_train_categories_train_phase = os.path.join(
            data_dir, "miniImageNet_category_split_train_phase_train.pickle"
        )
        file_train_categories_val_phase = os.path.join(
            data_dir, "miniImageNet_category_split_train_phase_val.pickle"
        )
        file_train_categories_test_phase = os.path.join(
            data_dir, "miniImageNet_category_split_train_phase_test.pickle"
        )
        file_val_categories_val_phase = os.path.join(
            data_dir, "miniImageNet_category_split_val.pickle"
        )
        file_test_categories_test_phase = os.path.join(
            data_dir, "miniImageNet_category_split_test.pickle"
        )

        self.phase = phase
        if load_single_file_split:
            assert file_split in (
                "category_split_train_phase_train",
                "category_split_train_phase_val",
                "category_split_train_phase_test",
                "category_split_val",
                "category_split_test",
            )
            self.name = "MiniImageNet_" + file_split

            print(f"==> Loading mini ImageNet dataset - phase {file_split}")

            file_to_load = os.path.join(data_dir, f"miniImageNet_{file_split}.pickle")

            data = utils.load_pickle_data(file_to_load)
            self.data = data["data"]
            self.labels = data["labels"]
            self.label2ind = utils.build_label_index(self.labels)
            self.labelIds = sorted(self.label2ind.keys())
            self.num_cats = len(self.labelIds)
        else:
            assert phase in ("train", "val", "test", "trainval") or "train_subset" in phase
            self.name = "MiniImageNet_" + phase

            print(f"Loading mini ImageNet dataset - phase {phase}")
            if self.phase == "train":
                # Loads the training classes (and their data) as base classes
                data_train = utils.load_pickle_data(file_train_categories_train_phase)
                self.data = data_train["data"]
                self.labels = data_train["labels"]

                self.label2ind = utils.build_label_index(self.labels)
                self.labelIds = sorted(self.label2ind.keys())
                self.num_cats = len(self.labelIds)
                self.labelIds_base = self.labelIds
                self.num_cats_base = len(self.labelIds_base)

            elif self.phase == "trainval":
                # Loads the training + validation classes (and their data) as
                # base classes
                data_train = utils.load_pickle_data(file_train_categories_train_phase)
                data_val = utils.load_pickle_data(file_val_categories_val_phase)
                self.data = np.concatenate([data_train["data"], data_val["data"]], axis=0)
                self.labels = data_train["labels"] + data_val["labels"]

                self.label2ind = utils.build_label_index(self.labels)
                self.labelIds = sorted(self.label2ind.keys())
                self.num_cats = len(self.labelIds)
                self.labelIds_base = self.labelIds
                self.num_cats_base = len(self.labelIds_base)

            elif self.phase.find("train_subset") != -1:
                subsetK = int(self.phase[len("train_subset") :])
                assert subsetK > 0
                # Loads the training classes as base classes. For each class it
                # loads only the `subsetK` first images.

                data_train = utils.load_pickle_data(file_train_categories_train_phase)
                label2ind = utils.build_label_index(data_train["labels"])

                all_indices = []
                for label, img_indices in label2ind.items():
                    assert len(img_indices) >= subsetK
                    all_indices += img_indices[:subsetK]

                labels_semi = [data_train["labels"][idx] for idx in all_indices]
                data_semi = data_train["data"][all_indices]

                self.data = data_semi
                self.labels = labels_semi

                self.label2ind = utils.build_label_index(self.labels)
                self.labelIds = sorted(self.label2ind.keys())
                self.num_cats = len(self.labelIds)
                self.labelIds_base = self.labelIds
                self.num_cats_base = len(self.labelIds_base)

                self.phase = "train"

            elif self.phase == "val" or self.phase == "test":
                # Uses the validation / test classes (and their data) as the
                # novel class data, and the validation / test image split of
                # the training classes as the base class data.

                if self.phase == "test":
                    # load data that will be used for evaluating the recognition
                    # accuracy of the base classes.
                    data_base = utils.load_pickle_data(file_train_categories_test_phase)
                    # load data that will be used for evaluating the few-shot
                    # recognition accuracy on the novel classes.
                    data_novel = utils.load_pickle_data(file_test_categories_test_phase)
                else:  # phase=='val'
                    # load data that will be used for evaluating the recognition
                    # accuracy of the base classes.
                    data_base = utils.load_pickle_data(file_train_categories_val_phase)
                    # load data that will be used for evaluating the few-shot
                    # recognition accuracy on the novel classes.
                    data_novel = utils.load_pickle_data(file_val_categories_val_phase)

                self.data = np.concatenate([data_base["data"], data_novel["data"]], axis=0)
                self.labels = data_base["labels"] + data_novel["labels"]

                self.label2ind = utils.build_label_index(self.labels)
                self.labelIds = sorted(self.label2ind.keys())
                self.num_cats = len(self.labelIds)

                self.labelIds_base = utils.build_label_index(data_base["labels"]).keys()
                self.labelIds_novel = utils.build_label_index(data_novel["labels"]).keys()
                self.num_cats_base = len(self.labelIds_base)
                self.num_cats_novel = len(self.labelIds_novel)
                intersection = set(self.labelIds_base) & set(self.labelIds_novel)
                assert len(intersection) == 0
            else:
                raise ValueError(f"Not valid phase {self.phase}")

        self.transform_test = transform_test
        self.transform_train = transform_train
        if (self.phase == "test" or self.phase == "val") or (do_not_use_random_transf == True):
            self.transform = self.transform_test
        else:
            self.transform = self.transform_train
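
Example #5 reads each miniImageNet split with `utils.load_pickle_data`, which is not shown on this page. Judging from how its return value is indexed ("data" and "labels" keys), a minimal sketch could look like the following; the `encoding="latin1"` argument is an assumption commonly needed for pickles written under Python 2, and the project's own loader may differ.

import pickle


def load_pickle_data(file_path):
    """Load one split pickle, expected to yield a dict with "data"
    (image array) and "labels" (list of class ids) keys."""
    with open(file_path, "rb") as f:
        # encoding="latin1" is an assumption for Python 2-era pickles.
        data = pickle.load(f, encoding="latin1")
    return data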