コード例 #1
0
ファイル: macros.py プロジェクト: ucb-sejits/ctree
def clSetKernelArg(kernel, arg_index, arg_size, arg_value):
    """Build a ctree call node for ``clSetKernelArg``.

    Plain Python values are converted to ctree nodes:
    ``kernel`` as ``str`` -> ``SymbolRef(kernel)``;
    ``arg_index`` / ``arg_size`` as ``int`` -> ``Constant(...)``;
    ``arg_value`` as ``str`` -> ``Ref(SymbolRef(arg_value))``.
    Any other value is passed through unchanged.

    Returns:
        ``FunctionCall`` of ``clSetKernelArg`` with the four arguments in order.
    """
    kernel_node = SymbolRef(kernel) if isinstance(kernel, str) else kernel
    index_node = Constant(arg_index) if isinstance(arg_index, int) else arg_index
    size_node = Constant(arg_size) if isinstance(arg_size, int) else arg_size
    if isinstance(arg_value, str):
        value_node = Ref(SymbolRef(arg_value))
    else:
        value_node = arg_value
    callee = SymbolRef("clSetKernelArg")
    return FunctionCall(callee, [kernel_node, index_node, size_node, value_node])
コード例 #2
0
ファイル: test_ArrayDefs.py プロジェクト: ucb-sejits/ctree
 def test_simple_array_def(self):
     """An int ArrayDef with two constant elements renders as a C initialiser."""
     name = SymbolRef('hi', ct.c_int())
     size = Constant(2)
     values = Array(body=[Constant(0), Constant(1)])
     expected = "int hi[2] = {0, 1}"
     self._check_code(ArrayDef(name, size, values), expected)
コード例 #3
0
ファイル: test_templates.py プロジェクト: ucb-sejits/ctree
 def test_child_nested_list(self):
     """A list nested inside the bound list is expected to render one statement per element."""
     bindings = {'stmts': [[Constant(1)], Constant(2)]}
     rendered = StringTemplate("""$stmts""", bindings)
     expected = """\
     1;
     2;"""
     self._check(rendered, expected)
コード例 #4
0
ファイル: test_ArrayDefs.py プロジェクト: ucb-sejits/ctree
 def test_complex(self):
     """Array elements may be compound expressions; ``(99 - d) * 200`` keeps its parentheses."""
     name = SymbolRef('myArray', ct.c_int())
     size = Constant(2)
     elements = [
         Add(SymbolRef('b'), SymbolRef('c')),
         Mul(Sub(Constant(99), SymbolRef('d')), Constant(200)),
     ]
     node = ArrayDef(name, size, Array(body=elements))
     expected = "int myArray[2] = {b + c, (99 - d) * 200}"
     self._check_code(node, expected)
コード例 #5
0
ファイル: macros.py プロジェクト: ucb-sejits/ctree
def clEnqueueReadBuffer(queue, buf, blocking, offset, cb, ptr, num_events=0, evt_list_ptr=None, evt=None):
    """Build a ctree call node for ``clEnqueueReadBuffer``.

    Plain Python values are converted to ctree nodes:
    ``buf`` / ``ptr`` as ``str`` -> ``SymbolRef``;
    ``blocking`` as ``bool`` -> ``Constant(0 or 1)``;
    ``offset`` / ``cb`` / ``num_events`` that are not AST nodes -> ``Constant``;
    ``evt_list_ptr`` / ``evt`` that are not AST nodes -> ``NULL()``.
    ``queue`` is passed through unchanged.

    Returns:
        ``FunctionCall`` of ``clEnqueueReadBuffer`` with the nine arguments
        ``queue, buf, blocking, offset, cb, ptr, num_events, evt_list_ptr, evt``.
    """
    if     isinstance(buf, str):              buf = SymbolRef(buf)
    if     isinstance(blocking, bool):        blocking = Constant(int(blocking))
    if     isinstance(ptr, str):              ptr = SymbolRef(ptr)
    if not isinstance(offset, ast.AST):       offset = Constant(offset)
    if not isinstance(cb, ast.AST):           cb = Constant(cb)
    if not isinstance(num_events, ast.AST):   num_events = Constant(num_events)
    # Rebind the parameter itself so that a caller-supplied AST event list is
    # forwarded to the call instead of being lost (or left unbound).
    if not isinstance(evt_list_ptr, ast.AST): evt_list_ptr = NULL()
    if not isinstance(evt, ast.AST):          evt = NULL()
    return FunctionCall(SymbolRef('clEnqueueReadBuffer'), [
        queue, buf, blocking, offset, cb, ptr, num_events, evt_list_ptr, evt])
