def testGetModelAndOutputNamesEmptyPredictions(self):
    """An empty predictions mapping yields no (model, output) name pairs."""
    config = config_pb2.EvalConfig(model_specs=[config_pb2.ModelSpec()])
    extracts = util.StandardExtracts({constants.PREDICTIONS_KEY: {}})
    names = extracts.get_model_and_output_names(config)
    self.assertEmpty(names)
def testSetLabelsModelNameOutputNameOverwrite(self):
    """set_labels overwrites only the targeted model/output label."""
    starting_labels = {
        'baseline': {'output1': 3},
        'candidate': {'output1': 3},
    }
    extracts = util.StandardExtracts({'labels': starting_labels})
    extracts.set_labels(7, model_name='candidate', output_name='output1')

    fetched = extracts.get_labels(model_name='candidate', output_name='output1')
    self.assertEqual(7, fetched)

    # The baseline entry must be left untouched; only candidate/output1 changes.
    want = {
        'labels': {
            'baseline': {'output1': 3},
            'candidate': {'output1': 7},
        }
    }
    self.assertEqual(want, dict(extracts))
def testGetLabelsNone(self):
    """get_labels returns None when the selected model's labels are None."""
    extracts = util.StandardExtracts(
        {'labels': {'candidate': None, 'baseline': None}})
    self.assertIsNone(extracts.get_labels('candidate'))
def testGetLabelsMultiOutput(self):
    """get_labels selects a single output's label via output_name."""
    extracts = util.StandardExtracts({'labels': {'output1': 7}})
    selected = extracts.get_labels(output_name='output1')
    self.assertEqual(7, selected)
def testGetLabelsMultiModel(self):
    """get_labels selects one model's label by model name."""
    per_model = {'candidate': 7, 'baseline': 8}
    extracts = util.StandardExtracts({'labels': per_model})
    self.assertEqual(7, extracts.get_labels('candidate'))
def testSetLabelsNonMappingRaisesException(self):
    """Setting a per-model label on non-mapping labels raises RuntimeError.

    The RuntimeError is expected to chain a ValueError explaining that keys
    cannot be set on a non-mapping root.
    """
    cause_pattern = r'Cannot set keys \(\[\'candidate\'\]\) on a non-mapping root.*'
    with self.assertRaisesRegex(
            RuntimeError, 'set_by_keys failed with arguments:') as raised:
        # Construction is kept inside the context so the assertion scope
        # matches the original test exactly.
        non_mapping = util.StandardExtracts({'labels': np.array([7])})
        non_mapping.set_labels(7, model_name='candidate')
    root_cause = raised.exception.__cause__
    self.assertIsInstance(root_cause, ValueError)
    self.assertRegex(str(root_cause), cause_pattern)
def testGetModelAndOutputNamesMultiOutput(self):
    """With one unnamed model spec, prediction keys are reported as outputs."""
    config = config_pb2.EvalConfig(model_specs=[config_pb2.ModelSpec()])
    predictions = {
        'output1': np.array([]),
        'output2': np.array([]),
    }
    extracts = util.StandardExtracts({constants.PREDICTIONS_KEY: predictions})
    expected_names = [(None, 'output1'), (None, 'output2')]
    self.assertEqual(expected_names,
                     extracts.get_model_and_output_names(config))
def testGetLabelsMultiModelMultiOutput(self):
    """get_labels resolves a label by both model name and output name."""
    nested = {
        'candidate': {'output1': 7},
        'baseline': {'output1': 8},
    }
    extracts = util.StandardExtracts({'labels': nested})
    self.assertEqual(7, extracts.get_labels('candidate', 'output1'))
def testGetModelAndOutputNamesMultiModel(self):
    """With named baseline/candidate specs, keys are reported as model names."""
    specs = [
        config_pb2.ModelSpec(name=constants.BASELINE_KEY),
        config_pb2.ModelSpec(name=constants.CANDIDATE_KEY),
    ]
    config = config_pb2.EvalConfig(model_specs=specs)
    predictions = {
        constants.BASELINE_KEY: np.array([]),
        constants.CANDIDATE_KEY: np.array([]),
    }
    extracts = util.StandardExtracts({constants.PREDICTIONS_KEY: predictions})
    expected_names = [
        (constants.BASELINE_KEY, None),
        (constants.CANDIDATE_KEY, None),
    ]
    self.assertEqual(expected_names,
                     extracts.get_model_and_output_names(config))
def testSetLabelsEmptyLabels(self):
    """set_labels adds a model entry to an initially empty labels dict."""
    extracts = util.StandardExtracts({'labels': {}})
    extracts.set_labels(7, model_name='candidate')
    stored = extracts.get_labels(model_name='candidate')
    self.assertEqual(7, stored)
def testSetLabelsModelNameOutputNameEmptyExtracts(self):
    """set_labels can populate a model/output label on empty extracts."""
    extracts = util.StandardExtracts({})
    extracts.set_labels(7, model_name='candidate', output_name='output1')
    stored = extracts.get_labels(model_name='candidate', output_name='output1')
    self.assertEqual(7, stored)
def testGetLabels(self):
    """get_labels with no arguments returns the top-level labels value."""
    extracts = util.StandardExtracts({'labels': 7})
    self.assertEqual(7, extracts.get_labels())
def testGetModelAndOutputNamesEmptyExtracts(self):
    """Extracts without any predictions yield no (model, output) pairs."""
    config = config_pb2.EvalConfig(model_specs=[config_pb2.ModelSpec()])
    empty = util.StandardExtracts({})
    self.assertEmpty(empty.get_model_and_output_names(config))