예제 #1
0
 def visit_for(self, for_node: astroid.For):
     """
     Report "gradient-clear-pytorch" for a loop that calls backward() and step()
     but never calls zero_grad().

     Only method calls that are direct expression statements of the loop body are
     inspected; statements nested inside other blocks are not searched.
     :param for_node: The For node being visited.
     """
     try:
         called_methods = set()
         for statement in for_node.body:
             is_method_call_statement = (
                 isinstance(statement, astroid.Expr)
                 and hasattr(statement.value, "func")
                 and hasattr(statement.value.func, "attrname")
             )
             if is_method_call_statement:
                 called_methods.add(statement.value.func.attrname)
         # backward() + step() without clearing gradients in between is the violation.
         if {"backward", "step"} <= called_methods and "zero_grad" not in called_methods:
             self.add_message("gradient-clear-pytorch", node=for_node)
     except: # pylint: disable = bare-except
         ExceptionHandler.handle(self, for_node)
예제 #2
0
    def visit_call(self, call_node: astroid.Call):
        """
        Report "missing-mask-pytorch" when torch.log() is called on an unmasked input.

        The input counts as masked when the first positional argument is itself a
        ``clip``/``clamp`` call, or is a variable whose entry in
        self._variables_with_processing_operation contains "torch.clip" or "torch.clamp".
        :param call_node: The Call node being visited.
        """
        try:
            callee = call_node.func
            is_torch_log = (
                hasattr(callee, "attrname")
                and callee.attrname == "log"
                and hasattr(callee, "expr")
                and hasattr(callee.expr, "name")
                and callee.expr.name == "torch"
            )

            first_arg = None
            if hasattr(call_node, "args") and len(call_node.args) > 0:
                first_arg = call_node.args[0]

            # Masked inline, e.g. torch.log(x.clamp(...)).
            masked_inline = bool(
                first_arg is not None
                and hasattr(first_arg, "func")
                and hasattr(first_arg.func, "attrname")
                and first_arg.func.attrname in ["clip", "clamp"]
            )

            # Masked earlier through a variable that went through clip/clamp.
            masked_variable = False
            if first_arg is not None and hasattr(first_arg, "name"):
                processing = self._variables_with_processing_operation
                if first_arg.name in processing:
                    operations = processing[first_arg.name]
                    masked_variable = bool("torch.clip" in operations
                                           or "torch.clamp" in operations)

            if is_torch_log and not (masked_inline or masked_variable):
                self.add_message(msgid="missing-mask-pytorch", node=call_node)
        except:  # pylint: disable = bare-except
            ExceptionHandler.handle(self, call_node)
예제 #3
0
    def visit_module(self, module: astroid.Module):
        """
        Report "randomness-control-numpy" when numpy and an ML library are imported but
        ``np.random.seed(...)`` / ``numpy.random.seed(...)`` is never called as a
        top-level statement of the module.

        Non-main modules are skipped unless the
        no_main_module_check_randomness_control_numpy option is enabled.
        :param module: The Module node being visited.
        """
        try:
            is_main = check_main_module(module)
            check_every_module = self.config.no_main_module_check_randomness_control_numpy
            if check_every_module is False and is_main is False:
                return

            for statement in module.body:
                if not (isinstance(statement, astroid.nodes.Expr)
                        and hasattr(statement, "value")):
                    continue
                call = statement.value
                seeds_numpy = (
                    hasattr(call, "func")
                    and hasattr(call.func, "attrname")
                    and call.func.attrname == "seed"
                    and hasattr(call.func.expr, "attrname")
                    and call.func.expr.attrname == "random"
                    and hasattr(call.func.expr, "expr")
                    and hasattr(call.func.expr.expr, "name")
                    and call.func.expr.expr.name in ["np", "numpy"]
                )
                if seeds_numpy:
                    self._has_manual_seed = True

            seed_missing = (self._import_numpy is True
                            and self._import_ml_libraries is True
                            and self._has_manual_seed is False)
            if seed_missing:
                self.add_message("randomness-control-numpy", node=module)
        except:  # pylint: disable = bare-except
            ExceptionHandler.handle(self, module)
