def ping_experiment(self, experiment):
    """Send a ping request for ``experiment`` to the backend.

    Args:
        experiment: experiment whose ``internal_id`` is sent as ``experimentId``.

    Raises:
        ExperimentNotFound: if the backend answers with HTTP 404.
    """
    try:
        ping_request = self.backend_swagger_client.api.pingExperiment(
            experimentId=experiment.internal_id)
        ping_request.response()  # the response value itself is not used
    except HTTPNotFound:
        # pylint: disable=protected-access
        raise ExperimentNotFound(
            experiment_short_id=experiment.id,
            project_qualified_name=experiment._project.full_id,
        )
def send_channels_values(self, experiment, channels_with_values):
    """Post the values of several channels in one ``postChannelValues`` request.

    Args:
        experiment: experiment whose ``internal_id`` is sent as ``experimentId``.
        channels_with_values: objects exposing ``channel_id`` and
            ``channel_values``; each value exposes ``ts``, ``x`` and a ``y``
            read with ``.get('numeric_value' | 'text_value' | 'image_value')``.

    Raises:
        ChannelsValuesSendBatchError: if the response ``result`` is truthy
            (the backend reported per-batch errors).
        ExperimentNotFound: if the backend answers with HTTP 404.
    """
    client = self.backend_swagger_client
    InputChannelValues = client.get_model('InputChannelValues')
    Point = client.get_model('Point')
    Y = client.get_model('Y')

    def to_point(value):
        # ``ts`` is multiplied by 1000 and truncated to fill ``timestampMillis``.
        return Point(
            timestampMillis=int(value.ts * 1000.0),
            x=value.x,
            y=Y(numericValue=value.y.get('numeric_value'),
                textValue=value.y.get('text_value'),
                inputImageValue=value.y.get('image_value')),
        )

    payload = []
    for channel in channels_with_values:
        channel_points = [to_point(v) for v in channel.channel_values]
        payload.append(
            InputChannelValues(channelId=channel.channel_id, values=channel_points))

    try:
        batch_errors = client.api.postChannelValues(
            experimentId=experiment.internal_id,
            channelsValues=payload,
        ).response().result
        if batch_errors:
            raise ChannelsValuesSendBatchError(experiment.id, batch_errors)
    except HTTPNotFound:
        # NOTE(review): other methods in this file read
        # ``experiment._project.full_id``; confirm which attribute exists.
        # pylint: disable=protected-access
        raise ExperimentNotFound(
            experiment_short_id=experiment.id,
            project_qualified_name=experiment._project_full_id,
        )
def send_hardware_metric_reports(self, experiment, metrics, metric_reports):
    """Post hardware (system) metric values for ``experiment``.

    Each report's ``values`` are grouped by ``gauge_name``. Every group becomes
    one ``SystemMetricValues`` entry tied to the metric whose ``name`` equals
    ``report.metric.name``. Each value's ``timestamp`` and ``running_time`` are
    multiplied by 1000 and truncated to ``int`` (presumably seconds to
    milliseconds, matching the ``timestampMillis`` field name).

    Args:
        experiment: experiment whose ``internal_id`` is sent as ``experimentId``.
        metrics: metric objects exposing ``name`` and ``internal_id``.
        metric_reports: reports exposing ``metric.name`` and ``values``.

    Returns:
        The object returned by ``postSystemMetricValues(...).response()``.

    Raises:
        ExperimentNotFound: if the backend answers with HTTP 404.
    """
    client = self.backend_swagger_client
    SystemMetricValues = client.get_model('SystemMetricValues')
    SystemMetricPoint = client.get_model('SystemMetricPoint')

    def gauge_of(value):
        return value.gauge_name

    try:
        metric_by_name = {}
        for metric in metrics:
            metric_by_name[metric.name] = metric

        series_entries = []
        for report in metric_reports:
            # NOTE(review): groupby merges only *adjacent* values that share a
            # gauge_name. If report.values is not ordered by gauge, one gauge
            # can produce several entries. Confirm producers keep them contiguous.
            for gauge_name, gauge_values in groupby(report.values, gauge_of):
                # dict.get returns None for an unknown name, so this raises
                # AttributeError when the report's metric is not in ``metrics``.
                metric_id = metric_by_name.get(report.metric.name).internal_id
                points = [
                    SystemMetricPoint(
                        timestampMillis=int(sample.timestamp * 1000.0),
                        x=int(sample.running_time * 1000.0),
                        y=sample.value,
                    )
                    for sample in gauge_values
                ]
                series_entries.append(
                    SystemMetricValues(metricId=metric_id, seriesName=gauge_name, values=points))

        return client.api.postSystemMetricValues(
            experimentId=experiment.internal_id,
            metricValues=series_entries,
        ).response()
    except HTTPNotFound:
        # pylint: disable=protected-access
        raise ExperimentNotFound(
            experiment_short_id=experiment.id,
            project_qualified_name=experiment._project.full_id,
        )
