def validation(self, dataloader: TrainingAbnormalDataSet, _model: BaseModel) -> dict:
    """Run a classification validation pass over ``dataloader``.

    A fresh RetinaNet is built and loaded with ``_model``'s weights so the
    training model itself is never moved to another device or flipped out of
    train mode.

    Args:
        dataloader: validation dataset; batches must contain a one-hot 'label'.
        _model: model whose current weights are evaluated.

    Returns:
        dict mapping 'healthy'/'abnormal' -> [correct_count, incorrect_count].
    """
    from src.modeling.models.retinaNet.retinaNet import RetinaNet

    model = RetinaNet()
    model.load_state_dict(_model.state_dict())
    model.to(config.validation_device)
    model.eval()

    self.log.info("Beginning Validation")
    dataloader.display_metrics(dataloader.get_metrics())

    data = iter(DataLoader(dataloader, batch_size=config.batch_size, num_workers=4))
    # Ceil division. The previous `len // bs + 1` over-counted by one whenever
    # the dataset size was an exact multiple of the batch size, which made
    # next(data) raise StopIteration on the final iteration.
    total = (len(dataloader) + config.batch_size - 1) // config.batch_size

    # idx 0 == correct, idx 1 == incorrect
    stats = {
        'healthy': [0, 0],
        'abnormal': [0, 0]
    }
    labels = ['healthy', 'abnormal']

    # Inference only: no_grad avoids building autograd graphs (and the memory
    # growth that comes with them) during validation.
    with torch.no_grad():
        for _ in tqdm(range(total), total=total, desc="Validating the model"):
            batch = next(data)
            # Push only tensor entries onto the validation device.
            for ky in batch:
                if isinstance(batch[ky], torch.Tensor):
                    batch[ky] = batch[ky].to(config.validation_device)

            # Ground truth class index from the one-hot 'label'.
            y: torch.Tensor = torch.argmax(batch['label'], 1)

            preds = model(batch)
            predictions = torch.argmax(preds['preds'], 1)

            for idx, prediction in enumerate(predictions.tolist()):
                # Bucket hit/miss counts under the ground-truth class name.
                if prediction == y[idx]:
                    stats[labels[y[idx]]][0] += 1
                else:
                    stats[labels[y[idx]]][1] += 1

    table = [[stat, stats[stat][0], stats[stat][1]] for stat in stats]
    self.log.info(f'\n-- Validation Report --\n{tabulate(table, headers=["Type","Correct","Incorrect"])}')

    model.train()
    return stats
def test_load_records(self):
    """Smoke-test record loading: records are non-empty and batch cleanly."""
    dataset = TrainingAbnormalDataSet()
    loaded_records = dataset.load_records()

    # Drain the loader once so any collation/worker problems surface here.
    batches = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
    for item in batches:
        print(item)

    assert len(loaded_records) > 0
# NOTE(review): this chunk begins mid-method — the `def` header for the two
# statements below is outside this view, so their indentation is reconstructed.
        loss = self.criterion(predictions, data['label'])
        return {'loss': loss}


if __name__ == "__main__":
    # Ad-hoc training entry point: wires a Res50 classifier into the
    # abnormal-classification training task with periodic validation and
    # checkpoint hooks.
    from src.data.abnormal_dataset import TrainingAbnormalDataSet
    from src.training_tasks.tasks.AbnormalClassificationTask import AbnormalClassificationTask
    from src.utils.hooks import StepTimer, PeriodicStepFuncHook, LogTrainingLoss
    from torch import optim

    from src.training_tasks import BackpropAggregators

    model = Res50()

    dataloader = TrainingAbnormalDataSet()
    dataloader.load_records(keep_annotations=False)

    # 75/25 train/validation split over the loaded records.
    train_dl, val_dl = dataloader.partition_data([0.75, 0.25], TrainingAbnormalDataSet)

    task = AbnormalClassificationTask(model, train_dl, optim.Adam(model.parameters(), lr=0.0001), backward_agg=BackpropAggregators.MeanLosses)
    task.max_iter = 25000

    # Validate every 5000 steps; checkpoint every 1000.
    val_hook = PeriodicStepFuncHook(5000, lambda: task.validation(val_dl, model))
    # NOTE(review): CheckpointHook is not imported in this fragment —
    # presumably imported at the top of the file; verify.
    checkpoint_hook = CheckpointHook(1000, "resnet50_test3")

    task.register_hook(LogTrainingLoss())
    task.register_hook(StepTimer())
    task.register_hook(val_hook)
    task.register_hook(checkpoint_hook)
