示例#1
0
文件: common.py 项目: ballet/ballet
def get_accepted_features(features: Collection[Feature],
                          proposed_feature: Feature) -> List[Feature]:
    """Return every feature except the proposed (not yet accepted) one.

    Two features count as the same feature when their ``source`` attributes
    compare equal. Exactly one element of ``features`` is expected to match
    ``proposed_feature`` in this way.

    Args:
        features: collection of all features in the ballet project: both
            accepted features and candidate ones that have not been accepted
        proposed_feature: candidate feature that has not been accepted

    Returns:
        new list of the remaining features, in their original order

    Raises:
        ballet.exc.BalletError: no feature matched the proposed feature, or
            more than one feature matched it.
    """
    # keep everything whose source differs from the proposed feature's source
    accepted = [
        feature for feature in features
        if not feature.source == proposed_feature.source
    ]

    n_removed = len(features) - len(accepted)
    if n_removed == 1:
        return accepted
    if n_removed == 0:
        raise BalletError(
            'Did not find match for proposed feature within \'contrib\'')
    # more than one feature shared the proposed feature's source
    raise BalletError(f'Unexpected condition (n_features={len(features)}, '
                      f'n_result={len(accepted)})')
示例#2
0
    def name(self):
        """Return the name of this scorer, resolving and caching it lazily.

        The name is resolved at most once and stored in ``self._name``:

        1. If a scorer object is set and it is an sklearn ``_BaseScorer``,
           it is looked up among sklearn's registered scorers. Exactly one
           equal registered scorer supplies the name. If none match, the
           scorer's ``__name__`` attribute is used when present. Several
           matches are logged and leave the name unresolved.
        2. Otherwise, if no scorer but a description is set, the description
           is reverse-looked-up in ``SCORING_NAME_MAPPER``. If absent, it is
           lower-cased with spaces replaced by underscores.

        Returns:
            the scorer name

        Raises:
            ballet.exc.BalletError: no name could be determined
        """
        if self._name is None:
            if self._scorer is not None:
                # try from scorer
                # NOTE(review): ``sklearn.metrics.scorer`` became private in
                # newer scikit-learn releases and ``SCORERS`` was deprecated;
                # confirm this still resolves for the pinned sklearn version.
                if isinstance(self._scorer,
                              sklearn.metrics.scorer._BaseScorer):
                    scorers = sklearn.metrics.scorer.SCORERS
                    # registered scorer names whose scorer equals ours
                    matches = select_values(
                        lambda x: x == self._scorer, scorers)
                    matches = list(matches.keys())
                    if len(matches) == 1:
                        self._name = matches[0]
                    elif len(matches) > 1:
                        # unexpected; leave the name unresolved. Log the scorer
                        # itself: ``self._name`` is always None on this path.
                        logger.debug(
                            'Unexpectedly found multiple matches for scorer '
                            '{scorer!r}: {matches!r}'
                            .format(scorer=self._scorer, matches=matches))
                    else:
                        # must be a custom scorer, try to get name
                        if hasattr(self._scorer, '__name__'):
                            self._name = self._scorer.__name__
            elif self._description is not None:
                # try from description, via the inverted mapping
                # (presumably name -> description before flipping)
                mapper = flip(SCORING_NAME_MAPPER)
                if self._description in mapper:
                    self._name = mapper[self._description]
                else:
                    # default formatting
                    self._name = '_'.join(self._description.lower().split(' '))

        if self._name is not None:
            return self._name
        else:
            raise BalletError('Could not get name from scorer')
示例#3
0
def code_to_module(code: str, modname='modname') -> ModuleType:
    """Execute ``code`` inside a brand-new module object and return it.

    The module is created from a loader-less spec and is not inserted into
    ``sys.modules``. ``code`` is run with ``exec`` in the module's namespace,
    so it must be trusted source.

    Args:
        code: Python source to execute
        modname: ``__name__`` given to the new module

    Returns:
        the populated module

    Raises:
        ballet.exc.BalletError: no module spec could be created
    """
    # technique from https://stackoverflow.com/a/53080237
    module_spec = importlib.util.spec_from_loader(modname, loader=None)
    if module_spec is not None:
        new_module = importlib.util.module_from_spec(module_spec)
        exec(code, new_module.__dict__)
        return new_module
    raise BalletError('Error compiling code into module')
