コード例 #1
0
    def run(self, upstream_repos: Dict[str, str] = None, splitfile_commands: str = None, output: Workspace = None, **kwargs: Any):
        """
        Execute Splitfile commands, building into the output workspace's repository.

        Args:
            - upstream_repos (Dict[str, str]): repository URIs keyed by name; each
                name becomes a formatting parameter whose value is that repository's
                v1 sgr URI
            - splitfile_commands (str): the commands passed to `execute_commands`
            - output (Workspace): workspace whose `repo_uri` names the repository the
                commands are executed into
            - **kwargs (Any): extra formatting parameters

        Returns:
            - No return
        """
        parsed_upstreams = {alias: parse_repo(uri) for alias, uri in upstream_repos.items()}
        upstream_params = {alias: info.v1_sgr_uri() for alias, info in parsed_upstreams.items()}

        # Later sources win on key collisions: upstream URIs, then kwargs, then
        # the flow parameters, then the whole Prefect context.
        params = dict(upstream_params)
        params.update(kwargs)
        params.update(prefect.context.get("parameters", {}))
        params.update(prefect.context)

        target_info = parse_repo(output['repo_uri'])
        target_repo = Repository(namespace=target_info.namespace, repository=target_info.repository)

        # NOTE(review): `output_base` (e.g. output['image_hash']) is not passed, so
        # execute_commands uses its own default base image — confirm this is intended.
        execute_commands(splitfile_commands, params=params, output=target_repo)
コード例 #2
0
    def exists(self, location: str, **kwargs: Any) -> bool:
        """
        Checks whether the target result exists in Splitgraph.

        Does not validate whether the result is `valid`, only that the table named
        by the location is present in the repository on the location's remote.

        Args:
            - location (str): Location of the result in the specific result target.
                Will check whether the provided location exists
            - **kwargs (Any): string format arguments for `location`

        Returns:
            - bool: whether or not the target result exists

        Raises:
            - Exception: any error while parsing the location or querying the remote
                is logged and re-raised
        """
        try:
            # Apply the format arguments before parsing, as documented for **kwargs.
            repo_info = parse_repo(location.format(**kwargs))
            repo = Repository(namespace=repo_info.namespace,
                              repository=repo_info.repository)
            remote = Repository.from_template(repo,
                                              engine=get_engine(
                                                  repo_info.remote_name,
                                                  autocommit=True))

            # NOTE(review): no image is passed, so this presumably checks the
            # remote repository's checked-out schema rather than a tagged image —
            # confirm that is the intended target.
            return bool(table_exists_at(remote, repo_info.table))

        except Exception as exc:
            self.logger.exception(
                "Unexpected error while reading from Splitgraph: {}".format(
                    repr(exc)))
            raise
コード例 #3
0
    def run(self,
            repo_uris: Dict[str, str] = None,
            retain: int = None,
            **kwargs: Any) -> Union[Version, None]:
        """
        Remove old version tags from each repository, keeping the newest `retain`.

        For each repository, every tag that points at an image is parsed with
        `parse_tag`; tags for which it returns a falsy value are ignored. If the
        parsed URI carries a `prerelease`, the candidates are the versions whose
        first prerelease identifier equals it. Otherwise the candidates are the
        versions with no prerelease. All but the `retain` highest candidates have
        their tag deleted from the image they point at. Images themselves are not
        deleted. When the URI names a remote, the remote copy of the repository
        is pruned.

        Args:
            - repo_uris (Dict[str, str]): repository URIs keyed by name
            - retain (int): number of highest candidate versions to keep per repository
            - **kwargs (Any): unused

        Returns:
            - None
        """
        repo_infos = {name: parse_repo(uri) for name, uri in repo_uris.items()}
        repos = {name: Repository(namespace=repo_info.namespace,
                                  repository=repo_info.repository)
                 for name, repo_info in repo_infos.items()}

        # Operate on the remote copy when the URI names a remote, else locally.
        repos_to_prune = {
            name: (Repository.from_template(repos[name],
                                            engine=get_engine(repo_info.remote_name))
                   if repo_info.remote_name else repos[name])
            for name, repo_info in repo_infos.items()
        }

        for name, repo_info in repo_infos.items():
            repo = repos_to_prune[name]
            prerelease = repo_info.prerelease

            # Map tag -> image hash, skipping entries with no image hash.
            tag_dict = {tag: image_hash
                        for image_hash, tag in repo.get_all_hashes_tags()
                        if image_hash}

            version_list = [
                parse_tag(tag)
                for tag in sorted(tag_dict, key=len, reverse=True)
            ]
            valid_versions = [version for version in version_list if version]

            if prerelease:
                prune_candidates = [
                    version for version in valid_versions
                    if len(version.prerelease) > 0
                    and version.prerelease[0] == prerelease
                ]
            else:
                prune_candidates = [
                    version for version in valid_versions
                    if len(version.prerelease) == 0
                ]

            # Clamp at zero. With fewer candidates than `retain`, a negative slice
            # bound would otherwise still prune the oldest versions.
            prune_count = max(len(prune_candidates) - retain, 0)
            prune_list = sorted(prune_candidates)[:prune_count]

            for version in prune_list:
                tag = str(version)
                image = repo.images[tag_dict[tag]]
                image.delete_tag(tag)
