コード例 #1
0
  def testSqlSliceKeyExtractorWithMultipleSchema(self):
    eval_config = config_pb2.EvalConfig(slicing_specs=[
        config_pb2.SlicingSpec(slice_keys_sql="""
        SELECT
          STRUCT(fixed_string)
        FROM
          example.fixed_string,
          example.fixed_int
        WHERE fixed_int = 1
        """)
    ])
    slice_key_extractor = sql_slice_key_extractor.SqlSliceKeyExtractor(
        eval_config)

    record_batch_1 = pa.RecordBatch.from_arrays([
        pa.array([[1], [1], [2]], type=pa.list_(pa.int64())),
        pa.array([[1.0], [1.0], [2.0]], type=pa.list_(pa.float64())),
        pa.array([['fixed_string1'], ['fixed_string2'], ['fixed_string3']],
                 type=pa.list_(pa.string())),
    ], ['fixed_int', 'fixed_float', 'fixed_string'])
    record_batch_2 = pa.RecordBatch.from_arrays([
        pa.array([[1], [1], [2]], type=pa.list_(pa.int64())),
        pa.array([[1.0], [1.0], [2.0]], type=pa.list_(pa.float64())),
        pa.array([['fixed_string1'], ['fixed_string2'], ['fixed_string3']],
                 type=pa.list_(pa.string())),
        pa.array([['extra_field1'], ['extra_field2'], ['extra_field3']],
                 type=pa.list_(pa.string())),
    ], ['fixed_int', 'fixed_float', 'fixed_string', 'extra_field'])

    with beam.Pipeline() as pipeline:
      # pylint: disable=no-value-for-parameter
      result = (
          pipeline
          | 'Create' >> beam.Create([record_batch_1, record_batch_2],
                                    reshuffle=False)
          | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts()
          | slice_key_extractor.stage_name >> slice_key_extractor.ptransform)

      # pylint: enable=no-value-for-parameter

      def check_result(got):
        try:
          self.assertLen(got, 2)
          self.assertEqual(got[0][constants.SLICE_KEY_TYPES_KEY],
                           [[(('fixed_string', 'fixed_string1'),)],
                            [(('fixed_string', 'fixed_string2'),)], []])
          self.assertEqual(got[1][constants.SLICE_KEY_TYPES_KEY],
                           [[(('fixed_string', 'fixed_string1'),)],
                            [(('fixed_string', 'fixed_string2'),)], []])

        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(result, check_result)
