class Statistics(object):
    """Callable that reports per-epoch training results to ABEJA Platform.

    Each call builds an ``ABEJAStatistics`` record for one epoch containing a
    train stage and a validation stage and sends it with ``Client``. Any
    exception raised while sending is logged as a warning and not re-raised,
    so a reporting failure does not interrupt the caller's training loop.

    Args:
        total_epochs: Total number of epochs of the run; sent as
            ``num_epochs`` with every update.
    """
    def __init__(self, total_epochs):
        self._total_epochs = total_epochs
        self.client = Client()

    def __call__(self, epoch, train_loss, train_acc, val_loss, val_acc):
        """Send the metrics of one epoch.

        Args:
            epoch: Epoch number, reported as given (no offset is applied).
            train_loss: Loss of the train stage.
            train_acc: Accuracy of the train stage.
            val_loss: Loss of the validation stage.
            val_acc: Accuracy of the validation stage.
        """
        statistics = ABEJAStatistics(num_epochs=self._total_epochs,
                                     epoch=epoch)

        # add_stage takes (name, accuracy, loss): note accuracy comes first.
        statistics.add_stage(ABEJAStatistics.STAGE_TRAIN, train_acc,
                             train_loss)
        statistics.add_stage(ABEJAStatistics.STAGE_VALIDATION, val_acc,
                             val_loss)

        try:
            self.client.update_statistics(statistics)
        except Exception:
            # Best effort: reporting must never abort training.
            logger.warning('failed to update statistics.')
 def __init__(self, entries, total_epochs,
              obs_key='epoch', log_report='LogReport'):
     """Store the reporting configuration and create the ABEJA client.

     Args:
         entries: Observation keys given by the caller.
         total_epochs: Total number of epochs of the run.
         obs_key: Observation key name, ``'epoch'`` by default.
         log_report: LogReport extension name or instance.
     """
     self._log_report = log_report
     self._entries = entries
     self._obs_key = obs_key
     self._total_epochs = total_epochs
     # Created last so every configuration attribute is set beforehand.
     self.client = Client()
Beispiel #3
0
class Statistics(Callback):
    """Keras callback that reports epoch results to ABEJA Platform.

    At the end of every epoch the train and validation accuracy/loss found
    in Keras' ``logs`` dict are sent as an ``ABEJAStatistics`` record. Any
    exception raised while sending is logged as a warning and not re-raised.

    cf. https://keras.io/callbacks/
    """
    def __init__(self):
        super(Statistics, self).__init__()
        self.client = Client()

    def on_epoch_end(self, epoch, logs=None):
        """Report the metrics of the epoch that just finished.

        Args:
            epoch: Epoch index supplied by Keras; reported as ``epoch + 1``.
            logs: Dict of metric values for the epoch, or None. A metric that
                is absent (e.g. no validation data, or accuracy logged as
                ``accuracy`` rather than ``acc``) is reported as None instead
                of raising.
        """
        logs = logs or {}
        epochs = self.params['epochs']
        statistics = ABEJAStatistics(num_epochs=epochs, epoch=epoch + 1)
        statistics.add_stage(ABEJAStatistics.STAGE_TRAIN,
                             self._find_metric(logs, 'acc', 'accuracy'),
                             logs.get('loss'))
        statistics.add_stage(ABEJAStatistics.STAGE_VALIDATION,
                             self._find_metric(logs, 'val_acc', 'val_accuracy'),
                             logs.get('val_loss'))
        try:
            self.client.update_statistics(statistics)
        except Exception:
            # Best effort: reporting must never abort training.
            logger.warning('failed to update statistics.')

    @staticmethod
    def _find_metric(logs, *names):
        """Return the value of the first of ``names`` present in ``logs``, else None."""
        for name in names:
            if name in logs:
                return logs[name]
        return None
