Exemplo n.º 1
0
def test_train_and_compare():
    """Train a v4 model, load its weights into v5, and check both agree.

    Trains ``anmccv4`` on ``dataset_dir``, saves the resulting weights to
    ``<artifact_dir>/weights.h5``, builds an ``anmccv5.ComponentModelCTC_v5``
    from that file, then feeds both predictors the same random batches and
    asserts that every output array is exactly equal.
    """
    weights_file = artifact_dir + "/weights.h5"
    print(weights_file)
    model_v4, test_func_v4 = anmccv4.train_model(dataset_dir)
    # An already-existing directory is fine; any other failure (permissions,
    # a regular file in the way, ...) must surface instead of being swallowed.
    try:
        os.makedirs(artifact_dir)
    except OSError:
        if not os.path.isdir(artifact_dir):
            raise

    model_v4.save_weights(weights_file)
    model_v5 = anmccv5.ComponentModelCTC_v5(weights_file)
    test_func_v5 = model_v5._predictor
    size = 32
    tests = 20
    for _ in range(tests):
        # Random batch of shape (size, 1, img_h, img_w), using the v5 model's
        # image dimensions for both predictors.
        X_data = np.random.ranf([size, 1, model_v5.img_h, model_v5.img_w])
        v4_out = test_func_v4([X_data])
        v5_out = test_func_v5([X_data])
        assert len(v4_out) == len(v5_out), (
            "Output lengths differ: \nv4: {}\nv5: {}".format(
                len(v4_out), len(v5_out)))
        print("V4:\n{}".format(v4_out))
        print("V5:\n{}".format(v5_out))
        for idx, (v4v, v5v) in enumerate(zip(v4_out, v5_out)):
            # array_equal also requires identical shapes; np.all(a == b)
            # could pass on broadcast-compatible but different shapes.
            assert np.array_equal(v4v, v5v), (
                "Output {} differs between v4 and v5".format(idx))
Exemplo n.º 2
0
def test_train_for_a_single_word():
    """Train a fresh v5 model, round-trip its weights, and classify a test set.

    Trains ``ComponentModelCTC_v5`` on ``dataset_dir``, saves the weights to
    ``/tmp/small-classifier.h5`` and reloads them into a new model. It then
    yields one ``(check_cls, expected, predicted)`` case per image under the
    small test dataset, in the generator-test (nose) style. The expected
    class is the directory name right after the dataset root, as in
    ``<root>/<class>/cc*png``.
    """
    model = anmccv5.ComponentModelCTC_v5(None)
    model.train_model(dataset_dir)
    save_file = "/tmp/small-classifier.h5"
    model.save_weights(save_file)
    model_reloaded = anmccv5.ComponentModelCTC_v5(save_file)
    image_files_dir = os.path.expandvars(
        "$HOME/Annex/Arabic/arabic-component-dataset-test-small/")
    image_files = aui.recursive_file_list(image_files_dir)
    assert len(image_files) > 0
    # Escape the root so characters in the expanded path (e.g. '.', '+')
    # are matched literally instead of being interpreted as regex syntax.
    class_re = re.compile(re.escape(image_files_dir) + r'([a-z]+)/cc.*png')
    for image_file in image_files:
        assert os.path.exists(image_file)
        # Flag 0 loads the image as single-channel grayscale.
        img = cv2.imread(image_file, 0)
        assert img is not None, "cv2 could not read {}".format(image_file)
        cls_reloaded = model_reloaded.classify_image(img)
        mm = class_re.match(image_file)
        assert mm is not None, (
            "Cannot derive class from path {}".format(image_file))
        correct_cls = mm.group(1)
        yield check_cls, correct_cls, cls_reloaded
Exemplo n.º 3
0
def test_train_and_classify():
    """Classify a random sample of dataset images with a cached or fresh model.

    Loads ``<artifact_dir>/weights.h5`` if it exists. Otherwise it trains a
    new ``ComponentModelCTC_v5`` for 100 epochs and saves the weights there.
    It then yields one ``(check_cls_among, expected, predicted)`` case for
    each of up to 200 distinct images sampled from ``dataset_dir``. The
    expected class is the directory name right after ``dataset_dir``.
    Judging by its name, ``check_cls_among`` presumably checks membership
    in a set of candidates (TODO: confirm).
    """
    current_weights_file = artifact_dir + "/weights.h5"
    if os.path.exists(current_weights_file):
        model = anmccv5.ComponentModelCTC_v5(current_weights_file)
    else:
        model = anmccv5.ComponentModelCTC_v5(None)
        model.train_model(dataset_dir=dataset_dir, nb_epoch=100)
        # makedirs (not mkdir) so missing parent directories are created too.
        if not os.path.isdir(artifact_dir):
            os.makedirs(artifact_dir)
        model.save_weights(current_weights_file)

    image_files_dir = dataset_dir
    image_files = aui.recursive_file_list(image_files_dir)
    # Check before sampling: np.random.choice raises ValueError on an empty
    # population, which would mask this clearer assertion.
    assert len(image_files) > 0
    # Up to 200 *distinct* files; the default replace=True could classify
    # the same image several times.
    image_files = np.random.choice(
        image_files, min(200, len(image_files)), replace=False)
    # Escape the root so path characters are matched literally.
    class_re = re.compile(re.escape(image_files_dir) + r'([a-z]+)/.*png')
    for image_file in image_files:
        assert os.path.exists(image_file)
        # Flag 0 loads the image as single-channel grayscale.
        img = cv2.imread(image_file, 0)
        assert img is not None, "cv2 could not read {}".format(image_file)
        cls = model.classify_image(img)
        mm = class_re.match(image_file)
        assert mm is not None, (
            "Cannot derive class from path {}".format(image_file))
        correct_cls = mm.group(1)
        yield check_cls_among, correct_cls, cls
Exemplo n.º 4
0
def test_classify_image_with_pretrained():
    """Classify a random sample of the 4170-component dataset with saved weights.

    Builds ``ComponentModelCTC_v5`` from the module-level ``weights_file``,
    which is defined outside this function. It then yields one
    ``(check_cls, expected, predicted)`` case for each of up to 200 distinct
    images. The expected class is the directory name right after the dataset
    root, as in ``<root>/<class>/cc*png``.
    """
    model = anmccv5.ComponentModelCTC_v5(weights_file)
    image_files_dir = os.path.expandvars(
        "$HOME/Annex/Arabic/Avians/arabic-component-dataset-4170/")
    image_files = aui.recursive_file_list(image_files_dir)
    # Check before sampling: np.random.choice raises ValueError on an empty
    # population, which would mask this clearer assertion.
    assert len(image_files) > 0
    # Up to 200 *distinct* files; the default replace=True could classify
    # the same image several times.
    image_files = np.random.choice(
        image_files, min(200, len(image_files)), replace=False)
    # Escape the root so characters in the expanded path are matched literally.
    class_re = re.compile(re.escape(image_files_dir) + r'([a-z]+)/cc.*png')
    for image_file in image_files:
        assert os.path.exists(image_file)
        # Flag 0 loads the image as single-channel grayscale.
        img = cv2.imread(image_file, 0)
        assert img is not None, "cv2 could not read {}".format(image_file)
        cls = model.classify_image(img)
        mm = class_re.match(image_file)
        assert mm is not None, (
            "Cannot derive class from path {}".format(image_file))
        correct_cls = mm.group(1)
        yield check_cls, correct_cls, cls
Exemplo n.º 5
0
def test_init():
    """Smoke test: a v5 model can be constructed from ``weights_file``."""
    assert anmccv5.ComponentModelCTC_v5(weights_file) is not None