コード例 #1
0
def create_dataset():
    """Build the synthetic dataset shared by the tests in this file.

    Sequence ``k`` is named ``str(k)``.  Class labels are assigned so that
    class ``j`` labels ``j + 1`` consecutive sequences
    (``'0', '1', '1', '2', '2', '2', ...``), truncated to ``num_seqs``
    entries.

    If ``h5_file`` does not exist yet, the name/label table is saved to
    ``key_file`` and one random ``(seq_lengths[i], dim)`` matrix per sequence
    is written to ``h5_file``.  If ``h5_file`` already exists nothing is
    written.

    Returns:
        The object produced by ``Utt2Info.create(file_path, key)``.
    """
    file_path = [str(k) for k in xrange(num_seqs)]

    # Loop on the number of labels actually generated.  The previous counter
    # (i -> 2*i + 1) did not track that number, so for some values of
    # num_seqs (e.g. 7 or 11..15) fewer than num_seqs labels were produced.
    # Appending whole strings also keeps multi-digit class ids intact.
    key = []
    j = 0
    while len(key) < num_seqs:
        key.extend([str(j)] * (j + 1))
        j += 1
    key = key[:num_seqs]

    u2c = Utt2Info.create(file_path, key)

    # An existing h5_file is taken to mean the fixtures are already on disk.
    if os.path.exists(h5_file):
        return u2c

    u2c.save(key_file, sep=' ')

    # NOTE(review): the writer is never explicitly closed here; confirm that
    # H5DataWriter flushes its contents on its own.
    h = H5DataWriter(h5_file)
    # Fixed seed so that the generated features are reproducible.
    rng = np.random.RandomState(seed=0)

    for i in xrange(num_seqs):
        x_i = rng.randn(seq_lengths[i], dim)
        h.write(file_path[i], x_i)

    return u2c
コード例 #2
0
def test_read_full_seq():
    """Read full sequences without shuffling and check two epochs.

    For each of two epochs, every batch must hold 5 items.  The per-item sum
    of the sample weights along the last axis must match ``seq_lengths``,
    the keys and arg-max labels read must rebuild the table returned by
    ``create_dataset``, and the data of the second epoch must match the
    data of the first.
    """
    expected_u2c = create_dataset()
    batch_gen = SBG(
        h5_file,
        key_file,
        shuffle_seqs=False,
        gen_method='full_seqs',
        batch_size=5)

    prev_x = []
    for epoch in xrange(2):
        keys, labels, feats, weights = [], [], [], []
        for _ in xrange(batch_gen.steps_per_epoch):
            batch_keys, batch_x, batch_sw, batch_y = batch_gen.read()
            assert len(batch_x) == 5
            keys.extend(batch_keys)
            # The label of each item is the arg-max over the last axis.
            labels.extend(str(c) for c in np.argmax(batch_y, axis=-1))
            feats.append(batch_x)
            weights.append(batch_sw)

        epoch_x = np.vstack(tuple(feats))
        epoch_sw = np.vstack(tuple(weights))
        epoch_lengths = np.sum(epoch_sw, axis=-1).astype(int)

        # From the second epoch on, the data must repeat the previous epoch.
        if epoch > 0:
            assert_allclose(prev_x, epoch_x)
        assert_allclose(seq_lengths, epoch_lengths)
        assert expected_u2c == Utt2Info.create(keys, labels)

        prev_x = epoch_x
コード例 #3
0
def test_balance_class_weight():
    """Check the table built with ``class_weight='balanced'``.

    A generator created with ``class_weight='unbalanced'`` provides the
    reference table.  The table of a second generator created with
    ``class_weight='balanced'`` and ``max_class_imbalance=1`` is expected to
    hold four entries for each of the classes '0'..'3', with the reference
    keys repeated as spelled out below.
    """
    create_dataset()

    reference_gen = SBG(
        h5_file,
        key_file,
        batch_size=5,
        class_weight='unbalanced',
        shuffle_seqs=False)
    u2c0 = reference_gen.u2c

    sr = SBG(
        h5_file,
        key_file,
        batch_size=5,
        class_weight='balanced',
        shuffle_seqs=False,
        max_class_imbalance=1)

    # Four consecutive entries for each of the classes '0', '1', '2', '3'.
    class_ids = [str(c) for c in range(4) for _ in range(4)]

    ref_keys = u2c0.key
    key = [ref_keys[0]] * 4          # first key, four times
    key += list(ref_keys[1:3]) * 2   # keys 1-2, twice in a row
    key += list(ref_keys[3:6])       # keys 3-5 once ...
    key += [ref_keys[3]]             # ... plus key 3 again
    key += list(ref_keys[6:])        # remaining keys unchanged

    print(key)
    print(class_ids)
    print(sr.u2c.key)
    print(sr.u2c.info)

    expected_u2c = Utt2Info.create(key, class_ids)
    assert expected_u2c == sr.u2c
