def test_tracking_without_org_jobdef_job(self):
    mock_logger = mock.MagicMock()
    Tracking(logger=mock_logger)
    mock_logger.warning.assert_any_call(
        'WARNING: No params/metrics/artifact will be uploaded to ABEJA Platform. '
        'Please specify "ABEJA_ORGANIZATION_ID", "TRAINING_JOB_DEFINITION_NAME" '
        'and "TRAINING_JOB_ID" for uploading.')

def test_tracking_statistics_not_work(self, m_update_statistics):
    for i in range(2):
        with Tracking(total_steps=10) as tk:
            tk.log_step(i + 1)
            tk.log_metric(key='main/acc', value=0.5)
            tk.log_metric(key='main/loss', value=0.5)
            tk.log_metric(key='test/acc', value=0.5)
            tk.log_metric(key='test/loss', value=0.5)
    self.assertEqual(0, m_update_statistics.call_count)

def test_tracking_summary_writer(self, m_add_scalar):
    m_add_scalar.return_value = None
    with Tracking() as tk:
        tk.log_step(1)
        tk.log_metric(key='main/acc', value=0.5)
        tk.log_metric(key='main/loss', value=0.5)
        tk.log_metric(key='test/acc', value=0.5)
        tk.log_metric(key='test/loss', value=0.5)
        tk.log_metric(key='hoge/fuga', value=0.5)
    self.assertEqual(4, m_add_scalar.call_count)

def test_tracking_error_on_create_training_model(
        self, m_create_training_model, m_get_training_job):
    m_get_training_job.return_value = {}
    with self.assertRaises(BadRequest):
        with tempfile.NamedTemporaryFile(suffix='.zip') as zipfile:
            tracking = Tracking()
            tracking.log_artifact(filepath=zipfile.name)
            tracking.flush()

def test_tracking_delete_after_flush(self, m, m_get_training_job):
    m_get_training_job.return_value = {}
    zipfile = tempfile.NamedTemporaryFile(suffix='.zip')
    tracking = Tracking()
    tracking.log_artifact(filepath=zipfile.name, delete_flag=True)
    tracking.flush()
    self.assertEqual(1, m.call_count)
    self.assertFalse(Path(zipfile.name).exists())

def test_tracking_error_on_update_statistics(
        self, m_update_statistics, m_create_training_model, m_get_training_job):
    m_get_training_job.return_value = {}
    tracking = Tracking(total_steps=10)
    tracking.log_step(1)
    tracking.flush()
    self.assertEqual(1, m_update_statistics.call_count)

def test_tracking_2(self, m, m_get_training_job, m_flush):
    m_get_training_job.return_value = {}
    m_flush.return_value = None
    with tempfile.NamedTemporaryFile(suffix='.zip') as zipfile:
        url = '{}/organizations/{}/training/definitions/{}/models'.format(
            ABEJA_API_URL, ORGANIZATION_ID, JOB_DEFINITION_NAME)
        parameters = {
            'training_job_id': TRAINING_JOB_ID,
            'description': 'STEP 1. {}'.format(DESCRIPTION),
            'user_parameters': USER_PARAMETERS,
            'metrics': METRICS,
        }
        params = json.dumps(parameters).encode()

        with Tracking() as tracking:
            for k, v in USER_PARAMETERS.items():
                tracking.log_param(k, v)
            for k, v in METRICS.items():
                tracking.log_metric(k, v)
            tracking.log_description(description=DESCRIPTION)
            tracking.log_step(step=1)
            tracking.log_artifact(filepath=zipfile.name)

        m_method, m_url = m.call_args[0]
        self.assertEqual('POST', m_method)
        self.assertEqual(url, m_url)

        body = m.call_args[1]
        self.assertIsNone(body['data'])
        self.assertIsNone(body['json'])
        self.assertIsNone(body['params'])
        self.assertDictEqual(
            {'User-Agent': 'abeja-platform-sdk/{}'.format(SDK_VERSION)},
            body['headers'])
        self.assertEqual(30, body['timeout'])
        self.assertIn('model_data', body['files'])
        self.assertIn('parameters', body['files'])
        self.assertEqual(params, body['files']['parameters'][1].read())
        self.assertEqual(1, m_flush.call_count)

def test_tracking_statistics(
        self, m_update_statistics, m_create_training_model, m_get_training_job):
    for i in range(2):
        with Tracking(total_steps=10) as tk:
            tk.log_step(i + 1)
            tk.log_metric(key='main/acc', value=0.5)
            tk.log_metric(key='main/loss', value=0.5)
            tk.log_metric(key='test/acc', value=0.5)
            tk.log_metric(key='test/loss', value=0.5)
            tk.log_param(key='dummy', value='dummy')
    self.assertEqual(2, m_update_statistics.call_count)
    expect = {
        'dummy': 'dummy',
        'accuracy': 0.5,
        'loss': 0.5
    }
    self.assertDictEqual(
        expect,
        m_update_statistics.call_args[1]['statistics']['stages']['train'])
    self.assertDictEqual(
        expect,
        m_update_statistics.call_args[1]['statistics']['stages']['validation'])

def test_tracking_error_on_get_training_job(self, m_get_training_job):
    with self.assertRaises(BadRequest):
        Tracking()

def test_tracking_without_artifact(self):
    mock_logger = mock.MagicMock()
    tracking = Tracking(logger=mock_logger)
    for k, v in USER_PARAMETERS.items():
        tracking.log_param(k, v)
    for k, v in METRICS.items():
        tracking.log_metric(k, v)
    tracking.log_description(description=DESCRIPTION)
    tracking.log_step(step=1)
    tracking.flush()
    mock_logger.warning.assert_any_call(
        'No output. Need to add "artifact" by "log_artifact()".')

def test_tracking_invalid(self):
    tracking = Tracking()
    with self.assertRaises(TypeError):
        tracking.log_step('val')
    with self.assertRaises(TypeError):
        tracking.log_step(None)
    with self.assertRaises(TypeError):
        tracking.log_description(None)
    with self.assertRaises(TypeError):
        tracking.log_param('key', None)
    with self.assertRaises(TypeError):
        tracking.log_params({'key': None})
    with self.assertRaises(TypeError):
        tracking.log_metric('key', 'val')
    with self.assertRaises(TypeError):
        tracking.log_metric('key', None)
    with self.assertRaises(TypeError):
        tracking.log_metrics({'key': 'val'})
    with self.assertRaises(TypeError):
        tracking.log_metrics({'key': None})
    with self.assertRaises(InvalidPathException):
        tracking.log_artifact('dummy')