예제 #1
0
def check_livecell_images():
    """Run ``check_loader`` on ten samples of the LIVECell "train" loader.

    The loader is requested from ``get_livecell_loader`` for ``PATH`` with
    512x512 patches, ``download=True`` and a batch size of one; the samples
    are then checked with ``instance_labels=True``.
    """
    loader_options = dict(download=True, batch_size=1)
    livecell_loader = get_livecell_loader(PATH, (512, 512), "train", **loader_options)
    check_loader(livecell_loader, 10, instance_labels=True)
예제 #2
0
def check(args, train=True, val=True, n_images=2):
    """Run ``check_loader`` on the train and/or validation loaders.

    Args:
        args: parsed arguments; only ``args.input`` is read here.
        train: check the loader built with ``is_train=True``.
        val: check the loader built with ``is_train=False``.
        n_images: number of images passed on to ``check_loader``.
    """
    from torch_em.util.debug import check_loader as inspect_loader

    plan = (
        (train, "Check train loader", True),
        (val, "Check val loader", False),
    )
    for enabled, banner, is_train in plan:
        if not enabled:
            continue
        print(banner)
        # Both splits draw a fixed number of samples (100).
        inspect_loader(get_loader(args.input, is_train=is_train, n_samples=100), n_images)
예제 #3
0
def check(args, train=True, val=True, n_images=2):
    """Run ``check_loader`` on the train and/or validation loaders.

    Both loaders use the same ``[32, 320, 320]`` patch shape list.

    Args:
        args: parsed arguments; only ``args.input`` is read here.
        train: check the loader built with ``True`` as second argument.
        val: check the loader built with ``False`` as second argument.
        n_images: number of images passed on to ``check_loader``.
    """
    from torch_em.util.debug import check_loader as inspect_loader

    patch_shape = [32, 320, 320]
    plan = (
        (train, "Check train loader", True),
        (val, "Check val loader", False),
    )
    for enabled, banner, is_train in plan:
        if enabled:
            print(banner)
            split_loader = get_loader(args.input, is_train, patch_shape)
            inspect_loader(split_loader, n_images)
예제 #4
0
def check(input_path, samples, train=True, val=True, n_images=5):
    """Run ``check_loader`` on the "train" and/or "val" split loaders.

    Args:
        input_path: forwarded unchanged to ``get_loader``.
        samples: forwarded unchanged to ``get_loader``.
        train: check the loader for ``splits=["train"]``.
        val: check the loader for ``splits=["val"]``.
        n_images: number of images passed on to ``check_loader``.
    """
    from torch_em.util.debug import check_loader as inspect_loader

    patch_shape = [32, 256, 256]
    plan = (
        (train, "Check train loader", "train"),
        (val, "Check val loader", "val"),
    )
    for enabled, banner, split in plan:
        if not enabled:
            continue
        print(banner)
        split_loader = get_loader(input_path, samples, splits=[split], patch_shape=patch_shape)
        inspect_loader(split_loader, n_images)
def check(train=True, val=True, n_images=5):
    """Run ``check_loader`` on the 'train' and/or 'val' loaders.

    Each loader is built with a ``[1, 512, 512]`` patch shape and batch size 1.

    Args:
        train: check the 'train' loader.
        val: check the 'val' loader.
        n_images: number of images passed on to ``check_loader``.
    """
    from torch_em.util.debug import check_loader as inspect_loader

    patch_shape = [1, 512, 512]
    plan = (
        (train, "Check train loader", 'train'),
        (val, "Check val loader", 'val'),
    )
    for enabled, banner, split in plan:
        if enabled:
            print(banner)
            inspect_loader(get_loader(split, patch_shape, batch_size=1), n_images)
