# Example 1
def train_net(net, net_name, nlabels):
    """Train ``net`` on 3D patches sampled from the training images.

    The model is first looked up on disk (``net_name + '.mod'``); if it is
    not found, the training centers are shuffled, split into ``dfactor``
    rounds, and the network is fitted round by round, then saved.

    :param net: compiled Keras-style model to train (used when no saved
        model is found on disk).
    :param net_name: path prefix for the serialized model file.
    :param nlabels: number of output labels passed to the patch loader.
    """
    options = parse_inputs()
    c = color_codes()
    # Data stuff
    train_data, train_labels = get_names_from_path(options)
    # Prepare the net architecture parameters
    dfactor = options['dfactor']
    # Prepare the net hyperparameters
    epochs = options['epochs']
    patch_width = options['patch_width']
    patch_size = (patch_width, patch_width, patch_width)
    batch_size = options['batch_size']
    conv_blocks = options['conv_blocks']
    conv_width = options['conv_width']
    kernel_size_list = conv_width if isinstance(
        conv_width, list) else [conv_width] * conv_blocks
    balanced = options['balanced']
    val_rate = options['val_rate']
    preload = options['preload']
    # Width of the fully connected layer after all the valid convolutions
    # (each kernel of size k shrinks the patch by k - 1 voxels per axis).
    fc_width = patch_width - sum(kernel_size_list) + conv_blocks
    fc_shape = (fc_width, ) * 3

    # BUG FIX: the original code loaded from ``net_name + '.md'`` while
    # saving to ``net_name + '.mod'``, so a previously trained model could
    # never be found and training always restarted from scratch. A single
    # path is now used for both operations.
    model_path = net_name + '.mod'
    try:
        net = load_model(model_path)
    except IOError:
        centers = np.random.permutation(
            get_cnn_centers(train_data[:, 0], train_labels, balanced=balanced))
        print(' '.join([''] * 15) + c['g'] + 'Total number of centers = ' +
              c['b'] + '(%d centers)' % (len(centers)) + c['nc'])
        # Train in ``dfactor`` rounds, each using an interleaved subset of
        # the shuffled centers, to keep the patch arrays small in memory.
        for i in range(dfactor):
            print(' '.join([''] * 16) + c['g'] + 'Round ' + c['b'] + '%d' %
                  (i + 1) + c['nc'] + c['g'] + '/%d' % dfactor + c['nc'])
            batch_centers = centers[i::dfactor]
            print(' '.join([''] * 16) + c['g'] + 'Loading data ' + c['b'] +
                  '(%d centers)' % (len(batch_centers)) + c['nc'])
            x, y = load_patches_train(
                image_names=train_data,
                label_names=train_labels,
                batch_centers=batch_centers,
                size=patch_size,
                fc_shape=fc_shape,
                nlabels=nlabels,
                preload=preload,
            )

            print(' '.join([''] * 16) + c['g'] + 'Training the model for ' +
                  c['b'] +
                  '(%d parameters)' % net.count_trainable_parameters() +
                  c['nc'])

            net.fit(x,
                    y,
                    batch_size=batch_size,
                    validation_split=val_rate,
                    epochs=epochs)

    net.save(model_path)
