コード例 #1
0
    def test_tqdm(self):
        """Drive the Tqdm callback through one training and one validation pass.

        The callback's ``tqdm_module`` attribute is replaced with a ``MagicMock``
        so the test can check how the progress bar is created (``total`` and
        ``desc``), that each step and each end-of-pass hook forwards the metrics
        to ``set_postfix``, that each step calls ``update(1)``, and that the bar
        is closed exactly once per pass. The validation label letter is set to
        ``'e'`` and is expected to show up in the validation description.
        """
        state = {
            torchbearer.EPOCH: 1,
            torchbearer.MAX_EPOCHS: 10,
            torchbearer.TRAIN_STEPS: 100,
            torchbearer.VALIDATION_STEPS: 101,
            torchbearer.METRICS: 'test',
        }
        callback = Tqdm(validation_label_letter='e')
        callback.tqdm_module = MagicMock()
        mock_tqdm = callback.tqdm_module
        # The object the callback gets back when it calls the (mocked) tqdm
        # constructor; reset_mock() keeps the same object, so caching is safe.
        progress_bar = mock_tqdm.return_value

        # --- training pass ---
        callback.on_start_training(state)
        mock_tqdm.assert_called_once_with(total=100, desc='1/10(t)')

        callback.on_step_training(state)
        progress_bar.set_postfix.assert_called_once_with('test')
        progress_bar.update.assert_called_once_with(1)
        progress_bar.set_postfix.reset_mock()

        callback.on_end_training(state)
        progress_bar.set_postfix.assert_called_once_with('test')
        # Check call_count rather than calling assert_called_once(): that helper
        # only exists on newer mock implementations, and on older ones it is
        # either an AttributeError or an auto-created attribute that verifies
        # nothing.
        self.assertEqual(progress_bar.close.call_count, 1)

        # Clear all recorded calls before the validation pass.
        mock_tqdm.reset_mock()
        progress_bar.set_postfix.reset_mock()
        progress_bar.update.reset_mock()
        progress_bar.close.reset_mock()

        # --- validation pass ---
        callback.on_start_validation(state)
        mock_tqdm.assert_called_once_with(total=101, desc='1/10(e)')

        callback.on_step_validation(state)
        progress_bar.set_postfix.assert_called_once_with('test')
        progress_bar.update.assert_called_once_with(1)
        progress_bar.set_postfix.reset_mock()

        callback.on_end_validation(state)
        progress_bar.set_postfix.assert_called_once_with('test')
        self.assertEqual(progress_bar.close.call_count, 1)
コード例 #2
0
    def test_tqdm(self):
        """Run the Tqdm callback through a training pass and a validation pass.

        ``tqdm_module`` on the callback is swapped for a ``MagicMock``. The test
        then checks the arguments used to build the bar, the string handed to
        ``set_postfix_str`` for the ``{'test': 0.99456}`` metrics, the per-step
        ``update(1)`` call, and a single ``close`` per pass.
        """
        metrics = {'test': 0.99456}
        state = {
            torchbearer.EPOCH: 1,
            torchbearer.MAX_EPOCHS: 10,
            torchbearer.TRAIN_STEPS: 100,
            torchbearer.VALIDATION_STEPS: 101,
            torchbearer.METRICS: metrics,
        }
        # The training pass runs with STEPS equal to the training step count.
        state[torchbearer.STEPS] = state[torchbearer.TRAIN_STEPS]

        callback = Tqdm(validation_label_letter='e')
        callback.tqdm_module = MagicMock()
        module_mock = callback.tqdm_module
        bar = module_mock.return_value
        expected_postfix = 'test=0.9946'

        # Training: bar creation, one step, then end of pass.
        callback.on_start_training(state)
        module_mock.assert_called_once_with(total=100, desc='1/10(t)')

        callback.on_step_training(state)
        bar.set_postfix_str.assert_called_once_with(expected_postfix)
        bar.update.assert_called_once_with(1)
        bar.set_postfix_str.reset_mock()

        callback.on_end_training(state)
        bar.set_postfix_str.assert_called_once_with(expected_postfix)
        self.assertEqual(bar.close.call_count, 1)

        # Wipe recorded calls so the validation assertions start clean.
        module_mock.reset_mock()
        bar.set_postfix_str.reset_mock()
        bar.update.reset_mock()
        bar.close.reset_mock()

        # Validation: STEPS now carries the validation step count.
        state[torchbearer.STEPS] = state[torchbearer.VALIDATION_STEPS]
        callback.on_start_validation(state)
        module_mock.assert_called_once_with(total=101, desc='1/10(e)')

        callback.on_step_validation(state)
        bar.set_postfix_str.assert_called_once_with(expected_postfix)
        bar.update.assert_called_once_with(1)
        bar.set_postfix_str.reset_mock()

        callback.on_end_validation(state)
        bar.set_postfix_str.assert_called_once_with(expected_postfix)
        self.assertEqual(bar.close.call_count, 1)
コード例 #3
0
    def test_tqdm(self, mock_tqdm):
        """Check the Tqdm callback against an injected tqdm mock.

        The test verifies how the progress bar is created (``total`` and
        ``desc``), that each step and each end-of-pass hook forwards the metrics
        to ``set_postfix``, that each step calls ``update(1)``, and that the bar
        is closed exactly once per pass, for both training and validation.

        Args:
            mock_tqdm: Mock standing in for the tqdm constructor. Presumably
                injected by a patch decorator that is outside this view --
                TODO confirm.
        """
        # Give the mocked constructor an explicit bar object with the three
        # methods the assertions below inspect.
        mock_tqdm.return_value = Mock()
        progress_bar = mock_tqdm.return_value
        progress_bar.set_postfix = Mock()
        progress_bar.close = Mock()
        progress_bar.update = Mock()

        state = {
            torchbearer.EPOCH: 1,
            torchbearer.MAX_EPOCHS: 10,
            torchbearer.TRAIN_STEPS: 100,
            torchbearer.VALIDATION_STEPS: 101,
            torchbearer.METRICS: 'test',
        }
        callback = Tqdm(validation_label_letter='e')

        # --- training pass ---
        callback.on_start_training(state)
        mock_tqdm.assert_called_once_with(total=100, desc='1/10(t)')

        callback.on_step_training(state)
        progress_bar.set_postfix.assert_called_once_with('test')
        progress_bar.update.assert_called_once_with(1)
        progress_bar.set_postfix.reset_mock()

        callback.on_end_training(state)
        progress_bar.set_postfix.assert_called_once_with('test')
        # Check call_count rather than calling assert_called_once(): that helper
        # only exists on newer mock implementations, and on older ones it is
        # either an AttributeError or an auto-created attribute that verifies
        # nothing.
        self.assertEqual(progress_bar.close.call_count, 1)

        # Clear all recorded calls before the validation pass.
        mock_tqdm.reset_mock()
        progress_bar.set_postfix.reset_mock()
        progress_bar.update.reset_mock()
        progress_bar.close.reset_mock()

        # --- validation pass ---
        callback.on_start_validation(state)
        mock_tqdm.assert_called_once_with(total=101, desc='1/10(e)')

        callback.on_step_validation(state)
        progress_bar.set_postfix.assert_called_once_with('test')
        progress_bar.update.assert_called_once_with(1)
        progress_bar.set_postfix.reset_mock()

        callback.on_end_validation(state)
        progress_bar.set_postfix.assert_called_once_with('test')
        self.assertEqual(progress_bar.close.call_count, 1)