def testEvaluationMetric(self):
    features_file = os.path.join(self.get_temp_dir(), "features.txt")
    labels_file = os.path.join(self.get_temp_dir(), "labels.txt")
    model_dir = self.get_temp_dir()
    with open(features_file, "w") as features, open(labels_file, "w") as labels:
        features.write("1\n2\n")
        labels.write("1\n2\n")
    model = TestModel({"a": [2, 5, 8], "b": [3, 6, 9]})
    model.initialize({})
    early_stopping = evaluation.EarlyStopping(metric="loss", min_improvement=0, steps=1)
    evaluator = evaluation.Evaluator(
        model,
        features_file,
        labels_file,
        batch_size=1,
        early_stopping=early_stopping,
        model_dir=model_dir,
        export_on_best="loss",
        exporter=TestExporter(),
    )
    self.assertSetEqual(evaluator.metrics_name, {"loss", "perplexity", "a", "b"})
    model.next_loss.assign(1)
    metrics_5 = evaluator(5)
    self._assertMetricsEqual(
        metrics_5, {"loss": 1.0, "perplexity": math.exp(1.0), "a": 2, "b": 3}
    )
    self.assertFalse(evaluator.should_stop())
    self.assertTrue(evaluator.is_best("loss"))
    self.assertTrue(os.path.isdir(os.path.join(evaluator.export_dir, str(5))))
    model.next_loss.assign(4)
    metrics_10 = evaluator(10)
    self._assertMetricsEqual(
        metrics_10, {"loss": 4.0, "perplexity": math.exp(4.0), "a": 5, "b": 6}
    )
    self.assertTrue(evaluator.should_stop())
    self.assertFalse(evaluator.is_best("loss"))
    self.assertFalse(os.path.isdir(os.path.join(evaluator.export_dir, str(10))))
    self.assertLen(evaluator.metrics_history, 2)
    self._assertMetricsEqual(evaluator.metrics_history[0][1], metrics_5)
    self._assertMetricsEqual(evaluator.metrics_history[1][1], metrics_10)

    # Recreating the evaluator should load the metrics history from the eval directory.
    evaluator = evaluation.Evaluator(
        model,
        features_file,
        labels_file,
        batch_size=1,
        model_dir=model_dir,
        export_on_best="loss",
        exporter=TestExporter(),
    )
    self.assertLen(evaluator.metrics_history, 2)
    self._assertMetricsEqual(evaluator.metrics_history[0][1], metrics_5)
    self._assertMetricsEqual(evaluator.metrics_history[1][1], metrics_10)

    # Evaluating previous steps should clear future steps in the history.
    model.next_loss.assign(7)
    self._assertMetricsEqual(
        evaluator(7), {"loss": 7.0, "perplexity": math.exp(7.0), "a": 8, "b": 9}
    )
    self.assertFalse(evaluator.is_best("loss"))
    self.assertFalse(os.path.isdir(os.path.join(evaluator.export_dir, str(10))))
    recorded_steps = [step for step, _ in evaluator.metrics_history]
    self.assertListEqual(recorded_steps, [5, 7])
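
# Sketch (an illustration, not part of the test suite): the pattern exercised
# above is how an evaluator is typically wired into a training loop. The
# loop names below (training_steps, train_one_step, eval_frequency) are
# hypothetical; the Evaluator/EarlyStopping calls match the ones in this test.
#
#   early_stopping = evaluation.EarlyStopping(metric="loss", min_improvement=0.01, steps=4)
#   evaluator = evaluation.Evaluator(
#       model, features_file, labels_file, batch_size=32,
#       early_stopping=early_stopping, model_dir=model_dir, export_on_best="loss")
#   for step in training_steps:        # hypothetical training loop
#       train_one_step(model)          # hypothetical helper
#       if step % eval_frequency == 0:
#           evaluator(step)
#           if evaluator.should_stop():
#               break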
def testExportsGarbageCollection(self):
    features_file = os.path.join(self.get_temp_dir(), "features.txt")
    labels_file = os.path.join(self.get_temp_dir(), "labels.txt")
    model_dir = self.get_temp_dir()
    with open(features_file, "w") as features, open(labels_file, "w") as labels:
        features.write("1\n2\n")
        labels.write("1\n2\n")
    model = TestModel()  # Was _TestModel(), which is undefined; the other tests use TestModel.
    exporter = TestExporter()
    evaluator = evaluation.Evaluator(
        model,
        features_file,
        labels_file,
        batch_size=1,
        model_dir=model_dir,
        export_on_best="loss",
        exporter=exporter,
        max_exports_to_keep=2,
    )

    # Generate some pre-existing exports.
    for step in (5, 10, 15):
        exporter.export(model, os.path.join(evaluator.export_dir, str(step)))

    def _eval_step(step, loss, expected_exported_steps):
        model.next_loss.assign(loss)
        evaluator(step)
        exported_steps = sorted(map(int, os.listdir(evaluator.export_dir)))
        self.assertListEqual(exported_steps, expected_exported_steps)

    _eval_step(20, 3, [15, 20])  # Exports 5 and 10 should be removed.
    _eval_step(25, 2, [20, 25])  # Export 15 should be removed.
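
# Note (an inference from the expected lists above, not a documented contract):
# with max_exports_to_keep=2 the evaluator keeps only the two most recent
# export directories, pruning older ones after each new export:
#
#   {5, 10, 15} --export 20--> {15, 20} --export 25--> {20, 25}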
def testEvaluationWithRougeScorer(self):
    features_file = os.path.join(self.get_temp_dir(), "features.txt")
    labels_file = os.path.join(self.get_temp_dir(), "labels.txt")
    model_dir = self.get_temp_dir()
    with open(features_file, "w") as features, open(labels_file, "w") as labels:
        features.write("1\n2\n")
        labels.write("1\n2\n")
    model = TestModel()
    evaluator = evaluation.Evaluator(
        model,
        features_file,
        labels_file,
        batch_size=1,
        scorers=[scorers.ROUGEScorer()],
        model_dir=model_dir,
    )
    self.assertNotIn("rouge", evaluator.metrics_name)
    self.assertIn("rouge-1", evaluator.metrics_name)
    self.assertIn("rouge-2", evaluator.metrics_name)
    self.assertIn("rouge-l", evaluator.metrics_name)
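
# Note: as the assertions above show, a scorer can expand into several entries
# in metrics_name; ROUGEScorer reports "rouge-1", "rouge-2", and "rouge-l"
# rather than a single aggregate "rouge" key. A sketch of attaching several
# scorers at once (assuming BLEUScorer lives in the same scorers module):
#
#   evaluator = evaluation.Evaluator(
#       model, features_file, labels_file, batch_size=1,
#       scorers=[scorers.ROUGEScorer(), scorers.BLEUScorer()],
#       model_dir=model_dir)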