def test_dotted_path_annotation_local_type_def(
    python_version, import_collision_policy, expected_import, comment, remove_type
):
    """
    A dotted-path docstring type whose first segment (``serializers``) is
    the name of a class defined locally in the same module.

    Args:
        python_version: grammar version passed through to the settings.
        import_collision_policy: the ``ImportCollisionPolicy`` under test.
        expected_import: text expected at the very top of the annotated
            file, before the class definition (e.g. an added import line,
            or empty).
        comment: the type comment expected directly beneath the ``def``
            line, or a falsy value when no annotation is expected.
        remove_type: whether the type is expected to have been removed from
            the docstring's "Args" entry.
    """
    content = '''
class serializers:
    # why would you do this
    pass

def no_op(arg1):
    """
    Args:
        arg1 (serializers.Serializer): blah
    """
    pass
'''

    # The expected comment is followed by a newline plus body indentation so
    # the docstring opener after it lines up. Any falsy value ("" or None)
    # means "no comment" and must render as nothing, never as "None".
    comment = f"{comment}\n    " if comment else ""

    docstring_type = "" if remove_type else " (serializers.Serializer)"

    expected = f'''{expected_import}
class serializers:
    # why would you do this
    pass

def no_op(arg1):
    {comment}"""
    Args:
        arg1{docstring_type}: blah
    """
    pass
'''

    with tempfile.NamedTemporaryFile(suffix=".py") as f:
        with open(f.name, "w") as fw:
            fw.write(content)

        test_settings = override_settings(
            PYTHON_VERSION=python_version,
            ALLOW_UNTYPED_ARGS=False,
            REQUIRE_RETURN_TYPE=False,
            IMPORT_COLLISION_POLICY=import_collision_policy,
            UNPATHED_TYPE_POLICY=UnpathedTypePolicy.FAIL,
        )
        inject.clear_and_configure(configuration_factory(test_settings))

        annotate(
            f.name, in_process=True, interactive=False, write=True, silent=True,
        )

        with open(f.name, "r") as fr:
            annotated = fr.read()

    assert annotated == expected
def test_decorated_function():
    """
    A decorated function gets its type comment directly beneath the ``def``
    line (after the decorator), and the ``typing`` import it needs is added
    at the top of the module.
    """
    source = '''
@whatever(param=val)
def decorated(arg1):
    """
    Args:
        arg1 (Tuple[str, ...]): blah

    Returns:
        Tuple[int, ...]: blah
    """
    return tuple(
        int(arg) for arg in arg1
    )
'''

    want = '''from typing import Tuple


@whatever(param=val)
def decorated(arg1):
    # type: (Tuple[str, ...]) -> Tuple[int, ...]
    """
    Args:
        arg1: blah

    Returns:
        blah
    """
    return tuple(
        int(arg) for arg in arg1
    )
'''

    options = dict(
        ALLOW_UNTYPED_ARGS=False,
        REQUIRE_RETURN_TYPE=False,
        IMPORT_COLLISION_POLICY=ImportCollisionPolicy.IMPORT,
        UNPATHED_TYPE_POLICY=UnpathedTypePolicy.FAIL,
    )

    with tempfile.NamedTemporaryFile(suffix=".py") as tmp:
        path = tmp.name
        with open(path, "w") as handle:
            handle.write(source)

        inject.clear_and_configure(
            configuration_factory(override_settings(**options))
        )
        annotate(path, in_process=True, interactive=False, write=True, silent=True)

        with open(path, "r") as handle:
            got = handle.read()

    assert got == want
def test_require_return_type(require_return_type):
    """
    Docstring has an "Args" block but no "Returns" block.

    With ``REQUIRE_RETURN_TYPE=True`` the function is left untouched. With
    it ``False``, the return type is assumed to be ``None``.

    NOTE: this is a case where assuming ``-> None`` gives the wrong result
    (the function returns its argument), which is an argument for
    ``REQUIRE_RETURN_TYPE=True``.
    TODO: is there any check for return statements we can do via Bowler?
    """
    source = '''
def identity(arg1):
    """
    Args:
        arg1 (Tuple[str, ...]): blah
    """
    return arg1
'''

    if require_return_type:
        # annotation is refused, so the file must come back unchanged
        want = source
    else:
        want = '''from typing import Tuple


def identity(arg1):
    # type: (Tuple[str, ...]) -> None
    """
    Args:
        arg1: blah
    """
    return arg1
'''

    options = dict(
        ALLOW_UNTYPED_ARGS=False,
        REQUIRE_RETURN_TYPE=require_return_type,
        IMPORT_COLLISION_POLICY=ImportCollisionPolicy.IMPORT,
        UNPATHED_TYPE_POLICY=UnpathedTypePolicy.FAIL,
    )

    with tempfile.NamedTemporaryFile(suffix=".py") as tmp:
        path = tmp.name
        with open(path, "w") as handle:
            handle.write(source)

        inject.clear_and_configure(
            configuration_factory(override_settings(**options))
        )
        annotate(path, in_process=True, interactive=False, write=True, silent=True)

        with open(path, "r") as handle:
            got = handle.read()

    assert got == want
