示例#1
0
def test_3D(model, device, args, num_fold=0):
    """Evaluate one fold's best checkpoint on whole 3-D volumes.

    The epoch to load comes from ``args.val_result_file`` via
    ``utils.best_model_in_fold`` when that file exists. Otherwise the final
    epoch ``args.num_epochs`` is used. Each entry of
    ``<data_root>/f<num_fold>`` is treated as one volume: all of its slices
    are predicted, concatenated along dim 0, and scored together.

    For every volume, one CSV row (volume name, mean and per-class Dice, IoU,
    accuracy, sensitivity, specificity) is appended to
    ``args.test_result_file``. When ``args.plot`` is set, the predicted
    volume is also written as ``<volume>_<mean dice>.nii.gz`` into
    ``args.plot_save_dir``.

    Args:
        model: network whose forward output is a dict with a ``"main_out"``
            entry; its argmax over dim 1 is taken as the predicted class.
        device: torch device for the checkpoint weights and the inputs.
        args: run configuration (paths, ``crop_size``, ``n_class``,
            ``drop_non``, ``plot``, ...).
        num_fold: fold to evaluate; also indexes ``args.checkpoint_dir``.

    Returns:
        ``None`` when ``num_fold == 0``; in that case the summary is saved via
        ``utils.save_print_score``. Otherwise returns
        ``(all_dice, all_iou, all_acc, all_sen, all_spe)``, each a list with
        one entry per volume, so the caller can aggregate across folds.
    """
    import re

    # Pick the epoch to evaluate: best validation epoch if recorded, else last.
    if os.path.exists(args.val_result_file):
        with open(args.val_result_file, "r", newline="") as f:
            val_result = list(csv.reader(f))
        best_epoch = utils.best_model_in_fold(val_result, num_fold)
    else:
        best_epoch = args.num_epochs

    # Locate the checkpoint. Try the canonical name first; otherwise use a
    # file containing the epoch as a whole number. A plain substring test
    # would let epoch 1 match "CP_epoch10.pth", and listdir order is
    # arbitrary. If nothing matches, torch.load reports the canonical path.
    ckpt_dir = args.checkpoint_dir[num_fold]
    model_dir = os.path.join(ckpt_dir, f'CP_epoch{best_epoch}.pth')
    if not os.path.exists(model_dir):
        epoch_re = re.compile(rf'(?<!\d){re.escape(str(best_epoch))}(?!\d)')
        candidates = sorted(
            x for x in os.listdir(ckpt_dir) if epoch_re.search(x))
        if candidates:
            model_dir = os.path.join(ckpt_dir, candidates[0])
    model.load_state_dict(torch.load(model_dir, map_location=device))
    print(f'\rtest model loaded: [fold:{num_fold}] [best_epoch:{best_epoch}]')

    data_root = args.val_data_root if hasattr(
        args, "val_data_root") else args.data_root
    # NOTE(review): without ``val_target_root`` the masks are read from
    # ``args.data_root`` (the image root). Confirm that masks really live
    # there and not under a separate target root.
    target_root = args.val_target_root if hasattr(
        args, "val_target_root") else args.data_root
    val_image_root = os.path.join(data_root, f"f{num_fold}")
    val_mask_root = os.path.join(target_root, f"f{num_fold}")
    volumes = sorted(os.listdir(val_image_root))

    all_dice, all_iou, all_acc, all_sen, all_spe = [], [], [], [], []
    model.eval()
    with torch.no_grad():
        with tqdm(total=len(volumes), desc='VAL', unit='img') as pbar:
            for volume in volumes:
                v_dataset = volume_Dataset(os.path.join(val_image_root, volume),
                                           os.path.join(val_mask_root, volume),
                                           args.crop_size)
                v_dataloader = DataLoader(v_dataset,
                                          batch_size=1,
                                          shuffle=False,
                                          num_workers=args.num_workers,
                                          pin_memory=True,
                                          drop_last=False)
                volume_pred = []
                volume_label = []
                for batch in v_dataloader:
                    image = batch["image"].to(device, dtype=torch.float32)
                    label = batch["label"].to(device, dtype=torch.long)
                    main_out = model(image)["main_out"]
                    # exp() is monotonic, so this is the per-pixel argmax class.
                    volume_pred.append(torch.exp(main_out).max(dim=1)[1])
                    volume_label.append(label)
                volume_pred = torch.cat(volume_pred, dim=0)
                volume_label = torch.cat(volume_label, dim=0)

                hist = utils.fast_hist(volume_label, volume_pred, args.n_class)
                dice, iou, acc, Sensitivity, Specificity = utils.cal_scores(
                    hist.cpu().numpy(), drop_non=args.drop_non)
                all_dice.append(list(dice))
                all_iou.append(list(iou))
                all_acc.append([acc])
                all_sen.append(list(Sensitivity))
                all_spe.append(list(Specificity))

                # One CSV row of metrics per test volume.
                test_result = [volume, dice.mean()]+list(dice)+[iou.mean()]+list(iou)+[acc] + \
                    [Sensitivity.mean()]+list(Sensitivity)+[Specificity.mean()]+list(Specificity)
                with open(args.test_result_file, "a", newline="") as f:
                    csv.writer(f).writerow(test_result)

                if args.plot:
                    pred_volume = sitk.GetImageFromArray(
                        np.array(volume_pred.cpu().numpy(), np.int8))
                    sitk.WriteImage(
                        pred_volume,
                        os.path.join(args.plot_save_dir,
                                     volume + f'_{dice.mean():.2f}.nii.gz'))

                # The bar counts volumes (total=len(volumes)): one step each.
                pbar.update(1)

    def _mean_over_present(scores):
        # Per class, average only over volumes where the score is > 0, then
        # average over classes. A class that never scores > 0 gives nan.
        arr = np.array(scores)
        return np.mean(arr.sum(axis=0) / np.sum(arr > 0, axis=0))

    print(f"\r---------Fold {num_fold} Test Result---------")
    if args.drop_non:
        print(f'mDice: {_mean_over_present(all_dice)}')
        print(f'mIoU:  {_mean_over_present(all_iou)}')
        print(f'mAcc:  {np.array(all_acc).mean()}')
        print(f'mSens: {_mean_over_present(all_sen)}')
        print(f'mSpec: {_mean_over_present(all_spe)}')
    else:
        print(f'mDice: {np.array(all_dice).mean()}')
        print(f'mIoU:  {np.array(all_iou).mean()}')
        print(f'mAcc:  {np.array(all_acc).mean()}')
        print(f'mSens: {np.array(all_sen).mean()}')
        print(f'mSpec: {np.array(all_spe).mean()}')

    if num_fold == 0:
        utils.save_print_score(all_dice,
                               all_iou,
                               all_acc,
                               all_sen,
                               all_spe,
                               args.test_result_file,
                               args.label_names,
                               drop_non=args.drop_non)
        return

    return all_dice, all_iou, all_acc, all_sen, all_spe
