def test_ast_build_expr():
    """Check the AST expression built for the affine expression ``n + 1``.

    An ``ast_build`` is created from the domain of the piecewise affine
    expression, and the expression produced by ``expr_from`` must be
    exactly of type ``isl.ast_expr_op_add`` with two arguments.
    """
    # NOTE(review): a second function with this same name follows in the
    # visible source; at module level the later one would replace this one.
    affine = isl.pw_aff("[n] -> { [n + 1] }")
    builder = isl.ast_build.from_context(affine.domain())
    add_expr = builder.expr_from(affine)

    # Exact type comparison (not isinstance): the concrete class is what
    # is being checked here.
    assert type(add_expr) == isl.ast_expr_op_add
    assert add_expr.n_arg() == 2
def test_ast_build_expr():
    """Verify that ``n + 1`` is turned into a two-argument addition node.

    The build context is the domain of the piecewise affine expression
    itself; the result of ``expr_from`` is required to have the exact
    type ``isl.ast_expr_op_add`` and to report two arguments.
    """
    # NOTE(review): a function with this same name is also defined just
    # above in the visible source; this later definition is the one that
    # stays bound if both live at module level -- confirm the duplicate
    # is intended.
    source_expr = isl.pw_aff("[n] -> { [n + 1] }")
    context = source_expr.domain()
    ast_node = isl.ast_build.from_context(context).expr_from(source_expr)

    assert type(ast_node) == isl.ast_expr_op_add
    assert ast_node.n_arg() == 2
def test_return_string():
    """Check the C strings produced for AST expressions.

    Using a build whose context is the unconstrained parametric set
    ``[n] -> { : }``, the piecewise affine expression ``n`` must print
    as ``"n"`` and the set ``n >= 0`` must print as ``"n >= 0"``.
    """
    # NOTE(review): another function with this same name follows in the
    # visible source; at module level the later one would replace this one.
    builder = isl.ast_build.from_context(isl.set("[n] -> { : }"))
    affine = isl.pw_aff("[n] -> { [n] }")
    nonnegative = isl.set("[n] -> { : n >= 0 }")

    # Each operand is converted and compared in turn, in the same order
    # as the expressions were constructed.
    for operand, expected_string in ((affine, "n"), (nonnegative, "n >= 0")):
        assert expected_string == builder.expr_from(operand).to_C_str()
def test_return_string():
    """Verify ``to_C_str`` on expressions built from a pw_aff and a set.

    The build context is ``[n] -> { : }``. The expression built from
    ``[n] -> { [n] }`` must render as ``"n"``; the expression built from
    ``[n] -> { : n >= 0 }`` must render as ``"n >= 0"``.
    """
    # NOTE(review): a function with this same name is also defined just
    # above in the visible source; this later definition is the one that
    # stays bound if both live at module level -- confirm the duplicate
    # is intended.
    context = isl.set("[n] -> { : }")
    builder = isl.ast_build.from_context(context)
    affine = isl.pw_aff("[n] -> { [n] }")
    condition = isl.set("[n] -> { : n >= 0 }")

    affine_expr = builder.expr_from(affine)
    assert "n" == affine_expr.to_C_str()

    condition_expr = builder.expr_from(condition)
    assert "n >= 0" == condition_expr.to_C_str()
