def transform(self):
    """Build the training transform pipeline selected by ``hp``.

    ``hp.mode`` must be ``'3d'`` or ``'2d'``; both modes build the same
    pipeline.  When ``hp.aug`` is truthy the pipeline crops/pads, normalizes
    and applies random augmentations; otherwise it only crops/pads to a cube
    of side ``hp.crop_or_pad_size`` and z-normalizes.

    Returns:
        The composed training transform.

    Raises:
        Exception: if ``hp.mode`` is neither ``'3d'`` nor ``'2d'``.
    """
    if hp.mode not in ('3d', '2d'):
        raise Exception('no such kind of mode!')

    if not hp.aug:
        # Plain preprocessing: explicit cubic target shape, no randomness.
        return Compose([
            CropOrPad((hp.crop_or_pad_size, ) * 3, padding_mode='reflect'),
            ZNormalization(),
        ])

    # Augmented pipeline.  ToCanonical and RandomMotion are deliberately
    # left out (they were disabled in the original pipeline).
    return Compose([
        # The size is handed over as a bare value here, not as a 3-tuple.
        CropOrPad(hp.crop_or_pad_size, padding_mode='reflect'),
        RandomBiasField(),
        ZNormalization(),
        RandomNoise(),
        RandomFlip(axes=(0, )),
        # Exactly one spatial deformation, picked with the given weights.
        OneOf({
            RandomAffine(): 0.8,
            RandomElasticDeformation(): 0.2,
        }),
    ])
def initialize_transforms_simple(p=0.8):
    """Compose a fixed augmentation pipeline that ends with intensity rescaling.

    Args:
        p: Value forwarded as ``p`` to the flip, motion, bias-field, blur and
            noise transforms.  ``RandomAnisotropy`` and ``RescaleIntensity``
            are built without it.

    Returns:
        A ``tio.Compose`` wrapping the transforms in the order listed below.
    """
    # RandomAffine and RandomElasticDeformation are intentionally not part of
    # this pipeline; the original author noted that elastic deformation
    # slows down the dataloader.
    flip = RandomFlip(axes=(0, 1, 2), flip_probability=1, p=p)
    motion = RandomMotion(
        degrees=10,
        translation=10,
        num_transforms=2,
        image_interpolation='linear',
        p=p,
    )
    anisotropy = RandomAnisotropy(axes=(0, 1, 2), downsampling=2)
    bias_field = RandomBiasField(coefficients=0.5, order=3, p=p)
    blur = RandomBlur(std=(0, 2), p=p)
    noise = RandomNoise(mean=0, std=(0, 5), p=p)
    rescale = RescaleIntensity((0, 255))

    return tio.Compose(
        [flip, motion, anisotropy, bias_field, blur, noise, rescale])
예제 #3
0
    def test_transforms(self):
        """Smoke test: run every transform on its own freshly created sample.

        Nothing is asserted about the outputs; the test only checks that each
        transform can be constructed and applied without raising.
        """
        histogram_landmarks = {
            't1': np.linspace(0, 100, 13),
            't2': np.linspace(0, 100, 13),
        }
        augmentations = (
            RandomFlip(axes=(0, 1, 2), flip_probability=1),
            RandomNoise(),
            RandomBiasField(),
            RandomElasticDeformation(proportion_to_augment=1),
            RandomAffine(),
            RandomMotion(proportion_to_augment=1),
        )
        normalizations = (
            Rescale(),
            ZNormalization(),
            HistogramStandardization(landmarks_dict=histogram_landmarks),
        )
        # Each transform gets an untouched sample so failures stay isolated.
        for candidate in augmentations + normalizations:
            candidate(self.get_sample())
예제 #4
0
 def test_transforms(self):
     """Smoke test: chain all transforms, feeding each output to the next.

     Nothing is asserted about the final result; the test only checks that
     the whole sequence can be applied to one sample without raising.
     """
     histogram_landmarks = {
         't1': np.linspace(0, 100, 13),
         't2': np.linspace(0, 100, 13),
     }
     pipeline = (
         CenterCropOrPad((9, 21, 30)),
         ToCanonical(),
         Resample((1, 1.1, 1.25)),
         RandomFlip(axes=(0, 1, 2), flip_probability=1),
         RandomMotion(proportion_to_augment=1),
         RandomGhosting(proportion_to_augment=1, axes=(0, 1, 2)),
         RandomSpike(),
         RandomNoise(),
         RandomBlur(),
         RandomSwap(patch_size=2, num_iterations=5),
         Lambda(lambda x: 1.5 * x, types_to_apply=INTENSITY),
         RandomBiasField(),
         Rescale((0, 1)),
         ZNormalization(masking_method='label'),
         HistogramStandardization(landmarks_dict=histogram_landmarks),
         RandomElasticDeformation(proportion_to_augment=1),
         RandomAffine(),
         Pad((1, 2, 3, 0, 5, 6)),
         Crop((3, 2, 8, 0, 1, 4)),
     )
     current = self.get_sample()
     for step in pipeline:
         # The output of one step is the input of the next one.
         current = step(current)
