Exemple #1
0
def quickstart(tempdir):
    """Render a fresh ballet project into ``tempdir`` and yield its details.

    Mirrors this shell session::

        $ cd tempdir
        $ ballet quickstart
        $ tree .

    Yields:
        Quickstart: namedtuple with fields ``project``, ``tempdir``,
        ``project_slug``, ``package_slug`` and ``repo``.
    """
    Quickstart = namedtuple(
        'Quickstart', 'project tempdir project_slug package_slug repo')

    project_slug = 'foo-bar'
    package_slug = 'foo_bar'

    # equivalent of `cd tempdir`
    with work_in(safepath(tempdir)):
        # equivalent of `ballet quickstart`, answering all prompts up front
        render_project_template(
            no_input=True,
            extra_context={
                'project_name': project_slug.capitalize(),
                'project_slug': project_slug,
                'package_slug': package_slug,
            },
            output_dir=safepath(tempdir),
        )

        # equivalent of `tree .`
        tree(tempdir)

        project = Project.from_path(tempdir.joinpath(project_slug))
        yield Quickstart(
            project=project,
            tempdir=tempdir,
            project_slug=project_slug,
            package_slug=package_slug,
            repo=project.repo,
        )
Exemple #2
0
def project_template_copy(tempdir):
    """Yield a scratch copy of the project template with ballet patched to use it.

    The template at ``ballet.templating.PROJECT_TEMPLATE_PATH`` is copied to
    ``tempdir/templates/project_template``. While the generator is suspended,
    ``PROJECT_TEMPLATE_PATH`` points at the copy instead of the original.

    Yields:
        the location of the copied template directory
    """
    source = ballet.templating.PROJECT_TEMPLATE_PATH
    destination = tempdir.joinpath('templates', 'project_template')
    shutil.copytree(safepath(source), safepath(destination))

    redirect = patch('ballet.templating.PROJECT_TEMPLATE_PATH', destination)
    with redirect:
        tree(destination)
        yield destination
Exemple #3
0
def splitext2(filepath):
    """Split filepath into root, filename, ext

    For example, ``'a/b/c.txt'`` is split into ``('a/b', 'c', '.txt')``.

    Args:
        filepath (PathLike): file path

    Returns:
        Tuple[str, str, str]: the directory part (``root``), the final path
        component without its extension (``filename``), and the extension
        including its leading dot (``ext``; empty if there is none)
    """
    root, basename = os.path.split(safepath(filepath))
    # basename is already a plain string, so no further conversion is needed
    filename, ext = os.path.splitext(basename)
    return root, filename, ext
Exemple #4
0
def isemptyfile(filepath):
    """Determine if the file both exists and is empty

    Args:
        filepath (PathLike): file path

    Returns:
        bool: True if the path exists and has size zero; False if it does
        not exist, cannot be stat'ed, or is non-empty
    """
    # EAFP: one stat call instead of exists() + getsize() avoids the race
    # where the file disappears between the two calls and getsize() raises.
    try:
        filesize = os.path.getsize(safepath(filepath))
    except (OSError, ValueError):
        # os.path.exists() maps exactly these errors to "does not exist"
        return False
    return filesize == 0
 def _create_queue(self):
     """Build ``self.feature_queue`` from the per-user feature directories.

     Each child of ``self.feature_path`` that is a directory matching
     ``USER_REGEX`` is treated as one user. Inside it, entries matching
     ``FEATURE_REGEX`` are features; the others are logged as invalid.
     User and feature numbers are parsed from the token after the first
     ``_`` in the entry name (for features, up to the next ``.``). Features
     outside the ``self.start``/``self.end`` bounds are skipped. The
     per-user lists, each ordered by feature number, are passed to
     ``self.shuffle_feature_queue`` and the result is stored on the instance.
     """
     def in_bounds(num):
         # a falsy bound disables that side of the range check
         if self.start and self.start > num:
             return False
         if self.end and self.end < num:
             return False
         return True

     per_user = []
     for user_dir in self.feature_path.iterdir():
         is_user_dir = (user_dir.is_dir()
                        and re.search(USER_REGEX, str(user_dir)))
         if not is_user_dir:
             continue
         user_id = int(user_dir.parts[-1].split("_")[1])
         logger.debug("COLLECTING FEATURES FROM USER {}".format(user_id))

         collected = []
         for entry in user_dir.iterdir():
             entry_name = entry.parts[-1]
             if not re.search(FEATURE_REGEX, str(entry)):
                 logger.debug(
                     "INVALID FEATURE {}".format(safepath(entry_name)))
                 continue
             feature_id = int(entry_name.split("_")[1].split(".")[0])
             if in_bounds(feature_id):
                 collected.append((user_id, feature_id))

         logger.debug("FOUND {} FEATURES".format(len(collected)))
         collected.sort(key=lambda pair: pair[1])
         per_user.append(collected)

     self.feature_queue = self.shuffle_feature_queue(per_user)
     logger.debug("USING QUEUE:")
     logger.debug("\n".join(str(item) for item in self.feature_queue))
