def test_reverse_compute_inline_fail_multi_producer():
    """Reverse-inlining block "D" of `elementwise_multi_producer_consumer` must raise ScheduleError."""
    schedule = tir.Schedule(elementwise_multi_producer_consumer, debug_mask="all")
    target_block = schedule.get_block("D")
    with pytest.raises(tvm.tir.ScheduleError):
        schedule.reverse_compute_inline(target_block)
def test_tir_schedule_error_none():
    """With error_render_level="none", ScheduleError messages are left unrendered.

    Looks up a block name that does not exist in `matmul` and checks that the
    raised ScheduleError carries the "(not rendered)" placeholder instead of
    a descriptive message.
    """
    # The constructor keyword is `debug_mask` (as every other test in this file
    # uses); the stale `debug_mode=True` spelling is not accepted, so the test
    # would fail in the constructor before exercising error rendering at all.
    sch = tir.Schedule(matmul, debug_mask="all", error_render_level="none")
    with pytest.raises(tir.ScheduleError) as excinfo:
        sch.get_block("wrong_name")
    (msg,) = excinfo.value.args
    assert "(not rendered)" in msg
def test_reduction_rfactor_loop_multiple_children():
    """rfactor on the first loop of block "C" in `matmul_loop_multiple_children` must be rejected."""
    sched = tir.Schedule(matmul_loop_multiple_children, debug_mask="all")
    block_c = sched.get_block("C")
    loop_k, _, _ = sched.get_loops(block_c)
    with pytest.raises(tvm.tir.ScheduleError):
        sched.rfactor(loop_k, 0)
def test_bind_cross_thread_reduction():
    """Binding the second loop of "B" in `rowsum` to threadIdx.x yields `rowsum_cross_thread_reduction`."""
    sch = tir.Schedule(rowsum, debug_mask="all")
    block_b = sch.get_block("B")
    _, loop_k = sch.get_loops(block_b)
    sch.bind(loop_k, "threadIdx.x")
    result = sch.mod["main"]
    tvm.ir.assert_structural_equal(result, rowsum_cross_thread_reduction)
    # The recorded trace must replay to the same schedule.
    verify_trace_roundtrip(sch, mod=rowsum)
def test_tir_schedule_error_detail():
    """With error_render_level="detail", the ScheduleError names the missing block."""
    schedule = tir.Schedule(matmul, debug_mask="all", error_render_level="detail")
    with pytest.raises(tir.ScheduleError) as raised:
        schedule.get_block("wrong_name")
    [message] = raised.value.args
    assert "Cannot find a block with the name: wrong_name" in message
def test_vectorize_predicate():
    """Vectorizing the first loop of "B" in `element_wise_split_predicate` gives the expected IR."""
    sch = tir.Schedule(element_wise_split_predicate, debug_mask="all")
    block_b = sch.get_block("B")
    outer, _, _ = sch.get_loops(block_b)
    sch.vectorize(outer)
    tvm.ir.assert_structural_equal(
        sch.mod["main"], element_wise_split_predicate_vectorized
    )
    verify_trace_roundtrip(sch, mod=element_wise_split_predicate)
def test_unroll():
    """Unrolling the first loop of "B" in `rowsum` yields `rowsum_unrolled`."""
    sch = tir.Schedule(rowsum, debug_mask="all")
    block_b = sch.get_block("B")
    first_loop, _ = sch.get_loops(block_b)
    sch.unroll(first_loop)
    tvm.ir.assert_structural_equal(sch.mod["main"], rowsum_unrolled)
    verify_trace_roundtrip(sch, mod=rowsum)
def test_compute_inline_with_opaque_access():
    """Test that opaque reads/writes are not rewritten after an irrelevant compute inline."""
    schedule = tir.Schedule(access_opaque_ptr_then_elemwise, debug_mask="all")
    block_bb = schedule.get_block("BB")
    schedule.compute_inline(block_bb)
    tvm.ir.assert_structural_equal(
        access_opaque_ptr_then_elemwise_inline, schedule.mod["main"]
    )
def _create_schedule(mod: IRModule, sch_fn: Callable[[Schedule], None]) -> Schedule:
    """Build a schedule over `mod` with all debug checks enabled, then apply `sch_fn`.

    Returns the same schedule object after `sch_fn` has been called on it.
    """
    schedule = tir.Schedule(mod=mod, debug_mask="all")
    sch_fn(schedule)
    return schedule
