コード例 #1
0
  def testSliceOneSlice(self):
    """Fans each input out over the overall `()` slice and its gender slice."""
    with beam.Pipeline() as pipeline:
      fpls = create_fpls()
      metrics = (
          pipeline
          | 'CreateTestInput' >> beam.Create(fpls)
          | 'WrapFpls' >> beam.Map(wrap_fpl)
          | 'ExtractSlices' >> slice_key_extractor._ExtractSliceKeys([
              slicer.SingleSliceSpec(),
              slicer.SingleSliceSpec(columns=['gender'])
          ])
          | 'FanoutSlices' >> slicer.FanoutSlices())

      def check_result(got):
        try:
          self.assertEqual(4, len(got), 'got: %s' % got)
          expected_result = [
              ((), wrap_fpl(fpls[0])),
              ((), wrap_fpl(fpls[1])),
              ((('gender', 'f'),), wrap_fpl(fpls[0])),
              ((('gender', 'm'),), wrap_fpl(fpls[1])),
          ]
          # Beam does not guarantee the order of `got`. Both overall-slice
          # entries share the key `()`, so a stable sort on the key alone
          # would keep their runner-dependent relative order and could make
          # this test flaky. Compare as multisets instead; assertCountEqual
          # also works for unhashable elements such as these extract dicts.
          self.assertCountEqual(got, expected_result)
        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(metrics, check_result)
コード例 #2
0
  def testSliceDefaultSlice(self):
    """Fans each input out over only the overall `()` slice."""
    with beam.Pipeline() as pipeline:
      fpls = create_fpls()

      metrics = (
          pipeline
          | 'CreateTestInput' >> beam.Create(fpls)
          | 'WrapFpls' >> beam.Map(wrap_fpl)
          | 'ExtractSlices' >> slice_key_extractor._ExtractSliceKeys(
              [slicer.SingleSliceSpec()])
          | 'FanoutSlices' >> slicer.FanoutSlices())

      def check_result(got):
        try:
          self.assertEqual(2, len(got), 'got: %s' % got)
          expected_result = [
              ((), wrap_fpl(fpls[0])),
              ((), wrap_fpl(fpls[1])),
          ]
          # Beam output order is unspecified, so compare as multisets. This
          # replaces a hand-written "either order" boolean check whose
          # failure message carried no information about the mismatch.
          self.assertCountEqual(got, expected_result)
        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(metrics, check_result)
コード例 #3
0
def AutoSliceKeyExtractor(
    statistics: statistics_pb2.DatasetFeatureStatisticsList,
    materialize: Optional[bool] = True
) -> extractor.Extractor:
  """Creates an extractor for automatically extracting slice keys.

  The slice specs are derived from `statistics` via `slice_spec_from_stats`
  and then passed to `slice_key_extractor._ExtractSliceKeys`.

  The incoming Extracts must contain a FeaturesPredictionsLabels extract keyed
  by tfma.FEATURES_PREDICTIONS_LABELS_KEY. Typically this will be obtained by
  calling the PredictExtractor.

  The extractor's PTransform yields a copy of the Extracts input with an
  additional extract pointing at the list of SliceKeyType values keyed by
  tfma.SLICE_KEY_TYPES_KEY. If materialize is True then a materialized version
  of the slice keys will be added under the key tfma.SLICE_KEYS_KEY.

  Args:
    statistics: Data statistics from which the slice specs are derived.
    materialize: True to add MaterializedColumn entries for the slice keys.

  Returns:
    Extractor for slice keys, registered under SLICE_KEY_EXTRACTOR_STAGE_NAME.
  """
  slice_spec = slice_spec_from_stats(statistics)
  return extractor.Extractor(
      stage_name=SLICE_KEY_EXTRACTOR_STAGE_NAME,
      ptransform=slice_key_extractor._ExtractSliceKeys(  # pylint: disable=protected-access
          slice_spec, materialize=materialize))
