def evaluate(config, datasets=None):
    """
    Evaluate the test data with the saved model.
    :param config: type dict: config parameter
    :param datasets: type list of str: list of dataset names
    :return: lists_loss_and_metrics: type list: evaluation results (losses and metrics)
    """
    lists_loss_and_metrics = []
    for dataset in datasets:
        # Get test dataset_image_paths and dataset_label_path
        if not os.path.exists(config['dir_dataset_info'] + '/split_paths_' + dataset + '.pickle'):
            raise FileNotFoundError(
                'Dataset split file `' + config['dir_dataset_info'] + '/split_paths_'
                + dataset + '.pickle` is not found!')
        with open(config['dir_dataset_info'] + '/split_paths_' + dataset + '.pickle', 'rb') as fp:
            split_path = pickle.load(fp)
        dataset_image_path = split_path['path_test_img']
        dataset_label_path = split_path['path_test_label']

        # Set the config files
        config = channel_config(config, dataset, evaluate=True)

        # Create the pipeline dataset
        ds_test = pipeline(config, dataset_image_path, dataset_label_path, dataset=dataset)

        # Load the trained model.
        model = load_model_file(config, dataset)
        print('Now evaluating data ', dataset, ' ...')

        # Evaluate the test data with the model.
        list_loss_and_metrics = model.evaluate(ds_test,
                                               verbose=config['evaluate_verbose_mode'])
        lists_loss_and_metrics.append(list_loss_and_metrics)

        path_pickle = config['result_rootdir'] + config['exp_name'] + '/' \
                      + config['model'] + '/evaluate_loss_and_metrics/'
        if not os.path.exists(path_pickle):
            os.makedirs(path_pickle)

        dictionary = dict()
        # Save loss (model.evaluate returns the loss first, then the metric values).
        dictionary['evaluate_loss'] = list_loss_and_metrics[0]
        # Save the metrics under their names.
        for i, metric_name in enumerate(model.metrics_names[1:]):
            dictionary['evaluate_' + metric_name] = list_loss_and_metrics[i + 1]

        with open(path_pickle + dataset + '.pickle', 'wb') as fp:
            pickle.dump(dictionary, fp, protocol=pickle.HIGHEST_PROTOCOL)
        print('Successfully saved the evaluation loss and metrics of ' + dataset + '.')
        print('Evaluating data ', dataset, 'is finished.')
    return lists_loss_and_metrics
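
# A minimal sketch of the `split_paths_<dataset>.pickle` layout that evaluate()
# reads above. Only 'path_test_img' and 'path_test_label' are accessed here
# (train() additionally reads the train/val keys). The helper name and the
# placeholder paths are hypothetical and not part of the pipeline.
def _write_example_split_paths(dir_dataset_info, dataset):
    split_path = {
        'path_test_img': [['tfrecords/patient_0/image/image.tfrecords']],
        'path_test_label': [['tfrecords/patient_0/label/label.tfrecords']],
    }
    if not os.path.exists(dir_dataset_info):
        os.makedirs(dir_dataset_info)
    with open(dir_dataset_info + '/split_paths_' + dataset + '.pickle', 'wb') as fp:
        pickle.dump(split_path, fp, protocol=pickle.HIGHEST_PROTOCOL)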
def k_fold_train_process(config, model, k_fold, paths, dataset, cp_callback,
                         init_epoch, saver1):
    """
    K-fold training process.
    :param config: type dict: config parameter
    :param model: type tf.keras.Model, training model
    :param k_fold: type int, number of folds
    :param paths: type dict of str: tfrecords paths loaded from the pickle file.
    :param dataset: type str: name of dataset
    :param cp_callback: type tf.keras.callbacks.ModelCheckpoint, training checkpoint
    :param init_epoch: type int, initial epoch.
    :param saver1: type tf.keras.callbacks.Callback, end-of-epoch logging callback
    :return: model: type tf.keras.Model, trained model
    :return: history: type list of float, metric values from each epoch.
    """
    list_1 = list(zip(paths['path_train_val_img'], paths['path_train_val_label']))
    random.shuffle(list_1)
    divided_datapath = len(list_1) // k_fold
    assert divided_datapath > 0
    history = []
    for k in range(k_fold):
        # Split into training and validation parts: fold k is the validation set.
        list_val = list_1[k * divided_datapath:(k + 1) * divided_datapath]
        list_train = list_1[0:k * divided_datapath] + list_1[(k + 1) * divided_datapath:]
        print('k_fold', k, ' list_val:', list_val, ' list_train:', list_train)
        [paths_train_img, paths_train_label] = zip(*list_train)
        [paths_val_img, paths_val_label] = zip(*list_val)

        print('Now training data:', dataset, ', k fold: ', k, ' ...')
        if not config['k_fold_merge_model']:
            # Train all k folds on one model; only the last fold's history is kept.
            model, history = train_process(config, model, paths_train_img, paths_train_label,
                                           paths_val_img, paths_val_label, dataset,
                                           cp_callback, saver1, k_fold_index=k,
                                           init_epoch=k * config['epochs'] + init_epoch)
        else:
            # Build one new model for each fold and collect all fold histories.
            model, hist = train_process(config, model, paths_train_img, paths_train_label,
                                        paths_val_img, paths_val_label, dataset,
                                        cp_callback, saver1, k_fold_index=k,
                                        init_epoch=init_epoch)
            history.append(hist)

            # Save the model of this fold.
            saved_model_path = config['saved_models_dir'] + '/' + config['exp_name'] \
                               + '/' + config['model']
            if not os.path.exists(saved_model_path):
                os.makedirs(saved_model_path)
            model.save(saved_model_path + '/' + dataset + 'k_fold_' + str(k) + '.h5')

            if k != k_fold - 1:
                # Create a new model for the next fold.
                if not config['train_premodel']:
                    call_model = getattr(ModelSet, config['model'])
                    model, list_metric_names = call_model(self=ModelSet, config=config)
                else:
                    model = load_model_file(config, dataset, compile=True)
    return model, history
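
# A pure-Python sketch of the fold indexing used in k_fold_train_process above:
# fold k takes the k-th contiguous chunk as validation and the rest as training.
# The helper is illustrative only and is not called by the pipeline.
def _example_k_fold_split(items, k_fold):
    """Yield (train, val) pairs. A trailing remainder of len(items) % k_fold
    items is never used for validation, matching the logic above."""
    fold_size = len(items) // k_fold
    assert fold_size > 0
    for k in range(k_fold):
        val = items[k * fold_size:(k + 1) * fold_size]
        train = items[:k * fold_size] + items[(k + 1) * fold_size:]
        yield train, val

# e.g. list(_example_k_fold_split(list(range(7)), 3)) gives folds of size 2;
# item 6 (the remainder) appears in every training split, never in validation.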
def train(config, restore=False):
    """
    Train the model on the datasets from the given paths.
    :param config: type dict: config parameter
    :param restore: type bool, True if training resumes from the last checkpoint
    :return: models: type list of model, trained models
    :return: histories: type list of list of float, metric values from each epoch.
    """
    models, histories = [], []
    for pickle_path, pickle_max_shape, dataset in zip(config['filename_tfrec_pickle'],
                                                      config['filename_max_shape_pickle'],
                                                      config['dataset']):
        if restore:
            # Resume dataset.
            with open(config['dir_model_checkpoint'] + '/' + config['exp_name']
                      + '/' + 'training_info.pickle', 'rb') as fp:
                restore_dataset = pickle.load(fp)['dataset']
            while True:
                if restore_dataset != dataset:
                    command = input('Warning! The dataset name stored at the last checkpoint '
                                    'does not match the current dataset, '
                                    'do you want to overwrite? (y/n)')
                    if command == 'y':
                        break
                    elif command == 'n':
                        dataset = restore_dataset
                        break
                    else:
                        print('Invalid command.')
                else:
                    print('Resume training dataset: ', dataset, '...')
                    break

        config = train_config_setting(config, dataset)

        # Load split (training, validation, test) tfrecord paths.
        if config['read_body_identification']:
            split_filename = config['dir_dataset_info'] + '/split_paths_' + dataset + '_bi.pickle'
        else:
            split_filename = config['dir_dataset_info'] + '/split_paths_' + dataset + '.pickle'
        with open(split_filename, 'rb') as fp:
            paths = pickle.load(fp)

        # Choose the training model.
        if not config['train_premodel']:
            call_model = getattr(ModelSet, config['model'])
            model, list_metric_names = call_model(self=ModelSet, config=config)
        else:
            # Load a pre-trained model.
            model = load_model_file(config, dataset, compile=True)
        model.summary()

        # Create a checkpoint for saving the model during training.
        if not os.path.exists(config['dir_model_checkpoint'] + '/' + config['exp_name']):
            os.makedirs(config['dir_model_checkpoint'] + '/' + config['exp_name'])
        checkpoint_path = config['dir_model_checkpoint'] + '/' + config['exp_name'] \
                          + '/cp_' + dataset + '_' + config['model'] + '.hdf5'
        tb_tool = TensorBoardTool(config['dir_model_checkpoint'] + '/'
                                  + config['exp_name'])  # start the Tensorboard

        # Create callbacks that save the model every X epochs and log to TensorBoard.
        cp_callback = [
            tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                               verbose=1,
                                               save_weights_only=False,
                                               period=config['save_training_model_period']),
            tf.keras.callbacks.TensorBoard(os.path.dirname(checkpoint_path),
                                           histogram_freq=1)
        ]

        # Initial epoch of training data
        init_epoch = 0
        if restore:
            # Resume from the saved epoch.
            model.load_weights(checkpoint_path)
            with open(config['dir_model_checkpoint'] + '/' + config['exp_name']
                      + '/' + 'training_info.pickle', 'rb') as fp:
                init_epoch = pickle.load(fp)['epoch'] + 1
            restore = False

        # Log data at the end of each training epoch.
        class Additional_Saver(tf.keras.callbacks.Callback):
            """Runs at the end of each epoch. Add further end-of-epoch
            processing here."""

            def on_epoch_end(self, epoch, logs=None):
                if epoch % config['save_training_model_period'] == 0 and epoch != 0:
                    with open(config['dir_model_checkpoint'] + '/' + config['exp_name']
                              + '/training_info.pickle', 'wb') as fp:
                        pickle.dump({'epoch': epoch,
                                     'dataset': dataset,
                                     'model': config['model']},
                                    fp, protocol=pickle.HIGHEST_PROTOCOL)
                if not os.path.exists('train_record'):
                    os.makedirs('train_record')
                now = datetime.datetime.now()
                with open('train_record/' + config['model'] + '_' + dataset + '.txt', 'a+') as file1:
                    file1.write('dataset: ' + dataset + ', Epoch: ' + str(epoch)
                                + ', Model: ' + config['model']
                                + ', time: ' + now.strftime('%Y-%m-%d %H:%M:%S')
                                + ', pid:' + str(os.getpid()))
                    file1.write('\n')

        saver1 = Additional_Saver()
        print('Now training data: ', dataset)
        k_fold = config['k_fold'][dataset]
        history_dataset = []
        if k_fold is not None:
            model, history = k_fold_train_process(config, model, k_fold, paths, dataset,
                                                  cp_callback, init_epoch, saver1)
        else:
            # Train without k-fold cross validation.
            model, history = train_process(config, model,
                                           paths['path_train_img'], paths['path_train_label'],
                                           paths['path_val_img'], paths['path_val_label'],
                                           dataset, cp_callback, saver1,
                                           k_fold_index=0, init_epoch=init_epoch)
        history_dataset.append(history)

        saved_model_path = config['saved_models_dir'] + '/' + config['exp_name'] \
                           + '/' + config['model']
        if not os.path.exists(saved_model_path):
            os.makedirs(saved_model_path)
        # Save the model when the training process is finished.
        model.save(saved_model_path + '/' + dataset + '.h5')
        print('Training data ', dataset, 'is finished.')
        models.append(model)
        histories.append(history_dataset)
    return models, histories
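
# Round-trip sketch of the resume record used by train() above: Additional_Saver
# dumps {'epoch', 'dataset', 'model'} every save period, and train(restore=True)
# reads back 'dataset' and 'epoch'. The directory and values below are
# hypothetical placeholders; the helper is not part of the pipeline.
def _example_resume_record(checkpoint_dir='checkpoints/exp0'):
    os.makedirs(checkpoint_dir, exist_ok=True)
    record = {'epoch': 42, 'dataset': 'dataset1', 'model': 'model_example'}
    with open(checkpoint_dir + '/training_info.pickle', 'wb') as fp:
        pickle.dump(record, fp, protocol=pickle.HIGHEST_PROTOCOL)
    with open(checkpoint_dir + '/training_info.pickle', 'rb') as fp:
        info = pickle.load(fp)
    init_epoch = info['epoch'] + 1  # resume at the epoch after the saved one
    return info['dataset'], init_epoch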
def predict_image(config, dataset, model, patch_imgs, indice_list, img_data_shape=None):
    """
    Predict the image with the given model.
    :param config: type dict: config parameter
    :param dataset: type str, name of dataset
    :param model: type tf.keras.Model, the model used for prediction
    :param patch_imgs: type list of ndarray, the patched image data
    :param indice_list: type list, start positions of the patches in the image
    :param img_data_shape: type tuple, shape of the original (unpatched) image
    :return: predict_img: type ndarray, the predicted image
             (for body identification: (prob_map, decision_map); -1 on failure)
    """
    # Select the input channels that correspond to the model input channels.
    if config['input_channel'][dataset] is not None:
        patch_imgs = patch_imgs[..., config['input_channel'][dataset]]

    # Regularize (normalize) the patch indices for the model input.
    indice_list_model = indice_list
    if config['regularize_indice_list']['max_shape']:
        indice_list_model = np.float32(np.array(indice_list)) \
                            / np.array(config['max_shape']['image'])[..., 0]
    elif config['regularize_indice_list']['image_shape']:
        indice_list_model = np.float32(np.array(indice_list)) / np.array(img_data_shape)[:-1]
    elif config['regularize_indice_list']['custom_specified']:
        indice_list_model = np.float32(np.array(indice_list)) \
                            / np.array(config['regularize_indice_list']['custom_specified'])

    # Predict the test data with the given trained model.
    print(config['saved_models_dir'] + '/' + config['exp_name'] + '/'
          + config['model'] + '/' + dataset + '.h5')
    try:
        if not config['read_body_identification']:
            predict_patch_imgs = model.predict(x=(patch_imgs, indice_list_model),
                                               batch_size=1, verbose=1)
            print('predict_patch_imgs_shape', predict_patch_imgs.shape)
        else:
            if config['model'] == 'model_body_identification_hybrid':
                predict_patch_imgs, predict_reg = model.predict(
                    x=(patch_imgs[..., 0], indice_list_model), batch_size=1, verbose=1)
            else:
                predict_patch_imgs = model.predict(x=(patch_imgs[..., 0]),
                                                   batch_size=1, verbose=1)
    except Exception:
        try:
            print('Prediction with load_weights_only=True failed, '
                  'rebuilding the model with load_weights_only=False ...')
            config['load_weights_only'] = False
            model = load_model_file(config, dataset)
            predict_patch_imgs = model.predict(x=(patch_imgs, indice_list_model),
                                               batch_size=1, verbose=1)
        except Exception:
            return -1

    # Reassemble the patch images into the whole image.
    if not config['read_body_identification']:
        predict_img = unpatch_predict_image(predict_patch_imgs, indice_list,
                                            config['patch_size'],
                                            output_patch_size=config['model_output_size'],
                                            set_zero_by_threshold=config['set_zero_by_threshold'],
                                            threshold=config['unpatch_start_threshold'])
        # Adjust the model output channel order.
        if config['predict_output_channel_order']:
            channel_stack = []
            for j in config['predict_output_channel_order']:
                channel_stack.append(predict_img[..., j])
            predict_img = np.stack(tuple(channel_stack), axis=-1)
        return predict_img
    else:
        # Especially for the body identification network.
        # prob_map, decision_map, e.g. shapes (191, 256, 110, 6) and (191, 256, 110, 6)
        prob_map, decision_map = prediction_prob(config, predict_patch_imgs, indice_list)
        print('prob_map, decision_map', prob_map.shape, decision_map.shape)
        return prob_map, decision_map
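
# Numeric sketch of the 'image_shape' regularization branch in predict_image
# above: the patch start indices are scaled into [0, 1] by the spatial image
# shape (the channel axis is dropped). All values are illustrative only.
def _example_regularize_indices():
    indice_list = [[0, 0, 0], [32, 64, 16]]  # patch start positions
    img_data_shape = (64, 128, 32, 2)        # (x, y, z, channels)
    scaled = np.float32(np.array(indice_list)) / np.array(img_data_shape)[:-1]
    return scaled  # [[0., 0., 0.], [0.5, 0.5, 0.5]]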
def predict(config, datasets=None, save_predict_data=False, name_ID=None):
    """
    Predict the test data with the saved model.
    :param config: type dict: config parameter
    :param datasets: type list of str, names of datasets
    :param save_predict_data: type bool, True if the predicted images are saved
    :param name_ID: type str, predict only the single patient with this ID
    :return:
    """
    if datasets is None:
        datasets = config['dataset']

    if config['load_predict_from_tfrecords']:
        # The predict data is in tfrecords format.
        for dataset in datasets:
            if not name_ID:
                # Load the paths of the predict dataset.
                if not config['read_body_identification']:
                    split_path = config['dir_dataset_info'] + '/split_paths_' + dataset + '.pickle'
                else:
                    split_path = config['dir_dataset_info'] + '/split_paths_' + dataset + '_bi.pickle'
                with open(split_path, 'rb') as fp:
                    split_path = pickle.load(fp)
                dataset_image_path = split_path['path_test_img']
                dataset_label_path = split_path['path_test_label']
            else:
                # Load the tfrecords of a single name_ID.
                dataset_image_path = [[config['rootdir_tfrec'][dataset] + '/' + name_ID
                                       + '/image/image.tfrecords']]
                dataset_label_path = [[config['rootdir_tfrec'][dataset] + '/' + name_ID
                                       + '/label/label.tfrecords']]

            config = channel_config(config, dataset)

            # Flatten the data path list: [[path1], [path2], ...] -> [path1, path2, ...]
            data_path_image_list = [t[i] for t in dataset_image_path
                                    for i in range(len(dataset_image_path[0]))]
            data_path_label_list = [t[i] for t in dataset_label_path
                                    for i in range(len(dataset_label_path[0]))]
            list_image_TFRecordDataset = [tf.data.TFRecordDataset(i) for i in data_path_image_list]
            list_label_TFRecordDataset = [tf.data.TFRecordDataset(i) for i in data_path_label_list]

            # Choose and create the same model as the saved model.
            print('load_model now...')
            model = load_model_file(config, dataset)
            collect_predict, collect_label = [], []

            for index, (image_TFRecordDataset, label_TFRecordDataset, data_path_image) in \
                    enumerate(zip(list_image_TFRecordDataset, list_label_TFRecordDataset,
                                  data_path_image_list)):
                dataset_image = image_TFRecordDataset.map(parser)
                dataset_label = label_TFRecordDataset.map(parser)

                # Get the image data from the tfrecords:
                # elem[0] = data, elem[1] = data shape
                img_data = [elem[0].numpy() for elem in dataset_image][0]
                label_data_onehot = [elem[0].numpy() for elem in dataset_label][0]

                if not config['read_body_identification']:
                    img_data, label_data_onehot = image_transform(config, img_data, label_data_onehot)

                    # Patch the image.
                    patch_imgs, indice_list = patch_image(config, img_data)
                    predict_img = predict_image(config, dataset, model, patch_imgs, indice_list)

                    # Get name_ID from the data path. The data path must have the
                    # specified format generated by med_io/preprocess_raw_dataset.py.
                    name_ID = data_path_image.replace('\\', '/').split('/')[-3]
                    if isinstance(predict_img, int) and predict_img == -1:
                        print('failed loading: ' + name_ID)
                        continue

                    if config['output_label_tfrecords']:
                        predict_img, label_data_onehot = select_output_channel(
                            config, dataset, predict_img, label_data_onehot)
                        predict_img_integers, predict_img_onehot, label_data_integers, label_data_onehot \
                            = convert_result(config, predict_img, label_data_onehot)
                    else:
                        # Bypass the labels.
                        label_data_onehot, label_data_integers = None, None
                        predict_img = select_output_channel(config, dataset, predict_img)
                        predict_img_integers, predict_img_onehot = convert_result(config, predict_img)

                    # Collect the data of one patient for plotting.
                    dict_data = {'predict_integers': predict_img_integers,
                                 'predict_onehot': predict_img_onehot,
                                 'label_integers': label_data_integers,
                                 'label_onehot': label_data_onehot,
                                 'original_image': img_data,
                                 'without_mask': np.zeros(predict_img_integers.shape)}
                    if config['plot_figure']:
                        # Plot the figures of the single patient.
                        plot_figures_single(config, dict_data, dataset=dataset, name_ID=name_ID)
                    if save_predict_data:
                        save_img_mat(config, dataset, name_ID, 'predict_image', predict_img_integers)

                    # Collect the images for plotting.
                    collect_predict.append(predict_img_onehot)
                    collect_label.append(label_data_onehot)
                else:
                    patch_imgs, indice_list = patch_image(config, img_data)
                    predict_img = predict_image(config, dataset, model, patch_imgs, indice_list)

                    if config['model'] == 'model_body_identification_hybrid' \
                            or config['model'] == 'model_body_identification_classification':
                        prob_map, decision_map = predict_img[0], predict_img[1]
                        select_slice = config['select_body_identification_predict_slice']
                        prob_map = convert_onehot_to_integers(prob_map[select_slice, ...], axis=-1)
                        decision_map = convert_onehot_to_integers(decision_map[select_slice, ...],
                                                                  axis=-1).astype(np.int32)
                        pred_thresholds = get_thresholds(decision_map, n_classes=6)
                        real_thresholds = label_data_onehot
                        # Extract name_ID from the tfrecords path.
                        name_ID = data_path_image.replace('\\', '/').split('/')[-3]
                        dict_data = {'predict_integers': decision_map.T,
                                     'original_image': img_data[select_slice, :, :, 0].T * 100,
                                     'pred_thresholds': pred_thresholds,
                                     'real_thresholds': real_thresholds,
                                     'predict_integers_probability': prob_map.T}
                        if config['plot_figure']:
                            # Plot the figures of the single patient.
                            plot_figures_single(config, dict_data, dataset=dataset, name_ID=name_ID)

            # list_images_series: collect the predict data and label data of the dataset.
            list_images_series = {'predict': collect_predict, 'label': collect_label}
            if config['plot_figure']:
                plot_figures_dataset(config, list_images_series, dataset=dataset)
            print('Predict data ', dataset, 'is finished.')
    else:
        # Load the dataset not from tfrecords, e.g. from NIfTI files.
        if datasets is None or datasets == []:
            datasets = ['New_predict_image']
        for dataset in datasets:
            config = channel_config(config, dataset)
            # Choose and create the same model as the saved model.
            model = load_model_file(config, dataset)
            collect_predict, collect_label = [], []
            data_dir_img = config['predict_data_dir_img']  # get the image dir

            if config['predict_load_label']:
                # The predict dataset has labels.
                data_dir_label = config['predict_data_dir_label']
                for index, (dir_name_img, dir_name_label) in \
                        enumerate(zip(os.listdir(data_dir_img), os.listdir(data_dir_label))):
                    data_path_img = os.path.join(config['predict_data_dir_img'],
                                                 dir_name_img).replace('\\', '/')
                    data_path_label = os.path.join(config['predict_data_dir_label'],
                                                   dir_name_label).replace('\\', '/')
                    img_data, label_data = read_predict_file(config, data_path_img, data_path_label)
                    img_data, label_data_onehot = image_transform(config, img_data,
                                                                  label_data_onehot=label_data)

                    # Patch the image.
                    patch_imgs, indice_list = patch_image(config, img_data)
                    predict_img = predict_image(config, dataset, model, patch_imgs, indice_list)
                    predict_img, label_data_onehot = select_output_channel(config, dataset,
                                                                           predict_img,
                                                                           label_data_onehot)
                    predict_img_integers, predict_img_onehot, label_data_integers, label_data_onehot \
                        = convert_result(config, predict_img, label_data_onehot=label_data_onehot)
                    name_ID = dir_name_img

                    # Collect the data of one patient for plotting.
                    dict_data = {'predict_integers': predict_img_integers,
                                 'predict_onehot': predict_img_onehot,
                                 'label_integers': label_data_integers,
                                 'label_onehot': label_data_onehot,
                                 'original_image': img_data,
                                 'without_mask': np.zeros(predict_img_integers.shape)}
                    if config['plot_figure']:
                        plot_figures_single(config, dict_data, dataset=dataset, name_ID=name_ID)
                    if save_predict_data:
                        save_img_mat(config, dataset, name_ID, 'predict_image', predict_img_integers)

                    # Collect the images for plotting.
                    collect_predict.append(predict_img_onehot)
                    collect_label.append(label_data_onehot)
            else:
                # The predict dataset (e.g. from NIfTI) has no labels.
                for name_ID in os.listdir(data_dir_img):
                    data_path_img = os.path.join(data_dir_img, name_ID).replace('\\', '/')
                    img_data = read_predict_file(config, data_path_img, name_ID=name_ID)
                    img_data = image_transform(config, img_data)
                    patch_imgs, indice_list = patch_image(config, img_data)
                    predict_img = predict_image(config, dataset, model, patch_imgs, indice_list,
                                                img_data_shape=img_data.shape)
                    predict_img = select_output_channel(config, dataset, predict_img)
                    predict_img_integers, predict_img_onehot = convert_result(config, predict_img)
                    dict_data = {'predict_integers': predict_img_integers,
                                 'predict_onehot': predict_img_onehot,
                                 'label_integers': None,
                                 'label_onehot': None,
                                 'original_image': img_data,
                                 'without_mask': np.zeros(predict_img_integers.shape)}
                    # Plot the figures of the single patient.
                    if config['plot_figure']:
                        plot_figures_single(config, dict_data, dataset=dataset, name_ID=name_ID)
                    if save_predict_data:
                        save_img_mat(config, dataset, name_ID, 'predict_image', predict_img_integers)
                    # Collect the images for plotting.
                    collect_predict.append(predict_img_onehot)
                    collect_label.append(None)

            # Collect the predict data and label data of the dataset.
            list_images_series = {'predict': collect_predict, 'label': collect_label}
            if config['plot_figure']:
                plot_figures_dataset(config, list_images_series, dataset=dataset)
            print('Predict data ', dataset, 'is finished.')
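
# Usage sketch of predict(): the config dict is assumed to be assembled
# elsewhere (e.g. parsed from the project's config file); the dataset name and
# patient ID below are hypothetical placeholders.
#
#     predict(config, datasets=['dataset1'], save_predict_data=True,
#             name_ID='patient_0042')
#
# With name_ID set, only `rootdir_tfrec/<name_ID>/image/image.tfrecords` (and
# the matching label file) is loaded instead of the whole test split.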