def visit_for(self, for_node: astroid.For):
    """
    Report a loop that calls ``backward`` and ``step`` but never ``zero_grad``.

    Only expression statements placed directly in the loop body are
    inspected; nested blocks are not searched.

    :param for_node: The For node which is visited.
    """
    try:
        # Names of every method invoked as a bare expression statement.
        called_methods = {
            stmt.value.func.attrname
            for stmt in for_node.body
            if isinstance(stmt, astroid.Expr)
            and hasattr(stmt.value, "func")
            and hasattr(stmt.value.func, "attrname")
        }
        updates_weights = "backward" in called_methods and "step" in called_methods
        if updates_weights and "zero_grad" not in called_methods:
            self.add_message("gradient-clear-pytorch", node=for_node)
    except:  # pylint: disable = bare-except
        ExceptionHandler.handle(self, for_node)
def visit_call(self, call_node: astroid.Call):
    """
    Report a ``torch.log(...)`` call whose first argument is not masked.

    The argument counts as masked when it is itself a ``clip``/``clamp``
    method call, or when it is a variable recorded in
    ``self._variables_with_processing_operation`` with ``torch.clip`` or
    ``torch.clamp`` among its operations.

    :param call_node: The Call node which is visited.
    """
    try:
        callee = call_node.func
        is_torch_log = (
            hasattr(callee, "attrname")
            and callee.attrname == "log"
            and hasattr(callee, "expr")
            and hasattr(callee.expr, "name")
            and callee.expr.name == "torch"
        )

        is_masked = False
        if hasattr(call_node, "args") and len(call_node.args) > 0:
            first_arg = call_node.args[0]
            # Case 1: torch.log(x.clamp(...)) / torch.log(torch.clip(...))
            if (
                hasattr(first_arg, "func")
                and hasattr(first_arg.func, "attrname")
                and first_arg.func.attrname in ["clip", "clamp"]
            ):
                is_masked = True
            # Case 2: the argument is a variable that was clipped/clamped before.
            if (
                hasattr(first_arg, "name")
                and first_arg.name in self._variables_with_processing_operation
                and (
                    "torch.clip" in self._variables_with_processing_operation[first_arg.name]
                    or "torch.clamp" in self._variables_with_processing_operation[first_arg.name]
                )
            ):
                is_masked = True

        if is_torch_log and not is_masked:
            self.add_message(msgid="missing-mask-pytorch", node=call_node)
    except:  # pylint: disable = bare-except
        ExceptionHandler.handle(self, call_node)
def visit_module(self, module: astroid.Module):
    """
    Report a module that imports numpy and an ML library without seeding numpy.

    The module is skipped when it is not a main module and the
    ``no_main_module_check_randomness_control_numpy`` option is False.
    Only top-level expression statements are searched for a
    ``np.random.seed(...)`` / ``numpy.random.seed(...)`` call.

    :param module: The Module node which is visited.
    """
    try:
        is_main = check_main_module(module)
        check_every_module = self.config.no_main_module_check_randomness_control_numpy
        if check_every_module is False and is_main is False:
            return

        for stmt in module.body:
            if not (isinstance(stmt, astroid.nodes.Expr) and hasattr(stmt, "value")):
                continue
            call = stmt.value
            if not (hasattr(call, "func") and hasattr(call.func, "attrname")):
                continue
            seed_attr = call.func
            if seed_attr.attrname != "seed":
                continue
            # ``<alias>.random.seed``: walk from ``seed`` back to the module alias.
            random_attr = seed_attr.expr
            if (
                hasattr(random_attr, "attrname")
                and random_attr.attrname == "random"
                and hasattr(random_attr, "expr")
                and hasattr(random_attr.expr, "name")
                and random_attr.expr.name in ["np", "numpy"]
            ):
                self._has_manual_seed = True

        if (
            self._import_numpy is True
            and self._import_ml_libraries is True
            and self._has_manual_seed is False
        ):
            self.add_message("randomness-control-numpy", node=module)
    except:  # pylint: disable = bare-except
        ExceptionHandler.handle(self, module)
