def add_value_provider_argument(self, *args, **kwargs): """ValueProvider arguments can be either of type keyword or positional. At runtime, even positional arguments will need to be supplied in the key/value form. """ # Extract the option name from positional argument ['pos_arg'] assert args != () and len(args[0]) >= 1 if args[0][0] != '-': option_name = args[0] if kwargs.get('nargs') is None: # make them optionally templated kwargs['nargs'] = '?' else: # or keyword arguments like [--kw_arg, -k, -w] or [--kw-arg] option_name = [i.replace('--', '') for i in args if i[:2] == '--'][0] # reassign the type to make room for using # StaticValueProvider as the type for add_argument value_type = kwargs.get('type') or str kwargs['type'] = _static_value_provider_of(value_type) # reassign default to default_value to make room for using # RuntimeValueProvider as the default for add_argument default_value = kwargs.get('default') kwargs['default'] = RuntimeValueProvider(option_name=option_name, value_type=value_type, default_value=default_value) # have add_argument do most of the work self.add_argument(*args, **kwargs)
def test_inserting_the_dest_table_schema_into_pcollection_runtime():
    """IngectTableSchema should attach the destination table's BigQuery schema
    to each element of the input PCollection."""
    with TestPipeline() as p:
        # Destination table supplied the way a templated run would: as a
        # RuntimeValueProvider rather than a plain string.
        lake_table = RuntimeValueProvider(
            option_name='dest',
            value_type=str,
            default_value=f'{project_id}:lake.wrench_metrics')
        # Expected: the original (empty) payload plus the table's schema.
        # NOTE(review): 'expirement_name' looks misspelled but presumably
        # matches the actual column name in lake.wrench_metrics — verify
        # against the table before "fixing" it.
        expected = [{
            'schema': [
                gcp_bq.schema.SchemaField('entity_id', 'STRING', 'REQUIRED', None, ()),
                gcp_bq.schema.SchemaField('tree_user_id', 'INTEGER', 'REQUIRED', None, ()),
                gcp_bq.schema.SchemaField('prediction', 'STRING', 'REQUIRED', None, ()),
                gcp_bq.schema.SchemaField('client_wrench_id', 'STRING', 'REQUIRED', None, ()),
                gcp_bq.schema.SchemaField('expirement_name', 'STRING', 'NULLABLE', None, ()),
                gcp_bq.schema.SchemaField('processing_datetime', 'DATETIME', 'NULLABLE', None, ()),
                gcp_bq.schema.SchemaField('ingestion_timestamp', 'TIMESTAMP', 'REQUIRED', None, ())
            ],
            'payload': {}
        }]
        pcoll = p | beam.Create([{}])
        schema_pcoll = pcoll | beam.ParDo(
            bq.IngectTableSchema(table=lake_table))
        assert_that(schema_pcoll, equal_to(expected))
    # Clear global runtime options so this test does not leak state.
    RuntimeValueProvider.set_runtime_options(None)
def test_value_provider_options(self):
    """Options declared with add_value_provider_argument act as ValueProviders."""

    class UserOptions(PipelineOptions):
        @classmethod
        def _add_argparse_args(cls, parser):
            parser.add_value_provider_argument(
                '--pot_vp_arg1', help='This flag is a value provider')
            parser.add_value_provider_argument(
                '--pot_vp_arg2', default=1, type=int)
            parser.add_argument('--pot_non_vp_arg1', default=1, type=int)

    # A VP option given on the command line parses to a StaticValueProvider,
    # an omitted one to a RuntimeValueProvider, while a plain argument keeps
    # its concrete parsed type.
    opts = UserOptions(['--pot_vp_arg1', 'hello'])
    self.assertIsInstance(opts.pot_vp_arg1, StaticValueProvider)
    self.assertIsInstance(opts.pot_vp_arg2, RuntimeValueProvider)
    self.assertIsInstance(opts.pot_non_vp_arg1, int)

    # Explicit keyword arguments override whatever parsing would produce.
    opts = UserOptions(
        pot_vp_arg1=5,
        pot_vp_arg2=StaticValueProvider(value_type=str, value='bye'),
        pot_non_vp_arg1=RuntimeValueProvider(
            option_name='foo', value_type=int, default_value=10))
    self.assertEqual(opts.pot_vp_arg1, 5)
    self.assertTrue(opts.pot_vp_arg2.is_accessible(),
                    '%s is not accessible' % opts.pot_vp_arg2)
    self.assertEqual(opts.pot_vp_arg2.get(), 'bye')
    # A RuntimeValueProvider with no runtime options set is inaccessible
    # and raises on get().
    self.assertFalse(opts.pot_non_vp_arg1.is_accessible())
    with self.assertRaises(RuntimeError):
        opts.pot_non_vp_arg1.get()