# Example 2
def main():
    """Run n-fold cross-validated training and testing of a 3D CNN for
    BraTS 2017 brain tumor segmentation.

    For each fold: reuse a saved Keras model if present, otherwise build a
    sequential VGG-like 3D CNN, train it with patch generators, save it,
    and then segment every test image, reporting DSC per label.

    NOTE(review): this is Python 2 code (`izip`, `xrange`, print statement,
    integer division) — porting to Python 3 would change semantics of the
    `/` divisions used for step counts.
    """
    options = parse_inputs()
    c = color_codes()  # terminal color escape codes for pretty printing

    # Prepare the net architecture parameters
    sequential = options['sequential']
    dfactor = options['dfactor']
    # Prepare the net hyperparameters
    num_classes = 5
    epochs = options['epochs']
    padding = options['padding']
    patch_width = options['patch_width']
    patch_size = (patch_width, patch_width, patch_width)
    batch_size = options['batch_size']
    dense_size = options['dense_size']
    conv_blocks = options['conv_blocks']
    n_filters = options['n_filters']
    # Broadcast a single filter count to all convolutional blocks.
    filters_list = n_filters if len(n_filters) > 1 else n_filters * conv_blocks
    conv_width = options['conv_width']
    # Same broadcasting for the kernel sizes.
    kernel_size_list = conv_width if isinstance(
        conv_width, list) else [conv_width] * conv_blocks
    balanced = options['balanced']
    # Data loading parameters
    preload = options['preload']
    queue = options['queue']

    # Prepare the sufix that will be added to the results for the net and images
    path = options['dir_name']
    filters_s = 'n'.join(['%d' % nf for nf in filters_list])
    conv_s = 'c'.join(['%d' % cs for cs in kernel_size_list])
    s_s = '.s' if sequential else '.f'
    ub_s = '.ub' if not balanced else ''
    params_s = (ub_s, dfactor, s_s, patch_width, conv_s, filters_s, dense_size,
                epochs, padding)
    sufix = '%s.D%d%s.p%d.c%s.n%s.d%d.e%d.pad_%s.' % params_s
    # Number of input channels = number of selected MRI modalities.
    n_channels = np.count_nonzero([
        options['use_flair'], options['use_t2'], options['use_t1'],
        options['use_t1ce']
    ])

    print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' +
          'Starting cross-validation' + c['nc'])
    # N-fold cross validation main loop (we'll do 2 training iterations with testing for each patient)
    data_names, label_names = get_names_from_path(options)
    folds = options['folds']
    fold_generator = izip(
        nfold_cross_validation(data_names, label_names, n=folds,
                               val_data=0.25), xrange(folds))
    dsc_results = list()
    for (train_data, train_labels, val_data, val_labels, test_data,
         test_labels), i in fold_generator:
        print(
            c['c'] + '[' + strftime("%H:%M:%S") + ']  ' + c['nc'] +
            'Fold %d/%d: ' % (i + 1, folds) + c['g'] +
            'Number of training/validation/testing images (%d=%d/%d=%d/%d)' %
            (len(train_data), len(train_labels), len(val_data),
             len(val_labels), len(test_data)) + c['nc'])
        # Prepare the data relevant to the leave-one-out (subtract the patient from the dataset and set the path)
        # Also, prepare the network
        net_name = os.path.join(
            path, 'baseline-brats2017.fold%d' % i + sufix + 'mdl')

        # First we check that we did not train for that patient, in order to save time
        try:
            # net_name_before =  os.path.join(path,'baseline-brats2017.fold0.D500.f.p13.c3c3c3c3c3.n32n32n32n32n32.d256.e1.pad_valid.mdl')
            net = keras.models.load_model(net_name)
        except IOError:
            print '==============================================================='
            # NET definition using Keras
            train_centers = get_cnn_centers(train_data[:, 0],
                                            train_labels,
                                            balanced=balanced)
            val_centers = get_cnn_centers(val_data[:, 0],
                                          val_labels,
                                          balanced=balanced)
            # Python 2 integer division: centers are thinned by ``dfactor``.
            train_samples = len(train_centers) / dfactor
            val_samples = len(val_centers) / dfactor
            print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] +
                  'Creating and compiling the model ' + c['b'] +
                  '(%d samples)' % train_samples + c['nc'])
            # -(-a / b) is the ceiling-division idiom (Python 2 floor div).
            train_steps_per_epoch = -(-train_samples / batch_size)
            val_steps_per_epoch = -(-val_samples / batch_size)
            input_shape = (n_channels, ) + patch_size

            # This architecture is based on the functional Keras API to introduce 3 output paths:
            # - Whole tumor segmentation
            # - Core segmentation (including whole tumor)
            # - Whole segmentation (tumor, core and enhancing parts)
            # The idea is to let the network work on the three parts to improve the multiclass segmentation.
            # merged_inputs = Input(shape=(4,) + patch_size, name='merged_inputs')
            # flair = merged_inputs

            # NOTE(review): despite the comment above, the model actually
            # built here is a plain Sequential VGG-style 3D CNN with a
            # single softmax output; the hard-coded 4 input channels may
            # disagree with ``n_channels`` computed above — verify.
            model = Sequential()
            model.add(
                Conv3D(64, (3, 3, 3),
                       strides=1,
                       padding='same',
                       activation='relu',
                       data_format='channels_first',
                       input_shape=(4, options['patch_width'],
                                    options['patch_width'],
                                    options['patch_width'])))
            model.add(
                Conv3D(64, (3, 3, 3),
                       strides=1,
                       padding='same',
                       activation='relu',
                       data_format='channels_first'))
            model.add(
                MaxPooling3D(pool_size=(3, 3, 3),
                             strides=2,
                             data_format='channels_first'))

            model.add(
                Conv3D(128, (3, 3, 3),
                       strides=1,
                       padding='same',
                       activation='relu',
                       data_format='channels_first'))
            model.add(
                Conv3D(128, (3, 3, 3),
                       strides=1,
                       padding='same',
                       activation='relu',
                       data_format='channels_first'))
            model.add(
                MaxPooling3D(pool_size=(3, 3, 3),
                             strides=2,
                             data_format='channels_first'))

            model.add(Flatten())

            model.add(Dense(256, activation='relu'))

            model.add(Dropout(0.5))

            model.add(Dense(num_classes, activation='softmax'))

            net = model

            # net_name_before =  os.path.join(path,'baseline-brats2017.fold0.D500.f.p13.c3c3c3c3c3.n32n32n32n32n32.d256.e1.pad_valid.mdl')
            # net = keras.models.load_model(net_name_before)
            net.compile(optimizer='sgd',
                        loss='categorical_crossentropy',
                        metrics=['accuracy'])

            print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] +
                  'Training the model with a generator for ' + c['b'] +
                  '(%d parameters)' % net.count_params() + c['nc'])
            print(net.summary())

            # Train with on-the-fly patch generators to avoid loading every
            # patch into memory at once.
            net.fit_generator(
                generator=load_patch_batch_train(
                    image_names=train_data,
                    label_names=train_labels,
                    centers=train_centers,
                    batch_size=batch_size,
                    size=patch_size,
                    # fc_shape = patch_size,
                    nlabels=num_classes,
                    dfactor=dfactor,
                    preload=preload,
                    split=not sequential,
                    datatype=np.float32),
                validation_data=load_patch_batch_train(
                    image_names=val_data,
                    label_names=val_labels,
                    centers=val_centers,
                    batch_size=batch_size,
                    size=patch_size,
                    # fc_shape = patch_size,
                    nlabels=num_classes,
                    dfactor=dfactor,
                    preload=preload,
                    split=not sequential,
                    datatype=np.float32),
                # workers=queue,
                steps_per_epoch=train_steps_per_epoch,
                validation_steps=val_steps_per_epoch,
                max_q_size=queue,
                epochs=epochs)
            net.save(net_name)

        # Then we test the net.
        for p, gt_name in zip(test_data, test_labels):
            p_name = p[0].rsplit('/')[-2]
            patient_path = '/'.join(p[0].rsplit('/')[:-1])
            outputname = os.path.join(patient_path,
                                      'deep-brats17' + sufix + 'test.nii.gz')
            gt_nii = load_nii(gt_name)
            gt = np.copy(gt_nii.get_data()).astype(dtype=np.uint8)
            try:
                # Skip this patient if the output segmentation already exists.
                load_nii(outputname)
            except IOError:
                roi_nii = load_nii(p[0])
                roi = roi_nii.get_data().astype(dtype=np.bool)
                centers = get_mask_voxels(roi)
                test_samples = np.count_nonzero(roi)
                image = np.zeros_like(roi).astype(dtype=np.uint8)
                print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] +
                      '<Creating the probability map ' + c['b'] + p_name +
                      c['nc'] + c['g'] + ' (%d samples)>' % test_samples +
                      c['nc'])
                test_steps_per_epoch = -(-test_samples / batch_size)
                y_pr_pred = net.predict_generator(
                    generator=load_patch_batch_generator_test(
                        image_names=p,
                        centers=centers,
                        batch_size=batch_size,
                        size=patch_size,
                        preload=preload,
                    ),
                    steps=test_steps_per_epoch,
                    max_q_size=queue)
                # Split the list of voxel centers into coordinate arrays.
                [x, y, z] = np.stack(centers, axis=1)

                if not sequential:
                    # Multi-output net: first output is the whole-tumor mask,
                    # last output is the full multiclass prediction.
                    tumor = np.argmax(y_pr_pred[0], axis=1)
                    y_pr_pred = y_pr_pred[-1]
                    roi = np.zeros_like(roi).astype(dtype=np.uint8)
                    roi[x, y, z] = tumor
                    roi_nii.get_data()[:] = roi
                    roiname = os.path.join(
                        patient_path,
                        'deep-brats17' + sufix + 'test.roi.nii.gz')
                    roi_nii.to_filename(roiname)

                y_pred = np.argmax(y_pr_pred, axis=1)

                image[x, y, z] = y_pred
                # Post-processing (Basically keep the biggest connected region)
                image = get_biggest_region(image)
                labels = np.unique(gt.flatten())
                # DSC per non-background label against the ground truth.
                results = (p_name, ) + tuple(
                    [dsc_seg(gt == l, image == l) for l in labels[1:]])
                text = 'Subject %s DSC: ' + '/'.join(
                    ['%f' for _ in labels[1:]])
                print(text % results)
                dsc_results.append(results)

                print(c['g'] + '                   -- Saving image ' + c['b'] +
                      outputname + c['nc'])
                roi_nii.get_data()[:] = image
                roi_nii.to_filename(outputname)