def create_hardware_metric(self, experiment, metric):
    """Create a system metric on ``experiment`` and return the new metric's id.

    Args:
        experiment: experiment whose ``internal_id`` is sent as ``experimentId``.
        metric: object exposing ``name``, ``description``, ``resource_type``,
            ``unit``, ``min_value``, ``max_value`` and ``gauges``. Each gauge
            has a ``name()`` method whose results form the series list.

    Returns:
        ``id`` of the created metric from the ``createSystemMetric`` result.

    Raises:
        ExperimentNotFound: if the backend answers with HTTP 404.
    """
    SystemMetricParams = self.backend_swagger_client.get_model('SystemMetricParams')
    try:
        series_names = [g.name() for g in metric.gauges]
        params = SystemMetricParams(
            name=metric.name,
            description=metric.description,
            resourceType=metric.resource_type,
            unit=metric.unit,
            min=metric.min_value,
            max=metric.max_value,
            series=series_names,
        )
        created = self.backend_swagger_client.api.createSystemMetric(
            experimentId=experiment.internal_id,
            metricToCreate=params,
        ).response().result
    except HTTPNotFound:
        # NOTE(review): other methods in this file read
        # ``experiment._project.full_id``; confirm which attribute exists.
        # pylint: disable=protected-access
        raise ExperimentNotFound(
            experiment_short_id=experiment.id,
            project_qualified_name=experiment._project_full_id,
        )
    return created.id
def put_tensorflow_graph(self, experiment, graph_id, graph):
    """Upload a TensorFlow graph for ``experiment``.

    The graph text is UTF-8 encoded, gzip-compressed and base64-encoded. The
    result is sent as the ``value`` of a ``TensorflowGraph`` model.

    Args:
        experiment: experiment whose ``internal_id`` is sent as ``experimentId``.
        graph_id: identifier stored in the model's ``id`` field.
        graph: the graph serialized as ``str``.

    Returns:
        ``result`` of the ``putTensorflowGraph`` response.

    Raises:
        ExperimentNotFound: if the backend answers with HTTP 404.
    """
    TensorflowGraph = self.backend_swagger_client.get_model('TensorflowGraph')

    raw_bytes = graph.encode('UTF-8')
    compressed = io.BytesIO()
    # Closing the GzipFile (on leaving the ``with``) writes the gzip trailer.
    with gzip.GzipFile(fileobj=compressed, mode='w') as gz_writer:
        gz_writer.write(raw_bytes)
    encoded_text = base64.b64encode(compressed.getvalue()).decode('UTF-8')
    graph_model = TensorflowGraph(id=graph_id, value=encoded_text)

    try:
        put_response = self.backend_swagger_client.api.putTensorflowGraph(
            experimentId=experiment.internal_id,
            tensorflowGraph=graph_model,
        ).response()
        return put_response.result
    except HTTPNotFound:
        # NOTE(review): other methods in this file read
        # ``experiment._project.full_id``; confirm which attribute exists.
        # pylint: disable=protected-access
        raise ExperimentNotFound(
            experiment_short_id=experiment.id,
            project_qualified_name=experiment._project_full_id,
        )
def get_system_channels(self, experiment):
    """Return the ``result`` of the backend's ``getSystemChannels`` call.

    Args:
        experiment: experiment whose ``internal_id`` is sent as ``experimentId``.

    Raises:
        ExperimentNotFound: if the backend answers with HTTP 404.
    """
    try:
        request = self.backend_swagger_client.api.getSystemChannels(
            experimentId=experiment.internal_id)
        return request.response().result
    except HTTPNotFound:
        # pylint: disable=protected-access
        raise ExperimentNotFound(
            experiment_short_id=experiment.id,
            project_qualified_name=experiment._project.full_id,
        )