コード例 #4
0
    def run(self,
            workspaces: Dict[str, Workspace] = None,
            sgr_tags: Dict[str, List[str]] = None,
            **kwargs: Any):
        """
        Tag repositories whose HEAD moved, push every repository, and report the tagged ones.

        Args:
            - workspaces (Dict[str, Workspace]): workspaces keyed by name; each is read
                for its `repo_uri` and for the `image_hash` it was created at
            - sgr_tags (Dict[str, List[str]], optional): tags to add to the HEAD image
                of each named repository that has a new image
            - **kwargs (Any): unused

        Returns:
            - Dict[str, str]: the `repo_uri` of every workspace whose repository has a
                HEAD image different from the workspace's recorded `image_hash`
        """
        repo_infos = {name: parse_repo(workspace['repo_uri'])
                      for name, workspace in workspaces.items()}
        repos = {name: Repository(namespace=repo_info.namespace,
                                  repository=repo_info.repository)
                 for name, repo_info in repo_infos.items()}

        # A repository has a new image when its HEAD exists and differs from the
        # image hash recorded in its workspace.
        repos_with_new_images = {
            name: repo for name, repo in repos.items()
            if repo.head
            and repo.head.image_hash != workspaces[name]['image_hash']
        }

        tags_by_name = sgr_tags or {}
        for name, repo in repos_with_new_images.items():
            for tag in tags_by_name.get(name, []):
                repo.head.tag(tag)

        # Push every repository, not only the re-tagged ones: we cannot be sure
        # that an unchanged repository does not need pushing.
        for name, repo in repos.items():
            remote_name = repo_infos[name].remote_name
            if not remote_name:
                self.logger.warning(
                    f'No remote_name specified. Not pushing {name}.')
                continue

            remote = Repository.from_template(repo,
                                              engine=get_engine(remote_name))

            repo.push(
                remote,
                handler="S3",
                handler_options={"threads": 8},
                overwrite_objects=True,
                overwrite_tags=True,
            )
            self.logger.info(f'Pushed {name} to {remote_name}')

        return {name: workspaces[name]['repo_uri']
                for name in repos_with_new_images}
コード例 #5
0
    def workspace_tags(self, workspace: Workspace, build: Tuple[str],
                       **kwargs: Any) -> List[str]:
        """
        Compute the tags for the next version built from `workspace`.

        The next version is the patch bump of the workspace's `version`. If there is
        no version, it is `<major>.<minor>.0` (or `<major>.0.0`) taken from the parsed
        `repo_uri`. A base version with at least two prerelease identifiers instead
        keeps major/minor/patch and increments the numeric second prerelease
        identifier. Build metadata is set from `build`, with each entry
        `str.format`-ed using kwargs, flow parameters and the Prefect context.

        Returns:
            - List[str]: `[major, major.minor, full version]`; for prereleases the
                first two also carry a `-<prerelease label>` suffix
        """
        context = prefect.context
        fmt_args = dict(kwargs)
        fmt_args.update(context.get('parameters', {}))
        fmt_args.update(context)

        base_ref = workspace['version']
        repo_info = parse_repo(workspace['repo_uri'])

        if base_ref:
            next_version = base_ref.next_patch()
        elif repo_info.minor:
            next_version = Version(f'{repo_info.major}.{repo_info.minor}.0')
        else:
            next_version = Version(f'{repo_info.major}.0.0')

        is_prerelease = bool(base_ref) and len(base_ref.prerelease) >= 2
        if is_prerelease:
            label, counter = base_ref.prerelease
            bumped_counter = int(counter) + 1
            next_version = Version(major=base_ref.major,
                                   minor=base_ref.minor,
                                   patch=base_ref.patch,
                                   prerelease=(label, str(bumped_counter)))

        next_version.build = [entry.format(**fmt_args) for entry in build]

        major, minor = next_version.major, next_version.minor
        full = str(next_version)
        if not is_prerelease:
            return [f'{major}', f'{major}.{minor}', full]
        return [f'{major}-{label}', f'{major}.{minor}-{label}', full]
