def validator_epoch_comp_callback(engine):
    """Log validation metrics at epoch end and record them on the trainer.

    Two families of metrics are reported. They are computed differently and
    will generally not agree:

    1. Mean/std over classes of IoU and Dice, derived from the confusion
       matrix accumulated over the whole epoch
       (``engine.state.epoch_metrics['confusion_matrix']``). The full matrix
       is passed in, rather than the ``[1:, 1:]`` slice used by the original
       code. Per the original note, only non-background results are
       presented; presumably this is handled inside
       ``calculate_iou``/``calculate_dice`` (verify).
    2. Mean/std over data items of per-frame IoU and Dice, as summarised by
       ``data_mean()`` on the ``'iou'`` and ``'dice'`` metric records.

    Only the second family is stored in
    ``trainer.state.metrics_records[trainer.state.epoch]``.
    """
    epoch_metrics = engine.state.epoch_metrics

    # --- 1. class-wise scores from the epoch-level confusion matrix ---
    confusion_matrix = epoch_metrics['confusion_matrix']
    ious = calculate_iou(confusion_matrix)
    dices = calculate_dice(confusion_matrix)
    iou_values = list(ious.values())
    dice_values = list(dices.values())
    logging_logger.info('mean IoU: %.3f, std: %.3f, for each class: %s'
                        % (np.mean(iou_values), np.std(iou_values), ious))
    logging_logger.info('mean Dice: %.3f, std: %.3f, for each class: %s'
                        % (np.mean(dice_values), np.std(dice_values), dices))

    # --- 2. per-item scores averaged over all validation data ---
    iou_summary = epoch_metrics['iou'].data_mean()
    dice_summary = epoch_metrics['dice'].data_mean()
    logging_logger.info('data (%d) mean IoU: %.3f, std: %.3f'
                        % (len(iou_summary['items']),
                           iou_summary['mean'], iou_summary['std']))
    logging_logger.info('data (%d) mean Dice: %.3f, std: %.3f'
                        % (len(dice_summary['items']),
                           dice_summary['mean'], dice_summary['std']))

    # Persist the data-level statistics, keyed by the trainer's current epoch.
    trainer.state.metrics_records[trainer.state.epoch] = {
        'miou': iou_summary['mean'],
        'std_miou': iou_summary['std'],
        'mdice': dice_summary['mean'],
        'std_mdice': dice_summary['std'],
    }
def tb_log_valid_epoch_vars(engine, logger, event_name):
    """Write epoch-level validation mean IoU and mean Dice to TensorBoard.

    Both scores are class means computed from the confusion matrix
    accumulated in ``engine.state.epoch_metrics['confusion_matrix']``. They
    are logged at step ``engine.state.epoch``.

    Args:
        engine: validation engine whose ``state.epoch_metrics`` holds the
            accumulated confusion matrix.
        logger: object exposing ``writer.add_scalar``. Presumably an ignite
            TensorboardLogger (confirm against the caller).
        event_name: event that triggered this handler. It is unused here but
            kept so the handler signature stays unchanged.
    """
    confusion_matrix = engine.state.epoch_metrics['confusion_matrix']
    mean_iou = np.mean(list(calculate_iou(confusion_matrix).values()))
    mean_dice = np.mean(list(calculate_dice(confusion_matrix).values()))
    logger.writer.add_scalar('mIoU', mean_iou, engine.state.epoch)
    # Dice gets its own tag. Writing it under 'mIoU' would interleave Dice
    # values into the IoU curve at the same step.
    logger.writer.add_scalar('mDice', mean_dice, engine.state.epoch)
def valid_step(engine, batch):
    """Run one validation iteration and accumulate its metrics into the engine.

    Under ``torch.no_grad()`` the model is switched to eval mode and fed
    ``batch['input']``. The loss against ``batch['target']`` is computed.
    For every frame in the batch, a confusion matrix between the argmax
    prediction and the target is built. Per-frame IoU/Dice are recorded in
    ``MetricRecord`` objects, and the frame matrices are summed. The batch
    totals are then merged into ``engine.state.epoch_metrics``.

    When ``args.model`` contains ``'TAPNet'``:

    * ``batch['attmap']`` is passed to the model as the ``attmap`` keyword.
    * After the step, the softmax probabilities are handed to
      ``valid_loader.dataset.update_attmaps`` together with
      ``batch['abs_idx']``.

    Returns:
        dict with keys ``'loss'`` (Python float), ``'output'`` (logits),
        ``'output_argmax'``, ``'target'``, ``'iou'`` and ``'dice'``
        (this batch's metric records). For TAPNet it also has ``'attmap'``,
        the attention maps that were fed to the model.
    """
    with torch.no_grad():
        model.eval()
        inputs = batch['input'].cuda(non_blocking=True)
        targets = batch['target'].cuda(non_blocking=True)

        # TAPNet consumes attention maps as an extra keyword argument.
        is_tapnet = 'TAPNet' in args.model
        model_kwargs = {}
        if is_tapnet:
            model_kwargs['attmap'] = batch['attmap'].cuda(non_blocking=True)

        logits = model(inputs, **model_kwargs)
        loss = loss_func(logits, targets)

        probabilities = torch.softmax(logits, dim=1)
        predicted = probabilities.argmax(dim=1)

        # Per-frame class maps on the host: <b, h, w>.
        pred_frames = predicted.cpu().numpy()
        true_frames = targets.cpu().numpy()

        iou_records = MetricRecord()
        dice_records = MetricRecord()
        batch_cm = np.zeros((num_classes, num_classes), dtype=np.uint32)
        for pred_frame, true_frame in zip(pred_frames, true_frames):
            # Derive both scores from the frame's confusion matrix. The
            # definition-based helpers (iou_multi_np / dice_multi_np) are the
            # alternative route.
            frame_cm = calculate_confusion_matrix_from_arrays(
                pred_frame, true_frame, num_classes)
            iou_records.update_record(calculate_iou(frame_cm))
            dice_records.update_record(calculate_dice(frame_cm))
            batch_cm += frame_cm

        # Fold this batch into the running epoch-level metrics.
        epoch_metrics = engine.state.epoch_metrics
        epoch_metrics['confusion_matrix'] += batch_cm
        epoch_metrics['iou'].merge(iou_records)
        epoch_metrics['dice'].merge(dice_records)

        step_output = {
            'loss': loss.item(),
            'output': logits,
            'output_argmax': predicted,
            'target': targets,
            # exposed for monitoring
            'iou': iou_records,
            'dice': dice_records,
        }

        if is_tapnet:
            # Refresh the dataset's attention maps with this step's predictions.
            valid_loader.dataset.update_attmaps(
                probabilities.cpu().numpy(), batch['abs_idx'].numpy())
            step_output['attmap'] = model_kwargs['attmap']
            # TODO: for TAPNet, return internal self-learned attention maps

        return step_output