예제 #1
0
def test_balance_class_weight():
    """Check the utterance table built with ``class_weight='balanced'``.

    A generator created with ``class_weight='unbalanced'`` supplies the
    reference keys. The generator created with ``class_weight='balanced'``
    and ``max_class_imbalance=1`` must then expose a table equal to the one
    assembled below, where each of the class ids '0'..'3' appears four times.
    """
    create_dataset()
    unbalanced = SBG(h5_file, key_file,
                     batch_size=5,
                     class_weight='unbalanced',
                     shuffle_seqs=False)
    ref_u2c = unbalanced.u2c

    balanced = SBG(h5_file, key_file,
                   batch_size=5,
                   class_weight='balanced',
                   shuffle_seqs=False,
                   max_class_imbalance=1)

    # Four copies of each class id, as strings: '0' x4, '1' x4, '2' x4, '3' x4.
    expected_classes = [str(c) for c in range(4) for _ in range(4)]

    # Expected keys: first key four times, keys 1-2 twice over, keys 3-5 once
    # plus key 3 again, then every remaining key once.
    ref_keys = ref_u2c.key
    expected_keys = [ref_keys[0]] * 4
    expected_keys.extend(list(ref_keys[1:3]) * 2)
    expected_keys.extend(ref_keys[3:6])
    expected_keys.append(ref_keys[3])
    expected_keys.extend(ref_keys[6:])

    print(expected_keys)
    print(expected_classes)
    print(balanced.u2c.key)
    print(balanced.u2c.info)

    expected_u2c = Utt2Info.create(expected_keys, expected_classes)
    assert expected_u2c == balanced.u2c
예제 #2
0
def test_read_full_seq():
    """Read two epochs with ``gen_method='full_seqs'`` and verify the output.

    For every epoch the batches returned by ``read()`` are accumulated; the
    per-sequence sums of the sample weights must match the externally defined
    ``seq_lengths`` and the collected keys/classes must rebuild the table
    returned by ``create_dataset()``. The second epoch must also reproduce the
    features of the first.
    """
    expected_u2c = create_dataset()
    reader = SBG(h5_file, key_file,
                 shuffle_seqs=False,
                 gen_method='full_seqs',
                 batch_size=5)

    feats = []
    for epoch in xrange(2):
        # Keep the previous epoch's stacked features for the comparison below.
        prev_feats = feats
        keys, labels = [], []
        feat_batches, weight_batches = [], []
        for _ in xrange(reader.steps_per_epoch):
            batch_keys, batch_x, batch_sw, batch_y = reader.read()
            assert len(batch_x) == 5
            keys.extend(batch_keys)
            # Class labels are recovered as the argmax over the last axis.
            labels.extend(str(c) for c in np.argmax(batch_y, axis=-1))
            feat_batches.append(batch_x)
            weight_batches.append(batch_sw)

        feats = np.vstack(feat_batches)
        weights = np.vstack(weight_batches)
        lengths = weights.sum(axis=-1).astype(int)

        if epoch > 0:
            assert_allclose(prev_feats, feats)
        assert_allclose(seq_lengths, lengths)

        epoch_u2c = Utt2Info.create(keys, labels)
        assert expected_u2c == epoch_u2c
예제 #3
0
def test_reset():
    """Check generator state after enabling shuffling and calling ``reset()``.

    The ``init_*`` attributes must be unchanged by the reset. The current
    table, sequence lengths and sub-sequence counts must match the initial
    ones once both sides are sorted by key, and the ``cur_subseq`` /
    ``cur_frame`` counters must all be zero.
    """
    create_dataset()
    gen = SBG(h5_file, key_file,
              batch_size=5,
              gen_method='sequential',
              shuffle_seqs=False,
              min_seq_length=5,
              max_seq_length=17,
              seq_overlap=1)

    # Snapshot taken before the reset.
    ref_u2c = gen.init_u2c
    ref_lengths = gen.seq_lengths
    ref_num_subseqs = gen._init_num_subseqs

    gen.shuffle_seqs = True
    gen.reset()

    # The initial bookkeeping must not be altered by reset().
    assert ref_u2c == gen.init_u2c
    assert_allclose(ref_lengths, gen._init_seq_lengths)
    assert_allclose(ref_num_subseqs, gen._init_num_subseqs)

    # Compare the current state with the snapshot after sorting both by key.
    ref_order = np.argsort(ref_u2c.key)
    cur_order = np.argsort(gen.u2c.key)

    sorted_ref = ref_u2c.filter_index(ref_order)
    sorted_cur = gen.u2c.filter_index(cur_order)

    assert sorted_ref == sorted_cur
    assert_allclose(ref_lengths[ref_order], gen.seq_lengths[cur_order])
    assert_allclose(ref_num_subseqs[ref_order], gen.num_subseqs[cur_order])
    assert np.all(gen.cur_subseq == 0)
    assert np.all(gen.cur_frame == 0)
    assert np.all(sr.cur_frame == 0)
