Ejemplo n.º 1
0
  def __init__(
      self,
      file_pattern,
      min_bundle_size=0,
      compression_type=CompressionTypes.AUTO,
      splittable=True,
      validate=True):
    """Sets up a ``FileBasedSource`` over the files matching a pattern.

    Args:
      file_pattern: glob of the files to read, given either as a string or as
        a ValueProvider (placeholder for a value injected at runtime). A
        plain string is wrapped in a ``StaticValueProvider`` before being
        stored.
      min_bundle_size: minimum size of the bundles generated when this source
        is initially split. This constructor only stores the value.
      compression_type: compression type to use; must be accepted by
        ``CompressionTypes.is_valid_compression_type``.
      splittable: whether a single file may be logically split into data
        ranges that are read in parallel. The flag is honoured only when
        ``compression_type`` is ``CompressionTypes.UNCOMPRESSED`` or
        ``CompressionTypes.AUTO``; for every other compression type
        splitting is switched off regardless of the value passed here.
      validate: when true, and the pattern's value is already accessible,
        ``self._validate()`` is called during construction.

    Raises:
      TypeError: if ``file_pattern`` is neither a string nor a ValueProvider,
        or if ``compression_type`` is not a valid compression type.

    Errors raised by ``self._validate()`` propagate unchanged (presumably an
    IOError when the pattern matches no files -- confirm against
    ``_validate``).
    """
    if isinstance(file_pattern, basestring):
      # Normalise plain strings so that the stored pattern is always a
      # ValueProvider.
      pattern_provider = StaticValueProvider(str, file_pattern)
    elif isinstance(file_pattern, ValueProvider):
      pattern_provider = file_pattern
    else:
      raise TypeError('%s: file_pattern must be of type string'
                      ' or ValueProvider; got %r instead'
                      % (self.__class__.__name__, file_pattern))

    self._pattern = pattern_provider
    self._concat_source = None
    self._min_bundle_size = min_bundle_size

    compression_ok = CompressionTypes.is_valid_compression_type(
        compression_type)
    if not compression_ok:
      raise TypeError('compression_type must be CompressionType object but '
                      'was %s' % type(compression_type))
    self._compression_type = compression_type

    # Compressed files cannot be split efficiently, so the caller's
    # preference only counts for uncompressed / auto-detected input.
    may_split = compression_type in (CompressionTypes.UNCOMPRESSED,
                                     CompressionTypes.AUTO)
    self._splittable = splittable if may_split else False

    # A pattern whose value is not accessible yet cannot be checked here.
    if validate and pattern_provider.is_accessible():
      self._validate()
Ejemplo n.º 2
0
 def test_static_value_provider_empty_write(self):
     # Writes an empty collection through MyFileSink with both the path
     # prefix and the file name suffix given as StaticValueProviders, then
     # checks the content of the single output shard.
     #
     # Only the *name* of the temporary file is used (as the output prefix);
     # the NamedTemporaryFile object itself is not kept.
     temp_path = StaticValueProvider(
         value_type=str, value=tempfile.NamedTemporaryFile().name)
     sink = MyFileSink(temp_path,
                       file_name_suffix=StaticValueProvider(
                           value_type=str, value='.output'),
                       coder=coders.ToStringCoder())
     p = TestPipeline()
     p | beam.Create([]) | beam.io.Write(sink)  # pylint: disable=expression-not-assigned
     p.run()
     # Read through a context manager so the file handle is closed
     # deterministically instead of being left for the garbage collector.
     with open(temp_path.get() + '-00000-of-00001.output') as output_file:
         self.assertEqual(output_file.read(), '[start][end]')
