def ComputeDiceWNet(result_path, t2_folder, data_type, th=0.3):
    """Compute per-sample Dice scores for both prediction maps of a two-output model.

    Loads ``<data_type>_label1.npy``, ``<data_type>_preds1.npy`` and
    ``<data_type>_preds2.npy`` from *result_path*. Each of the two prediction
    maps is scored against the label twice for every sample:
    - after thresholding the raw map at *th*;
    - after one-hot encoding its per-pixel argmax (over axis 0) with ``ROIOneHot``.

    NOTE(review): the previous body referenced ``label_arr1``..``label_arr4``,
    ``preds_arr3`` and ``preds_arr4``, none of which existed. It therefore
    raised NameError on the first sample. The four lists returned here are a
    reconstruction from the arrays that are actually loaded. Confirm that this
    is the intended set of metrics.

    Args:
        result_path: folder holding the saved ``.npy`` result arrays.
        t2_folder: unused; kept for call compatibility.
        data_type: file-name prefix of the arrays to load.
        th: threshold applied to the raw prediction maps.

    Returns:
        Tuple of four lists, one value per sample:
        - thresholded Dice of preds1;
        - thresholded Dice of preds2;
        - argmax Dice of preds1;
        - argmax Dice of preds2.
    """
    label_arr = np.load(
        os.path.join(result_path, '{}_label1.npy'.format(data_type)))
    preds_arr1 = np.load(
        os.path.join(result_path, '{}_preds1.npy'.format(data_type)))
    preds_arr2 = np.load(
        os.path.join(result_path, '{}_preds2.npy'.format(data_type)))

    dice1_list, dice2_list, dice3_list, dice4_list = [], [], [], []
    # The former read of the hard-coded '.../OneSlice/<data_type>_name.csv' was
    # removed: its case list was only consumed by commented-out plotting code.
    for index in range(preds_arr1.shape[0]):
        label = label_arr[index]
        dice1_list.append(Dice((preds_arr1[index] > th).astype(int), label))
        dice2_list.append(Dice((preds_arr2[index] > th).astype(int), label))
        dice3_list.append(
            Dice(ROIOneHot(np.argmax(preds_arr1[index], axis=0)), label))
        dice4_list.append(
            Dice(ROIOneHot(np.argmax(preds_arr2[index], axis=0)), label))
    return dice1_list, dice2_list, dice3_list, dice4_list
def case_test():
    """Evaluate a trained UNet25D on the ProstateX test cases and print mean Dice.

    For each case listed in ``test_case_name.csv``:
    - the NIfTI data is converted with ``ProstateXSeg.Nii2NPY``;
    - the model is run through ``ProstateXSeg.run``;
    - the argmax prediction is one-hot encoded with ``ROIOneHot``;
    - Dice is computed against the ROI for each of the 5 output channels.

    After all cases, the mean Dice of channels 0..4 is printed on one line.
    """
    seg = ProstateXSeg((200, 200))
    data_path = r'/home/zhangyihong/Documents/ProstateX_Seg_ZYH/OriginalData'
    model_folder = r'/home/zhangyihong/Documents/ProstateX_Seg_ZYH/Model'
    model = UNet25D(n_channels=1, n_classes=5, bilinear=True, factor=2)

    # One score list per output channel (0..4).
    channel_scores = [[] for _ in range(5)]

    name_table = pd.read_csv(
        os.path.join(
            r'/home/zhangyihong/Documents/ProstateX_Seg_ZYH/OneSlice',
            'test_case_name.csv'))

    for case in name_table.values.tolist()[0]:
        t2_arr, roi_arr = seg.Nii2NPY(case, os.path.join(data_path, case), r'', slice_num=3)
        # NOTE(review): Train() writes checkpoints to 'UNet_0330_weigthedloss'
        # (misspelled) while this loads from 'UNet_0330_weightedloss'; confirm
        # which folder actually exists on disk.
        output = seg.run(case, model,
                         model_path=os.path.join(model_folder, 'UNet_0330_weightedloss'),
                         weights_path='13-2.679143.pt',
                         inputs=t2_arr, outputs=roi_arr, is_save=False)
        # Some models return several outputs; the last one is the final map.
        if isinstance(output, tuple):
            output = output[-1]
        class_map = torch.argmax(torch.softmax(output, dim=1), dim=1).cpu().data.numpy()
        one_hot = ROIOneHot(class_map)
        for channel, scores in enumerate(channel_scores):
            scores.append(Dice(one_hot[:, channel, ...], roi_arr[:, channel, ...]))

    averages = [sum(scores) / len(scores) for scores in channel_scores]
    print('{:.3f}, {:.3f}, {:.3f}, {:.3f}, {:.3f}'.format(*averages))