def test_allow_untyped_args(allow_untyped_args):
    """
    The arg is documented without a type but the return type is given.

    With ``ALLOW_UNTYPED_ARGS=True`` the function is annotated as
    ``(...) -> <return type>``. Otherwise the file is left unchanged.
    """
    source = '''
def identity(arg1):
    """
    Args:
        arg1: blah

    Returns:
        Tuple[str, ...]: blah
    """
    return arg1
'''

    if not allow_untyped_args:
        # untyped args are rejected, so nothing is annotated
        want = source
    else:
        want = '''from typing import Tuple


def identity(arg1):
    # type: (...) -> Tuple[str, ...]
    """
    Args:
        arg1: blah

    Returns:
        blah
    """
    return arg1
'''

    options = dict(
        ALLOW_UNTYPED_ARGS=allow_untyped_args,
        REQUIRE_RETURN_TYPE=False,
        IMPORT_COLLISION_POLICY=ImportCollisionPolicy.IMPORT,
        UNPATHED_TYPE_POLICY=UnpathedTypePolicy.FAIL,
    )

    with tempfile.NamedTemporaryFile(suffix=".py") as tmp:
        path = tmp.name
        with open(path, "w") as handle:
            handle.write(source)

        inject.clear_and_configure(
            configuration_factory(override_settings(**options))
        )
        annotate(path, in_process=True, interactive=False, write=True, silent=True)

        with open(path, "r") as handle:
            got = handle.read()

    assert got == want
def test_staticmethod():
    """
    On a ``@staticmethod`` the first parameter is an ordinary argument, so
    it *is* included in the generated type comment.
    """
    source = '''
class SomeClass:
    @staticmethod
    def method(obj, whatever):
        """
        Args:
            obj (object)
            whatever (Any)
        """
        pass
'''

    want = '''from typing import Any


class SomeClass:
    @staticmethod
    def method(obj, whatever):
        # type: (object, Any) -> None
        """
        Args:
            obj
            whatever
        """
        pass
'''

    options = dict(
        ALLOW_UNTYPED_ARGS=False,
        REQUIRE_RETURN_TYPE=False,
        IMPORT_COLLISION_POLICY=ImportCollisionPolicy.IMPORT,
        UNPATHED_TYPE_POLICY=UnpathedTypePolicy.FAIL,
    )

    with tempfile.NamedTemporaryFile(suffix=".py") as tmp:
        path = tmp.name
        with open(path, "w") as handle:
            handle.write(source)

        inject.clear_and_configure(
            configuration_factory(override_settings(**options))
        )
        annotate(path, in_process=True, interactive=False, write=True, silent=True)

        with open(path, "r") as handle:
            got = handle.read()

    assert got == want
def test_yields_generator():
    """
    A "Yields" section is converted into a
    ``Generator[<yield type>, None, None]`` return type. The ``typing``
    names it needs are imported.
    """
    source = '''
def generator(arg1):
    """
    Args:
        arg1 (Iterable[int]): blah

    Yields:
        int: blah
    """
    for val in arg1:
        yield val
'''

    want = '''from typing import Generator, Iterable


def generator(arg1):
    # type: (Iterable[int]) -> Generator[int, None, None]
    """
    Args:
        arg1: blah

    Yields:
        blah
    """
    for val in arg1:
        yield val
'''

    options = dict(
        ALLOW_UNTYPED_ARGS=False,
        REQUIRE_RETURN_TYPE=False,
        IMPORT_COLLISION_POLICY=ImportCollisionPolicy.IMPORT,
        UNPATHED_TYPE_POLICY=UnpathedTypePolicy.FAIL,
    )

    with tempfile.NamedTemporaryFile(suffix=".py") as tmp:
        path = tmp.name
        with open(path, "w") as handle:
            handle.write(source)

        inject.clear_and_configure(
            configuration_factory(override_settings(**options))
        )
        annotate(path, in_process=True, interactive=False, write=True, silent=True)

        with open(path, "r") as handle:
            got = handle.read()

    assert got == want
def test_dotted_path_decorator():
    """
    A method with a dotted-path decorator is still annotated.

    NOTE: the example is treated as an instance method, so its first
    parameter is not part of the type comment. The "Returns" block, which
    has no description, is dropped from the docstring.
    """
    source = '''
class SomeClass:
    @some_package.decorator
    def method(cls, obj):
        """
        Args:
            obj (object)

        Returns:
            int
        """
        return 1
'''

    want = '''
class SomeClass:
    @some_package.decorator
    def method(cls, obj):
        # type: (object) -> int
        """
        Args:
            obj
        """
        return 1
'''

    options = dict(
        ALLOW_UNTYPED_ARGS=False,
        REQUIRE_RETURN_TYPE=False,
        IMPORT_COLLISION_POLICY=ImportCollisionPolicy.IMPORT,
        UNPATHED_TYPE_POLICY=UnpathedTypePolicy.FAIL,
    )

    with tempfile.NamedTemporaryFile(suffix=".py") as tmp:
        path = tmp.name
        with open(path, "w") as handle:
            handle.write(source)

        inject.clear_and_configure(
            configuration_factory(override_settings(**options))
        )
        annotate(path, in_process=True, interactive=False, write=True, silent=True)

        with open(path, "r") as handle:
            got = handle.read()

    assert got == want
