예제 #1
0
    def fetch_call_return(self,
                          node: vy_ast.Call) -> Optional[BaseTypeDefinition]:
        """
        Validate a call to this function and return the type of its result.

        Arguments
        ---------
        node : vy_ast.Call
            Call node being type-checked.

        Returns
        -------
        Optional[BaseTypeDefinition]
            `self.return_type`.

        Raises
        ------
        CallViolation
            If an external function is called via `self`, or if `value=` is
            supplied to a function whose mutability is below payable. Errors
            from `validate_call_args` / `validate_expected_type` propagate.
        """
        is_self_call = node.get("func.value.id") == "self"
        if is_self_call and self.visibility == FunctionVisibility.EXTERNAL:
            raise CallViolation("Cannot call external functions via 'self'",
                                node)

        # for external calls, include gas and value as optional kwargs
        kwarg_keys = self.kwarg_keys.copy()
        if not is_self_call:
            kwarg_keys += ["gas", "value"]
        validate_call_args(node, self.arg_count, kwarg_keys)

        if self.mutability < StateMutability.PAYABLE:
            kwarg_node = next((k for k in node.keywords if k.arg == "value"),
                              None)
            if kwarg_node is not None:
                raise CallViolation(
                    "Cannot send ether to nonpayable function", kwarg_node)

        for arg, expected in zip(node.args, self.arguments.values()):
            validate_expected_type(arg, expected)

        for kwarg in node.keywords:
            if kwarg.arg in ("gas", "value"):
                validate_expected_type(kwarg.value, Uint256Definition())
            else:
                # validate_expected_type takes (node, expected_type), as in the
                # positional-argument loop above: check the value node against
                # the declared type of the argument this keyword names.
                # NOTE(review): assumes every name in `self.kwarg_keys` is a key
                # of `self.arguments` and that validate_call_args rejects any
                # keyword outside `kwarg_keys` — confirm.
                validate_expected_type(kwarg.value, self.arguments[kwarg.arg])

        return self.return_type
예제 #2
0
    def fetch_call_return(self,
                          node: vy_ast.Call) -> Optional[BaseTypeDefinition]:
        """
        Validate a call to this function and return the type of its result.

        Arguments
        ---------
        node : vy_ast.Call
            Call node being type-checked.

        Returns
        -------
        Optional[BaseTypeDefinition]
            `self.return_type`.

        Raises
        ------
        CallViolation
            If an external function is called via `self`, or if `value=` is
            supplied to a function whose mutability is below payable.
        InvalidType
            If a call-site kwarg that requires a literal is given a non-literal.
        ArgumentException
            If a keyword argument other than a call-site kwarg is used.
        """
        if node.get(
                "func.value.id"
        ) == "self" and self.visibility == FunctionVisibility.EXTERNAL:
            raise CallViolation("Cannot call external functions via 'self'",
                                node)

        # for external calls, include gas and value as optional kwargs
        kwarg_keys = self.kwarg_keys.copy()
        if node.get("func.value.id") != "self":
            kwarg_keys += list(self.call_site_kwargs.keys())
        validate_call_args(node, (self.min_arg_count, self.max_arg_count),
                           kwarg_keys)

        if self.mutability < StateMutability.PAYABLE:
            kwarg_node = next((k for k in node.keywords if k.arg == "value"),
                              None)
            if kwarg_node is not None:
                raise CallViolation("Cannot send ether to nonpayable function",
                                    kwarg_node)

        for arg, expected in zip(node.args, self.arguments.values()):
            validate_expected_type(arg, expected)

        # TODO this should be moved to validate_call_args
        for kwarg in node.keywords:
            if kwarg.arg in self.call_site_kwargs:
                kwarg_settings = self.call_site_kwargs[kwarg.arg]
                validate_expected_type(kwarg.value, kwarg_settings.typ)
                if kwarg_settings.require_literal:
                    if not isinstance(kwarg.value, vy_ast.Constant):
                        raise InvalidType(
                            f"{kwarg.arg} must be literal {kwarg_settings.typ}",
                            kwarg.value)
            else:
                # Generate the modified source code string with the kwarg removed
                # as a suggestion to the user.
                value_source = kwarg.value.node_source_code
                # `\b` stops e.g. `x=` from matching inside `max=`.
                kwarg_pattern = (rf"\b{re.escape(kwarg.arg)}\s*=\s*"
                                 rf"{re.escape(value_source)}")
                # A callable replacement is inserted verbatim. A plain string
                # replacement would have its backslash escapes processed, which
                # corrupts literals containing `\` or raises `re.error` for
                # sequences such as `\x` (e.g. `b"\x01"`).
                modified_line = re.sub(kwarg_pattern,
                                       lambda _match: value_source,
                                       node.node_source_code)
                error_suggestion = (
                    f"\n(hint: Try removing the kwarg: `{modified_line}`)"
                    if modified_line != node.node_source_code else "")

                raise ArgumentException(
                    ("Usage of kwarg in Vyper is restricted to " +
                     ", ".join([f"{k}="
                                for k in self.call_site_kwargs.keys()]) +
                     f". {error_suggestion}"),
                    kwarg,
                )

        return self.return_type