def ComputeDice():
    """Print mean per-channel Dice between saved NIfTI labels and predictions.

    Iterates over the files in the module-level ``case_folder``. The part of
    each file name before ``'_0000'`` plus ``'.nii.gz'`` names the matching
    label (in ``label_folder``) and prediction (in ``predict_folder``). Both
    are one-hot encoded with ``ROIOneHot`` and compared channel by channel.
    The mean Dice is printed in the order PZ, CG, U, AS, background
    (channels 1, 2, 3, 4, 0).

    NOTE(review): a later ``def ComputeDice(result_path, t2_folder, data_type)``
    in this module rebinds the name, so this zero-argument version is not
    reachable by name after import.
    """
    # Index = channel: 0 bg, 1 PZ, 2 CG, 3 U, 4 AS.
    per_channel = [[] for _ in range(5)]
    for case in os.listdir(case_folder):
        case_label = case[:case.index('_0000')] + '.nii.gz'
        case_path = os.path.join(case_folder, case)
        label_path = os.path.join(label_folder, case_label)
        predict_path = os.path.join(predict_folder, case_label)

        # The T2 volume is loaded but not used (it only fed an old debug viewer).
        _, _t2, _ = LoadImage(case_path)
        _, label, _ = LoadImage(label_path)
        label = ROIOneHot(label)
        _, prediction, _ = LoadImage(predict_path)
        prediction = ROIOneHot(prediction)

        for channel, scores in enumerate(per_channel):
            scores.append(Dice(prediction[:, channel, :, :], label[:, channel, :, :]))

    means = [sum(scores) / len(scores) for scores in per_channel]
    # Foreground channels first, background last.
    print('{:.3f}, {:.3f}, {:.3f}, {:.3f}, {:.3f}'.format(*means[1:], means[0]))