예제 #5
0
    def __init__(self, transform1, m1, p1, transform2, m2, p2):
        """Store two (transform, magnitude, probability) choices.

        Args:
            transform1: Name of the first transform; one of ``'flip'``,
                ``'affine'``, ``'noise'``, ``'blur'``, ``'elasticD'``.
            m1: Index (0-9) into the magnitude table of ``transform1``.
            p1: Probability stored for the first transform.
            transform2: Name of the second transform (same choices).
            m2: Index (0-9) into the magnitude table of ``transform2``.
            p2: Probability stored for the second transform.

        Raises:
            KeyError: if a transform name is unknown.
            IndexError: if a magnitude index is outside the 10-entry table.
        """
        # Ten discrete magnitude levels per transform.  'flip' and
        # 'elasticD' have all-zero tables: their factories ignore magnitude.
        magnitude_table = {
            'flip': np.zeros(10),
            'affine': np.linspace(0, 180, 10),
            'noise': np.linspace(0, 0.5, 10),
            'blur': np.arange(10),
            'elasticD': np.zeros(10),
        }

        # Factories taking (magnitude, p); the transform object itself is
        # only created when a stored factory is called later.
        factories = {
            'flip':
            lambda magnitude, p: RandomFlip(p=p),
            'affine':
            lambda magnitude, p: RandomAffine(degrees=(magnitude), p=p),
            'noise':
            lambda magnitude, p: RandomNoise(std=magnitude, p=p),
            'blur':
            lambda magnitude, p: RandomBlur(std=magnitude, p=p),
            'elasticD':
            lambda magnitude, p: RandomElasticDeformation(p=p),
        }

        # First sub-policy: factory, raw inputs and resolved magnitude.
        self.transform1 = factories[transform1]
        self.t1_input = transform1
        self.m1 = magnitude_table[transform1][m1]
        self.m1_input = m1
        self.p1 = p1

        # Second sub-policy, same layout.
        self.transform2 = factories[transform2]
        self.t2_input = transform2
        self.m2 = magnitude_table[transform2][m2]
        self.m2_input = m2
        self.p2 = p2

        self.kappa = 0.0
예제 #6
0
 def _get_default_transforms(self):
     """Return the default pipeline: motion, flip, affine, then rescaling.

     Returns:
         A ``Compose`` of RandomMotion, RandomFlip along axis 1, a
         non-isotropic RandomAffine and RescaleIntensity to (0, 1).
     """
     motion = RandomMotion()
     flip = RandomFlip(axes=(1, ))
     affine = RandomAffine(
         scales=(0.9, 1.2),
         degrees=(10),
         isotropic=False,
         default_pad_value='otsu',
         image_interpolation='bspline',
     )
     rescale = RescaleIntensity((0, 1))
     return Compose([motion, flip, affine, rescale])
예제 #7
0
def get_brats(
        data_root='/scratch/weina/dld_data/brats2019/MICCAI_BraTS_2019_Data_Training/',
        fold=1,
        seed=None,
        **kwargs):
    """Build the training and validation BraTS iterators for one fold.

    Args:
        data_root: Dataset root.  The fold CSVs are read from
            ``<data_root>/IDH_label`` and the volumes from ``<data_root>/all``.
        fold: Fold number selecting ``train_fold_{fold}.csv`` and
            ``val_fold_{fold}.csv``.
        seed: Value reported in the debug log.  ``None`` (the default) is
            resolved when the function is called: the ``torch.distributed``
            rank if the process group is initialized, otherwise 0.
        **kwargs: Accepted for call compatibility; not used here.

    Returns:
        A ``(train, val)`` tuple of ``BratsIter`` objects.
    """
    # Resolve the default at call time.  A default expression in the
    # signature is evaluated once, when the module is imported, which is
    # normally before the process group exists and would pin the seed to 0.
    if seed is None:
        seed = (torch.distributed.get_rank()
                if torch.distributed.is_initialized() else 0)
    logging.debug("BratsIter:: fold = %s, seed = %s", fold, seed)

    # Resampling spacing = full volume extent / desired input extent, per axis.
    d_size, h_size, w_size = 155, 240, 240
    input_size = [7, 223, 223]
    spacing = (d_size / input_size[0], h_size / input_size[1],
               w_size / input_size[2])

    # read_brats_mean returns three values; the third one is not needed here.
    mean, std, _ = read_brats_mean(fold, data_root)
    normalize = transforms.Normalize(mean=mean, std=std)

    # RescaleIntensity, RandomMotion, HistogramStandardization,
    # ZNormalization and CropOrPad are intentionally not part of this
    # pipeline (they were disabled in the original code).
    training_transform = Compose([
        RandomBiasField(),
        RandomNoise(),
        ToCanonical(),
        Resample(spacing),
        RandomFlip(axes=(0, )),
        OneOf({
            RandomAffine(): 0.8,
            RandomElasticDeformation(): 0.2,
        }),
        normalize
    ])
    val_transform = Compose([Resample(spacing), normalize])

    label_dir = os.path.join(data_root, 'IDH_label')
    volume_dir = os.path.join(data_root, 'all')

    train = BratsIter(csv_file=os.path.join(
        label_dir, 'train_fold_{}.csv'.format(fold)),
                      brats_path=volume_dir,
                      brats_transform=training_transform,
                      shuffle=True)

    val = BratsIter(csv_file=os.path.join(label_dir,
                                          'val_fold_{}.csv'.format(fold)),
                    brats_path=volume_dir,
                    brats_transform=val_transform,
                    shuffle=False)
    return train, val
