예제 #1
0
    def all_data(self, partition_proportions=None, seed=None):
        """Collect every loaded image into one dataset, optionally split.

        Classes are visited in sorted key order and each image is labelled
        with the position of its class in that order. When
        ``self.info["one_hot_enc"]`` is truthy the labels are converted with
        ``dl.to_one_hot_enc``.

        Args:
            partition_proportions: Forwarded to ``redivide_data`` (with
                ``shuffle=True``) when truthy; otherwise the single full
                dataset is returned unsplit.
            seed: When not ``None``, seeds NumPy's global random state
                before the optional redivision. ``0`` is a valid seed.

        Returns:
            The result of ``dl.Datasets.from_list`` applied to the produced
            dataset list.
        """
        if not self._loaded_images:
            # Imported once here rather than on every polling iteration.
            import time

            self.load_all_images()
            # NOTE(review): presumably polls until each class holds 600
            # images -- confirm against check_loaded_images.
            while not self.check_loaded_images(600):
                time.sleep(5)

        data, targets = [], []
        for k, c in enumerate(sorted(self._loaded_images)):
            class_images = list(self._loaded_images[c].values())
            data += class_images
            # Label count follows the real number of images so data and
            # targets stay aligned even if a class is not exactly 600 long.
            targets += [k] * len(class_images)

        if self.info["one_hot_enc"]:
            targets = dl.to_one_hot_enc(targets,
                                        dimension=len(self._loaded_images))
        _dts = [
            dl.Dataset(data=np.stack(data),
                       target=np.array(targets),
                       name="MiniImagenet_full")
        ]
        # `is not None` so that seed=0 is honoured instead of being skipped.
        if seed is not None:
            np.random.seed(seed)
        if partition_proportions:
            _dts = redivide_data(_dts,
                                 partition_proportions=partition_proportions,
                                 shuffle=True)
        return dl.Datasets.from_list(_dts)
예제 #2
0
    def generate_datasets(self,
                          rand=None,
                          num_classes=None,
                          num_examples=None):
        """Draw a group of datasets over a random subset of classes.

        ``num_classes`` distinct classes are chosen without replacement and
        numbered in the order they were drawn. One dataset is then built for
        each entry of ``num_examples``; every sample removes its image from
        a per-dataset copy of the class's image list, so an image is not
        repeated inside one dataset.

        Args:
            rand: Passed through ``dl.get_rand_state`` to obtain the random
                state used for all sampling.
            num_classes: Number of classes to draw; falls back to
                ``self.kwargs["num_classes"]`` when falsy.
            num_examples: One size or several sizes (one dataset each);
                falls back to ``self.kwargs["num_examples"]`` when falsy.

        Returns:
            The result of ``dl.Datasets.from_list`` applied to the built
            datasets.
        """
        rand = dl.get_rand_state(rand)

        # Anything left unset comes from the stored keyword arguments.
        num_examples = num_examples or self.kwargs["num_examples"]
        num_classes = num_classes or self.kwargs["num_classes"]

        # Prefer images already in memory, else the class listing in info.
        source = self._loaded_images or self.info["classes"]

        chosen = rand.choice(list(source.keys()),
                             size=(num_classes, ),
                             replace=False)
        label_of = {cls_name: idx for idx, cls_name in enumerate(chosen)}

        built = []
        for n_samples in as_tuple_or_list(num_examples):
            # NOTE(review): presumably returns n_samples class names spread
            # evenly over `chosen` -- confirm against balanced_choice_wr.
            picks = balanced_choice_wr(chosen, n_samples, rand)

            # Fresh, mutable copy of each picked class's image names.
            remaining = {cls_name: list(source[cls_name])
                         for cls_name in picks}

            samples, labels, records = [], [], []
            for cls_name in picks:
                pool = remaining[cls_name]
                rand.shuffle(pool)
                image_id = pool[0]
                del pool[0]  # consume it so it cannot be drawn again
                records.append({"name": image_id, "label": cls_name})
                samples.append(source[cls_name][image_id])
                labels.append(label_of[cls_name])

            if self.info["one_hot_enc"]:
                labels = dl.to_one_hot_enc(labels, dimension=num_classes)

            dataset = dl.Dataset(
                data=np.array(np.stack(samples)),
                target=labels,
                sample_info=records,
                info={"all_classes": chosen},
            )
            built.append(dataset)

        return dl.Datasets.from_list(built)