def test_Omniglot():
    """Interactive smoke test: browse raw Omniglot samples with OpenCV.

    Keys: 'q' quits, 'n' advances to the next sample index.
    Requires a display; relies on the module-level Omniglot class and np.
    """
    import cv2
    root = './data/'
    dataset = Omniglot(root=root, download=True)
    #dataset = Omniglot(root=root,background=False,download=True)
    print(len(dataset))
    idx = 0
    idx_alphabet = 0
    idx_character = 0
    idx_sample = 0
    while True:
        # Re-fetch the current sample and its path every frame so the key
        # presses handled below take effect on the next iteration.
        sample = dataset[idx]
        image_path = dataset.sampleSample4Character4Alphabet(
            idx_alphabet, idx_character, idx_sample)
        img = np.array(sample[0])
        cv2.imshow('test', img)
        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break
        elif key == ord('n'):
            idx += 1
            #print(sample[1])
            print(image_path)
            # NOTE(review): idx_sample grows unboundedly here; presumably
            # sampleSample4Character4Alphabet wraps or raises — confirm.
            idx_sample += 1
            print("next image")
def test_OmniglotBatchedSeq():
    """Debug test for building a batched few-shot input sequence.

    Builds one input sequence per task in the batch, regroups the per-step
    dicts key-first, then stops at a deliberate bare `raise` (debug
    breakpoint) — the getBatchedSample call after it is unreachable.
    """
    import cv2
    root = './data/'
    h = 240
    w = 240
    dataset = Omniglot(root=root, h=h, w=w)
    #dataset = Omniglot(root=root,background=False,download=True)
    print(len(dataset))
    batch_size = 2
    max_nbr_char = 5
    idx_alphabet = list()
    for i in range(batch_size):
        idx_alphabet.append(i)
    idx_sample = 0
    keys = ('alphabet', 'character', 'sample', 'target', 'nbrCharacter')
    seqs = list()
    for s in range(batch_size):
        # One sequence per task; each task uses its own alphabet index.
        seq, nbrCharacter4Task, nbrSample4Task = dataset.generateIterFewShotInputSequence(
            alphabet_idx=idx_alphabet[s], max_nbr_char=max_nbr_char)
        # Transpose the list of per-step dicts into a dict of lists.
        dseqs = {
            'alphabet': [],
            'character': [],
            'sample': [],
            'target': [],
            'nbrCharacter': []
        }
        for el in seq:
            for k in el.keys():
                dseqs[k].append(el[k])
        seqs.append(dseqs)
    # Regroup: one dict per key, mapping key -> [task0_list, task1_list, ...].
    seqs = [{k: [seqs[b][k] for b in range(batch_size)]} for k in keys]
    dseqs = dict()
    for d in seqs:
        dseqs.update(d)
    # key x task_idx x seq_idx
    for k in dseqs.keys():
        # NOTE(review): zip() yields a lazy iterator; it is never
        # materialized before the raise below — confirm intended.
        dseqs[k] = zip(*dseqs[k])
    print(seqs)
    # Deliberate debug stop: a bare raise with no active exception raises
    # RuntimeError. Everything after this line is unreachable.
    raise
    batch = dataset.getBatchedSample(seqs)
def test_OmniglotClass():
    """Interactive browser over (alphabet, character, sample) indices.

    Keys: 'q' quits, 'n' next sample, 'c' next character (sample reset),
    'a' next alphabet (character and sample reset). The displayed image
    and path are only re-fetched when one of the indices changed.
    """
    import cv2
    root = './data/'
    dataset = Omniglot(root=root, download=True)
    #dataset = Omniglot(root=root,background=False,download=True)
    print(len(dataset))
    idx = 0
    idx_alphabet = 0
    idx_character = 0
    idx_sample = 0
    sample = dataset.getSample(idx_alphabet, idx_character, idx_sample)
    image_path = dataset.sampleSample4Character4Alphabet(
        idx_alphabet, idx_character, idx_sample)
    changed = False
    while True:
        #sample = dataset[idx]
        img = np.array(sample[0])
        cv2.imshow('test', img)
        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break
        elif key == ord('n'):
            # Next sample of the current character.
            idx += 1
            print(image_path)
            idx_sample += 1
            changed = True
        elif key == ord('a'):
            # Next alphabet; restart at its first character/sample.
            idx += 1
            print(image_path)
            idx_alphabet += 1
            idx_character = 0
            idx_sample = 0
            changed = True
        elif key == ord('c'):
            # Next character; restart at its first sample.
            idx += 1
            print(image_path)
            idx_character += 1
            idx_sample = 0
            changed = True
        if changed:
            changed = False
            sample = dataset.getSample(idx_alphabet, idx_character, idx_sample)
            image_path = dataset.sampleSample4Character4Alphabet(
                idx_alphabet, idx_character, idx_sample)
