コード例 #1
0
    def create_experiment(self, project, name, description, params, properties,
                          tags, abortable, monitored):
        """Create an experiment in ``project`` through the backend swagger API.

        Args:
            project: Target project. ``project.internal_id`` is sent to the
                backend; ``project.full_id`` is used in ProjectNotFound.
            name: Experiment name, sent unchanged.
            description: Experiment description, sent unchanged.
            params: Experiment parameters, passed through
                ``self._convert_to_api_parameters`` before sending.
            properties: Experiment properties, passed through
                ``self._convert_to_api_properties`` before sending.
            tags: Sent unchanged.
            abortable: Sent unchanged as ``abortable``.
            monitored: Sent unchanged as ``monitored``.

        Returns:
            Whatever ``self._convert_to_experiment`` builds from the API result.

        Raises:
            ProjectNotFound: an HTTPNotFound escaped the request.
            ExperimentValidationError: HTTPBadRequest whose response ``type`` is
                DUPLICATE_PARAMETER or INVALID_TAG.
            ExperimentLimitReached: HTTPUnprocessableEntity whose response
                ``type`` is LIMIT_OF_EXPERIMENTS_IN_PROJECT_REACHED.
            HTTPBadRequest / HTTPUnprocessableEntity with any other ``type``
            are re-raised unchanged.
        """
        creation_params_model = self.backend_swagger_client.get_model(
            'ExperimentCreationParams')

        try:
            # A separate name keeps the caller's ``params`` argument distinct
            # from the request model built from it.
            request_body = creation_params_model(
                projectId=project.internal_id,
                name=name,
                description=description,
                parameters=self._convert_to_api_parameters(params),
                properties=self._convert_to_api_properties(properties),
                tags=tags,
                enqueueCommand="command",  # FIXME
                entrypoint="",  # FIXME
                execArgsTemplate="",  # FIXME
                abortable=abortable,
                monitored=monitored)

            http_call = self.backend_swagger_client.api.createExperiment(
                experimentCreationParams=request_body)
            api_experiment = http_call.response().result
            return self._convert_to_experiment(api_experiment)
        except HTTPNotFound:
            raise ProjectNotFound(project_identifier=project.full_id)
        except HTTPBadRequest as bad_request:
            error_type = extract_response_field(bad_request.response, 'type')
            if error_type == 'INVALID_TAG':
                # Forward the backend-supplied message verbatim.
                raise ExperimentValidationError(
                    extract_response_field(bad_request.response, 'message'))
            if error_type == 'DUPLICATE_PARAMETER':
                raise ExperimentValidationError(
                    'Parameter list contains duplicates.')
            raise
        except HTTPUnprocessableEntity as unprocessable:
            error_type = extract_response_field(unprocessable.response, 'type')
            if error_type != 'LIMIT_OF_EXPERIMENTS_IN_PROJECT_REACHED':
                raise
            raise ExperimentLimitReached()
コード例 #2
0
 def update_tags(self, experiment, tags_to_add, tags_to_delete):
     """Add and remove tags on a single experiment through the backend API.

     Args:
         experiment: Experiment whose ``internal_id`` is sent; ``id`` and
             ``_project_full_id`` are only used to build ExperimentNotFound.
         tags_to_add: Sent unchanged as ``tagsToAdd``.
         tags_to_delete: Sent unchanged as ``tagsToDelete``.

     Returns:
         None. The API result is fetched but discarded.

     Raises:
         ExperimentNotFound: an HTTPNotFound escaped the request.
         ExperimentValidationError: HTTPBadRequest whose response ``type`` is
             INVALID_TAG; carries the backend-supplied ``message``.
         Any other HTTPBadRequest is re-raised unchanged.
     """
     update_params_model = self.backend_swagger_client.get_model(
         'UpdateTagsParams')
     try:
         request_body = update_params_model(
             experimentIds=[experiment.internal_id],
             groupsIds=[],
             tagsToAdd=tags_to_add,
             tagsToDelete=tags_to_delete)
         http_call = self.backend_swagger_client.api.updateTags(
             updateTagsParams=request_body)
         # Presumably .response() waits for the call to finish; the payload
         # itself is not needed here.
         _ = http_call.response().result
     except HTTPNotFound:
         # pylint: disable=protected-access
         raise ExperimentNotFound(
             experiment_short_id=experiment.id,
             project_qualified_name=experiment._project_full_id)
     except HTTPBadRequest as bad_request:
         if extract_response_field(bad_request.response, 'type') != 'INVALID_TAG':
             raise
         raise ExperimentValidationError(
             extract_response_field(bad_request.response, 'message'))
