Example #1
0
 def test_update_statistics_raise_ConnectionError(self, m):
     """update_statistics must not propagate a failed request to the caller.

     ``m`` is the patched request mock injected by a decorator outside this
     view; presumably it is configured there to raise ``ConnectionError``
     (or to answer 500 Internal-Server-Error) -- TODO confirm. The client is
     expected to log the failure once via ``logger.exception`` and emit no
     warning, instead of raising.
     """
     logger_mock = mock.MagicMock()
     self.client.logger = logger_mock
     statistics = Statistics(progress_percentage=0.5, key1='value1')

     # Guard only the call under test. The assertions below must stay outside
     # the ``try``: AssertionError is a subclass of Exception, so catching it
     # here would replace the real failure message with a bare ``self.fail()``.
     try:
         self.client.update_statistics(statistics)
     except Exception as e:
         self.fail('update_statistics raised unexpectedly: {!r}'.format(e))

     self.assertEqual(m.call_count, 1)
     url = '{}/organizations/{}/training/definitions/{}/jobs/{}/statistics'.format(
         ABEJA_API_URL, ORGANIZATION_ID, TRAINING_JON_DEFINITION_NAME, TRAINING_JOB_ID)
     # timeout=30 matches the request asserted by the other update_statistics
     # tests for the same POST.
     m.assert_called_with(
         'POST',
         url,
         params=None,
         headers={
             'User-Agent': 'abeja-platform-sdk/{}'.format(VERSION)},
         timeout=30,
         data=None,
         json={
             'statistics': {
                 'progress_percentage': 0.5,
                 'key1': 'value1'}})
     self.assertEqual(logger_mock.warning.call_count, 0)
     self.assertEqual(logger_mock.exception.call_count, 1)
Example #2
0
 def test_update_statistics_progress_within_statistics(self, m):
     """Progress and a custom stage are nested under the 'statistics' key.

     Sends a Statistics carrying a progress percentage plus one extra stage,
     then checks that exactly one POST was issued with the expected URL,
     headers, timeout and JSON body.
     """
     stats = Statistics(progress_percentage=0.5)
     stats.add_stage(name='other_stage', key1='value1')
     self.client.update_statistics(stats)
     self.assertEqual(m.call_count, 1)

     endpoint = '{}/organizations/{}/training/definitions/{}/jobs/{}/statistics'.format(
         ABEJA_API_URL, ORGANIZATION_ID, TRAINING_JON_DEFINITION_NAME, TRAINING_JOB_ID)
     other_stage = {'key1': 'value1'}
     payload = {
         'statistics': {
             'progress_percentage': 0.5,
             'stages': {'other_stage': other_stage},
         },
     }
     user_agent = 'abeja-platform-sdk/{}'.format(VERSION)
     m.assert_called_with(
         'POST',
         endpoint,
         params=None,
         headers={'User-Agent': user_agent},
         timeout=30,
         data=None,
         json=payload)
Example #3
0
 def test_update_statistics_with_empty_statistics(self, m):
     """An empty Statistics is skipped with a warning instead of raising.

     Checks that no request is sent (``m`` -- presumably the patched request
     mock injected by a decorator outside this view -- is never called),
     that exactly one warning is logged and that no exception is logged.
     """
     logger_mock = mock.MagicMock()
     self.client.logger = logger_mock

     # Guard only the call under test. Keeping the assertions outside the
     # ``try`` ensures their AssertionError (a subclass of Exception) surfaces
     # with its own message rather than being masked by ``self.fail()``.
     try:
         self.client.update_statistics(Statistics())
     except Exception as e:
         self.fail('update_statistics raised unexpectedly: {!r}'.format(e))

     m.assert_not_called()
     self.assertEqual(logger_mock.warning.call_count, 1)
     self.assertEqual(logger_mock.exception.call_count, 0)
Example #4
0
 def test_update_statistics_override_organization_id(self, m):
     """A Client built with an explicit organization_id posts to that org.

     Creates a dedicated Client with a custom organization id, sends simple
     statistics through it, and checks that the single POST targets a URL
     embedding that id.
     """
     override_id = '2222222222222'
     custom_client = Client(organization_id=override_id)
     custom_client.update_statistics(
         Statistics(progress_percentage=0.5, key1='value1'))
     self.assertEqual(m.call_count, 1)

     expected_url = '{}/organizations/{}/training/definitions/{}/jobs/{}/statistics'.format(
         ABEJA_API_URL, override_id, TRAINING_JON_DEFINITION_NAME, TRAINING_JOB_ID)
     expected_headers = {'User-Agent': 'abeja-platform-sdk/{}'.format(VERSION)}
     expected_body = {
         'statistics': {
             'progress_percentage': 0.5,
             'key1': 'value1',
         },
     }
     m.assert_called_with(
         'POST', expected_url,
         params=None, headers=expected_headers, timeout=30,
         data=None, json=expected_body)
Example #5
0
 def test_update_statistics(self, m):
     """Full Statistics payload: epochs, progress, two stages and extras.

     Builds Statistics with top-level epoch/progress values plus an extra key,
     adds a train and a validation stage, and checks that exactly one POST is
     made whose JSON body carries every one of those values.
     """
     stats = Statistics(progress_percentage=0.5, epoch=1,
                        num_epochs=5, key1='value1')
     stats.add_stage(name=Statistics.STAGE_TRAIN, accuracy=0.9, loss=0.05)
     stats.add_stage(
         name=Statistics.STAGE_VALIDATION,
         accuracy=0.8,
         loss=0.1,
         key2=2)
     self.client.update_statistics(stats)
     self.assertEqual(m.call_count, 1)

     train_stage = {'accuracy': 0.9, 'loss': 0.05}
     validation_stage = {'accuracy': 0.8, 'loss': 0.1, 'key2': 2}
     body = {
         'statistics': {
             'epoch': 1,
             'num_epochs': 5,
             'progress_percentage': 0.5,
             'key1': 'value1',
             'stages': {
                 'train': train_stage,
                 'validation': validation_stage,
             },
         },
     }
     target = '{}/organizations/{}/training/definitions/{}/jobs/{}/statistics'.format(
         ABEJA_API_URL, ORGANIZATION_ID, TRAINING_JON_DEFINITION_NAME, TRAINING_JOB_ID)
     user_agent = 'abeja-platform-sdk/{}'.format(VERSION)
     m.assert_called_with(
         'POST',
         target,
         params=None,
         headers={'User-Agent': user_agent},
         timeout=30,
         data=None,
         json=body)