示例#4
0
def push_changes(repo, user, feature, dry_run=False):
    """Push the user's feature branch to the ``origin`` remote.

    The branch name comes from ``_make_branch_name(user, feature)`` and is
    pushed to a branch of the same name on the remote.

    Args:
        repo: git repository holding the feature branch
        user: user the branch belongs to
        feature: feature the branch was created for
        dry_run: if true, only print the push that would be performed

    Raises:
        ballet.exc.BalletError: the push did not succeed
    """
    branch = _make_branch_name(user, feature)
    remote = repo.remote('origin')
    refspec = f'{branch}:{branch}'

    if dry_run:
        print(f"[dry run] would execute 'origin.push({refspec!r})'")
        return

    # exactly one push result is expected for a single refspec
    push_info = one_or_raise(remote.push(refspec))
    if not did_git_push_succeed(push_info):
        raise BalletError('Git push failed, '
                          'maybe you need to delete a branch on remote?')
示例#5
0
def needs_credentials(call):
    """Invoke ``call()`` and translate data-access failures into a hint.

    An exception whose class name contains ``NoCredentialsError``, or whose
    message contains ``Forbidden``, is replaced by a ``BalletError``. The
    replacement explains how to obtain access to the fragile-families data,
    and its exception context is suppressed. Any other exception propagates
    unchanged.

    Args:
        call: zero-argument callable to invoke

    Returns:
        whatever ``call()`` returns

    Raises:
        ballet.exc.BalletError: the call failed for lack of credentials or
            permission
    """
    try:
        return call()
    except Exception as exc:
        if ('NoCredentialsError' not in exc.__class__.__name__
                and 'Forbidden' not in str(exc)):
            # not an access problem -- let it through untouched
            raise
        logger.debug(f'Caught {exc} when trying to load data')
        hint = (
            'An error occurred trying to access fragile-families data. '
            'Access to this dataset is restricted. Make sure you follow '
            'the steps at '
            'https://github.com/HDI-Project/ballet-fragile-families#data-access'
            ' to get access, or ask in the project chat if you are having '
            'any trouble.')
        raise BalletError(hint) from None
示例#6
0
def push_branches_to_remote(repo: git.Repo, remote_name: str,
                            branches: Iterable[str]):
    """Push each named branch to the same-named branch on a remote.

    Similar to::

        $ git push origin branch1:branch1 branch2:branch2

    Every failed ref is logged before the error is raised.

    Args:
        repo: repository containing the branches
        remote_name: name of the remote to push to
        branches: names of the local branches to push

    Raises:
        ballet.exc.BalletError: Push failed in some way
    """
    remote = repo.remote(remote_name)
    refspecs = [f'{name}:{name}' for name in branches]
    push_infos = remote.push(refspecs)

    failed = [info for info in push_infos if not did_git_push_succeed(info)]
    if not failed:
        return

    for info in failed:
        logger.error(f'Failed to push ref {info.local_ref.name} to '
                     f'{info.remote_ref.name}')
    raise BalletError('Push failed')