예제 #8
0
def random_augment(x):
    """Augment ``x`` with one transform drawn from five equally weighted ones.

    The candidates are flip, elastic deformation, affine, noise and blur.
    They are wrapped in a ``OneOf`` built with ``p=0.95`` and handed to
    ``augment`` together with ``x``.

    Returns:
        Whatever ``augment`` returns for ``x`` and the ``OneOf`` transform.
    """
    candidate_types = (
        RandomFlip,
        RandomElasticDeformation,
        RandomAffine,
        RandomNoise,
        RandomBlur,
    )
    # Every candidate gets the same weight of 1.
    equal_weights = {make_transform(): 1 for make_transform in candidate_types}

    chooser = OneOf(equal_weights, p=0.95)
    return augment(x, chooser)
예제 #9
0
def predict_majority(model, x, y):
    """Predict on the original data plus five augmented copies and vote.

    For every entry the six predictions (original, flip, elastic
    deformation, affine, noise, blur) are collected and the most frequent
    one is taken as the final prediction.  Per-entry results, the accuracy
    and the quadratic weighted kappa are printed; nothing is returned.

    Usage: predict_majority(model, x_original, y_original)

    Args:
        model: Object exposing ``predict``.
        x: Input samples; reshaped to ``(len(x), 40, 40, 4, 1)``.
        y: Labels; each one is shifted down by 1 and then one-hot encoded
            into 5 classes.
    """
    # Reshape inputs and encode labels.
    x = np.reshape(x, (len(x), 40, 40, 4, 1))
    y = to_categorical([label - 1 for label in y], 5)

    # One augmented copy of the data per transform, created in this order.
    augmenter_types = (
        RandomFlip,
        RandomElasticDeformation,
        RandomAffine,
        RandomNoise,
        RandomBlur,
    )
    augmented_views = [
        augment(x.copy(), make_transform()) for make_transform in augmenter_types
    ]

    y_true = pred_list(y)
    # Column order matches the printed header: None, Flip, ED, Affine,
    # Noise, Blur.
    predictions_per_view = [pred_list(model.predict(x.copy()))]
    for view in augmented_views:
        predictions_per_view.append(pred_list(model.predict(view.copy())))

    y_most = []
    correct = 0
    print(
        '\nEntry Number | Prediction (None, Flip, Elastic Deformation, Affine, Noise, Blur) | Actual'
    )
    for i, truth in enumerate(y_true):
        preds = [view_predictions[i] for view_predictions in predictions_per_view]
        # Most frequent prediction among the six views.
        most = max(set(preds), key=preds.count)
        y_most.append(most)
        print('Entry', i, '| Predictions:', preds, '| Most Occuring:', most,
              '| Correct:', truth)
        if most == truth:
            correct += 1
    print('\nTest Accuracy: ', correct / len(y_true))
    print('Quadratic Weighted Kappa: ',
          cohen_kappa_score(y_true, y_most, weights='quadratic'))
예제 #10
0
파일: data.py 프로젝트: JIiminIT/Torch
def training_network(landmarks, dataset, subjects):
    """Split ``subjects`` 90/10 and wrap both parts in ``SubjectsDataset``.

    Args:
        landmarks: Histogram landmarks used for the ``'mri'`` image.
        dataset: Only its length is used, to compute the split position.
        subjects: Sequence of subjects that is actually sliced.

    Returns:
        ``(training_set, validation_set)``; the training set uses the
        augmenting pipeline, the validation set the deterministic one.
    """
    augmenting_pipeline = Compose([
        ToCanonical(),
        Resample(4),
        CropOrPad((48, 60, 48), padding_mode='reflect'),
        RandomMotion(),
        HistogramStandardization({'mri': landmarks}),
        RandomBiasField(),
        ZNormalization(masking_method=ZNormalization.mean),
        RandomNoise(),
        RandomFlip(axes=(0, )),
        OneOf({
            RandomAffine(): 0.8,
            RandomElasticDeformation(): 0.2,
        }),
    ])

    # Same preprocessing as above, minus every random transform.
    deterministic_pipeline = Compose([
        ToCanonical(),
        Resample(4),
        CropOrPad((48, 60, 48), padding_mode='reflect'),
        HistogramStandardization({'mri': landmarks}),
        ZNormalization(masking_method=ZNormalization.mean),
    ])

    # NOTE(review): the split position comes from len(dataset) but is
    # applied to `subjects`; confirm both always have the same length.
    split_at = int(0.9 * len(dataset))

    training_set = tio.SubjectsDataset(subjects[:split_at],
                                       transform=augmenting_pipeline)
    validation_set = tio.SubjectsDataset(subjects[split_at:],
                                         transform=deterministic_pipeline)

    print('Training set:', len(training_set), 'subjects')
    print('Validation set:', len(validation_set), 'subjects')
    return training_set, validation_set