def visit_call(self, call_node: astroid.Call):
    """
    Report a direct ``.forward(...)`` call made on anything but ``self`` or ``super()``.

    :param call_node: The Call node which is visited.
    """
    try:
        callee = call_node.func
        calls_forward = hasattr(callee, "attrname") and callee.attrname == "forward"
        receiver_is_self = (
            hasattr(callee, "expr")
            and hasattr(callee.expr, "name")
            and callee.expr.name == "self"
        )
        # ``super().forward(...)``: the receiver is itself a call to ``super``.
        receiver_is_super = (
            hasattr(callee, "expr")
            and hasattr(callee.expr, "func")
            and hasattr(callee.expr.func, "name")
            and callee.expr.func.name == "super"
        )
        if calls_forward and not (receiver_is_self or receiver_is_super):
            self.add_message("forward-pytorch", node=call_node)
    except:  # pylint: disable = bare-except
        ExceptionHandler.handle(self, call_node)
def visit_call(self, node: astroid.Call):
    """
    Report splitter calls that do not pin ``random_state``.

    A call is checked when the called name is listed in
    ``self.SPLITTER_FUNCTIONS`` or ``self.SPLITTER_CLASSES``. The message
    "randomness-control-scikitlearn" is added when

    * no ``random_state`` keyword is passed at all (this includes calls
      without any keyword arguments), or
    * ``random_state`` is passed with the literal value ``None``.

    :param node: The Call node which is visited.
    """
    try:
        if not (
            hasattr(node, "func")
            and hasattr(node, "keywords")
            and hasattr(node.func, "name")
        ):
            return
        callee_name = node.func.name
        if not (
            callee_name in self.SPLITTER_FUNCTIONS
            or callee_name in self.SPLITTER_CLASSES
        ):
            return

        # ``keywords`` may be None instead of an empty list (the previous code
        # guarded against that and then reported nothing). A call without any
        # keywords has no random_state either, so treat None as "no keywords".
        keywords = node.keywords or []
        random_state_values = [
            keyword.value for keyword in keywords if keyword.arg == "random_state"
        ]

        if not random_state_values:
            self.add_message("randomness-control-scikitlearn", node=node)
            return
        for value in random_state_values:
            # An explicit ``random_state=None`` still leaves the split unseeded.
            if value.as_string() == "None":
                self.add_message("randomness-control-scikitlearn", node=node)
    except:  # pylint: disable = W0702
        ExceptionHandler.handle(self, node)
        traceback.print_exc()
def visit_for(self, node: astroid.For):
    """
    Report a loop over a DataFrame-returning call that assigns to its own loop targets.

    The loop is only inspected when its iterable is a Call whose recorded type
    in ``self._call_types`` is a pandas or pyspark DataFrame.

    :param node: The For node which is visited.
    """
    try:
        iterated = node.iter
        if not isinstance(iterated, astroid.Call):
            return
        if iterated not in self._call_types:
            return
        dataframe_types = (
            '"pandas.core.frame.DataFrame"',
            '"pyspark.sql.dataframe.DataFrame"',
        )
        if self._call_types[iterated] not in dataframe_types:
            return

        loop_targets = self._get_for_targets(node)
        assigned_names = AssignUtil.get_assigned_target_names(node)
        if any(name in loop_targets for name in assigned_names):
            self.add_message("dataframe-iteration-modification-pandas", node=node)
    except:  # pylint: disable=bare-except
        ExceptionHandler.handle(self, node)
