def check_livecell_images():
    """Build the LiveCELL training loader and hand it to ``check_loader``.

    The loader is requested for the "train" split with 512 x 512 patches,
    ``download=True`` and a batch size of one. Ten is passed as the second
    argument of ``check_loader`` together with ``instance_labels=True``.
    """
    tile_shape = (512, 512)
    livecell_loader = get_livecell_loader(
        PATH, tile_shape, "train", download=True, batch_size=1
    )
    check_loader(livecell_loader, 10, instance_labels=True)
def check(args, train=True, val=True, n_images=2):
    """Run torch_em's debug ``check_loader`` on the train and/or val loader.

    Args:
        args: Object whose ``input`` attribute is forwarded to ``get_loader``.
        train: When truthy, build and check the training loader.
        val: When truthy, build and check the validation loader.
        n_images: Forwarded as the second argument of ``check_loader``.
    """
    from torch_em.util.debug import check_loader

    # One row per split, listed in the order they are processed:
    # (enabled flag, banner to print, value for ``is_train``).
    stages = (
        (train, "Check train loader", True),
        (val, "Check val loader", False),
    )
    for enabled, banner, is_train in stages:
        if not enabled:
            continue
        print(banner)
        split_loader = get_loader(args.input, is_train=is_train, n_samples=100)
        check_loader(split_loader, n_images)
def check(args, train=True, val=True, n_images=2):
    """Run torch_em's debug ``check_loader`` on the train and/or val loader.

    Both loaders are built from ``args.input`` with a ``[32, 320, 320]``
    patch shape; the second positional argument of ``get_loader`` is ``True``
    for the training loader and ``False`` for the validation loader.

    Args:
        args: Object whose ``input`` attribute is forwarded to ``get_loader``.
        train: When truthy, build and check the training loader.
        val: When truthy, build and check the validation loader.
        n_images: Forwarded as the second argument of ``check_loader``.
    """
    from torch_em.util.debug import check_loader

    patch_shape = [32, 320, 320]

    def _inspect(banner, train_flag):
        # Announce the split, build its loader, then hand it to the checker.
        print(banner)
        check_loader(get_loader(args.input, train_flag, patch_shape), n_images)

    if train:
        _inspect("Check train loader", True)
    if val:
        _inspect("Check val loader", False)
def check(input_path, samples, train=True, val=True, n_images=5):
    """Run torch_em's debug ``check_loader`` on the train and/or val loader.

    Args:
        input_path: Forwarded as the first argument of ``get_loader``.
        samples: Forwarded as the second argument of ``get_loader``.
        train: When truthy, check the loader built with ``splits=["train"]``.
        val: When truthy, check the loader built with ``splits=["val"]``.
        n_images: Forwarded as the second argument of ``check_loader``.
    """
    from torch_em.util.debug import check_loader

    patch_shape = [32, 256, 256]
    # (enabled flag, banner, split name) in processing order.
    plan = [
        (train, "Check train loader", "train"),
        (val, "Check val loader", "val"),
    ]
    for wanted, banner, split_name in plan:
        if wanted:
            print(banner)
            split_loader = get_loader(
                input_path, samples, splits=[split_name], patch_shape=patch_shape
            )
            check_loader(split_loader, n_images)
def check(train=True, val=True, n_images=5):
    """Run torch_em's debug ``check_loader`` on the train and/or val loader.

    Loaders are built by ``get_loader`` with a ``[1, 512, 512]`` patch shape
    and ``batch_size=1``.

    Args:
        train: When truthy, check the loader for the ``'train'`` split.
        val: When truthy, check the loader for the ``'val'`` split.
        n_images: Forwarded as the second argument of ``check_loader``.
    """
    from torch_em.util.debug import check_loader

    patch_shape = [1, 512, 512]
    for selected, banner, split in (
        (train, "Check train loader", 'train'),
        (val, "Check val loader", 'val'),
    ):
        if not selected:
            continue
        print(banner)
        current = get_loader(split, patch_shape, batch_size=1)
        check_loader(current, n_images)