예제 #4
0
 def visit_call(self, call_node: astroid.Call):
     """
     Report "forward-pytorch" when ``.forward()`` is called directly on anything other
     than ``self`` or ``super()``.
     :param call_node: The Call node being visited.
     """
     try:
         callee = call_node.func
         if not (hasattr(callee, "attrname") and callee.attrname == "forward"):
             return
         if hasattr(callee, "expr"):
             receiver = callee.expr
             # self.forward(...) is allowed.
             if hasattr(receiver, "name") and receiver.name == "self":
                 return
             # super().forward(...) is allowed.
             if (hasattr(receiver, "func")
                     and hasattr(receiver.func, "name")
                     and receiver.func.name == "super"):
                 return
         self.add_message("forward-pytorch", node=call_node)
     except: # pylint: disable = bare-except
         ExceptionHandler.handle(self, call_node)
예제 #5
0
    def visit_call(self, node: astroid.Call):
        """
        When a Call node is visited, check whether it violated the rules in this checker.

        A call to one of self.SPLITTER_FUNCTIONS / self.SPLITTER_CLASSES violates
        "randomness-control-scikitlearn" when it does not pass ``random_state`` at all,
        or passes ``random_state=None`` explicitly.

        :param node: The node which is visited.
        """
        try:
            is_splitter_call = (
                hasattr(node, "func") and hasattr(node, "keywords")
                and hasattr(node.func, "name")
                and (node.func.name in self.SPLITTER_FUNCTIONS
                     or node.func.name in self.SPLITTER_CLASSES)
            )
            if not is_splitter_call:
                return

            # Some astroid versions store None (not an empty list) when a call has no
            # keyword arguments. That call certainly lacks random_state, so treat it as
            # an empty keyword list instead of skipping the check.
            keywords = node.keywords or []
            random_state_values = [keyword.value for keyword in keywords
                                   if keyword.arg == "random_state"]
            if not random_state_values:
                self.add_message("randomness-control-scikitlearn", node=node)
                return
            for value in random_state_values:
                # random_state=None is as non-deterministic as omitting it.
                if value.as_string() == "None":
                    self.add_message("randomness-control-scikitlearn", node=node)

        # pylint: disable = W0702
        except:
            ExceptionHandler.handle(self, node)
            traceback.print_exc()
예제 #6
0
    def visit_for(self, node: astroid.For):
        """
        Report "dataframe-iteration-modification-pandas" when a loop iterates over a call
        whose inferred type in self._call_types is a pandas or pyspark DataFrame, and the
        loop assigns to one of its own loop targets.

        :param node: The For node being visited.
        """
        try:
            dataframe_types = ('"pandas.core.frame.DataFrame"',
                               '"pyspark.sql.dataframe.DataFrame"')
            iterable = node.iter
            iterates_dataframe = (
                isinstance(iterable, astroid.Call)
                and iterable in self._call_types
                and self._call_types[iterable] in dataframe_types
            )
            if not iterates_dataframe:
                return

            loop_targets = self._get_for_targets(node)
            if any(name in loop_targets
                   for name in AssignUtil.get_assigned_target_names(node)):
                self.add_message("dataframe-iteration-modification-pandas",
                                 node=node)
        except:  # pylint: disable=bare-except
            ExceptionHandler.handle(self, node)
