def convert_argument(i, arg, annotations, *, type_=None, is_ret=False):
    """Convert a clang cursor for a parameter (or the return slot) into an ``Argument``.

    :param i: Position of the argument; used to synthesize the name ``__arg{i}`` when the
        cursor has no display name.
    :param arg: The clang cursor to convert (the function cursor itself when ``is_ret``).
    :param annotations: Annotation set for this argument. The argument's own name is
        discarded from its ``depends_on`` entry, and the set is passed to ``apply_rules``
        and ``convert_type``.
    :param type_: Explicit type to convert; ``arg.type`` is used when this is falsy.
    :param is_ret: Build the return-value pseudo-argument named ``RET_ARGUMENT_NAME``.
    :return: The converted ``Argument``.
    """
    if is_ret:
        arg_name = RET_ARGUMENT_NAME
    else:
        arg_name = arg.displayname
    arg_name = arg_name or "__arg{}".format(i)

    # An argument must not list itself as a dependency.
    annotations["depends_on"].discard(arg_name)
    apply_rules(arg, annotations, name=arg_name)

    with location(f"argument {term.yellow(arg_name)}", convert_location(arg.location)):
        expr_source = None
        if not is_ret:
            # At most one expression child is allowed; its source text becomes ``value``.
            child_exprs = list(arg.find_descendants(lambda node: node.kind.is_expression()))
            parse_assert(
                len(child_exprs) <= 1,
                "There must only be one expression child in argument declarations.",
            )
            if child_exprs:
                expr_source = child_exprs[0].source

        converted = convert_type(type_ or arg.type, arg_name, annotations, set())
        return Argument(
            arg_name,
            converted,
            value=expr_source,
            location=convert_location(arg.location),
            **annotations.direct(argument_annotations).flatten(),
        )
def return_command_implementation(f: Function) -> str:
    """Generate the guest-side C ``case`` that handles the RET command for ``f``.

    The emitted ``case {f.ret_id_spelling}`` does the following:

    - checks the reply's API id and command size;
    - removes the pending call record from ``__ava_endpoint`` using ``__ret->__call_id``;
    - unpacks the record's arguments and logue declarations, plus the reply's return value
      when it is not ``void``;
    - copies results for buffer-carrying arguments and the return value;
    - emits the epilogue and the per-argument deallocation code from
      ``deallocate_managed_for_argument``;
    - sets ``__call_complete`` and frees the record when ``__handler_deallocate`` is set.

    :param f: The function whose RET handler is generated.
    :return: The C source of the ``case`` with surrounding whitespace stripped.
    """
    with location(f"at {term.yellow(str(f.name))}", f.location):
        # pthread_mutex_lock(&nw_handler_lock);
        # took_lock = 1;
        # Returned buffers may not use the per-call lifetime (AVA_CALL).
        generate_requires(
            not f.return_value.type.buffer or f.return_value.type.lifetime != Expr("AVA_CALL"),
            "Returned buffers must have a lifetime other than `call' (i.e., must be annotated with `ava_lifetime_static', `ava_lifetime_coupled', or `ava_lifetime_manual').",
        )
        return f"""
        case {f.ret_id_spelling}: {{\
            {timing_code_guest("before_unmarshal", str(f.name), f.generate_timing_code)}
            ava_is_in = 0; ava_is_out = 1;
            struct {f.ret_spelling}* __ret = (struct {f.ret_spelling}*)__cmd;
            assert(__ret->base.api_id == {f.api.number_spelling});
            assert(__ret->base.command_size == sizeof(struct {f.ret_spelling}) && "Command size does not match ID. (Can be caused by incorrectly computed buffer sizes, especially using `strlen(s)` instead of `strlen(s)+1`)");
            struct {f.call_record_spelling}* __local = (struct {f.call_record_spelling}*)ava_remove_call(&__ava_endpoint, __ret->__call_id);

            {{
                {unpack_struct("__local", f.arguments, "->")} \
                {unpack_struct("__local", f.logue_declarations, "->")} \
                {unpack_struct("__ret", [f.return_value], "->", convert=get_buffer_expr) if not f.return_value.type.is_void else ""} \
                {lines(copy_result_for_argument(a, "__local", "__ret") for a in f.arguments if a.type.contains_buffer)}
                {copy_result_for_argument(f.return_value, "__local", "__ret") if not f.return_value.type.is_void else ""}\
                {lines(f.epilogue)}
                {lines(deallocate_managed_for_argument(a, "__local") for a in f.arguments)}
            }}

            {timing_code_guest("after_unmarshal", str(f.name), f.generate_timing_code)}
            __local->__call_complete = 1;
            if(__local->__handler_deallocate) {{
                free(__local);
            }}
            break;
        }}
        """.strip()
def unsupported_function_implementation(f: Function) -> str: """ Generate a stub function which simply fails with an "Unsupported" message. :param f: The unsupported function. :return: A C function definition. """ with location(f"at {term.yellow(str(f.name))}", f.location): return f"""
def record_argument_metadata(arg: Argument):
    """Generate C code that records record/replay metadata for handles reachable from ``arg``.

    The value is read from the local variable named ``arg.name``:

    - Handles (transfer ``NW_HANDLE``) that do not deallocate get their record/replay
      functions assigned and call metadata recorded.
    - Deallocating handles have their recorded calls expunged.
    - Non-simple ``NW_BUFFER`` values are walked element by element via ``for_all_elements``.
    - Simple buffers produce no code.

    :param arg: The argument (or return value) to process.
    :return: The generated C code (string or Expr).
    """

    def convert_result_value(values, cast_type: Type, type_: Type, depth, original_type=None, **other) -> str:
        # Kernel for for_all_elements; ``values`` is a 1-tuple holding the C expression of the value.
        if isinstance(type_, ConditionalType):
            # Generate both branches and select between them with the C predicate.
            return Expr(type_.predicate).if_then_else(
                convert_result_value(
                    values, type_.then_type, type_.then_type, depth, original_type=type_.original_type, **other
                ),
                convert_result_value(
                    values, type_.else_type, type_.else_type, depth, original_type=type_.original_type, **other
                ),
            )
        if type_.is_void:
            return """abort_with_reason("Reached code to handle void value.");"""

        (param_value,) = values
        buffer_pred = Expr(type_.transfer).equals("NW_BUFFER") & Expr(param_value).not_equals("NULL")

        def simple_buffer_case():
            # Nothing is recorded for simple buffers.
            return ""

        def buffer_case():
            if not hasattr(type_, "pointee"):
                return """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""
            tmp_name = f"__tmp_{arg.name}_{depth}"
            inner_values = (tmp_name,)
            loop = for_all_elements(inner_values, cast_type, type_, depth=depth, original_type=original_type, **other)
            if loop:
                # Copy the (non-NULL) buffer pointer into a temporary and iterate over its elements.
                return buffer_pred.if_then_else(
                    f"""
                {type_.nonconst.attach_to(tmp_name)};
                {tmp_name} = {param_value};
                {loop}
                """
                )
            return ""

        def default_case():
            # NOTE(review): uses Python `not` on type_.deallocates, whereas other generators wrap it
            # as Expr(type.deallocates) / ~Expr(...); this is only right if deallocates is a plain
            # bool here -- confirm.
            return (Expr(type_.transfer).equals("NW_HANDLE")).if_then_else(
                Expr(not type_.deallocates).if_then_else(
                    assign_record_replay_functions(param_value, type_).then(record_call_metadata(param_value, type_)),
                    expunge_calls(param_value),
                )
            )

        if type_.fields:
            return for_all_elements(values, cast_type, type_, depth=depth, original_type=original_type, **other)
        # The case functions are passed uncalled; if_then_else is given callables for the branches.
        return type_.is_simple_buffer().if_then_else(
            simple_buffer_case, Expr(type_.transfer).equals("NW_BUFFER").if_then_else(buffer_case, default_case)
        )

    with location(f"at {term.yellow(str(arg.name))}", arg.location):
        conv = convert_result_value(
            (f"{arg.name}",), arg.type, arg.type, depth=0, name=arg.name, kernel=convert_result_value, self_index=0
        )
        return conv