def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz):
    """Build (or load) the cached Omniglot array and pre-generate episodes.

    Different from mnistNShot.
    :param root: directory that holds (or will receive) omniglot.npy
    :param batchsz: task num (tasks per batch)
    :param n_way: number of classes per episode
    :param k_shot: support samples per class
    :param k_query: query samples per class
    :param imgsz: side length images are resized to
    """
    self.resize = imgsz
    if not os.path.isfile(os.path.join(root, 'omniglot.npy')):
        # if root/omniglot.npy does not exist, just download it
        self.x = Omniglot(root, download=True,
                          transform=transforms.Compose(
                              [lambda x: Image.open(x).convert('L'),
                               lambda x: x.resize((imgsz, imgsz)),
                               lambda x: np.reshape(x, (imgsz, imgsz, 1)),
                               lambda x: np.transpose(x, [2, 0, 1]),
                               lambda x: x / 255.])
                          )
        # {label: [img1, img2, ...]} -- 20 imgs per label, 1623 labels
        temp = dict()
        for (img, label) in self.x:
            if label in temp.keys():
                temp[label].append(img)
            else:
                temp[label] = [img]
        self.x = []
        for label, imgs in temp.items():
            # labels info deserted, each label contains 20 imgs
            self.x.append(np.array(imgs))
        # FIX: np.float was removed in NumPy >= 1.24; the old alias meant the
        # builtin float, i.e. np.float64, so behavior is unchanged.
        self.x = np.array(self.x).astype(np.float64)
        # [[20 imgs], ..., 1623 classes in total]
        print('data shape:', self.x.shape)  # [1623, 20, 84, 84, 1]
        temp = []  # Free memory
        # save all dataset into npy file.
        np.save(os.path.join(root, 'omniglot.npy'), self.x)
        print('write into omniglot.npy.')
    else:
        # if omniglot.npy exists, just load it.
        self.x = np.load(os.path.join(root, 'omniglot.npy'))
        print('load from omniglot.npy.')  # [1623, 20, 84, 84, 1]

    # TODO: can not shuffle here, we must keep training and test set distinct!
    self.x_train, self.x_test = self.x[:1200], self.x[1200:]
    # self.normalization()
    self.batchsz = batchsz
    self.n_cls = self.x.shape[0]  # 1623
    self.n_way = n_way  # n way
    self.k_shot = k_shot  # k shot
    self.k_query = k_query  # k query
    assert (k_shot + k_query) <= 20
    # save pointer of current read batch in total cache
    self.indexes = {"train": 0, "test": 0}
    self.datasets = {"train": self.x_train, "test": self.x_test}  # original data cached
    print("DB: train", self.x_train.shape, "test", self.x_test.shape)
    self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]),  # current epoch data cached
                           "test": self.load_data_cache(self.datasets["test"])}
