Example No. 1
0
def main():
    """Run 3D UNet inference on a single image/label pair.

    Loads the network configuration (from ``--network-conf-path`` or the
    default ``TRAIN_CONF_PATH``), restores a UNet checkpoint, builds a
    one-item ``Torch2DSegmentationDataset`` for the requested phase and
    hands the image/label tensors to ``inference``.

    Fixes: removed a dead ``get_conf(..., key="experiment_dir")`` call
    whose result was discarded, and deleted commented-out code.
    """
    args = parse_args()
    model_path = Path(args.model_path)
    if args.network_conf_path is not None:
        train_conf = ConfigFactory.parse_file(str(Path(
            args.network_conf_path)))
        conf_path = Path(args.network_conf_path)
    else:
        train_conf = ConfigFactory.parse_file(str(TRAIN_CONF_PATH))
        conf_path = TRAIN_CONF_PATH
    data_config = DataConfig.from_conf(conf_path)
    dataset_config = DatasetConfig.from_conf(
        name=data_config.training_datasets[0],
        mount_prefix=data_config.mount_prefix,
        mode=data_config.data_mode)
    # The input image path is derived from the dataset's filename template.
    input_path = Path(args.input_dir).joinpath(
        dataset_config.image_label_format.image.format(phase=args.phase))
    output_dir = Path(args.output_dir)
    checkpoint = torch.load(str(model_path),
                            map_location=torch.device(args.device))

    network = UNet(
        in_channels=get_conf(train_conf, group="network", key="in_channels"),
        n_classes=get_conf(train_conf, group="network", key="n_classes"),
        n_filters=get_conf(train_conf, group="network", key="n_filters"),
    )
    network.load_state_dict(checkpoint)
    network.cuda(device=args.device)
    # Single-item dataset: the label is expected next to the image, named
    # by the dataset's label filename template.
    dataset = Torch2DSegmentationDataset(
        name=dataset_config.name,
        image_paths=[input_path],
        label_paths=[
            input_path.parent.joinpath(
                dataset_config.image_label_format.label.format(
                    phase=args.phase))
        ],
        feature_size=get_conf(train_conf, group="network", key="feature_size"),
        n_slices=get_conf(train_conf, group="network", key="n_slices"),
        is_3d=True,
    )

    # Add a leading batch dimension and move the image to the device.
    image = dataset.get_image_tensor_from_index(0)
    image = torch.unsqueeze(image, 0)
    image = prepare_tensors(image, True, args.device)

    label = dataset.get_label_tensor_from_index(0)
    inference(
        image=image,
        label=label,
        image_path=input_path,
        network=network,
        output_dir=output_dir,
    )
