Ejemplo n.º 1
0
def get_call_param_name(op_rec):
    """
    Returns the name of the API-method parameter that the given op-record is passed into.

    The returned name is one of:
    - a "normal" parameter name: `op_rec.kwarg` itself if it is listed in the method's `non_args_kwargs`,
      or `input_names[op_rec.position]` for a positional op-rec within the normal args,
    - `<args_name>[i]` for the i-th positional value past the normal args (if the method has *args),
    - `<kwargs_name>[key]` for a keyword not among the normal args (if the method has **kwargs).

    Args:
        op_rec: The op-record to look up. `op_rec.column.api_method_rec` describes the called API-method's
            signature; `op_rec.kwarg` (keyword name or None) and `op_rec.position` (positional index or None)
            describe how the op-rec was passed.

    Returns:
        str: The parameter name as described above.

    Raises:
        RLGraphAPICallParamError: If the op-rec cannot be mapped onto the API-method's signature (unknown
            keyword and no **kwargs, too many positional args and no *args, or neither a keyword nor a
            position given).
    """
    api_method_rec = op_rec.column.api_method_rec  # type: APIMethodRecord
    kwarg = op_rec.kwarg

    # op_rec has a name -> Can only be a normal arg or one of **kwargs (if any).
    if kwarg is not None:
        if kwarg in api_method_rec.non_args_kwargs:
            return kwarg
        if api_method_rec.kwargs_name is not None:
            return api_method_rec.kwargs_name + "[{}]".format(kwarg)
        # Unknown keyword and no **kwargs to absorb it -> ERROR (message depends on whether *args exist).
        if api_method_rec.args_name is not None:
            raise RLGraphAPICallParamError(
                "ERROR: API-method '{}' has no **kwargs, but op-rec {} indicates that it has kwarg '{}'!".
                format(api_method_rec.name, op_rec.id, kwarg)
            )
        raise RLGraphAPICallParamError(
            "Op-rec's kwarg ({}) is not an parameter of API-method {}'s signature!".
            format(kwarg, api_method_rec.name)
        )

    # op_rec has no name -> Can only be a normal arg or one of *args (if any).
    # Without a position, there is nothing to map it onto. Comparing a None position would
    # otherwise raise a bare TypeError.
    if op_rec.position is None:
        raise RLGraphAPICallParamError(
            "Op-rec {} has neither a kwarg nor a position, so it cannot be mapped onto a parameter of "
            "API-method {}!".format(op_rec.id, api_method_rec.name)
        )

    pos_past_normals = op_rec.position - len(api_method_rec.non_args_kwargs)
    # Normal arg (by position).
    if pos_past_normals < 0:
        return api_method_rec.input_names[op_rec.position]
    # Position is higher than number of "normal" args -> must be one of *args ...
    if api_method_rec.args_name is not None:
        return api_method_rec.args_name + "[{}]".format(pos_past_normals)
    # ... or, if there are no *args, an ERROR.
    if api_method_rec.kwargs_name is not None:
        raise RLGraphAPICallParamError(
            "Op-rec '{}' has no kwarg, but its position ({}) indicates that it's part "
            "of {}'s **kwargs!".format(op_rec.id, op_rec.position, api_method_rec.name)
        )
    raise RLGraphAPICallParamError(
        "Op-rec {}'s position ({}) is higher than {}'s number of args!".
        format(op_rec.id, op_rec.position, api_method_rec.name)
    )
