def testRaisesErrorWhenExampleWeightsDiffer(self):
    """Expects a ValueError when examples of one query have different weights.

    Both examples below belong to 'query1' but carry example weights of 1.0
    and 0.5 respectively.
    """
    # Setup stays outside assertRaises so that a ValueError raised while
    # building the metric or the fixtures fails the test instead of silently
    # satisfying the assertion.
    metric = min_label_position.MinLabelPosition().computations(
        query_key='query')[0]

    query1_example1 = {
        'labels': np.array([0.0]),
        'predictions': np.array([0.2]),
        'example_weights': np.array([1.0]),
        'features': {
            'query': np.array(['query1'])
        }
    }
    query1_example2 = {
        'labels': np.array([1.0]),
        'predictions': np.array([0.8]),
        'example_weights': np.array([0.5]),
        'features': {
            'query': np.array(['query1'])
        }
    }

    # Only building and running the pipeline is expected to raise.
    with self.assertRaises(ValueError):
      with beam.Pipeline() as pipeline:
        # pylint: disable=no-value-for-parameter
        _ = (
            pipeline
            | 'Create' >> beam.Create(
                [tfma_util.merge_extracts([query1_example1, query1_example2])])
            | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs, True)
            | 'AddSlice' >> beam.Map(lambda x: ((), x))
            | 'Combine' >> beam.CombinePerKey(metric.combiner))
        # pylint: enable=no-value-for-parameter
  def testMinLabelPositionWithNoWeightedExamples(self):
    """Checks that the metric is NaN when the only example has weight 0.0."""
    computation = min_label_position.MinLabelPosition().computations(
        query_key='query')[0]

    # The single input example: 'query1' with an example weight of zero.
    zero_weight_example = {
        'labels': np.array([1.0]),
        'predictions': np.array([0.2]),
        'example_weights': np.array([0.0]),
        'features': {
            'query': np.array(['query1'])
        }
    }

    def assert_metric_is_nan(actual):
      """Verifies one result for the empty slice whose metric value is NaN."""
      try:
        self.assertLen(actual, 1)
        slice_key, metrics = actual[0]
        self.assertEqual(slice_key, ())
        metric_key = metric_types.MetricKey(name='min_label_position')
        self.assertIn(metric_key, metrics)
        self.assertTrue(math.isnan(metrics[metric_key]))

      except AssertionError as err:
        raise util.BeamAssertException(err)

    with beam.Pipeline() as pipeline:
      # pylint: disable=no-value-for-parameter
      # Each pipeline element is a list of extracts; 'Process' converts every
      # entry of that list to standard metric inputs.
      combined = (
          pipeline
          | 'Create' >> beam.Create([[zero_weight_example]])
          | 'Process' >> beam.Map(lambda extracts_list: [
              metric_util.to_standard_metric_inputs(extracts, True)
              for extracts in extracts_list
          ])
          | 'AddSlice' >> beam.Map(lambda x: ((), x))
          | 'Combine' >> beam.CombinePerKey(computation.combiner))

      # pylint: enable=no-value-for-parameter

      util.assert_that(combined, assert_metric_is_nan, label='result')
  def testMinLabelPosition(self, label_key):
    """Checks the min_label_position value computed over three queries.

    Args:
      label_key: Forwarded to MinLabelPosition(label_key=...). When it equals
        'custom_label' the expected value is 2.333333, otherwise 1.166666.
        Presumably supplied by a parameterized decorator that is not visible
        here.
    """
    computation = min_label_position.MinLabelPosition(
        label_key=label_key).computations(query_key='query')[0]

    def build_example(query, label, prediction, weight, custom_label):
      """Returns one extracts dict whose values are single-element arrays."""
      return {
          'labels': np.array([label]),
          'predictions': np.array([prediction]),
          'example_weights': np.array([weight]),
          'features': {
              'custom_label': np.array([custom_label]),
              'query': np.array([query])
          }
      }

    # Rows: (query, label, prediction, example_weight, custom_label), grouped
    # per query. Each group is merged into a single extract.
    rows_by_query = [
        [('query1', 1.0, 0.2, 1.0, 0.0), ('query1', 0.0, 0.8, 1.0, 1.0)],
        [('query2', 1.0, 0.9, 2.0, 0.0), ('query2', 0.0, 0.1, 2.0, 1.0),
         ('query2', 0.0, 0.5, 2.0, 0.0)],
        [('query3', 1.0, 0.9, 3.0, 0.0)],
    ]
    merged_extracts = [
        tfma_util.merge_extracts([build_example(*row) for row in rows])
        for rows in rows_by_query
    ]

    if label_key:
      self.assertIsNotNone(computation.preprocessor)

    def verify(actual):
      """Verifies one result for the empty slice with the expected value."""
      try:
        self.assertLen(actual, 1)
        slice_key, metrics = actual[0]
        self.assertEqual(slice_key, ())
        metric_key = metric_types.MetricKey(name='min_label_position')
        self.assertIn(metric_key, metrics)
        # custom_label: (1*1.0 + 3*2.0) / (1.0 + 2.0) = 2.333333
        # otherwise: (2*1.0 + 1*2.0 + 1*3.0) / (1.0 + 2.0 + 3.0) = 1.166666
        expected = 2.333333 if label_key == 'custom_label' else 1.166666
        self.assertAllClose(metrics[metric_key], expected)

      except AssertionError as err:
        raise util.BeamAssertException(err)

    with beam.Pipeline() as pipeline:
      # pylint: disable=no-value-for-parameter
      combined = (
          pipeline
          | 'Create' >> beam.Create(merged_extracts)
          | 'Process' >> beam.Map(
              metric_util.to_standard_metric_inputs, include_features=True)
          | 'AddSlice' >> beam.Map(lambda x: ((), x))
          | 'Combine' >> beam.CombinePerKey(computation.combiner))

      # pylint: enable=no-value-for-parameter

      util.assert_that(combined, verify, label='result')
 def testRaisesErrorIfNoQueryKey(self):
   """Expects ValueError from computations() when no query_key is passed."""
   with self.assertRaises(ValueError):
     ranking_metric = min_label_position.MinLabelPosition()
     ranking_metric.computations()
  def testMinLabelPosition(self):
    """Checks the min_label_position value computed over three queries."""
    # NOTE(review): a method with this same name that takes label_key is also
    # defined in this file; if both live in one class, the later definition
    # replaces the earlier one. Confirm which is intended.
    computation = min_label_position.MinLabelPosition().computations(
        query_key='query')[0]

    def build_example(query, label, prediction, weight):
      """Returns one extracts dict whose values are single-element arrays."""
      return {
          'labels': np.array([label]),
          'predictions': np.array([prediction]),
          'example_weights': np.array([weight]),
          'features': {
              'query': np.array([query])
          }
      }

    # Rows: (query, label, prediction, example_weight), grouped per query.
    # Each pipeline element is the list of extracts for one query.
    rows_by_query = [
        [('query1', 1.0, 0.2, 1.0), ('query1', 0.0, 0.8, 1.0)],
        [('query2', 0.0, 0.5, 2.0), ('query2', 1.0, 0.9, 2.0),
         ('query2', 0.0, 0.1, 2.0)],
        [('query3', 1.0, 0.9, 3.0)],
    ]
    extracts_per_query = [
        [build_example(*row) for row in rows] for rows in rows_by_query
    ]

    def verify(actual):
      """Verifies one result for the empty slice with the expected value."""
      try:
        self.assertLen(actual, 1)
        slice_key, metrics = actual[0]
        self.assertEqual(slice_key, ())
        metric_key = metric_types.MetricKey(name='min_label_position')
        # NOTE(review): 0.66667 matches (2 + 1 + 1) / (1.0 + 2.0 + 3.0), i.e.
        # unweighted positions over total weight - confirm this is intended.
        self.assertDictElementsAlmostEqual(
            metrics, {
                metric_key: 0.66667,
            }, places=5)

      except AssertionError as err:
        raise util.BeamAssertException(err)

    with beam.Pipeline() as pipeline:
      # pylint: disable=no-value-for-parameter
      combined = (
          pipeline
          | 'Create' >> beam.Create(extracts_per_query)
          | 'Process' >> beam.Map(lambda extracts_list: [
              metric_util.to_standard_metric_inputs(extracts, True)
              for extracts in extracts_list
          ])
          | 'AddSlice' >> beam.Map(lambda x: ((), x))
          | 'Combine' >> beam.CombinePerKey(computation.combiner))

      # pylint: enable=no-value-for-parameter

      util.assert_that(combined, verify, label='result')