示例#1
0
def predict_multiclass(img_file, out_path):
    """Run multiclass segmentation on one image and save the mask(s).

    Reads the following module-level settings defined elsewhere in this file:
    ``im_type``, ``model_file``, ``im_bands``, ``window_size``,
    ``target_class``, ``dict_target`` and ``FLAG_APPROACH_PREDICT``
    (0 = original windowed prediction, 1 = smooth-windowing prediction).

    Args:
        img_file: path of the image, read with ``load_img_by_gdal``.
        out_path: directory the output PNG file(s) are written into.
    """
    print("[INFO] opening image...")
    raw = load_img_by_gdal(img_file)

    # Pick the divisor matching the declared bit depth; unknown types are
    # left unscaled and only clipped below.
    divisor = None
    if im_type == UINT8:
        divisor = 255.0
    elif im_type == UINT10:
        divisor = 1024.0
    elif im_type == UINT16:
        divisor = 65535.0
    if divisor is not None:
        raw = raw / divisor
    input_img = np.clip(raw, 0.0, 1.0)

    # Make sure the model file is present before trying to load it.
    print("model file: {}".format(model_file))
    if not os.path.isfile(model_file):
        print("model does not exist:{}".format(model_file))
        sys.exit(-2)

    model = load_model(model_file)
    stem = os.path.split(img_file)[1].split(".")[0]
    print(stem)

    if FLAG_APPROACH_PREDICT == 0:
        print("[INFO] predict image by orignal approach\n")
        result = orignal_predict_onehot(input_img, im_bands, model, window_size)
        target_png = out_path + '/original_predict_' + stem + '_multiclass.png'
        print("result save as to: {}".format(target_png))
        cv2.imwrite(target_png, result * 128)
    elif FLAG_APPROACH_PREDICT == 1:
        print("[INFO] predict image by smooth approach\n")
        result = predict_img_with_smooth_windowing_multiclassbands(
            input_img,
            model,
            window_size=window_size,
            subdivisions=2,
            # output channels = the real classes, i.e. total classes minus background
            real_classes=target_class,
            pred_func=smooth_predict_for_multiclass,
            PLOT_PROGRESS=False)

        # One PNG per class channel, named after the class.
        for band_idx in range(target_class):
            band_png = (out_path + '/mask_multiclass_' + stem + '_'
                        + dict_target[band_idx] + '.png')
            cv2.imwrite(band_png, result[:, :, band_idx])
            print("Saved to: {}".format(band_png))
    gc.collect()
示例#2
0
def predict_binary_jaccard(img_file, output_file):
    """Predict a binary mask for one image and write it to ``output_file``.

    Reads the following module-level settings defined elsewhere in this file:
    ``im_type``, ``model_file``, ``im_bands``, ``window_size``,
    ``target_class`` and ``FLAG_APPROACH_PREDICT``
    (0 = original windowed prediction, 1 = smooth-windowing prediction).

    Args:
        img_file: path of the image, read with ``load_img_by_gdal``.
        output_file: path of the PNG mask; used by BOTH prediction strategies.
    """
    print("[INFO] opening image...")
    input_img = load_img_by_gdal(img_file)
    # Scale raw pixel values by the declared bit depth, then clamp to [0, 1].
    if im_type == UINT8:
        input_img = input_img / 255.0
    elif im_type == UINT10:
        input_img = input_img / 1024.0
    elif im_type == UINT16:
        input_img = input_img / 65535.0

    input_img = np.clip(input_img, 0.0, 1.0)
    input_img = input_img.astype(np.float16)

    model = load_model(model_file)

    if FLAG_APPROACH_PREDICT == 0:
        print("[INFO] predict image by orignal approach\n")
        result = orignal_predict_notonehot(input_img, im_bands, model, window_size)
        # Honour the caller-supplied path. Do not derive a name from an
        # ``output_path`` that this function never receives: that raised
        # NameError or silently ignored ``output_file``.
        print("result save as to: {}".format(output_file))
        cv2.imwrite(output_file, result * 128)

    elif FLAG_APPROACH_PREDICT == 1:
        print("[INFO] predict image by smooth approach\n")
        result = predict_img_with_smooth_windowing_multiclassbands(
            input_img,
            model,
            window_size=window_size,
            subdivisions=2,
            # output channels = the real classes, i.e. total classes minus background
            real_classes=target_class,
            pred_func=smooth_predict_for_binary_notonehot,
            PLOT_PROGRESS=False
        )

        cv2.imwrite(output_file, result)
        print("Saved to: {}".format(output_file))

    gc.collect()