示例#2
0
            print(
                "###################### Train & Test Start ######################"
            )

        # Training pass: call `main` once per fold in [start_fold, end_fold),
        # passing 1-based fold indices (i + 1). Return values are ignored here.
        if mode == "train" or mode == "train_test":
            args.mode = "train"
            print("###################### Train Start ######################")
            for i in range(args.start_fold, args.end_fold):
                # Free unused cached CUDA memory before each fold.
                torch.cuda.empty_cache()
                main(args, num_fold=i + 1)

        # Test pass: each call to `main` is unpacked into five score lists.
        # The lists are concatenated across folds, and the combined summary
        # is saved once after the last fold.
        if mode == "test" or mode == "train_test":
            args.mode = "test"
            print("###################### Test Start ######################")
            all_dice = []
            all_iou = []
            all_acc = []
            all_sen = []
            all_spe = []
            for i in range(args.start_fold, args.end_fold):
                # NOTE(review): assumes `main` returns the score lists (not
                # None) for num_fold >= 1; verify against `main`.
                Dice, IoU, Acc, Sensitivity, Specificity = main(args,
                                                                num_fold=i + 1)
                all_dice += Dice
                all_iou += IoU
                all_acc += Acc
                all_sen += Sensitivity
                all_spe += Specificity
            utils.save_print_score(all_dice, all_iou, all_acc, all_sen,
                                   all_spe, args.test_result_file,
                                   args.label_names)
