def test_batches_context_window():
    """Check chunked batch generation when the dataset has a context window.

    One sequence of 11 frames is chunked with ``chunk_size = chunk_step = 5``.
    The test expects three batches whose "classes" frames cover [0, 5),
    [5, 10) and [10, 11). For the "data" stream, every chunk is expected to be
    widened by the context window: ``ctx_left`` frames before the chunk start
    and ``ctx_right`` frames after the chunk end.
    """
    context_window = 2
    ctx_total = context_window - 1  # extra "data" frames per chunk (left + right)
    ctx_left = ctx_total // 2
    ctx_right = ctx_total - ctx_left

    dataset = DummyDataset(
        input_dim=2, output_dim=3, num_seqs=1, seq_len=11,
        context_window=context_window)
    dataset.init_seq_order(1)
    dataset.chunk_size = 5
    dataset.chunk_step = 5
    batch_gen = dataset.generate_batches(recurrent_net=True, max_seqs=1, batch_size=20)

    # Drain the generator one batch at a time.
    all_batches = []  # type: list[Batch]
    while batch_gen.has_more():
        (batch,) = batch_gen.peek_next_n(1)
        assert_is_instance(batch, Batch)
        print("batch:", batch)
        print("batch seqs:", batch.seqs)
        all_batches.append(batch)
        batch_gen.advance(1)

    # Each batch holds 1 batch-slice (max_seqs) and up to 5 frames (chunk_size).
    # With chunk_step 5 over 11 frames we get 3 chunks, thus 3 batches.
    assert_equal(len(all_batches), 3)
    for produced in all_batches:
        assert isinstance(produced, Batch)

    # (start, end) of the "classes" frames expected for each batch, in order.
    expected_chunks = [(0, 5), (5, 10), (10, 11)]
    for produced, (chunk_start, chunk_end) in zip(all_batches, expected_chunks):
        chunk_len = chunk_end - chunk_start

        assert_equal(produced.start_seq, 0)
        assert_equal(produced.end_seq, 1)  # exclusive
        assert_equal(len(produced.seqs), 1)  # 1 BatchSeqCopyPart

        part = produced.seqs[0]
        assert_equal(part.seq_idx, 0)
        # Targets are not affected by the context window.
        assert_equal(part.seq_start_frame["classes"], chunk_start)
        assert_equal(part.seq_end_frame["classes"], chunk_end)
        assert_equal(part.frame_length["classes"], chunk_len)
        # Inputs are extended by the context on both sides.
        assert_equal(part.seq_start_frame["data"], chunk_start - ctx_left)
        assert_equal(part.seq_end_frame["data"], chunk_end + ctx_right)
        assert_equal(part.frame_length["data"], chunk_len + ctx_total)
        assert_equal(part.batch_slice, 0)
        assert_equal(part.batch_frame_offset, 0)
def test_batches_recurrent_1(): dataset = DummyDataset(input_dim=2, output_dim=3, num_seqs=2, seq_len=11) dataset.init_seq_order(1) dataset.chunk_size = 10 dataset.chunk_step = 5 batch_gen = dataset.generate_batches(recurrent_net=True, max_seqs=1, batch_size=20) all_batches = [] " :type: list[Batch] " while batch_gen.has_more(): batch, = batch_gen.peek_next_n(1) assert_is_instance(batch, Batch) print("batch:", batch) print("batch seqs:", batch.seqs) all_batches.append(batch) batch_gen.advance(1) # Each batch will have 1 batch-slice (max_seqs) and up to 10 frames (chunk_size). # For each seq, we get 3 chunks (chunk_step 5 for 11 frames). # Thus, 6 batches. assert_equal(len(all_batches), 6) assert_equal(all_batches[0].start_seq, 0) assert_equal(all_batches[0].end_seq, 1) # exclusive assert_equal(len(all_batches[0].seqs), 1) # 1 BatchSeqCopyPart assert_equal(all_batches[0].seqs[0].seq_idx, 0) assert_equal(all_batches[0].seqs[0].seq_start_frame, 0) assert_equal(all_batches[0].seqs[0].seq_end_frame, 10) assert_equal(all_batches[0].seqs[0].frame_length, 10) assert_equal(all_batches[0].seqs[0].batch_slice, 0) assert_equal(all_batches[0].seqs[0].batch_frame_offset, 0) assert_equal(all_batches[1].start_seq, 0) assert_equal(all_batches[1].end_seq, 1) # exclusive assert_equal(len(all_batches[1].seqs), 1) # 1 BatchSeqCopyPart assert_equal(all_batches[1].seqs[0].seq_idx, 0) assert_equal(all_batches[1].seqs[0].seq_start_frame, 5) assert_equal(all_batches[1].seqs[0].seq_end_frame, 11) assert_equal(all_batches[1].seqs[0].frame_length, 6) assert_equal(all_batches[1].seqs[0].batch_slice, 0) assert_equal(all_batches[1].seqs[0].batch_frame_offset, 0) assert_equal(all_batches[2].start_seq, 0) assert_equal(all_batches[2].end_seq, 1) # exclusive assert_equal(len(all_batches[2].seqs), 1) # 1 BatchSeqCopyPart assert_equal(all_batches[2].seqs[0].seq_idx, 0) assert_equal(all_batches[2].seqs[0].seq_start_frame, 10) assert_equal(all_batches[2].seqs[0].seq_end_frame, 11) 
assert_equal(all_batches[2].seqs[0].frame_length, 1) assert_equal(all_batches[2].seqs[0].batch_slice, 0) assert_equal(all_batches[2].seqs[0].batch_frame_offset, 0) assert_equal(all_batches[3].start_seq, 1) assert_equal(all_batches[3].end_seq, 2) # exclusive assert_equal(len(all_batches[3].seqs), 1) # 1 BatchSeqCopyPart assert_equal(all_batches[3].seqs[0].seq_idx, 1) assert_equal(all_batches[3].seqs[0].seq_start_frame, 0) assert_equal(all_batches[3].seqs[0].seq_end_frame, 10) assert_equal(all_batches[3].seqs[0].frame_length, 10) assert_equal(all_batches[3].seqs[0].batch_slice, 0) assert_equal(all_batches[3].seqs[0].batch_frame_offset, 0)