Example #1
0
    def add_value_provider_argument(self, *args, **kwargs):
        """Register an argument whose value may be resolved at runtime.

        ValueProvider arguments can be either of type keyword or positional.
        At runtime, even positional arguments will need to be supplied in the
        key/value form.
        """
        # Derive the option name from either a positional spec like
        # ['pos_arg'] or keyword specs like [--kw_arg, -k, -w] / [--kw-arg].
        assert args != () and len(args[0]) >= 1
        first_spec = args[0]
        if first_spec[0] != '-':
            # Positional argument: the spec itself is the option name.
            option_name = first_spec
            # Make positionals optionally templated unless nargs was given.
            if kwargs.get('nargs') is None:
                kwargs['nargs'] = '?'
        else:
            # Keyword argument: use the first long-form ('--') spec with the
            # leading dashes stripped.
            long_forms = [spec.replace('--', '') for spec in args
                          if spec[:2] == '--']
            option_name = long_forms[0]

        # Replace the declared type with a StaticValueProvider-producing
        # factory so parsed values arrive wrapped.
        value_type = kwargs.get('type') or str
        kwargs['type'] = _static_value_provider_of(value_type)

        # Replace the declared default with a RuntimeValueProvider so the
        # value can still be resolved when the pipeline actually runs.
        default_value = kwargs.get('default')
        kwargs['default'] = RuntimeValueProvider(option_name=option_name,
                                                 value_type=value_type,
                                                 default_value=default_value)

        # Delegate the actual registration to argparse's add_argument.
        self.add_argument(*args, **kwargs)
Example #2
0
def test_inserting_the_dest_table_schema_into_pcollection_runtime():
    """Verify that IngectTableSchema emits the destination table's BigQuery
    schema alongside the untouched (empty) payload for each input element.

    NOTE(review): 'expirement_name' and 'IngectTableSchema' look like typos,
    but they must match the real column/DoFn names — confirm before renaming.
    """
    with TestPipeline() as p:
        # 'dest' is a runtime-resolved option; the default points at the
        # lake.wrench_metrics table of the test project.
        lake_table = RuntimeValueProvider(
            option_name='dest',
            value_type=str,
            default_value=f'{project_id}:lake.wrench_metrics')
        # Expected output: one dict with the table's full schema plus the
        # original (empty) payload.
        expected = [{
            'schema': [
                gcp_bq.schema.SchemaField('entity_id', 'STRING', 'REQUIRED',
                                          None, ()),
                gcp_bq.schema.SchemaField('tree_user_id', 'INTEGER',
                                          'REQUIRED', None, ()),
                gcp_bq.schema.SchemaField('prediction', 'STRING', 'REQUIRED',
                                          None, ()),
                gcp_bq.schema.SchemaField('client_wrench_id', 'STRING',
                                          'REQUIRED', None, ()),
                gcp_bq.schema.SchemaField('expirement_name', 'STRING',
                                          'NULLABLE', None, ()),
                gcp_bq.schema.SchemaField('processing_datetime', 'DATETIME',
                                          'NULLABLE', None, ()),
                gcp_bq.schema.SchemaField('ingestion_timestamp', 'TIMESTAMP',
                                          'REQUIRED', None, ())
            ],
            'payload': {}
        }]
        # A single empty element is enough to trigger one schema lookup.
        pcoll = p | beam.Create([{}])
        schema_pcoll = pcoll | beam.ParDo(
            bq.IngectTableSchema(table=lake_table))
        assert_that(schema_pcoll, equal_to(expected))
        # Clear global runtime options so later tests are unaffected.
        RuntimeValueProvider.set_runtime_options(None)
    def test_value_provider_options(self):
        """Options declared via add_value_provider_argument resolve to
        StaticValueProvider when supplied on the command line and to
        RuntimeValueProvider when omitted; plain add_argument options keep
        their native type."""
        class UserOptions(PipelineOptions):
            @classmethod
            def _add_argparse_args(cls, parser):
                parser.add_value_provider_argument(
                    '--pot_vp_arg1', help='This flag is a value provider')

                parser.add_value_provider_argument('--pot_vp_arg2',
                                                   default=1,
                                                   type=int)

                parser.add_argument('--pot_non_vp_arg1', default=1, type=int)

        # Provide values: if not provided, the option becomes of the type runtime vp
        options = UserOptions(['--pot_vp_arg1', 'hello'])
        self.assertIsInstance(options.pot_vp_arg1, StaticValueProvider)
        self.assertIsInstance(options.pot_vp_arg2, RuntimeValueProvider)
        self.assertIsInstance(options.pot_non_vp_arg1, int)

        # Values can be overwritten
        options = UserOptions(pot_vp_arg1=5,
                              pot_vp_arg2=StaticValueProvider(value_type=str,
                                                              value='bye'),
                              pot_non_vp_arg1=RuntimeValueProvider(
                                  option_name='foo',
                                  value_type=int,
                                  default_value=10))
        self.assertEqual(options.pot_vp_arg1, 5)
        self.assertTrue(options.pot_vp_arg2.is_accessible(),
                        '%s is not accessible' % options.pot_vp_arg2)
        self.assertEqual(options.pot_vp_arg2.get(), 'bye')
        self.assertFalse(options.pot_non_vp_arg1.is_accessible())

        # An inaccessible runtime provider raises on get() until runtime
        # options are installed.
        with self.assertRaises(RuntimeError):
            options.pot_non_vp_arg1.get()