예제 #3
0
    def fetch_call_return(self,
                          node: vy_ast.Call) -> Optional[BaseTypeDefinition]:
        """
        Type-check a call to this member function and return its result type.

        Only `append` and `pop` on a `DynamicArrayDefinition` are recognized:
        `append` evaluates to `None`, `pop` to the array's `value_type`.

        Arguments
        ---------
        node : vy_ast.Call
            Call node being type-checked.

        Raises
        ------
        CallViolation
            If this member function is not recognized for the underlying type.
            Errors from `validate_call_args` propagate.
        """
        validate_call_args(node, (self.min_arg_count, self.max_arg_count))

        container = self.underlying_type
        is_dyn_array = isinstance(container, DynamicArrayDefinition)

        if is_dyn_array and self.name == "append":
            return None
        if is_dyn_array and self.name == "pop":
            return container.value_type

        raise CallViolation("Function does not exist on given type", node)
예제 #4
0
파일: module.py 프로젝트: fubuloubu/vyper
    def __init__(
        self, module_node: vy_ast.Module, interface_codes: InterfaceDict, namespace: dict
    ) -> None:
        """
        Visit every top-level node of a module, then analyse its internal call graph.

        Top-level nodes are visited in repeated passes until all succeed, so a node may
        depend on something declared later in the module. Afterwards the function
        selectors are checked for collisions, an interface primitive built from the
        module is stored in `module_node._metadata["type"]`, and for every function the
        directly called (`internal_calls`) and transitively reachable (`recursive_calls`)
        `self.<fn>()` targets are recorded on its member in `namespace["self"]`.

        Arguments
        ---------
        module_node : vy_ast.Module
            Top-level node of the module being analysed.
        interface_codes : InterfaceDict
            Stored on `self.interface_codes`; a falsy value is replaced with `{}`.
        namespace : dict
            Namespace to populate; must provide the "self" and "interface" entries.

        Raises
        ------
        InvalidLiteral, InvalidType, VariableDeclarationException
            Re-raised as soon as any top-level node raises one.
        VyperException
            The collected errors (via `ExceptionList.raise_if_not_empty`) when a whole
            pass makes no progress.
        CallViolation
            If a function calls itself, or the contract contains a cycle of calls.
        """
        self.ast = module_node
        self.interface_codes = interface_codes or {}
        self.namespace = namespace

        module_nodes = module_node.body.copy()
        while module_nodes:
            count = len(module_nodes)
            err_list = ExceptionList()
            # iterate over a snapshot because successfully visited nodes are removed
            for node in list(module_nodes):
                try:
                    self.visit(node)
                    module_nodes.remove(node)
                except (InvalidLiteral, InvalidType, VariableDeclarationException):
                    # these exceptions cannot be caused by another statement not yet being
                    # parsed, so we raise them immediately
                    raise
                except VyperException as e:
                    err_list.append(e)

            # Only raise if no nodes were successfully processed. This allows module
            # level logic to parse regardless of the ordering of code elements.
            if count == len(module_nodes):
                err_list.raise_if_not_empty()

        # check for collisions between 4byte function selectors
        # internal functions are intentionally included in this check, to prevent breaking
        # changes in case of a future change to their calling convention
        self_members = namespace["self"].members
        functions = [i for i in self_members.values() if isinstance(i, ContractFunction)]
        validate_unique_method_ids(functions)

        # generate an `InterfacePrimitive` from the top-level node - used for building the ABI
        interface = namespace["interface"].build_primitive_from_node(module_node)
        module_node._metadata["type"] = interface

        # get list of internal function calls made by each function
        function_defs = self.ast.get_children(vy_ast.FunctionDef)
        function_names = {fn_def.name for fn_def in function_defs}
        for node in function_defs:
            calls_to_self = {
                call.func.attr
                for call in node.get_descendants(vy_ast.Call, {"func.value.id": "self"})
            }
            # anything that is not a function call will get semantically checked later
            calls_to_self &= function_names
            self_members[node.name].internal_calls = calls_to_self
            if node.name in calls_to_self:
                self_node = node.get_descendants(
                    vy_ast.Attribute, {"value.id": "self", "attr": node.name}
                )[0]
                raise CallViolation(f"Function '{node.name}' calls into itself", self_node)

        for fn_name in sorted(function_names):

            if fn_name not in self_members:
                # the referenced function does not exist - this is an issue, but we'll report
                # it later when parsing the function so we can give more meaningful output
                continue

            # check for circular function calls
            sequence = _find_cyclic_call([fn_name], self_members)
            if sequence is not None:
                nodes = []
                # one `self.<callee>` node per caller -> callee edge of the cycle
                for caller, callee in zip(sequence, sequence[1:]):
                    fn_node = self.ast.get_children(vy_ast.FunctionDef, {"name": caller})[0]
                    call_node = fn_node.get_descendants(
                        vy_ast.Attribute, {"value.id": "self", "attr": callee}
                    )[0]
                    nodes.append(call_node)

                raise CallViolation("Contract contains cyclic function call", *nodes)

            # get complete list of functions that are reachable from this function;
            # expand one call level at a time until the set stops growing
            function_set = {i for i in self_members[fn_name].internal_calls if i in self_members}
            while True:
                expanded = {x for i in function_set for x in self_members[i].internal_calls}
                expanded |= function_set
                if expanded == function_set:
                    break
                function_set = expanded

            self_members[fn_name].recursive_calls = function_set
