示例#1
0
    def testBuildAnalysisTable(self):
        """Checks the columns BuildAnalysisTable materializes for one example.

        One serialized example is pushed through the transform. The single
        extract that comes back must hold the expected slice-key, feature and
        label columns with exact values, and must also contain the prediction
        columns and the slice keys column by name.
        """
        exported_model_dir = self._exportEvalSavedModel(
            linear_classifier.simple_linear_classifier)
        eval_shared_model = types.EvalSharedModel(
            model_path=exported_model_dir)

        example1 = self._makeExample(
            age=3.0, language='english', label=1.0, slice_key='first_slice')

        def check_result(got):
            self.assertEqual(1, len(got), 'got: %s' % got)
            only_extract = got[0]

            # Only MaterializedColumn values are of interest: they are the
            # ones emitted to tell downstream sink components to write the
            # data out to file.
            materialized = {
                key: value
                for key, value in only_extract.items()
                if isinstance(value, types.MaterializedColumn)
            }

            # (column name, expected value) pairs whose contents are checked.
            expected_values = (
                # Slice key
                ('features__slice_key', [b'first_slice']),
                # Features
                ('features__language', [b'english']),
                ('features__age', np.array([3.], dtype=np.float32)),
                # Label
                ('features__label', np.array([1.], dtype=np.float32)),
                ('labels', np.array([1.], dtype=np.float32)),
            )
            expected_columns = {
                name: types.MaterializedColumn(name=name, value=value)
                for name, value in expected_values
            }
            self._assertMaterializedColumns(materialized, expected_columns)

            # These columns only need to be present; values are not compared.
            required_names = [
                'predictions__logits', 'predictions__probabilities',
                'predictions__classes', 'predictions__logistic',
                'predictions__class_ids', constants.SLICE_KEYS_KEY
            ]
            self._assertMaterializedColumnsExist(materialized, required_names)

        with beam.Pipeline() as pipeline:
            serialized_inputs = [example1.SerializeToString()]
            result = (pipeline
                      | 'CreateInput' >> beam.Create(serialized_inputs)
                      | 'BuildTable' >> contrib.BuildAnalysisTable(
                          eval_shared_model=eval_shared_model))

            util.assert_that(result[constants.ANALYSIS_KEY], check_result)
    def testBuildAnalysisTableWithSlices(self):
        """Checks the slice keys column when slice specs are supplied.

        One serialized example is pushed through the transform together with
        three slice specs. The single extract that comes back must hold a
        slice keys column listing one key per spec, and must also contain
        the prediction columns by name.
        """
        exported_model_dir = self._exportEvalSavedModel(
            linear_classifier.simple_linear_classifier)
        eval_shared_model = model_eval_lib.default_eval_shared_model(
            eval_saved_model_path=exported_model_dir)

        example1 = self._makeExample(
            age=3.0, language='english', label=1.0, slice_key='first_slice')

        # A column-only spec, a feature-value spec, and a combination of both.
        by_age_column = slicer.SingleSliceSpec(columns=['age'])
        by_age_value = slicer.SingleSliceSpec(features=[('age', 3)])
        by_age_and_language = slicer.SingleSliceSpec(
            columns=['age'], features=[('language', 'english')])
        slice_spec = [by_age_column, by_age_value, by_age_and_language]

        def check_result(got):
            self.assertEqual(1, len(got), 'got: %s' % got)
            only_extract = got[0]

            # Only MaterializedColumn values are of interest: they are the
            # ones emitted to tell downstream sink components to write the
            # data out to file.
            materialized = {
                key: value
                for key, value in only_extract.items()
                if isinstance(value, types.MaterializedColumn)
            }

            expected_slice_keys = types.MaterializedColumn(
                name=constants.SLICE_KEYS_KEY,
                value=[
                    b'age:3.0',
                    b'age:3',
                    b'age_X_language:3.0_X_english',
                ])
            self._assertMaterializedColumns(
                materialized, {constants.SLICE_KEYS_KEY: expected_slice_keys})

            # These columns only need to be present; values are not compared.
            required_names = [
                'predictions__logits', 'predictions__probabilities',
                'predictions__classes', 'predictions__logistic',
                'predictions__class_ids'
            ]
            self._assertMaterializedColumnsExist(materialized, required_names)

        with beam.Pipeline() as pipeline:
            serialized_inputs = [example1.SerializeToString()]
            result = (pipeline
                      | 'CreateInput' >> beam.Create(serialized_inputs)
                      | 'BuildTable' >> contrib.BuildAnalysisTable(
                          eval_shared_model, slice_spec))

            util.assert_that(result[constants.ANALYSIS_KEY], check_result)