Esempio n. 1
0
def test_ast_build_expr():
    """Check the AST expression generated for the affine expression ``n + 1``.

    An ``isl.ast_build`` is created from the domain of the piecewise affine
    expression ``[n] -> { [n + 1] }``; the expression is then converted with
    ``expr_from`` and the result is checked to be an addition node with two
    arguments.
    """
    pa = isl.pw_aff("[n] -> { [n + 1] }")
    build = isl.ast_build.from_context(pa.domain())

    op = build.expr_from(pa)
    # The check is deliberately on the exact class; ``is`` is the idiomatic
    # way to compare type objects (pycodestyle E721) rather than ``==``.
    assert type(op) is isl.ast_expr_op_add
    assert op.n_arg() == 2
Esempio n. 2
0
def test_ast_build_expr():
    """Verify that ``expr_from`` turns ``n + 1`` into a two-argument addition.

    The build context is the domain of ``[n] -> { [n + 1] }``. The generated
    expression must be exactly an ``isl.ast_expr_op_add`` and report two
    arguments through ``n_arg()``.
    """
    pa = isl.pw_aff("[n] -> { [n + 1] }")
    build = isl.ast_build.from_context(pa.domain())

    op = build.expr_from(pa)
    # Exact-class comparison: type objects are compared by identity, so
    # ``is`` is used instead of ``==`` (pycodestyle E721).
    assert type(op) is isl.ast_expr_op_add
    assert op.n_arg() == 2
Esempio n. 3
0
def test_return_string():
    """Check that ``to_C_str()`` returns the expected strings.

    Two expressions are generated with the same ``isl.ast_build`` (built from
    the unconstrained parametric context ``[n] -> { : }``): one from the
    piecewise affine expression ``n`` and one from the parameter set
    ``n >= 0``. Each is compared against its expected C representation.
    """
    context = isl.set("[n] -> { : }")
    build = isl.ast_build.from_context(context)
    pw_aff = isl.pw_aff("[n] -> { [n] }")
    # Named ``nonneg_set`` rather than ``set`` so the builtin ``set`` is not
    # shadowed inside this function.
    nonneg_set = isl.set("[n] -> { : n >= 0 }")

    expr = build.expr_from(pw_aff)
    expected_string = "n"
    assert expected_string == expr.to_C_str()

    expr = build.expr_from(nonneg_set)
    expected_string = "n >= 0"
    assert expected_string == expr.to_C_str()
Esempio n. 4
0
def test_return_string():
    """Verify the C strings produced for an affine expression and a set.

    A single ``isl.ast_build`` is created from the context ``[n] -> { : }``.
    The expression generated from ``[n] -> { [n] }`` must print as ``n`` and
    the one generated from ``[n] -> { : n >= 0 }`` must print as ``n >= 0``.
    """
    context = isl.set("[n] -> { : }")
    build = isl.ast_build.from_context(context)
    pw_aff = isl.pw_aff("[n] -> { [n] }")
    # Avoid the name ``set`` here: it would hide the builtin of the same name
    # for the rest of the function.
    param_set = isl.set("[n] -> { : n >= 0 }")

    expr = build.expr_from(pw_aff)
    expected_string = "n"
    assert expected_string == expr.to_C_str()

    expr = build.expr_from(param_set)
    expected_string = "n >= 0"
    assert expected_string == expr.to_C_str()
