Example #1
0
 def test1(self):
     """Check compute_conf_mat when predictions are identical to the labels.

     With two classes and four samples (two per class) that all match,
     the expected result is a 2x2 matrix holding 2 on the diagonal and
     0 elsewhere.
     """
     targets = torch.tensor([0, 1, 0, 1])
     predictions = torch.tensor([0, 1, 0, 1])
     expected = torch.tensor([[2., 0], [0, 2]])

     # Argument order is (predictions, targets, number of classes).
     actual = compute_conf_mat(predictions, targets, 2)

     # Tensor.equal is True only when size and all elements match.
     self.assertTrue(actual.equal(expected))
Example #2
0
    def validate_step(self, batch, batch_ind):
        """Evaluate one validation batch.

        Args:
            batch: an ``(x, y)`` pair; ``x`` is fed to ``self.model`` and
                ``y`` is used as the target for the loss and the
                confusion matrix.
            batch_ind: index of the batch; not used in this method.

        Returns:
            A dict with ``'val_loss'`` (cross-entropy between the
            post-processed model output and ``y``) and ``'conf_mat'``
            (the result of ``compute_conf_mat`` on the predictions).
        """
        inputs, targets = batch
        model_out = self.post_forward(self.model(inputs))

        # The loss is taken on the post-processed output, before that
        # output is converted to predictions.
        metrics = {'val_loss': F.cross_entropy(model_out, targets)}

        class_count = len(self.cfg.data.class_names)
        predictions = self.prob_to_pred(model_out)
        metrics['conf_mat'] = compute_conf_mat(
            predictions, targets, class_count)

        return metrics