Beispiel #1
0
    def test_construct_with_custom_config(self):
        """Verifies custom_config is kept in exec_properties and parses back."""
        source_artifact = types.TfxArtifact(type_name='ExternalPath')
        expected_config = example_gen_pb2.CustomConfig(
            custom_config=any_pb2.Any())
        input_channel = channel.as_channel([source_artifact])
        gen = component.FileBasedExampleGen(
            input_base=input_channel,
            custom_config=expected_config,
            executor_class=TestExampleGenExecutor)

        # The stored value is parsed as JSON back into a fresh proto so the
        # comparison is message-to-message.
        serialized = gen.exec_properties['custom_config']
        parsed_config = json_format.Parse(serialized,
                                          example_gen_pb2.CustomConfig())
        self.assertEqual(expected_config, parsed_config)
Beispiel #2
0
    def testConstructWithCustomConfig(self):
        """Verifies the custom config round-trips through exec_properties."""
        expected = example_gen_pb2.CustomConfig(custom_config=any_pb2.Any())
        gen = component.FileBasedExampleGen(
            input_base='path',
            custom_config=expected,
            custom_executor_spec=executor_spec.ExecutorClassSpec(
                TestExampleGenExecutor))

        # Parse the stored JSON into an empty CustomConfig and compare protos.
        serialized = gen.exec_properties['custom_config']
        actual = json_format.Parse(serialized, example_gen_pb2.CustomConfig())
        self.assertEqual(expected, actual)
Beispiel #3
0
  def testConstructWithCustomConfig(self):
    """Verifies the custom config round-trips through exec_properties."""
    external_artifact = standard_artifacts.ExternalArtifact()
    expected_config = example_gen_pb2.CustomConfig(custom_config=any_pb2.Any())
    input_channel = channel_utils.as_channel([external_artifact])
    class_spec = executor_spec.ExecutorClassSpec(TestExampleGenExecutor)
    gen = component.FileBasedExampleGen(
        input_base=input_channel,
        custom_config=expected_config,
        custom_executor_spec=class_spec)

    # Rebuild a proto from the stored JSON before comparing.
    roundtripped = json_format.Parse(
        gen.exec_properties['custom_config'], example_gen_pb2.CustomConfig())
    self.assertEqual(expected_config, roundtripped)
Beispiel #4
0
  def testConstructWithCustomConfig(self):
    """Verifies the custom config round-trips through exec_properties."""
    expected_config = example_gen_pb2.CustomConfig(custom_config=any_pb2.Any())
    beam_spec = executor_spec.BeamExecutorSpec(TestExampleGenExecutor)
    gen = component.FileBasedExampleGen(
        input_base='path',
        custom_config=expected_config,
        custom_executor_spec=beam_spec)

    # Fill an empty proto from the stored JSON, then compare message-to-message.
    config_key = standard_component_specs.CUSTOM_CONFIG_KEY
    roundtripped = example_gen_pb2.CustomConfig()
    proto_utils.json_to_proto(gen.exec_properties[config_key], roundtripped)
    self.assertEqual(expected_config, roundtripped)
Beispiel #5
0
def _PrestoToExample(  # pylint: disable=invalid-name
        pipeline: beam.Pipeline,
        input_dict: Dict[Text, List[types.Artifact]],  # pylint: disable=unused-argument
        exec_properties: Dict[Text, Any],
        split_pattern: Text) -> beam.pvalue.PCollection:
    """Read from Presto and transform to TF examples.

    Args:
      pipeline: beam pipeline.
      input_dict: Input dict from input key to a list of Artifacts. Unused.
      exec_properties: A dict of execution properties; its 'custom_config'
        entry is parsed as a JSON-serialized example_gen_pb2.CustomConfig.
      split_pattern: Split.pattern in Input config, a Presto sql string.

    Returns:
      PCollection of TF examples.
    """
    # The connection settings travel as an Any payload inside CustomConfig.
    wrapper = example_gen_pb2.CustomConfig()
    json_format.Parse(exec_properties['custom_config'], wrapper)
    connection_params = presto_config_pb2.PrestoConnConfig()
    wrapper.custom_config.Unpack(connection_params)
    client = _deserialize_conn_config(connection_params)

    queries = pipeline | 'Query' >> beam.Create([split_pattern])
    rows = queries | 'QueryTable' >> beam.ParDo(_ReadPrestoDoFn(client))
    return rows | 'ToTFExample' >> beam.Map(_row_to_example)
