Exemplo n.º 1
0
    def _load_conda_full_env(self, conda_env_dict, requirements):
        """
        Restore a fully stored conda environment, then install its pip section.

        The stored dependency list is normalized before it is passed to
        ``conda env update``:

        * ``pip`` itself and ``_``-prefixed packages are dropped.
        * ``python`` is pinned to ``major.minor`` only.
        * ``tensorflow`` / ``tensorflow-gpu`` are swapped according to the
          configured ``agent.cuda_version`` (0 or unparsable means no CUDA).
        * ``cupti`` / ``cudnn`` are dropped so ``cudatoolkit`` can pull them.

        :param conda_env_dict: parsed environment dict with ``dependencies``
            (strings, plus presumably a nested ``{'pip': [...]}`` entry) —
            modified in place (``channels`` is replaced by
            ``self.extra_channels``).
        :param requirements: the task requirements dict; its ``pip`` entry is
            only used to recover version-control (vcs) links for packages
            listed in the stored pip section.
        :return: ``False`` if conda reported packages it could not install,
            ``True`` otherwise.
        """
        # noinspection PyBroadException
        try:
            cuda_version = int(self.session.config.get('agent.cuda_version',
                                                       0))
        except Exception:
            cuda_version = 0

        def _pinned_version(dep):
            # "name=1.2.3=build" -> "1.2.3" (also tolerates "name==1.2.3");
            # returns "" when the dependency carries no version at all
            _, sep, rest = dep.partition('=')
            return rest.lstrip('=').split('=')[0].strip() if sep else ''

        conda_env_dict['channels'] = self.extra_channels
        if 'dependencies' not in conda_env_dict:
            conda_env_dict['dependencies'] = []
        new_dependencies = OrderedDict()
        pip_requirements = None
        for line in conda_env_dict['dependencies']:
            if isinstance(line, dict):
                # nested pip section, installed with pip in step 2. Only take
                # it when present so another dict entry cannot erase it.
                if 'pip' in line:
                    pip_requirements = line.pop('pip')
                continue
            name = line.strip().split('=', 1)[0].lower()
            if name == 'pip':
                continue
            elif name == 'python':
                # keep only major.minor, leave an unpinned "python" untouched
                version = _pinned_version(line)
                if version:
                    line = 'python={}'.format('.'.join(
                        version.split('.')[:2]))
            elif name == 'tensorflow-gpu' and cuda_version == 0:
                version = _pinned_version(line)
                line = 'tensorflow={}'.format(
                    version) if version else 'tensorflow'
            elif name == 'tensorflow' and cuda_version > 0:
                version = _pinned_version(line)
                line = 'tensorflow-gpu={}'.format(
                    version) if version else 'tensorflow-gpu'
            elif name in ('cupti', 'cudnn'):
                # cudatoolkit should pull them based on the cudatoolkit version
                continue
            elif name.startswith('_'):
                continue
            # keyed by package name so a later duplicate replaces an earlier one
            new_dependencies[line.split('=', 1)[0].strip()] = line

        # fix packages:
        conda_env_dict['dependencies'] = list(new_dependencies.values())

        with self.temp_file("conda_env",
                            yaml.dump(conda_env_dict),
                            suffix=".yml") as name:
            print('Conda: Trying to install requirements:\n{}'.format(
                conda_env_dict['dependencies']))
            result = self._run_command(
                ("env", "update", "-p", self.path, "--file", name))

        # check if we need to remove specific packages
        bad_req = self._parse_conda_result_bad_packges(result)
        if bad_req:
            print('failed installing the following conda packages: {}'.format(
                bad_req))
            return False

        if pip_requirements:
            # map package name -> vcs requirement taken from the task's pip
            # list, so stored pip entries are replaced by their vcs links
            vcs_reqs = {}
            pip_lines = requirements.get('pip') or []
            if isinstance(pip_lines, six.string_types):
                pip_lines = pip_lines.splitlines()
            for line in pip_lines:
                # noinspection PyBroadException
                try:
                    marker = list(parse(line))
                except Exception:
                    marker = None
                if not marker:
                    continue

                m = MarkerRequirement(marker[0])
                if m.vcs:
                    vcs_reqs[m.name] = m
            try:
                pip_req_str = [
                    str(vcs_reqs.get(r.split('=', 1)[0], r))
                    for r in pip_requirements if not r.startswith('pip=')
                    and not r.startswith('virtualenv=')
                ]
                print(
                    'Conda: Installing requirements: step 2 - using pip:\n{}'.
                    format(pip_req_str))
                PackageManager._selected_manager = self.pip
                self.pip.load_requirements({'pip': '\n'.join(pip_req_str)})
            except Exception as e:
                print(e)
                raise
            finally:
                PackageManager._selected_manager = self

        self.requirements_manager.post_install(self.session)
        return True