# Example 3
def main():
    """Train and evaluate CNN vs GAN white-matter-hyperintensity (WMH)
    segmentation networks, with and without a DSC training objective.

    A shared set of training patches is extracted once; then, for every
    test case, four networks (cnn, gan, cnn_dsc, gan_dsc) are either
    loaded from previously saved NIfTI outputs or trained and applied.
    Binary and probabilistic DSC scores are printed per case and averaged
    at the end.
    """
    options = parse_inputs()
    c = color_codes()  # terminal color escape codes for pretty printing

    # Prepare the net hyperparameters
    epochs = options['epochs']
    patch_width = options['patch_width']
    patch_size = (patch_width, patch_width, patch_width)
    dense_size = options['dense_size']
    conv_blocks = options['conv_blocks']
    n_filters = options['n_filters']
    # Broadcast single values to all convolutional blocks.
    filters_list = n_filters if len(n_filters) > 1 else n_filters * conv_blocks
    conv_width = options['conv_width']
    kernel_size_list = conv_width if isinstance(conv_width, list) else [conv_width] * conv_blocks
    balanced = options['balanced']
    # Data loading parameters
    downsample = options['downsample']
    preload = options['preload']
    shuffle = options['shuffle']

    # Prepare the sufix that will be added to the results for the net and images
    filters_s = 'n'.join(['%d' % nf for nf in filters_list])
    conv_s = 'c'.join(['%d' % cs for cs in kernel_size_list])
    unbalanced_s = '.ub' if not balanced else ''
    shuffle_s = '.s' if shuffle else ''
    params_s = (unbalanced_s, shuffle_s, patch_width, conv_s, filters_s, dense_size, downsample)
    sufix = '%s%s.p%d.c%s.n%s.d%d.D%d' % params_s
    preload_s = ' (with %spreloading%s%s)' % (c['b'], c['nc'], c['c']) if preload else ''

    print('%s[%s] Starting training%s%s' % (c['c'], strftime("%H:%M:%S"), preload_s, c['nc']))
    train_data, _ = get_names_from_path(options)
    test_data, test_labels = get_names_from_path(options, False)

    # One input channel per image modality in the training data.
    input_shape = (train_data.shape[1],) + patch_size

    dsc_results = list()
    dsc_results_pr = list()

    train_data, train_labels = get_names_from_path(options)
    # Shuffle the candidate centers, then thin them by ``downsample``.
    centers_s = np.random.permutation(
        get_cnn_centers(train_data[:, 0], train_labels, balanced=balanced)
    )[::downsample]
    # Extract the (shared) training patches once for all test cases.
    x_seg, y_seg = load_patches_ganseg_by_batches(
        image_names=train_data,
        label_names=train_labels,
        source_centers=centers_s,
        size=patch_size,
        nlabels=2,
        preload=preload,
    )

    for i, (p, gt_name) in enumerate(zip(test_data, test_labels)):
        p_name = p[0].rsplit('/')[-3]
        patient_path = '/'.join(p[0].rsplit('/')[:-1])
        print('%s[%s] %sCase %s%s%s%s%s (%d/%d):%s' % (
            c['c'], strftime("%H:%M:%S"), c['nc'],
            c['c'], c['b'], p_name, c['nc'],
            c['c'], i + 1, len(test_data), c['nc']
        ))

        # NO DSC objective
        image_cnn_name = os.path.join(patient_path, p_name + '.cnn.test%s.e%d' % (shuffle_s, epochs))
        image_gan_name = os.path.join(patient_path, p_name + '.gan.test%s.e%d' % (shuffle_s, epochs))
        # DSC objective
        image_cnn_dsc_name = os.path.join(patient_path, p_name + '.dsc-cnn.test%s.e%d' % (shuffle_s, epochs))
        image_gan_dsc_name = os.path.join(patient_path, p_name + '.dsc-gan.test%s.e%d' % (shuffle_s, epochs))
        try:
            # Reuse previously computed segmentations if all eight NIfTI
            # outputs (mask + probability map per network) already exist.
            # NO DSC objective
            image_cnn = load_nii(image_cnn_name + '.nii.gz').get_data()
            image_cnn_pr = load_nii(image_cnn_name + '.pr.nii.gz').get_data()
            image_gan = load_nii(image_gan_name + '.nii.gz').get_data()
            image_gan_pr = load_nii(image_gan_name + '.pr.nii.gz').get_data()
            # DSC objective
            image_cnn_dsc = load_nii(image_cnn_dsc_name + '.nii.gz').get_data()
            image_cnn_dsc_pr = load_nii(image_cnn_dsc_name + '.pr.nii.gz').get_data()
            image_gan_dsc = load_nii(image_gan_dsc_name + '.nii.gz').get_data()
            image_gan_dsc_pr = load_nii(image_gan_dsc_name + '.pr.nii.gz').get_data()
        except IOError:
            # Lesion segmentation
            # Adversarial loss weight shared by both GAN variants; it is
            # presumably annealed inside train_nets — confirm there.
            adversarial_w = K.variable(0)
            # NO DSC objective
            cnn, gan, gan_test = get_wmh_nets(
                input_shape=input_shape,
                filters_list=filters_list,
                kernel_size_list=kernel_size_list,
                dense_size=dense_size,
                lambda_var=adversarial_w
            )
            # DSC objective
            cnn_dsc, gan_dsc, gan_dsc_test = get_wmh_nets(
                input_shape=input_shape,
                filters_list=filters_list,
                kernel_size_list=kernel_size_list,
                dense_size=dense_size,
                lambda_var=adversarial_w,
                dsc_obj=True
            )
            # Train all four networks on the shared patch set.
            train_nets(
                gan=gan,
                gan_dsc=gan_dsc,
                cnn=cnn,
                cnn_dsc=cnn_dsc,
                p=p,
                x=x_seg,
                y=y_seg,
                name='wmh2017' + sufix,
                adversarial_w=adversarial_w
            )
            # NO DSC objective
            image_cnn = test_net(cnn, p, image_cnn_name)
            image_cnn_pr = load_nii(image_cnn_name + '.pr.nii.gz').get_data()
            image_gan = test_net(gan_test, p, image_gan_name)
            image_gan_pr = load_nii(image_gan_name + '.pr.nii.gz').get_data()
            # DSC objective
            image_cnn_dsc = test_net(cnn_dsc, p, image_cnn_dsc_name)
            image_cnn_dsc_pr = load_nii(image_cnn_dsc_name + '.pr.nii.gz').get_data()
            image_gan_dsc = test_net(gan_dsc_test, p, image_gan_dsc_name)
            image_gan_dsc_pr = load_nii(image_gan_dsc_name + '.pr.nii.gz').get_data()
        # Binarize the segmentation maps for the hard DSC metric.
        # NO DSC objective
        seg_cnn = image_cnn.astype(np.bool)
        seg_gan = image_gan.astype(np.bool)
        # DSC objective
        seg_cnn_dsc = image_cnn_dsc.astype(np.bool)
        seg_gan_dsc = image_gan_dsc.astype(np.bool)

        seg_gt = load_nii(gt_name).get_data()
        # Label 2 in the ground truth is excluded from evaluation
        # (presumably an "other pathology"/ignore region — confirm).
        not_roi = np.logical_not(seg_gt == 2)

        # Hard and probabilistic DSC against label 1, masked to the ROI.
        results_cnn_dsc = dsc_seg(seg_gt == 1, np.logical_and(seg_cnn_dsc, not_roi))
        results_cnn_dsc_pr = probabilistic_dsc_seg(seg_gt == 1, image_cnn_dsc_pr * not_roi)
        results_cnn = dsc_seg(seg_gt == 1, np.logical_and(seg_cnn, not_roi))
        results_cnn_pr = probabilistic_dsc_seg(seg_gt == 1, image_cnn_pr * not_roi)

        results_gan_dsc = dsc_seg(seg_gt == 1, np.logical_and(seg_gan_dsc, not_roi))
        results_gan_dsc_pr = probabilistic_dsc_seg(seg_gt == 1, image_gan_dsc_pr * not_roi)
        results_gan = dsc_seg(seg_gt == 1, np.logical_and(seg_gan, not_roi))
        results_gan_pr = probabilistic_dsc_seg(seg_gt == 1, image_gan_pr * not_roi)

        whites = ''.join([' '] * 14)
        print('%sCase %s%s%s%s %sCNN%s vs %sGAN%s DSC: %s%f%s (%s%f%s) vs %s%f%s (%s%f%s)' % (
            whites, c['c'], c['b'], p_name, c['nc'],
            c['lgy'], c['nc'],
            c['y'], c['nc'],
            c['lgy'], results_cnn_dsc, c['nc'],
            c['lgy'], results_cnn, c['nc'],
            c['y'], results_gan_dsc, c['nc'],
            c['y'], results_gan, c['nc']
        ))
        print('%sCase %s%s%s%s %sCNN%s vs %sGAN%s DSC Pr: %s%f%s (%s%f%s) vs %s%f%s (%s%f%s)' % (
            whites, c['c'], c['b'], p_name, c['nc'],
            c['lgy'], c['nc'],
            c['y'], c['nc'],
            c['lgy'], results_cnn_dsc_pr, c['nc'],
            c['lgy'], results_cnn_pr, c['nc'],
            c['y'], results_gan_dsc_pr, c['nc'],
            c['y'], results_gan_pr, c['nc']
        ))

        dsc_results.append((results_cnn_dsc, results_cnn, results_gan_dsc, results_gan))
        dsc_results_pr.append((results_cnn_dsc_pr, results_cnn_pr, results_gan_dsc_pr, results_gan_pr))

    # Average the per-case scores over the whole test set.
    final_dsc = tuple(np.mean(dsc_results, axis=0))
    final_dsc_pr = tuple(np.mean(dsc_results_pr, axis=0))
    print('Final results DSC: %s%f%s (%s%f%s) vs %s%f%s (%s%f%s)' % (
        c['lgy'], final_dsc[0], c['nc'],
        c['lgy'], final_dsc[1], c['nc'],
        c['y'], final_dsc[2], c['nc'],
        c['y'], final_dsc[3], c['nc']
    ))
    print('Final results DSC Pr: %s%f%s (%s%f%s) vs %s%f%s (%s%f%s)' % (
        c['lgy'], final_dsc_pr[0], c['nc'],
        c['lgy'], final_dsc_pr[1], c['nc'],
        c['y'], final_dsc_pr[2], c['nc'],
        c['y'], final_dsc_pr[3], c['nc']
    ))
