def __init__(self, api_address, api_token):
    """Connect to the Neptune API at *api_address* using *api_token*.

    Loads the backend and leaderboard Swagger specs through one shared
    ``RequestsClient``, exchanges *api_token* via the backend
    ``exchangeApiToken`` operation, and installs the resulting
    ``NeptuneAuthenticator`` on that shared HTTP client.
    """
    self.api_address = api_address
    self.api_token = api_token
    shared_http = RequestsClient()
    self._http_client = shared_http

    def load_swagger(spec_url_template):
        # Each client gets its own config dict; every kind of validation is
        # switched off and UUIDs are handled by the custom ``uuid_format``.
        options = dict(validate_swagger_spec=False,
                       validate_requests=False,
                       validate_responses=False,
                       formats=[uuid_format])
        return SwaggerClient.from_url(spec_url_template.format(self.api_address),
                                      config=options,
                                      http_client=shared_http)

    self.backend_swagger_client = load_swagger('{}/api/backend/swagger.json')
    self.leaderboard_swagger_client = load_swagger('{}/api/leaderboard/swagger.json')

    # Trade the API token for auth tokens, then let the shared HTTP client
    # use the authenticator built from them.
    exchange_call = self.backend_swagger_client.api.exchangeApiToken(X_Neptune_Api_Token=api_token)
    exchanged_tokens = exchange_call.response().result
    self.authenticator = NeptuneAuthenticator(exchanged_tokens)
    shared_http.authenticator = self.authenticator
def test_apply_oauth2_session_to_request(self, time_mock, session_mock):
    """``apply`` should attach an OAuth2 session built from the exchanged tokens."""
    # given: a backend client whose token exchange yields known tokens
    api_token = MagicMock()
    backend_client = MagicMock()
    auth_tokens = MagicMock()
    auth_tokens.accessToken = an_access_token()
    auth_tokens.refreshToken = a_refresh_token()
    decoded_access_token = jwt.decode(auth_tokens.accessToken, SECRET, options=_decoding_options)
    exchange_response = backend_client.api.exchangeApiToken(X_Neptune_Api_Token=api_token).response()
    exchange_response.result = auth_tokens

    # and: the patched clock returns a fixed instant
    now = time.time()
    time_mock.time.return_value = now

    # and: the patched session factory hands back a controllable session
    session = MagicMock()
    session_mock.return_value = session
    session.token = {}

    # and
    neptune_authenticator = NeptuneAuthenticator(api_token, backend_client, False, None)
    request = a_request()

    # when
    updated_request = neptune_authenticator.apply(request)

    # then
    client_id = decoded_access_token["azp"]
    realm_url = decoded_access_token["iss"]
    expected_token = {
        "access_token": auth_tokens.accessToken,
        "refresh_token": auth_tokens.refreshToken,
        "expires_in": decoded_access_token["exp"] - now,
    }
    session_mock.assert_called_once_with(
        client_id=client_id,
        token=expected_token,
        auto_refresh_url="{realm_url}/protocol/openid-connect/token".format(realm_url=realm_url),
        auto_refresh_kwargs={"client_id": client_id},
        token_updater=_no_token_updater,
    )

    # and
    self.assertEqual(session, updated_request.auth.session)
def create_http_client_with_auth(
    credentials: Credentials, ssl_verify: bool, proxies: Dict[str, str]
) -> Tuple[RequestsClient, ClientConfig]:
    """Create an authenticated HTTP client plus the server-provided client config.

    The client version is checked against the fetched config. When the API
    URL advertised by the config differs from the one implied by
    *credentials*, the token client is pointed at an explicit backend
    Swagger endpoint built from the config's API URL.

    Returns:
        ``(http_client, client_config)``, where ``http_client`` carries a
        ``NeptuneAuthenticator``.
    """
    client_config = get_client_config(credentials=credentials, ssl_verify=ssl_verify, proxies=proxies)
    local_api_url = credentials.api_url_opt or credentials.token_origin_address
    verify_client_version(client_config, neptune_client_version)

    if config_url_differs := local_api_url != client_config.api_url:
        endpoint_url = build_operation_url(client_config.api_url, BACKEND_SWAGGER_PATH)
    else:
        endpoint_url = None

    http_client = create_http_client(ssl_verify=ssl_verify, proxies=proxies)
    api_token = credentials.api_token
    token_client = _get_token_client(
        credentials=credentials,
        ssl_verify=ssl_verify,
        proxies=proxies,
        endpoint_url=endpoint_url,
    )
    http_client.authenticator = NeptuneAuthenticator(api_token, token_client, ssl_verify, proxies)
    return http_client, client_config
