import digit_detector.show as show
import digit_detector.region_proposal as rp

# Settings for the training-patch extraction script below.
# Passed as n_files_to_sample; None presumably means "use every file".
N_IMAGES = None
DIR = '../datasets/svhn/train'
# Derived from DIR so the annotation path cannot drift away from the image
# directory it describes.
ANNOTATION_FILE = DIR + "/digitStruct.json"
# Overlap thresholds forwarded to Extractor.extract_patch.
NEG_OVERLAP_THD = 0.05
POS_OVERLAP_THD = 0.6
# (height, width) of extracted patches, as forwarded to extract_patch.
PATCH_SIZE = (32, 32)

if __name__ == "__main__":

    # NOTE(review): file_io, extractor_ and ann are not imported in this
    # file's header (only show and rp are) — TODO confirm they are brought
    # into scope, otherwise this block raises NameError.

    # 1. Load the list of *.png files directly inside DIR (non-recursive,
    #    listing order preserved).
    files = file_io.list_files(directory=DIR,
                               pattern="*.png",
                               recursive_option=False,
                               n_files_to_sample=N_IMAGES,
                               random_order=False)
    n_files = len(files)
    # The first 80% of the files are used for training; the remaining 20%
    # are held out for validation.
    n_train_files = int(n_files * 0.8)
    # Call form: prints the same single value under Python 2 and is also
    # valid Python 3 syntax (the bare `print x` statement is not).
    print(n_train_files)

    # 2. Extract training patches. Region proposals come from
    #    rp.MserRegionProposer; how the overlap thresholds label the
    #    patches is defined by Extractor.extract_patch.
    extractor = extractor_.Extractor(rp.MserRegionProposer(),
                                     ann.SvhnAnnotation(ANNOTATION_FILE),
                                     rp.OverlapCalculator())
    train_samples, train_labels = extractor.extract_patch(
        files[:n_train_files], PATCH_SIZE, POS_OVERLAP_THD, NEG_OVERLAP_THD)

    # 3. Extract patches from the held-out remainder for validation. A fresh
    #    Extractor is built rather than reusing the first, in case it keeps
    #    state between extract_patch calls.
    extractor = extractor_.Extractor(rp.MserRegionProposer(),
                                     ann.SvhnAnnotation(ANNOTATION_FILE),
                                     rp.OverlapCalculator())
    validation_samples, validation_labels = extractor.extract_patch(
        files[n_train_files:], PATCH_SIZE, POS_OVERLAP_THD, NEG_OVERLAP_THD)

    # NOTE(review): the extracted samples/labels are not saved or used in
    # this block — TODO confirm whether they should be persisted.
# Input image directory for the detection script. Rebinding DIR here
# replaces the training-set path assigned near the top of this module, so
# code from this point on reads './extra_examples'. Point it back at
# '../datasets/svhn/train' to run on the SVHN training images instead.
DIR = './extra_examples'

# Serialized CNN weights: one file for the detection stage and one for the
# recognition stage (paths relative to the working directory).
detect_model, recognize_model = "detector_model.hdf5", "recognize_model.hdf5"

# Mean values handed to each stage's preprocessor.
mean_value_for_detector, mean_value_for_recognizer = 107.524, 112.833

# Input shape shared by both classifiers.
model_input_shape = (32, 32, 1)

if __name__ == "__main__":
    # Set up the digit spotter: two cls.CnnClassifier instances (differing
    # only in model file and preprocessing mean) plus an MSER region
    # proposer, combined in detector.DigitSpotter.
    # NOTE(review): file_io, preproc, cls and detector are not imported in
    # this file's header (only show and rp are) — TODO confirm they are
    # brought into scope, otherwise this block raises NameError.

    # 1. Collect the input images: *.jpg files directly inside DIR
    #    (non-recursive, listing order preserved; n_files_to_sample=None
    #    presumably means no sub-sampling).
    img_files = file_io.list_files(directory=DIR,
                                   pattern="*.jpg",
                                   recursive_option=False,
                                   n_files_to_sample=None,
                                   random_order=False)

    # 2. One preproc.GrayImgPreprocessor per model, each given the mean
    #    value that belongs to that model.
    preproc_for_detector = preproc.GrayImgPreprocessor(mean_value_for_detector)
    preproc_for_recognizer = preproc.GrayImgPreprocessor(
        mean_value_for_recognizer)

    # 3. Load both classifiers from their .hdf5 model files; both use the
    #    same model_input_shape.
    char_detector = cls.CnnClassifier(detect_model, preproc_for_detector,
                                      model_input_shape)
    char_recognizer = cls.CnnClassifier(recognize_model,
                                        preproc_for_recognizer,
                                        model_input_shape)

    # 4. Compose the spotter from the two classifiers and a region proposer.
    # NOTE(review): digit_spotter is never invoked in the lines shown here;
    # if nothing after this uses it, the script performs no detection.
    digit_spotter = detector.DigitSpotter(char_detector, char_recognizer,
                                          rp.MserRegionProposer())