示例#3
0
        result = orignal_predict_notonehot(input_img, im_bands, model,
                                           window_size)
        output_file = ''.join([
            '../../data/predict/', dict_network[FLAG_USING_NETWORK],
            '/sat_4bands/original_pred_', abs_filename, '_',
            dict_target[FLAG_TARGET_CLASS], '_onlyjaccard.png'
        ])
        print("result save as to: {}".format(output_file))
        cv2.imwrite(output_file, result * 128)

    elif FLAG_APPROACH_PREDICT == 1:
        print("[INFO] predict image by smooth approach\n")
        result = predict_img_with_smooth_windowing_multiclassbands(
            input_img,
            model,
            window_size=window_size,
            subdivisions=2,
            real_classes=target_class,  # output channels = 是真的类别,总类别-背景
            pred_func=smooth_predict_for_binary_notonehot)
        # output_file = ''.join(['../../data/predict/', dict_network[FLAG_USING_NETWORK],'/sat_4bands/mask_binary_',
        #                        abs_filename, '_', dict_target[FLAG_TARGET_CLASS],'_onlyjaccard.png'])

        output_file = ''.join([
            '../../data/test/paper/pred/mask_binary_', abs_filename, '_',
            dict_target[FLAG_TARGET_CLASS], '_onlyjaccard.png'
        ])
        print("result save as to: {}".format(output_file))

        cv2.imwrite(output_file, result)

    gc.collect()
def predict_binary_for_single_image(input_dict=None):
    """Predict a binary mask for a single image described by ``input_dict``.

    Args:
        input_dict: settings mapping. Keys read here:
            ``'GPUID'`` (value assigned to CUDA_VISIBLE_DEVICES), ``'windsize'``,
            ``'im_bands'``, ``'dtype'``, ``'strategy'`` (0 = original windowed
            prediction, 1 = smooth windowing), ``'image_file'``,
            ``'model_file'``, ``'mask_path'`` (output PNG path) and
            ``'onehot'`` (truthy selects the one-hot prediction helpers).

    Returns:
        0 when finished.

    Raises:
        KeyError: if a required key is missing from ``input_dict``.
        SystemExit: with code -2 if the model file does not exist.
    """
    # ``None`` rather than a shared mutable ``{}`` default.
    if input_dict is None:
        input_dict = {}

    gpu_id = input_dict['GPUID']
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
    window_size = input_dict['windsize']
    im_bands = input_dict['im_bands']
    im_type = input_dict['dtype']
    FLAG_APPROACH_PREDICT = input_dict['strategy']
    img_file = input_dict['image_file']
    model_file = input_dict['model_file']
    output_file = input_dict['mask_path']

    out_bands = 1
    use_onehot = bool(input_dict['onehot'])

    # Check the model before the potentially slow image load so a bad path fails fast.
    print("model file: {}".format(model_file))
    if not os.path.isfile(model_file):
        print("model does not exist:{}".format(model_file))
        sys.exit(-2)

    input_img = load_img_by_gdal(img_file)
    # Scale raw pixel values by the declared bit depth, then clamp to [0, 1].
    if im_type == UINT8:
        input_img = input_img / 255.0
    elif im_type == UINT10:
        input_img = input_img / 1024.0
    elif im_type == UINT16:
        input_img = input_img / 65535.0

    input_img = np.clip(input_img, 0.0, 1.0)
    # NOTE(review): original comment here read "test accuracy"; intent of the
    # float16 cast (memory vs. precision trade-off) is unconfirmed.
    input_img = input_img.astype(np.float16)

    model = load_model(model_file)

    if FLAG_APPROACH_PREDICT == 0:
        print("[INFO] predict image by orignal approach\n")
        if use_onehot:
            result = orignal_predict_onehot(input_img, im_bands, model, window_size)
        else:
            result = orignal_predict_notonehot(input_img, im_bands, model, window_size)
        print("result save as to: {}".format(output_file))
        cv2.imwrite(output_file, result * 128)

    elif FLAG_APPROACH_PREDICT == 1:
        print("[INFO] predict image by smooth approach\n")
        # The one-hot and plain variants differ only in the per-window predictor.
        pred_func = (smooth_predict_for_binary_onehot if use_onehot
                     else smooth_predict_for_binary_notonehot)
        result = predict_img_with_smooth_windowing_multiclassbands(
            input_img,
            model,
            window_size=window_size,
            subdivisions=2,
            # output channels = the real classes, i.e. total classes minus background
            real_classes=out_bands,
            pred_func=pred_func
        )

        print("result save as to: {}".format(output_file))

        cv2.imwrite(output_file, result)
        print("Saved to {}".format(output_file))

    gc.collect()

    return 0
