def get_longitudinal_preprocess(is_label: bool) -> List[Transform]:
    """Build the preprocessing pipeline for the longitudinal dataset.

    No cropping step is included here: there is not much left to crop in this
    dataset. The steps are nonzero-voxel intensity normalization, adding an
    axis with ``Unsqueeze`` (presumably the channel axis), symmetric
    constant-mode padding to 215^3, and a resize to ``IMAGESIZE``^3.

    Args:
        is_label: Selects the label pipeline. Image and label pipelines are
            currently identical, so this flag does not change the result.

    Returns:
        A new list of freshly constructed transform instances on every call.
    """
    # NOTE(review): label volumes receive the same intensity normalization and
    # Resize as images; confirm this is intended for label maps.
    pad_shape = [215, 215, 215]
    target_shape = (IMAGESIZE, IMAGESIZE, IMAGESIZE)
    return [
        NormalizeIntensity(nonzero=True),
        Unsqueeze(),
        SpatialPad(spatial_size=pad_shape, method="symmetric", mode="constant"),
        Resize(target_shape),
    ]
Ejemplo n.º 2
0
 def test_pad_kwargs(self):
     """Check that backend-specific fill kwargs given to SpatialPad take effect.

     Tensor inputs are padded with a single ``value=2``; numpy inputs get
     per-axis ``constant_values`` (1 along axis 1, 2 along the last axis).
     """
     for to_backend in TEST_NDARRAYS:
         img = to_backend(np.zeros((3, 8, 4)))
         is_tensor = isinstance(img, torch.Tensor)
         if is_tensor:
             fill_kwargs = {"value": 2}
         else:
             fill_kwargs = {"constant_values": ((0, 0), (1, 1), (2, 2))}
         padder = SpatialPad(spatial_size=[15, 8], method="end", mode="constant", **fill_kwargs)
         padded = padder(img=img)
         if is_tensor:
             padded = padded.cpu().numpy()
         else:
             # Rows appended along axis 1 carry that axis' fill value (1).
             torch.testing.assert_allclose(padded[:, 8:, :4], np.ones((3, 7, 4)), rtol=1e-7, atol=0)
         # Columns appended along the last axis are 2 for both backends.
         torch.testing.assert_allclose(padded[:, :, 4:], np.ones((3, 15, 4)) + 1, rtol=1e-7, atol=0)
def get_preprocess(is_label: bool) -> List[Transform]:
    """Build the default preprocessing pipeline.

    The steps are ``Crop()``, nonzero-voxel intensity normalization, adding a
    channel axis with ``Unsqueeze``, symmetric constant-mode padding to 193^3,
    and a resize to ``IMAGESIZE``^3.

    Args:
        is_label: Selects the label pipeline. Image and label pipelines are
            currently identical, so this flag does not change the result.

    Returns:
        A new list of freshly constructed transform instances on every call.
    """
    # NOTE(review): label volumes receive the same intensity normalization and
    # Resize as images; confirm this is intended for label maps.
    pad_shape = [193, 193, 193]
    target_shape = (IMAGESIZE, IMAGESIZE, IMAGESIZE)
    return [
        Crop(),
        NormalizeIntensity(nonzero=True),
        Unsqueeze(),  # channel axis
        SpatialPad(spatial_size=pad_shape, method="symmetric", mode="constant"),
        Resize(target_shape),
    ]
Ejemplo n.º 4
0
    def __call__(self, data):
        """Crop a ``patch_size`` x ``patch_size`` patch around a centroid per key.

        The centroid is read from ``data[self.centroid_key]``. ``self.bbox``
        turns it into start/end bounds over the image's last two axes. The
        crop is then passed through ``SpatialPad`` to reach
        ``(patch_size, patch_size)``; extra pad options come from
        ``self.kwargs``. Presumably this handles crops clipped at the image
        border — confirm against ``bbox``.

        Args:
            data: Mapping of keys to arrays/tensors. The mapping itself is
                shallow-copied, not mutated.

        Returns:
            A new dict in which every key in ``self.keys`` holds its padded patch.
        """
        d = dict(data)
        # The centroid selects which object (per the original comment, which
        # nucleus) the patch is taken around.
        centroid = d[self.centroid_key]
        roi_size = (self.patch_size, self.patch_size)

        for key in self.keys:
            img = d[key]
            x_start, x_end, y_start, y_end = self.bbox(self.patch_size, centroid, img.shape[-2:])
            # Slice the same trailing two axes that bbox was sized against.
            # The old `img[:, a:b, c:d]` only matched shape[-2:] for 3-D input.
            cropped = img[..., x_start:x_end, y_start:y_end]
            d[key] = SpatialPad(spatial_size=roi_size, **self.kwargs)(cropped)
        return d