def test_apply_oauth2_session_to_request(self, time_mock, session_mock):
    """``apply`` should attach an OAuth2 session built from the given tokens."""
    # given: known access/refresh tokens
    auth_tokens = MagicMock()
    auth_tokens.accessToken = an_access_token()
    auth_tokens.refreshToken = a_refresh_token()
    decoded_access_token = jwt.decode(auth_tokens.accessToken, SECRET)

    # and: the patched clock returns a fixed instant
    now = time.time()
    time_mock.time.return_value = now

    # and: the patched session factory hands back a controllable session
    session = MagicMock()
    session_mock.return_value = session
    session.token = {}

    # and
    neptune_authenticator = NeptuneAuthenticator(auth_tokens, False, None)
    request = a_request()

    # when
    updated_request = neptune_authenticator.apply(request)

    # then
    client_id = decoded_access_token['azp']
    realm_url = decoded_access_token['iss']
    expected_token = {
        'access_token': auth_tokens.accessToken,
        'refresh_token': auth_tokens.refreshToken,
        'expires_in': decoded_access_token['exp'] - now,
    }
    session_mock.assert_called_once_with(
        client_id=client_id,
        token=expected_token,
        auto_refresh_url='{realm_url}/protocol/openid-connect/token'.format(realm_url=realm_url),
        auto_refresh_kwargs={'client_id': client_id},
        token_updater=_no_token_updater,
    )

    # and
    self.assertEqual(session, updated_request.auth.session)