示例#8
0
def test_find_local_types(python_version):
    """
    ``find_local_types`` over the napoleon fixture file reports every local
    type definition, import and function signature.
    """
    type_defs = {
        "T",
        "TopLevel",
        "InnerClass",
        "SomeTuple",
        "SomeTypedTuple",
        "SomeTypedDict",
        "NewClass",
    }
    # grouped by source package; insertion order matches the fixture listing
    names_to_packages = {
        **dict.fromkeys(("Irrelevant", "Nonsense", "ReallyUnused"), "..sub"),
        "Product": "other.module",
        **dict.fromkeys(
            ("Imported", "ConditionallyImported", "InnerImported"),
            "some.module",
        ),
        "namedtuple": "collections",
        **dict.fromkeys(("NamedTuple", "TypedDict", "TypeVar", "Union"), "typing"),
    }
    # fixture line number -> (function name, arg names)
    signatures = {
        34: ("static", ("cls", "self")),
        39: ("clsmethod", ("self",)),
        49: ("method", ("cls",)),
        53: ("conditionally_defined_method", ("cls",)),
        57: ("first", ("products", "getter")),
        77: ("second", ("products", "getter")),
        97: ("second_inner", ("product", "key", "default")),
        118: ("third", ("product_ids", "user_id")),
        133: ("fourth", ("product_id", "user_id")),
        141: ("fifth", ("product_ids", "user_id")),
        158: ("sixth", ("product_ids",)),
    }
    want = LocalTypes.factory(
        type_defs=type_defs,
        star_imports={"serious"},
        names_to_packages=names_to_packages,
        package_imports={"logging", "nott.so.serious"},
        signatures=signatures,
    )

    inject.clear_and_configure(
        configuration_factory(override_settings(PYTHON_VERSION=python_version))
    )

    assert find_local_types("tests/fixtures/napoleon.py") == want
def test_returns_none(python_version):
    """
    An explicit ``Returns: None`` satisfies ``REQUIRE_RETURN_TYPE=True`` and
    yields a ``-> None`` annotation.
    """
    source = '''
def no_op(arg1):
    """
    Args:
        arg1 (Tuple[str, ...]): blah

    Returns:
        None
    """
    pass
'''

    # "Returns" block omitted since there was no description
    want = '''from typing import Tuple


def no_op(arg1):
    # type: (Tuple[str, ...]) -> None
    """
    Args:
        arg1: blah
    """
    pass
'''

    options = dict(
        PYTHON_VERSION=python_version,
        ALLOW_UNTYPED_ARGS=False,
        REQUIRE_RETURN_TYPE=True,
        IMPORT_COLLISION_POLICY=ImportCollisionPolicy.IMPORT,
        UNPATHED_TYPE_POLICY=UnpathedTypePolicy.FAIL,
    )

    with tempfile.NamedTemporaryFile(suffix=".py") as tmp:
        path = tmp.name
        with open(path, "w") as handle:
            handle.write(source)

        inject.clear_and_configure(
            configuration_factory(override_settings(**options))
        )
        annotate(path, in_process=True, interactive=False, write=True, silent=True)

        with open(path, "r") as handle:
            got = handle.read()

    assert got == want
def test_allow_missing_args_section_no_args_func(allow_untyped_args):
    """
    A function without args needs no "Args" section. A "Returns" type alone
    is enough to annotate it, whatever ``ALLOW_UNTYPED_ARGS`` is set to.
    """
    source = '''
def identity():
    """
    Returns:
        Tuple[str, ...]: blah
    """
    return arg1
'''

    want = '''from typing import Tuple


def identity():
    # type: () -> Tuple[str, ...]
    """
    Returns:
        blah
    """
    return arg1
'''

    options = dict(
        ALLOW_UNTYPED_ARGS=allow_untyped_args,
        REQUIRE_RETURN_TYPE=False,
        IMPORT_COLLISION_POLICY=ImportCollisionPolicy.IMPORT,
        UNPATHED_TYPE_POLICY=UnpathedTypePolicy.FAIL,
    )

    with tempfile.NamedTemporaryFile(suffix=".py") as tmp:
        path = tmp.name
        with open(path, "w") as handle:
            handle.write(source)

        inject.clear_and_configure(
            configuration_factory(override_settings(**options))
        )
        annotate(path, in_process=True, interactive=False, write=True, silent=True)

        with open(path, "r") as handle:
            got = handle.read()

    assert got == want