コード例 #6
0
    def run(self,
            workspaces: Dict[str, Workspace] = None,
            comment: str = None,
            **kwargs: Any):
        """
        Commit each workspace's repository and report those whose contents changed.

        For every workspace a new image is committed. If `self.image_contents_equal`
        says it matches the image the workspace was created at (`image_hash`), the
        new image is deleted again.

        Args:
            - workspaces (Dict[str, Workspace]): workspaces keyed by name
            - comment (str, optional): commit comment
            - **kwargs (Any): unused

        Returns:
            - Dict[str, str]: the `repo_uri` of every workspace whose commit was kept
        """
        self.logger.info(f'Commit will eval: {workspaces}')

        # NOTE(review): the returned engine is unused; the call is kept in case
        # initialising the default engine matters — confirm, then drop.
        get_engine()

        parsed = {name: parse_repo(ws['repo_uri']) for name, ws in workspaces.items()}
        repos = {name: Repository(namespace=info.namespace, repository=info.repository)
                 for name, info in parsed.items()}

        changed = {}
        for name, repo in repos.items():
            previous_hash = workspaces[name]['image_hash']
            committed = repo.commit(comment=comment, chunk_size=self.chunk_size)

            if self.image_contents_equal(repo.images[previous_hash], committed):
                # Nothing changed: discard the image that was just committed.
                repo.images.delete([committed.image_hash])
                continue

            changed[name] = repo
            self.logger.info(f'Commit complete: {name}')

        self.logger.info('Commit now done')
        return {name: workspaces[name]['repo_uri'] for name in changed}
コード例 #7
0
    def run(self, repo_uri: str = None, query: str = None, schema: Schema = None, layer_query: bool = None, **kwargs: Any):
        """
        Run a SQL query against a Splitgraph repository and return the result.

        Args:
            - repo_uri (str): URI of the repository to query; required
            - query (str): SQL passed to `sql_to_df`
            - schema (Schema, optional): if given, the result is checked with
                `schema.validate` and rejected when it reports errors
            - layer_query (bool, optional): passed to `sql_to_df` as `use_lq`
            - **kwargs (Any): unused

        Returns:
            - the result of `sql_to_df` (presumably a pandas DataFrame)

        Raises:
            - ValueError: if `repo_uri` is empty or not provided
            - SchemaValidationError: if `schema` reports validation errors
        """
        if not repo_uri:
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            raise ValueError('Must specify repo_uri.')
        repo_info = parse_repo(repo_uri)

        repo = Repository(namespace=repo_info.namespace, repository=repo_info.repository)
        data = sql_to_df(query, repository=repo, use_lq=layer_query)

        if schema is not None:
            errors = schema.validate(data)
            if errors:
                raise SchemaValidationError(errors)

        return data
コード例 #8
0
    def run(self,
            upstream_repos: Dict[str, str] = None,
            **kwargs: Any) -> Dict[str, Workspace]:
        """
        Initialise each upstream repository and check out a workspace for it.

        Args:
            - upstream_repos (Dict[str, str]): repository URIs keyed by name
            - **kwargs (Any): unused

        Returns:
            - Dict[str, Workspace]: the workspace produced by
                `self.checkout_workspace` for each name
        """
        parsed = {name: parse_repo(uri) for name, uri in upstream_repos.items()}

        # All repositories are initialised before any workspace is checked out.
        initialised = {name: self.init_repo(info) for name, info in parsed.items()}

        return {name: self.checkout_workspace(repo, parsed[name])
                for name, repo in initialised.items()}
コード例 #9
0
    def run(self,
            request: DataFrameToTableRequest,
            repo_uri: str = None,
            **kwargs: Any):
        """
        Write the data frame described by `request` into a Splitgraph repository table.

        Args:
            - request (DataFrameToTableRequest): supplies `data_frame`, `table` and
                `if_exists` for `df_to_table`
            - repo_uri (str): URI of the destination repository; required
            - **kwargs (Any): unused

        Raises:
            - ValueError: if `repo_uri` is empty or not provided
        """
        if not repo_uri:
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            raise ValueError('Must specify repo_uri.')
        repo_info = parse_repo(repo_uri)

        repo = Repository(namespace=repo_info.namespace,
                          repository=repo_info.repository)

        df_to_table(request.data_frame,
                    repository=repo,
                    table=request.table,
                    if_exists=request.if_exists,
                    schema_check=self.schema_check)
