Example #1
import os

import numpy as np
import oneflow as flow
from tqdm import tqdm

# Project-local helpers (UNet, Dataset, get_datadir_path,
# get_test_augmentation, visualize) are assumed to be importable from the
# surrounding project.


def main(args):
    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    # Build the model, load the pretrained weights, and move it to the GPU.
    net = UNet(n_channels=3, n_classes=1)
    checkpoint = flow.load(args.pretrained_path)
    net.load_state_dict(checkpoint)
    net.to("cuda")

    x_test_dir, y_test_dir = get_datadir_path(args, split="test")

    test_dataset = Dataset(
        x_test_dir, y_test_dir, augmentation=get_test_augmentation(),
    )

    print("Begin Testing...")
    for i, (image, mask) in enumerate(tqdm(test_dataset)):
        show_image = image
        with flow.no_grad():
            # Scale to [0, 1], convert HWC -> CHW, and move to the GPU.
            image = image / 255.0
            image = image.astype(np.float32)
            image = flow.tensor(image, dtype=flow.float32)
            image = image.permute(2, 0, 1)
            image = image.to("cuda")

            # Forward pass on a single-image batch; binarize the output at 0.5.
            pred = net(image.unsqueeze(0))
            pred = pred.numpy()
            pred = pred > 0.5
        save_picture_name = os.path.join(args.save_path, "test_image_" + str(i))
        visualize(
            save_picture_name, image=show_image, GT=mask[0, :, :], Pred=pred[0, 0, :, :]
        )
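
The visualize helper called above is not part of the snippet. A minimal sketch, assuming it simply writes each named array (image, GT, Pred) as side-by-side matplotlib panels under the given file name:

import matplotlib.pyplot as plt


def visualize(save_name, **images):
    """Save the named images (e.g. image, GT, Pred) side by side in one figure."""
    n = len(images)
    plt.figure(figsize=(5 * n, 5))
    for idx, (name, img) in enumerate(images.items()):
        plt.subplot(1, n, idx + 1)
        plt.title(name)
        plt.axis("off")
        plt.imshow(img)
    plt.savefig(save_name + ".png", bbox_inches="tight")
    plt.close()

Any helper with the same keyword-argument interface would work equally well in the loop above.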
Example #2
import os

import numpy as np
import torch
from torch.utils.data import DataLoader

# Project-local names (config, UNet, MyAttackedDataset, transform_compose,
# get_iou) are assumed to be importable from the surrounding project.


def run(criterion, net_name, use_dataset_index, use_mask, save_img_pred: bool):
    path_data = config.path_data[use_dataset_index]

    # File paths
    path_data_root = path_data["dataset"]
    path_checkpoints = path_data["checkpoints"]
    path_result = os.path.join(path_data["result"], net_name)
    os.makedirs(path_result, exist_ok=True)

    # Freeze all parameters; the network is used for inference only.
    net = UNet().to(config.device)
    for param in net.parameters():
        param.requires_grad = False

    path_pert = "/home/pengzx/deepLearning/result/Glaucoma/UNet/pert/"

    iou_total = []
    # Outer loop: the checkpoint whose saved perturbations are loaded.
    for index in config.checkpoints_indexes:
        path_pert_save = os.path.join(path_pert, str(index))

        attacked_dataset = MyAttackedDataset(path_data_root=path_data_root,
                                             phase="train",
                                             path_pert=path_pert_save,
                                             transform_list=transform_compose)
        attacked_data_loader = DataLoader(attacked_dataset,
                                          batch_size=1,
                                          shuffle=False)

        # Inner loop: the checkpoint used to predict on the perturbed images.
        for index_2 in config.checkpoints_indexes:
            net.load_state_dict(
                torch.load(os.path.join(path_checkpoints,
                                        "{}_{}.pth".format(net_name, index_2)),
                           map_location=config.device))

            iou_list = []
            for i, (img, label, mask, pert,
                    name) in enumerate(attacked_data_loader):
                img, pert = img.to(config.device), pert.to(config.device)

                # Add the adversarial perturbation to the clean image.
                img_pert = img.clone() + pert

                # Binarize the prediction (logits > 0) and the label (> 0.5).
                pred = net(img_pert)
                pred[pred > 0] = 1.
                pred[pred < 1.] = 0.

                label[label > 0.5] = 1.
                label[label < 1.] = 0.
                iou = get_iou(pred[0].data.cpu().numpy(),
                              label.data.cpu().numpy())
                iou_list.append(iou)
            iou_list = np.array(iou_list)
            print("Perturbations from model [{}], mean IoU of model [{}]'s "
                  "predictions = {}".format(index, index_2, iou_list.mean()))
            iou_total.append(iou_list)
    iou_total = np.array(iou_total)
    np.save("./iou_total.npy", iou_total)
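
The get_iou helper used above is another project utility that is not shown. A minimal sketch, assuming it computes the intersection over union of two binary numpy masks:

import numpy as np


def get_iou(pred, label, eps=1e-7):
    """Intersection over union of two binary masks given as numpy arrays of 0/1."""
    pred = pred.astype(bool)
    label = label.astype(bool)
    intersection = np.logical_and(pred, label).sum()
    union = np.logical_or(pred, label).sum()
    return float(intersection) / (union + eps)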
Example #3
import json
import os
import os.path as osp

import numpy as np
import SimpleITK as sitk
import torch
import tqdm
from sklearn import metrics
from torch.utils.data import DataLoader

# gcv is assumed to be a PyTorch port of the GluonCV segmentation model zoo;
# UNet, SpineDataset, get_args, test_transform and val_transform are assumed
# to be project-local.


def main():
    # Arguments
    args = get_args()
    if not osp.exists(args.result_dir):
        os.makedirs(args.result_dir)
    print("Evaluating configuration:")
    for arg in vars(args):
        print("{}:\t{}".format(arg, getattr(args, arg)))
    with open('eval-config.json', 'w') as f:
        json.dump(args.__dict__, f, indent=4)
    # Data
    if args.test:
        dataset = SpineDataset(root=args.root,
                               split='test',
                               transform=test_transform)
    else:
        dataset = SpineDataset(root=args.root,
                               split='val',
                               transform=val_transform)

    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=4,
                            pin_memory=True)

    # Model
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    if args.network == 'DeepLab':
        model = gcv.models.DeepLabV3(nclass=args.num_classes,
                                     backbone=args.backbone)
    elif args.network == 'FCN':
        model = gcv.models.FCN(nclass=args.num_classes, backbone=args.backbone)
    elif args.network == 'PSPNet':
        model = gcv.models.PSP(nclass=args.num_classes, backbone=args.backbone)
    elif args.network == 'UNet':
        model = UNet(n_class=args.num_classes, backbone=args.backbone)
    else:
        raise ValueError('Unsupported network: {}'.format(args.network))
    print('load model from {} ...'.format(args.model))
    model.load_state_dict(
        torch.load(args.model, map_location='cpu')['state_dict'])
    model = model.to(device)
    print('Done!')

    # Testing
    def eval():
        with torch.no_grad():
            model.eval()
            result = []
            tq = tqdm.tqdm(total=len(dataloader))
            if args.test:
                tq.set_description('test')
                for i, (data, img_file) in enumerate(dataloader):
                    tq.update(1)
                    data = data.to(device)
                    # Per-slice prediction buffer for the 3D volume (S, H, W).
                    predict = np.zeros(
                        (data.size()[1], data.size()[3], data.size()[4]),
                        dtype=np.uint16)
                    for idx in range(data.size()[1]):
                        if args.network in ['DeepLab', 'FCN', 'PSPNet']:
                            final_out = model(data[:, idx])[0]
                        elif args.network == 'UNet':
                            final_out = model(data[:, idx])
                        predict[idx] = final_out.argmax(
                            dim=1).cpu().squeeze().numpy().astype(np.uint16)
                    pred_img = sitk.GetImageFromArray(predict)
                    test_img = sitk.ReadImage(
                        osp.join(args.root, 'test', 'image', img_file[0]))
                    pred_img.CopyInformation(test_img)
                    result_file = 'mask_' + img_file[0].lower()
                    sitk.WriteImage(pred_img,
                                    osp.join(args.result_dir, result_file))
            else:
                tq.set_description('val')
                for i, (data, mask, mask_file) in enumerate(dataloader):
                    tq.update(1)
                    gt_img = sitk.ReadImage(
                        osp.join(args.root, 'val', 'groundtruth',
                                 mask_file[0]))
                    data = data.to(device)
                    predict = np.zeros(
                        (data.size()[1], data.size()[3], data.size()[4]),
                        dtype=np.uint16)
                    for idx in range(data.size()[1]):
                        if args.network in ['DeepLab', 'FCN', 'PSPNet']:
                            final_out = model(data[:, idx])[0]
                        elif args.network == 'UNet':
                            final_out = model(data[:, idx])
                        predict[idx] = final_out.argmax(
                            dim=1).cpu().squeeze().numpy().astype(np.uint16)
                    pred_img = sitk.GetImageFromArray(predict)
                    pred_img.CopyInformation(gt_img)
                    sitk.WriteImage(pred_img,
                                    osp.join(args.result_dir, mask_file[0]))
                    # precision_recall_fscore_support returns precision,
                    # recall and F-score; for binary masks these correspond
                    # to PPV, sensitivity and Dice.
                    ppv, sensitivity, dice, _ = metrics.precision_recall_fscore_support(
                        mask.numpy().flatten(),
                        predict.flatten(),
                        average='binary')
                    result.append([dice, ppv, sensitivity])
                result = np.array(result)
                result_mean = result.mean(axis=0)
                result_std = result.std(axis=0)
                print(result_mean, result_std)
                np.savetxt(osp.join(args.result_dir, 'result.txt'),
                           result_mean,
                           fmt='%.3f',
                           header='Dice, PPV, Sensitivity')

            tq.close()

    eval()
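
get_args() is not included in the snippet. A minimal argparse sketch with hypothetical flag names, covering only the attributes the example actually reads (root, result_dir, model, network, backbone, num_classes, cuda, test):

import argparse


def get_args():
    """Hypothetical argument parser matching the attributes used in main()."""
    parser = argparse.ArgumentParser(description='Evaluate a segmentation model')
    parser.add_argument('--root', type=str, required=True, help='dataset root directory')
    parser.add_argument('--result-dir', dest='result_dir', type=str, default='results')
    parser.add_argument('--model', type=str, required=True, help='path to the checkpoint')
    parser.add_argument('--network', type=str, default='UNet',
                        choices=['DeepLab', 'FCN', 'PSPNet', 'UNet'])
    parser.add_argument('--backbone', type=str, default='resnet50')
    parser.add_argument('--num-classes', dest='num_classes', type=int, default=2)
    parser.add_argument('--cuda', type=str, default='0', help='CUDA_VISIBLE_DEVICES value')
    parser.add_argument('--test', action='store_true', help='evaluate on the test split')
    return parser.parse_args()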