Example No. 2
0
def main():
    """Fine-tune a pretrained DefSegNet and run the training loop.

    Parses the training config, restores the checkpoint named in the
    config, freezes the VoxelMorph sub-network, builds the datasets and
    trains a ``DefSegExperiment``.
    """
    args = parse_args()
    conf_path = TRAIN_CONF_PATH if args.conf_path is None else Path(args.conf_path)
    train_conf = ConfigFactory.parse_file(str(conf_path))

    def conf(group, key):
        # Shortcut for reading a single value from the parsed config.
        return get_conf(train_conf, group=group, key=key)

    exp_dir_value = conf("experiment", "experiment_dir")
    experiment_dir = Path(exp_dir_value) if exp_dir_value is not None else None
    config = ExperimentConfig(
        experiment_dir=experiment_dir,
        batch_size=conf("experiment", "batch_size"),
        num_epochs=conf("experiment", "num_epochs"),
        gpu=conf("experiment", "gpu"),
        device=conf("experiment", "device"),
        num_workers=conf("experiment", "num_workers"),
        pin_memory=conf("experiment", "pin_memory"),
        n_inference=conf("experiment", "n_inference"),
        seed=conf("experiment", "seed"),
    )
    augmentation_config = AugmentationConfig.from_conf(conf_path)
    # Keep a copy of the configuration next to the experiment outputs.
    shutil.copy(str(conf_path), str(config.experiment_dir.joinpath("train.conf")))

    network = DefSegNet(
        in_channels=conf("network", "in_channels"),
        n_classes=conf("network", "n_classes"),
        n_filters=conf("network", "n_filters"),
        feature_size=conf("network", "feature_size"),
        n_slices=conf("network", "n_slices"),
        int_downsize=conf("network", "integrate_downsize"),
        bidir=True,
    )
    checkpoint = torch.load(
        str(conf("network", "checkpoint_path")),
        map_location=torch.device(conf("experiment", "device")),
    )
    network.load_state_dict(checkpoint)
    network.cuda(device=conf("experiment", "device"))
    # Keep the registration sub-network fixed during fine-tuning.
    network.freeze_vxm()

    training_sets, validation_sets, extra_validation_sets = construct_training_validation_dataset(
        DataConfig.from_conf(conf_path),
        feature_size=conf("network", "feature_size"),
        n_slices=conf("network", "n_slices"),
        is_3d=True,
        seed=config.seed,
        augmentation_config=augmentation_config,
        output_dir=config.experiment_dir,
        template_path=Path(conf("network", "template_path")),
        template_image_path=Path(conf("network", "template_image_path")),
    )
    # Record which subjects ended up in each split.
    for datasets, pattern in (
        (training_sets, "training_set_{}.csv"),
        (validation_sets, "validation_set_{}.csv"),
        (extra_validation_sets, "extra_validation_set_{}.csv"),
    ):
        for dataset in datasets:
            dataset.export(config.experiment_dir.joinpath(pattern.format(dataset.name)))

    if conf("optimizer", "type") == "SGD":
        optimizer = torch.optim.SGD(
            network.parameters(),
            lr=conf("optimizer", "learning_rate"),
            momentum=conf("optimizer", "momentum"),
        )
    else:
        optimizer = torch.optim.Adam(
            network.parameters(), lr=conf("optimizer", "learning_rate"))

    loss = UDADefSegLoss(
        penalty="l2",
        loss_mult=conf("network", "integrate_downsize"),
        weights=conf("loss", "weights"),
    )
    experiment = DefSegExperiment(
        config=config,
        network=network,
        training_sets=training_sets,
        validation_sets=validation_sets,
        extra_validation_sets=extra_validation_sets,
        optimizer=optimizer,
        loss=loss,
        other_validation_metrics=[
            DefSegWarpedTemplateDice(), DefSegPredDice(), DefSegWarpedMapsDice()
        ],
        inference_func=inference,
    )
    experiment.train()