예제 #11
0
def flip(axes=0, p=1):
    """Return a ``RandomFlip`` built from ``axes`` and ``p``.

    Args:
        axes: Forwarded as ``axes`` (default 0).
        p: Forwarded as ``p`` (default 1).
    """
    flipper = RandomFlip(axes=axes, p=p)
    return flipper
예제 #12
0
def get_data_loader(cfg: DictConfig, _) -> dict:
    """Create train/validation/test data loaders for the configured dataset.

    The dataset named by ``cfg["dataset"]`` is looked up in ``datasets``
    first and, on ``ImportError``, in ``torchio.datasets``.  A number of
    subjects is plotted, ``cfg["size"]`` is set to the spatial shape of the
    first subject, and the dataset is split into test (21 subjects),
    validation (20 %, at least 1) and train (the rest).

    Args:
        cfg: Configuration; the keys read here are ``dataset``,
            ``base_path``, ``plot_number``, ``save_plot_dir``,
            ``patch_size``, ``queue_length`` and ``batch``.
        _: Unused.

    Returns:
        Dict with the keys ``data_loader_train``, ``data_loader_val`` and
        ``data_loader_test``.
    """
    log = logging.getLogger(__name__)

    augmentation = Compose([
        RandomMotion(),
        RandomBiasField(),
        RandomNoise(),
        RandomFlip(axes=(0, )),
    ])

    log.info(f"Data loader selected: {cfg['dataset']}")
    # NOTE(review): a missing attribute makes getattr raise AttributeError;
    # confirm ImportError is really the signal `datasets` gives here.
    try:
        log.info("Attempting to use defined data loader")
        dataset_factory = getattr(datasets, cfg["dataset"])
        dataset = dataset_factory(cfg, augmentation)
    except ImportError:
        log.info(
            "Not a defined data loader... Attempting to use torchio loader")
        dataset_factory = getattr(torchio.datasets, cfg["dataset"])
        dataset = dataset_factory(root=cfg["base_path"],
                                  transform=augmentation,
                                  download=True)

    # Plot a random selection of subjects for visual inspection.
    for subject in random.sample(dataset._subjects, cfg["plot_number"]):
        plot_target = os.path.join(os.environ["OUTPUT_PATH"],
                                   cfg["save_plot_dir"],
                                   subject["subject_id"])
        plot_subject(subject, plot_target)

    sampler = GridSampler(patch_size=cfg["patch_size"])
    first_subject_locations = sampler._compute_locations(
        dataset[0])  # type: ignore
    samples_per_volume = len(first_subject_locations)

    with open_dict(cfg):
        cfg["size"] = dataset[0].spatial_shape

    val_size = max(1, int(0.2 * len(dataset)))
    split_sizes = [21, len(dataset) - val_size - 21, val_size]
    test_set, train_set, val_set = split_dataset(dataset, split_sizes)

    def build_loader(subset):
        # All three loaders share every setting except the subset itself.
        return __create_data_loader(
            subset,
            queue_max_length=samples_per_volume * cfg["queue_length"],
            queue_samples_per_volume=samples_per_volume,
            sampler=sampler,
            verbose=log.level > 0,
            batch_size=cfg["batch"],
        )

    # Dict values are evaluated in order: train, then val, then test.
    return {
        "data_loader_train": build_loader(train_set),
        "data_loader_val": build_loader(val_set),
        "data_loader_test": build_loader(test_set),
    }
