示例#1
0
def test_arg_value_absent_keyword():
    # Looking up a keyword that the call does not contain must raise KeyError,
    # and the message must name both the keyword and the call source.
    source = "MXNet(entry_point='run')"
    expected_message = "arg 'framework_version' not found in call: {}".format(source)

    with pytest.raises(KeyError) as excinfo:
        parsing.arg_value(ast_call(source), "framework_version")

    assert expected_message in str(excinfo.value)
示例#2
0
def test_arg_value():
    # (call source, keyword to read, expected value) -- compared with ``==``:
    # string literals come back as str, bare names as their identifier, numbers as int.
    equality_cases = (
        ("MXNet(framework_version='1.6.0')", "framework_version", "1.6.0"),
        ("MXNet(framework_version=mxnet_version)", "framework_version", "mxnet_version"),
        ("MXNet(instance_count=1)", "instance_count", 1),
    )
    for source, keyword, expected in equality_cases:
        assert expected == parsing.arg_value(ast_call(source), keyword)

    # Singleton constants are checked by identity rather than equality.
    identity_cases = (
        ("MXNet(enable_network_isolation=True)", "enable_network_isolation", True),
        ("MXNet(source_dir=None)", "source_dir", None),
    )
    for source, keyword, expected in identity_cases:
        assert parsing.arg_value(ast_call(source), keyword) is expected
def has_arg(node, arg):
    """Checks if the call has the given argument with a non-``None`` value.

    A keyword that is missing from the call and a keyword explicitly set to
    ``None`` (e.g. ``source_dir=None``) are both reported as not present.

    Args:
        node (ast.Call): a node that represents a function call. For more,
            see https://docs.python.org/3/library/ast.html#abstract-grammar.
        arg (str): the name of the argument.

    Returns:
        bool: ``True`` if the node has the given argument and the value read by
            ``parsing.arg_value`` is not ``None``; ``False`` otherwise.
    """
    try:
        value = parsing.arg_value(node, arg)
    except KeyError:
        # parsing.arg_value signals a keyword absent from the call with KeyError.
        return False
    return value is not None
    def modify_node(self, node):
        """Modifies the ``ast.Call`` node's keywords to include ``framework_version`` and,
        where applicable, ``py_version``.

        The ``framework_version`` value is determined by the framework:

        - Chainer: "4.1.0"
        - MXNet: "1.2.0"
        - PyTorch: "0.4.0"
        - SKLearn: "0.20.0"
        - TensorFlow: "1.11.0"

        The ``py_version`` value is determined by the framework, framework_version, and if it is a
        model, whether the model accepts a py_version.

        A keyword that is present but explicitly set to ``None`` (e.g. ``framework_version=None``)
        is treated as not supplied. Its value is replaced in place of appending a second copy,
        because a call with a repeated keyword argument is a ``SyntaxError``.

        Args:
            node (ast.Call): a node that represents the constructor of a framework class.

        Returns:
            ast.AST: the original node, which has been potentially modified.
        """

        def set_string_keyword(name, value):
            # matching.has_arg() reports ``name=None`` as absent, so a None-valued
            # keyword may still be in the list; drop it before appending so the
            # rewritten call never contains the same keyword twice.
            node.keywords[:] = [kw for kw in node.keywords if kw.arg != name]
            node.keywords.append(ast.keyword(arg=name, value=ast.Str(s=value)))

        framework, is_model = _framework_from_node(node)

        # if framework_version is not supplied (or is None), get default and set keyword
        if matching.has_arg(node, FRAMEWORK_ARG):
            framework_version = parsing.arg_value(node, FRAMEWORK_ARG)
        else:
            framework_version = FRAMEWORK_DEFAULTS[framework]
            set_string_keyword(FRAMEWORK_ARG, framework_version)

        # if py_version is not supplied (or is None), get a conditional default,
        # and if not None, set keyword
        if not matching.has_arg(node, PY_ARG):
            py_version = _py_version_defaults(framework, framework_version, is_model)
            if py_version:
                set_string_keyword(PY_ARG, py_version)
        return node
def _version_args_needed(node):
    """Determines if image_arg or version_arg was supplied

    Applies similar logic as ``validate_version_or_image_args``: returns ``False``
    when an image argument is present; otherwise returns ``True`` when
    ``framework_version`` is missing or ``None``, or when a ``py_version`` default
    exists for this framework/version/model combination but the call has no
    (non-``None``) ``py_version``.
    """
    # An explicit image argument makes the version arguments unnecessary.
    if matching.has_arg(node, IMAGE_ARG):
        return False

    # A missing (or None-valued) framework_version always has to be filled in.
    if not matching.has_arg(node, FRAMEWORK_ARG):
        return True

    framework_version = parsing.arg_value(node, FRAMEWORK_ARG)

    # py_version is only required when the framework/model combination has a default for it.
    framework, is_model = _framework_from_node(node)
    py_version_expected = _py_version_defaults(framework, framework_version, is_model)
    return bool(py_version_expected) and not matching.has_arg(node, PY_ARG)