Beispiel #1
0
def prediction_fn(model_dir: str,
                  input_dir: str,
                  output_dir: str = None) -> None:
    """
    Load the model stored in ``model_dir`` and apply it to the image files (.jpg, .png) found in ``input_dir``.
    Each prediction is multiplied by 255, cast to ``uint8`` and saved in ``output_dir`` as a ``.npy`` file named
    after the input image with its extension removed.

    :param model_dir: Directory containing the saved model
    :param input_dir: input directory where the images to predict are
    :param output_dir: output directory to save the predictions (probability images). When empty or None,
        a ``predictions`` folder located two levels above ``model_dir`` is used.
    :return: None
    """
    if not output_dir:
        # For model_dir of style model_name/export/timestamp/ this will create a folder model_name/predictions.
        # Trailing separators are stripped first so the result is the same with or without a final separator.
        model_root = os.path.dirname(
            os.path.dirname(model_dir.rstrip(os.path.sep)))
        output_dir = os.path.join(model_root, 'predictions')

    os.makedirs(output_dir, exist_ok=True)
    filenames_to_predict = glob(os.path.join(input_dir, '*.jpg')) + glob(
        os.path.join(input_dir, '*.png'))

    with tf.Session():
        m = LoadedModel(model_dir, predict_mode='filename_original_shape')
        for filename in tqdm(filenames_to_predict, desc='Prediction'):
            pred = m.predict(filename)['probs'][0]
            # splitext drops only the final extension, so names such as 'page.1.jpg' and 'page.2.jpg'
            # no longer collapse to the same output file.
            basename = os.path.splitext(os.path.basename(filename))[0]
            # np.save appends the '.npy' extension itself.
            np.save(os.path.join(output_dir, basename), np.uint8(255 * pred))
def predict_dhSegment(images,
                      model_dir,
                      save=False,
                      save_dir='./',
                      force_refresh=False):
    """
    Run the model stored in ``model_dir`` on every image and collect the predictions.

    :param images: iterable of image file paths
    :param model_dir: directory of the saved model, handed to ``LoadedModel``
    :param save: when True, each prediction is cached in ``save_dir`` as ``<basename>_dhSegment_preds.npy``
        and an existing cache file is reused instead of predicting again
    :param save_dir: directory holding the cache files (created when ``save`` is True)
    :param force_refresh: when True, existing cache files are ignored and the prediction is recomputed
    :return: dict mapping each image basename (the part of the file name before its first '.') to a
        ``(probs, original_shape)`` tuple
    """
    prediction_data = {}

    if save and save_dir:
        os.makedirs(save_dir, exist_ok=True)

    with tf.Session():  # Start a tensorflow session
        # Load the model
        m = LoadedModel(model_dir, predict_mode='filename')
        for filename in tqdm(images):
            basename = os.path.basename(filename).split('.')[0]
            save_path = os.path.join(save_dir,
                                     basename + '_dhSegment_preds.npy')
            if save and not force_refresh and os.path.exists(save_path):
                # The cache is an object array, which np.load refuses to read unless pickling is allowed.
                # NOTE: unpickling executes code from the file; only point save_dir at trusted cache files.
                prediction_data[basename] = tuple(
                    np.load(save_path, allow_pickle=True))
                continue
            prediction_outputs = m.predict(filename)
            probs = prediction_outputs['probs'][0]
            original_shape = prediction_outputs['original_shape']
            save_data = (probs, original_shape)
            if save:
                # The two entries have different shapes; pack them explicitly into an object array because
                # recent NumPy versions refuse to build ragged arrays implicitly.
                packed = np.empty(len(save_data), dtype=object)
                packed[0], packed[1] = probs, original_shape
                np.save(save_path, packed)
            prediction_data[basename] = save_data
    return prediction_data
Beispiel #3
0
def predict_on_set(filenames_to_predict, model_dir, output_dir):
    """
    Apply the model stored in ``model_dir`` to each given file and save the predictions as ``.npy`` files.
    Each prediction is multiplied by 255 and cast to ``uint8`` before being saved.

    :param filenames_to_predict: iterable of image file paths to predict
    :param model_dir: directory of the saved model, handed to ``LoadedModel``
    :param output_dir: directory receiving one ``<basename>.npy`` file per input (created if missing)
    :return: None
    """
    os.makedirs(output_dir, exist_ok=True)

    with tf.Session():
        m = LoadedModel(model_dir, 'filename')
        for filename in tqdm(filenames_to_predict, desc='Prediction'):
            pred = m.predict(filename)['probs'][0]
            # splitext drops only the final extension, so names such as 'page.1.jpg' and 'page.2.jpg'
            # no longer collapse to the same output file.
            basename = os.path.splitext(os.path.basename(filename))[0]
            # np.save appends the '.npy' extension itself.
            np.save(os.path.join(output_dir, basename), np.uint8(255 * pred))