Beispiel #6
0
    def _extract_conn_config(self, custom_config):
        """Returns the PrestoConnConfig packed inside a JSON CustomConfig.

        Args:
          custom_config: Value handed to proto_utils.json_to_proto to fill an
            example_gen_pb2.CustomConfig.

        Returns:
          The presto_config_pb2.PrestoConnConfig unpacked from the Any field.
        """
        wrapper = example_gen_pb2.CustomConfig()
        proto_utils.json_to_proto(custom_config, wrapper)

        presto_settings = presto_config_pb2.PrestoConnConfig()
        wrapper.custom_config.Unpack(presto_settings)
        return presto_settings
Beispiel #7
0
    def testPrestoToExample(self):
        """Checks the example emitted by _PrestoToExample in a Beam pipeline."""
        with beam.Pipeline() as pipeline:
            # Both configs are passed as JSON-serialized default protos.
            exec_properties = {
                'input_config': json_format.MessageToJson(
                    example_gen_pb2.Input(),
                    preserving_proto_field_name=True),
                'custom_config': json_format.MessageToJson(
                    example_gen_pb2.CustomConfig(),
                    preserving_proto_field_name=True),
            }
            to_example = executor._PrestoToExample(
                exec_properties=exec_properties,
                split_pattern='SELECT i, f, s FROM `fake`')
            examples = pipeline | 'ToTFExample' >> to_example

            # Expected single row: i=1 (int64), f=2.0 (float), s='abc' (bytes).
            expected_features = {
                'i': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[1])),
                'f': tf.train.Feature(
                    float_list=tf.train.FloatList(value=[2.0])),
                's': tf.train.Feature(
                    bytes_list=tf.train.BytesList(
                        value=[tf.compat.as_bytes('abc')])),
            }
            expected_example = tf.train.Example(
                features=tf.train.Features(feature=expected_features))
            util.assert_that(examples, util.equal_to([expected_example]))
Beispiel #8
0
  def _extract_conn_config(self, custom_config):
    """Returns the PrestoConnConfig packed inside a JSON CustomConfig.

    Args:
      custom_config: JSON text parsed into an example_gen_pb2.CustomConfig.

    Returns:
      The presto_config_pb2.PrestoConnConfig unpacked from the Any field.
    """
    wrapper = example_gen_pb2.CustomConfig()
    json_format.Parse(custom_config, wrapper)

    presto_settings = presto_config_pb2.PrestoConnConfig()
    wrapper.custom_config.Unpack(presto_settings)
    return presto_settings
Beispiel #9
0
def load_custom_config(exec_properties):
    """Decode the BigQuery seed carried in the 'custom_config' exec property.

    Args:
      exec_properties: Mapping whose 'custom_config' entry is JSON text for an
        example_gen_pb2.CustomConfig wrapping a BigQuerySeed message.

    Returns:
      The result of json_utils.loads applied to the seed field.
    """
    wrapper = example_gen_pb2.CustomConfig()
    json_format.Parse(exec_properties['custom_config'], wrapper)

    # The seed message is stored as an Any payload inside the wrapper.
    seed_message = bigquery_example_gen_pb2.BigQuerySeed()
    wrapper.custom_config.Unpack(seed_message)

    serialized_seed = seed_message.seed
    return json_utils.loads(serialized_seed)