示例#3
0
def test(model, device, args, num_fold=0):
    """Evaluate one fold's best checkpoint on its 2-D test split.

    The epoch to load comes from ``args.val_result_file`` via
    ``utils.best_model_in_fold`` when that file exists. Otherwise the final
    epoch ``args.num_epochs`` is used. Every test image is scored
    individually, and one CSV row of metrics per image is appended to
    ``args.test_result_file``.

    When ``args.plot`` is set, the function writes into ``args.plot_save_dir``:
    the input image, an RGB overlay ``<name>(<mean dice>).png``, and, if the
    model outputs ``A1``..``A4`` maps, channels 0, 5, ..., 20 of each map
    upsampled to the input size.

    Args:
        model: network whose forward output is a dict with a ``"main_out"``
            entry; its argmax over dim 1 is taken as the predicted class.
        device: torch device for the checkpoint weights and the inputs.
        args: run configuration (paths, batch size, ``n_class``, ``plot``, ...).
        num_fold: fold to evaluate; also indexes ``args.checkpoint_dir``.

    Returns:
        ``None`` when ``num_fold == 0``; in that case the summary is saved via
        ``utils.save_print_score``. Otherwise returns
        ``(all_dice, all_iou, all_acc, all_sen, all_spe)``, each a list with
        one entry per image.
    """
    import re

    # Pick the epoch to evaluate: best validation epoch if recorded, else last.
    if os.path.exists(args.val_result_file):
        with open(args.val_result_file, "r", newline="") as f:
            val_result = list(csv.reader(f))
        best_epoch = utils.best_model_in_fold(val_result, num_fold)
    else:
        best_epoch = args.num_epochs

    # Locate the checkpoint. Try the canonical name first; otherwise use a
    # file containing the epoch as a whole number. A plain substring test
    # would let epoch 1 match "CP_epoch10.pth", and listdir order is
    # arbitrary. If nothing matches, torch.load reports the canonical path.
    ckpt_dir = args.checkpoint_dir[num_fold]
    model_dir = os.path.join(ckpt_dir, f'CP_epoch{best_epoch}.pth')
    if not os.path.exists(model_dir):
        epoch_re = re.compile(rf'(?<!\d){re.escape(str(best_epoch))}(?!\d)')
        candidates = sorted(
            x for x in os.listdir(ckpt_dir) if epoch_re.search(x))
        if candidates:
            model_dir = os.path.join(ckpt_dir, candidates[0])
    model.load_state_dict(torch.load(model_dir, map_location=device))
    print(f'\rtest model loaded: [fold:{num_fold}] [best_epoch:{best_epoch}]')

    dataset_test = myDataset(args.data_root,
                             args.target_root,
                             args.crop_size,
                             "test",
                             k_fold=args.k_fold,
                             imagefile_csv=args.dataset_file_list,
                             num_fold=num_fold)
    dataloader = DataLoader(dataset_test,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.num_workers,
                            pin_memory=True)

    all_dice, all_iou, all_acc, all_sen, all_spe = [], [], [], [], []
    model.eval()
    with torch.no_grad():
        with tqdm(total=len(dataset_test),
                  desc=f'TEST fold {num_fold}/{args.k_fold}',
                  unit='img') as pbar:
            for batch in dataloader:
                image = batch["image"]
                label = batch["label"]
                files = batch["file"]
                assert len(image.size()) == 4
                assert len(label.size()) == 3
                image = image.to(device, dtype=torch.float32)
                label = label.to(device, dtype=torch.long)

                outputs = model(image)
                # exp() is monotonic, so this is the per-pixel argmax class.
                pred = torch.exp(outputs["main_out"]).max(dim=1)[1]

                for b in range(image.size()[0]):
                    hist = utils.fast_hist(label[b, :, :], pred[b, :, :],
                                           args.n_class)
                    dice, iou, acc, Sensitivity, Specificity = utils.cal_scores(
                        hist.cpu().numpy(), smooth=0.01)

                    # One CSV row of metrics per test image.
                    test_result = [files[b], dice.mean()]+list(dice)+[iou.mean()]+list(iou)+[acc] + \
                        [Sensitivity.mean()]+list(Sensitivity)+[Specificity.mean()]+list(Specificity)
                    with open(args.test_result_file, "a", newline="") as f:
                        csv.writer(f).writerow(test_result)

                    all_dice.append(list(dice))
                    all_iou.append(list(iou))
                    all_acc.append([acc])
                    all_sen.append(list(Sensitivity))
                    all_spe.append(list(Specificity))

                    if not args.plot:
                        continue

                    file_name, _ = os.path.splitext(files[b])
                    if "A4" in outputs.keys():
                        for key in ("A1", "A2", "A3", "A4"):
                            # Upsample the map once, then save channels
                            # 0, 5, ..., 20 of it.
                            proj_map = F.interpolate(
                                outputs[key][b, ...].unsqueeze(0),
                                size=image.size()[-2:],
                                mode="bilinear",
                                align_corners=True).squeeze(0)
                            for i in range(0, 25, 5):
                                save_image(proj_map[i, :, :],
                                           os.path.join(
                                               args.plot_save_dir,
                                               file_name + f"_{key}_{i}.png"),
                                           normalize=True)
                    save_image(image[b, :, :].cpu(),
                               os.path.join(args.plot_save_dir, files[b]))
                    pred_image = pred[b, :, :].unsqueeze(0)
                    true_mask = label[b, :, :].unsqueeze(0)
                    # H x W x 3 overlay: red = prediction, green = ground truth.
                    result_image = torch.cat(
                        (pred_image, true_mask, torch.zeros_like(pred_image)),
                        dim=0).permute(1, 2, 0).cpu().numpy()
                    # NOTE(review): uint8 `* 255` wraps for class ids > 1, so
                    # this overlay is only meaningful for binary masks.
                    result_image = Image.fromarray(np.uint8(result_image) * 255)
                    result_image.save(
                        os.path.join(args.plot_save_dir,
                                     f"{file_name}({dice.mean():.2f}).png"))
                pbar.update(image.size()[0])

    print(f"\r---------Fold {num_fold} Test Result---------")
    print(f'mDice: {np.array(all_dice).mean()}')
    print(f'mIoU:  {np.array(all_iou).mean()}')
    print(f'mAcc:  {np.array(all_acc).mean()}')
    print(f'mSens: {np.array(all_sen).mean()}')
    print(f'mSpec: {np.array(all_spe).mean()}')

    if num_fold == 0:
        utils.save_print_score(all_dice, all_iou, all_acc, all_sen, all_spe,
                               args.test_result_file, args.label_names)
        return

    return all_dice, all_iou, all_acc, all_sen, all_spe
