Example #1
def test_Omniglot():
    import cv2
    import numpy as np  # needed for np.array below

    root = './data/'
    dataset = Omniglot(root=root, download=True)
    #dataset = Omniglot(root=root,background=False,download=True)
    print(len(dataset))
    idx = 0

    idx_alphabet = 0
    idx_character = 0
    idx_sample = 0

    while True:

        sample = dataset[idx]
        image_path = dataset.sampleSample4Character4Alphabet(
            idx_alphabet, idx_character, idx_sample)

        img = np.array(sample[0])
        cv2.imshow('test', img)

        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break
        elif key == ord('n'):
            idx += 1
            #print(sample[1])
            print(image_path)
            idx_sample += 1
            print("next image")
Example #2
def test_OmniglotBatchedSeq():
    import cv2

    root = './data/'
    h = 240
    w = 240
    dataset = Omniglot(root=root, h=h, w=w)
    #dataset = Omniglot(root=root,background=False,download=True)
    print(len(dataset))

    batch_size = 2
    max_nbr_char = 5
    idx_alphabet = list()
    for i in range(batch_size):
        idx_alphabet.append(i)

    idx_sample = 0

    keys = ('alphabet', 'character', 'sample', 'target', 'nbrCharacter')
    seqs = list()
    for s in range(batch_size):
        seq, nbrCharacter4Task, nbrSample4Task = dataset.generateIterFewShotInputSequence(
            alphabet_idx=idx_alphabet[s], max_nbr_char=max_nbr_char)
        dseqs = {
            'alphabet': [],
            'character': [],
            'sample': [],
            'target': [],
            'nbrCharacter': []
        }
        for el in seq:
            for k in el.keys():
                dseqs[k].append(el[k])
        seqs.append(dseqs)
    seqs = [{k: [seqs[b][k] for b in range(batch_size)]} for k in keys]
    dseqs = dict()
    for d in seqs:
        dseqs.update(d)
    # key x task_idx x seq_idx
    for k in dseqs.keys():
        dseqs[k] = list(zip(*dseqs[k]))  # materialize so the transposed sequences can be reused
    print(seqs)
    batch = dataset.getBatchedSample(seqs)
Example #3
def test_OmniglotClass():
    import cv2
    import numpy as np  # needed for np.array below

    root = './data/'
    dataset = Omniglot(root=root, download=True)
    #dataset = Omniglot(root=root,background=False,download=True)
    print(len(dataset))
    idx = 0

    idx_alphabet = 0
    idx_character = 0
    idx_sample = 0

    sample = dataset.getSample(idx_alphabet, idx_character, idx_sample)
    image_path = dataset.sampleSample4Character4Alphabet(
        idx_alphabet, idx_character, idx_sample)
    changed = False

    while True:

        #sample = dataset[idx]

        img = np.array(sample[0])
        cv2.imshow('test', img)

        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break
        elif key == ord('n'):
            idx += 1
            print(image_path)
            idx_sample += 1
            changed = True
        elif key == ord('a'):
            idx += 1
            print(image_path)
            idx_alphabet += 1
            idx_character = 0
            idx_sample = 0
            changed = True
        elif key == ord('c'):
            idx += 1
            print(image_path)
            idx_character += 1
            idx_sample = 0
            changed = True

        if changed:
            changed = False
            sample = dataset.getSample(idx_alphabet, idx_character, idx_sample)
            image_path = dataset.sampleSample4Character4Alphabet(
                idx_alphabet, idx_character, idx_sample)
Example #4
    def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz):
        """
        Different from mnistNShot, the
        :param root:
        :param batchsz: task num
        :param n_way:
        :param k_shot:
        :param k_qry:
        :param imgsz:
        """

        self.resize = imgsz
        if not os.path.isfile(os.path.join(root, 'omniglot.npy')):
            # if root/omniglot.npy does not exist, download and process the dataset
            self.x = Omniglot(root, download=True,
                              transform=transforms.Compose([lambda x: Image.open(x).convert('L'),
                                                            lambda x: x.resize((imgsz, imgsz)),
                                                            lambda x: np.reshape(x, (imgsz, imgsz, 1)),
                                                            lambda x: np.transpose(x, [2, 0, 1]),
                                                            lambda x: x/255.])
                              )

            temp = dict()  # {label: [img1, ..., img20], ...}, 1623 labels in total
            for (img, label) in self.x:
                if label in temp.keys():
                    temp[label].append(img)
                else:
                    temp[label] = [img]

            self.x = []
            for label, imgs in temp.items():  # label info discarded; each label holds 20 imgs
                self.x.append(np.array(imgs))

            # every class contributes exactly 20 images, so this stacks into a regular array
            self.x = np.array(self.x).astype(float)  # [[20 imgs], ..., 1623 classes in total]
            # each character contains 20 imgs
            print('data shape:', self.x.shape)  # [1623, 20, 84, 84, 1]
            temp = []  # Free memory
            # save all dataset into npy file.
            np.save(os.path.join(root, 'omniglot.npy'), self.x)
            print('write into omniglot.npy.')
        else:
            # omniglot.npy exists: just load it
            self.x = np.load(os.path.join(root, 'omniglot.npy'))
            print('load from omniglot.npy.')

        # [1623, 20, 84, 84, 1]
        # NOTE: cannot shuffle here; training and test sets must stay distinct!
        self.x_train, self.x_test = self.x[:1200], self.x[1200:]

        # self.normalization()

        self.batchsz = batchsz
        self.n_cls = self.x.shape[0]  # 1623
        self.n_way = n_way  # n way
        self.k_shot = k_shot  # k shot
        self.k_query = k_query  # k query
        assert (k_shot + k_query) <= 20

        # save pointer of current read batch in total cache
        self.indexes = {"train": 0, "test": 0}
        self.datasets = {"train": self.x_train, "test": self.x_test}  # original data cached
        print("DB: train", self.x_train.shape, "test", self.x_test.shape)

        self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]),  # current epoch data cached
                               "test": self.load_data_cache(self.datasets["test"])}
Example #5
class OmniglotNShot:

    def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz):
        """
        Different from mnistNShot, the
        :param root:
        :param batchsz: task num
        :param n_way:
        :param k_shot:
        :param k_qry:
        :param imgsz:
        """

        self.resize = imgsz
        if not os.path.isfile(os.path.join(root, 'omniglot.npy')):
            # if root/omniglot.npy does not exist, download and process the dataset
            self.x = Omniglot(root, download=True,
                              transform=transforms.Compose([lambda x: Image.open(x).convert('L'),
                                                            lambda x: x.resize((imgsz, imgsz)),
                                                            lambda x: np.reshape(x, (imgsz, imgsz, 1)),
                                                            lambda x: np.transpose(x, [2, 0, 1]),
                                                            lambda x: x/255.])
                              )

            temp = dict()  # {label: [img1, ..., img20], ...}, 1623 labels in total
            for (img, label) in self.x:
                if label in temp.keys():
                    temp[label].append(img)
                else:
                    temp[label] = [img]

            self.x = []
            for label, imgs in temp.items():  # label info discarded; each label holds 20 imgs
                self.x.append(np.array(imgs))

            # every class contributes exactly 20 images, so this stacks into a regular array
            self.x = np.array(self.x).astype(float)  # [[20 imgs], ..., 1623 classes in total]
            # each character contains 20 imgs
            print('data shape:', self.x.shape)  # [1623, 20, 84, 84, 1]
            temp = []  # Free memory
            # save all dataset into npy file.
            np.save(os.path.join(root, 'omniglot.npy'), self.x)
            print('write into omniglot.npy.')
        else:
            # omniglot.npy exists: just load it
            self.x = np.load(os.path.join(root, 'omniglot.npy'))
            print('load from omniglot.npy.')

        # [1623, 20, 84, 84, 1]
        # NOTE: cannot shuffle here; training and test sets must stay distinct!
        self.x_train, self.x_test = self.x[:1200], self.x[1200:]

        # self.normalization()

        self.batchsz = batchsz
        self.n_cls = self.x.shape[0]  # 1623
        self.n_way = n_way  # n way
        self.k_shot = k_shot  # k shot
        self.k_query = k_query  # k query
        assert (k_shot + k_query) <= 20

        # save pointer of current read batch in total cache
        self.indexes = {"train": 0, "test": 0}
        self.datasets = {"train": self.x_train, "test": self.x_test}  # original data cached
        print("DB: train", self.x_train.shape, "test", self.x_test.shape)

        self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]),  # current epoch data cached
                               "test": self.load_data_cache(self.datasets["test"])}

    def normalization(self):
        """
        Normalizes our data, to have a mean of 0 and sdt of 1
        """
        self.mean = np.mean(self.x_train)
        self.std = np.std(self.x_train)
        self.max = np.max(self.x_train)
        self.min = np.min(self.x_train)
        # print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
        self.x_train = (self.x_train - self.mean) / self.std
        self.x_test = (self.x_test - self.mean) / self.std

        self.mean = np.mean(self.x_train)
        self.std = np.std(self.x_train)
        self.max = np.max(self.x_train)
        self.min = np.min(self.x_train)

    # print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)

    def load_data_cache(self, data_pack):
        """
        Collects several batches data for N-shot learning
        :param data_pack: [cls_num, 20, 84, 84, 1]
        :return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks
        """
        #  take 5 way 1 shot as example: 5 * 1
        setsz = self.k_shot * self.n_way
        querysz = self.k_query * self.n_way
        data_cache = []

        # print('preload next 50 caches of batchsz of batch.')
        for sample in range(10):  # num of episodes

            x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
            for i in range(self.batchsz):  # one batch means one set

                x_spt, y_spt, x_qry, y_qry = [], [], [], []
                selected_cls = np.random.choice(data_pack.shape[0], self.n_way, False)

                for j, cur_class in enumerate(selected_cls):

                    selected_img = np.random.choice(20, self.k_shot + self.k_query, False)

                    # meta-training and meta-test
                    x_spt.append(data_pack[cur_class][selected_img[:self.k_shot]])
                    x_qry.append(data_pack[cur_class][selected_img[self.k_shot:]])
                    y_spt.append([j for _ in range(self.k_shot)])
                    y_qry.append([j for _ in range(self.k_query)])

                # shuffle inside a batch
                perm = np.random.permutation(self.n_way * self.k_shot)
                x_spt = np.array(x_spt).reshape(self.n_way * self.k_shot, 1, self.resize, self.resize)[perm]
                y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm]
                perm = np.random.permutation(self.n_way * self.k_query)
                x_qry = np.array(x_qry).reshape(self.n_way * self.k_query, 1, self.resize, self.resize)[perm]
                y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm]

                # append [sptsz, 1, 84, 84] => [b, setsz, 1, 84, 84]
                x_spts.append(x_spt)
                y_spts.append(y_spt)
                x_qrys.append(x_qry)
                y_qrys.append(y_qry)


            # [b, setsz, 1, 84, 84]
            x_spts = np.array(x_spts).astype(np.float32).reshape(self.batchsz, setsz, 1, self.resize, self.resize)
            y_spts = np.array(y_spts).astype(int).reshape(self.batchsz, setsz)
            # [b, qrysz, 1, 84, 84]
            x_qrys = np.array(x_qrys).astype(np.float32).reshape(self.batchsz, querysz, 1, self.resize, self.resize)
            y_qrys = np.array(y_qrys).astype(int).reshape(self.batchsz, querysz)

            data_cache.append([x_spts, y_spts, x_qrys, y_qrys])

        return data_cache

    def next(self, mode='train'):
        """
        Gets next batch from the dataset with name.
        :param mode: The name of the splitting (one of "train", "val", "test")
        :return:
        """
        # refill the cache when the pointer runs past the cached batches
        if self.indexes[mode] >= len(self.datasets_cache[mode]):
            self.indexes[mode] = 0
            self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode])

        next_batch = self.datasets_cache[mode][self.indexes[mode]]
        self.indexes[mode] += 1

        return next_batch
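
For reference, here is a minimal usage sketch for the OmniglotNShot class above. The hyperparameter values are illustrative assumptions; the returned shapes follow from load_data_cache and next().

# Hypothetical usage of the class above; values are illustrative.
db = OmniglotNShot(root='./data', batchsz=32, n_way=5, k_shot=1, k_query=15, imgsz=84)
x_spt, y_spt, x_qry, y_qry = db.next('train')
# x_spt: (32, 5, 1, 84, 84), y_spt: (32, 5)    -- setsz = n_way * k_shot = 5
# x_qry: (32, 75, 1, 84, 84), y_qry: (32, 75)  -- querysz = n_way * k_query = 75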
Example #6
def get_data(dataset, dset_dir, image_size, model, phase, unsupervised, spc,
             batch_size, workers):
    dataset = dataset.lower()
    if dataset == 'omniglot_original':
        dset_dir = os.path.join(dset_dir, 'omniglot')
        transformations = get_transform(dataset, image_size, model)
        if phase == 'train':
            train_data = Omniglot(
                root=dset_dir,
                phase='background',
                unsupervised=unsupervised,
                spc=spc,
                pre_transform=transformations['pre_transform'],
                transform=transformations['transform'],
                post_transform=transformations['post_transform'])
            data_loader = DataLoader(train_data,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=int(workers))
        else:
            test_data = Omniglot(
                root=dset_dir,
                phase='evaluation',
                spc=20,
                pre_transform=transformations['test_transform'])
            data_loader = DataLoader(test_data,
                                     batch_size=batch_size,
                                     shuffle=False,
                                     num_workers=int(workers))

    elif dataset == 'omniglot_aug':
        dset_dir = os.path.join(dset_dir, 'omniglot')
        transformations = get_transform(dataset, image_size, model)
        if phase == 'train':
            train_data = Omniglot(
                root=dset_dir,
                phase=phase,
                unsupervised=unsupervised,
                spc=spc,
                pre_transform=transformations['pre_transform'],
                transform=transformations['transform'],
                post_transform=transformations['post_transform'])
            data_loader = DataLoader(train_data,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=int(workers))
        else:
            test_data = Omniglot(
                root=dset_dir,
                phase=phase,
                spc=20,
                pre_transform=transformations['test_transform'])
            data_loader = DataLoader(test_data,
                                     batch_size=batch_size,
                                     shuffle=False,
                                     num_workers=int(workers))

    elif dataset == 'miniimagenet':
        dset_dir = os.path.join(dset_dir, 'miniImagenet')
        transformations = get_transform(dataset, image_size, model)
        if phase == 'train':
            train_data = MiniImageNet(
                root=dset_dir,
                phase=phase,
                unsupervised=unsupervised,
                spc=spc,
                pre_transform=transformations['pre_transform'],
                transform=transformations['transform'],
                post_transform=transformations['post_transform'])
            data_loader = DataLoader(train_data,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=int(workers))
        else:
            test_data = MiniImageNet(
                root=dset_dir,
                phase=phase,
                spc=600,
                pre_transform=transformations['test_transform'])
            data_loader = DataLoader(test_data,
                                     batch_size=batch_size,
                                     shuffle=False,
                                     num_workers=int(workers))

    else:
        raise NotImplementedError

    return data_loader
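
A hedged example call for get_data; the argument values are illustrative assumptions, and the structure of each yielded batch depends on what the Omniglot class above returns.

# Hypothetical call; argument values are illustrative.
loader = get_data(dataset='omniglot_original', dset_dir='./datasets',
                  image_size=28, model='protonet', phase='train',
                  unsupervised=True, spc=5, batch_size=64, workers=4)
for batch in loader:
    pass  # one mini-batch per iteration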
Example #7
def test_OmniglotTask():
    import cv2
    import numpy as np  # needed for np.array below

    root = './data/'
    dataset = Omniglot(root=root, download=True)
    #dataset = Omniglot(root=root,background=False,download=True)
    print(len(dataset))

    idx_alphabet = 0
    idx_sample = 0

    #task, nbrCharacter4Task, nbrSample4Task = dataset.generateFewShotLearningTask( alphabet_idx=idx_alphabet)
    task, nbrCharacter4Task, nbrSample4Task = dataset.generateIterFewShotLearningTask(
        alphabet_idx=idx_alphabet)
    sample = dataset.getSample(task[idx_sample]['alphabet'],
                               task[idx_sample]['character'],
                               task[idx_sample]['sample'])
    image_path = dataset.sampleSample4Character4Alphabet(
        task[idx_sample]['alphabet'], task[idx_sample]['character'],
        task[idx_sample]['sample'])
    changed = False
    taskChanged = False
    idx = 0
    while True:

        #sample = dataset[idx]

        img = np.array(sample[0])
        cv2.imshow('test', img)

        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break
        elif key == ord('n'):
            idx += 1
            idx_sample = (idx_sample + 1) % len(task)
            changed = True
        elif key == ord('a'):
            idx_alphabet += 1
            idx_character = 0
            idx_sample = 0
            taskChanged = True

        if changed:
            changed = False
            sample = dataset.getSample(task[idx_sample]['alphabet'],
                                       task[idx_sample]['character'],
                                       task[idx_sample]['sample'])
            image_path = dataset.sampleSample4Character4Alphabet(
                task[idx_sample]['alphabet'], task[idx_sample]['character'],
                task[idx_sample]['sample'])
            print(image_path, '/{}'.format(nbrCharacter4Task))
            print('Sample {} / {}'.format(idx, nbrSample4Task))

        if taskChanged:
            taskChanged = False
            idx = 0
            task, nbrCharacter4Task, nbrSample4Task = dataset.generateFewShotLearningTask(
                alphabet_idx=idx_alphabet)
            sample = dataset.getSample(task[idx_sample]['alphabet'],
                                       task[idx_sample]['character'],
                                       task[idx_sample]['sample'])
            image_path = dataset.sampleSample4Character4Alphabet(
                task[idx_sample]['alphabet'], task[idx_sample]['character'],
                task[idx_sample]['sample'])
            print(image_path, '/{}'.format(nbrCharacter4Task))
Example #8
def test_OmniglotSeq():
    import cv2
    import numpy as np  # needed for np.uint8 below

    root = './data/'
    h = 240
    w = 240
    dataset = Omniglot(root=root, h=h, w=w)
    #dataset = Omniglot(root=root,background=False,download=True)
    print(len(dataset))

    idx_alphabet = 0
    idx_sample = 0

    seq, nbrCharacter4Task, nbrSample4Task = dataset.generateIterFewShotInputSequence(
        alphabet_idx=idx_alphabet)
    sample = dataset.getSample(seq[idx_sample]['alphabet'],
                               seq[idx_sample]['character'],
                               seq[idx_sample]['sample'])
    image_path = dataset.sampleSample4Character4Alphabet(
        seq[idx_sample]['alphabet'], seq[idx_sample]['character'],
        seq[idx_sample]['sample'])
    changed = False
    seqChanged = False
    idx = 0
    while True:

        #sample = dataset[idx]

        img = (sample['image'].numpy() * 255).astype(np.uint8).transpose((1, 2, 0))  # cv2.imshow expects uint8 HxWxC
        cv2.imshow('test', img)

        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break
        elif key == ord('n'):
            idx += 1
            idx_sample = (idx_sample + 1) % len(seq)
            changed = True
        elif key == ord('a'):
            idx_alphabet += 1
            idx_character = 0
            idx_sample = 0
            seqChanged = True

        if changed:
            changed = False
            sample = dataset.getSample(seq[idx_sample]['alphabet'],
                                       seq[idx_sample]['character'],
                                       seq[idx_sample]['sample'])
            image_path = dataset.sampleSample4Character4Alphabet(
                seq[idx_sample]['alphabet'], seq[idx_sample]['character'],
                seq[idx_sample]['sample'])
            print(image_path, '/{}'.format(nbrCharacter4Task))
            print('Sample {} / {}'.format(idx, nbrSample4Task))
            print('Target :{}'.format(seq[idx_sample]['target']))

        if seqChanged:
            seqChanged = False
            idx = 0
            seq, nbrCharacter4Task, nbrSample4Task = dataset.generateIterFewShotInputSequence(
                alphabet_idx=idx_alphabet)
            sample = dataset.getSample(seq[idx_sample]['alphabet'],
                                       seq[idx_sample]['character'],
                                       seq[idx_sample]['sample'])
            image_path = dataset.sampleSample4Character4Alphabet(
                seq[idx_sample]['alphabet'], seq[idx_sample]['character'],
                seq[idx_sample]['sample'])
            print(image_path, '/{}'.format(nbrCharacter4Task))
Example #9

    def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz):
        """
        :param root: dataset root directory
        :param batchsz: number of tasks per batch
        :param n_way: number of classes per task
        :param k_shot: support samples per class
        :param k_query: query samples per class
        :param imgsz: image size after resizing
        """

        self.resize = imgsz
        if not os.path.isfile(os.path.join(root, 'omni.npy')):
            # if root/omni.npy does not exist, download and process the dataset
            self.x = Omniglot(
                root,
                download=True,
                transform=transforms.Compose([
                    lambda x: Image.open(x).convert('L'),
                    transforms.Resize(self.resize),
                    lambda x: np.reshape(x, (self.resize, self.resize, 1))
                ]))

            temp = dict()  # {label: [img1, ..., img20], ...}, 1623 labels in total
            for (img, label) in self.x:
                if label in temp:
                    temp[label].append(img)
                else:
                    temp[label] = [img]

            self.x = []
            for label, imgs in temp.items():  # label info discarded; each label holds 20 imgs
                self.x.append(np.array(imgs))

            # every class contributes exactly 20 images, so this stacks into a regular array
            self.x = np.array(self.x)  # [[20 imgs], ..., 1623 classes in total]
            # each character contains 20 imgs
            print('dataset shape:', self.x.shape)  # [1623, 20, 84, 84, 1]
            temp = []  # Free memory
            # save all dataset into npy file.
            np.save(os.path.join(root, 'omni.npy'), self.x)
        else:
            # omni.npy exists: just load it
            self.x = np.load(os.path.join(root, 'omni.npy'))

        self.x = self.x / 255
        np.random.shuffle(self.x)  # shuffle along the first dim (1623 classes)

        self.x_train, self.x_test = self.x[:1200], self.x[1200:]
        self.normalization()

        self.batchsz = batchsz
        self.n_cls = self.x.shape[0]  # 1623
        self.n_way = n_way  # n way
        self.k_shot = k_shot  # k shot
        self.k_query = k_query  # k query

        # save pointer of current read batch in total cache
        self.indexes = {"train": 0, "test": 0}
        self.datasets = {
            "train": self.x_train,
            "test": self.x_test
        }  # original data cached
        print("train_shape", self.x_train.shape, "test_shape",
              self.x_test.shape)

        self.datasets_cache = {
            "train": self.load_data_cache(
                self.datasets["train"]),  # current epoch data cached
            "test": self.load_data_cache(self.datasets["test"])
        }
Example #10

class OmniglotNShot:
    def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz):
        """
        :param root: dataset root directory
        :param batchsz: number of tasks per batch
        :param n_way: number of classes per task
        :param k_shot: support samples per class
        :param k_query: query samples per class
        :param imgsz: image size after resizing
        """

        self.resize = imgsz
        if not os.path.isfile(os.path.join(root, 'omni.npy')):
            # if root/omni.npy does not exist, download and process the dataset
            self.x = Omniglot(
                root,
                download=True,
                transform=transforms.Compose([
                    lambda x: Image.open(x).convert('L'),
                    transforms.Resize(self.resize),
                    lambda x: np.reshape(x, (self.resize, self.resize, 1))
                ]))

            temp = dict()  # {label: [img1, ..., img20], ...}, 1623 labels in total
            for (img, label) in self.x:
                if label in temp:
                    temp[label].append(img)
                else:
                    temp[label] = [img]

            self.x = []
            for label, imgs in temp.items():  # label info discarded; each label holds 20 imgs
                self.x.append(np.array(imgs))

            # every class contributes exactly 20 images, so this stacks into a regular array
            self.x = np.array(self.x)  # [[20 imgs], ..., 1623 classes in total]
            # each character contains 20 imgs
            print('dataset shape:', self.x.shape)  # [1623, 20, 84, 84, 1]
            temp = []  # Free memory
            # save all dataset into npy file.
            np.save(os.path.join(root, 'omni.npy'), self.x)
        else:
            # omni.npy exists: just load it
            self.x = np.load(os.path.join(root, 'omni.npy'))

        self.x = self.x / 255
        np.random.shuffle(self.x)  # shuffle along the first dim (1623 classes)

        self.x_train, self.x_test = self.x[:1200], self.x[1200:]
        self.normalization()

        self.batchsz = batchsz
        self.n_cls = self.x.shape[0]  # 1623
        self.n_way = n_way  # n way
        self.k_shot = k_shot  # k shot
        self.k_query = k_query  # k query

        # save pointer of current read batch in total cache
        self.indexes = {"train": 0, "test": 0}
        self.datasets = {
            "train": self.x_train,
            "test": self.x_test
        }  # original data cached
        print("train_shape", self.x_train.shape, "test_shape",
              self.x_test.shape)

        self.datasets_cache = {
            "train": self.load_data_cache(
                self.datasets["train"]),  # current epoch data cached
            "test": self.load_data_cache(self.datasets["test"])
        }

    def normalization(self):
        """
		Normalizes our data, to have a mean of 0 and sdt of 1
		"""
        self.mean = np.mean(self.x_train)
        self.std = np.std(self.x_train)
        self.max = np.max(self.x_train)
        self.min = np.min(self.x_train)
        print("before norm:", "mean", self.mean, "max", self.max, "min",
              self.min, "std", self.std)
        self.x_train = (self.x_train - self.mean) / self.std
        self.x_test = (self.x_test - self.mean) / self.std

        self.mean = np.mean(self.x_train)
        self.std = np.std(self.x_train)
        self.max = np.max(self.x_train)
        self.min = np.min(self.x_train)
        print("after norm:", "mean", self.mean, "max", self.max, "min",
              self.min, "std", self.std)

    def load_data_cache(self, data_pack):
        """
		Collects several batches data for N-shot learning
		:param data_pack: [cls_num, 20, 84, 84, 1]
		:return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks
		"""
        #  take 5 way 1 shot as example: 5 * 1
        setsz = self.k_shot * self.n_way
        querysz = self.k_query * self.n_way
        data_cache = []

        # print('preload next 50 caches of batchsz of batch.')
        for sample in range(50):  # num of episodes
            # (batch, setsz, imgs)
            support_x = np.zeros(
                (self.batchsz, setsz, self.resize, self.resize, 1))
            # (batch, setsz)
            support_y = np.zeros((self.batchsz, setsz), dtype=int)
            # (batch, querysz, imgs)
            query_x = np.zeros(
                (self.batchsz, querysz, self.resize, self.resize, 1))
            # (batch, querysz)
            query_y = np.zeros((self.batchsz, querysz), dtype=int)

            for i in range(self.batchsz):  # one batch means one set
                shuffle_idx = np.arange(self.n_way)  # [0,1,2,3,4]
                np.random.shuffle(shuffle_idx)  # [2,4,1,0,3]
                shuffle_idx_test = np.arange(self.n_way)  # [0,1,2,3,4]
                np.random.shuffle(shuffle_idx_test)  # [2,0,1,4,3]
                selected_cls = np.random.choice(data_pack.shape[0], self.n_way,
                                                False)

                for j, cur_class in enumerate(
                        selected_cls):  # for each selected cls
                    # pick k_shot + k_query distinct images for this class
                    selected_imgs = np.random.choice(
                        data_pack.shape[1], self.k_shot + self.k_query, False)

                    # meta-training, select the first k_shot imgs for each class as support imgs
                    for offset, img in enumerate(selected_imgs[:self.k_shot]):
                        # i: batch idx
                        # cur_class: cls in n_way
                        support_x[i, shuffle_idx[j] * self.k_shot + offset,
                                  ...] = data_pack[cur_class][img]
                        support_y[i, shuffle_idx[j] * self.k_shot +
                                  offset] = j  # relative indexing

                    # meta-test, treat following k_query imgs as query imgs
                    for offset, img in enumerate(selected_imgs[self.k_shot:]):
                        query_x[i, shuffle_idx_test[j] * self.k_query + offset,
                                ...] = data_pack[cur_class][img]
                        query_y[i, shuffle_idx_test[j] * self.k_query +
                                offset] = j  # relative indexing

            data_cache.append([support_x, support_y, query_x, query_y])
        return data_cache

    def __get_batch(self, mode):
        """
		Gets next batch from the dataset with name.
		:param dataset_name: The name of the dataset (one of "train", "val", "test")
		:return:
		"""
        # refill the cache when the pointer runs past the cached batches
        if self.indexes[mode] >= len(self.datasets_cache[mode]):
            self.indexes[mode] = 0
            self.datasets_cache[mode] = self.load_data_cache(
                self.datasets[mode])

        next_batch = self.datasets_cache[mode][self.indexes[mode]]
        self.indexes[mode] += 1

        return next_batch

    def get_batch(self, mode):
        """
		Get next batch
		:return: Next batch
		"""
        x_support_set, y_support_set, x_target, y_target = self.__get_batch(
            mode)

        k = int(np.random.uniform(low=0, high=4))  # 0 - 3
        # Iterate over the sequence. Extract batches.

        for i in np.arange(x_support_set.shape[0]):
            # batchsz, setsz, c, h, w
            x_support_set[i, :, :, :, :] = self.__rotate_batch(
                x_support_set[i, :, :, :, :], k)

        # Rotate all the batch of the target images
        for i in np.arange(x_target.shape[0]):
            x_target[i, :, :, :, :] = self.__rotate_batch(
                x_target[i, :, :, :, :], k)

        return x_support_set, y_support_set, x_target, y_target

    def __rotate_batch(self, batch_images, k):
        """
		Rotates a whole image batch
		:param batch_images: A batch of images
		:param k: integer degree of rotation counter-clockwise
		:return: The rotated batch of images
		"""
        batch_size = len(batch_images)
        for i in np.arange(batch_size):
            batch_images[i] = np.rot90(batch_images[i], k)
        return batch_images
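
As with the earlier variant, a brief usage sketch assuming the class above; here get_batch additionally applies one shared random 90-degree rotation to support and query images. Values are illustrative.

# Hypothetical usage; values are illustrative.
db = OmniglotNShot(root='./data', batchsz=16, n_way=5, k_shot=5, k_query=5, imgsz=28)
x_spt, y_spt, x_qry, y_qry = db.get_batch('train')
# x_spt: (16, 25, 28, 28, 1), y_spt: (16, 25)  -- setsz = n_way * k_shot = 25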
Example #11
            # (fragment) tail of a Reptile-style meta-update: move each parameter
            # toward the average of the task-adapted parameters
            new_var = torch.mean(torch.stack([s[k] for s in new_states]), dim=0)
            tmp[k] = (1 - meta_step_size) * old_var + meta_step_size * new_var.clone()
        
        self.load_state_dict(deepcopy(tmp))
        del self._old_state

# --
# Run

args = parse_args()
set_seeds(args.seed)

# --
# IO

dataset = Omniglot(root='./data/omniglot')

train_classes, test_classes = train_test_split(dataset._classes, train_size=args.num_train_classes)

train_taskset = OmniglotTaskWrapper(dataset, classes=train_classes, rotation=False)
test_taskset  = OmniglotTaskWrapper(dataset, classes=test_classes, rotation=False)

cuda = torch.device('cuda')
model = OmniglotModel(num_classes=args.num_classes).to(cuda)

model.init_optimizer(
    opt=torch.optim.Adam,
    params=model.parameters(),
    lr=args.lr,
    betas=[0.0, 0.999],
)
Example #12
    def __init__(self, args):
        self.task_num = args.task_num
        self.class_num = args.class_num
        self.train_sample_size_per_class = args.train_sample_size_per_class
        self.test_sample_size_per_class = args.test_sample_size_per_class
        self.sample_size_per_class = self.train_sample_size_per_class + self.test_sample_size_per_class
        if args.data_source == 'sinusoid':
            self.generate = self.generate_sinusoid_batch
            self.amp_range = args.amp_range
            self.phase_range = args.phase_range
            self.input_range = args.input_range
            self.dim_input = 1
            self.dim_output = 1
        elif args.data_source == 'omniglot':
            assert self.sample_size_per_class <= 20
            self.generate = self.load_omniglot_batch
            self.dim_output = self.class_num
            self.data_filename = 'omniglot_' + str(
                args.img_size[0]) + 'x' + str(args.img_size[0]) + '.npy'
            # load processed data, or download and process the original data
            if not os.path.isfile(
                    os.path.join(args.data_folder, self.data_filename)):
                # the processed .npy file does not exist yet: download and build it
                omniglot = Omniglot(args, download=True)
                temp = defaultdict(list)  # {label: [img1, ..., img20], ...}, 1623 labels in total
                for (img, label) in omniglot:
                    temp[label].append(img)
                self.omniglot_data = np.array(
                    list(temp.values()),
                    dtype=float)  # [[20 imgs], ..., 1623 classes in total]
                del temp  # free memory
                # save the whole dataset into a .npy file
                np.save(os.path.join(args.data_folder, self.data_filename),
                        self.omniglot_data)
                print('\nWrite data into ' + self.data_filename)
            else:
                # the processed .npy file exists: just load it
                self.omniglot_data = np.load(
                    os.path.join(args.data_folder, self.data_filename))
                print('\nLoad data from ' + self.data_filename)
            self.datasets = {
                'train': self.omniglot_data[:1200],
                'test': self.omniglot_data[1200:]
            }
            print('Training data size: {}, test data size: {}'.format(
                self.datasets['train'].shape, self.datasets['test'].shape))
            # pointer to the current read position within the cache
            self.indexes = {'train': 0, 'test': 0}
            self.datasets_cache = {
                'train': self.preload_omniglot_data_cache(
                    self.datasets['train']),  # current epoch data cached
                'test': self.preload_omniglot_data_cache(self.datasets['test'])
            }
        elif args.data_source == 'miniimagenet':
            self.generate = self.load_miniimagenet_batch
            self.datasets = {
                "train": MiniImagenet(args, mode='train',
                                      total_task_num=10000),
                "test": MiniImagenet(args, mode='test', total_task_num=100)
            }
        else:
            raise NotImplementedError
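
To show what this constructor consumes, here is a sketch of the args namespace covering the attributes the Omniglot branch reads; the values and the enclosing class name are assumptions, only the attribute names come from the code above.

from argparse import Namespace

# Hypothetical args; only the attribute names are taken from the constructor above.
args = Namespace(
    task_num=32, class_num=5,
    train_sample_size_per_class=1, test_sample_size_per_class=15,
    data_source='omniglot', img_size=(28, 28), data_folder='./data')
# generator = DataGenerator(args)  # hypothetical name for the enclosing class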