def test_fixed_eval_sees_the_same_evals(self, mock_get_dataset,
                                        mock_checkpoints_iterator):
  dataset = data_providers_test.make_golden_dataset()
  n_checkpoints = 3
  checkpoints = [
      self._write_fake_checkpoint('constant', name='model' + str(i))
      for i in range(n_checkpoints)
  ]

  # Set up our mocks.
  mock_checkpoints_iterator.return_value = checkpoints
  mock_get_dataset.return_value = dataset

  # Start up eval, loading that checkpoint.
  FLAGS.batch_size = 2
  FLAGS.checkpoint_dir = self.checkpoint_dir
  FLAGS.eval_dir = tf.test.get_temp_dir()
  FLAGS.max_evaluations = n_checkpoints
  FLAGS.model_name = 'constant'
  FLAGS.dataset_config_pbtxt = '/path/to/mock.pbtxt'
  model_eval.main(0)

  # The dataset should have been loaded once per evaluation.
  self.assertEqual(mock_get_dataset.call_args_list,
                   [mock.call(FLAGS.dataset_config_pbtxt)] * n_checkpoints)

  metrics = [
      model_eval.read_metrics(checkpoint, eval_dir=FLAGS.eval_dir)
      for checkpoint in checkpoints
  ]

  # Check that our metrics are what we expect them to be.
  # See b/62864044 for details on how to compute these counts.
  # Counts of labels in our golden dataset:
  #   count  label
  #       1      0
  #      12      1
  #      35      2
  expected_values_for_all_exact = {
      # We have 12 correct calls (the 12 variants with a label of 1) and 36
      # incorrect ones (1 with label 0 + 35 with label 2), so our accuracy is
      # 12 / 48 = 0.25.
      'Accuracy/All': 0.25,
      # We don't have any FNs because we call everything het.
      'FNs/All': 0,
      # One of our labels is 0, which we call het, giving us 1 FP.
      'FPs/All': 1.0,
      # We call everything het, so the recall has to be 1.
      'Recall/All': 1.0,
      # redacted
      # # We don't call anything but hets, so TNs has to be 0.
      # 'TNs/All': 0,
      # We find all 47 positives (12 + 35), so TPs has to be 47.
      'TPs/All': 47,
  }
  for key, expected_value in expected_values_for_all_exact.items():
    self.assertEqual(metrics[0][key], expected_value)

  expected_values_for_all_close = {
      # Of our 48 positive calls, 47 are true variants, so precision is
      # 47 / 48.
      'Precision/All': 47. / 48,
  }
  for key, expected_value in expected_values_for_all_close.items():
    self.assertAlmostEqual(metrics[0][key], expected_value, places=6)

  # Every checkpoint sees the same fixed eval set, so all metrics must match.
  for m1, m2 in zip(metrics, metrics[1:]):
    self.assertEqual(m1, m2)
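# A minimal standalone sketch (not part of the test) of the arithmetic behind
# the expected values above, assuming a caller that predicts het (class 1) for
# every example and the golden-set label counts {0: 1, 1: 12, 2: 35}. The
# helper name `expected_confusion_metrics` is illustrative, not a real API.
def expected_confusion_metrics(label_counts):
  """Returns the metrics an everything-is-het caller should produce."""
  total = sum(label_counts.values())         # 48 examples in the golden set.
  positives = total - label_counts[0]        # 47 true variants (labels 1, 2).
  tps = positives                            # Calling het still detects them.
  fps = label_counts[0]                      # The lone label-0 example.
  fns = 0                                    # Nothing is called reference.
  accuracy = label_counts[1] / float(total)  # Only exact label-1 matches.
  precision = tps / float(tps + fps)
  recall = tps / float(tps + fns)
  return dict(tps=tps, fps=fps, fns=fns, accuracy=accuracy,
              precision=precision, recall=recall)

# expected_confusion_metrics({0: 1, 1: 12, 2: 35}) yields accuracy=0.25,
# tps=47, fps=1, fns=0, recall=1.0, and precision=47/48, matching the
# assertions above.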
def test_fixed_eval_sees_the_same_evals(self, mock_get_input_fn_from_dataset,
                                        mock_checkpoints_iterator):
  dataset = data_providers_test.make_golden_dataset(use_tpu=FLAGS.use_tpu)
  n_checkpoints = 3
  checkpoints = [
      tf_test_utils.write_fake_checkpoint(
          'constant',
          self.test_session(),
          self.checkpoint_dir,
          FLAGS.moving_average_decay,
          name='model' + str(i)) for i in range(n_checkpoints)
  ]

  # Set up our mocks.
  mock_checkpoints_iterator.return_value = checkpoints
  mock_get_input_fn_from_dataset.return_value = dataset

  # Start up eval, loading that checkpoint.
  FLAGS.batch_size = 2
  FLAGS.checkpoint_dir = self.checkpoint_dir
  FLAGS.eval_name = self.eval_name
  FLAGS.max_evaluations = n_checkpoints
  FLAGS.model_name = 'constant'
  FLAGS.dataset_config_pbtxt = '/path/to/mock.pbtxt'
  FLAGS.master = ''
  model_eval.main(0)

  # The input_fn should have been built from the dataset config exactly once.
  self.assertEqual(mock_get_input_fn_from_dataset.call_args_list, [
      mock.call(
          use_tpu=FLAGS.use_tpu,
          dataset_config_filename=FLAGS.dataset_config_pbtxt,
          mode=tf.estimator.ModeKeys.EVAL)
  ])

  metrics = [
      model_eval.read_metrics(checkpoint, eval_name=FLAGS.eval_name)
      for checkpoint in checkpoints
  ]

  # Check that our metrics are what we expect them to be.
  # See b/62864044 for details on how to compute these counts.
  # Counts of labels in our golden dataset:
  #   count  label
  #       1      0
  #      12      1
  #      35      2
  expected_values_for_all_exact = {
      # We have 12 correct calls (the 12 variants with a label of 1) and 36
      # incorrect ones (1 with label 0 + 35 with label 2), so our accuracy is
      # 12 / 48 = 0.25.
      'Accuracy/All': 0.25,
      # We don't have any FNs because we call everything het.
      'FNs/All': 0,
      # One of our labels is 0, which we call het, giving us 1 FP.
      'FPs/All': 1.0,
      # We call everything het, so the recall has to be 1.
      'Recall/All': 1.0,
      # redacted
      # # We don't call anything but hets, so TNs has to be 0.
      # 'TNs/All': 0,
      # We find all 47 positives (12 + 35), so TPs has to be 47.
      'TPs/All': 47,
  }
  # Log each observed metric before asserting, to ease debugging on failure.
  for key in expected_values_for_all_exact:
    print(str(key) + '=' + str(metrics[0][key]))
  for key, expected_value in expected_values_for_all_exact.items():
    self.assertEqual(metrics[0][key], expected_value)

  expected_values_for_all_close = {
      # Of our 48 positive calls, 47 are true variants: 47 / 48 ~ 0.979167.
      'Precision/All': 0.979167,
      # F1 = 2 * P * R / (P + R) = (2 * 47 / 48) / (1 + 47 / 48) ~ 0.989474.
      'F1/All': 0.989474,
  }
  for key, expected_value in expected_values_for_all_close.items():
    self.assertAlmostEqual(metrics[0][key], expected_value, places=6)

  # Every checkpoint sees the same fixed eval set, so all metrics must match.
  for m1, m2 in zip(metrics, metrics[1:]):
    self.assertEqual(m1, m2)
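# A minimal sketch (not part of the test) of the F1 value asserted above,
# derived from precision = 47/48 and recall = 1.0; the helper name `f1_score`
# is illustrative, not an API the test uses.
def f1_score(precision, recall):
  """Returns the harmonic mean of precision and recall."""
  return 2. * precision * recall / (precision + recall)

# f1_score(47. / 48, 1.0) == 94. / 95 ~ 0.98947368, which agrees with the
# expected 0.989474 to the six decimal places checked by assertAlmostEqual.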