def test_tqdm_custom_args(self):
    """Extra Tqdm keyword arguments (``ascii=True``) reach the tqdm factory.

    Exercised twice: in per-batch mode through ``on_start_training`` and in
    ``on_epoch=True`` mode through ``on_start``.
    """
    def patched(**printer_kwargs):
        # Build a Tqdm printer whose progress-bar factory is a MagicMock.
        printer = Tqdm(**printer_kwargs)
        bar_factory = MagicMock()
        printer.tqdm_module = bar_factory
        return printer, bar_factory

    metrics = {'test': 10}
    state = {
        torchbearer.EPOCH: 1,
        torchbearer.MAX_EPOCHS: 10,
        torchbearer.TRAIN_STEPS: 100,
        torchbearer.VALIDATION_STEPS: 101,
        torchbearer.METRICS: metrics,
    }
    # A single history entry: a copy of the metrics with train_steps and
    # validation_steps keys set to None.
    state[torchbearer.HISTORY] = [
        dict(state[torchbearer.METRICS], train_steps=None, validation_steps=None)
    ]

    batch_printer, batch_factory = patched(ascii=True)
    state[torchbearer.STEPS] = state[torchbearer.TRAIN_STEPS]
    batch_printer.on_start_training(state)
    batch_factory.assert_called_once_with(total=100, desc='1/10(t)', ascii=True)

    epoch_printer, epoch_factory = patched(on_epoch=True, ascii=True)
    epoch_printer.on_start(state)
    epoch_factory.assert_called_once_with(initial=1, total=10, ascii=True)
def test_tqdm_custom_args(self):
    """Extra Tqdm keyword arguments (``ascii=True``) reach the tqdm factory.

    Each case builds a fresh printer with a mocked factory, fires one start
    hook and checks the exact keyword arguments the factory received.
    """
    state = {
        torchbearer.EPOCH: 1,
        torchbearer.MAX_EPOCHS: 10,
        torchbearer.TRAIN_STEPS: 100,
        torchbearer.VALIDATION_STEPS: 101,
        torchbearer.METRICS: 'test',
    }

    cases = (
        # (Tqdm kwargs, hook to fire, expected factory kwargs)
        (dict(ascii=True), 'on_start_training',
         dict(total=100, desc='1/10(t)', ascii=True)),
        (dict(on_epoch=True, ascii=True), 'on_start',
         dict(total=10, ascii=True)),
    )
    for printer_kwargs, hook_name, expected in cases:
        printer = Tqdm(**printer_kwargs)
        factory = MagicMock()
        printer.tqdm_module = factory
        getattr(printer, hook_name)(state)
        factory.assert_called_once_with(**expected)
def get_printer(verbose, validation_label_letter):
    """Choose the progress-printer callback for a verbosity level.

    Args:
        verbose: Verbosity level. ``>= 2`` yields a ``Tqdm`` printer built
            without ``on_epoch``; ``>= 1`` (but below 2) yields a ``Tqdm``
            printer with ``on_epoch=True``; anything lower yields a plain
            ``Callback()`` instance.
        validation_label_letter: Passed through to ``Tqdm`` as
            ``validation_label_letter``.

    Returns:
        The selected callback instance.
    """
    if verbose >= 2:
        extra = {}
    elif verbose >= 1:
        extra = {'on_epoch': True}
    else:
        return Callback()
    return Tqdm(validation_label_letter=validation_label_letter, **extra)
def test_tqdm_on_epoch(self):
    """Check on_epoch-mode Tqdm across on_start, on_end_epoch and on_end.

    With a two-entry history, ``on_start`` must create the bar with
    ``initial=2, total=10``, set the postfix and advance it by one.
    ``on_end_epoch`` must again set the postfix and advance by one.
    ``on_end`` must set the postfix and close the bar exactly once.
    """
    printer = Tqdm(validation_label_letter='e', on_epoch=True)
    factory = MagicMock()
    printer.tqdm_module = factory

    state = {
        torchbearer.EPOCH: 1,
        torchbearer.MAX_EPOCHS: 10,
        torchbearer.HISTORY: [0, (1, {'test': 0.99456})],
        torchbearer.METRICS: {'test': 0.99456},
    }

    printer.on_start(state)
    factory.assert_called_once_with(initial=2, total=10)
    bar = factory.return_value
    bar.set_postfix_str.assert_called_once_with('test=0.9946')
    bar.update.assert_called_once_with(1)

    bar.set_postfix_str.reset_mock()
    bar.update.reset_mock()
    printer.on_end_epoch(state)
    bar.set_postfix_str.assert_called_once_with('test=0.9946')
    bar.update.assert_called_once_with(1)

    bar.set_postfix_str.reset_mock()
    printer.on_end(state)
    bar.set_postfix_str.assert_called_once_with('test=0.9946')
    self.assertEqual(bar.close.call_count, 1)
