def get_call_param_name(op_rec):
    """
    Determines the name of the API-method parameter that `op_rec` is passed into.

    The op-rec's `kwarg` (its keyword, or None if it was passed by position) and its `position` are
    matched against the signature described by `op_rec.column.api_method_rec`:
    - keyword naming a regular parameter -> that keyword,
    - any other keyword -> "<kwargs_name>[<keyword>]" (only if the signature has **kwargs),
    - position within the regular parameters -> the matching entry of `input_names`,
    - position past the regular parameters -> "<args_name>[<offset>]" (only if the signature has *args).

    Args:
        op_rec: The op-record to resolve. Its `column.api_method_rec` describes the API-method; its
            `kwarg`, `position` and `id` attributes are read.

    Returns:
        str: The resolved parameter name (with a "[...]" suffix for *args/**kwargs entries).

    Raises:
        RLGraphAPICallParamError: If the op-rec does not fit the API-method's signature.
    """
    api_method_rec = op_rec.column.api_method_rec  # type: APIMethodRecord
    args_name = api_method_rec.args_name
    kwargs_name = api_method_rec.kwargs_name

    # How far the position reaches beyond the regular (non-*args/**kwargs) parameters.
    if op_rec.position is None:
        pos_past_normals = None
    else:
        pos_past_normals = op_rec.position - len(api_method_rec.non_args_kwargs)

    # Passed by keyword: either a regular parameter or an entry of **kwargs.
    if op_rec.kwarg is not None:
        if op_rec.kwarg in api_method_rec.non_args_kwargs:
            return op_rec.kwarg
        if kwargs_name is not None:
            return kwargs_name + "[{}]".format(op_rec.kwarg)
        if args_name is not None:
            raise RLGraphAPICallParamError(
                "ERROR: API-method '{}' has no **kwargs, but op-rec {} indicates that it has kwarg '{}'!".
                format(api_method_rec.name, op_rec.id, op_rec.kwarg)
            )
        raise RLGraphAPICallParamError(
            "Op-rec's kwarg ({}) is not an parameter of API-method {}'s signature!".
            format(op_rec.kwarg, api_method_rec.name)
        )

    # Passed by position beyond the regular parameters: only valid as one of *args.
    if pos_past_normals >= 0:
        if args_name is not None:
            return args_name + "[{}]".format(pos_past_normals)
        if kwargs_name is not None:
            raise RLGraphAPICallParamError(
                "Op-rec '{}' has no kwarg, but its position ({}) indicates that it's part "
                "of {}'s **kwargs!".format(op_rec.id, op_rec.position, api_method_rec.name)
            )
        raise RLGraphAPICallParamError(
            "Op-rec {}'s position ({}) is higher than {}'s number of args!".
            format(op_rec.id, op_rec.position, api_method_rec.name)
        )

    # Passed by position within the regular parameters.
    return api_method_rec.input_names[op_rec.position]