def _get_channels(self, experiment) -> List[AlphaChannelDTO]:
    """Wrap the experiment's channel-like attributes in ``AlphaChannelDTO``.

    Attributes come from ``self._get_attributes(experiment.internal_id)``. Only
    those accepted by ``AlphaChannelDTO.is_valid_attribute`` are kept, in their
    original order.

    Raises:
        ExperimentNotFound: if fetching the attributes fails with HTTP 404.
    """
    try:
        attributes = self._get_attributes(experiment.internal_id)
        valid_attributes = filter(AlphaChannelDTO.is_valid_attribute, attributes)
        return [AlphaChannelDTO(attribute) for attribute in valid_attributes]
    except HTTPNotFound:
        # pylint: disable=protected-access
        raise ExperimentNotFound(
            experiment_short_id=experiment.id,
            project_qualified_name=experiment._project.full_id,
        )
def get_metrics_csv(self, experiment):
    """Download the system-metrics CSV of ``experiment`` into a ``StringIO``.

    Args:
        experiment: experiment whose ``internal_id`` is sent as ``experimentId``.

    Returns:
        A ``StringIO`` holding the raw response text, rewound to position 0.

    Raises:
        ExperimentNotFound: if the backend answers with HTTP 404.
    """
    try:
        buffer = StringIO()
        response = self.backend_swagger_client.api.getSystemMetricsCSV(
            experimentId=experiment.internal_id).response()
        buffer.write(response.incoming_response.text)
        buffer.seek(0)
        return buffer
    except HTTPNotFound:
        # NOTE(review): other methods in this file read
        # ``experiment._project.full_id``; confirm which attribute exists.
        # pylint: disable=protected-access
        raise ExperimentNotFound(
            experiment_short_id=experiment.id,
            project_qualified_name=experiment._project_full_id,
        )
def _get_channel_tuples_from_csv(self, experiment, channel_attribute_path):
    """Download a float-series attribute as CSV and split it into rows.

    Args:
        experiment: experiment whose ``internal_id`` is sent as ``experimentId``.
        channel_attribute_path: value sent as the ``attribute`` parameter of
            ``getFloatSeriesValuesCSV``.

    Returns:
        One list of string fields per CSV line. Fields are split naively on
        ``","``; no CSV quoting is interpreted.

    Raises:
        ExperimentNotFound: if the backend answers with HTTP 404.
    """
    try:
        csv_text = self.leaderboard_swagger_client.api.getFloatSeriesValuesCSV(
            experimentId=experiment.internal_id,
            attribute=channel_attribute_path,
        ).response().incoming_response.text
    except HTTPNotFound:
        # pylint: disable=protected-access
        raise ExperimentNotFound(
            experiment_short_id=experiment.id,
            project_qualified_name=experiment._project.full_id,
        )
    # splitlines() drops only the final line terminator. The previous
    # ``split("\n")[:-1]`` silently discarded the last data row whenever the
    # body did not end with "\n", and it left "\r" on fields of CRLF responses.
    return [line.split(",") for line in csv_text.splitlines()]
def extract_experiment_output(self, experiment, data):
    """Upload ``data`` through ``uploadExperimentOutputAsTarstream``.

    The upload is delegated to ``self._upload_tar_data``, and its return value
    is passed through.

    NOTE(review): this file defines ``extract_experiment_output`` twice. If
    both definitions live in the same class, the later one replaces this one.

    Raises:
        ExperimentNotFound: on an ``HTTPError`` with status ``NOT_FOUND``.
        StorageLimitReached: on ``UNPROCESSABLE_ENTITY`` whose response field
            ``type`` is ``'LIMIT_OF_STORAGE_IN_PROJECT_REACHED'``.
        HTTPError: any other HTTP error is re-raised unchanged.
    """
    try:
        return self._upload_tar_data(
            experiment=experiment,
            api_method=self.backend_swagger_client.api.uploadExperimentOutputAsTarstream,
            data=data,
        )
    except HTTPError as error:
        status = error.response.status_code
        if status == NOT_FOUND:
            # pylint: disable=protected-access
            raise ExperimentNotFound(
                experiment_short_id=experiment.id,
                project_qualified_name=experiment._project.full_id,
            )
        # The response body is inspected only for 422 responses (short-circuit).
        storage_full = (
            status == UNPROCESSABLE_ENTITY
            and extract_response_field(error.response, 'type') == 'LIMIT_OF_STORAGE_IN_PROJECT_REACHED'
        )
        if storage_full:
            raise StorageLimitReached()
        raise