class OmniglotNShot:
    """N-way k-shot episode sampler over Omniglot.

    On first use the raw images are downloaded, resized and cached in
    ``root/omniglot.npy``; later runs load that file directly. Episodes are
    pre-generated in caches of 10 and served round-robin via ``next()``.
    """

    def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz):
        """Build (or load) the cached data array and pre-generate episodes.

        Different from mnistNShot.
        :param root: directory that holds (or will receive) omniglot.npy
        :param batchsz: task num (tasks per batch)
        :param n_way: number of classes per episode
        :param k_shot: support samples per class
        :param k_query: query samples per class
        :param imgsz: side length images are resized to
        """
        self.resize = imgsz
        if not os.path.isfile(os.path.join(root, 'omniglot.npy')):
            # if root/omniglot.npy does not exist, just download it
            self.x = Omniglot(root, download=True,
                              transform=transforms.Compose(
                                  [lambda x: Image.open(x).convert('L'),
                                   lambda x: x.resize((imgsz, imgsz)),
                                   lambda x: np.reshape(x, (imgsz, imgsz, 1)),
                                   lambda x: np.transpose(x, [2, 0, 1]),
                                   lambda x: x / 255.])
                              )
            # {label: [img1, img2, ...]} -- 20 imgs per label, 1623 labels
            temp = dict()
            for (img, label) in self.x:
                if label in temp.keys():
                    temp[label].append(img)
                else:
                    temp[label] = [img]
            self.x = []
            for label, imgs in temp.items():
                # labels info deserted, each label contains 20 imgs
                self.x.append(np.array(imgs))
            # FIX: np.float was removed in NumPy >= 1.24; the alias meant the
            # builtin float, i.e. np.float64, so behavior is unchanged.
            self.x = np.array(self.x).astype(np.float64)
            print('data shape:', self.x.shape)  # [1623, 20, 84, 84, 1]
            temp = []  # Free memory
            # save all dataset into npy file.
            np.save(os.path.join(root, 'omniglot.npy'), self.x)
            print('write into omniglot.npy.')
        else:
            # if omniglot.npy exists, just load it.
            self.x = np.load(os.path.join(root, 'omniglot.npy'))
            print('load from omniglot.npy.')  # [1623, 20, 84, 84, 1]

        # TODO: can not shuffle here, we must keep training and test set distinct!
        self.x_train, self.x_test = self.x[:1200], self.x[1200:]
        # self.normalization()
        self.batchsz = batchsz
        self.n_cls = self.x.shape[0]  # 1623
        self.n_way = n_way  # n way
        self.k_shot = k_shot  # k shot
        self.k_query = k_query  # k query
        assert (k_shot + k_query) <= 20
        # save pointer of current read batch in total cache
        self.indexes = {"train": 0, "test": 0}
        self.datasets = {"train": self.x_train, "test": self.x_test}  # original data cached
        print("DB: train", self.x_train.shape, "test", self.x_test.shape)
        self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]),  # current epoch data cached
                               "test": self.load_data_cache(self.datasets["test"])}

    def normalization(self):
        """
        Normalizes our data, to have a mean of 0 and sdt of 1
        (statistics are computed on the training split only).
        """
        self.mean = np.mean(self.x_train)
        self.std = np.std(self.x_train)
        self.max = np.max(self.x_train)
        self.min = np.min(self.x_train)
        # print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
        self.x_train = (self.x_train - self.mean) / self.std
        self.x_test = (self.x_test - self.mean) / self.std
        self.mean = np.mean(self.x_train)
        self.std = np.std(self.x_train)
        self.max = np.max(self.x_train)
        self.min = np.min(self.x_train)
        # print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)

    def load_data_cache(self, data_pack):
        """
        Collects several batches data for N-shot learning
        :param data_pack: [cls_num, 20, resize, resize, 1]
        :return: list of 10 episodes, each [support_set_x, support_set_y, target_x, target_y]
        """
        # take 5 way 1 shot as example: 5 * 1
        setsz = self.k_shot * self.n_way
        querysz = self.k_query * self.n_way
        data_cache = []
        for sample in range(10):  # num of episodes
            x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
            for i in range(self.batchsz):  # one batch means one set
                x_spt, y_spt, x_qry, y_qry = [], [], [], []
                selected_cls = np.random.choice(data_pack.shape[0], self.n_way, False)
                for j, cur_class in enumerate(selected_cls):
                    selected_img = np.random.choice(20, self.k_shot + self.k_query, False)
                    # meta-training and meta-test splits of the 20 images
                    x_spt.append(data_pack[cur_class][selected_img[:self.k_shot]])
                    x_qry.append(data_pack[cur_class][selected_img[self.k_shot:]])
                    # labels are relative within the episode (0..n_way-1)
                    y_spt.append([j for _ in range(self.k_shot)])
                    y_qry.append([j for _ in range(self.k_query)])
                # shuffle inside a batch
                perm = np.random.permutation(self.n_way * self.k_shot)
                x_spt = np.array(x_spt).reshape(self.n_way * self.k_shot, 1, self.resize, self.resize)[perm]
                y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm]
                perm = np.random.permutation(self.n_way * self.k_query)
                x_qry = np.array(x_qry).reshape(self.n_way * self.k_query, 1, self.resize, self.resize)[perm]
                y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm]
                # append [sptsz, 1, h, w] => [b, setsz, 1, h, w]
                x_spts.append(x_spt)
                y_spts.append(y_spt)
                x_qrys.append(x_qry)
                y_qrys.append(y_qry)
            # FIX: np.int was removed in NumPy >= 1.24; np.int64 matches the
            # default integer dtype the old alias produced on Linux.
            # [b, setsz, 1, h, w]
            x_spts = np.array(x_spts).astype(np.float32).reshape(self.batchsz, setsz, 1, self.resize, self.resize)
            y_spts = np.array(y_spts).astype(np.int64).reshape(self.batchsz, setsz)
            # [b, qrysz, 1, h, w]
            x_qrys = np.array(x_qrys).astype(np.float32).reshape(self.batchsz, querysz, 1, self.resize, self.resize)
            y_qrys = np.array(y_qrys).astype(np.int64).reshape(self.batchsz, querysz)
            data_cache.append([x_spts, y_spts, x_qrys, y_qrys])
        return data_cache

    def next(self, mode='train'):
        """
        Gets next batch from the dataset with name.
        :param mode: The name of the splitting (one of "train", "test")
        :return: [x_spts, y_spts, x_qrys, y_qrys] for one batch of tasks
        """
        # refill the cache once all pre-generated episodes were consumed
        if self.indexes[mode] >= len(self.datasets_cache[mode]):
            self.indexes[mode] = 0
            self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode])
        next_batch = self.datasets_cache[mode][self.indexes[mode]]
        self.indexes[mode] += 1
        return next_batch
