示例#1
0
    def test_dispatches_to_group_handler(self, mocker):
        """A regex-registered value group routes matching names to its handler.

        A name matching ``good_group/.*`` reaches the handler's ``write``
        exactly once; a name outside every registered group raises.
        """
        training_logger = TrainingLogger()
        group_sink = mocker.Mock()
        training_logger.register_handler('handler', group_sink)

        group_pattern = r'good_group/.*'
        training_logger.register_value_group(group_pattern, ['handler'])

        # Inside the group: dispatched to the handler once.
        training_logger.log_value('good_group/value', 1)
        assert group_sink.write.call_count == 1

        # Outside every group: logging must fail.
        with pytest.raises(Exception):
            training_logger.log_value('bad_group/value', 1)
示例#2
0
    def test_dispatches_to_handler(self, mocker):
        """Logging a registered value forwards it to its handler's ``write``."""
        training_logger = TrainingLogger()
        sink = mocker.Mock()
        training_logger.register_handler('handler', sink)
        training_logger.register_value('value', ['handler'])

        training_logger.log_value('value', 1)

        writes_seen = sink.write.call_count
        assert writes_seen == 1
示例#3
0
    def test_correct_scope_handled(self, periodic_handler_mock):
        """Only exiting the handler's own scope triggers a flush.

        NOTE(review): presumably the ``periodic_handler_mock`` fixture is
        bound to ``'test_scope'`` — confirm against the fixture definition.
        """
        training_logger = TrainingLogger()
        training_logger.register_handler('handler', periodic_handler_mock)
        training_logger.register_value('value', ['handler'])

        # An unrelated scope opening and closing must not flush.
        with training_logger.scope(scope='wrong_scope'):
            pass
        periodic_handler_mock.flush.assert_not_called()

        # Closing the matching scope flushes exactly once.
        with training_logger.scope(scope='test_scope'):
            pass
        flushes_seen = periodic_handler_mock.flush.call_count
        assert flushes_seen == 1
示例#4
0
    def test_instant_values_displayed(self, mocker):
        """Averaged values print both the latest sample and the running mean.

        One line is written per ``scope_enumerate`` step, formatted as
        ``value <instant> (<running mean>)``.
        """
        training_logger = TrainingLogger()
        fake_stream = mocker.Mock()
        batch_handler = StreamHandler(
            'batch',
            stream=fake_stream,
            fmt='* {values}',
        )
        training_logger.register_handler('handler', batch_handler)
        training_logger.register_value('value', ['handler'], average=True)

        # The step index yielded by scope_enumerate is not needed here.
        for _, sample in training_logger.scope_enumerate(range(1, 4)):
            training_logger.log_value('value', sample)

        expected_lines = [
            '* value 1.0000 (1.0000)\n',
            '* value 2.0000 (1.5000)\n',
            '* value 3.0000 (2.0000)\n',
        ]
        assert fake_stream.write.call_count == len(expected_lines)
        fake_stream.write.assert_has_calls(
            [mocker.call(line) for line in expected_lines])
def setup_logging(args):
    """Build the ``TrainingLogger`` used for training, validation and testing.

    Registers console (``StreamHandler``), Tensorboard and CSV handlers, then
    every logged value together with the handlers it is routed to. Handlers
    are registered before values because values reference handlers by name.

    Args:
        args: Run configuration. Must provide ``writer`` (passed to both
            Tensorboard handlers as ``summary_writer``) and ``result_path``
            (joined with ``/``, e.g. a ``pathlib.Path``; the CSV handlers use
            ``result_path / 'val.csv'`` and ``result_path / 'train.csv'``).

    Returns:
        The configured ``TrainingLogger``.
    """
    logger = TrainingLogger()
    _register_handlers(logger, args)
    _register_values(logger)
    return logger


