Пример #1
0
def test_override_pyproject_toml(platform, monkeypatch, intercepted_build_args, fake_package_dir):
    """``requires-python`` in pyproject.toml must drive the build selector."""
    pyproject_body = textwrap.dedent(
        """
        [project]
        requires-python = ">=3.8"
        """
    )
    fake_package_dir.joinpath("pyproject.toml").write_text(pyproject_body)

    main()

    build_options = intercepted_build_args.args[0]
    selector = build_options.globals.build_selector

    assert selector.requires_python == SpecifierSet(">=3.8")

    # cp39 satisfies ">=3.8"; cp36 does not.
    assert selector("cp39-win32")
    assert not selector("cp36-win32")
Пример #2
0
def read_spec(lines):
    """Parse a textual candidate/dependency listing.

    Lines that do not start with a space are ``name version`` candidate
    headers; lines starting with a space are ``name specifier`` dependencies
    of the most recent header. Empty lines and ``#`` comments are skipped.

    Returns a dict mapping each ``Candidate`` to a set of ``Requirement``.
    Raises RuntimeError if a dependency line appears before any header.
    """
    result = {}
    current = None
    for raw in lines:
        if not raw or raw.startswith("#"):
            continue
        if raw.startswith(" "):
            if current is None:
                raise RuntimeError(
                    "Spec has dependencies before first candidate"
                )
            dep_name, dep_spec = splitstrip(raw, 2)
            dep_spec = SpecifierSet(dep_spec)
            result[current].add(Requirement(dep_name, dep_spec))
            continue
        cand_name, cand_version = splitstrip(raw, 2)
        current = Candidate(cand_name, Version(cand_version))
        result[current] = set()
    return result
Пример #3
0
def check(packages, key, db_mirror, cached, ignore_ids, proxy):
    """Return the known vulnerabilities affecting ``packages``.

    :param packages: parsed requirements; ``RequirementFile`` entries
        (unresolved recursive includes) are skipped.
    :param key: API key; when falsy the ``SAFETY_API_KEY`` environment
        variable is used instead.
    :param db_mirror: database location forwarded to ``fetch_database``.
    :param cached: forwarded to ``fetch_database``.
    :param ignore_ids: vulnerability ids (with the ``pyup.io-`` prefix
        stripped) that must not be reported.
    :param proxy: forwarded to ``fetch_database``.
    :return: list of ``Vulnerability`` records.
    """
    key = key if key else os.environ.get("SAFETY_API_KEY", False)
    db = fetch_database(key=key, db=db_mirror, cached=cached, proxy=proxy)
    # The full database is fetched lazily, at most once, on the first match.
    db_full = None
    vulnerable_packages = frozenset(db.keys())
    vulnerable = []
    for pkg in packages:
        # Ignore recursive files not resolved
        if isinstance(pkg, RequirementFile):
            continue

        # normalize the package name: the safety-db converts underscores to
        # dashes and uses lowercase
        name = pkg.key.replace("_", "-").lower()
        if name not in vulnerable_packages:
            continue

        # we have a candidate here, build the spec set
        for specifier in db[name]:
            spec_set = SpecifierSet(specifiers=specifier)
            if not spec_set.contains(pkg.version):
                continue
            # ``is None`` (not falsiness) so an empty full database is not
            # re-fetched on every match.
            if db_full is None:
                db_full = fetch_database(full=True, key=key, db=db_mirror, cached=cached, proxy=proxy)
            for data in get_vulnerabilities(pkg=name, spec=specifier, db=db_full):
                # An entry without an id used to crash on ``None.replace``;
                # map it to "" so the guard below skips it as unreportable.
                vuln_id = (data.get("id") or "").replace("pyup.io-", "")
                if not vuln_id or vuln_id in ignore_ids:
                    continue
                cve_id = data.get("cve")
                if cve_id:
                    # Only the first of several comma-separated CVE ids is used.
                    cve_id = cve_id.split(",")[0].strip()
                cve_meta = db_full.get("$meta", {}).get("cve", {}).get(cve_id, {})
                vulnerable.append(
                    Vulnerability(
                        name=name,
                        spec=specifier,
                        version=pkg.version,
                        advisory=data.get("advisory"),
                        vuln_id=vuln_id,
                        cvssv2=cve_meta.get("cvssv2", None),
                        cvssv3=cve_meta.get("cvssv3", None),
                    )
                )
    return vulnerable
