def testWriteValidationResults(self):
        """Evaluates candidate vs. baseline and checks the written validations.

        Two Keras models (built with `mul=0` for the candidate and `mul=1`
        for the baseline) are evaluated over two examples against four
        thresholded metrics. Only the validations output is written; it is
        then loaded back and must report exactly the three expected
        threshold failures for the single (overall) slice.
        """
        candidate_dir = self._getExportDir()
        baseline_dir = self._getBaselineDir()
        candidate_model = self._build_keras_model(candidate_dir, mul=0)
        baseline_model = self._build_keras_model(baseline_dir, mul=1)
        validations_path = os.path.join(self._getTempDir(),
                                        constants.VALIDATIONS_KEY)

        # Schema: a dense "input" tensor plus label, example weight and one
        # extra bytes feature.
        schema_text = """
        tensor_representation_group {
          key: ""
          value {
            tensor_representation {
              key: "input"
              value {
                dense_tensor {
                  column_name: "input"
                  shape { dim { size: 1 } }
                }
              }
            }
          }
        }
        feature {
          name: "input"
          type: FLOAT
        }
        feature {
          name: "label"
          type: FLOAT
        }
        feature {
          name: "example_weight"
          type: FLOAT
        }
        feature {
          name: "extra_feature"
          type: BYTES
        }
        """
        example_schema = text_format.Parse(schema_text, schema_pb2.Schema())
        record_io = test_util.InMemoryTFExampleRecord(
            schema=example_schema,
            raw_record_column_name=constants.ARROW_INPUT_COLUMN)
        adapter_config = tensor_adapter.TensorAdapterConfig(
            arrow_schema=record_io.ArrowSchema(),
            tensor_representations=record_io.TensorRepresentations())

        example_features = [
            dict(input=0.0,
                 label=1.0,
                 example_weight=1.0,
                 extra_feature='non_model_feature'),
            dict(input=1.0,
                 label=0.0,
                 example_weight=0.5,
                 extra_feature='non_model_feature'),
        ]
        examples = [
            self._makeExample(**features) for features in example_features
        ]

        candidate_spec = config.ModelSpec(name='candidate',
                                          label_key='label',
                                          example_weight_key='example_weight')
        baseline_spec = config.ModelSpec(name='baseline',
                                         label_key='label',
                                         example_weight_key='example_weight',
                                         is_baseline=True)

        # WeightedExampleCount: 1.5 < 1, NOT OK.
        weighted_count_threshold = config.MetricThreshold(
            value_threshold=config.GenericValueThreshold(
                upper_bound={'value': 1}))
        # ExampleCount: 2 > 10, NOT OK.
        example_count_threshold = config.MetricThreshold(
            value_threshold=config.GenericValueThreshold(
                lower_bound={'value': 10}))
        # MeanLabel: 0 > 0 and 0 > 0%?: NOT OK.
        mean_label_threshold = config.MetricThreshold(
            change_threshold=config.GenericChangeThreshold(
                direction=config.MetricDirection.HIGHER_IS_BETTER,
                relative={'value': 0},
                absolute={'value': 0}))
        # MeanPrediction = (0+0)/(1+0.5) = 0
        # -.01 < 0 < .01, OK.
        # Diff% = -.333/.333 = -100% < -99%, OK.
        # Diff = 0 - .333 = -.333 < 0, OK.
        mean_prediction_threshold = config.MetricThreshold(
            value_threshold=config.GenericValueThreshold(
                upper_bound={'value': .01}, lower_bound={'value': -.01}),
            change_threshold=config.GenericChangeThreshold(
                direction=config.MetricDirection.LOWER_IS_BETTER,
                relative={'value': -.99},
                absolute={'value': 0}))

        thresholded_metrics = [
            config.MetricConfig(class_name='WeightedExampleCount',
                                threshold=weighted_count_threshold),
            config.MetricConfig(class_name='ExampleCount',
                                threshold=example_count_threshold),
            config.MetricConfig(class_name='MeanLabel',
                                threshold=mean_label_threshold),
            config.MetricConfig(class_name='MeanPrediction',
                                threshold=mean_prediction_threshold),
        ]
        eval_config = config.EvalConfig(
            model_specs=[candidate_spec, baseline_spec],
            slicing_specs=[config.SlicingSpec()],
            metrics_specs=[
                config.MetricsSpec(metrics=thresholded_metrics,
                                   model_names=['candidate', 'baseline']),
            ],
            options=config.Options(
                disabled_outputs={'values': ['eval_config.json']}),
        )

        slice_specs = [
            slicer.SingleSliceSpec(spec=slicing_spec)
            for slicing_spec in eval_config.slicing_specs
        ]
        models_by_name = {
            'candidate': candidate_model,
            'baseline': baseline_model
        }

        input_extractor = batched_input_extractor.BatchedInputExtractor(
            eval_config)
        predict_extractor = batched_predict_extractor_v2.BatchedPredictExtractor(
            eval_shared_model=models_by_name,
            eval_config=eval_config,
            tensor_adapter_config=adapter_config)
        extractors = [
            input_extractor,
            predict_extractor,
            unbatch_extractor.UnbatchExtractor(),
            slice_key_extractor.SliceKeyExtractor(slice_spec=slice_specs),
        ]
        evaluators = [
            metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
                eval_config=eval_config, eval_shared_model=models_by_name)
        ]

        # Only the validations output is requested from the writer.
        validations_writer = (
            metrics_plots_and_validations_writer.
            MetricsPlotsAndValidationsWriter(
                {constants.VALIDATIONS_KEY: validations_path},
                add_metrics_callbacks=[]))
        writers = [validations_writer]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            serialized_examples = (
                pipeline
                | 'Create' >> beam.Create(
                    [example.SerializeToString() for example in examples]))
            extracts = (
                serialized_examples
                | 'BatchExamples' >> record_io.BeamSource()
                | 'InputsToExtracts' >>
                model_eval_lib.BatchedInputsToExtracts())
            evaluation = (
                extracts
                | 'ExtractEvaluate' >> model_eval_lib.ExtractAndEvaluate(
                    extractors=extractors, evaluators=evaluators))
            _ = (
                evaluation
                | 'WriteResults' >> model_eval_lib.WriteResults(
                    writers=writers))
            # pylint: enable=no-value-for-parameter

        # The loader takes the directory that contains the validations file.
        output_dir = os.path.dirname(validations_path)
        validation_result = model_eval_lib.load_validation_result(output_dir)

        expected_failure_texts = [
            """
            metric_key {
              name: "weighted_example_count"
              model_name: "candidate"
            }
            metric_threshold {
              value_threshold {
                upper_bound {
                  value: 1.0
                }
              }
            }
            metric_value {
              double_value {
                value: 1.5
              }
            }
            """,
            """
            metric_key {
              name: "example_count"
            }
            metric_threshold {
              value_threshold {
                lower_bound {
                  value: 10.0
                }
              }
            }
            metric_value {
              double_value {
                value: 2.0
              }
            }
            """,
            """
            metric_key {
              name: "mean_label"
              model_name: "candidate"
              is_diff: true
            }
            metric_threshold {
              change_threshold {
                absolute {
                  value: 0.0
                }
                relative {
                  value: 0.0
                }
                direction: HIGHER_IS_BETTER
              }
            }
            metric_value {
              double_value {
                value: 0.0
              }
            }
            """,
        ]
        expected_failures = [
            text_format.Parse(failure_text,
                              validation_result_pb2.ValidationFailure())
            for failure_text in expected_failure_texts
        ]

        self.assertFalse(validation_result.validation_ok)
        per_slice_validations = validation_result.metric_validations_per_slice
        self.assertLen(per_slice_validations, 1)
        self.assertCountEqual(expected_failures,
                              per_slice_validations[0].failures)
    def testSerializeDeserializeToFile(self):
        """Round-trips metrics, plots and an eval config through output files.

        One serialized MetricsForSlice and one serialized PlotsForSlice are
        written with the default writers, together with the eval config, to
        a temp directory. Everything is then loaded back and compared with
        what was written.
        """
        expected_metrics_key = _make_slice_key(b'fruit', b'pear', b'animal',
                                               b'duck')
        metrics_text = """
        slice_key {
          single_slice_keys {
            column: "fruit"
            bytes_value: "pear"
          }
          single_slice_keys {
            column: "animal"
            bytes_value: "duck"
          }
        }
        metrics {
          key: "accuracy"
          value {
            double_value {
              value: 0.8
            }
          }
        }
        metrics {
          key: "example_weight"
          value {
            double_value {
              value: 10.0
            }
          }
        }
        metrics {
          key: "auc"
          value {
            bounded_value {
              lower_bound {
                value: 0.1
              }
              upper_bound {
                value: 0.3
              }
              value {
                value: 0.2
              }
            }
          }
        }
        metrics {
          key: "auprc"
          value {
            bounded_value {
              lower_bound {
                value: 0.05
              }
              upper_bound {
                value: 0.17
              }
              value {
                value: 0.1
              }
            }
          }
        }"""
        metrics_proto = text_format.Parse(
            metrics_text, metrics_for_slice_pb2.MetricsForSlice())

        plots_text = """
        slice_key {
          single_slice_keys {
            column: "fruit"
            bytes_value: "peach"
          }
          single_slice_keys {
            column: "animal"
            bytes_value: "cow"
          }
        }
        plots {
          key: ''
          value {
            calibration_histogram_buckets {
              buckets {
                lower_threshold_inclusive: -inf
                upper_threshold_exclusive: 0.0
                num_weighted_examples { value: 0.0 }
                total_weighted_label { value: 0.0 }
                total_weighted_refined_prediction { value: 0.0 }
              }
              buckets {
                lower_threshold_inclusive: 0.0
                upper_threshold_exclusive: 0.5
                num_weighted_examples { value: 1.0 }
                total_weighted_label { value: 1.0 }
                total_weighted_refined_prediction { value: 0.3 }
              }
              buckets {
                lower_threshold_inclusive: 0.5
                upper_threshold_exclusive: 1.0
                num_weighted_examples { value: 1.0 }
                total_weighted_label { value: 0.0 }
                total_weighted_refined_prediction { value: 0.7 }
              }
              buckets {
                lower_threshold_inclusive: 1.0
                upper_threshold_exclusive: inf
                num_weighted_examples { value: 0.0 }
                total_weighted_label { value: 0.0 }
                total_weighted_refined_prediction { value: 0.0 }
              }
            }
          }
        }"""
        plots_proto = text_format.Parse(plots_text,
                                        metrics_for_slice_pb2.PlotsForSlice())
        expected_plots_key = _make_slice_key(b'fruit', b'peach', b'animal',
                                             b'cow')

        slice_specs = [
            slicer.SingleSliceSpec(features=[('age', 5), ('gender', 'f')],
                                   columns=['country']),
            slicer.SingleSliceSpec(features=[('age', 6), ('gender', 'm')],
                                   columns=['interest'])
        ]
        eval_config = model_eval_lib.EvalConfig(
            model_location='/path/to/model',
            data_location='/path/to/data',
            slice_spec=slice_specs,
            example_weight_metric_key='key')

        output_path = self._getTempDir()
        with beam.Pipeline() as pipeline:
            metrics_pcoll = (
                pipeline
                | 'CreateMetrics' >> beam.Create(
                    [metrics_proto.SerializeToString()]))
            plots_pcoll = (
                pipeline
                | 'CreatePlots' >> beam.Create(
                    [plots_proto.SerializeToString()]))
            evaluation = {
                constants.METRICS_KEY: metrics_pcoll,
                constants.PLOTS_KEY: plots_pcoll
            }
            default_writers = model_eval_lib.default_writers(
                output_path=output_path)
            _ = (evaluation
                 | 'WriteResults' >> model_eval_lib.WriteResults(
                     writers=default_writers))
            _ = pipeline | model_eval_lib.WriteEvalConfig(
                eval_config, output_path)

        # Read back everything the pipeline wrote.
        metrics_file = os.path.join(output_path,
                                    model_eval_lib._METRICS_OUTPUT_FILE)
        plots_file = os.path.join(output_path,
                                  model_eval_lib._PLOTS_OUTPUT_FILE)
        loaded_metrics = (
            metrics_and_plots_evaluator.load_and_deserialize_metrics(
                path=metrics_file))
        loaded_plots = metrics_and_plots_evaluator.load_and_deserialize_plots(
            path=plots_file)

        self.assertSliceMetricsListEqual(
            [(expected_metrics_key, metrics_proto.metrics)], loaded_metrics)
        self.assertSlicePlotsListEqual(
            [(expected_plots_key, plots_proto.plots)], loaded_plots)
        loaded_eval_config = model_eval_lib.load_eval_config(output_path)
        self.assertEqual(eval_config, loaded_eval_config)