def test_score_estimator_cpu(self):
    """Run the ``score`` CLI on CPU with the large DA estimator.

    Checks that the command exits cleanly, that it reports and writes the
    JSON predictions file, that every sample carries a ``predicted_score``,
    and that the system score is printed.
    """
    estimator = "wmt-large-da-estimator-1719"
    results_path = DATA_PATH + "results.json"
    download_model(estimator, DATA_PATH)

    cli_args = [
        "--model", DATA_PATH + estimator + "/_ckpt_epoch_1.ckpt",
        "-s", DATA_PATH + "src.en",
        "-h", DATA_PATH + "mt.de",
        "-r", DATA_PATH + "ref.de",
        "--to_json", results_path,
        "--cpu",
        "--batch_size", "32",
    ]
    outcome = self.runner.invoke(score, cli_args, catch_exceptions=False)
    self.assertEqual(outcome.exit_code, 0)

    saved_msg = "Predictions saved in: {}".format(results_path)
    self.assertIn(saved_msg, outcome.stdout)
    self.assertTrue(os.path.exists(results_path))

    with open(results_path) as handle:
        predictions = json.load(handle)
    for sample in predictions:
        self.assertIn("predicted_score", sample)

    self.assertIn("COMET system score: ", outcome.stdout)
def test_download_model(self, mock_stdout):
    """Download the base DA estimator, then confirm a repeat call is cached.

    ``mock_stdout`` presumably comes from a stdout-patching decorator outside
    this view (NOTE(review): confirm); its captured text is inspected after
    each ``download_model`` call.
    """
    estimator = "wmt-base-da-estimator-1719"

    loaded = download_model(estimator, DATA_PATH)
    self.assertIsInstance(loaded, CometEstimator)
    self.assertIn("Download succeeded. Loading model...", mock_stdout.getvalue())

    # Second call for the same name should report a cache hit.
    download_model(estimator, DATA_PATH)
    self.assertIn("is already in cache.", mock_stdout.getvalue())
def test_score_ranker_cpu(self):
    """Run the ``score`` CLI on CPU with the EMNLP base DA ranker."""
    ranker = "emnlp-base-da-ranker"
    download_model(ranker, DATA_PATH)

    checkpoint = DATA_PATH + ranker + "/_ckpt_epoch_0.ckpt"
    cli_args = ["--model", checkpoint]
    cli_args += ["-s", DATA_PATH + "src.en"]
    cli_args += ["-h", DATA_PATH + "mt.de"]
    cli_args += ["-r", DATA_PATH + "ref.de"]
    cli_args.append("--cpu")

    outcome = self.runner.invoke(score, cli_args, catch_exceptions=False)
    self.assertEqual(outcome.exit_code, 0)
    self.assertIn("COMET system score: ", outcome.stdout)
def test_score_estimator_cpu(self):
    """Run the ``score`` CLI on CPU and check the JSON-saved message.

    NOTE(review): another method named ``test_score_estimator_cpu`` appears
    in this file; if both live in the same TestCase, only the later
    definition is collected by the test runner.
    """
    download_model("wmt-large-da-estimator-1719", DATA_PATH)
    args = [
        "--model", DATA_PATH + "wmt-large-da-estimator-1719/_ckpt_epoch_1.ckpt",
        "-s", DATA_PATH + "src.en",
        "-h", DATA_PATH + "mt.de",
        "-r", DATA_PATH + "ref.de",
        # The trailing comma is required: without it Python implicitly
        # concatenates the adjacent literals into ".../results.json--cpu",
        # so the JSON path is wrong and --cpu is never passed.
        "--to_json", DATA_PATH + "results.json",
        "--cpu",
    ]
    result = self.runner.invoke(score, args, catch_exceptions=False)
    self.assertEqual(result.exit_code, 0)
    self.assertIn(
        "Predictions saved in: {}".format(DATA_PATH + "results.json"),
        result.stdout,
    )
    self.assertIn("COMET system score: ", result.stdout)
