def compute_total_size(args: Iterable[Argument], copy_pred: Callable[[Argument], Expr]) -> str:
    """
    Sum the sizes of all the buffers created by all arguments.

    The generated C code declares a ``size_t __total_buffer_size`` accumulator
    initialized to 0, followed by a brace-enclosed group holding one commented
    block of size-accumulation statements per argument.

    :param args: All the arguments to sum.
    :param copy_pred: A function mapping an argument to a condition (an ``Expr``) that is
        true if the data associated with this argument will be copied.
    :return: A series of C statements.
    """
    # Name of the C accumulator variable declared in the returned code.
    size = "__total_buffer_size"

    def compute_size(values, type: Type, depth, argument: Argument, original_type=None, **other):
        # Per-value kernel. It is also passed as ``kernel=`` so that for_all_elements can
        # presumably recurse into buffer elements / struct fields with it.
        if isinstance(type, ConditionalType):
            # Emit a branch on the type's predicate and size each alternative separately.
            return Expr(type.predicate).if_then_else(
                compute_size(values, type.then_type, depth, argument, original_type=type.original_type, **other),
                compute_size(values, type.else_type, depth, argument, original_type=type.original_type, **other))

        value, = values
        # Only non-NULL NW_BUFFER values with a positive buffer length contribute to the total.
        pred = Expr(type.transfer).equals("NW_BUFFER") & Expr(
            value).not_equals("NULL") & (Expr(type.buffer) > 0)

        def add_buffer_size():
            # Copied data: AVA_CALL lifetime buffers travel in the command channel, other
            # lifetimes use a shadow buffer. Uncopied data: AVA_CALL contributes nothing, other
            # lifetimes still need a shadow buffer (without data).
            size_expr = size_to_bytes(compute_buffer_size(type, original_type), type)
            return Expr(copy_pred(argument)).if_then_else(
                type.lifetime.equals("AVA_CALL").if_then_else(
                    f"{size} += command_channel_buffer_size(__chan, {size_expr});\n",
                    f"{size} += ava_shadow_buffer_size(&__ava_endpoint, __chan, {size_expr});\n"
                ),
                type.lifetime.equals("AVA_CALL").if_then_else(
                    "",
                    f"{size} += ava_shadow_buffer_size_without_data(&__ava_endpoint, __chan, {size_expr});\n"
                ))

        def simple_buffer_case():
            if not hasattr(type, "pointee"):
                return """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""
            # add_buffer_size is passed uncalled; if_then_else presumably evaluates it lazily.
            return pred.if_then_else(add_buffer_size)

        def buffer_case():
            if not hasattr(type, "pointee"):
                return """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""
            # Size the buffers reachable from the elements first, then the outer buffer itself.
            loop = for_all_elements(values, type, depth=depth, argument=argument, original_type=original_type,
                                    **other)
            outer_buffer = str(add_buffer_size())
            return pred.if_then_else(loop + outer_buffer)

        if type.fields:
            # Struct-like types: size each field.
            return for_all_elements(values, type, depth=depth, argument=argument, original_type=original_type,
                                    **other)
        return type.is_simple_buffer(allow_handle=True).if_then_else(
            simple_buffer_case,
            lambda: Expr(type.transfer).equals("NW_BUFFER")
            .if_then_else(buffer_case))

    size_code = lines(
        comment_block(
            f"Size: {a}",
            compute_size((a.name, ), a.type, depth=0, name=a.name, kernel=compute_size,
                         only_complex_buffers=False, argument=a, self_index=0))
        for a in args)
    # ``nl`` is a newline constant defined elsewhere in this module.
    return f"size_t {size} = 0;{nl}{{ {size_code} }}"