def test_buffer_matched():
    """compute_inline of block "B" in `buffer_matched` must raise ScheduleError."""
    schedule = tir.Schedule(buffer_matched, debug_mask="all")
    producer = schedule.get_block("B")
    with pytest.raises(tvm.tir.ScheduleError):
        schedule.compute_inline(producer)
def test_compute_inline_multi_loads():
    """Inlining block "B" of `elementwise_multi_loads` yields `elementwise_multi_loads_inlined`."""
    schedule = tir.Schedule(elementwise_multi_loads, debug_mask="all")
    producer = schedule.get_block("B")
    schedule.compute_inline(producer)
    tvm.ir.assert_structural_equal(
        elementwise_multi_loads_inlined, schedule.mod["main"]
    )
    verify_trace_roundtrip(sch=schedule, mod=elementwise_multi_loads)
def test_opaque_access_store():
    """compute_inline of block "B" in `opaque_access_store` must raise ScheduleError."""
    schedule = tir.Schedule(opaque_access_store, debug_mask="all")
    producer = schedule.get_block("B")
    with pytest.raises(tvm.tir.ScheduleError):
        schedule.compute_inline(producer)
def test_reverse_compute_fail_multi_reverse_loads():
    """Reverse-inlining block "C" of `elementwise_multi_loads` must raise ScheduleError."""
    schedule = tir.Schedule(elementwise_multi_loads, debug_mask="all")
    consumer = schedule.get_block("C")
    with pytest.raises(tvm.tir.ScheduleError):
        schedule.reverse_compute_inline(consumer)
def test_reverse_compute_inline_fail_multi_reader():
    """Reverse-inlining block "C" of `fail_multi_reader_writer` must raise ScheduleError."""
    schedule = tir.Schedule(fail_multi_reader_writer, debug_mask="all")
    consumer = schedule.get_block("C")
    with pytest.raises(tvm.tir.ScheduleError):
        schedule.reverse_compute_inline(consumer)
def test_parallel_not_compact_data_flow():
    """Parallelizing the first loop of "B" in `rowsum_not_compact_data_flow` must be rejected."""
    sch = tir.Schedule(rowsum_not_compact_data_flow, debug_mask="all")
    block_b = sch.get_block("B")
    first_loop, _ = sch.get_loops(block_b)
    with pytest.raises(tvm.tir.ScheduleError):
        sch.parallel(first_loop)
def test_block_inside_init():
    """Binding the only loop of "outer" in `block_inside_init` to threadIdx.x gives the expected IR."""
    sch = tir.Schedule(block_inside_init, debug_mask="all")
    outer_block = sch.get_block("outer")
    [only_loop] = sch.get_loops(outer_block)
    sch.bind(only_loop, "threadIdx.x")
    tvm.ir.assert_structural_equal(sch.mod["main"], thread_bound_block_inside_init)
    verify_trace_roundtrip(sch, mod=block_inside_init)
def test_vectorize():
    """Vectorizing the third loop of "C" in `element_wise_compute_at_split` gives the expected IR."""
    sch = tir.Schedule(element_wise_compute_at_split, debug_mask="all")
    block_c = sch.get_block("C")
    _, _, j1_inner = sch.get_loops(block_c)
    sch.vectorize(j1_inner)
    tvm.ir.assert_structural_equal(
        sch.mod["main"], element_wise_compute_at_split_vectorized
    )
    verify_trace_roundtrip(sch, mod=element_wise_compute_at_split)
def test_vectorize_after_decompose():
    """Vectorizing the last loop of "C" in `decomposed_gemm` yields `decomposed_gemm_after_vectorize`."""
    sch = tir.Schedule(decomposed_gemm, debug_mask="all")
    block_c = sch.get_block("C")
    loops = sch.get_loops(block_c)
    last_loop = loops[-1]
    sch.vectorize(last_loop)
    tvm.ir.assert_structural_equal(sch.mod["main"], decomposed_gemm_after_vectorize)
    verify_trace_roundtrip(sch, mod=decomposed_gemm)
def test_vectorize_opaque_block():
    """Vectorizing the only loop of block "opaque" must raise ScheduleError."""
    sch = tir.Schedule(opaque_block, debug_mask="all")
    target = sch.get_block("opaque")
    [only_loop] = sch.get_loops(target)
    with pytest.raises(tvm.tir.ScheduleError):
        sch.vectorize(only_loop)
def test_storage_align_invalid_buffer_index():
    """storage_align with buffer index 2 on block "B" of `element_wise` must raise ScheduleError."""
    sch = tir.Schedule(element_wise, debug_mask="all")
    block_b = sch.get_block("B")
    with pytest.raises(tir.ScheduleError):
        sch.storage_align(block_b, 2, axis=0, factor=128, offset=127)