def annotation_validation(self, dataloader: TrainingAbnormalDataSet, _model: BaseModel) -> dict:
    """Validate detection quality (mean average precision) on ``dataloader``.

    ``_model``'s weights are copied into a fresh RetinaNetFPN so the training
    model is left untouched. Predicted and ground-truth boxes are collected
    (coordinates normalized by the 256px input resolution) and scored with
    ``mean_average_precision_for_boxes``; the resulting mAP is logged.

    Args:
        dataloader: validation dataset; batches carry an 'annotations' list.
        _model: model whose current weights are evaluated.

    Returns:
        dict mapping 'healthy'/'abnormal' -> [correct, incorrect]. NOTE: these
        counters are never incremented by this routine (kept for interface
        compatibility with `validation`); rely on the logged mAP instead.
    """
    from src.modeling.models.retinaNetFPN.retinaNetFPN import RetinaNetFPN

    model = RetinaNetFPN()
    model.load_state_dict(_model.state_dict())
    model.to(config.validation_device)
    model.eval()

    self.log.info("Beginning Validation")
    dataloader.display_metrics(dataloader.get_metrics())

    data = iter(DataLoader(dataloader, batch_size=config.batch_size, num_workers=4, collate_fn=self.collater))
    # Ceil division. The previous `len // bs + 1` over-counted by one whenever
    # the dataset size was an exact multiple of the batch size, which made
    # next(data) raise StopIteration on the final iteration.
    total = (len(dataloader) + config.batch_size - 1) // config.batch_size

    # idx 0 == correct, idx 1 == incorrect (returned as zeros, see docstring)
    stats = {
        'healthy': [0, 0],
        'abnormal': [0, 0]
    }

    det = []  # per predicted box: [image_id, label, score, x1, y1, x2, y2]
    ann = []  # per ground-truth box: [image_id, label, x1, y1, x2, y2]
    image_id = 0  # (was assigned twice; duplicate removed)

    # Inference only: no_grad avoids building autograd graphs during validation.
    with torch.no_grad():
        for _ in tqdm(range(total), total=total, desc="Validating the model"):
            batch = next(data)
            # Push only tensor entries onto the validation device.
            for ky in batch:
                if isinstance(batch[ky], torch.Tensor):
                    batch[ky] = batch[ky].to(config.validation_device)

            predictions = model(batch)
            for idx, pred in enumerate(predictions):
                annotation = batch['annotations'][idx]
                # Box coordinates are normalized by the 256px input size.
                for p_idx in range(len(pred['boxes'])):
                    det.append([f'{image_id}', pred['labels'][p_idx].item(), pred['scores'][p_idx].item(),
                                pred['boxes'][p_idx][0].item() / 256.0, pred['boxes'][p_idx][1].item() / 256.0,
                                pred['boxes'][p_idx][2].item() / 256.0, pred['boxes'][p_idx][3].item() / 256.0])
                for a_idx in range(len(annotation['boxes'])):
                    ann.append([f'{image_id}', torch.argmax(annotation['labels'][a_idx], 0).item(),
                                annotation['boxes'][a_idx][0].item() / 256.0, annotation['boxes'][a_idx][1].item() / 256.0,
                                annotation['boxes'][a_idx][2].item() / 256.0, annotation['boxes'][a_idx][3].item() / 256.0])
                image_id += 1

    # The scorer expects class names, not indices.
    for row in ann:
        row[1] = 'healthy' if row[1] == 0 else 'abnormal'
    for row in det:
        row[1] = 'healthy' if row[1] == 0 else 'abnormal'

    mean_ap, average_precisions = mean_average_precision_for_boxes(ann, det)
    # Previously computed but never surfaced anywhere (the report block was
    # commented out) — log the result so the validation run is observable.
    self.log.info(f'\n-- Validation Report --\nmAP: {mean_ap}\nPer-class AP: {average_precisions}')

    return stats
# NOTE(review): this chunk begins mid-method — the `def forward` header for
# the statements below is outside this view, so indentation is reconstructed.
        x = self.FC2(x)
        predictions = self.LSoftmax(x)

        out = {'preds': predictions}
        # Loss is only attached while training (targets available in `data`).
        if self.training:
            out['losses'] = self.loss(out, data)
        return out

    def loss(self, predictions: dict, data: dict) -> dict:
        """Apply self.criterion to predictions['preds'] vs the batch 'label'.

        Returns:
            dict with a single 'loss' entry holding the criterion's output.
        """
        predictions: torch.Tensor = predictions['preds']
        loss = self.criterion(predictions, data['label'])
        return {'loss': loss}


if __name__ == "__main__":
    # Ad-hoc training entry point for the SimpleNet abnormal classifier.
    from src.data.abnormal_dataset import TrainingAbnormalDataSet
    from src.training_tasks.tasks.AbnormalClassificationTask import AbnormalClassificationTask

    model = SimpleNet()

    dataloader = TrainingAbnormalDataSet()
    dataloader.load_records()

    training_task = AbnormalClassificationTask("abnormal_classification_task", checkpoint_frequency=1, validation_frequency=1)
    # 75/25 train/validation split.
    training_task.register_training_data(dataloader, train_to_val_split=0.75)
    training_task.begin_or_resume(model)