def main():
    options = parse_inputs()
    c = color_codes()

    # Prepare the net architecture parameters
    sequential = options['sequential']
    dfactor = options['dfactor']
    # Prepare the net hyperparameters
    num_classes = 5
    epochs = options['epochs']
    padding = options['padding']
    patch_width = options['patch_width']
    patch_size = (patch_width, patch_width, patch_width)
    batch_size = options['batch_size']
    dense_size = options['dense_size']
    conv_blocks = options['conv_blocks']
    n_filters = options['n_filters']
    filters_list = n_filters if len(n_filters) > 1 else n_filters * conv_blocks
    conv_width = options['conv_width']
    kernel_size_list = conv_width if isinstance(
        conv_width, list) else [conv_width] * conv_blocks
    balanced = options['balanced']
    recurrent = options['recurrent']
    # Data loading parameters
    preload = options['preload']
    queue = options['queue']

    # Prepare the sufix that will be added to the results for the net and images
    path = options['dir_name']
    filters_s = 'n'.join(['%d' % nf for nf in filters_list])
    conv_s = 'c'.join(['%d' % cs for cs in kernel_size_list])
    s_s = '.s' if sequential else '.f'
    ub_s = '.ub' if not balanced else ''
    params_s = (ub_s, dfactor, s_s, patch_width, conv_s, filters_s, dense_size,
                epochs, padding)
    sufix = '%s.D%d%s.p%d.c%s.n%s.d%d.e%d.pad_%s.' % params_s
    n_channels = np.count_nonzero([
        options['use_flair'], options['use_t2'], options['use_t1'],
        options['use_t1ce']
    ])

    print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' +
          'Starting cross-validation' + c['nc'])
    # N-fold cross validation main loop (we'll do 2 training iterations with testing for each patient)
    data_names, label_names = get_names_from_path(options)
    folds = options['folds']
    fold_generator = izip(
        nfold_cross_validation(data_names, label_names, n=folds,
                               val_data=0.25), xrange(folds))
    dsc_results = list()
    for (train_data, train_labels, val_data, val_labels, test_data,
         test_labels), i in fold_generator:
        print(c['c'] + '[' + strftime("%H:%M:%S") + ']  ' + c['nc'] +
              'Fold %d/%d: ' % (i + 1, folds) + c['g'] +
              'Number of training/validation/testing images (%d=%d/%d=%d/%d)' %
              (len(train_data), len(train_labels), len(val_data),
               len(val_labels), len(test_data)) + c['nc'])
        # Prepare the data relevant to the leave-one-out (subtract the patient from the dataset and set the path)
        # Also, prepare the network
        net_name = os.path.join(
            path, 'baseline-brats2017.fold%d' % i + sufix + 'mdl')

        # First we check that we did not train for that patient, in order to save time
        try:
            net = keras.models.load_model(net_name)
        except IOError:
            # NET definition using Keras
            train_centers = get_cnn_centers(train_data[:, 0],
                                            train_labels,
                                            balanced=balanced)
            val_centers = get_cnn_centers(val_data[:, 0],
                                          val_labels,
                                          balanced=balanced)
            train_samples = len(train_centers) / dfactor
            val_samples = len(val_centers) / dfactor
            print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] +
                  'Creating and compiling the model ' + c['b'] +
                  '(%d samples)' % train_samples + c['nc'])
            train_steps_per_epoch = -(-train_samples / batch_size)
            val_steps_per_epoch = -(-val_samples / batch_size)
            input_shape = (n_channels, ) + patch_size
            if sequential:
                # Sequential model that merges all 4 images. This architecture is just a set of convolutional blocks
                # that end in a dense layer. This is supposed to be an original baseline.
                net = Sequential()
                net.add(
                    Conv3D(filters_list[0],
                           kernel_size=kernel_size_list[0],
                           input_shape=input_shape,
                           activation='relu',
                           data_format='channels_first'))
                for filters, kernel_size in zip(filters_list[1:],
                                                kernel_size_list[1:]):
                    net.add(Dropout(0.5))
                    net.add(
                        Conv3D(filters,
                               kernel_size=kernel_size,
                               activation='relu',
                               data_format='channels_first'))
                net.add(Dropout(0.5))
                net.add(Flatten())
                net.add(Dense(dense_size, activation='relu'))
                net.add(Dropout(0.5))
                net.add(Dense(num_classes, activation='softmax'))
            else:
                # This architecture is based on the functional Keras API to introduce 3 output paths:
                # - Whole tumor segmentation
                # - Core segmentation (including whole tumor)
                # - Whole segmentation (tumor, core and enhancing parts)
                # The idea is to let the network work on the three parts to improve the multiclass segmentation.
                merged_inputs = Input(shape=(4, ) + patch_size,
                                      name='merged_inputs')
                flair = Reshape((1, ) + patch_size)(Lambda(
                    lambda l: l[:, 0, :, :, :],
                    output_shape=(1, ) + patch_size)(merged_inputs), )
                t2 = Reshape((1, ) + patch_size)(Lambda(
                    lambda l: l[:, 1, :, :, :],
                    output_shape=(1, ) + patch_size)(merged_inputs))
                t1 = Lambda(lambda l: l[:, 2:, :, :, :],
                            output_shape=(2, ) + patch_size)(merged_inputs)
                for filters, kernel_size in zip(filters_list,
                                                kernel_size_list):
                    flair = Conv3D(filters,
                                   kernel_size=kernel_size,
                                   activation='relu',
                                   data_format='channels_first')(flair)
                    t2 = Conv3D(filters,
                                kernel_size=kernel_size,
                                activation='relu',
                                data_format='channels_first')(t2)
                    t1 = Conv3D(filters,
                                kernel_size=kernel_size,
                                activation='relu',
                                data_format='channels_first')(t1)
                    flair = Dropout(0.5)(flair)
                    t2 = Dropout(0.5)(t2)
                    t1 = Dropout(0.5)(t1)

                # We only apply the RCNN to the multioutput approach (we keep the simple one, simple)
                if recurrent:
                    flair = Conv3D(dense_size,
                                   kernel_size=(1, 1, 1),
                                   activation='relu',
                                   data_format='channels_first',
                                   name='fcn_flair')(flair)
                    flair = Dropout(0.5)(flair)
                    t2 = concatenate([flair, t2], axis=1)
                    t2 = Conv3D(dense_size,
                                kernel_size=(1, 1, 1),
                                activation='relu',
                                data_format='channels_first',
                                name='fcn_t2')(t2)
                    t2 = Dropout(0.5)(t2)
                    t1 = concatenate([t2, t1], axis=1)
                    t1 = Conv3D(dense_size,
                                kernel_size=(1, 1, 1),
                                activation='relu',
                                data_format='channels_first',
                                name='fcn_t1')(t1)
                    t1 = Dropout(0.5)(t1)
                    flair = Dropout(0.5)(flair)
                    t2 = Dropout(0.5)(t2)
                    t1 = Dropout(0.5)(t1)
                    lstm_instance = LSTM(dense_size,
                                         implementation=1,
                                         name='rf_layer')
                    flair = lstm_instance(
                        Permute((2, 1))(Reshape((dense_size, -1))(flair)))
                    t2 = lstm_instance(
                        Permute((2, 1))(Reshape((dense_size, -1))(t2)))
                    t1 = lstm_instance(
                        Permute((2, 1))(Reshape((dense_size, -1))(t1)))

                else:
                    flair = Flatten()(flair)
                    t2 = Flatten()(t2)
                    t1 = Flatten()(t1)
                    flair = Dense(dense_size, activation='relu')(flair)
                    flair = Dropout(0.5)(flair)
                    t2 = concatenate([flair, t2])
                    t2 = Dense(dense_size, activation='relu')(t2)
                    t2 = Dropout(0.5)(t2)
                    t1 = concatenate([t2, t1])
                    t1 = Dense(dense_size, activation='relu')(t1)
                    t1 = Dropout(0.5)(t1)

                tumor = Dense(2, activation='softmax', name='tumor')(flair)
                core = Dense(3, activation='softmax', name='core')(t2)
                enhancing = Dense(num_classes,
                                  activation='softmax',
                                  name='enhancing')(t1)

                net = Model(inputs=merged_inputs,
                            outputs=[tumor, core, enhancing])

            net.compile(optimizer='adadelta',
                        loss='categorical_crossentropy',
                        metrics=['accuracy'])

            print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] +
                  'Training the model with a generator for ' + c['b'] +
                  '(%d parameters)' % net.count_params() + c['nc'])
            print(net.summary())
            net.fit_generator(
                generator=load_patch_batch_train(image_names=train_data,
                                                 label_names=train_labels,
                                                 centers=train_centers,
                                                 batch_size=batch_size,
                                                 size=patch_size,
                                                 nlabels=num_classes,
                                                 dfactor=dfactor,
                                                 preload=preload,
                                                 split=not sequential,
                                                 datatype=np.float32),
                validation_data=load_patch_batch_train(image_names=val_data,
                                                       label_names=val_labels,
                                                       centers=val_centers,
                                                       batch_size=batch_size,
                                                       size=patch_size,
                                                       nlabels=num_classes,
                                                       dfactor=dfactor,
                                                       preload=preload,
                                                       split=not sequential,
                                                       datatype=np.float32),
                steps_per_epoch=train_steps_per_epoch,
                validation_steps=val_steps_per_epoch,
                max_q_size=queue,
                epochs=epochs)
            net.save(net_name)

        # Then we test the net.
        use_gt = options['use_gt']
        for p, gt_name in zip(test_data, test_labels):
            p_name = p[0].rsplit('/')[-2]
            patient_path = '/'.join(p[0].rsplit('/')[:-1])
            outputname = os.path.join(patient_path,
                                      'deep-brats17' + sufix + 'test.nii.gz')
            try:
                load_nii(outputname)
            except IOError:
                roi_nii = load_nii(p[0])
                roi = roi_nii.get_data().astype(dtype=np.bool)
                centers = get_mask_voxels(roi)
                test_samples = np.count_nonzero(roi)
                image = np.zeros_like(roi).astype(dtype=np.uint8)
                print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] +
                      '<Creating the probability map ' + c['b'] + p_name +
                      c['nc'] + c['g'] + ' (%d samples)>' % test_samples +
                      c['nc'])
                test_steps_per_epoch = -(-test_samples / batch_size)
                y_pr_pred = net.predict_generator(
                    generator=load_patch_batch_generator_test(
                        image_names=p,
                        centers=centers,
                        batch_size=batch_size,
                        size=patch_size,
                        preload=preload,
                    ),
                    steps=test_steps_per_epoch,
                    max_q_size=queue)
                [x, y, z] = np.stack(centers, axis=1)

                if not sequential:
                    tumor = np.argmax(y_pr_pred[0], axis=1)
                    y_pr_pred = y_pr_pred[-1]
                    roi = np.zeros_like(roi).astype(dtype=np.uint8)
                    roi[x, y, z] = tumor
                    roi_nii.get_data()[:] = roi
                    roiname = os.path.join(
                        patient_path,
                        'deep-brats17' + sufix + 'test.roi.nii.gz')
                    roi_nii.to_filename(roiname)

                y_pred = np.argmax(y_pr_pred, axis=1)

                image[x, y, z] = y_pred
                # Post-processing (Basically keep the biggest connected region)
                image = get_biggest_region(image)
                if use_gt:
                    gt_nii = load_nii(gt_name)
                    gt = np.copy(gt_nii.get_data()).astype(dtype=np.uint8)
                    labels = np.unique(gt.flatten())
                    results = (p_name, ) + tuple(
                        [dsc_seg(gt == l, image == l) for l in labels[1:]])
                    text = 'Subject %s DSC: ' + '/'.join(
                        ['%f' for _ in labels[1:]])
                    print(text % results)
                    dsc_results.append(results)

                print(c['g'] + '                   -- Saving image ' + c['b'] +
                      outputname + c['nc'])
                roi_nii.get_data()[:] = image
                roi_nii.to_filename(outputname)