def Train():
    """Train a 5-class UNet25D with weighted cross-entropy plus Dice loss.

    Uses the module-level ``data_root`` (case-name CSVs) and ``model_root``
    (checkpoint and TensorBoard output). Each epoch:
    - runs one pass over the training loader and one over the validation loader;
    - logs losses, per-parameter histograms (skipping names containing
      ``'bn'``) and overall soft Dice to TensorBoard;
    - prints per-channel soft Dice;
    - steps ``ReduceLROnPlateau`` on the summed validation loss;
    - stops once ``EarlyStopping`` reports ``early_stop``.

    The prints label channels 1-4 as PZ, CG, U and AFMS. Channel 0 is
    presumably background, since it gets the smallest cross-entropy weight.
    """
    torch.autograd.set_detect_anomaly(True)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    input_shape = (200, 200)
    total_epoch = 10000
    batch_size = 24

    # NOTE(review): 'weigthedloss' is misspelled, and case_test() loads weights
    # from 'UNet_0330_weightedloss'. The name is left unchanged so existing runs
    # keep writing to the same folder.
    model_folder = MakeFolder(model_root + '/UNet_0330_weigthedloss')
    ClearGraphPath(model_folder)

    param_config = {
        RotateTransform.name: {'theta': ['uniform', -10, 10]},
        ShiftTransform.name: {'horizontal_shift': ['uniform', -0.05, 0.05],
                              'vertical_shift': ['uniform', -0.05, 0.05]},
        ZoomTransform.name: {'horizontal_zoom': ['uniform', 0.95, 1.05],
                             'vertical_zoom': ['uniform', 0.95, 1.05]},
        FlipTransform.name: {'horizontal_flip': ['choice', True, False]},
        BiasTransform.name: {'center': ['uniform', -1., 1., 2],
                             'drop_ratio': ['uniform', 0., 1.]},
        NoiseTransform.name: {'noise_sigma': ['uniform', 0., 0.03]},
        ContrastTransform.name: {'factor': ['uniform', 0.8, 1.2]},
        GammaTransform.name: {'gamma': ['uniform', 0.8, 1.2]},
        ElasticTransform.name: ['elastic', 1, 0.1, 256]
    }

    train_df = pd.read_csv(os.path.join(data_root, 'train_name.csv'))
    train_list = train_df.values.tolist()[0]
    val_df = pd.read_csv(os.path.join(data_root, 'val_name.csv'))
    val_list = val_df.values.tolist()[0]

    train_loader, train_batches = _GetLoader(train_list, param_config, input_shape, batch_size, True)
    # NOTE(review): the validation loader gets the same param_config and trailing
    # flag as the training loader. If those enable augmentation or shuffling,
    # validation metrics are measured on augmented data. Confirm against _GetLoader.
    val_loader, val_batches = _GetLoader(val_list, param_config, input_shape, batch_size, True)

    model = UNet25D(n_channels=1, n_classes=5, bilinear=True, factor=2).to(device)
    model.apply(HeWeightInit)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Per-class cross-entropy weights; channel 0 is strongly down-weighted.
    weight = torch.from_numpy(np.array([0.1, 0.8, 0.8, 1., 1.])).float()
    criterion1 = torch.nn.CrossEntropyLoss(weight=weight.to(device))
    criterion2 = DiceLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=10, factor=0.5,
                                                           verbose=True)
    early_stopping = EarlyStopping(store_path=str(model_folder / '{}-{:.6f}.pt'), patience=50, verbose=True)
    writer = SummaryWriter(log_dir=str(model_folder / 'log'), comment='Net')

    def _Forward(inputs, outputs):
        """Run one batch; return (softmax preds, one-hot targets on device, CE loss, Dice loss)."""
        # Cross-entropy loss: preds are raw logits (no softmax) and labels are
        # class indices (not one-hot), hence the argmax over the one-hot targets.
        # Dice loss: preds need softmax, and labels must be one-hot with the
        # same layout as preds.
        # The missing argmax here previously raised a NameError on the first
        # training batch.
        outputs_nocoding = MoveTensorsToDevice(torch.argmax(outputs, dim=1), device)
        inputs = MoveTensorsToDevice(inputs, device)
        outputs = MoveTensorsToDevice(outputs.int(), device)
        preds = model(inputs)
        softmax_preds = F.softmax(preds, dim=1)
        loss1 = criterion1(preds, outputs_nocoding)
        loss2 = criterion2(softmax_preds, outputs)
        return softmax_preds, outputs, loss1, loss2

    def _BatchDice(softmax_preds, outputs):
        """Return [soft Dice over all channels, then channels 1-4 (PZ, CG, U, AFMS)]."""
        # Convert to numpy once per batch instead of once per metric.
        pred_np = softmax_preds.detach().cpu().numpy()
        label_np = outputs.detach().cpu().numpy()
        return [Dice(pred_np, label_np)] + [Dice(pred_np[:, c], label_np[:, c]) for c in range(1, 5)]

    try:
        for epoch in range(total_epoch):
            # Index 0: all channels; indices 1-4: PZ, CG, U, AFMS.
            train_dice = [[] for _ in range(5)]
            val_dice = [[] for _ in range(5)]
            train_loss, val_loss = 0., 0.
            train_loss1, val_loss1 = 0., 0.
            train_loss2, val_loss2 = 0., 0.

            model.train()
            for inputs, outputs in train_loader:
                softmax_preds, outputs, loss1, loss2 = _Forward(inputs, outputs)
                for scores, value in zip(train_dice, _BatchDice(softmax_preds, outputs)):
                    scores.append(value)

                loss = loss1 + loss2
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss1 += loss1.item()
                train_loss2 += loss2.item()
                train_loss += loss.item()

            model.eval()
            with torch.no_grad():
                for inputs, outputs in val_loader:
                    softmax_preds, outputs, loss1, loss2 = _Forward(inputs, outputs)
                    for scores, value in zip(val_dice, _BatchDice(softmax_preds, outputs)):
                        scores.append(value)

                    loss = loss1 + loss2
                    val_loss += loss.item()
                    val_loss1 += loss1.item()
                    val_loss2 += loss2.item()

            # Save TensorBoard summaries.
            for name, param in model.named_parameters():
                if 'bn' not in name:
                    writer.add_histogram(name + '_data', param.cpu().data.numpy(), epoch + 1)

            writer.add_scalars('Loss',
                               {'train_loss': train_loss / train_batches,
                                'val_loss': val_loss / val_batches}, epoch + 1)
            writer.add_scalars('Crossentropy Loss',
                               {'train_loss': train_loss1 / train_batches,
                                'val_loss': val_loss1 / val_batches}, epoch + 1)
            writer.add_scalars('Dice Loss',
                               {'train_loss': train_loss2 / train_batches,
                                'val_loss': val_loss2 / val_batches}, epoch + 1)
            writer.add_scalars('Dice',
                               {'train_loss': np.mean(train_dice[0]),
                                'val_loss': np.mean(val_dice[0])}, epoch + 1)

            print('*************************************** Epoch {} | (◕ᴗ◕✿) ***************************************'.format(epoch + 1))
            print(' dice pz: {:.3f}, dice cg: {:.3f}, dice U: {:.3f}, dice AFMS: {:.3f}'.
                  format(*[np.mean(scores) for scores in train_dice[1:]]))
            print('val-dice pz: {:.3f}, val-dice cg: {:.3f}, val-dice U: {:.3f}, val-dice AFMS: {:.3f}'.
                  format(*[np.mean(scores) for scores in val_dice[1:]]))
            print()
            print('loss: {:.3f}, val-loss: {:.3f}'.format(train_loss / train_batches, val_loss / val_batches))

            scheduler.step(val_loss)
            early_stopping(val_loss, model, (epoch + 1, val_loss))

            if early_stopping.early_stop:
                print("Early stopping")
                break

            writer.flush()
    finally:
        # Close the writer even if training raises, so pending events are flushed.
        writer.close()