예제 #5
0
파일: module.py 프로젝트: Rish001/vyper
    def __init__(
        self,
        module_node: vy_ast.Module,
        interface_codes: InterfaceDict,
        namespace: dict,
    ) -> None:
        """
        Visit every top-level node of a module, then analyse its internal calls.

        Top-level nodes are visited in repeated passes until all succeed, so a
        node may depend on something declared later in the module. Afterwards,
        for every function the directly called (`internal_calls`) and the
        transitively reachable (`recursive_calls`) functions of this module are
        recorded on its member in `namespace["self"]`.

        Arguments
        ---------
        module_node : vy_ast.Module
            Top-level node of the module being analysed.
        interface_codes : InterfaceDict
            Stored on `self.interface_codes`; a falsy value becomes `{}`.
        namespace : dict
            Namespace to populate; must provide the "self" entry.

        Raises
        ------
        VyperException
            The collected errors (via `ExceptionList.raise_if_not_empty`) when
            a whole pass makes no progress.
        CallViolation
            If a function calls itself, or the contract contains a cycle of
            calls.
        """
        self.ast = module_node
        self.interface_codes = interface_codes or {}
        self.namespace = namespace

        module_nodes = module_node.body.copy()
        while module_nodes:
            count = len(module_nodes)
            err_list = ExceptionList()
            # iterate over a snapshot because visited nodes are removed
            for node in list(module_nodes):
                try:
                    self.visit(node)
                    module_nodes.remove(node)
                except VyperException as e:
                    err_list.append(e)

            # Only raise if no nodes were successfully processed. This allows module
            # level logic to parse regardless of the ordering of code elements.
            if count == len(module_nodes):
                err_list.raise_if_not_empty()

        # get list of internal function calls made by each function
        self_members = namespace["self"].members
        function_defs = self.ast.get_children(vy_ast.FunctionDef)
        call_function_names = {fn_def.name for fn_def in function_defs}
        for node in function_defs:
            calls_to_self = {
                call.func.attr for call in node.get_descendants(
                    vy_ast.Call, {"func.value.id": "self"})
            }
            # Keep only calls to functions defined in this module. Any other
            # `self.<name>(...)` is left for later semantic checks; keeping it
            # would make the reachability loop below index `self_members` with
            # a missing name (KeyError) or a non-function member.
            self_members[node.name].internal_calls = (calls_to_self &
                                                      call_function_names)
            if node.name in self_members[node.name].internal_calls:
                self_node = node.get_descendants(vy_ast.Attribute, {
                    "value.id": "self",
                    "attr": node.name
                })[0]
                raise CallViolation(
                    f"Function '{node.name}' calls into itself", self_node)

        for fn_name in sorted(call_function_names):

            if fn_name not in self_members:
                # the referenced function does not exist - this is an issue, but we'll report
                # it later when parsing the function so we can give more meaningful output
                continue

            # check for circular function calls
            sequence = _find_cyclic_call([fn_name], self_members)
            if sequence is not None:
                nodes = []
                # one `self.<callee>` node per caller -> callee edge of the cycle
                for caller, callee in zip(sequence, sequence[1:]):
                    fn_node = self.ast.get_children(vy_ast.FunctionDef,
                                                    {"name": caller})[0]
                    call_node = fn_node.get_descendants(
                        vy_ast.Attribute, {
                            "value.id": "self",
                            "attr": callee
                        })[0]
                    nodes.append(call_node)

                raise CallViolation("Contract contains cyclic function call",
                                    *nodes)

            # get complete list of functions that are reachable from this function;
            # expand one call level at a time until the set stops growing
            function_set = {
                i for i in self_members[fn_name].internal_calls
                if i in self_members
            }
            while True:
                expanded = {
                    x for i in function_set
                    for x in self_members[i].internal_calls
                }
                expanded |= function_set
                if expanded == function_set:
                    break
                function_set = expanded

            self_members[fn_name].recursive_calls = function_set