def test_prediction_extended_and_positions(self):
    """Predict with a stored checkpoint and verify text and character positions.

    Runs prediction with ``extended_prediction_data`` enabled, feeds the
    resulting JSON files to ``run_compute_avg_pred``, then loads the first
    JSON file (in sorted order) and checks every prediction's sentence and
    the global position interval of its first three and last two characters.
    """
    # Use an actual model checkpoint so the positions can be evaluated.
    checkpoint_path = os.path.join(
        this_dir, "models", f"version{SavedCalamariModel.VERSION}", "0.ckpt"
    )
    predict_params = predict_args()
    predict_params.checkpoint = [checkpoint_path]
    predict_params.extended_prediction_data = True
    run(predict_params)

    json_glob = os.path.join(this_dir, "data", "uw3_50lines", "test", "*.json")
    run_compute_avg_pred(ExtendedPredictionDataParams(files=[json_glob]))

    def assert_pos_in_interval(p, start, end):
        # Both ends of the position must lie inside [start, end].
        bounds = (p.global_start, p.global_end)
        for value in bounds:
            self.assertGreaterEqual(value, start)
        for value in bounds:
            self.assertLessEqual(value, end)

    # (index into positions, interval start, interval end)
    expected_intervals = (
        (0, 0, 24),  # T
        (1, 24, 43),  # h
        (2, 45, 63),  # e
        # ...
        (-2, 1062, 1081),  # a
        (-1, 1084, 1099),  # s
    )

    first_json = sorted(glob_all(json_glob))[0]
    with open(first_json) as handle:
        loaded: Predictions = Predictions.from_json(handle.read())

    for prediction in loaded.predictions:
        # Check for the correct prediction string (the model is trained!)
        self.assertEqual(
            prediction.sentence,
            "The problem, simplified for our purposes, is set up as",
        )
        # Check for correct character positions
        for index, start, end in expected_intervals:
            assert_pos_in_interval(prediction.positions[index], start, end)
def test_prediction_extended(self):
    """Run prediction with extended prediction data, then average it.

    Enables ``extended_prediction_data`` on the default ``predict_args()``
    and passes the JSON glob of the uw3_50lines test directory to
    ``run_compute_avg_pred``.
    """
    predict_params = predict_args()
    predict_params.extended_prediction_data = True
    run(predict_params)

    json_glob = os.path.join(this_dir, "data", "uw3_50lines", "test", "*.json")
    avg_params = ExtendedPredictionDataParams(files=[json_glob])
    run_compute_avg_pred(avg_params)
def test_prediction_abbyy(self):
    """Run prediction on an ``Abbyy`` dataset built from the sample JPG glob."""
    image_glob = os.path.join(
        this_dir,
        "data",
        "hiltl_die_bank_des_verderbens_abbyyxml",
        "*.jpg",
    )
    dataset = Abbyy(images=[image_glob])
    run(predict_args(data=dataset))
def test_prediction_extended_pagexml_with_voting(self):
    """Run extended prediction on the PageXML dataset with two checkpoints.

    The same stored checkpoint is listed twice, ``extended_prediction_data``
    is enabled, and the JSON glob is passed on to ``run_compute_avg_pred``.
    """
    # Use an actual model checkpoint so the positions can be evaluated.
    checkpoint_path = os.path.join(
        this_dir, "models", f"version{SavedCalamariModel.VERSION}", "0.ckpt"
    )
    predict_params = predict_args(data=pagexml_dataset())
    predict_params.checkpoint = [checkpoint_path, checkpoint_path]
    predict_params.extended_prediction_data = True
    run(predict_params)

    # NOTE(review): this glob targets the uw3_50lines test directory although
    # the input data comes from pagexml_dataset() - confirm that is where the
    # extended prediction JSON files are expected.
    json_glob = os.path.join(this_dir, "data", "uw3_50lines", "test", "*.json")
    avg_params = ExtendedPredictionDataParams(files=[json_glob])
    run_compute_avg_pred(avg_params)
def test_prediction_voter(self):
    """Run prediction with an unmodified ``PredictionAttrs`` instance."""
    run(PredictionAttrs())
def test_prediction(self):
    """Run prediction using only the first checkpoint of ``PredictionAttrs``."""
    attrs = PredictionAttrs()
    # Keep a one-element slice so ``checkpoint`` stays a sequence.
    first_only = attrs.checkpoint[:1]
    attrs.checkpoint = first_only
    run(attrs)
def test_prediction_files(self):
    """Run prediction with the default ``predict_args()`` configuration."""
    default_params = predict_args()
    run(default_params)
def test_prediction_hdf5(self):
    """Run prediction on an ``Hdf5`` dataset built from the uw3-50lines file."""
    h5_path = os.path.join(this_dir, "data", "uw3_50lines", "uw3-50lines.h5")
    dataset = Hdf5(files=[h5_path])
    run(predict_args(data=dataset))
def test_prediction_pagexml(self):
    """Run prediction on a ``PageXML`` dataset with the single sample image."""
    image_path = os.path.join(this_dir, "data", "avicanon_pagexml", "008.nrm.png")
    dataset = PageXML(images=[image_path])
    run(predict_args(data=dataset))
def test_prediction_voter_files(self):
    """Run prediction with ``predict_args`` configured for ``n_models=3``."""
    voter_params = predict_args(n_models=3)
    run(voter_params)
def test_prediction_voter(self):
    """Run prediction with the same checkpoint listed twice.

    The doubled checkpoint list is assigned to ``args`` so that ``run``
    actually receives it. Assigning it to ``self`` (the test case) would
    leave ``args`` untouched and the override would never reach ``run``.
    """
    args = PredictionAttrs()
    checkpoint_path = os.path.join(this_dir, "test_models", "uw3_50lines_best.ckpt")
    # Two entries of the same checkpoint are handed to the prediction run.
    args.checkpoint = [checkpoint_path] * 2
    run(args)
def test_prediction(self):
    """Run prediction with an unmodified ``PredictionAttrs`` instance."""
    run(PredictionAttrs())