Exemplo n.º 2
0
    def load_requirements(self, requirements):
        """
        Install task requirements into the conda environment at ``self.path``.

        Flow:

        1. In read-only mode nothing is installed and ``None`` is returned.
        2. If ``requirements['conda_env_json']`` is set, try to restore that
           full environment with ``_load_conda_full_env`` and return its
           result; on any exception fall back to the steps below.
        3. Build a conda environment file from the ``conda`` list (or from the
           ``pip`` list when there is no ``conda`` entry) and run
           ``conda env update``. Packages conda reports as failing are moved
           to the pip list and the update is retried.
        4. Everything conda did not take (vcs links, pip-only packages,
           failed packages) is installed through ``self.pip``.

        :param requirements: dict with optional ``pip`` / ``conda`` entries
            (newline-separated string or list of lines) and an optional
            ``conda_env_json`` entry. String ``pip`` / ``conda`` entries are
            replaced in place by their split lists. A missing ``pip`` entry
            is treated as empty.
        :return: ``None`` in read-only mode, the result of
            ``_load_conda_full_env`` when a full environment was restored,
            otherwise ``True``.
        """
        # if we are in read only mode, do not uninstall anything
        if self.env_read_only:
            print(
                'Conda environment in read-only mode, skipping requirements installation.'
            )
            return None

        # if we have a full conda environment, use it and pass the pip to pip
        if requirements.get('conda_env_json'):
            # noinspection PyBroadException
            try:
                conda_env_json = json.loads(requirements.get('conda_env_json'))
                print('Conda restoring full yaml environment')
                return self._load_conda_full_env(conda_env_json, requirements)
            except Exception:
                print(
                    'Could not load fully stored conda environment, falling back to requirements'
                )

        # create new environment file
        conda_env = dict()
        conda_env['channels'] = self.extra_channels
        reqs = []
        # normalize both entries to lists of lines (updates the caller's dict)
        if isinstance(requirements.get('pip'), six.string_types):
            requirements['pip'] = requirements['pip'].split('\n')
        if isinstance(requirements.get('conda'), six.string_types):
            requirements['conda'] = requirements['conda'].split('\n')
        pip_lines = requirements.get('pip') or []
        has_torch = False
        has_matplotlib = False
        # noinspection PyBroadException
        try:
            cuda_version = int(self.session.config.get('agent.cuda_version',
                                                       0))
        except Exception:
            cuda_version = 0

        # notice 'conda' entry with empty string is a valid conda requirements list, it means pip only
        # this should happen if experiment was executed on non-conda machine or old trains client
        conda_supported_req = pip_lines if requirements.get(
            'conda', None) is None else requirements['conda']
        conda_supported_req_names = []
        pip_requirements = []
        for r in conda_supported_req:
            # noinspection PyBroadException
            try:
                marker = list(parse(r))
            except Exception:
                marker = None
            if not marker:
                continue

            m = MarkerRequirement(marker[0])
            # conda does not support version control links
            if m.vcs:
                pip_requirements.append(m)
                continue
            # Skip over pip
            if m.name in (
                    'pip',
                    'virtualenv',
            ):
                continue
            # python version, only major.minor
            if m.name == 'python' and m.specs:
                m.specs = [
                    (m.specs[0][0], '.'.join(m.specs[0][1].split('.')[:2])),
                ]
                if '.' not in m.specs[0][1]:
                    continue

            conda_supported_req_names.append(m.name.lower())
            if m.req.name.lower() == 'matplotlib':
                has_matplotlib = True
            elif m.req.name.lower().startswith('torch'):
                has_torch = True

            if m.req.name.lower() in ('torch', 'pytorch'):
                has_torch = True
                m.req.name = 'pytorch'

            if m.req.name.lower() in ('tensorflow_gpu', 'tensorflow-gpu',
                                      'tensorflow'):
                # do not flag torch here: 'cpuonly' (added below) is a
                # pytorch-only metapackage and is irrelevant for tensorflow
                m.req.name = 'tensorflow-gpu' if cuda_version > 0 else 'tensorflow'

            reqs.append(m)

        # if we have a conda list, the rest should be installed with pip,
        # this means  any experiment that was executed with pip environment,
        # will be installed using pip
        if requirements.get('conda', None) is not None:
            for r in pip_lines:
                # noinspection PyBroadException
                try:
                    marker = list(parse(r))
                except Exception:
                    marker = None
                if not marker:
                    continue

                m = MarkerRequirement(marker[0])
                # skip over local files (we cannot change the version to a local file)
                if m.local_file:
                    continue
                m_name = (m.name or '').lower()
                if m_name in conda_supported_req_names:
                    # this package is in the conda list,
                    # make sure that if we changed version and we match it in conda
                    normalized_name = m_name.replace('_', '-')
                    for cr in reqs:
                        if normalized_name == cr.name.lower().replace('_', '-'):
                            # match versions
                            cr.specs = m.specs
                            break
                else:
                    # not in conda, it is a pip package
                    pip_requirements.append(m)
                    if m_name == 'matplotlib':
                        has_matplotlib = True

        # Conda requirements Hacks:
        if has_matplotlib:
            reqs.append(MarkerRequirement(Requirement.parse('graphviz')))
            reqs.append(MarkerRequirement(
                Requirement.parse('python-graphviz')))
            reqs.append(MarkerRequirement(Requirement.parse('kiwisolver')))

        # do not install cudnn / cupti directly. cudatoolkit itself may be
        # overridden, and it should pull the matching derivative packages.
        reqs = [r for r in reqs if r.name not in ('cudnn', 'cupti')]

        if has_torch and cuda_version == 0:
            reqs.append(MarkerRequirement(Requirement.parse('cpuonly')))

        # make sure we have no double entries
        reqs = list(OrderedDict((r.name, r) for r in reqs).values())

        # conform conda packages (version/name)
        for r in reqs:
            # change _ to - in name but not the prefix _ (as this is conda prefix)
            if r.name and not r.name.startswith('_') and not requirements.get(
                    'conda', None):
                r.name = r.name.replace('_', '-')
            # remove .post from version numbers, it fails ~= version, and change == to ~=
            if r.specs and r.specs[0]:
                r.specs = [(r.specs[0][0].replace('==', '~='),
                            r.specs[0][1].split('.post')[0])]

        def clean_ver(ar):
            # give conda more freedom in version selection, to help it choose
            # the best combination: a single-component version gets ".0"
            if not ar.specs:
                return ar.tostr()
            op, ver = ar.specs[0]
            ar.specs = [(op, ver if '.' in ver else ver + '.0')]
            return ar.tostr()

        while reqs:
            conda_env['dependencies'] = [clean_ver(r) for r in reqs]
            with self.temp_file("conda_env",
                                yaml.dump(conda_env),
                                suffix=".yml") as name:
                print('Conda: Trying to install requirements:\n{}'.format(
                    conda_env['dependencies']))
                result = self._run_command(
                    ("env", "update", "-p", self.path, "--file", name))
            # check if we need to remove specific packages
            bad_req = self._parse_conda_result_bad_packges(result)
            if not bad_req:
                break

            solved = False
            for bad_r in bad_req:
                # strip extras / version constraints to get the bare name
                name = bad_r.split('[')[0].split('=')[0].split('~')[0].split(
                    '<')[0].split('>')[0]
                # move the failing package from conda to pip
                for r in reqs:
                    if r.name.lower() == name.lower():
                        pip_requirements.append(r)
                        reqs.remove(r)
                        solved = True
                        break

            # we couldn't remove even one package,
            # nothing we can do but try pip
            if not solved:
                pip_requirements.extend(reqs)
                break

        if pip_requirements:
            try:
                pip_req_str = [
                    r.tostr() for r in pip_requirements if r.name not in (
                        'pip',
                        'virtualenv',
                    )
                ]
                print(
                    'Conda: Installing requirements: step 2 - using pip:\n{}'.
                    format(pip_req_str))
                PackageManager._selected_manager = self.pip
                self.pip.load_requirements({'pip': '\n'.join(pip_req_str)})
            except Exception as e:
                print(e)
                raise
            finally:
                PackageManager._selected_manager = self

        self.requirements_manager.post_install(self.session)
        return True
Exemplo n.º 3
0
 def safe_parse(req_str):
     """
     Parse one requirement line, falling back to a plain ``Requirement``.

     :param req_str: a single requirement specification line.
     :return: the first item produced by ``parse(req_str)``, or
         ``Requirement(req_str)`` when parsing raised or produced nothing.
     """
     # noinspection PyBroadException
     try:
         return next(parse(req_str))
     except Exception:
         # also covers StopIteration (an Exception subclass) when parse()
         # yields no requirement for the line
         return Requirement(req_str)