Beispiel #10
0
    def __init__(self,
                 query: Optional[Text] = None,
                 elwc_config: Optional[elwc_config_pb2.ElwcConfig] = None,
                 input_config: Optional[example_gen_pb2.Input] = None,
                 output_config: Optional[example_gen_pb2.Output] = None,
                 example_artifacts: Optional[types.Channel] = None,
                 instance_name: Optional[Text] = None):
        """Constructs a BigQueryElwcExampleGen component.

        Args:
          query: BigQuery sql string, query result will be treated as a single
            split, can be overwritten by input_config.
          elwc_config: The elwc config contains a list of context feature
            fields. The fields are used to build context feature. Examples
            with the same context feature will be converted to an
            ELWC(ExampleListWithContext) instance. For example, when there are
            two examples with the same context field, the two examples will be
            integrated into one ELWC instance.
          input_config: An example_gen_pb2.Input instance with Split.pattern
            as BigQuery sql string. If set, it overwrites the 'query' arg, and
            allows different queries per split. If any field is provided as a
            RuntimeParameter, input_config should be constructed as a dict
            with the same field names as Input proto message.
          output_config: An example_gen_pb2.Output instance, providing output
            configuration. If unset, default splits will be 'train' and 'eval'
            with size 2:1. If any field is provided as a RuntimeParameter,
            output_config should be constructed as a dict with the same field
            names as Output proto message.
          example_artifacts: Optional channel of 'ExamplesPath' for output
            train and eval examples.
          instance_name: Optional unique instance name. Necessary if multiple
            BigQueryExampleGen components are declared in the same pipeline.

        Raises:
          RuntimeError: If not exactly one of query and input_config is set,
            or if elwc_config is missing.
        """
        # Validate arguments up front: query XOR input_config, plus elwc_config.
        query_given = bool(query)
        input_config_given = bool(input_config)
        if query_given == input_config_given:
            raise RuntimeError(
                'Exactly one of query and input_config should be set.')
        if not elwc_config:
            raise RuntimeError(
                'elwc_config is required for BigQueryToElwcExampleGen.')

        if not input_config:
            input_config = utils.make_default_input_config(query)

        # The ELWC settings are passed to the base class as an Any payload.
        wrapped_elwc_config = example_gen_pb2.CustomConfig()
        wrapped_elwc_config.custom_config.Pack(elwc_config)

        super(BigQueryToElwcExampleGen, self).__init__(
            input_config=input_config,
            output_config=output_config,
            output_data_format=example_gen_pb2.FORMAT_PROTO,
            custom_config=wrapped_elwc_config,
            example_artifacts=example_artifacts,
            instance_name=instance_name)
Beispiel #11
0
  def testDo(self):
    """Runs the executor end to end and checks the train/eval output files."""
    base_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                              self.get_temp_dir())
    output_data_dir = os.path.join(base_dir, self._testMethodName)

    # Output artifact passed to the executor.
    examples = standard_artifacts.Examples()
    examples.uri = output_data_dir
    output_dict = {'examples': [examples]}

    # Execution properties: one input split and a 2:1 train/eval hash split.
    input_config = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(
            name='bq', pattern='SELECT i, f, s FROM `fake`'),
    ])
    split_config = example_gen_pb2.SplitConfig(splits=[
        example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
        example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
    ])
    output_config = example_gen_pb2.Output(split_config=split_config)
    exec_properties = {
        'input_config':
            json_format.MessageToJson(
                input_config, preserving_proto_field_name=True),
        'custom_config':
            json_format.MessageToJson(example_gen_pb2.CustomConfig()),
        'output_config':
            json_format.MessageToJson(output_config),
    }

    # Run the executor with no input artifacts.
    executor.Executor().Do({}, output_dict, exec_properties)

    expected_split_names = artifact_utils.encode_split_names(['train', 'eval'])
    self.assertEqual(expected_split_names, examples.split_names)

    # Both splits must exist, and train must be larger than eval.
    shard_name = 'data_tfrecord-00000-of-00001.gz'
    train_output_file = os.path.join(examples.uri, 'train', shard_name)
    eval_output_file = os.path.join(examples.uri, 'eval', shard_name)
    self.assertTrue(tf.io.gfile.exists(train_output_file))
    self.assertTrue(tf.io.gfile.exists(eval_output_file))
    train_size = tf.io.gfile.GFile(train_output_file).size()
    eval_size = tf.io.gfile.GFile(eval_output_file).size()
    self.assertGreater(train_size, eval_size)
