def test_passing_kwarg(self):
    """Check how Precision handles extra keyword arguments.

    - Passing ``average`` to the constructor raises ``ValueError``.
    - Other keyword arguments are forwarded to ``precision_score``. This is
      verified by patching the name as looked up in
      ``fastestimator.trace.metric.precision`` and inspecting the kwargs of
      its last call.
    """
    with self.subTest("illegal kwargs"):
        with self.assertRaises(ValueError):
            # `average` is an illegal kwarg
            Precision(true_key="label", pred_key="pred", output_name=self.p_key, average="binary")

    with self.subTest("check if kwargs pass to precision_score"):
        with unittest.mock.patch("fastestimator.trace.metric.precision.precision_score") as fake:
            kwargs = {"e1": "extra1", "e2": "extra2"}
            trace = Precision(true_key="label", pred_key="pred", output_name=self.p_key, **kwargs)
            batch = {"label": tf.constant([0, 1, 0, 1])}
            pred = {"pred": tf.constant([[0.2], [0.6], [0.8], [0.1]])}  # [[0], [1], [1], [0]]
            run = TraceRun(trace=trace, batch=batch, prediction=pred)
            run.run_trace()
            # If the patched function was never reached, call_args is None and
            # subscripting it would raise an opaque TypeError; fail clearly instead.
            fake.assert_called()
            fake_kwargs = fake.call_args[1]  # keyword arguments of the last call
            for key, val in kwargs.items():
                self.assertIn(key, fake_kwargs)
                self.assertEqual(val, fake_kwargs[key])
def test_torch_binary_class(self):
    """Check binary precision on torch tensors for ordinal and one-hot labels.

    Both cases use the same single-column scores. Per the expected counts,
    these correspond to predicted classes [0, 1, 1, 0].
    """
    def _epoch_precision(labels):
        # A new Precision trace is built for every call, then run over one batch.
        metric = Precision(true_key="label", pred_key="pred", output_name=self.p_key)
        scores = torch.tensor([[0.2], [0.6], [0.8], [0.1]])  # [[0], [1], [1], [0]]
        runner = TraceRun(trace=metric, batch={"label": labels}, prediction={"pred": scores})
        runner.run_trace()
        return runner.data_on_epoch_end[self.p_key]

    with self.subTest("ordinal label"):
        # tp, tn, fp, fn = [1, 1, 1, 1] -> precision = tp / (tp + fp) = 0.5
        self.assertEqual(_epoch_precision(torch.tensor([0, 1, 0, 1])), 0.5)

    with self.subTest("one-hot label"):
        # one-hot rows decode to [0, 1, 1, 1]
        # tp, tn, fp, fn = [2, 1, 0, 1] -> precision = tp / (tp + fp) = 1
        one_hot_labels = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
        self.assertEqual(_epoch_precision(one_hot_labels), 1.0)
def test_torch_multi_class(self):
    """Check per-class precision on torch tensors for a 3-class problem.

    Ordinal and one-hot encodings of the labels [0, 0, 0, 1, 1, 2] are
    scored against predictions whose row-wise argmax is [0, 1, 2, 0, 1, 2].
    Every class is expected to have precision 0.5.
    """
    label_cases = (
        ("ordinal label", [0, 0, 0, 1, 1, 2]),
        ("one-hot label", [[1, 0, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1]]),
    )
    for case_name, raw_labels in label_cases:
        with self.subTest(case_name):
            metric = Precision(true_key="label", pred_key="pred", output_name=self.p_key)
            labels = torch.tensor(raw_labels)
            scores = torch.tensor([[0.2, 0.1, -0.6], [0.6, 2.0, 0.1], [0.1, 0.1, 0.8], [0.4, 0.1, -0.3],
                                   [0.2, 0.7, 0.1], [0.3, 0.6, 1.5]])  # [[0], [1], [2], [0], [1], [2]]
            runner = TraceRun(trace=metric, batch={"label": labels}, prediction={"pred": scores})
            runner.run_trace()
            per_class = runner.data_on_epoch_end[self.p_key]
            # [tp, tn, fp, fn] per class:
            #   class 0: [1, 2, 1, 2]   class 1: [1, 3, 1, 1]   class 2: [1, 4, 1, 0]
            # so tp / (tp + fp) = 0.5 for every class.
            for class_idx in range(3):
                self.assertEqual(per_class[class_idx], 0.5)
def setUpClass(cls):
    """Build the shared fixtures for the test class.

    - ``cls.data``: a Data batch with 2-row arrays under 'x' and 'x_pred'.
    - ``cls.data_binary``: a single-sample Data batch under the same keys.
    - ``cls.precision``: a Precision trace that reads 'x' as truth and
      'x_pred' as prediction.
    """
    # NOTE(review): unittest calls setUpClass on the class object, which requires
    # a @classmethod decorator; none is visible in this view. Confirm it exists.
    cls.data = Data({
        'x': np.array([[1, 2], [3, 4]]),
        'x_pred': np.array([[1, 5, 3], [2, 1, 0]]),
    })
    cls.data_binary = Data({
        'x': np.array([1]),
        'x_pred': np.array([0.9]),
    })
    cls.precision = Precision(true_key='x', pred_key='x_pred')