Пример #4
0
def format_package(
    graph: DirectedGraph,
    package: Package,
    required: str = "",
    prefix: str = "",
    visited=None,
) -> str:
    """Render ``package`` and its dependency subtree as tree text.

    :param graph: the dependency graph
    :param package: the package instance
    :param required: the version required by its parent
    :param prefix: prefix text for children
    :param visited: names of packages already on the current path; a package
        seen again is labelled ``[circular]`` and not expanded
    :return: the rendered lines joined into one string
    """
    if visited is None:
        visited = set()

    # Version label: missing -> red placeholder; violates the parent's
    # requirement -> red; otherwise yellow.
    if not package.version:
        version = termui.red("[ not installed ]")
    elif (required and required not in ("Any", "This project")
          and not SpecifierSet(required).contains(package.version)):
        version = termui.red(package.version)
    else:
        version = termui.yellow(package.version)

    is_cycle = package.name in visited
    if is_cycle:
        version = termui.red("[circular]")
    required_label = f"[ required: {required} ]" if required else "[ Not required ]"
    header = f"{termui.green(package.name, bold=True)} {version} {required_label}\n"
    if is_cycle:
        return header

    visited.add(package.name)
    children = sorted(graph.iter_children(package), key=lambda p: p.name)
    last_index = len(children) - 1
    pieces = [header]
    for index, child in enumerate(children):
        at_end = index == last_index
        branch = LAST_CHILD if at_end else NON_LAST_CHILD
        indent = LAST_PREFIX if at_end else NON_LAST_PREFIX
        child_required = str(package.requirements[child.name].specifier or "Any")
        # Each branch gets its own copy so only ancestors count as "visited".
        subtree = format_package(graph, child, child_required,
                                 prefix + indent, visited.copy())
        pieces.append(prefix + branch + subtree)
    return "".join(pieces)
Пример #5
0
def test_smdataparallel_mnist_script_mode_multigpu(ecr_image, instance_type,
                                                   py_version,
                                                   sagemaker_session, tmpdir):
    """
    Run the SM Distributed DataParallel MNIST script-mode job on a single node.

    Skipped when the image's PyTorch version is below 1.6. The
    ``instance_type`` fixture value is ignored; ml.p3.16xlarge is always used.
    """
    _, framework_version = get_framework_and_version_from_tag(ecr_image)
    too_old = Version(framework_version) in SpecifierSet("<1.6")
    if too_old:
        pytest.skip("Data Parallelism is supported on PyTorch v1.6 and above")

    estimator_kwargs = dict(
        entry_point='smdataparallel_mnist_script_mode.sh',
        role='SageMakerRole',
        image_uri=ecr_image,
        source_dir=mnist_path,
        instance_count=1,
        instance_type="ml.p3.16xlarge",
        sagemaker_session=sagemaker_session,
    )
    with timeout(minutes=DEFAULT_TIMEOUT):
        estimator = PyTorch(**estimator_kwargs)
        estimator.fit()
Пример #6
0
def get_all_candidates(requirement):
    """Yield wheel ``Candidate``s for ``requirement`` from the PyPI simple index.

    A link is skipped when its ``data-requires-python`` excludes
    ``PYTHON_VERSION``, when the file is not a ``.whl``, or when the version
    taken from the filename is outside ``requirement.specifier``.
    """
    url = f"https://pypi.org/simple/{requirement.key}"
    # Fetch inside ``with`` so the session (and its connection pool) is closed;
    # the response body is already downloaded and is parsed afterwards.
    with HTMLSession() as session:
        resp = session.get(url)
    for a in resp.html.find('a'):
        link = a.attrs['href']
        python_requires = a.attrs.get('data-requires-python')
        filename = a.text

        if python_requires:
            spec = SpecifierSet(python_requires)
            if not spec.contains(PYTHON_VERSION):
                # Discard candidates that don't match the Python version.
                continue

        if not filename.endswith(".whl"):
            # Only parse wheels for this demo
            continue
        # Wheel filenames begin with "<name>-<version>-".
        name, version = filename.split("-")[:2]
        if requirement.specifier.contains(version):
            yield Candidate(name, version, link)
Пример #7
0
def fix_requires_python_marker(requires_python):
    """Translate a ``requires_python`` specifier string into a PEP 508 marker.

    Specifiers are grouped by operator and rendered as
    ``python_version<op>'v1,v2'`` clauses joined with `` and ``; the result is
    parsed by attaching it to a placeholder requirement.

    :param requires_python: specifier string such as ``">=2.6,<3.0"``.
    :return: the ``marker`` of ``PackagingRequirement('fakepkg; <marker>')``.

    NOTE(review): when ``requires_python`` does not start with an operator,
    ``marker_str`` stays empty and ``'fakepkg; '`` is parsed, which presumably
    fails. The original comment hints that an ``==`` fallback was intended;
    confirm with callers.
    """
    from packaging.requirements import Requirement as PackagingRequirement
    marker_str = ''
    # Only strings that start with a specifier operator are translated.
    if any(requires_python.startswith(op) for op in Specifier._operators.keys()):
        spec_dict = defaultdict(set)
        specifierset = list(SpecifierSet(requires_python))
        # for multiple specifiers, the correct way to represent that in
        # a specifierset is `Requirement('fakepkg; python_version<"3.0,>=2.6"')`
        marker_key = Variable('python_version')
        for spec in specifierset:
            operator, val = spec._spec
            cleaned_val = Value(val).serialize().replace('"', "")
            spec_dict[Op(operator).serialize()].add(cleaned_val)
        # Sort each value set: iteration order of a set of strings changes
        # between interpreter runs (hash randomization), which made the
        # marker text nondeterministic.
        marker_str = ' and '.join([
            "{0}{1}'{2}'".format(marker_key.serialize(), op, ','.join(sorted(vals)))
            for op, vals in spec_dict.items()
        ])
    marker_to_add = PackagingRequirement('fakepkg; {0}'.format(marker_str)).marker
    return marker_to_add