Beispiel #12
0
    def __init__(self,
                 conn_config: presto_config_pb2.PrestoConnConfig,
                 query: Optional[Text] = None,
                 input_config: Optional[example_gen_pb2.Input] = None,
                 output_config: Optional[example_gen_pb2.Output] = None,
                 example_artifacts: Optional[channel.Channel] = None,
                 name: Optional[Text] = None):
        """Constructs a PrestoExampleGen component.

        Args:
          conn_config: Parameters for Presto connection client.
          query: Presto sql string, query result will be treated as a single
            split, can be overwritten by input_config.
          input_config: An example_gen_pb2.Input instance with Split.pattern
            as Presto sql string. If set, it overwrites the 'query' arg, and
            allows different queries per split.
          output_config: An example_gen_pb2.Output instance, providing output
            configuration. If unset, default splits will be 'train' and 'eval'
            with size 2:1.
          example_artifacts: Optional channel of 'ExamplesPath' for output
            train and eval examples.
          name: Optional unique name. Necessary if multiple PrestoExampleGen
            components are declared in the same pipeline.

        Raises:
          RuntimeError: If not exactly one of query and input_config is set,
            or if the host field of conn_config is empty.
        """
        query_given = bool(query)
        input_config_given = bool(input_config)
        if query_given == input_config_given:
            raise RuntimeError(
                'Exactly one of query and input_config should be set.')
        if not conn_config.host:
            raise RuntimeError(
                'Required host field in connection config should be set.')

        if not input_config:
            input_config = utils.make_default_input_config(query)

        # The connection settings are passed to the base class as an Any
        # payload inside CustomConfig.
        wrapped_conn_config = example_gen_pb2.CustomConfig()
        wrapped_conn_config.custom_config.Pack(conn_config)

        if not output_config:
            output_config = utils.make_default_output_config(input_config)

        super(PrestoExampleGen, self).__init__(
            input_config=input_config,
            output_config=output_config,
            custom_config=wrapped_conn_config,
            component_name='PrestoExampleGen',
            example_artifacts=example_artifacts,
            name=name)
Beispiel #13
0
    def testDo(self):
        """Runs the executor end to end and checks the train/eval outputs."""
        base_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                  self.get_temp_dir())
        output_data_dir = os.path.join(base_dir, self._testMethodName)

        # One output artifact per split, each with its own sub-directory.
        train_examples = types.TfxArtifact(
            type_name='ExamplesPath', split='train')
        train_examples.uri = os.path.join(output_data_dir, 'train')
        eval_examples = types.TfxArtifact(
            type_name='ExamplesPath', split='eval')
        eval_examples.uri = os.path.join(output_data_dir, 'eval')
        output_dict = {'examples': [train_examples, eval_examples]}

        # Execution properties: one input split, a 2:1 train/eval hash split.
        input_config = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(
                name='bq', pattern='SELECT i, f, s FROM `fake`'),
        ])
        split_config = example_gen_pb2.SplitConfig(splits=[
            example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=2),
            example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1),
        ])
        output_config = example_gen_pb2.Output(split_config=split_config)
        exec_properties = {
            'input_config':
                json_format.MessageToJson(input_config),
            'custom_config':
                json_format.MessageToJson(example_gen_pb2.CustomConfig()),
            'output_config':
                json_format.MessageToJson(output_config),
        }

        # Run the executor with no input artifacts.
        executor.Executor().Do({}, output_dict, exec_properties)

        # Both splits must exist, and train must be larger than eval.
        shard_name = 'data_tfrecord-00000-of-00001.gz'
        train_output_file = os.path.join(train_examples.uri, shard_name)
        eval_output_file = os.path.join(eval_examples.uri, shard_name)
        self.assertTrue(tf.gfile.Exists(train_output_file))
        self.assertTrue(tf.gfile.Exists(eval_output_file))
        train_size = tf.gfile.GFile(train_output_file).size()
        eval_size = tf.gfile.GFile(eval_output_file).size()
        self.assertGreater(train_size, eval_size)