예제 #13
0
def compose_transforms() -> Compose:
    """Build the full pipeline: two preprocessing steps, then augmentations.

    A timestamped message is printed before and after the pipeline is
    assembled.  Apart from the arguments spelled out below, every transform
    is created with its default values.

    Returns:
        A ``Compose`` of the preprocessing transforms followed by the
        augmentation transforms.
    """
    print(f"{ctime()}:  Setting up transformations...")

    # Preprocessing options listed by the original author as available in
    # TorchIO:
    #   * Intensity: NormalizationTransform, RescaleIntensity,
    #     ZNormalization, HistogramStandardization
    #   * Spatial: CropOrPad, Crop, Pad, Resample, ToCanonical
    # Only a small default selection is used for now.
    preprocessing_steps = [
        ToCanonical(p=1),
        # Alternatively, use RescaleIntensity.
        ZNormalization(masking_method=None, p=1),
    ]

    # Augmentation options listed by the original author as available in
    # TorchIO:
    #   * Spatial: RandomFlip, RandomAffine, RandomElasticDeformation
    #   * Intensity: RandomMotion, RandomGhosting, RandomSpike,
    #     RandomBiasField, RandomBlur, RandomNoise, RandomSwap
    # RandomElasticDeformation is left out for now because it is the most
    # processing-intensive one.
    augmentation_steps = [
        RandomFlip(axes=(0, 1, 2), flip_probability=0.5),
        # Linear interpolation: the default, a speed/quality compromise.
        RandomAffine(image_interpolation="linear", p=0.8),
        RandomMotion(),
        RandomSpike(),
        RandomBiasField(),
        RandomBlur(),
        RandomNoise(),
    ]

    pipeline = Compose([*preprocessing_steps, *augmentation_steps])
    print(f"{ctime()}:  Transformations registered.")
    return pipeline
예제 #14
0
    # Fragment of a larger definition whose header is not visible here.
    # It builds an affine + flip pipeline, loads the first subject, samples
    # patches from it and saves one batch to disk.
    verbose = True
    patch_size = 128

    batch_size = 4

    # Both ends of each pair are equal, so a single fixed value is passed
    # for the scale and for the rotation (presumably (min, max) ranges --
    # TODO confirm against the RandomAffine version in use).
    scales = (0.75, 0.75)
    degrees = (-5, -5)
    axes = (0, )

    transforms = (
        RandomAffine(scales=scales,
                     degrees=degrees,
                     image_interpolation=Interpolation.BSPLINE,
                     isotropic=False,
                     verbose=verbose),
        RandomFlip(axes, verbose=verbose),
    )
    transform = Compose(transforms)
    # `subjects_paths` is defined outside this fragment.
    subjects_dataset = ImagesDataset(subjects_paths,
                                     transform=transform,
                                     verbose=verbose)
    # Only the first subject is used from here on.
    sample = subjects_dataset[0]

    sampler = LabelSampler(sample, patch_size)
    loader = DataLoader(sampler, batch_size=batch_size)

    # TODO: check that this works as expected, use dummy data
    # islice(loader, 1) stops after the first batch.
    for batch in islice(loader, 1):
        save_batch(batch, '/tmp/batch')
예제 #15
0
# Stand-in for a PyTorch model: returns its input unchanged.
model = lambda x: x

# Training and patch-sampling settings.
num_epochs = 4
patch_size = 128
queue_length = 100
samples_per_volume = 1
batch_size = 2

# Normalization first, then the random augmentations.
transforms = (
    ZNormalization(),
    RandomAffine(scales=(0.9, 1.1), degrees=10),
    RandomNoise(std_range=(0, 0.25)),
    RandomFlip(axes=(0, )),
)
transform = Compose(transforms)

# One subject: two intensity images (T1, T2) and one label map, each
# described by a file path and a TorchIO image type.
one_subject_dict = {
    'T1': {
        'path': '../BRATS2018_crop_renamed/LGG75_T1.nii.gz',
        'type': torchio.INTENSITY,
    },
    'T2': {
        'path': '../BRATS2018_crop_renamed/LGG75_T2.nii.gz',
        'type': torchio.INTENSITY,
    },
    'label': {
        'path': '../BRATS2018_crop_renamed/LGG75_Label.nii.gz',
        'type': torchio.LABEL,
    },
}
예제 #16
0
# Plot the intensity histogram of one z-normalized sample.
# `znormed` and `plot_histogram` are defined outside this snippet.
fig, ax = plt.subplots(dpi=100)
plot_histogram(ax, znormed.mri.data, label='Z-normed', alpha=1)
ax.set_title('Intensity values of one sample after z-normalization')
ax.set_xlabel('Intensity')
ax.grid()

# Training pipeline: preprocessing plus random augmentations that are each
# given a value of 0.2 for p / flip_probability.
training_transform = Compose([
    ToCanonical(),
    #  Resample(4),
    # padding_mode=0 is what is passed now; the older note on this line
    # mentioned 'reflect' and an original size of 112,112,48.
    CropOrPad((112, 112, 48), padding_mode=0),  #reflect , original 112,112,48
    RandomMotion(num_transforms=6, image_interpolation='nearest', p=0.2),
    # Landmarks are keyed by image name; only 'mri' is standardized.
    HistogramStandardization({'mri': landmarks}),
    RandomBiasField(p=0.2),
    RandomBlur(p=0.2),
    ZNormalization(masking_method=ZNormalization.mean),
    # The flip axis is given as an anatomical label, not as an index.
    RandomFlip(axes=['inferior-superior'], flip_probability=0.2),
    #  RandomNoise(std=0.5, p=0.2),
    RandomGhosting(intensity=1.8, p=0.2),
    # Disabled alternatives kept for reference:
    #  RandomNoise(),
    #  RandomFlip(axes=(0,)),
    #  OneOf({
    #      RandomAffine(): 0.8,
    #      RandomElasticDeformation(): 0.2,
    #  }),
])