예제 #4
0
def test_compute_iters_auto():
    """Check ``iters_per_epoch`` when it is not passed to the constructor.

    With ``gen_method='random'`` the value must be 1 by default and 2 when
    ``max_seq_length`` is set to the externally defined ``min_seq_length``.
    """
    create_dataset()
    shared_args = dict(batch_size=5, gen_method='random', shuffle_seqs=False)

    default_gen = SBG(h5_file, key_file, **shared_args)
    assert default_gen.iters_per_epoch == 1

    capped_gen = SBG(h5_file, key_file,
                     max_seq_length=min_seq_length,
                     **shared_args)
    assert capped_gen.iters_per_epoch == 2
예제 #5
0
def test_seq_lengths():
    """Check the length statistics reported by the generator.

    ``seq_lengths``, ``total_length``, ``min_seq_length`` and
    ``max_seq_length`` are compared with the externally defined reference
    values of the same names.
    """
    create_dataset()
    gen = SBG(h5_file, key_file, shuffle_seqs=False)

    lengths_match = gen.seq_lengths == seq_lengths
    assert np.all(lengths_match)

    expected_total = np.sum(seq_lengths)
    assert gen.total_length == expected_total

    assert gen.min_seq_length == min_seq_length
    assert gen.max_seq_length == max_seq_length
예제 #6
0
def test_class_info():
    """Check the number of classes and the key-to-class mapping.

    The generator must report four classes, and ``key2class`` must map the
    keys of ``u2c``, in order, to the integer class ids listed below.
    """
    create_dataset()
    gen = SBG(h5_file, key_file, batch_size=5, shuffle_seqs=False)
    assert gen.num_classes == 4

    print(gen.u2c.key)
    print(gen.u2c.info)

    # One key for class 0, two for class 1, three for class 2, four for class 3.
    expected_ids = [0, 1, 1, 2, 2, 2, 3, 3, 3, 3]
    expected_map = dict(zip(gen.u2c.key, expected_ids))
    assert gen.key2class == expected_map
예제 #7
0
def test_prune_min_length():
    """Check the statistics when ``prune_min_length`` is ``min_seq_length + 5``.

    The generator must then hold one sequence fewer than ``num_seqs``, and its
    lengths must equal the reference ``seq_lengths`` without the first entry.
    """
    create_dataset()
    threshold = min_seq_length + 5
    gen = SBG(h5_file, key_file,
              batch_size=5,
              shuffle_seqs=False,
              prune_min_length=threshold)

    # Reference lengths with the first sequence dropped.
    remaining = seq_lengths[1:]

    assert gen.num_seqs == num_seqs - 1
    assert np.all(gen.seq_lengths == remaining)
    assert gen.total_length == np.sum(remaining)
    assert gen.min_seq_length == np.min(remaining)
    assert gen.max_seq_length == max_seq_length
예제 #8
0
def test_prepare_sequential_subseqs():
    """Check the initial sub-sequence counts in sequential mode.

    With ``min_seq_length=5``, ``max_seq_length=17`` and ``seq_overlap=1`` the
    test expects ``_init_num_subseqs`` to equal ``seq_lengths / 10``.
    """
    create_dataset()
    gen = SBG(h5_file, key_file,
              batch_size=5,
              gen_method='sequential',
              shuffle_seqs=False,
              min_seq_length=5,
              max_seq_length=17,
              seq_overlap=1)

    print(gen._init_num_subseqs)

    expected_counts = seq_lengths / 10
    assert_allclose(gen._init_num_subseqs, expected_counts)