def visit_for(self, node: astroid.For):
    """
    Report a loop that creates a model without calling ``clear_session``.

    Only statements placed directly in the loop body are inspected. A model
    creation is an assignment whose value is a method call named in
    ``self.MODELS`` and whose first target passes ``self._is_tf_variable``.

    :param node: The For node which is visited.
    """
    try:
        clears_session = False
        creates_model = False
        for stmt in node.body:
            if not (
                hasattr(stmt, "value")
                and hasattr(stmt.value, "func")
                and hasattr(stmt.value.func, "attrname")
            ):
                continue
            method_name = stmt.value.func.attrname
            if method_name == "clear_session":
                clears_session = True
            if (
                method_name in self.MODELS
                and hasattr(stmt, "targets")
                and len(stmt.targets) > 0
                and hasattr(stmt.targets[0], "name")
                and self._is_tf_variable(stmt.targets[0].name)
            ):
                creates_model = True

        # A model is built on every iteration but memory is never released.
        if creates_model and not clears_session:
            self.add_message("memory-release-tensorflow", node=node)
    except:  # pylint: disable = bare-except
        ExceptionHandler.handle(self, node)
def visit_module(self, module: astroid.Module):
    """
    Report use of ``F1Score`` in a module that never calls ``AUROC``.

    Only top-level statements whose ``value`` is a Call are inspected. When
    the rule is violated the message is attached to the last ``F1Score`` call
    found.

    :param module: The Module node which is visited.
    """
    try:
        auc_used = False
        f1_call = None
        for stmt in module.body:
            if not (hasattr(stmt, "value") and isinstance(stmt.value, astroid.Call)):
                continue
            call = stmt.value
            if not (hasattr(call, "func") and hasattr(call.func, "name")):
                continue
            if call.func.name == "AUROC":
                auc_used = True
            if call.func.name == "F1Score":
                f1_call = call

        # A threshold-dependent metric is used without a threshold-independent one.
        if f1_call is not None and not auc_used:
            self.add_message("dependent-threshold-pytorch", node=f1_call)
    except:  # pylint: disable = bare-except
        ExceptionHandler.handle(self, module)
def visit_module(self, module: astroid.Module):
    """
    Report a module that imports pytorch without enabling deterministic algorithms.

    The module is skipped when it is not a main module and the
    ``no_main_module_check_deterministic_pytorch`` option is False. Only
    top-level expression statements are searched for a
    ``use_deterministic_algorithms(True)`` call.

    :param module: The Module node which is visited.
    """
    try:
        is_main = check_main_module(module)
        check_every_module = self.config.no_main_module_check_deterministic_pytorch
        if check_every_module is False and is_main is False:
            return

        for stmt in module.body:
            if not (isinstance(stmt, astroid.nodes.Expr) and hasattr(stmt, "value")):
                continue
            call = stmt.value
            names_option = (
                hasattr(call, "func")
                and hasattr(call.func, "attrname")
                and call.func.attrname == "use_deterministic_algorithms"
            )
            # The option only counts when its first positional argument is literally True.
            if (
                names_option
                and hasattr(call, "args")
                and len(call.args) > 0
                and hasattr(call.args[0], "value")
                and call.args[0].value is True
            ):
                self._has_deterministic_algorithm_option = True

        if (
            self._import_pytorch is True
            and self._has_deterministic_algorithm_option is False
        ):
            self.add_message("deterministic-pytorch", node=module)
    except:  # pylint: disable = bare-except
        ExceptionHandler.handle(self, module)
def visit_call(self, call_node: astroid.Call):
    """
    Report a learning function called on an estimator with preprocessed data.

    The call is inspected when its method name is in
    ``self.LEARNING_FUNCTIONS`` and its receiver passes
    ``self._expr_is_estimator``. The rule is violated when any Name argument
    has an assigned value that is a call on a receiver accepted by
    ``self._expr_is_preprocessor``.

    :param call_node: The Call node which is visited.
    """
    try:
        callee = call_node.func
        if not (
            callee is not None
            and hasattr(callee, "attrname")
            and callee.attrname in self.LEARNING_FUNCTIONS
            and hasattr(callee, "expr")
            and self._expr_is_estimator(callee.expr)
            and hasattr(call_node, "args")
        ):
            return

        fed_preprocessed_data = False
        for arg in call_node.args:
            if not isinstance(arg, astroid.Name):
                continue
            for value in AssignUtil.assignment_values(arg):
                if (
                    isinstance(value, astroid.Call)
                    and hasattr(value, "func")
                    and hasattr(value.func, "expr")
                    and self._expr_is_preprocessor(value.func.expr)
                ):
                    fed_preprocessed_data = True

        if fed_preprocessed_data:
            self.add_message("data-leakage-scikitlearn", node=call_node)
    except:  # pylint: disable=bare-except
        ExceptionHandler.handle(self, call_node)