Пример #8
0
    def _select_version(
        self,
        semantic_version_str: str,
        available_versions: List[Version],
    ) -> Optional[str]:
        """Return the highest available version matching ``semantic_version_str``.

        Args:
            semantic_version_str (str): ``"*"`` accepts every available
                version; anything else is matched through an ``==`` specifier.
            available_versions (List[Version]): list of available versions.

        Returns:
            Optional[str]: the highest match as a string, or None if there is none.
        """
        if semantic_version_str == "*":
            candidates = available_versions
        else:
            spec = SpecifierSet(f"=={semantic_version_str}")
            candidates = list(spec.filter(available_versions))
        best = max(candidates, default=None)
        return None if best is None else str(best)
Пример #9
0
    def build_extensions(self):
        """Verify a compatible libcommute is present, then build C++17 extensions.

        Raises:
            RuntimeError: if the libcommute headers cannot be found, or if
                their version is outside ``comp_libcommute_versions``.
        """
        found_version = self.find_libcommute()
        if not found_version:
            raise RuntimeError(
                "Could not find libcommute headers. "
                "Use the LIBCOMMUTE_INCLUDEDIR environment variable "
                "to specify location of libcommute include directory.")
        print("Found libcommute version " + found_version)

        parsed_version = Version(found_version)
        if parsed_version not in SpecifierSet(comp_libcommute_versions):
            raise RuntimeError(
                "Incompatible libcommute version %s (required %s)." %
                (found_version, comp_libcommute_versions))

        for extension in self.extensions:
            extension.include_dirs.append(self.libcommute_includedir)
            extension.cxx_std = 17

        build_ext.build_extensions(self)
Пример #10
0
def test_extension_proxy_legacy():
    """An ExtensionProxy around a legacy extension exposes legacy defaults."""
    legacy = LegacyExtension()
    proxy = ExtensionProxy(legacy, package_name="foo", package_version="1.2.3")
    class_path = "asdf.tests.test_extension.LegacyExtension"

    # New-style attributes are expected to be empty for a legacy extension.
    assert proxy.extension_uri is None
    assert proxy.legacy_class_names == {class_path}
    assert proxy.asdf_standard_requirement == SpecifierSet()
    assert proxy.converters == []
    assert proxy.tags == []

    # Legacy-specific attributes.
    assert proxy.types == [LegacyType]
    assert proxy.tag_mapping == LegacyExtension.tag_mapping
    assert proxy.url_mapping == LegacyExtension.url_mapping

    # Proxy bookkeeping.
    assert proxy.delegate is legacy
    assert proxy.legacy is True
    assert proxy.package_name == "foo"
    assert proxy.package_version == "1.2.3"
    assert proxy.class_name == class_path
Пример #11
0
def test_override_setup_py_simple(platform, monkeypatch,
                                  intercepted_build_args, fake_package_dir):
    """``python_requires`` in setup.py must drive the build selector."""
    setup_source = textwrap.dedent("""
        from setuptools import setup

        setup(
            name = "other",
            python_requires = ">=3.7",
        )
        """)
    fake_package_dir.joinpath("setup.py").write_text(setup_source)

    main()

    selector = intercepted_build_args.args[0].build_selector

    assert selector.requires_python == SpecifierSet(">=3.7")

    # cp39 satisfies ">=3.7"; cp36 does not.
    assert selector("cp39-win32")
    assert not selector("cp36-win32")
Пример #12
0
def test_ecs_pytorch_training_dgl_cpu(
    cpu_only, py3_only, ecs_container_instance, pytorch_training, training_cmd, ecs_cluster_name
):
    """
    CPU DGL test for PyTorch Training

    Instance Type - c5.12xlarge

    DGL is only supported in py3, so the "py3_only" fixture keeps py2 images
    from running this function.

    Registers a task with family named after this test, runs the task, and waits
    for the task to be stopped before teardown of instance and cluster.
    Images with PyTorch 1.10.* are skipped.
    """
    _, framework_version = get_framework_and_version_from_tag(pytorch_training)
    # TODO: Remove when DGL gpu test on ecs get fixed
    # NOTE(review): this is the CPU test, yet the skip reason mentions gpu;
    # confirm the 1.10 skip is really meant to apply here.
    if Version(framework_version) in SpecifierSet("==1.10.*"):
        pytest.skip("ecs test for DGL gpu fails for pt 1.10")

    instance_id, cluster_arn = ecs_container_instance
    ecs_utils.ecs_training_test_executor(
        ecs_cluster_name, cluster_arn, training_cmd, pytorch_training, instance_id
    )