Exemple #6
0
    def test_write_tabular_h5_ndarray(self):
        """Writing an ndarray with ``_write_tabular_h5`` yields a non-empty file."""
        data = self.array
        with tempfile.TemporaryDirectory() as scratch:
            target = pathlib.Path(scratch) / 'baz.h5'
            ballet.util.io._write_tabular_h5(data, target)

            size_on_disk = os.path.getsize(safepath(target))
            self.assertGreater(size_on_disk, 0)
Exemple #7
0
def load_config_at_path(path):
    """Load config at exact path

    Args:
        path (path-like): path to config file; plain ``str`` paths are
            accepted and converted to :class:`pathlib.Path`

    Returns:
        LazySettings: settings built from ``DYNACONF_OPTIONS`` with the root
        path and settings file set to the location of ``path``

    Raises:
        ConfigurationError: if ``path`` is not an existing regular file
    """
    import pathlib

    # Coerce so that str (or any os.PathLike) works, as documented above;
    # the code below needs the pathlib API (is_file/parent/name).
    path = pathlib.Path(path)
    # is_file() is already False for missing paths; no exists() needed
    if path.is_file():
        options = DYNACONF_OPTIONS.copy()
        options.update({
            'ROOT_PATH_FOR_DYNACONF': safepath(path.parent),
            'SETTINGS_FILE_FOR_DYNACONF': safepath(path.name),
        })
        return LazySettings(**options)
    else:
        raise ConfigurationError("Couldn't find ballet.yml config file.")
Exemple #8
0
def pwalk(d, **kwargs):
    """Walk a directory tree like :func:`os.walk`, yielding Path objects.

    For every directory visited, yields the paths of its subdirectories
    followed by the paths of its files.

    Args:
        d (PathLike): root directory to walk
        **kwargs: forwarded unchanged to :func:`os.walk`

    Returns:
        Iterable[Path]
    """
    for root, subdirs, files in os.walk(safepath(d), **kwargs):
        base = pathlib.Path(root)
        yield from (base.joinpath(name) for name in subdirs + files)
Exemple #9
0
def _write_tabular_pickle(obj, filepath):
    """Pickle a tabular object to ``filepath``.

    Args:
        obj (np.ndarray or pandas NDFrame): object to write. ndarrays are
            written with :func:`pickle.dump`; pandas objects (DataFrame,
            Series) with their own ``to_pickle`` method.
        filepath (PathLike): destination; its extension is checked against
            ``'.pkl'`` via ``_check_ext``

    Raises:
        NotImplementedError: if ``obj`` is of any other type
    """
    _, _, ext = splitext2(filepath)
    _check_ext(ext, '.pkl')
    if isinstance(obj, np.ndarray):
        with open(safepath(filepath), 'wb') as f:
            pickle.dump(obj, f)
    elif isinstance(obj, pd.core.frame.NDFrame):
        # use safepath here too, consistent with the ndarray branch
        obj.to_pickle(safepath(filepath))
    else:
        raise NotImplementedError(
            'Cannot write object of type {} as pickle'.format(
                type(obj).__name__))
