Example #1
def run_test(batch_size, img_name, seg_name, output_dir, device="cuda:0"):
    ds = NiftiDataset([img_name], [seg_name], transform=AddChannel(), seg_transform=AddChannel(), image_only=False)
    loader = DataLoader(ds, batch_size=1, pin_memory=torch.cuda.is_available())

    net = UNet(
        dimensions=3, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2
    ).to(device)
    roi_size = (16, 32, 48)
    sw_batch_size = batch_size

    def _sliding_window_processor(_engine, batch):
        net.eval()
        img, seg, meta_data = batch
        with torch.no_grad():
            seg_probs = sliding_window_inference(img.to(device), roi_size, sw_batch_size, net, device=device)
            return predict_segmentation(seg_probs)

    infer_engine = Engine(_sliding_window_processor)

    SegmentationSaver(
        output_dir=output_dir, output_ext=".nii.gz", output_postfix="seg", batch_transform=lambda x: x[2]
    ).attach(infer_engine)

    infer_engine.run(loader)

    basename = os.path.basename(img_name)[: -len(".nii.gz")]
    saved_name = os.path.join(output_dir, basename, f"{basename}_seg.nii.gz")
    return saved_name
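# A minimal usage sketch for run_test above; the input paths are hypothetical
# placeholders, and sw_batch_size is taken from the batch_size argument.
if __name__ == "__main__":
    saved = run_test(
        batch_size=4,
        img_name="/data/subject01.nii.gz",      # hypothetical image path
        seg_name="/data/subject01_seg.nii.gz",  # hypothetical segmentation path
        output_dir="./out",
    )
    print("segmentation saved to", saved)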
Example #2
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    images = [
        '/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz',
        '/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz',
        '/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz',
        '/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz',
        '/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz',
        '/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz',
        '/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz',
        '/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz',
        '/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz',
        '/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz'
    ]
    # binary labels for 2-class gender classification: man and woman
    labels = np.array([
        0, 0, 1, 0, 1, 0, 1, 0, 1, 0
    ])

    # Define transforms for image
    val_transforms = Compose([
        ScaleIntensity(),
        AddChannel(),
        Resize((96, 96, 96)),
        ToTensor()
    ])

    # Define nifti dataset
    val_ds = NiftiDataset(image_files=images, labels=labels, transform=val_transforms, image_only=False)
    # create a validation data loader
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())

    # Create DenseNet121
    device = torch.device('cuda:0')
    model = monai.networks.nets.densenet.densenet121(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
    ).to(device)

    model.load_state_dict(torch.load('best_metric_model.pth'))
    model.eval()
    with torch.no_grad():
        num_correct = 0.
        metric_count = 0
        saver = CSVSaver(output_dir='./output')
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            val_outputs = model(val_images).argmax(dim=1)
            value = torch.eq(val_outputs, val_labels)
            metric_count += len(value)
            num_correct += value.sum().item()
            saver.save_batch(val_outputs, val_data[2])
        metric = num_correct / metric_count
        print('evaluation metric:', metric)
        saver.finalize()
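# CSVSaver writes one row per case to ./output/predictions.csv (the default
# filename, assuming this MONAI version); a quick sketch of inspecting it:
import pandas as pd

preds = pd.read_csv("./output/predictions.csv", header=None, names=["filename", "prediction"])
print(preds.head())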
Example #3
def main():
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    print('generating synthetic data to {} (this may take a while)'.format(tempdir))
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

    images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
    segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    val_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())

    device = torch.device('cuda:0')
    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model.load_state_dict(torch.load('best_metric_model.pth'))
    model.eval()
    with torch.no_grad():
        metric_sum = 0.
        metric_count = 0
        saver = NiftiSaver(output_dir='./output')
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            # define sliding window size and batch size for windows inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True,
                                     to_onehot_y=False, add_sigmoid=True)
            metric_count += len(value)
            metric_sum += value.sum().item()
            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
            saver.save_batch(val_outputs, val_data[2])
        metric = metric_sum / metric_count
        print('evaluation metric:', metric)
    shutil.rmtree(tempdir)
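# For reference, compute_meandice estimates the Dice coefficient
# 2*|X∩Y| / (|X|+|Y|); a minimal standalone sketch for binary masks:
import torch

def dice_score(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> float:
    """Dice coefficient between two binary {0, 1} tensors of equal shape."""
    pred, target = pred.float(), target.float()
    intersection = (pred * target).sum()
    return (2.0 * intersection / (pred.sum() + target.sum() + eps)).item()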
Example #4
def main(tempdir):
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    val_ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)
    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model.load_state_dict(torch.load("best_metric_model_segmentation3d_array.pth"))
    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = NiftiSaver(output_dir="./output")
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            # define sliding window size and batch size for windows inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            value = dice_metric(y_pred=val_outputs, y=val_labels)
            metric_count += len(value)
            metric_sum += value.item() * len(value)
            val_outputs = (val_outputs.sigmoid() >= 0.5).float()
            saver.save_batch(val_outputs, val_data[2])
        metric = metric_sum / metric_count
        print("evaluation metric:", metric)
Example #5
train_transforms = Compose([
    ScaleIntensity(),
    AddChannel(),
    Resize((96, 96, 96)),
    RandRotate90(),
    ToTensor()
])
val_transforms = Compose(
    [ScaleIntensity(),
     AddChannel(),
     Resize((96, 96, 96)),
     ToTensor()])

# Define nifti dataset, data loader
check_ds = NiftiDataset(image_files=images,
                        labels=labels,
                        transform=train_transforms)
check_loader = DataLoader(check_ds,
                          batch_size=2,
                          num_workers=2,
                          pin_memory=torch.cuda.is_available())
im, label = monai.utils.misc.first(check_loader)
print(type(im), im.shape, label)

# create a training data loader
train_ds = NiftiDataset(image_files=images[:10],
                        labels=labels[:10],
                        transform=train_transforms)
train_loader = DataLoader(train_ds,
                          batch_size=2,
                          shuffle=True,
                          num_workers=2,
                          pin_memory=torch.cuda.is_available())
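# monai.utils.misc.first simply returns the first element of an iterable, so
# the sanity check above is equivalent to:
im, label = next(iter(check_loader))
print(type(im), im.shape, label)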
Example #6
def main():
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    print('generating synthetic data to {} (this may take a while)'.format(tempdir))
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

    images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
    segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans, image_only=False)

    device = torch.device('cuda:0')
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    net.to(device)

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4


    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch[0].to(device), batch[1].to(device)
            seg_probs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
            return seg_probs, val_labels


    evaluator = Engine(_sliding_window_processor)

    # add evaluation metric to the evaluator engine
    MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, 'Mean_Dice')

    # StatsHandler prints loss at every iteration and metrics at every epoch;
    # we don't need to print loss for the evaluator, so just print metrics (print functions can also be customized)
    val_stats_handler = StatsHandler(
        name='evaluator',
        output_transform=lambda x: None  # no need to print loss value, so disable per iteration output
    )
    val_stats_handler.attach(evaluator)

    # for the array data format, assume the 3rd item of batch data is the meta_data
    file_saver = SegmentationSaver(
        output_dir='tempdir', output_ext='.nii.gz', output_postfix='seg', name='evaluator',
        batch_transform=lambda x: x[2], output_transform=lambda output: predict_segmentation(output[0]))
    file_saver.attach(evaluator)

    # the model was trained by "unet_training_array" example
    ckpt_loader = CheckpointLoader(load_path='./runs/net_checkpoint_50.pth', load_dict={'net': net})
    ckpt_loader.attach(evaluator)

    # sliding window inference for one image at every iteration
    loader = DataLoader(ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    state = evaluator.run(loader)
    shutil.rmtree(tempdir)
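# The ignite Engine wraps a process function that is called once per batch;
# its return value lands in engine.state.output, which the attached handlers
# consume. A tiny self-contained sketch of that contract:
from ignite.engine import Engine

toy = Engine(lambda engine, batch: batch * 2)  # process_function(engine, batch)
toy_state = toy.run([1, 2, 3], max_epochs=1)
print(toy_state.output)  # output for the last processed batch: 6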
Example #7
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image/mask pairs
    tempdir = tempfile.mkdtemp()
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
    train_imtrans = Compose(
        [
            ScaleIntensity(),
            AddChannel(),
            RandSpatialCrop((96, 96, 96), random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 2)),
            ToTensor(),
        ]
    )
    train_segtrans = Compose(
        [
            AddChannel(),
            RandSpatialCrop((96, 96, 96), random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 2)),
            ToTensor(),
        ]
    )
    val_imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    val_segtrans = Compose([AddChannel(), ToTensor()])

    # define nifti dataset, data loader
    check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans)
    check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
    im, seg = monai.utils.misc.first(check_loader)
    print(im.shape, seg.shape)

    # create a training data loader
    train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)
    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available())
    # create a validation data loader
    val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available())

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda:0")
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(5):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{5}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                    roi_size = (96, 96, 96)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    value = compute_meandice(
                        y_pred=val_outputs, y=val_labels, include_background=True, to_onehot_y=False, add_sigmoid=True
                    )
                    metric_count += len(value)
                    metric_sum += value.sum().item()
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
    shutil.rmtree(tempdir)
    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
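# The best checkpoint saved above can later be restored for inference; a
# minimal sketch, assuming the same UNet configuration as in main():
import torch
import monai

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = monai.networks.nets.UNet(
    dimensions=3, in_channels=1, out_channels=1,
    channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2,
).to(device)
model.load_state_dict(torch.load("best_metric_model.pth", map_location=device))
model.eval()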
Example #8
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    images = [
        "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz",
    ]
    # binary labels for 2-class gender classification: man and woman
    labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0])

    # define transforms
    train_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), RandRotate90(), ToTensor()])
    val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), ToTensor()])

    # define nifti dataset, data loader
    check_ds = NiftiDataset(image_files=images, labels=labels, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())
    im, label = monai.utils.misc.first(check_loader)
    print(type(im), im.shape, label)

    # create DenseNet121, CrossEntropyLoss and Adam optimizer
    net = monai.networks.nets.densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=2,)
    loss = torch.nn.CrossEntropyLoss()
    lr = 1e-5
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.device("cuda:0")

    # ignite trainer expects batch=(img, label) and returns output=loss at every iteration,
    # user can add output_transform to return other values, like: y_pred, y, etc.
    trainer = create_supervised_trainer(net, opt, loss, device, False)

    # adding checkpoint handler to save models (network params and optimizer stats) during training
    checkpoint_handler = ModelCheckpoint("./runs/", "net", n_saved=10, require_empty=False)
    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={"net": net, "opt": opt}
    )

    # StatsHandler prints loss at every iteration and metrics at every epoch;
    # we don't set metrics for the trainer here, so it just prints loss; print functions can be customized,
    # and output_transform can convert engine.state.output if it's not a loss value
    train_stats_handler = StatsHandler(name="trainer")
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    train_tensorboard_stats_handler = TensorBoardStatsHandler()
    train_tensorboard_stats_handler.attach(trainer)

    # set parameters for validation
    validation_every_n_epochs = 1

    metric_name = "Accuracy"
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: Accuracy()}
    # ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(net, val_metrics, device, True)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,
    )  # fetch global epoch number from trainer
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every epoch
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,
    )  # fetch global epoch number from trainer
    val_tensorboard_stats_handler.attach(evaluator)

    # add early stopping handler to evaluator
    early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

    # create a validation data loader
    val_ds = NiftiDataset(image_files=images[-10:], labels=labels[-10:], transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())

    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        evaluator.run(val_loader)

    # create a training data loader
    train_ds = NiftiDataset(image_files=images[:10], labels=labels[:10], transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())

    train_epochs = 30
    state = trainer.run(train_loader, train_epochs)
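# stopping_fn_from_metric builds the score function that EarlyStopping calls
# after each evaluator run (higher score = better). A hedged sketch of an
# equivalent, assuming it reads the named metric from engine.state.metrics:
def stopping_fn_sketch(metric_name):
    def stopping_fn(engine):
        return engine.state.metrics[metric_name]
    return stopping_fn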
Example #9
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    images = [
        "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz",
    ]
    # binary labels for 2-class gender classification: man and woman
    labels = np.array([0, 0, 1, 0, 1, 0, 1, 0, 1, 0])

    # define transforms for image
    val_transforms = Compose(
        [ScaleIntensity(),
         AddChannel(),
         Resize((96, 96, 96)),
         ToTensor()])
    # define nifti dataset
    val_ds = NiftiDataset(image_files=images,
                          labels=labels,
                          transform=val_transforms,
                          image_only=False)
    # create DenseNet121
    net = monai.networks.nets.densenet.densenet121(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
    )
    device = torch.device("cuda:0")

    metric_name = "Accuracy"
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: Accuracy()}

    def prepare_batch(batch, device=None, non_blocking=False):
        return _prepare_batch((batch[0], batch[1]), device, non_blocking)

    # ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(net,
                                            val_metrics,
                                            device,
                                            True,
                                            prepare_batch=prepare_batch)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x:
        None,  # no need to print loss value, so disable per iteration output
    )
    val_stats_handler.attach(evaluator)

    # for the array data format, assume the 3rd item of batch data is the meta_data
    prediction_saver = ClassificationSaver(
        output_dir="tempdir",
        batch_transform=lambda batch: batch[2],
        output_transform=lambda output: output[0].argmax(1),
    )
    prediction_saver.attach(evaluator)

    # the model was trained by "densenet_training_array" example
    CheckpointLoader(load_path="./runs/net_checkpoint_20.pth",
                     load_dict={
                         "net": net
                     }).attach(evaluator)

    # create a validation data loader
    val_loader = DataLoader(val_ds,
                            batch_size=2,
                            num_workers=4,
                            pin_memory=torch.cuda.is_available())

    state = evaluator.run(val_loader)
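# prepare_batch above deliberately drops the meta_data element so the
# supervised evaluator only receives (image, label); ignite's _prepare_batch
# moves both tensors to the target device. Roughly equivalent by hand:
def prepare_batch_sketch(batch, device=None, non_blocking=False):
    img, label = batch[0], batch[1]  # batch[2] (meta_data) is intentionally dropped
    return img.to(device, non_blocking=non_blocking), label.to(device, non_blocking=non_blocking)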
Example #10
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    data_dir = '/home/marafath/scratch/eu_data'
    labels = np.load('eu_labels.npy')
    train_images = []
    train_labels = []

    val_images = []
    val_labels = []

    n_count = 0
    p_count = 0
    idx = 0
    for case in os.listdir(data_dir):
        if p_count < 13 and labels[idx] == 1:
            val_images.append(
                os.path.join(data_dir, case, 'image_masked.nii.gz'))
            val_labels.append(labels[idx])
            p_count += 1
            idx += 1
        elif n_count < 11 and labels[idx] == 0:
            val_images.append(
                os.path.join(data_dir, case, 'image_masked.nii.gz'))
            val_labels.append(labels[idx])
            n_count += 1
            idx += 1
        else:
            train_images.append(
                os.path.join(data_dir, case, 'image_masked.nii.gz'))
            train_labels.append(labels[idx])
            idx += 1

    # Define transforms
    train_transforms = Compose([
        ScaleIntensity(),
        AddChannel(),
        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
        SpatialPad((256, 256, 92), mode='constant'),
        Resize((256, 256, 92)),
        ToTensor()
    ])

    val_transforms = Compose([
        ScaleIntensity(),
        AddChannel(),
        SpatialPad((256, 256, 92), mode='constant'),
        Resize((256, 256, 92)),
        ToTensor()
    ])

    # create a training data loader
    train_ds = NiftiDataset(image_files=train_images,
                            labels=train_labels,
                            transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=2,
                              pin_memory=torch.cuda.is_available())

    # create a validation data loader
    val_ds = NiftiDataset(image_files=val_images,
                          labels=val_labels,
                          transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=2,
                            num_workers=2,
                            pin_memory=torch.cuda.is_available())

    # Create DenseNet121, CrossEntropyLoss and Adam optimizer
    device = torch.device('cuda:0')
    model = monai.networks.nets.densenet.densenet121(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
    ).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # finetuning
    #model.load_state_dict(torch.load('best_metric_model_d121.pth'))

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    epc = 100  # number of epochs
    for epoch in range(epc):
        print('-' * 10)
        print('epoch {}/{}'.format(epoch + 1, epc))
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(
                device=device, dtype=torch.int64)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print('{}/{}, train_loss: {:.4f}'.format(step, epoch_len,
                                                     loss.item()))
            writer.add_scalar('train_loss', loss.item(),
                              epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print('epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss))

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                num_correct = 0.
                metric_count = 0
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(
                        device), val_data[1].to(device)
                    val_outputs = model(val_images)
                    value = torch.eq(val_outputs.argmax(dim=1), val_labels)
                    metric_count += len(value)
                    num_correct += value.sum().item()
                metric = num_correct / metric_count
                metric_values.append(metric)
                #torch.save(model.state_dict(), 'model_d121_epoch_{}.pth'.format(epoch + 1))
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(
                        model.state_dict(),
                        '/home/marafath/scratch/saved_models/best_metric_model_d121.pth'
                    )
                    print('saved new best metric model')
                print(
                    'current epoch: {} current accuracy: {:.4f} best accuracy: {:.4f} at epoch {}'
                    .format(epoch + 1, metric, best_metric, best_metric_epoch))
                writer.add_scalar('val_accuracy', metric, epoch + 1)
    print('train completed, best_metric: {:.4f} at epoch: {}'.format(
        best_metric, best_metric_epoch))
    writer.close()
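# Note: the split loop above assumes os.listdir(data_dir) yields cases in the
# same order as the rows of eu_labels.npy. os.listdir order is arbitrary, so a
# deterministic variant sorts the listing first:
cases = sorted(os.listdir(data_dir))  # stable, alphabetical case order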
Example #11
    n = nib.Nifti1Image(im, np.eye(4))
    nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

    n = nib.Nifti1Image(seg, np.eye(4))
    nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

# define transforms for image and segmentation
imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
segtrans = Compose([AddChannel(), ToTensor()])
ds = NiftiDataset(images,
                  segs,
                  transform=imtrans,
                  seg_transform=segtrans,
                  image_only=False)

device = torch.device('cuda:0')
net = UNet(
    dimensions=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
net.to(device)

# define sliding window size and batch size for windows inference
Example #12
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    images = [
        os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", filename])
        for filename in [
            "IXI314-IOP-0889-T1.nii.gz",
            "IXI249-Guys-1072-T1.nii.gz",
            "IXI609-HH-2600-T1.nii.gz",
            "IXI173-HH-1590-T1.nii.gz",
            "IXI020-Guys-0700-T1.nii.gz",
            "IXI342-Guys-0909-T1.nii.gz",
            "IXI134-Guys-0780-T1.nii.gz",
            "IXI577-HH-2661-T1.nii.gz",
            "IXI066-Guys-0731-T1.nii.gz",
            "IXI130-HH-1528-T1.nii.gz",
            "IXI607-Guys-1097-T1.nii.gz",
            "IXI175-HH-1570-T1.nii.gz",
            "IXI385-HH-2078-T1.nii.gz",
            "IXI344-Guys-0905-T1.nii.gz",
            "IXI409-Guys-0960-T1.nii.gz",
            "IXI584-Guys-1129-T1.nii.gz",
            "IXI253-HH-1694-T1.nii.gz",
            "IXI092-HH-1436-T1.nii.gz",
            "IXI574-IOP-1156-T1.nii.gz",
            "IXI585-Guys-1130-T1.nii.gz",
        ]
    ]

    # binary labels for 2-class gender classification: man and woman
    labels = np.array(
        [0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0],
        dtype=np.int64)

    # Define transforms
    train_transforms = Compose([
        ScaleIntensity(),
        AddChannel(),
        Resize((96, 96, 96)),
        RandRotate90(),
        ToTensor()
    ])
    val_transforms = Compose(
        [ScaleIntensity(),
         AddChannel(),
         Resize((96, 96, 96)),
         ToTensor()])

    # Define nifti dataset, data loader
    check_ds = NiftiDataset(image_files=images,
                            labels=labels,
                            transform=train_transforms)
    check_loader = DataLoader(check_ds,
                              batch_size=2,
                              num_workers=2,
                              pin_memory=torch.cuda.is_available())
    im, label = monai.utils.misc.first(check_loader)
    print(type(im), im.shape, label)

    # create a training data loader
    train_ds = NiftiDataset(image_files=images[:10],
                            labels=labels[:10],
                            transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=2,
                              pin_memory=torch.cuda.is_available())

    # create a validation data loader
    val_ds = NiftiDataset(image_files=images[-10:],
                          labels=labels[-10:],
                          transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=2,
                            num_workers=2,
                            pin_memory=torch.cuda.is_available())

    # Create DenseNet121, CrossEntropyLoss and Adam optimizer
    device = torch.device("cuda:0")
    model = monai.networks.nets.densenet.densenet121(spatial_dims=3,
                                                     in_channels=1,
                                                     out_channels=2).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-5)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(5):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{5}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(),
                              epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                num_correct = 0.0
                metric_count = 0
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(
                        device), val_data[1].to(device)
                    val_outputs = model(val_images)
                    value = torch.eq(val_outputs.argmax(dim=1), val_labels)
                    metric_count += len(value)
                    num_correct += value.sum().item()
                metric = num_correct / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current accuracy: {:.4f} best accuracy: {:.4f} at epoch {}"
                    .format(epoch + 1, metric, best_metric, best_metric_epoch))
                writer.add_scalar("val_accuracy", metric, epoch + 1)
    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    writer.close()
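# The os.sep.join calls above assemble relative paths by hand; os.path.join
# is the more idiomatic equivalent and yields the same strings for these
# purely relative components:
import os

path = os.path.join("workspace", "data", "medical", "ixi", "IXI-T1", "IXI314-IOP-0889-T1.nii.gz")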
Example #13
    RandSpatialCrop((96, 96, 96), random_size=False),
    RandRotate90(prob=0.5, spatial_axes=(0, 2)),
    ToTensor()
])
train_segtrans = Compose([
    AddChannel(),
    RandSpatialCrop((96, 96, 96), random_size=False),
    RandRotate90(prob=0.5, spatial_axes=(0, 2)),
    ToTensor()
])
val_imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
val_segtrans = Compose([AddChannel(), ToTensor()])

# define nifti dataset, data loader
check_ds = NiftiDataset(images,
                        segs,
                        transform=train_imtrans,
                        seg_transform=train_segtrans)
check_loader = DataLoader(check_ds,
                          batch_size=10,
                          num_workers=2,
                          pin_memory=torch.cuda.is_available())
im, seg = monai.utils.misc.first(check_loader)
print(im.shape, seg.shape)

# create a training data loader
train_ds = NiftiDataset(images[:20],
                        segs[:20],
                        transform=train_imtrans,
                        seg_transform=train_segtrans)
train_loader = DataLoader(train_ds,
                          batch_size=4,
Example #14
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image/mask pairs
    tempdir = tempfile.mkdtemp()
    print('generating synthetic data to {} (this may take a while)'.format(tempdir))
    for i in range(40):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

    images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
    segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

    # define transforms for image and segmentation
    train_imtrans = Compose([
        ScaleIntensity(),
        AddChannel(),
        RandSpatialCrop((96, 96, 96), random_size=False),
        ToTensor()
    ])
    train_segtrans = Compose([
        AddChannel(),
        RandSpatialCrop((96, 96, 96), random_size=False),
        ToTensor()
    ])
    val_imtrans = Compose([
        ScaleIntensity(),
        AddChannel(),
        Resize((96, 96, 96)),
        ToTensor()
    ])
    val_segtrans = Compose([
        AddChannel(),
        Resize((96, 96, 96)),
        ToTensor()
    ])

    # define nifti dataset, data loader
    check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans)
    check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
    im, seg = monai.utils.misc.first(check_loader)
    print(im.shape, seg.shape)

    # create a training data loader
    train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)
    train_loader = DataLoader(train_ds, batch_size=5, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available())
    # create a validation data loader
    val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
    val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())

    # create UNet, DiceLoss and Adam optimizer
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    loss = monai.losses.DiceLoss(do_sigmoid=True)
    lr = 1e-3
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.device('cuda:0')

    # ignite trainer expects batch=(img, seg) and returns output=loss at every iteration,
    # user can add output_transform to return other values, like: y_pred, y, etc.
    trainer = create_supervised_trainer(net, opt, loss, device, False)

    # adding checkpoint handler to save models (network params and optimizer stats) during training
    checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={'net': net, 'opt': opt})

    # StatsHandler prints loss at every iteration and metrics at every epoch;
    # we don't set metrics for the trainer here, so it just prints loss; print functions can be customized,
    # and output_transform can convert engine.state.output if it's not a loss value
    train_stats_handler = StatsHandler(name='trainer')
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    train_tensorboard_stats_handler = TensorBoardStatsHandler()
    train_tensorboard_stats_handler.attach(trainer)

    validation_every_n_epochs = 1
    # Set parameters for validation
    metric_name = 'Mean_Dice'
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)}

    # ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(net, val_metrics, device, True)


    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        evaluator.run(val_loader)


    # add early stopping handler to evaluator
    early_stopper = EarlyStopping(patience=4,
                                  score_function=stopping_fn_from_metric(metric_name),
                                  trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name='evaluator',
        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every validation epoch
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch)  # fetch global epoch number from trainer
    val_tensorboard_stats_handler.attach(evaluator)

    # add handler to draw the first image and the corresponding label and model output in the last batch
    # here we draw the 3D output as GIF format along Depth axis, at every validation epoch
    val_tensorboard_image_handler = TensorBoardImageHandler(
        batch_transform=lambda batch: (batch[0], batch[1]),
        output_transform=lambda output: predict_segmentation(output[0]),
        global_iter_transform=lambda x: trainer.state.epoch
    )
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler)

    train_epochs = 30
    state = trainer.run(train_loader, train_epochs)
    shutil.rmtree(tempdir)
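# predict_segmentation (from monai.networks.utils) converts raw model output
# into a discrete mask; for the single-channel sigmoid setup used here it
# amounts to thresholding logits at 0 (probability 0.5). A hedged sketch:
import torch

def predict_segmentation_sketch(logits: torch.Tensor) -> torch.Tensor:
    # single-channel binary case: logit >= 0  <=>  sigmoid(logit) >= 0.5
    return (logits >= 0.0).int()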
Example #15
    '/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz',
    '/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz',
    '/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz'
]
# binary labels for 2-class gender classification: man and woman
labels = np.array([0, 0, 1, 0, 1, 0, 1, 0, 1, 0])

# define transforms for image
val_transforms = Compose(
    [ScaleIntensity(),
     AddChannel(),
     Resize((96, 96, 96)),
     ToTensor()])
# define nifti dataset
val_ds = NiftiDataset(image_files=images,
                      labels=labels,
                      transform=val_transforms,
                      image_only=False)
# create DenseNet121
net = monai.networks.nets.densenet.densenet121(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
)
device = torch.device('cuda:0')

metric_name = 'Accuracy'
# add evaluation metric to the evaluator engine
val_metrics = {metric_name: Accuracy()}


def prepare_batch(batch, device=None, non_blocking=False):
    return _prepare_batch((batch[0], batch[1]), device, non_blocking)
Example #16
    def test_dataset(self):
        tempdir = tempfile.mkdtemp()
        full_names, ref_data = [], []
        for filename in FILENAMES:
            test_image = np.random.randint(0, 2, size=(4, 4, 4))
            ref_data.append(test_image)
            save_path = os.path.join(tempdir, filename)
            full_names.append(save_path)
            nib.save(nib.Nifti1Image(test_image, np.eye(4)), save_path)

        # default loading no meta
        dataset = NiftiDataset(full_names)
        for d, ref in zip(dataset, ref_data):
            np.testing.assert_allclose(d, ref, atol=1e-3)

        # loading no meta, with a non-default (float16) dtype
        dataset = NiftiDataset(full_names, dtype=np.float16)
        for d, _ in zip(dataset, ref_data):
            self.assertEqual(d.dtype, np.float16)

        # loading with meta, no transform
        dataset = NiftiDataset(full_names, image_only=False)
        for d_tuple, ref in zip(dataset, ref_data):
            d, meta = d_tuple
            np.testing.assert_allclose(d, ref, atol=1e-3)
            np.testing.assert_allclose(meta["original_affine"], np.eye(4))

        # loading image/label, no meta
        dataset = NiftiDataset(full_names, seg_files=full_names, image_only=True)
        for d_tuple, ref in zip(dataset, ref_data):
            img, seg = d_tuple
            np.testing.assert_allclose(img, ref, atol=1e-3)
            np.testing.assert_allclose(seg, ref, atol=1e-3)

        # loading image with a transform, no meta
        dataset = NiftiDataset(full_names, transform=lambda x: x + 1, image_only=True)
        for d, ref in zip(dataset, ref_data):
            np.testing.assert_allclose(d, ref + 1, atol=1e-3)

        # set seg transform, but no seg_files
        with self.assertRaises(TypeError):
            dataset = NiftiDataset(full_names, seg_transform=lambda x: x + 1, image_only=True)
            _ = dataset[0]

        # loading image/label, with meta
        dataset = NiftiDataset(
            full_names, transform=lambda x: x + 1, seg_files=full_names, seg_transform=lambda x: x + 2, image_only=False
        )
        for d_tuple, ref in zip(dataset, ref_data):
            img, seg, meta = d_tuple
            np.testing.assert_allclose(img, ref + 1, atol=1e-3)
            np.testing.assert_allclose(seg, ref + 2, atol=1e-3)
            np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3)

        # loading image/label, with meta
        dataset = NiftiDataset(
            full_names, transform=lambda x: x + 1, seg_files=full_names, labels=[1, 2, 3], image_only=False
        )
        for idx, (d_tuple, ref) in enumerate(zip(dataset, ref_data)):
            img, seg, label, meta = d_tuple
            np.testing.assert_allclose(img, ref + 1, atol=1e-3)
            np.testing.assert_allclose(seg, ref, atol=1e-3)
            np.testing.assert_allclose(idx + 1, label)
            np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3)

        # loading image/label, with sync. transform
        dataset = NiftiDataset(
            full_names, transform=RandTest(), seg_files=full_names, seg_transform=RandTest(), image_only=False
        )
        for d_tuple, ref in zip(dataset, ref_data):
            img, seg, meta = d_tuple
            np.testing.assert_allclose(img, seg, atol=1e-3)
            self.assertTrue(not np.allclose(img, ref))
            np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3)
        shutil.rmtree(tempdir)
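# RandTest is referenced above but not defined in this snippet; for the
# synchronized-transform assertions to hold, it needs to be a MONAI
# Randomizable whose output depends only on its seeded RNG, e.g.:
from monai.transforms import Randomizable

class RandTest(Randomizable):
    def randomize(self, data=None):
        self._a = self.R.uniform()

    def __call__(self, data):
        self.randomize()
        return data + self._a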
Example #17
def main(tempdir):
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    ds = NiftiDataset(images,
                      segs,
                      transform=imtrans,
                      seg_transform=segtrans,
                      image_only=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    post_trans = Compose(
        [Activations(sigmoid=True),
         AsDiscrete(threshold_values=True)])

    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch[0].to(device), batch[1].to(device)
            seg_probs = sliding_window_inference(val_images, roi_size,
                                                 sw_batch_size, net)
            seg_probs = post_trans(seg_probs)
            return seg_probs, val_labels

    evaluator = Engine(_sliding_window_processor)

    # add evaluation metric to the evaluator engine
    MeanDice().attach(evaluator, "Mean_Dice")

    # StatsHandler prints loss at every iteration and metrics at every epoch;
    # we don't need to print loss for the evaluator, so just print metrics (print functions can also be customized)
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x:
        None,  # no need to print loss value, so disable per iteration output
    )
    val_stats_handler.attach(evaluator)

    # for the array data format, assume the 3rd item of batch data is the meta_data
    file_saver = SegmentationSaver(
        output_dir="tempdir",
        output_ext=".nii.gz",
        output_postfix="seg",
        name="evaluator",
        batch_transform=lambda x: x[2],
        output_transform=lambda output: output[0],
    )
    file_saver.attach(evaluator)

    # the model was trained by "unet_training_array" example
    ckpt_loader = CheckpointLoader(
        load_path="./runs_array/net_checkpoint_100.pt", load_dict={"net": net})
    ckpt_loader.attach(evaluator)

    # sliding window inference for one image at every iteration
    loader = DataLoader(ds,
                        batch_size=1,
                        num_workers=1,
                        pin_memory=torch.cuda.is_available())
    state = evaluator.run(loader)
    print(state)
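# main(tempdir) expects a writable scratch directory; a typical entry point
# (a sketch, mirroring how these example scripts are usually invoked):
if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as tempdir:
        main(tempdir)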