def predict_multiclass_for_batch_image(input_dict=None):
    """Predict multiclass label masks for every image found in a directory.

    Args:
        input_dict: settings mapping. Keys read here:
            ``'GPUID'`` (value assigned to CUDA_VISIBLE_DEVICES), ``'windsize'``,
            ``'im_bands'``, ``'dtype'``, ``'strategy'`` (0 = original windowed
            prediction, 1 = smooth windowing), ``'image_dir'``,
            ``'model_file'``, ``'mask_dir'`` (output directory) and
            ``'target_num'`` (number of class channels predicted).

    Returns:
        0 when all images have been processed.

    Raises:
        KeyError: if a required key is missing from ``input_dict``.
        SystemExit: with code -1 if ``get_file`` reports no files.
    """
    # ``None`` rather than a shared mutable ``{}`` default.
    if input_dict is None:
        input_dict = {}

    gpu_id = input_dict['GPUID']
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
    window_size = input_dict['windsize']
    im_bands = input_dict['im_bands']
    im_type = input_dict['dtype']
    FLAG_APPROACH_PREDICT = input_dict['strategy']
    input_path = input_dict['image_dir']
    model_file = input_dict['model_file']
    output_path = input_dict['mask_dir']

    out_bands = input_dict['target_num']

    all_files, num = get_file(input_path)
    if num == 0:
        print("There is no file in path:{}".format(input_path))
        sys.exit(-1)

    # The model is identical for every image: load it once, not per file.
    model = load_model(model_file)

    for img_file in all_files:
        print("[INFO] opening image: {}".format(img_file))
        input_img = load_img_by_gdal(img_file)
        # Scale raw pixel values by the declared bit depth, then clamp to [0, 1].
        if im_type == UINT8:
            input_img = input_img / 255.0
        elif im_type == UINT10:
            input_img = input_img / 1024.0
        elif im_type == UINT16:
            input_img = input_img / 65535.0

        input_img = np.clip(input_img, 0.0, 1.0)
        input_img = input_img.astype(np.float16)

        abs_filename = os.path.split(img_file)[1].split(".")[0]

        if FLAG_APPROACH_PREDICT == 0:
            print("[INFO] predict image by orignal approach\n")
            result = orignal_predict_onehot(input_img, im_bands, model, window_size)
            output_file = ''.join([output_path, '/', abs_filename, '.png'])
            print("result save as to: {}".format(output_file))
            cv2.imwrite(output_file, result * 128)

        elif FLAG_APPROACH_PREDICT == 1:
            print("[INFO] predict image by smooth approach\n")
            result = predict_img_with_smooth_windowing_multiclassbands(
                input_img,
                model,
                window_size=window_size,
                subdivisions=2,
                # output channels = the real classes, i.e. total classes minus background
                real_classes=out_bands,
                pred_func=smooth_predict_for_multiclass,
                PLOT_PROGRESS=False
            )

            height, width = input_img.shape[:2]
            output_file = ''.join([output_path, '/', abs_filename, '_smooth_pred.png'])
            # Collapse the per-class channels into one label image: a pixel
            # gets label i + 1 when channel i is >= 127 (later channels win
            # on overlap); pixels with no such channel stay 0.
            output_mask = np.zeros((height, width), np.uint8)
            for i in range(out_bands):
                output_mask[result[:, :, i] >= 127] = i + 1
            print(np.unique(result))
            cv2.imwrite(output_file, output_mask)
            print("Saved to:{}".format(output_file))

        gc.collect()

    return 0