# --- Esempio n. 5 (scraped-snippet separator; the stray "0" was the page's vote count) ---
def main():
    """Train CNN, capsule and GAN nets on BRATS data and report DSC scores.

    For every test patient the function trains a ROI net (background vs
    tumor) and a 5-label substructure net (reusing the ROI GAN conv
    weights), segments the patient with the CNN, CAPS and GAN variants
    (caching the output images on disk) and prints per-case and final mean
    DSC values against the ground truth.

    NOTE(review): ``main`` is defined twice in this file; this first
    definition is shadowed by the later one -- confirm which entry point is
    intended.
    """
    options = parse_inputs()
    c = color_codes()

    # Prepare the net hyperparameters
    epochs = options['epochs']
    patch_width = options['patch_width']
    patch_size = (patch_width, patch_width, patch_width)
    dense_size = options['dense_size']
    conv_blocks = options['conv_blocks']
    n_filters = options['n_filters']
    filters_list = n_filters if len(n_filters) > 1 else n_filters * conv_blocks
    conv_width = options['conv_width']
    kernel_size_list = conv_width if isinstance(conv_width, list) else [conv_width] * conv_blocks
    balanced = options['balanced']
    # Data loading parameters
    preload = options['preload']

    # Sufix added to net/result/image names so runs with different
    # hyperparameters do not overwrite each other.
    filters_s = 'n'.join(['%d' % nf for nf in filters_list])
    conv_s = 'c'.join(['%d' % cs for cs in kernel_size_list])
    ub_s = '.ub' if not balanced else ''
    params_s = (ub_s, patch_width, conv_s, filters_s, dense_size, epochs)
    sufix = '%s.p%d.c%s.n%s.d%d.e%d' % params_s
    preload_s = ' (with ' + c['b'] + 'preloading' + c['nc'] + c['c'] + ')' if preload else ''

    print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' + 'Starting training' + preload_s + c['nc'])
    # Fix: the original called get_names_from_path(options) twice (once
    # discarding the labels); a single call returns both.
    train_data, train_labels = get_names_from_path(options)
    test_data, test_labels = get_names_from_path(options, False)

    input_shape = (train_data.shape[1],) + patch_size

    dsc_results_gan = list()
    dsc_results_cnn = list()
    dsc_results_caps = list()

    # Training patches are sampled once and shared by every per-patient run.
    centers_s = np.random.permutation(
        get_cnn_centers(train_data[:, 0], train_labels, balanced=balanced)
    )[::options['down_sampling']]
    x_seg, y_seg = load_patches_ganseg_by_batches(
        image_names=train_data,
        label_names=train_labels,
        source_centers=centers_s,
        size=patch_size,
        nlabels=5,
        preload=preload,
        batch_size=51200
    )

    # Collapse the 5 one-hot labels to ROI labels (background vs any tumor).
    # Fix: np.bool was deprecated in NumPy 1.20 and removed in 1.24; the
    # builtin bool is the documented replacement.
    y_seg_roi = np.empty((len(y_seg), 2), dtype=bool)
    y_seg_roi[:, 0] = y_seg[:, 0]
    y_seg_roi[:, 1] = np.sum(y_seg[:, 1:], axis=1)

    for i, (p, gt_name) in enumerate(zip(test_data, test_labels)):
        p_name = p[0].rsplit('/')[-2]
        patient_path = '/'.join(p[0].rsplit('/')[:-1])
        print('%s[%s] %sCase %s%s%s%s%s (%d/%d):%s' % (
            c['c'],
            strftime("%H:%M:%S"),
            c['nc'],
            c['c'],
            c['b'],
            p_name,
            c['nc'],
            c['c'],
            i + 1,
            len(test_data),
            c['nc']
        ))

        # ROI segmentation (2 classes: background / tumor)
        adversarial_w = K.variable(0)
        roi_cnn = get_brats_fc(input_shape, filters_list, kernel_size_list, dense_size, 2)
        roi_caps = get_brats_caps(input_shape, filters_list, kernel_size_list, 8, 2)
        roi_gan, _ = get_brats_gan_fc(
            input_shape,
            filters_list,
            kernel_size_list,
            dense_size,
            2,
            lambda_var=adversarial_w
        )
        train_nets(
            x=x_seg,
            y=y_seg_roi,
            gan=roi_gan,
            cnn=roi_cnn,
            caps=roi_caps,
            p=p,
            name='brats2017-roi' + sufix,
            adversarial_w=adversarial_w
        )

        # Tumor substructures net (5 classes); the first conv blocks are
        # initialised from the ROI GAN weights (transfer learning).
        adversarial_w = K.variable(0)
        seg_cnn = get_brats_fc(input_shape, filters_list, kernel_size_list, dense_size, 5)
        seg_caps = get_brats_caps(input_shape, filters_list, kernel_size_list, 8, 5)
        seg_gan_tr, seg_gan_tst = get_brats_gan_fc(
            input_shape,
            filters_list,
            kernel_size_list,
            dense_size,
            5,
            lambda_var=adversarial_w
        )
        roi_net_conv_layers = [l for l in roi_gan.layers if 'conv' in l.name]
        seg_net_conv_layers = [l for l in seg_gan_tr.layers if 'conv' in l.name]
        for lr, ls in zip(roi_net_conv_layers[:conv_blocks], seg_net_conv_layers[:conv_blocks]):
            ls.set_weights(lr.get_weights())
        train_nets(
            x=x_seg,
            y=y_seg,
            gan=seg_gan_tr,
            cnn=seg_cnn,
            caps=seg_caps,
            p=p,
            name='brats2017-full' + sufix,
            adversarial_w=adversarial_w
        )

        # Segment the patient with each model, reusing images cached on disk
        # from previous runs when present.
        image_cnn_name = os.path.join(patient_path, p_name + '.cnn.test')
        try:
            image_cnn = load_nii(image_cnn_name + '.nii.gz').get_data()
        except IOError:
            image_cnn = test_net(seg_cnn, p, image_cnn_name)

        image_caps_name = os.path.join(patient_path, p_name + '.caps.test')
        try:
            image_caps = load_nii(image_caps_name + '.nii.gz').get_data()
        except IOError:
            image_caps = test_net(seg_caps, p, image_caps_name)

        image_gan_name = os.path.join(patient_path, p_name + '.gan.test')
        try:
            image_gan = load_nii(image_gan_name + '.nii.gz').get_data()
        except IOError:
            image_gan = test_net(seg_gan_tst, p, image_gan_name)

        results_cnn = check_dsc(gt_name, image_cnn)
        dsc_string = c['g'] + '/'.join(['%f'] * len(results_cnn)) + c['nc']
        print(''.join([' '] * 14) + c['c'] + c['b'] + p_name + c['nc'] + ' CNN DSC: ' +
              dsc_string % tuple(results_cnn))

        results_caps = check_dsc(gt_name, image_caps)
        dsc_string = c['g'] + '/'.join(['%f'] * len(results_caps)) + c['nc']
        print(''.join([' '] * 14) + c['c'] + c['b'] + p_name + c['nc'] + ' CAPS DSC: ' +
              dsc_string % tuple(results_caps))

        results_gan = check_dsc(gt_name, image_gan)
        dsc_string = c['g'] + '/'.join(['%f'] * len(results_gan)) + c['nc']
        print(''.join([' '] * 14) + c['c'] + c['b'] + p_name + c['nc'] + ' GAN DSC: ' +
              dsc_string % tuple(results_gan))

        dsc_results_cnn.append(results_cnn)
        dsc_results_caps.append(results_caps)
        dsc_results_gan.append(results_gan)

    # Mean DSC per label, counting only the cases that contain that label.
    f_dsc = tuple(
        [np.array([dsc[i] for dsc in dsc_results_cnn if len(dsc) > i]).mean() for i in range(3)]
    ) + tuple(
        [np.array([dsc[i] for dsc in dsc_results_caps if len(dsc) > i]).mean() for i in range(3)]
    ) + tuple(
        [np.array([dsc[i] for dsc in dsc_results_gan if len(dsc) > i]).mean() for i in range(3)]
    )
    print('Final results DSC: (%f/%f/%f) vs (%f/%f/%f) vs (%f/%f/%f)' % f_dsc)
