def testExampleWeightsExtractorMultiModel(self):
        """Checks example-weight extraction for an EvalConfig with two models.

        `model1` reads a single weight feature via `example_weight_key`, while
        `model2` maps each output to its own feature via `example_weight_keys`,
        so the weights extracted for `model2` are compared as a dict keyed by
        output name.
        """
        single_output_spec = config_pb2.ModelSpec(
            name='model1', example_weight_key='example_weight')
        multi_output_spec = config_pb2.ModelSpec(
            name='model2',
            example_weight_keys={
                'output1': 'example_weight1',
                'output2': 'example_weight2'
            })
        eval_config = config_pb2.EvalConfig(
            model_specs=[single_output_spec, multi_output_spec])
        feature_extractor = features_extractor.FeaturesExtractor(eval_config)
        example_weight_extractor = (
            example_weights_extractor.ExampleWeightsExtractor(eval_config))

        schema = text_format.Parse(
            """
        feature {
          name: "example_weight"
          type: FLOAT
        }
        feature {
          name: "example_weight1"
          type: FLOAT
        }
        feature {
          name: "example_weight2"
          type: FLOAT
        }
        feature {
          name: "fixed_int"
          type: INT
        }
        """, schema_pb2.Schema())
        tfx_io = test_util.InMemoryTFExampleRecord(
            schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN)

        # One row per example: (example_weight, example_weight1, example_weight2).
        weight_rows = [(0.5, 0.5, 0.5), (0.0, 0.0, 1.0)]
        serialized_examples = [
            self._makeExample(example_weight=weight,
                              example_weight1=weight1,
                              example_weight2=weight2,
                              fixed_int=1).SerializeToString()
            for weight, weight1, weight2 in weight_rows
        ]

        with beam.Pipeline() as p:
            # pylint: disable=no-value-for-parameter
            batched = (
                p
                | 'Create' >> beam.Create(serialized_examples, reshuffle=False)
                | 'BatchExamples' >> tfx_io.BeamSource(batch_size=2))
            extracts = (
                batched
                | 'InputsToExtracts' >>
                model_eval_lib.BatchedInputsToExtracts()
                | feature_extractor.stage_name >> feature_extractor.ptransform
                | example_weight_extractor.stage_name >>
                example_weight_extractor.ptransform)
            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    # Both examples land in a single batch of size 2.
                    self.assertLen(got, 1)
                    weights_by_model = got[0][constants.EXAMPLE_WEIGHTS_KEY]
                    for model_name in ('model1', 'model2'):
                        self.assertIn(model_name, weights_by_model)
                    self.assertAllClose(weights_by_model['model1'],
                                        np.array([0.5, 0.0]))
                    self.assertAllClose(
                        weights_by_model['model2'], {
                            'output1': np.array([0.5, 0.0]),
                            'output2': np.array([0.5, 1.0])
                        })
                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(extracts, check_result, label='result')
    def testExampleWeightsExtractor(self, example_weight):
        """Checks single-model weight extraction with and without a weight key.

        Args:
          example_weight: Name of the weight feature configured on the
            ModelSpec, or None when no weight key is configured. (Presumably
            supplied by a parameterization decorator outside this view.)
        """
        eval_config = config_pb2.EvalConfig(model_specs=[
            config_pb2.ModelSpec(example_weight_key=example_weight)
        ])
        feature_extractor = features_extractor.FeaturesExtractor(eval_config)
        example_weight_extractor = (
            example_weights_extractor.ExampleWeightsExtractor(eval_config))

        # The weight feature is only declared in the schema when a name is given.
        weight_schema = ''
        if example_weight is not None:
            weight_schema = """
          feature {
            name: "%s"
            type: FLOAT
          }
          """ % example_weight
        schema = text_format.Parse(
            weight_schema + """
        feature {
          name: "fixed_int"
          type: INT
        }
        """, schema_pb2.Schema())
        tfx_io = test_util.InMemoryTFExampleRecord(
            schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN)

        # Build per-example kwargs from (fixed_int, weight) rows; the weight is
        # attached under the configured feature name only when one exists.
        example_kwargs = []
        for fixed_int, weight in ((1, 0.5), (1, 0.0), (2, 1.0)):
            fields = {'fixed_int': fixed_int}
            if example_weight is not None:
                fields[example_weight] = weight
            example_kwargs.append(fields)
        serialized_examples = [
            self._makeExample(**fields).SerializeToString()
            for fields in example_kwargs
        ]

        with beam.Pipeline() as p:
            # pylint: disable=no-value-for-parameter
            extracts = (
                p
                | 'Create' >> beam.Create(serialized_examples, reshuffle=False)
                | 'BatchExamples' >> tfx_io.BeamSource(batch_size=3)
                | 'InputsToExtracts' >>
                model_eval_lib.BatchedInputsToExtracts()
                | feature_extractor.stage_name >> feature_extractor.ptransform
                | example_weight_extractor.stage_name >>
                example_weight_extractor.ptransform)
            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    batch = got[0]
                    if not example_weight:
                        # No weight key configured: no weights are extracted.
                        self.assertNotIn(constants.EXAMPLE_WEIGHTS_KEY, batch)
                    else:
                        self.assertAllClose(batch[constants.EXAMPLE_WEIGHTS_KEY],
                                            np.array([0.5, 0.0, 1.0]))
                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(extracts, check_result, label='result')
    def testUnbatchExtractor(self):
        """Checks that one batched extract is split back into per-example ones.

        Three examples are batched together and then passed through the
        features, labels, example-weights and predictions extractors, followed
        by the unbatch extractor. Each resulting extract must carry the
        feature, label and weight values of its own input example.
        """
        eval_config = config_pb2.EvalConfig(model_specs=[
            config_pb2.ModelSpec(label_key='label',
                                 example_weight_key='example_weight')
        ])
        # Applied to the pipeline in this exact order.
        extractors = [
            features_extractor.FeaturesExtractor(eval_config),
            labels_extractor.LabelsExtractor(eval_config),
            example_weights_extractor.ExampleWeightsExtractor(eval_config),
            predictions_extractor.PredictionsExtractor(eval_config),
            unbatch_extractor.UnbatchExtractor(),
        ]

        schema = text_format.Parse(
            """
        feature {
          name: "label"
          type: FLOAT
        }
        feature {
          name: "example_weight"
          type: FLOAT
        }
        feature {
          name: "fixed_int"
          type: INT
        }
        feature {
          name: "fixed_float"
          type: FLOAT
        }
        feature {
          name: "fixed_string"
          type: BYTES
        }
        """, schema_pb2.Schema())
        tfx_io = test_util.InMemoryTFExampleRecord(
            schema=schema, raw_record_column_name=constants.ARROW_INPUT_COLUMN)

        # Shared by the inputs and the expectations:
        # (label, example_weight, fixed_int, fixed_float, fixed_string).
        rows = [
            (1.0, 0.5, 1, 1.0, 'fixed_string1'),
            (0.0, 0.0, 1, 1.0, 'fixed_string2'),
            (0.0, 1.0, 2, 0.0, 'fixed_string3'),
        ]
        examples = [
            self._makeExample(label=label,
                              example_weight=weight,
                              fixed_int=fixed_int,
                              fixed_float=fixed_float,
                              fixed_string=fixed_string)
            for label, weight, fixed_int, fixed_float, fixed_string in rows
        ]

        with beam.Pipeline() as p:
            # pylint: disable=no-value-for-parameter
            extracts = (
                p
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples], reshuffle=False)
                | 'BatchExamples' >> tfx_io.BeamSource(batch_size=3)
                | 'InputsToExtracts' >>
                model_eval_lib.BatchedInputsToExtracts())
            for extractor in extractors:
                extracts = (
                    extracts | extractor.stage_name >> extractor.ptransform)
            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, len(rows))
                    for extract, (label, weight, fixed_int, fixed_float,
                                  fixed_string) in zip(got, rows):
                        features = extract[constants.FEATURES_KEY]
                        self.assertDictElementsAlmostEqual(
                            features, {
                                'fixed_int': np.array([fixed_int]),
                                'fixed_float': np.array([fixed_float]),
                            })
                        self.assertEqual(features['fixed_string'],
                                         np.array([fixed_string.encode()]))
                        self.assertAlmostEqual(extract[constants.LABELS_KEY],
                                               np.array([label]))
                        self.assertAlmostEqual(
                            extract[constants.EXAMPLE_WEIGHTS_KEY],
                            np.array([weight]))
                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(extracts, check_result, label='result')