コード例 #6
0
 def test_indent_0(self):
     """A multi-line template placed in a While body is expected one level deeper inside its braces."""
     template = """\
     while($cond)
         printf("hello");
     """
     inner = StringTemplate(template, {'cond': Constant(1)})
     outer = While(Constant(0), [inner])
     expected = """\
     while (0) {
         while(1)
             printf("hello");
     }"""
     self._check(outer, expected)
コード例 #7
0
ファイル: macros.py プロジェクト: ucb-sejits/ctree
def clEnqueueCopyBuffer(queue, src_buf, dst_buf, src_offset=0, dst_offset=0, cb=0):
    """Build a ctree call node for ``clEnqueueCopyBuffer``.

    ``src_buf`` / ``dst_buf`` given as ``str`` become ``SymbolRef``s.
    ``src_offset`` / ``dst_offset`` / ``cb`` given as ``int`` become
    ``Constant``s. ``queue`` is passed through unchanged. The final three
    arguments are always ``Constant(0), NULL(), NULL()``.

    Returns:
        ``FunctionCall`` of ``clEnqueueCopyBuffer`` with nine arguments.
    """
    def to_symbol(value):
        # Buffer names given as plain strings become symbol references.
        return SymbolRef(value) if isinstance(value, str) else value

    def to_constant(value):
        # Integer offsets and byte counts become literal constants.
        return Constant(value) if isinstance(value, int) else value

    src_buf, dst_buf = to_symbol(src_buf), to_symbol(dst_buf)
    src_offset, dst_offset, cb = (
        to_constant(src_offset), to_constant(dst_offset), to_constant(cb))

    trailing = [Constant(0), NULL(), NULL()]
    leading = [queue, src_buf, dst_buf, src_offset, dst_offset, cb]
    return FunctionCall(SymbolRef('clEnqueueCopyBuffer'), leading + trailing)
コード例 #8
0
 def test_child_single_list(self):
     """A one-element list bound to ``$stmts`` is expected to render as just that element."""
     tree = StringTemplate("""$stmts""", {'stmts': [Constant(1)]})
     expected = "1"
     self._check(tree, expected)
コード例 #9
0
 def test_array_ref(self):
     """DeclarationFiller is expected to declare the untyped temp as ``double`` from ``foo[0]``."""
     foo_decl = SymbolRef("foo", ctypes.POINTER(ctypes.c_double)())
     temp = SymbolRef("____temp__x")
     first_elem = ArrayRef(SymbolRef("foo"), Constant(0))
     tree = MultiNode([foo_decl, Assign(temp, first_elem)])
     DeclarationFiller().visit(tree)
     expected = "\ndouble* foo;\ndouble ____temp__x = foo[0];\n"
     self._check_code(tree, expected)
コード例 #10
0
ファイル: macros.py プロジェクト: ucb-sejits/ctree
def clEnqueueNDRangeKernel(queue, kernel, work_dim=1, work_offset=0, global_size=0, local_size=0):
    """Build a ``Block`` that sets up work sizes and calls ``clEnqueueNDRangeKernel``.

    The block contains, in order:
    ``global_size`` assigned to a new ``c_size_t`` symbol named ``global_size``;
    ``local_size`` assigned to a new ``c_size_t`` symbol named ``local_size``;
    the ``clEnqueueNDRangeKernel`` call, passed ``Ref``s to those two symbols.
    The call's last three arguments are always ``Constant(0), NULL(), NULL()``.

    Args:
        queue, kernel: must already be ``SymbolRef`` nodes (asserted).
        work_dim, work_offset: an ``int`` is wrapped in ``Constant``, so the
            argument list holds only AST nodes; other values pass through.
        global_size, local_size: wrapped in ``Constant`` unconditionally.

    NOTE(review): only one scalar is declared for each size, so a
    ``work_dim`` above 1 presumably makes the callee read past it. Confirm
    before generating multi-dimensional launches.
    """
    assert isinstance(queue, SymbolRef)
    assert isinstance(kernel, SymbolRef)
    if isinstance(work_dim, int):
        work_dim = Constant(work_dim)
    if isinstance(work_offset, int):
        work_offset = Constant(work_offset)
    global_size_sym = SymbolRef('global_size', c_size_t())
    local_size_sym = SymbolRef('local_size', c_size_t())
    call = FunctionCall(SymbolRef("clEnqueueNDRangeKernel"), [
        queue, kernel,
        work_dim, work_offset,
        Ref(global_size_sym.copy()), Ref(local_size_sym.copy()),
        Constant(0), NULL(), NULL()
    ])

    return Block([
        Assign(global_size_sym, Constant(global_size)),
        Assign(local_size_sym, Constant(local_size)),
        call
    ])