示例#7
0
        def _transform(x_group, *args, **kwargs):
            """Apply the fitted transformer for a single group.

            Intended to be passed to ``GroupBy.apply``. The group's
            transformer is looked up in ``self.transformers_`` by the group's
            ``name``. Its output is rewrapped in a DataFrame that carries the
            (column-selected) group's index and columns.

            Returns:
                pd.DataFrame: the transformed group. The (column-selected)
                group is returned unchanged instead if transforming raised
                and ``handle_error == 'ignore'``, or if the group is unknown
                and ``handle_unknown == 'ignore'``.

            Raises:
                NotImplementedError: ``x_group`` is not a DataFrame
                BalletError: unknown group and ``handle_unknown == 'error'``
                RuntimeError: unknown group and ``handle_unknown`` has any
                    other value
                Exception: whatever transforming raised, unless
                    ``handle_error == 'ignore'``
            """
            # ``GroupBy.apply`` promises DataFrame groups. Any other input
            # has no ``.name`` to look up its transformer with, and nothing
            # to reassemble the transformed values into.
            if not isinstance(x_group, pd.DataFrame):
                raise NotImplementedError

            # capture the group label before any column selection is applied
            label = x_group.name

            if self.column_selection is not None:
                x_group = x_group[self.column_selection]

            if label not in self.transformers_:
                if self.handle_unknown == 'error':
                    raise BalletError(f'Unknown group: {label}')
                if self.handle_unknown == 'ignore':
                    return x_group
                # Unreachable code
                raise RuntimeError

            fitted = self.transformers_[label]
            try:
                values = fitted.transform(x_group, *args, **kwargs)
                # sklearn's transform turns the DataFrame into an array, and
                # ``DataFrame.values`` cannot be assigned, so rebuild a frame
                # with the group's original index and column labels.
                return pd.DataFrame(
                    data=values, index=x_group.index, columns=x_group.columns)
            except Exception:
                if self.handle_error != 'ignore':
                    raise
                return x_group
示例#8
0
def _push(project):
    """Push default branch and project template branch to remote

    With default config (i.e. remote and branch names), equivalent to::

        $ git push origin master:master project-template:project-template

    The remote is the one named by the project's ``github.remote`` config
    value. The push itself is delegated to ``_call_remote_push``. Every failed
    ref is logged before the error is raised.

    Args:
        project: ballet project whose repository is pushed

    Raises:
        ballet.exc.BalletError: Push failed in some way
    """
    repo = project.repo
    remote_name = project.config.get('github.remote')
    push_infos = _call_remote_push(repo.remote(remote_name))

    failed = [info for info in push_infos if not did_git_push_succeed(info)]
    for info in failed:
        logger.error(f'Failed to push ref {info.local_ref.name} to '
                     f'{info.remote_ref.name}')
    if failed:
        raise BalletError('Push failed')
