def test_arg_value_absent_keyword():
    """``arg_value`` raises a descriptive ``KeyError`` when the keyword is missing."""
    source = "MXNet(entry_point='run')"
    expected_message = "arg 'framework_version' not found in call: {}".format(source)

    with pytest.raises(KeyError) as exc_info:
        parsing.arg_value(ast_call(source), "framework_version")

    assert expected_message in str(exc_info.value)
def test_arg_value():
    """``arg_value`` returns the literal (or variable name) bound to a keyword."""
    # Values that must compare equal to the expected result.
    equality_cases = (
        ("MXNet(framework_version='1.6.0')", "framework_version", "1.6.0"),
        ("MXNet(framework_version=mxnet_version)", "framework_version", "mxnet_version"),
        ("MXNet(instance_count=1)", "instance_count", 1),
    )
    for source, keyword, expected in equality_cases:
        assert expected == parsing.arg_value(ast_call(source), keyword)

    # Singletons are checked by identity so that e.g. ``1`` cannot stand in for ``True``.
    identity_cases = (
        ("MXNet(enable_network_isolation=True)", "enable_network_isolation", True),
        ("MXNet(source_dir=None)", "source_dir", None),
    )
    for source, keyword, expected in identity_cases:
        assert parsing.arg_value(ast_call(source), keyword) is expected
def has_arg(node, arg):
    """Checks whether the call supplies a non-``None`` value for the given argument.

    A keyword that is present but explicitly set to ``None`` is reported as
    absent, so callers can treat ``foo(arg=None)`` the same as ``foo()``.

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.
        arg (str): the name of the argument.

    Returns:
        bool: True if the node has the argument with a value other than ``None``.
    """
    try:
        value = parsing.arg_value(node, arg)
    except KeyError:
        # ``parsing.arg_value`` signals a missing keyword with ``KeyError``.
        return False
    return value is not None
def modify_node(self, node):
    """Modifies the ``ast.Call`` node's keywords to include ``framework_version``
    and, when needed, ``py_version``.

    If ``framework_version`` is not supplied, a default determined by the
    framework is used:

        - Chainer: "4.1.0"
        - MXNet: "1.2.0"
        - PyTorch: "0.4.0"
        - SKLearn: "0.20.0"
        - TensorFlow: "1.11.0"

    If ``py_version`` is not supplied, a default is derived from the framework,
    the framework version, and (for models) whether the model accepts a
    ``py_version``. No keyword is added when that default is falsy.

    ``matching.has_arg`` reports a keyword explicitly set to ``None`` as absent.
    In that case the existing keyword's value is overwritten in place instead of
    appending a second keyword of the same name. A repeated keyword argument is
    a ``SyntaxError`` in the regenerated source.

    Args:
        node (ast.Call): a node that represents the constructor of a framework class.

    Returns:
        ast.AST: the original node, which has been potentially modified.
    """
    framework, is_model = _framework_from_node(node)

    def _set_str_keyword(name, value):
        """Set keyword ``name`` to string ``value``, reusing an existing keyword if present."""
        new_value = ast.Str(s=value)
        for keyword in node.keywords:
            # ``keyword.arg`` is None for ``**kwargs`` entries, which never match.
            if keyword.arg == name:
                keyword.value = new_value
                return
        node.keywords.append(ast.keyword(arg=name, value=new_value))

    # if framework_version is not supplied (or is None), use the framework default
    if matching.has_arg(node, FRAMEWORK_ARG):
        framework_version = parsing.arg_value(node, FRAMEWORK_ARG)
    else:
        framework_version = FRAMEWORK_DEFAULTS[framework]
        _set_str_keyword(FRAMEWORK_ARG, framework_version)

    # if py_version is not supplied (or is None), use a conditional default when one exists
    if not matching.has_arg(node, PY_ARG):
        py_version = _py_version_defaults(framework, framework_version, is_model)
        if py_version:
            _set_str_keyword(PY_ARG, py_version)

    return node
def _version_args_needed(node):
    """Determines whether version arguments still have to be added to the call.

    Mirrors the logic of ``validate_version_or_image_args``: an image argument
    makes version arguments unnecessary. Otherwise ``framework_version`` must be
    present, and ``py_version`` must be present whenever a default exists for
    this framework/version/model combination.

    Args:
        node (ast.Call): a node that represents the constructor of a framework class.

    Returns:
        bool: True if version arguments need to be supplied.
    """
    # An image argument replaces the need for any version arguments.
    if matching.has_arg(node, IMAGE_ARG):
        return False

    # A missing (or None) framework_version always has to be supplied.
    if not matching.has_arg(node, FRAMEWORK_ARG):
        return True

    framework_version = parsing.arg_value(node, FRAMEWORK_ARG)
    framework, is_model = _framework_from_node(node)

    # Whether py_version is expected depends on the framework, its version, and model-ness.
    if not _py_version_defaults(framework, framework_version, is_model):
        return False
    return not matching.has_arg(node, PY_ARG)