Beispiel #1
0
    def test_tqdm_custom_args(self):
        """Tqdm forwards extra constructor kwargs (here ``ascii=True``) to the tqdm bars it opens.

        Covers the per-batch bar opened by ``on_start_training`` and the epoch bar opened by
        ``on_start`` when ``on_epoch=True``. With the one-entry history below, the epoch bar is
        expected to be created with ``initial=1``.
        """
        metrics = {'test': 10}
        state = {
            torchbearer.EPOCH: 1,
            torchbearer.MAX_EPOCHS: 10,
            torchbearer.TRAIN_STEPS: 100,
            torchbearer.VALIDATION_STEPS: 101,
            torchbearer.METRICS: metrics,
        }
        state[torchbearer.HISTORY] = [dict(metrics, train_steps=None, validation_steps=None)]
        state[torchbearer.STEPS] = state[torchbearer.TRAIN_STEPS]

        # Per-batch printer: the training bar must receive the custom kwarg.
        batch_printer = Tqdm(ascii=True)
        batch_factory = MagicMock()
        batch_printer.tqdm_module = batch_factory
        batch_printer.on_start_training(state)
        batch_factory.assert_called_once_with(total=100, desc='1/10(t)', ascii=True)

        # Epoch-level printer: the single epoch bar must receive it too.
        epoch_printer = Tqdm(on_epoch=True, ascii=True)
        epoch_factory = MagicMock()
        epoch_printer.tqdm_module = epoch_factory
        epoch_printer.on_start(state)
        epoch_factory.assert_called_once_with(initial=1, total=10, ascii=True)
Beispiel #2
0
    def test_tqdm_custom_args(self):
        """Extra Tqdm constructor kwargs are passed through to the tqdm factory call.

        Each case is (Tqdm kwargs, hook that opens a bar, expected factory kwargs).
        """
        state = {
            torchbearer.EPOCH: 1,
            torchbearer.MAX_EPOCHS: 10,
            torchbearer.TRAIN_STEPS: 100,
            torchbearer.VALIDATION_STEPS: 101,
            torchbearer.METRICS: 'test',
        }

        cases = (
            ({'ascii': True}, 'on_start_training',
             {'total': 100, 'desc': '1/10(t)', 'ascii': True}),
            ({'on_epoch': True, 'ascii': True}, 'on_start',
             {'total': 10, 'ascii': True}),
        )
        for ctor_kwargs, hook_name, expected_kwargs in cases:
            printer = Tqdm(**ctor_kwargs)
            factory = MagicMock()
            printer.tqdm_module = factory
            getattr(printer, hook_name)(state)
            factory.assert_called_once_with(**expected_kwargs)
Beispiel #3
0
def get_printer(verbose, validation_label_letter):
    """Pick the progress-printing callback for a verbosity level.

    Args:
        verbose: Verbosity level. ``>= 2`` gives a default ``Tqdm``; ``>= 1`` (but not ``>= 2``)
            gives a ``Tqdm`` built with ``on_epoch=True``; anything else gives a plain ``Callback``.
        validation_label_letter: Forwarded to ``Tqdm`` as ``validation_label_letter``.

    Returns:
        The selected callback instance.
    """
    # Checks are made in the same order and with the same operators as before, so values
    # that fail every comparison (e.g. NaN) still fall through to the plain Callback.
    if verbose >= 2:
        return Tqdm(validation_label_letter=validation_label_letter)
    if verbose >= 1:
        return Tqdm(validation_label_letter=validation_label_letter, on_epoch=True)
    return Callback()