Пример #13
0
    def from_pipfile(cls, name, pipfile):
        """Build an instance from a single Pipfile entry.

        :param name: the package name key from the Pipfile.
        :param pipfile: the entry value. When it exposes ``keys`` it is copied
            into ``_pipfile``; otherwise ``_pipfile`` starts empty. Either way
            ``_pipfile["version"]`` is set from ``get_version(pipfile)``.
        :return: a new ``cls`` instance built from the collected ``args``.
        """
        from .markers import PipenvMarkers

        _pipfile = {}
        if hasattr(pipfile, "keys"):
            _pipfile = dict(pipfile).copy()
        _pipfile["version"] = get_version(pipfile)
        # Choose the requirement flavour: a VCS key wins, then path/file/uri,
        # otherwise a plain named requirement.
        vcs = first([vcs for vcs in VCS_LIST if vcs in _pipfile])
        if vcs:
            _pipfile["vcs"] = vcs
            # NOTE(review): the requirement classes receive the original
            # ``pipfile``, not ``_pipfile``, so the "vcs" key set above is not
            # visible to them; confirm this is intended.
            r = VCSRequirement.from_pipfile(name, pipfile)
        elif any(key in _pipfile for key in ["path", "file", "uri"]):
            r = FileRequirement.from_pipfile(name, pipfile)
        else:
            r = NamedRequirement.from_pipfile(name, pipfile)
        markers = PipenvMarkers.from_pipfile(name, _pipfile)
        req_markers = None
        if markers:
            markers = str(markers)
            # Parse the marker text by attaching it to a placeholder requirement.
            req_markers = PackagingRequirement("fakepkg; {0}".format(markers))
        r.req.marker = getattr(req_markers, "marker", None)
        r.req.specifier = SpecifierSet(_pipfile["version"])
        extras = _pipfile.get("extras")
        # Extras are lower-cased, passed through ``dedup`` and sorted.
        r.req.extras = (sorted(dedup([extra.lower()
                                      for extra in extras])) if extras else [])
        args = {
            "name": r.name,
            "vcs": vcs,
            "req": r,
            "markers": markers,
            "extras": _pipfile.get("extras"),
            "editable": _pipfile.get("editable", False),
            "index": _pipfile.get("index"),
        }
        # Prefer an explicit "hashes" list; otherwise wrap the single "hash".
        if any(key in _pipfile for key in ["hash", "hashes"]):
            args["hashes"] = _pipfile.get("hashes", [pipfile.get("hash")])
        cls_inst = cls(**args)
        if cls_inst.is_named:
            cls_inst.req.req.line = cls_inst.as_line()
        return cls_inst
Пример #14
0
 def test_convert_runway_version(self) -> None:
     """Test _convert_runway_version.

     Every accepted input form must normalize to a SpecifierSet built with
     ``prereleases=True``.
     """
     expected = SpecifierSet(">1.11.0", prereleases=True)

     # handle string
     from_string = RunwayConfigDefinitionModel(runway_version=">1.11.0")
     assert from_string.runway_version == expected

     # handle exact version: a bare version becomes an "==" specifier
     from_exact = RunwayConfigDefinitionModel(runway_version="1.11.0")
     assert from_exact.runway_version == SpecifierSet("==1.11.0",
                                                      prereleases=True)

     # handle SpecifierSet, with and without prereleases already enabled
     for given in (SpecifierSet(">1.11.0"),
                   SpecifierSet(">1.11.0", prereleases=True)):
         model = RunwayConfigDefinitionModel(runway_version=given)
         assert model.runway_version == expected
Пример #15
0
    def get_requirements(spec, version=None):
        """Return a ``"<spec>": <requirements>`` JSON fragment for ``spec``.

        With ``version`` given, that version's requirements are used directly.
        Otherwise every known version inside ``spec`` must map to one shared
        requirements value.

        Raises RuntimeError when no known version matches ``spec`` or when the
        matching versions disagree.
        """
        spec = SpecifierSet(spec)
        if version is not None:
            reqs = requirements[parse(version)]
        else:
            distinct = {
                requirements[known] for known in sorted(versions) if known in spec
            }
            if not distinct:
                raise RuntimeError(
                    f"Unable to determine requirements for specifier '{spec}'."
                )
            if len(distinct) > 1:
                raise RuntimeError(
                    f"Requirements for specifier '{spec}' are not uniform.")
            (reqs,) = distinct

        # Strip the outer braces, presumably so the pair can be spliced into a
        # larger JSON object by the caller.
        return json.dumps({str(spec): reqs})[1:-1]
Пример #16
0
	def process_package(self):
		"""Record the sdist link currently being parsed as a release candidate.

		Links whose extension is not in SDIST_EXTS are ignored. When a version
		already has a candidate, the one whose extension comes earlier in
		SDIST_EXTS is kept.
		"""
		assert self._attrs is not None
		assert self._data is not None

		stem, extension = self.splitext(self._data)
		if extension not in SDIST_EXTS:
			return

		link_attrs = dict(self._attrs)
		version = self.get_version(stem)
		href = link_attrs['href']
		url = self.get_url(href)
		hash_type, hash_value = self.get_hash(href)
		requires_python = SpecifierSet(link_attrs.get('data-requires-python', ""))

		# An existing candidate with a more preferred extension wins.
		if version in self.candidates:
			previous_ext = self.splitext(self.candidates[version].url)[1]
			if SDIST_EXTS.index(previous_ext) < SDIST_EXTS.index(extension):
				return

		self.candidates[version] = Candidate(self.base_name, version, url, hash_type, hash_value, requires_python)
