Code example #1
    def test_find_datasets_client_api(self, client, created_entities):
        tags = [
            "test1a-{}".format(_utils.now()), "test1b-{}".format(_utils.now())
        ]
        dataset1 = client.set_dataset(type="big query", tags=tags)
        created_entities.append(dataset1)
        assert dataset1.id

        single_tag = ["test2-{}".format(_utils.now())]
        dataset2 = client.set_dataset(type="s3", tags=single_tag)
        created_entities.append(dataset2)
        assert dataset2.id

        # TODO: update once RAW is supported
        # dataset3 = client.set_dataset(type="raw")
        # created_entities.append(dataset3)
        # assert dataset3._dataset_type == _DatasetService.DatasetTypeEnum.RAW
        # assert dataset3.id

        # datasets = client.find_datasets()
        # assert len(datasets) == 3
        # assert datasets[0].id == dataset1.id
        # assert datasets[1].id == dataset2.id
        # assert datasets[2].id == dataset3.id

        datasets = client.find_datasets(tags=tags)
        assert len(datasets) == 1
        assert datasets[0].id == dataset1.id

        # str arg automatically wrapped into list by client
        datasets = client.find_datasets(tags=single_tag[0])
        assert len(datasets) == 1
        assert datasets[0].id == dataset2.id

        datasets = client.find_datasets(name=dataset1.name)
        assert len(datasets) == 1
        assert datasets[0].id == dataset1.id

        datasets = client.find_datasets(dataset_ids=[dataset1.id, dataset2.id],
                                        name=dataset1.name)
        assert len(datasets) == 1
        assert datasets[0].id == dataset1.id

        # test sorting ascending
        datasets = client.find_datasets(
            dataset_ids=[dataset1.id, dataset2.id],
            sort_key="time_created",
            ascending=True,
        )
        assert [dataset.id
                for dataset in datasets] == [dataset1.id, dataset2.id]
        # and descending
        datasets = client.find_datasets(
            dataset_ids=[dataset1.id, dataset2.id],
            sort_key="time_created",
            ascending=False,
        )
        assert [dataset.id
                for dataset in datasets] == [dataset2.id, dataset1.id]
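For reference outside the test fixture, here is a minimal standalone sketch of the same `set_dataset()`/`find_datasets()` calls, assuming a `verta.Client` connected to a running ModelDB backend; the host, tag values, and sort call below are illustrative, not part of the original test:

from verta import Client

client = Client("http://localhost:3000")  # illustrative host

# create a tagged dataset, then query it back by tag
dataset = client.set_dataset(type="s3", tags=["team-a", "daily-ingest"])

# a single tag string is wrapped into a list by the client
matches = client.find_datasets(tags="daily-ingest")
assert dataset.id in [d.id for d in matches]

# results can also be sorted server-side, as the test above exercises
ordered = client.find_datasets(sort_key="time_created", ascending=True)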
Code example #2
File: test_datasets.py  Project: stjordanis/modeldb
    def test_find_datasets_client_api(self, client, created_datasets):
        tags = [
            "test1-{}".format(_utils.now()), "test1-{}".format(_utils.now())
        ]
        dataset1 = client.set_dataset(type="big query", tags=tags)
        created_datasets.append(dataset1)
        assert dataset1._dataset_type == _DatasetService.DatasetTypeEnum.QUERY
        assert dataset1.id

        single_tag = ["test2-{}".format(_utils.now())]
        dataset2 = client.set_dataset(type="s3", tags=single_tag)
        created_datasets.append(dataset2)
        assert dataset2._dataset_type == _DatasetService.DatasetTypeEnum.PATH
        assert dataset2.id

        # TODO: update once RAW is supported
        # dataset3 = client.set_dataset(type="raw")
        # created_datasets.append(dataset3)
        # assert dataset3._dataset_type == _DatasetService.DatasetTypeEnum.RAW
        # assert dataset3.id

        # datasets = client.find_datasets()
        # assert len(datasets) == 3
        # assert datasets[0].id == dataset1.id
        # assert datasets[1].id == dataset2.id
        # assert datasets[2].id == dataset3.id

        datasets = client.find_datasets(tags=tags)
        assert len(datasets) == 1
        assert datasets[0].id == dataset1.id

        # str arg automatically wrapped into list by client
        datasets = client.find_datasets(tags=single_tag[0])
        assert len(datasets) == 1
        assert datasets[0].id == dataset2.id

        datasets = client.find_datasets(name=dataset1.name)
        assert len(datasets) == 1
        assert datasets[0].id == dataset1.id

        datasets = client.find_datasets(dataset_ids=[dataset1.id, dataset2.id])
        assert len(datasets) == 2

        datasets = client.find_datasets(dataset_ids=[dataset1.id, dataset2.id],
                                        name=dataset1.name)
        assert len(datasets) == 1
        assert datasets[0].id == dataset1.id
Code example #3
File: test_datasets.py  Project: stjordanis/modeldb
    def test_find_datasets_by_fuzzy_name(self, client, created_datasets):
        now = str(_utils.now())
        created_datasets.append(client.set_dataset(now + " appl"))
        created_datasets.append(client.set_dataset(now + " Appl"))
        created_datasets.append(client.set_dataset(now + " Apple"))

        datasets = client.find_datasets(name=now + " Appl")
        assert len(datasets) == 3
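This test implies that `name` matching in `find_datasets()` is fuzzy rather than exact: a case-insensitive substring query returns all three datasets sharing the timestamped prefix. A hedged standalone sketch of the same behavior, where the prefix is illustrative (the test uses `_utils.now()` only to keep names unique across runs):

prefix = "20240101 "  # illustrative stand-in for the timestamp prefix

client.set_dataset(prefix + "appl")
client.set_dataset(prefix + "Appl")
client.set_dataset(prefix + "Apple")

# the fuzzy query matches "appl", "Appl", and "Apple" alike
assert len(client.find_datasets(name=prefix + "Appl")) == 3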
Code example #4
File: _entity.py  Project: vishalbelsare/modeldb
    def log_code(self,
                 exec_path=None,
                 repo_url=None,
                 commit_hash=None,
                 overwrite=False,
                 is_dirty=None,
                 autocapture=True):
        """
        Logs the code version.

        A code version is either information about a Git snapshot or a bundle of Python source code files.

        `repo_url` and `commit_hash` can only be set if `use_git` was set to ``True`` in the Client.

        Parameters
        ----------
        exec_path : str, optional
            Filepath to the executable Python script or Jupyter notebook. If no filepath is provided,
            the Client will make its best effort to find the currently running script/notebook file.
        repo_url : str, optional
            URL for a remote Git repository containing `commit_hash`. If no URL is provided, the Client
            will make its best effort to find it.
        commit_hash : str, optional
            Git commit hash associated with this code version. If no hash is provided, the Client will
            make its best effort to find it.
        overwrite : bool, default False
            Whether to allow overwriting a code version.
        is_dirty : bool, optional
            Whether git status is dirty relative to `commit_hash`. If not provided, the Client will
            make its best effort to find it.
        autocapture : bool, default True
            Whether to automatically capture the parameters above when the Client is in git mode.

        Examples
        --------
        With ``Client(use_git=True)`` (default):

            Log Git snapshot information, plus the location of the currently executing notebook/script
            relative to the repository root:

            .. code-block:: python

                run.log_code()
                run.get_code()
                # {'exec_path': 'comparison/outcomes/classification.ipynb',
                #  'repo_url': 'git@github.com:VertaAI/experiments.git',
                #  'commit_hash': 'f99abcfae6c3ce6d22597f95ad6ef260d31527a6',
                #  'is_dirty': False}

            Log Git snapshot information, plus the location of a specific source code file relative
            to the repository root:

            .. code-block:: python

                run.log_code("../trainer/training_pipeline.py")
                run.get_code()
                # {'exec_path': 'comparison/trainer/training_pipeline.py',
                #  'repo_url': 'git@github.com:VertaAI/experiments.git',
                #  'commit_hash': 'f99abcfae6c3ce6d22597f95ad6ef260d31527a6',
                #  'is_dirty': False}

        With ``Client(use_git=False)``:

            Find and upload the currently executing notebook/script:

            .. code-block:: python

                run.log_code()
                zip_file = run.get_code()
                zip_file.printdir()
                # File Name                          Modified             Size
                # classification.ipynb        2019-07-10 17:18:24        10287

            Upload a specific source code file:

            .. code-block:: python

                run.log_code("../trainer/training_pipeline.py")
                zip_file = run.get_code()
                zip_file.printdir()
                # File Name                          Modified             Size
                # training_pipeline.py        2019-05-31 10:34:44          964

        """
        if self._conf.use_git and autocapture:
            # verify Git
            try:
                repo_root_dir = _git_utils.get_git_repo_root_dir()
            except OSError:
                # don't halt execution
                print(
                    "unable to locate git repository; you may be in an unsupported environment;\n"
                    "    consider using `autocapture=False` to manually enter values"
                )
                return
                # six.raise_from(OSError("failed to locate git repository; please check your working directory"),
                #                None)
            print("Git repository successfully located at {}".format(
                repo_root_dir))
        if not self._conf.use_git:
            if repo_url is not None or commit_hash is not None:
                raise ValueError(
                    "`repo_url` and `commit_hash` can only be set if `use_git` was set to True in the Client"
                )
            if not autocapture:  # user passed `False`
                raise ValueError(
                    "`autocapture` is only applicable if `use_git` was set to True in the Client"
                )

        if autocapture:
            if exec_path is None:
                # find dynamically
                try:
                    exec_path = _utils.get_notebook_filepath()
                except (ImportError, OSError):  # notebook not found
                    try:
                        exec_path = _utils.get_script_filepath()
                    except OSError:  # script not found
                        print("unable to find code file; skipping")
            else:
                exec_path = os.path.expanduser(exec_path)
                if not os.path.isfile(exec_path):
                    raise ValueError(
                        "`exec_path` \"{}\" must be a valid filepath".format(
                            exec_path))

        # TODO: remove this circular dependency
        from ._project import Project
        from ._experiment import Experiment
        from ._experimentrun import ExperimentRun
        if isinstance(self, Project):  # TODO: not this
            Message = self._service.LogProjectCodeVersion
            endpoint = "logProjectCodeVersion"
        elif isinstance(self, Experiment):
            Message = self._service.LogExperimentCodeVersion
            endpoint = "logExperimentCodeVersion"
        elif isinstance(self, ExperimentRun):
            Message = self._service.LogExperimentRunCodeVersion
            endpoint = "logExperimentRunCodeVersion"
        msg = Message(id=self.id)

        if overwrite:
            if isinstance(self, ExperimentRun):
                msg.overwrite = True
            else:
                raise ValueError(
                    "`overwrite=True` is currently only supported for ExperimentRun"
                )

        if self._conf.use_git:
            if autocapture:
                try:
                    # adjust `exec_path` to be relative to repo root
                    exec_path = os.path.relpath(
                        exec_path, _git_utils.get_git_repo_root_dir())
                except OSError as e:
                    print(
                        "{}; logging absolute path to file instead".format(e))
                    exec_path = os.path.abspath(exec_path)
            if exec_path:
                msg.code_version.git_snapshot.filepaths.append(exec_path)

            from verta.code import _git  # avoid Python 2 top-level circular import
            code_ver = _git.Git(
                repo_url=repo_url,
                commit_hash=commit_hash,
                is_dirty=is_dirty,
                autocapture=autocapture,
            )
            msg.code_version.git_snapshot.repo = code_ver.repo_url or ""
            msg.code_version.git_snapshot.hash = code_ver.commit_hash or ""
            if not autocapture and is_dirty is None:
                msg.code_version.git_snapshot.is_dirty = _CommonCommonService.TernaryEnum.UNKNOWN
            elif code_ver.is_dirty:
                msg.code_version.git_snapshot.is_dirty = _CommonCommonService.TernaryEnum.TRUE
            else:
                msg.code_version.git_snapshot.is_dirty = _CommonCommonService.TernaryEnum.FALSE
        else:  # log code as Artifact
            if exec_path is None:
                # don't halt execution
                print(
                    "unable to find code file; you may be in an unsupported environment"
                )
                return
                # raise RuntimeError("unable to find code file; you may be in an unsupported environment")

            # write ZIP archive
            zipstream = six.BytesIO()
            with zipfile.ZipFile(zipstream, 'w') as zipf:
                filename = os.path.basename(exec_path)
                if exec_path.endswith(".ipynb"):
                    try:
                        saved_notebook = _utils.save_notebook(exec_path)
                    except:  # failed to save
                        print("unable to automatically save Notebook;"
                              " logging latest checkpoint from disk")
                        zipf.write(exec_path, filename)
                    else:
                        zipf.writestr(
                            _artifact_utils.global_read_zipinfo(filename),
                            six.ensure_binary(saved_notebook.read()),
                        )
                else:
                    zipf.write(exec_path, filename)
            zipstream.seek(0)

            key = 'code'
            extension = 'zip'

            artifact_hash = _artifact_utils.calc_sha256(zipstream)
            basename = key + os.extsep + extension
            artifact_path = os.path.join(artifact_hash, basename)

            msg.code_version.code_archive.path = artifact_path
            msg.code_version.code_archive.path_only = False
            msg.code_version.code_archive.artifact_type = _CommonCommonService.ArtifactTypeEnum.CODE
            msg.code_version.code_archive.filename_extension = extension
        # TODO: check if we actually have any loggable information
        msg.code_version.date_logged = _utils.now()

        data = _utils.proto_to_json(msg)
        response = _utils.make_request("POST",
                                       self._request_url.format(endpoint),
                                       self._conn,
                                       json=data)
        if not response.ok:
            if response.status_code == 409:
                raise ValueError("a code version has already been logged")
            else:
                _utils.raise_for_http_error(response)

        if msg.code_version.WhichOneof("code") == 'code_archive':
            # upload artifact to artifact store
            # pylint: disable=no-member
            # this method should only be called on ExperimentRun, which does have _get_url_for_artifact()
            url = self._get_url_for_artifact(
                "verta_code_archive", "PUT",
                msg.code_version.code_archive.artifact_type).url

            response = _utils.make_request("PUT",
                                           url,
                                           self._conn,
                                           data=zipstream)
            _utils.raise_for_http_error(response)
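Following the hint printed in the Git-verification branch above ("consider using `autocapture=False` to manually enter values"), here is a hedged sketch of logging a code version with manually supplied values on an `ExperimentRun` in git mode; the path, remote URL, and commit hash are illustrative:

run.log_code(
    exec_path="trainer/training_pipeline.py",           # illustrative path
    repo_url="git@github.com:example/experiments.git",  # illustrative remote
    commit_hash="f99abcfae6c3ce6d22597f95ad6ef260d31527a6",  # illustrative hash
    is_dirty=False,     # recorded as TernaryEnum.FALSE in the snapshot message
    autocapture=False,  # skip Git auto-detection; use the values given above
)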