def function_implementation(f: Function) -> Union[str, Expr]: """ Generate a stub function which sends the appropriate CALL command over the channel. :param f: The function to generate a stub for. :return: A C function definition (as a string or Expr) """ with location(f"at {term.yellow(str(f.name))}", f.location): if f.return_value.type.buffer: forge_success = f"#error Async returned buffers are not implemented." elif f.return_value.type.is_void: forge_success = "return;" elif f.return_value.type.success is not None: forge_success = f"return {f.return_value.type.success};" else: forge_success = """abort_with_reason("Cannot forge success without a success value for the type.");""" if f.return_value.type.is_void: return_statement = f""" free(__call_record); return; """.strip() else: return_statement = f""" {f.return_value.declaration}; {f.return_value.name} = __call_record->{f.return_value.name}; free(__call_record); return {f.return_value.name}; """.strip() is_async = ~Expr(f.synchrony).equals("NW_SYNC") alloc_list = AllocList(f) send_code = f""" command_channel_send_command(__chan, (struct command_base*)__cmd); """.strip() if f.api.send_code: import_code = f.api.send_code.encode("ascii", "ignore").decode("unicode_escape")[1:-1] ldict = locals() exec(import_code, globals(), ldict) send_code = ldict["send_code"] return_code = is_async.if_then_else( forge_success, f""" shadow_thread_handle_command_until(nw_shadow_thread_pool, __call_record->__call_complete); {return_statement} """.strip(), ) return f"""
def function_call_struct(f: Function, errors: List[Any]):
    """Generate the C ``struct`` definition for the CALL command of ``f``.

    The struct starts with ``struct command_base base`` and ``intptr_t __call_id``,
    followed by one member per argument of ``f``.

    :param f: The function whose CALL command struct is generated.
    :param errors: Passed to ``location(..., report_continue=errors)``.
    :return: The C struct definition, or an ``#error`` directive carrying the captured
        errors if generation raised an error that ``capture_errors`` suppressed.
    """
    with capture_errors():
        with location(f"at {term.yellow(str(f.name))}", f.location, report_continue=errors):
            struct_name = f.call_spelling
            newline = "\n"
            members = "".join(argument(a) + newline for a in f.arguments).strip()
            return f"""
            struct {struct_name} {{
                struct command_base base;
                intptr_t __call_id;
                {members}
            }};
            """
    # Reached only when capture_errors() suppressed an error raised above.
    # noinspection PyUnreachableCode
    return f'#error "{captured_errors()}" '
def function_ret_struct(f: Function, errors: List[Any]):
    """Generate the C ``struct`` definition for the RET (reply) command of ``f``.

    The struct starts with ``struct command_base base`` and ``intptr_t __call_id``,
    followed by a member for every output argument whose type contains a buffer, and then a
    member for the return value unless it is ``void``.

    :param f: The function whose RET command struct is generated.
    :param errors: Passed to ``location(..., report_continue=errors)``.
    :return: The C struct definition, or an ``#error`` directive carrying the captured
        errors if generation raised an error that ``capture_errors`` suppressed.
    """
    with capture_errors():
        with location(f"at {term.yellow(str(f.name))}", f.location, report_continue=errors):
            struct_name = f.ret_spelling
            newline = "\n"
            output_members = "".join(
                argument(a) + newline for a in f.arguments if a.type.contains_buffer and a.output
            ).strip()
            return_member = "" if f.return_value.type.is_void else argument(f.return_value)
            return f"""
            struct {struct_name} {{
                struct command_base base;
                intptr_t __call_id;
                {output_members}\
                {return_member}
            }};
            """
    # Reached only when capture_errors() suppressed an error raised above.
    # noinspection PyUnreachableCode
    return f'#error "{captured_errors()}" '
def function_call_record_struct(f: Function, errors: List[Any]):
    """Generate the C ``struct`` used as the guest-side call record for ``f``.

    The record holds the following, in order:

    - one member per argument;
    - the return value, unless it is ``void``;
    - one member per prologue/epilogue declaration;
    - the ``__handler_deallocate`` and ``volatile __call_complete`` flags.

    :param f: The function whose call record struct is generated.
    :param errors: Passed to ``location(..., report_continue=errors)``.
    :return: The C struct definition, or an ``#error`` directive carrying the captured
        errors if generation raised an error that ``capture_errors`` suppressed.
    """
    with capture_errors():
        with location(f"at {term.yellow(str(f.name))}", f.location, report_continue=errors):
            struct_name = f.call_record_spelling
            newline = "\n"
            arg_members = "".join(argument(a) + newline for a in f.arguments).strip()
            return_member = "" if f.return_value.type.is_void else argument(f.return_value)
            logue_members = "".join(argument(a) + newline for a in f.logue_declarations).strip()
            return f"""
            struct {struct_name} {{
                {arg_members}\
                {return_member}\
                {logue_members}\
                char __handler_deallocate;
                volatile char __call_complete;
            }};
            """
    # Reached only when capture_errors() suppressed an error raised above.
    # noinspection PyUnreachableCode
    return f'#error "{captured_errors()}" '
def replay_command_implementation(f: Function):
    """Generate the worker-side C ``case`` that replays a recorded CALL of ``f``.

    The emitted ``case {f.call_id_spelling}`` does the following:

    - validates the recorded CALL command in ``__call_cmd``;
    - unpacks and translates its arguments and performs the call via ``call_function_wrapper``;
    - validates the paired RET command in ``__ret_cmd`` (API id, size, command id and matching
      ``__call_id``);
    - assigns original handle IDs from the RET command for the return value and for arguments
      containing buffers;
    - under ``AVA_RECORD_REPLAY``, records metadata for the replayed call;
    - emits ``alloc_list.alloc`` at the start and ``alloc_list.dealloc`` at the end.

    :param f: The function whose replay handler is generated.
    :return: The C source of the ``case`` with surrounding whitespace stripped.
    """
    with location(f"at {term.yellow(str(f.name))}", f.location):
        alloc_list = AllocList(f)
        # NOTE(review): record_argument_metadata is called below with a ``src`` argument; the
        # record_argument_metadata(arg) definition in this source takes only ``arg`` -- confirm
        # which signature is in scope for this generator.
        return f"""
        case {f.call_id_spelling}: {{\
            {alloc_list.alloc}
            ava_is_in = 1; ava_is_out = 0;
            __cmd = __call_cmd;
            struct {f.call_spelling}* __call = (struct {f.call_spelling}*)__call_cmd;
            assert(__call->base.api_id == {f.api.number_spelling});
            assert(__call->base.command_size == sizeof(struct {f.call_spelling}) && "Command size does not match ID. (Can be caused by incorrectly computed buffer sizes, especially using `strlen(s)` instead of `strlen(s)+1`)");

            /* Unpack and translate arguments */
            {lines(convert_input_for_argument(a, "__call") for a in f.arguments)}

            /* Perform Call */
            {call_function_wrapper(f)}

            ava_is_in = 0; ava_is_out = 1;
            __cmd = __ret_cmd;
            struct {f.ret_spelling}* __ret = (struct {f.ret_spelling}*)__ret_cmd;
            assert(__ret->base.api_id == {f.api.number_spelling});
            assert(__ret->base.command_size == sizeof(struct {f.ret_spelling}) && "Command size does not match ID. (Can be caused by incorrectly computed buffer sizes, especially using `strlen(s)` instead of `strlen(s)+1`)");
            assert(__ret->base.command_id == {f.ret_id_spelling});
            assert(__ret->__call_id == __call->__call_id);

            /* Assign original handle IDs */
            {assign_original_handle_for_argument(f.return_value, "__ret") if not f.return_value.type.is_void else ""}
            {lines(assign_original_handle_for_argument(a, "__ret") for a in f.arguments if a.type.contains_buffer)}

            #ifdef AVA_RECORD_REPLAY
            {log_call_declaration}
            {log_ret_declaration}
            {lines(record_argument_metadata(a, src="__ret" if a.type.contains_buffer and a.output else "__call") for a in f.arguments)}
            {record_argument_metadata(f.return_value, "__ret") if not f.return_value.type.is_void else ""}
            {record_call_metadata("NULL", None) if f.object_record else ""}
            #endif

            {alloc_list.dealloc}
            break;
        }}
        """.strip()
def function_implementation(f: Function, enabled_opts: List[str] = None ) -> Union[str, Expr]: """ Generate a stub function which sends the appropriate CALL command over the channel. :param f: The function to generate a stub for. :return: A C function definition (as a string or Expr) """ with location(f"at {term.yellow(str(f.name))}", f.location): if f.return_value.type.buffer: forge_success = "#error Async returned buffers are not implemented." elif f.return_value.type.is_void: forge_success = "return;" elif f.return_value.type.success is not None: forge_success = f"return {f.return_value.type.success};" else: forge_success = """abort_with_reason("Cannot forge success without a success value for the type.");""" if f.return_value.type.is_void: return_statement = """ free(__call_record); return; """.strip() else: return_statement = f""" {f.return_value.declaration}; {f.return_value.name} = __call_record->{f.return_value.name}; free(__call_record); return {f.return_value.name}; """.strip() is_async = ~Expr(f.synchrony).equals("NW_SYNC") alloc_list = AllocList(f) if enabled_opts: # Enable batching optimization: the APIs are batched into a `__do_batch_emit` call. if "batching" in enabled_opts: send_code = f""" batch_insert_command(nw_global_cmd_batch, (struct command_base*)__cmd, __chan, {int(is_async.is_true())}); """.strip() if f.name == "__do_batch_emit": send_code = """ command_channel_send_command(__chan, (struct command_base*)__cmd); """.strip() else: send_code = """ command_channel_send_command(__chan, (struct command_base*)__cmd); """.strip() return_code = is_async.if_then_else( forge_success, f""" shadow_thread_handle_command_until( common_context->nw_shadow_thread_pool, __call_record->__call_complete); {return_statement} """.strip(), ) return f"""
def deallocate_managed_for_argument(arg: Argument, src):
    """Generate C code that releases resources owned by ``arg`` after a call completes.

    For types whose ``deallocates`` annotation holds:

    - the value is passed to ``ava_shadow_buffer_free_coupled``;
    - ``NW_HANDLE`` values are also released with ``ava_coupled_free``.

    Non-NULL ``NW_BUFFER`` values are walked element by element so nested values get the same
    treatment. Simple buffers (``is_simple_buffer(allow_handle=False)``) get no per-element code.

    :param arg: The argument whose value is deallocated.
    :param src: C pointer expression of the struct holding the value (accessed as
        ``src->arg.name``); when falsy, the local variable ``arg.name`` is used.
    :return: The generated code wrapped in a ``Dealloc:`` comment block.
    """

    def convert_result_value(values, type: Type, depth, original_type=None, **other):
        # Kernel for for_all_elements; ``values`` is a 1-tuple holding the C expression of the value.
        if isinstance(type, ConditionalType):
            return Expr(type.predicate).if_then_else(
                convert_result_value(values, type.then_type, depth, original_type=type.original_type, **other),
                convert_result_value(values, type.else_type, depth, original_type=type.original_type, **other))

        local_value, = values
        buffer_pred = (Expr(type.transfer).equals("NW_BUFFER") & f"{local_value} != NULL")
        # Appended after the per-type code below (not on the fields or conditional paths).
        dealloc_shadows = Expr(type.deallocates).if_then_else(
            f"ava_shadow_buffer_free_coupled(&__ava_endpoint, (void *){local_value});"
        )

        def simple_buffer_case():
            return ""

        def buffer_case():
            if not hasattr(type, "pointee"):
                return """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""
            return buffer_pred.if_then_else(
                for_all_elements(values, type, depth=depth, original_type=original_type, **other))

        def default_case():
            dealloc_code = Expr(type.deallocates).if_then_else(
                Expr(type.transfer).equals("NW_HANDLE").if_then_else(f"""
                    ava_coupled_free(&__ava_endpoint, {local_value});
                    """.strip()))
            return dealloc_code

        if type.fields:
            return for_all_elements(values, type, depth=depth, original_type=original_type, **other)
        # Dispatch: simple buffer -> nothing; NW_BUFFER -> per-element walk;
        # NW_OPAQUE/NW_HANDLE -> default_case; then the shadow-buffer free.
        return type.is_simple_buffer(allow_handle=False).if_then_else(
            simple_buffer_case,
            Expr(type.transfer).equals("NW_BUFFER").if_then_else(
                buffer_case,
                (Expr(type.transfer).one_of({"NW_OPAQUE", "NW_HANDLE"})).if_then_else(default_case))).then(
            dealloc_shadows).scope()

    with location(f"at {term.yellow(str(arg.name))}", arg.location):
        conv = convert_result_value(
            (f"""{src + "->" if src else ""}{arg.name}""", ), arg.type, depth=0, name=arg.name,
            kernel=convert_result_value, self_index=0)
        return comment_block(f"Dealloc: {arg}", conv)
def convert_input_for_argument(arg: Argument, src):
    """
    Generate code to extract the value for ``arg`` from the CALL structure ``src``.

    The value is left in a local variable named ``arg.name`` and is fully converted to local
    values:

    - handles are dereferenced through ``handle_pool``;
    - callbacks are replaced by ``type.callback_stub_function``;
    - buffers are fetched with ``get_buffer``, copied or allocated as needed.

    This code is used in the command receiver to implement a CALL command.

    :param arg: The argument to extract.
    :param src: C pointer expression of the CALL command structure.
    :return: A series of C statements (wrapped in an ``Input:`` comment block) performing the
        extraction.
    """
    alloc_list = AllocList(arg.function)

    def convert_input_value(values, type: Type, depth, original_type=None, **other):
        # ``values`` is (local destination expression, source expression in the CALL struct).
        local_value, param_value = values
        preassignment = f"{local_value} = {get_transfer_buffer_expr(param_value, type)};"

        if isinstance(type, ConditionalType):
            return Expr(preassignment).then(
                Expr(type.predicate).if_then_else(
                    convert_input_value(values, type.then_type, depth, original_type=type.original_type, **other),
                    convert_input_value(values, type.else_type, depth, original_type=type.original_type, **other)))

        if type.is_void:
            return """abort_with_reason("Reached code to handle void value.");"""

        def maybe_alloc_local_temporary_buffer():
            # Allocate a local buffer (registered in alloc_list for later release) when the
            # incoming pointer is non-NULL.
            # TODO: Deduplicate with allocate_tmp_buffer
            allocator = type.buffer_allocator
            deallocator = type.buffer_deallocator
            return Expr(param_value).not_equals("NULL").if_then_else(f"""{{
            const size_t __size = {compute_buffer_size(type, original_type)};
            {local_value} = ({type.nonconst.spelling}){allocator}({size_to_bytes("__size", type)});
            {alloc_list.insert(local_value, deallocator)}
        }}""")

        src_name = f"__src_{arg.name}_{depth}"

        def get_buffer_code():
            # Remember the transferred pointer in src_name, fetch the buffer, and (for call-lifetime
            # buffers that are not simple malloc'd buffers) allocate a separate local buffer.
            return f"""
                {type.nonconst.attach_to(src_name)};
                {src_name} = {local_value};
                {get_buffer(local_value, param_value, type, original_type=original_type, not_null=True)}
                {(type.lifetime.equals("AVA_CALL") & (~type.is_simple_buffer() | type.buffer_allocator.not_equals("malloc"))).if_then_else(
                    maybe_alloc_local_temporary_buffer)}
                """

        def simple_buffer_case():
            if not hasattr(type, "pointee"):
                return """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""
            # Input data is copied only when a separate local buffer was allocated.
            copy_code = (Expr(arg.input) & Expr(local_value).not_equals(src_name)).if_then_else(
                f"""memcpy({local_value}, {src_name}, {size_to_bytes("__buffer_size", type)});""")
            # NOTE(review): type.transfer is used without an Expr(...) wrapper here, unlike the other
            # transfer checks in this function; confirm it is always an Expr.
            return ((type.lifetime.not_equals("AVA_CALL") | arg.input) & Expr(param_value).not_equals("NULL")).if_then_else(
                f"""
                    {get_buffer_code()}
                    {copy_code}
                """.strip(),
                (Expr(arg.input) | type.transfer.equals("NW_ZEROCOPY_BUFFER")).if_then_else(
                    preassignment, maybe_alloc_local_temporary_buffer))

        def buffer_case():
            if not hasattr(type, "pointee"):
                return """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""
            if not arg.input:
                return simple_buffer_case()
            # Input buffers whose elements need conversion are translated element by element.
            inner_values = (local_value, src_name)
            core = for_all_elements(inner_values, type, depth=depth, precomputed_size="__buffer_size",
                                    original_type=original_type, **other)
            return ((type.lifetime.not_equals("AVA_CALL") | arg.input) & Expr(param_value).not_equals("NULL")).if_then_else(
                f"""
                    {get_buffer_code()}
                    {core}
                """.strip(),
                maybe_alloc_local_temporary_buffer)

        def default_case():
            def deref_code(handlepool_function: str) -> callable:
                # Callbacks map to the stub function, handles are dereferenced through the handle
                # pool, and anything else is copied as-is.
                return lambda: (Expr(type.transfer).one_of({"NW_CALLBACK", "NW_CALLBACK_REGISTRATION"})).if_then_else(
                    f"{local_value} = ({param_value} == NULL) ? NULL : {type.callback_stub_function};",
                    (Expr(type.transfer).equals("NW_HANDLE")).if_then_else(
                        f"{local_value} = ({type.nonconst.spelling}){handlepool_function}(handle_pool, (void*){param_value});",
                        Expr(not type.is_void).if_then_else(
                            f"{local_value} = {param_value};",
                            """abort_with_reason("Reached code to handle void value.");"""
                        )))

            # Deallocating handles are also removed from the handle pool.
            return Expr(type.deallocates).if_then_else(
                deref_code("nw_handle_pool_deref_and_remove"),
                deref_code("nw_handle_pool_deref"))

        if type.fields:
            # NOTE(review): original_type is not forwarded here, unlike the other for_all_elements calls.
            return for_all_elements(values, type, depth=depth, **other)
        rest = type.is_simple_buffer().if_then_else(
            simple_buffer_case,
            Expr(type.transfer).equals("NW_BUFFER").if_then_else(
                buffer_case, default_case))
        if rest:
            return Expr(preassignment).then(rest).scope()
        else:
            return ""

    with location(f"at {term.yellow(str(arg.name))}", arg.location):
        conv = convert_input_value((arg.name, f"{src}->{arg.param_spelling}"), arg.type, depth=0, name=arg.name,
                                   kernel=convert_input_value, original_type=arg.type, self_index=0)
        return comment_block(
            f"Input: {arg}", f"""\
    {arg.type.nonconst.attach_to(arg.name)}; \
    {conv}
    """)
def convert_function(cursor, supported=True):
    """Convert a clang function cursor into a ``Function`` model.

    If the function has a body (a ``COMPOUND_STMT`` child), its statements are processed in order:

    - Statements with an ``implicit_argument`` attribute annotation contribute an extra argument.
    - Statements with annotations other than ``depends_on`` are skipped.
    - Variable declarations (except ``ret`` and names starting with ``NIGHTWATCH_PREFIX``) become
      logue declarations.
    - The statement containing ``ava_execute`` switches output from the prologue to the epilogue.
    - Every other statement's source text (with a ``;`` appended if missing) is added to the
      current prologue/epilogue list.

    :param cursor: The clang cursor of the function.
    :param supported: Initial ``supported`` flag; overridden by an ``unsupported`` annotation.
    :return: The converted ``Function``.
    """
    # NOTE(review): ``errors`` is not a parameter; it must come from an enclosing scope or the module.
    with location(f"at {term.yellow(cursor.displayname)}", convert_location(cursor.location), report_continue=errors):
        # TODO: Capture tokens here and then search them while processing arguments to find commented argument
        #  names.
        body = None
        for c in cursor.get_children():
            if c.kind == CursorKind.COMPOUND_STMT:
                body = c
                break

        prologue = []
        epilogue = []
        declarations = []
        implicit_arguments = []
        annotations = annotation_set()
        annotations.update(extract_attr_annotations(cursor))
        if body:
            annotations.update(extract_annotations(body))
            output_list = prologue
            for c in body.get_children():
                c_annotations = extract_annotations(c)
                c_attr_annotations = extract_attr_annotations(c)
                if "implicit_argument" in c_attr_annotations:
                    # FIXME: The [0] should be replaced with code to select the actual correct var decl
                    implicit_arguments.append(c.children[0])
                    continue
                if len(c_annotations) and list(c_annotations.keys()) != ["depends_on"]:
                    continue
                found_variables = False
                # NOTE(review): ``is_declaration`` is not called here. If it is a method (as in
                # clang.cindex) this test is always true, so every statement is searched for
                # VAR_DECL descendants; calling it would be false for DECL_STMT and stop
                # declarations from being found -- confirm before changing.
                if c.kind.is_declaration:
                    for cc in c.find_descendants(lambda cc: cc.kind == CursorKind.VAR_DECL):
                        if not cc.displayname.startswith(NIGHTWATCH_PREFIX) and cc.displayname != "ret":
                            parse_expects(
                                len(cc.children) == 0,
                                "Declarations in prologue and epilogue code may not be initialized. "
                                "(This is currently not checked fully.)",
                            )
                            declarations.append(convert_argument(-2, cc, annotation_set()))
                            found_variables = True
                if list(c.find_descendants(lambda cc: cc.displayname == "ava_execute")):
                    parse_requires(
                        c.kind != CursorKind.DECL_STMT or c.children[0].displayname == "ret",
                        "The result of ava_execute() must be named 'ret'.",
                    )
                    # Everything after the ava_execute statement belongs to the epilogue.
                    output_list = epilogue
                elif not found_variables:
                    src = c.source
                    output_list.append(src + ("" if src.endswith(";") else ";"))

        apply_rules(cursor, annotations, name=cursor.mangled_name)
        args = []
        for i, arg in enumerate(list(cursor.get_arguments()) + implicit_arguments):
            args.append(convert_argument(i, arg, annotations.subelement(arg.displayname)))

        # Collect resource-consumption annotations (names prefixed with consumes_amount_prefix).
        resources = {}
        for annotation_name, annotation_value in annotations.direct().flatten().items():
            if annotation_name.startswith(consumes_amount_prefix):
                resource = strip_prefix(consumes_amount_prefix, annotation_name)
                resources[resource] = annotation_value

        return_value = convert_argument(
            -1, cursor, annotations.subelement("return_value"), is_ret=True, type_=cursor.result_type)

        if "unsupported" in annotations:
            supported = not bool(annotations["unsupported"])

        disable_native = False
        if "disable_native" in annotations:
            disable_native = bool(annotations["disable_native"])

        return Function(
            cursor.mangled_name,
            return_value,
            args,
            location=convert_location(cursor.location),
            logue_declarations=declarations,
            prologue=prologue,
            epilogue=epilogue,
            consumes_resources=resources,
            supported=supported,
            disable_native=disable_native,
            type=convert_type(cursor.type, cursor.mangled_name, annotation_set(), set()),
            **annotations.direct(function_annotations).flatten(),
        )
def convert_type(tpe, name, annotations, containing_types):
    """Convert a clang type into a nightwatch ``Type`` (or subclass), applying annotations.

    :param tpe: The clang type to convert.
    :param name: Name/expression of the value of this type; rules are applied with it, and element
        (``name[index]``) and field (``name.field``) names are derived from it.
    :param annotations: Annotation set for this value. It is passed to ``apply_rules``, and a
        ``type_cast`` entry is popped from it when present.
    :param containing_types: Canonical spellings of the types being converted further up the
        stack; used to reject recursive types. It is not mutated; a copy is extended.
    :return: A ``Type``, ``FunctionPointer``, ``StaticArray`` or ``ConditionalType``.
    """
    parse_requires(
        tpe.get_canonical().spelling not in containing_types or "void" in tpe.get_canonical().spelling,
        "Recursive types don't work.",
    )
    original_containing_types = containing_types
    containing_types = copy.copy(original_containing_types)
    containing_types.add(tpe.get_canonical().spelling)
    parse_assert(tpe.spelling, "Element requires valid and complete type.")
    apply_rules(tpe, annotations, name=name)
    with location(f"in type {term.yellow(tpe.spelling)}"):
        allocates_resources, deallocates_resources = {}, {}
        for annotation_name, annotation_value in annotations.direct().flatten().items():
            if annotation_name.startswith(allocates_amount_prefix):
                resource = strip_prefix(allocates_amount_prefix, annotation_name)
                allocates_resources[resource] = annotation_value
            elif annotation_name.startswith(deallocates_amount_prefix):
                resource = strip_prefix(deallocates_amount_prefix, annotation_name)
                deallocates_resources[resource] = annotation_value

        parse_expects(
            allocates_resources.keys().isdisjoint(deallocates_resources.keys()),
            "The same argument is allocating and deallocating the same resource.",
        )

        our_annotations = annotations.direct(type_annotations).flatten()
        our_annotations.update(allocates_resources=allocates_resources, deallocates_resources=deallocates_resources)

        if annotations["type_cast"]:
            new_type = annotations["type_cast"]
            # annotations = copy.copy(annotations)
            annotations.pop("type_cast")
            if isinstance(new_type, Conditional):
                # A branch without an explicit cast keeps the original type. Re-converting the
                # original type is not a recursive reference, so (like ``original_type`` below) it
                # uses the containing set from before this type was added; otherwise the recursion
                # check at the top of this function rejects it unless the spelling contains "void".
                def convert_branch(branch):
                    if branch:
                        return convert_type(branch, name, annotations, containing_types)
                    return convert_type(tpe, name, annotations, original_containing_types)

                ret = ConditionalType(
                    new_type.predicate,
                    convert_branch(new_type.then_branch),
                    convert_branch(new_type.else_branch),
                    convert_type(tpe, name, annotations, original_containing_types),
                )
                return ret

            parse_assert(new_type is not None, "ava_type_cast must provide a new type")
            # Attach the original type and then perform conversion using the new type.
            our_annotations["original_type"] = convert_type(tpe, name, annotation_set(), original_containing_types)
            tpe = new_type

        if tpe.is_function_pointer():
            pointee = tpe.get_pointee()
            if pointee.kind == TypeKind.FUNCTIONNOPROTO:
                args = []
            else:
                args = [convert_type(t, "", annotation_set(), containing_types) for t in pointee.argument_types()]
            return FunctionPointer(
                tpe.spelling,
                Type(f"*{name}", **our_annotations),
                return_type=convert_type(pointee.get_result(), "ret", annotation_set(), containing_types),
                argument_types=args,
                **our_annotations,
            )

        if tpe.kind in (TypeKind.FUNCTIONPROTO, TypeKind.FUNCTIONNOPROTO):
            if tpe.kind == TypeKind.FUNCTIONNOPROTO:
                args = []
            else:
                args = [convert_type(t, "", annotation_set(), containing_types) for t in tpe.argument_types()]
            return FunctionPointer(
                tpe.spelling,
                Type(tpe.spelling, **our_annotations),
                return_type=convert_type(tpe.get_result(), "ret", annotation_set(), containing_types),
                argument_types=args,
                **our_annotations,
            )

        if tpe.is_static_array():
            pointee = tpe.get_pointee()
            pointee_annotations = annotations.subelement("element")
            pointee_name = f"{name}[{buffer_index_spelling}]"
            # The buffer size of a static array is its declared length.
            our_annotations["buffer"] = Expr(tpe.get_array_size())
            return StaticArray(
                tpe.spelling,
                pointee=convert_type(pointee, pointee_name, pointee_annotations, containing_types),
                **our_annotations,
            )

        if tpe.is_pointer():
            pointee = tpe.get_pointee()
            pointee_annotations = annotations.subelement("element")
            pointee_name = f"{name}[{buffer_index_spelling}]"
            if tpe.kind in (TypeKind.VARIABLEARRAY, TypeKind.INCOMPLETEARRAY):
                # Spell unsized arrays as pointers.
                sp: str = tpe.spelling
                sp = sp.replace("[]", "*")
                return Type(
                    sp,
                    pointee=convert_type(tpe.element_type, pointee_name, pointee_annotations, containing_types),
                    **our_annotations,
                )
            return Type(
                tpe.spelling,
                pointee=convert_type(pointee, pointee_name, pointee_annotations, containing_types),
                **our_annotations,
            )

        if tpe.get_canonical().kind == TypeKind.RECORD:

            def expand_field(f: Cursor, prefix):
                # Return [(field name, converted type)] for field ``f`` of the value named ``prefix``.
                f_tpe = f.type
                decl = f_tpe.get_declaration()
                if decl.is_anonymous():
                    if decl.kind == CursorKind.UNION_DECL:
                        # Members of an anonymous union are accessed directly on the enclosing value
                        # (``prefix.member``), so recurse with ``prefix`` unchanged. Use the largest
                        # member so that it covers the union's storage.
                        # FIXME: Only that one member (and its annotations) is transferred.
                        largest_field = max(f_tpe.get_fields(), key=lambda member: member.type.get_size())
                        return expand_field(largest_field, prefix)
                    parse_requires(False, "The only supported anonymous member type is unions.")
                return [
                    (
                        f.displayname,
                        convert_type(
                            f.type,
                            f"{prefix}.{f.displayname}",
                            annotations.subelement(Field(f.displayname)),
                            containing_types,
                        ),
                    )
                ]

            field_types = dict(ff for field in tpe.get_canonical().get_fields() for ff in expand_field(field, name))
            return Type(tpe.spelling, fields=field_types, **our_annotations)

        return Type(tpe.spelling, **our_annotations)
def for_all_elements(
    values: tuple,
    cast_type: Type,
    type: Type,
    *,
    depth: int,
    kernel,
    name: str,
    self_index: int,
    precomputed_size=None,
    original_type=None,
    **extra,
):
    """Generate C code that applies ``kernel`` to every element (pointer types) or field (structs).

    ``kernel(values, cast_type, type, depth=..., name=..., kernel=..., self_index=..., **extra)``
    is called with one dereferenced iterator expression per entry of ``values``. It returns
    the code for a single element or field.

    For a type with a pointee:

    - emits ``const size_t __{name}_size_{depth}`` holding ``precomputed_size``, or
      ``compute_buffer_size(type, original_type)`` when none is given;
    - emits a ``for`` loop over ``__{name}_index_{depth}`` that binds ``ava_index``, or a
      straight-line block when the size is the constant 1;
    - treats a ``void`` pointee as ``_char_type_like(pointee)``.

    For a struct type, it binds ``ava_self`` to a pointer to ``values[self_index]`` and emits
    the code for each field in ``_sort_fields`` order.

    ``original_type`` is consumed here and is not forwarded to ``kernel``.

    :return: The generated code, or ``""`` when the kernel produced nothing.
    :raises ValueError: If ``type`` has neither a pointee nor fields.
    """
    size = f"__{name}_size_{depth}"
    index = f"__{name}_index_{depth}"
    with location(f"in type {term.yellow(type.spelling)}"):
        if hasattr(type, "pointee") and type.pointee:
            loop = ""
            size_expr = Expr(precomputed_size or compute_buffer_size(type, original_type))
            eval_size = f"const size_t {size} = {size_expr};"
            # One iterator variable per value, e.g. __x_a_0, __x_b_0, ...
            inner_values = tuple(f"__{name}_{_letters[i]}_{depth}" for i in range(len(values)))
            type_pointee = _char_type_like(type.pointee) if type.pointee.is_void else type.pointee
            nested = kernel(
                tuple("*" + v for v in inner_values),
                type_pointee.nonconst,
                type_pointee,
                depth=depth + 1,
                name=name,
                kernel=kernel,
                self_index=self_index,
                **extra,
            )
            if nested:
                # Point each iterator at element ``index`` of the corresponding buffer.
                set_inner_values = lines(f"""
                    {type_pointee.nonconst.attach_to(iv, additional_inner_type_elements="*")};
                    {iv} = {type_pointee.nonconst.cast_type(type.ascribe_type(v), "*")} + {index};
                    """ for v, iv in zip(values, inner_values))
                if size_expr.is_constant(1):
                    loop = f"""
                    const size_t {index} = 0;
                    const size_t ava_index = 0;
                    {set_inner_values}
                    {nested}
                    """
                else:
                    loop = f"""
                    for(size_t {index} = 0; {index} < {size}; {index}++) {{
                        const size_t ava_index = {index};
                        {set_inner_values}
                        {nested}
                    }}
                    """.strip()
            if nested:
                return eval_size + loop
            else:
                return ""
        elif type.fields:
            # ava_self points at the struct value (presumably so annotations can refer to sibling fields).
            prefix = f"""
            {type.nonconst.attach_to("ava_self", additional_inner_type_elements="*")};
            ava_self = {type.nonconst.cast_type(type.ascribe_type(f"&{values[self_index]}", "*"), "*")};
            """
            field_infos = []
            for field_name, field in type.fields.items():
                inner_values = tuple(f"__{name}_{_letters[i]}_{depth}_{field_name}" for i in range(len(values)))
                nested = kernel(
                    tuple("*" + v for v in inner_values),
                    field.nonconst,
                    field,
                    depth=depth + 1,
                    name=name,
                    kernel=kernel,
                    self_index=self_index,
                    **extra,
                )
                inner_code = ""
                if str(nested).strip():
                    # Point each iterator at this field of the corresponding struct value.
                    set_inner_values = lines(f"""
                        {field.nonconst.attach_to(iv, additional_inner_type_elements="*")};
                        {iv} = {field.nonconst.cast_type(field.ascribe_type(f"&({v}).{field_name}", "*"), "*")};
                        """.strip() for v, iv in zip(values, inner_values))
                    inner_code = set_inner_values + str(nested)
                field_infos.append((field_name, nested, inner_code))
            code = "\n".join(inner_code for _, _, inner_code in _sort_fields(field_infos))
            if code.strip():
                return "{" + prefix + code + "}"
            else:
                return ""
        else:
            raise ValueError("Type must be a buffer of some kind.")
def assign_original_handle_for_argument(arg: Argument, original: str):
    """Generate code linking handles in ``arg``'s local value to the originals stored in ``original``.

    - For ``NW_HANDLE`` values that do not deallocate, emits
      ``nw_handle_pool_assign_handle(handle_pool, original, local)``; deallocating handles
      emit nothing.
    - For non-handle values that are the return value or ``NW_OPAQUE`` (and not function
      pointers), emits ``assert(original == local)``.
    - Non-NULL, non-empty ``NW_BUFFER`` values are walked element by element.

    Nothing is emitted unless ``arg`` is an output or the return value.

    :param arg: The argument (or return value); its local value is the variable ``arg.name``.
    :param original: C pointer expression of the struct holding the original values (accessed as
        ``original->param_spelling``).
    :return: The generated code in an ``Assign or check:`` comment block, conditional on
        ``arg.output | arg.ret``.
    """

    def convert_result_value(values, cast_type: Type, type: Type, depth, original_type=None, **other) -> str:
        if isinstance(type, ConditionalType):
            # Each branch must be converted with its own type as the cast type.
            return Expr(type.predicate).if_then_else(
                convert_result_value(
                    values, type.then_type.nonconst, type.then_type, depth, original_type=type.original_type, **other
                ),
                convert_result_value(
                    values, type.else_type.nonconst, type.else_type, depth, original_type=type.original_type, **other
                ),
            )
        if type.is_void:
            return """abort_with_reason("Reached code to handle void value.");"""

        # values = (expression in the ``original`` struct, local value expression)
        original_value, local_value = values
        buffer_pred = (
            Expr(type.transfer).equals("NW_BUFFER") & Expr(local_value).not_equals("NULL") & (Expr(type.buffer) > 0)
        )

        def simple_buffer_case():
            return ""

        def buffer_case():
            if not hasattr(type, "pointee"):
                return """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""
            tmp_name = f"__tmp_{arg.name}_{depth}"
            inner_values = (tmp_name, local_value)
            return buffer_pred.if_then_else(f"""
                {type.nonconst.attach_to(tmp_name)};
                {get_buffer(tmp_name, cast_type, original_value, type, original_type=original_type)}
                {for_all_elements(inner_values, cast_type, type, depth=depth, original_type=original_type, **other)}
                """)

        def default_case():
            return Expr(type.transfer).equals("NW_HANDLE").if_then_else(
                (~Expr(type.deallocates)).if_then_else(
                    f"nw_handle_pool_assign_handle(handle_pool, (void*){original_value}, (void*){local_value});"
                ),
                (
                    (Expr(arg.ret) | Expr(type.transfer).equals("NW_OPAQUE"))
                    & Expr(not isinstance(type, FunctionPointer))
                ).if_then_else(f"assert({original_value} == {local_value});"),
            )

        if type.fields:
            return for_all_elements(values, cast_type, type, depth=depth, original_type=original_type, **other)
        return (
            type.is_simple_buffer()
            .if_then_else(
                simple_buffer_case,
                Expr(type.transfer).equals("NW_BUFFER").if_then_else(buffer_case, default_case),
            )
            .scope()
        )

    with location(f"at {term.yellow(str(arg.name))}", arg.location):
        conv = convert_result_value(
            (f"{original}->{arg.param_spelling}", f"{arg.name}"),
            arg.type.nonconst,
            arg.type,
            depth=0,
            name=arg.name,
            kernel=convert_result_value,
            self_index=1,
        )
        return (Expr(arg.output) | arg.ret).if_then_else(comment_block(f"Assign or check: {arg}", conv))