Beispiel #4
0
    def test_tqdm_on_epoch(self):
        """With ``on_epoch=True`` one epoch bar is driven by on_start, on_end_epoch and on_end.

        History entries here use the ``(epoch, metrics)`` tuple form. The bar is expected to be
        opened with ``initial=2`` for the two-entry history, show the rounded metric, tick once
        per epoch and be closed by ``on_end``.
        """
        state = {
            torchbearer.EPOCH: 1,
            torchbearer.MAX_EPOCHS: 10,
            torchbearer.HISTORY: [0, (1, {'test': 0.99456})],
            torchbearer.METRICS: {'test': 0.99456},
        }
        printer = Tqdm(validation_label_letter='e', on_epoch=True)
        factory = MagicMock()
        printer.tqdm_module = factory
        bar = factory.return_value

        printer.on_start(state)
        factory.assert_called_once_with(initial=2, total=10)
        bar.set_postfix_str.assert_called_once_with('test=0.9946')
        bar.update.assert_called_once_with(1)
        bar.set_postfix_str.reset_mock()
        bar.update.reset_mock()

        printer.on_end_epoch(state)
        bar.set_postfix_str.assert_called_once_with('test=0.9946')
        bar.update.assert_called_once_with(1)
        bar.set_postfix_str.reset_mock()

        printer.on_end(state)
        bar.set_postfix_str.assert_called_once_with('test=0.9946')
        self.assertEqual(bar.close.call_count, 1)

    def test_tqdm_keys(self):
        """on_start with ``on_epoch=True`` must handle a history whose last entry is a plain dict.

        The companion on-epoch test stores history entries as ``(epoch, metrics)`` tuples. Here
        the last entry is a bare metrics dict, and ``on_start`` has to cope with it. Previously
        this test made no assertion and only checked that no exception was raised.
        """
        state = {
            torchbearer.EPOCH: 1,
            torchbearer.MAX_EPOCHS: 10,
            torchbearer.HISTORY: [0, {
                'test': 0.99456
            }],
            torchbearer.METRICS: {
                'test': 0.99456
            }
        }
        tqdm = Tqdm(validation_label_letter='e', on_epoch=True)
        tqdm.tqdm_module = MagicMock()
        mock_tqdm = tqdm.tqdm_module

        tqdm.on_start(state)
        # Mirrors test_tqdm_on_epoch: a two-entry history opens the epoch bar at initial=2
        # out of MAX_EPOCHS. NOTE(review): postfix/update behaviour for dict entries is not
        # pinned here; add those assertions once the expected format is confirmed.
        mock_tqdm.assert_called_once_with(initial=2, total=10)

    def test_state_dict(self):
        """CallbackList.state_dict calls each callback's state_dict once and records states and types.

        Assumes setUp builds ``self.list`` from ``callback_1`` (a Tqdm) followed by
        ``callback_2`` (a TensorBoard). setUp is not visible here, so confirm this.
        """
        self.callback_1.state_dict = Mock(return_value='test_1')
        self.callback_2.state_dict = Mock(return_value='test_2')

        state = self.list.state_dict()

        self.assertEqual(self.callback_1.state_dict.call_count, 1)
        self.assertEqual(self.callback_2.state_dict.call_count, 1)
        self.assertEqual(state[CallbackList.CALLBACK_STATES][0], 'test_1')
        self.assertEqual(state[CallbackList.CALLBACK_STATES][1], 'test_2')
        # Compare against the classes themselves. Building throwaway Tqdm()/TensorBoard()
        # instances just to read __class__ runs their constructors for no benefit.
        self.assertIs(state[CallbackList.CALLBACK_TYPES][0], Tqdm)
        self.assertIs(state[CallbackList.CALLBACK_TYPES][1], TensorBoard)
Beispiel #7
0
    def test_tqdm(self):
        """Each training/validation pass opens a bar, posts metrics per step, then closes it.

        Both passes share one Tqdm and one mocked factory. Recorded calls are cleared after
        each pass so the next pass is checked in isolation.
        """
        state = {
            torchbearer.EPOCH: 1,
            torchbearer.MAX_EPOCHS: 10,
            torchbearer.TRAIN_STEPS: 100,
            torchbearer.VALIDATION_STEPS: 101,
            torchbearer.METRICS: 'test',
        }
        printer = Tqdm(validation_label_letter='e')
        factory = MagicMock()
        printer.tqdm_module = factory
        bar = factory.return_value

        phases = (
            (printer.on_start_training, printer.on_step_training, printer.on_end_training,
             100, '1/10(t)'),
            (printer.on_start_validation, printer.on_step_validation, printer.on_end_validation,
             101, '1/10(e)'),
        )
        for open_bar, step_bar, close_bar, total, desc in phases:
            open_bar(state)
            factory.assert_called_once_with(total=total, desc=desc)

            step_bar(state)
            bar.set_postfix.assert_called_once_with('test')
            bar.update.assert_called_once_with(1)
            bar.set_postfix.reset_mock()

            close_bar(state)
            bar.set_postfix.assert_called_once_with('test')
            bar.close.assert_called_once()

            factory.reset_mock()
            bar.set_postfix.reset_mock()
            bar.update.reset_mock()
            bar.close.reset_mock()