def test_runtime_serialized_file_list_is_deserialized_and_processed_by_insertion_order(
        cloudstorage):
    """Matching files should be emitted in the order given by the sort key.

    The sort-key callable is dill-serialized and hex-encoded the same way a
    templated launch would pass it, then handed to the transform as a
    RuntimeValueProvider alongside the other runtime options.
    """
    with TestPipeline() as p:
        bucket = f'{project_id}-cdc-imports'

        # Update sort_key based on the filename format: order by the trailing
        # '-<timestamp>' component when numeric, otherwise by the raw name.
        def _sort_key(f):
            delimeter = '-'
            ts = f[f.rfind(delimeter) + 1:]
            return int(ts) if ts.isdigit() else f

        # Serialize exactly as the runtime option would carry it.
        _sort_key = bytes.hex(dill.dumps(_sort_key))

        runtime_env = RuntimeValueProvider(
            option_name='env', value_type=str, default_value='local')
        runtime_bucket = RuntimeValueProvider(
            option_name='bucket', value_type=str, default_value=bucket)
        runtime_startswith = RuntimeValueProvider(
            option_name='files_startwith', value_type=str,
            default_value='vibe-tree-user-statuses-final')
        runtime_sort_key = RuntimeValueProvider(
            option_name='sort_key', value_type=str, default_value=_sort_key)

        # Fix: the original used a list comprehension purely for its side
        # effects ([b.delete() for b in ...]); a plain loop states the intent
        # and avoids building a throwaway list.
        for blob in cloudstorage.client.list_blobs(bucket):
            blob.delete()

        file_paths = [
            'vibe-tree-user-statuses-final-0083c-1987582612499',
            'vibe-tree-user-statuses-final-003c-1587582612405',
            'vibe-order-items-final-0030dd8697-1588231505823'
        ]
        # Only the two 'vibe-tree-user-statuses-final' files match the prefix,
        # ordered by their trailing timestamp (1587... before 1987...).
        # NOTE(review): these URIs hard-code the 'icentris-ml-local-wbrito'
        # project while `bucket` above derives from project_id — confirm they
        # stay in sync for other environments.
        expected_output = [
            'gs://icentris-ml-local-wbrito-cdc-imports/vibe-tree-user-statuses-final-003c-1587582612405',
            'gs://icentris-ml-local-wbrito-cdc-imports/vibe-tree-user-statuses-final-0083c-1987582612499'
        ]

        for f in file_paths:
            cloudstorage.client.upload_blob_from_string(bucket, f, f)

        p_paths = p | FileListIteratorTransform(
            env=runtime_env,
            bucket=runtime_bucket,
            files_startwith=runtime_startswith,
            sort_key=runtime_sort_key)
        assert_that(p_paths, equal_to(expected_output))
    # Clear global runtime options so later tests see a clean slate.
    RuntimeValueProvider.set_runtime_options(None)