コード例 #11
0
    def visit_For(self, node):
        """Convert ``for x in range(...)`` into a C ``For`` node.

        Restricted, for now, to ``range`` as the iterator with long-typed
        args. Start, stop and step must each report a ``c_long`` from
        ``get_type()``, and the loop variable is declared ``c_long``. Any
        other ``for`` loop is returned with only its body visited.

        Raises:
            Exception: if ``range`` is called with other than 1-3 arguments.
            AssertionError: if start/stop/step is not typed ``c_long``.
            ValueError: if the step is the constant 0, as Python's
                ``range()`` would raise.
        """
        if isinstance(node, ast.For) and \
           isinstance(node.iter, ast.Call) and \
           isinstance(node.iter.func, ast.Name) and \
           node.iter.func.id == 'range':
            Range = node.iter
            nArgs = len(Range.args)
            if nArgs == 1:
                stop = self.visit(Range.args[0])
                start, step = Constant(0), Constant(1)
            elif nArgs == 2:
                start, stop = map(self.visit, Range.args)
                step = Constant(1)
            elif nArgs == 3:
                start, stop, step = map(self.visit, Range.args)
            else:
                raise Exception("Cannot convert a for...range with %d args." % nArgs)

            # TODO allow any expressions castable to Long type
            assert isinstance(stop.get_type(), c_long), "Can only convert range's with stop values of Long type."
            assert isinstance(start.get_type(), c_long), "Can only convert range's with start values of Long type."
            assert isinstance(step.get_type(), c_long), "Can only convert range's with step values of Long type."

            # Python's range() rejects a zero step; emitted as C it would
            # never terminate. A known negative step means counting down.
            counts_down = False
            if isinstance(step, Constant):
                if step.value == 0:
                    raise ValueError("range() step argument must not be zero")
                counts_down = step.value < 0

            target = SymbolRef(node.target.id, c_long())
            init = Assign(target, start)
            # range() with a negative step keeps going while target > stop.
            # That is written as ``stop < target`` so that Lt suffices.
            if counts_down:
                condition = Lt(stop, target.copy())
            else:
                condition = Lt(target.copy(), stop)
            for_loop = For(
                init,
                condition,
                AddAssign(target.copy(), step),
                [self.visit(stmt) for stmt in node.body],
            )
            return for_loop
        node.body = list(map(self.visit, node.body))
        return node
コード例 #12
0
 def test_indent_1(self):
     """A standalone multi-line template is expected to keep its relative indentation."""
     template = """\
     while($cond)
         printf("hello");
     """
     rendered = StringTemplate(template, {'cond': Constant(1)})
     expected = """\
     while(1)
         printf("hello");"""
     self._check(rendered, expected)
コード例 #13
0
 def visit_Num(self, node):
     """Translate a Python numeric literal node into a ctree ``Constant``."""
     number = node.n
     return Constant(number)
コード例 #14
0
 def test_float_00(self):
     """Despite the name, checks that the int ``Constant(0)`` renders as ``0``."""
     rendered = str(Constant(0))
     assert rendered == "0"
コード例 #15
0
 def test_char_02(self):
     """A one-character punctuation string renders wrapped in single quotes."""
     expected = "'!'"
     assert str(Constant("!")) == expected
コード例 #16
0
 def test_char_01(self):
     """An upper-case one-character string renders wrapped in single quotes."""
     expected = "'A'"
     assert str(Constant("A")) == expected
コード例 #17
0
 def test_char_00(self):
     """A lower-case one-character string renders wrapped in single quotes."""
     expected = "'a'"
     assert str(Constant("a")) == expected
コード例 #18
0
 def test_simple_template_one(self):
     """``$one`` is expected to be replaced by the rendered ``Constant(1)``."""
     bindings = {'one': Constant(1)}
     self._check(StringTemplate("return $one", bindings), "return 1")