Пример #17
0
def eval_specifier(spec, tag, bad_patterns=None):
    """Return True if the version ``tag`` satisfies the specifier string ``spec``.

    A ``spec`` containing no character from ``SPECIFIERS`` is treated as an
    exact pin (``"==" + spec``). ``tag`` is normalized with ``normalize_tag``.
    If the result is not a valid version, the error is printed to stderr and
    False is returned.

    ``bad_patterns`` is accepted but not used here.
    """
    # SpecifierSet requires an operator, so a bare version becomes "==<version>".
    if not any(ch in SPECIFIERS for ch in spec):
        spec = '==' + spec

    spec = SpecifierSet(spec)
    tag = normalize_tag(tag)

    try:
        tag = Version(tag)
    except InvalidVersion as e:
        print("{}".format(e), file=sys.stderr)
        # An unparseable tag cannot satisfy any specifier. The old fallback
        # checked '' against the set, which re-parses '' as a version.
        return False

    return tag in spec
Пример #18
0
def check(packages):
    """Return a ``Vulnerability`` for each advisory matching ``packages``.

    The summary database is fetched up front. The full database is fetched
    lazily, at most once, the first time a package version falls inside a
    vulnerable specifier.
    """
    db = fetch_database()
    db_full = None
    vulnerable_packages = frozenset(db.keys())
    vulnerable = []
    for pkg in packages:
        # normalize the package name: the safety-db converts underscores to
        # dashes and uses lowercase
        name = pkg.key.replace("_", "-").lower()
        if name not in vulnerable_packages:
            continue

        # we have a candidate here, build the spec set
        for specifier in db[name]:
            spec_set = SpecifierSet(specifiers=specifier)
            if not spec_set.contains(pkg.version):
                continue
            # ``is None`` (not falsiness) so an empty full database is not
            # re-fetched on every match.
            if db_full is None:
                db_full = fetch_database(full=True)
            for data in get_vulnerabilities(pkg=name, spec=specifier, db=db_full):
                vulnerable.append(
                    Vulnerability(name=name, spec=specifier, version=pkg.version, data=data)
                )
    return vulnerable
Пример #19
0
def format_reverse_package(
        graph: DirectedGraph,
        package: Package,
        child: Package | None = None,
        requires: str = "",
        prefix: str = "",
        visited: frozenset[str] = frozenset(),
) -> str:
    """Format one package for output reverse dependency graph.

    :param graph: the dependency graph
    :param package: the package rendered on this line
    :param child: the dependency rendered one level above, whose version is
        checked against ``requires``
    :param requires: the specifier ``package`` places on ``child``
    :param prefix: prefix text for nested lines
    :param visited: names already on the current path; a repeat is labelled
        ``[circular]`` and not expanded
    :return: the rendered lines joined into one string
    """
    if not package.version:
        version = termui.red("[ not installed ]")
    else:
        version = termui.yellow(package.version)
    is_cycle = package.name in visited
    if is_cycle:
        version = termui.red("[circular]")

    violated = (requires not in ("Any", "") and child and child.version
                and not SpecifierSet(requires).contains(child.version))
    if violated:
        requires_label = f"[ requires: {termui.red(requires)} ]"
    elif requires:
        requires_label = f"[ requires: {requires} ]"
    else:
        requires_label = ""

    lines = [f"{termui.green(package.name, bold=True)} {version} {requires_label}\n"]
    if is_cycle:
        return "".join(lines)

    parents: list[Package] = sorted(filter(None, graph.iter_parents(package)),
                                    key=lambda p: p.name)
    # frozenset is immutable, so one extended set can be shared by all branches.
    deeper = visited | {package.name}
    final_index = len(parents) - 1
    for position, parent in enumerate(parents):
        at_end = position == final_index
        branch = LAST_CHILD if at_end else NON_LAST_CHILD
        indent = LAST_PREFIX if at_end else NON_LAST_PREFIX
        parent_requires = specifier_from_requirement(
            parent.requirements[package.name])
        lines.append(prefix + branch + format_reverse_package(
            graph,
            parent,
            package,
            parent_requires,
            prefix + indent,
            deeper,
        ))
    return "".join(lines)
Пример #20
0
def test_install_requires(pkg, requirement: str, version: str,
                          has_specifier: bool):
    """are python packages requirements consistent with other versions?

    With ``has_specifier``, ``version`` is a full specifier that must equal the
    one declared in setup.cfg. Otherwise ``version`` must satisfy the declared
    specifier, and a warning is issued if the declared text differs from
    ``>=version``.
    """
    config = ConfigParser()
    config.read(pkg / "setup.cfg")
    declared: Dict[str, Requirement] = {}
    for entry in config["options"]["install_requires"].splitlines():
        if not entry.strip():
            continue
        parsed = Requirement(entry)
        declared[parsed.name] = parsed

    assert requirement in declared
    declared_spec = declared[requirement].specifier
    actual = str(declared_spec)
    expected = str(SpecifierSet(version if has_specifier else f">={version}"))

    if has_specifier:
        assert expected == actual
        return

    assert Version(version) in declared_spec
    if expected != actual:
        warn(f"Version matches, but specifier might need updating:"
             f" {requirement} {actual}; version: {version}")