def test_handle_splat_args():
    """
    Types documented for ``*args`` and ``**kwargs`` are emitted with ``*``
    and ``**`` prefixes in the type comment.
    """
    source = '''
def no_op(arg1, *args, **kwargs):
    """
    Args:
        arg1 (str): blah
        *args (int)
        **kwargs (Tuple[bool, ...])
    """
    return
'''

    want = '''from typing import Tuple


def no_op(arg1, *args, **kwargs):
    # type: (str, *int, **Tuple[bool, ...]) -> None
    """
    Args:
        arg1: blah
        *args
        **kwargs
    """
    return
'''

    options = dict(
        ALLOW_UNTYPED_ARGS=False,
        REQUIRE_RETURN_TYPE=False,
        IMPORT_COLLISION_POLICY=ImportCollisionPolicy.IMPORT,
        UNPATHED_TYPE_POLICY=UnpathedTypePolicy.FAIL,
    )

    with tempfile.NamedTemporaryFile(suffix=".py") as tmp:
        path = tmp.name
        with open(path, "w") as handle:
            handle.write(source)

        inject.clear_and_configure(
            configuration_factory(override_settings(**options))
        )
        annotate(path, in_process=True, interactive=False, write=True, silent=True)

        with open(path, "r") as handle:
            got = handle.read()

    assert got == want
def test_property():
    """
    On a ``@property`` the first parameter is not part of the type comment,
    so the result is ``() -> <return type>``.
    """
    source = '''
class SomeClass:
    @property
    def method(obj):
        """
        Returns:
            int
        """
        return 1
'''

    want = '''
class SomeClass:
    @property
    def method(obj):
        # type: () -> int
        """
        """
        return 1
'''

    options = dict(
        ALLOW_UNTYPED_ARGS=False,
        REQUIRE_RETURN_TYPE=False,
        IMPORT_COLLISION_POLICY=ImportCollisionPolicy.IMPORT,
        UNPATHED_TYPE_POLICY=UnpathedTypePolicy.FAIL,
    )

    with tempfile.NamedTemporaryFile(suffix=".py") as tmp:
        path = tmp.name
        with open(path, "w") as handle:
            handle.write(source)

        inject.clear_and_configure(
            configuration_factory(override_settings(**options))
        )
        annotate(path, in_process=True, interactive=False, write=True, silent=True)

        with open(path, "r") as handle:
            got = handle.read()

    assert got == want
def test_package_imports(python_version, import_line, arg_type):
    """
    A type reachable through the existing ``import_line`` is annotated
    exactly as documented, and no extra import is added.
    """
    source = f'''{import_line}

def no_op(arg1):
    """
    Args:
        arg1 ({arg_type}): blah
    """
    pass
'''

    want = f'''{import_line}

def no_op(arg1):
    # type: ({arg_type}) -> None
    """
    Args:
        arg1: blah
    """
    pass
'''

    options = dict(
        PYTHON_VERSION=python_version,
        ALLOW_UNTYPED_ARGS=False,
        REQUIRE_RETURN_TYPE=False,
        IMPORT_COLLISION_POLICY=ImportCollisionPolicy.IMPORT,
        UNPATHED_TYPE_POLICY=UnpathedTypePolicy.FAIL,
    )

    with tempfile.NamedTemporaryFile(suffix=".py") as tmp:
        path = tmp.name
        with open(path, "w") as handle:
            handle.write(source)

        inject.clear_and_configure(
            configuration_factory(override_settings(**options))
        )
        annotate(path, in_process=True, interactive=False, write=True, silent=True)

        with open(path, "r") as handle:
            got = handle.read()

    assert got == want
def test_arg_annotation_signature_mismatch(signature, arg_annotations):
    """
    The parametrized docstring args do not line up with ``signature``, so
    the function must be left unannotated and the file unchanged.

    ``arg_annotations`` maps each documented arg name to a
    ``(type, description)`` pair.
    """
    arg_lines = []
    for arg_name, (arg_type, arg_desc) in arg_annotations.items():
        arg_lines.append(f"        {arg_name} ({arg_type}): {arg_desc}")
    annotations = "\n".join(arg_lines)

    source = f'''
def no_op({signature}):
    """
    Args:
{annotations}
    """
    pass
'''

    options = dict(
        ALLOW_UNTYPED_ARGS=True,
        REQUIRE_RETURN_TYPE=False,
        IMPORT_COLLISION_POLICY=ImportCollisionPolicy.IMPORT,
        UNPATHED_TYPE_POLICY=UnpathedTypePolicy.FAIL,
    )

    with tempfile.NamedTemporaryFile(suffix=".py") as tmp:
        path = tmp.name
        with open(path, "w") as handle:
            handle.write(source)

        inject.clear_and_configure(
            configuration_factory(override_settings(**options))
        )
        annotate(path, in_process=True, interactive=False, write=True, silent=True)

        with open(path, "r") as handle:
            got = handle.read()

    # in all cases we should not have annotated
    assert got == source
