def testBuildDiagnosticsTable(self):
        """Checks the materialized columns BuildDiagnosticTable emits.

        A single serialized example is fed through the transform built on
        an exported simple linear classifier. The test expects exactly one
        output element whose MaterializedColumn entries carry the slice
        key, features and labels of that example, and which additionally
        contains the prediction and slice-key columns.
        """
        saved_model_dir = self._exportEvalSavedModel(
            linear_classifier.simple_linear_classifier)
        shared_model = types.EvalSharedModel(model_path=saved_model_dir)

        example = self._makeExample(age=3.0,
                                    language='english',
                                    label=1.0,
                                    slice_key='first_slice')

        def verify_table(got):
            """Asserts on the collected pipeline output."""
            self.assertEqual(1, len(got), 'got: %s' % got)
            _, extracts = got[0]

            # Keep only the MaterializedColumn values: these are what gets
            # emitted to tell downstream sink components to write the data
            # out to file.
            materialized = {
                key: column
                for key, column in extracts.items()
                if isinstance(column, types.MaterializedColumn)
            }

            expected_columns = {
                # Slice key
                'slice_key':
                types.MaterializedColumn(name='slice_key',
                                         value=[b'first_slice']),

                # Features
                'language':
                types.MaterializedColumn(name='language',
                                         value=[b'english']),
                'age':
                types.MaterializedColumn(
                    name='age',
                    value=np.array([3.], dtype=np.float32)),

                # Label
                'label':
                types.MaterializedColumn(
                    name='label',
                    value=np.array([1.], dtype=np.float32)),
                '__labels':
                types.MaterializedColumn(
                    name='__labels',
                    value=np.array([1.], dtype=np.float32)),
            }
            self._assertMaterializedColumns(materialized, expected_columns)

            # These columns are only checked for presence, not for value.
            present_columns = [
                'logits', 'probabilities', 'classes', 'logistic',
                'class_ids', 'materialized_slice_keys'
            ]
            self._assertMaterializedColumnsExist(materialized,
                                                 present_columns)

        with beam.Pipeline() as pipeline:
            table = (
                pipeline
                | 'CreateInput' >> beam.Create([example.SerializeToString()])
                | 'BuildTable' >> contrib.BuildDiagnosticTable(
                    eval_shared_model=shared_model))

            # Must be attached inside the `with` block so the check is part
            # of the pipeline graph before it runs on exit.
            util.assert_that(table, verify_table)
    def testBuildDiagnosticsTableWithSlices(self):
        """Checks the materialized slice keys when slice specs are given.

        A single serialized example is run through BuildDiagnosticTable
        together with three slice specs. The test expects exactly one
        output element whose 'materialized_slice_keys' column lists one
        key per spec, and which also contains the prediction columns.
        """
        saved_model_dir = self._exportEvalSavedModel(
            linear_classifier.simple_linear_classifier)
        shared_model = types.EvalSharedModel(model_path=saved_model_dir)

        example = self._makeExample(age=3.0,
                                    language='english',
                                    label=1.0,
                                    slice_key='first_slice')

        by_age_column = slicer.SingleSliceSpec(columns=['age'])
        by_age_value = slicer.SingleSliceSpec(features=[('age', 3)])
        by_age_and_language = slicer.SingleSliceSpec(
            columns=['age'], features=[('language', 'english')])
        slice_specs = [by_age_column, by_age_value, by_age_and_language]

        def verify_table(got):
            """Asserts on the collected pipeline output."""
            self.assertEqual(1, len(got), 'got: %s' % got)
            _, extracts = got[0]

            # Keep only the MaterializedColumn values: these are what gets
            # emitted to tell downstream sink components to write the data
            # out to file.
            materialized = {
                key: column
                for key, column in extracts.items()
                if isinstance(column, types.MaterializedColumn)
            }

            expected_columns = {
                'materialized_slice_keys':
                types.MaterializedColumn(
                    name='materialized_slice_keys',
                    value=[
                        b'age:3.0', b'age:3',
                        b'age_X_language:3.0_X_english'
                    ])
            }
            self._assertMaterializedColumns(materialized, expected_columns)

            # These columns are only checked for presence, not for value.
            present_columns = [
                'logits', 'probabilities', 'classes', 'logistic',
                'class_ids'
            ]
            self._assertMaterializedColumnsExist(materialized,
                                                 present_columns)

        with beam.Pipeline() as pipeline:
            table = (
                pipeline
                | 'CreateInput' >> beam.Create([example.SerializeToString()])
                | 'BuildTable' >> contrib.BuildDiagnosticTable(
                    shared_model, slice_specs))

            # Must be attached inside the `with` block so the check is part
            # of the pipeline graph before it runs on exit.
            util.assert_that(table, verify_table)