Example #4
0
def test_runtime_serialized_file_list_is_deserialized_and_processed_by_insertion_order(
        cloudstorage):
    """Files matching the runtime prefix should be emitted ordered by the
    dill-deserialized sort key (their trailing timestamp); files with other
    prefixes are skipped.
    """
    with TestPipeline() as p:
        bucket = f'{project_id}-cdc-imports'

        # Update sort_key based on the filename format
        def _sort_key(f):
            # Filenames end with '-<timestamp>'; sort numerically when the
            # suffix is all digits, otherwise fall back to the raw name.
            delimiter = '-'
            ts = f[f.rfind(delimiter) + 1:]
            return int(ts) if ts.isdigit() else f

        # The transform receives the function as a hex-encoded dill payload.
        _sort_key = bytes.hex(dill.dumps(_sort_key))

        runtime_env = RuntimeValueProvider(option_name='env',
                                           value_type=str,
                                           default_value='local')
        runtime_bucket = RuntimeValueProvider(option_name='bucket',
                                              value_type=str,
                                              default_value=bucket)
        runtime_startswith = RuntimeValueProvider(
            option_name='files_startwith',
            value_type=str,
            default_value='vibe-tree-user-statuses-final')
        runtime_sort_key = RuntimeValueProvider(option_name='sort_key',
                                                value_type=str,
                                                default_value=_sort_key)

        # Start from an empty bucket. A plain loop (not a list comprehension)
        # because delete() is called purely for its side effect.
        for blob in cloudstorage.client.list_blobs(bucket):
            blob.delete()

        file_paths = [
            'vibe-tree-user-statuses-final-0083c-1987582612499',
            'vibe-tree-user-statuses-final-003c-1587582612405',
            'vibe-order-items-final-0030dd8697-1588231505823'
        ]
        # Only the two matching files, in ascending timestamp order.
        expected_output = [
            'gs://icentris-ml-local-wbrito-cdc-imports/vibe-tree-user-statuses-final-003c-1587582612405',
            'gs://icentris-ml-local-wbrito-cdc-imports/vibe-tree-user-statuses-final-0083c-1987582612499'
        ]
        for f in file_paths:
            cloudstorage.client.upload_blob_from_string(bucket, f, f)

        p_paths = p | FileListIteratorTransform(
            env=runtime_env,
            bucket=runtime_bucket,
            files_startwith=runtime_startswith,
            sort_key=runtime_sort_key)
        assert_that(p_paths, equal_to(expected_output))
        # Clear global runtime options so later tests are unaffected.
        RuntimeValueProvider.set_runtime_options(None)
