Example #1
    def validation_step(self, batch, batch_idx):
        x, y = batch
        pred = self.forward(x)
        loss = self.m_loss_function(pred, y)
        # softmax is monotonic, so argmax over the raw logits would give the same labels
        y_hat = torch.argmax(torch.softmax(pred, dim=1), dim=1)
        acc = FM.accuracy(y_hat, y)
        # confusion matrices for 5 classes: raw counts, then row-normalized
        print(FM.confusion_matrix(y_hat, y, 5, normalize=None))
        print(FM.confusion_matrix(y_hat, y, 5, normalize='true'))

        return {'loss': loss, 'acc': acc}
Example #2
    def metrics(self, pred, target, num_classes=3, remove_bg=False, is_swnet=False):
        if is_swnet:
            pred[pred == 2] = 1
            target[target == 2] = 1
            num_classes = 2

        confusion_m = mf.confusion_matrix(pred, target, num_classes=num_classes)

        # acc: correctly classified samples (diagonal) over all samples
        accuracy = confusion_m.diag().sum() / len(pred)

        # kappa: (p0 - pc) / (1 - pc), where p0 is the observed accuracy and pc is
        # the agreement expected by chance from the row/column marginals
        p0 = accuracy
        pc = 0
        for i in range(confusion_m.shape[0]):
            pc = pc + confusion_m[i].sum() * confusion_m[:, i].sum()
        pc = pc / len(pred)**2
        kc = (p0 - pc) / (1 - pc)

        # iou, optionally ignoring the background class (index 0)
        if remove_bg:
            iou = mf.iou(pred, target, num_classes=num_classes, ignore_index=0)
        else:
            iou = mf.iou(pred, target, num_classes=num_classes)

        f1 = mf.f1_score(pred, target, num_classes=num_classes, class_reduction='none')
        precision = mf.precision(pred, target, num_classes=num_classes, class_reduction='none')
        recall = mf.recall(pred, target, num_classes=num_classes, class_reduction='none')

        return accuracy, kc, iou, f1, precision, recall
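For reference, the chance-agreement loop above can be vectorized. A minimal sketch, assuming confusion_m is an n x n confusion-matrix tensor and n_total is the number of samples (both names are illustrative, not from the example):

import torch

def cohen_kappa(confusion_m: torch.Tensor, n_total: int) -> torch.Tensor:
    # observed agreement p0: fraction of samples on the diagonal
    p0 = confusion_m.diag().sum() / n_total
    # chance agreement pc: dot product of row and column marginals, normalized
    pc = (confusion_m.sum(dim=1) * confusion_m.sum(dim=0)).sum() / n_total ** 2
    return (p0 - pc) / (1 - pc)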
Example #3
 def validation_epoch_end(self, outputs):
     preds = torch.cat([tmp['predictions'] for tmp in outputs])
     targets = torch.cat([tmp['labels'] for tmp in outputs])
     confusion = confusion_matrix(preds, targets, num_classes=10)
     confusion_table = wandb.Table(data=confusion.tolist(),
                                   columns=['plane', 'car', 'bird', 'cat', 'deer',
                                            'dog', 'frog', 'horse', 'ship', 'truck'])
     self.logger.experiment.log({'confusion': confusion_table})
Example #4
    def _calculate_step_metrics(self, logits, y):
        # prepare the metrics
        # the model output is indexable; logits[1] carries the classification logits
        loss = self._loss_function(logits[1], y)
        # loss = F.cross_entropy(logits[1], y)
        preds = torch.argmax(logits[1], dim=1)
        num_correct = torch.eq(preds.view(-1), y.view(-1)).sum()
        acc = accuracy(preds, y)
        f1_score = f1(preds, y, num_classes=2, average='weighted')
        fb05_score = fbeta(preds,
                           y,
                           num_classes=2,
                           average='weighted',
                           beta=0.5)
        fb2_score = fbeta(preds, y, num_classes=2, average='weighted', beta=2)
        cm = confusion_matrix(preds, y, num_classes=2)
        prec = precision(preds, y, num_classes=2, class_reduction='weighted')
        rec = recall(preds, y, num_classes=2, class_reduction='weighted')
        # au_roc = auroc(preds, y, pos_label=1)

        return {
            'loss': loss,
            'acc': acc,
            'f1_score': f1_score,
            'f05_score': fb05_score,
            'f2_score': fb2_score,
            'precision': prec,
            'recall': rec,
            # 'auroc': au_roc,
            'confusion_matrix': cm,
            'num_correct': num_correct
        }
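The dict returned here is typically reduced once per epoch; Example #11 below shows one such reduction, stacking these per-step outputs and averaging them.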
Example #5
    def get_stat(self, preds, targets, mode):
        if mode == 'train':
            _, pred_labels = preds.topk(1, dim=1, largest=True, sorted=True)
            b = pred_labels.shape[0]
            if b == 0:
                return

            pred_labels = pred_labels.squeeze(1).detach().reshape(b, -1)
            target_labels = targets.data.detach().reshape(b, -1)

        elif (mode == 'valid') or (mode == 'test'):
            # Old
            _, pred_labels = preds.topk(1, dim=1, largest=True, sorted=True)

            # Current
            # preds = F.softmax(preds, dim=1)
            # true_probs = preds[:, 1, :].unsqueeze(1)
            # pred_labels = torch.where(true_probs > 0.05,
            #                           torch.ones(true_probs.shape).to(0),
            #                           torch.zeros(true_probs.shape).to(0))
            b, _, num_stn = pred_labels.shape
            assert (b, num_stn) == targets.shape
        else:
            raise ValueError(f"Unknown mode: {mode}")

        pred_labels = pred_labels.squeeze(1).detach()
        target_labels = targets.data.detach()

        correct = [0] * b
        hit = [0] * b
        miss = [0] * b
        fa = [0] * b
        cn = [0] * b

        for i in range(b):
            pred, target = pred_labels[i], target_labels[i]
            """
            [Confusion matrix]
            - tp: Hit
            - fn: Miss
            - fp: False Alarm
            - tn: Correct Negative
            """
            if -1 in target:
                print('invalid target:', target.shape)
                print('target:\n', target)
            conf_mat = confusion_matrix(pred,
                                        target,
                                        num_classes=self.num_classes)
            # rows are targets, columns are predictions
            _hit, _miss, _fa, _cn = (conf_mat[1, 1], conf_mat[1, 0],
                                     conf_mat[0, 1], conf_mat[0, 0])
            _hit, _miss, _fa, _cn = int(_hit), int(_miss), int(_fa), int(_cn)
            _correct = _hit + _cn

            correct[i] = _correct
            hit[i] = _hit
            miss[i] = _miss
            fa[i] = _fa
            cn[i] = _cn

        return correct, hit, miss, fa, cn
Example #6
    def validation_step(self, batch: PathContextBatch, batch_idx: int) -> Dict:  # type: ignore
        # [batch size; num_classes]
        logits = self(batch.contexts, batch.contexts_per_label)
        loss = F.cross_entropy(logits, batch.labels.squeeze(0))
        with torch.no_grad():
            conf_matrix = confusion_matrix(logits.argmax(-1), batch.labels.squeeze(0))

        return {"loss": loss, "confusion_matrix": conf_matrix}
Example #7
 def get_confusion_matrix(self, display=True):
     cm = confusion_matrix(self.preds.argmax(dim=1),
                           self.targets,
                           num_classes=2)
     classes = find_classes(dir='./images/')
     if display:
         self.plot_confusion_matrix(cm.int(), classes)
     return cm.int()
Example #8
def test_v1_5_metric_classif_mix():
    ConfusionMatrix.__init__._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        ConfusionMatrix(num_classes=1)

    FBeta.__init__._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        FBeta(num_classes=1)

    F1.__init__._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        F1(num_classes=1)

    HammingDistance.__init__._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        HammingDistance()

    StatScores.__init__._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        StatScores()

    target = torch.tensor([1, 1, 0, 0])
    preds = torch.tensor([0, 1, 0, 0])
    confusion_matrix._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        assert torch.equal(
            confusion_matrix(preds, target, num_classes=2).float(),
            torch.tensor([[2.0, 0.0], [1.0, 1.0]]))

    target = torch.tensor([0, 1, 2, 0, 1, 2])
    preds = torch.tensor([0, 2, 1, 0, 0, 1])
    fbeta._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        assert torch.allclose(fbeta(preds, target, num_classes=3, beta=0.5),
                              torch.tensor(0.3333),
                              atol=1e-4)

    f1._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        assert torch.allclose(f1(preds, target, num_classes=3),
                              torch.tensor(0.3333),
                              atol=1e-4)

    target = torch.tensor([[0, 1], [1, 1]])
    preds = torch.tensor([[0, 1], [0, 1]])
    hamming_distance._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        assert hamming_distance(preds, target) == torch.tensor(0.25)

    preds = torch.tensor([1, 0, 2, 1])
    target = torch.tensor([1, 1, 2, 0])
    stat_scores._warned = False
    with pytest.deprecated_call(match="It will be removed in v1.5.0"):
        assert torch.equal(stat_scores(preds, target, reduce="micro"),
                           torch.tensor([2, 2, 6, 2, 4]))
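This test exercises the deprecation path of the old pytorch_lightning.metrics API, which was removed in v1.5 in favor of torchmetrics. A minimal migration sketch (version-dependent: newer torchmetrics routes classification metrics through a task argument):

import torch
from torchmetrics.functional import confusion_matrix

target = torch.tensor([1, 1, 0, 0])
preds = torch.tensor([0, 1, 0, 0])
# torchmetrics < 0.11: confusion_matrix(preds, target, num_classes=2)
print(confusion_matrix(preds, target, task="multiclass", num_classes=2))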
Example #9
    def training_step(self, batch: PathContextBatch, batch_idx: int) -> Dict:  # type: ignore
        # [batch size; num_classes]
        logits = self(batch.contexts, batch.contexts_per_label)
        loss = F.cross_entropy(logits, batch.labels.squeeze(0))
        log = {"train/loss": loss}
        with torch.no_grad():
            conf_matrix = confusion_matrix(logits.argmax(-1), batch.labels.squeeze(0))
            log["train/accuracy"] = conf_matrix.trace() / conf_matrix.sum()
        self.log_dict(log)

        return {"loss": loss, "confusion_matrix": conf_matrix}
Example #10
    def test_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict:
        """ 
        Runs one training step. This usually consists in the forward function followed
            by the loss function.
        
        :param batch: The output of your dataloader. 
        :param batch_nb: Integer displaying which batch this is

        Returns:
            - dictionary containing the loss and the metrics to be added to the lightning logger.
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss_val = self.loss(model_out, targets)

        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val = loss_val.unsqueeze(0)

        self.log('test_loss', loss_val)

        y_hat = model_out['logits']
        labels_hat = torch.argmax(y_hat, dim=1)
        y = targets['labels']

        f1 = metrics.f1_score(labels_hat, y, class_reduction='weighted')
        prec = metrics.precision(labels_hat, y, class_reduction='weighted')
        recall = metrics.recall(labels_hat, y, class_reduction='weighted')
        acc = metrics.accuracy(labels_hat, y, class_reduction='weighted')
        # auroc = metrics.multiclass_auroc(labels_hat, y)

        self.log('test_batch_prec', prec)
        self.log('test_batch_f1', f1)
        self.log('test_batch_recall', recall)
        self.log('test_batch_weighted_acc', acc)
        # self.log('test_batch_auc_roc', auroc)

        from pytorch_lightning.metrics.functional import confusion_matrix
        # TODO CHANGE THIS
        # return (labels_hat, y)
        cm = confusion_matrix(preds=labels_hat, target=y, normalize=None, num_classes=50)
        # cm = confusion_matrix(preds=labels_hat, target=y, normalize=False, num_classes=len(y.unique()))
        self.test_conf_matrices.append(cm)
Example #11
    def __metrics_per_epoch(outputs):
        num_correct = torch.stack(
            [output['num_correct'] for output in outputs]).sum()
        loss_mean = torch.stack([output['loss'] for output in outputs]).mean()
        acc_mean = torch.stack([output['acc'] for output in outputs]).mean()
        prec_mean = torch.stack([output['precision']
                                 for output in outputs]).mean()
        rec_mean = torch.stack([output['recall'] for output in outputs]).mean()
        f1_mean = torch.stack([output['f1_score']
                               for output in outputs]).mean()
        y_true = torch.cat([output['y_true'] for output in outputs], dim=-1)
        y_hat = torch.cat([output['y_hat'] for output in outputs], dim=-1)
        acc1_mean = torch.stack([output['acc1'] for output in outputs]).mean()
        acc5_mean = torch.stack([output['acc5'] for output in outputs]).mean()

        confusion_matrix = plm.confusion_matrix(y_hat, y_true)

        return (y_true, y_hat, loss_mean, num_correct, acc_mean, prec_mean,
                rec_mean, f1_mean, confusion_matrix, acc1_mean, acc5_mean)
Example #12
    def metrics(self, pred, target, num_classes=3, is_swnet=False):
        if is_swnet:
            pred[pred == 2] = 1
            target[target == 2] = 1
            num_classes = 2

        confusion_m = mf.confusion_matrix(pred, target, num_classes=num_classes)

        # acc: correctly classified samples (diagonal) over all samples
        accuracy = confusion_m.diag().sum() / len(pred)

        # kappa: (p0 - pc) / (1 - pc); guard the degenerate case pc == 1
        # (every prediction and target in a single class), which would divide by zero
        p0 = accuracy
        pc = 0
        for i in range(confusion_m.shape[0]):
            pc = pc + confusion_m[i].sum() * confusion_m[:, i].sum()
        pc = pc / len(pred)**2
        if pc != 1:
            kc = (p0 - pc) / (1 - pc)
        else:
            kc = torch.tensor(1.0, device=p0.device)

        f1 = mf.f1_score(pred, target, num_classes=num_classes, class_reduction='none')
        precision = mf.precision(pred, target, num_classes=num_classes, class_reduction='none')
        recall = mf.recall(pred, target, num_classes=num_classes, class_reduction='none')

        return accuracy, kc, f1, precision, recall
Example #13
    def __metrics_per_batch(self, batch):
        # 1. Forward pass:
        x, y_true = batch
        logits = self.forward(x)

        # 2. Compute loss & performance metrics:
        # class prediction: if binary (num_outputs == 1) then class label is 0 if logit < 0 else it's 1
        # if multiclass then simply run argmax to find the index of the most confident class
        if self.n_outputs > 1:
            y_hat = torch.argmax(logits, dim=1)
            loss = self.loss(logits, y_true)
        else:
            y_hat = (logits > 0.0).squeeze(1).long()
            loss = self.loss(logits, y_true.view((-1, 1)).type_as(x))
        num_correct = torch.eq(y_hat, y_true.view(-1)).sum()
        acc = num_correct.float() / self.batch_size
        prec = plm.precision(y_hat, y_true, num_classes=self.n_classes)
        rec = plm.recall(y_hat, y_true, num_classes=self.n_classes)
        f1 = plm.f1_score(y_hat, y_true, num_classes=self.n_classes)
        conf_matrix = plm.confusion_matrix(y_hat.long(), y_true.long())
        acc1, acc5 = self.__accuracy(logits, y_true, topk=(1, 5))

        return (y_true, y_hat, logits, loss, num_correct, acc, prec, rec, f1,
                conf_matrix, acc1, acc5)
Example #14
def main(*, module_type, checkpointPath, showCount, realDataPath,
         trainDataPath, testDataPath):
    # Ensure reproducibility
    seed_everything(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # benchmark mode picks algorithms non-deterministically

    # Parse model
    if module_type == 'mme':
        model = MMETrainingModule.load_from_checkpoint(
            checkpoint_path=checkpointPath, num_cls=4)
    elif module_type in ['baseline', 'sandt', 'hm', 'CycleGAN']:
        model = SimpleTrainModule.load_from_checkpoint(
            checkpoint_path=checkpointPath, num_cls=4)
    else:
        raise RuntimeError(f"Cannot recognize module type {module_type}")

    model.eval()
    print(f"Loaded {model.__class__} instance.")
    print(f"Model has the following hyperparameters: {model.hparams}")

    # Get transform function
    testTransform = MyTransform(augment=False)

    if trainDataPath is not None and realDataPath is not None:
        # Randomly sample showCount number of images from training and real folders
        train_img_paths = glob.glob(os.path.join(trainDataPath, '*.png'))
        train_img_paths = random.sample(train_img_paths, showCount)
        real_img_paths = glob.glob(os.path.join(realDataPath, '*.png'))
        real_img_paths = random.sample(real_img_paths, showCount)

        # Create samples from training and real image predictions
        finalResult = np.empty([0, 4 * model.width, 3], dtype=np.uint8)
        for train_img_path, real_img_path in zip(train_img_paths,
                                                 real_img_paths):
            train_img = cv2.imread(train_img_path, cv2.IMREAD_COLOR)
            train_img = cv2.resize(train_img, (model.width, model.height),
                                   interpolation=cv2.INTER_LANCZOS4)
            real_img = cv2.imread(real_img_path, cv2.IMREAD_COLOR)
            real_img = cv2.resize(real_img, (model.width, model.height),
                                  interpolation=cv2.INTER_LANCZOS4)

            img_batch = [train_img, real_img]
            img_batch = torch.stack(
                [testTransform(img_)[0] for img_ in img_batch])

            with torch.no_grad():
                _, pred = torch.max(model.forward(img_batch), 1)
            pred = pred.byte()
            pred = [pred_.squeeze().numpy() for pred_ in pred]

            train_img2 = train_img.copy()
            train_img2[pred[0] == 1] = (0, 255, 0)  # Right lane
            train_img2[pred[0] == 2] = (255, 0, 0)  # Left lane
            train_img2[pred[0] == 3] = (0, 0, 255)  # Obstacles
            real_img2 = real_img.copy()
            real_img2[pred[1] == 1] = (0, 255, 0)  # Right lane
            real_img2[pred[1] == 2] = (255, 0, 0)  # Left lane
            real_img2[pred[1] == 3] = (0, 0, 255)  # Obstacles

            result = np.concatenate(
                (train_img, train_img2, real_img, real_img2), axis=1)
            finalResult = np.concatenate((finalResult, result), axis=0)

        cv2.imwrite('results/samplePredictions.png', finalResult)

    if testDataPath is not None:
        # Perform qualitative evaluation
        testDataset = RightLaneDataset(testDataPath,
                                       transform=testTransform,
                                       haveLabels=True)
        testDataLoader = DataLoader(testDataset,
                                    batch_size=32,
                                    shuffle=False,
                                    num_workers=8)

        if torch.cuda.is_available():
            model = model.cuda()

        test_acc, test_dice, test_iou = 0.0, 0.0, 0.0
        test_conf_matrix = torch.zeros(4, 4, device=model.device)
        totalWeight = 0
        for batch in tqdm(testDataLoader):
            img, label = batch
            if torch.cuda.is_available():
                img, label = img.cuda(), label.cuda()

            with torch.no_grad():
                probas = model.forward(img)
            _, label_hat = torch.max(probas, 1)

            weight = img.shape[0]
            test_acc += accuracy(label_hat, label) * weight
            test_dice += dice_score(probas, label) * weight
            test_iou += iou(label_hat, label) * weight
            test_conf_matrix += confusion_matrix(label_hat,
                                                 label,
                                                 num_classes=4)
            totalWeight += weight

        assert totalWeight == len(testDataset)

        if len(testDataset) > 0:
            test_acc /= len(testDataset)
            test_dice /= len(testDataset)
            test_iou /= len(testDataset)

        print(f"Accuracy on test set: {test_acc * 100.0:.4f}%")
        print(f"Dice score on test set: {test_dice:.4f}")
        print(f"IoU on test set: {test_iou * 100.0:.4f}")
        print(f"Confusion matrix (column: prediction, row: label):")
        print(test_conf_matrix)
        print(f"Total: {torch.sum(test_conf_matrix)}")
Example #15
model.load_state_dict(state_dict, strict=True)

model = model.to(device)

model.eval()
torch.set_grad_enabled(False)

metrics = []

for x, y in tqdm(val_loader):
    x = x.to(device)
    y = y.to(device)

    y_pred = torch.argmax(model(x), 1)

    # per-batch confusion matrix over 19 classes, ignoring void pixels (label 255)
    cm_batch = confusion_matrix(y_pred[y != 255], y[y != 255], num_classes=19)
    metrics.append(cm_batch)


cm = sum(metrics)  # aggregate per-batch confusion matrices into one global matrix

# per-class IoU: intersection (diagonal) over union (row sum + column sum - diagonal)
iou = cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) - cm.diag() + 1e-15)

metrics = {
    # 'accuracy': state.metrics['accuracy'],
    'miou': iou.mean(),
    'iou': iou,
    # 'iou': {name: state.metrics['iou'][id].item() for id, name in enumerate(classes)},
}

pprint(metrics)
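The union term in the IoU line above is row sum + column sum minus the diagonal, i.e. TP + FP + FN per class. A quick check of that identity on a toy matrix (all names illustrative):

import torch

# toy 3-class confusion matrix: rows are targets, columns are predictions
cm = torch.tensor([[5., 1., 0.],
                   [2., 7., 1.],
                   [0., 1., 4.]])

tp = cm.diag()           # correct predictions per class
fp = cm.sum(dim=0) - tp  # predicted as the class, but wrong
fn = cm.sum(dim=1) - tp  # samples of the class that were missed
iou = tp / (tp + fp + fn + 1e-15)
# identical to: cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) - cm.diag() + 1e-15)
print(iou)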
Example #16
    def on_validation_epoch_end(self, trainer, pl_module):

        for param in tqdm(self.info, leave=False, colour="blue"):
            if param not in ["alpha", "drift_norm", "model"]:
                continue
            info = torch.cat(self.info[param], dim=0)
            pred = torch.cat(self.preds[param], dim=0)

            if param == "model":
                info = info[:, 0]
                pred = torch.argmax(pred, dim=1)
                n_models = len(pl_module.hparams["RW_types"])
                CM = (
                    confusion_matrix(pred, info, n_models, normalize="true")
                    .detach()
                    .cpu()
                )

                fig = plt.figure()
                ax = fig.add_subplot(111)
                ax.imshow(CM, cmap="Blues", vmin=0.0, vmax=1.0)
                ax.set_xticks(np.arange(n_models))
                ax.set_xticklabels(pl_module.hparams["RW_types"])
                ax.set_yticks(np.arange(n_models))
                ax.set_yticklabels(pl_module.hparams["RW_types"])

                self.tb.add_figure("model_val", fig, global_step=self.round, close=True)

            elif param == "drift_norm":
                cond = is_concerned(info).detach().cpu().numpy()
                info = info[:, 0].detach().cpu().numpy()
                pred = pred[:, 0].detach().cpu().numpy()
                info[~cond] = 0.0

                fig = plt.figure()
                ax = fig.add_subplot(211)
                ax.scatter(info[cond], pred[cond], label="BM & OU with drift", s=3)
                ax.plot([0.0, 0.5], [0.0, 0.5], ls=":", c="red")
                ax.set_xlim((-0.05, 0.55))
                ax.set_ylim((-0.05, 0.55))
                ax.set_xlabel("True drift")
                ax.set_ylabel("Inferred drift")
                ax.legend()

                ax = fig.add_subplot(212)
                ax.hist(
                    pred[~cond],
                    bins=30,
                    density=True,
                    label="Anomalous diffusion",
                    range=(-0.05, 1.0),
                    alpha=0.5,
                )
                ax.hist(
                    pred[cond & (info <= 0.25)],
                    bins=30,
                    density=True,
                    label="BM & OU w/ low drift",
                    range=(-0.05, 1.0),
                    alpha=0.5,
                )
                ax.hist(
                    pred[cond & (info > 0.25)],
                    bins=30,
                    density=True,
                    label="BM & OU w/ high drift",
                    range=(-0.05, 1.0),
                    alpha=0.5,
                )
                ax.set_xlabel("Inferred drift")
                ax.legend()

                plt.tight_layout()

                self.tb.add_figure("DNorm_val", fig, global_step=self.round, close=True)

        self.round += 1