def convert_result_for_argument(arg: Argument, dest) -> ExprOrStr:
    """
    Generate C statements that store the local value of ``arg`` into the RET command ``dest``.

    Buffers are attached to the command via ``attach_buffer`` (with ``expect_reply=False``). Output
    buffers that need element-wise conversion are first copied into a temporary buffer allocated
    through the function's ``AllocList``. ``NW_HANDLE`` values are translated with
    ``nw_handle_pool_lookup_or_insert``, or set to NULL when the type deallocates.

    :param arg: The argument (or return value) to place in the output.
    :param dest: C expression naming the RET command struct pointer.
    :return: C statements, emitted only when ``arg`` is an output or the return value.
    """
    alloc_list = AllocList(arg.function)

    def emit_value(values, type: Type, depth, original_type=None, **other) -> str:
        if isinstance(type, ConditionalType):
            predicate = Expr(type.predicate)
            return predicate.if_then_else(
                emit_value(values, type.then_type, depth, original_type=type.original_type, **other),
                emit_value(values, type.else_type, depth, original_type=type.original_type, **other))

        if type.is_void:
            return """abort_with_reason("Reached code to handle void value.");"""

        param_value, local_value = values
        non_pointer_error = """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""
        null_assignment = f"{param_value} = NULL;"

        def attach_data(data):
            return attach_buffer(param_value, local_value, data, type, arg.output, cmd=dest,
                                 original_type=original_type, expect_reply=False)

        def simple_buffer_case():
            if not hasattr(type, "pointee"):
                return non_pointer_error
            has_data = Expr(local_value).not_equals("NULL") & (Expr(type.buffer) > 0)
            return has_data.if_then_else(attach_data(local_value), null_assignment)

        def buffer_case():
            if not hasattr(type, "pointee"):
                return non_pointer_error
            if not arg.output:
                # Input-only buffers need no element conversion on the way back.
                return simple_buffer_case()
            tmp_name = f"__tmp_{arg.name}_{depth}"
            size_name = f"__size_{arg.name}_{depth}"
            not_null = Expr(local_value).not_equals("NULL")
            return not_null.if_then_else(
                f"""{{
                {allocate_tmp_buffer(tmp_name, size_name, type, alloc_list=alloc_list, original_type=original_type)}
                {for_all_elements((tmp_name, local_value), type, precomputed_size=size_name, depth=depth, original_type=original_type, **other)}
                {attach_data(tmp_name)}
                }}""",
                null_assignment)

        def default_case():
            lookup_fn = "nw_handle_pool_lookup_or_insert"
            is_handle = Expr(type.transfer).equals("NW_HANDLE")
            handle_code = Expr(type.deallocates).if_then_else(
                f"{param_value} = NULL;",
                f"{param_value} = ({type.nonconst.spelling}){lookup_fn}(handle_pool, (void*){local_value});")
            plain_code = Expr(not type.is_void).if_then_else(f"{param_value} = {local_value};")
            return is_handle.if_then_else(handle_code, plain_code)

        if type.fields:
            return for_all_elements(values, type, depth=depth, original_type=original_type, **other)

        is_simple = type.is_simple_buffer()
        is_buffer = Expr(type.transfer).equals("NW_BUFFER")
        return is_simple.if_then_else(
            simple_buffer_case,
            is_buffer.if_then_else(buffer_case, default_case),
        ).scope()

    with location(f"at {term.yellow(str(arg.name))}", arg.location):
        conv = emit_value(
            (f"{dest}->{arg.param_spelling}", f"{arg.name}"),
            arg.type,
            depth=0,
            name=arg.name,
            kernel=emit_value,
            self_index=1,
        )
        return (Expr(arg.output) | arg.ret).if_then_else(comment_block(f"Output: {arg}", conv))
def parse(filename: str, include_path: List[str], definitions: List[Any], extra_args: List[Any]) -> API:
    """
    Parse an AvA specification file with libclang and build the ``API`` description from it.

    The spec is parsed as C or C++ (C++ when ``filename`` ends with "cpp", "C" or "cc"), with the
    nightwatch parser header force-included via ``-include``. Every top-level cursor is then dispatched
    by ``convert_decl`` to collect forwarded functions, annotation rules, global configuration values,
    and the line extents of utility, replacement and type code. Finally the primary include files are
    re-emitted as a types header in which lines belonging to excluded extents are commented out.

    :param filename: Path of the specification file.
    :param include_path: Directories passed to clang as ``-I``.
    :param definitions: Macro definitions passed to clang as ``-D``.
    :param extra_args: Additional clang command-line arguments.
    :return: The assembled ``API``.
    :raises MultipleError: If any errors were collected while converting declarations.
    """
    index = Index.create(True)
    includes = [s for p in include_path for s in ["-I", p]]
    definitions = [s for d in definitions for s in ["-D", d]]

    cplusplus = filename.endswith(("cpp", "C", "cc"))

    nightwatch_parser_c_header_fullname = str(resource_directory / nightwatch_parser_c_header)
    llvm_args = (includes + extra_args + clang_flags + definitions + [
        "-include",
        nightwatch_parser_c_header_fullname,
        f"-D__AVA_PREFIX={NIGHTWATCH_PREFIX}",
        "-x",
        "c++" if cplusplus else "c",
        filename,
    ])
    unit = index.parse(
        None,
        args=llvm_args,
        options=TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD)

    # Errors reported (and continued past) inside `location(..., report_continue=errors)` blocks.
    errors = []

    # Maps a clang diagnostic severity to (report kind, checking function).
    severity_table = {
        Diagnostic.Ignored: (None, parse_expects),
        Diagnostic.Note: (info, parse_expects),
        Diagnostic.Warning: (warning, parse_expects),
        Diagnostic.Error: (error, parse_requires),
        Diagnostic.Fatal: (error, parse_requires),
    }

    for d in unit.diagnostics:
        # Suppress diagnostics that mention the placeholder type `struct __ava_unknown`.
        if (d.spelling == "incomplete definition of type 'struct __ava_unknown'"
                or d.spelling.startswith("incompatible pointer") and d.spelling.endswith(
                    "with an expression of type 'struct __ava_unknown *'")):
            continue
        with location("Clang Parser", report_continue=errors):
            kind, func = severity_table[d.severity]
            # Ignored diagnostics have kind None, so `not kind` is True and the check passes;
            # every other severity fails the check and is reported with its kind.
            func(not kind, d.spelling, loc=convert_location(d.location), kind=kind)

    primary_include_files: Dict[str, File] = {}
    # Entries are (file, start line, end line, included_extent flag) for cursors in primary includes.
    primary_include_extents = []
    # (start line, end line) ranges of the spec file that are copied out verbatim by load_extents.
    utility_extents = []
    replacement_extents = []
    type_extents = []

    global_config = {}
    functions: Dict[str, Function] = {}
    include_functions = {}
    replaced_functions = {}
    metadata_type = None

    # Rule lists filled from ava_category_* declarations in convert_decl. final_rules is not
    # appended to anywhere in this function.
    rules = []
    default_rules = []
    final_rules = []

    def apply_rules(c, annotations, *, name=None):
        """
        Apply the collected rules to cursor/type ``c``, mutating ``annotations`` in place.

        ``default_rules`` only run when, after ``rules``, the annotations are empty apart from the
        temporary "name" entry. ``final_rules`` always run. "name" is removed again afterwards.
        """
        if name:
            annotations["name"] = name

        def do(rules):
            for rule in rules:
                rule.apply(c, annotations)

        do(rules)
        if not annotations or (len(annotations) == 1 and "name" in annotations):
            do(default_rules)
        do(final_rules)
        if name:
            del annotations["name"]

    # pylint: disable=too-many-return-statements
    def convert_type(tpe, name, annotations, containing_types):
        """
        Recursively convert clang type ``tpe`` into a nightwatch ``Type`` (or subclass).

        ``containing_types`` holds the canonical spellings of the enclosing types and is used to reject
        recursive types (spellings containing "void" are exempt). An ``ava_type_cast`` annotation
        replaces the type to convert, and the original type is kept as ``original_type``.
        """
        parse_requires(
            tpe.get_canonical().spelling not in containing_types or "void" in tpe.get_canonical().spelling,
            "Recursive types don't work.",
        )
        original_containing_types = containing_types
        containing_types = copy.copy(original_containing_types)
        containing_types.add(tpe.get_canonical().spelling)
        parse_assert(tpe.spelling, "Element requires valid and complete type.")
        apply_rules(tpe, annotations, name=name)
        with location(f"in type {term.yellow(tpe.spelling)}"):
            # Split resource annotations into per-resource allocate/deallocate amounts.
            allocates_resources, deallocates_resources = {}, {}
            for annotation_name, annotation_value in annotations.direct().flatten().items():
                if annotation_name.startswith(allocates_amount_prefix):
                    resource = strip_prefix(allocates_amount_prefix, annotation_name)
                    allocates_resources[resource] = annotation_value
                elif annotation_name.startswith(deallocates_amount_prefix):
                    resource = strip_prefix(deallocates_amount_prefix, annotation_name)
                    deallocates_resources[resource] = annotation_value

            parse_expects(
                allocates_resources.keys().isdisjoint(deallocates_resources.keys()),
                "The same argument is allocating and deallocating the same resource.",
            )

            our_annotations = annotations.direct(type_annotations).flatten()
            our_annotations.update(allocates_resources=allocates_resources,
                                   deallocates_resources=deallocates_resources)

            if annotations["type_cast"]:
                new_type = annotations["type_cast"]
                # annotations = copy.copy(annotations)
                annotations.pop("type_cast")
                if isinstance(new_type, Conditional):
                    # Conditional cast: convert both branches (falling back to the original type when a
                    # branch is empty) plus the original type itself.
                    ret = ConditionalType(
                        new_type.predicate,
                        convert_type(new_type.then_branch or tpe, name, annotations, containing_types),
                        convert_type(new_type.else_branch or tpe, name, annotations, containing_types),
                        convert_type(tpe, name, annotations, containing_types),
                    )
                    return ret

                parse_assert(new_type is not None, "ava_type_cast must provide a new type")
                # Attach the original type and then perform conversion using the new type.
                our_annotations["original_type"] = convert_type(
                    tpe, name, annotation_set(), original_containing_types)
                tpe = new_type

            if tpe.is_function_pointer():
                pointee = tpe.get_pointee()
                if pointee.kind == TypeKind.FUNCTIONNOPROTO:
                    args = []
                else:
                    args = [
                        convert_type(t, "", annotation_set(), containing_types)
                        for t in pointee.argument_types()
                    ]
                return FunctionPointer(
                    tpe.spelling,
                    Type(f"*{name}", **our_annotations),
                    return_type=convert_type(pointee.get_result(), "ret", annotation_set(), containing_types),
                    argument_types=args,
                    **our_annotations,
                )

            # Bare function types are represented with FunctionPointer as well.
            if tpe.kind in (TypeKind.FUNCTIONPROTO, TypeKind.FUNCTIONNOPROTO):
                if tpe.kind == TypeKind.FUNCTIONNOPROTO:
                    args = []
                else:
                    args = [
                        convert_type(t, "", annotation_set(), containing_types)
                        for t in tpe.argument_types()
                    ]
                return FunctionPointer(
                    tpe.spelling,
                    Type(tpe.spelling, **our_annotations),
                    return_type=convert_type(tpe.get_result(), "ret", annotation_set(), containing_types),
                    argument_types=args,
                    **our_annotations,
                )

            if tpe.is_static_array():
                pointee = tpe.get_pointee()
                pointee_annotations = annotations.subelement("element")
                pointee_name = f"{name}[{buffer_index_spelling}]"
                # A static array's buffer size is its declared array size.
                our_annotations["buffer"] = Expr(tpe.get_array_size())
                return StaticArray(
                    tpe.spelling,
                    pointee=convert_type(pointee, pointee_name, pointee_annotations, containing_types),
                    **our_annotations,
                )

            if tpe.is_pointer():
                pointee = tpe.get_pointee()
                pointee_annotations = annotations.subelement("element")
                pointee_name = f"{name}[{buffer_index_spelling}]"
                if tpe.kind in (TypeKind.VARIABLEARRAY, TypeKind.INCOMPLETEARRAY):
                    # Unsized arrays are spelled as pointers ("T[]" -> "T*").
                    sp: str = tpe.spelling
                    sp = sp.replace("[]", "*")
                    return Type(
                        sp,
                        pointee=convert_type(tpe.element_type, pointee_name, pointee_annotations, containing_types),
                        **our_annotations,
                    )
                return Type(
                    tpe.spelling,
                    pointee=convert_type(pointee, pointee_name, pointee_annotations, containing_types),
                    **our_annotations,
                )

            if tpe.get_canonical().kind == TypeKind.RECORD:

                def expand_field(f: Cursor, prefix):
                    """Return [(field name, converted type)] for field ``f``; anonymous unions are
                    replaced by one selected member."""
                    f_tpe = f.type
                    decl = f_tpe.get_declaration()
                    if decl.is_anonymous():
                        if decl.kind == CursorKind.UNION_DECL:
                            # FIXME: This assumes the first field is as large or larger than any other field.
                            # NOTE(review): sorted() is ascending, so [0] is the *smallest* member by size,
                            # which contradicts the FIXME above; confirm whether the largest was intended.
                            # NOTE(review): the prefix passed below already ends in first_field.displayname,
                            # and the recursive call appends it again; confirm the intended access path.
                            first_field = sorted(
                                f_tpe.get_fields(), key=lambda f: f.type.get_size())[0]
                            return expand_field(
                                first_field, f"{prefix}.{first_field.displayname}")
                        parse_requires(
                            False, "The only supported anonymous member type is unions."
                        )
                    return [(
                        f.displayname,
                        convert_type(
                            f.type,
                            f"{prefix}.{f.displayname}",
                            annotations.subelement(Field(f.displayname)),
                            containing_types,
                        ),
                    )]

                field_types = dict(
                    ff for field in tpe.get_canonical().get_fields()
                    for ff in expand_field(field, name))
                return Type(tpe.spelling, fields=field_types, **our_annotations)

            return Type(tpe.spelling, **our_annotations)

    def convert_argument(i, arg, annotations, *, type_=None, is_ret=False):
        """
        Convert cursor ``arg`` into an ``Argument``.

        ``i`` is used only to synthesize the name "__arg{i}" for unnamed arguments. When ``is_ret``
        is set the argument is named RET_ARGUMENT_NAME and has no value; otherwise the value is the
        source of the (at most one) expression child. ``type_`` overrides ``arg.type``.
        """
        name = arg.displayname if not is_ret else RET_ARGUMENT_NAME
        if not name:
            name = "__arg{}".format(i)
        # An argument does not depend on itself.
        annotations["depends_on"].discard(name)
        apply_rules(arg, annotations, name=name)
        with location(f"argument {term.yellow(name)}", convert_location(arg.location)):
            if not is_ret:
                expressions = list(
                    arg.find_descendants(lambda c: c.kind.is_expression()))
                parse_assert(
                    len(expressions) <= 1,
                    "There must only be one expression child in argument declarations."
                )
                value = expressions[0].source if expressions else None
            else:
                value = None

            type_ = type_ or arg.type

            return Argument(
                name,
                convert_type(type_, name, annotations, set()),
                value=value,
                location=convert_location(arg.location),
                **annotations.direct(argument_annotations).flatten(),
            )

    def convert_function(cursor, supported=True):
        """
        Convert a function cursor from the spec (or a primary include) into a ``Function``.

        The optional body is split around the ``ava_execute()`` call: statements before it form the
        prologue, statements after it the epilogue. Local variable declarations become
        ``logue_declarations``, and statements with ``implicit_argument`` attributes add arguments.
        Errors are collected into ``errors`` via ``report_continue``.
        """
        with location(f"at {term.yellow(cursor.displayname)}", convert_location(cursor.location),
                      report_continue=errors):
            # TODO: Capture tokens here and then search them while processing arguments to find commented argument
            #  names.

            body = None
            for c in cursor.get_children():
                if c.kind == CursorKind.COMPOUND_STMT:
                    body = c
                    break

            prologue = []
            epilogue = []
            declarations = []
            implicit_arguments = []
            annotations = annotation_set()
            annotations.update(extract_attr_annotations(cursor))
            if body:
                annotations.update(extract_annotations(body))
                # Statements go to the prologue until ava_execute() is seen, then to the epilogue.
                output_list = prologue
                for c in body.get_children():
                    c_annotations = extract_annotations(c)
                    c_attr_annotations = extract_attr_annotations(c)
                    if "implicit_argument" in c_attr_annotations:
                        # FIXME: The [0] should be replaced with code to select the actual correct var decl
                        implicit_arguments.append(c.children[0])
                        continue
                    # Skip annotation statements, except those carrying only depends_on.
                    if len(c_annotations) and list(
                            c_annotations.keys()) != ["depends_on"]:
                        continue
                    found_variables = False
                    # NOTE(review): `is_declaration` is not called here, so this condition is always
                    # truthy. Do not simply add "()": DECL_STMT is a statement kind, so calling it would
                    # skip the local declarations this loop collects. Confirm intent before changing.
                    if c.kind.is_declaration:
                        for cc in c.find_descendants(
                                lambda cc: cc.kind == CursorKind.VAR_DECL):
                            if not cc.displayname.startswith(
                                    NIGHTWATCH_PREFIX
                            ) and cc.displayname != "ret":
                                parse_expects(
                                    len(cc.children) == 0,
                                    "Declarations in prologue and epilogue code may not be initialized. "
                                    "(This is currently not checked fully.)",
                                )
                                declarations.append(
                                    convert_argument(-2, cc, annotation_set()))
                                found_variables = True
                    if list(
                            c.find_descendants(
                                lambda cc: cc.displayname == "ava_execute")):
                        parse_requires(
                            c.kind != CursorKind.DECL_STMT or c.children[0].displayname == "ret",
                            "The result of ava_execute() must be named 'ret'.",
                        )
                        output_list = epilogue
                    elif not found_variables:
                        # Copy the statement source, ensuring it is terminated with ';'.
                        src = c.source
                        output_list.append(src + ("" if src.endswith(";") else ";"))

            apply_rules(cursor, annotations, name=cursor.mangled_name)

            args = []
            for i, arg in enumerate(
                    list(cursor.get_arguments()) + implicit_arguments):
                args.append(
                    convert_argument(i, arg, annotations.subelement(arg.displayname)))

            resources = {}
            for annotation_name, annotation_value in annotations.direct(
            ).flatten().items():
                if annotation_name.startswith(consumes_amount_prefix):
                    resource = strip_prefix(consumes_amount_prefix, annotation_name)
                    resources[resource] = annotation_value

            return_value = convert_argument(
                -1,
                cursor,
                annotations.subelement("return_value"),
                is_ret=True,
                type_=cursor.result_type)

            if "unsupported" in annotations:
                supported = not bool(annotations["unsupported"])

            disable_native = False
            if "disable_native" in annotations:
                disable_native = bool(annotations["disable_native"])

            return Function(
                cursor.mangled_name,
                return_value,
                args,
                location=convert_location(cursor.location),
                logue_declarations=declarations,
                prologue=prologue,
                epilogue=epilogue,
                consumes_resources=resources,
                supported=supported,
                disable_native=disable_native,
                type=convert_type(cursor.type, cursor.mangled_name, annotation_set(), set()),
                **annotations.direct(function_annotations).flatten(),
            )

    # Mode state toggled by the begin/end utility and replacement marker declarations.
    utility_mode = False
    utility_mode_start = None
    replacement_mode = False
    replacement_mode_start = None

    def convert_decl(c: Cursor):
        """
        Dispatch one top-level cursor of the translation unit.

        Updates the enclosing collections (functions, rules, extents, configuration) and mode state.
        Unless a branch returns early, cursors located in a primary include file also record their
        extent in ``primary_include_extents``.
        """
        nonlocal utility_mode, utility_mode_start, replacement_mode, replacement_mode_start, metadata_type
        assert not (replacement_mode and utility_mode)

        if c.kind in ignored_cursor_kinds:
            return

        normal_mode = not replacement_mode and not utility_mode
        # not (c.kind == CursorKind.VAR_DECL and c.displayname.startswith(
        #     NIGHTWATCH_PREFIX)) and (utility_mode or replacement_mode):
        # Whether this cursor's extent should be kept in the generated types header.
        included_extent = True
        if (normal_mode and c.kind == CursorKind.FUNCTION_DECL and c.location.file.name == filename
                and c.spelling == "ava_metadata"):
            metadata_type = convert_type(c.result_type.get_pointee(), "ava_metadata", annotation_set(), set())
        elif (normal_mode and c.kind == CursorKind.FUNCTION_DECL
              and c.displayname.startswith(NIGHTWATCH_PREFIX + "category_")):
            # Rule declarations; "default"-attributed ones go into default_rules.
            name = strip_unique_suffix(
                strip_prefix(NIGHTWATCH_PREFIX + "category_", c.displayname))
            annotations = extract_annotations(c)
            attr_annotations = extract_attr_annotations(c)
            rule_list = default_rules if "default" in attr_annotations else rules
            annotations.pop("default", None)
            if name == "type":
                rule_list.append(
                    Types(c.result_type.get_pointee(), annotations))
            elif name == "functions":
                rule_list.append(Functions(annotations))
            elif name == "pointer_types":
                rule_list.append(PointerTypes(annotations))
            elif name == "const_pointer_types":
                rule_list.append(ConstPointerTypes(annotations))
            elif name == "nonconst_pointer_types":
                rule_list.append(NonconstPointerTypes(annotations))
            elif name == "non_transferable_types":
                rule_list.append(NonTransferableTypes(annotations))
        elif normal_mode and c.kind == CursorKind.VAR_DECL and c.storage_class == StorageClass.STATIC:
            # This is a utility function for the API forwarding code.
            parse_expects(
                c.linkage == LinkageKind.INTERNAL,
                f"at {term.yellow(c.displayname)}",
                "API utility functions should be static (or similar) since they are included in header files.",
                loc=convert_location(c.location),
            )
            utility_extents.append((c.extent.start.line, c.extent.end.line))
        elif c.kind == CursorKind.VAR_DECL and c.displayname.startswith(
                NIGHTWATCH_PREFIX):
            # Prefixed variables are mode markers or global configuration strings.
            name = strip_unique_suffix(strip_nw(c.displayname))
            if name == "begin_utility":
                parse_requires(
                    not utility_mode,
                    "ava_begin_utility can only be used outside utility mode to enter that mode."
                )
                utility_mode = True
                utility_mode_start = c.extent.start.line
            elif name == "end_utility":
                parse_requires(
                    utility_mode,
                    "ava_end_utility can only be used inside utility mode to exit that mode."
                )
                utility_mode = False
                parse_assert(utility_mode_start is not None, "Should be unreachable.")
                utility_extents.append((utility_mode_start, c.extent.end.line))
            elif name == "begin_replacement":
                parse_requires(
                    not replacement_mode,
                    "ava_begin_replacement can only be used outside replacement mode to enter that mode.",
                )
                replacement_mode = True
                replacement_mode_start = c.extent.start.line
            elif name == "end_replacement":
                parse_requires(
                    replacement_mode,
                    "ava_end_replacement can only be used inside replacement mode to exit that mode."
                )
                replacement_mode = False
                parse_assert(replacement_mode_start is not None, "Should be unreachable.")
                replacement_extents.append(
                    (replacement_mode_start, c.extent.end.line))
            else:
                global_config[name] = get_string_literal(c)
        elif (normal_mode and c.kind == CursorKind.VAR_DECL and c.type.spelling.endswith("_resource")
              and c.type.spelling.startswith("ava_")):
            # TODO: Use the resource declarations to check resource usage.
            pass
        elif c.kind == CursorKind.FUNCTION_DECL and c.location.file.name == filename:
            if normal_mode and c.is_definition(
            ) and c.storage_class == StorageClass.STATIC:
                # This is a utility function for the API forwarding code.
                parse_expects(
                    c.linkage == LinkageKind.INTERNAL,
                    f"at {term.yellow(c.displayname)}",
                    "API utility functions should be static (or similar) since they are included in header files.",
                    loc=convert_location(c.location),
                )
                utility_extents.append(
                    (c.extent.start.line, c.extent.end.line))
            elif normal_mode:
                # This is an API function.
                f = convert_function(c)
                if f:
                    functions[c.mangled_name] = f
            elif replacement_mode:
                # Remove the function from the list because it is replaced
                replaced_functions[c.mangled_name] = c
        elif (normal_mode and c.kind == CursorKind.FUNCTION_DECL
              and c.location.file.name in [f.name for f in primary_include_files.values()]):
            # A function declared in a primary include: converted as unsupported and excluded
            # from the generated types header.
            included_extent = False
            f = convert_function(c, supported=False)
            if f:
                include_functions[c.mangled_name] = f
        elif (normal_mode and c.kind == CursorKind.INCLUSION_DIRECTIVE
              and not c.displayname.endswith(nightwatch_parser_c_header)
              and c.location.file.name == filename):
            try:
                primary_include_files[c.displayname] = c.get_included_file()
            except AssertionError as e:
                parse_assert(not e, str(e), loc=convert_location(c.location))
        # elif normal_mode and c.kind == CursorKind.INCLUSION_DIRECTIVE and c.tokens[-1].spelling == '"' \
        #         and not c.displayname.endswith(nightwatch_parser_c_header):
        #     parse_assert(False, "Including AvA specifications in other specifications is not yet supported.")
        elif (normal_mode and c.kind in (CursorKind.MACRO_DEFINITION, CursorKind.STRUCT_DECL,
                                         CursorKind.TYPEDEF_DECL)
              and c.location.file and c.location.file.name == filename):
            # This is a utility macro for the API forwarding code.
            type_extents.append((c.extent.start.line, c.extent.end.line))
        elif (  # pylint: disable=too-many-boolean-expressions
                (normal_mode or replacement_mode) and c.kind in (CursorKind.UNEXPOSED_DECL, )
                and len(c.tokens) and c.tokens[0].spelling == "extern"
                and c.location.file in primary_include_files.values()):
            # `extern ...` blocks in a primary include: recurse into their children.
            for cc in c.get_children():
                convert_decl(cc)
            return  # Skip the extents processing below
        elif normal_mode:
            # Default case for normal mode.
            is_semicolon = len(c.tokens) == 1 and c.tokens[0].spelling == ";"
            if c.location.file and not is_semicolon:
                parse_expects(
                    c.location.file.name != filename,
                    f"Ignoring unsupported: {c.kind} {c.spelling}",
                    loc=convert_location(c.location),
                )
            # if len(c.tokens) >= 1 and c.tokens[0].spelling == "extern" and c.kind == CursorKind.UNEXPOSED_DECL:
            #     print(c.kind, c.tokens[0].spelling)
        else:
            # Default case for non-normal modes
            return  # Skip the extents processing below

        if c.location.file in primary_include_files.values():
            primary_include_extents.append(
                (c.location.file, c.extent.start.line, c.extent.end.line, included_extent))

    for c in unit.cursor.get_children():
        convert_decl(c)

    parse_expects(primary_include_files, "Expected at least one API include file.")
    extra_functions = {}

    if errors:
        raise MultipleError(*errors)

    # Functions found both in the spec and the includes are no longer "missing"; non-callback
    # spec functions absent from the includes are reported below.
    for name, function in functions.items():
        if name in include_functions:
            del include_functions[name]
        elif not function.callback_decl:
            extra_functions[name] = function

    for name, cursor in replaced_functions.items():
        if name in include_functions:
            del include_functions[name]
        else:
            parse_requires(
                name not in functions,
                "Replacing forwarded functions is not allowed.",
                loc=convert_location(cursor.location),
            )

    if extra_functions:
        function_str = ", ".join(str(f.name) for f in extra_functions.values())
        parse_expects(
            False,
            f"""