def visit_module(self, module: astroid.Module):
    """
    Store the result of library type inference for the visited module.

    The value returned by
    ``TypeInference.infer_library_variable_first_types`` is kept in
    ``self._variable_types``.

    :param module: The Module node which is visited.
    """
    try:
        inferred_types = TypeInference.infer_library_variable_first_types(module)
        self._variable_types = inferred_types
    except:  # pylint: disable = bare-except
        ExceptionHandler.handle(self, module)
def visit_for(self, for_node: astroid.For):
    """
    Report a loop that calls ``eval()`` without ever calling ``train()``.

    Method calls are searched in expression statements placed directly in the
    loop body and directly in the body of any ``if`` statement in the loop
    body (``else`` branches and deeper nesting are not searched).

    :param for_node: The For node which is visited.
    """
    try:
        # Collect the statements to look at: loop-level expressions plus
        # everything found one level down inside ``if`` bodies.
        candidates = []
        for stmt in for_node.body:
            if isinstance(stmt, astroid.Expr):
                candidates.append(stmt)
            if isinstance(stmt, astroid.If):
                candidates.extend(stmt.body)

        called_methods = {
            stmt.value.func.attrname
            for stmt in candidates
            if isinstance(stmt, astroid.Expr)
            and isinstance(stmt.value, astroid.Call)
            and hasattr(stmt.value.func, "attrname")
        }
        if "eval" in called_methods and "train" not in called_methods:
            self.add_message("mode-toggling-pytorch", node=for_node)
    except:  # pylint: disable = bare-except
        ExceptionHandler.handle(self, for_node)
def visit_import(self, import_node: astroid.Import):
    """
    Remember that pandas was imported by a plain ``import`` statement.

    Sets ``self._imported_pandas`` to True when one of the imported module
    names is exactly ``pandas``; the flag is never reset here.

    :param import_node: The Import node which is visited.
    """
    try:
        imported_modules = [name for name, _ in import_node.names]
        if "pandas" in imported_modules:
            self._imported_pandas = True
    except:  # pylint: disable = bare-except
        ExceptionHandler.handle(self, import_node)
def visit_importfrom(self, node: astroid.ImportFrom):
    """
    Record whether a ``from ... import`` statement brings in scikit-learn.

    The flag ``self._import_ml_libraries`` is only recomputed while it is
    still False, using ``has_importfrom_sklearn``.

    :param node: The ImportFrom node which is visited.
    """
    try:
        if self._import_ml_libraries is not False:
            return
        self._import_ml_libraries = has_importfrom_sklearn(node)
    except:  # pylint: disable = bare-except
        ExceptionHandler.handle(self, node)
def visit_for(self, node: astroid.For):
    """
    Report a loop whose body augments a tensorflow variable.

    The decision is delegated to ``self._augmented_assign_with_tf_variable``,
    which receives the loop body.

    :param node: The For node which is visited.
    """
    try:
        if not hasattr(node, "body"):
            return
        if self._augmented_assign_with_tf_variable(node.body):
            self.add_message("iteration-tensorflow", node=node)
    except:  # pylint: disable=W0702
        ExceptionHandler.handle(self, node)
def visit_import(self, node: astroid.Import):
    """
    Record whether tensorflow is imported.

    ``self._import_tensorflow`` is only recomputed while it is still False,
    using ``has_import(node, "tensorflow")``.

    :param node: The Import node which is visited.
    """
    try:
        if self._import_tensorflow is not False:
            return
        self._import_tensorflow = has_import(node, "tensorflow")
    except:  # pylint: disable = bare-except
        ExceptionHandler.handle(self, node)