class Client(object):
    """Client for Neptune's backend and leaderboard Swagger APIs.

    Both Swagger clients share one ``RequestsClient`` whose authenticator is
    built from the tokens obtained by exchanging ``api_token``. Public
    methods translate ``HTTPNotFound`` / ``HTTPBadRequest`` / ``HTTPConflict``
    / ``HTTPUnprocessableEntity`` responses into the project's domain
    exceptions (``ProjectNotFound``, ``ExperimentNotFound``, ...).
    """

    @with_api_exceptions_handler
    def __init__(self, api_address, api_token):
        self.api_address = api_address
        self.api_token = api_token
        self._http_client = RequestsClient()

        self.backend_swagger_client = SwaggerClient.from_url(
            '{}/api/backend/swagger.json'.format(self.api_address),
            config=self._swagger_config(),
            http_client=self._http_client)
        self.leaderboard_swagger_client = SwaggerClient.from_url(
            '{}/api/leaderboard/swagger.json'.format(self.api_address),
            config=self._swagger_config(),
            http_client=self._http_client)

        # Trade the API token for auth tokens via the backend API and attach
        # the resulting authenticator to the shared HTTP client.
        self.authenticator = NeptuneAuthenticator(
            self.backend_swagger_client.api.exchangeApiToken(
                X_Neptune_Api_Token=api_token).response().result)
        self._http_client.authenticator = self.authenticator

    @staticmethod
    def _swagger_config():
        """Return a fresh Swagger client config: validation off, custom UUID format."""
        return dict(validate_swagger_spec=False,
                    validate_requests=False,
                    validate_responses=False,
                    formats=[uuid_format])

    @staticmethod
    def _experiment_not_found(experiment):
        """Build the ``ExperimentNotFound`` error reported for *experiment*."""
        # pylint: disable=protected-access
        return ExperimentNotFound(
            experiment_short_id=experiment.id,
            project_qualified_name=experiment._project_full_id)

    @with_api_exceptions_handler
    def get_project(self, project_qualified_name):
        """Return the project named *project_qualified_name*.

        Raises:
            ProjectNotFound: the API answered 404.
        """
        try:
            return self.backend_swagger_client.api.getProject(
                projectIdentifier=project_qualified_name).response().result
        except HTTPNotFound:
            raise ProjectNotFound(project_qualified_name)

    @with_api_exceptions_handler
    def get_projects(self, namespace):
        """Return the project entries of organization *namespace*.

        Raises:
            NamespaceNotFound: the API answered 404.
        """
        try:
            r = self.backend_swagger_client.api.listProjectsInOrganization(
                organizationName=namespace).response()
            return r.result.entries
        except HTTPNotFound:
            raise NamespaceNotFound(namespace_name=namespace)

    @with_api_exceptions_handler
    def get_project_members(self, project_identifier):
        """Return the members of project *project_identifier*.

        Raises:
            ProjectNotFound: the API answered 404.
        """
        try:
            r = self.backend_swagger_client.api.listProjectMembers(
                projectIdentifier=project_identifier).response()
            return r.result
        except HTTPNotFound:
            raise ProjectNotFound(project_identifier)

    @with_api_exceptions_handler
    def get_leaderboard_entries(self, project,
                                entry_types=None, ids=None,
                                states=None, owners=None, tags=None,
                                min_running_time=None):
        """Return every leaderboard entry of *project* matching the filters.

        Entries are fetched in pages of 100, sorted ascending by ``shortId``.
        When *entry_types* is None both experiments and notebooks are listed.

        Raises:
            ProjectNotFound: the API answered 404.
        """
        try:
            if entry_types is None:
                entry_types = ['experiment', 'notebook']

            def get_portion(limit, offset):
                return self.leaderboard_swagger_client.api.getLeaderboard(
                    projectIdentifier=project.full_id,
                    entryType=entry_types,
                    shortId=ids, groupShortId=None, state=states,
                    owner=owners, tags=tags,
                    minRunningTimeSeconds=min_running_time,
                    sortBy=['shortId'], sortFieldType=['native'],
                    sortDirection=['ascending'],
                    limit=limit, offset=offset).response().result.entries

            return [LeaderboardEntry(e) for e in self._get_all_items(get_portion, step=100)]
        except HTTPNotFound:
            raise ProjectNotFound(project_identifier=project.full_id)

    @with_api_exceptions_handler
    def get_channel_points_csv(self, experiment, channel_internal_id):
        """Return a rewound ``StringIO`` holding the channel's values as CSV text.

        Raises:
            ExperimentNotFound: the API answered 404.
        """
        try:
            csv = StringIO()
            csv.write(
                self.backend_swagger_client.api.getChannelValuesCSV(
                    experimentId=experiment.internal_id,
                    channelId=channel_internal_id).response().incoming_response.text)
            csv.seek(0)
            return csv
        except HTTPNotFound:
            raise self._experiment_not_found(experiment)

    @with_api_exceptions_handler
    def get_metrics_csv(self, experiment):
        """Return a rewound ``StringIO`` holding the system metrics as CSV text.

        Raises:
            ExperimentNotFound: the API answered 404.
        """
        try:
            csv = StringIO()
            csv.write(
                self.backend_swagger_client.api.getSystemMetricsCSV(
                    experimentId=experiment.internal_id).response().incoming_response.text)
            csv.seek(0)
            return csv
        except HTTPNotFound:
            raise self._experiment_not_found(experiment)

    @with_api_exceptions_handler
    def create_experiment(self, project, name, description, params, properties, tags,
                          abortable, monitored):
        """Create an experiment in *project* and return it as an ``Experiment``.

        Raises:
            ProjectNotFound: the API answered 404.
            ExperimentValidationError: duplicate parameters or an invalid tag.
            ExperimentLimitReached: the project's experiment limit is reached.
        """
        ExperimentCreationParams = self.backend_swagger_client.get_model('ExperimentCreationParams')

        try:
            params = ExperimentCreationParams(
                projectId=project.internal_id,
                name=name,
                description=description,
                parameters=self._convert_to_api_parameters(params),
                properties=self._convert_to_api_properties(properties),
                tags=tags,
                enqueueCommand="command",  # FIXME
                entrypoint="",  # FIXME
                execArgsTemplate="",  # FIXME,
                abortable=abortable,
                monitored=monitored)

            api_experiment = self.backend_swagger_client.api.createExperiment(
                experimentCreationParams=params).response().result

            return self._convert_to_experiment(api_experiment)
        except HTTPNotFound:
            raise ProjectNotFound(project_identifier=project.full_id)
        except HTTPBadRequest as e:
            error_type = extract_response_field(e.response, 'type')
            if error_type == 'DUPLICATE_PARAMETER':
                raise ExperimentValidationError('Parameter list contains duplicates.')
            elif error_type == 'INVALID_TAG':
                raise ExperimentValidationError(extract_response_field(e.response, 'message'))
            else:
                raise
        except HTTPUnprocessableEntity as e:
            if extract_response_field(e.response, 'type') == 'LIMIT_OF_EXPERIMENTS_IN_PROJECT_REACHED':
                raise ExperimentLimitReached()
            else:
                raise

    @with_api_exceptions_handler
    def get_experiment(self, experiment_id):
        """Return the raw API experiment object with id *experiment_id*."""
        return self.backend_swagger_client.api.getExperiment(
            experimentId=experiment_id).response().result

    @with_api_exceptions_handler
    def update_experiment(self, experiment, properties):
        """Replace *experiment*'s key/value properties and return *experiment*.

        Raises:
            ExperimentNotFound: the API answered 404.
        """
        EditExperimentParams = self.backend_swagger_client.get_model('EditExperimentParams')
        KeyValueProperty = self.backend_swagger_client.get_model('KeyValueProperty')
        try:
            self.backend_swagger_client.api.updateExperiment(
                experimentId=experiment.internal_id,
                editExperimentParams=EditExperimentParams(
                    properties=[KeyValueProperty(key=key, value=properties[key])
                                for key in properties])).response()
            return experiment
        except HTTPNotFound:
            raise self._experiment_not_found(experiment)

    @with_api_exceptions_handler
    def update_tags(self, experiment, tags_to_add, tags_to_delete):
        """Add *tags_to_add* to and remove *tags_to_delete* from *experiment*.

        Raises:
            ExperimentNotFound: the API answered 404.
            ExperimentValidationError: the API rejected a tag as invalid.
        """
        UpdateTagsParams = self.backend_swagger_client.get_model('UpdateTagsParams')
        try:
            self.backend_swagger_client.api.updateTags(
                updateTagsParams=UpdateTagsParams(
                    experimentIds=[experiment.internal_id],
                    groupsIds=[],
                    tagsToAdd=tags_to_add,
                    tagsToDelete=tags_to_delete)).response().result
        except HTTPNotFound:
            raise self._experiment_not_found(experiment)
        except HTTPBadRequest as e:
            error_type = extract_response_field(e.response, 'type')
            if error_type == 'INVALID_TAG':
                raise ExperimentValidationError(extract_response_field(e.response, 'message'))
            else:
                raise

    @with_api_exceptions_handler
    def upload_experiment_source(self, experiment, data):
        """Upload *data* in chunks as *experiment*'s source files.

        Raises:
            ExperimentNotFound: the API answered 404.
            StorageLimitReached: the project's storage limit is reached.
        """
        try:
            return self._upload_loop(
                partial(self._upload_raw_data,
                        api_method=self.backend_swagger_client.api.uploadExperimentSource),
                data=data,
                experiment=experiment)
        except HTTPNotFound:
            raise self._experiment_not_found(experiment)
        except HTTPUnprocessableEntity as e:
            if extract_response_field(e.response, 'type') == 'LIMIT_OF_STORAGE_IN_PROJECT_REACHED':
                raise StorageLimitReached()
            else:
                raise

    @with_api_exceptions_handler
    def extract_experiment_source(self, experiment, data):
        """Upload the bytes *data* as a tar stream of *experiment*'s sources.

        Raises:
            ExperimentNotFound: the API answered 404.
            StorageLimitReached: the project's storage limit is reached.
        """
        try:
            return self._upload_tar_data(
                experiment=experiment,
                api_method=self.backend_swagger_client.api.uploadExperimentSourceAsTarstream,
                data=data)
        except HTTPNotFound:
            raise self._experiment_not_found(experiment)
        except HTTPUnprocessableEntity as e:
            if extract_response_field(e.response, 'type') == 'LIMIT_OF_STORAGE_IN_PROJECT_REACHED':
                raise StorageLimitReached()
            else:
                raise

    @with_api_exceptions_handler
    def create_channel(self, experiment, name, channel_type):
        """Create channel *name* of *channel_type* and return it with an empty last value.

        Raises:
            ExperimentNotFound: the API answered 404.
            ChannelAlreadyExists: the API answered 409.
        """
        ChannelParams = self.backend_swagger_client.get_model('ChannelParams')

        try:
            params = ChannelParams(name=name, channelType=channel_type)

            channel = self.backend_swagger_client.api.createChannel(
                experimentId=experiment.internal_id,
                channelToCreate=params).response().result

            return self._convert_channel_to_channel_with_last_value(channel)
        except HTTPNotFound:
            raise self._experiment_not_found(experiment)
        except HTTPConflict:
            raise ChannelAlreadyExists(channel_name=name, experiment_short_id=experiment.id)

    @with_api_exceptions_handler
    def send_channels_values(self, experiment, channels_with_values):
        """Post a batch of channel points for *experiment*.

        Timestamps are converted from seconds to milliseconds.

        Raises:
            ChannelsValuesSendBatchError: the API returned per-item errors.
            ExperimentNotFound: the API answered 404.
        """
        InputChannelValues = self.backend_swagger_client.get_model('InputChannelValues')
        Point = self.backend_swagger_client.get_model('Point')
        Y = self.backend_swagger_client.get_model('Y')

        input_channels_values = []
        for channel_with_values in channels_with_values:
            points = [Point(
                timestampMillis=int(value.ts * 1000.0),
                x=value.x,
                y=Y(numericValue=value.y.get('numeric_value'),
                    textValue=value.y.get('text_value'),
                    inputImageValue=value.y.get('image_value'))
            ) for value in channel_with_values.channel_values]

            input_channels_values.append(InputChannelValues(
                channelId=channel_with_values.channel_id,
                values=points))

        try:
            batch_errors = self.backend_swagger_client.api.postChannelValues(
                experimentId=experiment.internal_id,
                channelsValues=input_channels_values).response().result

            if batch_errors:
                raise ChannelsValuesSendBatchError(experiment.id, batch_errors)
        except HTTPNotFound:
            raise self._experiment_not_found(experiment)

    @with_api_exceptions_handler
    def put_tensorflow_graph(self, experiment, graph_id, graph):
        """Upload *graph* (a str) gzip-compressed and base64-encoded.

        Raises:
            ExperimentNotFound: the API answered 404.
        """
        TensorflowGraph = self.backend_swagger_client.get_model('TensorflowGraph')

        def gzip_compress(data):
            output_buffer = io.BytesIO()
            # ``with`` guarantees the gzip stream is finalized/closed even if
            # the write fails; closing it does not close ``output_buffer``.
            with gzip.GzipFile(fileobj=output_buffer, mode='w') as gzip_stream:
                gzip_stream.write(data)
            return output_buffer.getvalue()

        bingraph = graph.encode('UTF-8')
        compressed_graph_data = base64.b64encode(gzip_compress(bingraph))
        data = compressed_graph_data.decode('UTF-8')

        value = TensorflowGraph(id=graph_id, value=data)

        try:
            r = self.backend_swagger_client.api.putTensorflowGraph(
                experimentId=experiment.internal_id,
                tensorflowGraph=value).response()

            return r.result
        except HTTPNotFound:
            raise self._experiment_not_found(experiment)

    @with_api_exceptions_handler
    def mark_succeeded(self, experiment):
        """Mark *experiment* as succeeded and return it.

        Raises:
            ExperimentNotFound: the API answered 404.
            ExperimentAlreadyFinished: the API answered 422.
        """
        CompletedExperimentParams = self.backend_swagger_client.get_model('CompletedExperimentParams')

        try:
            self.backend_swagger_client.api.markExperimentCompleted(
                experimentId=experiment.internal_id,
                completedExperimentParams=CompletedExperimentParams(
                    state='succeeded',
                    traceback=''  # FIXME
                )).response()

            return experiment
        except HTTPNotFound:
            raise self._experiment_not_found(experiment)
        except HTTPUnprocessableEntity:
            raise ExperimentAlreadyFinished(experiment.id)

    @with_api_exceptions_handler
    def mark_failed(self, experiment, traceback):
        """Mark *experiment* as failed with *traceback* and return it.

        Raises:
            ExperimentNotFound: the API answered 404.
            ExperimentAlreadyFinished: the API answered 422.
        """
        CompletedExperimentParams = self.backend_swagger_client.get_model('CompletedExperimentParams')

        try:
            self.backend_swagger_client.api.markExperimentCompleted(
                experimentId=experiment.internal_id,
                completedExperimentParams=CompletedExperimentParams(
                    state='failed',
                    traceback=traceback)).response()

            return experiment
        except HTTPNotFound:
            raise self._experiment_not_found(experiment)
        except HTTPUnprocessableEntity:
            raise ExperimentAlreadyFinished(experiment.id)

    @with_api_exceptions_handler
    def ping_experiment(self, experiment):
        """Send a ping for *experiment*.

        Raises:
            ExperimentNotFound: the API answered 404.
        """
        try:
            self.backend_swagger_client.api.pingExperiment(
                experimentId=experiment.internal_id).response()
        except HTTPNotFound:
            raise self._experiment_not_found(experiment)

    @with_api_exceptions_handler
    def create_hardware_metric(self, experiment, metric):
        """Register system metric *metric* (one series per gauge) and return its id.

        Raises:
            ExperimentNotFound: the API answered 404.
        """
        SystemMetricParams = self.backend_swagger_client.get_model('SystemMetricParams')

        try:
            series = [gauge.name() for gauge in metric.gauges]
            system_metric_params = SystemMetricParams(
                name=metric.name, description=metric.description, resourceType=metric.resource_type,
                unit=metric.unit, min=metric.min_value, max=metric.max_value, series=series)

            metric_dto = self.backend_swagger_client.api.createSystemMetric(
                experimentId=experiment.internal_id, metricToCreate=system_metric_params
            ).response().result

            return metric_dto.id
        except HTTPNotFound:
            raise self._experiment_not_found(experiment)

    @with_api_exceptions_handler
    def send_hardware_metric_reports(self, experiment, metrics, metric_reports):
        """Post system-metric values for *metric_reports* and return the raw response.

        Each report's ``metric.name`` is resolved against *metrics* to find the
        server-side ``internal_id``. Timestamps and running times are
        converted from seconds to milliseconds.

        Raises:
            ExperimentNotFound: the API answered 404.
        """
        SystemMetricValues = self.backend_swagger_client.get_model('SystemMetricValues')
        SystemMetricPoint = self.backend_swagger_client.get_model('SystemMetricPoint')

        try:
            metrics_by_name = {metric.name: metric for metric in metrics}

            # Note: ``groupby`` only merges *consecutive* values sharing a
            # gauge_name; non-contiguous values yield separate entries.
            system_metric_values = [
                SystemMetricValues(
                    metricId=metrics_by_name.get(report.metric.name).internal_id,
                    seriesName=gauge_name,
                    values=[
                        SystemMetricPoint(
                            timestampMillis=int(metric_value.timestamp * 1000.0),
                            x=int(metric_value.running_time * 1000.0),
                            y=metric_value.value)
                        for metric_value in metric_values
                    ])
                for report in metric_reports
                for gauge_name, metric_values in groupby(report.values, lambda value: value.gauge_name)
            ]

            response = self.backend_swagger_client.api.postSystemMetricValues(
                experimentId=experiment.internal_id, metricValues=system_metric_values).response()

            return response
        except HTTPNotFound:
            raise self._experiment_not_found(experiment)

    @with_api_exceptions_handler
    def upload_experiment_output(self, experiment, data):
        """Upload *data* in chunks as *experiment*'s output files.

        Raises:
            ExperimentNotFound: the API answered 404.
            StorageLimitReached: the project's storage limit is reached.
        """
        try:
            return self._upload_loop(
                partial(self._upload_raw_data,
                        api_method=self.backend_swagger_client.api.uploadExperimentOutput),
                data=data,
                experiment=experiment)
        except HTTPNotFound:
            raise self._experiment_not_found(experiment)
        except HTTPUnprocessableEntity as e:
            if extract_response_field(e.response, 'type') == 'LIMIT_OF_STORAGE_IN_PROJECT_REACHED':
                raise StorageLimitReached()
            else:
                raise

    @with_api_exceptions_handler
    def extract_experiment_output(self, experiment, data):
        """Upload the bytes *data* as a tar stream of *experiment*'s outputs.

        Raises:
            ExperimentNotFound: the API answered 404.
            StorageLimitReached: the project's storage limit is reached.
        """
        try:
            return self._upload_tar_data(
                experiment=experiment,
                api_method=self.backend_swagger_client.api.uploadExperimentOutputAsTarstream,
                data=data)
        except HTTPNotFound:
            raise self._experiment_not_found(experiment)
        except HTTPUnprocessableEntity as e:
            if extract_response_field(e.response, 'type') == 'LIMIT_OF_STORAGE_IN_PROJECT_REACHED':
                raise StorageLimitReached()
            else:
                raise

    @staticmethod
    def _get_all_items(get_portion, step):
        """Collect every item by calling ``get_portion(limit=step, offset=...)``.

        Paging stops at the first page shorter than *step*. When the total is
        an exact multiple of *step*, one extra call returns an empty page.
        """
        items = []

        previous_items = None
        while previous_items is None or len(previous_items) >= step:
            previous_items = get_portion(limit=step, offset=len(items))
            items += previous_items

        return items

    def _convert_to_api_parameters(self, raw_params):
        """Turn a ``{name: value}`` dict into API ``Parameter`` models.

        Values for which ``is_float`` holds are typed ``'double'``; everything
        else is ``'string'``. Values are always sent stringified, each under a
        random UUID id.
        """
        Parameter = self.backend_swagger_client.get_model('Parameter')

        params = []
        for name, value in raw_params.items():
            parameter_type = 'double' if is_float(value) else 'string'

            params.append(
                Parameter(
                    id=str(uuid.uuid4()),
                    name=name,
                    parameterType=parameter_type,
                    value=str(value)))

        return params

    def _convert_to_api_properties(self, raw_properties):
        """Turn a ``{key: value}`` dict into API ``KeyValueProperty`` models."""
        KeyValueProperty = self.backend_swagger_client.get_model('KeyValueProperty')

        return [
            KeyValueProperty(key=key, value=value)
            for key, value in raw_properties.items()
        ]

    def _convert_to_experiment(self, api_experiment):
        """Wrap an API experiment in an ``Experiment`` bound to this client."""
        return Experiment(
            client=self,
            _id=api_experiment.shortId,
            internal_id=api_experiment.id,
            project_full_id='{}/{}'.format(api_experiment.organizationName, api_experiment.projectName))

    def _convert_channel_to_channel_with_last_value(self, channel):
        """Wrap a freshly created *channel* as a ``ChannelWithLastValue`` with x/y unset."""
        ChannelWithValueDTO = self.leaderboard_swagger_client.get_model('ChannelWithValueDTO')
        return ChannelWithLastValue(
            ChannelWithValueDTO(
                channelId=channel.id,
                channelName=channel.name,
                channelType=channel.channelType,
                x=None,
                y=None))

    def _upload_loop(self, fun, data, checksums=None, **kwargs):
        """Send every part of *data* through *fun*, then close *data*.

        A part whose start offset is in *checksums* with a matching
        ``md5()`` is skipped (``part.skip()``) instead of being re-sent.

        Returns:
            The result of the last chunk actually sent, or None.
        """
        ret = None
        for part in data.generate():
            skip = False
            if checksums and part.start in checksums:
                skip = checksums[part.start].checksum == part.md5()

            if not skip:
                ret = self._upload_loop_chunk(fun, part, data, **kwargs)
            else:
                part.skip()
        # NOTE(review): *data* is not closed if an upload above raises. The
        # public callers are wrapped in ``with_api_exceptions_handler``, which
        # may re-invoke them with the same *data*, so this is left unchanged.
        data.close()
        return ret

    def _upload_loop_chunk(self, fun, part, data, **kwargs):
        """Send one *part* of *data* via *fun* with a byte ``Range`` header.

        The range end is inclusive (``part.end - 1``); when ``part.end`` is
        falsy the range is open-ended.
        """
        part_to_send = part.get_data()
        if part.end:
            binary_range = "bytes=%d-%d/%d" % (part.start, part.end - 1, data.length)
        else:
            binary_range = "bytes=%d-/%d" % (part.start, data.length)
        return fun(data=part_to_send,
                   headers={
                       "Content-Type": "application/octet-stream",
                       "Content-Filename": data.filename,
                       "Range": binary_range,
                       "X-File-Permissions": data.permissions
                   },
                   **kwargs)

    def _experiment_upload_url(self, experiment, api_method):
        """Return the absolute URL of *api_method* with *experiment*'s id filled in."""
        url = self.api_address + api_method.operation.path_name
        return url.replace("{experimentId}", experiment.internal_id)

    def _upload_raw_data(self, experiment, api_method, data, headers):
        """POST *data* with *headers* to *api_method*'s URL via the shared session."""
        url = self._experiment_upload_url(experiment, api_method)

        session = self._http_client.session

        request = self.authenticator.apply(
            requests.Request(
                method='POST',
                url=url,
                data=data,
                headers=headers
            )
        )

        return session.send(session.prepare_request(request))

    def _upload_tar_data(self, experiment, api_method, data):
        """POST the bytes *data* as an octet stream to *api_method*'s URL."""
        url = self._experiment_upload_url(experiment, api_method)

        session = self._http_client.session

        request = self.authenticator.apply(
            requests.Request(
                method='POST',
                url=url,
                data=io.BytesIO(data),
                headers={
                    "Content-Type": "application/octet-stream"
                }
            )
        )

        return session.send(session.prepare_request(request))
def _create_authenticator(self, api_token, ssl_verify, proxies, backend_client):
    """Return a ``NeptuneAuthenticator`` for *api_token* backed by *backend_client*.

    *ssl_verify* and *proxies* are forwarded unchanged; ``self`` is unused.
    """
    authenticator = NeptuneAuthenticator(api_token, backend_client, ssl_verify, proxies)
    return authenticator
def _create_authenticator(self, api_token, ssl_verify, proxies):
    """Exchange *api_token* via the backend API and wrap the result in an authenticator.

    *ssl_verify* and *proxies* are forwarded unchanged to ``NeptuneAuthenticator``.
    """
    exchange_call = self.backend_swagger_client.api.exchangeApiToken(X_Neptune_Api_Token=api_token)
    exchanged_tokens = exchange_call.response().result
    return NeptuneAuthenticator(exchanged_tokens, ssl_verify, proxies)