示例#15
0
def main(settings, argv=None):
    """
    Command-line entry point for waterloo.

    Builds the argument parser, parses the command line and dispatches to
    the requested sub-command:

    - ``version`` prints ``__version__``.
    - ``annotate`` copies the parsed options onto ``settings``, reconfigures
      the ``inject`` container from them and runs ``annotate`` over the
      given paths.
    - With no sub-command, the usage string is printed.

    Args:
        settings: mutable settings object. Its current values provide the
            option defaults. For ``annotate`` it is updated in place from
            the parsed arguments.
        argv: optional list of argument strings to parse instead of the
            process command line (``sys.argv[1:]``). Useful for calling
            ``main`` programmatically.
    """
    parser = argparse.ArgumentParser(description=(
        "Convert the type annotations in 'Google-style' docstrings (as"
        " understood by e.g. Sphinx's Napoleon docs plugin)"
        " into PEP-484 type comments which can be checked statically"
        " using `mypy --py2`"), )

    subparsers = parser.add_subparsers(
        dest="subparser",
        title="commands",
    )

    subparsers.add_parser(
        "version",
        help="Echo current waterloo version.",
    )

    annotate_cmd = subparsers.add_parser(
        "annotate",
        help="Annotate a file or set of files.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    annotate_cmd.add_argument(
        "files",
        metavar="F",
        type=str,
        nargs="+",  # required
        help="List of file or directory paths to process.",
    )

    annotation_group = annotate_cmd.add_argument_group("annotation options")
    annotation_group.add_argument(
        "-p",
        "--python-version",
        type=str,
        default=settings.PYTHON_VERSION,
        help="We can refactor either Python 2 or Python 3 source files but "
        "the underlying bowler+fissix libraries need to know which grammar "
        "to use (to know if `print` is a statement or a function). In Py2 "
        "mode, `print` will be auto-detected based on whether a `from "
        "__future__ import print_function` is found. For Py3 files `print` "
        "can only be a function. We also use `parso` library which can "
        "benefit from knowing <major>.<minor> version of your sources.",
    )
    annotation_group.add_argument(
        "-aa",
        "--allow-untyped-args",
        action="store_true",
        default=settings.ALLOW_UNTYPED_ARGS,
        help="If any args or return types are found in the docstring we can "
        "attempt to output a type annotation. If arg types are missing or "
        "incomplete, default behaviour is to raise an error. If this flag "
        "is set we will instead output an annotation like `(...) -> returnT` "
        "which mypy will treat as if all args are `Any`.",
    )
    annotation_group.add_argument(
        "-rr",
        "--require-return-type",
        action="store_true",
        default=settings.REQUIRE_RETURN_TYPE,
        help="If any args or return types are found in the docstring we can "
        "attempt to output a type annotation. If the return type is missing "
        "our default behaviour is to assume function should be annotated as "
        "returning `-> None`. If this flag is set we will instead raise an "
        "error.",
    )
    annotation_group.add_argument(
        "-ic",
        "--import-collision-policy",
        default=settings.IMPORT_COLLISION_POLICY.name,
        choices=[m.name for m in ImportCollisionPolicy],
        help="There are some cases where it is ambiguous whether we need to "
        "add an import for your documented type. This can occur if you gave "
        "a dotted package path but there is already a matching `from package "
        "import *`, or a relative import of same type name. In both cases it "
        "is safest for us to add a new specific import for your type, but it "
        "may be redundant. The default option IMPORT will add imports. The "
        "NO_IMPORT option will annotate without adding imports, and will also "
        "show a warning message. FAIL will print an error and won't add any "
        "annotation.",
    )
    annotation_group.add_argument(
        "-ut",
        "--unpathed-type-policy",
        default=settings.UNPATHED_TYPE_POLICY.name,
        choices=[m.name for m in UnpathedTypePolicy],
        help="There are some cases where we cannot determine an appropriate "
        "import to add - when your types do not have a dotted path and we "
        "can't find a matching type in builtins, typing package or locals. "
        "When policy is IGNORE we will annotate as documented, you will need "
        "to resolve any errors raised by mypy manually. WARN option will "
        "annotate as documented but also display a warning. FAIL will print "
        "an error and won't add any annotation.",
    )

    apply_group = annotate_cmd.add_argument_group("apply options")
    apply_group.add_argument(
        "-w",
        "--write",
        action="store_true",
        default=False,
        help="Whether to apply the changes to target files. Without this "
        "flag set waterloo will just perform a 'dry run'.",
    )
    apply_group.add_argument(
        "-s",
        "--show-diff",
        action="store_true",
        default=False,
        help="Whether to print the hunk diffs to be applied.",
    )
    apply_group.add_argument(
        "-i",
        "--interactive",
        action="store_true",
        default=False,
        help="Whether to prompt about applying each diff hunk.",
    )

    logging_group = annotate_cmd.add_argument_group("logging options")
    logging_group.add_argument(
        "-l",
        "--enable-logging",
        action="store_true",
        default=False,
        help="Enable structured logging to stderr.",
    )
    logging_group.add_argument(
        "-ll",
        "--log-level",
        default=settings.LOG_LEVEL.name,
        choices=[m.name for m in LogLevel],
        help="Set the log level for stderr logging.",
    )
    echo_group = logging_group.add_mutually_exclusive_group()
    echo_group.add_argument(
        "-q",
        "--quiet",
        action="store_true",
        default=False,
        help=
        "'quiet' mode for minimal details on stdout (filenames, summary stats only).",
    )
    echo_group.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        default=True,  # if this defaulted to False we'd have 3 levels
        help=
        "'verbose' mode for informative details on stdout (inc. warnings with suggested remedies).",
    )

    # argv=None makes argparse fall back to sys.argv[1:]
    args = parser.parse_args(argv)

    if args.subparser == "version":
        print(__version__)
        return
    elif args.subparser == "annotate":
        settings.PYTHON_VERSION = args.python_version

        settings.ALLOW_UNTYPED_ARGS = args.allow_untyped_args
        settings.REQUIRE_RETURN_TYPE = args.require_return_type

        # The enum-valued options are offered on the command line by member
        # *name*: their defaults and choices above are ``.name`` strings, so
        # argparse hands back plain strings. Map them back to enum members.
        # This keeps these settings the same type as their defaults, and
        # keeps both LOG_LEVEL branches below (name vs LogLevel.DISABLED)
        # consistent.
        settings.IMPORT_COLLISION_POLICY = ImportCollisionPolicy[
            args.import_collision_policy]
        settings.UNPATHED_TYPE_POLICY = UnpathedTypePolicy[
            args.unpathed_type_policy]

        if args.enable_logging:
            settings.LOG_LEVEL = LogLevel[args.log_level]
        else:
            settings.LOG_LEVEL = LogLevel.DISABLED
        # --verbose is store_true with default=True, so it is always True.
        # In practice VERBOSE_ECHO is switched off only by --quiet.
        settings.VERBOSE_ECHO = args.verbose and not args.quiet

        inject.clear_and_configure(configuration_factory(settings))

        annotate(
            *args.files,
            interactive=args.interactive,
            write=args.write,
            silent=not args.show_diff,
        )
    else:
        parser.print_usage()
