Example #1
def evaluate(config, datasets=None):
    """
    Evaluate the test data with the given saved model.
    :param config: type dict: config parameter
    :param datasets: type list of str: list of dataset names
    :return: lists_loss_and_metrics: type list: evaluation results (losses and metrics) per dataset
    """
    lists_loss_and_metrics = []
    for dataset in datasets:
        # Get the test image and label paths of the dataset.
        if not os.path.exists(config['dir_dataset_info'] + '/split_paths_' +
                              dataset + '.pickle'):
            raise FileNotFoundError(
                'Paths of dataset `config[dir_dataset_info]/split_paths_' +
                dataset + '.pickle` are not found!')
        with open(
                config['dir_dataset_info'] + '/split_paths_' + dataset +
                '.pickle', 'rb') as fp:
            split_path = pickle.load(fp)
            dataset_image_path = split_path['path_test_img']
            dataset_label_path = split_path['path_test_label']

        # Set the config files
        config = channel_config(config, dataset, evaluate=True)
        # create pipeline dataset
        ds_test = pipeline(config,
                           dataset_image_path,
                           dataset_label_path,
                           dataset=dataset)

        # Choose the training model.
        model = load_model_file(config, dataset)

        print('Now evaluating data ', dataset, ' ...')

        # Evaluate the test data with the model.
        list_loss_and_metrics = model.evaluate(
            ds_test, verbose=config['evaluate_verbose_mode'])
        lists_loss_and_metrics.append(list_loss_and_metrics)

        path_pickle = config['result_rootdir'] + config[
            'exp_name'] + '/' + config['model'] + '/evaluate_loss_and_metrics/'
        os.makedirs(path_pickle, exist_ok=True)

        dictionary = dict()
        # Save loss
        dictionary['evaluate_loss'] = list_loss_and_metrics[0]
        # Save metrics
        for i, metric_name in enumerate(model.metrics_names[1:]):
            dictionary['evaluate_' + metric_name] = list_loss_and_metrics[i + 1]

        with open(path_pickle + dataset + '.pickle', 'wb') as fp:
            pickle.dump(dictionary, fp, protocol=pickle.HIGHEST_PROTOCOL)
            print('Successfully saved the evaluation loss and metrics of ' +
                  dataset + '.')

        print('Evaluating data ', dataset, 'is finished.')

    return lists_loss_and_metrics
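
A minimal usage sketch for evaluate(); the dataset name 'dataset_A' and the assumption that config is the already-loaded parameter dict are hypothetical and only for illustration:

# Hypothetical usage sketch: 'config' is the loaded parameter dict and
# 'dataset_A' is a dataset name listed in config['dataset'].
results = evaluate(config, datasets=['dataset_A'])
print('Loss and metrics for dataset_A:', results[0])
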
Example #2
def k_fold_train_process(config, model, k_fold, paths, dataset, cp_callback,
                         init_epoch, saver1):
    """
    K-fold training process

    :param config: type dict: config parameter
    :param model:  type tf.keras.Model, training model
    :param k_fold: type int, number of folds
    :param paths:  type dict of str: tfrecords paths loaded from pickle file.
    :param dataset: type str: name of dataset
    :param cp_callback: type tf.keras.callbacks.ModelCheckpoint, training checkpoint
    :param init_epoch:  type int, initial epoch.
    :param saver1: type tf.keras.callbacks.Callback, additional end-of-epoch saver
    :return: model:  type tf.keras.Model, trained model
    :return: history: type list of float, metric values from each epoch.
    """
    # history = None
    list_1 = list(
        zip(paths['path_train_val_img'], paths['path_train_val_label']))
    random.shuffle(list_1)
    divided_datapath = len(list_1) // k_fold
    assert (divided_datapath > 0)
    history = []
    for k in range(k_fold):
        # Split into training and validation folds.
        list_val = list_1[k * divided_datapath:(k + 1) * divided_datapath]
        list_train = list_1[0:k * divided_datapath] + list_1[
            (k + 1) * divided_datapath:len(list_1)]

        print('k_fold', k, ' list_val:', list_val, ' list_train:', list_train)

        [paths_train_img, paths_train_label] = zip(*list_train)
        [paths_val_img, paths_val_label] = zip(*list_val)

        print('Now training data:', dataset, ', k fold: ', k, ' ...')
        if not config['k_fold_merge_model']:

            # Train all k folds on a single model.
            model, history = train_process(config,
                                           model,
                                           paths_train_img,
                                           paths_train_label,
                                           paths_val_img,
                                           paths_val_label,
                                           dataset,
                                           cp_callback,
                                           saver1,
                                           k_fold_index=k,
                                           init_epoch=k * config['epochs'] +
                                           init_epoch)

        else:
            # Create a new model for each fold.
            model, hist = train_process(config,
                                        model,
                                        paths_train_img,
                                        paths_train_label,
                                        paths_val_img,
                                        paths_val_label,
                                        dataset,
                                        cp_callback,
                                        saver1,
                                        k_fold_index=k,
                                        init_epoch=init_epoch)
            history.append(hist)
            # save model
            saved_model_path = config['saved_models_dir'] + '/' + config[
                'exp_name'] + '/' + config['model']
            if not os.path.exists(saved_model_path):
                os.makedirs(saved_model_path)
            # Save the model when training process is finished.
            model.save(saved_model_path + '/' + dataset + 'k_fold_' + str(k) +
                       '.h5')

            if k != k_fold - 1:
                # create a new model for next k-fold.
                if not config['train_premodel']:
                    call_model = getattr(ModelSet, config['model'])
                    model, list_metric_names = call_model(self=ModelSet,
                                                          config=config)
                else:
                    model = load_model_file(config, dataset, compile=True)
    return model, history
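
The fold split above can be illustrated in isolation; a minimal, self-contained sketch with hypothetical path names showing how each fold takes one slice for validation and keeps the rest for training:

# Standalone sketch of the k-fold split logic used above (paths are placeholders).
import random

k_fold = 3
pairs = list(zip(['img_%d' % i for i in range(6)],
                 ['label_%d' % i for i in range(6)]))
random.shuffle(pairs)
divided_datapath = len(pairs) // k_fold
for k in range(k_fold):
    fold_val = pairs[k * divided_datapath:(k + 1) * divided_datapath]
    fold_train = pairs[:k * divided_datapath] + pairs[(k + 1) * divided_datapath:]
    print('fold', k, 'validation pairs:', fold_val, 'training pairs:', len(fold_train))
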
Example #3
def train(config, restore=False):
    """
    Train the models on the given dataset paths.
    :param config: type dict: config parameter
    :param restore: type bool, True if training is resumed from the last checkpoint
    :return: models:  type list of tf.keras.Model, trained models
    :return: histories: type list of list of float, metric values from each epoch.
    """
    models, histories = [], []
    for pickle_path, pickle_max_shape, dataset in zip(
            config['filename_tfrec_pickle'],
            config['filename_max_shape_pickle'], config['dataset']):
        if restore:
            # Resume dataset.
            with open(
                    config['dir_model_checkpoint'] + '/' + config['exp_name'] +
                    '/' + 'training_info.pickle', 'rb') as fp:
                restore_dataset = pickle.load(fp)['dataset']

                while True:
                    if restore_dataset != dataset:
                        command = input(
                            'Warning! The dataset name stored at the last checkpoint does not match the current dataset,'
                            ' do you want to overwrite? (y/n)')
                        if command == 'y':
                            break
                        elif command == 'n':
                            dataset = restore_dataset
                            break
                        else:
                            print('Invalid command.')
                    else:
                        print('Resume training dataset: ', dataset, '...')
                        break

        config = train_config_setting(config, dataset)

        # Load split (training, validation, test) tfrecord paths.
        if config['read_body_identification']:
            split_filename = config[
                'dir_dataset_info'] + '/split_paths_' + dataset + '_bi.pickle'
        else:
            split_filename = config[
                'dir_dataset_info'] + '/split_paths_' + dataset + '.pickle'
        with open(split_filename, 'rb') as fp:
            paths = pickle.load(fp)

        # Choose the training model.
        if not config['train_premodel']:
            call_model = getattr(ModelSet, config['model'])
            model, list_metric_names = call_model(self=ModelSet, config=config)
        else:  # load pre-trained model
            model = load_model_file(config, dataset, compile=True)

        print(model.summary())
        # Create checkpoint for saving model during training.
        if not os.path.exists(config['dir_model_checkpoint'] + '/' +
                              config['exp_name']):
            os.makedirs(config['dir_model_checkpoint'] + '/' +
                        config['exp_name'])
        checkpoint_path = config['dir_model_checkpoint'] + '/' + config[
            'exp_name'] + '/cp_' + dataset + '_' + config['model'] + '.hdf5'

        tb_tool = TensorBoardTool(config['dir_model_checkpoint'] + '/' +
                                  config['exp_name'])  # start the Tensorboard

        # Create a callback that saves the model's weights every X epochs.
        cp_callback = [
            tf.keras.callbacks.ModelCheckpoint(
                filepath=checkpoint_path,
                verbose=1,
                save_weights_only=False,
                period=config['save_training_model_period']),
            tf.keras.callbacks.TensorBoard(os.path.dirname(checkpoint_path),
                                           histogram_freq=1)
        ]

        # Initial epoch of training data
        init_epoch = 0
        if restore:
            # Resume saved epoch.
            model.load_weights(checkpoint_path)
            with open(
                    config['dir_model_checkpoint'] + '/' + config['exp_name'] +
                    '/' + 'training_info.pickle', 'rb') as fp:
                init_epoch = pickle.load(fp)['epoch'] + 1
        restore = False

        # Log data at end of training epoch
        class Additional_Saver(tf.keras.callbacks.Callback):
            """
            The program on the end of each epoch,
            Add here if any progress are processed on the end of each epoch
            """
            def on_epoch_end(self, epoch, logs={}):

                if epoch % config[
                        'save_training_model_period'] == 0 and epoch != 0:
                    with open(
                            config['dir_model_checkpoint'] + '/' +
                            config['exp_name'] + '/training_info.pickle',
                            'wb') as fp:
                        pickle.dump(
                            {
                                'epoch': epoch,
                                'dataset': dataset,
                                'model': config['model']
                            },
                            fp,
                            protocol=pickle.HIGHEST_PROTOCOL)
                if not os.path.exists('train_record'):
                    os.makedirs('train_record')
                now = datetime.datetime.now()
                with open(
                        'train_record/' + config['model'] + '_' + dataset +
                        '.txt', 'a+') as record_file:
                    record_file.write('dataset: ' + dataset + ', Epoch: ' +
                                      str(epoch) + ', Model: ' +
                                      config['model'] + ', time: ' +
                                      now.strftime("%Y-%m-%d %H:%M:%S") +
                                      ', pid:' + str(os.getpid()) + '\n')

        saver1 = Additional_Saver()
        print('Now training data: ', dataset)
        k_fold = config['k_fold'][dataset]
        history_dataset = []
        if k_fold is not None:
            model, history = k_fold_train_process(config, model, k_fold, paths,
                                                  dataset, cp_callback,
                                                  init_epoch, saver1)
        else:
            # Train without k-fold cross validation.
            model, history = train_process(config,
                                           model,
                                           paths['path_train_img'],
                                           paths['path_train_label'],
                                           paths['path_val_img'],
                                           paths['path_val_label'],
                                           dataset,
                                           cp_callback,
                                           saver1,
                                           k_fold_index=0,
                                           init_epoch=init_epoch)
        history_dataset.append(history)
        saved_model_path = config['saved_models_dir'] + '/' + config[
            'exp_name'] + '/' + config['model']
        os.makedirs(saved_model_path, exist_ok=True)
        # Save the model when training process is finished.
        model.save(saved_model_path + '/' + dataset + '.h5')
        print('Training data', dataset, 'is finished.')

        models.append(model)
        histories.append(history_dataset)
    return models, histories
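
A minimal usage sketch for train(); it only assumes that config has already been assembled (e.g. loaded from the project's config file) and contains the dataset, model and checkpoint keys referenced above:

# Hypothetical usage sketch: 'config' is the loaded parameter dict.
# Fresh training run over all datasets in config['dataset']:
models, histories = train(config)

# Resume from the checkpoint stored under config['dir_model_checkpoint']:
models, histories = train(config, restore=True)
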
Example #4
def predict_image(config, dataset, model, patch_imgs, indice_list, img_data_shape=None):
    """
    Predict the image with the given model.
    :param config: type dict: config parameter
    :param dataset: type str, name of dataset
    :param model: type tf.keras.Model, the model used for prediction
    :param patch_imgs: type list of ndarray, image patches to be predicted
    :param indice_list: type list, position indices of the patches in the whole image
    :param img_data_shape: type tuple, shape of the original image, used to regularize the patch indices
    :return: predict_img: type ndarray, the reassembled predicted image
             (prob_map and decision_map for body identification models)
    """
    # Select the input channels that correspond to the model input channels.
    if config['input_channel'][dataset] is not None:
        patch_imgs = patch_imgs[..., config['input_channel'][dataset]]

    indice_list_model = indice_list
    if config['regularize_indice_list']['max_shape']:

        indice_list_model = np.float32(np.array(indice_list)) / np.array(config['max_shape']['image'])[..., 0]

    elif config['regularize_indice_list']['image_shape']:
        indice_list_model = np.float32(np.array(indice_list)) / np.array(
            img_data_shape)[:-1]

    elif config['regularize_indice_list']['custom_specified']:
        indice_list_model = np.float32(np.array(indice_list)) / np.array(
            config['regularize_indice_list']['custom_specified'])

    # Predict the test data by given trained model

    print(config['saved_models_dir'] + '/' + config['exp_name'] + '/' + config['model'] + '/' + dataset + '.h5')

    try:
        if not config['read_body_identification']:
            predict_patch_imgs = model.predict(x=(patch_imgs, indice_list_model), batch_size=1, verbose=1)

            print('predict_patch_imgs_shape', predict_patch_imgs.shape)
        else:

            if config['model'] == 'model_body_identification_hybrid':
                predict_patch_imgs, predict_reg = model.predict(x=(patch_imgs[..., 0], indice_list_model), batch_size=1,
                                                                verbose=1)
            else:
                predict_patch_imgs = model.predict(x=(patch_imgs[..., 0]), batch_size=1, verbose=1)
    except Exception:
        try:
            print('Prediction with load_weights_only=True failed, '
                  'rebuilding the model with load_weights_only=False...')
            config['load_weights_only'] = False
            model = load_model_file(config, dataset)

            predict_patch_imgs = model.predict(x=(patch_imgs, indice_list_model), batch_size=1, verbose=1)
        except Exception:
            return -1

    # patch images-> whole image
    if not config['read_body_identification']:

        predict_img = unpatch_predict_image(predict_patch_imgs, indice_list, config['patch_size'],
                                            output_patch_size=config['model_output_size'],
                                            set_zero_by_threshold=config['set_zero_by_threshold'],
                                            threshold=config['unpatch_start_threshold'])
        # Adjust model output channel order
        if config['predict_output_channel_order']:
            channel_stack = []
            for j in config['predict_output_channel_order']:
                channel_stack.append(predict_img[..., j])
            predict_img = np.stack(tuple(channel_stack), axis=-1)

        return predict_img

    else:

        # Especially for network body identification
        # prob_map,decision_map(191, 256, 110, 6) (191, 256, 110, 6)
        prob_map, decision_map = prediction_prob(config, predict_patch_imgs, indice_list)
        print("prob_map, decision_map", prob_map.shape, decision_map.shape)

        return prob_map, decision_map
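
A minimal sketch of the intended call sequence for predict_image(); it assumes the helpers load_model_file and patch_image from this module, a loaded config, and uses a random array and the dataset name 'dataset_A' purely as placeholders:

# Hypothetical usage sketch for predict_image() (random data as a placeholder).
import numpy as np

model = load_model_file(config, 'dataset_A')
img_data = np.random.rand(96, 96, 96, 2).astype(np.float32)  # placeholder volume
patch_imgs, indice_list = patch_image(config, img_data)
predict_img = predict_image(config, 'dataset_A', model, patch_imgs, indice_list,
                            img_data_shape=img_data.shape)
if isinstance(predict_img, int) and predict_img == -1:
    print('Prediction failed.')
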
Example #5
def predict(config, datasets=None, save_predict_data=False, name_ID=None):
    """
    Predict the test data with the saved model.
    :param config: type dict: config parameter
    :param datasets: type list of str, names of the datasets
    :param save_predict_data: type bool, True if the predicted images should be saved
    :param name_ID: type str, ID of a single case to predict (optional)
    :return:
    """
    if datasets is None:
        datasets = config['dataset']
    if config['load_predict_from_tfrecords']:
        #  If predict data is in tfrecords format
        for dataset in datasets:
            if not name_ID:
                # Load path of predict dataset.
                if not config['read_body_identification']:
                    split_path = config['dir_dataset_info'] + '/split_paths_' + dataset + '.pickle'
                else:
                    split_path = config['dir_dataset_info'] + '/split_paths_' + dataset + '_bi.pickle'

                with open(split_path, 'rb') as fp:
                    split_path = pickle.load(fp)
                    dataset_image_path = split_path['path_test_img']
                    dataset_label_path = split_path['path_test_label']

            else:
                # load tfrecords by a single name_ID
                dataset_image_path = [[config['rootdir_tfrec'][dataset] + '/' + name_ID + '/image/image.tfrecords']]
                dataset_label_path = [[config['rootdir_tfrec'][dataset] + '/' + name_ID + '/label/label.tfrecords']]

            config = channel_config(config, dataset)
            # Flatten the nested path list: [[path1], [path2], ...] -> [path1, path2, ...]
            data_path_image_list = [t[i] for t in dataset_image_path for i in range(len(dataset_image_path[0]))]
            data_path_label_list = [t[i] for t in dataset_label_path for i in range(len(dataset_label_path[0]))]
            list_image_TFRecordDataset = [tf.data.TFRecordDataset(i) for i in data_path_image_list]
            list_label_TFRecordDataset = [tf.data.TFRecordDataset(i) for i in data_path_label_list]

            # Choose and create the model which is the same with the saved model.
            print('load_model now...')
            model = load_model_file(config, dataset)

            collect_predict, collect_label = [], []

            for index, (image_TFRecordDataset, label_TFRecordDataset, data_path_image) in \
                    enumerate(zip(list_image_TFRecordDataset, list_label_TFRecordDataset, data_path_image_list)):
                dataset_image = image_TFRecordDataset.map(parser)
                dataset_label = label_TFRecordDataset.map(parser)

                # Get the image data from tfrecords
                # elem[0]= data, elem[1]= data shape
                img_data = [elem[0].numpy() for elem in dataset_image][0]
                label_data_onehot = [elem[0].numpy() for elem in dataset_label][0]
                if not config['read_body_identification']:

                    img_data, label_data_onehot = image_transform(config, img_data, label_data_onehot)
                    # Patch the image
                    patch_imgs, indice_list = patch_image(config, img_data)
                    predict_img = predict_image(config, dataset, model, patch_imgs, indice_list)
                    # Get name_ID from the data path
                    # The data path must have the specified format which is generated from  med_io/preprocess_raw_dataset.py
                    name_ID = data_path_image.replace('\\', '/').split('/')[-3]
                    if isinstance(predict_img, int) and predict_img == -1:
                        print('failed loading: ' + name_ID)
                        continue

                    if config['output_label_tfrecords']:
                        predict_img, label_data_onehot = select_output_channel(config, dataset, predict_img,
                                                                               label_data_onehot)
                        predict_img_integers, predict_img_onehot, label_data_integers, label_data_onehot = convert_result(
                            config, predict_img, label_data_onehot)
                    else:
                        # bypass
                        label_data_onehot, label_data_integers = None, None

                        predict_img = select_output_channel(config, dataset, predict_img)
                        predict_img_integers, predict_img_onehot = convert_result(config, predict_img)

                    # Get data of one patient for plot
                    dict_data = {'predict_integers': predict_img_integers,
                                 'predict_onehot': predict_img_onehot,
                                 'label_integers': label_data_integers,
                                 'label_onehot': label_data_onehot,
                                 'original_image': img_data,
                                 'without_mask': np.zeros(predict_img_integers.shape)}
                    if config['plot_figure']:
                        # Plot figure based on the single patient
                        plot_figures_single(config, dict_data, dataset=dataset, name_ID=name_ID)
                    if save_predict_data:
                        save_img_mat(config, dataset, name_ID, 'predict_image', predict_img_integers)
                    # Collect the image for plot.
                    collect_predict.append(predict_img_onehot)
                    collect_label.append(label_data_onehot)
                else:

                    patch_imgs, indice_list = patch_image(config, img_data)
                    predict_img = predict_image(config, dataset, model, patch_imgs, indice_list)

                    if config['model'] in ('model_body_identification_hybrid',
                                           'model_body_identification_classification'):
                        prob_map, decision_map = predict_img[0], predict_img[1]
                        select_slice = config['select_body_identification_predict_slice']

                        prob_map = convert_onehot_to_integers(prob_map[select_slice, ...], axis=-1)
                        decision_map = convert_onehot_to_integers(decision_map[select_slice, ...], axis=-1).astype(
                            np.int32)

                        pred_thresholds = get_thresholds(decision_map, n_classes=6)
                        real_thresholds = label_data_onehot

                        name_ID = data_path_image.replace('\\', '/').split('/')[
                            -3]  # extract name_ID from the tfrecords path
                        dict_data = {'predict_integers': decision_map.T,
                                     'original_image': img_data[select_slice, :, :, 0].T * 100,
                                     'pred_thresholds': pred_thresholds,
                                     'real_thresholds': real_thresholds,
                                     'predict_integers_probability': prob_map.T,
                                     }

                        if config['plot_figure']:
                            # Plot figure based on the single patient
                            plot_figures_single(config, dict_data, dataset=dataset, name_ID=name_ID)

            if config['load_predict_from_tfrecords']:
                # list_images_series: collect  predict data and label data
                list_images_series = {'predict': collect_predict, 'label': collect_label}
                if config['plot_figure']:
                    plot_figures_dataset(config, list_images_series, dataset=dataset)

            print('Predict data ', dataset, 'is finished.')
    else:
        # Load the dataset from a format other than tfrecords, e.g. NIfTI.

        if datasets is None or datasets == []:
            datasets = ['New_predict_image']
        for dataset in datasets:
            config = channel_config(config, dataset)

            # Choose and create the model which is the same with the saved model.
            model = load_model_file(config, dataset)

            collect_predict, collect_label = [], []
            data_dir_img = config['predict_data_dir_img']  # get the image dir

            # If predict dataset has label
            if config['predict_load_label']:
                data_dir_label = config['predict_data_dir_label']
                for index, (dir_name_img, dir_name_label) in \
                        enumerate(zip(sorted(os.listdir(data_dir_img)),
                                      sorted(os.listdir(data_dir_label)))):

                    data_path_img = os.path.join(config['predict_data_dir_img'], dir_name_img).replace('\\', '/')
                    data_path_label = os.path.join(config['predict_data_dir_label'], dir_name_label).replace('\\', '/')
                    img_data, label_data = read_predict_file(config, data_path_img, data_path_label)
                    img_data, label_data_onehot = image_transform(config, img_data, label_data_onehot=label_data)
                    # Patch the image
                    patch_imgs, indice_list = patch_image(config, img_data)
                    predict_img = predict_image(config, dataset, model, patch_imgs, indice_list)
                    predict_img, label_data_onehot = select_output_channel(config, dataset, predict_img,
                                                                           label_data_onehot)
                    predict_img_integers, predict_img_onehot, label_data_integers, label_data_onehot \
                        = convert_result(config, predict_img, label_data_onehot=label_data_onehot)
                    name_ID = dir_name_img

                    # Get data of one patient for plot
                    dict_data = {'predict_integers': predict_img_integers,
                                 'predict_onehot': predict_img_onehot,
                                 'label_integers': label_data_integers,
                                 'label_onehot': label_data_onehot,
                                 'original_image': img_data,
                                 'without_mask': np.zeros(predict_img_integers.shape)}

                    if config['plot_figure']:
                        plot_figures_single(config, dict_data, dataset=dataset, name_ID=name_ID)
                    if save_predict_data:
                        save_img_mat(config, dataset, name_ID, 'predict_image', predict_img_integers)
                    # Collect the image for plot.
                    collect_predict.append(predict_img_onehot)
                    collect_label.append(label_data_onehot)

            else:
                # Load the dataset from a format other than tfrecords (e.g. NIfTI) that has no labels.
                for name_ID in os.listdir(data_dir_img):

                    data_path_img = os.path.join(data_dir_img, name_ID).replace('\\', '/')
                    img_data = read_predict_file(config, data_path_img, name_ID=name_ID)
                    img_data = image_transform(config, img_data)
                    patch_imgs, indice_list = patch_image(config, img_data)
                    predict_img = predict_image(config, dataset, model, patch_imgs, indice_list,
                                                img_data_shape=img_data.shape)
                    predict_img = select_output_channel(config, dataset, predict_img)
                    predict_img_integers, predict_img_onehot = convert_result(config, predict_img)
                    dict_data = {'predict_integers': predict_img_integers,
                                 'predict_onehot': predict_img_onehot,
                                 'label_integers': None,
                                 'label_onehot': None,
                                 'original_image': img_data,
                                 'without_mask': np.zeros(predict_img_integers.shape)}
                    # Plot single figure
                    if config['plot_figure']:
                        plot_figures_single(config, dict_data, dataset=dataset, name_ID=name_ID)
                    if save_predict_data:
                        save_img_mat(config, dataset, name_ID, 'predict_image', predict_img_integers)
                    # Collect the image for plot.
                    collect_predict.append(predict_img_onehot)
                    collect_label.append(None)
                # dict collect
                list_images_series = {'predict': collect_predict, 'label': collect_label}
                if config['plot_figure']:
                    plot_figures_dataset(config, list_images_series, dataset=dataset)

                print('Predict data ', dataset, 'is finished.')
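
A minimal usage sketch for predict(); the dataset name and case ID below are placeholders, and config is assumed to be the loaded parameter dict:

# Hypothetical usage sketch for predict().
# Predict the whole test split of every dataset listed in config['dataset']:
predict(config, save_predict_data=True)

# Predict a single (hypothetical) case from its tfrecords by its name_ID:
predict(config, datasets=['dataset_A'], name_ID='case_0001')
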