def get_data(dataset, dset_dir, image_size, model, phase, unsupervised, spc,
             batch_size, workers):
    """Build a DataLoader for the requested dataset and phase.

    :param dataset: one of 'omniglot_original', 'omniglot_aug',
        'miniimagenet' (case-insensitive)
    :param dset_dir: base data directory (dataset sub-directory is appended)
    :param image_size, model: forwarded to get_transform
    :param phase: 'train' selects the training loader; anything else the
        evaluation loader
    :param unsupervised, spc: training-set options (samples per class)
    :param batch_size, workers: DataLoader options
    :raises NotImplementedError: for unknown dataset names
    :return: a torch DataLoader

    The three dataset branches were duplicated in the original; they are
    folded into one table of per-dataset settings.
    """
    dataset = dataset.lower()
    # (sub-directory, dataset class, phase passed when training,
    #  phase passed when evaluating, samples-per-class when evaluating)
    configs = {
        'omniglot_original': ('omniglot', Omniglot, 'background', 'evaluation', 20),
        'omniglot_aug': ('omniglot', Omniglot, phase, phase, 20),
        'miniimagenet': ('miniImagenet', MiniImageNet, phase, phase, 600),
    }
    if dataset not in configs:
        raise NotImplementedError
    subdir, dataset_cls, train_phase, test_phase, test_spc = configs[dataset]
    dset_dir = os.path.join(dset_dir, subdir)
    transformations = get_transform(dataset, image_size, model)
    if phase == 'train':
        train_data = dataset_cls(
            root=dset_dir,
            phase=train_phase,
            unsupervised=unsupervised,
            spc=spc,
            pre_transform=transformations['pre_transform'],
            transform=transformations['transform'],
            post_transform=transformations['post_transform'])
        data_loader = DataLoader(train_data, batch_size=batch_size,
                                 shuffle=True, num_workers=int(workers))
    else:
        test_data = dataset_cls(
            root=dset_dir,
            phase=test_phase,
            spc=test_spc,
            pre_transform=transformations['test_transform'])
        data_loader = DataLoader(test_data, batch_size=batch_size,
                                 shuffle=False, num_workers=int(workers))
    return data_loader
def test_OmniglotTask():
    """Interactive viewer for a generated few-shot learning task.

    Keys: 'q' quits, 'n' steps to the next task element (wrapping),
    'a' switches to the next alphabet and regenerates the task.
    """
    import cv2
    root = './data/'
    dataset = Omniglot(root=root, download=True)
    #dataset = Omniglot(root=root,background=False,download=True)
    print(len(dataset))
    idx_alphabet = 0
    idx_sample = 0
    #task, nbrCharacter4Task, nbrSample4Task = dataset.generateFewShotLearningTask( alphabet_idx=idx_alphabet)
    task, nbrCharacter4Task, nbrSample4Task = dataset.generateIterFewShotLearningTask(
        alphabet_idx=idx_alphabet)
    sample = dataset.getSample(task[idx_sample]['alphabet'],
                               task[idx_sample]['character'],
                               task[idx_sample]['sample'])
    image_path = dataset.sampleSample4Character4Alphabet(
        task[idx_sample]['alphabet'], task[idx_sample]['character'],
        task[idx_sample]['sample'])
    changed = False
    taskChanged = False
    idx = 0
    while True:
        #sample = dataset[idx]
        img = np.array(sample[0])
        cv2.imshow('test', img)
        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break
        elif key == ord('n'):
            idx += 1
            idx_sample = (idx_sample + 1) % len(task)
            changed = True
        elif key == ord('a'):
            idx_alphabet += 1
            # NOTE(review): idx_character is set but never read in this
            # function — likely a leftover from test_OmniglotClass.
            idx_character = 0
            idx_sample = 0
            taskChanged = True
        if changed:
            changed = False
            sample = dataset.getSample(task[idx_sample]['alphabet'],
                                       task[idx_sample]['character'],
                                       task[idx_sample]['sample'])
            image_path = dataset.sampleSample4Character4Alphabet(
                task[idx_sample]['alphabet'], task[idx_sample]['character'],
                task[idx_sample]['sample'])
            print(image_path, '/{}'.format(nbrCharacter4Task))
            print('Sample {} / {}'.format(idx, nbrSample4Task))
        if taskChanged:
            taskChanged = False
            idx = 0
            # NOTE(review): the initial task comes from
            # generateIterFewShotLearningTask but regeneration uses
            # generateFewShotLearningTask — confirm the two return
            # task elements with the same keys.
            task, nbrCharacter4Task, nbrSample4Task = dataset.generateFewShotLearningTask(
                alphabet_idx=idx_alphabet)
            sample = dataset.getSample(task[idx_sample]['alphabet'],
                                       task[idx_sample]['character'],
                                       task[idx_sample]['sample'])
            image_path = dataset.sampleSample4Character4Alphabet(
                task[idx_sample]['alphabet'], task[idx_sample]['character'],
                task[idx_sample]['sample'])
            print(image_path, '/{}'.format(nbrCharacter4Task))
