def test1(self):
    """A perfect prediction should yield a purely diagonal confusion matrix.

    All four predictions equal their labels, so each of the two classes is
    expected to have a count of 2 on the diagonal and 0 everywhere else.
    """
    n_classes = 2
    labels = torch.tensor([0, 1, 0, 1])
    preds = torch.tensor([0, 1, 0, 1])

    actual = compute_conf_mat(preds, labels, n_classes)

    # The `2.` literal makes `expected` a floating-point tensor; this assumes
    # compute_conf_mat also returns floats -- TODO confirm.
    expected = torch.tensor([[2., 0], [0, 2]])
    self.assertTrue(torch.equal(actual, expected))
def validate_step(self, batch, batch_ind):
    """Evaluate the model on a single validation batch.

    Args:
        batch: a pair ``(inputs, targets)``. ``inputs`` is fed to
            ``self.model``; ``targets`` is used both as the cross-entropy
            target and as the reference labels for the confusion matrix.
        batch_ind: index of the batch; not used by this method.

    Returns:
        dict with ``'val_loss'`` (cross-entropy of the post-processed model
        output against ``targets``) and ``'conf_mat'`` (the result of
        ``compute_conf_mat`` over ``len(self.cfg.data.class_names)`` labels).
    """
    inputs, targets = batch
    raw = self.model(inputs)
    scores = self.post_forward(raw)
    # NOTE(review): F.cross_entropy applies log-softmax internally, so this
    # assumes post_forward yields unnormalized scores rather than
    # probabilities -- confirm against post_forward.
    loss = F.cross_entropy(scores, targets)

    class_count = len(self.cfg.data.class_names)
    # prob_to_pred presumably maps scores to class indices -- TODO confirm.
    preds = self.prob_to_pred(scores)
    return {
        'val_loss': loss,
        'conf_mat': compute_conf_mat(preds, targets, class_count),
    }