Example No. 3
0
def main():
    """Train an ``FCN2DSegmentationModel`` from a HOCON training config.

    Bug fix: when ``--conf-path`` was supplied, the config was parsed from
    it but ``TRAIN_CONF_PATH`` was still copied to the experiment directory
    and used for ``DataConfig.from_conf`` — so the archived/used config did
    not match the one actually parsed. The chosen ``conf_path`` is now
    tracked and used throughout, matching the other entry points.
    """
    args = parse_args()
    if args.conf_path is None:
        train_conf = ConfigFactory.parse_file(str(TRAIN_CONF_PATH))
        conf_path = TRAIN_CONF_PATH
    else:
        train_conf = ConfigFactory.parse_file(str(Path(args.conf_path)))
        conf_path = Path(args.conf_path)

    if get_conf(train_conf, group="experiment",
                key="experiment_dir") is not None:
        experiment_dir = Path(
            get_conf(train_conf, group="experiment", key="experiment_dir"))
    else:
        experiment_dir = None
    config = ExperimentConfig(
        experiment_dir=experiment_dir,
        batch_size=get_conf(train_conf, group="experiment", key="batch_size"),
        num_epochs=get_conf(train_conf, group="experiment", key="num_epochs"),
        gpu=get_conf(train_conf, group="experiment", key="gpu"),
        device=get_conf(train_conf, group="experiment", key="device"),
        num_workers=get_conf(train_conf, group="experiment",
                             key="num_workers"),
        pin_memory=get_conf(train_conf, group="experiment", key="pin_memory"),
    )
    # Archive the configuration that was actually used for this run.
    shutil.copy(str(conf_path),
                str(config.experiment_dir.joinpath("train.conf")))
    network = FCN2DSegmentationModel(
        in_channels=get_conf(train_conf, group="network", key="in_channels"),
        n_classes=get_conf(train_conf, group="network", key="n_classes"),
        n_filters=get_conf(train_conf, group="network", key="n_filters"),
        up_conv_filter=get_conf(train_conf,
                                group="network",
                                key="up_conv_filter"),
        final_conv_filter=get_conf(train_conf,
                                   group="network",
                                   key="final_conv_filter"),
        feature_size=get_conf(train_conf, group="network", key="feature_size"))
    # NOTE(review): n_slices is read from the "in_channels" key — for this
    # 2D model slices appear to be fed as channels; confirm intentional.
    training_set, validation_set = construct_training_validation_dataset(
        DataConfig.from_conf(conf_path),
        feature_size=get_conf(train_conf, group="network", key="feature_size"),
        n_slices=get_conf(train_conf, group="network", key="in_channels"))
    training_set.export(config.experiment_dir.joinpath("training_set.csv"))
    validation_set.export(config.experiment_dir.joinpath("validation_set.csv"))
    if get_conf(train_conf, group="optimizer", key="type") == "SGD":
        optimizer = torch.optim.SGD(
            network.parameters(),
            lr=get_conf(train_conf, group="optimizer", key="learning_rate"),
            momentum=get_conf(train_conf, group="optimizer", key="momentum"),
        )
    else:
        optimizer = torch.optim.Adam(network.parameters(),
                                     lr=get_conf(train_conf,
                                                 group="optimizer",
                                                 key="learning_rate"))
    if get_conf(train_conf, group="loss", key="type") == "FocalLoss":
        loss = FocalLoss(
            alpha=get_conf(train_conf, group="loss", key="alpha"),
            gamma=get_conf(train_conf, group="loss", key="gamma"),
            logits=True,
        )
    else:
        loss = BCELoss()
    experiment = Experiment(
        config=config,
        network=network,
        training_set=training_set,
        validation_set=validation_set,
        optimizer=optimizer,
        loss=loss,
        other_validation_metrics=[DiceCoeffWithLogits()],
    )
    experiment.train()
