def test_admm_size():
    """Smoke-test ``AdmmSize`` on the ACDC 2D data.

    The test evaluates the freshly built model once on the validation
    loader, prints the score, and then runs three ``update`` steps on each
    usable training sample until a usable sample with index >= 3 is reached.
    """
    # Remove everything currently registered on flags.FLAGS before
    # AdmmSize registers its own flags (presumably to avoid clashes with
    # flags left over from other tests -- confirm).
    for flag_name in list(flags.FLAGS):
        delattr(flags.FLAGS, flag_name)

    data_root = '../dataset/ACDC-2D-All'
    train_set = MedicalImageDataset(
        data_root,
        'train',
        transform=segment_transform((200, 200)),
        augment=augment,
    )
    val_set = MedicalImageDataset(
        data_root,
        'val',
        transform=segment_transform((200, 200)),
        augment=None,
    )
    train_loader = DataLoader(train_set, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)

    AdmmSize.setup_arch_flags()
    hparams = flags.FLAGS.flag_values_dict()
    arch_hparam = extract_from_big_dict(hparams, AdmmSize.arch_hparam_keys)
    network = get_arch('enet', arch_hparam)

    class_weight = torch.Tensor([0.1, 1])
    criterion = get_loss_fn('cross_entropy', weight=class_weight)

    admm = AdmmSize(network, hparams)
    print(admm.evaluate(val_loader))

    for sample_idx, (img, gt, wgt, _) in enumerate(train_loader):
        # Samples whose full mask or weak mask sums to zero are skipped.
        if gt.sum() == 0 or wgt.sum() == 0:
            continue
        admm.reset(img)
        for _ in range(3):
            admm.update((img, gt, wgt), criterion)
        # Only checked for samples that were not skipped above.
        if sample_idx >= 3:
            break
def test_admm_gc_size():
    """Smoke-test ``AdmmGCSize`` on the ACDC 2D data.

    The test evaluates the model once on the validation loader, prints the
    score, and then runs two ``update`` steps on each usable training
    sample until a usable sample with index >= 4 is reached.
    """
    # Remove everything currently registered on flags.FLAGS before
    # AdmmGCSize registers its own flags (presumably to avoid clashes with
    # flags left over from other tests -- confirm).
    for flag_name in list(flags.FLAGS):
        delattr(flags.FLAGS, flag_name)

    data_root = '../dataset/ACDC-2D-All'
    train_set = MedicalImageDataset(
        data_root,
        'train',
        transform=segment_transform((128, 128)),
        augment=augment,
    )
    val_set = MedicalImageDataset(
        data_root,
        'val',
        transform=segment_transform((128, 128)),
        augment=None,
    )
    train_loader = DataLoader(train_set, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)

    AdmmGCSize.setup_arch_flags()
    hparams = flags.FLAGS.flag_values_dict()

    # No checkpoint is loaded here: the network is used exactly as
    # get_arch returns it.
    network = get_arch('enet', {'num_classes': 2})

    class_weight = torch.Tensor([0, 1])
    criterion = get_loss_fn('cross_entropy', weight=class_weight)

    admm = AdmmGCSize(network, hparams)
    print(admm.evaluate(val_loader))

    for sample_idx, (img, gt, wgt, _) in enumerate(train_loader):
        # Samples whose full mask or weak mask sums to zero are skipped.
        if gt.sum() == 0 or wgt.sum() == 0:
            continue
        admm.reset(img)
        for _ in range(2):
            admm.update((img, gt, wgt), criterion)
        # Only checked for samples that were not skipped above.
        if sample_idx >= 4:
            break
def test_prostate_dataloader():
    """Check the PROSTATE training dataset's length against its ``imgs``.

    Train and validation datasets and their loaders are constructed, but
    the loaders are not iterated; the only assertion compares the training
    dataset's length with the length of its ``imgs`` attribute.
    """
    data_root = '../dataset/PROSTATE'
    train_set = MedicalImageDataset(
        data_root,
        'train',
        transform=segment_transform((128, 128)),
        augment=augment,
    )
    val_set = MedicalImageDataset(
        data_root,
        'val',
        transform=segment_transform((128, 128)),
        augment=None,
    )

    # Built only to confirm that the loaders can be constructed.
    train_loader = DataLoader(train_set, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)

    assert len(train_set) == len(train_set.imgs)