Beispiel #4
0
class Statistics(Callback):
    """A Keras callback for reporting statistics to ABEJA Platform.

    At the end of every epoch the train and validation accuracy/loss found
    in Keras' ``logs`` dict are sent as an ``ABEJAStatistics`` record. Any
    exception raised while sending is logged as a warning and not re-raised.

    Args:
        **kwargs: Forwarded unchanged to ``Client``.
    """
    def __init__(self, **kwargs):
        super(Statistics, self).__init__()
        self.client = Client(**kwargs)

    def on_epoch_end(self, epoch, logs=None):
        """Report the metrics of the epoch that just finished.

        Args:
            epoch: Epoch index supplied by Keras; reported as ``epoch + 1``.
            logs: Dict of metric values for the epoch, or None. A metric that
                is absent (e.g. no validation data, or accuracy logged as
                ``accuracy`` rather than ``acc``) is reported as None instead
                of raising.
        """
        logs = logs or {}
        epochs = self.params['epochs']
        statistics = ABEJAStatistics(num_epochs=epochs, epoch=epoch + 1)
        statistics.add_stage(ABEJAStatistics.STAGE_TRAIN,
                             self._find_metric(logs, 'acc', 'accuracy'),
                             logs.get('loss'))
        statistics.add_stage(ABEJAStatistics.STAGE_VALIDATION,
                             self._find_metric(logs, 'val_acc', 'val_accuracy'),
                             logs.get('val_loss'))
        try:
            self.client.update_statistics(statistics)
        except Exception:
            # Best effort: reporting must never abort training.
            logger.warning('Failed to update statistics.')

    @staticmethod
    def _find_metric(logs, *names):
        """Return the value of the first of ``names`` present in ``logs``, else None."""
        for name in names:
            if name in logs:
                return logs[name]
        return None
Beispiel #5
0
 def test_update_statistics_override_organization_id(self, m):
     # A client bound to another organization must post to that org's URL.
     other_organization = '2222222222222'
     client = Client(organization_id=other_organization)
     payload = Statistics(progress_percentage=0.5, key1='value1')
     client.update_statistics(payload)
     self.assertEqual(m.call_count, 1)

     expected_url = (
         '{}/organizations/{}/training/definitions/{}/jobs/{}/statistics'
         .format(ABEJA_API_URL, other_organization,
                 TRAINING_JON_DEFINITION_NAME, TRAINING_JOB_ID))
     expected_headers = {
         'User-Agent': 'abeja-platform-sdk/{}'.format(VERSION)}
     expected_json = {
         'statistics': {'progress_percentage': 0.5, 'key1': 'value1'}}
     m.assert_called_with('POST', expected_url,
                          params=None,
                          headers=expected_headers,
                          timeout=30,
                          data=None,
                          json=expected_json)
Beispiel #6
0
 def __init__(self, **kwargs):
     """Initialise the base class, then build the ABEJA ``Client``.

     All keyword arguments are forwarded unchanged to ``Client``.
     """
     super(Statistics, self).__init__()
     self.client = Client(**kwargs)
 def __init__(self, total_epochs):
     """Create the ABEJA ``Client`` and remember the total epoch count."""
     self.client = Client()
     self._total_epochs = total_epochs
Beispiel #8
0
 def __init__(self):
     """Initialise the base class, then build a default ABEJA ``Client``."""
     super(Statistics, self).__init__()
     self.client = Client()
Beispiel #9
0
 def setUp(self):
     # Point the connection at the test API URL and silence client logging.
     # NOTE(review): the TestClient.setUp variant in this page wraps this in
     # mock.patch.dict('os.environ', PATCHED_ENVIRON); confirm whether this
     # copy needs the same environment.
     Connection.BASE_URL = ABEJA_API_URL
     client = Client()
     client.logger.setLevel(logging.FATAL)
     self.client = client