def main():
    """Train a two-headed patch CNN ('tumor' ROI + 'fc_out' dense labels).

    Builds a channels-first Conv3D stack with two softmax outputs, compiles
    it with weighted categorical cross-entropy, then runs ``r_epochs``
    rounds of center sampling, patch loading and ``net.fit``, saving (or
    reloading) the model after each round.
    """
    options = parse_inputs()
    c = color_codes()

    # Prepare the net architecture parameters
    dfactor = options['dfactor']
    # Prepare the net hyperparameters
    epochs = options['epochs']
    patch_width = options['patch_width']
    patch_size = (patch_width, patch_width, patch_width)
    batch_size = options['batch_size']
    dense_size = options['dense_size']
    conv_blocks = options['conv_blocks']
    n_filters = options['n_filters']
    filters_list = n_filters if len(n_filters) > 1 else n_filters * conv_blocks
    conv_width = options['conv_width']
    kernel_size_list = conv_width if isinstance(
        conv_width, list) else [conv_width] * conv_blocks
    balanced = options['balanced']
    val_rate = options['val_rate']
    # Data loading parameters
    preload = options['preload']

    # Prepare the sufix that will be added to the results for the net and images
    path = options['dir_name']
    filters_s = 'n'.join(['%d' % nf for nf in filters_list])
    conv_s = 'c'.join(['%d' % cs for cs in kernel_size_list])
    ub_s = '.ub' if not balanced else ''
    params_s = (ub_s, dfactor, patch_width, conv_s, filters_s, dense_size,
                epochs)
    sufix = '%s.D%d.p%d.c%s.n%s.d%d.e%d.' % params_s
    preload_s = ' (with ' + c['b'] + 'preloading' + c['nc'] + c[
        'c'] + ')' if preload else ''

    print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' + 'Starting training' +
          preload_s + c['nc'])
    # N-fold cross validation main loop (we'll do 2 training iterations with testing for each patient)
    train_data, train_labels = get_names_from_path(options)

    print(c['c'] + '[' + strftime("%H:%M:%S") + ']  ' + c['nc'] + c['g'] +
          'Number of training images (%d=%d)' %
          (len(train_data), len(train_labels)) + c['nc'])
    #  Also, prepare the network
    net_name = os.path.join(path, 'CBICA-brats2017' + sufix)

    print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] +
          'Creating and compiling the model ' + c['nc'])
    input_shape = (train_data.shape[1], ) + patch_size

    # Sequential model that merges all 4 images. This architecture is just a set of convolutional blocks
    #  that end in a dense layer. This is supposed to be an original baseline.
    inputs = Input(shape=input_shape, name='merged_inputs')
    conv = inputs
    for filters, kernel_size in zip(filters_list, kernel_size_list):
        conv = Conv3D(filters,
                      kernel_size=kernel_size,
                      activation='relu',
                      data_format='channels_first')(conv)
        conv = Dropout(0.5)(conv)

    # Fully-convolutional head: per-voxel 2-class logits.
    full = Conv3D(dense_size,
                  kernel_size=(1, 1, 1),
                  data_format='channels_first')(conv)
    full = PReLU()(full)
    full = Conv3D(2, kernel_size=(1, 1, 1), data_format='channels_first')(full)

    rf = concatenate([conv, full], axis=1)

    # Shrink the receptive-field branch with valid 3x3x3 convs until its
    # spatial extent collapses to a single voxel.
    while np.product(K.int_shape(rf)[2:]) > 1:
        rf = Conv3D(dense_size,
                    kernel_size=(3, 3, 3),
                    data_format='channels_first')(rf)
        rf = Dropout(0.5)(rf)

    full = Reshape((2, -1))(full)
    full = Permute((2, 1))(full)
    full_out = Activation('softmax', name='fc_out')(full)

    tumor = Dense(2, activation='softmax', name='tumor')(rf)

    outputs = [tumor, full_out]

    net = Model(inputs=inputs, outputs=outputs)

    net.compile(optimizer='adadelta',
                loss='categorical_crossentropy',
                loss_weights=[0.8, 1.0],
                metrics=['accuracy'])

    # Spatial size of the fully-convolutional output after the valid convs.
    fc_width = patch_width - sum(kernel_size_list) + conv_blocks
    fc_shape = (fc_width, ) * 3

    # Fix: the original template started from net_name (which already
    # contains ``path``) and was then joined with ``path`` again below,
    # duplicating the directory for relative paths. Keep the template
    # relative so os.path.join adds the directory exactly once (matching
    # the checkpoint handling in train_net).
    checkpoint = 'CBICA-brats2017' + sufix + '{epoch:02d}.{val_tumor_acc:.2f}.hdf5'
    callbacks = [
        EarlyStopping(monitor='val_tumor_loss', patience=options['patience']),
        ModelCheckpoint(os.path.join(path, checkpoint),
                        monitor='val_tumor_loss',
                        save_best_only=True)
    ]

    for i in range(options['r_epochs']):
        try:
            # Resume: a finished round is stored on disk and just reloaded.
            net = load_model(net_name + ('e%d.' % i) + 'mdl')
        except IOError:
            train_centers = get_cnn_centers(train_data[:, 0],
                                            train_labels,
                                            balanced=balanced)
            print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] +
                  'Loading data ' + c['b'] + '(%d centers)' %
                  (len(train_centers) // dfactor) + c['nc'])
            x, y = load_patches_train(image_names=train_data,
                                      label_names=train_labels,
                                      centers=train_centers,
                                      size=patch_size,
                                      fc_shape=fc_shape,
                                      nlabels=2,
                                      dfactor=dfactor,
                                      preload=preload,
                                      split=True,
                                      iseg=False,
                                      experimental=1,
                                      datatype=np.float32)

            print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] +
                  'Training the model for ' + c['b'] +
                  '(%d parameters)' % net.count_params() + c['nc'])
            print(net.summary())

            net.fit(x,
                    y,
                    batch_size=batch_size,
                    validation_split=val_rate,
                    epochs=epochs,
                    callbacks=callbacks)
            net.save(net_name + ('e%d.' % i) + 'mdl')
