def train_net(net, net_name, nlabels):
    options = parse_inputs()
    c = color_codes()
    # Data stuff
    train_data, train_labels = get_names_from_path(options)
    # Prepare the net architecture parameters
    dfactor = options['dfactor']
    # Prepare the net hyperparameters
    epochs = options['epochs']
    patch_width = options['patch_width']
    patch_size = (patch_width, patch_width, patch_width)
    batch_size = options['batch_size']
    conv_blocks = options['conv_blocks']
    conv_width = options['conv_width']
    kernel_size_list = conv_width if isinstance(conv_width, list) else [conv_width] * conv_blocks
    balanced = options['balanced']
    val_rate = options['val_rate']
    preload = options['preload']
    fc_width = patch_width - sum(kernel_size_list) + conv_blocks
    fc_shape = (fc_width, ) * 3

    try:
        net = load_model(net_name + '.mdl')
    except IOError:
        centers = np.random.permutation(
            get_cnn_centers(train_data[:, 0], train_labels, balanced=balanced))
        print(' '.join([''] * 15) + c['g'] + 'Total number of centers = ' +
              c['b'] + '(%d centers)' % len(centers) + c['nc'])
        for i in range(dfactor):
            print(' '.join([''] * 16) + c['g'] + 'Round ' + c['b'] + '%d' % (i + 1) + c['nc'] +
                  c['g'] + '/%d' % dfactor + c['nc'])
            batch_centers = centers[i::dfactor]
            print(' '.join([''] * 16) + c['g'] + 'Loading data ' + c['b'] +
                  '(%d centers)' % len(batch_centers) + c['nc'])
            x, y = load_patches_train(
                image_names=train_data,
                label_names=train_labels,
                batch_centers=batch_centers,
                size=patch_size,
                fc_shape=fc_shape,
                nlabels=nlabels,
                preload=preload,
            )
            print(' '.join([''] * 16) + c['g'] + 'Training the model for ' + c['b'] +
                  '(%d parameters)' % net.count_params() + c['nc'])
            net.fit(x, y, batch_size=batch_size, validation_split=val_rate, epochs=epochs)
        net.save(net_name + '.mdl')
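
# Illustration (not used by the training code above): the centers[i::dfactor] slices split the
# permuted training centers into `dfactor` interleaved, non-overlapping rounds, so every center
# is visited exactly once per pass over the data. Toy example with hypothetical values:
def _example_round_robin_split():
    centers = list(range(10))
    dfactor = 3
    rounds = [centers[i::dfactor] for i in range(dfactor)]
    # rounds == [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
    assert sorted(sum(rounds, [])) == centers
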
def main():
    options = parse_inputs()
    c = color_codes()

    # Prepare the net architecture parameters
    sequential = options['sequential']
    dfactor = options['dfactor']
    # Prepare the net hyperparameters
    num_classes = 5
    epochs = options['epochs']
    padding = options['padding']
    patch_width = options['patch_width']
    patch_size = (patch_width, patch_width, patch_width)
    batch_size = options['batch_size']
    dense_size = options['dense_size']
    conv_blocks = options['conv_blocks']
    n_filters = options['n_filters']
    filters_list = n_filters if len(n_filters) > 1 else n_filters * conv_blocks
    conv_width = options['conv_width']
    kernel_size_list = conv_width if isinstance(conv_width, list) else [conv_width] * conv_blocks
    balanced = options['balanced']
    # Data loading parameters
    preload = options['preload']
    queue = options['queue']

    # Prepare the sufix that will be added to the results for the net and images
    path = options['dir_name']
    filters_s = 'n'.join(['%d' % nf for nf in filters_list])
    conv_s = 'c'.join(['%d' % cs for cs in kernel_size_list])
    s_s = '.s' if sequential else '.f'
    ub_s = '.ub' if not balanced else ''
    params_s = (ub_s, dfactor, s_s, patch_width, conv_s, filters_s, dense_size, epochs, padding)
    sufix = '%s.D%d%s.p%d.c%s.n%s.d%d.e%d.pad_%s.' % params_s
    n_channels = np.count_nonzero([
        options['use_flair'], options['use_t2'], options['use_t1'], options['use_t1ce']
    ])

    print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' + 'Starting cross-validation' + c['nc'])
    # N-fold cross-validation main loop (we'll do 2 training iterations with testing for each patient)
    data_names, label_names = get_names_from_path(options)
    folds = options['folds']
    fold_generator = izip(
        nfold_cross_validation(data_names, label_names, n=folds, val_data=0.25), xrange(folds))
    dsc_results = list()
    for (train_data, train_labels, val_data, val_labels, test_data, test_labels), i in fold_generator:
        print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' + c['nc'] +
              'Fold %d/%d: ' % (i + 1, folds) + c['g'] +
              'Number of training/validation/testing images (%d=%d/%d=%d/%d)' %
              (len(train_data), len(train_labels), len(val_data), len(val_labels), len(test_data)) +
              c['nc'])
        # Prepare the data relevant to the leave-one-out (subtract the patient from the dataset and set the path)
        # Also, prepare the network
        net_name = os.path.join(path, 'baseline-brats2017.fold%d' % i + sufix + 'mdl')

        # First we check that we did not train for that patient, in order to save time
        try:
            # net_name_before = os.path.join(
            #     path, 'baseline-brats2017.fold0.D500.f.p13.c3c3c3c3c3.n32n32n32n32n32.d256.e1.pad_valid.mdl')
            net = keras.models.load_model(net_name)
        except IOError:
            print('===============================================================')
            # NET definition using Keras
            train_centers = get_cnn_centers(train_data[:, 0], train_labels, balanced=balanced)
            val_centers = get_cnn_centers(val_data[:, 0], val_labels, balanced=balanced)
            train_samples = len(train_centers) / dfactor
            val_samples = len(val_centers) / dfactor
            print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' + c['g'] +
                  'Creating and compiling the model ' + c['b'] +
                  '(%d samples)' % train_samples + c['nc'])
            train_steps_per_epoch = -(-train_samples / batch_size)
            val_steps_per_epoch = -(-val_samples / batch_size)
            input_shape = (n_channels, ) + patch_size

            # Unlike the functional multi-output variant (whole tumor / core / enhancing paths),
            # this baseline uses a single sequential VGG-style network that merges all 4
            # modalities as input channels. The functional input definition is kept for reference:
            # merged_inputs = Input(shape=(4,) + patch_size, name='merged_inputs')
            # flair = merged_inputs
            model = Sequential()
            model.add(Conv3D(64, (3, 3, 3), strides=1, padding='same', activation='relu',
                             data_format='channels_first',
                             input_shape=(4, options['patch_width'], options['patch_width'],
                                          options['patch_width'])))
            model.add(Conv3D(64, (3, 3, 3), strides=1, padding='same', activation='relu',
                             data_format='channels_first'))
            model.add(MaxPooling3D(pool_size=(3, 3, 3), strides=2, data_format='channels_first'))
            model.add(Conv3D(128, (3, 3, 3), strides=1, padding='same', activation='relu',
                             data_format='channels_first'))
            model.add(Conv3D(128, (3, 3, 3), strides=1, padding='same', activation='relu',
                             data_format='channels_first'))
            model.add(MaxPooling3D(pool_size=(3, 3, 3), strides=2, data_format='channels_first'))
            model.add(Flatten())
            model.add(Dense(256, activation='relu'))
            model.add(Dropout(0.5))
            model.add(Dense(num_classes, activation='softmax'))
            net = model
            # net_name_before = os.path.join(
            #     path, 'baseline-brats2017.fold0.D500.f.p13.c3c3c3c3c3.n32n32n32n32n32.d256.e1.pad_valid.mdl')
            # net = keras.models.load_model(net_name_before)
            net.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])

            print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' + c['g'] +
                  'Training the model with a generator for ' + c['b'] +
                  '(%d parameters)' % net.count_params() + c['nc'])
            print(net.summary())
            net.fit_generator(
                generator=load_patch_batch_train(
                    image_names=train_data,
                    label_names=train_labels,
                    centers=train_centers,
                    batch_size=batch_size,
                    size=patch_size,
                    # fc_shape=patch_size,
                    nlabels=num_classes,
                    dfactor=dfactor,
                    preload=preload,
                    split=not sequential,
                    datatype=np.float32),
                validation_data=load_patch_batch_train(
                    image_names=val_data,
                    label_names=val_labels,
                    centers=val_centers,
                    batch_size=batch_size,
                    size=patch_size,
                    # fc_shape=patch_size,
                    nlabels=num_classes,
                    dfactor=dfactor,
                    preload=preload,
                    split=not sequential,
                    datatype=np.float32),
                # workers=queue,
                steps_per_epoch=train_steps_per_epoch,
                validation_steps=val_steps_per_epoch,
                max_q_size=queue,
                epochs=epochs)
            net.save(net_name)

        # Then we test the net.
        for p, gt_name in zip(test_data, test_labels):
            p_name = p[0].rsplit('/')[-2]
            patient_path = '/'.join(p[0].rsplit('/')[:-1])
            outputname = os.path.join(patient_path, 'deep-brats17' + sufix + 'test.nii.gz')
            gt_nii = load_nii(gt_name)
            gt = np.copy(gt_nii.get_data()).astype(dtype=np.uint8)
            try:
                load_nii(outputname)
            except IOError:
                roi_nii = load_nii(p[0])
                roi = roi_nii.get_data().astype(dtype=np.bool)
                centers = get_mask_voxels(roi)
                test_samples = np.count_nonzero(roi)
                image = np.zeros_like(roi).astype(dtype=np.uint8)
                print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' + c['g'] +
                      '<Creating the probability map ' + c['b'] + p_name + c['nc'] +
                      c['g'] + ' (%d samples)>' % test_samples + c['nc'])
                test_steps_per_epoch = -(-test_samples / batch_size)
                y_pr_pred = net.predict_generator(
                    generator=load_patch_batch_generator_test(
                        image_names=p,
                        centers=centers,
                        batch_size=batch_size,
                        size=patch_size,
                        preload=preload,
                    ),
                    steps=test_steps_per_epoch,
                    max_q_size=queue)
                [x, y, z] = np.stack(centers, axis=1)
                if not sequential:
                    tumor = np.argmax(y_pr_pred[0], axis=1)
                    y_pr_pred = y_pr_pred[-1]
                    roi = np.zeros_like(roi).astype(dtype=np.uint8)
                    roi[x, y, z] = tumor
                    roi_nii.get_data()[:] = roi
                    roiname = os.path.join(patient_path, 'deep-brats17' + sufix + 'test.roi.nii.gz')
                    roi_nii.to_filename(roiname)
                y_pred = np.argmax(y_pr_pred, axis=1)
                image[x, y, z] = y_pred
                # Post-processing (Basically keep the biggest connected region)
                image = get_biggest_region(image)
                labels = np.unique(gt.flatten())
                results = (p_name, ) + tuple([dsc_seg(gt == l, image == l) for l in labels[1:]])
                text = 'Subject %s DSC: ' + '/'.join(['%f' for _ in labels[1:]])
                print(text % results)
                dsc_results.append(results)
                print(c['g'] + ' -- Saving image ' + c['b'] + outputname + c['nc'])
                roi_nii.get_data()[:] = image
                roi_nii.to_filename(outputname)
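
# Illustration (not part of the original script): the -(-samples / batch_size) expressions above
# are ceiling division, so the number of generator steps per epoch is rounded up and every sample
# is covered. The same trick written with explicit floor division, which also works on Python 3:
def _ceil_div(a, b):
    # Equivalent to int(math.ceil(float(a) / b)) for positive integers,
    # e.g. _ceil_div(10, 4) == 3 and _ceil_div(12, 4) == 3.
    return -(-a // b)
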
def main():
    options = parse_inputs()
    c = color_codes()

    # Prepare the net hyperparameters
    epochs = options['epochs']
    patch_width = options['patch_width']
    patch_size = (patch_width, patch_width, patch_width)
    dense_size = options['dense_size']
    conv_blocks = options['conv_blocks']
    n_filters = options['n_filters']
    filters_list = n_filters if len(n_filters) > 1 else n_filters * conv_blocks
    conv_width = options['conv_width']
    kernel_size_list = conv_width if isinstance(conv_width, list) else [conv_width] * conv_blocks
    balanced = options['balanced']
    # Data loading parameters
    downsample = options['downsample']
    preload = options['preload']
    shuffle = options['shuffle']

    # Prepare the sufix that will be added to the results for the net and images
    filters_s = 'n'.join(['%d' % nf for nf in filters_list])
    conv_s = 'c'.join(['%d' % cs for cs in kernel_size_list])
    unbalanced_s = '.ub' if not balanced else ''
    shuffle_s = '.s' if shuffle else ''
    params_s = (unbalanced_s, shuffle_s, patch_width, conv_s, filters_s, dense_size, downsample)
    sufix = '%s%s.p%d.c%s.n%s.d%d.D%d' % params_s

    preload_s = ' (with %spreloading%s%s)' % (c['b'], c['nc'], c['c']) if preload else ''
    print('%s[%s] Starting training%s%s' % (c['c'], strftime("%H:%M:%S"), preload_s, c['nc']))
    train_data, _ = get_names_from_path(options)
    test_data, test_labels = get_names_from_path(options, False)
    input_shape = (train_data.shape[1],) + patch_size

    dsc_results = list()
    dsc_results_pr = list()

    train_data, train_labels = get_names_from_path(options)
    centers_s = np.random.permutation(
        get_cnn_centers(train_data[:, 0], train_labels, balanced=balanced))[::downsample]
    x_seg, y_seg = load_patches_ganseg_by_batches(
        image_names=train_data,
        label_names=train_labels,
        source_centers=centers_s,
        size=patch_size,
        nlabels=2,
        preload=preload,
    )

    for i, (p, gt_name) in enumerate(zip(test_data, test_labels)):
        p_name = p[0].rsplit('/')[-3]
        patient_path = '/'.join(p[0].rsplit('/')[:-1])
        print('%s[%s] %sCase %s%s%s%s%s (%d/%d):%s' % (
            c['c'], strftime("%H:%M:%S"), c['nc'],
            c['c'], c['b'], p_name, c['nc'],
            c['c'], i + 1, len(test_data), c['nc']
        ))

        # NO DSC objective
        image_cnn_name = os.path.join(patient_path, p_name + '.cnn.test%s.e%d' % (shuffle_s, epochs))
        image_gan_name = os.path.join(patient_path, p_name + '.gan.test%s.e%d' % (shuffle_s, epochs))
        # DSC objective
        image_cnn_dsc_name = os.path.join(patient_path, p_name + '.dsc-cnn.test%s.e%d' % (shuffle_s, epochs))
        image_gan_dsc_name = os.path.join(patient_path, p_name + '.dsc-gan.test%s.e%d' % (shuffle_s, epochs))
        try:
            # NO DSC objective
            image_cnn = load_nii(image_cnn_name + '.nii.gz').get_data()
            image_cnn_pr = load_nii(image_cnn_name + '.pr.nii.gz').get_data()
            image_gan = load_nii(image_gan_name + '.nii.gz').get_data()
            image_gan_pr = load_nii(image_gan_name + '.pr.nii.gz').get_data()
            # DSC objective
            image_cnn_dsc = load_nii(image_cnn_dsc_name + '.nii.gz').get_data()
            image_cnn_dsc_pr = load_nii(image_cnn_dsc_name + '.pr.nii.gz').get_data()
            image_gan_dsc = load_nii(image_gan_dsc_name + '.nii.gz').get_data()
            image_gan_dsc_pr = load_nii(image_gan_dsc_name + '.pr.nii.gz').get_data()
        except IOError:
            # Lesion segmentation
            adversarial_w = K.variable(0)
            # NO DSC objective
            cnn, gan, gan_test = get_wmh_nets(
                input_shape=input_shape,
                filters_list=filters_list,
                kernel_size_list=kernel_size_list,
                dense_size=dense_size,
                lambda_var=adversarial_w
            )
            # DSC objective
            cnn_dsc, gan_dsc, gan_dsc_test = get_wmh_nets(
                input_shape=input_shape,
                filters_list=filters_list,
                kernel_size_list=kernel_size_list,
                dense_size=dense_size,
                lambda_var=adversarial_w,
                dsc_obj=True
            )
            train_nets(
                gan=gan,
                gan_dsc=gan_dsc,
                cnn=cnn,
                cnn_dsc=cnn_dsc,
                p=p,
                x=x_seg,
                y=y_seg,
                name='wmh2017' + sufix,
                adversarial_w=adversarial_w
            )
            # NO DSC objective
            image_cnn = test_net(cnn, p, image_cnn_name)
            image_cnn_pr = load_nii(image_cnn_name + '.pr.nii.gz').get_data()
            image_gan = test_net(gan_test, p, image_gan_name)
            image_gan_pr = load_nii(image_gan_name + '.pr.nii.gz').get_data()
            # DSC objective
            image_cnn_dsc = test_net(cnn_dsc, p, image_cnn_dsc_name)
            image_cnn_dsc_pr = load_nii(image_cnn_dsc_name + '.pr.nii.gz').get_data()
            image_gan_dsc = test_net(gan_dsc_test, p, image_gan_dsc_name)
            image_gan_dsc_pr = load_nii(image_gan_dsc_name + '.pr.nii.gz').get_data()

        # NO DSC objective
        seg_cnn = image_cnn.astype(np.bool)
        seg_gan = image_gan.astype(np.bool)
        # DSC objective
        seg_cnn_dsc = image_cnn_dsc.astype(np.bool)
        seg_gan_dsc = image_gan_dsc.astype(np.bool)

        seg_gt = load_nii(gt_name).get_data()
        not_roi = np.logical_not(seg_gt == 2)

        results_cnn_dsc = dsc_seg(seg_gt == 1, np.logical_and(seg_cnn_dsc, not_roi))
        results_cnn_dsc_pr = probabilistic_dsc_seg(seg_gt == 1, image_cnn_dsc_pr * not_roi)
        results_cnn = dsc_seg(seg_gt == 1, np.logical_and(seg_cnn, not_roi))
        results_cnn_pr = probabilistic_dsc_seg(seg_gt == 1, image_cnn_pr * not_roi)
        results_gan_dsc = dsc_seg(seg_gt == 1, np.logical_and(seg_gan_dsc, not_roi))
        results_gan_dsc_pr = probabilistic_dsc_seg(seg_gt == 1, image_gan_dsc_pr * not_roi)
        results_gan = dsc_seg(seg_gt == 1, np.logical_and(seg_gan, not_roi))
        results_gan_pr = probabilistic_dsc_seg(seg_gt == 1, image_gan_pr * not_roi)

        whites = ''.join([' '] * 14)
        print('%sCase %s%s%s%s %sCNN%s vs %sGAN%s DSC: %s%f%s (%s%f%s) vs %s%f%s (%s%f%s)' % (
            whites, c['c'], c['b'], p_name, c['nc'],
            c['lgy'], c['nc'], c['y'], c['nc'],
            c['lgy'], results_cnn_dsc, c['nc'], c['lgy'], results_cnn, c['nc'],
            c['y'], results_gan_dsc, c['nc'], c['y'], results_gan, c['nc']
        ))
        print('%sCase %s%s%s%s %sCNN%s vs %sGAN%s DSC Pr: %s%f%s (%s%f%s) vs %s%f%s (%s%f%s)' % (
            whites, c['c'], c['b'], p_name, c['nc'],
            c['lgy'], c['nc'], c['y'], c['nc'],
            c['lgy'], results_cnn_dsc_pr, c['nc'], c['lgy'], results_cnn_pr, c['nc'],
            c['y'], results_gan_dsc_pr, c['nc'], c['y'], results_gan_pr, c['nc']
        ))

        dsc_results.append((results_cnn_dsc, results_cnn, results_gan_dsc, results_gan))
        dsc_results_pr.append((results_cnn_dsc_pr, results_cnn_pr, results_gan_dsc_pr, results_gan_pr))

    final_dsc = tuple(np.mean(dsc_results, axis=0))
    final_dsc_pr = tuple(np.mean(dsc_results_pr, axis=0))
    print('Final results DSC: %s%f%s (%s%f%s) vs %s%f%s (%s%f%s)' % (
        c['lgy'], final_dsc[0], c['nc'], c['lgy'], final_dsc[1], c['nc'],
        c['y'], final_dsc[2], c['nc'], c['y'], final_dsc[3], c['nc']
    ))
    print('Final results DSC Pr: %s%f%s (%s%f%s) vs %s%f%s (%s%f%s)' % (
        c['lgy'], final_dsc_pr[0], c['nc'], c['lgy'], final_dsc_pr[1], c['nc'],
        c['y'], final_dsc_pr[2], c['nc'], c['y'], final_dsc_pr[3], c['nc']
    ))
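
# Sketch (assumption, not the repository's actual schedule): adversarial_w above is a Keras
# backend variable shared with the GAN loss, so train_nets can presumably re-weight the
# adversarial term while training. One hedged way such a variable can be ramped between epochs:
def _example_adversarial_ramp():
    from keras import backend as K
    from keras.callbacks import LambdaCallback

    adversarial_w = K.variable(0.0)
    ramp = LambdaCallback(
        on_epoch_end=lambda epoch, logs: K.set_value(
            adversarial_w, min(1.0, 0.1 * (epoch + 1))))
    # Passing callbacks=[ramp] to fit()/fit_generator() would then let any loss term that
    # multiplies adversarial_w grow as training progresses.
    return adversarial_w, ramp
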
def main():
    options = parse_inputs()
    c = color_codes()

    # Prepare the net architecture parameters
    sequential = options['sequential']
    dfactor = options['dfactor']
    # Prepare the net hyperparameters
    num_classes = 5
    epochs = options['epochs']
    padding = options['padding']
    patch_width = options['patch_width']
    patch_size = (patch_width, patch_width, patch_width)
    batch_size = options['batch_size']
    dense_size = options['dense_size']
    conv_blocks = options['conv_blocks']
    n_filters = options['n_filters']
    filters_list = n_filters if len(n_filters) > 1 else n_filters * conv_blocks
    conv_width = options['conv_width']
    kernel_size_list = conv_width if isinstance(conv_width, list) else [conv_width] * conv_blocks
    balanced = options['balanced']
    recurrent = options['recurrent']
    # Data loading parameters
    preload = options['preload']
    queue = options['queue']

    # Prepare the sufix that will be added to the results for the net and images
    path = options['dir_name']
    filters_s = 'n'.join(['%d' % nf for nf in filters_list])
    conv_s = 'c'.join(['%d' % cs for cs in kernel_size_list])
    s_s = '.s' if sequential else '.f'
    ub_s = '.ub' if not balanced else ''
    params_s = (ub_s, dfactor, s_s, patch_width, conv_s, filters_s, dense_size, epochs, padding)
    sufix = '%s.D%d%s.p%d.c%s.n%s.d%d.e%d.pad_%s.' % params_s
    n_channels = np.count_nonzero([
        options['use_flair'], options['use_t2'], options['use_t1'], options['use_t1ce']
    ])

    print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' + 'Starting cross-validation' + c['nc'])
    # N-fold cross-validation main loop (we'll do 2 training iterations with testing for each patient)
    data_names, label_names = get_names_from_path(options)
    folds = options['folds']
    fold_generator = izip(
        nfold_cross_validation(data_names, label_names, n=folds, val_data=0.25), xrange(folds))
    dsc_results = list()
    for (train_data, train_labels, val_data, val_labels, test_data, test_labels), i in fold_generator:
        print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' + c['nc'] +
              'Fold %d/%d: ' % (i + 1, folds) + c['g'] +
              'Number of training/validation/testing images (%d=%d/%d=%d/%d)' %
              (len(train_data), len(train_labels), len(val_data), len(val_labels), len(test_data)) +
              c['nc'])
        # Prepare the data relevant to the leave-one-out (subtract the patient from the dataset and set the path)
        # Also, prepare the network
        net_name = os.path.join(path, 'baseline-brats2017.fold%d' % i + sufix + 'mdl')

        # First we check that we did not train for that patient, in order to save time
        try:
            net = keras.models.load_model(net_name)
        except IOError:
            # NET definition using Keras
            train_centers = get_cnn_centers(train_data[:, 0], train_labels, balanced=balanced)
            val_centers = get_cnn_centers(val_data[:, 0], val_labels, balanced=balanced)
            train_samples = len(train_centers) / dfactor
            val_samples = len(val_centers) / dfactor
            print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' + c['g'] +
                  'Creating and compiling the model ' + c['b'] +
                  '(%d samples)' % train_samples + c['nc'])
            train_steps_per_epoch = -(-train_samples / batch_size)
            val_steps_per_epoch = -(-val_samples / batch_size)
            input_shape = (n_channels, ) + patch_size

            if sequential:
                # Sequential model that merges all 4 images. This architecture is just a set of
                # convolutional blocks that end in a dense layer. This is supposed to be an
                # original baseline.
                net = Sequential()
                net.add(Conv3D(filters_list[0],
                               kernel_size=kernel_size_list[0],
                               input_shape=input_shape,
                               activation='relu',
                               data_format='channels_first'))
                for filters, kernel_size in zip(filters_list[1:], kernel_size_list[1:]):
                    net.add(Dropout(0.5))
                    net.add(Conv3D(filters,
                                   kernel_size=kernel_size,
                                   activation='relu',
                                   data_format='channels_first'))
                net.add(Dropout(0.5))
                net.add(Flatten())
                net.add(Dense(dense_size, activation='relu'))
                net.add(Dropout(0.5))
                net.add(Dense(num_classes, activation='softmax'))
            else:
                # This architecture is based on the functional Keras API to introduce 3 output paths:
                # - Whole tumor segmentation
                # - Core segmentation (including whole tumor)
                # - Whole segmentation (tumor, core and enhancing parts)
                # The idea is to let the network work on the three parts to improve the multiclass segmentation.
                merged_inputs = Input(shape=(4, ) + patch_size, name='merged_inputs')
                flair = Reshape((1, ) + patch_size)(
                    Lambda(lambda l: l[:, 0, :, :, :], output_shape=(1, ) + patch_size)(merged_inputs))
                t2 = Reshape((1, ) + patch_size)(
                    Lambda(lambda l: l[:, 1, :, :, :], output_shape=(1, ) + patch_size)(merged_inputs))
                t1 = Lambda(lambda l: l[:, 2:, :, :, :], output_shape=(2, ) + patch_size)(merged_inputs)
                for filters, kernel_size in zip(filters_list, kernel_size_list):
                    flair = Conv3D(filters, kernel_size=kernel_size, activation='relu',
                                   data_format='channels_first')(flair)
                    t2 = Conv3D(filters, kernel_size=kernel_size, activation='relu',
                                data_format='channels_first')(t2)
                    t1 = Conv3D(filters, kernel_size=kernel_size, activation='relu',
                                data_format='channels_first')(t1)
                    flair = Dropout(0.5)(flair)
                    t2 = Dropout(0.5)(t2)
                    t1 = Dropout(0.5)(t1)

                # We only apply the RCNN to the multioutput approach (we keep the simple one, simple)
                if recurrent:
                    flair = Conv3D(dense_size, kernel_size=(1, 1, 1), activation='relu',
                                   data_format='channels_first', name='fcn_flair')(flair)
                    flair = Dropout(0.5)(flair)
                    t2 = concatenate([flair, t2], axis=1)
                    t2 = Conv3D(dense_size, kernel_size=(1, 1, 1), activation='relu',
                                data_format='channels_first', name='fcn_t2')(t2)
                    t2 = Dropout(0.5)(t2)
                    t1 = concatenate([t2, t1], axis=1)
                    t1 = Conv3D(dense_size, kernel_size=(1, 1, 1), activation='relu',
                                data_format='channels_first', name='fcn_t1')(t1)
                    t1 = Dropout(0.5)(t1)
                    flair = Dropout(0.5)(flair)
                    t2 = Dropout(0.5)(t2)
                    t1 = Dropout(0.5)(t1)
                    lstm_instance = LSTM(dense_size, implementation=1, name='rf_layer')
                    flair = lstm_instance(Permute((2, 1))(Reshape((dense_size, -1))(flair)))
                    t2 = lstm_instance(Permute((2, 1))(Reshape((dense_size, -1))(t2)))
                    t1 = lstm_instance(Permute((2, 1))(Reshape((dense_size, -1))(t1)))
                else:
                    flair = Flatten()(flair)
                    t2 = Flatten()(t2)
                    t1 = Flatten()(t1)
                    flair = Dense(dense_size, activation='relu')(flair)
                    flair = Dropout(0.5)(flair)
                    t2 = concatenate([flair, t2])
                    t2 = Dense(dense_size, activation='relu')(t2)
                    t2 = Dropout(0.5)(t2)
                    t1 = concatenate([t2, t1])
                    t1 = Dense(dense_size, activation='relu')(t1)
                    t1 = Dropout(0.5)(t1)

                tumor = Dense(2, activation='softmax', name='tumor')(flair)
                core = Dense(3, activation='softmax', name='core')(t2)
                enhancing = Dense(num_classes, activation='softmax', name='enhancing')(t1)

                net = Model(inputs=merged_inputs, outputs=[tumor, core, enhancing])

            net.compile(optimizer='adadelta', loss='categorical_crossentropy', metrics=['accuracy'])

            print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' + c['g'] +
                  'Training the model with a generator for ' + c['b'] +
                  '(%d parameters)' % net.count_params() + c['nc'])
            print(net.summary())
            net.fit_generator(
                generator=load_patch_batch_train(
                    image_names=train_data,
                    label_names=train_labels,
                    centers=train_centers,
                    batch_size=batch_size,
                    size=patch_size,
                    nlabels=num_classes,
                    dfactor=dfactor,
                    preload=preload,
                    split=not sequential,
                    datatype=np.float32),
                validation_data=load_patch_batch_train(
                    image_names=val_data,
                    label_names=val_labels,
                    centers=val_centers,
                    batch_size=batch_size,
                    size=patch_size,
                    nlabels=num_classes,
                    dfactor=dfactor,
                    preload=preload,
                    split=not sequential,
                    datatype=np.float32),
                steps_per_epoch=train_steps_per_epoch,
                validation_steps=val_steps_per_epoch,
                max_q_size=queue,
                epochs=epochs)
            net.save(net_name)

        # Then we test the net.
        use_gt = options['use_gt']
        for p, gt_name in zip(test_data, test_labels):
            p_name = p[0].rsplit('/')[-2]
            patient_path = '/'.join(p[0].rsplit('/')[:-1])
            outputname = os.path.join(patient_path, 'deep-brats17' + sufix + 'test.nii.gz')
            try:
                load_nii(outputname)
            except IOError:
                roi_nii = load_nii(p[0])
                roi = roi_nii.get_data().astype(dtype=np.bool)
                centers = get_mask_voxels(roi)
                test_samples = np.count_nonzero(roi)
                image = np.zeros_like(roi).astype(dtype=np.uint8)
                print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' + c['g'] +
                      '<Creating the probability map ' + c['b'] + p_name + c['nc'] +
                      c['g'] + ' (%d samples)>' % test_samples + c['nc'])
                test_steps_per_epoch = -(-test_samples / batch_size)
                y_pr_pred = net.predict_generator(
                    generator=load_patch_batch_generator_test(
                        image_names=p,
                        centers=centers,
                        batch_size=batch_size,
                        size=patch_size,
                        preload=preload,
                    ),
                    steps=test_steps_per_epoch,
                    max_q_size=queue)
                [x, y, z] = np.stack(centers, axis=1)
                if not sequential:
                    tumor = np.argmax(y_pr_pred[0], axis=1)
                    y_pr_pred = y_pr_pred[-1]
                    roi = np.zeros_like(roi).astype(dtype=np.uint8)
                    roi[x, y, z] = tumor
                    roi_nii.get_data()[:] = roi
                    roiname = os.path.join(patient_path, 'deep-brats17' + sufix + 'test.roi.nii.gz')
                    roi_nii.to_filename(roiname)
                y_pred = np.argmax(y_pr_pred, axis=1)
                image[x, y, z] = y_pred
                # Post-processing (Basically keep the biggest connected region)
                image = get_biggest_region(image)
                if use_gt:
                    gt_nii = load_nii(gt_name)
                    gt = np.copy(gt_nii.get_data()).astype(dtype=np.uint8)
                    labels = np.unique(gt.flatten())
                    results = (p_name, ) + tuple([dsc_seg(gt == l, image == l) for l in labels[1:]])
                    text = 'Subject %s DSC: ' + '/'.join(['%f' for _ in labels[1:]])
                    print(text % results)
                    dsc_results.append(results)
                print(c['g'] + ' -- Saving image ' + c['b'] + outputname + c['nc'])
                roi_nii.get_data()[:] = image
                roi_nii.to_filename(outputname)
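
# Sketch (hypothetical shapes, not the generator's actual code): the functional net above has
# three softmax heads named 'tumor', 'core' and 'enhancing', so each training batch must come
# with one label array per head, for example:
def _example_multioutput_targets(batch=8, num_classes=5):
    import numpy as np
    y_tumor = np.zeros((batch, 2))                 # matches the Dense(2) 'tumor' head
    y_core = np.zeros((batch, 3))                  # matches the Dense(3) 'core' head
    y_enhancing = np.zeros((batch, num_classes))   # matches the Dense(num_classes) 'enhancing' head
    # Keras accepts either a list ordered like `outputs` or a dict keyed by output name.
    return {'tumor': y_tumor, 'core': y_core, 'enhancing': y_enhancing}
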
def main():
    options = parse_inputs()
    c = color_codes()

    # Prepare the net hyperparameters
    epochs = options['epochs']
    patch_width = options['patch_width']
    patch_size = (patch_width, patch_width, patch_width)
    dense_size = options['dense_size']
    conv_blocks = options['conv_blocks']
    n_filters = options['n_filters']
    filters_list = n_filters if len(n_filters) > 1 else n_filters * conv_blocks
    conv_width = options['conv_width']
    kernel_size_list = conv_width if isinstance(conv_width, list) else [conv_width] * conv_blocks
    balanced = options['balanced']
    # Data loading parameters
    preload = options['preload']

    # Prepare the sufix that will be added to the results for the net and images
    filters_s = 'n'.join(['%d' % nf for nf in filters_list])
    conv_s = 'c'.join(['%d' % cs for cs in kernel_size_list])
    ub_s = '.ub' if not balanced else ''
    params_s = (ub_s, patch_width, conv_s, filters_s, dense_size, epochs)
    sufix = '%s.p%d.c%s.n%s.d%d.e%d' % params_s

    preload_s = ' (with ' + c['b'] + 'preloading' + c['nc'] + c['c'] + ')' if preload else ''
    print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' + 'Starting training' + preload_s + c['nc'])
    train_data, _ = get_names_from_path(options)
    test_data, test_labels = get_names_from_path(options, False)
    input_shape = (train_data.shape[1],) + patch_size

    dsc_results_gan = list()
    dsc_results_cnn = list()
    dsc_results_caps = list()

    train_data, train_labels = get_names_from_path(options)
    centers_s = np.random.permutation(
        get_cnn_centers(train_data[:, 0], train_labels, balanced=balanced))[::options['down_sampling']]
    x_seg, y_seg = load_patches_ganseg_by_batches(
        image_names=train_data,
        label_names=train_labels,
        source_centers=centers_s,
        size=patch_size,
        nlabels=5,
        preload=preload,
        batch_size=51200
    )
    y_seg_roi = np.empty((len(y_seg), 2), dtype=np.bool)
    y_seg_roi[:, 0] = y_seg[:, 0]
    y_seg_roi[:, 1] = np.sum(y_seg[:, 1:], axis=1)

    for i, (p, gt_name) in enumerate(zip(test_data, test_labels)):
        p_name = p[0].rsplit('/')[-2]
        patient_path = '/'.join(p[0].rsplit('/')[:-1])
        print('%s[%s] %sCase %s%s%s%s%s (%d/%d):%s' % (
            c['c'], strftime("%H:%M:%S"), c['nc'],
            c['c'], c['b'], p_name, c['nc'],
            c['c'], i + 1, len(test_data), c['nc']
        ))

        # ROI segmentation
        adversarial_w = K.variable(0)
        roi_cnn = get_brats_fc(input_shape, filters_list, kernel_size_list, dense_size, 2)
        roi_caps = get_brats_caps(input_shape, filters_list, kernel_size_list, 8, 2)
        roi_gan, _ = get_brats_gan_fc(
            input_shape, filters_list, kernel_size_list, dense_size, 2, lambda_var=adversarial_w)
        train_nets(
            x=x_seg,
            y=y_seg_roi,
            gan=roi_gan,
            cnn=roi_cnn,
            caps=roi_caps,
            p=p,
            name='brats2017-roi' + sufix,
            adversarial_w=adversarial_w
        )

        # Tumor substructures net
        adversarial_w = K.variable(0)
        seg_cnn = get_brats_fc(input_shape, filters_list, kernel_size_list, dense_size, 5)
        seg_caps = get_brats_caps(input_shape, filters_list, kernel_size_list, 8, 5)
        seg_gan_tr, seg_gan_tst = get_brats_gan_fc(
            input_shape, filters_list, kernel_size_list, dense_size, 5, lambda_var=adversarial_w)
        roi_net_conv_layers = [l for l in roi_gan.layers if 'conv' in l.name]
        seg_net_conv_layers = [l for l in seg_gan_tr.layers if 'conv' in l.name]
        for lr, ls in zip(roi_net_conv_layers[:conv_blocks], seg_net_conv_layers[:conv_blocks]):
            ls.set_weights(lr.get_weights())
        train_nets(
            x=x_seg,
            y=y_seg,
            gan=seg_gan_tr,
            cnn=seg_cnn,
            caps=seg_caps,
            p=p,
            name='brats2017-full' + sufix,
            adversarial_w=adversarial_w
        )

        image_cnn_name = os.path.join(patient_path, p_name + '.cnn.test')
        try:
            image_cnn = load_nii(image_cnn_name + '.nii.gz').get_data()
        except IOError:
            image_cnn = test_net(seg_cnn, p, image_cnn_name)
        image_caps_name = os.path.join(patient_path, p_name + '.caps.test')
        try:
            image_caps = load_nii(image_caps_name + '.nii.gz').get_data()
        except IOError:
            image_caps = test_net(seg_caps, p, image_caps_name)
        image_gan_name = os.path.join(patient_path, p_name + '.gan.test')
        try:
            image_gan = load_nii(image_gan_name + '.nii.gz').get_data()
        except IOError:
            image_gan = test_net(seg_gan_tst, p, image_gan_name)

        results_cnn = check_dsc(gt_name, image_cnn)
        dsc_string = c['g'] + '/'.join(['%f'] * len(results_cnn)) + c['nc']
        print(''.join([' '] * 14) + c['c'] + c['b'] + p_name + c['nc'] +
              ' CNN DSC: ' + dsc_string % tuple(results_cnn))

        results_caps = check_dsc(gt_name, image_caps)
        dsc_string = c['g'] + '/'.join(['%f'] * len(results_caps)) + c['nc']
        print(''.join([' '] * 14) + c['c'] + c['b'] + p_name + c['nc'] +
              ' CAPS DSC: ' + dsc_string % tuple(results_caps))

        results_gan = check_dsc(gt_name, image_gan)
        dsc_string = c['g'] + '/'.join(['%f'] * len(results_gan)) + c['nc']
        print(''.join([' '] * 14) + c['c'] + c['b'] + p_name + c['nc'] +
              ' GAN DSC: ' + dsc_string % tuple(results_gan))

        dsc_results_cnn.append(results_cnn)
        dsc_results_caps.append(results_caps)
        dsc_results_gan.append(results_gan)

    f_dsc = tuple(
        [np.array([dsc[i] for dsc in dsc_results_cnn if len(dsc) > i]).mean() for i in range(3)]
    ) + tuple(
        [np.array([dsc[i] for dsc in dsc_results_caps if len(dsc) > i]).mean() for i in range(3)]
    ) + tuple(
        [np.array([dsc[i] for dsc in dsc_results_gan if len(dsc) > i]).mean() for i in range(3)]
    )
    print('Final results DSC: (%f/%f/%f) vs (%f/%f/%f) vs (%f/%f/%f)' % f_dsc)
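
# Sketch (assumption: check_dsc is expected to return one DSC value per non-background label
# present in the ground truth, which is how the result lists above are consumed). A hypothetical
# array-based equivalent, for reference only; the real helper takes a ground-truth filename:
def _dsc_per_label(gt, segmentation):
    import numpy as np
    labels = np.unique(gt)
    return [2.0 * np.sum((gt == l) & (segmentation == l)) /
            (np.sum(gt == l) + np.sum(segmentation == l))
            for l in labels if l > 0]
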
def main():
    options = parse_inputs()
    c = color_codes()

    # Prepare the net architecture parameters
    dfactor = options['dfactor']
    # Prepare the net hyperparameters
    num_classes = 4
    epochs = options['epochs']
    patch_width = options['patch_width']
    patch_size = (patch_width, patch_width, patch_width)
    batch_size = options['batch_size']
    dense_size = options['dense_size']
    conv_blocks = options['conv_blocks']
    n_filters = options['n_filters']
    filters_list = n_filters if len(n_filters) > 1 else n_filters * conv_blocks
    conv_width = options['conv_width']
    kernel_size_list = conv_width if isinstance(conv_width, list) else [conv_width] * conv_blocks
    balanced = options['balanced']
    val_rate = options['val_rate']
    # Data loading parameters
    preload = options['preload']
    queue = options['queue']

    # Prepare the sufix that will be added to the results for the net and images
    path = options['dir_name']
    filters_s = 'n'.join(['%d' % nf for nf in filters_list])
    conv_s = 'c'.join(['%d' % cs for cs in kernel_size_list])
    ub_s = '.ub' if not balanced else ''
    params_s = (ub_s, dfactor, patch_width, conv_s, filters_s, dense_size, epochs)
    sufix = '%s.D%d.p%d.c%s.n%s.d%d.e%d.' % params_s
    n_channels = 4

    preload_s = ' (with ' + c['b'] + 'preloading' + c['nc'] + c['c'] + ')' if preload else ''
    print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' + 'Starting training' + preload_s + c['nc'])
    # N-fold cross validation main loop (we'll do 2 training iterations with testing for each patient)
    train_data, train_labels = get_names_from_path(options)
    print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' + c['nc'] + c['g'] +
          'Number of training images (%d=%d)' % (len(train_data), len(train_labels)) + c['nc'])

    # Also, prepare the network
    net_name = os.path.join(path, 'CBICA-brats2017' + sufix)
    print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' + c['g'] +
          'Creating and compiling the model ' + c['nc'])
    input_shape = (train_data.shape[1], ) + patch_size

    # Sequential model that merges all 4 images. This architecture is just a set of convolutional
    # blocks that end in a dense layer. This is supposed to be an original baseline.
    inputs = Input(shape=input_shape, name='merged_inputs')
    conv = inputs
    for filters, kernel_size in zip(filters_list, kernel_size_list):
        conv = Conv3D(filters, kernel_size=kernel_size, activation='relu',
                      data_format='channels_first')(conv)
        conv = Dropout(0.5)(conv)

    full = Conv3D(dense_size, kernel_size=(1, 1, 1), data_format='channels_first')(conv)
    full = PReLU()(full)
    full = Conv3D(2, kernel_size=(1, 1, 1), data_format='channels_first')(full)

    rf = concatenate([conv, full], axis=1)
    while np.product(K.int_shape(rf)[2:]) > 1:
        rf = Conv3D(dense_size, kernel_size=(3, 3, 3), data_format='channels_first')(rf)
        rf = Dropout(0.5)(rf)

    full = Reshape((2, -1))(full)
    full = Permute((2, 1))(full)
    full_out = Activation('softmax', name='fc_out')(full)
    tumor = Dense(2, activation='softmax', name='tumor')(rf)
    outputs = [tumor, full_out]

    net = Model(inputs=inputs, outputs=outputs)
    net.compile(optimizer='adadelta', loss='categorical_crossentropy',
                loss_weights=[0.8, 1.0], metrics=['accuracy'])

    fc_width = patch_width - sum(kernel_size_list) + conv_blocks
    fc_shape = (fc_width, ) * 3

    checkpoint = net_name + '{epoch:02d}.{val_tumor_acc:.2f}.hdf5'
    callbacks = [
        EarlyStopping(monitor='val_tumor_loss', patience=options['patience']),
        ModelCheckpoint(os.path.join(path, checkpoint), monitor='val_tumor_loss', save_best_only=True)
    ]

    for i in range(options['r_epochs']):
        try:
            net = load_model(net_name + ('e%d.' % i) + 'mdl')
        except IOError:
            train_centers = get_cnn_centers(train_data[:, 0], train_labels, balanced=balanced)
            train_samples = len(train_centers) / dfactor
            print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' + c['g'] + 'Loading data ' +
                  c['b'] + '(%d centers)' % (len(train_centers) / dfactor) + c['nc'])
            x, y = load_patches_train(
                image_names=train_data,
                label_names=train_labels,
                centers=train_centers,
                size=patch_size,
                fc_shape=fc_shape,
                nlabels=2,
                dfactor=dfactor,
                preload=preload,
                split=True,
                iseg=False,
                experimental=1,
                datatype=np.float32)
            print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' + c['g'] +
                  'Training the model for ' + c['b'] +
                  '(%d parameters)' % net.count_params() + c['nc'])
            print(net.summary())
            net.fit(x, y, batch_size=batch_size, validation_split=val_rate,
                    epochs=epochs, callbacks=callbacks)
            net.save(net_name + ('e%d.' % i) + 'mdl')
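
# Note on the r_epochs loop above: each outer iteration looks for a model saved as '...e%d.mdl'
# for that round, and only resamples centers and retrains when that file is missing, so an
# interrupted run can resume from the last completed round.
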
def train_net(fold_n, train_data, train_labels, options):
    # Prepare the net architecture parameters
    dfactor = options['dfactor']
    # Prepare the net hyperparameters
    epochs = options['epochs']
    patch_width = options['patch_width']
    patch_size = (patch_width, ) * 3
    batch_size = options['batch_size']
    dense_size = options['dense_size']
    conv_blocks = options['conv_blocks']
    nfilters = options['n_filters']
    filters_list = nfilters if len(nfilters) > 1 else nfilters * conv_blocks
    conv_width = options['conv_width']
    kernel_size_list = conv_width if isinstance(conv_width, list) else [conv_width] * conv_blocks
    experimental = options['experimental']
    fc_width = patch_width - sum(kernel_size_list) + conv_blocks
    fc_shape = (fc_width, ) * 3
    # Data loading parameters
    preload = options['preload']

    # Prepare the sufix that will be added to the results for the net and images
    path = options['dir_name']
    sufix = get_sufix(options)
    net_name = os.path.join(path, 'iseg2017.fold%d' % fold_n + sufix + 'mdl')
    checkpoint = 'iseg2017.fold%d' % fold_n + sufix + '{epoch:02d}.{val_brain_acc:.2f}.hdf5'

    c = color_codes()
    try:
        net = load_model(net_name)
        net.load_weights(os.path.join(path, checkpoint))
    except IOError:
        # Data loading
        train_centers = get_cnn_centers(train_data[:, 0], train_labels)
        train_samples = len(train_centers) / dfactor
        print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' + c['g'] + 'Loading data ' +
              c['b'] + '(%d centers)' % len(train_centers) + c['nc'])
        x, y = load_patches_train(
            image_names=train_data,
            label_names=train_labels,
            centers=train_centers,
            size=patch_size,
            fc_shape=fc_shape,
            nlabels=4,
            dfactor=dfactor,
            preload=preload,
            split=True,
            iseg=True,
            experimental=experimental,
            datatype=np.float32)

        # NET definition using Keras
        print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' + c['g'] +
              'Creating and compiling the model ' + c['b'] +
              '(%d samples)' % train_samples + c['nc'])
        input_shape = (2, ) + patch_size
        # The iSEG networks are built with the functional Keras API; the baseline and the
        # experimental variants below share the same inputs and hyperparameters.
        network_func = [
            get_iseg_baseline,
            get_iseg_experimental1,
            get_iseg_experimental2,
            get_iseg_experimental3,
            get_iseg_experimental4
        ]
        net = network_func[experimental](input_shape, filters_list, kernel_size_list, dense_size)

        print(c['c'] + '[' + strftime("%H:%M:%S") + '] ' + c['g'] + 'Training the model ' +
              c['b'] + '(%d parameters)' % net.count_params() + c['nc'])
        print(net.summary())
        callbacks = [
            EarlyStopping(monitor='val_brain_loss', patience=options['patience']),
            ModelCheckpoint(os.path.join(path, checkpoint), monitor='val_brain_loss', save_best_only=True)
        ]
        net.save(net_name)
        net.fit(x, y, batch_size=batch_size, validation_split=0.25, epochs=epochs, callbacks=callbacks)
        net.load_weights(os.path.join(path, checkpoint))
    return net
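
# Note: the architecture is saved to net_name before fitting, while ModelCheckpoint keeps the
# best weights according to val_brain_loss; those weights are reloaded after training so the
# returned net carries the best validation state rather than the last epoch.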