コード例 #4
0
def test_read_sequential_unbalanced():
    """Read with ``gen_method='sequential'`` and ``seq_weight='unbalanced'``.

    The expected table merges ``int(min(seq_lengths) / 10)`` copies of the
    full table with, for every ``i`` in ``1..num_seqs-1``, the table
    restricted to indices ``i..num_seqs-1``.  For each of two epochs, every
    batch must hold 5 items, each item's summed sample weights must lie in
    [5, 17], the keys/labels read must rebuild the expected table, and the
    data of the second epoch must match the data of the first.
    """
    base_u2c = create_dataset()
    num_full_copies = int(np.min(seq_lengths) / 10)
    parts = [base_u2c] * num_full_copies + [
        base_u2c.filter_index(np.arange(start, num_seqs))
        for start in xrange(1, num_seqs)
    ]
    expected_u2c = Utt2Info.merge(parts)

    batch_gen = SBG(
        h5_file,
        key_file,
        shuffle_seqs=False,
        reset_rng=True,
        min_seq_length=5,
        max_seq_length=17,
        seq_overlap=1,
        gen_method='sequential',
        seq_weight='unbalanced',
        batch_size=5)

    prev_x = []
    for epoch in xrange(2):
        keys, labels, feats, weights = [], [], [], []
        for _ in xrange(batch_gen.steps_per_epoch):
            batch_keys, batch_x, batch_sw, batch_y = batch_gen.read()
            assert len(batch_x) == 5
            keys.extend(batch_keys)
            # The label of each item is the arg-max over the last axis.
            labels.extend(str(c) for c in np.argmax(batch_y, axis=-1))
            feats.append(batch_x)
            weights.append(batch_sw)

        epoch_x = np.vstack(tuple(feats))
        epoch_sw = np.vstack(tuple(weights))
        epoch_lengths = np.sum(epoch_sw, axis=-1).astype(int)

        # From the second epoch on, the data must repeat the previous epoch.
        if epoch > 0:
            assert_allclose(prev_x, epoch_x)
        in_range = np.logical_and(epoch_lengths >= 5, epoch_lengths <= 17)
        assert np.all(in_range)

        print(expected_u2c.key)
        print(expected_u2c.info)
        print(np.array(keys))
        print(np.array(labels))

        assert expected_u2c == Utt2Info.create(keys, labels)

        prev_x = epoch_x
コード例 #5
0
def test_read_random():
    """Read with ``gen_method='random'`` and ``iters_per_epoch=2``.

    The expected table is the dataset table merged with itself.  For each of
    two epochs, every batch must hold 5 items, each item's summed sample
    weights must lie in [10, 20], the keys/labels read must rebuild the
    expected table, and the data of the second epoch must match the data of
    the first.
    """
    base_u2c = create_dataset()
    expected_u2c = Utt2Info.merge([base_u2c, base_u2c])

    batch_gen = SBG(
        h5_file,
        key_file,
        shuffle_seqs=False,
        reset_rng=True,
        iters_per_epoch=2,
        min_seq_length=10,
        max_seq_length=20,
        gen_method='random',
        batch_size=5)

    prev_x = []
    for epoch in xrange(2):
        keys, labels, feats, weights = [], [], [], []
        for _ in xrange(batch_gen.steps_per_epoch):
            batch_keys, batch_x, batch_sw, batch_y = batch_gen.read()
            assert len(batch_x) == 5
            keys.extend(batch_keys)
            # The label of each item is the arg-max over the last axis.
            labels.extend(str(c) for c in np.argmax(batch_y, axis=-1))
            feats.append(batch_x)
            weights.append(batch_sw)

        epoch_x = np.vstack(tuple(feats))
        epoch_sw = np.vstack(tuple(weights))
        epoch_lengths = np.sum(epoch_sw, axis=-1).astype(int)

        # From the second epoch on, the data must repeat the previous epoch.
        if epoch > 0:
            assert_allclose(prev_x, epoch_x)
        in_range = np.logical_and(epoch_lengths >= 10, epoch_lengths <= 20)
        assert np.all(in_range)
        assert expected_u2c == Utt2Info.create(keys, labels)

        prev_x = epoch_x