def test_tracking_invalid(self):
    """Invalid inputs to the Tracking logging API must be rejected.

    Type-invalid values (non-int / None steps, None descriptions or
    parameter values, non-numeric / None metrics) must raise ``TypeError``.
    ``log_artifact('dummy')`` must raise ``InvalidPathException``; presumably
    'dummy' is a path that does not exist in the test's working directory.
    A single Tracking instance is reused for every call, in the listed order.
    """
    tracking = Tracking()
    # (bound method, positional args, exception the call must raise)
    invalid_calls = [
        (tracking.log_step, ('val',), TypeError),
        (tracking.log_step, (None,), TypeError),
        (tracking.log_description, (None,), TypeError),
        (tracking.log_param, ('key', None), TypeError),
        (tracking.log_params, ({'key': None},), TypeError),
        (tracking.log_metric, ('key', 'val'), TypeError),
        (tracking.log_metric, ('key', None), TypeError),
        (tracking.log_metrics, ({'key': 'val'},), TypeError),
        (tracking.log_metrics, ({'key': None},), TypeError),
        (tracking.log_artifact, ('dummy',), InvalidPathException),
    ]
    for call, call_args, expected_exc in invalid_calls:
        with self.assertRaises(expected_exc):
            call(*call_args)
def test_tracking(self, m, m_get_training_job, m_flush):
    """Flushing a populated Tracking sends one POST to the models endpoint.

    The mocks are injected by decorators outside this view. Presumably ``m``
    patches the HTTP request call, ``m_get_training_job`` the training-job
    lookup, and ``m_flush`` a flush hook. TODO: confirm against the decorators.

    Checks the request method, URL, headers and timeout. Checks that the
    multipart ``files`` carry both ``model_data`` and a JSON ``parameters``
    part matching the logged description, step, user parameters and metrics.
    """
    m_get_training_job.return_value = {}
    m_flush.return_value = None
    # Named ``artifact`` rather than ``zipfile`` so the stdlib ``zipfile``
    # module is not shadowed inside this test.
    with tempfile.NamedTemporaryFile(suffix='.zip') as artifact:
        expected_url = '{}/organizations/{}/training/definitions/{}/models'.format(
            ABEJA_API_URL, ORGANIZATION_ID, JOB_DEFINITION_NAME)
        expected_parameters = {
            'training_job_id': TRAINING_JOB_ID,
            # log_step(step=1) is expected to prefix the description.
            'description': 'STEP 1. {}'.format(DESCRIPTION),
            'user_parameters': USER_PARAMETERS,
            'metrics': METRICS,
        }
        # Round-trip through JSON so the expectation uses the same types the
        # decoded payload will have (e.g. tuples -> lists, keys -> str).
        expected_parameters = json.loads(json.dumps(expected_parameters))

        tracking = Tracking()
        tracking.log_params(USER_PARAMETERS)
        tracking.log_metrics(METRICS)
        tracking.log_description(description=DESCRIPTION)
        tracking.log_step(step=1)
        tracking.log_artifact(filepath=artifact.name)
        tracking.flush()

        request_method, request_url = m.call_args[0]
        self.assertEqual('POST', request_method)
        self.assertEqual(expected_url, request_url)

        request_kwargs = m.call_args[1]
        self.assertIsNone(request_kwargs['data'])
        self.assertIsNone(request_kwargs['json'])
        self.assertIsNone(request_kwargs['params'])
        self.assertDictEqual(
            {'User-Agent': 'abeja-platform-sdk/{}'.format(SDK_VERSION)},
            request_kwargs['headers'])
        self.assertEqual(30, request_kwargs['timeout'])
        self.assertIn('model_data', request_kwargs['files'])
        self.assertIn('parameters', request_kwargs['files'])
        # Compare decoded JSON rather than raw bytes. The check then does not
        # depend on key order or separators, and a mismatch shows a dict diff.
        sent_parameters = json.loads(request_kwargs['files']['parameters'][1].read())
        self.assertDictEqual(expected_parameters, sent_parameters)
        self.assertEqual(1, m_flush.call_count)