def test_fully_typed(
    import_collision_policy: ImportCollisionPolicy,
    unpathed_type_policy: UnpathedTypePolicy,
    type_and_imports: Tuple[DocType, Dict[str, ImportT]],
    settings,
):
    """
    Annotate a generated single-function module whose arg and return types
    are both ``type_name``, and compare the result against an oracle.

    The oracle below re-derives, for every name in ``imports``:

    - the ``ImportStrategy`` the annotator is expected to use under the
      given collision and unpathed-type policies;
    - whether an annotation should be produced at all;
    - which import lines and local class definitions the source needs.

    Args:
        import_collision_policy: policy under test for ambiguous imports.
        import_collision_policy: policy under test for ambiguous imports.
        unpathed_type_policy: policy under test for types that have no
            dotted path and cannot be resolved.
        type_and_imports: ``((type_name, type_map, import_map), imports)``.
            ``type_map[name]`` unpacks to ``(TypeDef, TypeSrc)``,
            ``import_map[name]`` unpacks to ``(package, module)``, and
            ``imports`` maps each name to the import situation placed in the
            generated source. (Presumably drawn from a hypothesis strategy,
            given the ``note`` calls; confirm.)
        settings: base settings object. A deep copy is customised per run.
    """
    (type_name, type_map, import_map), imports = type_and_imports
    note(f"{type_and_imports}")
    note(f"import_collision_policy: {import_collision_policy}")
    note(f"unpathed_type_policy: {unpathed_type_policy}")
    # every name with an import situation must also have type metadata
    assert type_map.keys() >= imports.keys()

    # Oracle state: overall outcome, per-name strategy, and the bare names of
    # local classes the generated source must define.
    expect_annotated = True
    expected_import_strategy = {}
    class_defs_needed = set()
    for name, import_t in imports.items():
        type_def, type_src = type_map[name]
        if isinstance(import_t, NonAmbiguousImport):
            # an unambiguous import already exists in the source: reuse it
            if import_t.import_type is ImportType.FROM_PKG_IMPORT_NAME:
                expected_import_strategy[name] = ImportStrategy.USE_EXISTING
            elif import_t.import_type is ImportType.PACKAGE_ONLY:
                expected_import_strategy[
                    name] = ImportStrategy.USE_EXISTING_DOTTED
            else:
                raise TypeError(import_t.import_type)
        elif isinstance(import_t, AmbiguousImport):
            if import_t.ambiguity is Ambiguity.NAME_CLASH:
                # name clash is always resolved with a dotted import,
                # regardless of the collision policy
                expected_import_strategy[name] = ImportStrategy.ADD_DOTTED
            else:
                # star-import / relative-import ambiguity: the collision
                # policy decides the outcome
                if import_collision_policy is ImportCollisionPolicy.FAIL:
                    expect_annotated = False
                elif import_collision_policy is ImportCollisionPolicy.IMPORT:
                    if import_t.ambiguity is Ambiguity.STAR_IMPORT:
                        expected_import_strategy[
                            name] = ImportStrategy.ADD_FROM
                    elif import_t.ambiguity is Ambiguity.RELATIVE_IMPORT:
                        expected_import_strategy[
                            name] = ImportStrategy.ADD_DOTTED
                    else:
                        raise TypeError(import_t.ambiguity)
                else:
                    # remaining policy: annotate without adding an import
                    expected_import_strategy[
                        name] = ImportStrategy.USE_EXISTING
        elif isinstance(import_t, NoImport):
            if import_t is NoImport.NOT_NEEDED:
                expected_import_strategy[name] = ImportStrategy.USE_EXISTING
            elif import_t is NoImport.LOCAL_CLS:
                # the type is a class defined in the generated module itself
                if type_def is TypeDef.DOTTED_PATH:
                    if import_collision_policy is ImportCollisionPolicy.FAIL:
                        expect_annotated = False
                    expected_import_strategy[name] = (
                        ImportStrategy.ADD_DOTTED if
                        import_collision_policy is ImportCollisionPolicy.IMPORT
                        else ImportStrategy.USE_EXISTING)
                else:
                    expected_import_strategy[
                        name] = ImportStrategy.USE_EXISTING
                # last dotted segment is the bare class name to define
                class_defs_needed.add(name.rsplit(".", 1)[-1])
            elif import_t is NoImport.MISSING:
                if type_def is TypeDef.DOTTED_PATH:
                    expected_import_strategy[name] = ImportStrategy.ADD_FROM
                elif type_src is TypeSrc.TYPING:
                    expected_import_strategy[name] = ImportStrategy.ADD_FROM
                elif type_src is TypeSrc.USER:
                    # unresolvable unpathed user type
                    if unpathed_type_policy is UnpathedTypePolicy.FAIL:
                        expect_annotated = False
                    expected_import_strategy[
                        name] = ImportStrategy.USE_EXISTING
                else:
                    # builtin / non-type
                    expected_import_strategy[
                        name] = ImportStrategy.USE_EXISTING
            else:
                raise TypeError(import_t)
        else:
            raise TypeError(import_t)

    note(f"expect_annotated: {expect_annotated}")
    note(f"expected_import_strategy: {expected_import_strategy}")

    # Build the input module: existing imports, local class defs, and the
    # function itself, joined by blank lines (empty parts are skipped).
    import_lines = "\n".join(i.import_statement for i in imports.values()
                             if not isinstance(i, NoImport))

    class_defs = "\n\n\n".join(f"""
class {cls_name}(object):
    pass""" for cls_name in class_defs_needed)

    funcdef = f'''
def identity(arg1):
    """
    Args:
        arg1 ({type_name}): blah

    Returns:
        {type_name}: blah
    """
    return arg1
'''

    content = "\n\n".join(var for var in (import_lines, class_defs, funcdef)
                          if var)
    note(content)

    # strategies that rely on imports already present in the source
    NO_ADD_STRATEGIES = {
        ImportStrategy.USE_EXISTING,
        ImportStrategy.USE_EXISTING_DOTTED,
    }
    if expect_annotated:
        # add imports expected to have been added
        imports_to_add = []
        for name, strategy in expected_import_strategy.items():
            if strategy in NO_ADD_STRATEGIES:
                continue
            package, module = import_map[name]
            if strategy is ImportStrategy.ADD_FROM:
                imports_to_add.append(f"from {package} import {module}")
            elif strategy is ImportStrategy.ADD_DOTTED:
                imports_to_add.append(f"import {package}")

        # render the documented type the way the annotator should emit it
        atom = type_atom.parse(type_name)
        annotated_type_name = atom.to_annotation(expected_import_strategy)

        added_imports = "\n".join(imports_to_add)
    else:
        added_imports = ""

    # When no annotation is expected, funcdef stays as the input version, so
    # the expected output equals the input content.
    if expect_annotated:
        funcdef = f'''
def identity(arg1):
    # type: ({annotated_type_name}) -> {annotated_type_name}
    """
    Args:
        arg1: blah

    Returns:
        blah
    """
    return arg1
'''

    # added imports are expected to follow any pre-existing import lines
    if added_imports:
        if import_lines:
            import_lines = f"{import_lines}\n{added_imports}"
        else:
            import_lines = added_imports

    expected = "\n\n".join(var for var in (import_lines, class_defs, funcdef)
                           if var)
    note(expected)

    with tempfile.NamedTemporaryFile(suffix=".py") as f:
        with open(f.name, "w") as fw:
            fw.write(content)

        test_settings = settings.copy(deep=True)
        test_settings.ALLOW_UNTYPED_ARGS = False
        test_settings.REQUIRE_RETURN_TYPE = False
        test_settings.IMPORT_COLLISION_POLICY = import_collision_policy
        test_settings.UNPATHED_TYPE_POLICY = unpathed_type_policy

        inject.clear_and_configure(configuration_factory(test_settings))

        annotate(
            f.name,
            in_process=True,
            interactive=False,
            write=True,
            silent=True,
        )

        with open(f.name, "r") as fr:
            annotated = fr.read()

    assert annotated == expected