Example #5
0
 def test_runtime_values(self):
   """create_harness(dry_run=True) should install PIPELINE_OPTIONS so that a
   pre-existing RuntimeValueProvider becomes accessible with the value
   supplied in the JSON options blob."""
   test_runtime_provider = RuntimeValueProvider('test_param', int, None)
   # dry_run=True lets the harness parse and propagate the options without
   # actually starting a worker.
   sdk_worker_main.create_harness({
       'CONTROL_API_SERVICE_DESCRIPTOR': '',
       'PIPELINE_OPTIONS': '{"test_param": 37}',
   },
                                  dry_run=True)
   self.assertTrue(test_runtime_provider.is_accessible())
   self.assertEqual(test_runtime_provider.get(), 37)
Example #6
0
    def test_string_or_value_provider_only(self):
        """FileBasedSource accepts a str, StaticValueProvider, or
        RuntimeValueProvider file pattern and rejects any other type."""
        pattern = tempfile.NamedTemporaryFile(delete=False).name

        # A plain string gets wrapped; compare against the wrapped value.
        self.assertEqual(pattern, FileBasedSource(pattern)._pattern.value)

        # A StaticValueProvider is stored as-is.
        static_pattern = StaticValueProvider(value_type=str, value=pattern)
        self.assertEqual(static_pattern,
                         FileBasedSource(static_pattern)._pattern)

        # A RuntimeValueProvider is stored as-is too.
        runtime_pattern = RuntimeValueProvider(option_name='arg',
                                               value_type=str,
                                               default_value=pattern)
        self.assertEqual(runtime_pattern,
                         FileBasedSource(runtime_pattern)._pattern)

        # Anything else (e.g. an int) is rejected with TypeError.
        invalid_pattern = 123
        with self.assertRaises(TypeError):
            FileBasedSource(invalid_pattern)
Example #7
0
 def _reader_thread(self):
   """Drain sources_queue, reading each source and pushing its values onto
   element_queue until the queue is empty or any reader errors.

   NOTE(review): `Queue.Empty` is the Python 2 module spelling (queue.Empty
   in Python 3) — this block appears to be Py2-era code.
   """
   # pylint: disable=too-many-nested-blocks
   # 'experiments' runtime option is a comma-separated flag list.
   experiments = set(
       RuntimeValueProvider('experiments', str, '').get().split(','))
   try:
     while True:
       try:
         # Non-blocking take: Queue.Empty below ends the thread cleanly.
         source = self.sources_queue.get_nowait()
         if isinstance(source, iobase.BoundedSource):
           for value in source.read(source.get_range_tracker(None, None)):
             if self.has_errored:
               # If any reader has errored, just return.
               return
             # Wrap plain values in the global window before enqueueing.
             if isinstance(value, window.WindowedValue):
               self.element_queue.put(value)
             else:
               self.element_queue.put(_globally_windowed(value))
         else:
           # Native dataflow source.
           with source.reader() as reader:
             # The tracking of time spend reading and bytes read from side
             # inputs is kept behind an experiment flag to test performance
             # impact.
             if 'sideinput_io_metrics' in experiments:
               self.add_byte_counter(reader)
             returns_windowed_values = reader.returns_windowed_values
             for value in reader:
               if self.has_errored:
                 # If any reader has errored, just return.
                 return
               if returns_windowed_values:
                 self.element_queue.put(value)
               else:
                 self.element_queue.put(_globally_windowed(value))
       except Queue.Empty:
         return
   except Exception as e:  # pylint: disable=broad-except
     # Record the failure for the consumer and flag all readers to stop.
     logging.error('Encountered exception in PrefetchingSourceSetIterable '
                   'reader thread: %s', traceback.format_exc())
     self.reader_exceptions.put(e)
     self.has_errored = True
   finally:
     # Always signal completion so the consumer never blocks forever.
     self.element_queue.put(READER_THREAD_IS_DONE_SENTINEL)