def upload_experiment_output(self, experiment, data):
    """Upload ``data`` as experiment output via ``self._upload_loop``.

    ``self._upload_raw_data`` is bound to the ``uploadExperimentOutput`` API
    method and handed to ``self._upload_loop`` together with the experiment's
    ``internal_id`` as the ``experimentId`` path parameter.

    Raises:
        ExperimentNotFound: on an ``HTTPError`` with status ``NOT_FOUND``.
        StorageLimitReached: on ``UNPROCESSABLE_ENTITY`` whose response field
            ``type`` is ``'LIMIT_OF_STORAGE_IN_PROJECT_REACHED'``.
        HTTPError: any other HTTP error is re-raised unchanged.
    """
    try:
        # Per the original note, API exception handling is done in
        # _upload_loop; any HTTPError that still escapes is translated below.
        raw_uploader = partial(
            self._upload_raw_data,
            api_method=self.backend_swagger_client.api.uploadExperimentOutput)
        self._upload_loop(
            raw_uploader,
            data=data,
            path_params={'experimentId': experiment.internal_id},
            query_params={},
        )
    except HTTPError as error:
        status = error.response.status_code
        if status == NOT_FOUND:
            # pylint: disable=protected-access
            raise ExperimentNotFound(
                experiment_short_id=experiment.id,
                project_qualified_name=experiment._project.full_id,
            )
        storage_full = (
            status == UNPROCESSABLE_ENTITY
            and extract_response_field(error.response, 'type') == 'LIMIT_OF_STORAGE_IN_PROJECT_REACHED'
        )
        if storage_full:
            raise StorageLimitReached()
        raise
def mark_failed(self, experiment, traceback):
    """Mark ``experiment`` as completed with state ``'failed'``.

    Args:
        experiment: experiment whose ``internal_id`` is sent as ``experimentId``.
        traceback: value sent in the ``traceback`` field of the params.

    Returns:
        The same ``experiment`` object.

    Raises:
        ExperimentNotFound: if the backend answers with HTTP 404.
        ExperimentAlreadyFinished: if the backend answers with HTTP 422.
    """
    CompletedExperimentParams = self.backend_swagger_client.get_model(
        'CompletedExperimentParams')
    try:
        completion = CompletedExperimentParams(state='failed', traceback=traceback)
        self.backend_swagger_client.api.markExperimentCompleted(
            experimentId=experiment.internal_id,
            completedExperimentParams=completion,
        ).response()
    except HTTPNotFound:
        # NOTE(review): other methods in this file read
        # ``experiment._project.full_id``; confirm which attribute exists.
        # pylint: disable=protected-access
        raise ExperimentNotFound(
            experiment_short_id=experiment.id,
            project_qualified_name=experiment._project_full_id,
        )
    except HTTPUnprocessableEntity:
        raise ExperimentAlreadyFinished(experiment.id)
    return experiment
def update_experiment(self, experiment, properties):
    """Send ``properties`` to ``updateExperiment`` as key/value properties.

    Args:
        experiment: experiment whose ``internal_id`` is sent as ``experimentId``.
        properties: mapping iterated for keys and indexed for values; each
            pair becomes one ``KeyValueProperty``.

    Returns:
        The same ``experiment`` object.

    Raises:
        ExperimentNotFound: if the backend answers with HTTP 404.
    """
    client = self.backend_swagger_client
    EditExperimentParams = client.get_model('EditExperimentParams')
    KeyValueProperty = client.get_model('KeyValueProperty')
    try:
        property_list = []
        for name in properties:
            property_list.append(KeyValueProperty(key=name, value=properties[name]))
        client.api.updateExperiment(
            experimentId=experiment.internal_id,
            editExperimentParams=EditExperimentParams(properties=property_list),
        ).response()
    except HTTPNotFound:
        # NOTE(review): other methods in this file read
        # ``experiment._project.full_id``; confirm which attribute exists.
        # pylint: disable=protected-access
        raise ExperimentNotFound(
            experiment_short_id=experiment.id,
            project_qualified_name=experiment._project_full_id,
        )
    return experiment