コード例 #2
0
    def testSqlSliceKeyExtractorWithEmptySqlConfig(self):
        eval_config = config_pb2.EvalConfig()
        feature_extractor = features_extractor.FeaturesExtractor(
            eval_config=eval_config)
        slice_key_extractor = sql_slice_key_extractor.SqlSliceKeyExtractor(
            eval_config)

        tfx_io = tf_example_record.TFExampleBeamRecord(
            physical_format='inmem',
            telemetry_descriptors=['test', 'component'],
            schema=_SCHEMA,
            raw_record_column_name=constants.ARROW_INPUT_COLUMN)
        examples = [
            self._makeExample(fixed_int=1,
                              fixed_float=1.0,
                              fixed_string='fixed_string1'),
            self._makeExample(fixed_int=1,
                              fixed_float=1.0,
                              fixed_string='fixed_string2'),
            self._makeExample(fixed_int=2,
                              fixed_float=0.0,
                              fixed_string='fixed_string3')
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples], reshuffle=False)
                | 'BatchExamples' >> tfx_io.BeamSource(batch_size=3)
                |
                'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts()
                | feature_extractor.stage_name >> feature_extractor.ptransform
                | slice_key_extractor.stage_name >>
                slice_key_extractor.ptransform)

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    np.testing.assert_equal(
                        got[0][constants.SLICE_KEY_TYPES_KEY],
                        types.VarLenTensorValue.from_dense_rows(
                            [np.array([]),
                             np.array([]),
                             np.array([])]))

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result)
コード例 #3
0
  def testSqlSliceKeyExtractor(self):
    eval_config = config_pb2.EvalConfig(slicing_specs=[
        config_pb2.SlicingSpec(slice_keys_sql="""
        SELECT
          STRUCT(fixed_string)
        FROM
          example.fixed_string,
          example.fixed_int
        WHERE fixed_int = 1
        """)
    ])
    slice_key_extractor = sql_slice_key_extractor.SqlSliceKeyExtractor(
        eval_config)

    tfx_io = tf_example_record.TFExampleBeamRecord(
        physical_format='inmem',
        telemetry_descriptors=['test', 'component'],
        schema=_SCHEMA,
        raw_record_column_name=constants.ARROW_INPUT_COLUMN)
    examples = [
        self._makeExample(
            fixed_int=1, fixed_float=1.0, fixed_string='fixed_string1'),
        self._makeExample(
            fixed_int=1, fixed_float=1.0, fixed_string='fixed_string2'),
        self._makeExample(
            fixed_int=2, fixed_float=0.0, fixed_string='fixed_string3')
    ]

    with beam.Pipeline() as pipeline:
      # pylint: disable=no-value-for-parameter
      result = (
          pipeline
          | 'Create' >> beam.Create([e.SerializeToString() for e in examples],
                                    reshuffle=False)
          | 'BatchExamples' >> tfx_io.BeamSource(batch_size=3)
          | 'InputsToExtracts' >> model_eval_lib.BatchedInputsToExtracts()
          | slice_key_extractor.stage_name >> slice_key_extractor.ptransform)

      # pylint: enable=no-value-for-parameter

      def check_result(got):
        try:
          self.assertLen(got, 1)
          self.assertEqual(got[0][constants.SLICE_KEY_TYPES_KEY],
                           [[(('fixed_string', 'fixed_string1'),)],
                            [(('fixed_string', 'fixed_string2'),)], []])

        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(result, check_result)
コード例 #4
0
    def testSqlSliceKeyExtractorWithTransformedFeatures(self):
        eval_config = config_pb2.EvalConfig(
            model_specs=[
                config_pb2.ModelSpec(name='model1'),
                config_pb2.ModelSpec(name='model2')
            ],
            slicing_specs=[
                config_pb2.SlicingSpec(slice_keys_sql="""
            SELECT
              STRUCT(fixed_string)
            FROM
              example.fixed_string,
              example.fixed_int
            WHERE fixed_int = 1
            """)
            ])
        slice_key_extractor = sql_slice_key_extractor.SqlSliceKeyExtractor(
            eval_config)

        extracts = {
            constants.FEATURES_KEY: {
                'fixed_int': np.array([1, 1, 2]),
            },
            constants.TRANSFORMED_FEATURES_KEY: {
                'model1': {
                    'fixed_int':
                    np.array([1, 1, 2]),
                    'fixed_float':
                    np.array([1.0, 1.0, 0.0]),
                    'fixed_string':
                    np.array(
                        ['fixed_string1', 'fixed_string2', 'fixed_string3'])
                },
                'model2': {
                    'fixed_int':
                    np.array([1, 1, 2]),
                    'fixed_string':
                    np.array(
                        ['fixed_string1', 'fixed_string2', 'fixed_string3'])
                },
            }
        }

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (pipeline
                      | 'CreateTestInput' >> beam.Create([extracts])
                      | slice_key_extractor.stage_name >>
                      slice_key_extractor.ptransform)

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    np.testing.assert_equal(
                        got[0][constants.SLICE_KEY_TYPES_KEY],
                        types.VarLenTensorValue.from_dense_rows([
                            slicer_lib.slice_keys_to_numpy_array([
                                (('fixed_string', 'fixed_string1'), )
                            ]),
                            slicer_lib.slice_keys_to_numpy_array([
                                (('fixed_string', 'fixed_string2'), )
                            ]),
                            np.array([])
                        ]))

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result)