예제 #7
0
    def visit_for(self, node: astroid.For):
        """
        Report "memory-release-tensorflow" when a loop creates a model but never calls
        ``clear_session()``.

        Only top-level statements of the loop body are inspected. A statement counts as
        model creation when it calls a method named in self.MODELS and assigns the result
        to a first target accepted by self._is_tf_variable.
        :param node: The For node being visited.
        """
        try:
            clears_session = False
            creates_model = False
            for statement in node.body:
                if not (hasattr(statement, "value") and hasattr(statement.value, "func")
                        and hasattr(statement.value.func, "attrname")):
                    continue
                method_name = statement.value.func.attrname
                if method_name == "clear_session":
                    clears_session = True
                if (method_name in self.MODELS
                        and hasattr(statement, "targets") and len(statement.targets) > 0
                        and hasattr(statement.targets[0], "name")
                        and self._is_tf_variable(statement.targets[0].name)):
                    creates_model = True

            # Model creation inside the loop without releasing memory is the violation.
            if creates_model and not clears_session:
                self.add_message("memory-release-tensorflow", node=node)
        except:  # pylint: disable = bare-except
            ExceptionHandler.handle(self, node)
    def visit_module(self, module: astroid.Module):
        """
        Report "dependent-threshold-pytorch" when ``F1Score`` is called but ``AUROC`` is not.

        Only top-level module statements whose ``value`` is a Call are inspected. The
        message is attached to the last ``F1Score`` call encountered.
        :param module: The Module node being visited.
        """
        try:
            auroc_called = False
            last_f1_call = None

            for statement in module.body:
                if not (hasattr(statement, "value")
                        and isinstance(statement.value, astroid.Call)):
                    continue
                call = statement.value
                if not (hasattr(call, "func") and hasattr(call.func, "name")):
                    continue
                if call.func.name == "AUROC":
                    auroc_called = True
                if call.func.name == "F1Score":
                    last_f1_call = call

            if last_f1_call is not None and not auroc_called:
                self.add_message("dependent-threshold-pytorch",
                                 node=last_f1_call)
        except:  # pylint: disable = bare-except
            ExceptionHandler.handle(self, module)
예제 #9
0
    def visit_module(self, module: astroid.Module):
        """
        Report "deterministic-pytorch" when torch is imported but no top-level statement
        calls ``use_deterministic_algorithms(True)``.

        Non-main modules are skipped unless the
        no_main_module_check_deterministic_pytorch option is enabled.
        :param module: The Module node being visited.
        """
        try:
            is_main = check_main_module(module)
            check_every_module = self.config.no_main_module_check_deterministic_pytorch
            if check_every_module is False and is_main is False:
                return

            for statement in module.body:
                if not (isinstance(statement, astroid.nodes.Expr)
                        and hasattr(statement, "value")):
                    continue
                call = statement.value
                enables_determinism = (
                    hasattr(call, "func")
                    and hasattr(call.func, "attrname")
                    and call.func.attrname == "use_deterministic_algorithms"
                    and hasattr(call, "args")
                    and len(call.args) > 0
                    and hasattr(call.args[0], "value")
                    # Only a literal True argument enables the option.
                    and call.args[0].value is True
                )
                if enables_determinism:
                    self._has_deterministic_algorithm_option = True

            option_missing = (self._import_pytorch is True
                              and self._has_deterministic_algorithm_option is False)
            if option_missing:
                self.add_message("deterministic-pytorch", node=module)
        except: # pylint: disable = bare-except
            ExceptionHandler.handle(self, module)
예제 #10
0
    def visit_call(self, call_node: astroid.Call):
        """
        Report "data-leakage-scikitlearn" when a learning function is called on an
        estimator with a positional Name argument that was assigned from a call on a
        preprocessor.

        :param call_node: The Call node being visited.
        """
        try:
            callee = call_node.func
            is_learning_call = (
                callee is not None
                and hasattr(callee, "attrname")
                and callee.attrname in self.LEARNING_FUNCTIONS
                and hasattr(callee, "expr")
                and self._expr_is_estimator(callee.expr)
                and hasattr(call_node, "args")
            )
            if not is_learning_call:
                return

            preprocessed = False
            for argument in call_node.args:
                if not isinstance(argument, astroid.Name):
                    continue
                for assigned in AssignUtil.assignment_values(argument):
                    if (isinstance(assigned, astroid.Call)
                            and hasattr(assigned, "func")
                            and hasattr(assigned.func, "expr")
                            and self._expr_is_preprocessor(assigned.func.expr)):
                        preprocessed = True
            if preprocessed:
                self.add_message("data-leakage-scikitlearn",
                                 node=call_node)

        except:  # pylint: disable=bare-except
            ExceptionHandler.handle(self, call_node)