def test_OmniglotSeq():
    """Interactive viewer for a few-shot input sequence drawn from Omniglot.

    Keys: 'q' quits, 'n' steps to the next sequence element (wrapping),
    'a' switches to the next alphabet and regenerates the sequence.
    Requires a display (cv2) and the module-level Omniglot class.
    """
    import cv2
    root = './data/'
    h = 240
    w = 240
    dataset = Omniglot(root=root, h=h, w=w)
    #dataset = Omniglot(root=root,background=False,download=True)
    print(len(dataset))
    idx_alphabet = 0
    idx_sample = 0
    seq, nbrCharacter4Task, nbrSample4Task = dataset.generateIterFewShotInputSequence(
        alphabet_idx=idx_alphabet)
    sample = dataset.getSample(seq[idx_sample]['alphabet'],
                               seq[idx_sample]['character'],
                               seq[idx_sample]['sample'])
    image_path = dataset.sampleSample4Character4Alphabet(
        seq[idx_sample]['alphabet'], seq[idx_sample]['character'],
        seq[idx_sample]['sample'])
    changed = False
    seqChanged = False
    idx = 0
    while True:
        #sample = dataset[idx]
        # sample['image'] is a tensor in [0, 1]; scale and move channels
        # last for cv2 display.
        img = (sample['image'].numpy() * 255).transpose((1, 2, 0))
        cv2.imshow('test', img)
        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break
        elif key == ord('n'):
            idx += 1
            idx_sample = (idx_sample + 1) % len(seq)
            changed = True
        elif key == ord('a'):
            idx_alphabet += 1
            idx_character = 0
            idx_sample = 0
            seqChanged = True
        if changed:
            changed = False
            sample = dataset.getSample(seq[idx_sample]['alphabet'],
                                       seq[idx_sample]['character'],
                                       seq[idx_sample]['sample'])
            image_path = dataset.sampleSample4Character4Alphabet(
                seq[idx_sample]['alphabet'], seq[idx_sample]['character'],
                seq[idx_sample]['sample'])
            print(image_path, '/{}'.format(nbrCharacter4Task))
            print('Sample {} / {}'.format(idx, nbrSample4Task))
            print('Target :{}'.format(seq[idx_sample]['target']))
        if seqChanged:
            seqChanged = False
            idx = 0
            # BUG FIX: the original regenerated with
            # dataset.generateFewShotLearningTask(...), whose elements carry
            # no 'target' key (see test_OmniglotTask), so the 'Target' print
            # above would raise KeyError after an alphabet change. Use the
            # same generator as the initial call.
            seq, nbrCharacter4Task, nbrSample4Task = dataset.generateIterFewShotInputSequence(
                alphabet_idx=idx_alphabet)
            sample = dataset.getSample(seq[idx_sample]['alphabet'],
                                       seq[idx_sample]['character'],
                                       seq[idx_sample]['sample'])
            image_path = dataset.sampleSample4Character4Alphabet(
                seq[idx_sample]['alphabet'], seq[idx_sample]['character'],
                seq[idx_sample]['sample'])
            print(image_path, '/{}'.format(nbrCharacter4Task))