Beispiel #14
0
def _BigQueryToElwc(pipeline: beam.Pipeline, exec_properties: Dict[str, Any],
                    split_pattern: str) -> beam.pvalue.PCollection:
  """Read from BigQuery and transform to ExampleListWithContext.

  When a field has no value in BigQuery, a feature with no value will be
  generated in the tf.train.Features. This behavior is consistent with
  BigQueryExampleGen.

  Args:
    pipeline: beam pipeline.
    exec_properties: A dict of execution properties; its 'custom_config' entry
      is parsed as JSON for a CustomConfig wrapping an ElwcConfig.
    split_pattern: Split.pattern in Input config, a BigQuery sql string.

  Returns:
    PCollection of ExampleListWithContext.

  Raises:
    RuntimeError: Context features must be included in the queried result.
  """
  # Recover the ELWC settings from the Any payload inside CustomConfig.
  wrapper = example_gen_pb2.CustomConfig()
  json_format.Parse(exec_properties['custom_config'], wrapper)
  elwc_config = elwc_config_pb2.ElwcConfig()
  wrapper.custom_config.Unpack(elwc_config)

  # Run a zero-row query purely to learn each field's name and type.
  client = bigquery.Client()
  schema_probe = 'SELECT * FROM ({}) LIMIT 0'.format(split_pattern)
  results = client.query(schema_probe).result()
  context_feature_fields = set(elwc_config.context_feature_fields)
  type_map = {field.name: field.field_type for field in results.schema}

  # Every configured context field has to be present in the query output.
  missing_fields = context_feature_fields - set(type_map)
  if missing_fields:
    raise RuntimeError('Context feature fields are missing from the query.')

  rows = pipeline | 'ReadFromBigQuery' >> utils.ReadFromBigQuery(
      query=split_pattern)
  keyed_examples = rows | 'RowToContextFeatureAndExample' >> beam.ParDo(
      _RowToContextFeatureAndExample(type_map, context_feature_fields))
  grouped = keyed_examples | 'CombineByContext' >> beam.CombinePerKey(
      beam.combiners.ToListCombineFn())
  return grouped | 'ConvertContextAndExamplesToElwc' >> beam.Map(
      _ConvertContextAndExamplesToElwc)
Beispiel #15
0
 def testEnableCache(self):
     """Checks enable_cache is None when omitted and True when passed."""
     external_artifact = standard_artifacts.ExternalArtifact()
     shared_config = example_gen_pb2.CustomConfig(
         custom_config=any_pb2.Any())

     # Constructed without enable_cache: the attribute should be None.
     gen_without_flag = component.FileBasedExampleGen(
         input=channel_utils.as_channel([external_artifact]),
         custom_config=shared_config,
         custom_executor_spec=executor_spec.ExecutorClassSpec(
             TestExampleGenExecutor))
     self.assertEqual(None, gen_without_flag.enable_cache)

     # Constructed with enable_cache=True: the attribute should echo it.
     gen_with_flag = component.FileBasedExampleGen(
         input=channel_utils.as_channel([external_artifact]),
         custom_config=shared_config,
         custom_executor_spec=executor_spec.ExecutorClassSpec(
             TestExampleGenExecutor),
         enable_cache=True)
     self.assertEqual(True, gen_with_flag.enable_cache)
def build_query_seed():
    """Build the CustomConfig carrying the BigQuery seed pattern.

    A 'seed_pattern' RuntimeParameter is serialized with json_utils.dumps,
    stored in a BigQuerySeed message, packed into an Any under the
    'bigqueryseed.dstillery.com' type URL prefix, and wrapped in a
    CustomConfig.

    Returns:
      An example_gen_pb2.CustomConfig whose custom_config holds the packed
      BigQuerySeed.
    """
    seed_parameter = data_types.RuntimeParameter(
        name='seed_pattern', default="'%meni%','%avw3%'", ptype=str)

    seed_message = bigquery_example_gen_pb2.BigQuerySeed()
    seed_message.seed = json_utils.dumps(seed_parameter)

    packed_seed = any_pb2.Any()
    packed_seed.Pack(seed_message, 'bigqueryseed.dstillery.com')

    wrapped_config = example_gen_pb2.CustomConfig(custom_config=packed_seed)
    return wrapped_config
