Example #1
def build_datasets(hparams):
    """Build the train/val MedicalImageDataset pair from a hyperparameter dict."""
    root_dir = get_dataset_root(hparams['dataroot'])
    # Augmentation and histogram equalization are toggled by hyperparameters;
    # the validation set is never augmented.
    train_dataset = MedicalImageDataset(root_dir, 'train', transform=segment_transform((256, 256)),
                                        augment=augment if hparams['data_aug'] else None, equalize=hparams['data_equ'])
    val_dataset = MedicalImageDataset(root_dir, 'val', transform=segment_transform((256, 256)), augment=None, equalize=hparams['data_equ'])

    return train_dataset, val_dataset
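A minimal usage sketch, assuming the hyperparameter keys the function reads ('dataroot', 'data_aug', 'data_equ') and plain PyTorch DataLoader wrapping; the batch size and dataset name below are illustrative:

from torch.utils.data import DataLoader

# Hypothetical hyperparameters covering exactly the keys build_datasets() reads.
hparams = {'dataroot': 'cardiac', 'data_aug': True, 'data_equ': False}

train_dataset, val_dataset = build_datasets(hparams)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)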
Example #2
def test_dataloader():
    root_dir = '../dataset/ACDC-2D-All'
    train_dataset = MedicalImageDataset(root_dir,
                                        'train',
                                        transform=segment_transform(
                                            (500, 500)),
                                        augment=augment)
    val_dataset = MedicalImageDataset(root_dir,
                                      'val',
                                      transform=segment_transform((500, 500)),
                                      augment=None)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    # Only visualize samples that carry both a full mask and a weak mask.
    for i, (Img, GT, wgt, _) in enumerate(train_loader):
        if GT.sum() <= 0 or wgt.sum() <= 0:
            continue
        show_img_mask_weakmask(Img.numpy(), GT.numpy(), wgt.numpy())

    # Optional manual inspection: show a batch with ToPILImage()(Img[0]).show(),
    # or switch a loader's dataset into evaluation mode with
    # loader.dataset.set_mode('eval') after a few iterations.
    assert len(train_dataset) == len(train_dataset.imgs)
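The same per-batch contract can be checked without any visualization; a small sanity-check sketch, assuming each item is an (image, ground truth, weak mask, path) tuple with batch_size=1 as above (check_batch_contract is a hypothetical helper, called with either loader built above):

def check_batch_contract(loader, size=(500, 500)):
    Img, GT, wgt, _ = next(iter(loader))
    assert Img.shape[0] == 1              # batch_size=1, as in the loaders above
    assert Img.shape[-2:] == size         # spatial size set by segment_transform
    assert GT.shape[-2:] == size and wgt.shape[-2:] == size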
Example #3
def test_prostate_dataloader():
    root_dir = '../dataset/PROSTATE'
    train_dataset = MedicalImageDataset(root_dir,
                                        'train',
                                        transform=segment_transform(
                                            (128, 128)),
                                        augment=augment)
    val_dataset = MedicalImageDataset(root_dir,
                                      'val',
                                      transform=segment_transform((128, 128)),
                                      augment=None)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    # Optional manual inspection of train/val batches with ToPILImage()(img[0]).show(),
    # switching a loader's dataset into evaluation mode with
    # loader.dataset.set_mode('eval') after a few iterations.
    assert len(train_dataset) == len(train_dataset.imgs)
Example #4
def test_admm_gc_size():
    # Reset any previously registered absl flags so setup_arch_flags() can
    # re-register them cleanly.
    for name in list(flags.FLAGS):
        delattr(flags.FLAGS, name)
    root_dir = '../dataset/ACDC-2D-All'
    train_dataset = MedicalImageDataset(root_dir, 'train', transform=segment_transform((128, 128)), augment=augment)
    val_dataset = MedicalImageDataset(root_dir, 'val', transform=segment_transform((128, 128)), augment=None)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    AdmmGCSize.setup_arch_flags()
    hparams = flags.FLAGS.flag_values_dict()
    torchnet = get_arch('enet', {'num_classes': 2})
    # Optionally warm-start from a pretrained checkpoint via
    # torchnet.load_state_dict(torch.load(<checkpoint_path>, map_location=lambda storage, loc: storage)).

    weight = torch.Tensor([0, 1])
    criterion = get_loss_fn('cross_entropy', weight=weight)

    test_admm = AdmmGCSize(torchnet, hparams)

    val_score = test_admm.evaluate(val_loader)
    print(val_score)

    for i, (img, gt, wgt, _) in enumerate(train_loader):
        if gt.sum() == 0 or wgt.sum() == 0:
            continue
        test_admm.reset(img)
        for j in range(2):
            test_admm.update((img, gt, wgt), criterion)
        if i >= 4:
            break
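The reset/update pattern above generalizes into a small driver loop. A sketch assuming only the interface visible here (reset(img) and update((img, gt, wgt), criterion)); run_admm_epoch is a hypothetical helper name:

def run_admm_epoch(admm, loader, criterion, inner_iters=2, max_batches=None):
    # One pass over the loader, skipping samples that lack either a full
    # ground-truth mask or a weak mask, exactly as the test above does.
    for i, (img, gt, wgt, _) in enumerate(loader):
        if gt.sum() == 0 or wgt.sum() == 0:
            continue
        admm.reset(img)                      # re-initialize per-image ADMM state
        for _ in range(inner_iters):         # a few inner ADMM updates per image
            admm.update((img, gt, wgt), criterion)
        if max_batches is not None and i >= max_batches:
            break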
Example #5
def test_admm_size():
    # Reset any previously registered absl flags so setup_arch_flags() can
    # re-register them cleanly.
    for name in list(flags.FLAGS):
        delattr(flags.FLAGS, name)
    root_dir = '../dataset/ACDC-2D-All'
    train_dataset = MedicalImageDataset(root_dir, 'train', transform=segment_transform((200, 200)), augment=augment)
    val_dataset = MedicalImageDataset(root_dir, 'val', transform=segment_transform((200, 200)), augment=None)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    AdmmSize.setup_arch_flags()
    hparams = flags.FLAGS.flag_values_dict()
    arch_hparam = extract_from_big_dict(hparams, AdmmSize.arch_hparam_keys)
    torchnet = get_arch('enet', arch_hparam)
    weight = torch.Tensor([0.1, 1])
    criterion = get_loss_fn('cross_entropy', weight=weight)

    test_admm = AdmmSize(torchnet, hparams)

    val_score = test_admm.evaluate(val_loader)
    print(val_score)

    for i, (img, gt, wgt, _) in enumerate(train_loader):
        if gt.sum() == 0 or wgt.sum() == 0:
            continue
        test_admm.reset(img)
        for j in range(3):
            test_admm.update((img, gt, wgt), criterion)
        if i >= 3:
            break
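extract_from_big_dict is used here only to pull the architecture-related entries out of the full flag dict; a plausible equivalent, assuming it simply filters by key, is:

def extract_from_big_dict(big_dict, keys):
    # Keep only the entries whose keys appear in `keys`,
    # e.g. AdmmSize.arch_hparam_keys above.
    return {k: v for k, v in big_dict.items() if k in keys}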
Example #6
def inference(args: argparse.Namespace) -> None:
    ## load checkpoint
    assert args.dataset in ('cardiac', 'prostate')
    checkpoint_path = Path(args.checkpoint)
    assert checkpoint_path.exists(), f'Checkpoint given {args.checkpoint} does not exist.'
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')
    checkpoint = torch.load(checkpoint_path, map_location=device)
    ## report checkpoint:
    print(
        f">>> checkpoint {checkpoint_path} loaded.\n"
        f"Best epoch: {checkpoint['epoch']}, best val-2D dice: {round(checkpoint['dice'], 4)}")

    ## load model
    net: torch.nn.Module = get_arch(args.arch, {'num_classes': args.num_classes})
    net.load_state_dict(checkpoint['model'])
    net.to(device)
    net.eval()

    ## build dataloader
    root_dir = get_dataset_root(args.dataset)
    val_dataset = MedicalImageDataset(root_dir, 'val', transform=segment_transform((256, 256)), augment=None)
    val_loader = DataLoader(val_dataset, batch_size=1)

    val_loader = tqdm_(val_loader)
    dice_meter = AverageMeter()
    for i, (imgs, gts, wgts, paths) in enumerate(val_loader):
        imgs, gts, wgts = imgs.to(device), gts.to(device), wgts.to(device)
        pred_masks = net(imgs).max(1)[1]
        dice_meter.update(dice_loss(pred_masks, gts)[1], gts.shape[0])
        save_images(imgs, pred_masks, gts, paths, args)

        val_loader.set_postfix({'val 2d-dice': dice_meter.avg})
    print(f'\nrecalculated dice: {round(dice_meter.avg, 4)}')
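A minimal command-line wrapper for inference(), assuming only the attributes the function actually reads (dataset, checkpoint, cpu, arch, num_classes); the defaults are illustrative:

import argparse


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description='Run inference from a saved checkpoint.')
    parser.add_argument('--dataset', choices=('cardiac', 'prostate'), required=True)
    parser.add_argument('--checkpoint', required=True, help='path to the .pth checkpoint')
    parser.add_argument('--arch', default='enet')
    parser.add_argument('--num_classes', type=int, default=2)
    parser.add_argument('--cpu', action='store_true', help='force CPU even if CUDA is available')
    return parser.parse_args()


if __name__ == '__main__':
    inference(get_args())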
Example #7
def build_datasets(dataset_name, foldername, equalize):
    root = get_dataset_root(dataset_name)
    trainset = MedicalImageDataset(root,
                                   'train',
                                   transform=segment_transform((256, 256)),
                                   augment=None,
                                   foldername=foldername,
                                   equalize=equalize)
    trainLoader = DataLoader(trainset, batch_size=1)
    # No validation loader is built for this use case.
    return trainLoader, None
Example #8
def test_visualization():
    tensorboard_dir = 'runs_test'
    root_dir = '../dataset/ACDC-2D-All'
    train_dataset = MedicalImageDataset(root_dir,
                                        'train',
                                        transform=segment_transform(
                                            (200, 200)),
                                        augment=augment)
    val_dataset = MedicalImageDataset(root_dir,
                                      'val',
                                      transform=segment_transform((200, 200)),
                                      augment=None)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    torchnet = get_arch('enet', {'num_classes': 2})
    writer = Writter_tf(tensorboard_dir, torchnet, 40)

    for i in range(2):
        writer.add_images(train_loader, i)

    ## clean up
    shutil.rmtree(tensorboard_dir)