def _register_handlers(logger, args):
    """Register the console, Tensorboard and CSV handlers on *logger*."""
    # Console output for validation.
    logger.register_handler("val_batch",
                            StreamHandler(prefix='Val: ', scope='batch'))
    logger.register_handler(
        "val_epoch",
        StreamHandler(fmt="* {epoch} epochs done:  {values}",
                      scope='epoch',
                      display_instant=False))

    # Console output for testing.
    logger.register_handler(
        "test_std",
        StreamHandler(fmt="Testing: [{step}/{total}]\t{values}",
                      scope='batch'))
    logger.register_handler(
        "test_end",
        StreamHandler(fmt="* Testing results: {values}",
                      scope='epoch',
                      display_instant=False))

    # Console output for training.
    logger.register_handler(
        "train_epoch",
        StreamHandler(fmt="* Train epoch {epoch} done:  {values}",
                      scope='epoch'))
    logger.register_handler("train_batch",
                            StreamHandler(prefix='Train: ', scope='batch'))

    # Tensorboard: the same summary writer serves the epoch- and
    # global-scoped handlers.
    logger.register_handler(
        "tb", TensorboardHandler(scope='epoch', summary_writer=args.writer))
    logger.register_handler(
        "tb_global",
        TensorboardHandler(scope='global', summary_writer=args.writer))

    # CSV files under the result directory, indexed on the 'epoch' column.
    logger.register_handler(
        "val_csv",
        CSVHandler(scope='epoch',
                   csv_path=(args.result_path / 'val.csv'),
                   index_col='epoch'))
    logger.register_handler(
        "train_csv",
        CSVHandler(scope='epoch',
                   csv_path=(args.result_path / 'train.csv'),
                   index_col='epoch'))


def _register_values(logger):
    """Register every logged value on *logger* and route it to its handlers."""
    # Training: per-batch running averages (console + global Tensorboard).
    logger.register_value("train/acc", ['train_batch', 'tb_global'],
                          average=True,
                          display_name='clip')
    logger.register_value("train/loss", ['train_batch', 'tb_global'],
                          average=True,
                          display_name='loss')
    # Distinct label: train/loss goes to the same 'train_batch' handler, and
    # two entries both labelled 'loss' would be indistinguishable there.
    logger.register_value("train/kd_loss", ['train_batch', 'tb_global'],
                          average=True,
                          display_name='kd loss')
    # Training: per-epoch summaries (console, Tensorboard, train.csv).
    logger.register_value("train/epoch_acc",
                          ['train_epoch', 'tb', 'train_csv'],
                          display_name='clip')
    logger.register_value("train/epoch_loss",
                          ['train_epoch', 'tb', 'train_csv'],
                          display_name='loss')
    # Every name under "lr/" is sent to the epoch Tensorboard handler.
    logger.register_value_group("lr/.*", ['tb'])

    # Training timings.
    logger.register_value("time/train_data", ['train_batch'],
                          average=True,
                          display_name='data time')
    logger.register_value("time/train_step", ['train_batch'],
                          average=True,
                          display_name='time')
    logger.register_value("time/train_epoch", ['train_epoch'],
                          display_name='Train epoch time')

    # Validation metrics.
    logger.register_value("val/acc",
                          ['val_batch', 'val_epoch', 'tb', 'val_csv'],
                          average=True,
                          display_name='clip')
    logger.register_value("val/video",
                          ['val_batch', 'val_epoch', 'tb', 'val_csv'],
                          average=False,
                          display_name='video')
    logger.register_value("val/loss", ['val_batch', 'tb', 'val_csv'],
                          average=True,
                          display_name='loss')
    logger.register_value("val/generalization_error",
                          ['val_epoch', 'tb', 'val_csv'],
                          display_name='Train Val accuracy gap')

    # Validation timings.
    logger.register_value("time/val_data", ['val_batch'],
                          average=True,
                          display_name='data time')
    logger.register_value("time/val_step", ['val_batch'],
                          average=True,
                          display_name='time')
    logger.register_value("time/val_epoch", ['val_epoch'],
                          average=False,
                          display_name='Validation time')

    # Test metrics.
    logger.register_value("test/acc", ['test_std', 'test_end', 'tb'],
                          average=True,
                          display_name='clip')
    logger.register_value("test/video", ['test_std', 'test_end', 'tb'],
                          average=False,
                          display_name='video')
