Example #1
0
def test_OmniglotBatchedSeq():
    """Smoke-test batching of few-shot input sequences drawn from ``Omniglot``.

    Batch element ``i`` gets a sequence generated from alphabet ``i`` via
    ``dataset.generateIterFewShotInputSequence`` (called with
    ``max_nbr_char=5``). The per-step records of every sequence are regrouped
    by field, printed, and passed to ``dataset.getBatchedSample``.
    """
    root = './data/'
    h = 240
    w = 240
    dataset = Omniglot(root=root, h=h, w=w)
    # Alternative construction, kept for reference:
    # dataset = Omniglot(root=root, background=False, download=True)
    print(len(dataset))

    batch_size = 2
    max_nbr_char = 5
    # Batch element i draws its task from alphabet i.
    idx_alphabet = list(range(batch_size))

    keys = ('alphabet', 'character', 'sample', 'target', 'nbrCharacter')

    # tasks[b][k] -> list of field k over the steps of task b.
    tasks = []
    for alphabet_idx in idx_alphabet:
        seq, _nbr_character, _nbr_sample = dataset.generateIterFewShotInputSequence(
            alphabet_idx=alphabet_idx, max_nbr_char=max_nbr_char)
        task = {k: [] for k in keys}
        for el in seq:
            for k in el:
                task[k].append(el[k])
        tasks.append(task)

    # One single-key dict per field: {k: [task_0 values, task_1 values, ...]}.
    seqs = [{k: [task[k] for task in tasks]} for k in keys]

    # Same data merged into one dict (key x task_idx x seq_idx), then each
    # entry transposed to key x seq_idx x task_idx. zip() stops at the
    # shortest task; list() materialises it because a zip object can only be
    # iterated once.
    dseqs = {k: [task[k] for task in tasks] for k in keys}
    for k in keys:
        dseqs[k] = list(zip(*dseqs[k]))
    print(seqs)

    # TODO(review): confirm whether getBatchedSample expects `seqs` (list of
    # single-key dicts) or the transposed `dseqs`; `seqs` is what was passed.
    dataset.getBatchedSample(seqs)
Example #2
0
def test_OmniglotSeq():
    """Interactively browse a few-shot input sequence from ``Omniglot``.

    Builds a sequence for alphabet 0 with
    ``dataset.generateIterFewShotInputSequence`` and shows the current
    element's image in an OpenCV window named ``'test'``. Keys polled in the
    loop:

    * ``q`` -- leave the loop.
    * ``n`` -- advance to the next element of the sequence (wrapping around)
      and print its image path, position and target.
    * ``a`` -- move to the next alphabet and regenerate the sequence.

    Requires a display (``cv2.imshow``); the dataset is built with root
    ``./data/``.
    """
    import cv2

    root = './data/'
    h = 240
    w = 240
    dataset = Omniglot(root=root, h=h, w=w)
    # Alternative construction, kept for reference:
    #dataset = Omniglot(root=root,background=False,download=True)
    print(len(dataset))

    idx_alphabet = 0
    # Position inside `seq` of the element currently displayed.
    idx_sample = 0

    seq, nbrCharacter4Task, nbrSample4Task = dataset.generateIterFewShotInputSequence(
        alphabet_idx=idx_alphabet)
    sample = dataset.getSample(seq[idx_sample]['alphabet'],
                               seq[idx_sample]['character'],
                               seq[idx_sample]['sample'])
    # NOTE(review): this initial path is computed but never printed.
    image_path = dataset.sampleSample4Character4Alphabet(
        seq[idx_sample]['alphabet'], seq[idx_sample]['character'],
        seq[idx_sample]['sample'])
    # Flags set by key presses and consumed (then cleared) further down the
    # loop body: `changed` reloads the current element, `seqChanged`
    # regenerates the whole sequence.
    changed = False
    seqChanged = False
    # Counts 'n' presses (reset on 'a'); unlike idx_sample it does not wrap,
    # and it is what gets printed as the "Sample" position.
    idx = 0
    while True:

        #sample = dataset[idx]

        # Convert the sample's image (via .numpy()) to HxWxC for OpenCV.
        # NOTE(review): if the image is a float tensor in [0, 1], the result
        # stays float, and cv2.imshow multiplies float images by 255 again, so
        # the window may appear saturated -- consider .astype('uint8').
        img = (sample['image'].numpy() * 255).transpose((1, 2, 0))
        cv2.imshow('test', img)

        # Wait up to 1 ms for a key; mask to the low byte before comparing
        # with ord().
        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break
        elif key == ord('n'):
            idx += 1
            idx_sample = (idx_sample + 1) % len(seq)
            changed = True
        elif key == ord('a'):
            # NOTE(review): no upper bound / wraparound on idx_alphabet here;
            # presumably fails past the last alphabet -- TODO confirm.
            idx_alphabet += 1
            # NOTE(review): idx_character is not read anywhere in the visible code.
            idx_character = 0
            idx_sample = 0
            seqChanged = True

        if changed:
            changed = False
            sample = dataset.getSample(seq[idx_sample]['alphabet'],
                                       seq[idx_sample]['character'],
                                       seq[idx_sample]['sample'])
            image_path = dataset.sampleSample4Character4Alphabet(
                seq[idx_sample]['alphabet'], seq[idx_sample]['character'],
                seq[idx_sample]['sample'])
            print(image_path, '/{}'.format(nbrCharacter4Task))
            print('Sample {} / {}'.format(idx, nbrSample4Task))
            print('Target :{}'.format(seq[idx_sample]['target']))

        if seqChanged:
            seqChanged = False
            idx = 0
            # NOTE(review): the initial sequence comes from
            # generateIterFewShotInputSequence, but this path calls
            # generateFewShotLearningTask -- confirm which one is intended.
            seq, nbrCharacter4Task, nbrSample4Task = dataset.generateFewShotLearningTask(
                alphabet_idx=idx_alphabet)
            sample = dataset.getSample(seq[idx_sample]['alphabet'],
                                       seq[idx_sample]['character'],
                                       seq[idx_sample]['sample'])
            image_path = dataset.sampleSample4Character4Alphabet(
                seq[idx_sample]['alphabet'], seq[idx_sample]['character'],
                seq[idx_sample]['sample'])
            print(image_path, '/{}'.format(nbrCharacter4Task))