def visit_Return(self, node):
    """Translate a ``return`` statement into C++ that stores the value in output memory.

    The returned expression is visited first, then handled by kind:

    * ``OpenCall`` -- its output parameters describe the return buffer; the
      code produced by ``stmt.write_to(self.retModel)`` is emitted, followed
      by a bare ``return``.
    * ``cpp_ast.CNumber`` -- a scalar literal, stored through ``*_blb_result``.
    * anything else -- resolved to a data model via ``get_or_create_model`` and
      copied into ``_blb_result`` with ``memcpy``.

    In every case ``self.retModel`` is set to a ``ReturnModel`` describing the
    returned value's type and shape.

    Returns:
        The C++ AST block implementing the return.

    Raises:
        ValueError: if the value is neither an ``OpenCall`` nor a ``CNumber``
            and resolving or sizing it raises ``ValueError`` or ``TypeError``.
    """
    stmt = self.visit(node.value)
    if isinstance(stmt, OpenCall):
        output = stmt.get_output_params()
        self.retModel = ReturnModel(output['type'], output['shape'])
        pre = stmt.write_to(self.retModel)
        pre.contents.append(cpp_ast.ReturnStatement(""))
        return pre
    elif isinstance(stmt, cpp_ast.CNumber):
        self.retModel = ReturnModel(type(stmt.num), [1])
        return cpp_ast.UnbracedBlock([
            cpp_ast.Statement('*_blb_result = %s;' % stmt.num),
            cpp_ast.ReturnStatement("")
        ])
    else:
        try:
            source = self.get_or_create_model(stmt)
            self.retModel = ReturnModel(source.scalar_t, source.dimension())
            return cpp_ast.UnbracedBlock([
                cpp_ast.FunctionCall('memcpy', [
                    '_blb_result', source.ref_name(),
                    source.dimension() * source.scalar_t.csize()
                ]),
                cpp_ast.ReturnStatement("")
            ])
        # A tuple is required here: the Python 2 form
        # ``except ValueError, TypeError:`` catches only ValueError and binds
        # it to a local named ``TypeError``, so TypeErrors escaped unconverted.
        except (ValueError, TypeError):
            raise ValueError("Invalid return object: %s" % str(stmt))
def gen_array_unpack(self):
    """Build one assignment per argument that exposes its raw array buffer.

    For every key ``arg_name`` of ``self.argdict`` an ``cpp_ast.Assign`` is
    produced:

    * Its target is an ``npy_double`` pointer named ``_my_<arg_name>``.
    * Its value is ``PyArray_DATA(<arg_name>)`` cast to an ``npy_double``
      pointer.

    Returns:
        A list of ``cpp_ast.Assign`` nodes, in ``self.argdict.keys()`` order.
    """
    statements = []
    for arg_name in self.argdict.keys():
        # Build the nodes in the same order as the nested constructor calls
        # would evaluate them: destination, cast type, then the data call.
        destination = cpp_ast.Pointer(
            cpp_ast.Value("npy_double", "_my_" + arg_name))
        cast_type = cpp_ast.Pointer(cpp_ast.Value("npy_double", ""))
        raw_data = cpp_ast.FunctionCall(
            cpp_ast.CName("PyArray_DATA"),
            params=[cpp_ast.CName(arg_name)])
        statements.append(
            cpp_ast.Assign(destination, cpp_ast.TypeCast(cast_type, raw_data)))
    return statements
def handle_initialize(args, converter, kargs):
    """Translate ``initialize(target, value)`` into a ``_blb_vecinit`` call.

    Args:
        args: positional arguments of the call. ``args[0]`` is the target,
            passed to ``converter.get_or_create_model``. ``args[1]`` must be
            a ``cpp_ast.CNumber`` holding the fill value.
        converter: object providing ``get_or_create_model``.
        kargs: keyword arguments of the call (unused here).

    Returns:
        A ``cpp_ast.FunctionCall`` to ``_blb_vecinit`` with the target's
        reference name, the numeric value, and the target's dimension.

    Raises:
        TypeError: if fewer than two positional arguments are given.
        AssertionError: if ``args[1]`` is not exactly a ``cpp_ast.CNumber``
            (this check is an ``assert`` and is skipped under ``python -O``).

    Side effect: appends ``('_blb_vecinit', (<element ctype>,))`` to the
    module-level ``REQUESTED_FUNCS`` list.
    """
    given = len(args)
    if given < 2:
        raise TypeError("initialize() requires 2 arguments, %d given." % given)
    model = converter.get_or_create_model(args[0])
    init_value = args[1]
    assert type(init_value) == cpp_ast.CNumber, \
        "initialize(): Invalid initial value: %s" % str(init_value)
    element_ctype = model.scalar_t.ctype()
    REQUESTED_FUNCS.append(('_blb_vecinit', (element_ctype, )))
    call_args = [model.ref_name(), init_value.num, model.dimension()]
    return cpp_ast.FunctionCall('_blb_vecinit', call_args)
def handle_copy(args, converter, kargs):
    """Translate ``copy(target, source)`` into a ``memcpy`` call.

    Both arguments are visited with ``converter.visit`` and resolved to data
    models with ``converter.get_or_create_model``. The models must have the
    same scalar type and the same dimension.

    Args:
        args: positional arguments of the call; ``args[0]`` is the
            destination and ``args[1]`` the source.
        converter: object providing ``visit`` and ``get_or_create_model``.
        kargs: keyword arguments of the call (unused here).

    Returns:
        A ``cpp_ast.FunctionCall`` to ``memcpy(target, source, nbytes)``, where
        ``nbytes`` is ``target.dimension() * target.scalar_t.csize()``.

    Raises:
        TypeError: if fewer than two positional arguments are given.
        AssertionError: on a scalar-type or length mismatch. These checks are
            ``assert`` statements, so they are skipped under ``python -O``.
    """
    if len(args) < 2:
        raise TypeError("copy() requires two arguments, %d given" % len(args))
    target = converter.get_or_create_model(converter.visit(args[0]))
    source = converter.get_or_create_model(converter.visit(args[1]))
    assert target.scalar_t == source.scalar_t, \
        "copy(): Type mismatch between %s and %s" % (
            target.ref_name(), source.ref_name())
    # Previously misspelled ``source.dimentsion()``, which raised
    # AttributeError on every copy whose types matched.
    assert target.dimension() == source.dimension(), \
        "copy(): Length mismatch: %s and %s " % (
            target.ref_name(), source.ref_name())
    return cpp_ast.FunctionCall('memcpy', [
        target.ref_name(), source.ref_name(),
        target.dimension() * target.scalar_t.csize()
    ])
def visit_MathFunction(self, node):
    """Translate a math-function node into a C++ call of the same name.

    Each entry of ``node.args`` is translated with ``self.visit``, in order.
    The results are passed as ``params`` to ``cpp_ast.FunctionCall``, with the
    callee named ``node.name``.
    """
    # A list comprehension gives a real list on both Python 2 and 3;
    # ``map`` would hand FunctionCall a one-shot iterator on Python 3.
    return cpp_ast.FunctionCall(
        cpp_ast.CName(node.name),
        params=[self.visit(arg) for arg in node.args])