Beispiel #10
0
class TestClient(unittest.TestCase):
    """Tests for the training ``Client``: result download and statistics upload."""

    @mock.patch.dict('os.environ', PATCHED_ENVIRON)
    def setUp(self):
        Connection.BASE_URL = ABEJA_API_URL
        self.client = Client()
        self.client.logger.setLevel(logging.FATAL)

    @staticmethod
    def _job_url(suffix, organization_id=None):
        """Build the expected training-job endpoint URL.

        Args:
            suffix: Last path component, e.g. ``'statistics'`` or ``'result'``.
            organization_id: Organization in the URL; ``ORGANIZATION_ID`` when
                omitted.
        """
        if organization_id is None:
            organization_id = ORGANIZATION_ID
        return '{}/organizations/{}/training/definitions/{}/jobs/{}/{}'.format(
            ABEJA_API_URL, organization_id, TRAINING_JON_DEFINITION_NAME,
            TRAINING_JOB_ID, suffix)

    @staticmethod
    def _expected_headers():
        """Return the headers every SDK request is expected to carry."""
        return {'User-Agent': 'abeja-platform-sdk/{}'.format(VERSION)}

    def _update_statistics_without_raising(self, statistics):
        """Call ``update_statistics`` and turn any exception into a test failure.

        Only this call is guarded, so assertion failures in the calling test
        keep their own diagnostic message instead of becoming a bare fail().
        """
        try:
            self.client.update_statistics(statistics)
        except Exception as e:
            self.fail('update_statistics raised unexpectedly: {!r}'.format(e))

    def test_init(self):
        self.assertIsInstance(self.client.api, APIClient)

    @mock.patch('abeja.train.client.extract_zipfile')
    @mock.patch('abeja.train.client.Client._get_content')
    @mock.patch('requests.Session.request')
    def test_download_training_result(
            self, m, m_get_content, m_extract_zipfile):
        dummy_binary = b'dummy'
        m_get_content.return_value = dummy_binary
        self.client.download_training_result(TRAINING_JOB_ID)
        m.assert_called_with(
            'GET',
            self._job_url('result'),
            data=None,
            headers=self._expected_headers(),
            json=None,
            params=None,
            timeout=30)
        m_extract_zipfile.assert_called_once_with(dummy_binary, path=None)

    @mock.patch('requests.Session.request')
    def test_update_statistics(self, m):
        statistics = Statistics(progress_percentage=0.5, epoch=1,
                                num_epochs=5, key1='value1')
        statistics.add_stage(
            name=Statistics.STAGE_TRAIN,
            accuracy=0.9,
            loss=0.05)
        statistics.add_stage(name=Statistics.STAGE_VALIDATION,
                             accuracy=0.8, loss=0.1, key2=2)
        self.client.update_statistics(statistics)
        self.assertEqual(m.call_count, 1)
        expected_data = {
            'statistics': {
                'num_epochs': 5,
                'epoch': 1,
                'progress_percentage': 0.5,
                'stages': {
                    'train': {
                        'accuracy': 0.9,
                        'loss': 0.05
                    },
                    'validation': {
                        'accuracy': 0.8,
                        'loss': 0.1,
                        'key2': 2
                    }
                },
                'key1': 'value1'
            }
        }
        m.assert_called_with(
            'POST',
            self._job_url('statistics'),
            params=None,
            headers=self._expected_headers(),
            timeout=30,
            data=None,
            json=expected_data)

    @mock.patch('requests.Session.request')
    def test_update_statistics_without_statistics(self, m):
        statistics = Statistics(progress_percentage=0.5)
        self.client.update_statistics(statistics)
        self.assertEqual(m.call_count, 1)
        m.assert_called_with(
            'POST',
            self._job_url('statistics'),
            params=None,
            headers=self._expected_headers(),
            timeout=30,
            data=None,
            json={
                'statistics': {
                    'progress_percentage': 0.5}})

    @mock.patch('requests.Session.request')
    def test_update_statistics_progress_within_statistics(self, m):
        statistics = Statistics(progress_percentage=0.5)
        statistics.add_stage(name='other_stage', key1='value1')
        self.client.update_statistics(statistics)
        self.assertEqual(m.call_count, 1)
        expected_data = {
            'statistics': {
                'progress_percentage': 0.5,
                'stages': {
                    'other_stage': {
                        'key1': 'value1'
                    }
                }
            }
        }
        m.assert_called_with(
            'POST',
            self._job_url('statistics'),
            params=None,
            headers=self._expected_headers(),
            timeout=30,
            data=None,
            json=expected_data)

    @mock.patch('requests.Session.request')
    def test_update_statistics_override_organization_id(self, m):
        organization_id = '2222222222222'
        client = Client(organization_id=organization_id)
        statistics = Statistics(progress_percentage=0.5, key1='value1')
        client.update_statistics(statistics)
        self.assertEqual(m.call_count, 1)
        m.assert_called_with(
            'POST',
            self._job_url('statistics', organization_id=organization_id),
            params=None,
            headers=self._expected_headers(),
            timeout=30,
            data=None,
            json={
                'statistics': {
                    'progress_percentage': 0.5,
                    'key1': 'value1'}})

    @mock.patch(
        'abeja.common.connection.Connection.request',
        side_effect=BadRequest(
            'foo',
            'bar',
            400,
            'https://api.abeja.io/'))
    def test_update_statistics_raise_BadRequest(self, m):
        # check: don't raise Exception when model-api returns 400 Bad-Request
        logger_mock = mock.MagicMock()
        self.client.logger = logger_mock
        statistics = Statistics(progress_percentage=0.5, key1='value1')
        self._update_statistics_without_raising(statistics)
        self.assertEqual(m.call_count, 1)
        m.assert_called_with(
            'POST',
            self._job_url('statistics'),
            params=None,
            headers=self._expected_headers(),
            data=None,
            json={
                'statistics': {
                    'progress_percentage': 0.5,
                    'key1': 'value1'}})
        # A 400 is logged as a warning, not as an exception.
        self.assertEqual(logger_mock.warning.call_count, 1)
        self.assertEqual(logger_mock.exception.call_count, 0)

    @mock.patch(
        'abeja.common.connection.Connection.request',
        side_effect=InternalServerError(
            'foo',
            'bar',
            500,
            'https://api.abeja.io/'))
    def test_update_statistics_raise_InternalServerError(self, m):
        # check: don't raise Exception when model-api returns 500
        # Internal-Server-Error
        logger_mock = mock.MagicMock()
        self.client.logger = logger_mock
        statistics = Statistics(progress_percentage=0.5, key1='value1')
        self._update_statistics_without_raising(statistics)
        self.assertEqual(m.call_count, 1)
        m.assert_called_with(
            'POST',
            self._job_url('statistics'),
            params=None,
            headers=self._expected_headers(),
            data=None,
            json={
                'statistics': {
                    'progress_percentage': 0.5,
                    'key1': 'value1'}})
        # A 500 is logged via logger.exception, not as a warning.
        self.assertEqual(logger_mock.warning.call_count, 0)
        self.assertEqual(logger_mock.exception.call_count, 1)

    @mock.patch('abeja.common.connection.Connection.request',
                side_effect=requests.exceptions.ConnectionError())
    def test_update_statistics_raise_ConnectionError(self, m):
        # check: don't raise Exception when the connection to model-api
        # fails (requests ConnectionError)
        logger_mock = mock.MagicMock()
        self.client.logger = logger_mock
        statistics = Statistics(progress_percentage=0.5, key1='value1')
        self._update_statistics_without_raising(statistics)
        self.assertEqual(m.call_count, 1)
        m.assert_called_with(
            'POST',
            self._job_url('statistics'),
            params=None,
            headers=self._expected_headers(),
            data=None,
            json={
                'statistics': {
                    'progress_percentage': 0.5,
                    'key1': 'value1'}})
        # A connection failure is logged via logger.exception.
        self.assertEqual(logger_mock.warning.call_count, 0)
        self.assertEqual(logger_mock.exception.call_count, 1)

    @mock.patch('abeja.common.connection.Connection.request')
    def test_update_statistics_statistics_none(self, m):
        # check: don't raise Exception
        logger_mock = mock.MagicMock()
        self.client.logger = logger_mock
        self._update_statistics_without_raising(None)
        # Nothing is sent; the caller is warned instead.
        m.assert_not_called()
        self.assertEqual(logger_mock.warning.call_count, 1)
        self.assertEqual(logger_mock.exception.call_count, 0)

    @mock.patch('abeja.common.connection.Connection.request')
    def test_update_statistics_with_empty_statistics(self, m):
        # check: don't raise Exception
        logger_mock = mock.MagicMock()
        self.client.logger = logger_mock
        self._update_statistics_without_raising(Statistics())
        # Nothing is sent; the caller is warned instead.
        m.assert_not_called()
        self.assertEqual(logger_mock.warning.call_count, 1)
        self.assertEqual(logger_mock.exception.call_count, 0)