Functions appear in {filename}, but are not in {", ".join(primary_include_files.keys())}:
{function_str}""".strip(),
            loc=Location(filename, None, None, None),
        )

    # We use binary mode because clang operates in bytes not characters.
    # TODO: If the source files have "\r\n" and clang uses text mode then this will cause incorrect removals.
    # TODO: There could be functions in the header which are not processed with the current configuration. That will
    #  mess things up.
    c_types_header_code = bytearray()
    for name, file in primary_include_files.items():
        with open(file.name, "rb") as fi:
            # content = fi.read()
            # primary_include_extents.sort(key=lambda r: r(0).start.offset)
            def find_modes(i):
                """Return the set of included_extent flags of extents in `file` covering 1-based line i."""
                modes = set()
                for in_name, start, end, mode in primary_include_extents:  # pylint: disable=cell-var-from-loop
                    if in_name == file and start <= i <= end:
                        modes.add(mode)
                return modes

            error_reported = False
            i = None
            for i, line in enumerate(fi):
                modes = find_modes(i + 1)
                # print(i, modes, line)
                # Keep lines not covered by any extent or covered by an included extent.
                keep_line = True in modes or not modes
                error_line = keep_line and False in modes
                parse_expects(
                    not error_line or error_reported,
                    "Line both needed and excluded. Incorrect types header may be generated.",
                    # NOTE(review): i is the 0-based line index here (extents use i + 1); confirm
                    # whether Location expects 1-based line numbers.
                    loc=Location(file.name, i, None, None),
                )
                error_reported = error_reported or error_line
                if keep_line:
                    c_types_header_code.extend(line)
                else:
                    # Excluded lines are kept as C comments; nested comment markers are defused first.
                    c_types_header_code.extend(b"/* NWR: " + line.replace(b"/*", b"@*").replace(
                        b"*/", b"*@").rstrip() + b" */\n")

    def load_extents(extents):
        """
        Return the lines of ``filename`` that fall inside any (start, end) 1-based inclusive range in
        ``extents``, inserting a ``#line`` directive whenever output resumes after a gap.
        """
        with open(filename, "rb") as fi:

            def utility_line(i):
                for start, end in extents:
                    if start <= i <= end:
                        return True
                return False

            c_code = bytearray()
            last_emitted_line = None
            for i, line in enumerate(fi):
                if utility_line(i + 1):
                    if last_emitted_line != i - 1:
                        c_code.extend("""#line {} "{}"\n""".format(i + 1, filename).encode("utf-8"))
                    c_code.extend(line)
                    last_emitted_line = i
            return bytes(c_code).decode("utf-8")

    c_utility_code = load_extents(utility_extents)
    c_replacement_code = load_extents(replacement_extents)
    c_type_code = load_extents(type_extents)

    return API(
        functions=list(functions.values()) + list(include_functions.values()),
        includes=list(primary_include_files.keys()),
        c_types_header_code=bytes(c_types_header_code).decode("utf-8"),
        c_type_code=c_type_code,
        c_utility_code=c_utility_code,
        c_replacement_code=c_replacement_code,
        metadata_type=metadata_type,
        missing_functions=list(include_functions.values()),
        cplusplus=cplusplus,
        **global_config,
    )