コード例 #19
0
 def test_float_02(self):
     """A float Constant renders with its decimal digits, e.g. ``1.2``."""
     rendered = str(Constant(1.2))
     assert rendered == "1.2"
コード例 #20
0
 def test_no_args(self):
     """A parameterless CppDefine renders empty parens and a parenthesised body."""
     macro = CppDefine("test_macro", [], Constant(39))
     expected = "#define test_macro() (39)"
     self._check(macro, expected)
コード例 #21
0
def get_local_size(id):
    """Return a ctree call node for ``get_local_size(<id>)``.

    ``id`` is always wrapped in ``Constant``, so it should be a plain Python
    value (presumably the dimension index; confirm at call sites).
    """
    callee = SymbolRef('get_local_size')
    return FunctionCall(callee, [Constant(id)])
コード例 #22
0
 def test_dotgen(self):
     """Smoke test: ``to_dot()`` on a StringTemplate runs and yields output."""
     tree = StringTemplate("return $one $two", {
         'one': Constant(1),
         'two': Constant(2),
     })
     dot = tree.to_dot()
     # Check the result itself, not merely that the call did not raise.
     assert dot, "to_dot() returned an empty result"
コード例 #23
0
 def test_simple_template_two(self):
     """Two placeholders in one template are each expected to be substituted."""
     bindings = {'one': Constant(1), 'two': Constant(2)}
     tree = StringTemplate("return $one $two", bindings)
     expected = "return 1 2"
     self._check(tree, expected)
コード例 #24
0
 def visit_Name(self, node):
     """Map a Python ``Name`` node to a ctree node.

     Lookup order:
     a name in ``self.constants_dict`` becomes ``Constant`` of its mapped value;
     a name in ``self.names_dict`` becomes ``SymbolRef`` of its mapped name;
     any other name becomes ``SymbolRef(node.id)``.
     """
     name = node.id
     if name in self.constants_dict:
         return Constant(self.constants_dict[name])
     renamed = self.names_dict[name] if name in self.names_dict else name
     return SymbolRef(renamed)
コード例 #25
0
 def test_int_01(self):
     """A single-digit int Constant renders as its decimal digits."""
     rendered = str(Constant(1))
     assert rendered == "1"
コード例 #26
0
 def test_int_02(self):
     """A multi-digit int Constant renders as its decimal digits."""
     rendered = str(Constant(12))
     assert rendered == "12"
コード例 #27
0
def get_global_id(id):
    """Return a ctree call node for ``get_global_id(<id>)``.

    ``id`` is always wrapped in ``Constant``, so it should be a plain Python
    value (presumably the dimension index; confirm at call sites).
    """
    callee = SymbolRef('get_global_id')
    return FunctionCall(callee, [Constant(id)])
