def create_dataset():
    """Create (and cache) the synthetic HDF5 dataset used by these tests.

    Builds ``num_seqs`` random sequences of shape ``(seq_lengths[i], dim)``
    and assigns class labels '0', '1', '1', '2', '2', '2', ... (label j
    repeated j+1 times), truncated to ``num_seqs`` entries.

    Returns:
        Utt2Info mapping each sequence key (its file path string) to its
        class label.
    """
    file_path = [str(k) for k in range(num_seqs)]

    # Label j contributes j+1 copies of str(j); string += list appends the
    # individual characters.  Loop on len(key) so we always produce at
    # least num_seqs labels (the previous counter `i += (i+1)` doubled and
    # stopped too early for num_seqs >= 7, yielding fewer labels than keys).
    key = []
    j = 0
    while len(key) < num_seqs:
        key += (j + 1) * str(j)
        j += 1
    key = key[:num_seqs]

    u2c = Utt2Info.create(file_path, key)

    # Data already written by a previous test: reuse it.
    if os.path.exists(h5_file):
        return u2c

    u2c.save(key_file, sep=' ')

    h = H5DataWriter(h5_file)
    rng = np.random.RandomState(seed=0)  # fixed seed -> reproducible data
    for i in range(num_seqs):
        x_i = rng.randn(seq_lengths[i], dim)
        h.write(file_path[i], x_i)

    return u2c
def test_read_full_seq():
    """Full-sequence reads without shuffling are reproducible across epochs
    and recover the exact sequence lengths and utterance/class mapping."""
    u2c = create_dataset()
    sr = SBG(h5_file, key_file, shuffle_seqs=False,
             gen_method='full_seqs', batch_size=5)

    x_e = []
    for epoch in xrange(2):
        x_prev = x_e
        key_e, c_e, x_e, sw_e = [], [], [], []

        for step in xrange(sr.steps_per_epoch):
            key_b, x_b, sw_b, y_b = sr.read()
            assert len(x_b) == 5
            key_e += key_b
            c_e += [str(c) for c in np.argmax(y_b, axis=-1)]
            x_e.append(x_b)
            sw_e.append(sw_b)

        x_e = np.vstack(tuple(x_e))
        sw_e = np.vstack(tuple(sw_e))
        # Sample weights are 1 over valid frames: summing recovers lengths.
        sl_e = np.sum(sw_e, axis=-1).astype(int)

        if epoch > 0:
            # Second epoch must replay the first exactly (no shuffling).
            assert_allclose(x_prev, x_e)
        assert_allclose(seq_lengths, sl_e)

        u2c_e = Utt2Info.create(key_e, c_e)
        assert u2c == u2c_e
def test_balance_class_weight():
    """Balanced class weighting with max_class_imbalance=1 replicates
    under-represented utterances until every class has equal count."""
    create_dataset()

    sr = SBG(h5_file, key_file, batch_size=5,
             class_weight='unbalanced', shuffle_seqs=False)
    u2c0 = sr.u2c

    sr = SBG(h5_file, key_file, batch_size=5,
             class_weight='balanced', shuffle_seqs=False,
             max_class_imbalance=1)

    # Four classes, each balanced to 4 utterances.
    class_ids = [str(c // 4) for c in range(16)]

    # Expected replication pattern of the unbalanced key list.
    k = u2c0.key
    key = ([k[0]] * 4 + list(k[1:3]) * 2
           + list(k[3:6]) + [k[3]] + list(k[6:]))

    print(key)
    print(class_ids)
    print(sr.u2c.key)
    print(sr.u2c.info)

    u2c = Utt2Info.create(key, class_ids)
    assert u2c == sr.u2c
def test_read_sequential_unbalanced():
    """Sequential chunked reads with unbalanced sequence weighting cover the
    expected utterance list, stay within the length bounds, and are
    reproducible across epochs."""
    u2c = create_dataset()

    # Expected utterance list: whole set repeated, plus progressively
    # shorter suffixes (long utterances yield more sequential chunks).
    reps = [u2c] * int(np.min(seq_lengths) / 10)
    reps += [u2c.filter_index(np.arange(start, num_seqs))
             for start in xrange(1, num_seqs)]
    u2c = Utt2Info.merge(reps)

    sr = SBG(h5_file, key_file, shuffle_seqs=False, reset_rng=True,
             min_seq_length=5, max_seq_length=17, seq_overlap=1,
             gen_method='sequential', seq_weight='unbalanced',
             batch_size=5)

    x_e = []
    for epoch in xrange(2):
        x_prev = x_e
        key_e, c_e, x_e, sw_e = [], [], [], []

        for step in xrange(sr.steps_per_epoch):
            key_b, x_b, sw_b, y_b = sr.read()
            assert len(x_b) == 5
            key_e += key_b
            c_e += [str(c) for c in np.argmax(y_b, axis=-1)]
            x_e.append(x_b)
            sw_e.append(sw_b)

        x_e = np.vstack(tuple(x_e))
        sw_e = np.vstack(tuple(sw_e))
        sl_e = np.sum(sw_e, axis=-1).astype(int)

        if epoch > 0:
            # reset_rng=True makes every epoch identical.
            assert_allclose(x_prev, x_e)
        assert np.all(np.logical_and(sl_e >= 5, sl_e <= 17))

        print(u2c.key)
        print(u2c.info)
        print(np.array(key_e))
        print(np.array(c_e))

        u2c_e = Utt2Info.create(key_e, c_e)
        assert u2c == u2c_e
def test_read_random():
    """Random-chunk reads with iters_per_epoch=2 visit each utterance twice
    per epoch, respect the length bounds, and replay identically when the
    RNG is reset each epoch."""
    u2c = create_dataset()
    # Two iterations per epoch -> every utterance appears twice.
    u2c = Utt2Info.merge([u2c] * 2)

    sr = SBG(h5_file, key_file, shuffle_seqs=False, reset_rng=True,
             iters_per_epoch=2, min_seq_length=10, max_seq_length=20,
             gen_method='random', batch_size=5)

    x_e = []
    for epoch in xrange(2):
        x_prev = x_e
        key_e, c_e, x_e, sw_e = [], [], [], []

        for step in xrange(sr.steps_per_epoch):
            key_b, x_b, sw_b, y_b = sr.read()
            assert len(x_b) == 5
            key_e += key_b
            c_e += [str(c) for c in np.argmax(y_b, axis=-1)]
            x_e.append(x_b)
            sw_e.append(sw_b)

        x_e = np.vstack(tuple(x_e))
        sw_e = np.vstack(tuple(sw_e))
        sl_e = np.sum(sw_e, axis=-1).astype(int)

        if epoch > 0:
            assert_allclose(x_prev, x_e)
        assert np.all(np.logical_and(sl_e >= 10, sl_e <= 20))

        u2c_e = Utt2Info.create(key_e, c_e)
        assert u2c == u2c_e