def test_runtime_values(self):
    """create_harness(dry_run=True) should publish runtime pipeline options,
    making previously-declared RuntimeValueProviders accessible."""
    provider = RuntimeValueProvider('test_param', int, None)
    harness_env = {
        'CONTROL_API_SERVICE_DESCRIPTOR': '',
        'PIPELINE_OPTIONS': '{"test_param": 37}',
    }
    sdk_worker_main.create_harness(harness_env, dry_run=True)
    # The value from PIPELINE_OPTIONS now backs the provider.
    self.assertTrue(provider.is_accessible())
    self.assertEqual(provider.get(), 37)
def test_string_or_value_provider_only(self):
    """FileBasedSource accepts str or ValueProvider patterns; rejects others."""
    pattern = tempfile.NamedTemporaryFile(delete=False).name

    # A plain string is wrapped internally; the raw value round-trips.
    self.assertEqual(pattern, FileBasedSource(pattern)._pattern.value)

    # A StaticValueProvider is stored as-is.
    static_vp = StaticValueProvider(value_type=str, value=pattern)
    self.assertEqual(static_vp, FileBasedSource(static_vp)._pattern)

    # So is a RuntimeValueProvider.
    runtime_vp = RuntimeValueProvider(
        option_name='arg', value_type=str, default_value=pattern)
    self.assertEqual(runtime_vp, FileBasedSource(runtime_vp)._pattern)

    # Any other type (e.g. an int) is rejected with TypeError.
    invalid_pattern = 123
    with self.assertRaises(TypeError):
        FileBasedSource(invalid_pattern)
def _reader_thread(self):
    """Drains sources from self.sources_queue, pushing their elements onto
    self.element_queue as windowed values; always enqueues the done sentinel.
    """
    # pylint: disable=too-many-nested-blocks
    # Experiments are read from the worker's runtime options; an empty
    # option yields {''} which simply never matches a flag name.
    experiments = set(
        RuntimeValueProvider('experiments', str, '').get().split(','))
    try:
        while True:
            try:
                # Non-blocking get: an empty queue means all sources are
                # consumed and this thread is finished.
                source = self.sources_queue.get_nowait()
                if isinstance(source, iobase.BoundedSource):
                    for value in source.read(source.get_range_tracker(None, None)):
                        if self.has_errored:
                            # If any reader has errored, just return.
                            return
                        # Ensure every element is windowed before enqueueing.
                        if isinstance(value, window.WindowedValue):
                            self.element_queue.put(value)
                        else:
                            self.element_queue.put(_globally_windowed(value))
                else:
                    # Native dataflow source.
                    with source.reader() as reader:
                        # The tracking of time spend reading and bytes read from side
                        # inputs is kept behind an experiment flag to test performance
                        # impact.
                        if 'sideinput_io_metrics' in experiments:
                            self.add_byte_counter(reader)
                        # Hoisted out of the loop: whether the reader already
                        # yields windowed values.
                        returns_windowed_values = reader.returns_windowed_values
                        for value in reader:
                            if self.has_errored:
                                # If any reader has errored, just return.
                                return
                            if returns_windowed_values:
                                self.element_queue.put(value)
                            else:
                                self.element_queue.put(_globally_windowed(value))
            except Queue.Empty:
                return
    except Exception as e:  # pylint: disable=broad-except
        # Record the failure for the consumer side and flag other readers
        # to stop; the exception itself is re-raised by the consumer.
        logging.error('Encountered exception in PrefetchingSourceSetIterable '
                      'reader thread: %s', traceback.format_exc())
        self.reader_exceptions.put(e)
        self.has_errored = True
    finally:
        # Always signal completion, even on error, so consumers don't hang.
        self.element_queue.put(READER_THREAD_IS_DONE_SENTINEL)
