Example #1
def test_push_own_delete_own(local_engine_empty, unprivileged_pg_repo):
    destination = Repository.from_template(unprivileged_pg_repo,
                                           engine=local_engine_empty)
    clone(unprivileged_pg_repo, local_repository=destination)

    destination.images["latest"].checkout()
    destination.run_sql(
        """UPDATE fruits SET name = 'banana' WHERE fruit_id = 1""")
    destination.commit()

    # Test that we can push to our own namespace -- we can't upload the object into
    # splitgraph_meta directly since we can't create tables there, hence the S3 handler.
    remote_destination = Repository.from_template(
        destination,
        namespace=unprivileged_pg_repo.engine.conn_params["SG_NAMESPACE"],
        engine=unprivileged_pg_repo.engine,
    )
    destination.upstream = remote_destination

    destination.push(handler="S3")
    # Test we can delete a single image from our own repo
    assert len(remote_destination.images()) == 3
    remote_destination.images.delete([destination.images["latest"].image_hash])
    assert len(remote_destination.images()) == 2

    # Test we can delete our own repo once we've pushed it
    remote_destination.delete()
    assert len(remote_destination.images()) == 0
Example #2
def lookup_repository(name: str, include_local: bool = False) -> "Repository":
    """
    Queries the SG engines on the lookup path to locate one hosting the given repository.

    :param name: Repository name
    :param include_local: If True, also queries the local engine

    :return: Local or remote Repository object
    """
    from splitgraph.core.repository import Repository

    template = Repository.from_schema(name)

    if name in _LOOKUP_PATH_OVERRIDE:
        return Repository(
            template.namespace, template.repository, get_engine(_LOOKUP_PATH_OVERRIDE[name])
        )

    # Currently just check if the schema with that name exists on the remote.
    if include_local and repository_exists(template):
        return template

    for engine in _LOOKUP_PATH:
        candidate = Repository(template.namespace, template.repository, get_engine(engine))
        if repository_exists(candidate):
            return candidate
        candidate.engine.close()

    raise RepositoryNotFoundError("Unknown repository %s!" % name)
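
A quick usage sketch for the function above (the repository name is made up): the caller gets back a Repository bound to whichever engine on the lookup path hosts the schema.

# Hedged sketch; "noaa/climate" is a hypothetical repository name.
repo = lookup_repository("noaa/climate", include_local=True)
print(repo.namespace, repo.repository, repo.engine.name)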
Example #3
    def exists(self, location: str, **kwargs: Any) -> bool:
        """
        Checks whether the target result exists in the file system.

        Does not validate whether the result is `valid`, only that it is present.

        Args:
            - location (str): Location of the result in the specific result target.
                Will check whether the provided location exists
            - **kwargs (Any): string format arguments for `location`

        Returns:
            - bool: whether or not the target result exists
        """

        try:
            repo_info = parse_repo(location)
            repo = Repository(namespace=repo_info.namespace,
                              repository=repo_info.repository)
            remote = Repository.from_template(repo,
                                              engine=get_engine(
                                                  repo_info.remote_name,
                                                  autocommit=True))

            # table_exists_at() yields the boolean the docstring promises:
            # whether the table is present in the remote image.
            return table_exists_at(remote, repo_info.table)

        except Exception as exc:
            self.logger.exception(
                "Unexpected error while reading from Splitgraph: {}".format(
                    repr(exc)))
            raise
Example #4
def make_pg_repo(engine, repository=None):
    repository = repository or Repository("test", "pg_mount")
    repository = Repository.from_template(repository, engine=engine)
    repository.init()
    repository.run_sql(PG_DATA)
    repository.commit()
    return repository
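
This helper seeds a throwaway repository for tests. A minimal call sketch, assuming the default engine is available via get_engine():

# Sketch: initializes test/pg_mount on the default engine, loads PG_DATA and commits it.
repo = make_pg_repo(get_engine())
assert repository_exists(repo)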
Example #5
def _get_local_image_for_import(hash_or_tag: str,
                                repository: Repository) -> Tuple[Image, bool]:
    """
    Converts a remote repository and tag into an Image object that exists on the engine,
    optionally pulling the repository or cloning it into a temporary location.

    :param hash_or_tag: Hash/tag
    :param repository: Name of the repository (doesn't need to be local)
    :return: Image object and a boolean flag showing whether the repository should be
        deleted when the image is no longer needed.
    """
    tmp_repo = Repository(repository.namespace,
                          repository.repository + "_tmp_clone")
    repo_is_temporary = False

    logging.info("Resolving repository %s", repository)
    source_repo = lookup_repository(repository.to_schema(), include_local=True)
    if source_repo.engine.name != "LOCAL":
        clone(source_repo, local_repository=tmp_repo, download_all=False)
        source_image = tmp_repo.images[hash_or_tag]
        repo_is_temporary = True
    else:
        # For local repositories, first try to pull them to see if they are clones of a remote.
        if source_repo.upstream:
            source_repo.pull()
        source_image = source_repo.images[hash_or_tag]

    return source_image, repo_is_temporary
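
The boolean flag in the return value exists so the caller can dispose of the temporary clone. A hedged sketch of the intended calling pattern:

# Sketch: clean up the temporary clone once the image is no longer needed.
image, repo_is_temporary = _get_local_image_for_import("latest", repo)
try:
    image.checkout()  # or otherwise use the image
finally:
    if repo_is_temporary:
        image.repository.delete()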
Example #6
def _execute_custom(node: Node, output: Repository) -> ProvenanceLine:
    assert output.head is not None
    command, args = parse_custom_command(node)

    # Locate the command in the config file and instantiate it.
    cmd_fq_class: str = cast(
        str,
        get_all_in_section(CONFIG, "commands").get(command))
    if not cmd_fq_class:
        raise SplitfileError(
            "Custom command {0} not found in the config! Make sure you add an entry to your"
            " config like so:\n  [commands]  \n{0}=path.to.command.Class".
            format(command))

    assert isinstance(cmd_fq_class, str)
    index = cmd_fq_class.rindex(".")
    try:
        cmd_class = getattr(import_module(cmd_fq_class[:index]),
                            cmd_fq_class[index + 1:])
    except AttributeError as e:
        raise SplitfileError(
            "Error loading custom command {0}".format(command)) from e
    except ImportError as e:
        raise SplitfileError(
            "Error loading custom command {0}".format(command)) from e

    get_engine().run_sql("SET search_path TO %s", (output.to_schema(), ))
    command = cmd_class()

    # Pre-flight check: get the new command hash and see if we can short-circuit and just check the image out.
    command_hash = command.calc_hash(repository=output, args=args)
    output_head = output.head.image_hash

    if command_hash is not None:
        image_hash = _combine_hashes([output_head, command_hash])
        try:
            output.images.by_hash(image_hash).checkout()
            logging.info(" ---> Using cache")
            return {"type": "CUSTOM"}
        except ImageNotFoundError:
            pass

    logging.info(" Executing custom command...")
    exec_hash = command.execute(repository=output, args=args)
    command_hash = command_hash or exec_hash or "{:064x}".format(
        getrandbits(256))

    image_hash = _combine_hashes([output_head, command_hash])
    logging.info(" ---> %s" % image_hash[:12])

    # Check just in case if the new hash produced by the command already exists.
    try:
        output.images.by_hash(image_hash).checkout()
    except ImageNotFoundError:
        # Full command as a commit comment
        output.commit(image_hash, comment=node.text)
    return {"type": "CUSTOM"}
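
The custom command class is looked up in the [commands] config section and only needs the two methods called above. A hedged skeleton, with every name hypothetical:

# Hypothetical command registered in the config as:
#   [commands]
#   MYCOMMAND=mymodule.MyCommand
class MyCommand:
    def calc_hash(self, repository, args):
        # Return a deterministic hex digest for cache short-circuiting,
        # or None to skip the pre-flight cache check.
        return None

    def execute(self, repository, args):
        # Mutate the checked-out repository; optionally return a command hash
        # (returning None makes the executor generate a random image hash).
        repository.run_sql("UPDATE fruits SET name = 'kumquat' WHERE fruit_id = 1")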
Example #7
def test_engine_autocommit(local_engine_empty):
    conn_params = _prepare_engine_config(CONFIG)
    engine = PostgresEngine(conn_params=conn_params, name="test_engine", autocommit=True)

    repo = Repository("test", "repo", engine=engine)
    repo.init()

    repo.engine.rollback()
    assert repository_exists(Repository.from_template(repo, engine=local_engine_empty))
Example #8
def readonly_pg_repo(unprivileged_remote_engine, pg_repo_remote_registry):
    target = Repository.from_template(pg_repo_remote_registry, namespace=READONLY_NAMESPACE)
    clone(pg_repo_remote_registry, target)
    pg_repo_remote_registry.delete(uncheckout=False)
    pg_repo_remote_registry.engine.run_sql(
        "UPDATE splitgraph_meta.objects SET namespace=%s WHERE namespace=%s",
        (READONLY_NAMESPACE, REMOTE_NAMESPACE),
    )
    pg_repo_remote_registry.engine.commit()
    yield Repository.from_template(target, engine=unprivileged_remote_engine)
Example #9
def _prov_command_to_splitfile(
    prov_data: ProvenanceLine,
    source_replacement: Dict["Repository", str],
) -> str:
    """
    Converts the image's provenance data stored by the Splitfile executor back to a Splitfile used to
    reconstruct it.

    :param prov_data: Provenance line for one command
    :param source_replacement: Replace repository imports with different versions
    :return: String with the Splitfile command.
    """
    from splitgraph.core.repository import Repository

    prov_type = prov_data["type"]
    assert isinstance(prov_type, str)

    if prov_type == "IMPORT":
        repo, image = (
            Repository(cast(str, prov_data["source_namespace"]),
                       cast(str, prov_data["source"])),
            cast(str, prov_data["source_hash"]),
        )
        result = "FROM %s:%s IMPORT " % (str(repo),
                                         source_replacement.get(repo, image))
        result += ", ".join(
            "%s AS %s" %
            (tn if not q else "{" + tn.replace("}", "\\}") + "}", ta)
            for tn, ta, q in zip(
                cast(List[str], prov_data["tables"]),
                cast(List[str], prov_data["table_aliases"]),
                cast(List[bool], prov_data["table_queries"]),
            ))
        return result
    if prov_type == "FROM":
        repo = Repository(cast(str, prov_data["source_namespace"]),
                          cast(str, prov_data["source"]))
        return "FROM %s:%s" % (
            str(repo), source_replacement.get(repo, prov_data["source_hash"]))
    if prov_type == "SQL":
        # Use the SQL validator/replacer to rewrite old image hashes into new hashes/tags.

        def image_mapper(repository: Repository, image_hash: str):
            new_image = (repository.to_schema() + ":" +
                         source_replacement.get(repository, image_hash))
            return new_image, new_image

        if source_replacement:
            _, replaced_sql = prepare_splitfile_sql(str(prov_data["sql"]),
                                                    image_mapper)
        else:
            replaced_sql = str(prov_data["sql"])
        return "SQL " + "{" + replaced_sql.replace("}", "\\}") + "}"
    raise SplitGraphError("Cannot reconstruct provenance %s!" % prov_type)
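
As an illustration, a FROM provenance line round-trips into a Splitfile command like this (values made up):

# Illustrative input for the FROM branch above.
line = {
    "type": "FROM",
    "source_namespace": "noaa",
    "source": "climate",
    "source_hash": "abcdef123456",
}
print(_prov_command_to_splitfile(line, source_replacement={}))
# -> FROM noaa/climate:abcdef123456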
Example #10
    def run(self,
            workspaces: Dict[str, Workspace] = None,
            sgr_tags: Dict[str, List[str]] = None,
            **kwargs: Any):
        """

        Args:

        Returns:

        """

        repo_infos = dict((name, parse_repo(workspace['repo_uri']))
                          for (name, workspace) in workspaces.items())
        repos = dict((name,
                      Repository(namespace=repo_info.namespace,
                                 repository=repo_info.repository))
                     for (name, repo_info) in repo_infos.items())
        repos_with_new_images = dict(
            (name, repo) for (name, repo) in repos.items() if repo.head
            and repo.head.image_hash != workspaces[name]['image_hash'])

        for name, repo in repos_with_new_images.items():
            repo_tags = sgr_tags[name] if sgr_tags and name in sgr_tags else []
            for tag in repo_tags:
                repo.head.tag(tag)

        # Push all repos: even if a repo has no new image, we can't be sure it shouldn't be pushed.
        for name, repo in repos.items():
            remote_name = repo_infos[name].remote_name
            if not remote_name:
                self.logger.warning(
                    f'No remote_name specified. Not pushing {name}.')
                continue

            remote = Repository.from_template(repo,
                                              engine=get_engine(remote_name))

            repo.push(
                remote,
                handler="S3",
                handler_options={"threads": 8},
                overwrite_objects=True,
                overwrite_tags=True,
            )
            self.logger.info(f'Pushed {name} to {remote_name}')

        tagged_repo_uris = dict(
            (name, workspaces[name]['repo_uri'])
            for (name, repo) in repos_with_new_images.items())
        return tagged_repo_uris
Example #11
def _execute_from(
        node: Node,
        output: Repository) -> Tuple[Repository, Optional[ProvenanceLine]]:
    interesting_nodes = extract_nodes(node, ["repo_source", "repository"])
    repo_source = get_first_or_none(interesting_nodes, "repo_source")
    output_node = get_first_or_none(interesting_nodes, "repository")
    provenance: Optional[ProvenanceLine] = None

    if output_node:
        # AS (output) detected, change the current output repository to it.
        output = Repository.from_schema(output_node.match.group(0))
        logging.info("Changed output repository to %s" % str(output))

        # NB this destroys all data in the case where we ran some commands in the Splitfile and then
        # did FROM (...) without AS repository
        if repository_exists(output):
            logging.info("Clearing all output from %s" % str(output))
            output.delete()
    if not repository_exists(output):
        output.init()
    if repo_source:
        repository, tag_or_hash = parse_image_spec(repo_source)
        source_repo = lookup_repository(repository.to_schema(),
                                        include_local=True)

        if source_repo.engine.name == "LOCAL":
            # For local repositories, make sure to update them if they have an upstream
            if source_repo.upstream:
                source_repo.pull()

        # Get the target image hash from the source repo: otherwise, if the tag is, say, 'latest' and
        # the output has just had the base commit (000...) created in it, that commit will be the latest.
        clone(source_repo, local_repository=output, download_all=False)
        source_hash = source_repo.images[tag_or_hash].image_hash
        output.images.by_hash(source_hash).checkout()
        provenance = {
            "type": "FROM",
            "source_namespace": source_repo.namespace,
            "source": source_repo.repository,
            "source_hash": source_hash,
        }
    else:
        # FROM EMPTY AS repository -- initializes an empty repository (say, to create a table
        # or import the results of a previous stage in a multistage build).
        # In this case, if AS repository has been specified, it's already been initialized.
        # If not, this command literally does nothing.
        if not output_node:
            raise SplitfileError(
                "FROM EMPTY without AS (repository) does nothing!")
    return output, provenance
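
For reference, the two FROM shapes this function handles look like these Splitfile lines (repository names made up):

# FROM noaa/climate:latest AS myuser/derived  -- clones the source image into the output repo
# FROM EMPTY AS myuser/scratch                -- initializes an empty output repository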
Example #12
    def run(self,
            repo_uris: Dict[str, str] = None,
            retain: int = None,
            **kwargs: Any) -> None:
        """
        Prune old version tags from each repository, keeping the `retain` most recent
        versions (or matching prerelease versions, if the repo URI specifies a prerelease).

        Args:
            - repo_uris (Dict[str, str], optional): Map of name to repository URI.
            - retain (int, optional): Number of most recent versions whose tags to keep.
            - **kwargs (Any): Ignored.

        Returns:
            - None
        """

        repo_infos = dict(
            (name, parse_repo(uri)) for (name, uri) in repo_uris.items())
        repos = dict((name,
                      Repository(namespace=repo_info.namespace,
                                 repository=repo_info.repository))
                     for (name, repo_info) in repo_infos.items())

        repos_to_prune = dict(
            (name,
             repos[name] if not repo_info.remote_name else Repository.from_template(
                 repos[name], engine=get_engine(repo_info.remote_name)))
            for (name, repo_info) in repo_infos.items())

        for name, repo_info in repo_infos.items():
            repo = repos_to_prune[name]
            prerelease = repo_info.prerelease
            image_tags = repo.get_all_hashes_tags()

            tag_dict = dict((tag, image_hash)
                            for (image_hash, tag) in image_tags
                            if image_hash)  # reverse keys: tag -> image hash

            version_list = [
                parse_tag(tag)
                for tag in sorted(list(tag_dict.keys()), key=len, reverse=True)
            ]

            valid_versions = [version for version in version_list if version]
            non_prerelease_versions = [
                version for version in valid_versions
                if len(version.prerelease) == 0
            ]
            prerelease_versions = [
                version for version in valid_versions
                if prerelease and len(version.prerelease) > 0
                and version.prerelease[0] == prerelease
            ]
            prune_candidates = prerelease_versions if prerelease else non_prerelease_versions

            total_candidates = len(prune_candidates)
            prune_count = total_candidates - retain
            prune_list = sorted(prune_candidates)[:prune_count]

            for version in prune_list:
                tag = str(version)
                image_hash = tag_dict[tag]
                image = repo.images[image_hash]
                image.delete_tag(tag)
Example #13
def test_push_single_image(pg_repo_local, remote_engine):
    original_head = pg_repo_local.head
    _add_image_to_repo(pg_repo_local)
    remote_repo = Repository.from_template(pg_repo_local, engine=remote_engine)
    assert len(remote_repo.images()) == 0
    assert len(remote_repo.objects.get_all_objects()) == 0

    pg_repo_local.push(remote_repository=remote_repo,
                       single_image=original_head.image_hash)
    assert len(remote_repo.images()) == 1
    assert len(remote_repo.objects.get_all_objects()) == 2

    # Try pushing the same image again
    pg_repo_local.push(remote_repository=remote_repo,
                       single_image=original_head.image_hash)
    assert len(remote_repo.images()) == 1
    assert len(remote_repo.objects.get_all_objects()) == 2

    # Test we can check the repo out on the remote.
    remote_repo.images[original_head.image_hash].checkout()

    # Push the rest
    pg_repo_local.push(remote_repo)
    assert len(remote_repo.images()) == 3
    assert len(remote_repo.objects.get_all_objects()) == 3
Example #14
def test_import_splitfile_reuses_hash(local_engine_empty):
    # Create two repositories and run the same Splitfile that loads some data from a mounted database.
    # Check that the same contents result in the same hash and no extra objects being created
    output_2 = Repository.from_schema("output_2")

    execute_commands(load_splitfile("import_from_mounted_db.splitfile"),
                     output=OUTPUT)
    execute_commands(load_splitfile("import_from_mounted_db.splitfile"),
                     output=output_2)

    head = OUTPUT.head
    assert head.get_table("my_fruits").objects == [
        "o71ba35a5bbf8ac7779d8fe32226aaacc298773e154a4f84e9aabf829238fb1"
    ]
    assert head.get_table("o_vegetables").objects == [
        "o70e726f4bf18547242722600c4723dceaaede27db8fa5e9e6d7ec39187dd86"
    ]
    assert head.get_table("vegetables").objects == [
        "ob474d04a80c611fc043e8303517ac168444dc7518af60e4ccc56b3b0986470"
    ]
    assert head.get_table("all_fruits").objects == [
        "o0e742bd2ea4927f5193a2c68f8d4c51ea018b1ef3e3005a50727147d2cf57b"
    ]

    head_2 = output_2.head
    assert head_2.get_table("my_fruits").objects == head.get_table(
        "my_fruits").objects
    assert head_2.get_table("o_vegetables").objects == head.get_table(
        "o_vegetables").objects
    assert head_2.get_table("vegetables").objects == head.get_table(
        "vegetables").objects
    assert head_2.get_table("all_fruits").objects == head.get_table(
        "all_fruits").objects
Example #15
    def provenance(self,
                   reverse=False,
                   engine=None) -> List[Tuple["Repository", str]]:
        """
        Inspects the image's parent chain to come up with a set of repositories and their hashes
        that it was created from.

        If `reverse` is True, returns a list of images that were created _from_ this image. If
        this image is on a remote repository, `engine` can be passed in to override the engine
        used for the lookup of dependents.

        :return: List of (repository, image_hash)
        """
        from splitgraph.core.repository import Repository

        api_call = "get_image_dependents" if reverse else "get_image_dependencies"

        engine = engine or self.engine

        result = set()
        for namespace, repository, image_hash in engine.run_sql(
                select(api_call,
                       table_args="(%s,%s,%s)",
                       schema=SPLITGRAPH_API_SCHEMA),
            (self.repository.namespace, self.repository.repository,
             self.image_hash),
        ):
            result.add((Repository(namespace, repository), image_hash))
        return list(result)
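
Usage sketch (image variable hypothetical): the result is a deduplicated list of (Repository, image_hash) pairs.

# Sketch: list the repositories/images this image was built from.
for source_repo, source_hash in image.provenance():
    print("%s:%s" % (source_repo, source_hash))
# provenance(reverse=True, engine=registry_engine) lists dependents instead.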
Example #16
    def checkout_workspace(self, repo: Repository,
                           repo_info: RepoInfo) -> Workspace:
        image_tags = repo.get_all_hashes_tags()

        tag_dict = dict((tag, image_hash) for (image_hash, tag) in image_tags
                        if image_hash)  # reverse keys: tag -> image hash
        default_image = (repo.images[tag_dict['latest']]
                         if 'latest' in tag_dict else repo.head)
        version_list = [
            parse_tag(tag)
            for tag in sorted(list(tag_dict.keys()), key=len, reverse=True)
        ]

        valid_versions = [version for version in version_list if version]

        spec_expr = f'<={repo_info.major}.{repo_info.minor}' if repo_info.minor else f'<={repo_info.major}'
        base_ref_spec = NpmSpec(spec_expr)
        base_ref = base_ref_spec.select(valid_versions)

        if repo_info.prerelease:
            assert base_ref, 'Cannot checkout using prerelease until a repo is initialized.'
            prerelease_base_version = base_ref.next_patch()
            base_ref = NpmSpec(
                f'>={str(prerelease_base_version)}-{repo_info.prerelease}'
            ).select(valid_versions)

        image_hash = tag_dict[str(base_ref)] if base_ref else default_image.image_hash

        image = repo.images[image_hash]
        image.checkout(force=True)
        return Workspace(repo_uri=repo_info.uri,
                         image_hash=image_hash,
                         version=base_ref)
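
The range selection above leans on the semantic_version library's NpmSpec. A standalone sketch of the idea (version numbers made up):

from semantic_version import NpmSpec, Version

versions = [Version("0.9.0"), Version("1.2.0"), Version("1.3.0"), Version("2.0.0")]
# select() picks the highest version satisfying the range, as checkout_workspace does.
print(NpmSpec("<=1").select(versions))  # expected: 1.3.0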
Example #17
def _execute_db_import(conn_string, fdw_name, fdw_params, table_names,
                       target_mountpoint, table_aliases,
                       table_queries) -> ProvenanceLine:
    mount_handler = get_mount_handler(fdw_name)
    tmp_mountpoint = Repository.from_schema(fdw_name + "_tmp_staging")
    tmp_mountpoint.delete()
    try:
        handler_kwargs = json.loads(fdw_params)
        handler_kwargs.update(
            conn_string_to_dict(conn_string.group() if conn_string else None))
        mount_handler(tmp_mountpoint.to_schema(), **handler_kwargs)
        # The foreign database is a moving target, so the new image hash is random.
        # Maybe in the future, when the object hash is a function of its contents, we can be smarter here...
        target_hash = "{:064x}".format(getrandbits(256))
        target_mountpoint.import_tables(
            table_aliases,
            tmp_mountpoint,
            table_names,
            target_hash=target_hash,
            foreign_tables=True,
            table_queries=table_queries,
        )
        return {"type": "MOUNT"}
    finally:
        tmp_mountpoint.delete()
Example #18
def unprivileged_pg_repo(unprivileged_remote_engine, pg_repo_remote_registry):
    """Like pg_repo_remote_registry but accessed as an unprivileged user that can't
    access splitgraph_meta directly and has to use splitgraph_api. If access to
    splitgraph_meta is required, the test can use both fixtures and do e.g.
    pg_repo_remote_registry.objects.get_all_objects()"""
    yield Repository.from_template(pg_repo_remote_registry,
                                   engine=unprivileged_remote_engine)
Example #19
def test_push_own_delete_own_different_namespaces(local_engine_empty,
                                                  readonly_pg_repo):
    # Same as previous but we clone the read-only repo and push to our own namespace
    # to check that the objects we push get their namespaces rewritten to be the unprivileged user, not test.
    destination = clone(readonly_pg_repo)

    destination.images["latest"].checkout()
    destination.run_sql(
        """UPDATE fruits SET name = 'banana' WHERE fruit_id = 1""")
    destination.commit()

    remote_destination = Repository.from_template(
        readonly_pg_repo,
        namespace=readonly_pg_repo.engine.conn_params["SG_NAMESPACE"],
        engine=readonly_pg_repo.engine,
    )
    destination.upstream = remote_destination

    destination.push(handler="S3")

    object_id = destination.head.get_table("fruits").objects[-1]
    assert (
        remote_destination.objects.get_object_meta([object_id])[object_id].namespace
        == readonly_pg_repo.engine.conn_params["SG_NAMESPACE"]
    )

    # Test we can delete our own repo once we've pushed it
    remote_destination.delete(uncheckout=False)
    assert len(remote_destination.images()) == 0
Example #20
def dependents_c(image_spec, source_on, dependents_on):
    """
    List images that were created from an image.

    This is the inverse of the ``sgr provenance`` command. It will list all images that were
    created using a Splitfile that imported data from this image.

    By default, this will look at images on the local engine. The engine can be overridden
    with --source-on and --dependents-on. For example:

        sgr dependents --source-on data.splitgraph.com --dependents-on LOCAL noaa/climate:latest

    will show all images on the local engine that derived data from `noaa/climate:latest`
    on the Splitgraph registry.
    """
    from splitgraph.engine import get_engine
    from splitgraph.core.repository import Repository

    source_engine = get_engine(source_on) if source_on else get_engine()
    repository, image = image_spec
    repository = Repository.from_template(repository, engine=source_engine)
    image = repository.images[image]

    target_engine = get_engine(
        dependents_on) if dependents_on else get_engine()

    result = image.provenance(reverse=True, engine=target_engine)
    click.echo("%s:%s is depended on by:" %
               (str(repository), image.image_hash))
    click.echo("\n".join("%s:%s" % rs for rs in result))
Example #21
def build_c(splitfile, args, output_repository):
    """
    Build Splitgraph images.

    This executes a Splitfile, building a new image or checking it out from cache if the same
    image had already been built.

    Examples:

    ``sgr build my.splitfile``

    Executes ``my.splitfile`` and writes its output into a new repository with the same name
    as the Splitfile (my) unless the name is specified in the Splitfile.

    ``sgr build my.splitfile -o mynew/repo``

    Executes ``my.splitfile`` and writes its output into ``mynew/repo``.

    ``sgr build my_other.splitfile -o mynew/otherrepo --args PARAM1 VAL1 --args PARAM2 VAL2``

    Executes ``my_other.splitfile`` with parameters ``PARAM1`` and ``PARAM2`` set to
    ``VAL1`` and  ``VAL2``, respectively.
    """
    from splitgraph.splitfile import execute_commands
    from splitgraph.core.repository import Repository

    args = {k: v for k, v in args}
    click.echo("Executing Splitfile %s with arguments %r" %
               (splitfile.name, args))

    if output_repository is None:
        file_name = os.path.splitext(os.path.basename(splitfile.name))[0]
        output_repository = Repository.from_schema(file_name)

    execute_commands(splitfile.read(), args, output=output_repository)
Example #22
    def run(self,
            upstream_repos: Dict[str, str] = None,
            splitfile_commands: str = None,
            output: Workspace = None,
            **kwargs: Any):
        """
        Execute Splitfile commands, writing the result into the output workspace's repository.

        Args:
            - upstream_repos (Dict[str, str], optional): Map of name to upstream repo URI;
                each URI is exposed to the Splitfile as a formatting parameter.
            - splitfile_commands (str, optional): The Splitfile source to execute.
            - output (Workspace, optional): Workspace (with a 'repo_uri' key) that
                receives the output images.
            - **kwargs (Any): Extra formatting parameters for the Splitfile.

        Returns:
            - No return
        """
        repo_infos = dict((name, parse_repo(uri)) for (name, uri) in upstream_repos.items())
        v1_sgr_repo_uris = dict((name, repo_info.v1_sgr_uri())
                                for (name, repo_info) in repo_infos.items())

        formatting_kwargs = {
            **v1_sgr_repo_uris,
            **kwargs,
            **prefect.context.get("parameters", {}).copy(),
            **prefect.context,
        }

        repo_info = parse_repo(output['repo_uri'])
        repo = Repository(namespace=repo_info.namespace, repository=repo_info.repository)

        execute_commands(
            splitfile_commands,
            params=formatting_kwargs,
            output=repo,
            # output_base=output['image_hash'],
        )
Example #23
def test_bloom_reindex_remote(local_engine_empty, unprivileged_pg_repo,
                              clean_minio):
    _prepare_fully_remote_repo(local_engine_empty, unprivileged_pg_repo)

    # Do a reindex using our local engine to query the object and the remote engine
    # to write metadata to.
    repo = Repository.from_template(unprivileged_pg_repo,
                                    object_engine=local_engine_empty)
    fruits = repo.images["latest"].get_table("fruits")

    # The repo used for LQ tests has 2 objects that overwrite data, so we ignore those.
    reindexed = fruits.reindex(
        extra_indexes={"bloom": {"name": {"probability": 0.01}}},
        raise_on_patch_objects=False,
    )
    repo.commit_engines()

    assert len(reindexed) == 3
    assert set(repo.objects.get_downloaded_objects()) == set(reindexed)

    # Check the index was written to the remote metadata engine.
    assert ("bloom" in unprivileged_pg_repo.objects.get_object_meta(reindexed)[
        reindexed[0]].object_index)
Example #24
def test_splitfile_end_to_end_with_uploading(local_engine_empty, remote_engine,
                                             pg_repo_remote_multitag,
                                             mg_repo_remote, clean_minio):
    # An end-to-end test:
    #   * Create a derived dataset from some tables imported from the remote engine
    #   * Push it back to the remote engine, uploading all objects to S3 (instead of the remote engine itself)
    #   * Delete everything from pgcache
    #   * Run another splitfile that depends on the just-pushed dataset (and does lazy checkouts to
    #     get the required tables).

    # Do the same setting up first and run the splitfile against the remote data.
    execute_commands(load_splitfile("import_remote_multiple.splitfile"),
                     params={"TAG": "v1"},
                     output=OUTPUT)

    remote_output = Repository(OUTPUT.namespace, OUTPUT.repository,
                               remote_engine)

    # Push with upload
    OUTPUT.push(remote_repository=remote_output,
                handler="S3",
                handler_options={})
    # Unmount everything locally and cleanup
    for mountpoint, _ in get_current_repositories(local_engine_empty):
        mountpoint.delete()
    OUTPUT.objects.cleanup()

    stage_2 = Repository.from_schema("output_stage_2")
    execute_commands(
        load_splitfile("import_from_preuploaded_remote.splitfile"),
        output=stage_2)

    assert stage_2.run_sql("SELECT id, name, fruit, vegetable FROM diet") == [
        (2, "James", "orange", "carrot")
    ]
Example #25
def test_mount_and_import(local_engine_empty):
    runner = CliRunner()
    try:
        # sgr mount
        result = runner.invoke(
            mount_c,
            [
                "mongo_fdw",
                "tmp",
                "-c",
                "originro:originpass@mongoorigin:27017",
                "-o",
                json.dumps(_MONGO_PARAMS),
            ],
        )
        assert result.exit_code == 0

        result = runner.invoke(import_c, ["tmp", "stuff", str(MG_MNT)])
        assert result.exit_code == 0
        assert MG_MNT.head.get_table("stuff")

        result = runner.invoke(
            import_c, ["tmp", "SELECT * FROM stuff WHERE duration > 10", str(MG_MNT), "stuff_query"]
        )
        assert result.exit_code == 0
        assert MG_MNT.head.get_table("stuff_query")
    finally:
        Repository("", "tmp").delete()
Example #26
def clone_c(remote_repository_or_image, local_repository, remote, download_all,
            overwrite_object_meta, tags):
    """
    Clone a remote Splitgraph repository/image into a local one.

    The lookup path for the repository is governed by the ``SG_REPO_LOOKUP`` and ``SG_REPO_LOOKUP_OVERRIDE``
    config parameters and can be overridden by the command line ``--remote`` option.
    """
    from splitgraph.core.repository import Repository
    from splitgraph.engine import get_engine
    from splitgraph.core.repository import clone

    remote_repository, image = remote_repository_or_image

    # If the user passed in a remote, we can inject that into the repository spec.
    # Otherwise, we have to turn the repository into a string and let clone() look up the
    # actual engine the repository lives on.
    if remote:
        remote_repository = Repository.from_template(remote_repository,
                                                     engine=get_engine(remote))
    else:
        remote_repository = remote_repository.to_schema()

    clone(
        remote_repository,
        local_repository=local_repository,
        download_all=download_all,
        single_image=image,
        overwrite_objects=overwrite_object_meta,
        overwrite_tags=tags,
    )
Example #27
def test_push_target(
    repository,
    remote_repository,
    remote,
    available_remotes,
    upstream,
    expected_target,
    expected_remote,
):

    repository = Repository.from_schema(repository)
    remote_repository = Repository.from_schema(
        remote_repository) if remote_repository else None

    fake_config = {
        "remotes": {s: {"SG_NAMESPACE": "user"} for s in available_remotes}
    }

    with mock.patch.object(Repository, "upstream",
                           new_callable=PropertyMock) as up:
        up.return_value = upstream
        with mock.patch("splitgraph.commandline.push_pull.REMOTES",
                        available_remotes):
            with mock.patch("splitgraph.commandline.push_pull.CONFIG",
                            fake_config):
                with mock.patch("splitgraph.engine.get_engine") as ge:
                    ge.return_value = Mock()
                    ge.return_value.name = expected_remote

                    if isinstance(expected_target, type):
                        with pytest.raises(expected_target):
                            _determine_push_target(repository,
                                                   remote_repository, remote)
                    else:
                        result = _determine_push_target(
                            repository, remote_repository, remote)
                        if upstream:
                            assert result == upstream
                        else:
                            assert result.to_schema() == expected_target

                            ge_call = ge.mock_calls[0]
                            assert ge_call[1][0] == expected_remote
                            assert result.engine.name == expected_remote
Example #28
def test_singer_ingestion_errors(local_engine_empty):
    runner = CliRunner(mix_stderr=False)

    with open(os.path.join(INGESTION_RESOURCES, "singer/initial.json"),
              "r") as f:
        result = runner.invoke(singer_target, [TEST_REPO + ":latest"],
                               input=f,
                               catch_exceptions=False)

    assert result.exit_code == 0

    # Default strategy: delete image on failure
    with open(os.path.join(INGESTION_RESOURCES, "singer/wrong_schema.json"),
              "r") as f:
        result = runner.invoke(singer_target, [TEST_REPO + ":latest"],
                               input=f,
                               catch_exceptions=True)

    assert result.exit_code == 1
    assert isinstance(result.exception, psycopg2.errors.InvalidDatetimeFormat)
    repo = Repository.from_schema(TEST_REPO)
    assert len(repo.images()) == 1

    # Keep new image
    with open(os.path.join(INGESTION_RESOURCES, "singer/wrong_schema.json"),
              "r") as f:
        result = runner.invoke(
            singer_target,
            [TEST_REPO + ":latest", "--failure=keep-both"],
            input=f,
            catch_exceptions=True,
        )

    assert result.exit_code == 1
    assert isinstance(result.exception, psycopg2.errors.InvalidDatetimeFormat)
    repo = Repository.from_schema(TEST_REPO)
    assert len(repo.images()) == 2

    # The "stargazers" table is still the same but the "releases" table managed to get updated.
    image = repo.images["latest"]
    assert sorted(image.get_tables()) == ["releases", "stargazers"]
    image.checkout()

    assert repo.run_sql("SELECT COUNT(1) FROM releases",
                        return_shape=ResultShape.ONE_ONE) == 7
    assert repo.run_sql("SELECT COUNT(1) FROM stargazers",
                        return_shape=ResultShape.ONE_ONE) == 5
Example #29
def test_pull_push(local_engine_empty, pg_repo_remote):
    runner = CliRunner()
    pg_repo_local = Repository.from_template(pg_repo_remote,
                                             engine=local_engine_empty)

    # Clone the base 0000.. image first to check single-image clones
    assert len(pg_repo_local.images()) == 0
    result = runner.invoke(clone_c, [str(pg_repo_local) + ":" + "00000000"])
    assert result.exit_code == 0
    assert len(pg_repo_local.images()) == 1
    assert repository_exists(pg_repo_local)

    # Clone the rest of the repo
    result = runner.invoke(clone_c, [str(pg_repo_local)])
    assert result.exit_code == 0
    assert len(pg_repo_local.images()) == 2

    pg_repo_remote.run_sql("INSERT INTO fruits VALUES (3, 'mayonnaise')")
    remote_engine_head = pg_repo_remote.commit()

    # Pull the new image
    result = runner.invoke(
        pull_c,
        [str(pg_repo_local) + ":" + remote_engine_head.image_hash[:10]])
    assert result.exit_code == 0
    assert len(pg_repo_local.objects.get_downloaded_objects()) == 0
    assert len(pg_repo_local.images()) == 3

    # Pull the whole repo (should be no changes)
    result = runner.invoke(pull_c, [str(pg_repo_local)])
    assert result.exit_code == 0
    assert len(pg_repo_local.objects.get_downloaded_objects()) == 0
    assert len(pg_repo_local.images()) == 3

    # Pull repo downloading everything
    result = runner.invoke(pull_c, [str(pg_repo_local), "--download-all"])
    assert result.exit_code == 0
    assert len(pg_repo_local.objects.get_downloaded_objects()) == 3

    pg_repo_local.images.by_hash(remote_engine_head.image_hash).checkout()

    pg_repo_local.run_sql("INSERT INTO fruits VALUES (4, 'mustard')")
    local_head = pg_repo_local.commit()

    assert local_head.image_hash not in list(pg_repo_remote.images)

    # Push out the single new image first
    result = runner.invoke(
        push_c,
        [str(pg_repo_local) + ":" + local_head.image_hash[:10], "-h", "DB"])
    assert result.exit_code == 0
    assert len(pg_repo_remote.images()) == 4

    # Push out the whole repo
    result = runner.invoke(push_c, [str(pg_repo_local), "-h", "DB"])
    assert result.exit_code == 0
    assert pg_repo_local.head.get_table("fruits")
Example #30
def ingestion_test_repo():
    repo = Repository.from_schema("test/ingestion")
    try:
        repo.delete()
        repo.objects.cleanup()
        repo.init()
        yield repo
    finally:
        repo.rollback_engines()
        repo.delete()