Example #1
    def test_shape(self, img_transform, label_transform, indexes, expected_shape):
        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=(128, 128, 128)), np.eye(4))
        tempdir = tempfile.mkdtemp()
        test_image1 = os.path.join(tempdir, "test_image1.nii.gz")
        test_seg1 = os.path.join(tempdir, "test_seg1.nii.gz")
        test_image2 = os.path.join(tempdir, "test_image2.nii.gz")
        test_seg2 = os.path.join(tempdir, "test_seg2.nii.gz")
        nib.save(test_image, test_image1)
        nib.save(test_image, test_seg1)
        nib.save(test_image, test_image2)
        nib.save(test_image, test_seg2)
        test_images = [test_image1, test_image2]
        test_segs = [test_seg1, test_seg2]
        test_labels = [1, 1]
        dataset = ArrayDataset(test_images, img_transform, test_segs, label_transform, test_labels, None)
        self.assertEqual(len(dataset), 2)
        dataset.set_random_state(1234)
        data1 = dataset[0]
        data2 = dataset[1]

        self.assertTupleEqual(data1[indexes[0]].shape, expected_shape)
        self.assertTupleEqual(data1[indexes[1]].shape, expected_shape)
        np.testing.assert_allclose(data1[indexes[0]], data1[indexes[1]])
        self.assertTupleEqual(data2[indexes[0]].shape, expected_shape)
        self.assertTupleEqual(data2[indexes[1]].shape, expected_shape)
        np.testing.assert_allclose(data2[indexes[0]], data2[indexes[1]])

        dataset = ArrayDataset(test_images, img_transform, test_segs, label_transform, test_labels, None)
        dataset.set_random_state(1234)
        _ = dataset[0]
        data2_new = dataset[1]
        np.testing.assert_allclose(data2[indexes[0]], data2_new[indexes[0]], atol=1e-3)
        shutil.rmtree(tempdir)
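The assertions above work because `ArrayDataset` applies the same random seed to the image and segmentation transform chains for each item, so the two random crops land in the same place. A minimal sketch of that pairing with in-memory arrays (a sketch only, assuming a MONAI version that still provides `AddChannel`, as the tests above do):

# Minimal sketch: paired random cropping of image and segmentation arrays.
# Synthetic in-memory volumes; no files or nibabel needed.
import numpy as np
from monai.data import ArrayDataset
from monai.transforms import AddChannel, Compose, RandSpatialCrop

imgs = [np.random.rand(64, 64, 64).astype(np.float32) for _ in range(2)]
segs = [(np.random.rand(64, 64, 64) > 0.5).astype(np.float32) for _ in range(2)]

img_trans = Compose([AddChannel(), RandSpatialCrop((32, 32, 32), random_size=False)])
seg_trans = Compose([AddChannel(), RandSpatialCrop((32, 32, 32), random_size=False)])

ds = ArrayDataset(imgs, img_trans, segs, seg_trans)
ds.set_random_state(1234)                # one seed drives both transform chains per item
img_patch, seg_patch = ds[0]             # both crops come from the same spatial location
print(img_patch.shape, seg_patch.shape)  # (1, 32, 32, 32) each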
Example #2
    def test_dataloading_img_label(self, img_transform, expected_shape):
        test_image = nib.Nifti1Image(
            np.random.randint(0, 2, size=(128, 128, 128)), np.eye(4))
        with tempfile.TemporaryDirectory() as tempdir:
            test_image1 = os.path.join(tempdir, "test_image1.nii.gz")
            test_image2 = os.path.join(tempdir, "test_image2.nii.gz")
            test_label1 = os.path.join(tempdir, "test_label1.nii.gz")
            test_label2 = os.path.join(tempdir, "test_label2.nii.gz")
            nib.save(test_image, test_image1)
            nib.save(test_image, test_image2)
            nib.save(test_image, test_label1)
            nib.save(test_image, test_label2)
            test_images = [test_image1, test_image2]
            test_labels = [test_label1, test_label2]
            dataset = ArrayDataset(test_images, img_transform, test_labels,
                                   img_transform)
            self.assertEqual(len(dataset), 2)
            dataset.set_random_state(1234)
            loader = DataLoader(dataset, batch_size=10, num_workers=1)
            data = next(iter(loader))  # test batching
            np.testing.assert_allclose(data[0].shape,
                                       [2] + list(expected_shape))

            dataset.set_random_state(1234)
            new_data = next(iter(loader))  # test batching
            np.testing.assert_allclose(data[0], new_data[0], atol=1e-3)
Example #3
def main(tempdir):
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray((im * 255).astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray((seg * 255).astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))

    # define transforms for image and segmentation
    imtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
    segtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
    val_ds = ArrayDataset(images, imtrans, segs, segtrans)
    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    saver = SaveImage(output_dir="./output", output_ext=".png", output_postfix="seg")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model.load_state_dict(torch.load("best_metric_model_segmentation2d_array.pth"))
    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            # define sliding window size and batch size for windows inference
            roi_size = (96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            val_labels = decollate_batch(val_labels)
            # compute metric for current iteration
            dice_metric(y_pred=val_outputs, y=val_labels)
            for val_output in val_outputs:
                saver(val_output)
        # aggregate the final mean dice result
        print("evaluation metric:", dice_metric.aggregate().item())
        # reset the status
        dice_metric.reset()
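The key call in the loop above is `sliding_window_inference`, which tiles each input into `roi_size` patches, runs the network on `sw_batch_size` patches at a time, and stitches the patch predictions back to the full image size. A minimal, self-contained sketch of that call (an untrained `UNet` stands in for the loaded model; the sizes are illustrative):

# Minimal sketch of sliding-window inference on a random 2D input.
import torch
from monai.inferers import sliding_window_inference
from monai.networks.nets import UNet

net = UNet(spatial_dims=2, in_channels=1, out_channels=1,
           channels=(8, 16, 32), strides=(2, 2), num_res_units=1)
x = torch.rand(1, 1, 128, 128)  # batch of one single-channel 128x128 image
with torch.no_grad():
    y = sliding_window_inference(x, roi_size=(96, 96), sw_batch_size=4, predictor=net)
print(y.shape)  # torch.Size([1, 1, 128, 128]): patch predictions stitched to full size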
Example #4
def main(tempdir):
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray((im * 255).astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray(seg.astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))

    # define transforms for image and segmentation
    imtrans = Compose([LoadImage(image_only=True), ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([LoadImage(image_only=True), AddChannel(), ToTensor()])
    val_ds = ArrayDataset(images, imtrans, segs, segtrans)
    # sliding window inference for one image at every iteration
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=1, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    model.load_state_dict(torch.load("best_metric_model_segmentation2d_array.pth"))
    model.eval()
    with torch.no_grad():
        metric_sum = 0.0
        metric_count = 0
        saver = PNGSaver(output_dir="./output")
        for val_data in val_loader:
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            # define sliding window size and batch size for windows inference
            roi_size = (96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            value = dice_metric(y_pred=val_outputs, y=val_labels)
            metric_count += len(value)
            metric_sum += value.item() * len(value)
            val_outputs = val_outputs.sigmoid() >= 0.5
            saver.save_batch(val_outputs)
        metric = metric_sum / metric_count
        print("evaluation metric:", metric)
Example #5
    def test_default_none(self, img_transform, expected_shape):
        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=(128, 128, 128)), np.eye(4))
        with tempfile.TemporaryDirectory() as tempdir:
            test_image1 = os.path.join(tempdir, "test_image1.nii.gz")
            test_image2 = os.path.join(tempdir, "test_image2.nii.gz")
            nib.save(test_image, test_image1)
            nib.save(test_image, test_image2)
            test_images = [test_image1, test_image2]
            dataset = ArrayDataset(test_images, img_transform)
            self.assertEqual(len(dataset), 2)
            dataset.set_random_state(1234)
            data1 = dataset[0]
            data2 = dataset[1]
            self.assertTupleEqual(data1.shape, expected_shape)
            self.assertTupleEqual(data2.shape, expected_shape)

            dataset = ArrayDataset(test_images, img_transform)
            dataset.set_random_state(1234)
            _ = dataset[0]
            data2_new = dataset[1]
            np.testing.assert_allclose(data2, data2_new, atol=1e-3)
Example #6
    def test_dataloading(self, img_transform, expected_shape):
        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=(128, 128, 128)), np.eye(4))
        tempdir = tempfile.mkdtemp()
        test_image1 = os.path.join(tempdir, "test_image1.nii.gz")
        test_image2 = os.path.join(tempdir, "test_image2.nii.gz")
        nib.save(test_image, test_image1)
        nib.save(test_image, test_image2)
        test_images = [test_image1, test_image2]
        dataset = ArrayDataset(test_images, img_transform)
        self.assertEqual(len(dataset), 2)
        dataset.set_random_state(1234)
        loader = DataLoader(dataset, batch_size=10, num_workers=1)
        imgs = next(iter(loader))  # test batching
        np.testing.assert_allclose(imgs.shape, [2] + list(expected_shape))

        dataset.set_random_state(1234)
        new_imgs = next(iter(loader))  # test batching
        np.testing.assert_allclose(imgs, new_imgs, atol=1e-3)
        shutil.rmtree(tempdir)
Example #7
    def test_dataloading_img(self, img_transform, expected_shape):
        test_image = nib.Nifti1Image(
            np.random.randint(0, 2, size=(128, 128, 128)).astype(float),
            np.eye(4))
        with tempfile.TemporaryDirectory() as tempdir:
            test_image1 = os.path.join(tempdir, "test_image1.nii.gz")
            test_image2 = os.path.join(tempdir, "test_image2.nii.gz")
            nib.save(test_image, test_image1)
            nib.save(test_image, test_image2)
            test_images = [test_image1, test_image2]
            dataset = ArrayDataset(test_images, img_transform)
            self.assertEqual(len(dataset), 2)
            dataset.set_random_state(1234)
            n_workers = 0 if sys.platform == "win32" else 2
            loader = DataLoader(dataset, batch_size=10, num_workers=n_workers)
            imgs = next(iter(loader))  # test batching
            np.testing.assert_allclose(imgs.shape, [2] + list(expected_shape))

            dataset.set_random_state(1234)
            new_imgs = next(iter(loader))  # test batching
            np.testing.assert_allclose(imgs, new_imgs, atol=1e-3)
Example #8
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray((im * 255).astype("uint8")).save(
            os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray(seg.astype("uint8")).save(
            os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
    train_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[:20], segs[:20])]
    val_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_imtrans = Compose([
        LoadImage(image_only=True),
        ScaleIntensity(),
        AddChannel(),
        RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        ToTensor(),
    ])
    train_segtrans = Compose([
        LoadImage(image_only=True),
        AddChannel(),
        RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        ToTensor(),
    ])
    val_imtrans = Compose([
        LoadImage(image_only=True),
        ScaleIntensity(),
        AddChannel(),
        ToTensor()
    ])
    val_segtrans = Compose(
        [LoadImage(image_only=True),
         AddChannel(), ToTensor()])

    # define array dataset, data loader
    check_ds = ArrayDataset(images, train_imtrans, segs, train_segtrans)
    check_loader = DataLoader(check_ds,
                              batch_size=10,
                              num_workers=2,
                              pin_memory=torch.cuda.is_available())
    im, seg = monai.utils.misc.first(check_loader)
    print(im.shape, seg.shape)

    # create a training data loader
    train_ds = ArrayDataset(images[:20], train_imtrans, segs[:20],
                            train_segtrans)
    train_loader = DataLoader(train_ds,
                              batch_size=4,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=torch.cuda.is_available())
    # create a validation data loader
    val_ds = ArrayDataset(images[-20:], val_imtrans, segs[-20:], val_segtrans)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose(
        [Activations(sigmoid=True),
         AsDiscrete(threshold_values=True)])
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(10):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{10}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(),
                              epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(
                        device), val_data[1].to(device)
                    roi_size = (96, 96)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(
                        val_images, roi_size, sw_batch_size, model)
                    val_outputs = post_trans(val_outputs)
                    value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
                    metric_count += len(value)
                    metric_sum += value.item() * len(value)
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(),
                               "best_metric_model_segmentation2d_array.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}"
                    .format(epoch + 1, metric, best_metric, best_metric_epoch))
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="image")
                plot_2d_or_3d_image(val_labels,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="label")
                plot_2d_or_3d_image(val_outputs,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="output")

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    writer.close()
Example #9
    def _define_training_data_loaders(self) -> bool:
        """Initialize training datasets and data loaders.

        @Note: on Windows, it is essential to set `persistent_workers=True` in the data loaders (a standalone sketch of this setting follows the method).

        @return True if datasets and data loaders could be instantiated, False otherwise.
        """

        # Platform-dependent DataLoader arguments
        if sys.platform == 'win32':
            persistent_workers = True
            pin_memory = False
        else:
            persistent_workers = False
            pin_memory = torch.cuda.is_available()

        if len(self._train_image_names) == 0 or \
                len(self._train_mask_names) == 0 or \
                len(self._validation_image_names) == 0 or \
                len(self._validation_mask_names) == 0 or \
                len(self._test_image_names) == 0 or \
                len(self._test_mask_names) == 0:

            self._train_dataset = None
            self._train_dataloader = None
            self._validation_dataset = None
            self._validation_dataloader = None
            self._test_dataset = None
            self._test_dataloader = None

            return False

        # Training
        self._train_dataset = ArrayDataset(
            self._train_image_names,
            self._train_image_transforms,
            self._train_mask_names,
            self._train_mask_transforms
        )
        self._train_dataloader = DataLoader(
            self._train_dataset,
            batch_size=self._training_batch_size,
            shuffle=False,
            num_workers=self._training_num_workers,
            persistent_workers=persistent_workers,
            pin_memory=pin_memory
        )

        # Validation
        self._validation_dataset = ArrayDataset(
            self._validation_image_names,
            self._validation_image_transforms,
            self._validation_mask_names,
            self._validation_mask_transforms
        )
        self._validation_dataloader = DataLoader(
            self._validation_dataset,
            batch_size=self._validation_batch_size,
            shuffle=False,
            num_workers=self._validation_num_workers,
            persistent_workers=persistent_workers,
            pin_memory=pin_memory
        )

        # Test
        self._test_dataset = ArrayDataset(
            self._test_image_names,
            self._test_image_transforms,
            self._test_mask_names,
            self._test_mask_transforms
        )
        self._test_dataloader = DataLoader(
            self._test_dataset,
            batch_size=self._test_batch_size,
            shuffle=False,
            num_workers=self._test_num_workers,
            persistent_workers=persistent_workers,
            pin_memory=pin_memory
        )

        return True
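Below is the standalone sketch of the Windows-specific `persistent_workers`/`pin_memory` choice referenced in the docstring; `make_loader` and its defaults are illustrative, and `ds` stands for any map-style dataset such as the `ArrayDataset` instances built above.

# Minimal sketch of the platform-dependent DataLoader arguments used above.
import sys

import torch
from monai.data import DataLoader

def make_loader(ds, batch_size=4, num_workers=2):
    if sys.platform == "win32":
        # keep worker processes alive between epochs on Windows
        persistent_workers, pin_memory = True, False
    else:
        persistent_workers, pin_memory = False, torch.cuda.is_available()
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        persistent_workers=persistent_workers,
        pin_memory=pin_memory,
    )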
Example #10
### Define array dataset, data loader
# check_ds = ArrayDataset(img=image_files, img_transform=train_imtrans, seg=None, seg_transform=None)
check_ds = monai.data.NiftiDataset(image_files=image_files,
                                   transform=train_imtrans)
check_loader = DataLoader(check_ds,
                          batch_size=10,
                          num_workers=2,
                          pin_memory=torch.cuda.is_available())
im = monai.utils.misc.first(check_loader)
print(im.shape)

### Create a training data loader
train_ds = ArrayDataset(image_files[:-20],
                        train_imtrans,
                        seg=None,
                        seg_transform=None,
                        labels=None,
                        label_transform=None)
train_loader = DataLoader(train_ds,
                          batch_size=4,
                          shuffle=True,
                          num_workers=8,
                          pin_memory=torch.cuda.is_available())

p("Start Training")
for idx, batch_data in enumerate(train_loader):
    p("len(batch_data): ", len(batch_data))
    # inputs, labels = batch_data[0].cuda(), batch_data[1].cuda()
    inputs = batch_data
    import ipdb
    ipdb.set_trace()
Example #11
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    # --------------- Dataset  ---------------
    
    datadir1 = "/home1/quanquan/datasets/lsw/benign_65/fpAML_55/slices/"
    image_files = np.array([x.path for x in os.scandir(datadir1+"image") if x.name.endswith(".npy")])
    label_files = np.array([x.path for x in os.scandir(datadir1+"label") if x.name.endswith(".npy")])
    image_files.sort()
    label_files.sort()
    # split the slice files into training (all but the last 20) and validation (last 20) sets
    train_files = [{"img":img, "seg":seg} for img, seg in zip(image_files[:-20], label_files[:-20])]
    val_files   = [{"img":img, "seg":seg} for img, seg in zip(image_files[-20:], label_files[-20:])]
    # print("files", train_files[:20])
    # print(val_files)
   
    val_imtrans = Compose([LoadNumpy(data_only=True), ScaleIntensity(), AddChannel(), ToTensor()])
    val_segtrans = Compose([LoadNumpy(data_only=True), AddChannel(), ToTensor()])
    # assumed training transforms (not defined in this snippet): the validation chains
    # plus random crop/rotate augmentation, as in the training examples above
    train_imtrans = Compose([LoadNumpy(data_only=True), ScaleIntensity(), AddChannel(),
                             RandSpatialCrop((96, 96), random_size=False),
                             RandRotate90(prob=0.5, spatial_axes=(0, 1)), ToTensor()])
    train_segtrans = Compose([LoadNumpy(data_only=True), AddChannel(),
                              RandSpatialCrop((96, 96), random_size=False),
                              RandRotate90(prob=0.5, spatial_axes=(0, 1)), ToTensor()])
    
    # define array dataset, data loader
    check_ds = ArrayDataset(image_files, train_imtrans, label_files, train_segtrans)
    check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
    im, seg = monai.utils.misc.first(check_loader)
    print(im.shape, seg.shape)
    
    # create a training data loader
    train_ds = ArrayDataset(image_files[:-20], train_imtrans, label_files[:-20], train_segtrans)
    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available())
    # create a validation data loader
    val_ds = ArrayDataset(image_files[-20:], val_imtrans, label_files[-20:], val_segtrans)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
    
    # ---------------  model  ---------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    
    # ---------------  loss function  ---------------
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    # writer = SummaryWriter(logdir=tdir(output_dir, "summary"))
    writer = SummaryWriter()  # assumed default log dir, since `writer` is used in the loops below
    
    # -------------------  Training ----------------------
    for epoch in range(max_epoch):
        # print("-" * 10)
        # print(f"epoch {epoch + 1}/{10}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"\r\t Training batch: {step}/{epoch_len}, \ttrain_loss: {loss.item():.4f}\t", end="")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"\n\tepoch {epoch + 1} \taverage loss: {epoch_loss:.4f}")
        
        # -------------------  Save Model  ----------------------
        if epoch % 5 == 0:
            def get_lr(optimizer):
                for param_group in optimizer.param_groups:
                    return float(param_group['lr'])
                  
            state = {'epoch': epoch + 1,
                     'lr': get_lr(optimizer),
                     'model_state': model.state_dict(),
                     'optimizer_state': optimizer.state_dict()
                     }
            torch.save(state, tfilename(output_dir, "model", "{}_{}.pkl".format("lsw_monai_simple", epoch)))
    
        # -------------------  Evaluation  -----------------------
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                    roi_size = (96, 96)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    val_outputs = post_trans(val_outputs)
                    # value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
                    value = dice_metric(y_pred=val_outputs, y=val_labels)
                    metric_count += len(value)
                    metric_sum += value.item() * len(value)
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), "best_metric_model_segmentation2d_array.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
Example #12
ds = CacheDataset(data, trans)
loader = DataLoader(
    dataset=ds,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),
)

# for simplicity we'll keep the existing pipeline for the validation data since it doesn't have any augmentations

val_image_trans = Compose([ScaleIntensity(), AddChannel(), ToTensor(),])

val_seg_trans = Compose([AddChannel(), ToTensor()])


val_ds = ArrayDataset(test_images, val_image_trans, test_segs, val_seg_trans)
val_loader = DataLoader(
    dataset=val_ds,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),
)

# %timeit first(loader)

batch = first(loader)
im = batch["img"]
seg = batch["seg"]
print(im.shape, im.min(), im.max(), seg.shape)
plt.imshow(im[0, 0].numpy() + seg[0, 0].numpy(), cmap="gray")
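Note that the dict-based loader above is indexed with `batch["img"]` and `batch["seg"]`, whereas an `ArrayDataset` loader such as `val_loader` yields positional batches. A minimal sketch of that difference with synthetic arrays (the names here are illustrative):

# Minimal sketch: ArrayDataset loaders yield positional (image, seg) batches,
# not dictionaries. Synthetic in-memory arrays; no files needed.
import numpy as np
from monai.data import ArrayDataset, DataLoader
from monai.transforms import AddChannel
from monai.utils import first

imgs = [np.random.rand(64, 64).astype(np.float32) for _ in range(4)]
segs = [(np.random.rand(64, 64) > 0.5).astype(np.float32) for _ in range(4)]
ds = ArrayDataset(imgs, AddChannel(), segs, AddChannel())
loader = DataLoader(ds, batch_size=2)
im_batch, seg_batch = first(loader)      # positional unpacking, not batch["img"]
print(im_batch.shape, seg_batch.shape)   # torch.Size([2, 1, 64, 64]) each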