def deallocate_managed_for_argument(arg: Argument, src):
    """
    Generate C code that frees the coupled resources held by ``arg``.

    For NW_OPAQUE/NW_HANDLE values whose type deallocates, NW_HANDLE values get
    ``ava_coupled_free``. Every non-struct value whose type deallocates is then followed by
    ``ava_shadow_buffer_free_coupled``. Non-NULL NW_BUFFER values that are not simple buffers
    are walked element by element.

    :param arg: The argument whose resources are released.
    :param src: Optional struct pointer expression. When truthy, the value is read as
        ``src->arg.name``; otherwise ``arg.name`` is used directly.
    :return: The generated C statements wrapped in a "Dealloc" comment block.
    """
    # NOTE: despite its name, this kernel generates deallocation code.
    def convert_result_value(values, type: Type, depth, original_type=None, **other):
        if isinstance(type, ConditionalType):
            return Expr(type.predicate).if_then_else(
                convert_result_value(values, type.then_type, depth, original_type=type.original_type, **other),
                convert_result_value(values, type.else_type, depth, original_type=type.original_type, **other))

        local_value, = values
        buffer_pred = (Expr(type.transfer).equals("NW_BUFFER") & f"{local_value} != NULL")
        # Appended after whichever case below is selected (see ``.then(dealloc_shadows)``).
        dealloc_shadows = Expr(type.deallocates).if_then_else(
            f"ava_shadow_buffer_free_coupled(&__ava_endpoint, (void *){local_value});"
        )

        def simple_buffer_case():
            # Simple buffers need no per-element deallocation.
            return ""

        def buffer_case():
            if not hasattr(type, "pointee"):
                return """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""
            return buffer_pred.if_then_else(
                for_all_elements(values, type, depth=depth, original_type=original_type, **other))

        def default_case():
            dealloc_code = Expr(type.deallocates).if_then_else(
                Expr(type.transfer).equals("NW_HANDLE").if_then_else(f"""
                ava_coupled_free(&__ava_endpoint, {local_value});
                """.strip()))
            return dealloc_code

        if type.fields:
            return for_all_elements(values, type, depth=depth, original_type=original_type, **other)
        return type.is_simple_buffer(allow_handle=False).if_then_else(
            simple_buffer_case,
            Expr(type.transfer).equals("NW_BUFFER").if_then_else(
                buffer_case,
                (Expr(type.transfer).one_of({
                    "NW_OPAQUE", "NW_HANDLE"
                })).if_then_else(default_case))).then(dealloc_shadows).scope()

    with location(f"at {term.yellow(str(arg.name))}", arg.location):
        conv = convert_result_value(
            (f"""{src + "->" if src else ""}{arg.name}""", ), arg.type,
            depth=0, name=arg.name, kernel=convert_result_value, self_index=0)
        return comment_block(f"Dealloc: {arg}", conv)
def convert_input_for_argument(arg: Argument, src):
    """
    Generate code to extract the value for arg from the call structure in src.

    The value of arg is left in a variable named arg.name.
    The value is fully converted to local values.
    This code is used in the command receiver to implement a CALL command.

    :param arg: The argument to extract.
    :param src: The CALL command structure.
    :return: A series of C statements to perform the extraction.
    """
    # Tracks temporary buffers allocated by the generated code (see alloc_list.insert below).
    alloc_list = AllocList(arg.function)

    def convert_input_value(values, type: Type, depth, original_type=None, **other):
        # values is (local variable expression, expression for the value in the CALL struct).
        local_value, param_value = values
        # Plain assignment of the transported value; reused by several cases below.
        preassignment = f"{local_value} = {get_transfer_buffer_expr(param_value, type)};"

        if isinstance(type, ConditionalType):
            # Assign first, then refine with the branch selected by the type's predicate.
            return Expr(preassignment).then(
                Expr(type.predicate).if_then_else(
                    convert_input_value(values, type.then_type, depth, original_type=type.original_type, **other),
                    convert_input_value(values, type.else_type, depth, original_type=type.original_type, **other)))

        if type.is_void:
            return """abort_with_reason("Reached code to handle void value.");"""

        def maybe_alloc_local_temporary_buffer():
            # If the transported pointer is non-NULL, allocate a local buffer with the type's
            # allocator and register it in alloc_list together with the type's deallocator.
            # TODO: Deduplicate with allocate_tmp_buffer
            allocator = type.buffer_allocator
            deallocator = type.buffer_deallocator
            return Expr(param_value).not_equals("NULL").if_then_else(f"""{{
                const size_t __size = {compute_buffer_size(type, original_type)};
                {local_value} = ({type.nonconst.spelling}){allocator}({size_to_bytes("__size", type)});
                {alloc_list.insert(local_value, deallocator)}
            }}""")

        # C temporary that saves local_value before get_buffer reassigns it; the copy/loop code
        # below reads from it.
        src_name = f"__src_{arg.name}_{depth}"

        def get_buffer_code():
            # Save local_value into src_name, fetch the buffer into local_value, then, when the
            # lifetime is AVA_CALL and the type is not a simple buffer or its allocator is not
            # malloc, allocate a local temporary buffer.
            return f"""
                {type.nonconst.attach_to(src_name)};
                {src_name} = {local_value};
                {get_buffer(local_value, param_value, type, original_type=original_type, not_null=True)}
                {(type.lifetime.equals("AVA_CALL") & (~type.is_simple_buffer() | type.buffer_allocator.not_equals("malloc"))).if_then_else(maybe_alloc_local_temporary_buffer)}
            """

        def simple_buffer_case():
            if not hasattr(type, "pointee"):
                return """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""
            # For inputs, copy the received data into the local buffer when a separate local
            # buffer was allocated (i.e. local_value no longer equals src_name).
            copy_code = (
                Expr(arg.input) & Expr(local_value).not_equals(src_name)
            ).if_then_else(
                f"""memcpy({local_value}, {src_name}, {size_to_bytes("__buffer_size", type)});"""
            )
            # Otherwise: inputs and NW_ZEROCOPY_BUFFER values only get the plain assignment, and
            # everything else may get a locally allocated temporary.
            return ((type.lifetime.not_equals("AVA_CALL") | arg.input) &
                    Expr(param_value).not_equals("NULL")).if_then_else(
                f"""
                    {get_buffer_code()}
                    {copy_code}
                """.strip(),
                (Expr(arg.input) | type.transfer.equals("NW_ZEROCOPY_BUFFER")
                 ).if_then_else(preassignment, maybe_alloc_local_temporary_buffer))

        def buffer_case():
            if not hasattr(type, "pointee"):
                return """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""
            if not arg.input:
                return simple_buffer_case()

            # Convert each element from the received buffer (src_name) into local_value.
            # ``__buffer_size`` is presumably declared by the get_buffer code.
            inner_values = (local_value, src_name)
            core = for_all_elements(inner_values, type, depth=depth, precomputed_size="__buffer_size",
                                    original_type=original_type, **other)
            return ((type.lifetime.not_equals("AVA_CALL") | arg.input) &
                    Expr(param_value).not_equals("NULL")).if_then_else(
                f"""
                    {get_buffer_code()}
                    {core}
                """.strip(),
                maybe_alloc_local_temporary_buffer)

        def default_case():
            # Callbacks are replaced by the stub function; handles are dereferenced through
            # the handle pool; other values are assigned directly.
            def deref_code(handlepool_function: str) -> Callable:
                return lambda: (Expr(type.transfer).one_of({
                    "NW_CALLBACK", "NW_CALLBACK_REGISTRATION"
                })).if_then_else(
                    f"{local_value} = ({param_value} == NULL) ? NULL : {type.callback_stub_function};",
                    (Expr(type.transfer).equals("NW_HANDLE")).if_then_else(
                        f"{local_value} = ({type.nonconst.spelling}){handlepool_function}(handle_pool, (void*){param_value});",
                        # Void types already returned above, so this condition is always true here.
                        Expr(not type.is_void).if_then_else(
                            f"{local_value} = {param_value};",
                            """abort_with_reason("Reached code to handle void value.");"""
                        )))

            # Deallocating calls also remove the handle from the pool.
            return Expr(type.deallocates).if_then_else(
                deref_code("nw_handle_pool_deref_and_remove"),
                deref_code("nw_handle_pool_deref"))

        if type.fields:
            # NOTE(review): unlike the sibling generators, original_type is not forwarded here — confirm intent.
            return for_all_elements(values, type, depth=depth, **other)
        rest = type.is_simple_buffer().if_then_else(
            simple_buffer_case,
            Expr(type.transfer).equals("NW_BUFFER").if_then_else(
                buffer_case,
                default_case))
        # Emit the preassignment only when some conversion code was produced.
        if rest:
            return Expr(preassignment).then(rest).scope()
        else:
            return ""

    with location(f"at {term.yellow(str(arg.name))}", arg.location):
        conv = convert_input_value((arg.name, f"{src}->{arg.param_spelling}"), arg.type,
                                   depth=0, name=arg.name, kernel=convert_input_value,
                                   original_type=arg.type, self_index=0)
        # Declare the (non-const) local variable, then emit the conversion code.
        return comment_block(
            f"Input: {arg}",
            f"""\
{arg.type.nonconst.attach_to(arg.name)}; \
{conv}
""")
def convert_result_for_argument(arg: Argument, dest) -> ExprOrStr:
    """
    Take the value of arg in the local scope and write it into dest.

    Buffers are attached to the RET command (converting their elements into a temporary
    first when needed), handles are translated through the handle pool, and other values
    are copied directly. The result is guarded by ``arg.output`` or ``arg.ret``.

    :param arg: The argument to place in the output.
    :param dest: A RET command struct pointer.
    :return: A series of C statements.
    """
    # Tracks temporaries allocated by allocate_tmp_buffer.
    alloc_list = AllocList(arg.function)

    def convert_result_value(values, type: Type, depth, original_type=None, **other) -> ExprOrStr:
        if isinstance(type, ConditionalType):
            return Expr(type.predicate).if_then_else(
                convert_result_value(values, type.then_type, depth, original_type=type.original_type, **other),
                convert_result_value(values, type.else_type, depth, original_type=type.original_type, **other))

        if type.is_void:
            return """abort_with_reason("Reached code to handle void value.");"""

        # values is (field expression in the RET struct, local value expression).
        param_value, local_value = values

        def attach_data(data):
            return attach_buffer(param_value, local_value, data, type, arg.output, cmd=dest,
                                 original_type=original_type, expect_reply=False)

        def simple_buffer_case():
            if not hasattr(type, "pointee"):
                return """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""
            return (Expr(local_value).not_equals("NULL") & (Expr(type.buffer) > 0)).if_then_else(
                attach_data(local_value),
                f"{param_value} = NULL;")

        def buffer_case():
            if not hasattr(type, "pointee"):
                return """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""
            if not arg.output:
                return simple_buffer_case()

            # Allocate a temporary, convert each element of local_value into it, then attach it.
            tmp_name = f"__tmp_{arg.name}_{depth}"
            size_name = f"__size_{arg.name}_{depth}"
            inner_values = (tmp_name, local_value)
            return Expr(local_value).not_equals("NULL").if_then_else(
                f"""{{
                {allocate_tmp_buffer(tmp_name, size_name, type, alloc_list=alloc_list, original_type=original_type)}
                {for_all_elements(inner_values, type, precomputed_size=size_name, depth=depth, original_type=original_type, **other)}
                {attach_data(tmp_name)}
            }}""",
                f"{param_value} = NULL;")

        def default_case():
            # Handles are translated through the pool (NULL if the call deallocated them);
            # other values are copied as-is.
            handlepool_function = "nw_handle_pool_lookup_or_insert"
            return Expr(type.transfer).equals("NW_HANDLE").if_then_else(
                Expr(type.deallocates).if_then_else(
                    f"{param_value} = NULL;",
                    f"{param_value} = ({type.nonconst.spelling}){handlepool_function}(handle_pool, (void*){local_value});"
                ),
                # Void types already returned above, so this condition is always true here.
                Expr(not type.is_void).if_then_else(
                    f"{param_value} = {local_value};"))

        if type.fields:
            return for_all_elements(values, type, depth=depth, original_type=original_type, **other)
        return type.is_simple_buffer().if_then_else(
            simple_buffer_case,
            Expr(type.transfer).equals("NW_BUFFER").if_then_else(
                buffer_case,
                default_case)).scope()

    with location(f"at {term.yellow(str(arg.name))}", arg.location):
        conv = convert_result_value(
            (f"{dest}->{arg.param_spelling}", f"{arg.name}"), arg.type,
            depth=0, name=arg.name, kernel=convert_result_value, self_index=1)
        return (Expr(arg.output) | arg.ret).if_then_else(
            comment_block(f"Output: {arg}", conv))
def attach_for_argument(arg: Argument, dest):
    """
    Copy arg into dest attaching buffers as needed.

    If ``arg`` is annotated as callback userdata (and the function is not itself a callback
    declaration), the generated code first wraps the userdata and the callback pointer in
    a freshly malloc'd ``struct ava_callback_user_data`` and substitutes it for the argument.

    :param arg: The argument to copy.
    :param dest: The destination CALL struct.
    :return: A series of C statements.
    """
    # Tracks temporaries allocated by allocate_tmp_buffer.
    alloc_list = AllocList(arg.function)

    def copy_for_value(values, type: Type, depth, argument, original_type=None, **other):
        if isinstance(type, ConditionalType):
            return Expr(type.predicate).if_then_else(
                copy_for_value(values, type.then_type, depth, argument, original_type=type.original_type, **other),
                copy_for_value(values, type.else_type, depth, argument, original_type=type.original_type, **other))

        # values is (local argument expression, field expression in the CALL struct).
        arg_value, cmd_value = values

        def attach_data(data):
            return attach_buffer(cmd_value, arg_value, data, type, arg.input, cmd=dest,
                                 original_type=original_type, expect_reply=True)

        def simple_buffer_case():
            if not hasattr(type, "pointee"):
                return """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""
            return (Expr(arg_value).not_equals("NULL") & (Expr(type.buffer) > 0)).if_then_else(
                attach_data(arg_value),
                f"{cmd_value} = NULL;")

        def buffer_case():
            if not hasattr(type, "pointee"):
                return """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""
            if not arg.input:
                return simple_buffer_case()

            # Convert each element of arg_value into a temporary, then attach the temporary.
            tmp_name = f"__tmp_{arg.name}_{depth}"
            size_name = f"__size_{arg.name}_{depth}"
            loop = for_all_elements((arg_value, tmp_name), type, depth=depth, argument=argument,
                                    precomputed_size=size_name, original_type=original_type, **other)
            return (Expr(arg_value).not_equals("NULL") & (Expr(type.buffer) > 0)).if_then_else(
                f"""
                {allocate_tmp_buffer(tmp_name, size_name, type, alloc_list=alloc_list, original_type=original_type)}
                {loop}
                {attach_data(tmp_name)}
                """,
                f"{cmd_value} = NULL;"
            )

        def default_case():
            return Expr(not type.is_void).if_then_else(
                f"{cmd_value} = {arg_value};",
                """abort_with_reason("Reached code to handle void value.");""")

        if type.fields:
            return for_all_elements(values, type, depth=depth, argument=argument,
                                    original_type=original_type, **other)
        return type.is_simple_buffer(allow_handle=True).if_then_else(
            simple_buffer_case,
            Expr(type.transfer).equals("NW_BUFFER").if_then_else(
                buffer_case,
                default_case
            )
        ).scope()

    with location(f"at {term.yellow(str(arg.name))}", arg.location):
        userdata_code = ""
        if arg.userdata and not arg.function.callback_decl:
            # Exactly one NW_CALLBACK argument and exactly one userdata argument are required.
            try:
                callback_arg, = [a for a in arg.function.arguments if a.type.transfer == "NW_CALLBACK"]
            except ValueError:
                # NOTE(review): if generate_requires(False, ...) does not raise, callback_arg is
                # unbound below — confirm it always aborts generation.
                generate_requires(False, "If ava_userdata is applied to an argument exactly one other argument "
                                         "must be annotated with ava_callback.")
            generate_requires([arg] == [a for a in arg.function.arguments if a.userdata],
                              "Only one argument on a given function can be annotated with ava_userdata.")
            userdata_code = f"""
            if ({callback_arg.param_spelling} != NULL) {{
                // TODO:MEMORYLEAK: This leaks 2*sizeof(void*) whenever a callback is transported. Should be fixable
                // with "coupled buffer" framework.
                struct ava_callback_user_data *__callback_data = malloc(sizeof(struct ava_callback_user_data));
                __callback_data->userdata = {arg.param_spelling};
                __callback_data->function_pointer = (void*){callback_arg.param_spelling};
                {arg.param_spelling} = __callback_data;
            }}
            """
        return comment_block(
            f"Input: {arg}",
            Expr(userdata_code).then(
                copy_for_value((arg.param_spelling, f"{dest}->{arg.param_spelling}"), arg.type,
                               depth=0, argument=arg, name=arg.name, kernel=copy_for_value,
                               only_complex_buffers=False, self_index=0)))
def copy_result_for_argument(arg: Argument, dest, src) -> ExprOrStr:
    """
    Copy arg from the src struct into dest struct.

    Buffers are copied (simple buffers, via memcpy) or converted element by element into
    the destination. NW_HANDLE values whose type deallocates are released with
    ``ava_coupled_free``. NW_OPAQUE/NW_HANDLE outputs and return values are assigned directly.

    :param arg: The argument to copy.
    :param dest: The destination call record struct for the call.
    :param src: The source RET struct from the call.
    :return: A series of C statements.
    """
    # Ensures the "call lifetime" warning below is emitted at most once per argument.
    reported_missing_lifetime = False

    def convert_result_value(values, type: Type, depth, original_type=None, **other):
        if isinstance(type, ConditionalType):
            return Expr(type.predicate).if_then_else(
                convert_result_value(values, type.then_type, depth, original_type=type.original_type, **other),
                convert_result_value(values, type.else_type, depth, original_type=type.original_type, **other))

        # values is (destination field expression, source field expression in the RET struct).
        param_value, local_value = values

        src_name = f"__src_{arg.name}_{depth}"

        def get_buffer_code():
            nonlocal reported_missing_lifetime
            # Precedence note: this reads ``arg.ret or (arg.output and depth > 0)``.
            if not reported_missing_lifetime and \
                    ((arg.ret or arg.output and depth > 0) and type.buffer) and \
                    type.lifetime == "AVA_CALL":
                reported_missing_lifetime = True
                generate_expects(
                    False, "Returned buffers with call lifetime are almost always incorrect. (You may want to set a lifetime.)")
            # Declare __buffer_size and the source pointer; then either fetch the destination
            # buffer (non-AVA_CALL lifetimes) or just compute the size, and for outputs assert
            # the destination pointer is non-NULL.
            return Expr(f"""
                {DECLARE_BUFFER_SIZE_EXPR}
                {type.attach_to(src_name)};
                {src_name} = {get_transfer_buffer_expr(local_value, type, not_null=True)};
                """).then(
                Expr(type.lifetime).not_equals("AVA_CALL").if_then_else(
                    f"""{get_buffer(param_value, local_value, type, original_type=original_type, not_null=True, declare_buffer_size=False)}""",
                    f"""__buffer_size = {compute_buffer_size(type, original_type)};"""
                ).then(
                    Expr(arg.output).if_then_else(f"AVA_DEBUG_ASSERT({param_value} != NULL);")
                ))

        def simple_buffer_case():
            if not hasattr(type, "pointee"):
                return """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""
            copy_code = Expr(arg.output).if_then_else(
                f"""memcpy({param_value}, {src_name}, {size_to_bytes("__buffer_size", type)});""")
            # Skip the buffer lookup entirely when there is nothing to copy.
            if copy_code:
                return Expr(local_value).not_equals("NULL").if_then_else(
                    f"""
                    {get_buffer_code()}
                    {copy_code}
                    """.strip()
                )
            else:
                return ""

        def buffer_case():
            if not hasattr(type, "pointee"):
                return """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""
            if not arg.output:
                return simple_buffer_case()

            inner_values = (param_value, src_name)
            loop = for_all_elements(inner_values, type, depth=depth, precomputed_size="__buffer_size",
                                    original_type=original_type, **other)
            if loop:
                return Expr(local_value).not_equals("NULL").if_then_else(
                    f"""
                    {get_buffer_code()}
                    {loop}
                    """
                )
            else:
                return ""

        def default_case():
            dealloc_code = (Expr(type.transfer).equals("NW_HANDLE") & type.deallocates).if_then_else(
                f"""
                ava_coupled_free(&__ava_endpoint, {local_value});
                """.strip()
            )
            return dealloc_code.then((Expr(arg.output) | arg.ret).if_then_else(f"{param_value} = {local_value};"))

        if type.fields:
            return for_all_elements(values, type, depth=depth, original_type=original_type, **other)
        return type.is_simple_buffer(allow_handle=False).if_then_else(
            simple_buffer_case,
            Expr(type.transfer).equals("NW_BUFFER").if_then_else(
                buffer_case,
                Expr(type.transfer).one_of({"NW_OPAQUE", "NW_HANDLE"}).if_then_else(
                    default_case
                )
            )
        ).scope()

    with location(f"at {term.yellow(str(arg.name))}", arg.location):
        conv = convert_result_value((f"{dest}->{arg.param_spelling}", f"{src}->{arg.name}"), arg.type,
                                    depth=0, name=arg.name, kernel=convert_result_value, self_index=0)
        return comment_block(f"Output: {arg}", conv)
def assign_original_handle_for_argument(arg: Argument, original: str):
    """
    Generate C code that reconciles the values of ``arg`` with those stored in ``original``.

    For NW_HANDLE values that are not deallocated, the generated code registers the original
    handle for the local value via ``nw_handle_pool_assign_handle``. For other values that are
    the return value or use NW_OPAQUE transfer (function pointers excluded), it emits an
    ``assert`` that the original and local values are equal. Non-NULL NW_BUFFER values with a
    positive length are fetched from ``original`` and processed element by element.

    :param arg: The argument to process.
    :param original: C expression for a pointer to the struct holding the original values;
        ``original->param_spelling`` is read.
    :return: The generated C statements in an "Assign or check" comment block, guarded by
        ``arg.output`` or ``arg.ret``.
    """
    def convert_result_value(values, cast_type: Type, type: Type, depth, original_type=None,
                             **other) -> ExprOrStr:
        if isinstance(type, ConditionalType):
            # Each branch uses the non-const form of its *own* branch type as the cast type.
            return Expr(type.predicate).if_then_else(
                convert_result_value(values, type.then_type.nonconst, type.then_type, depth,
                                     original_type=type.original_type, **other),
                convert_result_value(values, type.else_type.nonconst, type.else_type, depth,
                                     original_type=type.original_type, **other),
            )

        if type.is_void:
            return """abort_with_reason("Reached code to handle void value.");"""

        # values is (expression in the original struct, local value expression).
        original_value, local_value = values
        buffer_pred = (Expr(type.transfer).equals("NW_BUFFER") &
                       Expr(local_value).not_equals("NULL") &
                       (Expr(type.buffer) > 0))

        def simple_buffer_case():
            # Simple buffers need no assignment or check here.
            return ""

        def buffer_case():
            if not hasattr(type, "pointee"):
                return """abort_with_reason("Reached code to handle buffer in non-pointer type.");"""
            tmp_name = f"__tmp_{arg.name}_{depth}"
            inner_values = (tmp_name, local_value)
            # NOTE(review): get_buffer and for_all_elements receive an extra ``cast_type``
            # positional argument here compared with other call sites — confirm against their
            # signatures.
            return buffer_pred.if_then_else(f"""
                {type.nonconst.attach_to(tmp_name)};
                {get_buffer(tmp_name, cast_type, original_value, type, original_type=original_type)}
                {for_all_elements(inner_values, cast_type, type, depth=depth, original_type=original_type, **other)}
                """)

        def default_case():
            return Expr(type.transfer).equals("NW_HANDLE").if_then_else(
                # Live handles: map the original handle onto the new local object.
                (~Expr(type.deallocates)).if_then_else(
                    f"nw_handle_pool_assign_handle(handle_pool, (void*){original_value}, (void*){local_value});"
                ),
                # Return values and opaque values must reproduce the original exactly.
                ((Expr(arg.ret) | Expr(type.transfer).equals("NW_OPAQUE")) &
                 Expr(not isinstance(type, FunctionPointer))
                 ).if_then_else(f"assert({original_value} == {local_value});"),
            )

        if type.fields:
            return for_all_elements(values, cast_type, type, depth=depth, original_type=original_type, **other)
        return type.is_simple_buffer().if_then_else(
            simple_buffer_case,
            Expr(type.transfer).equals("NW_BUFFER").if_then_else(
                buffer_case,
                default_case)).scope()

    with location(f"at {term.yellow(str(arg.name))}", arg.location):
        conv = convert_result_value(
            (f"{original}->{arg.param_spelling}", f"{arg.name}"),
            arg.type.nonconst,
            arg.type,
            depth=0,
            name=arg.name,
            kernel=convert_result_value,
            self_index=1,
        )
        return (Expr(arg.output) | arg.ret).if_then_else(
            comment_block(f"Assign or check: {arg}", conv))