Esempio n. 1
0
 def testGetModelAndOutputNamesEmptyPredictions(self):
     """An empty predictions mapping yields no (model, output) name pairs."""
     config = config_pb2.EvalConfig(model_specs=[config_pb2.ModelSpec()])
     extracts = util.StandardExtracts({constants.PREDICTIONS_KEY: {}})
     names = extracts.get_model_and_output_names(config)
     self.assertEmpty(names)
Esempio n. 2
0
 def testSetLabelsModelNameOutputNameOverwrite(self):
     """set_labels overwrites only the targeted model/output entry.

     The baseline entry must keep its original value while the candidate's
     'output1' label is replaced.
     """
     starting_labels = {
         'baseline': {'output1': 3},
         'candidate': {'output1': 3},
     }
     extracts = util.StandardExtracts({'labels': starting_labels})
     extracts.set_labels(7, model_name='candidate', output_name='output1')

     updated = extracts.get_labels(model_name='candidate', output_name='output1')
     self.assertEqual(7, updated)

     expected_extracts = {
         'labels': {
             'baseline': {'output1': 3},
             'candidate': {'output1': 7},
         }
     }
     self.assertEqual(expected_extracts, dict(extracts))
Esempio n. 3
0
 def testGetLabelsNone(self):
     """A model whose label entry is None makes get_labels return None."""
     labels_by_model = {'candidate': None, 'baseline': None}
     extracts = util.StandardExtracts({'labels': labels_by_model})
     self.assertIsNone(extracts.get_labels('candidate'))
Esempio n. 4
0
 def testGetLabelsMultiOutput(self):
     """Labels keyed only by output name are found via output_name."""
     extracts = util.StandardExtracts({'labels': {'output1': 7}})
     result = extracts.get_labels(output_name='output1')
     self.assertEqual(7, result)
Esempio n. 5
0
 def testGetLabelsMultiModel(self):
     """Labels keyed by model name return the requested model's value."""
     per_model = {'candidate': 7, 'baseline': 8}
     extracts = util.StandardExtracts({'labels': per_model})
     result = extracts.get_labels('candidate')
     self.assertEqual(7, result)
Esempio n. 6
0
 def testSetLabelsNonMappingRaisesException(self):
     """Setting a model's labels on a non-mapping labels value fails.

     The RuntimeError raised by set_labels must wrap (via __cause__) a
     ValueError that names the keys which could not be set.
     """
     expected_outer = 'set_by_keys failed with arguments:'
     with self.assertRaisesRegex(RuntimeError, expected_outer) as raised:
         # Construction stays inside the context manager, as in the
         # original, so any error it raises is treated identically.
         extracts = util.StandardExtracts({'labels': np.array([7])})
         extracts.set_labels(7, model_name='candidate')

     root_cause = raised.exception.__cause__
     self.assertIsInstance(root_cause, ValueError)
     self.assertRegex(
         str(root_cause),
         r'Cannot set keys \(\[\'candidate\'\]\) on a non-mapping root.*')
Esempio n. 7
0
 def testGetModelAndOutputNamesMultiOutput(self):
     """A single unnamed model with two outputs yields (None, output) pairs."""
     config = config_pb2.EvalConfig(model_specs=[config_pb2.ModelSpec()])
     predictions = {
         'output1': np.array([]),
         'output2': np.array([]),
     }
     extracts = util.StandardExtracts({constants.PREDICTIONS_KEY: predictions})
     expected_pairs = [(None, 'output1'), (None, 'output2')]
     self.assertEqual(expected_pairs,
                      extracts.get_model_and_output_names(config))
Esempio n. 8
0
 def testGetLabelsMultiModelMultiOutput(self):
     """Nested model -> output labels resolve to the requested leaf value."""
     nested_labels = {
         'candidate': {'output1': 7},
         'baseline': {'output1': 8},
     }
     extracts = util.StandardExtracts({'labels': nested_labels})
     result = extracts.get_labels('candidate', 'output1')
     self.assertEqual(7, result)
Esempio n. 9
0
 def testGetModelAndOutputNamesMultiModel(self):
     """Two named single-output models yield (model, None) pairs in order."""
     specs = [
         config_pb2.ModelSpec(name=constants.BASELINE_KEY),
         config_pb2.ModelSpec(name=constants.CANDIDATE_KEY),
     ]
     config = config_pb2.EvalConfig(model_specs=specs)
     predictions = {
         constants.BASELINE_KEY: np.array([]),
         constants.CANDIDATE_KEY: np.array([]),
     }
     expected_pairs = [
         (constants.BASELINE_KEY, None),
         (constants.CANDIDATE_KEY, None),
     ]
     extracts = util.StandardExtracts({constants.PREDICTIONS_KEY: predictions})
     self.assertEqual(expected_pairs,
                      extracts.get_model_and_output_names(config))
Esempio n. 10
0
 def testSetLabelsEmptyLabels(self):
     """set_labels populates a model entry in an initially empty labels dict."""
     extracts = util.StandardExtracts({'labels': {}})
     extracts.set_labels(7, model_name='candidate')
     stored = extracts.get_labels(model_name='candidate')
     self.assertEqual(7, stored)
Esempio n. 11
0
 def testSetLabelsModelNameOutputNameEmptyExtracts(self):
     """set_labels builds the model/output path when extracts start empty."""
     extracts = util.StandardExtracts({})
     extracts.set_labels(7, model_name='candidate', output_name='output1')
     stored = extracts.get_labels(model_name='candidate', output_name='output1')
     self.assertEqual(7, stored)
Esempio n. 12
0
 def testGetLabels(self):
     """A non-mapping labels value is returned as-is when no names are given."""
     extracts = util.StandardExtracts({'labels': 7})
     self.assertEqual(7, extracts.get_labels())
Esempio n. 13
0
 def testGetModelAndOutputNamesEmptyExtracts(self):
     """Extracts with no predictions at all yield no (model, output) pairs."""
     config = config_pb2.EvalConfig(model_specs=[config_pb2.ModelSpec()])
     extracts = util.StandardExtracts({})
     self.assertEmpty(extracts.get_model_and_output_names(config))