Esempio n. 1
0
    def test_prediction_extended_and_positions(self):
        """Predict with a trained model and check the extended output.

        Runs prediction with ``extended_prediction_data`` enabled, averages
        the written JSON files, then checks every prediction stored in the
        first JSON file. For each prediction it verifies the recognized
        sentence and the global character positions of the first three and
        last two characters.
        """
        # With actual model to evaluate correct positions
        args = predict_args()
        args.checkpoint = [
            os.path.join(this_dir, "models",
                         f"version{SavedCalamariModel.VERSION}", "0.ckpt")
        ]
        args.extended_prediction_data = True
        run(args)
        jsons = [
            os.path.join(this_dir, "data", "uw3_50lines", "test", "*.json")
        ]
        run_compute_avg_pred(ExtendedPredictionDataParams(files=jsons))

        def assert_pos_in_interval(p, start, end):
            # Both ends of the character's span must lie inside [start, end].
            self.assertGreaterEqual(p.global_start, start)
            self.assertGreaterEqual(p.global_end, start)
            self.assertLessEqual(p.global_start, end)
            self.assertLessEqual(p.global_end, end)

        json_files = sorted(glob_all(jsons[0]))
        # Report a clear assertion failure instead of an IndexError when no
        # extended prediction JSON was produced.
        self.assertTrue(json_files,
                        "no extended prediction JSON files were written")
        with open(json_files[0]) as f:
            first_pred: Predictions = Predictions.from_json(f.read())

        # Without this check the loop below would pass vacuously on an
        # empty prediction list.
        self.assertTrue(first_pred.predictions,
                        "extended prediction JSON contains no predictions")
        for p in first_pred.predictions:
            # Check for correct prediction string (the model is trained!)
            self.assertEqual(
                p.sentence,
                "The problem, simplified for our purposes, is set up as")
            # Check for correct character positions
            assert_pos_in_interval(p.positions[0], 0, 24)  # T
            assert_pos_in_interval(p.positions[1], 24, 43)  # h
            assert_pos_in_interval(p.positions[2], 45, 63)  # e
            # ...
            assert_pos_in_interval(p.positions[-2], 1062, 1081)  # a
            assert_pos_in_interval(p.positions[-1], 1084, 1099)  # s
Esempio n. 2
0
 def test_prediction_extended(self):
     """Predict with extended output enabled, then average the JSON files written for the uw3_50lines test set."""
     params = predict_args()
     params.extended_prediction_data = True
     run(params)
     json_pattern = os.path.join(this_dir, "data", "uw3_50lines", "test",
                                 "*.json")
     run_compute_avg_pred(ExtendedPredictionDataParams(files=[json_pattern]))
Esempio n. 3
0
 def test_prediction_abbyy(self):
     """Smoke-test prediction on the Abbyy-XML sample images (no assertions)."""
     image_glob = os.path.join(this_dir, "data",
                               "hiltl_die_bank_des_verderbens_abbyyxml",
                               "*.jpg")
     dataset = Abbyy(images=[image_glob])
     run(predict_args(data=dataset))
Esempio n. 4
0
 def test_prediction_extended_pagexml_with_voting(self):
     """Predict PageXML data with two copies of the trained model and extended output, then average JSON predictions."""
     # With actual model to evaluate correct positions
     args = predict_args(data=pagexml_dataset())
     model_ckpt = os.path.join(this_dir, "models",
                               f"version{SavedCalamariModel.VERSION}",
                               "0.ckpt")
     # Same checkpoint listed twice -> two models (the voting path, per the
     # test name).
     args.checkpoint = [model_ckpt, model_ckpt]
     args.extended_prediction_data = True
     run(args)
     # NOTE(review): this glob targets uw3_50lines/test, not the PageXML data
     # predicted above; it may pick up JSON written by another test. Confirm
     # this is intended.
     json_pattern = os.path.join(this_dir, "data", "uw3_50lines", "test",
                                 "*.json")
     run_compute_avg_pred(ExtendedPredictionDataParams(files=[json_pattern]))
Esempio n. 5
0
 def test_prediction_voter(self):
     """Run prediction with the unmodified default PredictionAttrs configuration."""
     # Whether this votes depends on how many checkpoints PredictionAttrs
     # provides by default — not visible here.
     run(PredictionAttrs())
Esempio n. 6
0
 def test_prediction(self):
     """Run prediction using only the first default checkpoint."""
     attrs = PredictionAttrs()
     # Truncate to at most one checkpoint so a single model is used.
     attrs.checkpoint = attrs.checkpoint[:1]
     run(attrs)
Esempio n. 7
0
 def test_prediction_files(self):
     """Smoke-test prediction with the default predict_args() settings."""
     args = predict_args()
     run(args)
Esempio n. 8
0
 def test_prediction_hdf5(self):
     """Smoke-test prediction on the uw3-50lines HDF5 file (no assertions)."""
     h5_path = os.path.join(this_dir, "data", "uw3_50lines", "uw3-50lines.h5")
     dataset = Hdf5(files=[h5_path])
     run(predict_args(data=dataset))
Esempio n. 9
0
 def test_prediction_pagexml(self):
     """Smoke-test prediction on a single PageXML sample image (no assertions)."""
     image_path = os.path.join(this_dir, "data", "avicanon_pagexml",
                               "008.nrm.png")
     dataset = PageXML(images=[image_path])
     run(predict_args(data=dataset))
Esempio n. 10
0
 def test_prediction_voter_files(self):
     """Run prediction with predict_args configured for three models."""
     args = predict_args(n_models=3)
     run(args)
Esempio n. 11
0
 def test_prediction_voter(self):
     """Run prediction with the same best checkpoint loaded twice to exercise voting."""
     args = PredictionAttrs()
     # Must be set on ``args`` (not ``self``): ``run`` only receives ``args``.
     # Assigning to ``self.checkpoint`` left the PredictionAttrs default
     # checkpoints in place, so the two-model setup was never tested.
     args.checkpoint = [
         os.path.join(this_dir, "test_models", "uw3_50lines_best.ckpt")
     ] * 2
     run(args)
Esempio n. 12
0
 def test_prediction(self):
     """Smoke-test prediction with the default PredictionAttrs configuration."""
     run(PredictionAttrs())