def call_command_implementation(f: Function):
    """
    Generate the C ``case`` that handles the CALL command of ``f`` in the worker.

    The generated code:
    - unpacks and translates the arguments from ``__call``;
    - performs the call;
    - allocates a RET command sized by the output arguments;
    - writes the return value and buffer-containing arguments into it;
    - optionally records replay metadata;
    - sends the reply and releases temporaries.

    :param f: The function to generate the handler for.
    :return: The C source of the ``case`` block, stripped of surrounding whitespace.
    """
    with location(f"at {term.yellow(str(f.name))}", f.location):
        alloc_list = AllocList(f)

        # Not referenced in the template below, but visible to the exec'd spec snippets through
        # locals(); do not remove or rename locals in this function without checking those snippets.
        is_async = ~Expr(f.synchrony).equals("NW_SYNC")
        reply_code = f"""
            command_channel_send_command(__chan, (struct command_base*)__ret);
        """.strip()

        # The spec may override reply_code. Its text is ASCII-filtered, backslash escapes are
        # interpreted, and the first and last characters (presumably the surrounding quotes of a
        # string literal; verify) are dropped. The result is then executed with this function's
        # locals, and must assign `reply_code`.
        # NOTE(review): exec() runs spec-provided code; only use with trusted specifications.
        if (f.api.reply_code):
            import_code = f.api.reply_code.encode(
                'ascii', 'ignore').decode('unicode_escape')[1:-1]
            ldict = locals()
            exec(import_code, globals(), ldict)
            reply_code = ldict['reply_code']

        # Same mechanism for optional argument pre-processing code inserted before the call.
        worker_argument_process_code = ""
        if (f.api.worker_argument_process_code):
            import_code = f.api.worker_argument_process_code.encode(
                'ascii', 'ignore').decode('unicode_escape')[1:-1]
            ldict = locals()
            exec(import_code, globals(), ldict)
            worker_argument_process_code = ldict[
                'worker_argument_process_code']

        return f"""
        case {f.call_id_spelling}: {{\
            {timing_code_worker("before_unmarshal", str(f.name), f.generate_timing_code)}
            ava_is_in = 1; ava_is_out = 0;
            {alloc_list.alloc}
            struct {f.call_spelling}* __call = (struct {f.call_spelling}*)__cmd;
            assert(__call->base.api_id == {f.api.number_spelling});
            assert(__call->base.command_size == sizeof(struct {f.call_spelling}) && "Command size does not match ID. (Can be caused by incorrectly computed buffer sizes, expecially using `strlen(s)` instead of `strlen(s)+1`)");

            /* Unpack and translate arguments */
            {lines(convert_input_for_argument(a, "__call") for a in f.arguments)}

            {timing_code_worker("after_unmarshal", str(f.name), f.generate_timing_code)}
            /* Perform Call */
            {worker_argument_process_code}
            {call_function_wrapper(f)}
            {timing_code_worker("after_execution", str(f.name), f.generate_timing_code)}

            ava_is_in = 0; ava_is_out = 1;
            {compute_total_size(f.arguments + [f.return_value], lambda a: a.output)}
            struct {f.ret_spelling}* __ret = (struct {f.ret_spelling}*)command_channel_new_command(
                __chan, sizeof(struct {f.ret_spelling}), __total_buffer_size);
            __ret->base.api_id = {f.api.number_spelling};
            __ret->base.command_id = {f.ret_id_spelling};
            __ret->base.thread_id = __call->base.original_thread_id;
            __ret->__call_id = __call->__call_id;

            {convert_result_for_argument(f.return_value, "__ret") if not f.return_value.type.is_void else ""}
            {lines(convert_result_for_argument(a, "__ret") for a in f.arguments if a.type.contains_buffer)}

            #ifdef AVA_RECORD_REPLAY
            {log_call_declaration}
            {log_ret_declaration}
            {lines(record_argument_metadata(a, src="__ret" if a.type.contains_buffer else "__call") for a in f.arguments)}
            {record_argument_metadata(f.return_value, "__ret") if not f.return_value.type.is_void else ""}
            {record_call_metadata("NULL", None) if f.object_record else ""}
            #endif

            {timing_code_worker("after_marshal", str(f.name), f.generate_timing_code)}
            /* Send reply message */
            {reply_code}
            {alloc_list.dealloc}
            {lines(deallocate_managed_for_argument(a, "") for a in f.arguments)}
            break;
        }}
        """.strip()