Ejemplo n.º 5
0
 def test_pad_shape(self, input_param, input_shape, expected_shape):
     """Check SpatialPad output shape and cross-backend consistency.

     Each backend in TEST_NDARRAYS is padded twice: once with the mode from the
     constructor, and once with ``mode`` passed at call time. Every output must
     have ``expected_shape``. Unless the mode is "empty", each output must also
     match the first backend's output for the same call style.
     """
     def _to_cpu(out):
         return out.cpu() if isinstance(out, torch.Tensor) else out

     ctor_mode_outputs = []
     call_mode_outputs = []
     input_data = self.get_arr(input_shape)
     for to_backend in TEST_NDARRAYS:
         padder = SpatialPad(**input_param)
         ctor_out = padder(to_backend(input_data))
         call_out = padder(to_backend(input_data), mode=input_param["mode"])
         ctor_mode_outputs.append(_to_cpu(ctor_out))
         call_mode_outputs.append(_to_cpu(call_out))
         for outputs in (ctor_mode_outputs, call_mode_outputs):
             latest = outputs[-1]
             np.testing.assert_allclose(latest.shape, expected_shape)
             if input_param["mode"] in ("empty", NumpyPadMode.EMPTY):
                 # "empty" padding leaves values undefined; only shape is checked.
                 continue
             torch.testing.assert_allclose(outputs[0], latest, atol=0, rtol=1e-5)
Ejemplo n.º 6
0
from monai.transforms.spatial.dictionary import RandAffined, RandRotate90d
from monai.utils import optional_import, set_determinism
from monai.utils.enums import InverseKeys
from tests.utils import make_nifti_image

# has_nib: second value returned by optional_import — presumably whether
# nibabel could be imported (confirm against monai.utils.optional_import).
_, has_nib = optional_import("nibabel")

KEYS = ["image"]

# Dictionary-transform chains, applied to the "image" key.
TESTS_DICT: List[Tuple] = [
    (SpatialPadd(KEYS, 150), RandFlipd(KEYS, prob=1.0, spatial_axis=1)),
    (RandRotate90d(KEYS, prob=0.0, max_k=1),),
    (RandAffined(KEYS, prob=0.0, translate_range=10),),
]

# Array-transform versions of the same chains, in the same order and with
# the same arguments.
TESTS_LIST: List[Tuple] = [
    (SpatialPad(150), RandFlip(prob=1.0, spatial_axis=1)),
    (RandRotate90(prob=0.0, max_k=1),),
    (RandAffine(prob=0.0, translate_range=10),),
]


TEST_BASIC = [
    [("channel", "channel"), ["channel", "channel"]],
    [torch.Tensor([1, 2, 3]), [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)]],
    [
        [[torch.Tensor((1.0, 2.0, 3.0)), torch.Tensor((2.0, 3.0, 1.0))]],
        [
            [[torch.tensor(1.0), torch.tensor(2.0)]],
            [[torch.tensor(2.0), torch.tensor(3.0)]],
            [[torch.tensor(3.0), torch.tensor(1.0)]],
        ],
    ],
Ejemplo n.º 7
0
 def test_array_transform(self):
     """Check that allow_missing_keys_mode raises TypeError for array transforms.

     This is checked both for a bare SpatialPad and for one wrapped in Compose.
     """
     candidates = (SpatialPad(10), Compose([SpatialPad(10)]))
     for transform in candidates:
         with self.assertRaises(TypeError), allow_missing_keys_mode(transform):
             pass
Ejemplo n.º 8
0
 def test_pad_shape(self, input_param, input_data, expected_val):
     """Check SpatialPad output shape against ``expected_val``.

     The check runs twice: once with the mode from the constructor, and once
     with ``mode`` passed at call time.
     """
     padder = SpatialPad(**input_param)

     def _check(**call_kwargs):
         padded = padder(input_data, **call_kwargs)
         np.testing.assert_allclose(padded.shape, expected_val.shape)

     _check()
     _check(mode=input_param["mode"])
Ejemplo n.º 9
0
 def test_pad_shape(self, input_param, input_data, expected_val):
     """Check SpatialPad output shape against ``expected_val``.

     The check runs twice: once with the mode from the constructor, and once
     with ``mode`` passed at call time.
     """
     padder = SpatialPad(**input_param)
     result = padder(input_data)
     # assertAlmostEqual is meant for numbers. On unequal shape tuples it
     # raises TypeError (tuple subtraction) instead of a readable failure.
     # Normalise both sides to tuples so torch.Size and numpy shapes compare.
     self.assertTupleEqual(tuple(result.shape), tuple(expected_val.shape))
     result = padder(input_data, mode=input_param["mode"])
     self.assertTupleEqual(tuple(result.shape), tuple(expected_val.shape))
