Beispiel #1
0
    def test_image_pipeline_and_pin_memory(self):
        """
        Smoke test: pull batches through an augmentation pipeline that ends in
        a numpy-to-tensor conversion and check that the result is a pinned
        torch tensor.

        Returns early (without failing) when torch cannot be imported.
        :return:
        """
        try:
            import torch
        except ImportError:
            # torch is not installed, so there is nothing to test here
            return

        pipeline = Compose([
            MirrorTransform(),
            TransposeAxesTransform(transpose_any_of_these=(0, 1),
                                   p_per_sample=0.5),
            NumpyToTensor(keys='data', cast_to='float'),
        ])

        # NOTE(review): the trailing positional True is presumably the
        # pin_memory switch -- confirm against MultiThreadedAugmenter.
        augmenter = MultiThreadedAugmenter(self.dl_images, pipeline, 4, 1,
                                           None, True)

        for _ in range(50):
            batch = augmenter.next()

        # Only the last batch is inspected.
        data = batch['data']
        assert isinstance(data, torch.Tensor)
        assert data.is_pinned()

        # Give the augmenter time to finish caching; otherwise it prints an
        # error on teardown (harmless for the test result, but noisy).
        sleep(2)
Beispiel #2
0
    def test_no_crash(self):
        """
        Smoke test: fetching 20 batches from an augmenter that has no
        transform (None) must not raise.
        :return:
        """
        augmenter = MultiThreadedAugmenter(self.dl_images, None,
                                           self.num_threads, 1, None, False)

        # The batches themselves are irrelevant; only the absence of an
        # exception matters.
        for _ in range(20):
            augmenter.next()
Beispiel #3
0
def train(train_loader, model, optimizer, criterion_ADC, criterion_T2,
          final_transform, workers, seed, training_batches):
    """
    Run ``training_batches`` optimisation steps and return ``train_losses.avg``.

    Every step pulls one batch from a MultiThreadedAugmenter built around
    ``train_loader`` and ``final_transform``, moves it to the GPU, evaluates
    both criteria on the same model output and back-propagates their mean.

    :param train_loader: data loader handed to MultiThreadedAugmenter
    :param model: network, called as ``model(input)``; inputs are moved to the
        GPU with ``.cuda()``, so the model is presumably on the GPU as well
    :param optimizer: optimizer; ``zero_grad()`` / ``step()`` are called once
        per batch
    :param criterion_ADC: loss called as ``criterion(output, batch['seg'])``
    :param criterion_T2: loss called as ``criterion(output, batch['seg_T2'])``
    :param final_transform: transform handed to MultiThreadedAugmenter
    :param workers: number of augmenter workers; also the number of worker
        seeds drawn
    :param seed: seeds numpy's global RNG and is also the first argument of
        ``np.random.choice`` (for an int, worker seeds are drawn without
        replacement from ``range(seed)``, which requires ``seed >= workers``)
    :param training_batches: number of batches to train on
    :return: ``train_losses.avg`` after the last batch
    """
    train_losses = AverageMeter()
    np.random.seed(seed)
    seeds = np.random.choice(seed, workers, False, None)
    model.train()
    # NOTE(review): the augmenter is never explicitly shut down here --
    # confirm that its workers are cleaned up elsewhere.
    multithreaded_generator = MultiThreadedAugmenter(train_loader,
                                                     final_transform,
                                                     workers,
                                                     2,
                                                     seeds=seeds)
    torch.cuda.empty_cache()

    for i in range(training_batches):
        print('Batch: [{0}/{1}]'.format(i + 1, training_batches))
        batch = multithreaded_generator.next()
        TensorBatch = ToTensor(batch)

        # `async` became a reserved word in Python 3.7; PyTorch >= 0.4 (already
        # required by loss.item() below) names this argument `non_blocking`.
        # The Variable wrapper is kept so the input still requires grad.
        input_var = torch.autograd.Variable(
            TensorBatch['data'], requires_grad=True).cuda(non_blocking=True)
        input_var = input_var.float()
        target_var = TensorBatch['seg'].cuda().long()
        target_var_T2 = TensorBatch['seg_T2'].cuda().long()

        optimizer.zero_grad()
        output = model(input_var)
        loss_ADC = criterion_ADC(output, target_var)
        loss_T2 = criterion_T2(output, target_var_T2)
        # Both criteria contribute equally.
        loss = (loss_ADC + loss_T2) / 2.
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_losses.update(batch_loss)
        # str.format keeps the output identical on Python 2 and Python 3.
        print('train_loss {}'.format(batch_loss))

    torch.cuda.empty_cache()

    return train_losses.avg