def copy_result_for_argument(arg: Argument, dest, src) -> ExprOrStr:
    """
    Generate C statements that copy ``arg`` from the RET struct ``src`` into the call record ``dest``.

    Buffers are located in the RET command with ``get_transfer_buffer_expr``. When the buffer's
    lifetime is not ``AVA_CALL`` it is fetched with ``get_buffer``; otherwise only ``__buffer_size``
    is computed. Simple output buffers are then copied with ``memcpy``, and complex output buffers
    are walked element by element. Handle/opaque values are assigned directly. Deallocating
    ``NW_HANDLE`` values are first released with ``ava_coupled_free``.

    A warning about ``AVA_CALL`` lifetime on returned buffers is emitted at most once per call.

    :param arg: The argument to copy.
    :param dest: The destination call record struct for the call.
    :param src: The source RET struct from the call.
    :return: A series of C statements.
    """
    warned_about_call_lifetime = False

    def copy_value(values, type: Type, depth, original_type=None, **other):
        if isinstance(type, ConditionalType):
            predicate = Expr(type.predicate)
            return predicate.if_then_else(
                copy_value(values, type.then_type, depth, original_type=type.original_type, **other),
                copy_value(values, type.else_type, depth, original_type=type.original_type, **other))

        param_value, local_value = values
        src_name = f"__src_{arg.name}_{depth}"
        non_pointer_error = """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""

        def get_buffer_code():
            nonlocal warned_about_call_lifetime
            if not warned_about_call_lifetime:
                if ((arg.ret or arg.output and depth > 0) and type.buffer) and type.lifetime == "AVA_CALL":
                    warned_about_call_lifetime = True
                    generate_expects(
                        False,
                        "Returned buffers with call lifetime are almost always incorrect. (You may want to set a lifetime.)")

            locate_source = Expr(f"""
                {DECLARE_BUFFER_SIZE_EXPR}
                {type.attach_to(src_name)};
                {src_name} = {get_transfer_buffer_expr(local_value, type, not_null=True)};
                """)
            fetch_or_size = Expr(type.lifetime).not_equals("AVA_CALL").if_then_else(
                f"""{get_buffer(param_value, local_value, type, original_type=original_type, not_null=True, declare_buffer_size=False)}""",
                f"""__buffer_size = {compute_buffer_size(type, original_type)};""")
            output_check = Expr(arg.output).if_then_else(f"AVA_DEBUG_ASSERT({param_value} != NULL);")
            return locate_source.then(fetch_or_size.then(output_check))

        def simple_buffer_case():
            if not hasattr(type, "pointee"):
                return non_pointer_error
            copy_code = Expr(arg.output).if_then_else(
                f"""memcpy({param_value}, {src_name}, {size_to_bytes("__buffer_size", type)});""")
            if not copy_code:
                return ""
            return Expr(local_value).not_equals("NULL").if_then_else(
                f"""
                {get_buffer_code()}
                {copy_code}
                """.strip())

        def buffer_case():
            if not hasattr(type, "pointee"):
                return non_pointer_error
            if not arg.output:
                return simple_buffer_case()
            loop = for_all_elements((param_value, src_name), type, depth=depth, precomputed_size="__buffer_size",
                                    original_type=original_type, **other)
            if not loop:
                return ""
            return Expr(local_value).not_equals("NULL").if_then_else(f"""
                {get_buffer_code()}
                {loop}
                """)

        def default_case():
            frees_handle = Expr(type.transfer).equals("NW_HANDLE") & type.deallocates
            dealloc_code = frees_handle.if_then_else(f"ava_coupled_free(&__ava_endpoint, {local_value});")
            assignment = (Expr(arg.output) | arg.ret).if_then_else(f"{param_value} = {local_value};")
            return dealloc_code.then(assignment)

        if type.fields:
            return for_all_elements(values, type, depth=depth, original_type=original_type, **other)

        is_simple = type.is_simple_buffer(allow_handle=False)
        is_buffer = Expr(type.transfer).equals("NW_BUFFER")
        is_opaque_or_handle = Expr(type.transfer).one_of({"NW_OPAQUE", "NW_HANDLE"})
        return is_simple.if_then_else(
            simple_buffer_case,
            is_buffer.if_then_else(buffer_case, is_opaque_or_handle.if_then_else(default_case)),
        ).scope()

    with location(f"at {term.yellow(str(arg.name))}", arg.location):
        conv = copy_value(
            (f"{dest}->{arg.param_spelling}", f"{src}->{arg.name}"),
            arg.type,
            depth=0,
            name=arg.name,
            kernel=copy_value,
            self_index=0,
        )
        return comment_block(f"Output: {arg}", conv)
def function_wrapper(f: Function) -> str:
    """
    Emit a static C wrapper ``__wrapper_<name>`` around ``f``.

    The wrapper takes ``f``'s arguments, declares the prologue/epilogue locals, runs the prologue,
    performs the call (direct for normal functions, or through the function pointer stored in the
    ``ava_userdata`` argument for ``ava_callback_decl`` functions; omitted when native execution is
    disabled), runs the epilogue, reports resource usage, and returns the captured result.

    :param f: A function.
    :return: A C static function definition.
    """
    with location(f"at {term.yellow(str(f.name))}", f.location):
        if not f.return_value.type.is_void:
            ret_name = f.return_value.name
            declare_ret = f"{f.return_value.type.nonconst.attach_to(ret_name)};"
            capture_ret = f"{ret_name} = "
            return_statement = f"return {ret_name};"
        else:
            declare_ret, capture_ret, return_statement = "", "", "return;"

        callback_unpack = ""
        if f.disable_native:
            # Works for both normal functions and callbacks: they differ only in the call itself,
            # which is not emitted here.
            capture_ret = ""
            call_code = ""
        elif f.callback_decl:
            # Indirect call through the function pointer stored in the userdata argument.
            userdata_candidates = [a for a in f.arguments if a.userdata]
            try:
                userdata_arg, = userdata_candidates
            except ValueError:
                generate_requires(
                    False, "ava_callback_decl function must have exactly one argument annotated with "
                    "ava_userdata.")
            call_code = f"""__target_function({", ".join(a.name for a in f.real_arguments)})"""
            callback_unpack = f"""
                {f.type.attach_to("__target_function")};
                __target_function = {f.type.cast_type(f"((struct ava_callback_user_data*){userdata_arg.name})->function_pointer")};
                {userdata_arg.name} = ((struct ava_callback_user_data*){userdata_arg.name})->userdata;
            """
        else:
            call_code = f"""{f.name}({", ".join(a.name for a in f.real_arguments)})"""

        parameter_list = ", ".join(a.declaration for a in f.arguments)
        return f"""
        static {f.return_value.type.spelling} __wrapper_{f.name}({parameter_list}) {{
            {callback_unpack}\
            {lines(a.declaration + ";" for a in f.logue_declarations)}\
            {lines(f.prologue)}\
            {{
            {declare_ret}
            {capture_ret}{call_code};
            {lines(f.epilogue)}

            /* Report resources */
            {lines(report_alloc_resources(arg) for arg in f.arguments)}
            {report_alloc_resources(f.return_value)}
            {report_consume_resources(f)}

            {return_statement}
            }}
        }}
        """.strip()
def attach_for_argument(arg: Argument, dest):
    """
    Copy arg into dest attaching buffers as needed.

    Generates the guest-side C statements that store the value of ``arg`` into the corresponding
    field of the outgoing CALL command ``dest``. Buffers are attached to the command through
    ``attach_buffer``. For callback userdata arguments, the userdata and callback pointer are first
    packed into a ``struct ava_callback_user_data``.

    :param arg: The argument to copy.
    :param dest: The destination CALL struct.
    :return: A series of C statements.
    """
    # Temporary buffers created by buffer_case() are allocated through this list
    # (passed to allocate_tmp_buffer below).
    alloc_list = AllocList(arg.function)

    # Per-value code generator. It is also handed to the traversal as `kernel=copy_for_value`,
    # presumably so for_all_elements can recurse back into it for nested fields/elements.
    def copy_for_value(values, type: Type, depth, argument, original_type=None, **other):
        # Conditional types: generate code for both alternatives and select between them with the
        # type's predicate, remembering the original type for the nested calls.
        if isinstance(type, ConditionalType):
            return Expr(type.predicate).if_then_else(
                copy_for_value(values, type.then_type, depth, argument, original_type=type.original_type, **other),
                copy_for_value(values, type.else_type, depth, argument, original_type=type.original_type, **other))

        # arg_value: expression for the source value; cmd_value: destination field in the command.
        arg_value, cmd_value = values

        def attach_data(data):
            # Attach `data` as the buffer for cmd_value; expect_reply=True is passed for every
            # argument attached here.
            return attach_buffer(cmd_value, arg_value, data, type, arg.input, cmd=dest, original_type=original_type,
                                 expect_reply=True)

        def simple_buffer_case():
            # Attach the argument's own memory when it is non-NULL and the buffer expression is
            # positive; otherwise send NULL.
            if not hasattr(type, "pointee"):
                return """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""
            return (Expr(arg_value).not_equals("NULL") & (Expr(type.buffer) > 0)).if_then_else(
                attach_data(arg_value),
                f"{cmd_value} = NULL;")

        def buffer_case():
            if not hasattr(type, "pointee"):
                return """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""
            # Non-input buffers carry no data to translate, so they are handled like simple buffers.
            if not arg.input:
                return simple_buffer_case()

            # Input buffers: copy each element into a temporary buffer (via for_all_elements) and
            # attach that temporary instead of the original memory.
            tmp_name = f"__tmp_{arg.name}_{depth}"
            size_name = f"__size_{arg.name}_{depth}"
            loop = for_all_elements((arg_value, tmp_name), type, depth=depth, argument=argument,
                                    precomputed_size=size_name, original_type=original_type, **other)
            return (Expr(arg_value).not_equals("NULL") & (Expr(type.buffer) > 0)).if_then_else(
                f"""
                    {allocate_tmp_buffer(tmp_name, size_name, type, alloc_list=alloc_list, original_type=original_type)}
                    {loop}
                    {attach_data(tmp_name)}
                """,
                f"{cmd_value} = NULL;"
            )

        def default_case():
            # Non-buffer values are copied by plain assignment; a void value is a generator error.
            return Expr(not type.is_void).if_then_else(
                f"{cmd_value} = {arg_value};",
                """abort_with_reason("Reached code to handle void value.");""")

        # Types with fields are delegated entirely to the element/field traversal.
        if type.fields:
            return for_all_elements(values, type, depth=depth, argument=argument, original_type=original_type, **other)
        # Dispatch: simple buffers (handles allowed) -> simple_buffer_case, NW_BUFFER transfers ->
        # buffer_case, everything else -> default_case. The case functions are passed uncalled;
        # .scope() presumably wraps the result in its own C block (TODO confirm in Expr).
        return type.is_simple_buffer(allow_handle=True).if_then_else(
            simple_buffer_case,
            Expr(type.transfer).equals("NW_BUFFER").if_then_else(
                buffer_case,
                default_case
            )
        ).scope()

    with location(f"at {term.yellow(str(arg.name))}", arg.location):
        userdata_code = ""
        # A userdata argument of a normal (non-callback-declaration) function is paired with the
        # single NW_CALLBACK argument. Both are packed into a heap-allocated ava_callback_user_data
        # that replaces the userdata argument before it is copied into the command.
        if arg.userdata and not arg.function.callback_decl:
            try:
                callback_arg, = [a for a in arg.function.arguments if a.type.transfer == "NW_CALLBACK"]
            except ValueError:
                # NOTE(review): if generate_requires does not raise, callback_arg is unbound below.
                generate_requires(False, "If ava_userdata is applied to an argument exactly one other argument "
                                         "must be annotated with ava_callback.")
            generate_requires([arg] == [a for a in arg.function.arguments if a.userdata],
                              "Only one argument on a given function can be annotated with ava_userdata.")
            userdata_code = f"""
            if ({callback_arg.param_spelling} != NULL) {{
                // TODO:MEMORYLEAK: This leaks 2*sizeof(void*) whenever a callback is transported. Should be fixable
                // with "coupled buffer" framework.
                struct ava_callback_user_data *__callback_data = malloc(sizeof(struct ava_callback_user_data));
                __callback_data->userdata = {arg.param_spelling};
                __callback_data->function_pointer = (void*){callback_arg.param_spelling};
                {arg.param_spelling} = __callback_data;
            }}
            """
        return comment_block(
            f"Input: {arg}",
            Expr(userdata_code).then(
                copy_for_value((arg.param_spelling, f"{dest}->{arg.param_spelling}"), arg.type, depth=0, argument=arg,
                               name=arg.name, kernel=copy_for_value, only_complex_buffers=False, self_index=0)))
def command_print_implementation(f: Function):
    """
    Generate the C ``case`` labels that pretty-print the CALL and RET commands of ``f``.

    The CALL case prints the call id, thread id, function name and the input arguments (plus any
    arguments that contain no buffer). The RET case prints the output buffer arguments followed by
    ``-> `` and the return value. All output goes through ``fprintf(file, ...)``.

    :param f: The function whose commands are printed.
    :return: C source for the two ``case`` labels, as a string.
    """
    with location(f"at {term.yellow(str(f.name))}", f.location):

        # Build a C fprintf call to `file` with the given format and comma-joined argument expressions.
        def printf(format, *values):
            return f"""fprintf(file, "{format}", {",".join(values)});"""

        # Per-value printer. It is passed to the traversal as `kernel=print_value_deep`, presumably
        # so for_all_elements can recurse into buffer elements.
        def print_value_deep(values, cast_type: Type, type: Type, depth, no_depends, argument, **other):
            (value, ) = values
            if type.is_void:
                return ""
            # Element data is only dumped for non-NULL NW_BUFFER values.
            buffer_pred = Expr(type.transfer).equals("NW_BUFFER") & Expr(value).not_equals("NULL")

            def address():
                # Print the pointer, then (under buffer_pred) " = {" + elements + ",...}". The
                # traversal uses precomputed_size=Expr(1), presumably printing only the first
                # element (hence the trailing ",..."). TODO confirm in for_all_elements.
                if not hasattr(type, "pointee"):
                    return """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""
                tmp_name = f"__tmp_{argument.name}_{depth}"
                inner_values = (tmp_name, )
                data_code = buffer_pred.if_then_else(f"""
                fprintf(file, " = {{");
                {type.nonconst.attach_to(tmp_name)};
                {tmp_name} = ({cast_type})({get_transfer_buffer_expr(value, type)});
                {for_all_elements(inner_values, cast_type, type, precomputed_size=Expr(1), depth=depth, argument=argument, no_depends=no_depends, **other)}
                fprintf(file, ",...}}");
                """)
                return f"""
                {printf("ptr 0x%012lx", f"(long int){value}")}
                {data_code}
                """

            def handle():
                return printf("handle %#lx", f"(long int){value}")

            def opaque():
                # Choose a printf format by a substring heuristic on the C type's spelling.
                st = str(type)
                if "*" in st:
                    return printf("%#lx", f"(long int){value}")
                elif "int" in st:
                    return printf("%ld", f"(long int){value}")
                elif "float" in st or "double" in st:
                    return printf("%Lf", f"(long double){value}")
                else:
                    # Fall back on pointer representation
                    return printf("%#lx", f"(long int){value}")

            # Nothing is printed for types with fields, nor (when no_depends is set) for arguments
            # that have depends_on. Otherwise dispatch on the transfer kind. The innermost
            # NW_HANDLE test has no else branch, so other kinds presumably print nothing.
            return Expr(bool(type.fields or argument.depends_on and no_depends)).if_then_else(
                "",  # Using only else branch
                Expr(type.transfer).equals("NW_BUFFER").if_then_else(
                    address,
                    Expr(type.transfer).equals("NW_ZEROCOPY_BUFFER").if_then_else(
                        address,
                        Expr(type.transfer).equals("NW_OPAQUE").if_then_else(
                            opaque,
                            Expr(type.transfer).equals("NW_HANDLE").if_then_else(handle)),
                    ),
                ),
            )

        # Print one argument: "name=" (omitted for the return value) followed by its value.
        def print_value(argument: Argument, value, type: Type, no_depends):
            conv = print_value_deep(
                (value, ),
                argument.type.nonconst,
                argument.type,
                depth=0,
                name=argument.name,
                argument=argument,
                no_depends=no_depends,
                kernel=print_value_deep,
                self_index=0,
            )
            return (printf("%s=", f'"{argument.name}"') if not argument.ret else "") + str(conv)

        # Separator emitted between printed arguments.
        print_comma = """ fprintf(file, ", ");\n"""

        # The RET case unpacks only output buffer arguments without depends_on. Those with
        # depends_on are printed with no_depends=True, so their value code is empty.
        return f"""
        case {f.call_id_spelling}: {{ \
            ava_is_in = 1; ava_is_out = 0;
            struct {f.call_spelling}* __call = (struct {f.call_spelling}*)__cmd;
            assert(__call->base.api_id == {f.api.number_spelling});
            assert(__call->base.command_size == sizeof(struct {f.call_spelling}) && "Command size does not match ID. (Can be caused by incorrectly computed buffer sizes, especially using `strlen(s)` instead of `strlen(s)+1`)");
            {unpack_struct("__call", f.arguments, "->", get_transfer_buffer_expr)}
            {printf("<%03ld> <thread=%012lx> %s(", "(long int)__call->__call_id", "(unsigned long int)__call->base.thread_id", f'"{f.name}"')}
            {print_comma.join(str(print_value(a, f"__call->{a.name}", a.type, False)) for a in f.arguments if a.input or not a.type.contains_buffer)}
            fprintf(file, "){snl}");
            break;
        }}
        case {f.ret_id_spelling}: {{ \
            ava_is_in = 0; ava_is_out = 1;
            struct {f.ret_spelling}* __ret = (struct {f.ret_spelling}*)__cmd;
            assert(__ret->base.api_id == {f.api.number_spelling});
            assert(__ret->base.command_size == sizeof(struct {f.ret_spelling}) && "Command size does not match ID. (Can be caused by incorrectly computed buffer sizes, especially using `strlen(s)` instead of `strlen(s)+1`)");
            {unpack_struct("__ret", ([] if f.return_value.type.is_void else [f.return_value]) + [a for a in f.arguments if a.output and a.type.contains_buffer and not bool(a.depends_on)], "->", get_transfer_buffer_expr)}
            {printf("<%03ld> <thread=%012lx> %s(", "(long int)__ret->__call_id", "(unsigned long int)__ret->base.thread_id", f'"{f.name}"')}
            {print_comma.join(str(print_value(a, f"__ret->{a.name}", a.type, True)) for a in f.arguments if a.output and a.type.contains_buffer)}
            fprintf(file, ") -> ");
            {print_value(f.return_value, f"__ret->{f.return_value.name}", f.return_value.type, True)}
            fprintf(file, "{snl}");
            break;
        }}
        """.strip()