Ejemplo n.º 3
0
 def test_static_value_provider_empty_write(self):
   # Writes an empty collection through MyFileSink with both the path prefix
   # and the file name suffix given as StaticValueProviders, then checks the
   # content of the single output shard.
   #
   # Only the *name* of the temporary file is used (as the output prefix);
   # the NamedTemporaryFile object itself is not kept.
   temp_path = StaticValueProvider(value_type=str,
                                   value=tempfile.NamedTemporaryFile().name)
   sink = MyFileSink(
       temp_path,
       file_name_suffix=StaticValueProvider(value_type=str, value='.output'),
       coder=coders.ToStringCoder()
   )
   p = TestPipeline()
   p | beam.Create([]) | beam.io.Write(sink)  # pylint: disable=expression-not-assigned
   p.run()
   # Read through a context manager so the file handle is closed
   # deterministically instead of being left for the garbage collector.
   with open(temp_path.get() + '-00000-of-00001.output') as output_file:
     self.assertEqual(output_file.read(), '[start][end]')
Ejemplo n.º 4
0
    def __init__(self,
                 file_path_prefix,
                 coder,
                 file_name_suffix='',
                 num_shards=0,
                 shard_name_template=None,
                 mime_type='application/octet-stream',
                 compression_type=CompressionTypes.AUTO):
        """Validates the sink configuration and stores it on the instance.

        Args:
          file_path_prefix: string or ValueProvider. A plain string is
            wrapped in a ``StaticValueProvider`` before being stored.
          coder: stored as ``self.coder``.
          file_name_suffix: string or ValueProvider. A plain string is
            wrapped in a ``StaticValueProvider`` before being stored.
          num_shards: stored as ``self.num_shards``; forced to 1 when
            ``shard_name_template`` is the empty string.
          shard_name_template: ``None`` selects
            ``DEFAULT_SHARD_NAME_TEMPLATE``. The template is converted with
            ``self._template_to_format`` and the result is stored as
            ``self.shard_name_format``.
          mime_type: stored as ``self.mime_type``.
          compression_type: must be accepted by
            ``CompressionTypes.is_valid_compression_type``.

        Raises:
          TypeError: if file path parameters are not a string or
            ValueProvider, or if compression_type is not a valid compression
            type.
          ValueError: not raised directly here; presumably raised by
            ``_template_to_format`` when shard_name_template is malformed
            (confirm against that helper).
        """
        if not isinstance(file_path_prefix, (basestring, ValueProvider)):
            # The 1-tuple keeps %-formatting correct even when the offending
            # value is itself a tuple.
            raise TypeError(
                'file_path_prefix must be a string or ValueProvider; '
                'got %r instead' % (file_path_prefix,))
        if not isinstance(file_name_suffix, (basestring, ValueProvider)):
            raise TypeError(
                'file_name_suffix must be a string or ValueProvider; '
                'got %r instead' % (file_name_suffix,))

        if not CompressionTypes.is_valid_compression_type(compression_type):
            raise TypeError(
                'compression_type must be CompressionType object but '
                'was %s' % type(compression_type))

        if shard_name_template is None:
            shard_name_template = DEFAULT_SHARD_NAME_TEMPLATE
        elif shard_name_template == '':
            # Compare by value: identity (`is`) against a string literal
            # depends on interning and misses equal-but-distinct strings.
            # An empty template means exactly one shard is written.
            num_shards = 1

        # Normalise plain strings so both path attributes are always
        # ValueProviders.
        if isinstance(file_path_prefix, basestring):
            file_path_prefix = StaticValueProvider(str, file_path_prefix)
        if isinstance(file_name_suffix, basestring):
            file_name_suffix = StaticValueProvider(str, file_name_suffix)

        self.file_path_prefix = file_path_prefix
        self.file_name_suffix = file_name_suffix
        self.num_shards = num_shards
        self.coder = coder
        self.shard_name_format = self._template_to_format(shard_name_template)
        self.compression_type = compression_type
        self.mime_type = mime_type

        # The filesystem can only be looked up once the prefix value is
        # available; otherwise it is left unset.
        if file_path_prefix.is_accessible():
            self._file_system = get_filesystem(file_path_prefix.get())
        else:
            self._file_system = None
    def test_value_provider_options(self):
        # Checks how value-provider arguments surface on a PipelineOptions
        # subclass, both when parsed from flags and when overridden through
        # keyword arguments.
        class UserOptions(PipelineOptions):
            @classmethod
            def _add_argparse_args(cls, parser):
                # Two value-provider flags (the second with an int default)
                # and one ordinary argparse flag.
                parser.add_value_provider_argument(
                    '--vp_arg',
                    help='This flag is a value provider')
                parser.add_value_provider_argument(
                    '--vp_arg2', default=1, type=int)
                parser.add_argument('--non_vp_arg', default=1, type=int)

        # Only --vp_arg is supplied: it becomes a static provider, while the
        # value-provider flag that was not supplied becomes a runtime one.
        options = UserOptions(['--vp_arg', 'hello'])
        expected_types = (
            ('vp_arg', StaticValueProvider),
            ('vp_arg2', RuntimeValueProvider),
            ('non_vp_arg', int),
        )
        for option_name, expected_type in expected_types:
            self.assertIsInstance(getattr(options, option_name), expected_type)

        # Keyword arguments overwrite the values, whatever their type.
        static_override = StaticValueProvider(value_type=str, value='bye')
        runtime_override = RuntimeValueProvider(
            option_name='foo', value_type=int, default_value=10)
        options = UserOptions(
            vp_arg=5,
            vp_arg2=static_override,
            non_vp_arg=runtime_override)

        self.assertEqual(options.vp_arg, 5)
        self.assertTrue(
            options.vp_arg2.is_accessible(),
            '%s is not accessible' % options.vp_arg2)
        self.assertEqual(options.vp_arg2.get(), 'bye')
        self.assertFalse(options.non_vp_arg.is_accessible())

        # Reading a runtime provider that is not accessible must fail.
        with self.assertRaises(RuntimeError):
            options.non_vp_arg.get()