def test_bind1():
    """Binding the first loop of "B" in `element_wise` to threadIdx.x yields `element_wise_i_bound`."""
    sch = tir.Schedule(element_wise, debug_mask="all")
    block_b = sch.get_block("B")
    loop_i, _ = sch.get_loops(block_b)
    sch.bind(loop_i, "threadIdx.x")
    tvm.ir.assert_structural_equal(sch.mod["main"], element_wise_i_bound)
    verify_trace_roundtrip(sch, mod=element_wise)
def test_storage_align_invalid_annotation():
    """storage_align on block "B" of `element_wise_invalid_annotation` must raise ScheduleError."""
    sch = tir.Schedule(element_wise_invalid_annotation, debug_mask="all")
    block_b = sch.get_block("B")
    with pytest.raises(tir.ScheduleError):
        sch.storage_align(block_b, 0, axis=2, factor=128, offset=127)
def test_bind_not_cross_thread_reduction():
    """Binding the second loop of "B" in `rowsum` to blockIdx.x must raise ScheduleError."""
    sch = tir.Schedule(rowsum, debug_mask="all")
    block_b = sch.get_block("B")
    _, loop_k = sch.get_loops(block_b)
    with pytest.raises(tvm.tir.ScheduleError):
        sch.bind(loop_k, "blockIdx.x")
def test_parallel():
    """Parallelizing the first loop of "B" in `element_wise` yields `element_wise_parallelized`."""
    sch = tir.Schedule(element_wise, debug_mask="all")
    block_b = sch.get_block("B")
    loop_i, _ = sch.get_loops(block_b)
    sch.parallel(loop_i)
    tvm.ir.assert_structural_equal(sch.mod["main"], element_wise_parallelized)
    verify_trace_roundtrip(sch, mod=element_wise)
def test_tir_schedule_error_fast():
    """With error_render_level="fast", ScheduleError carries only the short message.

    Looks up a block name that does not exist in `matmul` and checks for the
    generic "specified name" wording rather than the detailed rendering.
    """
    # Use the `debug_mask` keyword like the rest of this file; the stale
    # `debug_mode=True` spelling is not accepted by the constructor.
    sch = tir.Schedule(matmul, debug_mask="all", error_render_level="fast")
    with pytest.raises(tir.ScheduleError) as excinfo:
        sch.get_block("wrong_name")
    (msg,) = excinfo.value.args
    assert "Cannot find a block with the specified name" in msg
def test_parallel_reduction_block_iter():
    """Parallelizing the third loop of "C" in `matmul` must raise ScheduleError."""
    sch = tir.Schedule(matmul, debug_mask="all")
    block_c = sch.get_block("C")
    _, _, loop_k = sch.get_loops(block_c)
    with pytest.raises(tvm.tir.ScheduleError):
        sch.parallel(loop_k)
def test_postproc_verify_gpu_3():
    """The first postproc of the GPU context must report failure on `Conv2dCuda3`."""
    context = _create_context(Conv2dCuda3, target=_target())
    schedule = tir.Schedule(Conv2dCuda3, debug_mask="all")
    verifier = context.postprocs[0]
    assert not verifier.apply(schedule)
def test_parallel_not_quasi_affine():
    """Parallelizing the first loop of "B" in `rowsum_not_quasi_affine` must be rejected."""
    sch = tir.Schedule(rowsum_not_quasi_affine, debug_mask="all")
    block_b = sch.get_block("B")
    first_loop, _ = sch.get_loops(block_b)
    with pytest.raises(tvm.tir.ScheduleError):
        sch.parallel(first_loop)
def test_reduction_rfactor_not_stage_pipeline():
    """rfactor on the third loop of "C" in `matmul_not_stage_pipeline` must be rejected."""
    sched = tir.Schedule(matmul_not_stage_pipeline, debug_mask="all")
    block_c = sched.get_block("C")
    _, _, loop_k = sched.get_loops(block_c)
    with pytest.raises(tvm.tir.ScheduleError):
        sched.rfactor(loop_k, 0)
def test_reverse_compute_inline_fail_as_dce():
    """Reverse-inlining block "B" of `elementwise_standalone` must raise ScheduleError."""
    schedule = tir.Schedule(elementwise_standalone, debug_mask="all")
    target_block = schedule.get_block("B")
    with pytest.raises(tvm.tir.ScheduleError):
        schedule.reverse_compute_inline(target_block)