def modify_node(self, node):
    """Ensure the ``ast.Call`` node carries version keywords.

    When ``framework_version`` is not already passed, the default for the
    detected framework is looked up in ``FRAMEWORK_DEFAULTS`` and appended
    as a keyword. The documented defaults are:

    - Chainer: "4.1.0"
    - MXNet: "1.2.0"
    - PyTorch: "0.4.0"
    - SKLearn: "0.20.0"
    - TensorFlow: "1.11.0"

    When ``py_version`` is not already passed, a default is requested from
    ``_py_version_defaults`` using the framework, the framework version and
    whether the class is a model; the keyword is appended only if that
    default is truthy.

    Args:
        node (ast.Call): a node that represents the constructor of a
            framework class.

    Returns:
        ast.AST: the original node, which has been potentially modified.
    """
    framework, is_model = _framework_from_node(node)

    if not matching.has_arg(node, FRAMEWORK_ARG):
        # No explicit framework_version: fall back to the framework default
        # and record it on the call so it becomes explicit.
        framework_version = FRAMEWORK_DEFAULTS[framework]
        version_keyword = ast.keyword(arg=FRAMEWORK_ARG, value=ast.Str(s=framework_version))
        node.keywords.append(version_keyword)
    else:
        framework_version = parsing.arg_value(node, FRAMEWORK_ARG)

    # An explicit py_version is left alone.
    if matching.has_arg(node, PY_ARG):
        return node

    py_version = _py_version_defaults(framework, framework_version, is_model)
    if py_version:
        py_keyword = ast.keyword(arg=PY_ARG, value=ast.Str(s=py_version))
        node.keywords.append(py_keyword)

    return node
def _version_args_needed(node):
    """Report whether version arguments still have to be added to the call.

    Applies similar logic as ``validate_version_or_image_args``.

    Returns ``False`` when the image argument is present, ``True`` when the
    framework version argument is absent, and otherwise ``True`` only if a
    ``py_version`` is expected for this framework/version/model combination
    but has not been passed.
    """
    # An explicit image makes the version arguments unnecessary.
    if matching.has_arg(node, IMAGE_ARG):
        return False

    # Without a framework_version keyword, the arguments must be supplied.
    if not matching.has_arg(node, FRAMEWORK_ARG):
        return True

    supplied_version = parsing.arg_value(node, FRAMEWORK_ARG)

    # Whether py_version is expected depends on the framework and on
    # whether the call constructs a model.
    framework, is_model = _framework_from_node(node)
    py_default = _py_version_defaults(framework, supplied_version, is_model)

    return bool(py_default) and not matching.has_arg(node, PY_ARG)
def node_should_be_modified(self, node):
    """Decide whether this call is one whose parameter must be renamed.

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.

    Returns:
        bool: If the ``ast.Call`` matches the relevant function calls and
            contains the parameter to be renamed.
    """
    is_relevant_call = matching.matches_any(node, self.calls_to_modify)
    if not is_relevant_call:
        # Hand back the falsy match result itself, as a short-circuit would.
        return is_relevant_call
    return matching.has_arg(node, self.old_param_name)
def test_has_arg():
    """``matching.has_arg`` is true only when the keyword is on the call."""
    arg_name = "framework_version"

    call_with_arg = ast_call("MXNet(framework_version=mxnet_version)")
    assert matching.has_arg(call_with_arg, arg_name)

    call_without_arg = ast_call("MXNet()")
    assert not matching.has_arg(call_without_arg, arg_name)