Example #1
 def test_rename_error(self):
   path1 = os.path.join(self.tmpdir, 'f1')
   path2 = os.path.join(self.tmpdir, 'f2')
   with self.assertRaisesRegexp(BeamIOError,
                                r'^Rename operation failed') as error:
     FileSystems.rename([path1], [path2])
   self.assertEqual(error.exception.exception_details.keys(), [(path1, path2)])
Example #2
 def test_delete_error(self):
   path1 = os.path.join(self.tmpdir, 'f1')
   with self.assertRaises(BeamIOError) as error:
     FileSystems.delete([path1])
   self.assertTrue(
       error.exception.message.startswith('Delete operation failed'))
   self.assertEqual(error.exception.exception_details.keys(), [path1])
Example #3
 def _rename_batch(batch):
   """_rename_batch executes batch rename operations."""
   source_files, destination_files = batch
   exceptions = []
   try:
     FileSystems.rename(source_files, destination_files)
     return exceptions
   except BeamIOError as exp:
     if exp.exception_details is None:
       raise
     for (src, dest), exception in exp.exception_details.iteritems():
       if exception:
         logging.warning('Rename not successful: %s -> %s, %s', src, dest,
                         exception)
         should_report = True
         if isinstance(exception, IOError):
           # May have already been copied.
           try:
             if FileSystems.exists(dest):
               should_report = False
           except Exception as exists_e:  # pylint: disable=broad-except
             logging.warning('Exception when checking if file %s exists: '
                             '%s', dest, exists_e)
         if should_report:
           logging.warning(('Exception in _rename_batch. src: %s, '
                            'dest: %s, err: %s'), src, dest, exception)
           exceptions.append(exception)
       else:
         logging.debug('Rename successful: %s -> %s', src, dest)
     return exceptions
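
A minimal hedged sketch of the exception_details structure the loop above iterates over; the paths are placeholders and the local rename is expected to fail:

from apache_beam.io.filesystem import BeamIOError
from apache_beam.io.filesystems import FileSystems

try:
    FileSystems.rename(['/tmp/does-not-exist'], ['/tmp/renamed'])
except BeamIOError as err:
    # Keys are (source, destination) pairs; values are the underlying errors.
    for (src, dst), cause in err.exception_details.items():
        print('Rename %s -> %s failed: %r' % (src, dst, cause))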
Example #4
 def test_match_file_exception(self):
   # Match files with None so that it throws an exception
   with self.assertRaises(BeamIOError) as error:
     FileSystems.match([None])
   self.assertTrue(
       error.exception.message.startswith('Unable to get the Filesystem'))
   self.assertEqual(error.exception.exception_details.keys(), [None])
Example #5
 def test_exists(self):
   path1 = os.path.join(self.tmpdir, 'f1')
   path2 = os.path.join(self.tmpdir, 'f2')
   with open(path1, 'a') as f:
     f.write('Hello')
   self.assertTrue(FileSystems.exists(path1))
   self.assertFalse(FileSystems.exists(path2))
Example #6
  def stage_file(self, gcs_or_local_path, file_name, stream,
                 mime_type='application/octet-stream'):
    """Stages a file at a GCS or local path with stream-supplied contents."""
    if not gcs_or_local_path.startswith('gs://'):
      local_path = FileSystems.join(gcs_or_local_path, file_name)
      logging.info('Staging file locally to %s', local_path)
      with open(local_path, 'wb') as f:
        f.write(stream.read())
      return
    gcs_location = FileSystems.join(gcs_or_local_path, file_name)
    bucket, name = gcs_location[5:].split('/', 1)

    request = storage.StorageObjectsInsertRequest(
        bucket=bucket, name=name)
    logging.info('Starting GCS upload to %s...', gcs_location)
    upload = storage.Upload(stream, mime_type)
    try:
      response = self._storage_client.objects.Insert(request, upload=upload)
    except exceptions.HttpError as e:
      reportable_errors = {
          403: 'access denied',
          404: 'bucket not found',
      }
      if e.status_code in reportable_errors:
        raise IOError(('Could not upload to GCS path %s: %s. Please verify '
                       'that credentials are valid and that you have write '
                       'access to the specified path.') %
                      (gcs_or_local_path, reportable_errors[e.status_code]))
      raise
    logging.info('Completed GCS upload to %s', gcs_location)
    return response
Example #7
 def _verify_copy(self, src, dst, dst_kms_key_name=None):
   self.assertTrue(FileSystems.exists(src), 'src does not exist: %s' % src)
   self.assertTrue(FileSystems.exists(dst), 'dst does not exist: %s' % dst)
   src_checksum = self.gcsio.checksum(src)
   dst_checksum = self.gcsio.checksum(dst)
   self.assertEqual(src_checksum, dst_checksum)
   self.assertEqual(self.gcsio.kms_key(dst), dst_kms_key_name)
Example #8
 def initialize_write(self):
   file_path_prefix = self.file_path_prefix.get()
   file_name_suffix = self.file_name_suffix.get()
   tmp_dir = file_path_prefix + file_name_suffix + time.strftime(
       '-temp-%Y-%m-%d_%H-%M-%S')
   FileSystems.mkdirs(tmp_dir)
   return tmp_dir
Example #9
 def test_get_filesystem(self):
   self.assertTrue(isinstance(FileSystems.get_filesystem('/tmp'),
                              localfilesystem.LocalFileSystem))
   self.assertTrue(isinstance(FileSystems.get_filesystem('c:\\abc\def'),  # pylint: disable=anomalous-backslash-in-string
                              localfilesystem.LocalFileSystem))
   with self.assertRaises(ValueError):
     FileSystems.get_filesystem('error://abc/def')
Example #10
  def test_copy(self):
    path1 = os.path.join(self.tmpdir, 'f1')
    path2 = os.path.join(self.tmpdir, 'f2')
    with open(path1, 'a') as f:
      f.write('Hello')

    FileSystems.copy([path1], [path2])
    self.assertTrue(filecmp.cmp(path1, path2))
Example #11
 def test_rename_error(self):
   path1 = os.path.join(self.tmpdir, 'f1')
   path2 = os.path.join(self.tmpdir, 'f2')
   with self.assertRaises(BeamIOError) as error:
     FileSystems.rename([path1], [path2])
   self.assertTrue(
       error.exception.message.startswith('Rename operation failed'))
   self.assertEqual(error.exception.exception_details.keys(), [(path1, path2)])
Example #12
 def test_windows_path_join(self, *unused_mocks):
   # Test joining of Windows paths.
   localfilesystem.os.path.join.side_effect = _gen_fake_join('\\')
   self.assertEqual(r'C:\tmp\path\to\file',
                    FileSystems.join(r'C:\tmp\path', 'to', 'file'))
   self.assertEqual(r'C:\tmp\path\to\file',
                    FileSystems.join(r'C:\tmp\path', r'to\file'))
   self.assertEqual(r'C:\tmp\path\to\file',
                    FileSystems.join(r'C:\tmp\path\\', 'to', 'file'))
Example #13
  def test_delete(self):
    path1 = os.path.join(self.tmpdir, 'f1')

    with open(path1, 'a') as f:
      f.write('Hello')

    self.assertTrue(FileSystems.exists(path1))
    FileSystems.delete([path1])
    self.assertFalse(FileSystems.exists(path1))
Example #14
  def test_delete_files_succeeds(self):
    path = os.path.join(self.tmpdir, 'f1')

    with open(path, 'a') as f:
      f.write('test')

    assert FileSystems.exists(path)
    utils.delete_files([path])
    assert not FileSystems.exists(path)
Example #15
File: stager.py  Project: onderson/beam
  def _stage_beam_sdk(self, sdk_remote_location, staging_location, temp_dir):
    """Stages a Beam SDK file with the appropriate version.

      Args:
        sdk_remote_location: A URL from which the file can be downloaded or a
          remote file location. The SDK file can be a tarball or a wheel. Set
          to 'pypi' to download and stage a wheel and source SDK from PyPI.
        staging_location: Location where the SDK file should be copied.
        temp_dir: path to temporary location where the file should be
          downloaded.

      Returns:
        A list of SDK files that were staged to the staging location.

      Raises:
        RuntimeError: if staging was not successful.
      """
    if sdk_remote_location == 'pypi':
      sdk_local_file = Stager._download_pypi_sdk_package(temp_dir)
      sdk_sources_staged_name = Stager.\
          _desired_sdk_filename_in_staging_location(sdk_local_file)
      staged_path = FileSystems.join(staging_location, sdk_sources_staged_name)
      logging.info('Staging SDK sources from PyPI to %s', staged_path)
      self.stage_artifact(sdk_local_file, staged_path)
      staged_sdk_files = [sdk_sources_staged_name]
      try:
        # Stage binary distribution of the SDK, for now on a best-effort basis.
        sdk_local_file = Stager._download_pypi_sdk_package(
            temp_dir, fetch_binary=True)
        sdk_binary_staged_name = Stager.\
            _desired_sdk_filename_in_staging_location(sdk_local_file)
        staged_path = FileSystems.join(staging_location, sdk_binary_staged_name)
        logging.info('Staging binary distribution of the SDK from PyPI to %s',
                     staged_path)
        self.stage_artifact(sdk_local_file, staged_path)
        staged_sdk_files.append(sdk_binary_staged_name)
      except RuntimeError as e:
        logging.warn(
            'Failed to download requested binary distribution '
            'of the SDK: %s', repr(e))

      return staged_sdk_files
    elif Stager._is_remote_path(sdk_remote_location):
      local_download_file = os.path.join(temp_dir, 'beam-sdk.tar.gz')
      Stager._download_file(sdk_remote_location, local_download_file)
      staged_name = Stager._desired_sdk_filename_in_staging_location(
          sdk_remote_location)
      staged_path = FileSystems.join(staging_location, staged_name)
      logging.info('Staging Beam SDK from %s to %s', sdk_remote_location,
                   staged_path)
      self.stage_artifact(local_download_file, staged_path)
      return [staged_name]
    else:
      raise RuntimeError(
          'The --sdk_location option was used with an unsupported '
          'type of location: %s' % sdk_remote_location)
Example #16
  def __init__(self, options, proto_pipeline):
    self.options = options
    self.proto_pipeline = proto_pipeline
    self.google_cloud_options = options.view_as(GoogleCloudOptions)
    if not self.google_cloud_options.job_name:
      self.google_cloud_options.job_name = self.default_job_name(
          self.google_cloud_options.job_name)

    required_google_cloud_options = ['project', 'job_name', 'temp_location']
    missing = [
        option for option in required_google_cloud_options
        if not getattr(self.google_cloud_options, option)]
    if missing:
      raise ValueError(
          'Missing required configuration parameters: %s' % missing)

    if not self.google_cloud_options.staging_location:
      logging.info('Defaulting to the temp_location as staging_location: %s',
                   self.google_cloud_options.temp_location)
      (self.google_cloud_options
       .staging_location) = self.google_cloud_options.temp_location

    # Make the staging and temp locations job name and time specific. This is
    # needed to avoid clashes between job submissions using the same staging
    # area or team members using same job names. This method is not entirely
    # foolproof since two job submissions with same name can happen at exactly
    # the same time. However the window is extremely small given that
    # time.time() has at least microseconds granularity. We add the suffix only
    # for GCS staging locations where the potential for such clashes is high.
    if self.google_cloud_options.staging_location.startswith('gs://'):
      path_suffix = '%s.%f' % (self.google_cloud_options.job_name, time.time())
      self.google_cloud_options.staging_location = FileSystems.join(
          self.google_cloud_options.staging_location, path_suffix)
      self.google_cloud_options.temp_location = FileSystems.join(
          self.google_cloud_options.temp_location, path_suffix)

    self.proto = dataflow.Job(name=self.google_cloud_options.job_name)
    if self.options.view_as(StandardOptions).streaming:
      self.proto.type = dataflow.Job.TypeValueValuesEnum.JOB_TYPE_STREAMING
    else:
      self.proto.type = dataflow.Job.TypeValueValuesEnum.JOB_TYPE_BATCH
    if self.google_cloud_options.update:
      self.proto.replaceJobId = self.job_id_for_name(self.proto.name)

    # Labels.
    if self.google_cloud_options.labels:
      self.proto.labels = dataflow.Job.LabelsValue()
      for label in self.google_cloud_options.labels:
        parts = label.split('=', 1)
        key = parts[0]
        value = parts[1] if len(parts) > 1 else ''
        self.proto.labels.additionalProperties.append(
            dataflow.Job.LabelsValue.AdditionalProperty(key=key, value=value))

    self.base64_str_re = re.compile(r'^[A-Za-z0-9+/]*=*$')
    self.coder_str_re = re.compile(r'^([A-Za-z]+\$)([A-Za-z0-9+/]*=*)$')
Example #17
 def test_unix_path_join(self, *unused_mocks):
   # Test joining of Unix paths.
   localfilesystem.os.path.join.side_effect = _gen_fake_join('/')
   self.assertEqual('/tmp/path/to/file',
                    FileSystems.join('/tmp/path', 'to', 'file'))
   self.assertEqual('/tmp/path/to/file',
                    FileSystems.join('/tmp/path', 'to/file'))
   self.assertEqual('/tmp/path/to/file',
                    FileSystems.join('/', 'tmp/path', 'to/file'))
   self.assertEqual('/tmp/path/to/file',
                    FileSystems.join('/tmp/', 'path', 'to/file'))
Example #18
  def pre_finalize(self, init_result, writer_results):
    num_shards = len(list(writer_results))
    dst_glob = self._get_final_name_glob(num_shards)
    dst_glob_files = [file_metadata.path
                      for mr in FileSystems.match([dst_glob])
                      for file_metadata in mr.metadata_list]

    if dst_glob_files:
      logging.warn('Deleting %d existing files in target path matching: %s',
                   len(dst_glob_files), self.shard_name_glob_format)
      FileSystems.delete(dst_glob_files)
Example #19
 def _create_temp_dir(self, file_path_prefix):
   base_path, last_component = FileSystems.split(file_path_prefix)
   if not last_component:
     # Trying to re-split the base_path to check if it's a root.
     new_base_path, _ = FileSystems.split(base_path)
     if base_path == new_base_path:
       raise ValueError('Cannot create a temporary directory for root path '
                        'prefix %s. Please specify a file path prefix with '
                        'at least two components.' % file_path_prefix)
   path_components = [base_path,
                      'beam-temp-' + last_component + '-' + uuid.uuid1().hex]
   return FileSystems.join(*path_components)
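
A small hedged sketch of the path shape the helper above produces, built with the same FileSystems calls; the prefix is a placeholder:

import uuid
from apache_beam.io.filesystems import FileSystems

file_path_prefix = '/data/output/results'  # placeholder prefix
base_path, last_component = FileSystems.split(file_path_prefix)
tmp_dir = FileSystems.join(
    base_path, 'beam-temp-' + last_component + '-' + uuid.uuid1().hex)
print(tmp_dir)  # e.g. /data/output/beam-temp-results-<32 hex chars>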
Example #20
def delete_files(file_paths):
  """A function to clean up files or directories using ``FileSystems``.

  Glob is supported in file path and directories will be deleted recursively.

  Args:
    file_paths: A list of strings contains file paths or directories.
  """
  if len(file_paths) == 0:
    raise RuntimeError('Clean up failed. Invalid file path: %s.' %
                       file_paths)
  FileSystems.delete(file_paths)
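
A hedged usage sketch of delete_files, assuming it is importable from apache_beam.testing.test_utils as in the test snippets above; the file is created locally just for illustration:

import os
import tempfile

from apache_beam.io.filesystems import FileSystems
from apache_beam.testing import test_utils as utils  # assumed home of delete_files

tmpdir = tempfile.mkdtemp()
path = os.path.join(tmpdir, 'data.txt')
with open(path, 'w') as f:
    f.write('payload')

utils.delete_files([path])  # removes the file via FileSystems.delete
assert not FileSystems.exists(path)

# Per the docstring above, globs are accepted and directories are deleted
# recursively; an empty list raises RuntimeError, so guard callers that may
# pass one.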
Example #21
 def pre_finalize(self, init_result, writer_results):
   writer_results = sorted(writer_results)
   num_shards = len(writer_results)
   existing_files = []
   for shard_num in range(len(writer_results)):
     final_name = self._get_final_name(shard_num, num_shards)
     if FileSystems.exists(final_name):
       existing_files.append(final_name)
   if existing_files:
     logging.info('Deleting existing files in target path: %d',
                  len(existing_files))
     FileSystems.delete(existing_files)
Example #22
 def _verify_copy(self, src, dst, dst_kms_key_name=None):
   self.assertTrue(FileSystems.exists(src), 'src does not exist: %s' % src)
   self.assertTrue(FileSystems.exists(dst), 'dst does not exist: %s' % dst)
   src_checksum = self.gcsio.checksum(src)
   dst_checksum = self.gcsio.checksum(dst)
   self.assertEqual(src_checksum, dst_checksum)
   actual_dst_kms_key = self.gcsio.kms_key(dst)
   if actual_dst_kms_key is None:
     self.assertEqual(actual_dst_kms_key, dst_kms_key_name)
   else:
     self.assertTrue(actual_dst_kms_key.startswith(dst_kms_key_name),
                     "got: %s, wanted startswith: %s" % (actual_dst_kms_key,
                                                         dst_kms_key_name))
Example #23
  def _read_with_retry(self):
    """Read path with retry if I/O failed"""
    read_lines = []
    match_result = FileSystems.match([self.file_path])[0]
    matched_path = [f.path for f in match_result.metadata_list]
    if not matched_path:
      raise IOError('No such file or directory: %s' % self.file_path)

    logging.info('Find %d files in %s: \n%s',
                 len(matched_path), self.file_path, '\n'.join(matched_path))
    for path in matched_path:
      with FileSystems.open(path, 'r') as f:
        for line in f:
          read_lines.append(line)
    return read_lines
Example #24
  def open(self, temp_path):
    """Opens ``temp_path``, returning an opaque file handle object.

    The returned file handle is passed to ``write_[encoded_]record`` and
    ``close``.
    """
    return FileSystems.create(temp_path, self.mime_type, self.compression_type)
Example #25
  def test_match_file_empty(self):
    path = os.path.join(self.tmpdir, 'f2')  # Does not exist

    # Match files in the temp directory
    result = FileSystems.match([path])[0]
    files = [f.path for f in result.metadata_list]
    self.assertEqual(files, [])
Example #26
  def _get_concat_source(self):
    if self._concat_source is None:
      pattern = self._pattern.get()

      single_file_sources = []
      match_result = FileSystems.match([pattern])[0]
      files_metadata = match_result.metadata_list

      # We create a reference for FileBasedSource that will be serialized along
      # with each _SingleFileSource. To prevent this FileBasedSource from having
      # a reference to ConcatSource (resulting in quadratic space complexity)
      # we clone it here.
      file_based_source_ref = pickler.loads(pickler.dumps(self))

      for file_metadata in files_metadata:
        file_name = file_metadata.path
        file_size = file_metadata.size_in_bytes
        if file_size == 0:
          continue  # Ignoring empty file.

        # We determine splittability of this specific file.
        splittable = (
            self.splittable and
            _determine_splittability_from_compression_type(
                file_name, self._compression_type))

        single_file_source = _SingleFileSource(
            file_based_source_ref, file_name,
            0,
            file_size,
            min_bundle_size=self._min_bundle_size,
            splittable=splittable)
        single_file_sources.append(single_file_source)
      self._concat_source = concat_source.ConcatSource(single_file_sources)
    return self._concat_source
Example #27
  def test_match_file(self):
    path = os.path.join(self.tmpdir, 'f1')
    open(path, 'a').close()

    # Match files in the temp directory
    result = FileSystems.match([path])[0]
    files = [f.path for f in result.metadata_list]
    self.assertEqual(files, [path])
Example #28
  def test_delete_files_fails_with_io_error(self, mocked_delete):
    f = tempfile.NamedTemporaryFile(dir=self.tmpdir, delete=False)
    assert FileSystems.exists(f.name)

    with self.assertRaises(BeamIOError):
      utils.delete_files([f.name])
    self.assertTrue(mocked_delete.called)
    self.assertEqual(mocked_delete.call_count, 4)
Example #29
 def _rename_batch(batch):
   """_rename_batch executes batch rename operations."""
   source_files, destination_files = batch
   exceptions = []
   try:
     FileSystems.rename(source_files, destination_files)
     return exceptions
   except BeamIOError as exp:
     if exp.exception_details is None:
       raise
     for (src, dst), exception in exp.exception_details.iteritems():
       if exception:
         logging.error(('Exception in _rename_batch. src: %s, '
                        'dst: %s, err: %s'), src, dst, exception)
         exceptions.append(exception)
       else:
         logging.debug('Rename successful: %s -> %s', src, dst)
     return exceptions
Example #30
 def file_copy(from_path, to_path):
   if not from_path.endswith(names.PICKLED_MAIN_SESSION_FILE):
     self.assertEqual(expected_from_path, from_path)
     self.assertEqual(FileSystems.join(expected_to_dir,
                                       names.DATAFLOW_SDK_TARBALL_FILE),
                      to_path)
   if from_path.startswith('gs://') or to_path.startswith('gs://'):
     logging.info('Faking file_copy(%s, %s)', from_path, to_path)
   else:
     shutil.copyfile(from_path, to_path)
Example #31
    def test_match_all_two_directories(self):
        files = []
        directories = []

        for _ in range(2):
            # TODO: What about this having to append the ending slash?
            d = '%s%s' % (self._new_tempdir(), os.sep)
            directories.append(d)

            files.append(self._create_temp_file(dir=d))
            files.append(self._create_temp_file(dir=d))

        with TestPipeline() as p:
            files_pc = (p
                        | beam.Create(
                            [FileSystems.join(d, '*') for d in directories])
                        | fileio.MatchAll()
                        | beam.Map(lambda x: x.path))

            assert_that(files_pc, equal_to(files))
Example #32
    def test_match_files_one_directory_failure(self):
        directories = [
            '%s%s' % (self._new_tempdir(), os.sep),
            '%s%s' % (self._new_tempdir(), os.sep)
        ]

        files = list()
        files.append(self._create_temp_file(dir=directories[0]))
        files.append(self._create_temp_file(dir=directories[0]))

        with self.assertRaises(beam.io.filesystem.BeamIOError):
            with TestPipeline() as p:
                files_pc = (
                    p
                    | beam.Create(
                        [FileSystems.join(d, '*') for d in directories])
                    | fileio.MatchAll(fileio.EmptyMatchTreatment.DISALLOW)
                    | beam.Map(lambda x: x.path))

                assert_that(files_pc, equal_to(files))
Example #33
def get_metadata_header_lines(input_file):
    # type: (str) -> List[str]
    """Returns header lines from the given VCF file ``input_file``.

  Only returns lines starting with ## and not #.

  Args:
    input_file: A string specifying the path to a VCF file.
      It can be local or remote (e.g. on GCS).
  Returns:
    A list containing header lines of ``input_file``.
  Raises:
    ValueError: If ``input_file`` does not exist.
  """
    if not FileSystems.exists(input_file):
        raise ValueError('{} does not exist'.format(input_file))
    return [
        line for line in _header_line_generator(input_file)
        if line.startswith('##')
    ]
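
A hedged usage sketch of the function above; the VCF path is a placeholder, and gs:// paths work the same because the reads go through FileSystems:

# Only '##'-prefixed metadata lines are returned; the '#CHROM ...' header
# line is excluded, and a missing file fails fast with ValueError.
meta_lines = get_metadata_header_lines('/tmp/sample.vcf')  # placeholder path
for line in meta_lines:
    assert line.startswith('##')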
Example #34
    def test_write_to_single_file_batch(self):

        dir = self._new_tempdir()

        with TestPipeline() as p:
            _ = (p
                 | beam.Create(WriteFilesTest.SIMPLE_COLLECTION)
                 | "Serialize" >> beam.Map(json.dumps)
                 | beam.io.fileio.WriteToFiles(path=dir))

        with TestPipeline() as p:
            result = (p
                      | fileio.MatchFiles(FileSystems.join(dir, '*'))
                      | fileio.ReadMatches()
                      | beam.FlatMap(
                          lambda f: f.read_utf8().strip().split('\n'))
                      | beam.Map(json.loads))

            assert_that(result,
                        equal_to([row for row in self.SIMPLE_COLLECTION]))
Example #35
File: apiclient.py  Project: zhoubh/beam
    def create_job_description(self, job):
        """Creates a job described by the workflow proto."""

        # Stage the pipeline for the runner harness
        self.stage_file(job.google_cloud_options.staging_location,
                        names.STAGED_PIPELINE_FILENAME,
                        StringIO(job.proto_pipeline.SerializeToString()))

        # Stage other resources for the SDK harness
        resources = dependency.stage_job_resources(
            job.options, file_copy=self._gcs_file_copy)

        job.proto.environment = Environment(
            pipeline_url=FileSystems.join(
                job.google_cloud_options.staging_location,
                names.STAGED_PIPELINE_FILENAME),
            packages=resources,
            options=job.options,
            environment_version=self.environment_version).proto
        logging.debug('JOB: %s', job)
Example #36
 def _verify_data(self, pcol, init_size, data_size):
     read = pcol | 'read' >> ReadAllFromParquet()
     v1 = (
         read
         | 'get_number' >> Map(lambda x: x['number'])
         | 'sum_globally' >> CombineGlobally(sum)
         | 'validate_number' >> FlatMap(
             lambda x: TestParquetIT._sum_verifier(init_size, data_size, x)))
     v2 = (
         read
         | 'make_pair' >> Map(lambda x: (x['name'], x['number']))
         | 'count_per_key' >> Count.PerKey()
         | 'validate_name' >> FlatMap(
             lambda x: TestParquetIT._count_verifier(init_size, data_size, x)))
     _ = ((v1, v2, pcol)
          | 'flatten' >> Flatten()
          | 'reshuffle' >> Reshuffle()
          | 'cleanup' >> Map(lambda x: FileSystems.delete([x])))
Example #37
    def _validate_config(self, config_file_path):
        # type: (str) -> None
        with FileSystems.open(config_file_path, 'r') as f:
            try:
                partition_configs = yaml.load(f)
            except yaml.YAMLError as e:
                raise ValueError('Invalid yaml file: %s' % str(e))
        if len(partition_configs) > _MAX_NUM_PARTITIONS:
            raise ValueError(
                'There can be at most {} partitions but given config file '
                'contains {}'.format(_MAX_NUM_PARTITIONS,
                                     len(partition_configs)))
        if not partition_configs:
            raise ValueError(
                'There must be at least one partition in config file.')

        existing_partition_names = set()
        for partition_config in partition_configs:
            partition = partition_config.get('partition', None)
            if partition is None:
                raise ValueError(
                    'Wrong yaml file format, partition field missing.')
            regions = partition.get('regions', None)
            if regions is None:
                raise ValueError(
                    'Each partition must have at least one region.')
            if len(regions) > _MAX_NUM_REGIONS:
                raise ValueError(
                    'At most {} regions per partition, this partition '
                    'contains {}'.format(_MAX_NUM_REGIONS, len(regions)))
            if not partition.get('partition_name', None):
                raise ValueError(
                    'Each partition must have partition_name field.')
            partition_name = partition.get('partition_name').strip()
            if not partition_name:
                raise ValueError('Partition name can not be empty string.')
            if partition_name in existing_partition_names:
                raise ValueError('Partition names must be unique, '
                                 '{} is duplicated'.format(partition_name))
            existing_partition_names.add(partition_name)
        return partition_configs
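
A minimal hedged sketch of a config document that satisfies the checks above; the field names come from the validation code, while _MAX_NUM_PARTITIONS and _MAX_NUM_REGIONS are module constants not shown here:

import yaml

config_text = """
- partition:
    partition_name: chr1
    regions:
      - chr1
- partition:
    partition_name: rest
    regions:
      - chr2
      - chr3
"""

# Each entry has a 'partition' with a non-empty, unique 'partition_name' and
# at least one region, so a file with this content would pass _validate_config.
partition_configs = yaml.safe_load(config_text)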
Example #38
    def upload_to_bundle_store(self, bundle: Bundle, source: Source, git: bool,
                               unpack: bool):
        """Uploads the given source to the bundle store.
        Given arguments are the same as UploadManager.upload_to_bundle_store().
        Used when uploading from rest server."""
        try:
            # bundle_path = self._bundle_store.get_bundle_location(bundle.uuid)
            is_url, is_fileobj, filename = self._interpret_source(source)
            if is_url:
                assert isinstance(source, str)
                if git:
                    bundle_path = self._update_and_get_bundle_location(
                        bundle, is_directory=True)
                    self.write_git_repo(source, bundle_path)
                else:
                    # If downloading from a URL, convert the source to a file object.
                    is_fileobj = True
                    source = (filename, urlopen_with_retry(source))
            if is_fileobj:
                source_filename, source_fileobj = cast(Tuple[str, IO[bytes]],
                                                       source)
                source_ext = zip_util.get_archive_ext(source_filename)
                if unpack and zip_util.path_is_archive(filename):
                    bundle_path = self._update_and_get_bundle_location(
                        bundle, is_directory=source_ext in ARCHIVE_EXTS_DIR)
                    self.write_fileobj(source_ext,
                                       source_fileobj,
                                       bundle_path,
                                       unpack_archive=True)
                else:
                    bundle_path = self._update_and_get_bundle_location(
                        bundle, is_directory=False)
                    self.write_fileobj(source_ext,
                                       source_fileobj,
                                       bundle_path,
                                       unpack_archive=False)

        except UsageError:
            if FileSystems.exists(bundle_path):
                path_util.remove(bundle_path)
            raise
Example #39
def get_vcf_headers(input_file):
    """Returns VCF headers (FORMAT and INFO) from ``input_file``.

  Args:
    input_file (str): A string specifying the path to the representative VCF
    file, i.e., the VCF file that contains a header representative of all VCF
    files matching the input_pattern of the job. It can be local or remote (e.g.
    on GCS).
  Returns:
    ``HeaderFields`` specifying header info.
  Raises:
    ValueError: If ``input_file`` is not a valid VCF file (e.g. bad format,
    empty, non-existent).
  """
    if not FileSystems.exists(input_file):
        raise ValueError('VCF header does not exist')
    try:
        vcf_reader = vcf.Reader(fsock=_line_generator(input_file))
    except (SyntaxError, StopIteration) as e:
        raise ValueError('Invalid VCF header: %s' % str(e))
    return HeaderFields(vcf_reader.infos, vcf_reader.formats)
Example #40
    def process(self, element: Union[str, FileMetadata], *args,
                **kwargs) -> Tuple[FileMetadata, OffsetRange]:
        if isinstance(element, FileMetadata):
            metadata_list = [element]
        else:
            match_results = FileSystems.match([element])
            metadata_list = match_results[0].metadata_list
        for metadata in metadata_list:
            splittable = (self._splittable
                          and _determine_splittability_from_compression_type(
                              metadata.path, self._compression_type))

            if splittable:
                for split in OffsetRange(0, metadata.size_in_bytes).split(
                        self._desired_bundle_size, self._min_bundle_size):
                    yield (metadata, split)
            else:
                yield (metadata,
                       OffsetRange(
                           0,
                           range_trackers.OffsetRangeTracker.OFFSET_INFINITY))
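
The splittable branch above delegates chunking to OffsetRange.split; a small hedged sketch of that call on its own (the sizes are illustrative):

from apache_beam.io.restriction_trackers import OffsetRange

# A 100-byte file with a desired bundle size of 40 bytes and a minimum of
# 1 byte is covered by contiguous, non-overlapping ranges over [0, 100).
splits = list(OffsetRange(0, 100).split(40, 1))
print(splits)  # e.g. ranges equivalent to (0, 40), (40, 80), (80, 100)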
Example #41
File: filebasedsink.py  Project: nielm/beam
  def _check_state_for_finalize_write(self, writer_results, num_shards):
    """Checks writer output files' states.

    Returns:
      src_files, dst_files: Lists of files to rename. For each i, finalize_write
        should rename(src_files[i], dst_files[i]).
      delete_files: Src files to delete. These could be leftovers from an
        incomplete (non-atomic) rename operation.
      num_skipped: Tally of writer results files already renamed, such as from
        a previous run of finalize_write().
    """
    if not writer_results:
      return [], [], [], 0

    src_glob = FileSystems.join(FileSystems.split(writer_results[0])[0], '*')
    dst_glob = self._get_final_name_glob(num_shards)
    src_glob_files = set(
        file_metadata.path for mr in FileSystems.match([src_glob])
        for file_metadata in mr.metadata_list)
    dst_glob_files = set(
        file_metadata.path for mr in FileSystems.match([dst_glob])
        for file_metadata in mr.metadata_list)

    src_files = []
    dst_files = []
    delete_files = []
    num_skipped = 0
    for shard_num, src in enumerate(writer_results):
      final_name = self._get_final_name(shard_num, num_shards)
      dst = final_name
      src_exists = src in src_glob_files
      dst_exists = dst in dst_glob_files
      if not src_exists and not dst_exists:
        raise BeamIOError(
            'src and dst files do not exist. src: %s, dst: %s' % (src, dst))
      if not src_exists and dst_exists:
        _LOGGER.debug('src: %s -> dst: %s already renamed, skipping', src, dst)
        num_skipped += 1
        continue
      if (src_exists and dst_exists and
          FileSystems.checksum(src) == FileSystems.checksum(dst)):
        _LOGGER.debug('src: %s == dst: %s, deleting src', src, dst)
        delete_files.append(src)
        continue

      src_files.append(src)
      dst_files.append(dst)
    return src_files, dst_files, delete_files, num_skipped
Example #42
    def _export_files(self, bq):
        """Runs a BigQuery export job.

        Returns:
          bigquery.TableSchema instance, a list of FileMetadata instances
        """
        job_id = uuid.uuid4().hex
        gcs_location = self.get_destination_uri()
        job_ref = bq.perform_extract_job([gcs_location],
                                         job_id,
                                         self.table_reference,
                                         bigquery_tools.FileFormat.JSON,
                                         include_header=False)
        bq.wait_for_bq_job(job_ref)
        metadata_list = FileSystems.match([gcs_location])[0].metadata_list

        table = bq.get_table(self.table_reference.projectId,
                             self.table_reference.datasetId,
                             self.table_reference.tableId)

        return table.schema, metadata_list
Example #43
    def test_infer_compressed_file(self):
        dir = '%s%s' % (self._new_tempdir(), os.sep)

        file_contents = b'compressed_contents!'
        import gzip
        with gzip.GzipFile(os.path.join(dir, 'compressed.gz'), 'w') as f:
            f.write(file_contents)

        file_contents2 = b'compressed_contents_bz2!'
        import bz2
        with bz2.BZ2File(os.path.join(dir, 'compressed2.bz2'), 'w') as f:
            f.write(file_contents2)

        with TestPipeline() as p:
            content_pc = (p
                          | beam.Create([FileSystems.join(dir, '*')])
                          | fileio.MatchAll()
                          | fileio.ReadMatches()
                          | beam.Map(lambda rf: rf.open().readline()))

            assert_that(content_pc, equal_to([file_contents, file_contents2]))
Example #44
def parse_gcp_path(path):
    """ parse a custom gcp path and determine which apache beam source/sink type that path refers to.

    returns a tuple of (service, path)
    where service indicates the apache beam source/sink type to use and path is the provided path with
    the scheme:// stripped off the front


    Examples

    Bigquery table          bq://project:dataset.table          (table, project:dataset.table)
    Bigquery query          "query://select * from [Table]"     (query, select * from [Table])
    GCS file                gs://bucket/path/file               (text, gs://bucket/path/file )
    Local file              file://path/file                    (text, file://path/file)
    Local file (absolute)   /path/file                          (text, /path/file )
    Local file (relative)   ./path/file                         (text, ./path/file)
    Bigquery query          "select * from [Table]"             (query, select * from [Table])
    """

    scheme = FileSystems.get_scheme(path)

    if scheme == 'query':
        # path contains a sql query
        # strip off the scheme and just return the rest
        return 'query', path[8:]
    elif scheme == 'bq':
        # path is a reference to a big query table
        # strip off the scheme and just return the table id in path
        return 'table', path[5:]
    elif scheme == 'gs':
        # path is a Google Cloud Storage reference
        return 'file', path
    elif scheme is None:
        # could be a local file or a sql query
        if path[0] in ('.', '/'):
            return 'file', path
        else:
            return 'query', path
    else:
        raise ValueError("Unknown scheme %s" % scheme)
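
A hedged sketch of the dispatch above, reusing the docstring's own sample paths; the expected tuples follow the code, which returns 'file' where the docstring table says 'text':

assert parse_gcp_path('bq://project:dataset.table') == (
    'table', 'project:dataset.table')
assert parse_gcp_path('gs://bucket/path/file') == ('file', 'gs://bucket/path/file')
assert parse_gcp_path('./path/file') == ('file', './path/file')
assert parse_gcp_path('select * from [Table]') == ('query', 'select * from [Table]')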
Example #45
    def _get_concat_source(self):
        if self._concat_source is None:
            pattern = self._pattern.get()

            single_file_sources = []
            match_result = FileSystems.match([pattern])[0]
            files_metadata = match_result.metadata_list

            # We create a reference for FileBasedSource that will be serialized along
            # with each _SingleFileSource. To prevent this FileBasedSource from having
            # a reference to ConcatSource (resulting in quadratic space complexity)
            # we clone it here.
            file_based_source_ref = pickler.loads(pickler.dumps(self))

            for file_metadata in files_metadata:
                file_name = file_metadata.path
                file_size = file_metadata.size_in_bytes
                if file_size == 0:
                    continue  # Ignoring empty file.

                # We determine splittability of this specific file.
                splittable = self.splittable
                if (splittable
                        and self._compression_type == CompressionTypes.AUTO):
                    compression_type = CompressionTypes.detect_compression_type(
                        file_name)
                    if compression_type != CompressionTypes.UNCOMPRESSED:
                        splittable = False

                single_file_source = _SingleFileSource(
                    file_based_source_ref,
                    file_name,
                    0,
                    file_size,
                    min_bundle_size=self._min_bundle_size,
                    splittable=splittable)
                single_file_sources.append(single_file_source)
            self._concat_source = concat_source.ConcatSource(
                single_file_sources)
        return self._concat_source
Example #46
class TestPrepare(unittest.TestCase):

    test_data_dir = FileSystems.join(
        os.path.dirname(os.path.realpath(__file__)), 'testdata')

    def test_valid(self):
        file_pattern = FileSystems.join(self.test_data_dir, 'detail.json')
        expected_valid = [
            (1, {
                'error': [],
                'first_name': 'Bart',
                'last_name': 'Bruck',
                'email': '*****@*****.**',
                'id': 1
            }),
            (3, {
                'error': [u"email 'wtuppeny2bandcamp.com' is invalid"],
                'first_name': 'Winny',
                'last_name': 'Tuppeny',
                'email': None,
                'id': 3
            }),
        ]
        expected_broken = [{
            'error': 'id is missing',
            'element':
                '{"first_name":"Alfonso","last_name":"Koenen","email":"*****@*****.**"}'
        }]
        # Make use of the TestPipeline from the Beam testing util.
        with TestPipeline() as p:
            actual_valid, actual_broken = (p | Prepare(file_pattern))
            # The labels are required because otherwise the assert_that Transform does not have a stable unique label.
            assert_that(actual_valid, equal_to(expected_valid), label='valid')
            assert_that(actual_broken,
                        equal_to(expected_broken),
                        label='broken')
Example #47
def run(argv=None):
    """Run the beam pipeline."""
    args, pipeline_args = _parse_args(argv)

    pipeline_options = PipelineOptions(pipeline_args)
    pipeline_options.view_as(SetupOptions).save_main_session = True
    p = beam.Pipeline(options=pipeline_options)

    sentence_files_match = FileSystems.match([args.sentence_files])[0]
    sentence_files = [
        file_metadata.path
        for file_metadata in sentence_files_match.metadata_list
    ]
    logging.info("Reading %i files from %s.", len(sentence_files),
                 args.sentence_files)
    assert len(sentence_files) > 0
    sentence_files = p | beam.Create(sentence_files)
    serialized_examples = sentence_files | "create examples" >> beam.FlatMap(
        partial(_create_examples_from_file,
                min_length=args.min_length,
                max_length=args.max_length,
                num_extra_contexts=args.num_extra_contexts))

    serialized_examples = _shuffle_examples(serialized_examples)

    serialized_examples |= "split train and test" >> beam.ParDo(
        _TrainTestSplitFn(args.train_split)).with_outputs(
            _TrainTestSplitFn.TEST_TAG, _TrainTestSplitFn.TRAIN_TAG)

    (serialized_examples[_TrainTestSplitFn.TRAIN_TAG]
     | "write train" >> WriteToTFRecord(os.path.join(args.output_dir, "train"),
                                        file_name_suffix=".tfrecords",
                                        num_shards=args.num_shards_train))
    (serialized_examples[_TrainTestSplitFn.TEST_TAG]
     | "write test" >> WriteToTFRecord(os.path.join(args.output_dir, "test"),
                                       file_name_suffix=".tfrecords",
                                       num_shards=args.num_shards_test))

    result = p.run()
    result.wait_until_finish()
Example #48
 def check_file_equals_string(self, file_subpath: str,
                              expected_contents: str):
     with FileSystems.open(
             self.bundle_location,
             compression_type=CompressionTypes.UNCOMPRESSED) as f:
         if not file_subpath:
             # Should be a .gz file
             self.assertTrue(self.bundle_location.endswith("contents.gz"))
             self.assertEqual(
                 gzip.decompress(f.read()).decode(), expected_contents)
         else:
             # Should be a .tar.gz file
             self.assertTrue(
                 self.bundle_location.endswith("contents.tar.gz"))
             with tarfile.open(fileobj=f, mode='r:gz') as tf:
                 # Prepend "./" to the file subpath so that it corresponds with a file in the archive.
                 self.assertEqual(
                     cast(IO[bytes],
                          tf.extractfile("./" +
                                         file_subpath)).read().decode(),
                     expected_contents,
                 )
Example #49
    def process(self,
                element,
                timestamp=beam.DoFn.TimestampParam,
                window=beam.DoFn.WindowParam,
                pane_info=beam.DoFn.PaneInfoParam):

        # Logging to audit triggering of the side input refresh process. The
        # statement is logged only when the Pub/Sub notification triggers the
        # refresh (i.e. normally once every x hours).
        if isinstance(window, beam.transforms.window.GlobalWindow):
            logging.info(
                f"(Re)loading side input data from basepath {element.decode()} for global window: {timestamp} - {window}"
            )
        else:
            logging.info(
                f"(Re)loading side input data from basepath {element.decode()} for window: {util.get_formatted_time(window.start)} - {util.get_formatted_time(window.end)}"
            )

        for sideinput_type in self.sideinput_types:
            yield beam.pvalue.TaggedOutput(
                sideinput_type,
                FileSystems.join(element.decode(), sideinput_type,
                                 self.file_prefix))
Example #50
    def test_torch_run_inference_imagenet_mobilenetv2(self):
        test_pipeline = TestPipeline(is_integration_test=True)
        # Text file containing absolute paths to the ImageNet validation data on GCS.
        file_of_image_names = 'gs://apache-beam-ml/testing/inputs/it_mobilenetv2_imagenet_validation_inputs.txt'  # disable: line-too-long
        output_file_dir = 'gs://apache-beam-ml/testing/predictions'
        output_file = '/'.join(
            [output_file_dir, str(uuid.uuid4()), 'result.txt'])

        model_state_dict_path = 'gs://apache-beam-ml/models/imagenet_classification_mobilenet_v2.pt'
        extra_opts = {
            'input': file_of_image_names,
            'output': output_file,
            'model_state_dict_path': model_state_dict_path,
        }
        pytorch_image_classification.run(
            test_pipeline.get_full_options_as_args(**extra_opts),
            save_main_session=False)

        self.assertEqual(FileSystems().exists(output_file), True)
        predictions = process_outputs(filepath=output_file)

        for prediction in predictions:
            filename, prediction = prediction.split(',')
            self.assertEqual(_EXPECTED_OUTPUTS[filename], prediction)
Example #51
def _create_examples_from_file(file_name, min_length, max_length,
                               num_extra_contexts):
    _, file_id = path.split(file_name)
    previous_lines = []
    for line in FileSystems.open(file_name, "application/octet-stream"):
        line = _preprocess_line(line)
        if not line:
            continue

        should_skip = _should_skip(line,
                                   min_length=min_length,
                                   max_length=max_length)

        if previous_lines:
            should_skip |= _should_skip(previous_lines[-1],
                                        min_length=min_length,
                                        max_length=max_length)

            if not should_skip:
                yield create_example(previous_lines, line, file_id)

        previous_lines.append(line)
        if len(previous_lines) > num_extra_contexts + 1:
            del previous_lines[0]
Example #52
 def test_mkdirs(self):
     path = os.path.join(self.tmpdir, 't1/t2')
     FileSystems.mkdirs(path)
     self.assertTrue(os.path.isdir(path))
Example #53
    def initialize_write(self):
        file_path_prefix = self.file_path_prefix.get()

        tmp_dir = self._create_temp_dir(file_path_prefix)
        FileSystems.mkdirs(tmp_dir)
        return tmp_dir
Example #54
    def test_streaming_complex_timing(self):
        # Use state on the TestCase class, since other references would be pickled
        # into a closure and not have the desired side effects.
        #
        # TODO(BEAM-5295): Use assert_that after it works for the cases here in
        # streaming mode.
        WriteFilesTest.all_records = []

        dir = '%s%s' % (self._new_tempdir(), os.sep)

        # Setting up the input (TestStream)
        ts = TestStream().advance_watermark_to(0)
        for elm in WriteFilesTest.LARGER_COLLECTION:
            timestamp = int(elm)

            ts.add_elements([('key', '%s' % elm)])
            if timestamp % 5 == 0 and timestamp != 0:
                # TODO(BEAM-3759): Add many firings per window after getting PaneInfo.
                ts.advance_processing_time(5)
                ts.advance_watermark_to(timestamp)
        ts.advance_watermark_to_infinity()

        def no_colon_file_naming(*args):
            file_name = fileio.destination_prefix_naming()(*args)
            return file_name.replace(':', '_')

        # The pipeline that we are testing
        options = PipelineOptions()
        options.view_as(StandardOptions).streaming = True
        with TestPipeline(options=options) as p:
            res = (p
                   | ts
                   | beam.WindowInto(
                       FixedWindows(10),
                       trigger=trigger.AfterWatermark(),
                       accumulation_mode=trigger.AccumulationMode.DISCARDING)
                   | beam.GroupByKey()
                   | beam.FlatMap(lambda x: x[1]))
            # Triggering after 5 processing-time seconds, and on the watermark. Also
            # discarding old elements.

            _ = (res
                 | beam.io.fileio.WriteToFiles(
                     path=dir,
                     file_naming=no_colon_file_naming,
                     max_writers_per_bundle=0)
                 | beam.Map(lambda fr: FileSystems.join(dir, fr.file_name))
                 | beam.ParDo(self.record_dofn()))

        # Verification pipeline
        with TestPipeline() as p:
            files = (p | beam.io.fileio.MatchFiles(FileSystems.join(dir, '*')))

            file_names = (files | beam.Map(lambda fm: fm.path))

            file_contents = (
                files
                | beam.io.fileio.ReadMatches()
                | beam.Map(lambda rf: (rf.metadata.path,
                                       rf.read_utf8().strip().split('\n'))))

            content = (file_contents
                       | beam.FlatMap(lambda fc: [ln.strip() for ln in fc[1]]))

            assert_that(file_names,
                        equal_to(WriteFilesTest.all_records),
                        label='AssertFilesMatch')
            assert_that(content,
                        matches_all(WriteFilesTest.LARGER_COLLECTION),
                        label='AssertContentsMatch')
Example #55
 def _InferArrowSchema(self):
     match_result = FileSystems.match([self._file_pattern])[0]
     files_metadata = match_result.metadata_list[0]
     with FileSystems.open(files_metadata.path) as f:
         return pq.read_schema(f)
Example #56
 def tearDown(self):
     FileSystems.delete([self.outdir + '/'])
Example #57
 def read_csv(path):
     with FileSystems.open(path) as fp:
         return pd.read_csv(fp)
Example #58
    def __init__(self, options, proto_pipeline):
        self.options = options
        self.proto_pipeline = proto_pipeline
        self.google_cloud_options = options.view_as(GoogleCloudOptions)
        if not self.google_cloud_options.job_name:
            self.google_cloud_options.job_name = self.default_job_name(
                self.google_cloud_options.job_name)

        required_google_cloud_options = [
            'project', 'job_name', 'temp_location'
        ]
        missing = [
            option for option in required_google_cloud_options
            if not getattr(self.google_cloud_options, option)
        ]
        if missing:
            raise ValueError('Missing required configuration parameters: %s' %
                             missing)

        if not self.google_cloud_options.staging_location:
            logging.info(
                'Defaulting to the temp_location as staging_location: %s',
                self.google_cloud_options.temp_location)
            (self.google_cloud_options.staging_location
             ) = self.google_cloud_options.temp_location

        # Make the staging and temp locations job name and time specific. This is
        # needed to avoid clashes between job submissions using the same staging
        # area or team members using same job names. This method is not entirely
        # foolproof since two job submissions with same name can happen at exactly
        # the same time. However the window is extremely small given that
        # time.time() has at least microseconds granularity. We add the suffix only
        # for GCS staging locations where the potential for such clashes is high.
        if self.google_cloud_options.staging_location.startswith('gs://'):
            path_suffix = '%s.%f' % (self.google_cloud_options.job_name,
                                     time.time())
            self.google_cloud_options.staging_location = FileSystems.join(
                self.google_cloud_options.staging_location, path_suffix)
            self.google_cloud_options.temp_location = FileSystems.join(
                self.google_cloud_options.temp_location, path_suffix)

        self.proto = dataflow.Job(name=self.google_cloud_options.job_name)
        if self.options.view_as(StandardOptions).streaming:
            self.proto.type = dataflow.Job.TypeValueValuesEnum.JOB_TYPE_STREAMING
        else:
            self.proto.type = dataflow.Job.TypeValueValuesEnum.JOB_TYPE_BATCH
        if self.google_cloud_options.update:
            self.proto.replaceJobId = self.job_id_for_name(self.proto.name)

        # Labels.
        if self.google_cloud_options.labels:
            self.proto.labels = dataflow.Job.LabelsValue()
            for label in self.google_cloud_options.labels:
                parts = label.split('=', 1)
                key = parts[0]
                value = parts[1] if len(parts) > 1 else ''
                self.proto.labels.additionalProperties.append(
                    dataflow.Job.LabelsValue.AdditionalProperty(key=key,
                                                                value=value))

        self.base64_str_re = re.compile(r'^[A-Za-z0-9+/]*=*$')
        self.coder_str_re = re.compile(r'^([A-Za-z]+\$)([A-Za-z0-9+/]*=*)$')
Example #59
    def test_streaming_different_file_types(self):
        dir = self._new_tempdir()
        input = iter(WriteFilesTest.SIMPLE_COLLECTION)
        ts = (TestStream().advance_watermark_to(0).add_elements(
            [next(input), next(input)]).advance_watermark_to(10).add_elements(
                [next(input),
                 next(input)]).advance_watermark_to(20).add_elements([
                     next(input), next(input)
                 ]).advance_watermark_to(30).add_elements([
                     next(input), next(input)
                 ]).advance_watermark_to(40).advance_watermark_to_infinity())

        def no_colon_file_naming(*args):
            file_name = fileio.destination_prefix_naming()(*args)
            return file_name.replace(':', '_')

        with TestPipeline() as p:
            _ = (p
                 | ts
                 | beam.WindowInto(FixedWindows(10))
                 | beam.io.fileio.WriteToFiles(
                     path=dir,
                     destination=lambda record: record['foundation'],
                     sink=lambda dest:
                     (WriteFilesTest.CsvSink(WriteFilesTest.CSV_HEADERS)
                      if dest == 'apache' else WriteFilesTest.JsonSink()),
                     file_naming=no_colon_file_naming,
                     max_writers_per_bundle=0,
                 ))

        with TestPipeline() as p:
            cncf_files = (p
                          | fileio.MatchFiles(FileSystems.join(dir, 'cncf*'))
                          | "CncfFileNames" >> beam.Map(lambda fm: fm.path))

            apache_files = (p
                            | "MatchApache" >> fileio.MatchFiles(
                                FileSystems.join(dir, 'apache*'))
                            | "ApacheFileNames" >> beam.Map(lambda fm: fm.path))

            assert_that(
                cncf_files,
                matches_all([
                    stringmatches.matches_regexp(
                        '.*cncf-1970-01-01T00_00_00-1970-01-01T00_00_10--.*'),
                    stringmatches.matches_regexp(
                        '.*cncf-1970-01-01T00_00_10-1970-01-01T00_00_20--.*'),
                    stringmatches.matches_regexp(
                        '.*cncf-1970-01-01T00_00_20-1970-01-01T00_00_30--.*'),
                    stringmatches.matches_regexp(
                        '.*cncf-1970-01-01T00_00_30-1970-01-01T00_00_40--.*')
                ]),
                label='verifyCNCFFiles')

            assert_that(
                apache_files,
                matches_all([
                    stringmatches.matches_regexp(
                        '.*apache-1970-01-01T00_00_00-1970-01-01T00_00_10--.*'
                    ),
                    stringmatches.matches_regexp(
                        '.*apache-1970-01-01T00_00_10-1970-01-01T00_00_20--.*'
                    ),
                    stringmatches.matches_regexp(
                        '.*apache-1970-01-01T00_00_20-1970-01-01T00_00_30--.*'
                    ),
                    stringmatches.matches_regexp(
                        '.*apache-1970-01-01T00_00_30-1970-01-01T00_00_40--.*')
                ]),
                label='verifyApacheFiles')
Example #60
    def finalize_write(self, init_result, writer_results,
                       unused_pre_finalize_results):
        writer_results = sorted(writer_results)
        num_shards = len(writer_results)

        src_files, dst_files, delete_files, num_skipped = (
            self._check_state_for_finalize_write(writer_results, num_shards))
        num_skipped += len(delete_files)
        FileSystems.delete(delete_files)
        num_shards_to_finalize = len(src_files)
        min_threads = min(num_shards_to_finalize,
                          FileBasedSink._MAX_RENAME_THREADS)
        num_threads = max(1, min_threads)

        chunk_size = FileSystems.get_chunk_size(self.file_path_prefix.get())
        source_file_batch = [
            src_files[i:i + chunk_size]
            for i in range(0, len(src_files), chunk_size)
        ]
        destination_file_batch = [
            dst_files[i:i + chunk_size]
            for i in range(0, len(dst_files), chunk_size)
        ]

        if num_shards_to_finalize:
            logging.info(
                'Starting finalize_write threads with num_shards: %d (skipped: %d), '
                'batches: %d, num_threads: %d', num_shards_to_finalize,
                num_skipped, len(source_file_batch), num_threads)
            start_time = time.time()

            # Use a thread pool for renaming operations.
            def _rename_batch(batch):
                """_rename_batch executes batch rename operations."""
                source_files, destination_files = batch
                exceptions = []
                try:
                    FileSystems.rename(source_files, destination_files)
                    return exceptions
                except BeamIOError as exp:
                    if exp.exception_details is None:
                        raise
                    for (src,
                         dst), exception in iteritems(exp.exception_details):
                        if exception:
                            logging.error(
                                ('Exception in _rename_batch. src: %s, '
                                 'dst: %s, err: %s'), src, dst, exception)
                            exceptions.append(exception)
                        else:
                            logging.debug('Rename successful: %s -> %s', src,
                                          dst)
                    return exceptions

            exception_batches = util.run_using_threadpool(
                _rename_batch,
                list(zip(source_file_batch, destination_file_batch)),
                num_threads)

            all_exceptions = [
                e for exception_batch in exception_batches
                for e in exception_batch
            ]
            if all_exceptions:
                raise Exception(
                    'Encountered exceptions in finalize_write: %s' %
                    all_exceptions)

            for final_name in dst_files:
                yield final_name

            logging.info('Renamed %d shards in %.2f seconds.',
                         num_shards_to_finalize,
                         time.time() - start_time)
        else:
            logging.warning(
                'No shards found to finalize. num_shards: %d, skipped: %d',
                num_shards, num_skipped)

        try:
            FileSystems.delete([init_result])
        except IOError:
            # May have already been removed.
            pass