def test_passing_kwarg(self):
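        """Illegal kwargs should raise a ValueError; extra kwargs should be forwarded to precision_score."""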
        with self.subTest("illegal kwargs"):
            with self.assertRaises(ValueError):
                # `average` is an illegal kwarg
                trace = Precision(true_key="label",
                                  pred_key="pred",
                                  output_name=self.p_key,
                                  average="binary")

        with self.subTest("check if kwargs pass to precision_score"):
            with unittest.mock.patch(
                    "fastestimator.trace.metric.precision.precision_score"
            ) as fake:
                kwargs = {"e1": "extra1", "e2": "extra2"}
                trace = Precision(true_key="label",
                                  pred_key="pred",
                                  output_name=self.p_key,
                                  **kwargs)
                batch = {"label": tf.constant([0, 1, 0, 1])}
                pred = {
                    "pred": tf.constant([[0.2], [0.6], [0.8], [0.1]])
                }  # [[0], [1], [1], [0]]
                run = TraceRun(trace=trace, batch=batch, prediction=pred)
                run.run_trace()

            fake_kwargs = fake.call_args[1]  # kwargs of the mocked precision_score call
            for key, val in kwargs.items():
                self.assertIn(key, fake_kwargs)
                self.assertEqual(val, fake_kwargs[key])

    def test_torch_binary_class(self):
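        """Binary-class precision on torch tensors, with both ordinal and one-hot labels."""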
        with self.subTest("ordinal label"):
            trace = Precision(true_key="label",
                              pred_key="pred",
                              output_name=self.p_key)
            # tp, tn, fp, fn = [1, 1, 1, 1]
            batch = {"label": torch.tensor([0, 1, 0, 1])}
            pred = {
                "pred": torch.tensor([[0.2], [0.6], [0.8], [0.1]])
            }  # [[0], [1], [1], [0]]
            run = TraceRun(trace=trace, batch=batch, prediction=pred)
            run.run_trace()
            self.assertEqual(run.data_on_epoch_end[self.p_key],
                             0.5)  # precision = tp / (tp + fp) = 0.5

        with self.subTest("one-hot label"):
            trace = Precision(true_key="label",
                              pred_key="pred",
                              output_name=self.p_key)
            # tp, tn, fp, fn = [2, 1, 0, 1]
            batch = {
                "label": torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
            }  #  [0, 1, 1, 1]
            pred = {
                "pred": torch.tensor([[0.2], [0.6], [0.8], [0.1]])
            }  #  [[0], [1], [1], [0]]
            run = TraceRun(trace=trace, batch=batch, prediction=pred)
            run.run_trace()
            self.assertEqual(run.data_on_epoch_end[self.p_key],
                             1.0)  # precision = tp / (tp + fp) = 1

    def test_torch_multi_class(self):
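        """Multi-class (per-class) precision on torch tensors, with both ordinal and one-hot labels."""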
        with self.subTest("ordinal label"):
            trace = Precision(true_key="label",
                              pred_key="pred",
                              output_name=self.p_key)
            batch = {"label": torch.tensor([0, 0, 0, 1, 1, 2])}
            pred = {
                "pred": torch.tensor([[0.2, 0.1, -0.6],
                                      [0.6, 2.0, 0.1],
                                      [0.1, 0.1, 0.8],
                                      [0.4, 0.1, -0.3],
                                      [0.2, 0.7, 0.1],
                                      [0.3, 0.6, 1.5]])
            }  # argmax -> [0, 1, 2, 0, 1, 2]
            run = TraceRun(trace=trace, batch=batch, prediction=pred)
            run.run_trace()
            self.assertEqual(
                run.data_on_epoch_end[self.p_key][0],
                0.5)  # for 0, [tp, tn, fp, fn] = [1, 2, 1, 2], precision = 0.5
            self.assertEqual(
                run.data_on_epoch_end[self.p_key][1],
                0.5)  # for 1, [tp, tn, fp, fn] = [1, 3, 1, 1], precision = 0.5
            self.assertEqual(
                run.data_on_epoch_end[self.p_key][2],
                0.5)  # for 2, [tp, tn, fp, fn] = [1, 4, 1, 0], precision = 0.5

        with self.subTest("one-hot label"):
            trace = Precision(true_key="label",
                              pred_key="pred",
                              output_name=self.p_key)
            batch = {
                "label": torch.tensor([[1, 0, 0], [1, 0, 0], [1, 0, 0],
                                       [0, 1, 0], [0, 1, 0], [0, 0, 1]])
            }  # [0, 0, 0, 1, 1, 2]
            pred = {
                "pred": torch.tensor([[0.2, 0.1, -0.6],
                                      [0.6, 2.0, 0.1],
                                      [0.1, 0.1, 0.8],
                                      [0.4, 0.1, -0.3],
                                      [0.2, 0.7, 0.1],
                                      [0.3, 0.6, 1.5]])
            }  # argmax -> [0, 1, 2, 0, 1, 2]
            run = TraceRun(trace=trace, batch=batch, prediction=pred)
            run.run_trace()
            self.assertEqual(
                run.data_on_epoch_end[self.p_key][0],
                0.5)  # for 0, [tp, tn, fp, fn] = [1, 2, 1, 2], precision = 0.5
            self.assertEqual(
                run.data_on_epoch_end[self.p_key][1],
                0.5)  # for 1, [tp, tn, fp, fn] = [1, 3, 1, 1], precision = 0.5
            self.assertEqual(
                run.data_on_epoch_end[self.p_key][2],
                0.5)  # for 2, [tp, tn, fp, fn] = [1, 4, 1, 0], precision = 0.5
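        # Cross-check: Precision delegates to precision_score (the function mocked in
        # test_passing_kwarg, which comes from sklearn.metrics), so the per-class values
        # asserted above can be reproduced directly:
        #   precision_score([0, 0, 0, 1, 1, 2], [0, 1, 2, 0, 1, 2], average=None)
        #   -> array([0.5, 0.5, 0.5])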

    @classmethod
    def setUpClass(cls):
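        """Shared fixtures: multi-class and binary Data objects plus a Precision trace."""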
        x = np.array([[1, 2], [3, 4]])
        x_pred = np.array([[1, 5, 3], [2, 1, 0]])
        x_binary = np.array([1])
        x_pred_binary = np.array([0.9])
        cls.data = Data({'x': x, 'x_pred': x_pred})
        cls.data_binary = Data({'x': x_binary, 'x_pred': x_pred_binary})
        cls.precision = Precision(true_key='x', pred_key='x_pred')