def ComputeDice(result_path, t2_folder, data_type):
    """Compute per-sample, per-channel Dice from saved result arrays and plot histograms.

    Loads ``<data_type>_label.npy`` and the prediction array from *result_path*.
    If ``<data_type>_preds2.npy`` exists (the second output of a two-output
    model, as written next to ``_preds1.npy``), that file is used; otherwise
    ``<data_type>_preds.npy`` is used. Each prediction is turned into a one-hot
    map via ``ROIOneHot(argmax over axis 0)`` and compared with the label for
    channels 0..4. A 2x2 figure then shows the Dice histograms of channels 1-4,
    each titled with its mean.

    This definition rebinds the name of the zero-argument ``ComputeDice``
    defined earlier in the module.

    Args:
        result_path: folder holding the saved ``.npy`` result arrays.
        t2_folder: unused; kept for call compatibility.
        data_type: file-name prefix of the arrays to load.

    Returns:
        Tuple of five lists (channels 0..4), one Dice value per sample.
    """
    # Previously the two-output case was detected by "exactly 9 files in
    # result_path", which picked the wrong file as soon as any extra file was
    # present. Checking for the preds2 file directly expresses the same intent.
    preds2_path = os.path.join(result_path, '{}_preds2.npy'.format(data_type))
    if os.path.exists(preds2_path):
        preds_arr = np.load(preds2_path)
    else:
        preds_arr = np.load(
            os.path.join(result_path, '{}_preds.npy'.format(data_type)))
    label_arr = np.load(
        os.path.join(result_path, '{}_label.npy'.format(data_type)))

    dice_lists = [[] for _ in range(5)]
    for index in range(preds_arr.shape[0]):
        pred = ROIOneHot(np.argmax(preds_arr[index], axis=0))
        label = label_arr[index]
        for channel, scores in enumerate(dice_lists):
            scores.append(Dice(pred[channel], label[channel]))

    def _mean(values):
        # NaN instead of ZeroDivisionError when there are no samples.
        return sum(values) / len(values) if values else float('nan')

    # Histograms for channels 1-4 only (channel 0 is not plotted).
    for position, scores in zip((221, 222, 223, 224), dice_lists[1:]):
        plt.subplot(position)
        plt.title('aver: {:.3f}'.format(_mean(scores)))
        plt.hist(scores, bins=20)
    plt.show()

    return tuple(dice_lists)