예제 #9
0
def test_read_sequential_unbalanced():
    """Read two epochs with ``gen_method='sequential'``, unbalanced weights.

    The expected table is a merge of ``int(min(seq_lengths) / 10)`` copies of
    the full dataset table followed by, for each ``i`` in ``1..num_seqs - 1``,
    the table restricted to indices ``i..num_seqs - 1``. Every epoch must
    rebuild that table, yield lengths within [5, 17], and the second epoch
    must reproduce the features of the first.
    """
    full_u2c = create_dataset()

    # Build the expected table from repeated and progressively shorter tables.
    parts = [full_u2c] * int(np.min(seq_lengths) / 10)
    parts.extend(full_u2c.filter_index(np.arange(start, num_seqs))
                 for start in xrange(1, num_seqs))
    expected_u2c = Utt2Info.merge(parts)

    reader = SBG(h5_file, key_file,
                 shuffle_seqs=False,
                 reset_rng=True,
                 min_seq_length=5,
                 max_seq_length=17,
                 seq_overlap=1,
                 gen_method='sequential',
                 seq_weight='unbalanced',
                 batch_size=5)

    feats = []
    for epoch in xrange(2):
        # Keep the previous epoch's stacked features for the comparison below.
        prev_feats = feats
        keys, labels = [], []
        feat_batches, weight_batches = [], []
        for _ in xrange(reader.steps_per_epoch):
            batch_keys, batch_x, batch_sw, batch_y = reader.read()
            assert len(batch_x) == 5
            keys.extend(batch_keys)
            # Class labels are recovered as the argmax over the last axis.
            labels.extend(str(c) for c in np.argmax(batch_y, axis=-1))
            feat_batches.append(batch_x)
            weight_batches.append(batch_sw)

        feats = np.vstack(feat_batches)
        weights = np.vstack(weight_batches)
        lengths = weights.sum(axis=-1).astype(int)

        if epoch > 0:
            assert_allclose(prev_feats, feats)
        assert np.all((lengths >= 5) & (lengths <= 17))
        print(expected_u2c.key)
        print(expected_u2c.info)
        print(np.array(keys))
        print(np.array(labels))

        epoch_u2c = Utt2Info.create(keys, labels)
        assert expected_u2c == epoch_u2c
예제 #10
0
def test_read_random():
    """Read two epochs with ``gen_method='random'`` and ``iters_per_epoch=2``.

    The expected table is the dataset table merged with itself. Every epoch
    must rebuild that table, yield lengths within [10, 20], and the second
    epoch must reproduce the features of the first (the generator is built
    with ``reset_rng=True``).
    """
    base_u2c = create_dataset()
    expected_u2c = Utt2Info.merge([base_u2c] * 2)
    reader = SBG(h5_file, key_file,
                 shuffle_seqs=False,
                 reset_rng=True,
                 iters_per_epoch=2,
                 min_seq_length=10,
                 max_seq_length=20,
                 gen_method='random',
                 batch_size=5)

    feats = []
    for epoch in xrange(2):
        # Keep the previous epoch's stacked features for the comparison below.
        prev_feats = feats
        keys, labels = [], []
        feat_batches, weight_batches = [], []
        for _ in xrange(reader.steps_per_epoch):
            batch_keys, batch_x, batch_sw, batch_y = reader.read()
            assert len(batch_x) == 5
            keys.extend(batch_keys)
            # Class labels are recovered as the argmax over the last axis.
            labels.extend(str(c) for c in np.argmax(batch_y, axis=-1))
            feat_batches.append(batch_x)
            weight_batches.append(batch_sw)

        feats = np.vstack(feat_batches)
        weights = np.vstack(weight_batches)
        lengths = weights.sum(axis=-1).astype(int)

        if epoch > 0:
            assert_allclose(prev_feats, feats)
        assert np.all((lengths >= 10) & (lengths <= 20))

        epoch_u2c = Utt2Info.create(keys, labels)
        assert expected_u2c == epoch_u2c
예제 #11
0
def test_num_total_subseqs():
    """Check ``num_total_subseqs`` when ``gen_method='full_seqs'``.

    The generator's ``num_total_subseqs`` is expected to equal the externally
    defined ``num_seqs``.
    """
    create_dataset()
    sr = SBG(h5_file, key_file, gen_method='full_seqs', batch_size=5)
    # The comparison was previously a bare expression whose result was
    # discarded, so the test could never fail on a wrong value.
    assert sr.num_total_subseqs == num_seqs
예제 #12
0
def test_num_seqs():
    """Check that the generator reports ``num_seqs`` sequences."""
    create_dataset()
    gen = SBG(h5_file, key_file)
    found = gen.num_seqs
    assert found == num_seqs