Ejemplo n.º 6
0
    def test_string_or_value_provider_only(self):
        # FileBasedSource must accept a plain string, a StaticValueProvider
        # or a RuntimeValueProvider as its file pattern, and reject anything
        # else with a TypeError.
        import os  # local import: only needed for the temp-file cleanup

        str_file_pattern = tempfile.NamedTemporaryFile(delete=False).name
        # delete=False keeps the file on disk after the handle is closed, so
        # remove it explicitly once the test is over instead of leaking one
        # file per run.
        self.addCleanup(os.remove, str_file_pattern)

        # A plain string is wrapped; the wrapped value must be the original.
        self.assertEqual(str_file_pattern,
                         FileBasedSource(str_file_pattern)._pattern.value)

        # Providers are stored as given.
        static_vp_file_pattern = StaticValueProvider(value_type=str,
                                                     value=str_file_pattern)
        self.assertEqual(static_vp_file_pattern,
                         FileBasedSource(static_vp_file_pattern)._pattern)

        runtime_vp_file_pattern = RuntimeValueProvider(
            option_name='arg', value_type=str, default_value=str_file_pattern)
        self.assertEqual(runtime_vp_file_pattern,
                         FileBasedSource(runtime_vp_file_pattern)._pattern)

        invalid_file_pattern = 123
        with self.assertRaises(TypeError):
            FileBasedSource(invalid_file_pattern)
Ejemplo n.º 7
0
 def _f(value):
     # Wraps `value` in a StaticValueProvider of the enclosing `value_type`.
     # The function is renamed after `value_type` on each call. `__name__` is
     # used instead of `func_name`: the latter is a Python 2-only alias, and
     # on Python 3 assigning to it would merely create an unrelated
     # attribute and leave the function's name unchanged.
     _f.__name__ = value_type.__name__
     return StaticValueProvider(value_type, value)
Ejemplo n.º 8
0
 def test_static_value_provider_to(self):
     # to_json_value() applied to a StaticValueProvider must equal a
     # JsonValue built from the provider's underlying string value.
     svp = StaticValueProvider(str, 'abc')
     # assertEqual replaces the deprecated alias assertEquals.
     self.assertEqual(JsonValue(string_value=svp.value),
                      to_json_value(svp))