Ejemplo n.º 1
0
 def simple_array_test(self):
     """End-to-end check of ``x = 0.5 + b * c`` on (2, nodes, modes) arrays.

     The expression is turned into a dataflow DAG and then into a ctree,
     wrapped in a C function ``test_fun(b, c, nodes, modes, <return symbol>)``,
     JIT-compiled, and its output compared against the NumPy result.
     """
     def float64_ptr():
         # A fresh ctypes float64 array-pointer instance for each use site.
         return np.ctypeslib.ndpointer(dtype=np.float64)()

     source = "x = 0.5 + b * c\nreturn x"
     # Symbolic shapes; 'nodes' and 'modes' become C int parameters below.
     shapes = {'b': (2, 'nodes', 'modes'), 'c': (2, 'nodes', 'modes')}
     dag = nes.ast_to_dfdag(ast.parse(source), shapes)
     builder = nes.dfdag_to_ctree(dag)

     # The result is handed back through the return symbol, which is passed
     # to the kernel as the last (output) pointer argument.
     out_sym = builder.return_symbol
     out_sym.type = float64_ptr()

     include = CppInclude("stdlib.h")
     kernel_params = [
         SymbolRef("b", float64_ptr()),
         SymbolRef("c", float64_ptr()),
         SymbolRef("nodes", ctypes.c_int()),
         SymbolRef("modes", ctypes.c_int()),
         out_sym,
     ]
     decl = FunctionDecl(None, "test_fun", params=kernel_params, defn=builder.c_ast)
     c_file = CFile(body=[include, decl])

     jit = JitModule()
     compile_path = CONFIG.get('jit', 'COMPILE_PATH')
     wrapper = CFile("test_fun", [c_file], path=compile_path)
     jit._link_in(wrapper._compile(c_file.codegen()))

     entry = c_file.find(FunctionDecl, name="test_fun")
     test_fun = jit.get_callable(entry.name, entry.get_type())

     nodes, modes = 19, 5
     shape = (2, nodes, modes)
     # lhs is bound to kernel parameter "b", rhs to "c".
     lhs = np.random.rand(*shape)
     rhs = np.random.rand(*shape)
     out = np.zeros(shape)
     test_fun(lhs, rhs, nodes, modes, out)
     self.assertTrue(np.allclose(out, 0.5 + lhs * rhs))
Ejemplo n.º 2
0
    def transform(self, py_ast, program_cfg):
        """Lower a Python element-wise kernel to a Hwacha-vectorized C file.

        :param py_ast: Python AST whose first top-level node (after
            ``PyBasicConversions``) is the kernel ``FunctionDecl``.
        :param program_cfg: ``(arg_cfg, tune_cfg)`` pair. ``arg_cfg`` is a
            sequence of ctypes-style argument types exposing ``_dtype_``;
            ``arg_cfg[0]`` also exposes ``_shape_``. ``tune_cfg`` is unused.
        :returns: A one-element list containing the ``CFile`` named
            ``"generated"``.
        """
        arg_cfg, _tune_cfg = program_cfg
        tree = PyBasicConversions().visit(py_ast)
        func = tree.body[0]

        # The output buffer is an extra trailing parameter typed like the
        # first argument.
        func.params.append(C.SymbolRef("retval", arg_cfg[0]()))

        # Annotate arguments with concrete types. zip() stops at the shorter
        # sequence, so "retval" keeps the type assigned above and gets no
        # param_dict entry unless arg_cfg has an entry at its position.
        param_dict = {}
        for param, arg_type in zip(func.params, arg_cfg):
            param.type = arg_type()
            param_dict[param.name] = arg_type._dtype_

        # Wrap the original statements in one flat loop over every element
        # of the first argument (i = 0 .. prod(shape) - 1).
        length = np.prod(arg_cfg[0]._shape_)
        transformer = MapTransformer("i", param_dict, "retval")
        body = [transformer.visit(stmt) for stmt in func.defn]

        func.defn = [C.For(
            C.Assign(C.SymbolRef("i", ct.c_int()), C.Constant(0)),
            C.Lt(C.SymbolRef("i"), C.Constant(length)),
            C.PostInc(C.SymbolRef("i")),
            body=body,
            pragma="ivdep",  # presumably emitted as `#pragma ivdep` on the loop
        )]

        tree = DeclarationFiller().visit(tree)
        # HwachaVectorize presumably fills `defns` with auxiliary definitions;
        # they are emitted before the kernel in the generated file.
        defns = []
        tree = HwachaVectorize(param_dict, defns).visit(tree)
        file_body = [
            StringTemplate("#include <stdlib.h>"),
            StringTemplate("#include <stdint.h>"),
            StringTemplate("#include <assert.h>"),
            # NOTE(review): `extern "C"` is C++-only syntax; confirm the
            # generated file is built with a C++ compiler.
            StringTemplate('extern "C" void __hwacha_body(void);'),
        ]
        file_body.extend(defns)
        file_body.append(tree)
        return [CFile("generated", file_body)]
Ejemplo n.º 3
0
 def transform(self, tree, program_cfg):
     """Build the ``conv`` C file from the ``conv_test.tmpl.c`` template.

     The template, located next to this module, is filled with values read
     from ``self.conv_param`` and with the dimensions of the ``'out'``,
     ``'in_ptr'`` and ``'weights'`` arguments, each wrapped in ``Constant``.

     :param tree: Unused; the kernel body comes entirely from the template.
     :param program_cfg: ``(arg_cfg, tune_cfg)`` pair. ``arg_cfg`` maps
         ``'out'``, ``'in_ptr'`` and ``'weights'`` to types whose ``_shape_``
         has exactly four entries (ValueError otherwise). ``tune_cfg`` is
         unused.
     :returns: A one-element list holding ``CFile('conv', ...)`` created with
         ``config_target='omp'``.
     """
     import os

     arg_cfg, _tune_cfg = program_cfg
     conv_param = self.conv_param
     kernel_size = conv_param.kernel_size
     pad = conv_param.pad
     stride = conv_param.stride
     group = conv_param.group
     out_num, out_c, out_h, out_w = arg_cfg['out']._shape_
     in_num, in_c, in_h, in_w = arg_cfg['in_ptr']._shape_
     weight_g, weight_c, weight_h, weight_w = arg_cfg['weights']._shape_

     # Raw template values; the keys are the placeholder names in the template.
     values = {
         'kernel_size': kernel_size,
         'pad': pad,
         'stride': stride,
         'group': group,
         'out_num': out_num,
         'out_c': out_c,
         'out_h': out_h,
         'out_w': out_w,
         'in_num': in_num,
         'in_c': in_c,
         'in_h': in_h,
         'in_w': in_w,
         'weight_g': weight_g,
         'weight_c': weight_c,
         'weight_h': weight_h,
         'weight_w': weight_w,
         'bias_term': 1 if conv_param.bias_term else 0,
     }
     template_path = os.path.join(
         os.path.dirname(os.path.realpath(__file__)), 'conv_test.tmpl.c')
     template = FileTemplate(
         template_path,
         {name: Constant(value) for name, value in values.items()})
     # NOTE(review): 'omp' presumably selects an OpenMP build config; confirm.
     return [CFile('conv', [template], config_target='omp')]
Ejemplo n.º 4
0
 def visit_Module(self, node):
     """Translate a Python ``Module`` node into a single-file ``Project``.

     Each top-level statement is visited in source order, and the translated
     nodes form the body of one ``CFile`` named ``"module"``.
     """
     translated = []
     for stmt in node.body:
         translated.append(self.visit(stmt))
     return Project([CFile("module", translated)])