示例#6
0
    def test_value_resets(self, periodic_handler_mock):
        """After ``reset_values`` the average reflects only new samples.

        The first batch (1, 2, 3) averages to 2. After the reset, the second
        batch (5, 10, 15) must average to 10 on its own.
        """
        training_logger = TrainingLogger()
        training_logger.register_handler('handler', periodic_handler_mock)
        training_logger.register_value('scope/value', ['handler'],
                                       average=True)

        def log_batch_in_scope(samples):
            # Log every sample inside 'test_scope', then read the value
            # back after the scope has been exited.
            with training_logger.scope(scope='test_scope'):
                for sample in samples:
                    training_logger.log_value('scope/value', sample)
            return training_logger.get_value('scope/value')

        assert log_batch_in_scope((1, 2, 3)) == 2

        training_logger.reset_values('scope')

        assert log_batch_in_scope((5, 10, 15)) == 10
示例#7
0
    def test_saves_instant_values(self, mocker):
        """Each write carries the raw sample as ``instant_value``.

        ``get_value`` returns the running mean. The record passed as the
        first positional argument to the handler's ``write`` keeps the
        individual sample.
        """
        training_logger = TrainingLogger()
        recorder = mocker.Mock()
        training_logger.register_handler('handler', recorder)
        training_logger.register_value('value', ['handler'], average=True)

        # (sample to log, expected running mean right after logging it)
        steps = [(1, 1), (2, 1.5), (3, 2)]
        for write_index, (sample, running_mean) in enumerate(steps):
            training_logger.log_value('value', sample)
            assert training_logger.get_value('value') == running_mean
            written_record = recorder.write.call_args_list[write_index][0][0]
            assert written_record.instant_value == sample
示例#8
0
    def test_value_averaged_within_scope(self, mocker):
        """With ``average=True``, ``get_value`` is the mean of samples so far."""
        training_logger = TrainingLogger()
        training_logger.register_handler('handler', mocker.Mock())
        training_logger.register_value('value', ['handler'], average=True)

        # (sample to log, expected mean of everything logged up to it)
        for sample, expected_mean in ((1, 1), (2, 1.5), (3, 2)):
            training_logger.log_value('value', sample)
            assert training_logger.get_value('value') == expected_mean
示例#9
0
    def test_raises_with_unknown_value(self):
        """Logging a name that was never registered is an error."""
        fresh_logger = TrainingLogger()
        unregistered_name = 'value'
        with pytest.raises(Exception):
            fresh_logger.log_value(unregistered_name, 1)
示例#10
0
    def test_raises_with_unknown_handler(self):
        """Routing a value to a handler that was never registered is an error."""
        fresh_logger = TrainingLogger()
        missing_handlers = ['handler']
        with pytest.raises(Exception):
            fresh_logger.register_value('value', missing_handlers)
示例#11
0
    def test_flush_writes_to_stream(self, mocker):
        """Exiting ``scope()`` writes the latest value once and flushes.

        ``value`` is registered without ``average=True``, so only the last
        sample (0.1) appears in the output line. NOTE(review): presumably
        the default ``scope()`` is the 'epoch' scope the handler uses —
        confirm against ``TrainingLogger.scope``.
        """
        training_logger = TrainingLogger()
        fake_stream = mocker.Mock()
        epoch_handler = StreamHandler(
            'epoch',
            stream=fake_stream,
            fmt='* {values}',
        )
        training_logger.register_handler('handler', epoch_handler)
        training_logger.register_value('value', ['handler'])

        with training_logger.scope():
            for sample in (1, 0.1):
                training_logger.log_value('value', sample)

        fake_stream.write.assert_called_once_with('* value 0.1000\n')
        assert fake_stream.flush.call_count == 1