Example #1
def get_validation_split(data_file,
                         training_file,
                         validation_file,
                         data_split=0.8,
                         change_validation=False,
                         split_list=()):
    """
    Splits the data into the training and validation indices list.
    :param data_file: pytables hdf5 data file
    :param training_file:
    :param validation_file:
    :param data_split:
    :param overwrite:
    :return:
    """
    if change_validation or not os.path.exists(training_file):
        nb_samples = data_file.root.data.shape[0]
        sample_list = list(range(nb_samples))
        training_list, validation_list = split_a_list(sample_list,
                                                      split=data_split)
        print("Training list: ", training_list)
        print("Validation list: ", validation_list)
        training_list = split_list[0]
        validation_list = split_list[1]
        pickle_dump(training_list, training_file)
        pickle_dump(validation_list, validation_file)
        return training_list, validation_list
    else:
        return pickle_load(training_file), pickle_load(validation_file)
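These snippets all lean on a few project-local helpers that are never shown: pickle_load, pickle_dump, and split_a_list. Below is a minimal sketch of what they would need to look like, inferred purely from how the examples call them; the project's real implementations may differ.

import pickle
import random


def pickle_dump(item, out_file):
    # Serialize any picklable object (here: index lists) to disk.
    with open(out_file, "wb") as opened_file:
        pickle.dump(item, opened_file)


def pickle_load(in_file):
    # Read back an object written by pickle_dump.
    with open(in_file, "rb") as opened_file:
        return pickle.load(opened_file)


def split_a_list(input_list, split=0.8, shuffle_list=True):
    # Shuffle, then cut the list so the first `split` fraction becomes the
    # training indices and the remainder the validation indices.
    if shuffle_list:
        random.shuffle(input_list)
    n_training = int(len(input_list) * split)
    return input_list[:n_training], input_list[n_training:]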
Example #2
 def run_validation_cases(self, validation_keys_file, training_modalities, labels, hdf5_file,
                          output_label_map=False, output_dir=".", threshold=0.5, overlap=16, permute=False,
                          save_image=False):
     '''
     Run a validation case for each patient in the validation set.
     :param validation_keys_file: pickle file holding the validation sample indices
     :param training_modalities: list of image modalities the model was trained on
     :param labels: label values the model output is mapped onto
     :param hdf5_file: path to the pytables hdf5 data file
     :param output_label_map: if True, write a label map instead of raw probabilities
     :param output_dir: directory in which the per-case results are written
     :param threshold: probability threshold used when building the label map
     :param overlap: overlap (in voxels) between patches during patch-wise prediction
     :param permute: if True, average predictions over data permutations
     :param save_image: if True, also save the input images for each case
     :return: None
     '''
     validation_indices = pickle_load(validation_keys_file)
     model = self.model
     data_file = tables.open_file(hdf5_file, "r")
     for i, index in enumerate(validation_indices):
         progress = round(i / len(validation_indices) * 100, 2)
         print("Running validation case: ", progress, "%")
         if 'subject_ids' in data_file.root:
             case_directory = os.path.join(output_dir, data_file.root.subject_ids[index].decode('utf-8'))
         else:
             case_directory = os.path.join(output_dir, "validation_case_{}".format(index))
         run_validation_case(data_index=index, output_dir=case_directory, model=model, data_file=data_file,
                             training_modalities=training_modalities, output_label_map=output_label_map, labels=labels,
                             threshold=threshold, overlap=overlap, permute=permute, save_image=save_image)
     data_file.close()
Example #3
def run_validation_cases(validation_keys_file,
                         model_file,
                         training_modalities,
                         labels,
                         hdf5_file,
                         output_label_map=False,
                         output_dir=".",
                         threshold=0.5,
                         overlap=16,
                         permute=False):
    validation_indices = pickle_load(validation_keys_file)
    model = load_old_model(model_file)
    data_file = tables.open_file(hdf5_file, "r")
    for i, index in enumerate(validation_indices):
        progress = round(i / len(validation_indices) * 100, 2)
        print("Running validation case: ", progress, "%")
        if 'subject_ids' in data_file.root:
            case_directory = os.path.join(
                output_dir, data_file.root.subject_ids[index].decode('utf-8'))
        else:
            case_directory = os.path.join(output_dir,
                                          "validation_case_{}".format(index))
        run_validation_case(data_index=index,
                            output_dir=case_directory,
                            model=model,
                            data_file=data_file,
                            training_modalities=training_modalities,
                            output_label_map=output_label_map,
                            labels=labels,
                            threshold=threshold,
                            overlap=overlap,
                            permute=permute)
    data_file.close()
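For reference, the standalone run_validation_cases above can be wired to the same config dictionary the later examples use. The config key names below are taken from Examples #6 and #8; the output directory is an assumption.

run_validation_cases(validation_keys_file=config["validation_file"],
                     model_file=config["model_file"],
                     training_modalities=config["training_modalities"],
                     labels=config["labels"],
                     hdf5_file=config["data_file"],
                     output_label_map=True,
                     output_dir=os.path.abspath("prediction"))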
Example #4
def load_index_patches_with_ceil(file_name):
    '''
    Load the array indices with lesions (if they have already been computed).
    :param file_name: pickle file holding the precomputed indices
    :return: the unpickled list of indices
    '''
    return pickle_load(file_name)
Example #5
def is_all_cases_predicted(prediction_folder, testing_file):
    # Use the function's own arguments instead of the global config the
    # original indexed into; as written, both parameters were ignored.
    testing_indices = pickle_load(testing_file)
    num_cases = len(testing_indices)
    if not os.path.exists(prediction_folder):
        return False
    num_predicted = len(glob.glob(os.path.join(prediction_folder, "*")))
    return num_cases == num_predicted
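A hypothetical call site for the check above, so inference is skipped once every test case has a prediction on disk; the file names and run_inference are placeholders, not names from the project:

if not is_all_cases_predicted("prediction", "testing_ids.pkl"):
    run_inference()  # placeholder for the project's actual prediction entry point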
Example #6
def main():
    prediction_dir = os.path.abspath("prediction")
    validation_indices = pickle_load(config["validation_file"])
    for i in range(len(validation_indices)):
        run_validation_case(test_index=i,
                            out_dir=os.path.join(
                                prediction_dir,
                                "validation_case_{}".format(i)),
                            model_file=config["model_file"],
                            validation_keys_file=config["validation_file"],
                            training_modalities=config["training_modalities"],
                            output_label_map=True,
                            labels=config["labels"],
                            hdf5_file=config["data_file"])
Example #7
def main(config=None):
    model = load_old_model(config, re_compile=False)
    data_file_opened = open_data_file(config["data_file"])
    validation_idxs = pickle_load(config['validation_file'])
    validation_generator = data_generator(data_file_opened,
                                          validation_idxs,
                                          batch_size=config['validation_batch_size'],
                                          n_labels=config['n_labels'],
                                          labels=config['labels'],
                                          skip_blank=config['skip_blank'],
                                          shuffle_index_list=False)
    steps = math.ceil(len(validation_idxs) / config['validation_batch_size'])
    results = model.evaluate(validation_generator, steps=steps, verbose=1)
    metrics_names = model.metrics_names
    for i, x in enumerate(metrics_names):
        print('{}: {}'.format(x, results[i]))

    data_file_opened.close()
Example #8
def main():
    prediction_dir = os.path.abspath("../data/prediction_isensee2017")
    if not os.path.exists(prediction_dir):
        # The original concatenated 'mkdir' + '-p' + prediction_dir into the
        # single token "mkdir-p<path>", so the shell call could never work;
        # os.makedirs is the portable fix.
        os.makedirs(prediction_dir)

    validation_indices = pickle_load(config["validation_file"])
    for i in range(len(validation_indices)):
        run_validation_case(test_index=i,
                            out_dir=os.path.join(
                                prediction_dir,
                                "validation_case_{}".format(i)),
                            model_file=config["model_file"],
                            validation_keys_file=config["validation_file"],
                            training_modalities=config["training_modalities"],
                            output_label_map=True,
                            labels=config["labels"],
                            hdf5_file=config["data_file"])
Example #9
def get_test_indices(testing_file):
    return pickle_load(testing_file)
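A small usage sketch tying the loaded indices back to the PyTables file the other examples read; the file names are placeholders, but data_file.root.data matches the layout assumed in Example #1.

import tables

test_indices = get_test_indices("testing_ids.pkl")  # placeholder path
with tables.open_file("brats_data.h5", "r") as data_file:  # placeholder path
    for index in test_indices:
        sample = data_file.root.data[index]  # one stored volume per index
        print(index, sample.shape)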