Пример #21
0
def test_sm_profiler_pt(pytorch_training):
    """Run the SageMaker profiler PyTorch tests against ``pytorch_training``.

    Skipped for processors other than cpu/gpu and for PyTorch >= 1.12.
    Downloads the sagemaker-tests bundle plus the CIFAR-10 and MNIST archives
    from S3 into a per-container directory under ``$CODEBUILD_SRC_DIR``, then
    hands off to ``run_sm_profiler_tests``.
    """
    processor = get_processor_from_image_uri(pytorch_training)
    if processor not in ("cpu", "gpu"):
        pytest.skip(f"Processor {processor} not supported. Skipping test.")

    _, image_framework_version = get_framework_and_version_from_tag(pytorch_training)
    if Version(image_framework_version) in SpecifierSet(">=1.12"):
        pytest.skip("sm profiler ZCC test is not supported in PT 1.12 and above")

    ctx = Context()

    # NOTE(review): os.getenv returns None when CODEBUILD_SRC_DIR is unset,
    # and os.path.join then raises TypeError.
    profiler_tests_dir = os.path.join(
        os.getenv("CODEBUILD_SRC_DIR"), get_container_name("smprof", pytorch_training), "smprofiler_tests"
    )
    ctx.run(f"mkdir -p {profiler_tests_dir}", hide=True)

    # Download sagemaker-tests zip
    sm_tests_zip = "sagemaker-tests.zip"
    ctx.run(
        f"aws s3 cp {os.getenv('SMPROFILER_TESTS_BUCKET')}/{sm_tests_zip} {profiler_tests_dir}/{sm_tests_zip}",
        hide=True,
    )

    # PT test setup requirements
    # Each ctx.prefix nests a "cd" so the commands below run inside the
    # unpacked pytorch_scripts/data directories.
    with ctx.prefix(f"cd {profiler_tests_dir}"):
        ctx.run(f"unzip {sm_tests_zip}", hide=True)
        with ctx.prefix("cd sagemaker-tests/tests/scripts/pytorch_scripts"):
            ctx.run("mkdir -p data", hide=True)
            ctx.run(
                "aws s3 cp s3://smdebug-testing/datasets/cifar-10-python.tar.gz data/cifar-10-batches-py.tar.gz",
                hide=True,
            )
            ctx.run("aws s3 cp s3://smdebug-testing/datasets/MNIST_pytorch.tar.gz data/MNIST_pytorch.tar.gz", hide=True)
            with ctx.prefix("cd data"):
                ctx.run("tar -zxf MNIST_pytorch.tar.gz", hide=True)
                ctx.run("tar -zxf cifar-10-batches-py.tar.gz", hide=True)

    run_sm_profiler_tests(pytorch_training, profiler_tests_dir, "test_profiler_pytorch.py", processor)
Пример #22
0
    def pip_1_4_format(user_agent):
        """Parse a pip 1.4 through 5.x style user agent into a metadata dict.

        The expected shape is ``pip/<version> <impl>/<impl-version>
        <system>/<release>``; any component after the slash may be ``Unknown``.

        :param user_agent: the raw User-Agent string.
        :return: a dict with ``installer`` and ``implementation`` keys and,
            when known, ``system`` and ``python`` keys; None when the agent is
            not in this format.
        """
        # We're only concerned about pip user agents.
        if not user_agent.startswith("pip/"):
            return

        # This format was brand new in pip 1.4, and went away in pip 6.0, so
        # we'll need to restrict it to only versions of pip between 1.4 and 6.0
        version_str = user_agent.split()[0].split("/", 1)[1]
        version = packaging.version.parse(version_str)
        if version not in SpecifierSet(">=1.4,<6", prereleases=True):
            return

        # The format has exactly three whitespace-separated fields; anything
        # shorter is not this format (the unpack would raise ValueError).
        fields = user_agent.split(maxsplit=2)
        if len(fields) != 3:
            return
        _, impl, system = fields

        # partition() never raises; an empty separator means "no slash".
        impl_name, impl_sep, impl_version = impl.partition("/")
        system_name, system_sep, system_release = system.partition("/")

        data = {
            "installer": {
                "name": "pip",
                "version": version_str,
            },
            "implementation": {
                "name": impl_name,
            },
        }

        # A component without "/" carries no version/release; skip it
        # instead of raising IndexError.
        if impl_sep and not impl.endswith("/Unknown"):
            data["implementation"]["version"] = impl_version

        if not system.startswith("Unknown/"):
            data.setdefault("system", {})["name"] = system_name

        if system_sep and not system.endswith("/Unknown"):
            data.setdefault("system", {})["release"] = system_release

        if (data["implementation"]["name"].lower() == "cpython" and
                data["implementation"].get("version")):
            data["python"] = data["implementation"]["version"]

        return data
Пример #23
0
def clean_requires_python(candidates):
    """Get a cleaned list of all the candidates with valid specifiers in the `requires_python` attributes.

    A candidate is kept when it has no ``requires_python`` constraint, or when
    the constraint parses and contains the target Python version. The target
    is ``$PIP_PYTHON_VERSION`` if set, otherwise the running interpreter's
    major.minor.micro. Candidates whose constraint is not a valid specifier
    are dropped.
    """
    all_candidates = []
    sys_version = '.'.join(map(str, sys.version_info[:3]))
    from packaging.version import parse as parse_version
    py_version = parse_version(os.environ.get('PIP_PYTHON_VERSION', sys_version))
    # Loop-invariant accessor for candidates that keep the value on .location.
    from_location = attrgetter("location.requires_python")
    for c in candidates:
        # Fall back to ``c.location`` only when ``c`` lacks ``requires_python``.
        # Passing ``from_location(c)`` as getattr's default evaluated it
        # eagerly, raising AttributeError for candidates without ``location``.
        try:
            requires_python = c.requires_python
        except AttributeError:
            requires_python = from_location(c)
        if requires_python:
            # Old specifications had people setting this to single digits
            # which is effectively the same as '>=digit,<digit+1'
            if requires_python.isdigit():
                requires_python = '>={0},<{1}'.format(requires_python, int(requires_python) + 1)
            try:
                specifierset = SpecifierSet(requires_python)
            except InvalidSpecifier:
                continue
            else:
                if not specifierset.contains(py_version):
                    continue
        all_candidates.append(c)
    return all_candidates