Exemple #10
0
def test_quickstart_install(quickstart, virtualenv):
    """Smoke-test running a command from inside the quickstarted project.

    Ideally this would run ``make install``. However, the project template
    declares the most recent tagged ballet version as an installation
    dependency, and that version has not necessarily been released on PyPI
    yet, so a placeholder command is run instead.
    """
    # TODO: figure out the right way to manage this. The intended command is
    #     cmd = 'cd "{d!s}" && make install'.format(d=project_dir)
    project_dir = quickstart.tempdir.joinpath(
        quickstart.project_slug).absolute()
    with work_in(safepath(project_dir)):
        placeholder_cmd = 'echo okay'
        virtualenv.run(placeholder_cmd, capture=True)
Exemple #11
0
def _render_project_template(cwd, tempdir, project_template_path=None):
    """Re-render the project template into ``tempdir`` without prompting.

    Args:
        cwd: project directory whose saved context is loaded via
            ``_get_full_context``
        tempdir (PathLike): directory to render into
        project_template_path: optional override for the template location

    Returns:
        the value returned by ``render_project_template``
    """
    target_dir = pathlib.Path(tempdir)
    saved_context = _get_full_context(cwd)

    # Patch out cookiecutter's replay dump so that rendering does not write
    # replay files into the user's home directory.
    with patch('cookiecutter.main.dump'):
        rendered = render_project_template(
            project_template_path=project_template_path,
            no_input=True,
            extra_context=saved_context,
            output_dir=safepath(target_dir),
        )
        return rendered
Exemple #12
0
def spliceext(filepath, s):
    """Insert ``s`` into ``filepath`` immediately before its extension.

    For example, ``spliceext('data/train.csv', '_v2')`` gives
    ``'data/train_v2.csv'``.

    Args:
        filepath (PathLike): file path
        s (str): string to splice

    Returns:
        str
    """
    stem, ext = os.path.splitext(safepath(filepath))
    return ''.join((stem, s, ext))
Exemple #13
0
def _synctree(src, dst, onexist):
    """Copy the tree at ``src`` into ``dst`` without overwriting anything.

    Directories missing from ``dst`` are created and files missing from
    ``dst`` are copied. For each file that already exists in ``dst``,
    ``onexist(dstfile)`` is called instead of copying.

    If anything fails, every directory and file created so far is removed
    (best effort, most recently created first) and the original exception
    is re-raised.

    Args:
        src (PathLike): source directory
        dst (pathlib.Path): destination directory
        onexist (callable): called with the destination path of each file
            that already exists

    Returns:
        list of ``(path, kind)`` tuples, one per created item, where ``kind``
        is ``'dir'`` or ``'file'``

    Raises:
        BalletError: if a destination path that should be a directory
            exists but is not a directory
    """
    result = []
    cleanup = []
    try:
        for root, dirnames, filenames in os.walk(safepath(src)):
            root = pathlib.Path(root)
            relative_dir = root.relative_to(src)

            for dirname in dirnames:
                dstdir = dst.joinpath(relative_dir, dirname)
                if dstdir.exists():
                    if not dstdir.is_dir():
                        raise BalletError(
                            'Cannot sync directory: destination {dstdir!s} '
                            'exists and is not a directory'
                            .format(dstdir=dstdir))
                else:
                    logger.debug(
                        'Making directory: {dstdir!s}'.format(dstdir=dstdir))
                    dstdir.mkdir()
                    result.append((dstdir, 'dir'))
                    cleanup.append(partial(os.rmdir, safepath(dstdir)))

            for filename in filenames:
                srcfile = root.joinpath(filename)
                dstfile = dst.joinpath(relative_dir, filename)
                if dstfile.exists():
                    onexist(dstfile)
                else:
                    logger.debug(
                        'Copying file to destination: {dstfile!s}'
                        .format(dstfile=dstfile))
                    copyfile(srcfile, dstfile)
                    result.append((dstfile, 'file'))
                    cleanup.append(partial(os.unlink, safepath(dstfile)))

    except Exception:
        # Undo in reverse creation order; with the top-down walk, files are
        # removed before the directories created to hold them. Each step is
        # suppressed individually so one failed removal does not abort the
        # rest of the rollback (previously a single suppress() wrapped the
        # whole loop).
        for undo in reversed(cleanup):
            with suppress(Exception):
                undo()
        raise

    return result
    def call_validate_all(pr=None):
        """Run ``ballet validate -A`` under a simulated Travis CI environment.

        Uses ``repo`` and ``base`` from the enclosing scope.

        Args:
            pr: pull request number to simulate, or None to simulate a
                non-PR build whose commit range ends at ``HEAD``
        """
        travis_env = {'TRAVIS_BUILD_DIR': repo.working_tree_dir}
        if pr is not None:
            travis_env.update({
                'TRAVIS_PULL_REQUEST': str(pr),
                'TRAVIS_COMMIT_RANGE': make_commit_range(
                    repo.heads.master.name,
                    repo.commit('pull/{pr}'.format(pr=pr)).hexsha),
            })
        else:
            travis_env.update({
                'TRAVIS_PULL_REQUEST': 'false',
                'TRAVIS_COMMIT_RANGE': make_commit_range(
                    repo.commit('HEAD@{-1}').hexsha,
                    repo.commit('HEAD').hexsha),
                'TRAVIS_PULL_REQUEST_BRANCH': '',
                'TRAVIS_BRANCH': repo.heads.master.name,
            })

        # NOTE(review): if check_call is subprocess.check_call, a single
        # string without shell=True is treated as the program name on
        # POSIX — confirm which check_call is imported here.
        with patch.dict(os.environ, travis_env):
            check_call('ballet validate -A', cwd=safepath(base),
                       env=os.environ)