예제 #11
0
 def visit_module(self, module: astroid.Module):
     """
     Store in self._variable_types the per-variable library types computed by
     TypeInference.infer_library_variable_first_types for this module.
     """
     try:
         inferred_types = TypeInference.infer_library_variable_first_types(module)
         self._variable_types = inferred_types
     except:  # pylint: disable = bare-except
         ExceptionHandler.handle(self, module)
예제 #12
0
    def visit_for(self, for_node: astroid.For):
        """
        Report "mode-toggling-pytorch" when a loop calls ``.eval()`` but never ``.train()``.

        Method calls are looked for among the loop body's expression statements and
        among expression statements directly inside the bodies of top-level ``if``
        statements of the loop.
        :param for_node: The For node being visited.
        """
        try:
            candidates = []
            for statement in for_node.body:
                candidates.append(statement)
                if isinstance(statement, astroid.If):
                    candidates.extend(statement.body)

            called_methods = {
                candidate.value.func.attrname
                for candidate in candidates
                if isinstance(candidate, astroid.Expr)
                and isinstance(candidate.value, astroid.Call)
                and hasattr(candidate.value.func, "attrname")
            }

            if "eval" in called_methods and "train" not in called_methods:
                self.add_message("mode-toggling-pytorch", node=for_node)
        except:  # pylint: disable = bare-except
            ExceptionHandler.handle(self, for_node)
예제 #13
0
 def visit_import(self, import_node: astroid.Import):
     """Set self._imported_pandas when this ``import`` statement imports ``pandas``."""
     try:
         imported_modules = [module_name for module_name, _ in import_node.names]
         if "pandas" in imported_modules:
             self._imported_pandas = True
     except:  # pylint: disable = bare-except
         ExceptionHandler.handle(self, import_node)
예제 #14
0
 def visit_importfrom(self, node: astroid.ImportFrom):
     """
     Record whether scikit-learn is imported through a ``from ... import`` statement.

     self._import_ml_libraries is only recomputed (via has_importfrom_sklearn) while it
     is still exactly False.
     :param node: The ImportFrom node being visited.
     """
     try:
         if self._import_ml_libraries is not False:
             return
         self._import_ml_libraries = has_importfrom_sklearn(node)
     except:  # pylint: disable = bare-except
         ExceptionHandler.handle(self, node)
 def visit_for(self, node: astroid.For):
     """
     Report "iteration-tensorflow" when the loop body contains an augmented assignment
     involving a TensorFlow variable (as judged by self._augmented_assign_with_tf_variable);
     such accumulation can usually be replaced by a faster reduction operation.
     """
     try:
         if not hasattr(node, "body"):
             return
         if self._augmented_assign_with_tf_variable(node.body):
             self.add_message("iteration-tensorflow", node=node)
     except:  # pylint: disable=W0702
         ExceptionHandler.handle(self, node)
 def visit_import(self, node: astroid.Import):
     """
     Record whether tensorflow is imported.

     self._import_tensorflow is only recomputed while it is still exactly False.
     :param node: The Import node being visited.
     """
     try:
         already_recorded = self._import_tensorflow is not False
         if not already_recorded:
             self._import_tensorflow = has_import(node, "tensorflow")
     except:  # pylint: disable = bare-except
         ExceptionHandler.handle(self, node)