def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz):
    """Build (or load) the cached Omniglot array and pre-generate episodes.

    :param root: directory holding (or receiving) omni.npy
    :param batchsz: number of tasks per batch
    :param n_way: number of classes per episode
    :param k_shot: support samples per class
    :param k_query: query samples per class
    :param imgsz: side length images are resized to
    """
    self.resize = imgsz
    if not os.path.isfile(os.path.join(root, 'omni.npy')):
        # if root/omni.npy does not exist, just download it
        self.x = Omniglot(
            root,
            download=True,
            transform=transforms.Compose([
                lambda x: Image.open(x).convert('L'),
                transforms.Resize(self.resize),
                lambda x: np.reshape(x, (self.resize, self.resize, 1))
            ]))
        temp = dict(
        )  # {label:img1, img2..., 20 imgs in total, 1623 label}
        for (img, label) in self.x:
            if label in temp:
                temp[label].append(img)
            else:
                temp[label] = [img]
        self.x = []
        for label, imgs in temp.items(
        ):  # labels info deserted , each label contains 20imgs
            self.x.append(np.array(imgs))
        # as different class may have different number of imgs
        self.x = np.array(self.x)
        # [[20 imgs],..., 1623 classes in total]
        print('dataset shape:', self.x.shape)  # [1623, 20, 84, 84, 1]
        temp = []  # Free memory
        # save all dataset into npy file.
        np.save(os.path.join(root, 'omni.npy'), self.x)
    else:
        # if data.npy exists, just load it.
        self.x = np.load(os.path.join(root, 'omni.npy'))
    # scale pixel values to [0, 1] before normalization
    self.x = self.x / 255
    # self.x: [1623, shuffled, 20 imgs, 84, 84, 1]
    np.random.shuffle(self.x)  # shuffle on the first dim = 1623 cls
    # NOTE(review): this shuffle happens BEFORE the train/test split, so
    # the split differs between runs — confirm intended.
    self.x_train, self.x_test = self.x[:1200], self.x[1200:]
    self.normalization()
    self.batchsz = batchsz
    self.n_cls = self.x.shape[0]  # 1623
    self.n_way = n_way  # n way
    self.k_shot = k_shot  # k shot
    self.k_query = k_query  # k query
    # save pointer of current read batch in total cache
    self.indexes = {"train": 0, "test": 0}
    self.datasets = {
        "train": self.x_train,
        "test": self.x_test
    }  # original data cached
    print("train_shape", self.x_train.shape, "test_shape",
          self.x_test.shape)
    self.datasets_cache = {
        "train": self.load_data_cache(
            self.datasets["train"]),  # current epoch data cached
        "test": self.load_data_cache(self.datasets["test"])
    }
