예제 #1
0
    def visit_for(self, node: astroid.For):
        """
        Report loops that assign to their own loop targets while iterating over a dataframe call.

        The loop is only inspected when its iterable is a Call node whose recorded type in
        ``self._call_types`` is a pandas or pyspark DataFrame.

        :param node: The For node being visited.
        """
        try:
            iterated = node.iter

            # Guard clauses: bail out unless the iterable is a call made on a dataframe.
            if not isinstance(iterated, astroid.Call):
                return
            if iterated not in self._call_types:
                return
            receiver_type = self._call_types[iterated]
            if not (
                receiver_type == "pandas.core.frame.DataFrame"
                or receiver_type == "pyspark.sql.dataframe.DataFrame"
            ):
                return

            loop_targets = DataFrameChecker._get_for_targets(node)
            assigned_names = AssignUtil.get_assigned_target_names(node)

            # One message per loop is enough, so stop at the first reassigned loop target.
            for assigned_name in assigned_names:
                if assigned_name in loop_targets:
                    self.add_message("dataframe-iteration-modification", node=node)
                    break
        except:  # pylint: disable=bare-except
            ExceptionHandler.handle(self, node)
예제 #2
0
    def visit_call(self, node: astroid.Call):
        """
        When a Call node is visited, check whether hyperparameters are set.

        Only calls whose function name appears in ``Resources.get_hyperparameters()`` are checked.

        In strict mode, the call is checked against the full hyperparameter resource.
        In non-strict mode, a call to a function listed in ``HYPERPARAMETERS_MAIN`` is checked
        against ``HYPERPARAMETERS_MAIN``; any other known function only has to be called with at
        least one positional or keyword argument.

        :param node: Node which is visited.
        """
        try:
            try:
                function_name = node.func.name
            except AttributeError:
                # The callee exposes no ``name`` attribute, so it cannot be looked up.
                return

            hyperparams_all = Resources.get_hyperparameters()
            if function_name not in hyperparams_all:  # pylint: disable=unsupported-membership-test
                return

            if self.config.strict_hyperparameters:
                if not HyperparameterChecker._has_required_hyperparameters(node, hyperparams_all):
                    self.add_message("hyperparameters", node=node)
            else:  # non-strict
                if (
                    function_name in self.HYPERPARAMETERS_MAIN
                    and not HyperparameterChecker._has_required_hyperparameters(node, self.HYPERPARAMETERS_MAIN)
                ):
                    self.add_message("hyperparameters", node=node)
                # Truthiness instead of ``is None``: depending on the astroid version, a call
                # without keyword arguments stores either None or an empty list in ``keywords``.
                elif not node.args and not node.keywords:
                    self.add_message("hyperparameters", node=node)
        except:  # pylint: disable=bare-except
            ExceptionHandler.handle(self, node)
예제 #3
0
    def visit_module(self, node: astroid.Module):
        """
        When a Module node is visited, record type information for the Call nodes it contains.

        The result of ``TypeInference.infer_types`` is stored in ``self._call_types``.

        :param node: Node which is visited.
        """

        def receiver_name(call):
            # Name of the expression the called attribute belongs to (``df`` in ``df.foo()``).
            return call.func.expr.name

        try:
            # noinspection PyTypeChecker
            inferred = TypeInference.infer_types(node, astroid.Call, receiver_name)
            self._call_types = inferred
        except:  # pylint: disable=bare-except
            ExceptionHandler.handle(self, node)
예제 #4
0
파일: nan.py 프로젝트: MarkHaakman/dslinter
    def visit_compare(self, node: astroid.Compare):
        """
        When a compare node is visited, check whether a comparison is done with np.nan.

        Every operand of the comparison is inspected, so chained comparisons such as
        ``a == b == np.nan`` are covered as well. At most one message is added per node.

        :param node: Node which is visited.
        """
        try:
            # ``node.ops`` holds (operator, operand) pairs; collect the left side plus each operand.
            operands = [node.left]
            operands.extend(operand for _, operand in node.ops)

            for side in operands:
                if (
                    isinstance(side, astroid.Attribute)
                    and side.attrname == "nan"
                    # getattr: the expression before ``.nan`` may not be a plain name
                    # (e.g. ``a.b.nan``); such operands are simply not a match.
                    and getattr(side.expr, "name", None) == "np"
                ):
                    self.add_message("nan-equality", node=node)
                    return
        except:  # pylint: disable=bare-except
            ExceptionHandler.handle(self, node)
예제 #5
0
    def visit_import_from(self, node: astroid.ImportFrom):
        """
        When an ImportFrom node is visited, check if it follows the conventions.

        A message is added for every name imported from ``sklearn`` (or one of its submodules)
        that is given an alias.

        :param node: Node which is visited.
        """
        try:
            modname = node.modname
            # Match the package itself or a dotted submodule only; a bare prefix comparison
            # would also hit unrelated packages whose names merely start with "sklearn".
            if modname == "sklearn" or modname.startswith("sklearn."):
                for _, alias in node.names:
                    if alias is not None:
                        self.add_message("import-sklearn", node=node)
        except:  # pylint: disable=bare-except
            ExceptionHandler.handle(self, node)
예제 #6
0
    def visit_call(self, node: astroid.Call):
        """
        Add an ``sk-pipeline`` message when a learning function is called on an estimator.

        :param node: The Call node being visited.
        """
        try:
            callee = node.func

            # Only attribute-style calls (``something.method(...)``) are of interest.
            if callee is None or not hasattr(callee, "attrname"):
                return
            if callee.attrname not in self.LEARNING_FUNCTIONS:
                return

            # The rule is violated when the object the method is called on is an estimator.
            if self._expr_is_estimator(callee.expr):
                self.add_message("sk-pipeline", node=node)
        except:  # pylint: disable=bare-except
            ExceptionHandler.handle(self, node)
예제 #7
0
    def visit_import(self, node: astroid.Import):
        """
        Check that pandas, numpy and matplotlib.pyplot are imported under their usual aliases.

        :param node: The Import node being visited.
        """
        # Module name -> (expected alias, message to add when the alias differs).
        conventions = {
            "pandas": ("pd", "import-pandas"),
            "numpy": ("np", "import-numpy"),
            "matplotlib.pyplot": ("plt", "import-pyplot"),
        }
        try:
            for module_name, used_alias in node.names:
                rule = conventions.get(module_name)
                if rule is None:
                    continue
                expected_alias, message_id = rule
                if used_alias != expected_alias:
                    self.add_message(message_id, node=node)
        except:  # pylint: disable=bare-except
            ExceptionHandler.handle(self, node)
예제 #8
0
    def visit_call(self, node: astroid.Call):
        """
        Run the two dataframe checks of this checker on a Call node.

        ``unassigned-dataframe`` is added for a simple, non-whitelisted call for which
        ``_dataframe_is_lost`` holds; ``dataframe-iteration`` is added when
        ``_iterating_through_dataframe`` holds. The two checks are independent of each other.

        :param node: The Call node being visited.
        """
        try:
            # First check: the outcome of a dataframe operation is discarded.
            if self._is_simple_call_node(node):
                if not self._function_whitelisted(node):
                    if self._dataframe_is_lost(node):
                        self.add_message("unassigned-dataframe", node=node)

            # Second check: the call iterates through a dataframe.
            iterating = self._iterating_through_dataframe(node)
            if iterating:
                self.add_message("dataframe-iteration", node=node)
        except:  # pylint: disable=bare-except
            ExceptionHandler.handle(self, node)