def test_file_retrieval(self):
        """Checks resolving and chunked fetching of file-backed artifacts.

        Two in-memory files are registered: a 1-byte file and a 37-byte file.
        With chunk_size=10 the test expects the small file to come back as a
        single response and the large one as responses of 10, 10, 10 and 7
        bytes. Resolving a file artifact is expected to return it unchanged.
        """
        api = beam_artifact_api_pb2
        files = {'path/to/a': b'a', 'path/to/b': b'b' * 37}
        manager = InMemoryFileManager(files)
        service = artifact_service.ArtifactRetrievalService(
            manager.file_reader, chunk_size=10)

        # Resolution of a file artifact should echo the same artifact back.
        artifact_a = self.file_artifact('path/to/a')
        resolved = service.ResolveArtifacts(
            api.ResolveArtifactsRequest(artifacts=[artifact_a]))
        self.assertEqual(
            resolved, api.ResolveArtifactsResponse(replacements=[artifact_a]))

        # A file smaller than the chunk size arrives as exactly one response.
        small_chunks = list(
            service.GetArtifact(api.GetArtifactRequest(artifact=artifact_a)))
        self.assertEqual(small_chunks, [api.GetArtifactResponse(data=b'a')])

        # 37 bytes at 10 bytes per chunk: three full chunks plus a remainder.
        artifact_b = self.file_artifact('path/to/b')
        large_chunks = list(
            service.GetArtifact(api.GetArtifactRequest(artifact=artifact_b)))
        expected_chunks = [
            api.GetArtifactResponse(data=b'b' * size)
            for size in (10, 10, 10, 7)
        ]
        self.assertEqual(large_chunks, expected_chunks)
Exemplo n.º 2
0
def offer_artifacts(
    artifact_staging_service, artifact_retrieval_service, staging_token):
  """Serves a staging service's artifact requests over the reverse API.

  Opens a ReverseArtifactRetrievalService stream on the staging service and
  answers every request it yields by delegating to
  artifact_retrieval_service, which should therefore be able to resolve/get
  all artifacts relevant to this job.

  The first message sent carries only the staging token. After that:
    * a 'resolve_artifact' request is answered with the result of
      ResolveArtifacts;
    * a 'get_artifact' request is answered with one message per chunk
      produced by GetArtifact, followed by an empty-data message flagged
      is_last=True.
  Requests with neither field set produce no reply.

  Once the request stream is exhausted the response queue's done() is
  called. If anything raises while serving (including done() itself), the
  queue's abort() is called and the exception is re-raised.
  """
  outgoing = _QueueIter()
  wrap = beam_artifact_api_pb2.ArtifactResponseWrapper

  # The token is queued before the stream is opened so it is sent first.
  outgoing.put(wrap(staging_token=staging_token))
  incoming = artifact_staging_service.ReverseArtifactRetrievalService(outgoing)
  try:
    for req in incoming:
      if req.HasField('resolve_artifact'):
        resolved = artifact_retrieval_service.ResolveArtifacts(
            req.resolve_artifact)
        outgoing.put(wrap(resolve_artifact_response=resolved))
      elif req.HasField('get_artifact'):
        chunks = artifact_retrieval_service.GetArtifact(req.get_artifact)
        for piece in chunks:
          outgoing.put(wrap(get_artifact_response=piece))
        # An empty chunk with is_last set closes this artifact's reply.
        terminator = beam_artifact_api_pb2.GetArtifactResponse(data=b'')
        outgoing.put(wrap(get_artifact_response=terminator, is_last=True))
    outgoing.done()
  except:  # pylint: disable=bare-except
    # Deliberately catches everything: the queue is aborted, then the
    # original exception propagates unchanged.
    outgoing.abort()
    raise
Exemplo n.º 3
0
  def GetArtifact(self, request, context=None):
    """Yields the requested artifact's content as GetArtifactResponse chunks.

    The artifact's type URN selects where the bytes are read from:
      * FILE: the path in the payload is opened via self._file_reader;
      * URL: the url in the payload is opened with urlopen;
      * EMBEDDED: the bytes stored in the payload itself are used.

    The source is then read self._chunk_size at a time, one response per
    non-empty read, and is closed when the `with` block exits.

    Args:
      request: a GetArtifactRequest whose `artifact` names what to fetch.
      context: unused here; accepted with a default of None.

    Raises:
      NotImplementedError: with the type URN as its argument, when the URN
        is none of the three above. Because this is a generator, the error
        surfaces when iteration starts rather than at call time.
    """
    artifact = request.artifact
    kind = artifact.type_urn
    known_types = common_urns.artifact_types

    if kind == known_types.FILE.urn:
      file_payload = proto_utils.parse_Bytes(
          artifact.type_payload, beam_runner_api_pb2.ArtifactFilePayload)
      source = self._file_reader(file_payload.path)
    elif kind == known_types.URL.urn:
      url_payload = proto_utils.parse_Bytes(
          artifact.type_payload, beam_runner_api_pb2.ArtifactUrlPayload)
      # TODO(Py3): Remove the unneeded contextlib wrapper.
      source = contextlib.closing(urlopen(url_payload.url))
    elif kind == known_types.EMBEDDED.urn:
      embedded_payload = proto_utils.parse_Bytes(
          artifact.type_payload, beam_runner_api_pb2.EmbeddedFilePayload)
      source = BytesIO(embedded_payload.data)
    else:
      raise NotImplementedError(kind)

    with source as stream:
      # Primed read loop: stop at the first empty (falsy) read.
      piece = stream.read(self._chunk_size)
      while piece:
        yield beam_artifact_api_pb2.GetArtifactResponse(data=piece)
        piece = stream.read(self._chunk_size)
 def test_embedded_retrieval(self):
     """Checks that an embedded artifact's bytes are returned verbatim.

     The service is constructed with None as its only argument, and the
     9-byte embedded payload is expected back as one GetArtifactResponse.
     """
     service = artifact_service.ArtifactRetrievalService(None)
     artifact = self.embedded_artifact(b'some_data')
     request = beam_artifact_api_pb2.GetArtifactRequest(artifact=artifact)
     received = list(service.GetArtifact(request))
     expected = [beam_artifact_api_pb2.GetArtifactResponse(data=b'some_data')]
     self.assertEqual(received, expected)
 def GetArtifact(self, request):
     """Yields an embedded artifact's bytes as GetArtifactResponse pieces.

     Only artifacts whose type URN equals the EMBEDDED URN are supported.
     The embedded payload's data is sliced into consecutive pieces of at
     most 13 bytes, one response per piece; empty data yields nothing.

     Args:
       request: a GetArtifactRequest whose `artifact` carries the payload.

     Raises:
       NotImplementedError: with the offending type URN as its argument,
         for any non-embedded artifact. Because this is a generator, the
         error surfaces when iteration starts rather than at call time.
     """
     type_urn = request.artifact.type_urn
     if type_urn != common_urns.artifact_types.EMBEDDED.urn:
         # Report which type was rejected instead of raising a bare error.
         raise NotImplementedError(type_urn)

     # NOTE(review): 13 looks like a deliberately small, odd piece size so
     # that payloads span several responses -- confirm intent.
     piece_size = 13
     content = proto_utils.parse_Bytes(
         request.artifact.type_payload,
         beam_runner_api_pb2.EmbeddedFilePayload).data
     for start in range(0, len(content), piece_size):
         yield beam_artifact_api_pb2.GetArtifactResponse(
             data=content[start:start + piece_size])