コード例 #10
0
    def run(self,
            data_frame: pd.DataFrame,
            table: str = None,
            if_exists: str = None,
            repo_uri: str = None,
            **kwargs: Any):
        """
        Write a data frame into a table of a Splitgraph repository.

        Args:
            - data_frame (pd.DataFrame): rows to write
            - table (str, optional): destination table name, passed to `df_to_table`
            - if_exists (str, optional): passed to `df_to_table` as `if_exists`
            - repo_uri (str): URI of the destination repository; required
            - **kwargs (Any): unused

        Raises:
            - ValueError: if `repo_uri` is empty or not provided
        """
        if not repo_uri:
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            raise ValueError('Must specify repo_uri.')
        repo_info = parse_repo(repo_uri)

        repo = Repository(namespace=repo_info.namespace,
                          repository=repo_info.repository)

        df_to_table(data_frame,
                    repository=repo,
                    table=table,
                    if_exists=if_exists)
コード例 #11
0
 def repo_info(self) -> RepoInfo:
     """Parse `self.location` with `parse_repo` and return the result."""
     location = self.location
     return parse_repo(location)
コード例 #12
0
    def write(self, value_: Any, **kwargs: Any) -> Result:
        """
        Writes the result to a repository on Splitgraph.

        The data frame replaces the table named by the formatted location. It is
        then committed as a new image and tagged with the location's tag. When
        `self.auto_push` is set, the repository is afterwards pushed to the remote
        named in the location.

        Args:
            - value_ (Any): the value to write; must be a `pandas.DataFrame`. It will
                then be stored as the `value` attribute of the returned `Result` instance
            - **kwargs (optional): if provided, will be used to format the `table`, `comment`, and `tag`

        Returns:
            - Result: returns a new `Result` with both `value`, `comment`, `table`, and `tag` attributes

        Raises:
            - TypeError: if `value_` is not a `pandas.DataFrame`
            - SchemaValidationError: if `self.schema` is set and reports errors for `value_`
            - ValueError: if the repository has no HEAD image to check out (for
                example, it does not exist and `self.auto_init_repo` is disabled)
        """
        # Explicit check (an `assert` is stripped under `python -O`), done first so
        # a wrong type is reported before any schema or engine work.
        if not isinstance(value_, pd.DataFrame):
            raise TypeError("Expected a pandas DataFrame, got {}".format(
                type(value_).__name__))

        if self.schema is not None:
            errors = self.schema.validate(value_)
            if errors:
                raise SchemaValidationError(errors)

        new = self.format(**kwargs)
        new.value = value_

        repo_info = parse_repo(new.location)

        repo = Repository(namespace=repo_info.namespace,
                          repository=repo_info.repository)

        if not repository_exists(repo) and self.auto_init_repo:
            self.logger.info("Creating repo {}/{}...".format(
                repo.namespace, repo.repository))
            repo.init()

        # TODO: Retrieve the repo from bedrock first

        self.logger.info("Starting to upload result to {}...".format(
            new.location))

        with self.atomic(repo.engine):
            self.logger.info("checkout")
            img = repo.head
            if img is None:
                # Otherwise this would surface as an AttributeError on `None.checkout`.
                raise ValueError(
                    "Repository {}/{} has no HEAD image to check out; "
                    "initialize it or enable auto_init_repo.".format(
                        repo.namespace, repo.repository))

            img.checkout(force=True)

            self.logger.info("df to table")
            df_to_table(new.value,
                        repository=repo,
                        table=repo_info.table,
                        if_exists='replace')

            self.logger.info("commit")
            new_img = repo.commit(comment=new.comment, chunk_size=10000)
            new_img.tag(repo_info.tag)

        # TODO: consider skipping the push when `repo.diff(...)` between `img` and
        # `new_img` reports no change to the table.
        if self.auto_push:
            # The remote engine is only needed for pushing, so it is created here.
            remote = Repository.from_template(repo,
                                              engine=get_engine(
                                                  repo_info.remote_name,
                                                  autocommit=True))
            self.logger.info("push")
            repo.push(
                remote,
                handler="S3",
                overwrite_objects=True,
                overwrite_tags=True,
                reupload_objects=True,
            )

        self.logger.info("Finished uploading result to {}...".format(
            new.location))

        return new