class OmniglotNShot:
    """N-way k-shot episode sampler over Omniglot (rotation-augmented variant).

    On first use the raw images are downloaded, resized and cached in
    ``root/omni.npy``; later runs load that file. Episodes are pre-generated
    in caches of 50 and served via ``get_batch``, which additionally rotates
    every image of a batch by the same random multiple of 90 degrees.
    """

    def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz):
        """Build (or load) the cached data array and pre-generate episodes.

        :param root: directory holding (or receiving) omni.npy
        :param batchsz: number of tasks per batch
        :param n_way: number of classes per episode
        :param k_shot: support samples per class
        :param k_query: query samples per class
        :param imgsz: side length images are resized to
        """
        self.resize = imgsz
        if not os.path.isfile(os.path.join(root, 'omni.npy')):
            # if root/omni.npy does not exist, just download it
            self.x = Omniglot(
                root,
                download=True,
                transform=transforms.Compose([
                    lambda x: Image.open(x).convert('L'),
                    transforms.Resize(self.resize),
                    lambda x: np.reshape(x, (self.resize, self.resize, 1))
                ]))
            # {label: [img1, img2, ...]} -- 20 imgs per label, 1623 labels
            temp = dict()
            for (img, label) in self.x:
                if label in temp:
                    temp[label].append(img)
                else:
                    temp[label] = [img]
            self.x = []
            for label, imgs in temp.items():
                # labels info deserted, each label contains 20 imgs
                self.x.append(np.array(imgs))
            # [[20 imgs], ..., 1623 classes in total]
            self.x = np.array(self.x)
            print('dataset shape:', self.x.shape)  # [1623, 20, 84, 84, 1]
            temp = []  # Free memory
            # save all dataset into npy file.
            np.save(os.path.join(root, 'omni.npy'), self.x)
        else:
            # if omni.npy exists, just load it.
            self.x = np.load(os.path.join(root, 'omni.npy'))

        # scale pixel values to [0, 1] before normalization
        self.x = self.x / 255
        np.random.shuffle(self.x)  # shuffle on the first dim = 1623 cls
        self.x_train, self.x_test = self.x[:1200], self.x[1200:]
        self.normalization()
        self.batchsz = batchsz
        self.n_cls = self.x.shape[0]  # 1623
        self.n_way = n_way  # n way
        self.k_shot = k_shot  # k shot
        self.k_query = k_query  # k query
        # save pointer of current read batch in total cache
        self.indexes = {"train": 0, "test": 0}
        self.datasets = {
            "train": self.x_train,
            "test": self.x_test
        }  # original data cached
        print("train_shape", self.x_train.shape, "test_shape",
              self.x_test.shape)
        self.datasets_cache = {
            "train": self.load_data_cache(
                self.datasets["train"]),  # current epoch data cached
            "test": self.load_data_cache(self.datasets["test"])
        }

    def normalization(self):
        """
        Normalizes our data, to have a mean of 0 and sdt of 1
        (statistics are computed on the training split only).
        """
        self.mean = np.mean(self.x_train)
        self.std = np.std(self.x_train)
        self.max = np.max(self.x_train)
        self.min = np.min(self.x_train)
        print("before norm:", "mean", self.mean, "max", self.max, "min",
              self.min, "std", self.std)
        self.x_train = (self.x_train - self.mean) / self.std
        self.x_test = (self.x_test - self.mean) / self.std
        self.mean = np.mean(self.x_train)
        self.std = np.std(self.x_train)
        self.max = np.max(self.x_train)
        self.min = np.min(self.x_train)
        print("after norm:", "mean", self.mean, "max", self.max, "min",
              self.min, "std", self.std)

    def load_data_cache(self, data_pack):
        """
        Collects several batches data for N-shot learning
        :param data_pack: [cls_num, 20, resize, resize, 1]
        :return: list of 50 episodes, each [support_set_x, support_set_y, target_x, target_y]
        """
        # take 5 way 1 shot as example: 5 * 1
        setsz = self.k_shot * self.n_way
        querysz = self.k_query * self.n_way
        data_cache = []
        for sample in range(50):  # num of episodes
            # (batch, setsz, imgs)
            support_x = np.zeros(
                (self.batchsz, setsz, self.resize, self.resize, 1))
            # FIX: np.int was removed in NumPy >= 1.24; np.int64 matches the
            # default integer dtype the old alias produced on Linux.
            support_y = np.zeros((self.batchsz, setsz), dtype=np.int64)
            # (batch, querysz, imgs)
            query_x = np.zeros(
                (self.batchsz, querysz, self.resize, self.resize, 1))
            query_y = np.zeros((self.batchsz, querysz), dtype=np.int64)

            for i in range(self.batchsz):  # one batch means one set
                # independent class-position shuffles for support and query
                shuffle_idx = np.arange(self.n_way)  # [0,1,2,3,4]
                np.random.shuffle(shuffle_idx)  # [2,4,1,0,3]
                shuffle_idx_test = np.arange(self.n_way)  # [0,1,2,3,4]
                np.random.shuffle(shuffle_idx_test)  # [2,0,1,4,3]
                selected_cls = np.random.choice(data_pack.shape[0],
                                                self.n_way, False)
                for j, cur_class in enumerate(
                        selected_cls):  # for each selected cls
                    # [img1, img2 ,,, = k_shot + k_query ]
                    selected_imgs = np.random.choice(
                        data_pack.shape[1], self.k_shot + self.k_query,
                        False)
                    # meta-training: first k_shot imgs of each class are support
                    for offset, img in enumerate(selected_imgs[:self.k_shot]):
                        support_x[i, shuffle_idx[j] * self.k_shot + offset,
                                  ...] = data_pack[cur_class][img]
                        support_y[i, shuffle_idx[j] * self.k_shot +
                                  offset] = j  # relative indexing
                    # meta-test: the following k_query imgs are queries
                    for offset, img in enumerate(selected_imgs[self.k_shot:]):
                        query_x[i, shuffle_idx_test[j] * self.k_query +
                                offset, ...] = data_pack[cur_class][img]
                        query_y[i, shuffle_idx_test[j] * self.k_query +
                                offset] = j  # relative indexing

            data_cache.append([support_x, support_y, query_x, query_y])
        return data_cache

    def __get_batch(self, mode):
        """
        Gets next batch from the dataset with name.
        :param mode: one of "train", "test"
        :return: [support_x, support_y, query_x, query_y]
        """
        # refill the cache once all pre-generated episodes were consumed
        if self.indexes[mode] >= len(self.datasets_cache[mode]):
            self.indexes[mode] = 0
            self.datasets_cache[mode] = self.load_data_cache(
                self.datasets[mode])
        next_batch = self.datasets_cache[mode][self.indexes[mode]]
        self.indexes[mode] += 1
        return next_batch

    def get_batch(self, mode):
        """
        Get next batch, rotating support and target images by the same
        random multiple of 90 degrees (data augmentation).
        :return: Next batch
        """
        x_support_set, y_support_set, x_target, y_target = self.__get_batch(
            mode)
        k = int(np.random.uniform(low=0, high=4))  # 0 - 3
        # Iterate over the sequence. Extract batches.
        for i in np.arange(x_support_set.shape[0]):
            x_support_set[i, :, :, :, :] = self.__rotate_batch(
                x_support_set[i, :, :, :, :], k)
        # Rotate all the batch of the target images
        for i in np.arange(x_target.shape[0]):
            x_target[i, :, :, :, :] = self.__rotate_batch(
                x_target[i, :, :, :, :], k)
        return x_support_set, y_support_set, x_target, y_target

    def __rotate_batch(self, batch_images, k):
        """
        Rotates a whole image batch
        :param batch_images: A batch of images
        :param k: integer degree of rotation counter-clockwise
        :return: The rotated batch of images
        """
        batch_size = len(batch_images)
        for i in np.arange(batch_size):
            batch_images[i] = np.rot90(batch_images[i], k)
        return batch_images