コード例 #4
0
    def testSliceOnMetaFeature(self):
        """Slices on `num_interests`, a feature added by the meta extractor."""
        # The point is to confirm that a freshly added meta feature can be
        # sliced on, which is why slicing is included in this pipeline.
        with beam.Pipeline() as pipeline:
            fpls = create_fpls()
            specs = [
                slicer.SingleSliceSpec(),
                slicer.SingleSliceSpec(columns=['num_interests']),
            ]
            metrics = (
                pipeline
                | 'CreateTestInput' >> beam.Create(fpls)
                | 'WrapFpls' >> beam.Map(wrap_fpl)
                | 'ExtractInterestsNum' >>
                meta_feature_extractor.ExtractMetaFeature(get_num_interests)
                | 'ExtractSlices' >> slice_key_extractor._ExtractSliceKeys(
                    specs)
                | 'FanoutSlices' >> slicer.FanoutSlices())

            # Two overall-slice keys plus one num_interests key per input.
            wanted_keys = sorted([
                (),
                (),
                (('num_interests', 1),),
                (('num_interests', 2),),
            ])

            def check_result(got):
                try:
                    self.assertEqual(4, len(got), 'got: %s' % got)
                    # Only the slice keys are compared; the paired extracts
                    # are deliberately ignored.
                    seen_keys = sorted(key for key, _unused in got)
                    self.assertEqual(seen_keys, wanted_keys)
                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(metrics, check_result)
コード例 #5
0
  def testSliceKeys(self):
    """Checks that each extract gains its SLICE_KEY_TYPES_KEY slice keys."""
    with beam.Pipeline() as pipeline:
      fpls = create_fpls()
      slice_keys_extracts = (
          pipeline
          | 'CreateTestInput' >> beam.Create(fpls)
          | 'WrapFpls' >> beam.Map(wrap_fpl)
          | 'ExtractSlices' >> slice_key_extractor._ExtractSliceKeys([
              slicer.SingleSliceSpec(),
              slicer.SingleSliceSpec(columns=['gender'])
          ]))

      def check_result(got):
        try:
          self.assertEqual(2, len(got), 'got: %s' % got)
          # Already sorted here, so no second sort is needed below.
          expected_results = sorted([[(), (('gender', 'f'),)],
                                     [(), (('gender', 'm'),)]])
          got_results = []
          for item in got:
            # assertIn names the missing key on failure; assertTrue(k in d)
            # would only report "False is not true".
            self.assertIn(constants.SLICE_KEY_TYPES_KEY, item)
            # Sort each list so the check does not depend on the order in
            # which slice keys are emitted within one extract.
            got_results.append(sorted(item[constants.SLICE_KEY_TYPES_KEY]))
          # Sort across extracts because Beam output order is unspecified.
          self.assertEqual(sorted(got_results), expected_results)
        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(slice_keys_extracts, check_result)
コード例 #6
0
  def testMaterializedSliceKeys(self):
    """Checks materialize=True adds a SLICE_KEYS_KEY MaterializedColumn."""
    with beam.Pipeline() as pipeline:
      fpls = create_fpls()
      slice_keys_extracts = (
          pipeline
          | 'CreateTestInput' >> beam.Create(fpls)
          | 'WrapFpls' >> beam.Map(wrap_fpl)
          | 'ExtractSlices' >> slice_key_extractor._ExtractSliceKeys(
              [
                  slicer.SingleSliceSpec(),
                  slicer.SingleSliceSpec(columns=['gender'])
              ],
              materialize=True))

      def check_result(got):
        try:
          self.assertEqual(2, len(got), 'got: %s' % got)
          # Already sorted here, so no second sort is needed below.
          expected_results = sorted([
              types.MaterializedColumn(
                  name=constants.SLICE_KEYS_KEY,
                  value=[b'Overall', b'gender:f']),
              types.MaterializedColumn(
                  name=constants.SLICE_KEYS_KEY,
                  value=[b'Overall', b'gender:m'])
          ])
          got_results = []
          for item in got:
            # assertIn names the missing key on failure; assertTrue(k in d)
            # would only report "False is not true".
            self.assertIn(constants.SLICE_KEYS_KEY, item)
            got_results.append(item[constants.SLICE_KEYS_KEY])
          # Sort across extracts because Beam output order is unspecified.
          self.assertEqual(sorted(got_results), expected_results)
        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(slice_keys_extracts, check_result)