def train_net(fold_n, train_data, train_labels, options):
    """Train (or reload) the iSeg2017 net for one cross-validation fold.

    Returns the Keras model with the weights of the best saved checkpoint
    (the one ModelCheckpoint kept while monitoring 'val_brain_loss').
    """
    from glob import glob  # local import: the file's import header is outside this block

    # Prepare the net architecture parameters
    dfactor = options['dfactor']
    # Prepare the net hyperparameters
    epochs = options['epochs']
    patch_width = options['patch_width']
    patch_size = (patch_width, ) * 3
    batch_size = options['batch_size']
    dense_size = options['dense_size']
    conv_blocks = options['conv_blocks']
    nfilters = options['n_filters']
    filters_list = nfilters if len(nfilters) > 1 else nfilters * conv_blocks
    conv_width = options['conv_width']
    kernel_size_list = conv_width if isinstance(
        conv_width, list) else [conv_width] * conv_blocks
    experimental = options['experimental']

    # Spatial size of the fully-convolutional output after the valid convs.
    fc_width = patch_width - sum(kernel_size_list) + conv_blocks
    fc_shape = (fc_width, ) * 3
    # Data loading parameters
    preload = options['preload']

    # Prepare the sufix that will be added to the results for the net and images
    path = options['dir_name']
    sufix = get_sufix(options)

    net_name = os.path.join(path, 'iseg2017.fold%d' % fold_n + sufix + 'mdl')
    checkpoint = 'iseg2017.fold%d' % fold_n + sufix + '{epoch:02d}.{val_brain_acc:.2f}.hdf5'

    def _load_best_weights(model):
        # Fix: the original passed the literal checkpoint template (with the
        # unexpanded {epoch:02d}/{val_brain_acc:.2f} placeholders) to
        # load_weights, which can never name a file ModelCheckpoint actually
        # wrote. Glob for the real checkpoints instead. With
        # save_best_only=True the most recently written file is the best
        # one, and the zero-padded epoch prefix makes lexicographic order
        # match training order. Raises IOError when no checkpoint exists so
        # the caller's except branch triggers training.
        saved = sorted(glob(os.path.join(
            path, 'iseg2017.fold%d' % fold_n + sufix + '*.hdf5')))
        if not saved:
            raise IOError('No checkpoint saved for fold %d' % fold_n)
        model.load_weights(saved[-1])

    c = color_codes()
    try:
        net = load_model(net_name)
        _load_best_weights(net)
    except IOError:
        # Data loading
        train_centers = get_cnn_centers(train_data[:, 0], train_labels)
        train_samples = len(train_centers) // dfactor
        print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] +
              'Loading data ' + c['b'] + '(%d centers)' % len(train_centers) +
              c['nc'])
        x, y = load_patches_train(image_names=train_data,
                                  label_names=train_labels,
                                  centers=train_centers,
                                  size=patch_size,
                                  fc_shape=fc_shape,
                                  nlabels=4,
                                  dfactor=dfactor,
                                  preload=preload,
                                  split=True,
                                  iseg=True,
                                  experimental=experimental,
                                  datatype=np.float32)
        # NET definition using Keras
        print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] +
              'Creating and compiling the model ' + c['b'] +
              '(%d samples)' % train_samples + c['nc'])
        input_shape = (2, ) + patch_size
        # NOTE(review): the original comment here described the three BRATS
        # tumor outputs (whole/core/enhancing) although these builders are
        # iSeg variants -- it looks copied from the BRATS script; confirm.
        # ``experimental`` selects one of the candidate architectures below.
        network_func = [
            get_iseg_baseline, get_iseg_experimental1, get_iseg_experimental2,
            get_iseg_experimental3, get_iseg_experimental4
        ]
        net = network_func[experimental](input_shape, filters_list,
                                         kernel_size_list, dense_size)

        print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] +
              'Training the model ' + c['b'] +
              '(%d parameters)' % net.count_params() + c['nc'])
        print(net.summary())
        callbacks = [
            EarlyStopping(monitor='val_brain_loss',
                          patience=options['patience']),
            ModelCheckpoint(os.path.join(path, checkpoint),
                            monitor='val_brain_loss',
                            save_best_only=True)
        ]
        # The architecture is saved before training; the best weights are
        # restored from the checkpoint files afterwards.
        net.save(net_name)
        net.fit(x,
                y,
                batch_size=batch_size,
                validation_split=0.25,
                epochs=epochs,
                callbacks=callbacks)
        _load_best_weights(net)
    return net