def test_file_retrieval(self):
    """FILE artifacts resolve to themselves and stream in chunk_size pieces."""
    file_manager = InMemoryFileManager({
        'path/to/a': b'a',
        'path/to/b': b'b' * 37,
    })
    service = artifact_service.ArtifactRetrievalService(
        file_manager.file_reader, chunk_size=10)

    def fetch_all(artifact):
        # Drain the streaming GetArtifact call into a list of responses.
        request = beam_artifact_api_pb2.GetArtifactRequest(artifact=artifact)
        return list(service.GetArtifact(request))

    artifact_a = self.file_artifact('path/to/a')
    resolve_request = beam_artifact_api_pb2.ResolveArtifactsRequest(
        artifacts=[artifact_a])
    self.assertEqual(
        service.ResolveArtifacts(resolve_request),
        beam_artifact_api_pb2.ResolveArtifactsResponse(
            replacements=[artifact_a]))

    self.assertEqual(
        fetch_all(artifact_a),
        [beam_artifact_api_pb2.GetArtifactResponse(data=b'a')])

    # 37 bytes at chunk_size=10 -> three full chunks plus a 7-byte tail.
    expected_b = [
        beam_artifact_api_pb2.GetArtifactResponse(data=b'b' * size)
        for size in (10, 10, 10, 7)
    ]
    self.assertEqual(fetch_all(self.file_artifact('path/to/b')), expected_b)
def offer_artifacts(
        artifact_staging_service, artifact_retrieval_service, staging_token):
    """Offer artifacts to a staging service via ReverseArtifactRetrievalService.

    The staging token is sent first. The staging service then issues
    resolve/get requests, and each one is answered using
    ``artifact_retrieval_service``. That service should therefore be able to
    resolve and get every artifact relevant to this job.

    A ``get_artifact`` request is answered with every chunk produced by the
    retrieval service, followed by an empty chunk flagged ``is_last``.
    Requests carrying neither ``resolve_artifact`` nor ``get_artifact`` are
    ignored. If anything raises while serving requests, the response stream
    is aborted and the exception propagates.
    """
    outbox = _QueueIter()
    wrapper = beam_artifact_api_pb2.ArtifactResponseWrapper
    outbox.put(wrapper(staging_token=staging_token))
    incoming = artifact_staging_service.ReverseArtifactRetrievalService(outbox)
    try:
        for request in incoming:
            if request.HasField('resolve_artifact'):
                resolved = artifact_retrieval_service.ResolveArtifacts(
                    request.resolve_artifact)
                outbox.put(wrapper(resolve_artifact_response=resolved))
            elif request.HasField('get_artifact'):
                chunks = artifact_retrieval_service.GetArtifact(
                    request.get_artifact)
                for chunk in chunks:
                    outbox.put(wrapper(get_artifact_response=chunk))
                # Close out this artifact's stream with an empty, final chunk.
                end_marker = beam_artifact_api_pb2.GetArtifactResponse(data=b'')
                outbox.put(
                    wrapper(get_artifact_response=end_marker, is_last=True))
        outbox.done()
    except BaseException:
        # Equivalent to a bare except: abort the stream on any failure,
        # then let the original exception propagate.
        outbox.abort()
        raise
def GetArtifact(self, request, context=None):
    """Stream the bytes of ``request.artifact`` as GetArtifactResponse chunks.

    FILE artifacts are opened with ``self._file_reader``, URL artifacts with
    ``urlopen``, and EMBEDDED artifacts are served from their inline payload.
    Each yielded response holds at most ``self._chunk_size`` bytes. An
    artifact with no content yields nothing. ``context`` is unused here.

    Raises:
      NotImplementedError: if the artifact type URN is none of the above.
    """
    artifact = request.artifact
    urn = artifact.type_urn
    types = common_urns.artifact_types
    if urn == types.FILE.urn:
        file_payload = proto_utils.parse_Bytes(
            artifact.type_payload, beam_runner_api_pb2.ArtifactFilePayload)
        source = self._file_reader(file_payload.path)
    elif urn == types.URL.urn:
        url_payload = proto_utils.parse_Bytes(
            artifact.type_payload, beam_runner_api_pb2.ArtifactUrlPayload)
        # TODO(Py3): Remove the unneeded contextlib wrapper.
        source = contextlib.closing(urlopen(url_payload.url))
    elif urn == types.EMBEDDED.urn:
        embedded_payload = proto_utils.parse_Bytes(
            artifact.type_payload, beam_runner_api_pb2.EmbeddedFilePayload)
        source = BytesIO(embedded_payload.data)
    else:
        raise NotImplementedError(urn)

    with source as stream:
        # Prime the first read; stop as soon as a read comes back empty.
        data = stream.read(self._chunk_size)
        while data:
            yield beam_artifact_api_pb2.GetArtifactResponse(data=data)
            data = stream.read(self._chunk_size)
def test_embedded_retrieval(self):
    """An EMBEDDED artifact is returned as its inline bytes."""
    # The service is built with None as its file reader.
    service = artifact_service.ArtifactRetrievalService(None)
    artifact = self.embedded_artifact(b'some_data')
    request = beam_artifact_api_pb2.GetArtifactRequest(artifact=artifact)
    chunks = list(service.GetArtifact(request))
    expected = [beam_artifact_api_pb2.GetArtifactResponse(data=b'some_data')]
    self.assertEqual(chunks, expected)
def GetArtifact(self, request):
    """Stream an EMBEDDED artifact's inline bytes in 13-byte chunks.

    Only EMBEDDED artifacts are supported. Their payload data is sliced into
    consecutive 13-byte pieces; the last piece may be shorter. Empty content
    yields nothing.

    Raises:
      NotImplementedError: for any other artifact type. The rejected URN is
        used as the exception message.
    """
    urn = request.artifact.type_urn
    if urn != common_urns.artifact_types.EMBEDDED.urn:
        # Report which URN was rejected so failures are diagnosable.
        raise NotImplementedError(urn)
    content = proto_utils.parse_Bytes(
        request.artifact.type_payload,
        beam_runner_api_pb2.EmbeddedFilePayload).data
    # Odd chunk size, presumably chosen so callers exercise multi-chunk
    # reassembly -- NOTE(review): confirm intent.
    chunk_size = 13
    for start in range(0, len(content), chunk_size):
        yield beam_artifact_api_pb2.GetArtifactResponse(
            data=content[start:start + chunk_size])