Beispiel #4
0
    def test_image_pipeline(self):
        """
        Smoke test: drawing 50 batches through a mirror + axis-transpose
        pipeline must not raise.
        :return:
        """
        pipeline = Compose([
            MirrorTransform(),
            TransposeAxesTransform(transpose_any_of_these=(0, 1),
                                   p_per_sample=0.5),
        ])

        augmenter = MultiThreadedAugmenter(self.dl_images, pipeline, 4, 1,
                                           None, False)

        for _ in range(50):
            augmenter.next()

        # Give the augmenter time to finish caching; otherwise it prints an
        # error on teardown (harmless for the test result, but noisy).
        sleep(2)
Beispiel #5
0
def _ensure_dir(path):
    """Create ``path`` (including missing parents); an existing directory is fine."""
    try:
        os.makedirs(path)
    except OSError:
        # Only "already there" is tolerated; a real failure (permissions,
        # a file in the way, ...) must not be hidden.
        if not os.path.isdir(path):
            raise


def validate(val_loader,
             model,
             epoch,
             criterion_ADC,
             criterion_T2,
             split_ixs,
             Center_Crop,
             workers,
             seed,
             folder_name,
             test=False):
    """
    Evaluate ``model`` on one batch per entry of ``split_ixs`` and dump the
    inputs, labels, predictions and probability maps as ``.nrrd`` files.

    For each patient the images are written to
    ``<folder_name>/Val_Images/Epoch_<epoch>/Patient_<patient>`` or, when
    ``test`` is true, to ``<folder_name>/Test_Images/Patient_<patient>``.

    :param val_loader: data loader handed to MultiThreadedAugmenter
    :param model: network, called as ``model(input)`` on GPU tensors; its
        output is indexed as a 4-D array with the class scores on axis 1
        (channels 0, 1, 2 are written as Back / PRO / TU probability maps)
    :param epoch: only used to build the validation output directory name
    :param criterion_ADC: loss called as ``criterion(output, batch['seg'])``
    :param criterion_T2: loss called as ``criterion(output, batch['seg_T2'])``
    :param split_ixs: patient identifiers; one batch is drawn per entry and
        the entry is used in the output directory name
    :param Center_Crop: transform handed to MultiThreadedAugmenter
    :param workers: number of augmenter workers / worker seeds
    :param seed: first argument of ``np.random.choice`` when drawing worker
        seeds (for an int this requires ``seed >= workers``)
    :param folder_name: root directory for the written images
    :param test: bool; selects the test output layout and log label
    :return: ``val_losses.avg`` after the last patient
    """
    val_losses = AverageMeter()
    seeds = np.random.choice(seed, workers, False, None)
    torch.cuda.empty_cache()
    model.eval()
    # NOTE(review): the augmenter is never explicitly shut down here --
    # confirm that its workers are cleaned up elsewhere.
    multithreaded_generator = MultiThreadedAugmenter(val_loader,
                                                     Center_Crop,
                                                     workers,
                                                     2,
                                                     seeds=seeds)

    loss_label = 'test_loss' if test else 'val_loss'

    for patient in split_ixs:
        # str.format keeps the output identical on Python 2 and Python 3.
        print('patient {}'.format(patient))
        batch = multithreaded_generator.next()
        TensorBatch = ToTensor(batch)

        # `volatile=True` is ignored by PyTorch >= 0.4 (already required by
        # loss.item()), so the original still built an autograd graph;
        # torch.no_grad() is the supported way to disable it.
        with torch.no_grad():
            target = TensorBatch['seg'].cuda()
            target_T2 = TensorBatch['seg_T2'].cuda()
            # `async` is a reserved word since Python 3.7 -> `non_blocking`.
            input_var = TensorBatch['data'].cuda(non_blocking=True).float()
            output = model(input_var)
            # Class axis made explicit; dim=1 is what the deprecated implicit
            # choice resolved to for the 4-D output indexed below.
            probs = F.softmax(output, dim=1)
            loss_ADC = criterion_ADC(output, target.long())
            loss_T2 = criterion_T2(output, target_T2.long())
            loss = (loss_ADC + loss_T2) / 2.

        batch_loss = loss.item()
        val_losses.update(batch_loss)
        print('{} {}'.format(loss_label, batch_loss))

        image = input_var.cpu().numpy()
        prob_maps = probs.cpu().numpy()
        segmentation = target.cpu().numpy()
        segmentation_T2 = target_T2.cpu().numpy()

        # Ground-truth masks: label value 2 -> "Label", value 1 -> "Pro_Label".
        label = (segmentation == 2).astype(np.uint8)
        label_T2 = (segmentation_T2 == 2).astype(np.uint8)
        PRO = (segmentation == 1).astype(np.uint8)
        PRO_T2 = (segmentation_T2 == 1).astype(np.uint8)

        # Hard prediction = class with the highest probability per pixel
        # (ties go to the lowest class index, as before).
        predicted_class = np.argmax(prob_maps, axis=1)
        ProstateOut = (predicted_class == 1).astype(np.uint8)
        TumorOut = (predicted_class == 2).astype(np.uint8)

        if test:
            save_images_to = folder_name + '/Test_Images/Patient_{}'.format(
                patient)
        else:
            save_images_to = folder_name + '/Val_Images/Epoch_{}/Patient_{}'.format(
                epoch, patient)
        _ensure_dir(save_images_to)

        # (array, file name, written as mask?) in the original write order.
        outputs = [
            (TumorOut, '/Tumor_Output.nrrd', True),
            (ProstateOut, '/Prostate_Output.nrrd', True),
            (image[:, 0, :, :], '/ADCImage.nrrd', False),
            (image[:, 1, :, :], '/BVALImage.nrrd', False),
            (image[:, 2, :, :], '/T2Image.nrrd', False),
            (label, '/Label.nrrd', True),
            (label_T2, '/Label_T2.nrrd', True),
            (PRO, '/Pro_Label.nrrd', True),
            (PRO_T2, '/Pro_Label_T2.nrrd', True),
            (prob_maps[:, 0, :, :], '/ProbabilityMapBack.nrrd', False),
            (prob_maps[:, 2, :, :], '/ProbabilityMapTU.nrrd', False),
            (prob_maps[:, 1, :, :], '/ProbabilityMapPRO.nrrd', False),
        ]
        for array, file_name, is_mask in outputs:
            sitk_image = sitk.GetImageFromArray(array)
            if is_mask:
                save(sitk_image, save_images_to + file_name, Mask=True)
            else:
                # Mask is left at save()'s own default, as in the original.
                save(sitk_image, save_images_to + file_name)

        torch.cuda.empty_cache()

    return val_losses.avg
    do_rotation=True,
    angle_z=(0, 2 * np.pi),  # 旋转
    do_scale=True,
    scale=(0.3, 3.),  # 缩放
    border_mode_data='constant',
    border_cval_data=0,
    order_data=1,
    random_crop=False)