コード例 #3
0
 def extract_experiment_output(self, experiment, data):
     """Upload ``data`` as experiment output via the tar-stream endpoint.

     Delegates to ``self._upload_tar_data`` with the
     ``uploadExperimentOutputAsTarstream`` API method and returns its result.

     Raises:
         ExperimentNotFound: an HTTPError with status ``NOT_FOUND``.
         StorageLimitReached: an HTTPError with status ``UNPROCESSABLE_ENTITY``
             whose response ``type`` is LIMIT_OF_STORAGE_IN_PROJECT_REACHED.
         Any other HTTPError is re-raised unchanged.
     """
     try:
         tar_endpoint = self.backend_swagger_client.api.uploadExperimentOutputAsTarstream
         return self._upload_tar_data(experiment=experiment, api_method=tar_endpoint, data=data)
     except HTTPError as http_error:
         status_code = http_error.response.status_code
         if status_code == NOT_FOUND:
             # pylint: disable=protected-access
             raise ExperimentNotFound(
                 experiment_short_id=experiment.id,
                 project_qualified_name=experiment._project.full_id)
         storage_limit_hit = status_code == UNPROCESSABLE_ENTITY and (
             extract_response_field(http_error.response, 'type') == 'LIMIT_OF_STORAGE_IN_PROJECT_REACHED')
         if storage_limit_hit:
             raise StorageLimitReached()
         raise
コード例 #4
0
 def upload_experiment_output(self, experiment, data):
     """Upload ``data`` to the experiment-output endpoint via ``_upload_loop``.

     Each upload step is ``self._upload_raw_data`` bound to the
     ``uploadExperimentOutput`` API method; the experiment's ``internal_id`` is
     passed as the ``experimentId`` path parameter and no query parameters are
     sent. Returns None.

     Raises:
         ExperimentNotFound: an HTTPError with status ``NOT_FOUND``.
         StorageLimitReached: an HTTPError with status ``UNPROCESSABLE_ENTITY``
             whose response ``type`` is LIMIT_OF_STORAGE_IN_PROJECT_REACHED.
         Any other HTTPError is re-raised unchanged.
     """
     try:
         # NOTE(review): the original author notes that API exception handling
         # happens inside _upload_loop; any HTTPError that still escapes is
         # translated below.
         send_chunk = partial(self._upload_raw_data,
                              api_method=self.backend_swagger_client.api.uploadExperimentOutput)
         self._upload_loop(send_chunk,
                           data=data,
                           path_params={'experimentId': experiment.internal_id},
                           query_params={})
     except HTTPError as http_error:
         status_code = http_error.response.status_code
         if status_code == NOT_FOUND:
             # pylint: disable=protected-access
             raise ExperimentNotFound(
                 experiment_short_id=experiment.id,
                 project_qualified_name=experiment._project.full_id)
         storage_limit_hit = status_code == UNPROCESSABLE_ENTITY and (
             extract_response_field(http_error.response, 'type') == 'LIMIT_OF_STORAGE_IN_PROJECT_REACHED')
         if storage_limit_hit:
             raise StorageLimitReached()
         raise
コード例 #5
0
 def extract_experiment_output(self, experiment, data):
     """Upload ``data`` as experiment output via the tar-stream endpoint.

     Delegates to ``self._upload_tar_data`` with the
     ``uploadExperimentOutputAsTarstream`` API method and returns its result.

     Raises:
         ExperimentNotFound: an HTTPNotFound escaped the upload.
         StorageLimitReached: HTTPUnprocessableEntity whose response ``type``
             is LIMIT_OF_STORAGE_IN_PROJECT_REACHED.
         Any other HTTPUnprocessableEntity is re-raised unchanged.
     """
     try:
         tar_endpoint = self.backend_swagger_client.api.uploadExperimentOutputAsTarstream
         return self._upload_tar_data(
             experiment=experiment, api_method=tar_endpoint, data=data)
     except HTTPNotFound:
         # pylint: disable=protected-access
         raise ExperimentNotFound(
             experiment_short_id=experiment.id,
             project_qualified_name=experiment._project_full_id)
     except HTTPUnprocessableEntity as unprocessable:
         error_type = extract_response_field(unprocessable.response, 'type')
         if error_type != 'LIMIT_OF_STORAGE_IN_PROJECT_REACHED':
             raise
         raise StorageLimitReached()
