def validate_model(model, data_val, criterion, epoch, make_prediction=True,
                   save_folder_name='prediction'):
    """Run a validation pass and return the mean accuracy and mean loss.

    Each batch from ``data_val`` is unpacked as
    ``(images_v, masks_v, original_msk)``. Dimension 1 of ``images_v`` and
    ``masks_v`` is iterated slice by slice. Each slice is sent to the GPU, fed
    to ``model`` and scored with ``criterion`` against the matching mask
    slice. Its ``argmax`` over dim 1 is concatenated into one stacked
    prediction per batch.

    This function does not call ``model.eval()``; the caller controls the
    model's mode.

    Args:
        model: network called as ``model(image)``; its output goes to
            ``criterion`` and to ``torch.argmax(..., dim=1)``.
        data_val: iterable of ``(images_v, masks_v, original_msk)`` batches.
        criterion: loss function called as ``criterion(output, mask)``.
        epoch: forwarded to ``save_prediction_image``.
        make_prediction: when True, save each stacked prediction and
            accumulate its accuracy against ``original_msk``.
        save_folder_name: forwarded to ``save_prediction_image``.

    Returns:
        ``(mean_accuracy, mean_loss)``.

        - Accuracy is summed per batch and divided by the number of batches.
          It stays 0 when ``make_prediction`` is False.
        - Loss is averaged over every ``criterion`` evaluation.
        - Both values are 0.0 when ``data_val`` yields no batches.
    """
    total_val_loss = 0
    total_val_acc = 0
    num_batches = 0
    # One term per criterion() call. This replaces the former hard-coded
    # "4 slices per batch" divisor, which was only right for 4-slice inputs.
    num_loss_terms = 0
    for batch, (images_v, masks_v, original_msk) in enumerate(data_val):
        num_batches += 1
        seg_slices = []
        with torch.no_grad():
            for index in range(images_v.size()[1]):
                image_v = images_v[:, index, :, :].unsqueeze(0).cuda()
                mask_v = masks_v[:, index, :, :].squeeze(1).cuda()
                output_v = model(image_v)
                total_val_loss += criterion(output_v, mask_v).cpu().item()
                num_loss_terms += 1
                seg_slices.append(torch.argmax(output_v, dim=1).float())
        if make_prediction:
            # Concatenate once instead of re-copying the accumulator per slice.
            # The empty seed tensor mirrors the original accumulator, so the
            # result is still a valid (empty) CUDA tensor when there are no
            # slices.
            stacked_img = torch.cat([torch.Tensor([]).cuda(), *seg_slices])
            im_name = batch  # TODO: Change this to real image name so we know
            # NOTE(review): test_model and validate_hard_model in this file call
            # save_prediction_image with six arguments
            # (img, shape, name, epoch, flag, folder). Confirm that this
            # four-argument call still matches its signature.
            pred_msk = save_prediction_image(stacked_img, im_name, epoch,
                                             save_folder_name)
            total_val_acc += accuracy_check(original_msk, pred_msk)
    return (total_val_acc / max(num_batches, 1),
            total_val_loss / max(num_loss_terms, 1))