# Build the augmentation pipeline: each transform below is appended to
# my_transforms, the list is wrapped in Compose, and one augmented batch is
# drawn and plotted.
# NOTE(review): my_transforms, spatial_transform and batchgen are defined
# earlier in the file (outside this excerpt).
my_transforms.append(spatial_transform)
GaussianNoise = GaussianNoiseTransform()  # Gaussian noise
my_transforms.append(GaussianNoise)
GaussianBlur = GaussianBlurTransform()  # Gaussian blur
my_transforms.append(GaussianBlur)
Brightness = BrightnessTransform(0, 0.2)  # brightness
my_transforms.append(Brightness)
# NOTE: despite its name, this variable holds the contrast augmentation.
brightness_transform = ContrastAugmentationTransform(
    (0.3, 3.), preserve_range=True)  # contrast
my_transforms.append(brightness_transform)
SimulateLowResolution = SimulateLowResolutionTransform()  # low resolution
my_transforms.append(SimulateLowResolution)
Gamma = GammaTransform()  # gamma augmentation
my_transforms.append(Gamma)
mirror_transform = MirrorTransform(axes=(0, 1))  # mirroring
my_transforms.append(mirror_transform)
all_transforms = Compose(my_transforms)
# NOTE(review): the positional 1 and 2 are presumably the worker count and the
# per-worker cache size -- confirm against MultiThreadedAugmenter's signature.
multithreaded_generator = MultiThreadedAugmenter(batchgen, all_transforms, 1,
                                                 2)

# Draw a single augmented batch and visualise it.
t = multithreaded_generator.next()
plot_batch(t)