예제 #17
0
 def visit_import(self, node: astroid.Import):
     """
     Record whether torch (PyTorch) is imported.

     self._import_pytorch is only recomputed while it is still exactly False.
     :param node: The Import node being visited.
     """
     try:
         already_recorded = self._import_pytorch is not False
         if not already_recorded:
             self._import_pytorch = has_import(node, "torch")
     except: # pylint: disable = bare-except
         ExceptionHandler.handle(self, node)
예제 #18
0
    def visit_call(self, node: astroid.Call):
        """
        Report "scaler-missing-scikitlearn" when a scaling-sensitive operation is used
        without a scaler.

        Two independent situations are checked:
        * Pipeline: a call whose name is in self.PIPELINE and whose positional Call
          arguments contain a scaling-sensitive operation, with no scaler at or before
          the first such operation.
        * Direct use: a learning function (self.LEARNING_FUNCTIONS) called on a
          scaling-sensitive operation, where no positional Name argument has an assigned
          value produced by a call on a scaler.
        :param node: The Call node being visited.
        """
        try:
            # Pipeline case: steps after the first scaling-sensitive one do not count.
            if (hasattr(node, "func")
                    and hasattr(node.func, "name")
                    and node.func.name in self.PIPELINE
                    and hasattr(node, "args")):
                scaler_seen = False
                sensitive_seen = False
                for step in node.args:
                    if not isinstance(step, astroid.Call):
                        continue
                    if self._call_initiates_scaler(step):
                        scaler_seen = True
                    if self._call_initiates_scaling_sensitive_operations(step):
                        sensitive_seen = True
                        break
                if sensitive_seen and not scaler_seen:
                    self.add_message("scaler-missing-scikitlearn", node=node)

            # Direct case: no pipeline, learning function called on a sensitive estimator.
            if (hasattr(node, "func")
                    and hasattr(node.func, "attrname")
                    and node.func.attrname in self.LEARNING_FUNCTIONS
                    and hasattr(node.func, "expr")
                    and self._expr_is_scaling_sensitive_operation(node.func.expr)
                    and hasattr(node, "args")):
                scaled_input = False
                for argument in node.args:
                    if not isinstance(argument, astroid.Name):
                        continue
                    for assigned in AssignUtil.assignment_values(argument):
                        if (isinstance(assigned, astroid.Call)
                                and hasattr(assigned, "func")
                                and hasattr(assigned.func, "expr")
                                and self._expr_is_scaler(assigned.func.expr)):
                            scaled_input = True
                if not scaled_input:
                    self.add_message("scaler-missing-scikitlearn", node=node)

        except:  # pylint: disable=bare-except
            ExceptionHandler.handle(self, node)
            traceback.print_exc()
예제 #19
0
    def visit_call(self, node: astroid.Call):
        """
        Report "inplace-numpy" when self._result_is_lost(node) is truthy for this call.

        :param node: The Call node being visited.
        """
        try:
            result_lost = self._result_is_lost(node)
            if not result_lost:
                return
            self.add_message("inplace-numpy", node=node)
        except:  # pylint: disable=bare-except
            ExceptionHandler.handle(self, node)
예제 #20
0
    def visit_module(self, node: astroid.Module):
        """
        Build self._call_types with TypeInference.infer_types over the module's Call
        nodes, using ``call.func.expr.name`` (the receiver's name) as the name to resolve.

        :param node: The Module node being visited.
        """
        try:
            def receiver_name(call):
                return call.func.expr.name

            # noinspection PyTypeChecker
            self._call_types = TypeInference.infer_types(node, astroid.Call, receiver_name)
        except:  # pylint: disable=bare-except
            ExceptionHandler.handle(self, node)
예제 #21
0
    def visit_call(self, node: astroid.Call):
        """
        Report "dataframe-iteration-modification-pandas" when
        self._iterating_through_dataframe(node) is truthy for this call.

        :param node: The Call node being visited.
        """
        try:
            if not self._iterating_through_dataframe(node):
                return
            self.add_message("dataframe-iteration-modification-pandas", node=node)
        except:  # pylint: disable=bare-except
            ExceptionHandler.handle(self, node)