Example No. 4
0
def main():
    """Train (or resume) a template-deformation segmentation experiment.

    Builds either an ImgTemplateEncoderNet ("encodedecode") or an ISTNNet
    ("istn") from the training config. When --starting-epoch > 0, resumes
    from the checkpoint and atlas saved by the previous epoch under
    --exp-dir; otherwise starts fresh with no atlas.
    """
    args = parse_args()
    if args.starting_epoch > 0:
        # Resuming requires the previous experiment directory to locate
        # the checkpoint and atlas files.
        assert args.exp_dir is not None
        exp_dir = Path(args.exp_dir)
    if args.conf_path is None:
        train_conf = ConfigFactory.parse_file(str(TRAIN_CONF_PATH))
        conf_path = TRAIN_CONF_PATH
    else:
        train_conf = ConfigFactory.parse_file(str(Path(args.conf_path)))
        conf_path = Path(args.conf_path)

    if get_conf(train_conf, group="experiment",
                key="experiment_dir") is not None:
        experiment_dir = Path(
            get_conf(train_conf, group="experiment", key="experiment_dir"))
    else:
        experiment_dir = None
    config = ExperimentConfig(
        experiment_dir=experiment_dir,
        batch_size=get_conf(train_conf, group="experiment", key="batch_size"),
        num_epochs=get_conf(train_conf, group="experiment", key="num_epochs"),
        gpu=get_conf(train_conf, group="experiment", key="gpu"),
        device=get_conf(train_conf, group="experiment", key="device"),
        num_workers=get_conf(train_conf, group="experiment",
                             key="num_workers"),
        pin_memory=get_conf(train_conf, group="experiment", key="pin_memory"),
        n_inference=get_conf(train_conf, group="experiment",
                             key="n_inference"),
        seed=get_conf(train_conf, group="experiment", key="seed"),
    )
    augmentation_config = AugmentationConfig.from_conf(conf_path)
    # Archive the configuration next to the experiment outputs.
    shutil.copy(str(conf_path),
                str(config.experiment_dir.joinpath("train.conf")))

    # Both architectures share the same constructor arguments; only the
    # network class differs based on the configured name.
    if get_conf(train_conf, group="network", key="name") == "encodedecode":
        network = ImgTemplateEncoderNet(
            feature_size=get_conf(train_conf,
                                  group="network",
                                  key="feature_size"),
            n_slices=get_conf(train_conf, group="network", key="n_slices"),
            n_filters=get_conf(train_conf, group="network", key="n_filters"),
            batch_norm=get_conf(train_conf, group="network", key="batch_norm"),
            group_norm=get_conf(train_conf, group="network", key="group_norm"),
            int_downsize=get_conf(train_conf,
                                  group="network",
                                  key="integrate_downsize"),
            bidir=True,
            in_channels=get_conf(train_conf,
                                 group="network",
                                 key="in_channels"),
            n_classes=get_conf(train_conf, group="network", key="n_classes"),
            batch_size=get_conf(train_conf,
                                group="experiment",
                                key="batch_size"),
            gpu=get_conf(train_conf, group="experiment", key="gpu"),
            device=get_conf(train_conf, group="experiment", key="device"),
        )
    elif get_conf(train_conf, group="network", key="name") == "istn":
        network = ISTNNet(
            feature_size=get_conf(train_conf,
                                  group="network",
                                  key="feature_size"),
            n_slices=get_conf(train_conf, group="network", key="n_slices"),
            n_filters=get_conf(train_conf, group="network", key="n_filters"),
            batch_norm=get_conf(train_conf, group="network", key="batch_norm"),
            group_norm=get_conf(train_conf, group="network", key="group_norm"),
            int_downsize=get_conf(train_conf,
                                  group="network",
                                  key="integrate_downsize"),
            bidir=True,
            in_channels=get_conf(train_conf,
                                 group="network",
                                 key="in_channels"),
            n_classes=get_conf(train_conf, group="network", key="n_classes"),
            batch_size=get_conf(train_conf,
                                group="experiment",
                                key="batch_size"),
            gpu=get_conf(train_conf, group="experiment", key="gpu"),
            device=get_conf(train_conf, group="experiment", key="device"),
        )
    else:
        raise ValueError("network name not supported. istn or encodedecode")

    if args.starting_epoch > 0:
        # Resume: restore weights from the checkpoint written at the end
        # of the previous epoch (CP_{epoch-1}.pth).
        print(
            exp_dir.joinpath("checkpoints",
                             "CP_{}.pth".format(args.starting_epoch - 1)))
        checkpoint = torch.load(
            str(
                exp_dir.joinpath("checkpoints",
                                 "CP_{}.pth".format(args.starting_epoch - 1))),
            map_location=torch.device(
                get_conf(train_conf, group="experiment", key="device")),
        )
        network.load_state_dict(checkpoint)
        network.cuda(
            device=get_conf(train_conf, group="experiment", key="device"))
        # Reload the evolving atlas (image + label) saved by the previous
        # epoch. feature_size/n_slices are None so the files are read
        # as-is, without resizing.
        image = DefSegDataset.read_image(
            image_path=exp_dir.joinpath(
                "atlas", "epoch_{}".format(args.starting_epoch - 1),
                "image.nii.gz"),
            feature_size=None,
            n_slices=None,
        )
        label = DefSegDataset.read_label(
            label_path=exp_dir.joinpath(
                "atlas", "epoch_{}".format(args.starting_epoch - 1),
                "label.nii.gz"),
            feature_size=None,
            n_slices=None,
        )
        atlas = Atlas(image=image, label=label)
    else:
        # Fresh run: the experiment starts without a prebuilt atlas.
        atlas = None

    training_sets, validation_sets, extra_validation_sets = construct_training_validation_dataset(
        DataConfig.from_conf(conf_path),
        feature_size=get_conf(train_conf, group="network", key="feature_size"),
        n_slices=get_conf(train_conf, group="network", key="n_slices"),
        is_3d=True,
        seed=config.seed,
        augmentation_config=augmentation_config,
        output_dir=config.experiment_dir,
    )
    # Record which subjects ended up in each split.
    for train in training_sets:
        train.export(
            config.experiment_dir.joinpath("training_set_{}.csv".format(
                train.name)))
    for val in validation_sets:
        val.export(
            config.experiment_dir.joinpath("validation_set_{}.csv".format(
                val.name)))
    for val in extra_validation_sets:
        val.export(
            config.experiment_dir.joinpath(
                "extra_validation_set_{}.csv".format(val.name)))
    if get_conf(train_conf, group="optimizer", key="type") == "SGD":
        optimizer = torch.optim.SGD(
            network.parameters(),
            lr=get_conf(train_conf, group="optimizer", key="learning_rate"),
            momentum=get_conf(train_conf, group="optimizer", key="momentum"),
        )
    else:
        # Default optimizer when the config does not request SGD.
        optimizer = torch.optim.Adam(network.parameters(),
                                     lr=get_conf(train_conf,
                                                 group="optimizer",
                                                 key="learning_rate"))

    loss = DefLoss(
        penalty="l2",
        loss_mult=get_conf(train_conf,
                           group="network",
                           key="integrate_downsize"),
        weights=get_conf(train_conf, group="loss", key="weights"),
    )

    experiment = DefSegExperiment(
        config=config,
        network=network,
        training_sets=training_sets,
        validation_sets=validation_sets,
        extra_validation_sets=extra_validation_sets,
        optimizer=optimizer,
        loss=loss,
        other_validation_metrics=[
            DefAffineWarpedTemplateDice(),
            DefWarpedTemplateDice(),
            DefPredDice(),
        ],
        inference_func=inference,
    )
    experiment.train(
        starting_epoch=args.starting_epoch,
        atlas=atlas,
        atlas_eta=get_conf(train_conf, group="network", key="atlas_eta"),
    )