示例#9
0
文件: update.py 项目: ballet/ballet
def update_project_template(push: bool = False,
                            project_template_path: Optional[Pathy] = None):
    """Update project with updates to upstream project template

    The update is fairly complicated and proceeds as follows:

    1. Load project: user must run command from the project root, on the
       default branch (``DEFAULT_BRANCH``) with no uncommitted changes, and
       ballet must be able to detect the project-template branch
       (``TEMPLATE_BRANCH``)
    2. Load the saved cookiecutter context from disk
    3. Render the project template into a temporary directory using the
       saved context, *prompting the user if new keys are required*. Note
       that the project template is simply loaded from the data files of the
       installed version of ballet. Note further that by the project
       template's post_gen_hook, a new git repo is initialized [in the
       temporary directory] and files are committed.
    4. Add the temporary directory as a remote and merge it into the
       project-template branch, favoring changes made to the upstream template.
       Any failure to merge results in an unrecoverable error.
    5. Merge the project-template branch into the default branch. The user is
       responsible for merging conflicts and they are given instructions to
       do so and recover.
    6. If applicable, push both the default branch and the project-template
       branch to the remote named by the ``github.remote`` config value.

    If the re-rendered template brings no changes, the function logs that
    and returns after step 4 without merging or pushing.

    Args:
        push: whether to push updates to remote, defaults to False
        project_template_path: an override for the path to the
            project template

    Raises:
        ballet.exc.ConfigurationError: not run from the project root, not on
            the default branch (including a detached HEAD), or the
            project-template branch is missing
        ballet.exc.BalletError: the working tree has uncommitted changes, or
            pushing to the remote failed
        git.exc.GitCommandError: merging or committing failed
    """
    cwd = pathlib.Path.cwd().resolve()

    # get ballet project info -- must be at project root directory with a
    # ballet.yml file.
    try:
        project = Project.from_path(cwd)
    except ConfigurationError as e:
        raise ConfigurationError('Must run command from project root.') from e

    repo = project.repo
    # abbreviated sha, shown in the recovery instructions if step 5 conflicts
    original_head = repo.head.commit.hexsha[:7]

    if repo.is_dirty():
        raise BalletError(
            'Can\'t update project template with uncommitted changes. '
            'Please commit your changes and try again.')

    # ``repo.head.ref`` raises TypeError on a detached HEAD, so check for that
    # first and report it as the same configuration error.
    if repo.head.is_detached or repo.head.ref.name != DEFAULT_BRANCH:
        raise ConfigurationError(
            f'Must run command from branch {DEFAULT_BRANCH}')

    if TEMPLATE_BRANCH not in repo.branches:
        raise ConfigurationError(
            f'Could not find \'{TEMPLATE_BRANCH}\' branch.')

    # check for upstream updates to ballet
    new_version = _check_for_updated_ballet()
    if new_version:
        _warn_of_updated_ballet(new_version)

    with tempfile.TemporaryDirectory() as _tempdir:
        tempdir = pathlib.Path(_tempdir)

        # cookiecutter returns path to the resulting project dir
        logger.debug(f'Re-rendering project template at {tempdir}')
        updated_template = _render_project_template(
            cwd, tempdir, project_template_path=project_template_path)
        updated_repo = git.Repo(updated_template)

        # tempdir is a randomly-named dir suitable for a random remote name
        # to avoid conflicts
        remote_name = tempdir.name

        remote = repo.create_remote(
            remote_name, updated_repo.working_tree_dir)
        remote.fetch()

        repo.heads[TEMPLATE_BRANCH].checkout()
        try:
            logger.debug('Merging re-rendered template to project-template '
                         'branch')
            # The freshly rendered repo shares no history with this one.
            # ``-X theirs`` favors the upstream template on conflicting hunks.
            # ``--squash`` only stages the result; it is committed below.
            repo.git.merge(
                remote_name + '/' + DEFAULT_BRANCH,
                allow_unrelated_histories=True,
                strategy_option='theirs',
                squash=True,
            )
            if not repo.is_dirty():
                # nothing staged: the template is unchanged. The ``finally``
                # below still removes the remote and restores the branch.
                logger.info('No updates to template -- done.')
                return
            commit_message = _make_template_branch_merge_commit_message()
            logger.debug(f'Committing updates: {commit_message}')
            repo.git.commit(m=commit_message)
        except GitCommandError:
            logger.critical(
                f'Could not merge changes into {TEMPLATE_BRANCH} branch, '
                f'update failed')
            raise
        finally:
            # NOTE(review): if the squash merge left conflicts, this checkout
            # may itself fail on the unmerged index and mask the original
            # error -- confirm whether a reset is needed first.
            _safe_delete_remote(repo, remote_name)
            logger.debug('Checking out master branch')
            repo.heads[DEFAULT_BRANCH].checkout()

    try:
        logger.debug('Merging project-template branch into master')
        repo.git.merge(TEMPLATE_BRANCH, no_ff=True)
    except GitCommandError as e:
        if 'merge conflict' in str(e).lower():
            logger.critical(dedent(
                f'''
                Update failed due to a merge conflict.
                Fix conflicts, and then complete merge manually:
                    $ git add .
                    $ git commit --no-edit
                Otherwise, abandon the update:
                    $ git reset --merge {original_head}
                '''
            ).strip())
        raise

    if push:
        remote_name = project.config.get('github.remote')
        branches = [DEFAULT_BRANCH, TEMPLATE_BRANCH]
        push_branches_to_remote(repo, remote_name, branches)

    _log_recommended_reinstall()