Ejemplo n.º 2
0
        def api_method_wrapper(self, *args, **kwargs):
            # Wrapper around an API-method (or a graph_fn exposed as an API-method) of a Component.
            # - In "define_by_run" execution mode: calls `wrapped_func` directly (prepending "" if the
            #   API-method record requests an auto key as first param) and records call count and runtime.
            # - Otherwise: registers an in-column of op-records for this call, updates
            #   `self.api_method_inputs` with the (possibly still unknown) input Spaces, calls the
            #   API-method (or the graph_fn via `graph_fn_wrapper`), registers an out-column, and returns
            #   raw ops if a `run_through_graph_fn` frame (graph_builder.py) is found on the call stack
            #   before any other `api_method_wrapper` frame, op-records otherwise.
            # Raises:
            #     RLGraphError: If a dict of DataOpRecords is passed in as an arg, if the caller Component
            #         cannot be determined, or if the caller is neither this Component nor one of its parents.
            #     RLGraphAPICallParamError: If too few or too many call params are given.
            # NOTE: The caller detection below indexes into `inspect.stack()` by depth, so it must stay
            # directly in this function (moving it into a helper would shift the frame indices).
            api_fn_name = name or re.sub(r'^_graph_fn_', "",
                                         wrapped_func.__name__)
            # Direct evaluation of function.
            if self.execution_mode == "define_by_run":
                type(self).call_count += 1

                start = time.perf_counter()
                # Check with owner if extra args needed.
                if api_fn_name in self.api_methods and self.api_methods[
                        api_fn_name].add_auto_key_as_first_param:
                    output = wrapped_func(self, "", *args, **kwargs)
                else:
                    output = wrapped_func(self, *args, **kwargs)

                # Store runtime for this method.
                type(self).call_times.append(  # Component.call_times
                    (self.name, wrapped_func.__name__,
                     time.perf_counter() - start))
                return output

            api_method_rec = self.api_methods[api_fn_name]

            # Sanity check input args for accidental dict-return values being passed into the next API as
            # supposed DataOpRecord. Only the first value of the first dict arg is inspected.
            # The `None` default keeps an empty dict arg from raising StopIteration here.
            dict_args = [
                next(iter(a.values()), None) for a in args if isinstance(a, dict)
            ]
            if len(dict_args) > 0 and isinstance(dict_args[0], DataOpRecord):
                raise RLGraphError(
                    "One of your input args to API-method '{}.{}()' is a dict of DataOpRecords! This is probably "
                    "coming from a previous call to an API-method (returning a dict) and the DataOpRecord should be "
                    "extracted by string-key and passed into '{}' "
                    "directly.".format(api_method_rec.component.global_scope,
                                       api_fn_name, api_fn_name))
            # Create op-record column to call API method with. Ignore None input params. These should not be sent
            # to the API-method.
            in_op_column = DataOpRecordColumnIntoAPIMethod(
                component=self,
                api_method_rec=api_method_rec,
                args=args,
                kwargs=kwargs)
            # Add the column to the API-method record.
            api_method_rec.in_op_columns.append(in_op_column)

            # Check minimum number of passed args.
            minimum_num_call_params = len(in_op_column.api_method_rec.non_args_kwargs) - \
                len(in_op_column.api_method_rec.default_args)
            if len(in_op_column.op_records) < minimum_num_call_params:
                raise RLGraphAPICallParamError(
                    "Number of call params ({}) for call to API-method '{}' is too low. Needs to be at least {} "
                    "params!".format(len(in_op_column.op_records),
                                     api_method_rec.name,
                                     minimum_num_call_params))

            # Link from incoming op_recs into the new column or populate new column with ops/Spaces (this happens
            # if this call was made from within a graph_fn such that ops and Spaces are already known).
            # None-valued args/kwargs are skipped, so index `i` below counts only the non-None params
            # (kwargs in sorted-key order after the positional args).
            all_args = [(i, a) for i, a in enumerate(args) if a is not None] + \
                       [(k, v) for k, v in sorted(kwargs.items()) if v is not None]
            flex = None
            build_when_done = False
            for i, (key, value) in enumerate(all_args):
                # Named arg/kwarg -> get input_name from that and peel op_rec.
                if isinstance(key, str):
                    param_name = key
                # Positional arg -> get input_name from input_names list.
                else:
                    # Once a "*flex" (var-positional) slot was hit, all further positional args map onto it.
                    slot = key if flex is None else flex
                    if slot >= len(api_method_rec.input_names):
                        raise RLGraphAPICallParamError(
                            "Too many input args given in call to API-method '{}'!"
                            .format(api_method_rec.name))
                    param_name = api_method_rec.input_names[slot]

                # Var-positional arg, attach the actual position to input_name string.
                if self.api_method_inputs.get(param_name, "") == "*flex":
                    if flex is None:
                        flex = i
                    param_name += "[{}]".format(i - flex)
                # Actual kwarg (not in list of api_method_inputs).
                elif api_method_rec.kwargs_name is not None and param_name not in self.api_method_inputs:
                    param_name = api_method_rec.kwargs_name + "[{}]".format(
                        param_name)

                # We are already in building phase (params may be coming from inside graph_fn).
                if self.graph_builder is not None and self.graph_builder.phase == "building":
                    # If Space not stored yet, determine it from op.
                    assert in_op_column.op_records[i].op is not None
                    if in_op_column.op_records[i].space is None:
                        in_op_column.op_records[i].space = get_space_from_op(
                            in_op_column.op_records[i].op)
                    self.api_method_inputs[
                        param_name] = in_op_column.op_records[i].space
                    # Check input-completeness of Component (but not strict as we are only calling API, not a graph_fn).
                    if self.input_complete is False:
                        # Build right after this loop in case more Space information comes in through next args/kwargs.
                        build_when_done = True

                # A DataOpRecord from the meta-graph.
                elif isinstance(value, DataOpRecord):
                    # Create entry with unknown Space if it doesn't exist yet.
                    if param_name not in self.api_method_inputs:
                        self.api_method_inputs[param_name] = None

                # Fixed value (instead of op-record): Store the fixed value directly in the op.
                else:
                    if self.api_method_inputs.get(param_name) is None:
                        self.api_method_inputs[
                            param_name] = in_op_column.op_records[i].space

            if build_when_done:
                # Check Spaces and create variables.
                self.graph_builder.build_component_when_input_complete(self)

            # Regular API-method: Call it here.
            api_fn_args, api_fn_kwargs = in_op_column.get_args_and_kwargs()

            if api_method_rec.is_graph_fn_wrapper is False:
                return_values = wrapped_func(self, *api_fn_args,
                                             **api_fn_kwargs)
            # Wrapped graph_fn: Call it through yet another wrapper.
            else:
                return_values = graph_fn_wrapper(
                    self, wrapped_func, returns,
                    dict(
                        flatten_ops=flatten_ops,
                        split_ops=split_ops,
                        add_auto_key_as_first_param=add_auto_key_as_first_param,
                        requires_variable_completeness=
                        requires_variable_completeness), *api_fn_args,
                    **api_fn_kwargs)

            # Process the results (push into a column).
            out_op_column = DataOpRecordColumnFromAPIMethod(
                component=self,
                api_method_name=api_fn_name,
                args=util.force_tuple(return_values)
                if type(return_values) != dict else None,
                kwargs=return_values if type(return_values) == dict else None)

            # If we already have actual op(s) and Space(s), push them already into the
            # DataOpRecordColumnFromAPIMethod's records.
            if self.graph_builder is not None and self.graph_builder.phase == "building":
                # Link the returned ops to that new out-column.
                # NOTE(review): `rec` is `out_op_column.op_records[i]`, so these assignments look like
                # self-assignments (no-ops) - confirm against DataOpRecord before removing.
                for i, rec in enumerate(out_op_column.op_records):
                    out_op_column.op_records[i].op = rec.op
                    out_op_column.op_records[i].space = rec.space
            # And append the new out-column to the api-method-rec.
            api_method_rec.out_op_columns.append(out_op_column)

            # Do we need to return the raw ops or the op-recs?
            # Only need to check if False, otherwise, we return ops directly anyway.
            return_ops = False
            # context=0: only frame objects, filenames and function names are used below, so don't
            # read a source line for every frame on the stack (which the default context=1 does).
            stack = inspect.stack(0)
            f_locals = stack[1][0].f_locals
            # We may be in a list comprehension, try next frame.
            if f_locals.get(".0"):
                f_locals = stack[2][0].f_locals
            # Check whether the caller component is a parent of this one.
            caller_component = f_locals.get(
                "root", f_locals.get("self_", f_locals.get("self")))

            # Potential call from a lambda (a caller frame two levels up may not exist, e.g. for
            # module-level calls).
            if caller_component is None and len(stack) > 2 and "fn" in stack[2][0].f_locals:
                # This is the component.
                prev_caller_component = TraceContext.PREV_CALLER
                lambda_obj = stack[2][0].f_locals["fn"]
                if prev_caller_component is not None and "lambda" in inspect.getsource(lambda_obj):
                    # Try to reconstruct caller by using parent of prior caller.
                    caller_component = prev_caller_component.parent_component

            if caller_component is None:
                raise RLGraphError(
                    "API-method '{}' must have as 1st parameter (the component) either `root` or `self`. Other names "
                    "are not allowed!".format(api_method_rec.name))
            # Not directly called by this method itself (auto-helper-component-API-call).
            # AND call is coming from some caller Component, but that component is not this component
            # OR a parent -> Error.
            elif type(caller_component).__name__ != "MetaGraphBuilder" and \
                    caller_component not in [self] + self.get_parents():
                # Calls made from `__init__` in op_records.py are exempt from this check.
                if not (stack[1][3] == "__init__"
                        and re.search(r'op_records\.py$', stack[1][1])):
                    raise RLGraphError(
                        "The component '{}' is not a child (or grand-child) of the caller ({})! Maybe you forgot to "
                        "add it as a sub-component via `add_components()`.".
                        format(self.global_scope,
                               caller_component.global_scope))

            # Update trace context.
            TraceContext.PREV_CALLER = caller_component

            for stack_item in stack[1:]:  # skip current frame
                # If we hit an API-method call -> return op-recs.
                if stack_item[3] == "api_method_wrapper" and re.search(
                        r'decorators\.py$', stack_item[1]):
                    break
                # If we hit a graph_fn call -> return ops.
                elif stack_item[3] == "run_through_graph_fn" and re.search(
                        r'graph_builder\.py$', stack_item[1]):
                    return_ops = True
                    break

            if return_ops is True:
                if type(return_values) == dict:
                    return {
                        key: value.op
                        for key, value in out_op_column.get_args_and_kwargs()
                        [1].items()
                    }
                else:
                    tuple_returns = tuple(
                        map(lambda x: x.op,
                            out_op_column.get_args_and_kwargs()[0]))
                    return tuple_returns[0] if len(
                        tuple_returns) == 1 else tuple_returns
            # Parent caller is non-graph_fn: Return op-recs.
            else:
                if type(return_values) == dict:
                    return return_values
                else:
                    tuple_returns = out_op_column.get_args_and_kwargs()[0]
                    return tuple_returns[0] if len(
                        tuple_returns) == 1 else tuple_returns