Example #8
0
 def test_runtime_value_provider_to(self):
     """to_json_value of an inaccessible RuntimeValueProvider is JSON null."""
     # Ensure no runtime options are installed, so the provider is not
     # accessible and serializes to null.
     RuntimeValueProvider.set_runtime_options(None)
     # NOTE(review): the positional args ('arg', 123, int) put 123 in the
     # value_type slot — looks swapped with the int default; confirm against
     # RuntimeValueProvider's signature.
     rvp = RuntimeValueProvider('arg', 123, int)
     self.assertEqual(JsonValue(is_null=True), to_json_value(rvp))
     # Reset runtime options to avoid side-effects in other tests.
     RuntimeValueProvider.set_runtime_options(None)
 def test_runtime_value_provider_to(self):
     """to_json_value of an inaccessible RuntimeValueProvider is JSON null.

     NOTE(review): this repeats the method name defined just above — if both
     live in the same TestCase class, this definition shadows the earlier
     one and only this version runs.
     """
     # Clear runtime options directly so the provider is not accessible.
     RuntimeValueProvider.runtime_options = None
     # NOTE(review): positional args ('arg', 123, int) put 123 in the
     # value_type slot — looks swapped; confirm against the signature.
     rvp = RuntimeValueProvider('arg', 123, int)
     # assertEquals is a deprecated alias (removed in Python 3.12);
     # use assertEqual.
     self.assertEqual(JsonValue(is_null=True), to_json_value(rvp))
Example #10
0
    def _read_side_inputs(self, tags_and_types):
        """Generator reading side inputs in the order prescribed by tags_and_types.

    Args:
      tags_and_types: List of tuples (tag, type). Each side input has a string
        tag that is specified in the worker instruction. The type is actually
        a boolean which is True for singleton input (read just first value)
        and False for collection input (read all values).

    Yields:
      With each iteration it yields the result of reading an entire side source
      either in singleton or collection mode according to the tags_and_types
      argument.

    NOTE(review): `itertools.ifilter` exists only in Python 2 — this block
    appears to be Py2-era code.
    """
        # Only call this on the old path where side_input_maps was not
        # provided directly.
        assert self.side_input_maps is None

        # Get experiments active in the worker to check for side input metrics exp.
        experiments = set(
            RuntimeValueProvider('experiments', str, '').get().split(','))

        # We will read the side inputs in the order prescribed by the
        # tags_and_types argument because this is exactly the order needed to
        # replace the ArgumentPlaceholder objects in the args/kwargs of the DoFn
        # getting the side inputs.
        #
        # Note that for each tag there could be several read operations in the
        # specification. This can happen for instance if the source has been
        # sharded into several files.
        for i, (side_tag, view_class,
                view_options) in enumerate(tags_and_types):
            sources = []
            # Using the side_tag in the lambda below will trigger a pylint warning.
            # However in this case it is fine because the lambda is used right away
            # while the variable has the value assigned by the current iteration of
            # the for loop.
            # pylint: disable=cell-var-from-loop
            for si in itertools.ifilter(lambda o: o.tag == side_tag,
                                        self.spec.side_inputs):
                if not isinstance(si, operation_specs.WorkerSideInputSource):
                    raise NotImplementedError('Unknown side input type: %r' %
                                              si)
                sources.append(si.source)
                # The tracking of time spend reading and bytes read from side inputs is
                # behind an experiment flag to test its performance impact.
                if 'sideinput_io_metrics' in experiments:
                    si_counter = opcounters.SideInputReadCounter(
                        self.counter_factory,
                        self.state_sampler,
                        declaring_step=self.operation_name,
                        # Inputs are 1-indexed, so we add 1 to i in the side input id
                        input_index=i + 1)
                else:
                    si_counter = opcounters.TransformIOCounter()
            # NOTE(review): if no side input matched side_tag above, the inner
            # loop never runs and si_counter is unbound here (NameError on the
            # first tag, or a stale counter from a previous iteration).
            iterator_fn = sideinputs.get_iterator_fn_for_sources(
                sources, read_counter=si_counter)

            # Backwards compatibility for pre BEAM-733 SDKs.
            if isinstance(view_options, tuple):
                if view_class == pvalue.AsSingleton:
                    # view_options is (has_default, default) in old SDKs.
                    has_default, default = view_options
                    view_options = {'default': default} if has_default else {}
                else:
                    view_options = {}

            yield apache_sideinputs.SideInputMap(
                view_class, view_options,
                sideinputs.EmulatedIterable(iterator_fn))