def call_command_implementation(f: Function, enabled_opts: List[str] = None):
    """
    Generate the worker-side C ``case`` that executes the CALL command of ``f``.

    The generated code does the following:
    - unpacks and translates the input arguments;
    - invokes the function through ``call_function_wrapper(f)``;
    - allocates a RET command on ``__chan`` and fills its header;
    - converts the return value (if non-void) and the buffer-carrying arguments into it;
    - records replay metadata under ``AVA_RECORD_REPLAY``;
    - sends the reply;
    - releases temporary and managed allocations.

    :param f: The function whose CALL command is handled.
    :param enabled_opts: Optional collection of enabled optimization names (may be None). Only
        ``"batching"`` is recognized here. With it, replies are not sent for asynchronous
        functions, and ``__do_batch_emit`` additionally runs
        ``__handle_command_cudart_opt_single`` before the call. Any other content (or None/empty)
        keeps the default behavior of always sending the reply.
    :return: C source for the ``case`` label, as a string.
    """
    with location(f"at {term.yellow(str(f.name))}", f.location):
        alloc_list = AllocList(f)  # pylint: disable=possibly-unused-variable

        is_async = ~Expr(f.synchrony).equals("NW_SYNC")

        # Defaults: no extra worker-side processing, and the reply is always sent. Setting them up
        # front guarantees the template below never references an unbound name, whatever
        # `enabled_opts` contains (previously a non-empty list without "batching" crashed).
        worker_argument_process_code = ""
        reply_code = "command_channel_send_command(__chan, (struct command_base*)__ret);"

        if enabled_opts and "batching" in enabled_opts:
            # Enable batching optimization: batch the reply command; but for now we do not send reply
            # at all (because all batched APIs are ava_async).
            # TODO: improve the batching logic to get rid of worker_argument_process_code.
            if f.name == "__do_batch_emit":
                worker_argument_process_code = (
                    "__handle_command_cudart_opt_single(__chan, handle_pool, __log, NULL);"
                )
            reply_code = is_async.if_then_else("", reply_code)

        # log_call_declaration / log_ret_declaration are not defined in this function; presumably
        # they are module-level names (TODO confirm).
        return f"""
        case {f.call_id_spelling}: {{\
            {timing_code_worker("before_unmarshal", str(f.name), f.generate_timing_code)}
            ava_is_in = 1; ava_is_out = 0;
            {alloc_list.alloc}
            struct {f.call_spelling}* __call = (struct {f.call_spelling}*)__cmd;
            assert(__call->base.api_id == {f.api.number_spelling});
            assert(__call->base.command_size == sizeof(struct {f.call_spelling}) && "Command size does not match ID. (Can be caused by incorrectly computed buffer sizes, especially using `strlen(s)` instead of `strlen(s)+1`)");

            /* Unpack and translate arguments */
            {lines(convert_input_for_argument(a, "__call") for a in f.arguments)}

            {timing_code_worker("after_unmarshal", str(f.name), f.generate_timing_code)}
            /* Perform Call */
            {worker_argument_process_code}
            {call_function_wrapper(f)}
            {timing_code_worker("after_execution", str(f.name), f.generate_timing_code)}

            ava_is_in = 0; ava_is_out = 1;
            {compute_total_size(f.arguments + [f.return_value], lambda a: a.output)}
            struct {f.ret_spelling}* __ret = (struct {f.ret_spelling}*)command_channel_new_command(
                __chan, sizeof(struct {f.ret_spelling}), __total_buffer_size);
            __ret->base.api_id = {f.api.number_spelling};
            __ret->base.command_id = {f.ret_id_spelling};
            __ret->base.thread_id = __call->base.original_thread_id;
            __ret->__call_id = __call->__call_id;

            {convert_result_for_argument(f.return_value, "__ret") if not f.return_value.type.is_void else ""}
            {lines(convert_result_for_argument(a, "__ret") for a in f.arguments if a.type.contains_buffer)}

            #ifdef AVA_RECORD_REPLAY
            {log_call_declaration}
            {log_ret_declaration}
            {lines(record_argument_metadata(a) for a in f.arguments)}
            {record_argument_metadata(f.return_value) if not f.return_value.type.is_void else ""}
            {record_call_metadata("NULL", None) if f.object_record else ""}
            #endif

            {timing_code_worker("after_marshal", str(f.name), f.generate_timing_code)}
            /* Send reply message */
            {reply_code}
            {alloc_list.dealloc}
            {lines(deallocate_managed_for_argument(a, "") for a in f.arguments)}
            break;
        }}
        """.strip()