コード例 #6
0
    def create_experiment(self,
                          project,
                          name,
                          description,
                          params,
                          properties,
                          tags,
                          abortable,
                          monitored,
                          git_info,
                          hostname,
                          entrypoint,
                          notebook_id,
                          checkpoint_id):
        """Validate the arguments and create an experiment in ``project``.

        Args:
            project: Target project. ``internal_id`` is sent to the backend,
                ``full_id`` is used in ProjectNotFound, and the project itself is
                passed to ``_convert_to_experiment``.
            name: Must be a ``six.string_types`` instance.
            description: Must be a ``six.string_types`` instance.
            params: Must be a dict; converted by ``_convert_to_api_parameters``.
            properties: Must be a dict; converted by ``_convert_to_api_properties``.
            tags: Sent unchanged (not validated here).
            abortable: Sent unchanged.
            monitored: Sent unchanged.
            git_info: ``None`` or an object exposing ``commit_id``, ``message``,
                ``author_name``, ``author_email``, ``commit_date``,
                ``remote_urls``, ``active_branch`` and ``repository_dirty``;
                mapped to a GitInfoDTO / GitCommitDTO pair.
            hostname: Must be a ``six.string_types`` instance.
            entrypoint: ``None`` or a ``six.string_types`` instance.
            notebook_id: Sent unchanged as ``notebookId``.
            checkpoint_id: Sent unchanged as ``checkpointId``.

        Returns:
            ``self._convert_to_experiment(api_result, project)``.

        Raises:
            ValueError: an argument failed one of the type checks above; the
                first failing argument (in the order listed) is reported.
            ProjectNotFound: an HTTPNotFound escaped the request.
            ExperimentValidationError: HTTPBadRequest whose response ``type`` is
                DUPLICATE_PARAMETER or INVALID_TAG.
            ExperimentLimitReached: HTTPUnprocessableEntity whose response
                ``type`` is LIMIT_OF_EXPERIMENTS_IN_PROJECT_REACHED.
            HTTPBadRequest / HTTPUnprocessableEntity with any other ``type``
            are re-raised unchanged.
        """
        # Checked in this exact order so the first offending argument wins.
        for value, expected_type, message in (
                (name, six.string_types, "Invalid name {}, should be a string."),
                (description, six.string_types, "Invalid description {}, should be a string."),
                (params, dict, "Invalid params {}, should be a dict."),
                (properties, dict, "Invalid properties {}, should be a dict."),
                (hostname, six.string_types, "Invalid hostname {}, should be a string."),
        ):
            if not isinstance(value, expected_type):
                raise ValueError(message.format(value))
        # entrypoint is optional: only a non-None value is type-checked.
        if entrypoint is not None and not isinstance(entrypoint, six.string_types):
            raise ValueError("Invalid entrypoint {}, should be a string.".format(entrypoint))

        get_model = self.backend_swagger_client.get_model
        creation_params_model = get_model('ExperimentCreationParams')
        git_info_model = get_model('GitInfoDTO')
        git_commit_model = get_model('GitCommitDTO')

        if git_info is None:
            git_info_dto = None
        else:
            commit_dto = git_commit_model(
                commitId=git_info.commit_id,
                message=git_info.message,
                authorName=git_info.author_name,
                authorEmail=git_info.author_email,
                commitDate=git_info.commit_date)
            git_info_dto = git_info_model(
                commit=commit_dto,
                remotes=git_info.remote_urls,
                currentBranch=git_info.active_branch,
                repositoryDirty=git_info.repository_dirty)

        try:
            # A separate name keeps the caller's ``params`` argument distinct
            # from the request model built from it.
            request_body = creation_params_model(
                projectId=project.internal_id,
                name=name,
                description=description,
                parameters=self._convert_to_api_parameters(params),
                properties=self._convert_to_api_properties(properties),
                tags=tags,
                gitInfo=git_info_dto,
                # legacy: ignored by the backend, but a non-empty string is required
                enqueueCommand="command",
                entrypoint=entrypoint,
                execArgsTemplate="",  # legacy
                abortable=abortable,
                monitored=monitored,
                hostname=hostname,
                notebookId=notebook_id,
                checkpointId=checkpoint_id)

            # 'X-Neptune-CliVersion' contains dashes, so it cannot be written as
            # a keyword argument and is passed through ** instead.
            http_call = self.backend_swagger_client.api.createExperiment(**{
                'experimentCreationParams': request_body,
                'X-Neptune-CliVersion': self.client_lib_version,
            })
            api_experiment = http_call.response().result
            return self._convert_to_experiment(api_experiment, project)
        except HTTPNotFound:
            raise ProjectNotFound(project_identifier=project.full_id)
        except HTTPBadRequest as bad_request:
            error_type = extract_response_field(bad_request.response, 'type')
            if error_type == 'INVALID_TAG':
                # Forward the backend-supplied message verbatim.
                raise ExperimentValidationError(extract_response_field(bad_request.response, 'message'))
            if error_type == 'DUPLICATE_PARAMETER':
                raise ExperimentValidationError('Parameter list contains duplicates.')
            raise
        except HTTPUnprocessableEntity as unprocessable:
            error_type = extract_response_field(unprocessable.response, 'type')
            if error_type != 'LIMIT_OF_EXPERIMENTS_IN_PROJECT_REACHED':
                raise
            raise ExperimentLimitReached()