def create_channel(self, experiment, name, channel_type):
    """Create a channel named ``name`` of type ``channel_type`` on ``experiment``.

    Returns:
        The created channel after conversion by
        ``self._convert_channel_to_channel_with_last_value``.

    Raises:
        ExperimentNotFound: if the backend answers with HTTP 404.
        ChannelAlreadyExists: if the backend answers with HTTP 409 (conflict).
    """
    ChannelParams = self.backend_swagger_client.get_model('ChannelParams')
    try:
        channel_params = ChannelParams(name=name, channelType=channel_type)
        created = self.backend_swagger_client.api.createChannel(
            experimentId=experiment.internal_id,
            channelToCreate=channel_params,
        ).response().result
        return self._convert_channel_to_channel_with_last_value(created)
    except HTTPNotFound:
        # NOTE(review): other methods in this file read
        # ``experiment._project.full_id``; confirm which attribute exists.
        # pylint: disable=protected-access
        raise ExperimentNotFound(
            experiment_short_id=experiment.id,
            project_qualified_name=experiment._project_full_id,
        )
    except HTTPConflict:
        raise ChannelAlreadyExists(channel_name=name, experiment_short_id=experiment.id)
def get_channel_points_csv(self, experiment, channel_internal_id, channel_name):
    """Return two columns of a user channel's CSV as a rewound ``StringIO``.

    Rows come from ``self._get_channel_tuples_from_csv`` for the attribute
    path of ``channel_name`` in ``ChannelNamespace.USER``. Each output line is
    ``row[0] + "," + row[2] + "\\n"`` (the code names these step and value).

    Args:
        experiment: experiment the channel belongs to.
        channel_internal_id: not used by this implementation.
        channel_name: user channel name.

    Raises:
        ExperimentNotFound: if a lookup fails with HTTP 404.
    """
    try:
        attribute_path = self._get_channel_attribute_path(
            channel_name, ChannelNamespace.USER)
        rows = self._get_channel_tuples_from_csv(experiment, attribute_path)
        content = "".join(row[0] + "," + row[2] + "\n" for row in rows)
        buffer = StringIO()
        buffer.write(content)
        buffer.seek(0)
        return buffer
    except HTTPNotFound:
        # pylint: disable=protected-access
        raise ExperimentNotFound(
            experiment_short_id=experiment.id,
            project_qualified_name=experiment._project.full_id,
        )
def extract_experiment_output(self, experiment, data):
    """Upload ``data`` through ``uploadExperimentOutputAsTarstream``.

    The upload is delegated to ``self._upload_tar_data``, and its return value
    is passed through.

    NOTE(review): this file defines ``extract_experiment_output`` twice. If
    both definitions live in the same class, the later one is the one in effect.

    Raises:
        ExperimentNotFound: if the backend answers with HTTP 404.
        StorageLimitReached: on HTTP 422 whose response field ``type`` is
            ``'LIMIT_OF_STORAGE_IN_PROJECT_REACHED'``; other 422s are re-raised.
    """
    try:
        return self._upload_tar_data(
            experiment=experiment,
            api_method=self.backend_swagger_client.api.uploadExperimentOutputAsTarstream,
            data=data,
        )
    except HTTPNotFound:
        # NOTE(review): other methods in this file read
        # ``experiment._project.full_id``; confirm which attribute exists.
        # pylint: disable=protected-access
        raise ExperimentNotFound(
            experiment_short_id=experiment.id,
            project_qualified_name=experiment._project_full_id,
        )
    except HTTPUnprocessableEntity as error:
        error_type = extract_response_field(error.response, 'type')
        if error_type != 'LIMIT_OF_STORAGE_IN_PROJECT_REACHED':
            raise
        raise StorageLimitReached()