示例#4
0
def test(model, device, args, num_fold=0):
    """Score the best checkpoint of one fold on its 2-D test split.

    The checkpoint ``CP_epoch<best>.pth`` is loaded from
    ``args.checkpoint_dir[num_fold]``. ``<best>`` is taken from
    ``args.val_result_file`` via ``utils.best_model_in_fold`` when that file
    exists, and is ``args.num_epochs`` otherwise. Each image is scored on its
    own, and a CSV row of metrics is appended to ``args.test_result_file``.
    With ``args.plot`` set, the prediction, input and label images are saved
    to ``args.plot_save_dir``.

    Returns:
        ``None`` for ``num_fold == 0``, after saving the summary via
        ``utils.save_print_score``. Otherwise returns the per-image lists
        ``(dice, iou, acc, sensitivity, specificity)``.
    """
    # Decide which epoch to load.
    if not os.path.exists(args.val_result_file):
        best_epoch = args.num_epochs
    else:
        with open(args.val_result_file, "r") as f:
            rows = list(csv.reader(f))
        best_epoch = utils.best_model_in_fold(rows, num_fold)

    checkpoint_path = os.path.join(args.checkpoint_dir[num_fold],
                                   f'CP_epoch{best_epoch}.pth')
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state)
    print(f'\rtest model loaded: [fold:{num_fold}] [best_epoch:{best_epoch}]')

    test_set = myDataset(args.data_root, args.target_root, args.crop_size,
                         "test", k_fold=args.k_fold,
                         imagefile_csv=args.dataset_file_list,
                         num_fold=num_fold)
    loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.num_workers, pin_memory=True)

    dice_rows, iou_rows, acc_rows, sen_rows, spe_rows = [], [], [], [], []
    model.eval()
    with torch.no_grad(), tqdm(total=len(test_set),
                               desc=f'TEST fold {num_fold}/{args.k_fold}',
                               unit='img') as pbar:
        for batch in loader:
            images, labels, names = batch["image"], batch["label"], batch["file"]
            assert len(images.size()) == 4
            assert len(labels.size()) == 3
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            # exp() preserves ordering, so max(dim=1) yields the argmax class.
            preds = torch.exp(model(images)["main_out"]).max(dim=1)[1]

            for idx in range(images.size()[0]):
                hist = utils.fast_hist(labels[idx, :, :], preds[idx, :, :],
                                       args.n_class)
                dice, iou, acc, sen, spe = utils.cal_scores(
                    hist.cpu().numpy(), smooth=0.01)

                # Per-image CSV row: name, then mean + per-class values.
                row = [names[idx], dice.mean(), *dice, iou.mean(), *iou, acc,
                       sen.mean(), *sen, spe.mean(), *spe]
                with open(args.test_result_file, "a") as f:
                    csv.writer(f).writerow(row)

                dice_rows.append(list(dice))
                iou_rows.append(list(iou))
                acc_rows.append([acc])
                sen_rows.append(list(sen))
                spe_rows.append(list(spe))

                if args.plot:
                    stem = os.path.splitext(names[idx])[0]
                    save_image(preds[idx, :, :].cpu().float().unsqueeze(0),
                               os.path.join(args.plot_save_dir,
                                            stem + f"_pred_{dice.mean():.2f}.png"),
                               normalize=True)
                    save_image(images[idx, :, :].cpu(),
                               os.path.join(args.plot_save_dir, names[idx]))
                    save_image(labels[idx, :, :].cpu().float().unsqueeze(0),
                               os.path.join(args.plot_save_dir,
                                            stem + "_label.png"),
                               normalize=True)
            pbar.update(images.size()[0])

    print(f"\r---------Fold {num_fold} Test Result---------")
    summary = (('mDice: ', dice_rows), ('mIoU:  ', iou_rows),
               ('mAcc:  ', acc_rows), ('mSens: ', sen_rows),
               ('mSpec: ', spe_rows))
    for tag, values in summary:
        print(f'{tag}{np.array(values).mean()}')

    if num_fold != 0:
        return dice_rows, iou_rows, acc_rows, sen_rows, spe_rows
    utils.save_print_score(dice_rows, iou_rows, acc_rows, sen_rows, spe_rows,
                           args.test_result_file, args.label_names)