def predict_binary_for_batch_image(input_dict=None):
    """Predict binary masks for every image found in a directory.

    Args:
        input_dict: settings mapping. Keys read here:
            ``'GPUID'`` (value assigned to CUDA_VISIBLE_DEVICES), ``'windsize'``,
            ``'im_bands'``, ``'dtype'``, ``'strategy'`` (0 = original windowed
            prediction, 1 = smooth windowing), ``'image_dir'``,
            ``'model_file'``, ``'mask_dir'`` (output directory) and
            ``'onehot'`` (truthy selects the one-hot prediction helpers).

    Returns:
        0 when all images have been processed.

    Raises:
        KeyError: if a required key is missing from ``input_dict``.
        SystemExit: with code -1 if ``get_file`` reports no files.
    """
    # ``None`` rather than a shared mutable ``{}`` default.
    if input_dict is None:
        input_dict = {}

    gpu_id = input_dict['GPUID']
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
    window_size = input_dict['windsize']
    im_bands = input_dict['im_bands']
    im_type = input_dict['dtype']
    FLAG_APPROACH_PREDICT = input_dict['strategy']
    input_path = input_dict['image_dir']
    model_file = input_dict['model_file']
    output_path = input_dict['mask_dir']

    out_bands = 1
    use_onehot = bool(input_dict['onehot'])

    all_files, num = get_file(input_path)
    if num == 0:
        print("There is no file in path:{}".format(input_path))
        sys.exit(-1)

    # The model is identical for every image: load it once, not per file.
    model = load_model(model_file)

    for img_file in all_files:
        print("[INFO] opening image...")
        print("FileName:{}".format(img_file))
        input_img = load_img_by_gdal(img_file)
        # Scale raw pixel values by the declared bit depth, then clamp to [0, 1].
        if im_type == UINT8:
            input_img = input_img / 255.0
        elif im_type == UINT10:
            input_img = input_img / 1024.0
        elif im_type == UINT16:
            input_img = input_img / 65535.0

        input_img = np.clip(input_img, 0.0, 1.0)
        input_img = input_img.astype(np.float16)

        abs_filename = os.path.split(img_file)[1].split(".")[0]

        if FLAG_APPROACH_PREDICT == 0:
            print("[INFO] predict image by orignal approach\n")
            if use_onehot:
                result = orignal_predict_onehot(input_img, im_bands, model, window_size)
            else:
                result = orignal_predict_notonehot(input_img, im_bands, model, window_size)

            output_file = ''.join([output_path, '/', abs_filename, '.png'])
            print("result save as to: {}".format(output_file))
            cv2.imwrite(output_file, result * 128)

        elif FLAG_APPROACH_PREDICT == 1:
            print("[INFO] predict image by smooth approach\n")
            # The one-hot and plain variants differ only in the per-window predictor.
            pred_func = (smooth_predict_for_binary_onehot if use_onehot
                         else smooth_predict_for_binary_notonehot)
            result = predict_img_with_smooth_windowing_multiclassbands(
                input_img,
                model,
                window_size=window_size,
                subdivisions=2,
                # output channels = the real classes, i.e. total classes minus background
                real_classes=out_bands,
                pred_func=pred_func,
                PLOT_PROGRESS=False
            )

            # NOTE(review): no separator before 'smooth' (e.g. 'tile01smooth.png');
            # kept unchanged so existing downstream file names still match.
            output_file = ''.join([output_path, '/', abs_filename, 'smooth.png'])
            cv2.imwrite(output_file, result)
            print("Saved to: {}".format(output_file))

        gc.collect()

    return 0
                                                    result_channels)
            for key, val in multi_dict.items():
                output_file = output_mask + key + '.png'
                cv2.imwrite(
                    output_file,
                    result_test[:, :,
                                val - 1])  # achieve the integer automatically
    else:
        print("[INFO]sooth predict")
        if FLAG_USING_MODEL == 0:
            print("unet or fcnnet binary")
            predictions_smooth = predict_img_with_smooth_windowing_multiclassbands(
                input_img,
                model,
                window_size=window_size,
                subdivisions=2,
                real_classes=result_channels,  # output channels = 是真的类别,总类别-背景
                pred_func=smooth_predict_for_unet_binary
                # labelencoder=labelencoder
            )

            cv2.imwrite(output_mask, predictions_smooth)
        elif FLAG_USING_MODEL == 1:
            print("unet or fcn multiclass")
            predictions_smooth = predict_img_with_smooth_windowing_multiclassbands(
                input_img,
                model,
                window_size=window_size,
                subdivisions=2,
                real_classes=result_channels,  # output channels = 是真的类别,总类别-背景
                pred_func=smooth_predict_for_unet_multiclass