def test_update_statistics_raise_ConnectionError(self, m):
    """update_statistics must not propagate a failure of the underlying request.

    The original comment describes the scenario as "model-api returns
    500 Internal-Server-Error", while the test name mentions ConnectionError.
    NOTE(review): ``m`` is assumed to be a patched HTTP request function whose
    failure mode is configured by a decorator not visible here -- confirm.

    Only the ``update_statistics`` call is guarded by ``try``.  Assertion
    failures below therefore surface with their own message instead of being
    caught by ``except Exception`` and collapsed into a bare ``self.fail()``.
    """
    logger_mock = mock.MagicMock()
    self.client.logger = logger_mock
    statistics = Statistics(progress_percentage=0.5, key1='value1')

    try:
        self.client.update_statistics(statistics)
    except Exception as e:
        self.fail('update_statistics raised unexpectedly: {!r}'.format(e))

    self.assertEqual(m.call_count, 1)
    url = '{}/organizations/{}/training/definitions/{}/jobs/{}/statistics'.format(
        ABEJA_API_URL, ORGANIZATION_ID, TRAINING_JON_DEFINITION_NAME, TRAINING_JOB_ID)
    # Same request shape as the success-path tests, including timeout=30.
    m.assert_called_with(
        'POST', url, params=None,
        headers={'User-Agent': 'abeja-platform-sdk/{}'.format(VERSION)},
        timeout=30, data=None,
        json={'statistics': {'progress_percentage': 0.5, 'key1': 'value1'}})
    # The failure is reported via logger.exception, never logger.warning.
    self.assertEqual(logger_mock.warning.call_count, 0)
    self.assertEqual(logger_mock.exception.call_count, 1)
def test_update_statistics_progress_within_statistics(self, m):
    """Top-level progress and a custom stage are sent together in one POST."""
    stats = Statistics(progress_percentage=0.5)
    stats.add_stage(name='other_stage', key1='value1')

    self.client.update_statistics(stats)

    self.assertEqual(m.call_count, 1)
    endpoint = '{}/organizations/{}/training/definitions/{}/jobs/{}/statistics'.format(
        ABEJA_API_URL, ORGANIZATION_ID,
        TRAINING_JON_DEFINITION_NAME, TRAINING_JOB_ID)
    # The custom stage is nested under 'stages', next to progress_percentage.
    stage_payload = {'other_stage': {'key1': 'value1'}}
    payload = {
        'statistics': {
            'progress_percentage': 0.5,
            'stages': stage_payload,
        },
    }
    user_agent = {'User-Agent': 'abeja-platform-sdk/{}'.format(VERSION)}
    m.assert_called_with('POST', endpoint,
                         params=None, headers=user_agent, timeout=30,
                         data=None, json=payload)
def test_update_statistics_with_empty_statistics(self, m):
    """An empty Statistics must not raise and must not issue a request.

    The empty payload is expected to produce exactly one logger.warning call
    and no logger.exception call.  Only the call under test is wrapped in
    ``try``, so assertion failures keep their own diagnostic message.
    """
    logger_mock = mock.MagicMock()
    self.client.logger = logger_mock

    try:
        self.client.update_statistics(Statistics())
    except Exception as e:
        self.fail('update_statistics raised unexpectedly: {!r}'.format(e))

    # Nothing to send, so the patched request function is never invoked.
    m.assert_not_called()
    self.assertEqual(logger_mock.warning.call_count, 1)
    self.assertEqual(logger_mock.exception.call_count, 0)
def test_update_statistics_override_organization_id(self, m):
    """A Client built with an explicit organization_id uses it in the URL."""
    custom_org = '2222222222222'
    overriding_client = Client(organization_id=custom_org)

    overriding_client.update_statistics(
        Statistics(progress_percentage=0.5, key1='value1'))

    self.assertEqual(m.call_count, 1)
    # The overridden id, not ORGANIZATION_ID, must appear in the path.
    expected_url = '{}/organizations/{}/training/definitions/{}/jobs/{}/statistics'.format(
        ABEJA_API_URL, custom_org, TRAINING_JON_DEFINITION_NAME, TRAINING_JOB_ID)
    expected_kwargs = dict(
        params=None,
        headers={'User-Agent': 'abeja-platform-sdk/{}'.format(VERSION)},
        timeout=30,
        data=None,
        json={'statistics': {'progress_percentage': 0.5, 'key1': 'value1'}},
    )
    m.assert_called_with('POST', expected_url, **expected_kwargs)
def test_update_statistics(self, m):
    """Epoch info, per-stage metrics and extra keys are all sent in one POST."""
    stats = Statistics(progress_percentage=0.5, epoch=1, num_epochs=5,
                       key1='value1')
    stats.add_stage(name=Statistics.STAGE_TRAIN, accuracy=0.9, loss=0.05)
    stats.add_stage(name=Statistics.STAGE_VALIDATION,
                    accuracy=0.8, loss=0.1, key2=2)

    self.client.update_statistics(stats)
    self.assertEqual(m.call_count, 1)

    endpoint = '{}/organizations/{}/training/definitions/{}/jobs/{}/statistics'.format(
        ABEJA_API_URL, ORGANIZATION_ID,
        TRAINING_JON_DEFINITION_NAME, TRAINING_JOB_ID)
    # Per-stage metrics, keyed by stage name, including the extra key2.
    stages = {
        'train': {'accuracy': 0.9, 'loss': 0.05},
        'validation': {'accuracy': 0.8, 'loss': 0.1, 'key2': 2},
    }
    payload = {
        'statistics': {
            'num_epochs': 5,
            'epoch': 1,
            'progress_percentage': 0.5,
            'stages': stages,
            'key1': 'value1',
        },
    }
    m.assert_called_with(
        'POST', endpoint,
        params=None,
        headers={'User-Agent': 'abeja-platform-sdk/{}'.format(VERSION)},
        timeout=30,
        data=None,
        json=payload,
    )