Example No. 5
0
def main():
    """Train a 3D UNet segmentation experiment from a HOCON config."""
    args = parse_args()
    conf_path = TRAIN_CONF_PATH if args.conf_path is None else Path(args.conf_path)
    train_conf = ConfigFactory.parse_file(str(conf_path))

    def conf(group, key):
        # Shortcut for reading a single value from the parsed config.
        return get_conf(train_conf, group=group, key=key)

    exp_dir_value = conf("experiment", "experiment_dir")
    experiment_dir = Path(exp_dir_value) if exp_dir_value is not None else None
    config = ExperimentConfig(
        experiment_dir=experiment_dir,
        batch_size=conf("experiment", "batch_size"),
        num_epochs=conf("experiment", "num_epochs"),
        gpu=conf("experiment", "gpu"),
        device=conf("experiment", "device"),
        num_workers=conf("experiment", "num_workers"),
        pin_memory=conf("experiment", "pin_memory"),
        n_inference=conf("experiment", "n_inference"),
        seed=conf("experiment", "seed"),
    )
    augmentation_config = AugmentationConfig.from_conf(conf_path)
    # Keep a copy of the configuration next to the experiment outputs.
    shutil.copy(str(conf_path), str(config.experiment_dir.joinpath("train.conf")))

    network = UNet(
        in_channels=conf("network", "in_channels"),
        n_classes=conf("network", "n_classes"),
        n_filters=conf("network", "n_filters"),
    )
    training_sets, validation_sets, extra_validation_sets = construct_training_validation_dataset(
        DataConfig.from_conf(conf_path),
        feature_size=conf("network", "feature_size"),
        n_slices=conf("network", "n_slices"),
        is_3d=True,
        seed=config.seed,
        augmentation_config=augmentation_config,
        output_dir=config.experiment_dir,
    )
    # Record which subjects ended up in each split.
    for datasets, pattern in (
        (training_sets, "training_set_{}.csv"),
        (validation_sets, "validation_set_{}.csv"),
        (extra_validation_sets, "extra_validation_set_{}.csv"),
    ):
        for dataset in datasets:
            dataset.export(config.experiment_dir.joinpath(pattern.format(dataset.name)))

    if conf("optimizer", "type") == "SGD":
        optimizer = torch.optim.SGD(
            network.parameters(),
            lr=conf("optimizer", "learning_rate"),
            momentum=conf("optimizer", "momentum"),
        )
    else:
        optimizer = torch.optim.Adam(
            network.parameters(), lr=conf("optimizer", "learning_rate"))

    if conf("loss", "type") == "FocalLoss":
        loss = FocalLoss(
            alpha=conf("loss", "alpha"),
            gamma=conf("loss", "gamma"),
            logits=True,
        )
    else:
        loss = BCELoss()

    experiment = FCN3DExperiment(
        config=config,
        network=network,
        training_sets=training_sets,
        validation_sets=validation_sets,
        extra_validation_sets=extra_validation_sets,
        optimizer=optimizer,
        loss=loss,
        other_validation_metrics=[DiceCoeffWithLogits()],
        inference_func=inference,
    )
    experiment.train()
