Example #1
def main(argv):
    parser = ParserCreator.createArgumentParser("./train.yml")
    if len(argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(1)
    args = parser.parse_args(argv[1:])
    print(args)
    N2VNetwork.reportGPUAccess()
    N2VNetwork.reportDevices()
    N2VNetwork.reportTensorFlowAccess()
    n2v = N2VNetwork(args.name)
    n2v.setTrainingSource(args.dataPath)
    n2v.setPath(args.baseDir)
    n2v.setNumberOfEpochs(args.epochs)
    n2v.setPatchSize(args.patchSizeXY)
    n2v.setBatchSize(args.batchSize)
    n2v.setPercentValidation(args.validationFraction)
    n2v.setNumberOfSteps(args.stepsPerEpoch)
    n2v.setInitialLearningRate(args.learningRate)
    n2v.setNetDepth(args.netDepth)
    n2v.setKernelSize(args.netKernelSize)
    n2v.setUNetNFirst(args.unetNFirst)
    n2v.setPercentPixel(args.n2vPercPix)
    if args.dataAugment:
        n2v.activateDataAugmentation()
    else:
        n2v.deactivateDataAugmentation()

    n2v.train()
    n2v.saveHistory()
    n2v.printElapsedTime()
    print("---training done---")
Example #2
def main(argv):
    parser = ParserCreator.createArgumentParser("./predict.yml")
    if len(argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(1)
    args = parser.parse_args(argv[1:])
    print(args)
    model = models.Cellpose(gpu=args.gpu, model_type=args.modelType)
    files = [
        os.path.join(args.dataPath, image)
        for image in os.listdir(args.dataPath)
        if (image.lower().endswith(".tif") or image.lower().endswith(".jpg")
            or image.lower().endswith(".png"))
        and not image.lower().endswith("_cp_masks.png")
    ]
    channels = [[args.segChannel, args.nucleiChannel]] * len(files)
    diameter = args.diameter
    if diameter == 0:
        diameter = None
    for channel, filename in zip(channels, files):
        img = io.imread(filename)
        masks, flows, styles, diams = model.eval(img,
                                                 diameter=diameter,
                                                 channels=channel)
        base, extension = os.path.splitext(filename)
        newFilename = base + "_c" + str(args.segChannel) + extension
        io.save_to_png(img, masks, flows, newFilename)
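A short, hedged sketch of the Cellpose conventions used above: the channels pair is [segmentation channel, nuclei channel] with 0 = grayscale, 1 = red, 2 = green, 3 = blue, and passing diameter=None asks Cellpose to estimate the object diameter per image. The file name and channel choices below are illustrative only.

from cellpose import models, io

model = models.Cellpose(gpu=False, model_type="cyto")
img = io.imread("example.tif")                 # hypothetical input image
masks, flows, styles, diams = model.eval(
    img,
    diameter=None,        # None triggers automatic diameter estimation
    channels=[2, 3],      # segment the green channel, nuclei in blue
)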
Example #3
def main(argv):
    prediction_prefix = ""
    parser = ParserCreator.createArgumentParser("./evaluate.yml")
    args = parser.parse_args(argv[1:])

    full_QC_model_path = os.path.join(args.baseDir, args.name)
    model_weight_path = getValidModel(full_QC_model_path)

    # Create a quality control/Prediction Folder

    prediction_QC_folder = cleanAndGetFolderQC(full_QC_model_path)

    model = createNetwork()

    model.keras_model.load_weights(model_weight_path)

    source_dir_list = prepareDataset()

    predictions = runPredictions(source_dir_list, model)

    saveResult(prediction_QC_folder,
               predictions,
               source_dir_list,
               prefix=prediction_prefix,
               threshold=None)

    write_QC(full_QC_model_path)

    print("---evaluation done---")
Example #4
def main(argv):
    parser = ParserCreator.createArgumentParser("./predict.yml")
    if len(argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(1)
    args = parser.parse_args(argv[1:])
    print(args)
    n2v = N2VNetwork(args.name)
    n2v.setPath(args.baseDir)
    n2v.setTile((args.tileY, args.tileX))
    n2v.predict(args.dataPath, args.output)
    print("---predictions done---")
Example #5
def main(argv):
    prediction_prefix = ""
    W = '\033[0m'  # white (normal)
    R = '\033[31m'  # red
    parser = ParserCreator.createArgumentParser("./predict.yml")
    args = parser.parse_args(argv[1:])
    full_Prediction_model_path = os.path.join(args.baseDir, args.name)
    if os.path.exists(
            os.path.join(full_Prediction_model_path, 'weights_best.hdf5')):
        print("The " + args.name + " network will be used.")
    else:
        print(R + '!! WARNING: The chosen model does not exist !!' + W)
        print(
            'Please make sure you provide a valid model path and model name before proceeding further.'
        )

    unet = load_model(os.path.join(args.baseDir, args.name,
                                   'weights_best.hdf5'),
                      custom_objects={
                          '_weighted_binary_crossentropy':
                          weighted_binary_crossentropy(np.ones(2))
                      })
    Input_size = unet.layers[0].output_shape[1:3]
    print('Model input size: ' + str(Input_size[0]) + 'x' + str(Input_size[1]))

    source_dir_list = os.listdir(args.dataPath)
    number_of_dataset = len(source_dir_list)
    print('Number of datasets found in the folder: ' + str(number_of_dataset))

    predictions = []
    for i in range(number_of_dataset):
        print("processing dataset " + str(i + 1) + ", file: " +
              source_dir_list[i])
        predictions.append(
            predict_as_tiles(os.path.join(args.dataPath, source_dir_list[i]),
                             unet))

    # Save the results in the folder along with the masks according to the set threshold
    saveResult(args.output,
               predictions,
               source_dir_list,
               prefix=prediction_prefix,
               threshold=args.threshold)

    print("---predictions done---")
Example #6
def main(argv):
    prediction_prefix = ""
    parser = ParserCreator.createArgumentParser("./predict.yml")
    args = parser.parse_args(argv[1:])
    full_Prediction_model_path = os.path.join(args.baseDir, args.name)
    if os.path.exists(os.path.join(full_Prediction_model_path, 'weights_best.h5')):
        print("The " + args.name + " network will be used.")
    else:
        print('!! WARNING: The chosen model does not exist !!')
        print('Please make sure you provide a valid model path and model name before proceeding further.')

    conf = Config2D(
        n_rays=args.nRays,
        grid=(2, 2),
    )
    model = stardist(conf, name=args.name, basedir=args.baseDir)

    model.keras_model.summary()
    model.keras_model.load_weights(os.path.join(full_Prediction_model_path, 'weights_best.h5'))

    source_dir_list = os.listdir(args.dataPath)
    number_of_dataset = len(source_dir_list)
    print('Number of datasets found in the folder: ' + str(number_of_dataset))

    im = imread(os.path.join(args.dataPath, source_dir_list[0]))
    n_channel = 1 if im.ndim == 2 else im.shape[-1]
    if n_channel == 1:
        axis_norm = (0, 1)  # normalize channels independently
    if n_channel > 1:
        axis_norm = (0, 1, 2)  # normalize channels jointly

    predictions = []
    polygons = []
    for i in range(number_of_dataset):
        print("processing dataset " + str(i + 1) + ", file: " + source_dir_list[i])
        image = imread(os.path.join(args.dataPath, source_dir_list[i]))
        image = normalize(image, 1, 99.8, axis=axis_norm)
        labels, poly = model.predict_instances(image)
        predictions.append(labels)
        polygons.append(poly)

    # Save the results in the folder
    saveResult(args.output, predictions, polygons, source_dir_list, prefix=prediction_prefix)

    print("---predictions done---")
Example #7
def main(argv):
    prediction_prefix = ""
    parser = ParserCreator.createArgumentParser("./predict.yml")
    args = parser.parse_args(argv[1:])

    model_weight_path = getValidModel(os.path.join(args.baseDir, args.name))

    model = createNetwork()

    model.keras_model.load_weights(model_weight_path)

    source_dir_list = prepareDataset()

    predictions = runPredictions(source_dir_list, model)

    saveResult(args.output,
               predictions,
               source_dir_list,
               prefix=prediction_prefix,
               threshold=args.threshold)

    print("---predictions done---")
Example #8
def main(argv):
    prediction_prefix = ""
    W = '\033[0m'  # white (normal)
    R = '\033[31m'  # red
    parser = ParserCreator.createArgumentParser("./evaluate.yml")
    args = parser.parse_args(argv[1:])
    full_QC_model_path = os.path.join(args.baseDir, args.name)
    if os.path.exists(os.path.join(full_QC_model_path, 'weights_best.hdf5')):
        print("The " + args.name + " network will be evaluated")
    else:
        print(R + '!! WARNING: The chosen model does not exist !!' + W)
        print(
            'Please make sure you provide a valid model path and model name before proceeding further.'
        )

    # Create a quality control/Prediction Folder
    prediction_QC_folder = os.path.join(full_QC_model_path, 'Quality Control',
                                        'Prediction')
    if os.path.exists(prediction_QC_folder):
        shutil.rmtree(prediction_QC_folder)

    os.makedirs(prediction_QC_folder)

    # Load the model
    denseUnet = load_model(os.path.join(full_QC_model_path,
                                        'weights_best.hdf5'),
                           custom_objects={
                               '_weighted_binary_crossentropy':
                               weighted_binary_crossentropy(np.ones(2))
                           })
    Input_size = denseUnet.layers[0].output_shape[1:3]
    print('Model input size: ' + str(Input_size[0]) + 'x' + str(Input_size[1]))

    # Create a list of sources
    source_dir_list = os.listdir(args.testInputPath)
    number_of_dataset = len(source_dir_list)
    print('Number of datasets found in the folder: ' + str(number_of_dataset))

    predictions = []
    for i in range(number_of_dataset):
        print("processing dataset " + str(i + 1) + ", file: " +
              source_dir_list[i])
        predictions.append(
            predict_as_tiles(
                os.path.join(args.testInputPath, source_dir_list[i]),
                denseUnet))

    # Save the results in the folder along with the masks according to the set threshold
    saveResult(prediction_QC_folder,
               predictions,
               source_dir_list,
               prefix=prediction_prefix,
               threshold=None)

    with open(os.path.join(full_QC_model_path, 'Quality Control',
                           'QC_metrics_' + args.name + '.csv'),
              "w",
              newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["File name", "IoU", "IoU-optimised threshold"])

        # Initialise the lists
        filename_list = []
        best_threshold_list = []
        best_IoU_score_list = []

        for filename in os.listdir(args.testInputPath):

            if not os.path.isdir(os.path.join(args.testInputPath, filename)):
                print('Running QC on: ' + filename)
                test_input = io.imread(os.path.join(args.testInputPath,
                                                    filename),
                                       as_gray=True)
                test_ground_truth_image = io.imread(os.path.join(
                    args.testGroundTruthPath, filename),
                                                    as_gray=True)

                (threshold_list, iou_scores_per_threshold) = getIoUvsThreshold(
                    os.path.join(prediction_QC_folder,
                                 prediction_prefix + filename),
                    os.path.join(args.testGroundTruthPath, filename))

                # Here we find which threshold yielded the highest IoU score for image n.
                best_IoU_score = max(iou_scores_per_threshold)
                best_threshold = threshold_list[iou_scores_per_threshold.index(best_IoU_score)]

                # Write the results in the CSV file
                writer.writerow(
                    [filename,
                     str(best_IoU_score),
                     str(best_threshold)])

                # Here we append the best threshold and score to the lists
                filename_list.append(filename)
                best_IoU_score_list.append(best_IoU_score)
                best_threshold_list.append(best_threshold)

    print("---evaluation done---")
Example #9
def main(argv):
    parser = ParserCreator.createArgumentParser("./evaluate.yml")
    if len(argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(1)
    args = parser.parse_args(argv[1:])
    print(args)

    Source_QC_folder = args.testInputPath
    Target_QC_folder = args.testGroundTruthPath
    QC_model_path = args.baseDir
    QC_model_name = args.name

    # Create a quality control/Prediction Folder
    if os.path.exists(QC_model_path + "/" + QC_model_name +
                      "/Quality Control/Prediction"):
        shutil.rmtree(QC_model_path + "/" + QC_model_name +
                      "/Quality Control/Prediction")

    os.makedirs(QC_model_path + "/" + QC_model_name +
                "/Quality Control/Prediction")

    # Activate the pretrained model.
    model_training = N2V(config=None,
                         name=QC_model_name,
                         basedir=QC_model_path)

    # List Tif images in Source_QC_folder
    Source_QC_folder_tif = Source_QC_folder + "/*.tif"
    Z = sorted(glob(Source_QC_folder_tif))
    Z = list(map(imread, Z))

    print('Number of test datasets found in the folder: ' + str(len(Z)))

    # Perform prediction on all datasets in the Source_QC folder
    for filename in os.listdir(Source_QC_folder):
        img = imread(os.path.join(Source_QC_folder, filename))
        predicted = model_training.predict(img, axes='YX', n_tiles=(2, 1))
        #        os.chdir(QC_model_path+"/"+QC_model_name+"/Quality Control/Prediction")
        imsave(
            os.path.join(
                QC_model_path + "/" + QC_model_name +
                "/Quality Control/Prediction", filename), predicted)

    # Open and create the csv file that will contain all the QC metrics
    with open(QC_model_path+"/"+QC_model_name+"/Quality Control/QC_metrics_"+QC_model_name+".csv", "w", newline='') \
            as file:
        writer = csv.writer(file)

        # Write the header in the csv file
        writer.writerow([
            "image #", "Prediction v. GT mSSIM", "Input v. GT mSSIM",
            "Prediction v. GT NRMSE", "Input v. GT NRMSE"
        ])

        # Let's loop through the provided dataset in the QC folders
        for i in os.listdir(Source_QC_folder):
            if not os.path.isdir(os.path.join(Source_QC_folder, i)):
                print('Running QC on: ' + i)
                # -------------------------------- Target test data (Ground truth) --------------------------------
                test_GT = io.imread(os.path.join(Target_QC_folder, i))
                test_GT_norm = normalizeImageWithPercentile(
                    test_GT)  # For normalisation between 0 and 1.

                # -------------------------------- Source test data --------------------------------
                test_source = io.imread(os.path.join(Source_QC_folder, i))
                test_source_norm = normalizeImageWithPercentile(
                    test_source)  # For normalisation between 0 and 1.
                # Normalize the image further via linear regression wrt the normalised GT image
                test_source_norm = normalizeByLinearRegression(
                    test_source_norm, test_GT_norm)

                # -------------------------------- Prediction --------------------------------
                test_prediction = io.imread(
                    os.path.join(
                        QC_model_path + "/" + QC_model_name +
                        "/Quality Control/Prediction", i))
                test_prediction_norm = normalizeImageWithPercentile(
                    test_prediction)
                # For normalisation between 0 and 1.
                # Normalize the image further via linear regression wrt the normalised GT image
                test_prediction_norm = normalizeByLinearRegression(
                    test_prediction_norm, test_GT_norm)

                # -------------------------------- Calculate the metric maps and save them ----------------------------

                # Calculate the SSIM images based on the default window parameters defined in the function
                GTforSSIM = img_as_uint(clipImageMinAndMax(test_GT_norm, 0, 1),
                                        force_copy=True)
                PredictionForSSIM = img_as_uint(clipImageMinAndMax(
                    test_prediction_norm, 0, 1),
                                                force_copy=True)
                SourceForSSIM = img_as_uint(clipImageMinAndMax(
                    test_source_norm, 0, 1),
                                            force_copy=True)

                # Calculate the SSIM maps
                img_SSIM_GTvsPrediction = ssim(GTforSSIM, PredictionForSSIM)
                img_SSIM_GTvsSource = ssim(GTforSSIM, SourceForSSIM)

                # Save ssim_maps
                img_SSIM_GTvsPrediction_32bit = np.float32(
                    img_SSIM_GTvsPrediction)
                io.imsave(
                    QC_model_path + '/' + QC_model_name +
                    '/Quality Control/SSIM_GTvsPrediction_' + i,
                    img_SSIM_GTvsPrediction_32bit)
                img_SSIM_GTvsSource_32bit = np.float32(img_SSIM_GTvsSource)
                io.imsave(
                    QC_model_path + '/' + QC_model_name +
                    '/Quality Control/SSIM_GTvsSource_' + i,
                    img_SSIM_GTvsSource_32bit)

                # Calculate the Root Squared Error (RSE) maps
                img_RSE_GTvsPrediction = np.sqrt(
                    np.square(test_GT_norm - test_prediction_norm))
                img_RSE_GTvsSource = np.sqrt(
                    np.square(test_GT_norm - test_source_norm))

                # Save SE maps
                img_RSE_GTvsPrediction_32bit = np.float32(
                    img_RSE_GTvsPrediction)
                img_RSE_GTvsSource_32bit = np.float32(img_RSE_GTvsSource)
                io.imsave(
                    QC_model_path + '/' + QC_model_name +
                    '/Quality Control/RSE_GTvsPrediction_' + i,
                    img_RSE_GTvsPrediction_32bit)
                io.imsave(
                    QC_model_path + '/' + QC_model_name +
                    '/Quality Control/RSE_GTvsSource_' + i,
                    img_RSE_GTvsSource_32bit)

                # SAVE THE METRIC MAPS HERE #########
                # -------------------------------- Calculate the metrics and save them --------------------------------

                # Calculate the mean SSIM metric
                # SSIM_GTvsPrediction_metrics = np.mean(img_SSIM_GTvsPrediction) # THIS IS WRONG, please compute the
                # SSIM over the whole image and not in patches.
                # SSIM_GTvsSource_metrics = np.mean(img_SSIM_GTvsSource) # THIS IS WRONG, please compute the SSIM over
                # the whole image and not in patches.
                index_SSIM_GTvsPrediction = ssim_index(GTforSSIM,
                                                       PredictionForSSIM)
                index_SSIM_GTvsSource = ssim_index(GTforSSIM, SourceForSSIM)

                # Normalised Root Mean Squared Error (here it's valid to take the mean of the image)
                NRMSE_GTvsPrediction = np.sqrt(np.mean(img_RSE_GTvsPrediction))
                NRMSE_GTvsSource = np.sqrt(np.mean(img_RSE_GTvsSource))

                writer.writerow([
                    i,
                    str(index_SSIM_GTvsPrediction),
                    str(index_SSIM_GTvsSource),
                    str(NRMSE_GTvsPrediction),
                    str(NRMSE_GTvsSource)
                ])

    # All data is now processed and saved
    Test_FileList = os.listdir(Source_QC_folder)  # this assumes, as it should,
    #                                               that both source and target are named the same
    print("---evaluation done---")
Example #10
def main(argv):
    parser = ParserCreator.createArgumentParser("./train.yml")
    if len(argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(1)
    args = parser.parse_args(argv[1:])
    print(args)

    if tf.test.gpu_device_name() == '':
        print('You do not have GPU access.')
        print('Expect slow performance.')
    else:
        print('You have GPU access')

    print('Tensorflow version is ' + str(tf.__version__))

    #Prepare Data - TODO

    #Create Model - TODO

    #Display Model and Training Parameters - OK
    model.keras_model.summary()

    # ------------------ Display ------------------
    print(
        '---------------------------- Main training parameters ----------------------------'
    )
    print('Number of epochs: ' + str(args.epochs))
    print('Batch size: ' + str(args.batchSize))
    print('Number of training datasets: ' + str(len(X)))
    print('Number of training steps: ' + str(number_of_steps - n_val))
    print('Number of validation steps: ' + str(n_val))
    print(
        '---------------------------- ------------------------ ----------------------------'
    )

    start = time.time()
    #Train Model - TODO

    #history = model.train(X_trn, Y_trn, validation_data=(X_val, Y_val), augmenter=augment,
    #                      epochs=args.epochs,
    #                      steps_per_epoch=number_of_steps)

    #Write to CSV - OK
    lossDataCSVPath = os.path.join(full_model_path,
                                   'Quality Control/training_evaluation.csv')

    with open(lossDataCSVPath, 'w', newline="") as f:
        writer = csv.writer(f)
        writer.writerow(history.history.keys())
        values = list(history.history.values())
        #writer.writerows(values)
        for i in range(args.epochs):
            v = [values[j][i] for j in range(len(values))]
            writer.writerow(v)

    #Print End - OK
    print("------------------------------------------")
    dt = time.time() - start
    mins, sec = divmod(dt, 60)
    hour, mins = divmod(mins, 60)
    print("Time elapsed:", hour, "hour(s)", mins, "min(s)", round(sec),
          "sec(s)")
    print("------------------------------------------")

    print("---training done---")
Example #11
def main(argv):
    print("py:entered Main")

    parser = ParserCreator.createArgumentParser("./train.yml")
    if len(argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(1)
    args = parser.parse_args(argv[1:])
    print(args)

    if tf.test.gpu_device_name() == '':
        print('You do not have GPU access.')
        print('Expect slow performance.')
    else:
        print('You have GPU access')

    print('Tensorflow version is ' + str(tf.__version__))

    if args.useDataAugmentation:
        data_gen_args = dict(width_shift_range=args.horizontalShift / 100.,
                             height_shift_range=args.verticalShift / 100.,
                             rotation_range=args.rotationRange,
                             zoom_range=args.zoomRange / 100.,
                             shear_range=args.shearRange / 100.,
                             horizontal_flip=args.horizontalFlip,
                             vertical_flip=args.verticalFlip,
                             validation_split=args.validationFraction / 100,
                             fill_mode='reflect')
    else:
        data_gen_args = dict(validation_split=args.validationFraction / 100,
                             fill_mode='reflect')
    print('Creating patches...')
    Patch_source, Patch_target = create_patches(args.dataSourcePath,
                                                args.dataTargetPath,
                                                args.patchSizeXY,
                                                args.patchSizeXY)

    (train_datagen,
     validation_datagen) = prepareGenerators(Patch_source,
                                             Patch_target,
                                             data_gen_args,
                                             args.batchSize,
                                             target_size=(args.patchSizeXY,
                                                          args.patchSizeXY))
    full_model_path = os.path.join(args.baseDir, args.name)
    W = '\033[0m'  # white (normal)
    R = '\033[31m'  # red
    if os.path.exists(full_model_path):
        print(R +
              '!! WARNING: Folder already exists and will be overwritten !!' +
              W)

    model_checkpoint = ModelCheckpoint(os.path.join(full_model_path,
                                                    'weights_best.hdf5'),
                                       monitor='val_loss',
                                       verbose=0,
                                       save_best_only=True)
    print('Getting class weights...')
    class_weights = getClassWeights(args.dataTargetPath)
    h5_file_path = None
    reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                                  factor=0.1,
                                  verbose=1,
                                  mode='auto',
                                  patience=10,
                                  min_lr=0)
    model = denseUnet(pretrained_weights=h5_file_path,
                      input_size=(args.patchSizeXY, args.patchSizeXY, 1),
                      pooling_steps=args.poolingSteps,
                      learning_rate=args.learningRate,
                      class_weights=class_weights)

    number_of_training_dataset = len(os.listdir(Patch_source))

    if args.stepsPerEpoch == 0:
        number_of_steps = math.ceil(
            (100 - args.validationFraction) / 100 *
            number_of_training_dataset / args.batchSize)
    else:
        number_of_steps = args.stepsPerEpoch

    validation_steps = max(
        1,
        math.ceil(args.validationFraction / 100 * number_of_training_dataset /
                  args.batchSize))
    config_model = model.optimizer.get_config()
    print(config_model)
    if os.path.exists(full_model_path):
        print(
            R +
            '!! WARNING: Model folder already existed and has been removed !!'
            + W)
        shutil.rmtree(full_model_path)

    os.makedirs(full_model_path)
    os.makedirs(os.path.join(full_model_path, 'Quality Control'))

    # ------------------ Display ------------------
    print(
        '---------------------------- Main training parameters ----------------------------'
    )
    print('Number of epochs: ' + str(args.epochs))
    print('Batch size: ' + str(args.batchSize))
    print('Number of training datasets: ' + str(number_of_training_dataset))
    print('Number of training steps: ' + str(number_of_steps))
    print('Number of validation steps: ' + str(validation_steps))
    print(
        '---------------------------- ------------------------ ----------------------------'
    )

    start = time.time()

    history = model.fit_generator(train_datagen,
                                  steps_per_epoch=number_of_steps,
                                  epochs=args.epochs,
                                  callbacks=[model_checkpoint, reduce_lr],
                                  validation_data=validation_datagen,
                                  validation_steps=validation_steps,
                                  shuffle=True,
                                  verbose=1)
    # Save the last model
    model.save(os.path.join(full_model_path, 'weights_last.hdf5'))

    lossDataCSVPath = os.path.join(full_model_path,
                                   'Quality Control/training_evaluation.csv')
    with open(lossDataCSVPath, 'w') as f:
        writer = csv.writer(f)
        writer.writerow(['loss', 'val_loss', 'learning rate'])
        for i in range(len(history.history['loss'])):
            writer.writerow([
                history.history['loss'][i], history.history['val_loss'][i],
                history.history['lr'][i]
            ])

    print("------------------------------------------")
    dt = time.time() - start
    mins, sec = divmod(dt, 60)
    hour, mins = divmod(mins, 60)
    print("Time elapsed:", hour, "hour(s)", mins, "min(s)", round(sec),
          "sec(s)")
    print("------------------------------------------")

    print("---training done---")
Example #12
def main(argv):
    prediction_prefix = ""
    parser = ParserCreator.createArgumentParser("./evaluate.yml")
    args = parser.parse_args(argv[1:])
    full_QC_model_path = os.path.join(args.baseDir, args.name)
    if os.path.exists(os.path.join(full_QC_model_path, 'weights_best.h5')):
        print("The " + args.name + " network will be evaluated")
    else:
        print('!! WARNING: The chosen model does not exist !!')
        print(
            'Please make sure you provide a valid model path and model name before proceeding further.'
        )

    # Create a quality control/Prediction Folder
    prediction_QC_folder = os.path.join(full_QC_model_path, 'Quality Control',
                                        'Prediction')
    if os.path.exists(prediction_QC_folder):
        shutil.rmtree(prediction_QC_folder)

    os.makedirs(prediction_QC_folder)

    # Load the model

    conf = Config2D(
        n_rays=args.nRays,
        grid=(args.gridParameter, args.gridParameter),
    )
    model = stardist(conf, name=args.name, basedir=args.baseDir)

    model.keras_model.summary()
    model.keras_model.load_weights(
        os.path.join(full_QC_model_path, 'weights_best.h5'))

    # Create a list of sources
    source_dir_list = os.listdir(args.testInputPath)
    number_of_dataset = len(source_dir_list)
    print('Number of datasets found in the folder: ' + str(number_of_dataset))
    im = imread(args.testInputPath + "/" + source_dir_list[0])
    n_channel = 1 if im.ndim == 2 else im.shape[-1]
    if n_channel == 1:
        axis_norm = (0, 1)  # normalize channels independently
    if n_channel > 1:
        axis_norm = (0, 1, 2)  # normalize channels jointly

    predictions = []
    polygons = []
    for i in range(number_of_dataset):
        print("processing dataset " + str(i + 1) + ", file: " +
              source_dir_list[i])
        image = imread(args.testInputPath + "/" + source_dir_list[i])
        image = normalize(image, 1, 99.8, axis=axis_norm)
        labels, poly = model.predict_instances(image)
        predictions.append(labels)
        polygons.append(poly)

    # Save the results in the folder
    saveResult(prediction_QC_folder,
               predictions,
               polygons,
               source_dir_list,
               prefix=prediction_prefix)

    with open(os.path.join(full_QC_model_path, 'Quality Control',
                           'QC_metrics_' + args.name + '.csv'),
              "w",
              newline='') as file:
        writer = csv.writer(file)
        writer.writerow([
            "File name", "IoU", "False Positives", "True Positives",
            "False Negatives", "precision", "recall", "Accuracy", "F1",
            "N True", "N Pred", "Mean True Score", "Mean Matched Score",
            "Panoptic Quality"
        ])

        # Corresponding ground-truth and predicted objects are counted as true positives,
        # false positives, or false negatives depending on whether their intersection
        # over union (IoU) >= thresh.
        # mean_true_score is the mean IoU of matched true positives, normalised by the
        # total number of ground-truth objects.
        # mean_matched_score is the mean IoU of matched true positives.
        # panoptic_quality is defined as in Eq. 1 of Kirillov et al., "Panoptic Segmentation", CVPR 2019.

        # Initialise the lists

        for filename in os.listdir(args.testInputPath):
            if not os.path.isdir(os.path.join(args.testInputPath, filename)):
                print('Running QC on: ' + filename)
                test_input = io.imread(
                    os.path.join(args.testInputPath, filename))
                test_prediction = io.imread(
                    os.path.join(full_QC_model_path, 'Quality Control',
                                 'Prediction', filename))
                test_ground_truth_image = io.imread(
                    os.path.join(args.testGroundTruthPath, filename))

                # Calculate the matching (with IoU threshold `thresh`) and all metrics

                stats = matching(test_ground_truth_image,
                                 test_prediction,
                                 thresh=0.5)

                # Convert pixel values to 0 or 255 (work on copies so the label images
                # used for the IoU computation below stay untouched)
                test_prediction_0_to_255 = test_prediction.copy()
                test_prediction_0_to_255[test_prediction_0_to_255 > 0] = 255

                test_ground_truth_0_to_255 = test_ground_truth_image.copy()
                test_ground_truth_0_to_255[test_ground_truth_0_to_255 > 0] = 255
                # Intersection over Union metric

                intersection = np.logical_and(test_ground_truth_image,
                                              test_prediction)
                union = np.logical_or(test_ground_truth_image, test_prediction)
                iou_score = np.sum(intersection) / np.sum(union)
                writer.writerow([
                    filename,
                    str(iou_score),
                    str(stats.fp),
                    str(stats.tp),
                    str(stats.fn),
                    str(stats.precision),
                    str(stats.recall),
                    str(stats.accuracy),
                    str(stats.f1),
                    str(stats.n_true),
                    str(stats.n_pred),
                    str(stats.mean_true_score),
                    str(stats.mean_matched_score),
                    str(stats.panoptic_quality)
                ])

    print("---evaluation done---")