Esempio n. 5
0
def cuda_tile(tree, tile_size, permutation=None):
    """Tile the outermost band of ``tree`` and insert CUDA binding marks.

    Steps, in order:

    1. Read the outermost band's box (sizes, lower bounds, strides) and round
       each size up to a whole number of effective tiles, where the effective
       tile width is the requested tile size times the stride.
    2. Build two synthetic statements, ``_thread`` and ``_block``, add them
       to the tree's domain and to every member of the band's schedule.
    3. Replace the band's child with a sequence holding one filter per
       synthetic statement followed by a filter on the original domain; the
       band's previous child (if any) is moved under that last filter.
    4. Optionally permute the band, call ``band.tile`` with the effective
       tile widths, and insert the marks ``bind=blockIdx`` (before the band),
       ``bind=threadIdx`` (before the band's child after tiling) and
       ``clear(bind)`` (before that child's child).

    Args:
        tree: Schedule tree object; ``tree.parallel_tilable()`` must hold.
            The tree is modified in place.
        tile_size: Indexable sequence of per-dimension tile sizes. Only the
            first ``n`` entries are used, ``n`` being the number of box
            dimensions.
        permutation: Optional iterable unpacked into ``band.permute``;
            no permutation is applied when it is ``None``.

    Returns:
        None.
    """
    assert tree.parallel_tilable()
    extents, origins, steps = tree.outermost_band_box()
    depth = len(extents)
    dims = range(depth)
    tiles = tile_size[:depth]

    # Effective tile widths: requested size times the stride.
    scaled_tiles = [tiles[d] * steps[d] for d in dims]

    # Round every extent up to a multiple of its effective tile width; the
    # double negation turns floor division into ceiling division.
    padded_extents = []
    for d in dims:
        tile_count = -(-extents[d] // scaled_tiles[d])
        padded_extents.append(tile_count * scaled_tiles[d])

    iter_names = [f'i{d}' for d in dims]
    iter_list = ", ".join(iter_names)

    def on_grid(var, origin, step, extent):
        # Points congruent to the lower bound modulo the stride that lie
        # inside the padded extent.
        return (f'({var} mod {step}) = (({origin}) mod {step})'
                f' and 0 <= {var} - ({origin}) < {extent}')

    def as_statement(tuple_text, constraints):
        # Conjunction of all per-dimension constraints as a coalesced set.
        body = " and ".join(constraints)
        return isl.union_set(f'{{ {tuple_text} : {body} }}').coalesce()

    thread_tuple = f'_thread[{iter_list}]'
    thread_points = as_statement(thread_tuple, [
        on_grid(var, origin, step, extent)
        for var, origin, step, extent in zip(
            iter_names, origins, steps, padded_extents)
    ])

    # The block statement additionally requires alignment to the effective
    # tile width in every dimension.
    block_tuple = f'_block[{iter_list}]'
    block_points = as_statement(block_tuple, [
        on_grid(var, origin, step, extent)
        + f' and ({var} mod {width}) = (({origin}) mod {width})'
        for var, origin, step, extent, width in zip(
            iter_names, origins, steps, padded_extents, scaled_tiles)
    ])

    # Capture the domain before the synthetic statements are added so the
    # kernel filter below covers only the original statements.
    original_domain = tree.domain()
    tree.add_to_domain(thread_points)
    tree.add_to_domain(block_points)

    band = tree.outermost_band()

    # Extend each schedule member so it also maps the synthetic statements
    # (thread first, then block) to their own iterator.
    for pos, var in enumerate(iter_names):
        for tuple_text in (thread_tuple, block_tuple):
            current = band.schedule.at(pos)
            merged = current.union_add(
                isl.pw_aff(f'{{ {tuple_text} -> [({var})] }}'))
            band.schedule = band.schedule.set_at(pos, merged.coalesce())

    branches = SequenceNode()
    for tuple_text in (thread_tuple, block_tuple):
        branches.add_child(FilterNode(filter=f'{{{tuple_text}}}'))

    kernel_filter = FilterNode(filter=original_domain)
    if band.child:
        kernel_filter.child = band.child
    branches.add_child(kernel_filter)

    band.child = branches

    if permutation is not None:
        band.permute(*permutation)

    band.tile(*scaled_tiles)
    band.insert_before(MarkNode('bind=blockIdx'))
    inner = band.child
    inner.insert_before(MarkNode('bind=threadIdx'))
    # Read ``inner.child`` only after the mark above has been inserted, as
    # the original statement order does.
    inner.child.insert_before(MarkNode('clear(bind)'))