Example No. 6
0
def main():
    """Run FCN 2D segmentation inference on one image and save NIfTI outputs.

    Restores an ``FCN2DSegmentationModel`` checkpoint, predicts a
    segmentation for the configured input image, collapses the per-class
    binary maps into a single label map and writes ``seg.nii.gz``,
    ``image.nii.gz`` and ``label.nii.gz`` to the output directory.

    Fixes: removed a dead ``get_conf(..., key="experiment_dir")`` call
    whose result was discarded, deleted commented-out code, and dropped a
    redundant second ``output_dir.mkdir`` (the first already uses
    ``exist_ok=True``).
    """
    args = parse_args()
    model_path = Path(args.model_path)
    if args.network_conf_path is not None:
        train_conf = ConfigFactory.parse_file(str(Path(
            args.network_conf_path)))
        conf_path = Path(args.network_conf_path)
    else:
        train_conf = ConfigFactory.parse_file(str(TRAIN_CONF_PATH))
        conf_path = TRAIN_CONF_PATH
    data_config = DataConfig.from_conf(conf_path)
    dataset_config = DatasetConfig.from_conf(
        name=data_config.dataset_names[0],
        mount_prefix=data_config.mount_prefix,
        mode=data_config.data_mode)
    input_path = Path(args.input_dir).joinpath(
        dataset_config.image_label_format.image.format(phase=args.phase))
    output_dir = Path(args.output_dir)
    checkpoint = torch.load(str(model_path),
                            map_location=torch.device(args.device))

    network = FCN2DSegmentationModel(
        in_channels=get_conf(train_conf, group="network", key="in_channels"),
        n_classes=get_conf(train_conf, group="network", key="n_classes"),
        n_filters=get_conf(train_conf, group="network", key="n_filters"),
        up_conv_filter=get_conf(train_conf,
                                group="network",
                                key="up_conv_filter"),
        final_conv_filter=get_conf(train_conf,
                                   group="network",
                                   key="final_conv_filter"),
        feature_size=get_conf(train_conf, group="network", key="feature_size"))
    network.load_state_dict(checkpoint)
    network.cuda(device=args.device)

    image = Torch2DSegmentationDataset.read_image(
        input_path,
        get_conf(train_conf, group="network", key="feature_size"),
        get_conf(train_conf, group="network", key="in_channels"),
        crop=args.crop_image,
    )
    # Add a batch dimension and run the forward pass on the device.
    image = np.expand_dims(image, 0)
    image = torch.from_numpy(image).float()
    image = prepare_tensors(image, gpu=True, device=args.device)
    predicted = network(image)
    predicted = torch.sigmoid(predicted)
    print("sigmoid", torch.mean(predicted).item(), torch.max(predicted).item())
    # Binarise the per-class probability maps at 0.5.
    predicted = (predicted > 0.5).float()
    print("0.5", torch.mean(predicted).item(), torch.max(predicted).item())
    predicted = predicted.cpu().detach().numpy()

    nim = nib.load(str(input_path))
    # Drop the batch axis before converting to a label map.
    predicted = np.squeeze(predicted, axis=0)
    print(predicted.shape)

    # Collapse the per-class binary maps into one label map:
    # voxels of class i get label value i + 1 (0 stays background).
    final_predicted = np.zeros(
        (image.shape[1], image.shape[2], image.shape[3]))
    print(predicted.shape, final_predicted.shape)

    for i in range(predicted.shape[0]):
        a = predicted[i, :, :, :] > 0.5
        print(a.shape)
        final_predicted[predicted[i, :, :, :] > 0.5] = i + 1
    # Move the first axis last — presumably to restore the on-disk NIfTI
    # axis order; TODO(review) confirm against the dataset convention.
    final_predicted = np.transpose(final_predicted, [1, 2, 0])
    print(predicted.shape, final_predicted.shape)

    print(predicted.shape, final_predicted.shape, np.max(final_predicted),
          np.mean(final_predicted), np.min(final_predicted))

    # Save the segmentation with the input image's affine and pixel sizes.
    nim2 = nib.Nifti1Image(final_predicted, nim.affine)
    nim2.header['pixdim'] = nim.header['pixdim']
    output_dir.mkdir(parents=True, exist_ok=True)
    nib.save(nim2, '{0}/seg.nii.gz'.format(str(output_dir)))

    # Save the (preprocessed) input image alongside the segmentation.
    final_image = image.cpu().detach().numpy()
    final_image = np.squeeze(final_image, 0)
    final_image = np.transpose(final_image, [1, 2, 0])
    print(final_image.shape)
    nim2 = nib.Nifti1Image(final_image, nim.affine)
    nim2.header['pixdim'] = nim.header['pixdim']
    nib.save(nim2, '{0}/image.nii.gz'.format(str(output_dir)))

    # Convert the ground-truth one-hot label the same way and save it.
    label = Torch2DSegmentationDataset.read_label(
        input_path.parent.joinpath(
            dataset_config.image_label_format.label.format(phase=args.phase)),
        feature_size=get_conf(train_conf, group="network", key="feature_size"),
        n_slices=get_conf(train_conf, group="network", key="in_channels"),
        crop=args.crop_image,
    )
    final_label = np.zeros((image.shape[1], image.shape[2], image.shape[3]))
    for i in range(label.shape[0]):
        final_label[label[i, :, :, :] == 1.0] = i + 1

    final_label = np.transpose(final_label, [1, 2, 0])
    print(final_label.shape)
    nim2 = nib.Nifti1Image(final_label, nim.affine)
    nim2.header['pixdim'] = nim.header['pixdim']
    nib.save(nim2, '{0}/label.nii.gz'.format(str(output_dir)))