def test_dataloader():
    """Walk the ACDC training loader and check the dataset length.

    Every training sample whose full mask and weak mask both have a
    positive sum is passed to ``show_img_mask_weakmask``. Afterwards the
    training dataset's length is compared with the length of its ``imgs``
    attribute.
    """
    data_root = '../dataset/ACDC-2D-All'
    train_set = MedicalImageDataset(
        data_root,
        'train',
        transform=segment_transform((500, 500)),
        augment=augment,
    )
    val_set = MedicalImageDataset(
        data_root,
        'val',
        transform=segment_transform((500, 500)),
        augment=None,
    )
    train_loader = DataLoader(train_set, batch_size=1, shuffle=True)
    # The validation loader is constructed but never iterated.
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)

    for _, (image, mask, weak_mask, _) in enumerate(train_loader):
        # Skip samples where either mask has a non-positive sum.
        if mask.sum() <= 0 or weak_mask.sum() <= 0:
            continue
        show_img_mask_weakmask(image.numpy(), mask.numpy(), weak_mask.numpy())

    assert len(train_set) == len(train_set.imgs)
def test_visualization():
    """Exercise ``Writter_tf.add_images`` and always remove its output dir.

    A writer is created on the ``runs_test`` directory and ``add_images``
    is called for epochs 0 and 1 with the training loader. The directory
    is removed afterwards even when the writer or ``add_images`` raises,
    so a failing run does not leave ``runs_test`` behind for later runs.
    """
    tensorbord_dir = 'runs_test'
    root_dir = '../dataset/ACDC-2D-All'
    train_dataset = MedicalImageDataset(
        root_dir,
        'train',
        transform=segment_transform((200, 200)),
        augment=augment,
    )
    val_dataset = MedicalImageDataset(
        root_dir,
        'val',
        transform=segment_transform((200, 200)),
        augment=None,
    )
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    # The validation loader is constructed but not used by this test.
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    torchnet = get_arch('enet', {'num_classes': 2})

    try:
        writer = Writter_tf(tensorbord_dir, torchnet, 40)
        for epoch in range(2):
            writer.add_images(train_loader, epoch)
    finally:
        # Clean up on success and on failure. ignore_errors keeps a missing
        # directory (e.g. the writer failed before creating it) from
        # masking the original exception.
        shutil.rmtree(tensorbord_dir, ignore_errors=True)
def test_iter():
    """Check that two ``iterator_`` wrappers each cover the whole dataset.

    Two loaders are built over the same dataset and wrapped in
    ``iterator_``. Each wrapper is advanced ``len(loader) + 1`` times and
    the third element of every batch is collected. The collected values
    must equal, as a set, the stems of the dataset's ``img`` filenames.
    """
    dataroot: str = '../dataset/ACDC-all'
    train_set = MedicalImageDataset(
        dataroot,
        'train',
        subfolders=['img', 'gt'],
        transform=segment_transform((256, 256)),
        augment=PILaugment,
        equalize=None,
        pin_memory=False,
        metainfo="[classSizeCalulator, {'C':4,'foldernames':['gt','gt2']}]",
    )
    loader_a = DataLoader(train_set, batch_size=10, num_workers=8)
    loader_b = DataLoader(train_set, batch_size=10, num_workers=8)

    batches_a = len(loader_a)
    batches_b = len(loader_b)
    assert batches_a == batches_b

    stream_a = iterator_(loader_a)
    stream_b = iterator_(loader_b)

    seen_a = []
    seen_b = []
    # One step more than a full pass; presumably relies on iterator_
    # restarting the loader instead of raising StopIteration -- confirm.
    for _ in range(batches_b + 1):
        batch_a = next(stream_a)
        batch_b = next(stream_b)
        seen_a.extend(batch_a[2])
        seen_b.extend(batch_b[2])

    expected_a = {Path(name).stem for name in loader_a.dataset.filenames['img']}
    assert set(seen_a) == expected_a
    expected_b = {Path(name).stem for name in loader_b.dataset.filenames['img']}
    assert set(seen_b) == expected_b
def test_dataset():
    """Print every batch produced by the ACDC training loader.

    Each batch is unpacked as ``(imgs, metainfo, filenames)`` and the three
    parts are printed, with a two-second pause after each batch. The test
    makes no assertions.
    """
    dataroot: str = '../dataset/ACDC-all'
    train_set = MedicalImageDataset(
        dataroot,
        'train',
        subfolders=['img', 'gt'],
        transform=segment_transform((256, 256)),
        augment=PILaugment,
        equalize=None,
        pin_memory=False,
        metainfo="[classSizeCalulator, {'C':4,'foldernames':['gt','gt2']}]",
    )
    train_loader = DataLoader(train_set, batch_size=10, num_workers=8)

    # Not used afterwards; kept so len() is still called on the loader.
    batch_count = len(train_loader)

    for imgs, metainfo, filenames in train_loader:
        for part in (imgs, metainfo, filenames):
            print(part)
        time.sleep(2)