Пример #24
0
    def _match_node_at_path(self, key: str, metadata: Dict) -> bool:
        """Check whether the metadata node addressed by ``key`` matches.

        ``key`` is ``[tag:...]<dotted.path>``. The optional colon-separated
        tags ``not-null`` / ``match-null`` override ``self.nulls_match``, and
        the last one wins. When nulls match, a falsy node passes. Otherwise the
        node is parsed as a ``SpecifierSet``. The check passes if it contains
        any entry of ``self.specifiers[key]``, with prereleases allowed.
        """
        # Everything after the last colon is the dotted path; the rest are tags.
        *tags, path = key.split(":")

        nulls_match = self.nulls_match
        for tag in tags:
            if tag == "not-null":
                nulls_match = False
            elif tag == "match-null":
                nulls_match = True

        node = self._find_element_by_dotted_path(path, metadata)
        if nulls_match and not node:
            return True

        # Check if the SpecifierSet matches target versions
        # TODO: Figure out proper intersection of SpecifierSets
        ospecs: SpecifierSet = SpecifierSet(node)
        ispecs = self.specifiers[key]
        matched = any(ospecs.contains(ispec, prereleases=True) for ispec in ispecs)
        if not matched:
            logger.info(
                f"Failed check for {key}='{ospecs}' against '{ispecs}'"  # noqa: E501
            )
        return matched
Пример #25
0
def format_reverse_package(
    graph: DirectedGraph,
    package: Package,
    child: Optional[Package] = None,
    requires: str = "",
    prefix: str = "",
    visited=None,
):
    """Format one package for output reverse dependency graph.

    :param graph: the dependency graph
    :param package: the package rendered on this line
    :param child: the dependency rendered one level above, whose version is
        checked against ``requires``
    :param requires: the specifier ``package`` places on ``child``
    :param prefix: prefix text for nested lines
    :param visited: names already on the current path; a repeat is labelled
        ``[circular]`` and not expanded
    :return: the rendered lines joined into one string
    """
    if visited is None:
        visited = set()

    if package.version:
        version = stream.yellow(package.version)
    else:
        version = stream.red("[ not installed ]")
    seen = package.name in visited
    if seen:
        version = stream.red("[circular]")

    mismatch = (requires not in ("Any", "") and child and child.version
                and not SpecifierSet(requires).contains(child.version))
    if mismatch:
        label = f"[ requires: {stream.red(requires)} ]"
    elif requires:
        label = f"[ requires: {requires} ]"
    else:
        label = ""

    output = [f"{stream.green(package.name, bold=True)} {version} {label}\n"]
    if seen:
        return "".join(output)

    visited.add(package.name)
    parents = sorted(filter(None, graph.iter_parents(package)),
                     key=lambda p: p.name)
    last = len(parents) - 1
    for idx, parent in enumerate(parents):
        branch = LAST_CHILD if idx == last else NON_LAST_CHILD
        indent = LAST_PREFIX if idx == last else NON_LAST_PREFIX
        parent_requires = str(parent.requirements[package.name].specifier or "Any")
        # Each branch gets its own copy so only ancestors count as "visited".
        output.append(prefix + branch + format_reverse_package(
            graph, parent, package, parent_requires, prefix + indent,
            visited.copy()))
    return "".join(output)
Пример #26
0
    def _run_state(self, state: State) -> bool:
        """Check state match.

        Returns False as soon as an entry of the prescription's
        ``resolved_dependencies`` is missing from ``state``, has a different
        index URL (``resolved[2]``), or has a version (``resolved[1]``) outside
        the given specifier. Returns True otherwise, including when there is
        no ``state`` prescription.
        """
        prescription = self.match_prescription.get("state")
        if not prescription:
            return True

        for wanted in prescription.get("resolved_dependencies", []):
            resolved = state.resolved_dependencies.get(wanted["name"])
            if not resolved:
                return False

            wanted_index = wanted.get("index_url")
            if wanted_index is not None and resolved[2] != wanted_index:
                return False

            wanted_version = wanted.get("version")
            # XXX: the SpecifierSet is rebuilt on every call; could be optimized out
            if wanted_version is not None and resolved[1] not in SpecifierSet(wanted_version):
                return False

        return True