def cuda_tile(tree, tile_size, permutation=None):
    """Tile the outermost band of ``tree`` and mark it for CUDA binding.

    Two synthetic statements, ``_thread`` and ``_block``, are added to the
    domain of ``tree``. Each spans the box of the outermost band with every
    extent rounded up to a whole number of tiles; ``_block`` is further
    restricted to points aligned to the tile size. Both are scheduled by
    the identity on each band dimension and placed, together with the
    original domain, under a new sequence node below the band. The band is
    then optionally permuted, tiled, and surrounded by ``bind=blockIdx``,
    ``bind=threadIdx`` and ``clear(bind)`` mark nodes.

    NOTE(review): the synthetic statements presumably exist so that the
    block and thread loops get full, tile-aligned extents -- confirm.

    Args:
        tree: Schedule tree to modify in place. ``tree.parallel_tilable()``
            must be true.
        tile_size: Per-dimension tile sizes, counted in stride steps; only
            the first ``n`` entries are used, where ``n`` is the number of
            dimensions of the outermost band box.
        permutation: Optional arguments forwarded to ``band.permute``
            before tiling.

    Raises:
        AssertionError: If ``tree.parallel_tilable()`` is false.
    """
    assert tree.parallel_tilable()

    box_size, lowers, strides = tree.outermost_band_box()
    n = len(box_size)
    dims = range(n)
    tile_size = tile_size[:n]

    # Tile extents in iteration-space units (tile size scaled by stride).
    real_tile_size = [tile_size[d] * strides[d] for d in dims]

    # Box extents rounded up to a multiple of the tile extent; the
    # negate / floor-divide / negate sequence is a ceiling division.
    filled_box_size = []
    for d in dims:
        tile_count = -(-box_size[d] // (real_tile_size[d]))
        filled_box_size.append(tile_count * real_tile_size[d])

    fake_args = ['i%d' % d for d in dims]

    # Per-dimension constraints. The block statement shares the thread
    # statement's constraints and adds alignment to the tile extent.
    thread_fake_constraints = []
    block_fake_constraints = []
    for var, lower, stride, size, rt_size in zip(
            fake_args, lowers, strides, filled_box_size, real_tile_size):
        inside_box = (
            f'({var} mod {stride}) = (({lower}) mod {stride})'
            f' and 0 <= {var} - ({lower}) < {size}'
        )
        thread_fake_constraints.append(inside_box)
        block_fake_constraints.append(
            inside_box
            + f' and ({var} mod {rt_size}) = (({lower}) mod {rt_size})'
        )

    joined_args = ", ".join(fake_args)
    thread_fake_named_tuple = f'_thread[{joined_args}]'
    block_fake_named_tuple = f'_block[{joined_args}]'

    thread_condition = " and ".join(thread_fake_constraints)
    thread_fake_statement = isl.union_set(
        f'{{ {thread_fake_named_tuple} : {thread_condition} }}'
    ).coalesce()

    block_condition = " and ".join(block_fake_constraints)
    block_fake_statement = isl.union_set(
        f'{{ {block_fake_named_tuple} : {block_condition} }}'
    ).coalesce()

    # Remember the original domain before it is extended; it becomes the
    # filter of the real kernel branch below.
    old_domain = tree.domain()
    tree.add_to_domain(thread_fake_statement)
    tree.add_to_domain(block_fake_statement)

    # Schedule each synthetic statement by its own coordinate on every
    # band dimension (thread first, then block).
    band = tree.outermost_band()
    for dim, var in enumerate(fake_args):
        for named_tuple in (thread_fake_named_tuple, block_fake_named_tuple):
            current = band.schedule.at(dim)
            extension = isl.pw_aff(f'{{ {named_tuple} -> [({var})] }}')
            extended = current.union_add(extension)
            band.schedule = band.schedule.set_at(dim, extended.coalesce())

    # New subtree under the band: thread filter, block filter, then the
    # original kernel (which keeps the band's previous child, if any).
    kernel_branch = FilterNode(filter=old_domain)
    if band.child:
        kernel_branch.child = band.child

    fake_branch = SequenceNode()
    for branch in (
            FilterNode(filter='{%s}' % thread_fake_named_tuple),
            FilterNode(filter='{%s}' % block_fake_named_tuple),
            kernel_branch):
        fake_branch.add_child(branch)
    band.child = fake_branch

    if permutation is not None:
        band.permute(*permutation)
    band.tile(*real_tile_size)

    # Marks are inserted from the outside in: above the band, above the
    # band's child, and above that child's child.
    band.insert_before(MarkNode('bind=blockIdx'))
    inner = band.child
    inner.insert_before(MarkNode('bind=threadIdx'))
    inner.child.insert_before(MarkNode('clear(bind)'))