def visit_import(self, node: astroid.Import):
    """
    Record whether pytorch is imported.

    ``self._import_pytorch`` is only recomputed while it is still False,
    using ``has_import(node, "torch")``.

    :param node: The Import node which is visited.
    """
    try:
        if self._import_pytorch is not False:
            return
        self._import_pytorch = has_import(node, "torch")
    except:  # pylint: disable = bare-except
        ExceptionHandler.handle(self, node)
def visit_call(self, node: astroid.Call):
    """
    Report a scaling-sensitive operation that is not preceded by a scaler.

    Two situations are checked independently:

    * a call whose name is in ``self.PIPELINE``: its Call arguments are
      scanned in order until the first scaling-sensitive operation; the rule
      is violated when no scaler was seen up to and including that argument;
    * a learning function (``self.LEARNING_FUNCTIONS``) called on a
      scaling-sensitive receiver: the rule is violated when none of the Name
      arguments was assigned from a call on a scaler.

    :param node: The Call node which is visited.
    """
    try:
        # Situation 1: the steps are passed to a pipeline constructor.
        if (
            hasattr(node, "func")
            and hasattr(node.func, "name")
            and node.func.name in self.PIPELINE
            and hasattr(node, "args")
        ):
            scaler_seen = False
            sensitive_step_found = False
            for step in node.args:
                if not isinstance(step, astroid.Call):
                    continue
                if self._call_initiates_scaler(step):
                    scaler_seen = True
                if self._call_initiates_scaling_sensitive_operations(step):
                    sensitive_step_found = True
                    # Scalers placed after the sensitive step do not count.
                    break
            if sensitive_step_found and not scaler_seen:
                self.add_message("scaler-missing-scikitlearn", node=node)

        # Situation 2: a learning function is called directly on a
        # scaling-sensitive operation.
        if (
            hasattr(node, "func")
            and hasattr(node.func, "attrname")
            and node.func.attrname in self.LEARNING_FUNCTIONS
            and hasattr(node.func, "expr")
            and self._expr_is_scaling_sensitive_operation(node.func.expr)
            and hasattr(node, "args")
        ):
            input_was_scaled = False
            for arg in node.args:
                if not isinstance(arg, astroid.Name):
                    continue
                for value in AssignUtil.assignment_values(arg):
                    if (
                        isinstance(value, astroid.Call)
                        and hasattr(value, "func")
                        and hasattr(value.func, "expr")
                        and self._expr_is_scaler(value.func.expr)
                    ):
                        input_was_scaled = True
            if not input_was_scaled:
                self.add_message("scaler-missing-scikitlearn", node=node)
    except:  # pylint: disable=bare-except
        ExceptionHandler.handle(self, node)
        traceback.print_exc()
def visit_call(self, node: astroid.Call):
    """
    Report a call for which ``self._result_is_lost`` returns a truthy value.

    :param node: The Call node which is visited.
    """
    try:
        result_discarded = self._result_is_lost(node)
        if result_discarded:
            self.add_message("inplace-numpy", node=node)
    except:  # pylint: disable=bare-except
        ExceptionHandler.handle(self, node)
def visit_module(self, node: astroid.Module):
    """
    Store the types inferred for the Call nodes of the visited module.

    ``TypeInference.infer_types`` is given the module, the Call node class and
    an extractor returning the name of the expression a method is called on;
    its result is kept in ``self._call_types``.

    :param node: The Module node which is visited.
    """

    def receiver_name(call):
        # For ``obj.method(...)`` this is the name ``obj``.
        return call.func.expr.name

    try:
        # noinspection PyTypeChecker
        self._call_types = TypeInference.infer_types(node, astroid.Call, receiver_name)
    except:  # pylint: disable=bare-except
        ExceptionHandler.handle(self, node)