예제 #6
0
def check(args, train=True, val=True, n_images=2):
    """Run ``check_loader`` on the train and/or validation loaders.

    ``args.samples`` is passed through ``normalize_samples``; only the first
    of the two returned values is used.

    Args:
        args: parsed arguments; ``args.input`` and ``args.samples`` are read.
        train: check the loader built with ``True`` as third argument.
        val: check the loader built with ``False`` as third argument.
        n_images: number of images passed on to ``check_loader``.
    """
    from torch_em.util.debug import check_loader as inspect_loader

    patch_shape = [32, 256, 256]
    # Two-name unpacking kept on purpose: it rejects any other result arity.
    samples, _unused = normalize_samples(args.samples)
    plan = (
        (train, "Check train loader", True),
        (val, "Check val loader", False),
    )
    for enabled, banner, is_train in plan:
        if not enabled:
            continue
        print(banner)
        split_loader = get_loader(args.input, samples, is_train, patch_shape)
        inspect_loader(split_loader, n_images)
예제 #7
0
def check(datasets, train=True, val=True, n_images=5):
    """Run ``check_loader`` on loaders over the train and/or val dataset variants.

    Every name in ``datasets`` gets a ``_train`` (or ``_val``) suffix before the
    list is handed to ``get_loader`` with a ``[32, 256, 256]`` patch shape.

    Args:
        datasets: iterable of dataset base names; iterated once per enabled split.
        train: check the loader over the ``*_train`` names.
        val: check the loader over the ``*_val`` names.
        n_images: number of images passed on to ``check_loader``.
    """
    from torch_em.util.debug import check_loader as inspect_loader

    patch_shape = [32, 256, 256]
    plan = (
        (train, "Check train loader", "train"),
        (val, "Check val loader", "val"),
    )
    for enabled, banner, suffix in plan:
        if enabled:
            print(banner)
            suffixed_names = [f"{name}_{suffix}" for name in datasets]
            inspect_loader(get_loader(suffixed_names, patch_shape), n_images)
예제 #8
0
def check_loader(args, n=4):
    """Inspect ``n`` samples from the ISBI "train" loader.

    The loader comes from ``get_isbi_loader`` with the checkpoint directory
    "./checkpoints/isbi2d/rfs".
    """
    # Aliased so the debug helper is not confused with this function's own name.
    from torch_em.util.debug import check_loader as inspect_loader

    train_loader = get_isbi_loader(args, "train", "./checkpoints/isbi2d/rfs")
    inspect_loader(train_loader, n)
예제 #9
0
    # Build a trainer with torch_em's default segmentation trainer; dice_loss
    # serves as both the training loss and the validation metric.
    # NOTE(review): `name`, `model`, `train_loader`, `val_loader`, `dice_loss`
    # and `args` come from the enclosing function, whose header is not visible
    # in this excerpt — confirm their definitions there.
    trainer = torch_em.default_segmentation_trainer(name,
                                                    model,
                                                    train_loader,
                                                    val_loader,
                                                    loss=dice_loss,
                                                    metric=dice_loss,
                                                    learning_rate=1.0e-4,
                                                    device=args.device,
                                                    # presumably the number of iterations between logged images — verify
                                                    log_image_interval=50)
    # Run training for the requested number of iterations.
    trainer.fit(args.n_iterations)


def check_loader(args, n=4):
    """Inspect ``n`` samples from the ISBI "train" loader.

    The loader comes from ``get_isbi_loader`` with the checkpoint directory
    "./checkpoints/isbi2d/rfs".
    """
    # Aliased so the debug helper is not confused with this function's own name.
    from torch_em.util.debug import check_loader as inspect_loader

    train_loader = get_isbi_loader(args, "train", "./checkpoints/isbi2d/rfs")
    inspect_loader(train_loader, n)


if __name__ == "__main__":
    # Start from torch_em's shared CLI and add this script's own options.
    parser = torch_em.util.parser_helper()
    parser.add_argument("-p", "--pseudo_label", type=int, default=0)
    parser.add_argument("--n_rfs", type=int, default=500)
    parser.add_argument("--n_threads", type=int, default=32)
    args = parser.parse_args()

    # Exactly one entry point runs: `check` takes precedence over
    # `pseudo_label`, and shallow2deep training is the fallback.
    entry_point = (
        check_loader if args.check
        else train_pseudo_label if args.pseudo_label
        else train_shallow2deep
    )
    entry_point(args)
예제 #10
0
def check():
    """Inspect four samples from the loader returned by ``get_loader('train')``."""
    from torch_em.util.debug import check_loader as inspect_loader

    inspect_loader(get_loader('train'), 4)