def _execute_operations(self, experiment: Experiment, operations: List[alpha_operation.Operation]):
    """Send non-file ``operations`` for ``experiment`` in one ``executeOperations`` call.

    Each operation is serialized as a dict with a ``"path"`` entry plus one
    entry keyed by ``AlphaOperationApiNameVisitor().visit(op)`` whose value is
    ``AlphaOperationApiObjectConverter().convert(op)``.

    Returns:
        None.

    Raises:
        NeptuneException: if any operation is an ``UploadFile``,
            ``UploadFileContent`` or ``UploadFileSet`` (these must go through
            ``_execute_upload_operation``).
        ExperimentOperationErrors: if the response result lists any errors.
        ExperimentNotFound: if the backend answers with HTTP 404.
    """
    internal_id = experiment.internal_id
    upload_operation_types = (
        alpha_operation.UploadFile,
        alpha_operation.UploadFileContent,
        alpha_operation.UploadFileSet,
    )
    for op in operations:
        if isinstance(op, upload_operation_types):
            raise NeptuneException(
                "File operations must be handled directly by `_execute_upload_operation`,"
                " not by `_execute_operations` function call.")

    def to_api_dict(op):
        return {
            "path": alpha_path_utils.path_to_str(op.path),
            AlphaOperationApiNameVisitor().visit(op): AlphaOperationApiObjectConverter().convert(op),
        }

    serialized_operations = [to_api_dict(op) for op in operations]

    try:
        response = self.leaderboard_swagger_client.api.executeOperations(
            experimentId=internal_id,
            operations=serialized_operations,
        ).response()
        failures = [
            alpha_exceptions.MetadataInconsistency(err.errorDescription)
            for err in response.result
        ]
    except HTTPNotFound as err:
        # pylint: disable=protected-access
        raise ExperimentNotFound(
            experiment_short_id=experiment.id,
            project_qualified_name=experiment._project.full_id,
        ) from err

    if failures:
        raise ExperimentOperationErrors(errors=failures)
    return None
def update_tags(self, experiment, tags_to_add, tags_to_delete):
    """Add ``tags_to_add`` to and remove ``tags_to_delete`` from ``experiment``.

    Raises:
        ExperimentNotFound: if the backend answers with HTTP 404.
        ExperimentValidationError: on HTTP 400 whose response field ``type`` is
            ``'INVALID_TAG'``; the response ``message`` is passed along.
        HTTPBadRequest: any other 400 is re-raised unchanged.
    """
    UpdateTagsParams = self.backend_swagger_client.get_model('UpdateTagsParams')
    try:
        params = UpdateTagsParams(
            experimentIds=[experiment.internal_id],
            groupsIds=[],
            tagsToAdd=tags_to_add,
            tagsToDelete=tags_to_delete,
        )
        self.backend_swagger_client.api.updateTags(
            updateTagsParams=params).response().result
    except HTTPNotFound:
        # NOTE(review): other methods in this file read
        # ``experiment._project.full_id``; confirm which attribute exists.
        # pylint: disable=protected-access
        raise ExperimentNotFound(
            experiment_short_id=experiment.id,
            project_qualified_name=experiment._project_full_id,
        )
    except HTTPBadRequest as error:
        if extract_response_field(error.response, 'type') != 'INVALID_TAG':
            raise
        raise ExperimentValidationError(
            extract_response_field(error.response, 'message'))
def handler(*args, **kwargs):
    """Call the wrapped ``f`` and translate its HTTP errors into Neptune errors.

    ``f`` is a free variable supplied by the enclosing decorator (not visible
    here). The experiment must be passed by keyword as ``experiment=...``.

    Raises:
        NeptuneException: if no ``experiment`` keyword argument is given (or it
            is None).
        ExperimentNotFound: on an ``HTTPError`` with status ``NOT_FOUND``.
        StorageLimitReached: on ``UNPROCESSABLE_ENTITY`` whose response
            ``title`` starts with ``"Storage limit reached in organization: "``.
        HTTPError: every other HTTP error is re-raised unchanged.
    """
    experiment = kwargs.get("experiment")
    if experiment is None:
        raise NeptuneException(
            "This function must be called with experiment passed by name,"
            " like this fun(..., experiment=<experiment>, ...)"
        )
    try:
        return f(*args, **kwargs)
    except HTTPError as e:
        if e.response.status_code == NOT_FOUND:
            # pylint: disable=protected-access
            raise ExperimentNotFound(
                experiment_short_id=experiment.id,
                project_qualified_name=experiment._project.full_id,
            )
        if e.response.status_code == UNPROCESSABLE_ENTITY:
            title = extract_response_field(e.response, "title")
            # Presumably extract_response_field yields a non-str (e.g. None)
            # when the body has no "title". Calling .startswith on that would
            # raise AttributeError and hide the real HTTPError, so the type is
            # checked first. TODO confirm the helper's return contract.
            if isinstance(title, str) and title.startswith(
                "Storage limit reached in organization: "
            ):
                raise StorageLimitReached()
        raise