예제 #22
0
 def visit_importfrom(self, importfrom_node: astroid.ImportFrom):
     """
     Set self._import_DataLoader when ``DataLoader`` is imported from ``torch.utils.data``.

     :param importfrom_node: The ImportFrom node being visited.
     """
     try:
         from_torch_data = (hasattr(importfrom_node, "modname")
                            and importfrom_node.modname == "torch.utils.data"
                            and hasattr(importfrom_node, "names"))
         if not from_torch_data:
             return
         imported_names = [imported for imported, _ in importfrom_node.names]
         if "DataLoader" in imported_names:
             self._import_DataLoader = True
     except:  # pylint: disable = bare-except
         ExceptionHandler.handle(self, importfrom_node)
예제 #23
0
    def visit_importfrom(self, node: astroid.ImportFrom):
        """
        When an ImportFrom node is visited, check if it follows the conventions.

        Importing from scikit-learn (``sklearn`` itself or any ``sklearn.*`` submodule)
        with an ``as`` alias violates "import-sklearn". The message is attached to the
        import statement, so it is emitted once per statement even if several names are
        aliased.

        :param node: Node which is visited.
        """
        try:
            modname = node.modname
            # Match the package or its submodules exactly; a plain 7-character prefix
            # check would also match unrelated packages such as ``sklearn_extra``.
            if modname == "sklearn" or modname.startswith("sklearn."):
                if any(alias is not None for _, alias in node.names):
                    self.add_message("import-sklearn", node=node)
        except:  # pylint: disable=bare-except
            ExceptionHandler.handle(self, node)
예제 #24
0
    def visit_call(self, node: astroid.Call):
        """
        Report "inplace-pandas" for a call that passes self._is_simple_call_node, is not
        whitelisted by self._function_whitelisted, and for which self._dataframe_is_lost
        is truthy.

        :param node: The Call node being visited.
        """
        try:
            if not self._is_simple_call_node(node):
                return
            if self._function_whitelisted(node):
                return
            if self._dataframe_is_lost(node):
                self.add_message("inplace-pandas", node=node)
        except:  # pylint: disable=bare-except
            ExceptionHandler.handle(self, node)
예제 #25
0
 def visit_for(self, for_node: astroid.For):
     """
     Visit for node and see whether the rule is violated.

     The loop violates "tensor-array-tensorflow" when a top-level assignment in its
     body re-assigns a variable whose recorded type in self._variable_types is
     ``tf.constant`` / ``tensorflow.constant``, from an expression that
     self.infer_call_expression resolves to ``tf.concat``.
     The message is attached to the loop, so it is reported at most once per loop.
     :param for_node: The For node being visited.
     """
     try:
         constant_types = ("tf.constant", "tensorflow.constant")
         for node in for_node.body:
             if (isinstance(node, astroid.Assign) and len(node.targets) > 0
                     and hasattr(node.targets[0], "name")
                     and node.targets[0].name in self._variable_types
                     and self._variable_types[node.targets[0].name] in constant_types
                     and self.infer_call_expression(node.value) in ["tf.concat"]):
                 # Further matches would emit an identical message on the same loop.
                 self.add_message("tensor-array-tensorflow", node=for_node)
                 return
     except:  # pylint: disable = bare-except
         ExceptionHandler.handle(self, for_node)
예제 #26
0
 def visit_call(self, call_node: astroid.Call):
     """
     Visit call node to see whether there is rule violation.

     A method call made on a ``.values`` attribute (e.g. ``df.values.tolist()``)
     violates "dataframe-conversion-pandas".
     :param call_node: The Call node being visited.
     """
     try:
         # Only ``func.expr`` can be the ``.values`` attribute. A Call node has no
         # ``attrname`` of its own, so the former ``call_node.attrname`` check was dead.
         if (hasattr(call_node, "func") and hasattr(call_node.func, "expr")
                 and hasattr(call_node.func.expr, "attrname")
                 and call_node.func.expr.attrname == "values"):
             self.add_message("dataframe-conversion-pandas", node=call_node)
     except:  # pylint: disable = bare-except
         ExceptionHandler.handle(self, call_node)