def api_method_wrapper(self, *args, **kwargs):
    """
    Runs a decorated API-method (or an API-method wrapping a graph_fn) on component `self`.

    NOTE(review): this is a closure. `name`, `wrapped_func`, `returns`, `flatten_ops`, `split_ops`,
    `add_auto_key_as_first_param` and `requires_variable_completeness` are not defined here and are
    presumably supplied by the enclosing decorator -- confirm against the decorator definition.

    In "define_by_run" execution mode, the wrapped function is called directly, and the call count and
    runtime are recorded on the component's class (`call_count`, `call_times`).

    Otherwise the call is recorded in the meta-graph. An incoming op-record column is created from
    `args`/`kwargs`, with None values skipped. Known input Spaces are stored in
    `self.api_method_inputs`. The wrapped function is then called, through `graph_fn_wrapper` for
    wrapped graph_fns. Finally, the results are pushed into an outgoing op-record column.

    Args:
        *args: Positional call params (op-records, ops or fixed values). None values are ignored.
        **kwargs: Keyword call params. None values are ignored.

    Returns:
        In "define_by_run" mode: whatever the wrapped function returns.
        Otherwise: a dict (if the API-method returned a dict), a single item (one return value) or a
        tuple. The items are raw ops when the nearest enclosing call up the stack is
        `run_through_graph_fn` (graph_builder.py). They are op-records when the nearest enclosing call
        is another `api_method_wrapper` (decorators.py) or when neither is found.

    Raises:
        RLGraphError: If a dict of DataOpRecords is passed in as an arg, if no caller component can be
            determined, or if the caller component is neither this component nor one of its parents.
        RLGraphAPICallParamError: If too few or too many call params are given.
    """
    api_fn_name = name or re.sub(r'^_graph_fn_', "", wrapped_func.__name__)

    # Direct evaluation of function.
    if self.execution_mode == "define_by_run":
        type(self).call_count += 1
        start = time.perf_counter()
        # Check with owner if extra args needed.
        if api_fn_name in self.api_methods and self.api_methods[api_fn_name].add_auto_key_as_first_param:
            output = wrapped_func(self, "", *args, **kwargs)
        else:
            output = wrapped_func(self, *args, **kwargs)
        # Store runtime for this method.
        type(self).call_times.append(  # Component.call_times
            (self.name, wrapped_func.__name__, time.perf_counter() - start))
        return output

    api_method_rec = self.api_methods[api_fn_name]

    # Sanity check input args for accidental dict-return values being passed into the next API as
    # supposed DataOpRecord. Empty dicts are skipped: `next(iter({}.values()))` would otherwise raise a
    # bare StopIteration here.
    dict_args = [next(iter(a.values())) for a in args if isinstance(a, dict) and len(a) > 0]
    if len(dict_args) > 0 and isinstance(dict_args[0], DataOpRecord):
        raise RLGraphError(
            "One of your input args to API-method '{}.{}()' is a dict of DataOpRecords! This is probably "
            "coming from a previous call to an API-method (returning a dict) and the DataOpRecord should be "
            "extracted by string-key and passed into '{}' "
            "directly.".format(api_method_rec.component.global_scope, api_fn_name, api_fn_name))

    # Create op-record column to call API method with. Ignore None input params. These should not be sent
    # to the API-method.
    in_op_column = DataOpRecordColumnIntoAPIMethod(
        component=self, api_method_rec=api_method_rec, args=args, kwargs=kwargs)
    # Add the column to the API-method record.
    api_method_rec.in_op_columns.append(in_op_column)

    # Check minimum number of passed args (regular params minus those with defaults).
    minimum_num_call_params = len(in_op_column.api_method_rec.non_args_kwargs) - \
        len(in_op_column.api_method_rec.default_args)
    if len(in_op_column.op_records) < minimum_num_call_params:
        raise RLGraphAPICallParamError(
            "Number of call params ({}) for call to API-method '{}' is too low. Needs to be at least {} "
            "params!".format(len(in_op_column.op_records), api_method_rec.name, minimum_num_call_params))

    # Link from incoming op_recs into the new column or populate new column with ops/Spaces (this happens
    # if this call was made from within a graph_fn such that ops and Spaces are already known).
    # Positional params (int keys) come first, then keyword params (str keys) sorted by name; `i` below
    # is the index into this filtered list.
    all_args = [(i, a) for i, a in enumerate(args) if a is not None] + \
        [(k, v) for k, v in sorted(kwargs.items()) if v is not None]
    flex = None  # Index (into `all_args`) of the first var-positional value.
    flex_slot = None  # Index (into `input_names`) of the var-positional parameter itself.
    build_when_done = False
    for i, (key, value) in enumerate(all_args):
        # Named arg/kwarg -> get input_name from that and peel op_rec.
        if isinstance(key, str):
            param_name = key
        # Positional arg -> get input_name from input_names list.
        else:
            # Once inside the var-positional section, all further positional values map to its slot.
            # The slot is tracked separately from `flex` because skipped None args make `i` differ
            # from the actual position `key`.
            slot = key if flex_slot is None else flex_slot
            if slot >= len(api_method_rec.input_names):
                raise RLGraphAPICallParamError(
                    "Too many input args given in call to API-method '{}'!".format(api_method_rec.name))
            param_name = api_method_rec.input_names[slot]

        # Var-positional arg, attach the actual position to input_name string.
        if self.api_method_inputs.get(param_name, "") == "*flex":
            if flex is None:
                flex = i
                if not isinstance(key, str):
                    flex_slot = key
            param_name += "[{}]".format(i - flex)
        # Actual kwarg (not in list of api_method_inputs).
        elif api_method_rec.kwargs_name is not None and param_name not in self.api_method_inputs:
            param_name = api_method_rec.kwargs_name + "[{}]".format(param_name)

        # We are already in building phase (params may be coming from inside graph_fn).
        if self.graph_builder is not None and self.graph_builder.phase == "building":
            # If Space not stored yet, determine it from op.
            assert in_op_column.op_records[i].op is not None
            if in_op_column.op_records[i].space is None:
                in_op_column.op_records[i].space = get_space_from_op(in_op_column.op_records[i].op)
            self.api_method_inputs[param_name] = in_op_column.op_records[i].space
            # Check input-completeness of Component (but not strict as we are only calling API, not a graph_fn).
            if self.input_complete is False:
                # Build right after this loop in case more Space information comes in through next args/kwargs.
                build_when_done = True

        # A DataOpRecord from the meta-graph.
        elif isinstance(value, DataOpRecord):
            # Create entry with unknown Space if it doesn't exist yet.
            if param_name not in self.api_method_inputs:
                self.api_method_inputs[param_name] = None

        # Fixed value (instead of op-record): Store the fixed value directly in the op.
        else:
            if self.api_method_inputs.get(param_name) is None:
                self.api_method_inputs[param_name] = in_op_column.op_records[i].space

    if build_when_done:
        # Check Spaces and create variables.
        self.graph_builder.build_component_when_input_complete(self)

    # Regular API-method: Call it here.
    api_fn_args, api_fn_kwargs = in_op_column.get_args_and_kwargs()

    if api_method_rec.is_graph_fn_wrapper is False:
        return_values = wrapped_func(self, *api_fn_args, **api_fn_kwargs)
    # Wrapped graph_fn: Call it through yet another wrapper.
    else:
        return_values = graph_fn_wrapper(
            self, wrapped_func, returns, dict(
                flatten_ops=flatten_ops,
                split_ops=split_ops,
                add_auto_key_as_first_param=add_auto_key_as_first_param,
                requires_variable_completeness=requires_variable_completeness
            ), *api_fn_args, **api_fn_kwargs
        )

    # Process the results (push into a column).
    out_op_column = DataOpRecordColumnFromAPIMethod(
        component=self,
        api_method_name=api_fn_name,
        args=util.force_tuple(return_values) if type(return_values) != dict else None,
        kwargs=return_values if type(return_values) == dict else None
    )

    # If we already have actual op(s) and Space(s), push them already into the
    # DataOpRecordColumnFromAPIMethod's records.
    if self.graph_builder is not None and self.graph_builder.phase == "building":
        # Link the returned ops to that new out-column.
        # NOTE(review): `rec` is `out_op_column.op_records[i]`, so these are self-assignments; they only
        # have an effect if `op`/`space` are properties with setters -- verify before removing.
        for i, rec in enumerate(out_op_column.op_records):
            out_op_column.op_records[i].op = rec.op
            out_op_column.op_records[i].space = rec.space
    # And append the new out-column to the api-method-rec.
    api_method_rec.out_op_columns.append(out_op_column)

    # Do we need to return the raw ops or the op-recs?
    # Only need to check if False, otherwise, we return ops directly anyway.
    return_ops = False
    # context=0: only frames, file names and function names are used below, so skip reading source lines.
    stack = inspect.stack(0)
    try:
        f_locals = stack[1][0].f_locals
        # We may be in a list comprehension, try next frame.
        if f_locals.get(".0"):
            f_locals = stack[2][0].f_locals
        # Check whether the caller component is a parent of this one.
        caller_component = f_locals.get("root", f_locals.get("self_", f_locals.get("self")))

        # Potential call from a lambda (held in a local named `fn` two frames up).
        if caller_component is None and len(stack) > 2 and "fn" in stack[2][0].f_locals:
            prev_caller_component = TraceContext.PREV_CALLER
            lambda_obj = stack[2][0].f_locals["fn"]
            try:
                lambda_source = inspect.getsource(lambda_obj)
            except (OSError, TypeError):
                # Source not retrievable (e.g. not a Python function): cannot be recognized as a lambda.
                lambda_source = ""
            if "lambda" in lambda_source and prev_caller_component is not None:
                # Try to reconstruct caller by using parent of prior caller.
                caller_component = prev_caller_component.parent_component

        if caller_component is None:
            raise RLGraphError(
                "API-method '{}' must have as 1st parameter (the component) either `root` or `self`. Other names "
                "are not allowed!".format(api_method_rec.name))
        # Not directly called by this method itself (auto-helper-component-API-call).
        # AND call is coming from some caller Component, but that component is not this component
        # OR a parent -> Error.
        elif type(caller_component).__name__ != "MetaGraphBuilder" and \
                caller_component not in [self] + self.get_parents():
            # Calls made from an `__init__` in op_records.py are exempt from the parent check.
            if not (stack[1][3] == "__init__" and re.search(r'op_records\.py$', stack[1][1])):
                raise RLGraphError(
                    "The component '{}' is not a child (or grand-child) of the caller ({})! Maybe you forgot to "
                    "add it as a sub-component via `add_components()`.".
                    format(self.global_scope, caller_component.global_scope))

        # Update trace context.
        TraceContext.PREV_CALLER = caller_component

        # Walk up the stack; the nearest matching enclosing call decides what to return.
        for stack_item in stack[1:]:  # skip current frame
            # If we hit an API-method call -> return op-recs.
            if stack_item[3] == "api_method_wrapper" and re.search(r'decorators\.py$', stack_item[1]):
                break
            # If we hit a graph_fn call -> return ops.
            elif stack_item[3] == "run_through_graph_fn" and re.search(r'graph_builder\.py$', stack_item[1]):
                return_ops = True
                break
    finally:
        # `stack` contains this very frame, whose locals contain `stack`: break that reference cycle now so
        # the frames (and everything their locals reference) are freed deterministically.
        del stack

    if return_ops is True:
        if type(return_values) == dict:
            return {key: value.op for key, value in out_op_column.get_args_and_kwargs()[1].items()}
        else:
            tuple_returns = tuple(map(lambda x: x.op, out_op_column.get_args_and_kwargs()[0]))
            return tuple_returns[0] if len(tuple_returns) == 1 else tuple_returns
    # Parent caller is non-graph_fn: Return op-recs.
    else:
        if type(return_values) == dict:
            return return_values
        else:
            tuple_returns = out_op_column.get_args_and_kwargs()[0]
            return tuple_returns[0] if len(tuple_returns) == 1 else tuple_returns