def main():
    """Train a 3-D DenseNet121 two-class classifier on the masked EU volumes.

    Each case directory under ``data_dir`` is paired by position with an
    entry of ``eu_labels.npy``. The validation set takes the first 13 cases
    labelled 1 and the first 11 cases labelled 0, in directory-listing order;
    all other cases are used for training. After each validation pass, the
    model with the best validation accuracy so far is saved. Training loss
    and validation accuracy are logged to TensorBoard.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    data_dir = '/home/marafath/scratch/eu_data'
    # Per-case labels, matched to the directory listing by position.
    case_labels = np.load('eu_labels.npy')
    train_images = []
    train_labels = []

    val_images = []
    val_labels = []

    # NOTE(review): os.listdir order is arbitrary, yet labels are matched to
    # cases by position. This assumes eu_labels.npy was produced from the same
    # listing order on the same filesystem — confirm, or key labels by case id.
    n_count = 0  # label-0 cases placed in validation so far
    p_count = 0  # label-1 cases placed in validation so far
    for idx, case in enumerate(os.listdir(data_dir)):
        label = case_labels[idx]
        image_path = os.path.join(data_dir, case, 'image_masked.nii.gz')
        if p_count < 13 and label == 1:
            val_images.append(image_path)
            val_labels.append(label)
            p_count += 1
        elif n_count < 11 and label == 0:
            val_images.append(image_path)
            val_labels.append(label)
            n_count += 1
        else:
            train_images.append(image_path)
            train_labels.append(label)

    # Define transforms
    train_transforms = Compose([
        ScaleIntensity(),
        AddChannel(),
        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
        SpatialPad((256, 256, 92), mode='constant'),
        Resize((256, 256, 92)),
        ToTensor()
    ])

    val_transforms = Compose([
        ScaleIntensity(),
        AddChannel(),
        SpatialPad((256, 256, 92), mode='constant'),
        Resize((256, 256, 92)),
        ToTensor()
    ])

    # create a training data loader
    train_ds = NiftiDataset(image_files=train_images,
                            labels=train_labels,
                            transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=2,
                              pin_memory=torch.cuda.is_available())

    # create a validation data loader
    val_ds = NiftiDataset(image_files=val_images,
                          labels=val_labels,
                          transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=2,
                            num_workers=2,
                            pin_memory=torch.cuda.is_available())

    # Create DenseNet121, CrossEntropyLoss and Adam optimizer.
    # Fall back to CPU when CUDA is unavailable, matching the pin_memory checks above.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = monai.networks.nets.densenet.densenet121(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
    ).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # finetuning
    #model.load_state_dict(torch.load('best_metric_model_d121.pth'))

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    epc = 100  # Number of epochs
    # Batches per epoch, including the final partial batch (drop_last is off).
    # Using len(train_loader) keeps TensorBoard global steps of consecutive
    # epochs from overlapping; len(train_ds) // batch_size undercounts.
    epoch_len = len(train_loader)
    try:
        for epoch in range(epc):
            print('-' * 10)
            print('epoch {}/{}'.format(epoch + 1, epc))
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs = batch_data[0].to(device)
                targets = batch_data[1].to(device=device, dtype=torch.int64)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, targets)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print('{}/{}, train_loss: {:.4f}'.format(step, epoch_len,
                                                         loss.item()))
                writer.add_scalar('train_loss', loss.item(),
                                  epoch_len * epoch + step)
            # max() guards against ZeroDivisionError on an empty training set.
            epoch_loss /= max(step, 1)
            epoch_loss_values.append(epoch_loss)
            print('epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss))

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    num_correct = 0.
                    metric_count = 0
                    for val_data in val_loader:
                        # Distinct names so the val_images/val_labels lists
                        # built above are not shadowed.
                        val_inputs = val_data[0].to(device)
                        val_targets = val_data[1].to(device)
                        val_outputs = model(val_inputs)
                        value = torch.eq(val_outputs.argmax(dim=1), val_targets)
                        metric_count += len(value)
                        num_correct += value.sum().item()
                    metric = num_correct / metric_count if metric_count else 0.0
                    metric_values.append(metric)
                    #torch.save(model.state_dict(), 'model_d121_epoch_{}.pth'.format(epoch + 1))
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(
                            model.state_dict(),
                            '/home/marafath/scratch/saved_models/best_metric_model_d121.pth'
                        )
                        print('saved new best metric model')
                    print(
                        'current epoch: {} current accuracy: {:.4f} best accuracy: {:.4f} at epoch {}'
                        .format(epoch + 1, metric, best_metric, best_metric_epoch))
                    writer.add_scalar('val_accuracy', metric, epoch + 1)
        print('train completed, best_metric: {:.4f} at epoch: {}'.format(
            best_metric, best_metric_epoch))
    finally:
        # Close the writer even if training fails, so logged scalars are kept.
        writer.close()