def check(args, train=True, val=True, n_images=2):
    """Run torch_em's debug ``check_loader`` on the train and/or val loader.

    ``args.samples`` is first passed through ``normalize_samples``; only the
    first element of its result is used. Loaders are then built from
    ``args.input`` with a ``[32, 256, 256]`` patch shape.

    Args:
        args: Object providing the ``samples`` and ``input`` attributes.
        train: When truthy, check the loader built with ``True`` as the third
            argument of ``get_loader``.
        val: When truthy, check the loader built with ``False`` as the third
            argument of ``get_loader``.
        n_images: Forwarded as the second argument of ``check_loader``.
    """
    from torch_em.util.debug import check_loader

    patch_shape = [32, 256, 256]
    # The second return value of normalize_samples is not needed here.
    normalized = normalize_samples(args.samples)
    samples, _ = normalized

    schedule = (
        (train, "Check train loader", True),
        (val, "Check val loader", False),
    )
    for active, banner, train_flag in schedule:
        if active:
            print(banner)
            built = get_loader(args.input, samples, train_flag, patch_shape)
            check_loader(built, n_images)
def check(datasets, train=True, val=True, n_images=5):
    """Run torch_em's debug ``check_loader`` on the train and/or val loader.

    Each dataset name gets a ``_train`` or ``_val`` suffix before the list is
    passed to ``get_loader`` together with a ``[32, 256, 256]`` patch shape.

    Args:
        datasets: Iterable of dataset names; it is iterated once per enabled
            split.
        train: When truthy, check the loader for the ``_train`` datasets.
        val: When truthy, check the loader for the ``_val`` datasets.
        n_images: Forwarded as the second argument of ``check_loader``.
    """
    from torch_em.util.debug import check_loader

    patch_shape = [32, 256, 256]
    steps = (
        (train, "Check train loader", "train"),
        (val, "Check val loader", "val"),
    )
    for enabled, banner, suffix in steps:
        if not enabled:
            continue
        print(banner)
        # Build the suffixed names only for splits that are actually checked.
        names = []
        for ds in datasets:
            names.append(f'{ds}_{suffix}')
        check_loader(get_loader(names, patch_shape), n_images)
def check_loader(args, n=4):
    """Build the ISBI "train" loader and pass it to torch_em's ``check_loader``.

    Args:
        args: Forwarded unchanged as the first argument of ``get_isbi_loader``.
        n: Forwarded as the second argument of torch_em's ``check_loader``.
    """
    # Import under an alias so the helper is not confused with this function,
    # which has the same name.
    from torch_em.util.debug import check_loader as _debug_check_loader

    isbi_train_loader = get_isbi_loader(args, "train", "./checkpoints/isbi2d/rfs")
    _debug_check_loader(isbi_train_loader, n)
# NOTE(review): the next two statements look like the tail of a training
# function whose header lies above this view -- confirm the enclosing scope and
# restore its indentation when merging.
# Build the default segmentation trainer; dice_loss serves as both the loss and
# the validation metric.
trainer = torch_em.default_segmentation_trainer(name, model, train_loader, val_loader,
                                                loss=dice_loss, metric=dice_loss, learning_rate=1.0e-4,
                                                device=args.device, log_image_interval=50)
# Train for the number of iterations given on the command line.
trainer.fit(args.n_iterations)


def check_loader(args, n=4):
    """Build the ISBI "train" loader and pass it to torch_em's ``check_loader``.

    Args:
        args: Forwarded unchanged as the first argument of ``get_isbi_loader``.
        n: Forwarded as the second argument of torch_em's ``check_loader``.
    """
    # This local import rebinds ``check_loader`` inside the function, so the
    # call below reaches torch_em's helper and does not recurse.
    from torch_em.util.debug import check_loader
    loader = get_isbi_loader(args, "train", "./checkpoints/isbi2d/rfs")
    check_loader(loader, n)


if __name__ == "__main__":
    # Start from torch_em's shared argument parser and add script-specific options.
    parser = torch_em.util.parser_helper()
    # Integer switch: any non-zero value selects pseudo-label training below.
    parser.add_argument("-p", "--pseudo_label", type=int, default=0)
    # NOTE(review): n_rfs and n_threads are not read in the visible code;
    # presumably consumed by the training / loader functions -- confirm.
    parser.add_argument("--n_rfs", type=int, default=500)
    parser.add_argument("--n_threads", type=int, default=32)
    args = parser.parse_args()
    # Dispatch: loader check takes priority, then pseudo-label training,
    # otherwise shallow2deep training.
    # NOTE(review): ``args.check`` is presumably defined by parser_helper -- confirm.
    if args.check:
        check_loader(args)
    elif args.pseudo_label:
        train_pseudo_label(args)
    else:
        train_shallow2deep(args)
def check():
    """Build the ``'train'`` loader and pass it to torch_em's ``check_loader``.

    Four is passed as the second argument of ``check_loader``.
    """
    from torch_em.util.debug import check_loader

    check_loader(get_loader('train'), 4)