def test_tqdm_keys(self):
    """Run on_epoch-mode ``on_start`` on a history whose last entry is a dict.

    NOTE(review): this test makes no assertions. It only checks that
    ``on_start`` does not raise for this history shape; consider asserting on
    the mocked tqdm factory and its return value.
    """
    printer = Tqdm(validation_label_letter='e', on_epoch=True)
    printer.tqdm_module = MagicMock()

    state = {
        torchbearer.EPOCH: 1,
        torchbearer.MAX_EPOCHS: 10,
        torchbearer.HISTORY: [0, {'test': 0.99456}],
        torchbearer.METRICS: {'test': 0.99456},
    }
    printer.on_start(state)
def test_state_dict(self):
    """CallbackList.state_dict gathers each child's state and class.

    Each child's ``state_dict`` must be called exactly once. Its return value
    and the child's class are stored at the matching index under
    ``CALLBACK_STATES`` / ``CALLBACK_TYPES``. This presumes ``self.list`` wraps
    ``callback_1`` (a Tqdm) then ``callback_2`` (a TensorBoard) — setUp is not
    shown here.
    """
    self.callback_1.state_dict = Mock(return_value='test_1')
    self.callback_2.state_dict = Mock(return_value='test_2')

    state = self.list.state_dict()

    self.assertEqual(self.callback_1.state_dict.call_count, 1)
    self.assertEqual(self.callback_2.state_dict.call_count, 1)
    self.assertEqual(state[CallbackList.CALLBACK_STATES][0], 'test_1')
    self.assertEqual(state[CallbackList.CALLBACK_STATES][1], 'test_2')
    # Compare with the classes themselves. Instantiating Tqdm()/TensorBoard()
    # only to read __class__ runs their constructors for no benefit.
    self.assertEqual(state[CallbackList.CALLBACK_TYPES][0], Tqdm)
    self.assertEqual(state[CallbackList.CALLBACK_TYPES][1], TensorBoard)
def test_tqdm(self):
    """Per-batch Tqdm with string metrics: bars are created, updated and closed.

    One printer runs a training phase and then a validation phase. For each
    phase the test checks:
    - the tqdm factory call (total and ``'epoch/max(letter)'`` description);
    - the postfix and ``update(1)`` after one step;
    - the postfix at the end of the phase, and that the bar is closed once.
    """
    state = {
        torchbearer.EPOCH: 1,
        torchbearer.MAX_EPOCHS: 10,
        torchbearer.TRAIN_STEPS: 100,
        torchbearer.VALIDATION_STEPS: 101,
        torchbearer.METRICS: 'test',
    }
    tqdm = Tqdm(validation_label_letter='e')
    tqdm.tqdm_module = MagicMock()
    mock_tqdm = tqdm.tqdm_module

    # Training phase.
    tqdm.on_start_training(state)
    mock_tqdm.assert_called_once_with(total=100, desc='1/10(t)')
    tqdm.on_step_training(state)
    mock_tqdm.return_value.set_postfix.assert_called_once_with('test')
    mock_tqdm.return_value.update.assert_called_once_with(1)
    mock_tqdm.return_value.set_postfix.reset_mock()
    tqdm.on_end_training(state)
    mock_tqdm.return_value.set_postfix.assert_called_once_with('test')
    # Mock.assert_called_once() only exists from Python 3.6 (and newer `mock`
    # backports). On older versions it raises AttributeError, or on the oldest
    # it is auto-created as a child mock and never fails. call_count is portable.
    self.assertEqual(mock_tqdm.return_value.close.call_count, 1)

    # Clear everything recorded during training before validating.
    mock_tqdm.reset_mock()
    mock_tqdm.return_value.set_postfix.reset_mock()
    mock_tqdm.return_value.update.reset_mock()
    mock_tqdm.return_value.close.reset_mock()

    # Validation phase.
    tqdm.on_start_validation(state)
    mock_tqdm.assert_called_once_with(total=101, desc='1/10(e)')
    tqdm.on_step_validation(state)
    mock_tqdm.return_value.set_postfix.assert_called_once_with('test')
    mock_tqdm.return_value.update.assert_called_once_with(1)
    mock_tqdm.return_value.set_postfix.reset_mock()
    tqdm.on_end_validation(state)
    mock_tqdm.return_value.set_postfix.assert_called_once_with('test')
    self.assertEqual(mock_tqdm.return_value.close.call_count, 1)
