def test_admm_gc_size():
    """Smoke-test ``AdmmGCSize`` on the ACDC 2D data.

    Clears all registered flags, lets ``AdmmGCSize`` register its own, prints
    the validation score of a fresh ENet, then runs two ADMM updates on each
    labelled training slice until a sample at index >= 4 has been processed.
    """
    # Remove every existing flag so setup_arch_flags() can (re)define them.
    for flag_name in list(flags.FLAGS):
        delattr(flags.FLAGS, flag_name)

    data_root = '../dataset/ACDC-2D-All'
    train_set = MedicalImageDataset(
        data_root, 'train', transform=segment_transform((128, 128)), augment=augment
    )
    val_set = MedicalImageDataset(
        data_root, 'val', transform=segment_transform((128, 128)), augment=None
    )
    train_iter = DataLoader(train_set, batch_size=1, shuffle=True)
    val_iter = DataLoader(val_set, batch_size=1, shuffle=False)

    AdmmGCSize.setup_arch_flags()
    config = flags.FLAGS.flag_values_dict()
    network = get_arch('enet', {'num_classes': 2})

    # Class 0 (presumably background) gets zero weight in the cross-entropy.
    loss_fn = get_loss_fn('cross_entropy', weight=torch.Tensor([0, 1]))

    admm = AdmmGCSize(network, config)
    score = admm.evaluate(val_iter)
    print(score)

    for step, (image, full_gt, weak_gt, _) in enumerate(train_iter):
        # Skip slices whose ground truth or weak label (presumably) is empty.
        if full_gt.sum() == 0 or weak_gt.sum() == 0:
            continue
        admm.reset(image)
        for _ in range(2):
            admm.update((image, full_gt, weak_gt), loss_fn)
        if step >= 4:
            break
def inference(args: argparse.Namespace) -> None:
    """Re-evaluate a saved checkpoint on the validation split.

    Loads the checkpoint at ``args.checkpoint``, rebuilds the network from
    ``args.arch`` / ``args.num_classes``, runs it over the ``'val'`` split of
    ``args.dataset`` (resized to 256x256), passes every batch to
    ``save_images`` and prints the recomputed mean 2D dice.

    Args:
        args: parsed CLI namespace. This function reads ``dataset``,
            ``checkpoint``, ``cpu``, ``arch`` and ``num_classes``, and the
            namespace is also forwarded to ``save_images``.

    Raises:
        AssertionError: if ``args.dataset`` is not ``'cardiac'`` or
            ``'prostate'``, or the checkpoint file does not exist.
    """
    # Validate inputs.
    assert args.dataset in ('cardiac', 'prostate'), f'Unsupported dataset {args.dataset!r}.'
    checkpoint_path = Path(args.checkpoint)
    assert checkpoint_path.exists(), f'Checkpoint given {args.checkpoint} does not exist.'

    device: torch.device = torch.device(
        'cuda' if torch.cuda.is_available() and not args.cpu else 'cpu'
    )
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Report what the checkpoint recorded at training time.
    print(
        f">>>checkpoint {checkpoint_path} loaded. \n"
        f"Best epoch: {checkpoint['epoch']}, best val-2D dice: {round(checkpoint['dice'], 4)}")

    # Rebuild the model and switch it to evaluation mode.
    net: torch.nn.Module = get_arch(args.arch, {'num_classes': args.num_classes})
    net.load_state_dict(checkpoint['model'])
    net.to(device)
    net.eval()

    # Build the validation loader.
    root_dir = get_dataset_root(args.dataset)
    val_dataset = MedicalImageDataset(
        root_dir, 'val', transform=segment_transform((256, 256)), augment=None
    )
    val_loader = tqdm_(DataLoader(val_dataset, batch_size=1))

    dice_meter = AverageMeter()
    # Pure inference: disable autograd so no computation graph is built,
    # which saves memory and time on every forward pass.
    with torch.no_grad():
        for imgs, gts, _wgts, paths in val_loader:
            # The weak labels are not needed here, so they stay on the host.
            imgs, gts = imgs.to(device), gts.to(device)
            # Per-pixel argmax over dim 1 (class dim, presumably NCHW) gives the predicted mask.
            pred_masks = net(imgs).max(1)[1]
            dice_meter.update(dice_loss(pred_masks, gts)[1], gts.shape[0])
            save_images(imgs, pred_masks, gts, paths, args)
            val_loader.set_postfix({'val 2d-dice': dice_meter.avg})
    print(f'\nrecalculated dice: {round(dice_meter.avg, 4)}')