def ShoweResult(model_folder, data_type='train', num_pred=1, save_path=r''):
    """Plot label channels next to prediction maps for every stored sample.

    Reads ``<model_folder>/Result/<data_type>_label.npy`` plus either
    ``<data_type>_preds.npy`` (``num_pred == 1``) or ``<data_type>_preds1.npy``
    and ``<data_type>_preds2.npy`` (``num_pred == 2``). Each sample gets one
    figure laid out as follows:
    - top row: the 5 label channels;
    - middle row (``num_pred == 2`` only): the first 3 channels of preds1;
    - bottom row: the 5 channels of the final prediction. Each panel is titled
      with the Dice between ``ROIOneHot(argmax)`` of that prediction and the label.

    Args:
        model_folder: folder containing the ``Result`` subfolder.
        data_type: file-name prefix of the arrays to load.
        num_pred: 1 or 2, the number of prediction files to show.
        save_path: if non-empty, figures are saved there as ``test_<index>.jpg``
            (the folder is created if needed) instead of being shown.
            NOTE(review): the file name does not include *data_type*, so runs
            for different data types overwrite each other.

    Raises:
        ValueError: if *num_pred* is not 1 or 2 (previously a silent no-op).
    """
    if num_pred not in (1, 2):
        raise ValueError('num_pred must be 1 or 2, got {}'.format(num_pred))
    if save_path:
        # makedirs also handles nested paths, unlike os.mkdir.
        os.makedirs(save_path, exist_ok=True)

    result_folder = os.path.join(model_folder, 'Result')
    label = np.load(os.path.join(result_folder, '{}_label.npy'.format(data_type)))
    if num_pred == 1:
        first_pred = None
        final_pred = np.load(os.path.join(result_folder, '{}_preds.npy'.format(data_type)))
    else:
        first_pred = np.load(os.path.join(result_folder, '{}_preds1.npy'.format(data_type)))
        final_pred = np.load(os.path.join(result_folder, '{}_preds2.npy'.format(data_type)))
    n_rows = 2 if first_pred is None else 3

    for index in range(label.shape[0]):
        pred_index = ROIOneHot(np.argmax(final_pred[index], axis=0))
        plt.figure(figsize=(12, 2 * n_rows))

        # Row 1: label channels.
        for channel in range(5):
            plt.subplot(n_rows, 5, channel + 1)
            plt.axis('off')
            plt.imshow(label[index][channel, ...], cmap='gray')

        row = 1
        if first_pred is not None:
            # Middle row: first output; only its first 3 channels are shown.
            for channel in range(3):
                plt.subplot(n_rows, 5, 5 + channel + 1)
                plt.axis('off')
                plt.imshow(first_pred[index][channel, ...], cmap='gray')
            row = 2

        # Last row: final prediction, titled with the Dice of its argmax one-hot.
        for channel in range(5):
            plt.subplot(n_rows, 5, row * 5 + channel + 1)
            plt.title('{:.3f}'.format(Dice(pred_index[channel], label[index][channel])))
            plt.axis('off')
            plt.imshow(final_pred[index][channel, ...], cmap='gray')

        # Both modes now honour save_path; the two-output mode used to always show.
        if save_path:
            plt.savefig(os.path.join(save_path, 'test_{}.jpg'.format(index)))
            plt.close()
        else:
            plt.show()
def ShowResult(save_path):
    """Plot label and prediction channels slice by slice for every case in ``case_folder``.

    For each file in the module-level ``case_folder``, the part of its name
    before ``'_0000'`` plus ``'.nii.gz'`` names the matching label (in
    ``label_folder``) and prediction (in ``predict_folder``). Both are one-hot
    encoded with ``ROIOneHot``. Every slice gets a 2x5 figure:
    - top row: the label channels (bg, PZ, CG, U, AS);
    - bottom row: the prediction channels, each titled with its Dice against
      the label.

    Args:
        save_path: folder for ``<case>_<slice>.jpg`` figures (created if
            missing). When empty, the figures are shown interactively instead.
    """
    if save_path:
        # savefig does not create directories; ShoweResult likewise creates its folder.
        os.makedirs(save_path, exist_ok=True)

    for case in os.listdir(case_folder):
        case_label = case[:case.index('_0000')] + '.nii.gz'
        # The T2 image itself is never displayed, so only label and prediction are loaded.
        _, label, _ = LoadImage(os.path.join(label_folder, case_label))
        label = ROIOneHot(label)
        _, prediction, _ = LoadImage(os.path.join(predict_folder, case_label))
        prediction = ROIOneHot(prediction)

        for index in range(label.shape[0]):
            plt.figure(figsize=(12, 4))
            # Top row: label channels.
            for channel in range(5):
                plt.subplot(2, 5, channel + 1)
                plt.axis('off')
                plt.imshow(label[index, channel, :, :], cmap='gray')
            # Bottom row: prediction channels with per-channel Dice.
            for channel in range(5):
                pred_slice = prediction[index, channel, :, :]
                plt.subplot(2, 5, channel + 6)
                plt.title('{:.3f}'.format(Dice(pred_slice, label[index, channel, :, :])))
                plt.axis('off')
                plt.imshow(pred_slice, cmap='gray')

            if save_path:
                plt.savefig(os.path.join(save_path, '{}_{}.jpg'.format(case, index)))
                plt.close()
            else:
                plt.show()