new_var = torch.mean(torch.stack([s[k] for s in new_states]), dim=0) tmp[k] = (1 - meta_step_size) * old_var + meta_step_size * new_var.clone() self.load_state_dict(deepcopy(tmp)) del self._old_state # -- # Run args = parse_args() set_seeds(args.seed) # -- # IO dataset = Omniglot(root='./data/omniglot') train_classes, test_classes = train_test_split(dataset._classes, train_size=args.num_train_classes) train_taskset = OmniglotTaskWrapper(dataset, classes=train_classes, rotation=False) test_taskset = OmniglotTaskWrapper(dataset, classes=test_classes, rotation=False) cuda = torch.device('cuda') model = OmniglotModel(num_classes=args.num_classes).to(cuda) model.init_optimizer( opt=torch.optim.Adam, params=model.parameters(), lr=args.lr, betas=[0.0, 0.999], )
def __init__(self, args):
    """Task/batch generator configured by ``args.data_source``.

    :param args: namespace with task_num, class_num,
        train/test_sample_size_per_class, plus per-source options
        (amp/phase/input ranges for 'sinusoid'; img_size and data_folder
        for 'omniglot').
    :raises NotImplementedError: for unknown data sources
    """
    self.task_num = args.task_num
    self.class_num = args.class_num
    self.train_sample_size_per_class = args.train_sample_size_per_class
    self.test_sample_size_per_class = args.test_sample_size_per_class
    self.sample_size_per_class = self.train_sample_size_per_class + self.test_sample_size_per_class
    if args.data_source == 'sinusoid':
        self.generate = self.generate_sinusoid_batch
        self.amp_range = args.amp_range
        self.phase_range = args.phase_range
        self.input_range = args.input_range
        self.dim_input = 1
        self.dim_output = 1
    elif args.data_source == 'omniglot':
        # Omniglot provides 20 samples per character.
        assert self.sample_size_per_class <= 20
        self.generate = self.load_omniglot_batch
        self.dim_output = self.class_num
        self.data_filename = 'omniglot_' + str(
            args.img_size[0]) + 'x' + str(args.img_size[0]) + '.npy'
        # load processed data or download and process the original data
        if not os.path.isfile(
                os.path.join(args.data_folder, self.data_filename)):
            # if root/data.npy does not exist, just download it
            omniglot = Omniglot(args, download=True)
            # {label: [img, ...]} -- 20 imgs per label, 1623 labels
            temp = defaultdict(list)
            for (img, label) in omniglot:
                temp[label].append(img)
            # FIX: np.float was removed in NumPy >= 1.24; the alias meant the
            # builtin float, i.e. np.float64, so behavior is unchanged.
            self.omniglot_data = np.array(
                list(temp.values()),
                dtype=np.float64)  # [[20 imgs],..., 1623 classes in total]
            del temp  # Free memory
            # save all dataset into npy file.
            np.save(os.path.join(args.data_folder, self.data_filename),
                    self.omniglot_data)
            print('\nWrite data into ' + self.data_filename)
        else:
            # if data.npy exists, just load it.
            self.omniglot_data = np.load(
                os.path.join(args.data_folder, self.data_filename))
            print('\nLoad data from ' + self.data_filename)
        # first 1200 classes for training, the remainder for test
        self.datasets = {
            'train': self.omniglot_data[:1200],
            'test': self.omniglot_data[1200:]
        }
        print('Training data size: {}, test data size: {}'.format(
            self.datasets['train'].shape, self.datasets['test'].shape))
        # save pointer of current read batch in total cache
        self.indexes = {'train': 0, 'test': 0}
        self.datasets_cache = {
            'train': self.preload_omniglot_data_cache(
                self.datasets['train']),  # current epoch data cached
            'test': self.preload_omniglot_data_cache(self.datasets['test'])
        }
    elif args.data_source == 'miniimagenet':
        self.generate = self.load_miniimagenet_batch
        self.datasets = {
            "train": MiniImagenet(args, mode='train', total_task_num=10000),
            "test": MiniImagenet(args, mode='test', total_task_num=100)
        }
    else:
        raise NotImplementedError