def test_runtime_value_provider_to(self):
    """An inaccessible RuntimeValueProvider serializes to JSON null."""
    # Clear any runtime options left over from other tests so the provider
    # is inaccessible and to_json_value falls back to null.
    RuntimeValueProvider.set_runtime_options(None)
    # Fix: arguments were passed as ('arg', 123, int), i.e. value_type=123
    # and default_value=int — swapped relative to the
    # (option_name, value_type, default_value) signature used elsewhere in
    # this file (e.g. RuntimeValueProvider('test_param', int, None)).
    rvp = RuntimeValueProvider('arg', int, 123)
    self.assertEqual(JsonValue(is_null=True), to_json_value(rvp))
    # Reset runtime options to avoid side-effects in other tests.
    RuntimeValueProvider.set_runtime_options(None)
def test_runtime_value_provider_to(self):
    """An inaccessible RuntimeValueProvider serializes to JSON null."""
    # Fix: use the public set_runtime_options() instead of poking the class
    # attribute directly, and correct the swapped (value_type, default_value)
    # arguments — the signature is (option_name, value_type, default_value).
    RuntimeValueProvider.set_runtime_options(None)
    rvp = RuntimeValueProvider('arg', int, 123)
    # assertEquals is a deprecated alias of assertEqual.
    self.assertEqual(JsonValue(is_null=True), to_json_value(rvp))
    # Reset runtime options to avoid side-effects in other tests.
    RuntimeValueProvider.set_runtime_options(None)
def _read_side_inputs(self, tags_and_types):
    """Generator reading side inputs in the order prescribed by tags_and_types.

    Args:
      tags_and_types: List of tuples (tag, type). Each side input has a string
        tag that is specified in the worker instruction. The type is actually
        a boolean which is True for singleton input (read just first value)
        and False for collection input (read all values).

    Yields:
      With each iteration it yields the result of reading an entire side source
      either in singleton or collection mode according to the tags_and_types
      argument.
    """
    # Only call this on the old path where side_input_maps was not
    # provided directly.
    assert self.side_input_maps is None

    # Get experiments active in the worker to check for side input metrics exp.
    experiments = set(
        RuntimeValueProvider('experiments', str, '').get().split(','))

    # We will read the side inputs in the order prescribed by the
    # tags_and_types argument because this is exactly the order needed to
    # replace the ArgumentPlaceholder objects in the args/kwargs of the DoFn
    # getting the side inputs.
    #
    # Note that for each tag there could be several read operations in the
    # specification. This can happen for instance if the source has been
    # sharded into several files.
    for i, (side_tag, view_class, view_options) in enumerate(tags_and_types):
        sources = []
        # Using the side_tag in the lambda below will trigger a pylint warning.
        # However in this case it is fine because the lambda is used right away
        # while the variable has the value assigned by the current iteration of
        # the for loop.
        # pylint: disable=cell-var-from-loop
        # NOTE: itertools.ifilter is Python 2 only.
        for si in itertools.ifilter(
                lambda o: o.tag == side_tag, self.spec.side_inputs):
            if not isinstance(si, operation_specs.WorkerSideInputSource):
                raise NotImplementedError('Unknown side input type: %r' % si)
            sources.append(si.source)
        # The tracking of time spend reading and bytes read from side inputs is
        # behind an experiment flag to test its performance impact.
        if 'sideinput_io_metrics' in experiments:
            si_counter = opcounters.SideInputReadCounter(
                self.counter_factory,
                self.state_sampler,
                declaring_step=self.operation_name,
                # Inputs are 1-indexed, so we add 1 to i in the side input id
                input_index=i + 1)
        else:
            # Metrics experiment disabled: use the no-op counter.
            si_counter = opcounters.TransformIOCounter()
        iterator_fn = sideinputs.get_iterator_fn_for_sources(
            sources, read_counter=si_counter)

        # Backwards compatibility for pre BEAM-733 SDKs.
        if isinstance(view_options, tuple):
            if view_class == pvalue.AsSingleton:
                has_default, default = view_options
                view_options = {'default': default} if has_default else {}
            else:
                view_options = {}

        yield apache_sideinputs.SideInputMap(
            view_class, view_options, sideinputs.EmulatedIterable(iterator_fn))