def test_tqdm(self):
    """Per-batch Tqdm with dict metrics: training then validation bars.

    Each phase points ``state[STEPS]`` at that phase's step count. It then
    checks the tqdm factory call, the formatted postfix and ``update(1)`` after
    one step, and that the bar is closed exactly once at the end of the phase.
    """
    state = {
        torchbearer.EPOCH: 1,
        torchbearer.MAX_EPOCHS: 10,
        torchbearer.TRAIN_STEPS: 100,
        torchbearer.VALIDATION_STEPS: 101,
        torchbearer.METRICS: {'test': 0.99456},
    }
    printer = Tqdm(validation_label_letter='e')
    factory = MagicMock()
    printer.tqdm_module = factory
    bar = factory.return_value

    phases = (
        (torchbearer.TRAIN_STEPS, printer.on_start_training,
         printer.on_step_training, printer.on_end_training, 100, '1/10(t)'),
        (torchbearer.VALIDATION_STEPS, printer.on_start_validation,
         printer.on_step_validation, printer.on_end_validation, 101, '1/10(e)'),
    )
    for steps_key, start, step, end, total, desc in phases:
        # Forget calls recorded by the previous phase (a no-op on the first pass).
        factory.reset_mock()
        bar.set_postfix_str.reset_mock()
        bar.update.reset_mock()
        bar.close.reset_mock()

        state[torchbearer.STEPS] = state[steps_key]
        start(state)
        factory.assert_called_once_with(total=total, desc=desc)

        step(state)
        bar.set_postfix_str.assert_called_once_with('test=0.9946')
        bar.update.assert_called_once_with(1)
        bar.set_postfix_str.reset_mock()

        end(state)
        bar.set_postfix_str.assert_called_once_with('test=0.9946')
        self.assertEqual(bar.close.call_count, 1)
def test_tqdm_module_init_not_notebook(self, mock_is_notebook):
    """Outside a notebook, Tqdm must use the plain ``tqdm.tqdm`` factory.

    Args:
        mock_is_notebook: Mock standing in for the notebook check, presumably
            injected by a ``@patch`` decorator that is not visible here.
            It is forced to return False.
    """
    from tqdm import tqdm as base_tqdm

    mock_is_notebook.return_value = False
    tqdm = Tqdm(validation_label_letter='e', on_epoch=True)
    # assertEqual keeps the original `==` semantics but reports both values on
    # failure, unlike assertTrue(a == b), which only reports "False is not true".
    self.assertEqual(tqdm.tqdm_module, base_tqdm)
def test_tqdm(self, mock_tqdm):
    """Per-batch Tqdm with a patched tqdm factory: create, update and close bars.

    Args:
        mock_tqdm: Mock replacing the tqdm factory used by Tqdm, presumably
            injected by a ``@patch`` decorator that is not visible here.
    """
    mock_tqdm.return_value = Mock()
    mock_tqdm.return_value.set_postfix = Mock()
    mock_tqdm.return_value.close = Mock()
    mock_tqdm.return_value.update = Mock()
    state = {
        torchbearer.EPOCH: 1,
        torchbearer.MAX_EPOCHS: 10,
        torchbearer.TRAIN_STEPS: 100,
        torchbearer.VALIDATION_STEPS: 101,
        torchbearer.METRICS: 'test',
    }
    tqdm = Tqdm(validation_label_letter='e')

    # Training phase.
    tqdm.on_start_training(state)
    mock_tqdm.assert_called_once_with(total=100, desc='1/10(t)')
    tqdm.on_step_training(state)
    mock_tqdm.return_value.set_postfix.assert_called_once_with('test')
    mock_tqdm.return_value.update.assert_called_once_with(1)
    mock_tqdm.return_value.set_postfix.reset_mock()
    tqdm.on_end_training(state)
    mock_tqdm.return_value.set_postfix.assert_called_once_with('test')
    # Mock.assert_called_once() only exists from Python 3.6 (and newer `mock`
    # backports). On older versions it raises AttributeError, or on the oldest
    # it is auto-created as a child mock and never fails. call_count is portable.
    self.assertEqual(mock_tqdm.return_value.close.call_count, 1)

    # Clear everything recorded during training before validating.
    mock_tqdm.reset_mock()
    mock_tqdm.return_value.set_postfix.reset_mock()
    mock_tqdm.return_value.update.reset_mock()
    mock_tqdm.return_value.close.reset_mock()

    # Validation phase.
    tqdm.on_start_validation(state)
    mock_tqdm.assert_called_once_with(total=101, desc='1/10(e)')
    tqdm.on_step_validation(state)
    mock_tqdm.return_value.set_postfix.assert_called_once_with('test')
    mock_tqdm.return_value.update.assert_called_once_with(1)
    mock_tqdm.return_value.set_postfix.reset_mock()
    tqdm.on_end_validation(state)
    mock_tqdm.return_value.set_postfix.assert_called_once_with('test')
    self.assertEqual(mock_tqdm.return_value.close.call_count, 1)