def visit_call(self, node: astroid.Call):
    """
    Report a call for which ``self._iterating_through_dataframe`` returns a truthy value.

    :param node: The Call node which is visited.
    """
    try:
        iterates_dataframe = self._iterating_through_dataframe(node)
        if iterates_dataframe:
            self.add_message("dataframe-iteration-modification-pandas", node=node)
    except:  # pylint: disable=bare-except
        ExceptionHandler.handle(self, node)
def visit_importfrom(self, importfrom_node: astroid.ImportFrom):
    """
    Remember that ``DataLoader`` was imported from ``torch.utils.data``.

    Sets ``self._import_DataLoader`` to True; the flag is never reset here.

    :param importfrom_node: The ImportFrom node which is visited.
    """
    try:
        if not (
            hasattr(importfrom_node, "modname")
            and importfrom_node.modname == "torch.utils.data"
            and hasattr(importfrom_node, "names")
        ):
            return
        imported_names = [name for name, _ in importfrom_node.names]
        if "DataLoader" in imported_names:
            self._import_DataLoader = True
    except:  # pylint: disable = bare-except
        ExceptionHandler.handle(self, importfrom_node)
def visit_importfrom(self, node: astroid.ImportFrom):
    """
    Report aliased names imported from scikit-learn.

    One "import-sklearn" message is added for every name imported with an
    alias (``from sklearn.x import y as z``). Only the ``sklearn`` package and
    its submodules are considered; packages that merely start with the same
    letters (for example ``sklearnex``) are ignored.

    :param node: The ImportFrom node which is visited.
    """
    try:
        module_name = node.modname
        # A bare prefix comparison would also match unrelated packages such
        # as ``sklearnex`` or ``sklearn_pandas``, so require an exact package
        # match or a dotted submodule.
        from_sklearn = module_name == "sklearn" or module_name.startswith("sklearn.")
        if not from_sklearn:
            return
        for _, alias in node.names:
            if alias is not None:
                self.add_message("import-sklearn", node=node)
    except:  # pylint: disable=bare-except
        ExceptionHandler.handle(self, node)
def visit_call(self, node: astroid.Call):
    """
    Report a simple, non-whitelisted call whose DataFrame result is lost.

    The three helper checks are evaluated in order and stop at the first one
    that fails: ``_is_simple_call_node``, ``_function_whitelisted`` (must be
    falsy) and ``_dataframe_is_lost``.

    :param node: The Call node which is visited.
    """
    try:
        if not self._is_simple_call_node(node):
            return
        if self._function_whitelisted(node):
            return
        if self._dataframe_is_lost(node):
            self.add_message("inplace-pandas", node=node)
    except:  # pylint: disable=bare-except
        ExceptionHandler.handle(self, node)
def visit_for(self, for_node: astroid.For):
    """
    Report a loop that grows a constant tensor with ``tf.concat``.

    For every assignment placed directly in the loop body, a message is added
    when the first target is recorded in ``self._variable_types`` as
    ``tf.constant`` / ``tensorflow.constant`` and the assigned value is a
    ``tf.concat`` call according to ``self.infer_call_expression``.

    :param for_node: The For node which is visited.
    """
    try:
        for stmt in for_node.body:
            if not (
                isinstance(stmt, astroid.Assign)
                and len(stmt.targets) > 0
                and hasattr(stmt.targets[0], "name")
            ):
                continue
            target_name = stmt.targets[0].name
            if target_name not in self._variable_types:
                continue
            if self._variable_types[target_name] not in ["tf.constant", "tensorflow.constant"]:
                continue
            if self.infer_call_expression(stmt.value) in ["tf.concat"]:
                self.add_message("tensor-array-tensorflow", node=for_node)
    except:  # pylint: disable = bare-except
        ExceptionHandler.handle(self, for_node)