コード例 #28
0
    def visit_For(self, node):
        """Convert ``for x in range(...)``/``xrange(...)`` into a C ``For``.

        Only loops whose iterator is a direct call to ``range`` or ``xrange``
        are converted. Every other ``for`` is returned with its body visited.

        Any start/stop/step that has ``get_type()`` must report one of
        ``c_byte``, ``c_int``, ``c_long`` or ``c_short``. The loop variable
        gets the type returned by ``get_common_ctype`` over those types,
        seeded with ``c_long``.

        Returns:
            A ctree ``For`` node; or ``None`` when constant bounds describe
            an empty range (the loop is dropped); or the original node, body
            visited, when the range arguments are not all ``CtreeNode``s.

        Raises:
            Exception: if ``range`` is called with other than 1-3 arguments.
            ValueError: if the step is the constant 0, as Python's
                ``range()`` would raise.
            AssertionError: if a typed bound is not an integer ctype.
        """
        if isinstance(node, ast.For) and \
           isinstance(node.iter, ast.Call) and \
           isinstance(node.iter.func, ast.Name) and \
           node.iter.func.id in ('range', 'xrange'):
            Range = node.iter
            nArgs = len(Range.args)
            if nArgs == 1:
                stop = self.visit(Range.args[0])
                start, step = Constant(0), Constant(1)
            elif nArgs == 2:
                start, stop = map(self.visit, Range.args)
                step = Constant(1)
            elif nArgs == 3:
                start, stop, step = map(self.visit, Range.args)
            else:
                raise Exception("Cannot convert a for...range with %d args." %
                                nArgs)

            # A constant zero step is rejected by Python's range(). Emitted as
            # C it would never terminate, so reject it even when the bounds
            # are not constants.
            if isinstance(step, Constant) and step.value == 0:
                raise ValueError("range() step argument must not be zero")

            #  check no-op conditions.
            if all(isinstance(item, Constant) for item in (start, stop, step)):
                if start.value == stop.value or \
                        (start.value < stop.value and step.value < 0) or \
                        (start.value > stop.value and step.value > 0):
                    return None

            if not all(
                    isinstance(item, CtreeNode)
                    for item in (start, stop, step)):
                node.body = list(map(self.visit, node.body))
                return node

            # TODO allow any expressions castable to Long type
            target_types = [c_long]
            for el in (stop, start, step):
                #  typed item to try and guess type off of. Imperfect right now.
                if hasattr(el, 'get_type'):
                    # TODO take the proper class instead of the last; if start,
                    # end are doubles, but step is long, target is double
                    t = el.get_type()
                    assert any(
                        isinstance(t, klass)
                        for klass in [c_byte, c_int, c_long, c_short]
                    ), ("Can only convert ranges with integer/long "
                        "start/stop/step values")

                    target_types.append(type(t))
            target_type = get_common_ctype(target_types)()

            target = SymbolRef(node.target.id, target_type)
            # The loop direction follows the sign of a constant step: a
            # negative step counts down, so the test must be ``>``. Only when
            # the step is not a constant, fall back to comparing constant
            # bounds.
            op = Lt
            if isinstance(step, Constant):
                if step.value < 0:
                    op = Gt
            elif hasattr(start, 'value') and hasattr(stop, 'value') and \
                    start.value > stop.value:
                op = Gt
            for_loop = For(
                Assign(target, start),
                op(target.copy(), stop),
                AddAssign(target.copy(), step),
                [self.visit(stmt) for stmt in node.body],
            )
            return for_loop
        node.body = list(map(self.visit, node.body))
        return node
コード例 #29
0
def get_num_groups(id):
    """Return a ctree call node for ``get_num_groups(<id>)``.

    ``id`` is always wrapped in ``Constant``, so it should be a plain Python
    value (presumably the dimension index; confirm at call sites).
    """
    callee = SymbolRef('get_num_groups')
    return FunctionCall(callee, [Constant(id)])
コード例 #30
0
 def transform(self, tree, program_cfg):
     """Instantiate ``conv_test.tmpl.c`` (next to this module) as a ``CFile``.

     ``program_cfg`` must unpack into ``(arg_cfg, tune_cfg)``; only
     ``arg_cfg`` is read. The ``_shape_`` of ``arg_cfg['out']``,
     ``arg_cfg['in_ptr']`` and ``arg_cfg['weights']`` must each unpack into
     exactly four values.

     Those dimensions are bound to the template placeholders as ``Constant``
     nodes. So are the ``kernel_size`` / ``pad`` / ``stride`` / ``group``
     fields of ``self.conv_param`` and ``bias_term`` (as 1 or 0).
     ``tree`` is unused.

     Returns:
         A one-element list holding a ``CFile`` named ``'conv'`` with
         ``config_target='omp'``.
     """
     arg_cfg, _tune_cfg = program_cfg
     param = self.conv_param
     bindings = [
         ('kernel_size', param.kernel_size),
         ('pad', param.pad),
         ('stride', param.stride),
         ('group', param.group),
     ]
     # Each argument's shape is exposed under its own placeholder names.
     # Unpacking into four names keeps the original "exactly four" check.
     shape_keys = (
         ('out', ('out_num', 'out_c', 'out_h', 'out_w')),
         ('in_ptr', ('in_num', 'in_c', 'in_h', 'in_w')),
         ('weights', ('weight_g', 'weight_c', 'weight_h', 'weight_w')),
     )
     for arg_name, keys in shape_keys:
         d0, d1, d2, d3 = arg_cfg[arg_name]._shape_
         bindings.extend(zip(keys, (d0, d1, d2, d3)))
     template_path = (os.path.dirname(os.path.realpath(__file__)) +
                      '/conv_test.tmpl.c')
     bindings.append(('bias_term', 1 if param.bias_term else 0))
     template_args = {key: Constant(value) for key, value in bindings}
     return [
         CFile('conv', [FileTemplate(template_path, template_args)],
               config_target='omp')
     ]