Ejemplo n.º 1
0
def _build_baseline_cnn(patch_width, num_classes):
    """Build the uncompiled baseline 3D CNN (Keras ``Sequential``).

    The network is two stages of two 3x3x3 ``same``-padded ReLU convolutions
    (64 and then 128 filters), each stage followed by a 3x3x3 max pooling with
    stride 2, then ``Flatten``, a 256-unit ReLU dense layer, dropout (0.5) and
    a softmax over ``num_classes``.

    NOTE(review): the input is hard-coded to 4 channels, independently of the
    ``use_flair``/``use_t2``/``use_t1``/``use_t1ce`` options - confirm that the
    patch generators always yield 4-channel patches.
    """
    conv_args = dict(strides=1,
                     padding='same',
                     activation='relu',
                     data_format='channels_first')
    pool_args = dict(pool_size=(3, 3, 3),
                     strides=2,
                     data_format='channels_first')

    model = Sequential()
    model.add(
        Conv3D(64, (3, 3, 3),
               input_shape=(4, ) + (patch_width, ) * 3,
               **conv_args))
    model.add(Conv3D(64, (3, 3, 3), **conv_args))
    model.add(MaxPooling3D(**pool_args))

    model.add(Conv3D(128, (3, 3, 3), **conv_args))
    model.add(Conv3D(128, (3, 3, 3), **conv_args))
    model.add(MaxPooling3D(**pool_args))

    model.add(Flatten())
    model.add(Dense(256, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(num_classes, activation='softmax'))
    return model


def main():
    """Train (or load) and test the fixed baseline CNN with n-fold CV.

    For every fold produced by ``nfold_cross_validation``:

    * the fold's model is loaded from ``options['dir_name']``; when loading
      raises ``IOError`` a new network is built, trained with patch
      generators and saved under the same name;
    * every test patient whose output image cannot be loaded is segmented
      voxel-wise inside the mask given by its first image, only the biggest
      connected region is kept, the per-label DSC against the ground truth is
      printed and the segmentation is written next to the patient images.

    The architecture options (filters, kernel widths, dense size, padding) are
    only used to build the file-name suffix here; the network itself is fixed
    (see ``_build_baseline_cnn``).
    """
    # Local import: avoids relying on the Python 2-only izip/xrange builtins.
    from itertools import islice

    options = parse_inputs()
    c = color_codes()

    # Net architecture parameters
    sequential = options['sequential']
    dfactor = options['dfactor']
    # Net hyperparameters
    num_classes = 5
    epochs = options['epochs']
    padding = options['padding']
    patch_width = options['patch_width']
    patch_size = (patch_width, patch_width, patch_width)
    batch_size = options['batch_size']
    dense_size = options['dense_size']
    conv_blocks = options['conv_blocks']
    n_filters = options['n_filters']
    filters_list = n_filters if len(n_filters) > 1 else n_filters * conv_blocks
    conv_width = options['conv_width']
    kernel_size_list = conv_width if isinstance(
        conv_width, list) else [conv_width] * conv_blocks
    balanced = options['balanced']
    # Data loading parameters
    preload = options['preload']
    queue = options['queue']

    # Suffix that is added to the file names of the net and the result images
    path = options['dir_name']
    filters_s = 'n'.join(['%d' % nf for nf in filters_list])
    conv_s = 'c'.join(['%d' % cs for cs in kernel_size_list])
    s_s = '.s' if sequential else '.f'
    ub_s = '.ub' if not balanced else ''
    params_s = (ub_s, dfactor, s_s, patch_width, conv_s, filters_s, dense_size,
                epochs, padding)
    sufix = '%s.D%d%s.p%d.c%s.n%s.d%d.e%d.pad_%s.' % params_s

    print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' +
          'Starting cross-validation' + c['nc'])
    # N-fold cross validation main loop
    data_names, label_names = get_names_from_path(options)
    folds = options['folds']
    # At most `folds` folds are consumed, numbered from 0.
    fold_generator = enumerate(
        islice(
            nfold_cross_validation(data_names,
                                   label_names,
                                   n=folds,
                                   val_data=0.25), folds))
    # NOTE(review): the DSC tuples are collected but never used or returned
    # in this function.
    dsc_results = list()
    for i, (train_data, train_labels, val_data, val_labels, test_data,
            test_labels) in fold_generator:
        print(
            c['c'] + '[' + strftime("%H:%M:%S") + ']  ' + c['nc'] +
            'Fold %d/%d: ' % (i + 1, folds) + c['g'] +
            'Number of training/validation/testing images (%d=%d/%d=%d/%d)' %
            (len(train_data), len(train_labels), len(val_data),
             len(val_labels), len(test_data)) + c['nc'])
        net_name = os.path.join(
            path, 'baseline-brats2017.fold%d' % i + sufix + 'mdl')

        # First we check that we did not already train this fold, to save time
        try:
            net = keras.models.load_model(net_name)
        except IOError:
            print('===============================================================')
            train_centers = get_cnn_centers(train_data[:, 0],
                                            train_labels,
                                            balanced=balanced)
            val_centers = get_cnn_centers(val_data[:, 0],
                                          val_labels,
                                          balanced=balanced)
            # Floor division keeps these integers on Python 2 and Python 3.
            train_samples = len(train_centers) // dfactor
            val_samples = len(val_centers) // dfactor
            print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] +
                  'Creating and compiling the model ' + c['b'] +
                  '(%d samples)' % train_samples + c['nc'])
            # -(-a // b) is the ceiling of a / b
            train_steps_per_epoch = -(-train_samples // batch_size)
            val_steps_per_epoch = -(-val_samples // batch_size)

            net = _build_baseline_cnn(patch_width, num_classes)
            net.compile(optimizer='sgd',
                        loss='categorical_crossentropy',
                        metrics=['accuracy'])

            print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] +
                  'Training the model with a generator for ' + c['b'] +
                  '(%d parameters)' % net.count_params() + c['nc'])
            # summary() prints by itself and returns None, so it is not
            # wrapped in print().
            net.summary()

            net.fit_generator(
                generator=load_patch_batch_train(
                    image_names=train_data,
                    label_names=train_labels,
                    centers=train_centers,
                    batch_size=batch_size,
                    size=patch_size,
                    nlabels=num_classes,
                    dfactor=dfactor,
                    preload=preload,
                    split=not sequential,
                    datatype=np.float32),
                validation_data=load_patch_batch_train(
                    image_names=val_data,
                    label_names=val_labels,
                    centers=val_centers,
                    batch_size=batch_size,
                    size=patch_size,
                    nlabels=num_classes,
                    dfactor=dfactor,
                    preload=preload,
                    split=not sequential,
                    datatype=np.float32),
                steps_per_epoch=train_steps_per_epoch,
                validation_steps=val_steps_per_epoch,
                max_q_size=queue,
                epochs=epochs)
            net.save(net_name)

        # Then we test the net.
        for p, gt_name in zip(test_data, test_labels):
            p_name = p[0].rsplit('/')[-2]
            patient_path = '/'.join(p[0].rsplit('/')[:-1])
            outputname = os.path.join(patient_path,
                                      'deep-brats17' + sufix + 'test.nii.gz')
            try:
                # Only used as an "already segmented" check.
                load_nii(outputname)
            except IOError:
                # The ground truth is only needed when we really segment.
                gt_nii = load_nii(gt_name)
                gt = np.copy(gt_nii.get_data()).astype(dtype=np.uint8)

                # The first image doubles as the mask of voxels to classify.
                roi_nii = load_nii(p[0])
                roi = roi_nii.get_data().astype(dtype=bool)
                centers = get_mask_voxels(roi)
                test_samples = np.count_nonzero(roi)
                image = np.zeros_like(roi).astype(dtype=np.uint8)
                print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] +
                      '<Creating the probability map ' + c['b'] + p_name +
                      c['nc'] + c['g'] + ' (%d samples)>' % test_samples +
                      c['nc'])
                test_steps_per_epoch = -(-test_samples // batch_size)
                y_pr_pred = net.predict_generator(
                    generator=load_patch_batch_generator_test(
                        image_names=p,
                        centers=centers,
                        batch_size=batch_size,
                        size=patch_size,
                        preload=preload,
                    ),
                    steps=test_steps_per_epoch,
                    max_q_size=queue)
                [x, y, z] = np.stack(centers, axis=1)

                # A list of predictions is only returned by a multi-output
                # model (e.g. one loaded from disk). The single-output net
                # built above returns one array, which must not be indexed
                # as if it were a list of outputs.
                if not sequential and isinstance(y_pr_pred, list):
                    tumor = np.argmax(y_pr_pred[0], axis=1)
                    y_pr_pred = y_pr_pred[-1]
                    roi = np.zeros_like(roi).astype(dtype=np.uint8)
                    roi[x, y, z] = tumor
                    roi_nii.get_data()[:] = roi
                    roiname = os.path.join(
                        patient_path,
                        'deep-brats17' + sufix + 'test.roi.nii.gz')
                    roi_nii.to_filename(roiname)

                y_pred = np.argmax(y_pr_pred, axis=1)

                image[x, y, z] = y_pred
                # Post-processing (basically keep the biggest connected region)
                image = get_biggest_region(image)
                # DSC for every ground-truth label but the first (lowest) one
                labels = np.unique(gt.flatten())
                results = (p_name, ) + tuple(
                    [dsc_seg(gt == l, image == l) for l in labels[1:]])
                text = 'Subject %s DSC: ' + '/'.join(
                    ['%f' for _ in labels[1:]])
                print(text % results)
                dsc_results.append(results)

                print(c['g'] + '                   -- Saving image ' + c['b'] +
                      outputname + c['nc'])
                roi_nii.get_data()[:] = image
                roi_nii.to_filename(outputname)
def test_seg(net, p, outputname, nlabels, mask=None, verbose=True):
    """Segment one patient with ``net`` and return the resulting image object.

    If ``<patient dir>/<outputname>.nii.gz`` can be loaded it is returned
    as is. Otherwise the segmentation is computed, stored into the image
    object loaded from ``p[0]``, written to that path and returned.

    :param net: trained Keras model.
    :param p: image paths of one patient; ``p[0]`` provides the patient
        directory, the patient name and the output image container.
    :param outputname: output file name without the ``.nii.gz`` extension.
    :param nlabels: number of labels passed to ``options['net']`` when the
        whole-image network is rebuilt (``mask is None``).
    :param mask: when ``None`` the whole image is fed to a network rebuilt
        with ``options['net']`` that receives the weights of ``net``;
        otherwise only the voxels returned by ``get_mask_blocks(mask)`` are
        classified patch-wise with ``net`` itself.
    :param verbose: print progress messages.
    :return: the image object holding the segmentation.
    """
    c = color_codes()
    options = parse_inputs()
    p_name = p[0].rsplit('/')[-2]
    patient_path = '/'.join(p[0].rsplit('/')[:-1])
    outputname_path = os.path.join(patient_path, outputname + '.nii.gz')
    indent = ' ' * 14
    try:
        roi_nii = load_nii(outputname_path)
    except IOError:
        roi_nii = load_nii(p[0])
        if mask is None:
            # This is the unet path: the whole image is a single sample
            x = np.expand_dims(np.stack(load_images(p), axis=0), axis=0)
            # Network parameters
            conv_blocks = options['conv_blocks']
            n_filters = options['n_filters']
            filters_list = n_filters if len(
                n_filters) > 1 else n_filters * conv_blocks
            conv_width = options['conv_width']
            kernel_size_list = conv_width if isinstance(
                conv_width, list) else [conv_width] * conv_blocks

            # Rebuild the net for the full image shape and copy the trained
            # weights layer by layer (the first layer of each net is skipped).
            image_net = options['net'](x.shape[1:], filters_list,
                                       kernel_size_list, nlabels)
            for l_new, l_orig in zip(image_net.layers[1:], net.layers[1:]):
                l_new.set_weights(l_orig.get_weights())

            if verbose:
                print('%s[%s] %sTesting the network%s' %
                      (c['c'], strftime("%H:%M:%S"), c['g'], c['nc']))
                print(
                    '%s%s<Creating the probability map for %s%s%s%s - %s%s%s %s>%s'
                    % (indent, c['g'], c['b'], p_name, c['nc'], c['g'],
                       c['b'], outputname_path, c['nc'], c['g'], c['nc']))
            pr_maps = image_net.predict(x, batch_size=options['test_size'])
            image = np.argmax(pr_maps, axis=-1).reshape(x.shape[2:])
            image = get_biggest_region(image)
        else:
            # This is the ensemble path: patch-wise prediction inside the mask
            image = np.zeros_like(mask, dtype=np.int8)
            conv_blocks = options['conv_blocks_seg']
            test_centers = get_mask_blocks(mask)
            patches = get_data(
                image_names=[p],
                list_of_centers=[test_centers],
                patch_size=(conv_blocks * 2 + 3, ) * 3,
                verbose=verbose,
            )
            if verbose:
                print('%s- Concatenating the data x' % (' ' * 11))
            patches = np.concatenate(patches)
            pr_maps = net.predict(patches, batch_size=options['test_size'])
            [x, y, z] = np.stack(test_centers, axis=1)
            image[x, y, z] = np.argmax(pr_maps, axis=1).astype(dtype=np.int8)

        roi_nii.get_data()[:] = image
        roi_nii.to_filename(outputname_path)
    else:
        if verbose:
            print('%s%s<%s%s%s%s - probability map loaded>%s' %
                  (indent, c['g'], c['b'], p_name, c['nc'], c['g'], c['nc']))
    return roi_nii
def main():
    """Train (or load) and test the BRATS 2017 baseline nets with n-fold CV.

    For every fold produced by ``nfold_cross_validation``:

    * the fold's model is loaded from ``options['dir_name']``; when loading
      raises ``IOError`` a new network is built, trained with patch
      generators and saved under the same name. With
      ``options['sequential']`` the net is a plain stack of 3D convolutions;
      otherwise it is a three-output functional model (``tumor``, ``core``,
      ``enhancing``), optionally with a shared LSTM when
      ``options['recurrent']`` is set;
    * every test patient whose output image cannot be loaded is segmented
      voxel-wise inside the mask given by its first image, only the biggest
      connected region is kept and the result is written next to the patient
      images. With ``options['use_gt']`` the per-label DSC against the ground
      truth is printed as well.
    """
    # Local import: avoids relying on the Python 2-only izip/xrange builtins.
    from itertools import islice

    options = parse_inputs()
    c = color_codes()

    # Net architecture parameters
    sequential = options['sequential']
    dfactor = options['dfactor']
    # Net hyperparameters
    num_classes = 5
    epochs = options['epochs']
    padding = options['padding']
    patch_width = options['patch_width']
    patch_size = (patch_width, patch_width, patch_width)
    batch_size = options['batch_size']
    dense_size = options['dense_size']
    conv_blocks = options['conv_blocks']
    n_filters = options['n_filters']
    filters_list = n_filters if len(n_filters) > 1 else n_filters * conv_blocks
    conv_width = options['conv_width']
    kernel_size_list = conv_width if isinstance(
        conv_width, list) else [conv_width] * conv_blocks
    balanced = options['balanced']
    recurrent = options['recurrent']
    # Data loading parameters
    preload = options['preload']
    queue = options['queue']
    # Testing parameters (loop invariant, so read once)
    use_gt = options['use_gt']

    # Suffix that is added to the file names of the net and the result images
    path = options['dir_name']
    filters_s = 'n'.join(['%d' % nf for nf in filters_list])
    conv_s = 'c'.join(['%d' % cs for cs in kernel_size_list])
    s_s = '.s' if sequential else '.f'
    ub_s = '.ub' if not balanced else ''
    params_s = (ub_s, dfactor, s_s, patch_width, conv_s, filters_s, dense_size,
                epochs, padding)
    sufix = '%s.D%d%s.p%d.c%s.n%s.d%d.e%d.pad_%s.' % params_s
    n_channels = np.count_nonzero([
        options['use_flair'], options['use_t2'], options['use_t1'],
        options['use_t1ce']
    ])

    print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' +
          'Starting cross-validation' + c['nc'])
    # N-fold cross validation main loop
    data_names, label_names = get_names_from_path(options)
    folds = options['folds']
    # At most `folds` folds are consumed, numbered from 0.
    fold_generator = enumerate(
        islice(
            nfold_cross_validation(data_names,
                                   label_names,
                                   n=folds,
                                   val_data=0.25), folds))
    # NOTE(review): the DSC tuples are collected but never used or returned
    # in this function.
    dsc_results = list()
    for i, (train_data, train_labels, val_data, val_labels, test_data,
            test_labels) in fold_generator:
        print(c['c'] + '[' + strftime("%H:%M:%S") + ']  ' + c['nc'] +
              'Fold %d/%d: ' % (i + 1, folds) + c['g'] +
              'Number of training/validation/testing images (%d=%d/%d=%d/%d)' %
              (len(train_data), len(train_labels), len(val_data),
               len(val_labels), len(test_data)) + c['nc'])
        net_name = os.path.join(
            path, 'baseline-brats2017.fold%d' % i + sufix + 'mdl')

        # First we check that we did not already train this fold, to save time
        try:
            net = keras.models.load_model(net_name)
        except IOError:
            # NET definition using Keras
            train_centers = get_cnn_centers(train_data[:, 0],
                                            train_labels,
                                            balanced=balanced)
            val_centers = get_cnn_centers(val_data[:, 0],
                                          val_labels,
                                          balanced=balanced)
            # Floor division keeps these integers on Python 2 and Python 3.
            train_samples = len(train_centers) // dfactor
            val_samples = len(val_centers) // dfactor
            print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] +
                  'Creating and compiling the model ' + c['b'] +
                  '(%d samples)' % train_samples + c['nc'])
            # -(-a // b) is the ceiling of a / b
            train_steps_per_epoch = -(-train_samples // batch_size)
            val_steps_per_epoch = -(-val_samples // batch_size)
            input_shape = (n_channels, ) + patch_size
            if sequential:
                # Sequential model that merges all the images. This
                # architecture is just a set of convolutional blocks that end
                # in a dense layer. It is meant as the original baseline.
                net = Sequential()
                net.add(
                    Conv3D(filters_list[0],
                           kernel_size=kernel_size_list[0],
                           input_shape=input_shape,
                           activation='relu',
                           data_format='channels_first'))
                for filters, kernel_size in zip(filters_list[1:],
                                                kernel_size_list[1:]):
                    net.add(Dropout(0.5))
                    net.add(
                        Conv3D(filters,
                               kernel_size=kernel_size,
                               activation='relu',
                               data_format='channels_first'))
                net.add(Dropout(0.5))
                net.add(Flatten())
                net.add(Dense(dense_size, activation='relu'))
                net.add(Dropout(0.5))
                net.add(Dense(num_classes, activation='softmax'))
            else:
                # This architecture is based on the functional Keras API to
                # introduce 3 output paths:
                # - Whole tumor segmentation
                # - Core segmentation (including whole tumor)
                # - Whole segmentation (tumor, core and enhancing parts)
                # The idea is to let the network work on the three parts to
                # improve the multiclass segmentation.
                # NOTE(review): the input is hard-coded to 4 channels here,
                # unlike the sequential branch that uses n_channels.
                merged_inputs = Input(shape=(4, ) + patch_size,
                                      name='merged_inputs')
                # Channel 0 and channel 1 get a branch each; the remaining
                # channels (2 onwards) share the third branch.
                flair = Reshape((1, ) + patch_size)(Lambda(
                    lambda l: l[:, 0, :, :, :],
                    output_shape=(1, ) + patch_size)(merged_inputs))
                t2 = Reshape((1, ) + patch_size)(Lambda(
                    lambda l: l[:, 1, :, :, :],
                    output_shape=(1, ) + patch_size)(merged_inputs))
                t1 = Lambda(lambda l: l[:, 2:, :, :, :],
                            output_shape=(2, ) + patch_size)(merged_inputs)
                for filters, kernel_size in zip(filters_list,
                                                kernel_size_list):
                    flair = Conv3D(filters,
                                   kernel_size=kernel_size,
                                   activation='relu',
                                   data_format='channels_first')(flair)
                    t2 = Conv3D(filters,
                                kernel_size=kernel_size,
                                activation='relu',
                                data_format='channels_first')(t2)
                    t1 = Conv3D(filters,
                                kernel_size=kernel_size,
                                activation='relu',
                                data_format='channels_first')(t1)
                    flair = Dropout(0.5)(flair)
                    t2 = Dropout(0.5)(t2)
                    t1 = Dropout(0.5)(t1)

                # We only apply the RCNN to the multioutput approach (we keep
                # the simple one, simple)
                if recurrent:
                    # Each branch is fed (after a 1x1x1 convolution) into the
                    # next one: flair -> t2 -> t1.
                    flair = Conv3D(dense_size,
                                   kernel_size=(1, 1, 1),
                                   activation='relu',
                                   data_format='channels_first',
                                   name='fcn_flair')(flair)
                    flair = Dropout(0.5)(flair)
                    t2 = concatenate([flair, t2], axis=1)
                    t2 = Conv3D(dense_size,
                                kernel_size=(1, 1, 1),
                                activation='relu',
                                data_format='channels_first',
                                name='fcn_t2')(t2)
                    t2 = Dropout(0.5)(t2)
                    t1 = concatenate([t2, t1], axis=1)
                    t1 = Conv3D(dense_size,
                                kernel_size=(1, 1, 1),
                                activation='relu',
                                data_format='channels_first',
                                name='fcn_t1')(t1)
                    t1 = Dropout(0.5)(t1)
                    # NOTE(review): every branch goes through a second
                    # dropout layer here, right after the ones above -
                    # confirm that this is intended. Left untouched so the
                    # architecture does not change.
                    flair = Dropout(0.5)(flair)
                    t2 = Dropout(0.5)(t2)
                    t1 = Dropout(0.5)(t1)
                    # A single LSTM instance (shared weights) is applied to
                    # the three branches.
                    lstm_instance = LSTM(dense_size,
                                         implementation=1,
                                         name='rf_layer')
                    flair = lstm_instance(
                        Permute((2, 1))(Reshape((dense_size, -1))(flair)))
                    t2 = lstm_instance(
                        Permute((2, 1))(Reshape((dense_size, -1))(t2)))
                    t1 = lstm_instance(
                        Permute((2, 1))(Reshape((dense_size, -1))(t1)))

                else:
                    # Dense version of the same cascade: flair -> t2 -> t1
                    flair = Flatten()(flair)
                    t2 = Flatten()(t2)
                    t1 = Flatten()(t1)
                    flair = Dense(dense_size, activation='relu')(flair)
                    flair = Dropout(0.5)(flair)
                    t2 = concatenate([flair, t2])
                    t2 = Dense(dense_size, activation='relu')(t2)
                    t2 = Dropout(0.5)(t2)
                    t1 = concatenate([t2, t1])
                    t1 = Dense(dense_size, activation='relu')(t1)
                    t1 = Dropout(0.5)(t1)

                tumor = Dense(2, activation='softmax', name='tumor')(flair)
                core = Dense(3, activation='softmax', name='core')(t2)
                enhancing = Dense(num_classes,
                                  activation='softmax',
                                  name='enhancing')(t1)

                net = Model(inputs=merged_inputs,
                            outputs=[tumor, core, enhancing])

            net.compile(optimizer='adadelta',
                        loss='categorical_crossentropy',
                        metrics=['accuracy'])

            print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] +
                  'Training the model with a generator for ' + c['b'] +
                  '(%d parameters)' % net.count_params() + c['nc'])
            # summary() prints by itself and returns None, so it is not
            # wrapped in print().
            net.summary()
            net.fit_generator(
                generator=load_patch_batch_train(image_names=train_data,
                                                 label_names=train_labels,
                                                 centers=train_centers,
                                                 batch_size=batch_size,
                                                 size=patch_size,
                                                 nlabels=num_classes,
                                                 dfactor=dfactor,
                                                 preload=preload,
                                                 split=not sequential,
                                                 datatype=np.float32),
                validation_data=load_patch_batch_train(image_names=val_data,
                                                       label_names=val_labels,
                                                       centers=val_centers,
                                                       batch_size=batch_size,
                                                       size=patch_size,
                                                       nlabels=num_classes,
                                                       dfactor=dfactor,
                                                       preload=preload,
                                                       split=not sequential,
                                                       datatype=np.float32),
                steps_per_epoch=train_steps_per_epoch,
                validation_steps=val_steps_per_epoch,
                max_q_size=queue,
                epochs=epochs)
            net.save(net_name)

        # Then we test the net.
        for p, gt_name in zip(test_data, test_labels):
            p_name = p[0].rsplit('/')[-2]
            patient_path = '/'.join(p[0].rsplit('/')[:-1])
            outputname = os.path.join(patient_path,
                                      'deep-brats17' + sufix + 'test.nii.gz')
            try:
                # Only used as an "already segmented" check.
                load_nii(outputname)
            except IOError:
                # The first image doubles as the mask of voxels to classify.
                roi_nii = load_nii(p[0])
                roi = roi_nii.get_data().astype(dtype=bool)
                centers = get_mask_voxels(roi)
                test_samples = np.count_nonzero(roi)
                image = np.zeros_like(roi).astype(dtype=np.uint8)
                print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] +
                      '<Creating the probability map ' + c['b'] + p_name +
                      c['nc'] + c['g'] + ' (%d samples)>' % test_samples +
                      c['nc'])
                test_steps_per_epoch = -(-test_samples // batch_size)
                y_pr_pred = net.predict_generator(
                    generator=load_patch_batch_generator_test(
                        image_names=p,
                        centers=centers,
                        batch_size=batch_size,
                        size=patch_size,
                        preload=preload,
                    ),
                    steps=test_steps_per_epoch,
                    max_q_size=queue)
                [x, y, z] = np.stack(centers, axis=1)

                if not sequential:
                    # Multi-output net: the first output is stored as the
                    # ROI and the last one is the final segmentation.
                    tumor = np.argmax(y_pr_pred[0], axis=1)
                    y_pr_pred = y_pr_pred[-1]
                    roi = np.zeros_like(roi).astype(dtype=np.uint8)
                    roi[x, y, z] = tumor
                    roi_nii.get_data()[:] = roi
                    roiname = os.path.join(
                        patient_path,
                        'deep-brats17' + sufix + 'test.roi.nii.gz')
                    roi_nii.to_filename(roiname)

                y_pred = np.argmax(y_pr_pred, axis=1)

                image[x, y, z] = y_pred
                # Post-processing (basically keep the biggest connected region)
                image = get_biggest_region(image)
                if use_gt:
                    gt_nii = load_nii(gt_name)
                    gt = np.copy(gt_nii.get_data()).astype(dtype=np.uint8)
                    # DSC for every ground-truth label but the first (lowest)
                    labels = np.unique(gt.flatten())
                    results = (p_name, ) + tuple(
                        [dsc_seg(gt == l, image == l) for l in labels[1:]])
                    text = 'Subject %s DSC: ' + '/'.join(
                        ['%f' for _ in labels[1:]])
                    print(text % results)
                    dsc_results.append(results)

                print(c['g'] + '                   -- Saving image ' + c['b'] +
                      outputname + c['nc'])
                roi_nii.get_data()[:] = image
                roi_nii.to_filename(outputname)
Ejemplo n.º 4
0
def test_net(net, p, outputname):
    """Segment one patient patch-wise with a two-output ``net``.

    If both ``<patient dir>/<outputname>.nii.gz`` and the matching
    ``.roi.nii.gz`` can be loaded, the stored segmentation data is returned.
    Otherwise every voxel inside the mask given by ``p[0]`` is classified in
    batches of ``options['test_size']`` centers and four images are written
    next to the patient images: ``.roi``, ``.fcn``, ``.dense`` and the final
    ``<outputname>.nii.gz``.

    :param net: Keras model whose ``predict`` returns two outputs; the first
        one is stored as the ``dense`` result, the second one is sliced along
        its second axis and stored as the ``fcn`` result.
    :param p: image paths of one patient; ``p[0]`` provides the patient
        directory, the patient name and the mask.
    :param outputname: output file name without the ``.nii.gz`` extension.
    :return: the segmentation array (loaded from disk or newly computed).
    """
    c = color_codes()
    options = parse_inputs()
    patch_width = options['patch_width']
    patch_size = (patch_width, patch_width, patch_width)
    batch_size = options['test_size']
    p_name = p[0].rsplit('/')[-2]
    patient_path = '/'.join(p[0].rsplit('/')[:-1])
    outputname_path = os.path.join(patient_path, outputname + '.nii.gz')
    roiname = os.path.join(patient_path, outputname + '.roi.nii.gz')
    try:
        image = load_nii(outputname_path).get_data()
        # Only checks that the ROI image exists too.
        load_nii(roiname)
    except IOError:
        print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] +
              'Testing the network' + c['nc'])
        roi_nii = load_nii(p[0])
        roi = roi_nii.get_data().astype(dtype=bool)
        centers = get_mask_voxels(roi)
        test_samples = np.count_nonzero(roi)
        image = np.zeros_like(roi).astype(dtype=np.uint8)
        print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] +
              '<Creating the probability map ' + c['b'] + p_name + c['nc'] +
              c['g'] + ' - ' + c['b'] + outputname + c['nc'] + c['g'] +
              ' (%d samples)>' % test_samples + c['nc'])

        n_centers = len(centers)
        # -(-a // b) is the ceiling of a / b
        n_steps = -(-n_centers // batch_size)
        image_list = [load_norm_list(p)]
        is_roi = True
        roi = np.zeros_like(roi).astype(dtype=np.uint8)
        fcn_out = np.zeros_like(roi).astype(dtype=np.uint8)
        out = np.zeros_like(roi).astype(dtype=np.uint8)
        for i in range(0, n_centers, batch_size):
            print('%f%% tested (step %d/%d)' %
                  (100.0 * i / n_centers, (i // batch_size) + 1, n_steps),
                  end='\r')
            sys.stdout.flush()
            centers_i = [centers[i:i + batch_size]]
            patches = get_patches_list(image_list, centers_i, patch_size,
                                       True)
            patches = np.concatenate(patches).astype(dtype=np.float32)
            y_pr_pred = net.predict(patches, batch_size=options['batch_size'])

            [x, y, z] = np.stack(centers_i[0], axis=1)

            # First output: one prediction per patch
            out[x, y, z] = np.argmax(y_pr_pred[0], axis=1)
            # Second output: keep a single position of its second axis.
            # Floor division is required to get a valid (integer) index.
            # NOTE(review): `shape[1] // 2 + 1` is one past the middle
            # position - confirm that this offset is intended.
            fcn_index = y_pr_pred[1].shape[1] // 2 + 1
            fcn_pred = np.squeeze(y_pr_pred[1][:, fcn_index, :])
            tumor = np.argmax(fcn_pred, axis=1)
            fcn_out[x, y, z] = tumor

            # We store the ROI (non-zero labels become 1)
            roi[x, y, z] = tumor.astype(dtype=bool)
            # We store the results
            image[x, y, z] = tumor

        # Blank the progress line
        print(' ' * 49, end='\r')
        sys.stdout.flush()

        # Post-processing (basically keep the biggest connected region)
        image = get_biggest_region(image, is_roi)
        print(c['g'] + '                   -- Saving image ' + c['b'] +
              outputname_path + c['nc'])

        # The same image object is reused as container for every output.
        roi_nii.get_data()[:] = roi
        roi_nii.to_filename(roiname)
        roi_nii.get_data()[:] = get_biggest_region(fcn_out, is_roi)
        roi_nii.to_filename(
            os.path.join(patient_path, outputname + '.fcn.nii.gz'))

        roi_nii.get_data()[:] = get_biggest_region(out, is_roi)
        roi_nii.to_filename(
            os.path.join(patient_path, outputname + '.dense.nii.gz'))
        roi_nii.get_data()[:] = image
        roi_nii.to_filename(outputname_path)
    return image
Ejemplo n.º 5
0
def test_network(net, p, batch_size, patch_size, queue=50, sufix='', centers=None, filename=None):
    """Segment one patient with ``net`` and store the result next to its images.

    All outputs are written in the directory that contains ``p[0]``. When both
    the segmentation (``<name>.nii.gz``) and the ROI (``<name>.roi.nii.gz``)
    can already be loaded, the network is not run and the stored segmentation
    is returned instead.

    :param net: Model exposing ``predict_generator``. It may return a single
        array or a list of arrays (see the handling of ``y_pr_pred`` below).
    :param p: List of image paths of the patient. ``p[0]`` is also used as the
        NIfTI template for the outputs and, binarised, as the mask from which
        the default centers are computed.
    :param batch_size: Number of patches per prediction batch.
    :param patch_size: Patch size forwarded to the batch generator.
    :param queue: Maximum size of the generator queue.
    :param sufix: Suffix used for the default output name and the log messages.
    :param centers: Optional sequence of voxel coordinates to classify. When
        omitted, ``get_mask_voxels`` is applied to the binarised ``p[0]``.
    :param filename: Optional base name (without extension) for the outputs.
    :return: The labelled volume (either loaded from disk or freshly computed).
    """
    c = color_codes()
    p_name = p[0].rsplit('/')[-2]
    patient_path = '/'.join(p[0].rsplit('/')[:-1])
    outputname = filename if filename is not None else 'deep-brats17.test.' + sufix
    outputname_path = os.path.join(patient_path, outputname + '.nii.gz')
    roiname = os.path.join(patient_path, outputname + '.roi.nii.gz')
    try:
        image = load_nii(outputname_path).get_data()
        # The ROI is only loaded to check that it exists; its data is unused.
        load_nii(roiname)
    except IOError:
        print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] + 'Testing ' +
              c['b'] + sufix + c['nc'] + c['g'] + ' network' + c['nc'])
        roi_nii = load_nii(p[0])
        # Builtin bool instead of the np.bool alias (removed from recent NumPy).
        roi = roi_nii.get_data().astype(dtype=bool)
        centers = get_mask_voxels(roi) if centers is None else centers
        # Predictions are scattered at `centers` below, so the number of
        # samples (and therefore of generator steps) has to follow `centers`,
        # including the case where the caller supplied its own list.
        test_samples = len(centers)
        image = np.zeros_like(roi, dtype=np.uint8)
        print(c['c'] + '[' + strftime("%H:%M:%S") + ']    ' + c['g'] +
              '<Creating the probability map ' + c['b'] + p_name + c['nc'] + c['g'] +
              ' (%d samples)>' % test_samples + c['nc'])
        # Ceiling division; `//` keeps the result an integer regardless of
        # whether true division is active.
        test_steps_per_epoch = -(-test_samples // batch_size)
        y_pr_pred = net.predict_generator(
            generator=load_patch_batch_generator_test(
                image_names=p,
                centers=centers,
                batch_size=batch_size,
                size=patch_size,
                preload=True,
            ),
            steps=test_steps_per_epoch,
            max_q_size=queue
        )
        # Blank the progress line left on the terminal.
        print(' '.join([''] * 50), end='\r')
        sys.stdout.flush()
        [x, y, z] = np.stack(centers, axis=1)

        if isinstance(y_pr_pred, list):
            # Several outputs: the first one feeds the ROI file and the last
            # one provides the final labels.
            tumor = np.argmax(y_pr_pred[0], axis=1)
            y_pr_pred = y_pr_pred[-1]
            is_roi = False
        else:
            # Single output: it is used for both the ROI and the final labels.
            tumor = np.argmax(y_pr_pred, axis=1)
            is_roi = True

        # We save the ROI
        roi_labels = np.zeros_like(roi, dtype=np.uint8)
        roi_labels[x, y, z] = tumor
        roi_nii.get_data()[:] = roi_labels
        roi_nii.to_filename(roiname)

        y_pred = np.argmax(y_pr_pred, axis=1)

        # We save the results
        image[x, y, z] = tumor if is_roi else y_pred
        # Post-processing (Basically keep the biggest connected region)
        image = get_biggest_region(image, is_roi)
        print(c['g'] + '                   -- Saving image ' + c['b'] + outputname_path + c['nc'])
        roi_nii.get_data()[:] = image
        roi_nii.to_filename(outputname_path)
    return image
def main():
    """Segment the case stored under ``/data`` in two stages and save the labels.

    First a U-Net is run on the four stacked input images and the result of
    ``get_biggest_region`` on its label map is kept as a binary mask. Then an
    ensemble network classifies 9x9x9 patches taken at the centers returned by
    ``get_mask_blocks`` for that mask. The final label volume is written to
    ``/data/results/tumor_NVICOROB_class.nii.gz`` using the FLAIR image as the
    NIfTI template. Model weights are read from ``/usr/local/models``.
    """
    # Init
    c = color_codes()
    nlabels = 5

    # Prepare the names
    flair_name = '/data/flair.nii.gz'
    t2_name = '/data/t2.nii.gz'
    t1_name = '/data/t1.nii.gz'
    t1ce_name = '/data/t1ce.nii.gz'

    image_names = [flair_name, t2_name, t1_name, t1ce_name]

    # ---- Unet stage ----
    # Data loading: stack the images and add a leading singleton axis.
    unet_input = np.expand_dims(np.stack(load_images(image_names), axis=0), axis=0)

    # Network loading and testing
    print('%s[%s] %sTesting the Unet%s' % (c['c'], strftime("%H:%M:%S"), c['g'], c['nc']))
    net = get_brats_unet(unet_input.shape[1:], [32] * 5, [3] * 5, nlabels)
    net.load_weights('/usr/local/models/brats18-unet.hdf5')
    unet_pr_maps = net.predict(unet_input)
    unet_labels = np.argmax(unet_pr_maps, axis=-1).reshape(unet_input.shape[2:])
    # Builtin bool instead of the np.bool alias (removed from recent NumPy).
    mask = get_biggest_region(unet_labels).astype(bool)

    # ---- Ensemble stage ----
    # Init
    image = np.zeros_like(mask, dtype=np.int8)

    # Data loading
    test_centers = get_mask_blocks(mask)
    patches = get_data(
        image_names=[image_names],
        list_of_centers=[test_centers],
        patch_size=(9,) * 3,
        verbose=True,
    )
    print('%s- Concatenating the data x' % ' '.join([''] * 12))
    patches = np.concatenate(patches)

    # Networks loading
    nets, unet, cnn, fcnn, ucnn = get_brats_nets(
        n_channels=len(image_names),
        filters_list=[32] * 3,
        kernel_size_list=[3] * 3,
        nlabels=nlabels,
        dense_size=256
    )

    ensemble = get_brats_ensemble(
        n_channels=len(image_names),
        n_blocks=3,
        unet=unet,
        cnn=cnn,
        fcnn=fcnn,
        ucnn=ucnn,
        nlabels=nlabels
    )

    nets.load_weights('/usr/local/models/brats18-nets.hdf5')
    ensemble.load_weights('/usr/local/models/brats18-ensemble.hdf5')

    # Network testing and results saving
    pr_maps = ensemble.predict(patches)
    # One coordinate vector per axis, used to scatter the labels into the volume.
    [cx, cy, cz] = np.stack(test_centers, axis=1)
    image[cx, cy, cz] = np.argmax(pr_maps, axis=1).astype(dtype=np.int8)

    if not os.path.isdir('/data/results'):
        os.mkdir('/data/results')
    # The FLAIR image is reused as the template (header/affine) for the output.
    roi_nii = load_nii(flair_name)
    roi_nii.get_data()[:] = image
    roi_nii.to_filename('/data/results/tumor_NVICOROB_class.nii.gz')