def test_read_random_samples():
    """Read two epochs in 'random_samples' split mode and compare them.

    Checks that every batch holds 5 items, that one epoch yields as many
    items as fit into whole batches of ``num_total_subseqs``, that both
    epochs have the same number of items, that each first-epoch item has a
    first dimension equal to ``max_batch_seq_length``, and that paired items
    from the two epochs differ in at least one element.
    """
    sr = SequenceReader(h5_file, key_file,
                        batch_size=5,
                        max_seq_length=20, min_seq_length=20,
                        seq_split_mode='random_samples',
                        seq_split_overlap=5)

    # First pass over the data.
    epoch1 = []
    for _ in xrange(sr.num_batches):
        batch = sr.read()[0]
        assert len(batch) == 5
        epoch1.extend(batch)

    # Second pass over the data.
    epoch2 = []
    for _ in xrange(sr.num_batches):
        batch = sr.read()[0]
        assert len(batch) == 5
        epoch2.extend(batch)

    # Only complete batches are expected to be delivered in an epoch.
    whole_batches = int(sr.num_total_subseqs / sr.batch_size)
    assert len(epoch1) == whole_batches * sr.batch_size
    assert len(epoch1) == len(epoch2)

    # Lengths match (asserted above), so pairing with zip visits every index.
    for first, second in zip(epoch1, epoch2):
        assert first.shape[0] == sr.max_batch_seq_length
        assert np.any(first != second)
def test_read_sequential():
    """Read two epochs in 'sequential' split mode and compare them.

    With ``shuffle_seqs=False`` the test expects both epochs to return the
    same data: every batch holds 5 items, an epoch yields
    ``num_total_subseqs`` items, no first-epoch item is longer than
    ``max_batch_seq_length`` along its first dimension, and paired items
    from the two epochs are element-wise equal.
    """
    sr = SequenceReader(h5_file, key_file,
                        shuffle_seqs=False,
                        batch_size=5,
                        max_seq_length=20,
                        seq_split_mode='sequential',
                        seq_split_overlap=5)

    # First pass over the data.
    epoch1 = []
    for _ in xrange(sr.num_batches):
        batch = sr.read()[0]
        assert len(batch) == 5
        epoch1.extend(batch)

    # Second pass over the data.
    epoch2 = []
    for _ in xrange(sr.num_batches):
        batch = sr.read()[0]
        assert len(batch) == 5
        epoch2.extend(batch)

    assert len(epoch1) == sr.num_total_subseqs
    assert len(epoch1) == len(epoch2)

    # Lengths match (asserted above), so pairing with zip visits every index.
    for first, second in zip(epoch1, epoch2):
        assert first.shape[0] <= sr.max_batch_seq_length
        assert np.all(first == second)
def test_read_random_slice_1seq():
    """Read two epochs in 'random_slice_1seq' split mode and compare them.

    Checks that every batch holds 5 items, that the number of items in an
    epoch is consistent with ``num_batches``, ``num_seqs`` and
    ``num_total_subseqs``, that items of both epochs have a first dimension
    equal to ``max_batch_seq_length``, and that paired items from the two
    epochs differ in every element.
    """
    sr = SequenceReader(h5_file, key_file,
                        shuffle_seqs=False,
                        batch_size=5,
                        max_seq_length=20, min_seq_length=20,
                        seq_split_mode='random_slice_1seq')
    print(sr.num_batches)

    # First pass over the data.
    epoch1 = []
    for _ in xrange(sr.num_batches):
        batch = sr.read()[0]
        assert len(batch) == 5
        epoch1.extend(batch)

    # Second pass over the data.
    epoch2 = []
    for _ in xrange(sr.num_batches):
        batch = sr.read()[0]
        assert len(batch) == 5
        epoch2.extend(batch)

    num_items = len(epoch1)
    assert int(num_items / 5) == sr.num_batches
    assert num_items == sr.num_seqs
    assert num_items == sr.num_total_subseqs
    assert num_items == len(epoch2)

    # Lengths match (asserted above), so pairing with zip visits every index.
    for first, second in zip(epoch1, epoch2):
        assert first.shape[0] == sr.max_batch_seq_length
        assert second.shape[0] == sr.max_batch_seq_length
        assert np.all(first != second)
def test_read_full_seq():
    """Read two epochs with default splitting options and compare them.

    Calls ``create_dataset()`` first, then checks that every batch holds 5
    items, that both epochs have the same number of items, that the i-th
    item's first dimension equals ``sr.seq_length[i]`` (captured before
    reading), and that paired items from the two epochs are element-wise
    equal.
    """
    create_dataset()
    sr = SequenceReader(h5_file, key_file,
                        shuffle_seqs=False,
                        batch_size=5)
    seq_length = sr.seq_length

    # First pass over the data.
    epoch1 = []
    for _ in xrange(sr.num_batches):
        batch = sr.read()[0]
        assert len(batch) == 5
        epoch1.extend(batch)

    # Second pass over the data.
    epoch2 = []
    for _ in xrange(sr.num_batches):
        batch = sr.read()[0]
        assert len(batch) == 5
        epoch2.extend(batch)

    assert len(epoch1) == len(epoch2)

    # Lengths match (asserted above); the index is kept for seq_length lookup.
    for i, (first, second) in enumerate(zip(epoch1, epoch2)):
        assert first.shape[0] == seq_length[i]
        assert np.all(first == second)