def api_method_wrapper(self, *args, **kwargs):
    """
    Runs a decorated API-method (or an API-method wrapping a graph_fn) on component `self`.

    NOTE(review): this is a closure. `name`, `wrapped_func`, `returns`, `flatten_ops`, `split_ops` and
    `add_auto_key_as_first_param` are not defined here and are presumably supplied by the enclosing
    decorator -- confirm against the decorator definition.

    In "define_by_run" execution mode, the wrapped function is called directly, and the call count and
    runtime are recorded on the component's class (`call_count`, `call_times`).

    Otherwise the call is recorded in the meta-graph. An incoming op-record column is created from
    `args`/`kwargs`, with None values skipped. Known input Spaces are stored in
    `self.api_method_inputs`. The wrapped function is then called, through `graph_fn_wrapper` for
    wrapped graph_fns. Finally, the results are pushed into an outgoing op-record column.

    Args:
        *args: Positional call params (op-records, ops or fixed values). None values are ignored.
        **kwargs: Keyword call params. None values are ignored. The special key `return_ops` is popped
            off and not passed on; if it is True, raw ops are returned instead of op-records.

    Returns:
        In "define_by_run" mode: whatever the wrapped function returns.
        Otherwise: a dict (if the API-method returned a dict), a single item (one return value) or a
        tuple. The items are raw ops if `return_ops` is True or the direct caller is a function named
        `_graph_fn_...`, and op-records otherwise.

    Raises:
        RLGraphAPICallParamError: If too few or too many call params are given.
    """
    name_ = name or re.sub(r'^_graph_fn_', "", wrapped_func.__name__)
    return_ops = kwargs.pop("return_ops", False)

    # Direct evaluation of function.
    if self.execution_mode == "define_by_run":
        type(self).call_count += 1
        start = time.perf_counter()
        # Check with owner if extra args needed.
        if name_ in self.api_methods and self.api_methods[name_].add_auto_key_as_first_param:
            output = wrapped_func(self, "", *args, **kwargs)
        else:
            output = wrapped_func(self, *args, **kwargs)
        # Store runtime for this method.
        type(self).call_times.append(  # Component.call_times
            (self.name, wrapped_func.__name__, time.perf_counter() - start)
        )
        return output

    api_method_rec = self.api_methods[name_]

    # Create op-record column to call API method with. Ignore None input params. These should not be sent
    # to the API-method.
    in_op_column = DataOpRecordColumnIntoAPIMethod(
        component=self, api_method_rec=api_method_rec, args=args, kwargs=kwargs
    )
    # Add the column to the API-method record.
    api_method_rec.in_op_columns.append(in_op_column)

    # Check minimum number of passed args (regular params minus those with defaults).
    minimum_num_call_params = len(in_op_column.api_method_rec.non_args_kwargs) - \
        len(in_op_column.api_method_rec.default_args)
    if len(in_op_column.op_records) < minimum_num_call_params:
        raise RLGraphAPICallParamError(
            "Number of call params ({}) for call to API-method '{}' is too low. Needs to be at least {} "
            "params!".format(len(in_op_column.op_records), api_method_rec.name, minimum_num_call_params)
        )

    # Link from incoming op_recs into the new column or populate new column with ops/Spaces (this happens
    # if this call was made from within a graph_fn such that ops and Spaces are already known).
    # Positional params (int keys) come first, then keyword params (str keys) sorted by name; `i` below
    # is the index into this filtered list.
    all_args = [(i, a) for i, a in enumerate(args) if a is not None] + \
        [(k, v) for k, v in sorted(kwargs.items()) if v is not None]
    flex = None  # Index (into `all_args`) of the first var-positional value.
    flex_slot = None  # Index (into `input_names`) of the var-positional parameter itself.
    for i, (key, value) in enumerate(all_args):
        # Named arg/kwarg -> get input_name from that and peel op_rec.
        if isinstance(key, str):
            param_name = key
        # Positional arg -> get input_name from input_names list.
        else:
            # Once inside the var-positional section, all further positional values map to its slot.
            # The slot is tracked separately from `flex` because skipped None args make `i` differ
            # from the actual position `key`.
            slot = key if flex_slot is None else flex_slot
            if slot >= len(api_method_rec.input_names):
                raise RLGraphAPICallParamError(
                    "Too many input args given in call to API-method '{}'!".format(api_method_rec.name)
                )
            param_name = api_method_rec.input_names[slot]

        # Var-positional arg, attach the actual position to input_name string.
        # `.get` because keyword params may not be registered in `api_method_inputs` yet (handled below).
        if self.api_method_inputs.get(param_name, "") == "*flex":
            if flex is None:
                flex = i
                if not isinstance(key, str):
                    flex_slot = key
            param_name += "[{}]".format(i - flex)

        # We are already in building phase (params may be coming from inside graph_fn).
        if self.graph_builder is not None and self.graph_builder.phase == "building":
            self.api_method_inputs[param_name] = in_op_column.op_records[i].space
            # Check input-completeness of Component (but not strict as we are only calling API, not a graph_fn).
            if self.input_complete is False:
                # Check Spaces and create variables.
                self.graph_builder.build_component_when_input_complete(self)
        # A DataOpRecord from the meta-graph.
        elif isinstance(value, DataOpRecord):
            # Create entry with unknown Space if it doesn't exist yet.
            if param_name not in self.api_method_inputs:
                self.api_method_inputs[param_name] = None
        # Fixed value (instead of op-record): Store the fixed value directly in the op.
        else:
            if self.api_method_inputs.get(param_name) is None:
                self.api_method_inputs[param_name] = in_op_column.op_records[i].space

    # Regular API-method: Call it here.
    args_, kwargs_ = in_op_column.get_args_and_kwargs()

    if api_method_rec.is_graph_fn_wrapper is False:
        return_values = wrapped_func(self, *args_, **kwargs_)
    # Wrapped graph_fn: Call it through yet another wrapper.
    else:
        return_values = graph_fn_wrapper(
            self, wrapped_func, returns, dict(
                flatten_ops=flatten_ops,
                split_ops=split_ops,
                add_auto_key_as_first_param=add_auto_key_as_first_param
            ), *args_, **kwargs_
        )

    # Process the results (push into a column).
    out_op_column = DataOpRecordColumnFromAPIMethod(
        component=self,
        api_method_name=name_,
        args=util.force_tuple(return_values) if type(return_values) != dict else None,
        kwargs=return_values if type(return_values) == dict else None
    )

    # If we already have actual op(s) and Space(s), push them already into the
    # DataOpRecordColumnFromAPIMethod's records.
    if self.graph_builder is not None and self.graph_builder.phase == "building":
        # Link the returned ops to that new out-column.
        # NOTE(review): `rec` is `out_op_column.op_records[i]`, so these are self-assignments; they only
        # have an effect if `op`/`space` are properties with setters -- verify before removing.
        for i, rec in enumerate(out_op_column.op_records):
            out_op_column.op_records[i].op = rec.op
            out_op_column.op_records[i].space = rec.space
    # And append the new out-column to the api-method-rec.
    api_method_rec.out_op_columns.append(out_op_column)

    # Do we need to return the raw ops or the op-recs?
    # Direct parent caller is a `_graph_fn_...`: Return raw ops.
    if return_ops is not True:
        # Only the direct caller's function name is needed: read it off the caller frame instead of
        # building the whole `inspect.stack()`, which also loads source lines for every frame.
        frame = inspect.currentframe()
        try:
            caller_frame = frame.f_back if frame is not None else None
            caller_name = caller_frame.f_code.co_name if caller_frame is not None else ""
        finally:
            # Break the frame -> locals -> frame reference cycle right away.
            del frame
        return_ops = re.match(r'^_graph_fn_.+$', caller_name) is not None

    if return_ops is True:
        if type(return_values) == dict:
            return {key: value.op for key, value in out_op_column.get_args_and_kwargs()[1].items()}
        else:
            tuple_ = tuple(map(lambda x: x.op, out_op_column.get_args_and_kwargs()[0]))
            return tuple_[0] if len(tuple_) == 1 else tuple_
    # Parent caller is non-graph_fn: Return op-recs.
    else:
        if type(return_values) == dict:
            return return_values
        else:
            tuple_ = out_op_column.get_args_and_kwargs()[0]
            return tuple_[0] if len(tuple_) == 1 else tuple_