Example no. 1
  def run_pipeline(self, pipeline):
    """Execute the entire pipeline and returns an DirectPipelineResult."""

    # TODO: Move imports to top. The Pipeline <-> Runner dependency causes
    # problems with resolving imports when they are at the top.
    # pylint: disable=wrong-import-position
    from apache_beam.pipeline import PipelineVisitor
    from apache_beam.runners.direct.consumer_tracking_pipeline_visitor import \
      ConsumerTrackingPipelineVisitor
    from apache_beam.runners.direct.evaluation_context import EvaluationContext
    from apache_beam.runners.direct.executor import Executor
    from apache_beam.runners.direct.transform_evaluator import \
      TransformEvaluatorRegistry
    from apache_beam.testing.test_stream import TestStream

    # Performing configured PTransform overrides.
    pipeline.replace_all(_get_transform_overrides(pipeline.options))

    # If the TestStream I/O is used, use a mock test clock.
    class _TestStreamUsageVisitor(PipelineVisitor):
      """Visitor determining whether a Pipeline uses a TestStream."""

      def __init__(self):
        self.uses_test_stream = False

      def visit_transform(self, applied_ptransform):
        if isinstance(applied_ptransform.transform, TestStream):
          self.uses_test_stream = True

    visitor = _TestStreamUsageVisitor()
    pipeline.visit(visitor)
    clock = TestClock() if visitor.uses_test_stream else RealClock()

    MetricsEnvironment.set_metrics_supported(True)
    logging.info('Running pipeline with DirectRunner.')
    self.consumer_tracking_visitor = ConsumerTrackingPipelineVisitor()
    pipeline.visit(self.consumer_tracking_visitor)

    evaluation_context = EvaluationContext(
        pipeline._options,
        BundleFactory(stacked=pipeline._options.view_as(DirectOptions)
                      .direct_runner_use_stacked_bundle),
        self.consumer_tracking_visitor.root_transforms,
        self.consumer_tracking_visitor.value_to_consumers,
        self.consumer_tracking_visitor.step_names,
        self.consumer_tracking_visitor.views,
        clock)

    executor = Executor(self.consumer_tracking_visitor.value_to_consumers,
                        TransformEvaluatorRegistry(evaluation_context),
                        evaluation_context)
    # DirectRunner does not support injecting
    # PipelineOptions values at runtime
    RuntimeValueProvider.set_runtime_options({})
    # Start the executor. This is a non-blocking call; it will start the
    # execution in background threads and return.
    executor.start(self.consumer_tracking_visitor.root_transforms)
    result = DirectPipelineResult(executor, evaluation_context)

    return result
Example no. 2
    def run_pipeline(self, pipeline, options):
        """Execute the entire pipeline and returns an DirectPipelineResult."""

        # TODO: Move imports to top. The Pipeline <-> Runner dependency causes
        # problems with resolving imports when they are at the top.
        # pylint: disable=wrong-import-position
        from apache_beam.pipeline import PipelineVisitor
        from apache_beam.runners.direct.consumer_tracking_pipeline_visitor import \
          ConsumerTrackingPipelineVisitor
        from apache_beam.runners.direct.evaluation_context import EvaluationContext
        from apache_beam.runners.direct.executor import Executor
        from apache_beam.runners.direct.transform_evaluator import \
          TransformEvaluatorRegistry
        from apache_beam.testing.test_stream import TestStream

        # Performing configured PTransform overrides.
        pipeline.replace_all(_get_transform_overrides(options))

        # If the TestStream I/O is used, use a mock test clock.
        class _TestStreamUsageVisitor(PipelineVisitor):
            """Visitor determining whether a Pipeline uses a TestStream."""
            def __init__(self):
                self.uses_test_stream = False

            def visit_transform(self, applied_ptransform):
                if isinstance(applied_ptransform.transform, TestStream):
                    self.uses_test_stream = True

        visitor = _TestStreamUsageVisitor()
        pipeline.visit(visitor)
        clock = TestClock() if visitor.uses_test_stream else RealClock()

        # TODO(BEAM-4274): Circular import runners-metrics. Requires refactoring.
        from apache_beam.metrics.execution import MetricsEnvironment
        MetricsEnvironment.set_metrics_supported(True)
        logging.info('Running pipeline with DirectRunner.')
        self.consumer_tracking_visitor = ConsumerTrackingPipelineVisitor()
        pipeline.visit(self.consumer_tracking_visitor)

        evaluation_context = EvaluationContext(
            options,
            BundleFactory(stacked=options.view_as(
                DirectOptions).direct_runner_use_stacked_bundle),
            self.consumer_tracking_visitor.root_transforms,
            self.consumer_tracking_visitor.value_to_consumers,
            self.consumer_tracking_visitor.step_names,
            self.consumer_tracking_visitor.views, clock)

        executor = Executor(self.consumer_tracking_visitor.value_to_consumers,
                            TransformEvaluatorRegistry(evaluation_context),
                            evaluation_context)
        # DirectRunner does not support injecting
        # PipelineOptions values at runtime
        RuntimeValueProvider.set_runtime_options({})
        # Start the executor. This is a non-blocking call; it will start the
        # execution in background threads and return.
        executor.start(self.consumer_tracking_visitor.root_transforms)
        result = DirectPipelineResult(executor, evaluation_context)

        return result
Example no. 3
    def _r(runner, options, seeds):
        bigquery.truncate(seeds)
        bigquery.seed(seeds)

        RuntimeValueProvider.set_runtime_options(None)

        runner._run(TestPipeline(options=options), options)
Example no. 4
def test_inserting_the_dest_table_schema_into_pcollection_runtime():
    with TestPipeline() as p:
        lake_table = RuntimeValueProvider(
            option_name='dest',
            value_type=str,
            default_value=f'{project_id}:lake.wrench_metrics')
        expected = [{
            'schema': [
                gcp_bq.schema.SchemaField('entity_id', 'STRING', 'REQUIRED',
                                          None, ()),
                gcp_bq.schema.SchemaField('tree_user_id', 'INTEGER',
                                          'REQUIRED', None, ()),
                gcp_bq.schema.SchemaField('prediction', 'STRING', 'REQUIRED',
                                          None, ()),
                gcp_bq.schema.SchemaField('client_wrench_id', 'STRING',
                                          'REQUIRED', None, ()),
                gcp_bq.schema.SchemaField('expirement_name', 'STRING',
                                          'NULLABLE', None, ()),
                gcp_bq.schema.SchemaField('processing_datetime', 'DATETIME',
                                          'NULLABLE', None, ()),
                gcp_bq.schema.SchemaField('ingestion_timestamp', 'TIMESTAMP',
                                          'REQUIRED', None, ())
            ],
            'payload': {}
        }]
        pcoll = p | beam.Create([{}])
        schema_pcoll = pcoll | beam.ParDo(
            bq.IngectTableSchema(table=lake_table))
        assert_that(schema_pcoll, equal_to(expected))
        RuntimeValueProvider.set_runtime_options(None)
Example no. 5
    def test_experiments_setup(self):
        self.assertFalse('feature_1' in RuntimeValueProvider.experiments)

        RuntimeValueProvider.set_runtime_options(
            {'experiments': ['feature_1', 'feature_2']})
        self.assertTrue(isinstance(RuntimeValueProvider.experiments, set))
        self.assertTrue('feature_1' in RuntimeValueProvider.experiments)
        self.assertTrue('feature_2' in RuntimeValueProvider.experiments)
  def test_get_destination_uri_runtime_vp(self):
    # Provide values at job-execution time.
    RuntimeValueProvider.set_runtime_options({'gcs_location': 'gs://bucket'})
    options = self.UserDefinedOptions()
    unique_id = uuid.uuid4().hex

    uri = bigquery_export_destination_uri(options.gcs_location, None, unique_id)
    self.assertEqual(
        uri, 'gs://bucket/' + unique_id + '/bigquery-table-dump-*.json')
Example no. 7
 def test_runtime_values(self):
   test_runtime_provider = RuntimeValueProvider('test_param', int, None)
   sdk_worker_main.create_harness({
       'CONTROL_API_SERVICE_DESCRIPTOR': '',
       'PIPELINE_OPTIONS': '{"test_param": 37}',
   },
                                  dry_run=True)
   self.assertTrue(test_runtime_provider.is_accessible())
   self.assertEqual(test_runtime_provider.get(), 37)
Example no. 8
 def test_experiments_setup(self):
     RuntimeValueProvider.set_runtime_options(
         {'experiments': ['feature_1', 'feature_2']})
     self.assertTrue(isinstance(RuntimeValueProvider.experiments, set))
     self.assertTrue('feature_1' in RuntimeValueProvider.experiments)
     self.assertTrue('feature_2' in RuntimeValueProvider.experiments)
      # Clean up runtime_options after this test case finishes; otherwise it'll
      # affect other cases, since runtime_options is a static attribute.
     RuntimeValueProvider.set_runtime_options(None)
  def test_get_destination_uri_empty_runtime_vp(self):
    with self.assertRaisesRegex(ValueError,
                                '^ReadFromBigQuery requires a GCS '
                                'location to be provided'):
      # Don't provide any runtime values.
      RuntimeValueProvider.set_runtime_options({})
      options = self.UserDefinedOptions()

      bigquery_export_destination_uri(
          options.gcs_location, None, uuid.uuid4().hex)
Example no. 10
    def test_set_runtime_option(self):
        # define ValueProvider options, with and without default values
        class UserDefinedOptions1(PipelineOptions):
            @classmethod
            def _add_argparse_args(cls, parser):
                parser.add_value_provider_argument(
                    '--vpt_vp_arg6',
                    help='This keyword argument is a value provider'
                )  # set at runtime

                parser.add_value_provider_argument(  # not set, had default int
                    '-v',
                    '--vpt_vp_arg7',  # with short form
                    default=123,
                    type=int)

                parser.add_value_provider_argument(  # not set, had default str
                    '--vpt_vp-arg8',  # with dash in name
                    default='123',
                    type=str)

                parser.add_value_provider_argument(  # not set and no default
                    '--vpt_vp_arg9', type=float)

                parser.add_value_provider_argument(  # positional argument set
                    'vpt_vp_arg10',  # default & runtime ignored
                    help='This positional argument is a value provider',
                    type=float,
                    default=5.4)

        # provide values at graph-construction time
        # (options not provided here become of the type RuntimeValueProvider)
        options = UserDefinedOptions1(['1.2'])
        self.assertFalse(options.vpt_vp_arg6.is_accessible())
        self.assertFalse(options.vpt_vp_arg7.is_accessible())
        self.assertFalse(options.vpt_vp_arg8.is_accessible())
        self.assertFalse(options.vpt_vp_arg9.is_accessible())
        self.assertTrue(options.vpt_vp_arg10.is_accessible())

        # provide values at job-execution time
        # (options not provided here will use their default, if they have one)
        RuntimeValueProvider.set_runtime_options({
            'vpt_vp_arg6': 'abc',
            'vpt_vp_arg10': '3.2'
        })
        self.assertTrue(options.vpt_vp_arg6.is_accessible())
        self.assertEqual(options.vpt_vp_arg6.get(), 'abc')
        self.assertTrue(options.vpt_vp_arg7.is_accessible())
        self.assertEqual(options.vpt_vp_arg7.get(), 123)
        self.assertTrue(options.vpt_vp_arg8.is_accessible())
        self.assertEqual(options.vpt_vp_arg8.get(), '123')
        self.assertTrue(options.vpt_vp_arg9.is_accessible())
        self.assertIsNone(options.vpt_vp_arg9.get())
        self.assertTrue(options.vpt_vp_arg10.is_accessible())
        self.assertEqual(options.vpt_vp_arg10.get(), 1.2)
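
A minimal, self-contained sketch of the ValueProvider lifecycle that the test_set_runtime_option case above (and its twin in Example no. 14 below) exercises. The MyOptions class and the --my_param flag are invented for illustration, and the import paths assume a recent Beam release:

from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.value_provider import RuntimeValueProvider


class MyOptions(PipelineOptions):  # hypothetical options class
    @classmethod
    def _add_argparse_args(cls, parser):
        # Not supplied at construction time, so it becomes a RuntimeValueProvider.
        parser.add_value_provider_argument(
            '--my_param', type=str, default='fallback')


# Start from a clean slate, as the setUp/tearDown examples further down do.
RuntimeValueProvider.set_runtime_options(None)

options = MyOptions([])
assert not options.my_param.is_accessible()  # nothing injected yet

# Simulate job-execution time: the runner injects the runtime values.
RuntimeValueProvider.set_runtime_options({'my_param': 'abc'})
assert options.my_param.is_accessible()
assert options.my_param.get() == 'abc'

# Options absent from the runtime dict fall back to their declared default.
RuntimeValueProvider.set_runtime_options({})
assert options.my_param.get() == 'fallback'

# Reset the class-level state so other code is not affected.
RuntimeValueProvider.set_runtime_options(None)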
Example no. 11
  def test_experiments_setup(self):
    self.assertFalse('feature_1' in RuntimeValueProvider.experiments)

    RuntimeValueProvider.set_runtime_options(
        {'experiments': ['feature_1', 'feature_2']}
    )
    self.assertTrue(isinstance(RuntimeValueProvider.experiments, set))
    self.assertTrue('feature_1' in RuntimeValueProvider.experiments)
    self.assertTrue('feature_2' in RuntimeValueProvider.experiments)
    # Clean up runtime_options after this test case finishes; otherwise it'll
    # affect other cases, since runtime_options is a static attribute.
    RuntimeValueProvider.set_runtime_options(None)
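
The runtime 'experiments' option inspected by this test is also read on the worker side through RuntimeValueProvider.get_value (see the _reader_thread and _read_side_inputs examples further down). A short sketch of both access paths, using arbitrary experiment names and assuming the value_provider import path of recent Beam releases:

from apache_beam.options.value_provider import RuntimeValueProvider

# Inject runtime options the way a runner (or a test) would.
RuntimeValueProvider.set_runtime_options(
    {'experiments': ['sideinput_io_metrics_v2', 'feature_1']})

# Class-level view asserted by the test above: a set of experiment names.
assert isinstance(RuntimeValueProvider.experiments, set)
assert 'feature_1' in RuntimeValueProvider.experiments

# Generic accessor used by the worker-side readers: name, expected type, default.
experiments = RuntimeValueProvider.get_value('experiments', list, [])
assert 'sideinput_io_metrics_v2' in experiments

# Clean up, since runtime_options is class-level state shared across tests.
RuntimeValueProvider.set_runtime_options(None)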
Example no. 12
    def run_pipeline(
            self,
            pipeline,  # type: Pipeline
            options  # type: pipeline_options.PipelineOptions
    ):
        # type: (...) -> RunnerResult
        RuntimeValueProvider.set_runtime_options({})

        # Setup "beam_fn_api" experiment options if lacked.
        experiments = (options.view_as(
            pipeline_options.DebugOptions).experiments or [])
        if 'beam_fn_api' not in experiments:
            experiments.append('beam_fn_api')
        options.view_as(
            pipeline_options.DebugOptions).experiments = experiments

        # This is sometimes needed if type checking is disabled
        # to enforce that the inputs (and outputs) of GroupByKey operations
        # are known to be KVs.
        from apache_beam.runners.dataflow.dataflow_runner import DataflowRunner
        # TODO: Move group_by_key_input_visitor() to a non-dataflow specific file.
        pipeline.visit(
            DataflowRunner.group_by_key_input_visitor(
                not options.view_as(pipeline_options.TypeOptions
                                    ).allow_non_deterministic_key_coders))
        self._bundle_repeat = self._bundle_repeat or options.view_as(
            pipeline_options.DirectOptions).direct_runner_bundle_repeat
        pipeline_direct_num_workers = options.view_as(
            pipeline_options.DirectOptions).direct_num_workers
        if pipeline_direct_num_workers == 0:
            self._num_workers = multiprocessing.cpu_count()
        else:
            self._num_workers = pipeline_direct_num_workers or self._num_workers

        # Set the direct workers' running mode if defined in pipeline options.
        running_mode = \
          options.view_as(pipeline_options.DirectOptions).direct_running_mode
        if running_mode == 'multi_threading':
            self._default_environment = environments.EmbeddedPythonGrpcEnvironment(
            )
        elif running_mode == 'multi_processing':
            command_string = '%s -m apache_beam.runners.worker.sdk_worker_main' \
                          % sys.executable
            self._default_environment = environments.SubprocessSDKEnvironment(
                command_string=command_string)

        self._profiler_factory = Profile.factory_from_options(
            options.view_as(pipeline_options.ProfilingOptions))

        self._latest_run_result = self.run_via_runner_api(
            pipeline.to_runner_api(
                default_environment=self._default_environment))
        return self._latest_run_result
Example no. 13
  def run_pipeline(self, pipeline):
    """Execute the entire pipeline and returns an DirectPipelineResult."""

    # Performing configured PTransform overrides.
    pipeline.replace_all(self._ptransform_overrides)

    # TODO: Move imports to top. The Pipeline <-> Runner dependency causes
    # problems with resolving imports when they are at the top.
    # pylint: disable=wrong-import-position
    from apache_beam.runners.direct.consumer_tracking_pipeline_visitor import \
      ConsumerTrackingPipelineVisitor
    from apache_beam.runners.direct.evaluation_context import EvaluationContext
    from apache_beam.runners.direct.executor import Executor
    from apache_beam.runners.direct.transform_evaluator import \
      TransformEvaluatorRegistry

    MetricsEnvironment.set_metrics_supported(True)
    logging.info('Running pipeline with DirectRunner.')
    self.consumer_tracking_visitor = ConsumerTrackingPipelineVisitor()
    pipeline.visit(self.consumer_tracking_visitor)

    clock = TestClock() if self._use_test_clock else RealClock()
    evaluation_context = EvaluationContext(
        pipeline._options,
        BundleFactory(stacked=pipeline._options.view_as(DirectOptions)
                      .direct_runner_use_stacked_bundle),
        self.consumer_tracking_visitor.root_transforms,
        self.consumer_tracking_visitor.value_to_consumers,
        self.consumer_tracking_visitor.step_names,
        self.consumer_tracking_visitor.views,
        clock)

    evaluation_context.use_pvalue_cache(self._cache)

    executor = Executor(self.consumer_tracking_visitor.value_to_consumers,
                        TransformEvaluatorRegistry(evaluation_context),
                        evaluation_context)
    # DirectRunner does not support injecting
    # PipelineOptions values at runtime
    RuntimeValueProvider.set_runtime_options({})
    # Start the executor. This is a non-blocking call; it will start the
    # execution in background threads and return.
    executor.start(self.consumer_tracking_visitor.root_transforms)
    result = DirectPipelineResult(executor, evaluation_context)

    if self._cache:
      # We are running in eager mode, block until the pipeline execution
      # completes in order to have full results in the cache.
      result.wait_until_finish()
      self._cache.finalize()

    return result
Example no. 14
  def test_set_runtime_option(self):
    # define ValueProvider options, with and without default values
    class UserDefinedOptions1(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument(
            '--vpt_vp_arg6',
            help='This keyword argument is a value provider')   # set at runtime

        parser.add_value_provider_argument(         # not set, had default int
            '-v', '--vpt_vp_arg7',                      # with short form
            default=123,
            type=int)

        parser.add_value_provider_argument(         # not set, had default str
            '--vpt_vp-arg8',                            # with dash in name
            default='123',
            type=str)

        parser.add_value_provider_argument(         # not set and no default
            '--vpt_vp_arg9',
            type=float)

        parser.add_value_provider_argument(         # positional argument set
            'vpt_vp_arg10',                         # default & runtime ignored
            help='This positional argument is a value provider',
            type=float,
            default=5.4)

    # provide values at graph-construction time
    # (options not provided here become of the type RuntimeValueProvider)
    options = UserDefinedOptions1(['1.2'])
    self.assertFalse(options.vpt_vp_arg6.is_accessible())
    self.assertFalse(options.vpt_vp_arg7.is_accessible())
    self.assertFalse(options.vpt_vp_arg8.is_accessible())
    self.assertFalse(options.vpt_vp_arg9.is_accessible())
    self.assertTrue(options.vpt_vp_arg10.is_accessible())

    # provide values at job-execution time
    # (options not provided here will use their default, if they have one)
    RuntimeValueProvider.set_runtime_options({'vpt_vp_arg6': 'abc',
                                              'vpt_vp_arg10':'3.2'})
    self.assertTrue(options.vpt_vp_arg6.is_accessible())
    self.assertEqual(options.vpt_vp_arg6.get(), 'abc')
    self.assertTrue(options.vpt_vp_arg7.is_accessible())
    self.assertEqual(options.vpt_vp_arg7.get(), 123)
    self.assertTrue(options.vpt_vp_arg8.is_accessible())
    self.assertEqual(options.vpt_vp_arg8.get(), '123')
    self.assertTrue(options.vpt_vp_arg9.is_accessible())
    self.assertIsNone(options.vpt_vp_arg9.get())
    self.assertTrue(options.vpt_vp_arg10.is_accessible())
    self.assertEqual(options.vpt_vp_arg10.get(), 1.2)
Example no. 15
 def run_pipeline(self, pipeline, options):
   MetricsEnvironment.set_metrics_supported(False)
   RuntimeValueProvider.set_runtime_options({})
   # This is sometimes needed if type checking is disabled
   # to enforce that the inputs (and outputs) of GroupByKey operations
   # are known to be KVs.
   from apache_beam.runners.dataflow.dataflow_runner import DataflowRunner
   pipeline.visit(DataflowRunner.group_by_key_input_visitor())
   self._bundle_repeat = self._bundle_repeat or options.view_as(
       pipeline_options.DirectOptions).direct_runner_bundle_repeat
   self._profiler_factory = profiler.Profile.factory_from_options(
       options.view_as(pipeline_options.ProfilingOptions))
   return self.run_via_runner_api(pipeline.to_runner_api(
       default_environment=self._default_environment))
Example no. 17
  def test_bytes_read_are_reported(self):
    RuntimeValueProvider.set_runtime_options(
        {'experiments': ['sideinput_io_metrics_v2', 'other']})
    mock_read_counter = mock.MagicMock()
    source_records = ['a', 'b', 'c', 'd']
    sources = [
        FakeSource(source_records, notify_observers=True),
    ]
    iterator_fn = sideinputs.get_iterator_fn_for_sources(
        sources, max_reader_threads=3, read_counter=mock_read_counter)
    assert list(strip_windows(iterator_fn())) == source_records
    mock_read_counter.add_bytes_read.assert_called_with(4)

    # Remove runtime options from the runtime value provider.
    RuntimeValueProvider.set_runtime_options({})
Example no. 18
  def test_get_destination_uri_fallback_temp_location(self):
    # Don't provide any runtime values.
    RuntimeValueProvider.set_runtime_options({})
    options = self.UserDefinedOptions()

    with self.assertLogs('apache_beam.io.gcp.bigquery_read_internal',
                         level='DEBUG') as context:
      bigquery_export_destination_uri(
          options.gcs_location, 'gs://bucket', uuid.uuid4().hex)
    self.assertEqual(
        context.output,
        [
            'DEBUG:apache_beam.io.gcp.bigquery_read_internal:gcs_location is '
            'empty, using temp_location instead'
        ])
Example no. 19
  def test_bytes_read_are_reported(self):
    RuntimeValueProvider.set_runtime_options(
        {'experiments': 'sideinput_io_metrics,other'})
    mock_read_counter = mock.MagicMock()
    source_records = ['a', 'b', 'c', 'd']
    sources = [
        FakeSource(source_records, notify_observers=True),
    ]
    iterator_fn = sideinputs.get_iterator_fn_for_sources(
        sources, max_reader_threads=3, read_counter=mock_read_counter)
    assert list(strip_windows(iterator_fn())) == source_records
    mock_read_counter.add_bytes_read.assert_called_with(4)

    # Remove runtime options from the runtime value provider.
    RuntimeValueProvider.set_runtime_options({})
  def test_nested_value_provider_wrap_runtime(self):
    class UserDefinedOptions(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument(
            '--vpt_vp_arg15',
            help='This keyword argument is a value provider')  # set at runtime

    options = UserDefinedOptions([])
    vp = NestedValueProvider(options.vpt_vp_arg15, lambda x: x + x)
    self.assertFalse(vp.is_accessible())

    RuntimeValueProvider.set_runtime_options({'vpt_vp_arg15': 'abc'})

    self.assertTrue(vp.is_accessible())
    self.assertEqual(vp.get(), 'abcabc')
Example no. 21
    def add_value_provider_argument(self, *args, **kwargs):
        """ValueProvider arguments can be either of type keyword or positional.
    At runtime, even positional arguments will need to be supplied in the
    key/value form.
    """
        # Extract the option name from positional argument ['pos_arg']
        assert args != () and len(args[0]) >= 1
        if args[0][0] != '-':
            option_name = args[0]
            if kwargs.get('nargs') is None:  # make them optionally templated
                kwargs['nargs'] = '?'
        else:
            # or keyword arguments like [--kw_arg, -k, -w] or [--kw-arg]
            option_name = [i.replace('--', '') for i in args
                           if i[:2] == '--'][0]

        # reassign the type to make room for using
        # StaticValueProvider as the type for add_argument
        value_type = kwargs.get('type') or str
        kwargs['type'] = _static_value_provider_of(value_type)

        # reassign default to default_value to make room for using
        # RuntimeValueProvider as the default for add_argument
        default_value = kwargs.get('default')
        kwargs['default'] = RuntimeValueProvider(option_name=option_name,
                                                 value_type=value_type,
                                                 default_value=default_value)

        # have add_argument do most of the work
        self.add_argument(*args, **kwargs)
    def test_value_provider_options(self):
        class UserOptions(PipelineOptions):
            @classmethod
            def _add_argparse_args(cls, parser):
                parser.add_value_provider_argument(
                    '--pot_vp_arg1', help='This flag is a value provider')

                parser.add_value_provider_argument('--pot_vp_arg2',
                                                   default=1,
                                                   type=int)

                parser.add_argument('--pot_non_vp_arg1', default=1, type=int)

        # Provide values: if not provided, the option becomes of the type runtime vp
        options = UserOptions(['--pot_vp_arg1', 'hello'])
        self.assertIsInstance(options.pot_vp_arg1, StaticValueProvider)
        self.assertIsInstance(options.pot_vp_arg2, RuntimeValueProvider)
        self.assertIsInstance(options.pot_non_vp_arg1, int)

        # Values can be overwritten
        options = UserOptions(pot_vp_arg1=5,
                              pot_vp_arg2=StaticValueProvider(value_type=str,
                                                              value='bye'),
                              pot_non_vp_arg1=RuntimeValueProvider(
                                  option_name='foo',
                                  value_type=int,
                                  default_value=10))
        self.assertEqual(options.pot_vp_arg1, 5)
        self.assertTrue(options.pot_vp_arg2.is_accessible(),
                        '%s is not accessible' % options.pot_vp_arg2)
        self.assertEqual(options.pot_vp_arg2.get(), 'bye')
        self.assertFalse(options.pot_non_vp_arg1.is_accessible())

        with self.assertRaises(RuntimeError):
            options.pot_non_vp_arg1.get()
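
A short sketch of what the add_value_provider_argument implementation above arranges for a user option, with a hypothetical LimitOptions class and --limit flag: a value supplied at construction time arrives as a StaticValueProvider, while an omitted one falls back to a RuntimeValueProvider carrying the declared default. Import paths assume a recent Beam release.

from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.value_provider import (RuntimeValueProvider,
                                                StaticValueProvider)


class LimitOptions(PipelineOptions):  # hypothetical options class
    @classmethod
    def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument('--limit', type=int, default=10)


RuntimeValueProvider.set_runtime_options(None)  # clean slate for is_accessible

# Supplied at construction time: parsed through _static_value_provider_of(int),
# so the attribute is a StaticValueProvider holding the converted value.
static = LimitOptions(['--limit', '25'])
assert isinstance(static.limit, StaticValueProvider)
assert static.limit.get() == 25

# Omitted at construction time: the default installed above takes over, so the
# attribute is a RuntimeValueProvider with default_value=10, not yet readable.
deferred = LimitOptions([])
assert isinstance(deferred.limit, RuntimeValueProvider)
assert not deferred.limit.is_accessible()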
Example no. 23
    def test_string_or_value_provider_only(self):
        str_file_pattern = tempfile.NamedTemporaryFile(delete=False).name
        self.assertEqual(str_file_pattern,
                         FileBasedSource(str_file_pattern)._pattern.value)

        static_vp_file_pattern = StaticValueProvider(value_type=str,
                                                     value=str_file_pattern)
        self.assertEqual(static_vp_file_pattern,
                         FileBasedSource(static_vp_file_pattern)._pattern)

        runtime_vp_file_pattern = RuntimeValueProvider(
            option_name='arg', value_type=str, default_value=str_file_pattern)
        self.assertEqual(runtime_vp_file_pattern,
                         FileBasedSource(runtime_vp_file_pattern)._pattern)
        # Reset runtime options to avoid side-effects in other tests.
        RuntimeValueProvider.set_runtime_options(None)

        invalid_file_pattern = 123
        with self.assertRaises(TypeError):
            FileBasedSource(invalid_file_pattern)
Example no. 24
def ndjson(env, cloudstorage, record_testsuite_property):
    cloudstorage.client.delete_blob(bucket, dest_blob_name)
    assert cloudstorage.client.blob_exists(bucket, dest_blob_name) is False

    sql = BigQuery.querybuilder(union=('all', [
        BigQuery.querybuilder(select=[
            ('NULL', 'none'),
            ('True', 'true_bool'),
            ('False', 'false_bool'),
            ('"2020-04-03"', 'date'),
            ('"2020-04-03 13:45:00"', 'datetime'),
            ('"1966-06-06 06:06:06.666666 UTC"', 'timestamp'),
            ('"STRING"', 'string'),
            ('234', 'integer'),
            ('123.54', 'float')
        ]),
        BigQuery.querybuilder(select=['NULL'] * 9),
        BigQuery.querybuilder(select=[
            '"String"', 'False', 'True', '"1993-09-03"',
            '"1993-09-03 03:44:00"', '"1993-09-03 03:44:00.777555 UTC"',
            '"Not String"', '567', '456'
        ])
    ]))

    RuntimeValueProvider.set_runtime_options(None)

    options = RuntimeOptions([
        '--env', env['env'], '--query',
        str(sql), '--destination', f'gs://{bucket}/{blob_name}'
    ])
    Runner._run(TestPipeline(options=options), options)

    assert cloudstorage.client.blob_exists(bucket, dest_blob_name) is True

    zbytes = cloudstorage.client.download_blob_as_string(
        bucket, dest_blob_name)
    bytes = gzip.decompress(zbytes)
    lns = bytes.decode('utf8').rstrip().split('\n')
    yield [json.loads(l) for l in lns]
Example no. 25
 def _reader_thread(self):
     # pylint: disable=too-many-nested-blocks
     experiments = set(
         RuntimeValueProvider.get_value('experiments', str, '').split(','))
     try:
         while True:
             try:
                 source = self.sources_queue.get_nowait()
                 if isinstance(source, iobase.BoundedSource):
                     for value in source.read(
                             source.get_range_tracker(None, None)):
                         if self.has_errored:
                             # If any reader has errored, just return.
                             return
                         if isinstance(value, window.WindowedValue):
                             self.element_queue.put(value)
                         else:
                             self.element_queue.put(
                                 _globally_windowed(value))
                 else:
                     # Native dataflow source.
                     with source.reader() as reader:
                          # The tracking of time spent reading and bytes read from side
                          # inputs is kept behind an experiment flag to test performance
                          # impact.
                         if 'sideinput_io_metrics' in experiments:
                             self.add_byte_counter(reader)
                         returns_windowed_values = reader.returns_windowed_values
                         for value in reader:
                             if self.has_errored:
                                 # If any reader has errored, just return.
                                 return
                             if returns_windowed_values:
                                 self.element_queue.put(value)
                             else:
                                 self.element_queue.put(
                                     _globally_windowed(value))
             except Queue.Empty:
                 return
     except Exception as e:  # pylint: disable=broad-except
         logging.error(
             'Encountered exception in PrefetchingSourceSetIterable '
             'reader thread: %s', traceback.format_exc())
         self.reader_exceptions.put(e)
         self.has_errored = True
     finally:
         self.element_queue.put(READER_THREAD_IS_DONE_SENTINEL)
Example no. 26
 def _reader_thread(self):
   # pylint: disable=too-many-nested-blocks
   experiments = RuntimeValueProvider.get_value('experiments', list, [])
   try:
     while True:
       try:
         source = self.sources_queue.get_nowait()
         if isinstance(source, iobase.BoundedSource):
           for value in source.read(source.get_range_tracker(None, None)):
             if self.has_errored:
               # If any reader has errored, just return.
               return
             if isinstance(value, window.WindowedValue):
               self.element_queue.put(value)
             else:
               self.element_queue.put(_globally_windowed(value))
         else:
           # Native dataflow source.
           with source.reader() as reader:
              # The tracking of time spent reading and bytes read from side
              # inputs is kept behind an experiment flag to test performance
              # impact.
             if 'sideinput_io_metrics' in experiments:
               self.add_byte_counter(reader)
             returns_windowed_values = reader.returns_windowed_values
             for value in reader:
               if self.has_errored:
                 # If any reader has errored, just return.
                 return
               if returns_windowed_values:
                 self.element_queue.put(value)
               else:
                 self.element_queue.put(_globally_windowed(value))
       except Queue.Empty:
         return
   except Exception as e:  # pylint: disable=broad-except
     logging.error('Encountered exception in PrefetchingSourceSetIterable '
                   'reader thread: %s', traceback.format_exc())
     self.reader_exceptions.put(e)
     self.has_errored = True
   finally:
     self.element_queue.put(READER_THREAD_IS_DONE_SENTINEL)
Example no. 27
def test_runtime_serialized_file_list_is_deserialized_and_processed_by_insertion_order(
        cloudstorage):
    with TestPipeline() as p:
        bucket = f'{project_id}-cdc-imports'

        # Update sort_key based on the filename format
        def _sort_key(f):
            delimeter = '-'
            ts = f[f.rfind(delimeter) + 1:]
            return int(ts) if ts.isdigit() else f

        _sort_key = bytes.hex(dill.dumps(_sort_key))

        runtime_env = RuntimeValueProvider(option_name='env',
                                           value_type=str,
                                           default_value='local')
        runtime_bucket = RuntimeValueProvider(option_name='bucket',
                                              value_type=str,
                                              default_value=bucket)
        runtime_startswith = RuntimeValueProvider(
            option_name='files_startwith',
            value_type=str,
            default_value='vibe-tree-user-statuses-final')
        runtime_sort_key = RuntimeValueProvider(option_name='sort_key',
                                                value_type=str,
                                                default_value=_sort_key)
        [b.delete() for b in cloudstorage.client.list_blobs(bucket)]
        file_paths = [
            'vibe-tree-user-statuses-final-0083c-1987582612499',
            'vibe-tree-user-statuses-final-003c-1587582612405',
            'vibe-order-items-final-0030dd8697-1588231505823'
        ]
        expected_output = [
            'gs://icentris-ml-local-wbrito-cdc-imports/vibe-tree-user-statuses-final-003c-1587582612405',
            'gs://icentris-ml-local-wbrito-cdc-imports/vibe-tree-user-statuses-final-0083c-1987582612499'
        ]
        for f in file_paths:
            cloudstorage.client.upload_blob_from_string(bucket, f, f)

        p_paths = p | FileListIteratorTransform(
            env=runtime_env,
            bucket=runtime_bucket,
            files_startwith=runtime_startswith,
            sort_key=runtime_sort_key)
        assert_that(p_paths, equal_to(expected_output))
        RuntimeValueProvider.set_runtime_options(None)
 def test_runtime_value_provider_to(self):
      RuntimeValueProvider.runtime_options = None
      rvp = RuntimeValueProvider('arg', 123, int)
      self.assertEqual(JsonValue(is_null=True), to_json_value(rvp))
Example no. 29
  def _read_side_inputs(self, tags_and_types):
    """Generator reading side inputs in the order prescribed by tags_and_types.

    Args:
      tags_and_types: List of tuples (tag, type). Each side input has a string
        tag that is specified in the worker instruction. The type is actually
        a boolean which is True for singleton input (read just first value)
        and False for collection input (read all values).

    Yields:
      With each iteration it yields the result of reading an entire side source
      either in singleton or collection mode according to the tags_and_types
      argument.
    """
    # Only call this on the old path where side_input_maps was not
    # provided directly.
    assert self.side_input_maps is None

    # Get experiments active in the worker to check for side input metrics exp.
    experiments = set(
        RuntimeValueProvider.get_value('experiments', str, '').split(','))

    # We will read the side inputs in the order prescribed by the
    # tags_and_types argument because this is exactly the order needed to
    # replace the ArgumentPlaceholder objects in the args/kwargs of the DoFn
    # getting the side inputs.
    #
    # Note that for each tag there could be several read operations in the
    # specification. This can happen for instance if the source has been
    # sharded into several files.
    for i, (side_tag, view_class, view_options) in enumerate(tags_and_types):
      sources = []
      # Using the side_tag in the lambda below will trigger a pylint warning.
      # However in this case it is fine because the lambda is used right away
      # while the variable has the value assigned by the current iteration of
      # the for loop.
      # pylint: disable=cell-var-from-loop
      for si in itertools.ifilter(
          lambda o: o.tag == side_tag, self.spec.side_inputs):
        if not isinstance(si, operation_specs.WorkerSideInputSource):
          raise NotImplementedError('Unknown side input type: %r' % si)
        sources.append(si.source)
        # The tracking of time spent reading and bytes read from side inputs is
        # behind an experiment flag to test its performance impact.
        if 'sideinput_io_metrics' in experiments:
          si_counter = opcounters.SideInputReadCounter(
              self.counter_factory,
              self.state_sampler,
              declaring_step=self.operation_name,
              # Inputs are 1-indexed, so we add 1 to i in the side input id
              input_index=i + 1)
        else:
          si_counter = opcounters.TransformIOCounter()
      iterator_fn = sideinputs.get_iterator_fn_for_sources(
          sources, read_counter=si_counter)

      # Backwards compatibility for pre BEAM-733 SDKs.
      if isinstance(view_options, tuple):
        if view_class == pvalue.AsSingleton:
          has_default, default = view_options
          view_options = {'default': default} if has_default else {}
        else:
          view_options = {}

      yield apache_sideinputs.SideInputMap(
          view_class, view_options, sideinputs.EmulatedIterable(iterator_fn))
Example no. 30
 def tearDown(self):
     # Reset runtime options to avoid side-effects in other tests.
     RuntimeValueProvider.set_runtime_options(None)
Example no. 31
 def setUp(self):
     # Reset runtime options to avoid side-effects caused by other tests.
     # Note that is_accessible assertions require runtime_options to
     # be uninitialized.
     RuntimeValueProvider.set_runtime_options(None)
Example no. 32
def main(unused_argv):
    """Main entry point for SDK Fn Harness."""
    if 'LOGGING_API_SERVICE_DESCRIPTOR' in os.environ:
        try:
            logging_service_descriptor = endpoints_pb2.ApiServiceDescriptor()
            text_format.Merge(os.environ['LOGGING_API_SERVICE_DESCRIPTOR'],
                              logging_service_descriptor)

            # Send all logs to the runner.
            fn_log_handler = FnApiLogRecordHandler(logging_service_descriptor)
            # TODO(BEAM-5468): This should be picked up from pipeline options.
            logging.getLogger().setLevel(logging.DEBUG)
            logging.getLogger().addHandler(fn_log_handler)
            logging.info('Logging handler created.')
        except Exception:
            logging.error(
                "Failed to set up logging handler, continuing without.",
                exc_info=True)
            fn_log_handler = None
    else:
        fn_log_handler = None

    # Start status HTTP server thread.
    thread = threading.Thread(name='status_http_server',
                              target=StatusServer().start)
    thread.daemon = True
    thread.setName('status-server-demon')
    thread.start()

    if 'PIPELINE_OPTIONS' in os.environ:
        sdk_pipeline_options = _parse_pipeline_options(
            os.environ['PIPELINE_OPTIONS'])
    else:
        sdk_pipeline_options = PipelineOptions.from_dictionary({})

    if 'SEMI_PERSISTENT_DIRECTORY' in os.environ:
        semi_persistent_directory = os.environ['SEMI_PERSISTENT_DIRECTORY']
    else:
        semi_persistent_directory = None

    logging.info('semi_persistent_directory: %s', semi_persistent_directory)
    _worker_id = os.environ.get('WORKER_ID', None)

    try:
        _load_main_session(semi_persistent_directory)
    except Exception:  # pylint: disable=broad-except
        exception_details = traceback.format_exc()
        logging.error('Could not load main session: %s',
                      exception_details,
                      exc_info=True)

    try:
        logging.info('Python sdk harness started with pipeline_options: %s',
                     sdk_pipeline_options.get_all_options(drop_default=True))
        RuntimeValueProvider.set_runtime_options(
            sdk_pipeline_options.view_as(
                pipeline_options.HadoopFileSystemOptions).get_all_options())
        service_descriptor = endpoints_pb2.ApiServiceDescriptor()
        text_format.Merge(os.environ['CONTROL_API_SERVICE_DESCRIPTOR'],
                          service_descriptor)
        # TODO(robertwb): Support credentials.
        assert not service_descriptor.oauth2_client_credentials_grant.url
        SdkHarness(control_address=service_descriptor.url,
                   worker_count=_get_worker_count(sdk_pipeline_options),
                   worker_id=_worker_id,
                   profiler_factory=profiler.Profile.factory_from_options(
                       sdk_pipeline_options.view_as(
                           pipeline_options.ProfilingOptions))).run()
        logging.info('Python sdk harness exiting.')
    except:  # pylint: disable=broad-except
        logging.exception('Python sdk harness failed: ')
        raise
    finally:
        if fn_log_handler:
            fn_log_handler.close()
Example no. 33
 def setUp(self):
     # Reset runtime options, since the is_accessible assertions require them to
     # be uninitialized.
     RuntimeValueProvider.set_runtime_options(None)
Example no. 34
  def _read_side_inputs(self, tags_and_types):
    """Generator reading side inputs in the order prescribed by tags_and_types.

    Args:
      tags_and_types: List of tuples (tag, type). Each side input has a string
        tag that is specified in the worker instruction. The type is actually
        a boolean which is True for singleton input (read just first value)
        and False for collection input (read all values).

    Yields:
      With each iteration it yields the result of reading an entire side source
      either in singleton or collection mode according to the tags_and_types
      argument.
    """
    # Only call this on the old path where side_input_maps was not
    # provided directly.
    assert self.side_input_maps is None

    # Get experiments active in the worker to check for side input metrics exp.
    experiments = RuntimeValueProvider.get_value('experiments', list, [])

    # We will read the side inputs in the order prescribed by the
    # tags_and_types argument because this is exactly the order needed to
    # replace the ArgumentPlaceholder objects in the args/kwargs of the DoFn
    # getting the side inputs.
    #
    # Note that for each tag there could be several read operations in the
    # specification. This can happen for instance if the source has been
    # sharded into several files.
    for i, (side_tag, view_class, view_options) in enumerate(tags_and_types):
      sources = []
      # Using the side_tag in the lambda below will trigger a pylint warning.
      # However in this case it is fine because the lambda is used right away
      # while the variable has the value assigned by the current iteration of
      # the for loop.
      # pylint: disable=cell-var-from-loop
      for si in itertools.ifilter(
          lambda o: o.tag == side_tag, self.spec.side_inputs):
        if not isinstance(si, operation_specs.WorkerSideInputSource):
          raise NotImplementedError('Unknown side input type: %r' % si)
        sources.append(si.source)
        # The tracking of time spent reading and bytes read from side inputs is
        # behind an experiment flag to test its performance impact.
        if 'sideinput_io_metrics' in experiments:
          si_counter = opcounters.SideInputReadCounter(
              self.counter_factory,
              self.state_sampler,
              declaring_step=self.name_context.step_name,
              # Inputs are 1-indexed, so we add 1 to i in the side input id
              input_index=i + 1)
        else:
          si_counter = opcounters.TransformIOCounter()
      iterator_fn = sideinputs.get_iterator_fn_for_sources(
          sources, read_counter=si_counter)

      # Backwards compatibility for pre BEAM-733 SDKs.
      if isinstance(view_options, tuple):
        if view_class == pvalue.AsSingleton:
          has_default, default = view_options
          view_options = {'default': default} if has_default else {}
        else:
          view_options = {}

      yield apache_sideinputs.SideInputMap(
          view_class, view_options, sideinputs.EmulatedIterable(iterator_fn))
Example no. 35
 def test_runtime_value_provider_to(self):
     RuntimeValueProvider.set_runtime_options(None)
     rvp = RuntimeValueProvider('arg', 123, int)
     self.assertEqual(JsonValue(is_null=True), to_json_value(rvp))
     # Reset runtime options to avoid side-effects in other tests.
     RuntimeValueProvider.set_runtime_options(None)
Example no. 36
def create_harness(environment, dry_run=False):
    """Creates SDK Fn Harness."""
    if 'LOGGING_API_SERVICE_DESCRIPTOR' in environment:
        try:
            logging_service_descriptor = endpoints_pb2.ApiServiceDescriptor()
            text_format.Merge(environment['LOGGING_API_SERVICE_DESCRIPTOR'],
                              logging_service_descriptor)

            # Send all logs to the runner.
            fn_log_handler = FnApiLogRecordHandler(logging_service_descriptor)
            # TODO(BEAM-5468): This should be picked up from pipeline options.
            logging.getLogger().setLevel(logging.INFO)
            logging.getLogger().addHandler(fn_log_handler)
            _LOGGER.info('Logging handler created.')
        except Exception:
            _LOGGER.error(
                "Failed to set up logging handler, continuing without.",
                exc_info=True)
            fn_log_handler = None
    else:
        fn_log_handler = None

    pipeline_options_dict = _load_pipeline_options(
        environment.get('PIPELINE_OPTIONS'))
    # These are used for dataflow templates.
    RuntimeValueProvider.set_runtime_options(pipeline_options_dict)
    sdk_pipeline_options = PipelineOptions.from_dictionary(
        pipeline_options_dict)
    filesystems.FileSystems.set_options(sdk_pipeline_options)

    if 'SEMI_PERSISTENT_DIRECTORY' in environment:
        semi_persistent_directory = environment['SEMI_PERSISTENT_DIRECTORY']
    else:
        semi_persistent_directory = None

    _LOGGER.info('semi_persistent_directory: %s', semi_persistent_directory)
    _worker_id = environment.get('WORKER_ID', None)

    try:
        _load_main_session(semi_persistent_directory)
    except CorruptMainSessionException:
        exception_details = traceback.format_exc()
        _LOGGER.error('Could not load main session: %s',
                      exception_details,
                      exc_info=True)
        raise
    except Exception:  # pylint: disable=broad-except
        exception_details = traceback.format_exc()
        _LOGGER.error('Could not load main session: %s',
                      exception_details,
                      exc_info=True)

    _LOGGER.info('Pipeline_options: %s',
                 sdk_pipeline_options.get_all_options(drop_default=True))
    control_service_descriptor = endpoints_pb2.ApiServiceDescriptor()
    status_service_descriptor = endpoints_pb2.ApiServiceDescriptor()
    text_format.Merge(environment['CONTROL_API_SERVICE_DESCRIPTOR'],
                      control_service_descriptor)
    if 'STATUS_API_SERVICE_DESCRIPTOR' in environment:
        text_format.Merge(environment['STATUS_API_SERVICE_DESCRIPTOR'],
                          status_service_descriptor)
    # TODO(robertwb): Support authentication.
    assert not control_service_descriptor.HasField('authentication')

    experiments = sdk_pipeline_options.view_as(DebugOptions).experiments or []
    enable_heap_dump = 'enable_heap_dump' in experiments
    if dry_run:
        return
    sdk_harness = SdkHarness(
        control_address=control_service_descriptor.url,
        status_address=status_service_descriptor.url,
        worker_id=_worker_id,
        state_cache_size=_get_state_cache_size(experiments),
        data_buffer_time_limit_ms=_get_data_buffer_time_limit_ms(experiments),
        profiler_factory=profiler.Profile.factory_from_options(
            sdk_pipeline_options.view_as(ProfilingOptions)),
        enable_heap_dump=enable_heap_dump)
    return fn_log_handler, sdk_harness