Example No. 7
0
def main():
    """Train a DefSegNet against a fixed 3D atlas label template.

    Fixes: the atlas path was a non-raw string relying on invalid escape
    sequences (``\\s``, ``\\c``, ``\\D``, ...) passing through unchanged —
    a DeprecationWarning today and a SyntaxError in future Python. It is
    now a raw-string literal with the identical runtime value.
    """
    args = parse_args()
    if args.conf_path is None:
        train_conf = ConfigFactory.parse_file(str(TRAIN_CONF_PATH))
        conf_path = TRAIN_CONF_PATH
    else:
        train_conf = ConfigFactory.parse_file(str(Path(args.conf_path)))
        conf_path = Path(args.conf_path)

    if get_conf(train_conf, group="experiment",
                key="experiment_dir") is not None:
        experiment_dir = Path(
            get_conf(train_conf, group="experiment", key="experiment_dir"))
    else:
        experiment_dir = None
    config = ExperimentConfig(
        experiment_dir=experiment_dir,
        batch_size=get_conf(train_conf, group="experiment", key="batch_size"),
        num_epochs=get_conf(train_conf, group="experiment", key="num_epochs"),
        gpu=get_conf(train_conf, group="experiment", key="gpu"),
        device=get_conf(train_conf, group="experiment", key="device"),
        num_workers=get_conf(train_conf, group="experiment",
                             key="num_workers"),
        pin_memory=get_conf(train_conf, group="experiment", key="pin_memory"),
        n_inference=get_conf(train_conf, group="experiment",
                             key="n_inference"),
        seed=get_conf(train_conf, group="experiment", key="seed"),
    )
    augmentation_config = AugmentationConfig.from_conf(conf_path)
    # Archive the configuration next to the experiment outputs.
    shutil.copy(str(conf_path),
                str(config.experiment_dir.joinpath("train.conf")))
    # TODO(review): hard-coded machine-local atlas path — move into the
    # training config so other machines can run this entry point.
    atlas_path = (r"D:\surui\rbh\cardiac\DL_segmentation\RBH_3D_atlases"
                  r"\IHD586_10RZ04876_RBH_IHD_201410241040_MRC25598"
                  r"\seg_lvsa_SR_ED.nii.gz")
    atlas = Torch2DSegmentationDataset.read_label(
        label_path=Path(atlas_path),
        feature_size=get_conf(train_conf, group="network", key="feature_size"),
        n_slices=get_conf(train_conf, group="network", key="n_slices"),
    )
    network = DefSegNet(
        in_channels=get_conf(train_conf, group="network", key="in_channels"),
        n_classes=get_conf(train_conf, group="network", key="n_classes"),
        n_filters=get_conf(train_conf, group="network", key="n_filters"),
        feature_size=get_conf(train_conf, group="network", key="feature_size"),
        n_slices=get_conf(train_conf, group="network", key="n_slices"),
        template=atlas,
        int_downsize=2,
    )

    training_sets, validation_sets, extra_validation_sets = construct_training_validation_dataset(
        DataConfig.from_conf(conf_path),
        feature_size=get_conf(train_conf, group="network", key="feature_size"),
        n_slices=get_conf(train_conf, group="network", key="n_slices"),
        is_3d=True,
        seed=config.seed,
        augmentation_config=augmentation_config,
        output_dir=config.experiment_dir,
    )
    # Record which subjects ended up in each split.
    for train in training_sets:
        train.export(
            config.experiment_dir.joinpath("training_set_{}.csv".format(
                train.name)))
    for val in validation_sets:
        val.export(
            config.experiment_dir.joinpath("validation_set_{}.csv".format(
                val.name)))
    for val in extra_validation_sets:
        val.export(
            config.experiment_dir.joinpath(
                "extra_validation_set_{}.csv".format(val.name)))
    if get_conf(train_conf, group="optimizer", key="type") == "SGD":
        optimizer = torch.optim.SGD(
            network.parameters(),
            lr=get_conf(train_conf, group="optimizer", key="learning_rate"),
            momentum=get_conf(train_conf, group="optimizer", key="momentum"),
        )
    else:
        optimizer = torch.optim.Adam(network.parameters(),
                                     lr=get_conf(train_conf,
                                                 group="optimizer",
                                                 key="learning_rate"))
    if get_conf(train_conf, group="loss", key="type") == "FocalLoss":
        loss = FocalLoss(
            alpha=get_conf(train_conf, group="loss", key="alpha"),
            gamma=get_conf(train_conf, group="loss", key="gamma"),
            logits=True,
        )
    else:
        loss = DefSegLoss(penalty="l2", loss_mult=2, template=atlas)
    experiment = Experiment(config=config,
                            network=network,
                            training_sets=training_sets,
                            validation_sets=validation_sets,
                            extra_validation_sets=extra_validation_sets,
                            optimizer=optimizer,
                            loss=loss,
                            other_validation_metrics=[
                                DefSegWarpedDice(template=atlas),
                                DefSegPredDice()
                            ],
                            inference_func=inference)
    experiment.train()