Beispiel #4
0
def prediction_fn(model_dir: str, input_dir: str, output_dir: str = None):
    """
    Load the model stored in ``model_dir`` and apply it to the ``.jpg`` files found in ``input_dir``.
    Each prediction is multiplied by 255, cast to ``uint8`` and saved in ``output_dir`` as a ``.npy`` file named
    after the input image with its extension removed.

    :param model_dir: Directory containing the saved model
    :param input_dir: input directory where the images to predict are
    :param output_dir: output directory to save the predictions. When empty or None, a ``predictions``
        folder located two levels above ``model_dir`` is used.
    :return: None
    """
    if not output_dir:
        # For model_dir of style model_name/export/timestamp/ this will create a folder model_name/predictions.
        # Trailing separators are stripped first so the result is the same with or without a final separator.
        model_root = os.path.dirname(
            os.path.dirname(model_dir.rstrip(os.path.sep)))
        output_dir = os.path.join(model_root, 'predictions')

    os.makedirs(output_dir, exist_ok=True)
    filenames_to_predict = glob(os.path.join(input_dir, '*.jpg'))
    # Load model
    with tf.Session():
        m = LoadedModel(model_dir, 'filename_original_shape')
        for filename in tqdm(filenames_to_predict, desc='Prediction'):
            pred = m.predict(filename)['probs'][0]
            # splitext drops only the final extension, so names such as 'page.1.jpg' and 'page.2.jpg'
            # no longer collapse to the same output file.
            basename = os.path.splitext(os.path.basename(filename))[0]
            # np.save appends the '.npy' extension itself.
            np.save(os.path.join(output_dir, basename), np.uint8(255 * pred))
Beispiel #5
0
def run(testDir, modelDir, modelName, outDir, _config):
    """
    Predict every file found in ``testDir`` and pass each result to ``pageSeparatorsToXml``.
    Files for which ``<basename>.xml`` already exists in ``outDir`` are skipped.

    :param testDir: directory whose entries are all processed
    :param modelDir: handed to ``LoadedModel`` as first argument
    :param modelName: handed to ``LoadedModel`` as second argument
    :param outDir: output directory (created if missing)
    :param _config: not used by this function
    :return: None
    """
    # Create output directory
    os.makedirs(outDir, exist_ok=True)

    # I/O
    files = glob(testDir + '/*')

    with tf.Session():  # Start a tensorflow session
        # Load the model
        model = LoadedModel(modelDir, modelName, predict_mode='filename')

        for filename in tqdm(files, desc='Processed files'):
            basename = os.path.basename(filename).split('.')[0]

            # presumably pageSeparatorsToXml writes <basename>.xml into outDir -- TODO confirm
            if os.path.exists(os.path.join(outDir, basename + '.xml')):
                print(basename + " skipped...")
                # Actually skip the file, as the message announces.
                continue

            prediction_outputs = model.predict(filename)
            probs = prediction_outputs['probs'][0].astype(float)

            # glob already returns paths that include testDir; joining testDir again would
            # duplicate the directory for relative paths.
            img = imread(filename)

            pageSeparatorsToXml(probs, img.shape, filename, outDir)
Beispiel #6
0
    output_dir = 'demo/processed_images'
    os.makedirs(output_dir, exist_ok=True)
    # PAGE XML format output
    output_pagexml_dir = os.path.join(output_dir, PAGE_XML_DIR)
    os.makedirs(output_pagexml_dir, exist_ok=True)

    # Store coordinates of page in a .txt file
    txt_coordinates = ''

    with tf.Session():  # Start a tensorflow session
        # Load the model
        m = LoadedModel(model_dir, predict_mode='filename')

        for filename in tqdm(input_files, desc='Processed files'):
            # For each image, predict each pixel's label
            prediction_outputs = m.predict(filename)
            probs = prediction_outputs['probs'][0]
            original_shape = prediction_outputs['original_shape']
            probs = probs[:, :,
                          1]  # Take only class '1' (class 0 is the background, class 1 is the page)
            probs = probs / np.max(probs)  # Normalize to be in [0, 1]

            # Binarize the predictions
            page_bin = page_make_binary_mask(probs)

            # Upscale to have full resolution image (cv2 uses (w,h) and not (h,w) for giving shapes)
            bin_upscaled = cv2.resize(page_bin.astype(np.uint8, copy=False),
                                      tuple(original_shape[::-1]),
                                      interpolation=cv2.INTER_NEAREST)

            # Find quadrilateral enclosing the page