Exemple #15
0
def get_repo(repo=None):
    """Return ``repo``, or discover the git repository containing the cwd.

    Args:
        repo (git.Repo, optional): repository to return unchanged

    Returns:
        git.Repo: ``repo`` if given; otherwise the repository found by
        searching from the current working directory up through its parents
    """
    if repo is not None:
        return repo
    return git.Repo(safepath(pathlib.Path.cwd()),
                    search_parent_directories=True)
Exemple #16
0
def _read_tabular_pickle(filepath):
    """Load a pickled object from ``filepath``.

    Args:
        filepath (PathLike): path to the pickle file; its extension is
            checked against ``'.pkl'`` via ``_check_ext``

    Returns:
        the unpickled object

    Warning:
        Unpickling can execute arbitrary code. Only call this on files
        from a trusted source.
    """
    _, _, ext = splitext2(filepath)
    _check_ext(ext, '.pkl')
    # SECURITY: pickle.load on untrusted data allows arbitrary code execution
    with open(safepath(filepath), 'rb') as f:
        return pickle.load(f)
Exemple #17
0
 def load(self, filepath):
     """Load a joblib-serialized estimator into ``self.estimator``.

     Args:
         filepath (PathLike): path to the serialized model

     Raises:
         ValueError: if nothing exists at ``filepath``
     """
     # convert once and use the same value for both the check and the load
     path = safepath(filepath)
     if not os.path.exists(path):
         raise ValueError("Couldn't find model at {}".format(filepath))
     # SECURITY: joblib.load unpickles data; only load trusted model files
     self.estimator = joblib.load(path)