def comet(l2, sources, outputs, references, LOG):
    """Score MT outputs with the COMET ``wmt-large-da-estimator-1719`` model.

    Args:
        l2: Not used by this function; kept for call compatibility.
        sources: Source-side segments.
        outputs: System (MT) segments, positionally aligned with ``sources``.
        references: Reference segments, positionally aligned with ``sources``.
            Segments are paired with ``zip``, so extra items in a longer
            list are ignored.
        LOG: Path of a log file; the per-segment scores and their mean are
            appended to it.

    Returns:
        The mean of the per-segment scores. (Earlier versions computed this
        value but returned ``None``.)
    """
    import numpy as np
    from comet.models import download_model

    model = download_model("wmt-large-da-estimator-1719", "comet_models/")

    data = {"src": sources, "mt": outputs, "ref": references}
    # Preview the first few aligned segments for debugging.
    print(data['src'][:10])
    print(data['mt'][:10])
    print(data['ref'][:10])

    # Transpose the column dict into one {"src", "mt", "ref"} dict per segment.
    samples = [dict(zip(data, t)) for t in zip(*data.values())]
    # Element 1 of predict()'s return value holds the per-segment scores.
    segment_scores = model.predict(samples, cuda=False, show_progress=True)[1]
    comet_score = np.mean(segment_scores)

    scores_line = f"All comment values: {segment_scores}\n"
    summary_line = f"COMET score: {comet_score}\n"
    print(scores_line)
    print(summary_line)
    with open(LOG, 'a', encoding="utf-8") as op:
        op.write(scores_line)
        op.write(summary_line)
    return comet_score
def score(model, source, hypothesis, reference, cuda, to_json):
    """Score hypotheses with a COMET model and report the system score.

    Args:
        model: Path to a checkpoint on disk, or a model name passed to
            ``download_model`` when no such path exists.
        source, hypothesis, reference: Open text files, one segment per line;
            lines are paired positionally (``zip`` stops at the shortest).
        cuda: Forwarded to ``model.predict``.
        to_json: When a ``str``, the predictions are dumped to this path as
            UTF-8 JSON.

    Raises:
        click.ClickException: If the inputs contain no segments to score
            (previously this surfaced as a ZeroDivisionError).
    """
    source = [s.strip() for s in source.readlines()]
    hypothesis = [s.strip() for s in hypothesis.readlines()]
    reference = [s.strip() for s in reference.readlines()]
    data = {"src": source, "mt": hypothesis, "ref": reference}
    # One {"src", "mt", "ref"} dict per aligned segment.
    data = [dict(zip(data, t)) for t in zip(*data.values())]
    if not data:
        raise click.ClickException("No segments to score: input files are empty.")

    model = load_checkpoint(model) if os.path.exists(model) else download_model(model)
    data, scores = model.predict(data, cuda, show_progress=True)
    if not scores:
        raise click.ClickException("The model returned no segment scores.")

    if isinstance(to_json, str):
        # ensure_ascii=False emits raw non-ASCII characters, so the file must
        # be UTF-8 rather than the platform's locale encoding.
        with open(to_json, "w", encoding="utf-8") as outfile:
            json.dump(data, outfile, ensure_ascii=False, indent=4)
        click.secho(f"Predictions saved in: {to_json}.", fg="yellow")

    click.secho("COMET system score: {}.".format(sum(scores) / len(scores)), fg="yellow")
def download(data, model, saving_path):
    """Fetch each named corpus, then each named model, into ``saving_path``.

    Args:
        data: Iterable of corpus names passed to ``download_corpus``.
        model: Iterable of model names passed to ``download_model``.
        saving_path: Destination directory handed to both downloaders.
    """
    for corpus_name in data:
        download_corpus(corpus_name, saving_path)

    for model_name in model:
        download_model(model_name, saving_path)
def test_download_wrong_model(self):
    """An unknown model name must raise with a descriptive message."""
    with self.assertRaises(Exception) as context:
        download_model("WrongModel", DATA_PATH)
    # The message check must sit after the with-block: anything placed after
    # the raising call inside the block is never executed.
    self.assertEqual(
        str(context.exception), "WrongModel is not a valid COMET model!"
    )
def _download_and_prepare(self, dl_manager):
    """Load the COMET scorer selected by ``self.config_name``.

    The ``"default"`` config maps to ``wmt-large-da-estimator-1719``; any
    other config name is used directly as the model name. ``dl_manager`` is
    not used here.
    """
    model_name = (
        "wmt-large-da-estimator-1719"
        if self.config_name == "default"
        else self.config_name
    )
    self.scorer = download_model(model_name)
def __init__(self, model_name="wmt-large-da-estimator-1719"):
    """Record CUDA availability and load the named COMET model.

    Args:
        model_name: Name passed to ``comet.models.download_model``.
    """
    # Imported lazily so the heavy dependencies load only on construction.
    import torch
    from comet.models import download_model

    has_gpu = torch.cuda.is_available()
    self.cuda = has_gpu
    self.model = download_model(model_name)