コード例 #1
0
def inference(args):
    """Run sliding-window inference on the test split.

    Prepares the output directory, selects the CUDA device (initialising the
    NCCL process group first when ``args.multi_gpu`` is set), obtains the test
    loader and the network through ``get_data`` / ``get_network``, and runs a
    ``DynUNetInferrer`` over the loader.

    Args:
        args: Namespace-like object providing ``task_id``, ``checkpoint``,
            ``fold``, ``expr_name``, ``sw_batch_size``, ``window_mode``,
            ``eval_overlap``, ``amp``, ``tta_val``, ``multi_gpu`` and
            ``local_rank``.
    """
    # load hyper parameters
    task_id = args.task_id
    checkpoint = args.checkpoint
    val_output_dir = "./runs_{}_fold{}_{}/".format(args.task_id, args.fold,
                                                   args.expr_name)
    sw_batch_size = args.sw_batch_size
    infer_output_dir = os.path.join(val_output_dir, task_name[task_id])
    window_mode = args.window_mode
    eval_overlap = args.eval_overlap
    amp = args.amp
    tta_val = args.tta_val
    multi_gpu_flag = args.multi_gpu
    local_rank = args.local_rank

    # In multi-GPU mode every rank executes this function, so a separate
    # "exists?" check followed by makedirs can race between processes.
    # exist_ok=True makes the creation safe to repeat.
    os.makedirs(infer_output_dir, exist_ok=True)

    if multi_gpu_flag:
        dist.init_process_group(backend="nccl", init_method="env://")
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cuda")

    properties, test_loader = get_data(args, mode="test")

    net = get_network(properties, task_id, val_output_dir, checkpoint)
    net = net.to(device)

    if multi_gpu_flag:
        net = DistributedDataParallel(module=net,
                                      device_ids=[device],
                                      find_unused_parameters=True)

    net.eval()

    inferrer = DynUNetInferrer(
        device=device,
        val_data_loader=test_loader,
        network=net,
        output_dir=infer_output_dir,
        n_classes=len(properties["labels"]),
        inferer=SlidingWindowInferer(
            roi_size=patch_size[task_id],
            sw_batch_size=sw_batch_size,
            overlap=eval_overlap,
            mode=window_mode,
        ),
        amp=amp,
        tta_val=tta_val,
    )

    inferrer.run()
