Пример #1
0
def test_batches_context_window():
    """Test recurrent chunked batching when the dataset has a context window.

    A single 11-frame sequence is chunked with chunk_size 5 and chunk_step 5,
    with one sequence per batch (max_seqs=1). For the "classes" target each
    chunk covers exactly its own [start, end) frame range. For the "data"
    input that range is widened by ``ctx_left`` frames on the left and
    ``ctx_right`` frames on the right, so its length grows by ``ctx_lr``.
    """
    context_window = 2
    # Total number of extra context frames, split into a left and a right part
    # (with an odd total, the right side gets the extra frame).
    ctx_lr = context_window - 1
    ctx_left = ctx_lr // 2
    ctx_right = ctx_lr - ctx_left

    dataset = DummyDataset(input_dim=2,
                           output_dim=3,
                           num_seqs=1,
                           seq_len=11,
                           context_window=context_window)
    dataset.init_seq_order(1)
    dataset.chunk_size = 5
    dataset.chunk_step = 5
    batch_gen = dataset.generate_batches(recurrent_net=True,
                                         max_seqs=1,
                                         batch_size=20)
    all_batches = []  # type: list[Batch]
    # Drain the generator one batch at a time, collecting every batch.
    while batch_gen.has_more():
        batch, = batch_gen.peek_next_n(1)
        assert_is_instance(batch, Batch)
        print("batch:", batch)
        print("batch seqs:", batch.seqs)
        all_batches.append(batch)
        batch_gen.advance(1)

    # Each batch will have 1 batch-slice (max_seqs) and up to 5 frames (chunk_size).
    # The single seq of 11 frames gives 3 chunks (chunk_step 5):
    # [0, 5), [5, 10) and [10, 11).
    # Thus, 3 batches.
    expected_class_ranges = [(0, 5), (5, 10), (10, 11)]
    assert_equal(len(all_batches), len(expected_class_ranges))
    for batch, (start, end) in zip(all_batches, expected_class_ranges):
        assert isinstance(batch, Batch)
        assert_equal(batch.start_seq, 0)
        assert_equal(batch.end_seq, 1)  # exclusive
        assert_equal(len(batch.seqs), 1)  # 1 BatchSeqCopyPart
        part = batch.seqs[0]
        assert_equal(part.seq_idx, 0)
        # The target covers exactly the chunk's own frames.
        assert_equal(part.seq_start_frame["classes"], start)
        assert_equal(part.seq_end_frame["classes"], end)
        assert_equal(part.frame_length["classes"], end - start)
        # The input is extended by the context window on both sides.
        assert_equal(part.seq_start_frame["data"], start - ctx_left)
        assert_equal(part.seq_end_frame["data"], end + ctx_right)
        assert_equal(part.frame_length["data"], end - start + ctx_lr)
        # One seq per batch, so it always sits in slice 0 at frame offset 0.
        assert_equal(part.batch_slice, 0)
        assert_equal(part.batch_frame_offset, 0)
Пример #2
0
def test_batches_recurrent_1():
    """Test recurrent chunked batching over two sequences, one seq per batch.

    Two 11-frame sequences are chunked with chunk_size 10 and chunk_step 5,
    so each sequence yields the overlapping chunks [0, 10), [5, 11) and
    [10, 11), each placed in its own batch (max_seqs=1).
    """
    dataset = DummyDataset(input_dim=2, output_dim=3, num_seqs=2, seq_len=11)
    dataset.init_seq_order(1)
    dataset.chunk_size = 10
    dataset.chunk_step = 5
    batch_gen = dataset.generate_batches(recurrent_net=True,
                                         max_seqs=1,
                                         batch_size=20)
    all_batches = []
    # The string statement below is a legacy type hint for all_batches;
    # it has no runtime effect.
    " :type: list[Batch] "
    # Drain the generator one batch at a time, collecting every batch.
    while batch_gen.has_more():
        batch, = batch_gen.peek_next_n(1)
        assert_is_instance(batch, Batch)
        print("batch:", batch)
        print("batch seqs:", batch.seqs)
        all_batches.append(batch)
        batch_gen.advance(1)

    # Each batch will have 1 batch-slice (max_seqs) and up to 10 frames (chunk_size).
    # For each seq, we get 3 chunks (chunk_step 5 for 11 frames).
    # Thus, 6 batches.
    assert_equal(len(all_batches), 6)

    # Batch 0: seq 0, frames [0, 10).
    assert_equal(all_batches[0].start_seq, 0)
    assert_equal(all_batches[0].end_seq, 1)  # exclusive
    assert_equal(len(all_batches[0].seqs), 1)  # 1 BatchSeqCopyPart
    assert_equal(all_batches[0].seqs[0].seq_idx, 0)
    assert_equal(all_batches[0].seqs[0].seq_start_frame, 0)
    assert_equal(all_batches[0].seqs[0].seq_end_frame, 10)
    assert_equal(all_batches[0].seqs[0].frame_length, 10)
    assert_equal(all_batches[0].seqs[0].batch_slice, 0)
    assert_equal(all_batches[0].seqs[0].batch_frame_offset, 0)

    # Batch 1: seq 0, frames [5, 11); the chunk is truncated at the seq end.
    assert_equal(all_batches[1].start_seq, 0)
    assert_equal(all_batches[1].end_seq, 1)  # exclusive
    assert_equal(len(all_batches[1].seqs), 1)  # 1 BatchSeqCopyPart
    assert_equal(all_batches[1].seqs[0].seq_idx, 0)
    assert_equal(all_batches[1].seqs[0].seq_start_frame, 5)
    assert_equal(all_batches[1].seqs[0].seq_end_frame, 11)
    assert_equal(all_batches[1].seqs[0].frame_length, 6)
    assert_equal(all_batches[1].seqs[0].batch_slice, 0)
    assert_equal(all_batches[1].seqs[0].batch_frame_offset, 0)

    # Batch 2: seq 0, frames [10, 11); the last single-frame chunk.
    assert_equal(all_batches[2].start_seq, 0)
    assert_equal(all_batches[2].end_seq, 1)  # exclusive
    assert_equal(len(all_batches[2].seqs), 1)  # 1 BatchSeqCopyPart
    assert_equal(all_batches[2].seqs[0].seq_idx, 0)
    assert_equal(all_batches[2].seqs[0].seq_start_frame, 10)
    assert_equal(all_batches[2].seqs[0].seq_end_frame, 11)
    assert_equal(all_batches[2].seqs[0].frame_length, 1)
    assert_equal(all_batches[2].seqs[0].batch_slice, 0)
    assert_equal(all_batches[2].seqs[0].batch_frame_offset, 0)

    # Batch 3: seq 1 starts, frames [0, 10).
    assert_equal(all_batches[3].start_seq, 1)
    assert_equal(all_batches[3].end_seq, 2)  # exclusive
    assert_equal(len(all_batches[3].seqs), 1)  # 1 BatchSeqCopyPart
    assert_equal(all_batches[3].seqs[0].seq_idx, 1)
    assert_equal(all_batches[3].seqs[0].seq_start_frame, 0)
    assert_equal(all_batches[3].seqs[0].seq_end_frame, 10)
    assert_equal(all_batches[3].seqs[0].frame_length, 10)
    assert_equal(all_batches[3].seqs[0].batch_slice, 0)
    assert_equal(all_batches[3].seqs[0].batch_frame_offset, 0)