Beispiel #17
0
  def __init__(self,
               conn_config: presto_config_pb2.PrestoConnConfig,
               query: Optional[str] = None,
               input_config: Optional[example_gen_pb2.Input] = None,
               output_config: Optional[example_gen_pb2.Output] = None):
    """Constructs a PrestoExampleGen component.

    Args:
      conn_config: Parameters for Presto connection client.
      query: Presto sql string, query result will be treated as a single split,
        can be overwritten by input_config.
      input_config: An example_gen_pb2.Input instance with Split.pattern as
        Presto sql string. If set, it overwrites the 'query' arg, and allows
        different queries per split.
      output_config: An example_gen_pb2.Output instance, providing output
        configuration. If unset, default splits will be 'train' and 'eval' with
        size 2:1.

    Raises:
      RuntimeError: If not exactly one of query and input_config is set, or if
        the host field of conn_config is empty.
    """
    query_given = bool(query)
    input_config_given = bool(input_config)
    if query_given == input_config_given:
      raise RuntimeError('Exactly one of query and input_config should be set.')
    if not conn_config.host:
      raise RuntimeError(
          'Required host field in connection config should be set.')

    if not input_config:
      input_config = utils.make_default_input_config(query)

    # The connection settings are passed to the base class as an Any payload
    # inside CustomConfig.
    wrapped_conn_config = example_gen_pb2.CustomConfig()
    wrapped_conn_config.custom_config.Pack(conn_config)

    if not output_config:
      output_config = utils.make_default_output_config(input_config)

    super().__init__(
        input_config=input_config,
        output_config=output_config,
        custom_config=wrapped_conn_config)
Beispiel #18
0
def _PrestoToExample(  # pylint: disable=invalid-name
    pipeline: beam.Pipeline, exec_properties: Dict[str, Any],
    split_pattern: str) -> beam.pvalue.PCollection:
  """Read from Presto and transform to TF examples.

  Args:
    pipeline: beam pipeline.
    exec_properties: A dict of execution properties; its 'custom_config' entry
      is converted into an example_gen_pb2.CustomConfig.
    split_pattern: Split.pattern in Input config, a Presto sql string.

  Returns:
    PCollection of TF examples.
  """
  # The connection settings travel as an Any payload inside CustomConfig.
  wrapper = example_gen_pb2.CustomConfig()
  proto_utils.json_to_proto(exec_properties['custom_config'], wrapper)
  connection_params = presto_config_pb2.PrestoConnConfig()
  wrapper.custom_config.Unpack(connection_params)
  client = _deserialize_conn_config(connection_params)

  queries = pipeline | 'Query' >> beam.Create([split_pattern])
  rows = queries | 'QueryTable' >> beam.ParDo(_ReadPrestoDoFn(client))
  return rows | 'ToTFExample' >> beam.Map(_row_to_example)
Beispiel #19
0
    def testBigQueryToElwc(self, mock_client):
        """Checks the ELWC records produced by _BigQueryToElwc with a mock client."""
        # Make the mocked client's query(...).result() expose the test schema.
        query_result = (
            mock_client.return_value.query.return_value.result.return_value)
        query_result.schema = self._schema

        # Wrap the ELWC settings the same way the component does: as an Any
        # payload inside CustomConfig, serialized to JSON.
        elwc_config = elwc_config_pb2.ElwcConfig(
            context_feature_fields=['context_feature_1', 'context_feature_2'])
        wrapped_config = example_gen_pb2.CustomConfig()
        wrapped_config.custom_config.Pack(elwc_config)

        with beam.Pipeline() as pipeline:
            exec_properties = {
                '_beam_pipeline_args': [],
                'custom_config': json_format.MessageToJson(
                    wrapped_config, preserving_proto_field_name=True),
            }
            to_elwc = executor._BigQueryToElwc(
                exec_properties=exec_properties,
                split_pattern='SELECT context_feature_1, context_feature_2, '
                'feature_id_1, feature_id_2, feature_id_3 FROM `fake`')
            elwc_examples = pipeline | 'ToElwc' >> to_elwc

            expected_elwc_examples = [
                _ELWC_1, _ELWC_2, _ELWC_3, _ELWC_4, _ELWC_5
            ]
            util.assert_that(elwc_examples,
                             util.equal_to(expected_elwc_examples))