Example #1
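The scraped examples omit their import sections. Below is a minimal sketch of what Example #1 appears to rely on; helper and rf_config (like bravenet_config in the later examples) are project-local modules that are not listed on this page.

import os
import pickle
import sys

import numpy as np
from sklearn.metrics import accuracy_score, f1_score

# Project-local modules (not shown here): I/O and timing utilities plus the
# random-forest pipeline configuration.
# import helper
# import rf_config
#
# Note: balanced_accuracy_score is used below as if it returned the per-class
# accuracies as well, so it is presumably a project-local helper rather than
# sklearn.metrics.balanced_accuracy_score.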
def evaluate_and_save_to_csv(version, n_estimators,
                             n_patients_train, data_dir, dataset, models_dir, predictions_dir,
                             patients, result_file, result_file_per_patient, label_load_name,
                             washed):
    """

    :param version: The model version name. String.
    :param n_estimators: Number of trees in the forest. Positive integer.
    :param n_patients_train: The number of patients used for training. Positive integer.
    :param data_dir: The path to the data directory. String.
    :param dataset: The dataset name, e.g. train/val/test. String.
    :param models_dir: Directory where trained models are saved. String.
    :param predictions_dir: Directory where predictions are saved. String.
    :param patients: The names of patients in the given dataset. List of strings.
    :param result_file: The path to the csv file where the performance results averaged over all patients in the
    dataset are stored. String.
    :param result_file_per_patient: The path to the csv file where the per-patient performance results are stored.
    String.
    :param label_load_name: The file name of the ground-truth label. String.
    :param washed: True to evaluate predictions washed with vessel segments. Boolean.
    :return:
    """
    print('model version', version)
    print('number of estimators', n_estimators)
    print('label load name', label_load_name)
    print('washed', washed)

    starttime_row = helper.start_time_measuring(what_to_measure='model evaluation')

    # create the name of current run
    run_name = rf_config.get_run_name(version=version, n_estimators=n_estimators, n_patients_train=n_patients_train)

    # -----------------------------------------------------------
    # TRAINING PARAMS
    # -----------------------------------------------------------
    try:
        train_metadata_filepath = rf_config.get_train_metadata_filepath(models_dir, run_name)
        with open(train_metadata_filepath, 'rb') as handle:
            train_metadata = pickle.load(handle)
        print('Train params:')
        print(train_metadata['params'])
    except FileNotFoundError:
        print('Could not read train metadata:', sys.exc_info()[0])

    # -----------------------------------------------------------
    # DATASET RESULTS (TRAIN / VAL / TEST)
    # -----------------------------------------------------------
    # initialize empty list for saving measures for each patient
    acc_list = []
    f1_macro_list = []
    f1_macro_wo_0_list = []
    f1_macro_wo_0_and_1_list = []
    avg_acc_list = []
    avg_acc_wo_0_list = []
    avg_acc_wo_0_and_1_list = []
    f1_bin_list = []
    f1_class_dict = {}  # for saving f1 for each class
    for cls in range(rf_config.NUM_CLASSES):
        f1_class_dict[cls] = []

    # calculate the performance per patient
    for patient in patients:
        # load patient ground truth and prediction
        print('-----------------------------------------------------------------')
        print(patient)
        print('> Loading label...')
        label = helper.load_nifti_mat_from_file(os.path.join(data_dir, patient + label_load_name))
        print('> Loading prediction...')
        if washed:
            prediction_path = rf_config.get_washed_annotation_filepath(predictions_dir, run_name, patient, dataset)
        else:
            prediction_path = rf_config.get_annotation_filepath(predictions_dir, run_name, patient, dataset)
        if not os.path.exists(prediction_path) and prediction_path.endswith('.gz'):
            prediction_path = prediction_path[:-3]
        prediction = helper.load_nifti_mat_from_file(prediction_path)

        # flatten the 3d volumes to 1d vector, necessary for performance calculation with sklearn functions
        label_f = label.flatten()
        prediction_f = prediction.flatten().astype(np.int16)
        label_f_bin = np.clip(label_f, 0, 1)
        prediction_f_bin = np.clip(prediction_f, 0, 1)

        print('Computing performance measures...')
        # classification accuracy
        acc = accuracy_score(label_f, prediction_f)
        # Dice / F1 scores for the multiclass classification.
        # f1 for each class label present in ground truth or prediction
        f1_per_classes = f1_score(label_f, prediction_f, average=None)
        # averaged f1 over all present class labels
        f1_macro = np.mean(f1_per_classes)
        # averaged f1 over all present class labels except the background class
        f1_macro_wo_0 = np.mean(f1_per_classes[1:])
        # averaged f1 over all present class labels except the background and not-annotated classes
        f1_macro_wo_0_and_1 = np.mean(f1_per_classes[2:])
        # average class accuracy for the multiclass classification
        # NOTE: balanced_accuracy_score is used here as if it returned the per-class accuracies as well,
        # so it is presumably a project-local helper rather than sklearn.metrics.balanced_accuracy_score.
        avg_acc, avg_acc_per_classes = balanced_accuracy_score(label_f, prediction_f)
        avg_acc_wo_0 = np.mean(avg_acc_per_classes[1:])
        avg_acc_wo_0_and_1 = np.mean(avg_acc_per_classes[2:])
        # dice for binary prediction -> labels converted to 1 for annotated vessels and 0 for background
        f1_bin = f1_score(label_f_bin, prediction_f_bin)
        patient_f1_class_list = ['-'] * rf_config.NUM_CLASSES  # for saving f1 for each class

        # print out to console
        print('acc:', acc)
        print('f1 all classes:', f1_macro)
        print('f1 without background class:', f1_macro_wo_0)
        print('f1 without background and not-annotated class:', f1_macro_wo_0_and_1)
        print('avg_acc all classes:', avg_acc)
        print('avg_acc without background class:', avg_acc_wo_0)
        print('avg_acc without background and not-annotated class:', avg_acc_wo_0_and_1)
        print('f1 binary predictions:', f1_bin)
        print('f1 per classes', f1_per_classes, 'size', len(f1_per_classes))

        # find what labels are in ground-truth and what labels were predicted
        unique_labels = np.unique(label_f)
        unique_prediction = np.unique(prediction_f)
        print('label unique', unique_labels, 'size', len(unique_labels))
        print('predicted annotation unique', unique_prediction, 'size', len(unique_prediction))

        # save results for patient to the lists
        acc_list.append(acc)
        f1_macro_list.append(f1_macro)
        f1_macro_wo_0_list.append(f1_macro_wo_0)
        f1_macro_wo_0_and_1_list.append(f1_macro_wo_0_and_1)
        avg_acc_list.append(avg_acc)
        avg_acc_wo_0_list.append(avg_acc_wo_0)
        avg_acc_wo_0_and_1_list.append(avg_acc_wo_0_and_1)
        f1_bin_list.append(f1_bin)
        all_classes = np.concatenate((unique_labels, unique_prediction))
        unique_classes = np.unique(all_classes)
        for i, cls in enumerate(unique_classes):
            f1_class_dict[cls].append(f1_per_classes[i])
            patient_f1_class_list[cls] = f1_per_classes[i]

        # create row for saving to csv file with details for each patient and write to the csv file
        row_per_patient = [n_estimators, patient,
                           acc,
                           f1_macro,
                           f1_macro_wo_0,
                           f1_macro_wo_0_and_1,
                           avg_acc,
                           avg_acc_wo_0,
                           avg_acc_wo_0_and_1,
                           f1_bin] + patient_f1_class_list
        helper.write_to_csv(result_file_per_patient, [row_per_patient])

    # calculate patient mean and std for each class
    f1_class_list_mean = []
    f1_class_list_std = []
    for cls, class_f1_list in f1_class_dict.items():
        f1_class_list_mean.append(np.mean(class_f1_list))
        f1_class_list_std.append(np.std(class_f1_list))
    # create row for saving to csv file with averages over whole set and write to the csv file
    row_avg = [n_estimators, len(patients),
               'AVG',
               np.mean(acc_list),
               np.mean(f1_macro_list),
               np.mean(f1_macro_wo_0_list),
               np.mean(f1_macro_wo_0_and_1_list),
               np.mean(avg_acc_list),
               np.mean(avg_acc_wo_0_list),
               np.mean(avg_acc_wo_0_and_1_list),
               np.mean(f1_bin_list)] + f1_class_list_mean
    row_std = [n_estimators, len(patients),
               'STD',
               np.std(acc_list),
               np.std(f1_macro_list),
               np.std(f1_macro_wo_0_list),
               np.std(f1_macro_wo_0_and_1_list),
               np.std(avg_acc_list),
               np.std(avg_acc_wo_0_list),
               np.std(avg_acc_wo_0_and_1_list),
               np.std(f1_bin_list)] + f1_class_list_std
    print('AVG:', row_avg)
    print('STD:', row_std)
    helper.write_to_csv(result_file, [row_avg, row_std])

    # print out how long the calculations took
    helper.end_time_measuring(starttime_row, what_to_measure='model evaluation')
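A hypothetical call of evaluate_and_save_to_csv follows; every path, patient ID and file suffix is a placeholder, not a value taken from the original project.

# Hypothetical usage sketch -- adapt directories, patient names and suffixes to the actual data layout.
evaluate_and_save_to_csv(version='rf_v1',
                         n_estimators=100,
                         n_patients_train=20,
                         data_dir='/data/val/',
                         dataset='val',
                         models_dir='/models/',
                         predictions_dir='/predictions/',
                         patients=['patient001', 'patient002'],
                         result_file='results_val.csv',
                         result_file_per_patient='results_val_per_patient.csv',
                         label_load_name='_label.nii.gz',
                         washed=False)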
Example #2
def train_and_save(version,
                   models_dir,
                   n_estimators=10,
                   n_train_patients=0,
                   train_X=None,
                   train_y=None,
                   scaler=None,
                   run_name=None,
                   train_metadata_filepath=None,
                   model_filepath=None):
    """
    :param version: The name describes the Unet version. String.
    :param models_dir: Directory where trained models are saved. String.
    :param n_estimators: Number of trees in forest. Positive integer.
    :param n_train_patients: The number of training patients. Positive integer.
    :param train_X: The features in train set. Ndarray.
    :param train_y: The labels in train set. Ndarray.
    :param scaler: Scikit scaler used for train set scaling.
    """
    print(
        '________________________________________________________________________________'
    )
    print('network version', version)
    print('number of estimators', n_estimators)
    print('num train samples', len(train_X))

    # -----------------------------------------------------------
    # CREATING NAME OF CURRENT RUN
    # -----------------------------------------------------------
    if not run_name:
        run_name = rf_config.get_run_name(version=version,
                                          n_estimators=n_estimators,
                                          n_patients_train=n_train_patients)
    # file paths
    if not os.path.exists(models_dir):
        os.makedirs(models_dir)
    if not train_metadata_filepath:
        train_metadata_filepath = rf_config.get_train_metadata_filepath(
            models_dir, run_name)
    if not model_filepath:
        model_filepath = rf_config.get_model_filepath(models_dir, run_name)

    # -----------------------------------------------------------
    # CREATING MODEL
    # -----------------------------------------------------------
    print('Creating new model in', model_filepath)
    print('Training Random forest...')
    model = RandomForestClassifier(n_estimators=n_estimators)

    # -----------------------------------------------------------
    # TRAINING MODEL
    # -----------------------------------------------------------
    starttime_train = helper.start_time_measuring(what_to_measure='training')
    model.fit(train_X, train_y)
    print(model)
    endtime_train, duration_train = helper.end_time_measuring(
        starttime_train, what_to_measure='training')

    # -----------------------------------------------------------
    # SAVING MODEL
    # -----------------------------------------------------------
    print('Saving the final model to:', model_filepath)
    with open(model_filepath, 'wb') as handle:
        pickle.dump(model, handle)

    print('Saving params to ', train_metadata_filepath)
    params = {
        'version': version,
        'n_estimators': n_estimators,
        'n_patients': n_train_patients,
        'samples': len(train_X),
        'total_time': duration_train,
        'scaler': scaler
    }
    results = {'params': params}
    with open(train_metadata_filepath, 'wb') as handle:
        pickle.dump(results, handle)
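A hypothetical usage sketch for train_and_save; train_X_raw and train_y are placeholder arrays that would normally come from the project's feature-extraction step, and the StandardScaler only illustrates how the scaler argument is meant to be produced.

from sklearn.preprocessing import StandardScaler

# Placeholder training data: rows are voxels, columns are features (coordinates + feature values).
train_X_raw = np.random.rand(1000, 6)
train_y = np.random.randint(0, 5, size=1000)

# Fit the scaler on the raw features and pass it on so that prediction can reuse it.
scaler = StandardScaler().fit(train_X_raw)
train_X = scaler.transform(train_X_raw)

train_and_save(version='rf_v1',
               models_dir='/models/',
               n_estimators=100,
               n_train_patients=20,
               train_X=train_X,
               train_y=train_y,
               scaler=scaler)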
Example #3
def predict_and_save(version,
                     n_estimators,
                     n_patients_train,
                     patient,
                     data_dir,
                     dataset,
                     feature_filenames,
                     models_dir,
                     predictions_dir,
                     run_name=None,
                     model=None,
                     train_metadata=None):
    """
    :param version: The name describes the SVM version. String.
    :param n_estimators: Number of trees in forest. Positive integer.
    :param n_patients_train: The number of patient used for training. Positive integer.
    :param patient: The patient name. String.
    :param data_dir: The path to data dirs. String.
    :param dataset: The dataset. E.g. train/val/set
    :param feature_filenames: List of file names of the feature inputs to the network. List of strings.
    :param models_dir: Directory where trained models are saved. String.
    :param predictions_dir: Directory where predictions are saved. String.
    """
    print('model version', version)
    print('number of estimators', n_estimators)
    print('patient:', patient)
    print('feature file names', feature_filenames)

    # create the name of current run
    if not run_name:
        run_name = rf_config.get_run_name(version=version,
                                          n_estimators=n_estimators,
                                          n_patients_train=n_patients_train)

    # -----------------------------------------------------------
    # LOADING MODEL, RESULTS AND WHOLE BRAIN MATRICES
    # -----------------------------------------------------------
    try:
        if model is None:
            model_filepath = rf_config.get_model_filepath(models_dir, run_name)
            print('Model path:', model_filepath)
            with open(model_filepath, 'rb') as handle:
                model = pickle.load(handle)
        # -----------------------------------------------------------
        # TRAINING PARAMS AND FEATURES
        # -----------------------------------------------------------
        if train_metadata is None:
            try:
                train_metadata_filepath = rf_config.get_train_metadata_filepath(
                    models_dir, run_name)
                with open(train_metadata_filepath, 'rb') as handle:
                    train_metadata = pickle.load(handle)
                print('Train params:')
                print(train_metadata['params'])
            except FileNotFoundError:
                print('Could not read train metadata:', sys.exc_info()[0])

        print('> Loading features...')
        loaded_feature_list = [
            helper.load_nifti_mat_from_file(os.path.join(data_dir, patient + f), printout=True)
            for f in feature_filenames
        ]

        # -----------------------------------------------------------
        # PREDICTION
        # -----------------------------------------------------------
        print('Predicting...')
        starttime_predict = helper.start_time_measuring(
            what_to_measure='patient prediction')

        # find all vessel voxel indices
        vessel_inds = np.where(loaded_feature_list[0] > 0)

        # extract features per voxel and predict
        prediction_3D = np.zeros(loaded_feature_list[0].shape, dtype='uint8')
        scaler = train_metadata['params']['scaler']  # fetch the scaler once, not in every loop iteration
        for voxel in range(len(vessel_inds[0])):
            # vessel voxel coordinates
            x, y, z = vessel_inds[0][voxel], vessel_inds[1][voxel], vessel_inds[2][voxel]

            # per-voxel feature vector: voxel coordinates followed by the feature values at this voxel
            features = [[x, y, z]]
            for feature in loaded_feature_list:
                features[0].append(feature[x, y, z])

            # scale data
            features = scaler.transform(features)

            # predict and rebuild the 3D volume
            prediction = model.predict(features)
            prediction_3D[x, y, z] = prediction

        # how long does the prediction take for a patient
        helper.end_time_measuring(starttime_predict,
                                  what_to_measure='patient prediction')

        # -----------------------------------------------------------
        # SAVE AS NIFTI
        # -----------------------------------------------------------
        print(predictions_dir)
        save_path = rf_config.get_annotation_filepath(predictions_dir,
                                                      run_name, patient,
                                                      dataset)
        helper.create_and_save_nifti(prediction_3D, save_path)
    except FileNotFoundError:
        print('Could not read model:', sys.exc_info()[0])
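A hypothetical call of predict_and_save; the feature file suffixes, directories and patient name are placeholders.

# Hypothetical usage sketch.
predict_and_save(version='rf_v1',
                 n_estimators=100,
                 n_patients_train=20,
                 patient='patient001',
                 data_dir='/data/val/',
                 dataset='val',
                 feature_filenames=['_skeleton.nii.gz', '_radius.nii.gz'],
                 models_dir='/models/',
                 predictions_dir='/predictions/')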
Example #4
def evaluate_set(version,
                 n_epochs,
                 batch_size,
                 learning_rate,
                 dropout_rate,
                 l1,
                 l2,
                 batch_normalization,
                 deconvolution,
                 n_base_filters,
                 n_patients_train,
                 n_patients_val,
                 patients,
                 data_dir,
                 predictions_dir,
                 models_dir,
                 dataset,
                 result_file,
                 result_file_per_patient,
                 label_load_name,
                 washed=False,
                 washing_version='score',
                 mode='voxel-wise',
                 segment_load_name=''):
    """
    :param version: The net version name. String.
    :param n_epochs: The number of epochs. Positive integer.
    :param batch_size: The size of one mini-batch. Positive integer.
    :param learning_rate: The learning rate. Positive float.
    :param dropout_rate: The dropout rate. Positive float or None.
    :param l1: The L1 regularization. Positive float or None.
    :param l2: The L2 regularization. Positive float or None.
    :param batch_normalization: Whether to train with batch normalization. Boolean.
    :param deconvolution: Whether to use deconvolution instead of up-sampling layer. Boolean.
    :param n_base_filters: The number of filters in the first convolutional layer of the net. Positive integer.
    :param n_patients_train: The number of patients used for training. Positive integer.
    :param n_patients_val: The number of patients used for validation. Positive integer.
    :param patients: The names of patients in the given dataset. List of strings.
    :param data_dir: The path to the data directory. String.
    :param predictions_dir: Path to the directory where predictions and results are saved. String.
    :param models_dir: Path to the directory where models are saved. String.
    :param dataset: The dataset name, e.g. train/val/test. String.
    :param result_file: The path to the csv file where the performance results averaged over all patients in the
    dataset are stored. String.
    :param result_file_per_patient: The path to the csv file where the per-patient performance results are stored.
    String.
    :param label_load_name: The file name of the ground-truth label. String.
    :param washed: True to evaluate predictions washed with vessel segments. Boolean.
    :param washing_version: The name of the washing version. Can be empty. String.
    :param mode: One of ['voxel-wise', 'segment-wise']. Voxel-wise: voxel-wise scores are calculated.
    Segment-wise: segment-wise scores are calculated. String.
    :param segment_load_name: Filename regex for files containing the skeletons with vessel segments.
    :return:
    """
    print('label load name', label_load_name)
    print('washed', washed)
    print('mode', mode)

    starttime_set = helper.start_time_measuring(
        'performance assessment in set')

    # create the name of current run
    run_name = bravenet_config.get_run_name(
        version=version,
        n_epochs=n_epochs,
        batch_size=batch_size,
        learning_rate=learning_rate,
        dropout_rate=dropout_rate,
        l1=l1,
        l2=l2,
        batch_normalization=batch_normalization,
        deconvolution=deconvolution,
        n_base_filters=n_base_filters,
        n_patients_train=n_patients_train,
        n_patients_val=n_patients_val)

    # -----------------------------------------------------------
    # TRAINING PARAMS
    # -----------------------------------------------------------
    try:
        train_metadata_filepath = bravenet_config.get_train_metadata_filepath(
            models_dir, run_name)
        with open(train_metadata_filepath, 'rb') as handle:
            train_metadata = pickle.load(handle)
        print('Train params:', train_metadata['params'])
    except FileNotFoundError:
        train_metadata = None
        print('Could not read train metadata:', sys.exc_info()[0])

    # -----------------------------------------------------------
    # DATASET RESULTS (TRAIN / VAL / TEST)
    # -----------------------------------------------------------
    num_epochs = train_metadata['params']['epochs'] if train_metadata else n_epochs
    # initialize empty list for saving measures for each patient
    acc_list = []
    f1_macro_list = []
    f1_macro_wo_0_list = []
    f1_macro_wo_0_and_1_list = []
    avg_acc_list = []
    avg_acc_wo_0_list = []
    avg_acc_wo_0_and_1_list = []
    f1_bin_list = []
    f1_class_dict = {}  # for saving f1 for each class
    for cls in range(bravenet_config.NUM_CLASSES):
        f1_class_dict[cls] = []

    # calculate the performance per patient
    for p, patient in enumerate(patients):
        # load patient ground truth and prediction
        print(
            '-----------------------------------------------------------------'
        )
        print('PATIENT:', patient, ',', str(p + 1) + '/' + str(len(patients)))
        print('> Loading label...')
        label = helper.load_nifti_mat_from_file(
            os.path.join(data_dir, patient + label_load_name))
        print('> Loading prediction...')
        if washed:
            prediction_path = bravenet_config.get_washed_annotation_filepath(
                predictions_dir,
                run_name,
                patient,
                dataset,
                washing_version=washing_version)
        else:
            prediction_path = bravenet_config.get_annotation_filepath(
                predictions_dir, run_name, patient, dataset)
        if not os.path.exists(prediction_path) and prediction_path.endswith(
                '.gz'):
            prediction_path = prediction_path[:-3]
        prediction = helper.load_nifti_mat_from_file(prediction_path)

        # If mode is segment-wise, prepare the segment-wise data points.
        if mode == 'segment-wise':
            print('> Loading segments...')
            skel_seg = helper.load_nifti_mat_from_file(
                os.path.join(data_dir, patient + segment_load_name))

            # Get label and prediction segment datapoints.
            label_segment_datapoints = []
            prediction_segment_datapoints = []
            unique_segments = np.unique(skel_seg)
            for seg in unique_segments:
                segment_inds = np.where(skel_seg == seg)

                # Get label segment data points.
                label_segment_classes = label[segment_inds]
                label_segment_unique_classes, label_segment_classes_counts = np.unique(
                    label_segment_classes, return_counts=True)
                if len(label_segment_unique_classes) > 1:
                    label_segment_class = label_segment_unique_classes[
                        np.argmax(label_segment_classes_counts)]
                else:
                    label_segment_class = label_segment_unique_classes[0]
                label_segment_datapoints.append(label_segment_class)

                # Get prediction segment data points.
                prediction_segment_classes = prediction[segment_inds]
                prediction_segment_unique_classes, prediction_segment_classes_counts = np.unique(
                    prediction_segment_classes, return_counts=True)
                if len(prediction_segment_unique_classes) > 1:
                    prediction_segment_class = prediction_segment_unique_classes[
                        np.argmax(prediction_segment_classes_counts)]
                else:
                    prediction_segment_class = prediction_segment_unique_classes[
                        0]
                prediction_segment_datapoints.append(prediction_segment_class)

            # assign the segment data point lists to the flattened variables
            label_f = label_segment_datapoints
            prediction_f = prediction_segment_datapoints
        elif mode == 'voxel-wise':
            # flatten the 3d volumes to 1d vector, necessary for performance calculation with sklearn functions
            label_f = label.flatten()
            prediction_f = prediction.flatten()
        else:
            raise ValueError(
                'Mode can be only one of the values ["voxel-wise", "segment-wise"].'
            )

        # Calculate scores.
        scores = evaluate_volume(label_f, prediction_f)

        # find what labels are in ground-truth and what labels were predicted
        unique_labels = np.unique(label_f)
        unique_prediction = np.unique(prediction_f)
        print('label unique', unique_labels, 'size', len(unique_labels))
        print('predicted annotation unique', unique_prediction, 'size',
              len(unique_prediction))

        # save results for patient to the lists
        acc_list.append(scores['acc'])
        f1_macro_list.append(scores['f1_macro'])
        f1_macro_wo_0_list.append(scores['f1_macro_wo_0'])
        f1_macro_wo_0_and_1_list.append(scores['f1_macro_wo_0_and_1'])
        avg_acc_list.append(scores['avg_acc'])
        avg_acc_wo_0_list.append(scores['avg_acc_wo_0'])
        avg_acc_wo_0_and_1_list.append(scores['avg_acc_wo_0_and_1'])
        f1_bin_list.append(scores['f1_bin'])
        all_classes = np.concatenate((unique_labels, unique_prediction))
        unique_classes = np.unique(all_classes)
        f1_per_classes = scores['f1_per_classes']
        patient_f1_class_list = ['-'] * bravenet_config.NUM_CLASSES  # for saving f1 for each class
        for i, cls in enumerate(unique_classes):
            f1_class_dict[cls].append(f1_per_classes[i])
            patient_f1_class_list[cls] = f1_per_classes[i]

        # create row for saving to csv file with details for each patient and write to the csv file
        row_per_patient = [
            num_epochs, batch_size, learning_rate, dropout_rate, l1, l2,
            batch_normalization, deconvolution, n_base_filters, patient,
            scores['acc'], scores['f1_macro'], scores['f1_macro_wo_0'],
            scores['f1_macro_wo_0_and_1'], scores['avg_acc'],
            scores['avg_acc_wo_0'], scores['avg_acc_wo_0_and_1'],
            scores['f1_bin']
        ] + patient_f1_class_list
        print('Writing to per patient csv...')
        helper.write_to_csv(result_file_per_patient, [row_per_patient])

    # calculate patient mean and std for each class
    f1_class_list_mean = []
    f1_class_list_std = []
    for cls, class_f1_list in f1_class_dict.items():
        f1_class_list_mean.append(np.mean(class_f1_list))
        f1_class_list_std.append(np.std(class_f1_list))
    # create row for saving to csv file with averages over whole set and write to the csv file
    row_avg = [
        num_epochs, batch_size, learning_rate, dropout_rate, l1, l2,
        batch_normalization, deconvolution, n_base_filters,
        len(patients), 'AVG',
        np.mean(acc_list),
        np.mean(f1_macro_list),
        np.mean(f1_macro_wo_0_list),
        np.mean(f1_macro_wo_0_and_1_list),
        np.mean(avg_acc_list),
        np.mean(avg_acc_wo_0_list),
        np.mean(avg_acc_wo_0_and_1_list),
        np.mean(f1_bin_list)
    ] + f1_class_list_mean
    row_std = [
        num_epochs, batch_size, learning_rate, dropout_rate, l1, l2,
        batch_normalization, deconvolution, n_base_filters,
        len(patients), 'STD',
        np.std(acc_list),
        np.std(f1_macro_list),
        np.std(f1_macro_wo_0_list),
        np.std(f1_macro_wo_0_and_1_list),
        np.std(avg_acc_list),
        np.std(avg_acc_wo_0_list),
        np.std(avg_acc_wo_0_and_1_list),
        np.std(f1_bin_list)
    ] + f1_class_list_std
    print('AVG:', row_avg)
    print('STD:', row_std)
    print('Writing to csv...')
    helper.write_to_csv(result_file, [row_avg, row_std])

    # print out how long the calculations took
    helper.end_time_measuring(starttime_set,
                              what_to_measure='performance assessment in set')
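A hypothetical call of evaluate_set in segment-wise mode; all hyperparameters, paths and file name patterns are placeholders, and segment_load_name is only given because mode='segment-wise' requires it.

# Hypothetical usage sketch.
evaluate_set(version='bravenet_v1', n_epochs=50, batch_size=16, learning_rate=1e-4,
             dropout_rate=0.1, l1=None, l2=None, batch_normalization=True,
             deconvolution=False, n_base_filters=16, n_patients_train=40,
             n_patients_val=10, patients=['patient001', 'patient002'],
             data_dir='/data/val/', predictions_dir='/predictions/', models_dir='/models/',
             dataset='val', result_file='results_val.csv',
             result_file_per_patient='results_val_per_patient.csv',
             label_load_name='_label.nii.gz', washed=False, mode='segment-wise',
             segment_load_name='_skel_seg.nii.gz')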
Example #5
def train(version, n_epochs, batch_size, lr, dr, l1, l2, bn, deconvolution, n_base_filters,
          depth, filter_size, activation, final_activation, n_classes, optimizer, loss_function, metrics,
          n_train_patients, n_val_patients, checkpoint_model, models_dir,
          train_feature_files, train_label_files, val_feature_files, val_label_files,
          train_feature_files_big=None, train_label_files_big=None, val_feature_files_big=None,
          val_label_files_big=None, normalize_features=True,
          max_values_for_normalization=None, transfer_learning=False, transfer_weights_path=None):
    """
    Trains one model with given samples and given parameters and saves it.

    :param version: The name describes the net version. String.
    :param n_epochs: The number of epochs. Positive integer.
    :param batch_size: The size of one mini-batch. Positive integer.
    :param lr: The learning rate. Positive float.
    :param dr: The dropout rate. Positive float or None.
    :param l1: The L1 regularization. Positive float or None.
    :param l2: The L2 regularization. Positive float or None.
    :param bn: True for training with batch normalization. Boolean.
    :param deconvolution: True for using deconvolution instead of up-sampling layer. Boolean.
    :param n_base_filters: The number of filters in the first convolutional layer of the net. Positive integer.
    :param depth: The number of levels of the net. Positive integer.
    :param filter_size: The size of the 3D convolutional filters. Tuple of three positive integers.
    :param activation: The activation after the convolutional layers. String.
    :param final_activation: The activation in the final layer. String.
    :param n_classes: The number of class labels to be predicted. Positive integer.
    :param optimizer: The optimization algorithm used for training. E.g. Adam. String.
    :param loss_function: The loss function. String.
    :param metrics: List of metrics (i.e. performance measures). List of strings.
    :param n_train_patients: The number of training samples. Positive integer.
    :param n_val_patients: The number of validation samples. Positive integer.
    :param checkpoint_model: True for saving the model after each epoch during training. Boolean.
    :param models_dir: String. Directory path where the model will be stored.
    :param train_feature_files: List of file names containing features from training set. List of strings.
    :param train_label_files: List of file names containing labels from training set. List of strings.
    :param val_feature_files: List of file names containing features from validation set. List of strings.
    :param val_label_files: List of file names containing labels from validation set. List of strings.
    :param train_feature_files_big: List of file names containing features in double-sized volume from training set.
    List of strings.
    :param train_label_files_big: List of file names containing labels in double-sized volume from training set.
    List of strings.
    :param val_feature_files_big: List of file names containing features in double-sized volume from validation set.
    List of strings.
    :param val_label_files_big: List of file names containing labels in double-sized volume from validation set.
    List of strings.
    :param normalize_features: True to scale the input data between 0 and 1. Boolean.
    :param max_values_for_normalization: Max values used for the scaling. List of floats.
    :param transfer_learning: True to initialize the network with pretrained weights. Boolean.
    :param transfer_weights_path: Path to the pretrained weights. String.
    """
    print('network version', version)
    print('number of epochs', n_epochs)
    print('batch size', batch_size)
    print('learning rate', lr)
    print('dropout rate', dr)
    print('L1', l1)
    print('L2', l2)
    print('batch normalization', bn)
    print('deconvolution', deconvolution)

    # Get number of training and validation samples.
    n_train_samples = len(train_feature_files) if train_feature_files else len(train_feature_files_big)
    n_val_samples = len(val_feature_files) if val_feature_files else len(val_feature_files_big)

    # -----------------------------------------------------------
    # CREATING NAME OF CURRENT RUN
    # -----------------------------------------------------------
    run_name = bravenet_config.get_run_name(version=version, n_epochs=n_epochs, batch_size=batch_size, learning_rate=lr,
                                            dropout_rate=dr, l1=l1, l2=l2, batch_normalization=bn,
                                            deconvolution=deconvolution, n_base_filters=n_base_filters,
                                            n_patients_train=n_train_patients, n_patients_val=n_val_patients)
    formatting_run_name = bravenet_config.get_run_name(version=version, n_epochs=n_epochs, batch_size=batch_size,
                                                       learning_rate=lr, dropout_rate=dr, l1=l1, l2=l2,
                                                       batch_normalization=bn, deconvolution=deconvolution,
                                                       n_base_filters=n_base_filters, n_patients_train=n_train_patients,
                                                       n_patients_val=n_val_patients, formatting_epoch=True)
    # File paths.
    if not os.path.exists(models_dir):
        os.makedirs(models_dir)
    model_filepath = bravenet_config.get_model_filepath(models_dir, run_name)
    train_metadata_filepath = bravenet_config.get_train_metadata_filepath(models_dir, run_name)
    train_history_filepath = bravenet_config.get_train_history_filepath(models_dir, run_name)
    logdir = bravenet_config.LOG_DIR

    # -----------------------------------------------------------
    # CREATING MODEL
    # -----------------------------------------------------------
    input_shape = (bravenet_config.PATCH_SIZE_X, bravenet_config.PATCH_SIZE_Y, bravenet_config.PATCH_SIZE_Z,
                   bravenet_config.NUM_FEATURES)
    # Double all dimensions except the last one because that is the number of feature channels.
    input_shape_big = tuple(v * 2 if i < len(input_shape) - 1 else v for i, v in enumerate(input_shape))
    num_outputs = depth - 1

    # Load specific architectures according to the model version.
    model = get_bravenet(input_shapes=[input_shape_big, input_shape], n_classes=n_classes,
                         activation=activation, final_activation=final_activation, n_base_filters=n_base_filters,
                         depth=depth, optimizer=optimizer, learning_rate=lr, dropout=dr, l1=l1, l2=l2,
                         batch_normalization=bn, loss_function=loss_function, metrics=metrics, filter_size=filter_size,
                         deconvolution=deconvolution)

    # -----------------------------------------------------------
    # TRAINING MODEL
    # -----------------------------------------------------------
    starttime_training = helper.start_time_measuring('training')

    # SET CALLBACKS
    callbacks = []
    # keras callback for tensorboard logging
    tb = TensorBoard(log_dir=logdir, histogram_freq=1)
    callbacks.append(tb)
    # keras callback for saving the training history to csv file
    csv_logger = CSVLogger(train_history_filepath)
    callbacks.append(csv_logger)
    # keras ModelCheckpoint callback saves the model after every epoch; it monitors val_loss and, since
    # save_best_only=False, writes a checkpoint regardless of whether val_loss improved
    if checkpoint_model:
        mc = ModelCheckpoint(bravenet_config.get_model_filepath(models_dir, formatting_run_name),
                             monitor='val_loss', verbose=1, save_best_only=False,
                             mode='min')
        callbacks.append(mc)

    # DATAGENERATORS
    train_generator = DataGenerator(train_feature_files, train_feature_files_big, train_label_files,
                                    train_label_files_big, bravenet_config.SAMPLES_PATH, batch_size=batch_size,
                                    dim=input_shape[:-1], dim_big=input_shape_big[:-1],
                                    n_channels=bravenet_config.NUM_FEATURES, num_outputs=num_outputs, shuffle=True,
                                    normalize_features=normalize_features,
                                    max_values_for_normalization=max_values_for_normalization)
    val_generator = DataGenerator(val_feature_files, val_feature_files_big, val_label_files, val_label_files_big,
                                  bravenet_config.SAMPLES_PATH, batch_size=batch_size, dim=input_shape[:-1],
                                  dim_big=input_shape_big[:-1], n_channels=bravenet_config.NUM_FEATURES,
                                  num_outputs=num_outputs, shuffle=True, normalize_features=normalize_features,
                                  max_values_for_normalization=max_values_for_normalization)

    # TRANSFER LEARNING
    if transfer_learning:
        # Load weights
        # untrained_model = clone_model(model)
        model.load_weights(transfer_weights_path, by_name=True)
        print('Weights loaded.')
        # Multiple loss
        loss, loss_weights = ds_loss(depth=depth, loss_function=loss_function)
        model.compile(optimizer=optimizer(lr=lr), loss=loss, metrics=metrics, loss_weights=loss_weights)
        # untrained_model.compile(optimizer=optimizer(lr=lr), loss=loss, metrics=metrics, loss_weights=loss_weights)
        print('model compiled.')

    # TRAIN
    history = None
    try:
        history = model.fit_generator(
            generator=train_generator,
            validation_data=val_generator,
            steps_per_epoch=n_train_samples // batch_size,
            validation_steps=n_val_samples // batch_size,
            epochs=n_epochs,
            verbose=2, shuffle=True, callbacks=callbacks)
    except KeyboardInterrupt:
        print("KeyboardInterrupt has been caught.")
        exit(0)
    finally:
        if history is not None:
            # helper.end_time_measuring returns the end time and the duration (see Example #2)
            _, duration_training = helper.end_time_measuring(starttime_training, 'training')

            # SAVING MODEL AND PARAMS
            if checkpoint_model:
                print('Model was checkpointed -> not saving the model from last epoch.')
            else:
                print('Model was not checkpointed -> saving the model from last epoch to:', model_filepath)
                model.save(model_filepath)

            print('Saving params to ', train_metadata_filepath)
            history.params['version'] = version
            history.params['batchsize'] = batch_size
            history.params['learning_rate'] = lr
            history.params['dropout_rate'] = dr
            history.params['l1'] = l1
            history.params['l2'] = l2
            history.params['batch_norm'] = bn
            history.params['deconvolution'] = deconvolution
            history.params['num_base_filters'] = n_base_filters
            history.params['loss'] = loss_function
            history.params['samples'] = n_train_samples
            history.params['val_samples'] = n_val_samples
            history.params['total_time'] = duration_training
            history.params['normalize features'] = normalize_features
            history.params['max_values_for_normalization'] = max_values_for_normalization
            results = {'params': history.params, 'history': history.history}
            with open(train_metadata_filepath, 'wb') as handle:
                pickle.dump(results, handle)
    return history
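A hypothetical call of train; the hyperparameters, sample file lists and optimizer are placeholders rather than the original project configuration (the real file lists come from the patch-extraction step, a fragment of which follows). The optimizer is passed as the Adam class here because the transfer-learning branch calls optimizer(lr=lr); depending on the project, get_bravenet may expect a string instead.

from keras.optimizers import Adam

# Placeholder sample file lists; in the project these are produced by the patch extraction.
train_feature_files = ['patient001_patch000_features.npy']
train_label_files = ['patient001_patch000_labels.npy']
val_feature_files = ['patient002_patch000_features.npy']
val_label_files = ['patient002_patch000_labels.npy']

history = train(version='bravenet_v1', n_epochs=50, batch_size=16, lr=1e-4, dr=0.1, l1=None, l2=None,
                bn=True, deconvolution=False, n_base_filters=16, depth=4, filter_size=(3, 3, 3),
                activation='relu', final_activation='softmax', n_classes=bravenet_config.NUM_CLASSES,
                optimizer=Adam, loss_function='categorical_crossentropy', metrics=['accuracy'],
                n_train_patients=40, n_val_patients=10, checkpoint_model=True, models_dir='/models/',
                train_feature_files=train_feature_files, train_label_files=train_label_files,
                val_feature_files=val_feature_files, val_label_files=val_label_files,
                normalize_features=True, max_values_for_normalization=[1.0, 10.0])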
patch_size_x = bravenet_config.PATCH_SIZE_X
patch_size_y = bravenet_config.PATCH_SIZE_Y
patch_size_z = bravenet_config.PATCH_SIZE_Z
num_features = bravenet_config.NUM_FEATURES

half_patch_size_x = patch_size_x // 2
half_patch_size_y = patch_size_y // 2
half_patch_size_z = patch_size_z // 2

# classes without background and not-annotated class
classes = [*class_dict][2:]
helper.write_to_csv(csv_file, [['patient'] + classes])

print('Number of patches to extract per class per patient:', nr_of_patches_around_each_class)

# extract patches from each data stack (patient)
starttime_total = helper.start_time_measuring(what_to_measure='total extraction')
for dataset in datasets:
    patients = patient_files[dataset]
    # check number of patients in datasets
    helper.check_num_patients(patients)
    all_patients = patients['working']
    num_patients = len(all_patients)
    print('Number of patients:', num_patients)
    patch_directory = bravenet_config.SAMPLES_PATH
    if not os.path.exists(patch_directory):
        os.makedirs(patch_directory)

    for patient in all_patients:
        print('DATA SET:', dataset)
        print('PATIENT:', patient)
        patient_class_list = []
def predict(feature_volumes,
            model,
            num_classes,
            patch_size_x,
            patch_size_y,
            patch_size_z,
            annotation_save_path,
            max_values_for_normalization=None,
            skeleton_segments=None,
            washed_annotation_save_path=None):
    """
    :param feature_volumes: List of input feature volumes. List of 3D ndarrays.
    :param model: Trained Keras model.
    :param num_classes: Number of classes to predict. Positive integer.
    :param patch_size_x: Size of the patch in x axis. Positive integer.
    :param patch_size_y: Size of the patch in y axis. Positive integer.
    :param patch_size_z: Size of the patch in z axis. Positive integer.
    :param annotation_save_path: Path (including the file name) where the annotation result is saved as nifti. String.
    :param max_values_for_normalization: Max values used for scaling the features between 0 and 1. List of floats.
    :param skeleton_segments: Volume with vessel segments used for washing the annotation. 3D ndarray or None.
    :param washed_annotation_save_path: Path where the washed annotation is saved as nifti. String or None.
    :return: Tuple of the predicted annotation volume and the washed annotation volume (3D ndarrays).
    """
    normalize = max_values_for_normalization is not None

    num_features = len(feature_volumes)
    volume_dimensions = feature_volumes[0].shape
    min_x = 0
    min_y = 0
    min_z = 0
    max_x = volume_dimensions[0]
    max_y = volume_dimensions[1]
    max_z = volume_dimensions[2]
    num_x_patches = int(np.ceil((max_x - min_x) / float(patch_size_x)))
    num_y_patches = int(np.ceil((max_y - min_y) / float(patch_size_y)))
    num_z_patches = int(np.ceil((max_z - min_z) / float(patch_size_z)))

    print('volume dimensions', volume_dimensions)
    print('num x patches', num_x_patches)
    print('num y patches', num_y_patches)
    print('num z patches', num_z_patches)

    # the predicted annotation is going to be saved in this probability matrix
    predicted_probabilities = np.zeros(volume_dimensions + (num_classes, ),
                                       dtype='float32')

    # start cutting out and predicting the patches
    starttime_volume = helper.start_time_measuring(
        what_to_measure='volume prediction')
    p = 0
    for ix in range(num_x_patches):
        for iy in range(num_y_patches):
            for iz in range(num_z_patches):
                # find the starting and ending coordinates of the given patch
                patch_start_x = patch_size_x * ix
                patch_end_x = min(patch_size_x * (ix + 1), max_x)
                patch_start_y = patch_size_y * iy
                patch_end_y = min(patch_size_y * (iy + 1), max_y)
                patch_start_z = patch_size_z * iz
                patch_end_z = min(patch_size_z * (iz + 1), max_z)
                # extract patch for prediction
                patch = np.zeros((1, patch_size_x, patch_size_y, patch_size_z, num_features),
                                 dtype='float32')
                for f in range(num_features):
                    patch[0,
                          :patch_end_x - patch_start_x,
                          :patch_end_y - patch_start_y,
                          :patch_end_z - patch_start_z,
                          f] = feature_volumes[f][patch_start_x:patch_end_x,
                                                  patch_start_y:patch_end_y,
                                                  patch_start_z:patch_end_z].astype('float32')
                # normalize patch
                if normalize:
                    assert isinstance(max_values_for_normalization, list)
                    for f in range(num_features):
                        patch[..., f] = patch[..., f] / max_values_for_normalization[f]

                # find center location in the patch
                center_x = patch_start_x + patch_size_x // 2
                center_y = patch_start_y + patch_size_y // 2
                center_z = patch_start_z + patch_size_z // 2
                # find the starting and ending coordinates of the big patch
                big_patch_start_x = max(center_x - patch_size_x, min_x)
                big_patch_end_x = min(center_x + patch_size_x, max_x)
                big_patch_start_y = max(center_y - patch_size_y, min_y)
                big_patch_end_y = min(center_y + patch_size_y, max_y)
                big_patch_start_z = max(center_z - patch_size_z, min_z)
                big_patch_end_z = min(center_z + patch_size_z, max_z)
                # if the patch should reach outside the volume, prepare offset for zero padding
                offset_x = max(min_x - (center_x - patch_size_x), 0)
                offset_y = max(min_y - (center_y - patch_size_y), 0)
                offset_z = max(min_z - (center_z - patch_size_z), 0)
                # extract big patch for prediction
                big_patch = np.zeros((1, patch_size_x * 2, patch_size_y * 2, patch_size_z * 2, num_features),
                                     dtype='float32')
                for f in range(num_features):
                    big_patch[0,
                              offset_x:offset_x + (big_patch_end_x - big_patch_start_x),
                              offset_y:offset_y + (big_patch_end_y - big_patch_start_y),
                              offset_z:offset_z + (big_patch_end_z - big_patch_start_z),
                              f] = feature_volumes[f][big_patch_start_x:big_patch_end_x,
                                                      big_patch_start_y:big_patch_end_y,
                                                      big_patch_start_z:big_patch_end_z].astype('float32')

                # normalize big patch
                if normalize:
                    assert isinstance(max_values_for_normalization, list)
                    for f in range(num_features):
                        big_patch[..., f] = big_patch[..., f] / max_values_for_normalization[f]

                predicted_patch = model.predict([big_patch, patch],
                                                batch_size=1,
                                                verbose=0)[-1]

                # in case the last patch along an axis reached outside the volume, cut off the zero padding
                sliced_predicted_patch = predicted_patch[0,
                                                         :patch_end_x - patch_start_x,
                                                         :patch_end_y - patch_start_y,
                                                         :patch_end_z - patch_start_z]
                predicted_probabilities[patch_start_x:patch_end_x,
                                        patch_start_y:patch_end_y,
                                        patch_start_z:patch_end_z] = sliced_predicted_patch
                p += 1

    # how long does the prediction take for the whole volume
    helper.end_time_measuring(starttime_volume,
                              what_to_measure='volume prediction',
                              print_end_time=False)

    # save annotated volume as nifti
    annotation = np.argmax(predicted_probabilities, axis=-1)
    annotation = np.asarray(annotation, dtype='uint8')
    helper.create_and_save_nifti(annotation, annotation_save_path)

    # wash with segments according to max softmax score in segment and save as nifti
    # (only possible when a skeleton-segment volume and a save path were provided)
    washed_annotation = None
    if skeleton_segments is not None and washed_annotation_save_path is not None:
        starttime_washing = helper.start_time_measuring(
            what_to_measure='washing with segments')
        washed_annotation = helper.wash_with_segments_max_scores(
            predicted_probabilities, skeleton_segments)
        helper.end_time_measuring(starttime_washing,
                                  what_to_measure='washing with segments',
                                  print_end_time=False)
        washed_annotation = np.asarray(washed_annotation, dtype='uint8')
        helper.create_and_save_nifti(washed_annotation,
                                     washed_annotation_save_path)

    return annotation, washed_annotation
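A hypothetical usage sketch for predict; the file names, model path and normalization maxima are placeholders, and compile=False is used only so that the custom loss and metrics do not have to be passed as custom_objects for inference.

from keras.models import load_model

feature_volumes = [helper.load_nifti_mat_from_file('/data/test/patient001' + f)
                   for f in ['_skeleton.nii.gz', '_radius.nii.gz']]
skeleton_segments = helper.load_nifti_mat_from_file('/data/test/patient001_skel_seg.nii.gz')
model = load_model('/models/bravenet_v1.h5', compile=False)

annotation, washed_annotation = predict(
    feature_volumes=feature_volumes,
    model=model,
    num_classes=bravenet_config.NUM_CLASSES,
    patch_size_x=bravenet_config.PATCH_SIZE_X,
    patch_size_y=bravenet_config.PATCH_SIZE_Y,
    patch_size_z=bravenet_config.PATCH_SIZE_Z,
    annotation_save_path='/predictions/patient001_annotation.nii.gz',
    max_values_for_normalization=[1.0, 10.0],
    skeleton_segments=skeleton_segments,
    washed_annotation_save_path='/predictions/patient001_annotation_washed.nii.gz')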