Ejemplo n.º 3
0
        def api_method_wrapper(self, *args, **kwargs):
            # Wrapper around an API-method (or a graph_fn exposed as an API-method) of a Component.
            # An optional `return_ops` kwarg is popped off (not forwarded) and forces raw ops to be returned.
            # - In "define_by_run" execution mode: calls `wrapped_func` directly (prepending "" if the
            #   API-method record requests an auto key as first param) and records call count and runtime.
            # - Otherwise: registers an in-column of op-records for this call, updates
            #   `self.api_method_inputs` with the (possibly still unknown) input Spaces, calls the
            #   API-method (or the graph_fn via `graph_fn_wrapper`), registers an out-column, and returns
            #   raw ops if `return_ops` is True or the direct caller is a `_graph_fn_...` function,
            #   op-records otherwise.
            # Raises:
            #     RLGraphAPICallParamError: If too few or too many call params are given.
            name_ = name or re.sub(r'^_graph_fn_', "", wrapped_func.__name__)

            return_ops = kwargs.pop("return_ops", False)

            # Direct evaluation of function.
            if self.execution_mode == "define_by_run":
                type(self).call_count += 1

                start = time.perf_counter()
                # Check with owner if extra args needed.
                if name_ in self.api_methods and self.api_methods[name_].add_auto_key_as_first_param:
                    output = wrapped_func(self, "", *args, **kwargs)
                else:
                    output = wrapped_func(self, *args, **kwargs)

                # Store runtime for this method.
                type(self).call_times.append(  # Component.call_times
                    (self.name, wrapped_func.__name__, time.perf_counter() - start)
                )
                return output

            api_method_rec = self.api_methods[name_]

            # Create op-record column to call API method with. Ignore None input params. These should not be sent
            # to the API-method.
            in_op_column = DataOpRecordColumnIntoAPIMethod(
                component=self, api_method_rec=api_method_rec, args=args, kwargs=kwargs
            )
            # Add the column to the API-method record.
            api_method_rec.in_op_columns.append(in_op_column)

            # Check minimum number of passed args.
            minimum_num_call_params = len(in_op_column.api_method_rec.non_args_kwargs) - \
                len(in_op_column.api_method_rec.default_args)
            if len(in_op_column.op_records) < minimum_num_call_params:
                raise RLGraphAPICallParamError(
                    "Number of call params ({}) for call to API-method '{}' is too low. Needs to be at least {} "
                    "params!".format(len(in_op_column.op_records), api_method_rec.name, minimum_num_call_params)
                )

            # Link from incoming op_recs into the new column or populate new column with ops/Spaces (this happens
            # if this call was made from within a graph_fn such that ops and Spaces are already known).
            # None-valued args/kwargs are skipped, so index `i` below counts only the non-None params
            # (kwargs in sorted-key order after the positional args).
            all_args = [(i, a) for i, a in enumerate(args) if a is not None] + \
                       [(k, v) for k, v in sorted(kwargs.items()) if v is not None]
            flex = None
            build_when_done = False
            for i, (key, value) in enumerate(all_args):
                # Named arg/kwarg -> get input_name from that and peel op_rec.
                if isinstance(key, str):
                    param_name = key
                # Positional arg -> get input_name from input_names list.
                else:
                    # Once a "*flex" (var-positional) slot was hit, all further positional args map onto it.
                    slot = key if flex is None else flex
                    if slot >= len(api_method_rec.input_names):
                        raise RLGraphAPICallParamError(
                            "Too many input args given in call to API-method '{}'!".format(api_method_rec.name)
                        )
                    param_name = api_method_rec.input_names[slot]

                # Var-positional arg, attach the actual position to input_name string.
                # `.get()`: a kwarg that is not (yet) a registered input must not raise a KeyError here;
                # it gets registered further below.
                if self.api_method_inputs.get(param_name) == "*flex":
                    if flex is None:
                        flex = i
                    param_name += "[{}]".format(i - flex)

                # We are already in building phase (params may be coming from inside graph_fn).
                if self.graph_builder is not None and self.graph_builder.phase == "building":
                    self.api_method_inputs[param_name] = in_op_column.op_records[i].space
                    # Check input-completeness of Component (but not strict as we are only calling API, not a graph_fn).
                    if self.input_complete is False:
                        # Build right after this loop in case more Space information comes in through next args/kwargs.
                        build_when_done = True

                # A DataOpRecord from the meta-graph.
                elif isinstance(value, DataOpRecord):
                    # Create entry with unknown Space if it doesn't exist yet.
                    if param_name not in self.api_method_inputs:
                        self.api_method_inputs[param_name] = None

                # Fixed value (instead of op-record): Store the fixed value directly in the op.
                else:
                    if self.api_method_inputs.get(param_name) is None:
                        self.api_method_inputs[param_name] = in_op_column.op_records[i].space

            if build_when_done:
                # Check Spaces and create variables.
                self.graph_builder.build_component_when_input_complete(self)

            # Regular API-method: Call it here.
            args_, kwargs_ = in_op_column.get_args_and_kwargs()

            if api_method_rec.is_graph_fn_wrapper is False:
                return_values = wrapped_func(self, *args_, **kwargs_)
            # Wrapped graph_fn: Call it through yet another wrapper.
            else:
                return_values = graph_fn_wrapper(
                    self, wrapped_func, returns, dict(
                        flatten_ops=flatten_ops, split_ops=split_ops,
                        add_auto_key_as_first_param=add_auto_key_as_first_param
                    ), *args_, **kwargs_
                )

            # Process the results (push into a column).
            out_op_column = DataOpRecordColumnFromAPIMethod(
                component=self,
                api_method_name=name_,
                args=util.force_tuple(return_values) if type(return_values) != dict else None,
                kwargs=return_values if type(return_values) == dict else None
            )

            # If we already have actual op(s) and Space(s), push them already into the
            # DataOpRecordColumnFromAPIMethod's records.
            if self.graph_builder is not None and self.graph_builder.phase == "building":
                # Link the returned ops to that new out-column.
                # NOTE(review): `rec` is `out_op_column.op_records[i]`, so these assignments look like
                # self-assignments (no-ops) - confirm against DataOpRecord before removing.
                for i, rec in enumerate(out_op_column.op_records):
                    out_op_column.op_records[i].op = rec.op
                    out_op_column.op_records[i].space = rec.space
            # And append the new out-column to the api-method-rec.
            api_method_rec.out_op_columns.append(out_op_column)

            # Do we need to return the raw ops or the op-recs?
            # Direct parent caller is a `_graph_fn_...`: Return raw ops.
            # context=0: only the caller's function name is needed, so skip reading source lines per frame.
            stack = inspect.stack(0)
            if return_ops is True or re.match(r'^_graph_fn_.+$', stack[1][3]):
                if type(return_values) == dict:
                    return {key: value.op for key, value in out_op_column.get_args_and_kwargs()[1].items()}
                else:
                    tuple_ = tuple(map(lambda x: x.op, out_op_column.get_args_and_kwargs()[0]))
                    return tuple_[0] if len(tuple_) == 1 else tuple_
            # Parent caller is non-graph_fn: Return op-recs.
            else:
                if type(return_values) == dict:
                    return return_values
                else:
                    tuple_ = out_op_column.get_args_and_kwargs()[0]
                    return tuple_[0] if len(tuple_) == 1 else tuple_