def test_OmniglotBatchedSeq():
    import cv2

    root = './data/'
    h = 240
    w = 240
    dataset = Omniglot(root=root, h=h, w=w)
    # dataset = Omniglot(root=root, background=False, download=True)
    print(len(dataset))

    batch_size = 2
    max_nbr_char = 5
    # One alphabet index per element of the batch:
    idx_alphabet = list(range(batch_size))
    idx_sample = 0

    keys = ('alphabet', 'character', 'sample', 'target', 'nbrCharacter')

    # Generate one few-shot input sequence per batch element and regroup each
    # sequence (a list of per-step dicts) into a dict of lists.
    seqs = list()
    for s in range(batch_size):
        seq, nbrCharacter4Task, nbrSample4Task = dataset.generateIterFewShotInputSequence(
            alphabet_idx=idx_alphabet[s],
            max_nbr_char=max_nbr_char)
        dseqs = {k: [] for k in keys}
        for el in seq:
            for k in el.keys():
                dseqs[k].append(el[k])
        seqs.append(dseqs)

    # Regroup key-major: key -> [sequence of batch element 0, sequence of batch element 1, ...]
    seqs = [{k: [seqs[b][k] for b in range(batch_size)]} for k in keys]
    dseqs = dict()
    for d in seqs:
        dseqs.update(d)

    # dseqs is currently key x task_idx x seq_idx; transpose each key to
    # key x seq_idx x task_idx (wrapped in list() so the Python 3 zip iterator can be reused).
    for k in dseqs.keys():
        dseqs[k] = list(zip(*dseqs[k]))

    print(seqs)
    batch = dataset.getBatchedSample(seqs)
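# Minimal, self-contained sketch of the regrouping logic used in test_OmniglotBatchedSeq above,
# run on dummy data only. Assumption: the helper name `_demo_batched_seq_transpose` and the
# dummy string entries are hypothetical and not part of the Omniglot API.
def _demo_batched_seq_transpose():
    batch_size = 2
    seq_len = 3
    keys = ('alphabet', 'character', 'sample', 'target', 'nbrCharacter')
    # Dummy per-task sequences, already regrouped as key -> list over sequence steps:
    seqs = [{k: ['{}_b{}_t{}'.format(k, b, t) for t in range(seq_len)] for k in keys}
            for b in range(batch_size)]
    # Key-major regrouping: key -> [sequence of task 0, sequence of task 1, ...]
    dseqs = {k: [seqs[b][k] for b in range(batch_size)] for k in keys}
    # Transpose each key to seq-major: key -> [step 0 over tasks, step 1 over tasks, ...]
    for k in dseqs:
        dseqs[k] = list(zip(*dseqs[k]))
    # e.g. [('alphabet_b0_t0', 'alphabet_b1_t0'), ('alphabet_b0_t1', 'alphabet_b1_t1'), ...]
    print(dseqs['alphabet'])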
def test_OmniglotSeq():
    import cv2

    root = './data/'
    h = 240
    w = 240
    dataset = Omniglot(root=root, h=h, w=w)
    # dataset = Omniglot(root=root, background=False, download=True)
    print(len(dataset))

    idx_alphabet = 0
    idx_sample = 0
    seq, nbrCharacter4Task, nbrSample4Task = dataset.generateIterFewShotInputSequence(
        alphabet_idx=idx_alphabet)
    sample = dataset.getSample(seq[idx_sample]['alphabet'],
                               seq[idx_sample]['character'],
                               seq[idx_sample]['sample'])
    image_path = dataset.sampleSample4Character4Alphabet(seq[idx_sample]['alphabet'],
                                                         seq[idx_sample]['character'],
                                                         seq[idx_sample]['sample'])

    changed = False
    seqChanged = False
    idx = 0
    # Interactive viewer: 'n' shows the next sample, 'a' moves to the next alphabet, 'q' quits.
    while True:
        # sample = dataset[idx]
        # Convert the CHW float tensor in [0, 1] to an HWC uint8 image for cv2.imshow:
        img = (sample['image'].numpy() * 255).transpose((1, 2, 0)).astype('uint8')
        cv2.imshow('test', img)

        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break
        elif key == ord('n'):
            idx += 1
            idx_sample = (idx_sample + 1) % len(seq)
            changed = True
        elif key == ord('a'):
            idx_alphabet += 1
            idx_character = 0
            idx_sample = 0
            seqChanged = True

        if changed:
            changed = False
            sample = dataset.getSample(seq[idx_sample]['alphabet'],
                                       seq[idx_sample]['character'],
                                       seq[idx_sample]['sample'])
            image_path = dataset.sampleSample4Character4Alphabet(seq[idx_sample]['alphabet'],
                                                                 seq[idx_sample]['character'],
                                                                 seq[idx_sample]['sample'])
            print(image_path, '/{}'.format(nbrCharacter4Task))
            print('Sample {} / {}'.format(idx, nbrSample4Task))
            print('Target : {}'.format(seq[idx_sample]['target']))

        if seqChanged:
            seqChanged = False
            idx = 0
            seq, nbrCharacter4Task, nbrSample4Task = dataset.generateFewShotLearningTask(
                alphabet_idx=idx_alphabet)
            sample = dataset.getSample(seq[idx_sample]['alphabet'],
                                       seq[idx_sample]['character'],
                                       seq[idx_sample]['sample'])
            image_path = dataset.sampleSample4Character4Alphabet(seq[idx_sample]['alphabet'],
                                                                 seq[idx_sample]['character'],
                                                                 seq[idx_sample]['sample'])
            print(image_path, '/{}'.format(nbrCharacter4Task))
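# Usage sketch (assumption, not part of the original file): how these tests might be invoked
# directly. It presumes the custom Omniglot class is importable in this module, the data lives
# under './data/', and an OpenCV GUI backend is available for cv2.imshow.
if __name__ == '__main__':
    test_OmniglotSeq()
    # test_OmniglotBatchedSeq()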