Beispiel #11
0
import os
import catboost
from sklearn import datasets, model_selection

from abeja.train.client import Client
from abeja.train.statistics import Statistics as ABEJAStatistics

# Output directory for training artifacts; falls back to the working dir.
ABEJA_TRAINING_RESULT_DIR = os.getenv('ABEJA_TRAINING_RESULT_DIR', '.')
# Module-wide client used by handler() to report statistics.
client = Client()


def handler(context):
    """Train a CatBoost classifier on iris and report accuracy to ABEJA Platform.

    Args:
        context: Handler context argument; not used here.
    """
    iris = datasets.load_iris()
    cls = catboost.CatBoostClassifier(loss_function='MultiClass')

    X = iris.data
    y = iris.target

    data_train, data_test, label_train, label_test = model_selection.train_test_split(
        X, y)

    cls.fit(data_train, label_train)

    train_acc = cls.score(data_train, label_train)
    test_acc = cls.score(data_test, label_test)

    # Only accuracy is computed above, so loss is reported as None.
    statistics = ABEJAStatistics(num_epochs=1, epoch=1)
    statistics.add_stage(ABEJAStatistics.STAGE_TRAIN, train_acc, None)
    statistics.add_stage(ABEJAStatistics.STAGE_VALIDATION, test_acc, None)
    # Building the statistics object alone reports nothing; it must be sent.
    client.update_statistics(statistics)
    print(train_acc, test_acc)
    # NOTE(review): ABEJA_TRAINING_RESULT_DIR is defined at module level but
    # the trained model is never saved there — confirm whether it should be.