def test_model(
        model_path, data_test, epoch, save_folder_name='prediction',
        save_dir="./history/RMS/result_images_test",
        save_file_name="./history/RMS/result_images_test/history_RMS3.csv"):
    """Evaluate a saved model on ``data_test``, save its outputs and log metrics.

    Setup:

    - Loads a whole pickled model from ``model_path``.
    - Wraps it in ``torch.nn.DataParallel`` over all visible CUDA devices and
      switches it to eval mode.

    Each batch ``(images_v, masks_v, original_msk)`` is then processed as
    follows (``masks_v`` is not used):

    - The model is run slice by slice along dimension 1 of ``images_v``. It
      must return a pair ``(output_v, output_r_v)``.
    - The ``argmax`` of ``output_v`` over dim 1 is stacked and saved with
      ``save_prediction_image``.
    - The dim-0-squeezed ``output_r_v`` slices are stacked and passed to
      ``reconstruct_image``.
    - ``original_msk`` is cropped by ``pdsz`` (20) elements at both ends of
      each of its last three dimensions, then compared with the saved
      prediction via ``accuracy_check`` and ``dice_coeff``.

    The per-batch means of accuracy, Jaccard and Dice are printed and appended
    to the history file through ``export_history``.

    Args:
        model_path: path of a file written with ``torch.save(model)``.
        data_test: iterable of ``(images_v, masks_v, original_msk)`` batches.
        epoch: forwarded to the save helpers and written to the history row.
        save_folder_name: forwarded to ``save_prediction_image`` and
            ``reconstruct_image``.
        save_dir: directory passed to ``export_history``.
        save_file_name: history file passed to ``export_history``.

    Returns:
        The mean accuracy over all batches.

    Raises:
        ValueError: if ``data_test`` yields no batches. Previously this
            surfaced as an UnboundLocalError on ``batch``.
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from trusted sources.
    model = torch.load(model_path)
    model = torch.nn.DataParallel(
        model, device_ids=list(range(torch.cuda.device_count()))).cuda()
    model.eval()
    pdsz = 20  # elements cropped from each end of the mask's last three dims
    total_val_acc = 0
    total_val_jac = 0
    total_val_dice = 0
    num_batches = 0
    for batch, (images_v, _, original_msk) in enumerate(data_test):
        num_batches += 1
        ori_shape = original_msk.shape
        original_msk = original_msk[..., pdsz:-pdsz, pdsz:-pdsz, pdsz:-pdsz]
        seg_slices = []
        reg_slices = []
        with torch.no_grad():
            for index in range(images_v.size()[1]):
                image_v = images_v[:, index, :, :].unsqueeze(0).cuda()
                output_v, output_r_v = model(image_v)
                seg_slices.append(torch.argmax(output_v, dim=1).float())
                reg_slices.append(torch.squeeze(output_r_v, dim=0))
        # Concatenate once per volume. The empty seed tensor mirrors the
        # original accumulator (same dtype promotion, valid empty result).
        stacked_img = torch.cat([torch.Tensor([]).cuda(), *seg_slices])
        stacked_reg = torch.cat([torch.Tensor([]).cuda(), *reg_slices])
        im_name = batch  # TODO: Change this to real image name so we know
        pred_msk = save_prediction_image(stacked_img, ori_shape[-3:], im_name,
                                         epoch, 0, save_folder_name)
        acc_val = accuracy_check(original_msk, pred_msk)
        avg_dice, jac = dice_coeff(pred_msk, original_msk)
        total_val_jac += jac
        total_val_dice += avg_dice
        total_val_acc += acc_val
        # NOTE(review): the collapsed source does not show whether this call
        # sat inside the batch loop or after it. Here it runs once per volume.
        # Confirm the intended placement.
        reconstruct_image(stacked_reg, ori_shape[-3:], epoch, save_folder_name)
    if num_batches == 0:
        raise ValueError("data_test yielded no batches")
    mean_acc = total_val_acc / num_batches
    mean_jac = total_val_jac / num_batches
    mean_dice = total_val_dice / num_batches
    print("total_val_acc is:%f. total_val_jac is:%f . total_val_dice is:%f "
          "Finish Prediction!" % (mean_acc, mean_jac, mean_dice))
    header = ['epoch', 'total_val_jac', 'total_val_dice', 'total_val_acc']
    values = [epoch, mean_jac, mean_dice, mean_acc]
    export_history(header, values, save_dir, save_file_name)
    return mean_acc
def validate_hard_model(model, data_val, criterion, epoch, make_prediction=True,
                        save_folder_name='prediction'):
    """Run a validation pass for a two-output model and return mean metrics.

    Calls ``model.eval()`` and leaves the model in eval mode. Each batch
    ``(images_v, masks_v, original_msk)`` is processed slice by slice along
    dimension 1:

    - The model must return a pair. Only the first element is scored with
      ``criterion`` against the matching mask slice.
    - Its ``argmax`` over dim 1 is stacked into one prediction per batch.

    When ``make_prediction`` is True, each stacked prediction is also handled:

    - It is saved with ``save_prediction_image``.
    - It is compared, via ``accuracy_check`` and ``dice_coeff``, with
      ``original_msk``. The mask is first cropped by ``pdsz`` (20) elements at
      both ends of each of its last three dimensions.

    Args:
        model: network returning ``(output_v, output_r_v)``.
        data_val: iterable of ``(images_v, masks_v, original_msk)`` batches.
        criterion: loss function called as ``criterion(output, mask)``.
        epoch: forwarded to ``save_prediction_image``.
        make_prediction: when True, save predictions and accumulate
            accuracy/Dice/Jaccard.
        save_folder_name: forwarded to ``save_prediction_image``.

    Returns:
        ``(mean_accuracy, mean_loss, mean_dice, mean_jaccard)``.

        - Accuracy, Dice and Jaccard are averaged over the predictions made.
          The old code divided by ``count + 1``, which was off by one.
        - Loss is averaged over every ``criterion`` evaluation. The old code
          used ``(count + 1) * 4``, where ``count`` stayed 0 whenever
          ``make_prediction`` was False.
        - Each value is 0.0 when nothing was accumulated.
    """
    model.eval()
    pdsz = 20  # elements cropped from each end of the mask's last three dims
    total_val_loss = 0
    total_val_acc = 0
    total_val_dice = 0
    total_val_jac = 0
    num_loss_terms = 0
    num_predictions = 0
    for batch, (images_v, masks_v, original_msk) in enumerate(data_val):
        seg_slices = []
        with torch.no_grad():
            for index in range(images_v.size()[1]):
                image_v = images_v[:, index, :, :].unsqueeze(0).cuda()
                mask_v = masks_v[:, index, :, :].squeeze(1).cuda()
                # The second model output is not used during validation.
                output_v, _ = model(image_v)
                total_val_loss += criterion(output_v, mask_v).cpu().item()
                num_loss_terms += 1
                seg_slices.append(torch.argmax(output_v, dim=1).float())
        if make_prediction:
            ori_shape = original_msk.shape
            cropped_msk = original_msk[..., pdsz:-pdsz, pdsz:-pdsz, pdsz:-pdsz]
            # Concatenate once. The empty seed tensor mirrors the original
            # accumulator and keeps the result valid when there are no slices.
            stacked_img = torch.cat([torch.Tensor([]).cuda(), *seg_slices])
            im_name = batch  # TODO: Change this to real image name so we know
            pred_msk = save_prediction_image(stacked_img, ori_shape[-3:],
                                             im_name, epoch, 1,
                                             save_folder_name)
            acc_val = accuracy_check(cropped_msk, pred_msk)
            dice, jac = dice_coeff(pred_msk, cropped_msk)
            total_val_acc += acc_val
            total_val_dice += dice
            total_val_jac += jac
            num_predictions += 1
    pred_denom = max(num_predictions, 1)
    return (total_val_acc / pred_denom,
            total_val_loss / max(num_loss_terms, 1),
            total_val_dice / pred_denom,
            total_val_jac / pred_denom)