def test_py3_syntax():
    """
    Source containing a ``print(...)`` call with a keyword argument parses
    with ``PYTHON_VERSION="3.6"``. Dotted-path types get a ``from`` import
    and are annotated by their bare name.
    """
    source = '''
class OtherClass(object):
    raise ValueError("WTF python 2")


def one(arg2):
    """
    Args:
        arg2 (str): blah

    Returns:
        my.module.SomeClass[str]: blah
    """
    try:
        SomeClass([])
    except TypeError as e:
        pass
    return SomeClass(arg2)


def two(arg1):
    """
    Args:
        arg1 (Tuple[str, ...]): blah

    Returns:
        Tuple[my.module.SomeClass[str], ...]: blah
    """
    print("print function with kwarg", end='end')
    return tuple(
        one(arg) for arg in arg1
    )
'''

    want = '''from my.module import SomeClass
from typing import Tuple


class OtherClass(object):
    raise ValueError("WTF python 2")


def one(arg2):
    # type: (str) -> SomeClass[str]
    """
    Args:
        arg2: blah

    Returns:
        blah
    """
    try:
        SomeClass([])
    except TypeError as e:
        pass
    return SomeClass(arg2)


def two(arg1):
    # type: (Tuple[str, ...]) -> Tuple[SomeClass[str], ...]
    """
    Args:
        arg1: blah

    Returns:
        blah
    """
    print("print function with kwarg", end='end')
    return tuple(
        one(arg) for arg in arg1
    )
'''

    options = dict(
        PYTHON_VERSION="3.6",
        ALLOW_UNTYPED_ARGS=False,
        REQUIRE_RETURN_TYPE=False,
        IMPORT_COLLISION_POLICY=ImportCollisionPolicy.IMPORT,
        UNPATHED_TYPE_POLICY=UnpathedTypePolicy.FAIL,
    )

    with tempfile.NamedTemporaryFile(suffix=".py") as tmp:
        path = tmp.name
        with open(path, "w") as handle:
            handle.write(source)

        inject.clear_and_configure(
            configuration_factory(override_settings(**options))
        )
        annotate(path, in_process=True, interactive=False, write=True, silent=True)

        with open(path, "r") as handle:
            got = handle.read()

    assert got == want