Beispiel #12
0
class Statistics(extension.Extension):

    """Trainer extension to report the accumulated results to ABEJA Platform.

    This extension reads the latest entry of the log accumulated by a
    :class:`LogReport` extension and sends it to ABEJA Platform as an
    ``ABEJAStatistics`` record. Observations whose key starts with ``main/``
    form the train stage, and those starting with ``validation/main/`` form
    the validation stage. ``loss`` and ``accuracy`` become the stage's loss
    and accuracy; every other key suffix is attached as an extra value.

    Args:
        entries (list of str): List of keys of observations.
            NOTE(review): stored but not used when building the report —
            every ``main/*`` and ``validation/main/*`` key is sent.
        total_epochs: Total number of epochs; sent as ``num_epochs``.
        log_report (str or LogReport): Log report to accumulate the
            observations. This is either the name of a LogReport extension
            registered to the trainer, or a LogReport instance to use
            internally.
    """

    def __init__(self, entries, total_epochs, log_report='LogReport'):
        self._entries = entries
        self._log_report = log_report
        self._total_epochs = total_epochs
        self.client = Client()

    def __call__(self, trainer):
        """Resolve the LogReport and report its most recent log entry.

        Raises:
            TypeError: If ``log_report`` is neither a str nor a LogReport.
        """
        log_report = self._log_report
        if isinstance(log_report, str):
            log_report = trainer.get_extension(log_report)
        elif isinstance(log_report, log_report_module.LogReport):
            log_report(trainer)  # update the log report
        else:
            raise TypeError('log report has a wrong type %s' %
                            type(log_report))

        log = log_report.log
        if log:
            self._print(log[-1])

    def serialize(self, serializer):
        """Serialize the internally owned LogReport, if there is one."""
        log_report = self._log_report
        if isinstance(log_report, log_report_module.LogReport):
            log_report.serialize(serializer['_log_report'])

    def _print(self, observation):
        """Upload one LogReport observation to ABEJA Platform.

        Despite its name (kept for compatibility) nothing is printed. Any
        exception raised while sending is logged as a warning and not
        re-raised.
        """
        epoch = observation['epoch']
        statistics = ABEJAStatistics(num_epochs=self._total_epochs, epoch=epoch)

        train_acc, train_loss, train_list = self._collect_stage(
            observation, ('main',))
        val_acc, val_loss, val_list = self._collect_stage(
            observation, ('validation', 'main'))

        statistics.add_stage(ABEJAStatistics.STAGE_TRAIN,
                             train_acc, train_loss, **train_list)
        statistics.add_stage(ABEJAStatistics.STAGE_VALIDATION,
                             val_acc, val_loss, **val_list)

        try:
            self.client.update_statistics(statistics)
        except Exception:
            # Best effort: reporting must never abort training.
            logger.warning('failed to update statistics.')

    @staticmethod
    def _collect_stage(observation, prefix):
        """Gather the observations whose ``/``-separated key starts with ``prefix``.

        Args:
            observation (dict): One LogReport entry, e.g. ``{'main/loss': ...}``.
            prefix (tuple of str): Leading key components, e.g. ``('main',)``.

        Returns:
            tuple: ``(accuracy, loss, extras)``. ``accuracy`` and ``loss`` are
            None when absent; ``extras`` maps every other key suffix (the part
            after ``prefix``, re-joined with ``/``) to its value.
        """
        depth = len(prefix)
        accuracy = None
        loss = None
        extras = {}
        for key, value in observation.items():
            parts = key.split('/')
            # Require at least one component after the prefix.
            if len(parts) <= depth or tuple(parts[:depth]) != prefix:
                continue
            name = '/'.join(parts[depth:])
            if name == 'loss':
                loss = value
            elif name == 'accuracy':
                accuracy = value
            else:
                extras[name] = value
        return accuracy, loss, extras