Пример #27
0
 def add_pyversion(self, v: str) -> None:
     """Add Python version ``v`` to the project's supported versions.

     ``v`` is parsed with ``PyVersion.parse``; if it is already listed in
     ``self.details.python_versions`` only a log message is emitted.
     Otherwise it:

     - inserts a trove classifier after the last existing one in setup.cfg;
     - when ``self.details.has_tests``, rewrites the ``envlist`` line of
       tox.ini via ``add_py_env``;
     - when ``self.details.has_ci``, inserts the version after the last
       version entry in .github/workflows/test.yml;
     - inserts it into ``self.details.python_versions`` with ``insort``.

     :raises ValueError: if ``v`` does not satisfy ``python_requires``.
     """
     pyv = PyVersion.parse(v)
     if pyv in self.details.python_versions:
         log.info("Project already supports %s; not adding", pyv)
         return
     # Validate against python_requires before touching any file.
     if str(pyv) not in SpecifierSet(self.details.python_requires):
         raise ValueError(f"Version {pyv} does not match python_requires ="
                          f" {self.details.python_requires!r}")
     log.info("Adding %s to supported Python versions", pyv)
     log.info("Updating setup.cfg ...")
     add_line_to_file(
         self.directory / "setup.cfg",
         f"    Programming Language :: Python :: {pyv}\n",
         inserter=AfterLast(
             r"^    Programming Language :: Python :: \d+\.\d+$"),
         encoding="utf-8",
     )
     if self.details.has_tests:
         log.info("Updating tox.ini ...")
         map_lines(
             self.directory / "tox.ini",
             partial(
                 replace_group,
                 re.compile(r"^envlist\s*=[ \t]*(.+)$", flags=re.M),
                 partial(add_py_env, pyv),
             ),
         )
     if self.details.has_ci:
         log.info("Updating .github/workflows/test.yml ...")
         # The pattern matches "- 3.x" lines indented by 10 spaces, optionally
         # quoted ('\x22' is a double quote to the regex engine).
         add_line_to_file(
             self.directory / ".github" / "workflows" / "test.yml",
             f"{' ' * 10}- '{pyv}'\n",
             inserter=AfterLast(fr"^{' ' * 10}- ['\x22]?\d+\.\d+['\x22]?$"),
             encoding="utf-8",
         )
     insort(self.details.python_versions, pyv)
Пример #28
0
def test_smmodelparallel_mnist_multigpu(ecr_image, instance_type, py_version,
                                        sagemaker_session, tmpdir):
    """
    Run the PyTorch model-parallel MNIST script-mode job on a single node.

    Skipped unless the image has PyTorch >=1.6,<1.8 and CUDA tag ``cu110``.
    The ``instance_type`` fixture value is ignored; ml.p3.16xlarge is used.
    """
    _, framework_version = get_framework_and_version_from_tag(ecr_image)
    cuda_version = get_cuda_version_from_tag(ecr_image)
    framework_ok = Version(framework_version) in SpecifierSet(">=1.6,<1.8")
    if not framework_ok or cuda_version != "cu110":
        pytest.skip(
            "Model Parallelism only supports CUDA 11 on PyTorch 1.6 and PyTorch 1.7"
        )

    estimator_kwargs = dict(
        entry_point='smmodelparallel_pt_mnist.sh',
        role='SageMakerRole',
        image_uri=ecr_image,
        source_dir=mnist_path,
        instance_count=1,
        instance_type="ml.p3.16xlarge",
        sagemaker_session=sagemaker_session,
    )
    with timeout(minutes=DEFAULT_TIMEOUT):
        estimator = PyTorch(**estimator_kwargs)
        estimator.fit()
Пример #29
0
    def record_import(self, pkg, name, spec=None):
        """record_import(pkg, name, spec=None) -> None

        Note a bimport() call by a package being imported. This should
        only be called by bimport().

        ``spec`` may be None, a ``Version``, a ``SpecifierSet``, or a string,
        which is parsed into a ``SpecifierSet``. Any other value is ignored
        and nothing is recorded. Otherwise ``(name, spec)`` is appended to
        ``self.import_recorder[pkg.key]``.
        """
        # isinstance (rather than ``type(spec) in [str]``) also accepts str
        # subclasses, which were previously dropped silently.
        if isinstance(spec, str):
            spec = SpecifierSet(spec)

        # Unsupported spec types are ignored rather than recorded.
        if spec is not None and not isinstance(spec, (Version, SpecifierSet)):
            return

        ls = self.import_recorder.get(pkg.key)
        if not ls:
            ls = []
            self.import_recorder[pkg.key] = ls

        ls.append((name, spec))
Пример #30
0
 def parse(self):
     """
     Parse a Pipfile (as seen in pipenv).

     Appends a ``Dependency`` to ``self.obj.dependencies`` for every
     string-valued entry of the ``packages`` and ``dev-packages`` tables.
     Non-string entries (e.g. VCS dependencies) are skipped, and ``'*'`` is
     treated as an empty specifier. ``toml.TomlDecodeError`` and
     ``IndexError`` are swallowed; dependencies appended before the error
     are kept.
     """
     try:
         data = toml.loads(self.obj.content, _dict=OrderedDict)
         if not data:
             return
         for section in ('packages', 'dev-packages'):
             if section not in data:
                 continue
             for name, specs in data[section].items():
                 # skip on VCS dependencies
                 if not isinstance(specs, str):
                     continue
                 if specs == '*':
                     specs = ''
                 self.obj.dependencies.append(
                     Dependency(name=name,
                                specs=SpecifierSet(specs),
                                dependency_type=filetypes.pipfile,
                                line=''.join([name, specs]),
                                section=section))
     except (toml.TomlDecodeError, IndexError):
         pass