def testEmptyMergeResults(self):
  """Merging two empty results must produce an empty result."""
  empty_inputs = [
      results._Result(properties={}, property_names=[]) for _ in range(2)
  ]
  merged = results._merge_results(empty_inputs)
  expected = results._Result(properties={}, property_names=[])
  self.assertEqual(expected, merged)
def testOneEmptyMergeResults(self):
  """Merging a populated result with an empty one keeps the populated data."""

  def _populated():
    # Build a fresh object each call so the input and the expected value
    # never share the same dicts/lists.
    return results._Result(
        properties={'key': {'nkey': 'val'}}, property_names=['test'])

  populated = _populated()
  empty = results._Result(properties={}, property_names=[])
  merged = results._merge_results([populated, empty])
  self.assertEqual(_populated(), merged)
def testMergeResults(self):
  """Merging two results combines per-key properties and property names.

  The expected value encodes the merge contract under test:
    * `key1` appears in both inputs, so its properties must contain the
      `hparam` entry from the first result and the `metrics` entry from the
      second.
    * `key2` and `key3` each appear in only one input and pass through as-is.
    * `property_names` lists the first input's names followed by the second's.
  """
  result1 = results._Result(
      properties={
          'key1': {'hparam': 'val'},
          'key2': {'hparam': 'val'},
      },
      property_names=['hparam_names'])
  # `properties` is passed by keyword, like every other `_Result` built in
  # these tests, so the test does not depend on the field order.
  result2 = results._Result(
      properties={
          'key1': {'metrics': 'val'},
          'key3': {'metrics': 'val'},
      },
      property_names=['metric_names'])
  merge_result = results._merge_results([result1, result2])
  want_result = results._Result(
      properties={
          'key1': {'hparam': 'val', 'metrics': 'val'},
          'key2': {'hparam': 'val'},
          'key3': {'metrics': 'val'},
      },
      property_names=['hparam_names', 'metric_names'])
  self.assertEqual(want_result, merge_result)
def testGetHparams(self):
  """Hparams stored on a trainer execution are read back as typed values.

  The execution is recorded with a stringified list of `name=value` pairs.
  `_get_hparams` is expected to return a single row keyed by
  '<run_id>.<benchmark>' holding the parsed values plus run metadata.
  """
  hparams_str = ("['batch_size=256', 'learning_rate=0.05', "
                 "'decay_rate=0.95']")
  run = '0'
  trainer = results._TRAINER_PREFIX + '.Test'
  self._put_execution(run, trainer, hparams_str)

  actual = results._get_hparams(self.test_mlmd.store)

  expected_row = dict(batch_size=256, learning_rate=0.05, decay_rate=0.95)
  expected_row[results.RUN_ID_KEY] = '0'
  expected_row[results.BENCHMARK_KEY] = 'Test'
  expected_row[results.STARTED_AT] = datetime.datetime.fromtimestamp(0)
  expected = results._Result(
      properties={'0.Test': expected_row},
      property_names=['batch_size', 'decay_rate', 'learning_rate'])
  self.assertEqual(expected, actual)
def testGetBenchmarkResults(self):
  """Benchmark metrics attached to an execution's artifact are read back.

  The artifact stores metric values as strings; the expected result holds
  them as floats under the '<run_id>.<benchmark>' key, together with the
  run metadata.
  """
  run = '0'
  artifact_props = {
      'accuracy': '0.25',
      'average_loss': '2.40',
      results.BENCHMARK_KEY: 'Test',
  }
  mlmd = self.test_mlmd
  # Link the artifact to the execution so the reader can find it.
  artifact = mlmd.put_artifact(artifact_props)
  execution = mlmd.put_execution(run)
  mlmd.put_event(artifact, execution)

  actual = results._get_benchmark_results(mlmd.store)

  expected_row = {
      'accuracy': 0.25,
      'average_loss': 2.40,
  }
  expected_row.update({
      results.RUN_ID_KEY: '0',
      results.BENCHMARK_KEY: 'Test',
      results.STARTED_AT: datetime.datetime.fromtimestamp(0),
  })
  expected = results._Result(
      properties={'0.Test': expected_row},
      property_names=['accuracy', 'average_loss'])
  self.assertEqual(expected, actual)