コード例 #1
0
    def test_tracking_invalid(self):
        """Tracking's ``log_*`` methods must reject invalid argument values.

        All cases run, in order, against one shared ``Tracking`` instance; each
        call must raise the exception paired with it, and the test stops at the
        first call that does not.
        """
        tracker = Tracking()
        rejected_calls = [
            # log_step: a string or None is not accepted as a step.
            (TypeError, lambda: tracker.log_step('val')),
            (TypeError, lambda: tracker.log_step(None)),
            # log_description: None is not accepted.
            (TypeError, lambda: tracker.log_description(None)),
            # log_param / log_params: None is not accepted as a parameter value.
            (TypeError, lambda: tracker.log_param('key', None)),
            (TypeError, lambda: tracker.log_params({'key': None})),
            # log_metric / log_metrics: string and None metric values are rejected.
            (TypeError, lambda: tracker.log_metric('key', 'val')),
            (TypeError, lambda: tracker.log_metric('key', None)),
            (TypeError, lambda: tracker.log_metrics({'key': 'val'})),
            (TypeError, lambda: tracker.log_metrics({'key': None})),
            # log_artifact: 'dummy' is presumably not an existing file — confirm.
            (InvalidPathException, lambda: tracker.log_artifact('dummy')),
        ]
        for expected_error, rejected_call in rejected_calls:
            with self.assertRaises(expected_error):
                rejected_call()
コード例 #2
0
    def test_tracking(self, m, m_get_training_job, m_flush):
        """Flushing a populated ``Tracking`` sends one POST to the models endpoint.

        ``m`` is presumably the patched HTTP request function; it is inspected
        as ``m(method, url, **kwargs)``. ``m_get_training_job`` and ``m_flush``
        are mocks whose return values are stubbed here. NOTE(review): the patch
        decorators that inject these mocks are not visible in this excerpt —
        confirm what each one replaces.

        The uploaded ``parameters`` part must equal the JSON encoding of the
        logged params and metrics. Its description must carry the
        ``'STEP 1. '`` prefix derived from ``log_step(step=1)``.
        """
        m_get_training_job.return_value = {}
        m_flush.return_value = None

        with tempfile.NamedTemporaryFile(suffix='.zip') as artifact:
            expected_url = '{}/organizations/{}/training/definitions/{}/models'.format(
                ABEJA_API_URL, ORGANIZATION_ID, JOB_DEFINITION_NAME)
            # Key order matters: the upload is compared as raw JSON bytes.
            expected_parameters = json.dumps({
                'training_job_id': TRAINING_JOB_ID,
                'description': 'STEP 1. {}'.format(DESCRIPTION),
                'user_parameters': USER_PARAMETERS,
                'metrics': METRICS,
            }).encode()

            tracking = Tracking()
            tracking.log_params(USER_PARAMETERS)
            tracking.log_metrics(METRICS)
            tracking.log_description(description=DESCRIPTION)
            tracking.log_step(step=1)
            tracking.log_artifact(filepath=artifact.name)
            tracking.flush()

            # call_args is an (args, kwargs) pair for the most recent call.
            (method, called_url), request_kwargs = m.call_args
            self.assertEqual('POST', method)
            self.assertEqual(expected_url, called_url)

            # Only multipart ``files`` carry the payload; other body kwargs are None.
            for unused_kwarg in ('data', 'json', 'params'):
                self.assertIsNone(request_kwargs[unused_kwarg])
            expected_headers = {'User-Agent': 'abeja-platform-sdk/{}'.format(SDK_VERSION)}
            self.assertDictEqual(expected_headers, request_kwargs['headers'])
            self.assertEqual(30, request_kwargs['timeout'])

            uploaded_files = request_kwargs['files']
            self.assertIn('model_data', uploaded_files)
            self.assertIn('parameters', uploaded_files)
            # Index 1 of the 'parameters' entry is a readable file-like body.
            self.assertEqual(expected_parameters, uploaded_files['parameters'][1].read())

            self.assertEqual(1, m_flush.call_count)