def validation(self, dataloader: TrainingAbnormalDataSet, _model: BaseModel) -> dict:
    """Validate *_model* as a binary healthy/abnormal classifier.

    A fresh RetinaNet is rebuilt on the validation device from _model's
    state_dict, so the caller's model object is never moved between devices.

    :param dataloader: labelled validation dataset
    :param _model: trained model whose weights are evaluated
    :return: {'healthy': [correct, incorrect], 'abnormal': [correct, incorrect]}
    """
    from src.modeling.models.retinaNet.retinaNet import RetinaNet

    model = RetinaNet()
    model.load_state_dict(_model.state_dict())
    model.to(config.validation_device)
    model.eval()

    self.log.info("Beginning Validation")

    dataloader.display_metrics(dataloader.get_metrics())

    data = iter(DataLoader(dataloader, batch_size=config.batch_size, num_workers=4))
    # Ceil-divide: the previous `len // batch_size + 1` requested one batch
    # too many whenever len(dataloader) was an exact multiple of the batch
    # size, making next(data) raise StopIteration mid-validation.
    total = -(-len(dataloader) // config.batch_size)

    # idx 0 == correct, idx 1 == incorrect
    stats = {
        'healthy': [0, 0],
        'abnormal': [0, 0]
    }
    labels = ['healthy', 'abnormal']

    # Inference only: no_grad skips autograd bookkeeping and saves memory.
    with torch.no_grad():
        for _ in tqdm(range(total), desc="Validating the model"):
            batch = next(data)

            # Move only tensor entries onto the validation device.
            for ky in batch:
                if isinstance(batch[ky], torch.Tensor):
                    batch[ky] = batch[ky].to(config.validation_device)

            # Labels are one-hot along dim 1; argmax recovers the class index.
            y: torch.Tensor = torch.argmax(batch['label'], 1)

            preds = model(batch)
            predictions = torch.argmax(preds['preds'], 1)

            for idx, prediction in enumerate(predictions.tolist()):
                # Bucket by ground-truth class; slot 0 = correct, slot 1 = incorrect.
                if prediction == y[idx]:
                    stats[labels[y[idx]]][0] += 1
                else:
                    stats[labels[y[idx]]][1] += 1

    table = [[name, counts[0], counts[1]] for name, counts in stats.items()]
    self.log.info(f'\n-- Validation Report --\n{tabulate(table, headers=["Type","Correct","Incorrect"])}')

    model.train()

    return stats
# --- Example #2 ("예제" is Korean for "example"; scraper pagination artifact) ---
    def test_load_records(self):
        """Smoke-test record loading: records load and batch cleanly via DataLoader.

        NOTE(review): two stray lines referencing undefined names
        (`loss = self.criterion(predictions, data['label'])` and its return)
        were pasted in from another module's loss function; they raised
        NameError before the assert could run and have been removed.
        """
        data_loader = TrainingAbnormalDataSet()

        records = data_loader.load_records()

        loader = DataLoader(data_loader,
                            batch_size=4,
                            shuffle=True,
                            num_workers=4)

        # Drain the loader once to prove every batch materialises.
        for batch in loader:
            print(batch)

        assert len(records) > 0



if __name__ == "__main__":
    from src.data.abnormal_dataset import TrainingAbnormalDataSet
    from src.training_tasks.tasks.AbnormalClassificationTask import AbnormalClassificationTask
    from src.utils.hooks import StepTimer, PeriodicStepFuncHook, LogTrainingLoss
    from torch import optim

    from src.training_tasks import BackpropAggregators

    # NOTE(review): Res50 and CheckpointHook are not imported in this view —
    # confirm they are defined/imported earlier in the original file.
    model = Res50()

    dataloader = TrainingAbnormalDataSet()
    dataloader.load_records(keep_annotations=False)

    # 75/25 train/validation partition over the same dataset class.
    train_dl, val_dl = dataloader.partition_data([0.75, 0.25], TrainingAbnormalDataSet)

    # Adam with lr=1e-4; mean-reduce per-batch losses before backprop.
    task = AbnormalClassificationTask(model, train_dl, optim.Adam(model.parameters(), lr=0.0001), backward_agg=BackpropAggregators.MeanLosses)
    task.max_iter = 25000

    # Validate every 5000 steps; checkpoint every 1000 steps.
    val_hook = PeriodicStepFuncHook(5000, lambda: task.validation(val_dl, model))
    checkpoint_hook = CheckpointHook(1000, "resnet50_test3")

    task.register_hook(LogTrainingLoss())
    task.register_hook(StepTimer())
    task.register_hook(val_hook)
    task.register_hook(checkpoint_hook)
    def annotation_validation(self, dataloader: TrainingAbnormalDataSet, _model: BaseModel) -> dict:
        """Validate *_model* as a box detector via mean average precision.

        Detections and ground-truth annotations are collected in the
        row format expected by mean_average_precision_for_boxes
        (coordinates normalised by the 256px image side), then mAP is
        computed and logged.

        :param dataloader: validation dataset carrying box annotations
        :param _model: trained model whose weights are evaluated
        :return: zeroed per-class stats dict (kept for interface parity with
                 validation(); the mAP itself is logged, not returned)
        """
        from src.modeling.models.retinaNetFPN.retinaNetFPN import RetinaNetFPN

        model = RetinaNetFPN()
        model.load_state_dict(_model.state_dict())
        model.to(config.validation_device)
        model.eval()

        self.log.info("Beginning Validation")

        dataloader.display_metrics(dataloader.get_metrics())

        data = iter(DataLoader(dataloader, batch_size=config.batch_size, num_workers=4, collate_fn=self.collater))
        # Ceil-divide: the previous `len // batch_size + 1` over-counted by
        # one batch when the dataset size was a multiple of the batch size,
        # making next(data) raise StopIteration.
        total = -(-len(dataloader) // config.batch_size)

        # idx 0 == correct, idx 1 == incorrect. Never updated in this method;
        # returned as-is so the return shape matches validation().
        stats = {
            'healthy': [0, 0],
            'abnormal': [0, 0]
        }

        det = []  # rows: [image_id, label, score, x1, y1, x2, y2]
        ann = []  # rows: [image_id, label, x1, y1, x2, y2]

        image_id = 0  # was assigned twice in the original; once suffices

        # Inference only: no_grad skips autograd bookkeeping and saves memory.
        with torch.no_grad():
            for _ in tqdm(range(total), desc="Validating the model"):
                batch = next(data)

                # Move only tensor entries onto the validation device.
                for ky in batch:
                    if isinstance(batch[ky], torch.Tensor):
                        batch[ky] = batch[ky].to(config.validation_device)

                predictions = model(batch)

                for idx, pred in enumerate(predictions):
                    annotation = batch['annotations'][idx]

                    # Box coordinates are normalised by the 256px image side.
                    for p_idx in range(len(pred['boxes'])):
                        box = pred['boxes'][p_idx]
                        det.append([f'{image_id}', pred['labels'][p_idx].item(), pred['scores'][p_idx].item(),
                                    box[0].item() / 256.0, box[1].item() / 256.0,
                                    box[2].item() / 256.0, box[3].item() / 256.0])

                    # Ground-truth labels are one-hot; argmax recovers the class id.
                    for a_idx in range(len(annotation['boxes'])):
                        box = annotation['boxes'][a_idx]
                        ann.append([f'{image_id}', torch.argmax(annotation['labels'][a_idx], 0).item(),
                                    box[0].item() / 256.0, box[1].item() / 256.0,
                                    box[2].item() / 256.0, box[3].item() / 256.0])

                    image_id += 1

        # Map numeric class ids to the names the mAP helper reports on.
        for row in ann:
            row[1] = 'healthy' if row[1] == 0 else 'abnormal'
        for row in det:
            row[1] = 'healthy' if row[1] == 0 else 'abnormal'

        # The mAP result was previously computed and silently discarded — log it.
        mean_ap, average_precisions = mean_average_precision_for_boxes(ann, det)
        self.log.info(f'\n-- Validation Report --\nmAP: {mean_ap}\nPer-class AP: {average_precisions}')

        return stats
# --- Example #5 ("예제" is Korean for "example"; scraper pagination artifact) ---
        # NOTE(review): fragment — the enclosing method's def line is not
        # visible in this chunk; `x` and `data` arrive from the missing
        # header. Presumably the tail of a model forward() — confirm.
        x = self.FC2(x)

        predictions = self.LSoftmax(x)

        out = {'preds': predictions}

        # In training mode the loss is attached alongside the predictions.
        if self.training:
            out['losses'] = self.loss(out, data)

        return out

    def loss(self, predictions: dict, data: dict) -> dict:
        """Return {'loss': ...} scoring predicted class scores against batch labels.

        :param predictions: forward-pass output; 'preds' holds the scores
        :param data: batch dict; 'label' holds the ground-truth targets
        """
        scores = predictions['preds']
        return {'loss': self.criterion(scores, data['label'])}



if __name__ == "__main__":
    from src.data.abnormal_dataset import TrainingAbnormalDataSet
    from src.training_tasks.tasks.AbnormalClassificationTask import AbnormalClassificationTask

    # NOTE(review): SimpleNet is not imported in this view — presumably
    # defined earlier in the original file; confirm before running.
    model = SimpleNet()
    dataloader = TrainingAbnormalDataSet()
    dataloader.load_records()

    # Checkpoint and validate every epoch; 75/25 train/validation split.
    training_task = AbnormalClassificationTask("abnormal_classification_task", checkpoint_frequency=1, validation_frequency=1)
    training_task.register_training_data(dataloader, train_to_val_split=0.75)
    training_task.begin_or_resume(model)