def visit_call(self, call_node: astroid.Call):
    """
    Report a call that involves a ``.values`` attribute access.

    The message is added at most once per node, either when the node itself
    carries an ``attrname`` equal to ``values`` or when the method is called
    on a ``<something>.values`` expression.

    :param call_node: The Call node which is visited.
    """
    try:
        node_is_values = (
            hasattr(call_node, "attrname") and call_node.attrname == "values"
        )
        # ``x.values.method(...)``: the receiver of the call is ``x.values``.
        if node_is_values or (
            hasattr(call_node, "func")
            and hasattr(call_node.func, "expr")
            and hasattr(call_node.func.expr, "attrname")
            and call_node.func.expr.attrname == "values"
        ):
            self.add_message("dataframe-conversion-pandas", node=call_node)
    except:  # pylint: disable = bare-except
        ExceptionHandler.handle(self, call_node)
def visit_import(self, node: astroid.Import):
    """
    Record whether numpy and any ML library are imported.

    ``self._import_numpy`` and ``self._import_ml_libraries`` are each only
    recomputed while still False. The ML libraries looked for are
    ``sklearn``, ``torch`` and ``tensorflow``.

    :param node: The Import node which is visited.
    """
    try:
        if self._import_numpy is False:
            self._import_numpy = has_import(node, "numpy")
        if self._import_ml_libraries is not False:
            return
        self._import_ml_libraries = (
            has_import(node, "sklearn")
            or has_import(node, "torch")
            or has_import(node, "tensorflow")
        )
    except:  # pylint: disable = bare-except
        ExceptionHandler.handle(self, node)
def visit_call(self, node: astroid.Call):
    """
    Hand every call to a plainly named callable over to the hyperparameter check.

    The check itself lives in ``self.hyperparameter_in_class``, which receives
    the node and the called name. The rules it applies (the original
    documentation mentions a strict and a non-strict mode) are not visible
    from this method.

    :param node: The Call node which is visited.
    """
    try:
        if not hasattr(node, "func"):
            return
        callee = node.func
        if hasattr(callee, "name"):
            self.hyperparameter_in_class(node, callee.name)
    except:  # pylint: disable=bare-except
        ExceptionHandler.handle(self, node)
def visit_compare(self, node: astroid.Compare):
    """
    Report a comparison in which any operand is ``np.nan`` / ``numpy.nan``.

    Every operand is inspected, including those of chained comparisons such
    as ``a == b == np.nan``. At most one "nan-numpy" message is added per
    Compare node.

    :param node: The Compare node which is visited.
    """
    try:
        # ``ops`` holds (operator, operand) pairs for everything right of ``left``.
        operands = [node.left]
        operands.extend(operand for _, operand in node.ops)
        for operand in operands:
            if not isinstance(operand, astroid.Attribute):
                continue
            if operand.attrname != "nan":
                continue
            # The receiver may be something other than a plain name
            # (e.g. ``a.b.nan``); getattr avoids aborting the whole scan then.
            if getattr(operand.expr, "name", None) in ("np", "numpy"):
                self.add_message("nan-numpy", node=node)
                return
    except:  # pylint: disable=bare-except
        ExceptionHandler.handle(self, node)
def visit_subscript(self, subscript_node: astroid.Subscript):
    """
    Report chained indexing on a variable recorded as a pandas DataFrame.

    The chain is followed through ``.value`` attributes down to its base. The
    rule is violated when the base is a name recorded in
    ``self._subscript_types`` as ``pd.DataFrame`` and at least two levels were
    traversed.

    :param subscript_node: The Subscript node which is visited.
    """
    try:
        depth = 0
        base = subscript_node
        # Peel off one level per ``.value`` until the innermost expression is reached.
        while hasattr(base, "value"):
            base = base.value
            depth += 1

        if not hasattr(base, "name"):
            return
        if (
            base.name in self._subscript_types
            and self._subscript_types[base.name] == "pd.DataFrame"
            and depth >= 2
        ):
            self.add_message("chain-indexing-pandas", node=subscript_node)
    except:  # pylint: disable = bare-except
        ExceptionHandler.handle(self, subscript_node)