Example #1
 def test_correct_results(self, min_zoom, max_zoom, order, mode, cval,
                          prefilter, use_gpu, keep_size):
     random_zoom = RandZoom(
         prob=1.0,
         min_zoom=min_zoom,
         max_zoom=max_zoom,
         order=order,
         mode=mode,
         cval=cval,
         prefilter=prefilter,
         use_gpu=use_gpu,
         keep_size=keep_size,
     )
     random_zoom.set_random_state(234)
     zoomed = random_zoom(self.imt[0])
     expected = list()
     for channel in self.imt[0]:
         expected.append(
             zoom_scipy(channel,
                        zoom=random_zoom._zoom,
                        mode=mode,
                        order=order,
                        cval=cval,
                        prefilter=prefilter))
     expected = np.stack(expected).astype(np.float32)
     self.assertTrue(np.allclose(expected, zoomed))
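Note that the test method above receives its arguments from a parameter list, which suggests a parameterized test runner. A minimal sketch of how such a case table could be wired up with the `parameterized` package (the case values below are hypothetical, not the suite's real ones):

import unittest

from parameterized import parameterized

TEST_CASES = [
    # (min_zoom, max_zoom, order, mode, cval, prefilter, use_gpu, keep_size)
    (0.8, 1.2, 1, "constant", 0.0, True, False, False),  # hypothetical values
]

class TestRandZoom(unittest.TestCase):
    @parameterized.expand(TEST_CASES)
    def test_correct_results(self, min_zoom, max_zoom, order, mode, cval,
                             prefilter, use_gpu, keep_size):
        ...  # body as in the example above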
Example #2
    def test_gpu_zoom(self, min_zoom, max_zoom, order, mode, cval, prefilter):
        if importlib.util.find_spec("cupy"):
            random_zoom = RandZoom(
                prob=1.0,
                min_zoom=min_zoom,
                max_zoom=max_zoom,
                order=order,
                mode=mode,
                cval=cval,
                prefilter=prefilter,
                use_gpu=True,
                keep_size=False,
            )
            random_zoom.set_random_state(234)

            zoomed = random_zoom(self.imt[0])
            expected = list()
            for channel in self.imt[0]:
                expected.append(
                    zoom_scipy(channel,
                               zoom=random_zoom._zoom,
                               mode=mode,
                               order=order,
                               cval=cval,
                               prefilter=prefilter))
            expected = np.stack(expected).astype(np.float32)

            self.assertTrue(np.allclose(expected, zoomed))
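The GPU variant above runs only when CuPy is importable; the `importlib.util.find_spec` guard silently skips the body otherwise. An equivalent, more explicit gate using `unittest.skipUnless` (a sketch, not the original test's code):

import importlib.util
import unittest

HAS_CUPY = importlib.util.find_spec("cupy") is not None

class TestGPUZoom(unittest.TestCase):
    @unittest.skipUnless(HAS_CUPY, "CuPy is required for GPU zoom")
    def test_gpu_zoom(self):
        ...  # body as in the example above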
Example #3
 def test_correct_results(self, min_zoom, max_zoom, order, keep_size):
     random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, interp_order=order, keep_size=keep_size)
     random_zoom.set_random_state(1234)
     zoomed = random_zoom(self.imt[0])
     expected = list()
     for channel in self.imt[0]:
         expected.append(zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False))
     expected = np.stack(expected).astype(np.float32)
     np.testing.assert_allclose(zoomed, expected, atol=1.0)
Example #4
 def test_auto_expand_3d(self):
     random_zoom = RandZoom(
         prob=1.0,
         min_zoom=[0.8, 0.7],
         max_zoom=[1.2, 1.3],
         mode="nearest",
         keep_size=False,
     )
     random_zoom.set_random_state(1234)
     test_data = np.random.randint(0, 2, size=[2, 2, 3, 4])
     zoomed = random_zoom(test_data)
     np.testing.assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2)
     np.testing.assert_allclose(zoomed.shape, (2, 2, 3, 3))
Example #5
 def test_auto_expand_3d(self):
     for p in TEST_NDARRAYS_ALL:
         random_zoom = RandZoom(prob=1.0,
                                min_zoom=[0.8, 0.7],
                                max_zoom=[1.2, 1.3],
                                mode="nearest",
                                keep_size=False)
         random_zoom.set_random_state(1234)
         test_data = p(np.random.randint(0, 2, size=[2, 2, 3, 4]))
         zoomed = random_zoom(test_data)
         assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637),
                         atol=1e-2,
                         type_test=False)
         assert_allclose(zoomed.shape, (2, 2, 3, 3), type_test=False)
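Examples #4 and #5 exercise the auto-expand behavior: a 2-element zoom range applied to a 3-D image yields three factors. Judging by the asserted tuple, the first sampled factor is repeated for the leading spatial dims and output sizes are floored; a small numpy check of that reading (an assumption drawn from the expected values, not from MONAI source):

import numpy as np

in_shape = (2, 3, 4)                    # spatial shape of one channel above
zoom_hw, zoom_d = 1.048844, 0.962637    # the two factors asserted above
factors = (zoom_hw, zoom_hw, zoom_d)    # 2 factors expanded to 3 dims
out_shape = tuple(int(np.floor(s * f)) for s, f in zip(in_shape, factors))
print(out_shape)  # (2, 3, 3) -> matches the (2, 2, 3, 3) assertion per channel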
Example #6
 def test_keep_size(self):
     for p in TEST_NDARRAYS_ALL:
         im = p(self.imt[0])
         random_zoom = RandZoom(prob=1.0,
                                min_zoom=0.6,
                                max_zoom=0.7,
                                keep_size=True)
         random_zoom.set_random_state(12)
         zoomed = random_zoom(im)
         test_local_inversion(random_zoom, zoomed, im)
         self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))
         zoomed = random_zoom(im)
         self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))
         zoomed = random_zoom(im)
         self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))
Example #7
 def test_keep_size(self):
     random_zoom = RandZoom(prob=1.0,
                            min_zoom=0.6,
                            max_zoom=0.7,
                            keep_size=True)
     zoomed = random_zoom(self.imt[0])
     self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))
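`keep_size=True` presumably restores the input shape after zooming, which is why the assertions above compare against `self.imt.shape[1:]`. A minimal numpy sketch of that idea (center crop or zero-pad after the zoom; an illustration, not MONAI's actual implementation):

import numpy as np
from scipy.ndimage import zoom as zoom_scipy

def zoom_keep_size(img: np.ndarray, factor: float) -> np.ndarray:
    # zoom, then center-crop or zero-pad back to the original shape
    zoomed = zoom_scipy(img, zoom=factor, order=1)
    out = np.zeros_like(img)
    src, dst = [], []
    for orig, new in zip(img.shape, zoomed.shape):
        if new >= orig:  # crop the center of the enlarged axis
            start = (new - orig) // 2
            src.append(slice(start, start + orig))
            dst.append(slice(0, orig))
        else:            # pad the shrunken axis around the center
            start = (orig - new) // 2
            src.append(slice(0, new))
            dst.append(slice(start, start + new))
    out[tuple(dst)] = zoomed[tuple(src)]
    return out

img = np.random.rand(16, 16)
assert zoom_keep_size(img, 0.65).shape == img.shape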
Example #8
 def test_invalid_inputs(self, _, min_zoom, max_zoom, mode, raises):
     with self.assertRaises(raises):
         random_zoom = RandZoom(prob=1.0,
                                min_zoom=min_zoom,
                                max_zoom=max_zoom,
                                mode=mode)
         random_zoom(self.imt[0])
Example #9
 def test_invalid_inputs(self, _, min_zoom, max_zoom, order, raises):
     with self.assertRaises(raises):
         random_zoom = RandZoom(prob=1.0,
                                min_zoom=min_zoom,
                                max_zoom=max_zoom,
                                order=order)
          random_zoom(self.imt[0])
Example #10
 def test_invalid_inputs(self, _, min_zoom, max_zoom, mode, raises):
     for p in TEST_NDARRAYS_ALL:
         with self.assertRaises(raises):
             random_zoom = RandZoom(prob=1.0,
                                    min_zoom=min_zoom,
                                    max_zoom=max_zoom,
                                    mode=mode)
             random_zoom(p(self.imt[0]))
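The invalid-input tests receive `(name, min_zoom, max_zoom, mode, raises)` tuples from a case table that is not shown here. A plausible, purely hypothetical shape for that table:

# hypothetical cases; the suite's real ones are not shown in the source
INVALID_CASES = [
    ("no_min_zoom", None, 1.1, "bilinear", TypeError),
    ("invalid_mode", 0.9, 1.1, "s", ValueError),
]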
Example #11
    def test_correct_results(self, min_zoom, max_zoom, mode, keep_size):
        for p in TEST_NDARRAYS:
            random_zoom = RandZoom(prob=1.0,
                                   min_zoom=min_zoom,
                                   max_zoom=max_zoom,
                                   mode=mode,
                                   keep_size=keep_size)
            random_zoom.set_random_state(1234)
            zoomed = random_zoom(p(self.imt[0]))
            expected = [
                zoom_scipy(channel,
                           zoom=random_zoom._zoom,
                           mode="nearest",
                           order=0,
                           prefilter=False) for channel in self.imt[0]
            ]

            expected = np.stack(expected).astype(np.float32)
            assert_allclose(zoomed, p(expected), atol=1.0)
Example #12
 def test_keep_size(self):
     for p in TEST_NDARRAYS:
         im = p(self.imt[0])
         random_zoom = RandZoom(prob=1.0,
                                min_zoom=0.6,
                                max_zoom=0.7,
                                keep_size=True)
         zoomed = random_zoom(im)
         self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))
         zoomed = random_zoom(im)
         self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))
         zoomed = random_zoom(im)
         self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))
Example #13
    def test_invert(self):
        set_determinism(seed=0)
        im_fname = make_nifti_image(create_test_image_3d(101, 100, 107, noise_max=100)[1])  # label image, discrete
        data = [im_fname for _ in range(12)]
        transform = Compose(
            [
                LoadImage(image_only=True),
                EnsureChannelFirst(),
                Orientation("RPS"),
                Spacing(pixdim=(1.2, 1.01, 0.9), mode="bilinear", dtype=np.float32),
                RandFlip(prob=0.5, spatial_axis=[1, 2]),
                RandAxisFlip(prob=0.5),
                RandRotate90(prob=0, spatial_axes=(1, 2)),
                RandZoom(prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
                RandRotate(prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True, dtype=np.float64),
                RandAffine(prob=0.5, rotate_range=np.pi, mode="nearest"),
                ResizeWithPadOrCrop(100),
                CastToType(dtype=torch.uint8),
            ]
        )

        # use num_workers = 0 on non-Linux platforms or when CUDA is available (GPU transforms)
        num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2
        dataset = Dataset(data, transform=transform)
        self.assertIsInstance(transform.inverse(dataset[0]), MetaTensor)
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=1)
        inverter = Invert(transform=transform, nearest_interp=True, device="cpu")

        for d in loader:
            d = decollate_batch(d)
            for item in d:
                orig = deepcopy(item)
                i = inverter(item)
                self.assertTupleEqual(orig.shape[1:], (100, 100, 100))
                # check the nearest interpolation mode
                torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
                self.assertTupleEqual(i.shape[1:], (100, 101, 107))
        # check labels match
        reverted = i.detach().cpu().numpy().astype(np.int32)
        original = LoadImage(image_only=True)(data[-1])
        n_good = np.sum(np.isclose(reverted, original.numpy(), atol=1e-3))
        reverted_name = i.meta["filename_or_obj"]
        original_name = original.meta["filename_or_obj"]
        self.assertEqual(reverted_name, original_name)
        print("invert diff", reverted.size - n_good)
        self.assertTrue((reverted.size - n_good) < 300000, f"diff. {reverted.size - n_good}")
        set_determinism(seed=None)
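`decollate_batch` above splits one collated batch back into per-sample items so the inverter can undo each item's own transform trace. A tiny standalone illustration (shapes hypothetical):

import torch
from monai.data import decollate_batch

batch = torch.zeros(4, 1, 8, 8)   # (batch, channel, H, W)
items = decollate_batch(batch)    # list of 4 tensors, each (1, 8, 8)
assert len(items) == 4 and items[0].shape == (1, 8, 8)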
Example #14
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    data_dir = '/home/marafath/scratch/eu_data'
    labels = np.load('eu_labels.npy')
    train_images = []
    train_labels = []

    val_images = []
    val_labels = []

    n_count = 0
    p_count = 0
    idx = 0
    # hold out the first 13 positive and 11 negative cases for validation;
    # the rest go to training (labels are assumed to follow os.listdir order)
    for case in os.listdir(data_dir):
        if p_count < 13 and labels[idx] == 1:
            val_images.append(
                os.path.join(data_dir, case, 'image_masked.nii.gz'))
            val_labels.append(labels[idx])
            p_count += 1
            idx += 1
        elif n_count < 11 and labels[idx] == 0:
            val_images.append(
                os.path.join(data_dir, case, 'image_masked.nii.gz'))
            val_labels.append(labels[idx])
            n_count += 1
            idx += 1
        else:
            train_images.append(
                os.path.join(data_dir, case, 'image_masked.nii.gz'))
            train_labels.append(labels[idx])
            idx += 1

    # Define transforms
    train_transforms = Compose([
        ScaleIntensity(),
        AddChannel(),
        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
        SpatialPad((256, 256, 92), mode='constant'),
        Resize((256, 256, 92)),
        ToTensor()
    ])

    val_transforms = Compose([
        ScaleIntensity(),
        AddChannel(),
        SpatialPad((256, 256, 92), mode='constant'),
        Resize((256, 256, 92)),
        ToTensor()
    ])

    # create a training data loader
    train_ds = NiftiDataset(image_files=train_images,
                            labels=train_labels,
                            transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=2,
                              pin_memory=torch.cuda.is_available())

    # create a validation data loader
    val_ds = NiftiDataset(image_files=val_images,
                          labels=val_labels,
                          transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=2,
                            num_workers=2,
                            pin_memory=torch.cuda.is_available())

    # Create DenseNet121, CrossEntropyLoss and Adam optimizer
    device = torch.device('cuda:0')
    model = monai.networks.nets.densenet.densenet121(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
    ).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # finetuning
    #model.load_state_dict(torch.load('best_metric_model_d121.pth'))

    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    epc = 100  # number of epochs
    for epoch in range(epc):
        print('-' * 10)
        print('epoch {}/{}'.format(epoch + 1, epc))
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(
                device=device, dtype=torch.int64)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print('{}/{}, train_loss: {:.4f}'.format(step, epoch_len,
                                                     loss.item()))
            writer.add_scalar('train_loss', loss.item(),
                              epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print('epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss))

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                num_correct = 0.
                metric_count = 0
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(
                        device), val_data[1].to(device)
                    val_outputs = model(val_images)
                    value = torch.eq(val_outputs.argmax(dim=1), val_labels)
                    metric_count += len(value)
                    num_correct += value.sum().item()
                metric = num_correct / metric_count
                metric_values.append(metric)
                #torch.save(model.state_dict(), 'model_d121_epoch_{}.pth'.format(epoch + 1))
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(
                        model.state_dict(),
                        '/home/marafath/scratch/saved_models/best_metric_model_d121.pth'
                    )
                    print('saved new best metric model')
                print(
                    'current epoch: {} current accuracy: {:.4f} best accuracy: {:.4f} at epoch {}'
                    .format(epoch + 1, metric, best_metric, best_metric_epoch))
                writer.add_scalar('val_accuracy', metric, epoch + 1)
    print('train completed, best_metric: {:.4f} at epoch: {}'.format(
        best_metric, best_metric_epoch))
    writer.close()
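Note that the training loop casts the labels to `torch.int64`: `CrossEntropyLoss` expects integer class indices alongside raw, unnormalized logits. A minimal standalone check:

import torch

logits = torch.randn(2, 2)        # raw model outputs, shape (batch, classes)
targets = torch.tensor([0, 1])    # int64 class indices, not one-hot vectors
loss = torch.nn.CrossEntropyLoss()(logits, targets)
print(loss.item())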
Example #15
    TESTS.append((dict, pad_collate,
                  RandZoomd("image",
                            prob=1,
                            min_zoom=1.1,
                            max_zoom=2.0,
                            keep_size=False)))
    TESTS.append((dict, pad_collate, RandRotate90d("image", prob=1, max_k=2)))

    TESTS.append(
        (list, pad_collate, RandSpatialCrop(roi_size=[8, 7],
                                            random_size=True)))
    TESTS.append(
        (list, pad_collate, RandRotate(prob=1, range_x=np.pi,
                                       keep_size=False)))
    TESTS.append((list, pad_collate,
                  RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0,
                           keep_size=False)))
    TESTS.append((list, pad_collate, RandRotate90(prob=1, max_k=2)))


class _Dataset(torch.utils.data.Dataset):
    def __init__(self, images, labels, transforms):
        self.images = images
        self.labels = labels
        self.transforms = transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        return self.transforms(self.images[index]), self.labels[index]
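Each `TESTS` entry pairs a container type and a collate function with a random transform. A sketch of consuming one of the list-based cases, assuming `pad_collate` is a `DataLoader`-compatible `collate_fn` that pads ragged spatial shapes (the dummy data below is hypothetical):

import numpy as np
import torch

images = [np.random.rand(1, 10, 9).astype(np.float32) for _ in range(4)]
labels = [0, 1, 0, 1]

# take a list-based case; the dict-based ones expect {"image": ...} inputs
_, collate_fn, transform = TESTS[-1]
ds = _Dataset(images, labels, transforms=transform)
loader = torch.utils.data.DataLoader(ds, batch_size=2, collate_fn=collate_fn)
batch_im, batch_labels = next(iter(loader))  # images padded to a common shape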
Example #16
    train_x = [image_files_list[i] for i in train_indices]
    train_y = [image_class[i] for i in train_indices]
    val_x = [image_files_list[i] for i in val_indices]
    val_y = [image_class[i] for i in val_indices]
    test_x = [image_files_list[i] for i in test_indices]
    test_y = [image_class[i] for i in test_indices]

    # MONAI transforms, Dataset and Dataloader for preprocessing
    train_transforms = Compose([
        LoadImage(image_only=True),
        AddChannel(),
        ScaleIntensity(),
        RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
        RandFlip(spatial_axis=0, prob=0.5),
        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
        ToTensor(),
    ])

    val_transforms = Compose([
        LoadImage(image_only=True),
        AddChannel(),
        ScaleIntensity(),
        ToTensor()
    ])

    act = Activations(softmax=True)
    to_onehot = AsDiscrete(to_onehot=True, n_classes=num_class)

    class MedNISTDataset(torch.utils.data.Dataset):
        def __init__(self, image_files, labels, transforms):
            # class body completed to match its use in Examples #15 and #18
            self.image_files = image_files
            self.labels = labels
            self.transforms = transforms

        def __len__(self):
            return len(self.image_files)

        def __getitem__(self, index):
            return self.transforms(self.image_files[index]), self.labels[index]
Example #17
class Loader():
    """Loader for different image datasets with built in split function and download if needed.
    
    Functions:
        load_IXIT1: Loads the IXIT1 3D brain MRI dataset.
        load_MedNIST: Loads the MedNIST 2D image dataset.
    """
    
    ixi_train_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), RandRotate90()])
    ixi_test_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96))])
    
    mednist_train_transforms = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(),
                                        RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True), 
                                        RandFlip(spatial_axis=0, prob=0.5), 
                                        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5)])
    mednist_test_transforms = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity()])
    
    
    @staticmethod
    def load_IXIT1(download: bool = False, train_transforms: object = ixi_train_transforms, 
                   test_transforms: object = ixi_test_transforms, test_size: float = 0.2, 
                   val_size: float = 0.0, sample_size: float = 0.01, shuffle: bool = True):
        """Loads the IXIT1 3D Brain MRI dataset.
        
        Consists of ~566 3D brain MRI scans labeled (0) for male and (1) for female.
        
        Args:
            download (bool): If true, then data is downloaded before loading it as dataset.
            train_transforms (Compose): Specify the transformations to be applied to the training dataset.
            test_transforms (Compose): Specify the transformations to be applied to the test dataset.
            sample_size (float): Percentage of available images to be used.
            test_size (float): Percentage of sample to be used as test data.
            val_size (float): Percentage of sample to be used as validation data.
            shuffle (bool): Whether or not the data should be shuffled after loading.
        """
        # Download data if needed
        if download:
            data_url = 'http://biomedic.doc.ic.ac.uk/brain-development/downloads/IXI/IXI-T1.tar'
            compressed_file = os.sep.join(['Data', 'IXI-T1.tar'])
            data_dir = os.sep.join(['Data', 'IXI-T1'])

            # Data download
            monai.apps.download_and_extract(data_url, compressed_file, './Data/IXI-T1')

            # Labels document download
            labels_url = 'http://biomedic.doc.ic.ac.uk/brain-development/downloads/IXI/IXI.xls'
            monai.apps.download_url(labels_url, './Data/IXI.xls')
            
        # Get all the images and corresponding Labels
        images = [impath for impath in os.listdir('./Data/IXI-T1')]

        df = pd.read_excel('./Data/IXI.xls')

        data = []
        labels = []
        for i in images:
            ixi_id = int(i[3:6])
            row = df.loc[df['IXI_ID'] == ixi_id]
            if not row.empty:
                data.append(os.sep.join(['Data', 'IXI-T1', i]))
                labels.append(int(row.iat[0, 1] - 1)) # Sex labels are 1/2 but need to be 0/1

        data, labels = data[:int(len(data) * sample_size)], labels[:int(len(data) * sample_size)]
        
        # Make train test validation split
        train_data, train_labels, test_data, test_labels, val_data, val_labels = _split(data, labels, 
                                                                                        test_size, val_size)
        
        # Construct and return Datasets
        train_ds = IXIT1Dataset(train_data, train_labels, train_transforms, shuffle)
        test_ds = IXIT1Dataset(test_data, test_labels, test_transforms, shuffle)
        
        if val_size == 0:
            return train_ds, test_ds
        else:
            val_ds = IXIT1Dataset(val_data, val_labels, test_transforms, shuffle)
            return train_ds, test_ds, val_ds
        
    
    @staticmethod
    def load_MedNIST(download: bool = False, train_transforms: object = mednist_train_transforms, 
                   test_transforms: object = mednist_test_transforms, test_size: float = 0.2, 
                   val_size: float = 0.0, sample_size: float = 0.01, shuffle: bool = True):
        """Loads the MedNIST 2D image dataset.
        
        Consists of ~60,000 2D images from 6 classes: AbdomenCT, BreastMRI, ChestCT, CXR, Hand, HeadCT.
        
        Args:
            download (bool): If true, then data is downloaded before loading it as dataset.
            train_transforms (Compose): Specify the transformations to be applied to the training dataset.
            test_transforms (Compose): Specify the transformations to be applied to the test dataset.
            sample_size (float): Percentage of available images to be used.
            test_size (float): Percentage of sample to be used as test data.
            val_size (float): Percentage of sample to be used as validation data.
            shuffle (bool): Whether or not the data should be shuffled after loading.
        """
        
        root_dir = './Data'
        resource = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz?dl=1"
        md5 = "0bc7306e7427e00ad1c5526a6677552d"

        compressed_file = os.path.join(root_dir, "MedNIST.tar.gz")
        data_dir = os.path.join(root_dir, "MedNIST")
            
        if download:
            monai.apps.download_and_extract(resource, compressed_file, root_dir, md5)

        # Reading image filenames from dataset folders and assigning labels
        class_names = sorted(x for x in os.listdir(data_dir)
                             if os.path.isdir(os.path.join(data_dir, x)))
        num_class = len(class_names)

        image_files = [
            [
                os.path.join(data_dir, class_names[i], x)
                for x in os.listdir(os.path.join(data_dir, class_names[i]))
            ]
            for i in range(num_class)
        ]
        
        image_files = [images[:int(len(images) * sample_size)] for images in image_files]
        
        # Constructing data and labels
        num_each = [len(image_files[i]) for i in range(num_class)]
        data = []
        labels = []

        for i in range(num_class):
            data.extend(image_files[i])
            labels.extend([int(i)] * num_each[i])
            
        if shuffle:
            np.random.seed(42)
            indices = np.arange(len(data))
            np.random.shuffle(indices)

            data = [data[i] for i in indices]
            labels = [labels[i] for i in indices]
        
        # Make train test validation split
        train_data, train_labels, test_data, test_labels, val_data, val_labels = _split(data, labels, 
                                                                                        test_size, val_size)
        
        # Construct and return datasets
        train_ds = MedNISTDataset(train_data, train_labels, train_transforms, shuffle)
        test_ds = MedNISTDataset(test_data, test_labels, test_transforms, shuffle)
        
        if val_size == 0:
            return train_ds, test_ds
        else:
            val_ds = MedNISTDataset(val_data, val_labels, test_transforms, shuffle)
            return train_ds, test_ds, val_ds
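Both loaders delegate to a `_split` helper that is not shown in this excerpt. A minimal sketch consistent with its call signature and return order (an assumption, not the original implementation):

def _split(data, labels, test_size, val_size):
    # sequentially split into train / test / validation partitions
    n = len(data)
    n_test = int(n * test_size)
    n_val = int(n * val_size)
    n_train = n - n_test - n_val
    train_data, train_labels = data[:n_train], labels[:n_train]
    test_data, test_labels = (data[n_train:n_train + n_test],
                              labels[n_train:n_train + n_test])
    val_data, val_labels = data[n_train + n_test:], labels[n_train + n_test:]
    return train_data, train_labels, test_data, test_labels, val_data, val_labels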
Example #18
def run_training_test(root_dir,
                      train_x,
                      train_y,
                      val_x,
                      val_y,
                      device="cuda:0",
                      num_workers=10):

    monai.config.print_config()
    # define transforms for image and classification
    train_transforms = Compose([
        LoadPNG(image_only=True),
        AddChannel(),
        ScaleIntensity(),
        RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
        RandFlip(spatial_axis=0, prob=0.5),
        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
        ToTensor(),
    ])
    train_transforms.set_random_state(1234)
    val_transforms = Compose(
        [LoadPNG(image_only=True),
         AddChannel(),
         ScaleIntensity(),
         ToTensor()])

    # create train, val data loaders
    train_ds = MedNISTDataset(train_x, train_y, train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=300,
                              shuffle=True,
                              num_workers=num_workers)

    val_ds = MedNISTDataset(val_x, val_y, val_transforms)
    val_loader = DataLoader(val_ds, batch_size=300, num_workers=num_workers)

    model = densenet121(spatial_dims=2,
                        in_channels=1,
                        out_channels=len(np.unique(train_y))).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-5)
    epoch_num = 4
    val_interval = 1

    # start training validation
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    for epoch in range(epoch_num):
        print("-" * 10)
        print(f"Epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss:{epoch_loss:0.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                y_pred = torch.tensor([], dtype=torch.float32, device=device)
                y = torch.tensor([], dtype=torch.long, device=device)
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(
                        device), val_data[1].to(device)
                    y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                    y = torch.cat([y, val_labels], dim=0)
                auc_metric = compute_roc_auc(y_pred,
                                             y,
                                             to_onehot_y=True,
                                             softmax=True)
                metric_values.append(auc_metric)
                acc_value = torch.eq(y_pred.argmax(dim=1), y)
                acc_metric = acc_value.sum().item() / len(acc_value)
                if auc_metric > best_metric:
                    best_metric = auc_metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), model_filename)
                    print("saved new best metric model")
                print(
                    f"current epoch {epoch +1} current AUC: {auc_metric:0.4f} "
                    f"current accuracy: {acc_metric:0.4f} best AUC: {best_metric:0.4f} at epoch {best_metric_epoch}"
                )
    print(
        f"train completed, best_metric: {best_metric:0.4f}  at epoch: {best_metric_epoch}"
    )
    return epoch_loss_values, best_metric, best_metric_epoch
Example #19
        params['data_dir'], params['test_val_split'])

    size_n = 3
    sample_data = train_data.sample(size_n**2)
    #show_sample_dataframe(sample_data, size_n=size_n, title="Training samples")

    train_transforms = Compose([
        LoadPNG(image_only=True),
        AddChannel(),
        ScaleIntensity(),
        RandRotate(range_x=params['rotate_range_x'],
                   prob=params['rotate_prob'],
                   keep_size=True),
        # RandFlip(spatial_axis=0, prob=0.5),
        RandZoom(min_zoom=params['min_zoom'],
                 max_zoom=params['max_zoom'],
                 prob=params['zoom_prob'],
                 keep_size=True),
        ToTensor()
    ])

    val_transforms = Compose(
        [LoadPNG(image_only=True),
         AddChannel(),
         ScaleIntensity(),
         ToTensor()])
    train_ds = LabeledImageDataset(train_data, train_transforms)
    train_loader = DataLoader(
        train_ds,
        batch_size=params['batch_size'],
        # num_workers=params['num_workers'], collate_fn=collate_fn)
        num_workers=params['num_workers'])
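This fragment reads all of its settings from a `params` dict defined elsewhere; a hypothetical minimal dict covering every key referenced above:

import numpy as np

params = {
    'data_dir': './data',            # hypothetical path and values
    'test_val_split': 0.2,
    'rotate_range_x': np.pi / 12,
    'rotate_prob': 0.5,
    'min_zoom': 0.9,
    'max_zoom': 1.1,
    'zoom_prob': 0.5,
    'batch_size': 32,
    'num_workers': 2,
}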