validation_transform = Compose([
    ToCanonical(),
    #  Resample(4),
    CropOrPad((112, 112, 48), padding_mode=0),  #original 112,112,48
    #  RandomMotion(num_transforms=6, image_interpolation='nearest', p = 0.2),
예제 #17
0
def main():
    """Train the segmentation U-Net on the source and target domains.

    Steps: parse the options, require a CUDA device, create the ``models``
    and ``output`` folders under ``opt.model_dir``, redirect ``sys.stdout``
    to a fresh log file in that folder, build the per-domain subject lists
    and transforms, build the model and call ``train``.

    The log file is closed and ``sys.stdout`` is restored even when
    something raises in between.

    Raises:
        Exception: if no CUDA device is available.
        AssertionError: if a dataset split file does not exist.
    """
    # Local import: the rest of the file is not known to import contextlib.
    from contextlib import redirect_stdout

    opt = parsing_data()

    print("[INFO]Reading data")
    if not torch.cuda.is_available():
        raise Exception(
            "[INFO] No GPU found or Wrong gpu id, please run without --cuda")
    print('[INFO] GPU available.')
    device = torch.device("cuda:0")

    # FOLDERS
    fold_dir = opt.model_dir
    fold_dir_model = os.path.join(fold_dir, 'models')
    os.makedirs(fold_dir_model, exist_ok=True)
    save_path = os.path.join(fold_dir_model, './CP_{}.pth')
    os.makedirs(os.path.join(fold_dir, 'output'), exist_ok=True)

    # LOGGING
    # Never overwrite an earlier log: fall back to the first unused
    # out_<n>.txt when out.txt already exists.
    log_path = os.path.join(fold_dir, 'out.txt')
    if os.path.exists(log_path):
        compt = 0
        while os.path.exists(
                os.path.join(fold_dir, 'out_' + str(compt) + '.txt')):
            compt += 1
        log_path = os.path.join(fold_dir, 'out_' + str(compt) + '.txt')

    # Everything printed from here on goes to the log file.  Both context
    # managers unwind on error, so the file is closed and stdout restored
    # no matter how the block is left.
    with open(log_path, 'w') as f, redirect_stdout(f):
        # SPLITS
        split_path = {
            'source': opt.dataset_split_source,
            'target': opt.dataset_split_target,
        }
        assert os.path.isfile(split_path['source']), 'source file not found'
        assert os.path.isfile(split_path['target']), 'target file not found'

        path_file = {
            'source': opt.path_source,
            'target': opt.path_target,
        }

        list_split = [
            'training',
            'validation',
        ]
        paths_dict = dict()

        for domain in ['source', 'target']:
            # Split CSV without header: column 0 = subject, column 1 = split.
            df_split = pd.read_csv(split_path[domain], header=None)
            list_file = dict()
            for split in list_split:
                list_file[split] = df_split[df_split[1].isin(
                    [split])][0].tolist()

            paths_dict_domain = {split: [] for split in list_split}
            for split in list_split:
                for subject in list_file[split]:
                    # File names are plain string concatenations:
                    # <path><subject><modality>.nii.gz
                    subject_data = [
                        Image(
                            modality,
                            path_file[domain] + subject + modality + '.nii.gz',
                            torchio.INTENSITY)
                        for modality in MODALITIES[domain]
                    ]
                    # Both splits handled here come with a label map.
                    if split in ['training', 'validation']:
                        subject_data.append(
                            Image('label',
                                  path_file[domain] + subject + 'Label.nii.gz',
                                  torchio.LABEL))
                    paths_dict_domain[split].append(Subject(*subject_data))
                print(domain, split, len(paths_dict_domain[split]))
            paths_dict[domain] = paths_dict_domain

        # PREPROCESSING
        transform_training = dict()
        transform_validation = dict()
        for domain in ['source', 'target']:
            transform_training[domain] = Compose((
                ToCanonical(),
                ZNormalization(),
                CenterCropOrPad((144, 192, 48)),
                RandomAffine(scales=(0.9, 1.1), degrees=10),
                RandomNoise(std_range=(0, 0.10)),
                RandomFlip(axes=(0, )),
            ))
            # Validation: same preprocessing without the random transforms.
            transform_validation[domain] = Compose((
                ToCanonical(),
                ZNormalization(),
                CenterCropOrPad((144, 192, 48)),
            ))

        transform = {
            'training': transform_training,
            'validation': transform_validation
        }

        # MODEL
        norm_op_kwargs = {'eps': 1e-5, 'affine': True}
        net_nonlin = nn.LeakyReLU
        net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}

        print("[INFO] Building model")
        model = Generic_UNet(input_modalities=MODALITIES_TARGET,
                             base_num_features=32,
                             num_classes=nb_classes,
                             num_pool=4,
                             num_conv_per_stage=2,
                             feat_map_mul_on_downscale=2,
                             conv_op=torch.nn.Conv3d,
                             norm_op=torch.nn.InstanceNorm3d,
                             norm_op_kwargs=norm_op_kwargs,
                             nonlin=net_nonlin,
                             nonlin_kwargs=net_nonlin_kwargs,
                             convolutional_pooling=False,
                             convolutional_upsampling=False,
                             final_nonlin=torch.nn.Softmax(1))

        print("[INFO] Training")
        train(paths_dict, model, transform, device, save_path, opt)
예제 #18
0
# Wrap the in-memory arrays as TorchIO subjects.  The image kind has to be
# passed as `type=`: an unknown keyword such as `label=` is only stored as
# an extra attribute, which leaves the segmentation typed as an intensity
# image, so the intensity augmentations below (noise, blur) would be
# applied to it as well.
train_subject = torchio.Subject(
    data=torchio.Image(tensor=torch.from_numpy(train_data),
                       type=torchio.INTENSITY),
    label=torchio.Image(tensor=torch.from_numpy(train_seg),
                        type=torchio.LABEL),
)
valid_subject = torchio.Subject(
    data=torchio.Image(tensor=torch.from_numpy(valid_data),
                       type=torchio.INTENSITY),
    label=torchio.Image(tensor=torch.from_numpy(valid_seg),
                        type=torchio.LABEL),
)
# Define the transforms for the set of training patches
training_transform = Compose([
    RandomNoise(p=0.2),
    RandomFlip(axes=(0, 1, 2)),
    RandomBlur(p=0.2),
    # One spatial deformation, affine or elastic, weighted 0.8 / 0.2.
    OneOf({
        RandomAffine(): 0.8,
        RandomElasticDeformation(): 0.2,
    }, p=0.5),  # Changed from p=0.75 24/6/20
])
# Create the datasets; only the training set is augmented.
training_dataset = torchio.ImagesDataset(
    [train_subject], transform=training_transform)

validation_dataset = torchio.ImagesDataset(
    [valid_subject])
# Define the queue of sampled patches for training and validation
sampler = torchio.data.UniformSampler(PATCH_SIZE)
patches_training_set = torchio.Queue(
 def get_torchio_transformer(mask=False):
     """Return a ``RandomFlip`` built from the surrounding scope's settings.

     ``axes``, ``flip_probability``, ``p`` and ``seed`` are not parameters;
     they are read from the enclosing scope.  ``mask`` is accepted but not
     used.
     """
     flip_settings = dict(
         axes=axes,
         flip_probability=flip_probability,
         p=p,
         seed=seed,
     )
     return RandomFlip(**flip_settings)
예제 #20
0
def flip(parameters):
    """Return a ``RandomFlip`` configured from a parameter mapping.

    Args:
        parameters: Mapping providing ``"axis"`` (forwarded as ``axes``) and
            ``"probability"`` (forwarded as ``p``).
    """
    flip_axes = parameters["axis"]
    probability = parameters["probability"]
    return RandomFlip(axes=flip_axes, p=probability)
 def get_torchio_transformer(mask=False):
     """Return a ``RandomFlip`` built from the surrounding scope's settings.

     ``axes``, ``flip_probability``, ``p``, ``seed`` and ``is_tensor`` are
     read from the enclosing scope and passed positionally, in that order.
     ``mask`` is accepted but not used.
     """
     positional_settings = (axes, flip_probability, p, seed, is_tensor)
     return RandomFlip(*positional_settings)
예제 #22
0
def main():
    """Train the 2.5D U-Net on the source and target domains.

    Steps: parse the options, require a CUDA device, create the ``models``
    and ``output`` folders under ``opt.model_dir``, create a fresh log file
    there, print the hyperparameters, build the per-domain subject lists and
    transforms, build the model and loss, and call ``train``.

    Redirecting ``sys.stdout`` to the log file is currently disabled, so the
    file is created but nothing is written to it here.  It is closed even
    when something raises.

    Raises:
        Exception: if no CUDA device is available.
        AssertionError: if a dataset split file does not exist.
    """
    opt = parsing_data()

    print("[INFO] Reading data")
    if not torch.cuda.is_available():
        raise Exception(
            "[INFO] No GPU found or Wrong gpu id, please run without --cuda")
    print('[INFO] GPU available.')
    device = torch.device("cuda:0")

    # FOLDERS
    fold_dir = opt.model_dir
    fold_dir_model = os.path.join(fold_dir, 'models')
    os.makedirs(fold_dir_model, exist_ok=True)
    save_path = os.path.join(fold_dir_model, './CP_{}.pth')

    output_path = os.path.join(fold_dir, 'output')
    os.makedirs(output_path, exist_ok=True)
    output_path = os.path.join(output_path, 'output_{}.nii.gz')

    # LOGGING
    # Never overwrite an earlier log: fall back to the first unused
    # out_<n>.txt when out.txt already exists.
    orig_stdout = sys.stdout
    log_path = os.path.join(fold_dir, 'out.txt')
    if os.path.exists(log_path):
        compt = 0
        while os.path.exists(
                os.path.join(fold_dir, 'out_' + str(compt) + '.txt')):
            compt += 1
        log_path = os.path.join(fold_dir, 'out_' + str(compt) + '.txt')
    f = open(log_path, 'w')
    #sys.stdout = f

    # The try/finally guarantees that the log file handle is released.
    try:
        print("[INFO] Hyperparameters")
        print('Alpha: {}'.format(opt.alpha))
        print('Beta: {}'.format(opt.beta))
        print('Beta_DA: {}'.format(opt.beta_da))
        print('Weight Reg: {}'.format(opt.weight_crf))

        # SPLITS
        split_path = {
            'source': opt.dataset_split_source,
            'target': opt.dataset_split_target,
        }
        assert os.path.isfile(split_path['source']), 'source file not found'
        assert os.path.isfile(split_path['target']), 'target file not found'

        path_file = {
            'source': opt.path_source,
            'target': opt.path_target,
        }

        list_split = ['training', 'validation', 'inference']
        paths_dict = dict()

        for domain in ['source', 'target']:
            # Split CSV without header: column 0 = subject, column 1 = split.
            df_split = pd.read_csv(split_path[domain], header=None)
            list_file = dict()
            for split in list_split:
                list_file[split] = df_split[df_split[1].isin(
                    [split])][0].tolist()

            # Inference also runs on the validation subjects.
            list_file['inference'] += list_file['validation']

            paths_dict_domain = {split: [] for split in list_split}
            for split in list_split:
                for subject in list_file[split]:
                    # File names are plain string concatenations:
                    # <path><subject><modality>.nii.gz
                    subject_data = [
                        Image(
                            modality,
                            path_file[domain] + subject + modality + '.nii.gz',
                            torchio.INTENSITY)
                        for modality in MODALITIES[domain]
                    ]
                    # Supervision differs per domain: a segmentation for
                    # the source, a scribble map for the target.  Inference
                    # subjects get neither.
                    if split in ['training', 'validation']:
                        if domain == 'source':
                            subject_data.append(
                                Image(
                                    'label',
                                    path_file[domain] + subject +
                                    't1_seg.nii.gz', torchio.LABEL))
                        else:
                            subject_data.append(
                                Image(
                                    'scribble', path_file[domain] + subject +
                                    't2scribble_cor.nii.gz', torchio.LABEL))
                    paths_dict_domain[split].append(Subject(*subject_data))
                print(domain, split, len(paths_dict_domain[split]))
            paths_dict[domain] = paths_dict_domain

        # PREPROCESSING
        transform_training = dict()
        transform_validation = dict()

        for domain in ['source', 'target']:
            transformations = (
                ToCanonical(),
                ZNormalization(),
                CenterCropOrPad((288, 128, 48)),
                RandomAffine(scales=(0.9, 1.1), degrees=10),
                RandomNoise(std_range=(0, 0.10)),
                RandomFlip(axes=(0, )),
            )
            transform_training[domain] = Compose(transformations)

        # Validation: same preprocessing without the random transforms.
        for domain in ['source', 'target']:
            transformations = (ToCanonical(), ZNormalization(),
                               CenterCropOrPad((288, 128, 48)))
            transform_validation[domain] = Compose(transformations)

        transform = {
            'training': transform_training,
            'validation': transform_validation
        }

        # MODEL
        norm_op_kwargs = {'eps': 1e-5, 'affine': True}
        net_nonlin = nn.LeakyReLU
        net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}

        print("[INFO] Building model")
        model = UNet2D5(input_channels=1,
                        base_num_features=16,
                        num_classes=NB_CLASSES,
                        num_pool=4,
                        conv_op=nn.Conv3d,
                        norm_op=nn.InstanceNorm3d,
                        norm_op_kwargs=norm_op_kwargs,
                        nonlin=net_nonlin,
                        nonlin_kwargs=net_nonlin_kwargs)

        print("[INFO] Training")
        #criterion = DC_and_CE_loss({}, {})
        criterion = DC_CE(NB_CLASSES)

        train(paths_dict, model, transform, criterion, device, save_path, opt)
    finally:
        f.close()