예제 #27
0
 def visit_import(self, node: astroid.Import):
     """
     Record whether numpy and an ML library (sklearn, torch or tensorflow) are imported.

     Each flag is only recomputed while it is still exactly False.
     :param node: The Import node being visited.
     """
     try:
         if self._import_numpy is False:
             self._import_numpy = has_import(node, "numpy")
         if self._import_ml_libraries is False:
             # Keep the first truthy has_import result (or the last result if none is
             # truthy), exactly like a chained ``or``.
             found = False
             for library in ("sklearn", "torch", "tensorflow"):
                 found = has_import(node, library)
                 if found:
                     break
             self._import_ml_libraries = found
     except:  # pylint: disable = bare-except
         ExceptionHandler.handle(self, node)
예제 #28
0
    def visit_call(self, node: astroid.Call):
        """
        When a Call node is visited, check whether hyperparameters are set.

        The check is delegated to self.hyperparameter_in_class for calls whose callee has
        a ``name``. Per the checker's intent: in strict mode all hyperparameters should be
        set; in non-strict mode, learning-function calls should either contain all
        hyperparameters in HYPERPARAMETERS_MAIN or set at least one hyperparameter.

        :param node: Node which is visited.
        """
        try:
            if not hasattr(node, "func"):
                return
            callee = node.func
            if hasattr(callee, "name"):
                self.hyperparameter_in_class(node, callee.name)
        except:  # pylint: disable=bare-except
            ExceptionHandler.handle(self, node)
예제 #29
0
    def visit_compare(self, node: astroid.Compare):
        """
        When a compare node is visited, check whether a comparison is done with np.nan.

        Every operand of the (possibly chained) comparison is inspected, so
        ``a == b == np.nan`` is caught as well. Both the ``np`` and ``numpy`` spellings
        are recognised. At most one "nan-numpy" message is emitted per comparison.

        :param node: Node which is visited.
        """
        try:
            operands = [node.left] + [right for _, right in node.ops]
            for operand in operands:
                if (isinstance(operand, astroid.Attribute)
                        and operand.attrname == "nan"
                        # ``expr`` may be e.g. an Attribute or Call with no ``name``
                        # (``obj.attr.nan``); that is not numpy, not an error.
                        and getattr(operand.expr, "name", None) in ("np", "numpy")):
                    self.add_message("nan-numpy", node=node)
                    return
        except:  # pylint: disable=bare-except
            ExceptionHandler.handle(self, node)
예제 #30
0
 def visit_subscript(self, subscript_node: astroid.Subscript):
     """
     Visit subscript node and check whether there is chain indexing.

     Consecutive subscripts are counted by following ``.value`` while the node is a
     Subscript. When there are two or more and the base is a name whose entry in
     self._subscript_types is "pd.DataFrame", "chain-indexing-pandas" is reported.
     :param subscript_node: The Subscript node being visited.
     """
     try:
         indexing_num = 0
         node = subscript_node
         # Count only real indexing operations. Other nodes also expose ``value``
         # (e.g. NamedExpr, Await, Const) and must be neither counted nor walked through.
         while isinstance(node, astroid.Subscript):
             indexing_num += 1
             node = node.value
         # Chain indexing on a DataFrame (two or more subscripts) violates the rule.
         if (indexing_num >= 2
                 and hasattr(node, "name") and node.name in self._subscript_types
                 and self._subscript_types[node.name] == "pd.DataFrame"):
             self.add_message("chain-indexing-pandas", node=subscript_node)
     except:  # pylint: disable = bare-except
         ExceptionHandler.handle(self, subscript_node)