예제 #1
0
def test_SplitRandom_ratios():
    """Check three-way splitting by ratio and rejection of bad ratios."""
    parts = range(1000) >> SplitRandom(ratio=(0.6, 0.3, 0.1))
    train, val, test = parts
    assert (len(train), len(val), len(test)) == (600, 300, 100)

    # Ratios of 0.6 + 0.7 do not add up to one and must be refused.
    with pytest.raises(ValueError) as ex:
        range(1000) >> SplitRandom(ratio=(0.6, 0.7))
    message = str(ex.value)
    assert message.startswith('Ratios must sum up to one')
예제 #2
0
def test_SplitRandom_constraint():
    """Check a constrained split with a seeded ``rnd.Random`` generator.

    The constraint function returns the letter of each (letter, number)
    sample; the expected splits below keep equal letters together.
    """

    def same_letter(sample):
        return sample[0]

    samples = zip('aabbccddee', range(10))
    splitter = SplitRandom(
        rand=rnd.Random(0), ratio=0.6, constraint=same_letter)
    # Sort each split so the comparison does not depend on sample order.
    train, val = samples >> splitter >> Map(sorted)

    expected_train = [
        ('a', 0), ('a', 1), ('b', 2), ('b', 3), ('c', 4), ('c', 5)]
    expected_val = [('d', 6), ('d', 7), ('e', 8), ('e', 9)]
    assert train == expected_train
    assert val == expected_val
예제 #3
0
def test_SplitRandom_constraint():
    """Check a constrained split with ``StableRandom(0)``, order included.

    The splits are collected without sorting, so the assertions pin the
    exact sample order produced for this seed.
    """

    def same_letter(sample):
        return sample[0]

    samples = zip('aabbccddee', range(10))
    splitter = SplitRandom(
        rand=StableRandom(0), ratio=0.6, constraint=same_letter)
    train, val = samples >> splitter >> Collect()

    # Emit both splits to make a failing comparison easier to diagnose.
    print(train)
    print(val)

    expected_train = [
        ('a', 1), ('a', 0), ('d', 7), ('b', 2), ('d', 6), ('b', 3)]
    expected_val = [('c', 5), ('e', 8), ('e', 9), ('c', 4)]
    assert train == expected_train
    assert val == expected_val
예제 #4
0
def test_SplitRandom_constraint():
    """Check a constrained split when ``rand=None`` is passed.

    Both splits are compared in sorted order, so only membership of each
    split is pinned, not the order in which samples were emitted.
    """

    def same_letter(sample):
        return sample[0]

    samples = zip('aabbccddee', range(10))
    splitter = SplitRandom(rand=None, ratio=0.6, constraint=same_letter)
    train, val = samples >> splitter >> Collect()

    expected_train = [
        ('a', 0), ('a', 1), ('b', 2), ('b', 3), ('d', 6), ('d', 7)]
    expected_val = [('c', 4), ('c', 5), ('e', 8), ('e', 9)]
    assert sorted(train) == expected_train
    assert sorted(val) == expected_val
예제 #5
0
def train():
    """Train, validate and test the network, printing metrics per epoch.

    Builds the preprocessing/augmentation pipeline, creates the network via
    ``create_network()``, loads samples via ``load_samples()`` and holds out
    20% of the training samples for validation. For each of ``NUM_EPOCHS``
    epochs it trains on the augmented training samples, validates, prints
    mean loss/accuracy, saves the best network by validation accuracy and
    updates the accuracy plot. Finally it evaluates categorical accuracy on
    the test samples and prints it.
    """
    from keras.metrics import categorical_accuracy

    # Rescale pixel values from [0, 255] to [0, 1] as float32.
    rerange = TransformImage(0).by('rerange', 0, 255, 0, 1, 'float32')
    # Column 0 becomes the image batch, column 1 the one-hot label batch.
    build_batch = (BuildBatch(BATCH_SIZE)
                   .by(0, 'image', 'float32')
                   .by(1, 'one_hot', 'uint8', NUM_CLASSES))
    p = 0.1  # value passed to every augmentation except 'identical'
    augment = (AugmentImage(0)
               .by('identical', 1.0)
               .by('elastic', p, [5, 5], [100, 100], [0, 100])
               .by('brightness', p, [0.7, 1.3])
               .by('color', p, [0.7, 1.3])
               .by('shear', p, [0, 0.1])
               .by('fliplr', p)
               .by('rotate', p, [-10, 10]))
    plot_eval = PlotLines((0, 1), layout=(2, 1))

    print('creating network...')
    network = create_network()

    print('loading data...')
    train_samples, test_samples = load_samples()
    train_samples, val_samples = train_samples >> SplitRandom(0.8)

    print('training...', len(train_samples), len(val_samples))
    for epoch in range(NUM_EPOCHS):
        print('EPOCH:', epoch)

        t_loss, t_acc = (train_samples >> PrintProgress(train_samples) >>
                         Pick(PICK) >> augment >> rerange >> Shuffle(100) >>
                         build_batch >> network.train() >> Unzip())
        t_loss, t_acc = t_loss >> Mean(), t_acc >> Mean()
        print("train loss : {:.6f}".format(t_loss))
        print("train acc  : {:.1f}".format(100 * t_acc))

        v_loss, v_acc = (val_samples >> rerange >> build_batch >>
                         network.validate() >> Unzip())
        # The mean loss must come from the loss stream; averaging v_acc for
        # both values would report the accuracy as the validation loss.
        v_loss, v_acc = v_loss >> Mean(), v_acc >> Mean()
        print('val loss   : {:.6f}'.format(v_loss))
        print('val acc    : {:.1f}'.format(100 * v_acc))

        # Accuracy is a score to maximize, hence isloss=False.
        network.save_best(v_acc, isloss=False)
        plot_eval((t_acc, v_acc))

    print('testing...', len(test_samples))
    e_acc = (test_samples >> rerange >> build_batch >> network.evaluate(
        [categorical_accuracy]))
    print('test acc   : {:.1f}'.format(100 * e_acc))
예제 #6
0
def test_SplitRandom_seed():
    """Check that equal seeds give equal splits and a different seed does not."""
    seeded_zero = range(10) >> SplitRandom(rand=StableRandom(0))
    seeded_zero_again = range(10) >> SplitRandom(rand=StableRandom(0))
    seeded_one = range(10) >> SplitRandom(rand=StableRandom(1))

    assert seeded_zero == seeded_zero_again
    assert seeded_zero != seeded_one
예제 #7
0
def test_SplitRandom_stable_default():
    """Check that two default-constructed SplitRandom nuts split identically."""
    first = range(10) >> SplitRandom()
    second = range(10) >> SplitRandom()
    assert first == second
예제 #8
0
def test_SplitRandom_split():
    """Check a 70/30 split for partition sizes and for no shared elements."""
    parts = range(1000) >> SplitRandom(ratio=0.7)
    train, val = parts
    assert (len(train), len(val)) == (700, 300)
    # No element may appear in both partitions.
    assert set(train).isdisjoint(val)
예제 #9
0
def load_samples() -> (Samples, Samples, Samples):
    """Read labelled samples from ``CFG.datadir`` and split them randomly.

    Samples are read with ``ReadLabelDirs`` using the ``'*.json'`` pattern,
    column 1 of each sample is converted with ``int``, and the result is
    split by ``SplitRandom`` according to ``CFG.ratios``.
    """
    reader = ReadLabelDirs(CFG.datadir, '*.json')
    label_to_int = MapCol(1, int)
    splitter = SplitRandom(ratio=CFG.ratios)
    return reader >> label_to_int >> splitter