def test_admm_size():
    """Smoke-test ``AdmmSize`` on the ACDC 2D data.

    Resets the flag registry, registers the ``AdmmSize`` flags, builds an ENet
    from the architecture subset of those flags, prints the validation score,
    then runs three ADMM updates per labelled training slice until a sample at
    index >= 3 has been processed.
    """
    # Start from an empty flag registry so AdmmSize can register its flags.
    for flag_name in list(flags.FLAGS):
        delattr(flags.FLAGS, flag_name)

    data_root = '../dataset/ACDC-2D-All'
    train_set = MedicalImageDataset(
        data_root, 'train', transform=segment_transform((200, 200)), augment=augment
    )
    val_set = MedicalImageDataset(
        data_root, 'val', transform=segment_transform((200, 200)), augment=None
    )
    train_iter = DataLoader(train_set, batch_size=1, shuffle=True)
    val_iter = DataLoader(val_set, batch_size=1, shuffle=False)

    AdmmSize.setup_arch_flags()
    config = flags.FLAGS.flag_values_dict()
    net_config = extract_from_big_dict(config, AdmmSize.arch_hparam_keys)
    network = get_arch('enet', net_config)

    # Class 0 (presumably background) is down-weighted to 0.1 in the loss.
    loss_fn = get_loss_fn('cross_entropy', weight=torch.Tensor([0.1, 1]))

    admm = AdmmSize(network, config)
    score = admm.evaluate(val_iter)
    print(score)

    for step, (image, full_gt, weak_gt, _) in enumerate(train_iter):
        # Only slices with non-empty ground truth and (presumably) weak labels are used.
        if full_gt.sum() == 0 or weak_gt.sum() == 0:
            continue
        admm.reset(image)
        for _ in range(3):
            admm.update((image, full_gt, weak_gt), loss_fn)
        if step >= 3:
            break
def run(argv):
    """Train the ADMM method selected by the command-line flags.

    Collects every flag into a dict, validates it with ``check_consistance``,
    builds the train/val datasets, the network, the ADMM method and the loss,
    then starts ``ADMM_Trainer``.

    Args:
        argv: leftover positional arguments. They are unused; the signature
            presumably matches an absl ``app.run`` callback (confirm).
    """
    del argv  # configuration comes entirely from flags.FLAGS
    config = flags.FLAGS.flag_values_dict()
    check_consistance(config)
    train_set, val_set = build_datasets(config)

    # NOTE(review): the architecture keys always come from AdmmGCSize,
    # regardless of config['method']; confirm every method shares them.
    net_config = extract_from_big_dict(config, AdmmGCSize.arch_hparam_keys)
    network = get_arch(net_config['arch'], net_config)
    method = get_method(config['method'], network, **config)
    loss_fn = get_loss_fn(config['loss'])

    ADMM_Trainer(method, [train_set, val_set], loss_fn, config).start_training()
def test_visualization():
    """Smoke-test ``Writter_tf.add_images`` on ACDC training samples.

    Writes images for steps 0 and 1 into a temporary ``runs_test`` directory
    and always removes that directory afterwards, even if the writer fails.
    """
    tensorbord_dir = 'runs_test'
    root_dir = '../dataset/ACDC-2D-All'
    train_dataset = MedicalImageDataset(
        root_dir, 'train', transform=segment_transform((200, 200)), augment=augment
    )
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    torchnet = get_arch('enet', {'num_classes': 2})
    try:
        writer = Writter_tf(tensorbord_dir, torchnet, 40)
        for i in range(2):
            writer.add_images(train_loader, i)
    finally:
        # Clean up even on failure so reruns do not see stale event files.
        # ignore_errors keeps a missing directory from masking the real error.
        shutil.rmtree(tensorbord_dir, ignore_errors=True)