コード例 #2
0
ファイル: train.py プロジェクト: nicolizamacorrea/tutorials-1
def train(args):
    """Train the network for one task/fold with periodic validation.

    Reads the run configuration from ``args``, selects the CUDA device
    (initialising the NCCL process group first when ``args.multi_gpu`` is
    set), builds the validation and training loaders, the network, an SGD
    optimizer with a polynomial learning-rate schedule, a ``DynUNetEvaluator``
    and a ``DynUNetTrainer``, configures logging on rank 0, and starts
    training.

    Args:
        args: Namespace-like object providing ``task_id``, ``fold``,
            ``expr_name``, ``interval``, ``learning_rate``, ``max_epochs``,
            ``multi_gpu``, ``amp``, ``lr_decay``, ``sw_batch_size``,
            ``tta_val``, ``batch_dice``, ``window_mode``, ``eval_overlap``,
            ``local_rank``, ``determinism_flag``, ``determinism_seed`` and
            ``checkpoint``.
    """
    # load hyper parameters
    task_id = args.task_id
    fold = args.fold
    val_output_dir = "./runs_{}_fold{}_{}/".format(task_id, fold,
                                                   args.expr_name)
    log_filename = "nnunet_task{}_fold{}.log".format(task_id, fold)
    log_filename = os.path.join(val_output_dir, log_filename)
    interval = args.interval
    learning_rate = args.learning_rate
    max_epochs = args.max_epochs
    multi_gpu_flag = args.multi_gpu
    amp_flag = args.amp
    lr_decay_flag = args.lr_decay
    sw_batch_size = args.sw_batch_size
    tta_val = args.tta_val
    batch_dice = args.batch_dice
    window_mode = args.window_mode
    eval_overlap = args.eval_overlap
    local_rank = args.local_rank
    determinism_flag = args.determinism_flag
    determinism_seed = args.determinism_seed
    if determinism_flag:
        set_determinism(seed=determinism_seed)
        if local_rank == 0:
            print("Using deterministic training.")

    # transforms
    train_batch_size = data_loader_params[task_id]["batch_size"]
    if multi_gpu_flag:
        dist.init_process_group(backend="nccl", init_method="env://")

        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cuda")

    properties, val_loader = get_data(args, mode="validation")
    _, train_loader = get_data(args, batch_size=train_batch_size, mode="train")

    # produce the network
    checkpoint = args.checkpoint
    net = get_network(properties, task_id, val_output_dir, checkpoint)
    net = net.to(device)

    if multi_gpu_flag:
        net = DistributedDataParallel(module=net,
                                      device_ids=[device],
                                      find_unused_parameters=True)

    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=learning_rate,
        momentum=0.99,
        weight_decay=3e-5,
        nesterov=True,
    )

    # Polynomial decay: the LR multiplier goes from 1 at epoch 0 to 0 at
    # epoch == max_epochs.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda epoch: (1 - epoch / max_epochs)**0.9)
    # produce evaluator
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointSaver(save_dir=val_output_dir,
                        save_dict={"net": net},
                        save_key_metric=True),
    ]

    evaluator = DynUNetEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        n_classes=len(properties["labels"]),
        inferer=SlidingWindowInferer(
            roi_size=patch_size[task_id],
            sw_batch_size=sw_batch_size,
            overlap=eval_overlap,
            mode=window_mode,
        ),
        post_transform=None,
        key_val_metric={
            "val_mean_dice":
            MeanDice(
                include_background=False,
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
        val_handlers=val_handlers,
        amp=amp_flag,
        tta_val=tta_val,
    )
    # produce trainer
    loss = DiceCELoss(to_onehot_y=True, softmax=True, batch=batch_dice)
    train_handlers = []
    if lr_decay_flag:
        train_handlers += [
            LrScheduleHandler(lr_scheduler=scheduler, print_lr=True)
        ]

    train_handlers += [
        ValidationHandler(validator=evaluator,
                          interval=interval,
                          epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
    ]

    trainer = DynUNetTrainer(
        device=device,
        max_epochs=max_epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=optimizer,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=None,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=amp_flag,
    )

    # run
    # Only rank 0 logs. The handlers are created inside this guard because
    # logging.FileHandler opens its file as soon as it is constructed; building
    # it on every rank would leave unused open handles on the other ranks and
    # could fail there if the run directory is not present yet.
    if local_rank == 0:
        # Make sure the directory holding the log file exists before opening.
        os.makedirs(val_output_dir, exist_ok=True)

        logger = logging.getLogger()

        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s")

        # Setup file handler
        fhandler = logging.FileHandler(log_filename)
        fhandler.setLevel(logging.INFO)
        fhandler.setFormatter(formatter)

        # Configure stream handler for the cells
        chandler = logging.StreamHandler()
        chandler.setLevel(logging.INFO)
        chandler.setFormatter(formatter)

        # Add both handlers
        logger.addHandler(fhandler)
        logger.addHandler(chandler)
        logger.setLevel(logging.INFO)

    trainer.run()
コード例 #3
0
ファイル: train.py プロジェクト: nicolizamacorrea/tutorials-1
def validation(args):
    """Evaluate a checkpoint on the validation split and print Dice results.

    Selects the CUDA device (initialising the NCCL process group first when
    ``args.multi_gpu`` is set), loads the validation data and the network,
    and runs a ``DynUNetEvaluator`` with a ``MeanDice`` key metric. On rank 0
    the evaluator's metrics are printed; when there are more than two
    classes, the mean of each column of the ``"val_mean_dice"`` metric
    details is printed as well, one line per label.

    Args:
        args: Namespace-like object providing ``task_id``, ``sw_batch_size``,
            ``tta_val``, ``window_mode``, ``eval_overlap``, ``multi_gpu``,
            ``local_rank``, ``amp``, ``checkpoint``, ``fold`` and
            ``expr_name``.
    """
    # Settings taken from the parsed arguments.
    task = args.task_id
    window_batch = args.sw_batch_size
    use_tta = args.tta_val
    blend_mode = args.window_mode
    overlap = args.eval_overlap
    distributed = args.multi_gpu
    rank = args.local_rank
    use_amp = args.amp

    # Where the run's artefacts live, and which checkpoint to load.
    ckpt = args.checkpoint
    run_dir = "./runs_{}_fold{}_{}/".format(task, args.fold, args.expr_name)

    # Device selection.
    if not distributed:
        device = torch.device("cuda")
    else:
        dist.init_process_group(backend="nccl", init_method="env://")
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)

    # Data and model.
    properties, loader = get_data(args, mode="validation")
    model = get_network(properties, task, run_dir, ckpt).to(device)
    if distributed:
        model = DistributedDataParallel(
            module=model,
            device_ids=[device],
            find_unused_parameters=True,
        )

    num_classes = len(properties["labels"])
    model.eval()

    # Evaluator pieces, built up front and then handed over.
    window_inferer = SlidingWindowInferer(
        roi_size=patch_size[task],
        sw_batch_size=window_batch,
        overlap=overlap,
        mode=blend_mode,
    )
    dice_metric = MeanDice(
        include_background=False,
        output_transform=lambda out: (out["pred"], out["label"]),
    )
    evaluator = DynUNetEvaluator(
        device=device,
        val_data_loader=loader,
        network=model,
        n_classes=num_classes,
        inferer=window_inferer,
        post_transform=None,
        key_val_metric={"val_mean_dice": dice_metric},
        additional_metrics=None,
        amp=use_amp,
        tta_val=use_tta,
    )

    evaluator.run()

    # Reporting happens on rank 0 only.
    if rank != 0:
        return

    print(evaluator.state.metrics)
    details = evaluator.state.metric_details["val_mean_dice"]
    if num_classes <= 2:
        return

    # Column i of the details is reported as label i + 1.
    for label in range(1, num_classes):
        print("mean dice for label {} is {}".format(
            label, details[:, label - 1].mean()))