def test_arg_annotation_signature_validate(signature, arg_annotations):
    """
    Docstring args that match the parametrized ``signature`` (including
    ``*``/``**`` splat args) are annotated. Their types move into the type
    comment and are stripped from the docstring.

    ``arg_annotations`` maps each documented arg name to a
    ``(type, description)`` pair.
    """
    # A single pass builds both the input docstring lines and an
    # independently derived "oracle" of the expected output.
    documented = []
    stripped = []
    comment_types = []
    for arg_name, (arg_type, arg_desc) in arg_annotations.items():
        documented.append(f"        {arg_name} ({arg_type}): {arg_desc}")
        stripped.append(f"        {arg_name}: {arg_desc}")
        # splat args carry their star prefix into the type comment
        if arg_name.startswith("**"):
            comment_types.append(f"**{arg_type}")
        elif arg_name.startswith("*"):
            comment_types.append(f"*{arg_type}")
        else:
            comment_types.append(arg_type)

    annotations = "\n".join(documented)
    source = f'''
def no_op({signature}):
    """
    Args:
{annotations}
    """
    pass
'''

    stripped_annotations = "\n".join(stripped)
    str_types = ", ".join(comment_types)
    type_comment = f"# type: ({str_types}) -> None"

    # only builtin types in examples, no imports needed
    want = f'''
def no_op({signature}):
    {type_comment}
    """
    Args:
{stripped_annotations}
    """
    pass
'''

    options = dict(
        PYTHON_VERSION="3.8",  # keyword-only args
        ALLOW_UNTYPED_ARGS=True,
        REQUIRE_RETURN_TYPE=False,
        IMPORT_COLLISION_POLICY=ImportCollisionPolicy.IMPORT,
        UNPATHED_TYPE_POLICY=UnpathedTypePolicy.FAIL,
    )

    with tempfile.NamedTemporaryFile(suffix=".py") as tmp:
        path = tmp.name
        with open(path, "w") as handle:
            handle.write(source)

        inject.clear_and_configure(
            configuration_factory(override_settings(**options))
        )
        annotate(path, in_process=True, interactive=False, write=True, silent=True)

        with open(path, "r") as handle:
            got = handle.read()

    assert got == want