Beispiel #8
0
    def test_tqdm(self):
        """Per-batch Tqdm: each pass opens a bar, posts rounded metrics every step, then closes it."""
        state = {
            torchbearer.EPOCH: 1,
            torchbearer.MAX_EPOCHS: 10,
            torchbearer.TRAIN_STEPS: 100,
            torchbearer.VALIDATION_STEPS: 101,
            torchbearer.METRICS: {'test': 0.99456},
        }
        printer = Tqdm(validation_label_letter='e')
        factory = MagicMock()
        printer.tqdm_module = factory
        bar = factory.return_value
        expected_postfix = 'test=0.9946'

        def check_pass(steps_key, on_start, on_step, on_end, total, desc):
            # STEPS is set to this pass's step count before the bar is opened.
            state[torchbearer.STEPS] = state[steps_key]

            on_start(state)
            factory.assert_called_once_with(total=total, desc=desc)

            on_step(state)
            bar.set_postfix_str.assert_called_once_with(expected_postfix)
            bar.update.assert_called_once_with(1)
            bar.set_postfix_str.reset_mock()

            on_end(state)
            bar.set_postfix_str.assert_called_once_with(expected_postfix)
            self.assertEqual(bar.close.call_count, 1)

        check_pass(torchbearer.TRAIN_STEPS, printer.on_start_training,
                   printer.on_step_training, printer.on_end_training, 100, '1/10(t)')

        # Clear recorded calls so the validation pass is checked in isolation.
        factory.reset_mock()
        bar.set_postfix_str.reset_mock()
        bar.update.reset_mock()
        bar.close.reset_mock()

        check_pass(torchbearer.VALIDATION_STEPS, printer.on_start_validation,
                   printer.on_step_validation, printer.on_end_validation, 101, '1/10(e)')
Beispiel #9
0
 def test_tqdm_module_init_not_notebook(self, mock_is_notebook):
     """Outside a notebook, Tqdm should select the plain console ``tqdm.tqdm`` class.

     ``mock_is_notebook`` is presumably injected by a patch decorator that is not visible
     here; it is forced to report "not a notebook".
     """
     from tqdm import tqdm as base_tqdm
     mock_is_notebook.return_value = False
     tqdm = Tqdm(validation_label_letter='e', on_epoch=True)
     # assertIs checks identity and reports both objects on failure, unlike assertTrue(a == b).
     self.assertIs(tqdm.tqdm_module, base_tqdm)
Beispiel #10
0
    def test_tqdm(self, mock_tqdm):
        """Training and validation passes each open, step and close a tqdm bar.

        ``mock_tqdm`` is presumably the patched tqdm factory injected by a decorator that is
        not visible here. Its return value is replaced by a plain Mock bar.
        """
        progress = Mock()
        mock_tqdm.return_value = progress
        progress.set_postfix = Mock()
        progress.close = Mock()
        progress.update = Mock()

        state = {
            torchbearer.EPOCH: 1,
            torchbearer.MAX_EPOCHS: 10,
            torchbearer.TRAIN_STEPS: 100,
            torchbearer.VALIDATION_STEPS: 101,
            torchbearer.METRICS: 'test',
        }
        printer = Tqdm(validation_label_letter='e')

        hooks_per_pass = [
            (printer.on_start_training, printer.on_step_training, printer.on_end_training,
             {'total': 100, 'desc': '1/10(t)'}),
            (printer.on_start_validation, printer.on_step_validation, printer.on_end_validation,
             {'total': 101, 'desc': '1/10(e)'}),
        ]
        for begin, advance, finish, factory_kwargs in hooks_per_pass:
            begin(state)
            mock_tqdm.assert_called_once_with(**factory_kwargs)

            advance(state)
            progress.set_postfix.assert_called_once_with('test')
            progress.update.assert_called_once_with(1)
            progress.set_postfix.reset_mock()

            finish(state)
            progress.set_postfix.assert_called_once_with('test')
            progress.close.assert_called_once()

            # Clear recorded calls before the next pass.
            mock_tqdm.reset_mock()
            progress.set_postfix.reset_mock()
            progress.update.reset_mock()
            progress.close.reset_mock()