Exemple #18
0
def update_project_template(push=False, project_template_path=None):
    """Update project with updates to upstream project template

    The update is fairly complicated and proceeds as follows:

    1. Load project: user must run command from master branch and ballet
       must be able to detect the project-template branch
    2. Load the saved cookiecutter context from disk
    3. Render the project template into a temporary directory using the
       saved context, *prompting the user if new keys are required*. Note
       that the project template is simply loaded from the data files of the
       installed version of ballet (unless ``project_template_path`` is
       given). Note further that by the project template's post_gen_hook, a
       new git repo is initialized [in the temporary directory] and files
       are committed.
    4. Add the temporary directory as a remote and merge it into the
       project-template branch, favoring changes made to the upstream template.
       Any failure to merge results in an unrecoverable error.
    5. Merge the project-template branch into the master branch. The user is
       responsible for merging conflicts and they are given instructions to
       do so and recover.
    6. If applicable, push to master.

    Args:
        push (bool): whether to push updates to remote, defaults to False
        project_template_path (PathLike): an override for the path to the
            project template

    Raises:
        ConfigurationError: if not run from the project root, if HEAD is not
            on the default branch (including a detached HEAD), or if the
            template branch does not exist
        BalletError: if the working tree has uncommitted changes
        GitCommandError: if a git operation fails, including a merge into
            the template branch or a conflicting merge into the default branch
    """
    cwd = pathlib.Path.cwd().resolve()

    # get ballet project info -- must be at project root directory with a
    # ballet.yml file.
    try:
        project = Project.from_path(cwd)
    except ConfigurationError as e:
        raise ConfigurationError('Must run command from project root.') from e

    repo = project.repo
    # short sha of the starting commit, quoted in the recovery instructions
    original_head = repo.head.commit.hexsha[:7]

    if repo.is_dirty():
        raise BalletError(
            'Can\'t update project template with uncommitted changes. '
            'Please commit your changes and try again.')

    # A detached HEAD has no branch, and accessing repo.head.ref would raise
    # TypeError, so report it as the same "wrong branch" configuration error.
    if repo.head.is_detached or repo.head.ref.name != DEFAULT_BRANCH:
        raise ConfigurationError(
            'Must run command from branch {master}'.format(
                master=DEFAULT_BRANCH))

    if TEMPLATE_BRANCH not in repo.branches:
        raise ConfigurationError(
            'Could not find \'{}\' branch.'.format(TEMPLATE_BRANCH))

    # check for upstream updates to ballet
    new_version = _check_for_updated_ballet()
    if new_version:
        _warn_of_updated_ballet(new_version)

    with tempfile.TemporaryDirectory() as tempdir:
        tempdir = pathlib.Path(tempdir)

        # cookiecutter returns path to the resulting project dir
        logger.debug('Re-rendering project template at {}'.format(tempdir))
        updated_template = _render_project_template(
            cwd, tempdir, project_template_path=project_template_path)
        updated_repo = git.Repo(safepath(updated_template))

        # tempdir is a randomly-named dir suitable for a random remote name
        # to avoid conflicts
        remote_name = tempdir.name

        remote = repo.create_remote(remote_name, updated_repo.working_tree_dir)
        # Once the remote exists, it must always be removed and the default
        # branch restored -- even if fetching or checkout fails -- so that no
        # dangling remote pointing at a deleted tempdir is left in the config.
        try:
            remote.fetch()
            repo.heads[TEMPLATE_BRANCH].checkout()
            try:
                logger.debug(
                    'Merging re-rendered template to project-template branch')
                repo.git.merge(
                    remote_name + '/' + DEFAULT_BRANCH,
                    allow_unrelated_histories=True,
                    strategy_option='theirs',
                    squash=True,
                )
                if not repo.is_dirty():
                    logger.info('No updates to template -- done.')
                    return
                commit_message = _make_template_branch_merge_commit_message()
                logger.debug('Committing updates: {}'.format(commit_message))
                repo.git.commit(m=commit_message)
            except GitCommandError:
                logger.critical(
                    'Could not merge changes into {template_branch} branch, '
                    'update failed'.format(template_branch=TEMPLATE_BRANCH))
                raise
        finally:
            _safe_delete_remote(repo, remote_name)
            logger.debug('Checking out master branch')
            repo.heads[DEFAULT_BRANCH].checkout()

    try:
        logger.debug('Merging project-template branch into master')
        repo.git.merge(TEMPLATE_BRANCH, no_ff=True)
    except GitCommandError as e:
        if 'merge conflict' in str(e).lower():
            logger.critical('\n'.join([
                'Update failed due to a merge conflict.',
                'Fix conflicts, and then complete merge manually:',
                '    $ git add .', '    $ git commit --no-edit',
                'Otherwise, abandon the update:',
                '    $ git reset --merge {original_head}'
            ]).format(original_head=original_head))
        raise

    if push:
        _push(project)

    _log_recommended_reinstall()
def _run_ballet_update_template(d, project_slug, **kwargs):
    """Run ``update_project_template`` from inside ``d/project_slug``.

    Args:
        d (pathlib.Path): parent directory of the project
        project_slug (str): name of the project directory within ``d``
        **kwargs: forwarded to ``ballet.update.update_project_template``
    """
    project_root = d.joinpath(project_slug)
    with work_in(safepath(project_root)):
        ballet.update.update_project_template(**kwargs)