Esempio n. 1
0
 def __init__(self, methodName='runTest'):
     """Build a CallbackList over two spec'd callback mocks.

     Args:
         methodName: Name of the test method this TestCase instance will run.
     """
     super(TestCallbackList, self).__init__(methodName)
     # spec= limits each mock to the attributes of a real callback instance,
     # so touching a hook the real callback lacks raises AttributeError.
     tqdm_spec = torchbearer.callbacks.printer.Tqdm()
     self.callback_1 = MagicMock(spec=tqdm_spec)
     board_spec = torchbearer.callbacks.tensor_board.TensorBoard()
     self.callback_2 = MagicMock(spec=board_spec)
     self.list = CallbackList([self.callback_1, self.callback_2])
Esempio n. 2
0
 def test_iter_copy(self):
     """Copies of a CallbackList are distinct objects holding the same
     callbacks, and iterating the original yields its callbacks."""
     original = CallbackList(['test'])
     # Exercise both the copy-protocol hook and the public copy() method.
     for make_copy in (original.__copy__, original.copy):
         duplicate = make_copy()
         self.assertTrue(duplicate.callback_list[0] == 'test')
         self.assertTrue(duplicate is not original)
     for item in original:
         self.assertTrue(item == 'test')
Esempio n. 3
0
class TestCallbackList(TestCase):
    """Tests for CallbackList: hook dispatch, nesting and state (de)serialisation."""

    def __init__(self, methodName='runTest'):
        super().__init__(methodName)
        # spec'd mocks only expose attributes that exist on the real callbacks.
        self.callback_1 = MagicMock(spec=torchbearer.callbacks.printer.Tqdm())
        self.callback_2 = MagicMock(
            spec=torchbearer.callbacks.tensor_board.TensorBoard())
        callbacks = [self.callback_1, self.callback_2]
        self.list = CallbackList(callbacks)

    def test_state_dict(self):
        """state_dict collects each callback's state and class, in order."""
        self.callback_1.state_dict = Mock(return_value='test_1')
        self.callback_2.state_dict = Mock(return_value='test_2')

        state = self.list.state_dict()

        self.assertEqual(self.callback_1.state_dict.call_count, 1)
        self.assertEqual(self.callback_2.state_dict.call_count, 1)
        self.assertEqual(state[CallbackList.CALLBACK_STATES][0], 'test_1')
        self.assertEqual(state[CallbackList.CALLBACK_STATES][1], 'test_2')
        self.assertEqual(state[CallbackList.CALLBACK_TYPES][0],
                         Tqdm().__class__)
        self.assertEqual(state[CallbackList.CALLBACK_TYPES][1],
                         TensorBoard().__class__)

    def test_load_state_dict(self):
        """load_state_dict forwards each saved state and warns on class mismatch."""
        self.callback_1.load_state_dict = Mock(return_value='test_1')
        self.callback_2.load_state_dict = Mock(return_value='test_2')

        self.callback_1.state_dict = Mock(return_value='test_1')
        self.callback_2.state_dict = Mock(return_value='test_2')

        state = self.list.state_dict()
        self.list.load_state_dict(state)

        self.callback_1.load_state_dict.assert_called_once_with('test_1')
        self.callback_2.load_state_dict.assert_called_once_with('test_2')

        # Reverse the recorded classes so they no longer match the live callbacks.
        state = self.list.state_dict()
        state[CallbackList.CALLBACK_TYPES] = list(
            reversed(state[CallbackList.CALLBACK_TYPES]))

        with self.assertWarns(UserWarning) as cm:
            self.list.load_state_dict(state)
        # assertWarns' ``msg`` argument is only the failure message and does not
        # check the warning text, so inspect the captured warning explicitly.
        self.assertIn('Callback classes did not match', str(cm.warning))

    def test_for_list(self):
        """on_start on the list is forwarded to every wrapped callback."""
        self.list.on_start({})
        self.assertEqual(self.callback_1.method_calls[0][0], 'on_start')
        self.assertEqual(self.callback_2.method_calls[0][0], 'on_start')

    def test_list_in_list(self):
        """A CallbackList built from another exposes the inner callbacks directly."""
        callback = 'test'
        clist = CallbackList([callback])
        clist2 = CallbackList([clist])
        self.assertEqual(clist2.callback_list[0], 'test')
Esempio n. 4
0
    def __init__(self, model, optimizer=None, criterion=None, metrics=None, callbacks=None, verbose=2):
        """Set up the trial's state dictionary and fire the ``on_init`` callback hook.

        Args:
            model: Model stored in state under ``torchbearer.MODEL``.
            optimizer: Optimizer stored under ``torchbearer.OPTIMIZER``; a ``MockOptimizer`` is used when None.
            criterion: Loss callable taking two positional arguments. When None, a placeholder that returns
                a zero tensor (on the state's device and dtype, with ``requires_grad=True``) is used.
            metrics (list, optional): Metrics wrapped in a ``MetricList``. None means no metrics.
            callbacks (list, optional): Callbacks wrapped in a ``CallbackList``. None means no callbacks.
            verbose (int): Verbosity level, stored as ``self.verbose``.
        """
        # None (not []) as the default so each call gets its own fresh list instead of
        # sharing a single default object between every Trial.
        if metrics is None:
            metrics = []
        if callbacks is None:
            callbacks = []

        if criterion is None:
            def criterion(_, __):
                # self.state is read at call time (it does not exist yet at definition time),
                # so the zero loss uses whatever device/dtype the state holds when invoked.
                return torch.zeros(1, device=self.state[torchbearer.DEVICE], dtype=self.state[torchbearer.DATA_TYPE], requires_grad=True)

        self.verbose = verbose

        # Closure built from the state keys it operates on; base_closure is defined elsewhere.
        self.closure = base_closure(torchbearer.X, torchbearer.MODEL, torchbearer.Y_PRED, torchbearer.Y_TRUE, torchbearer.CRITERION, torchbearer.LOSS, torchbearer.OPTIMIZER)
        self.state = State()
        self.state.update({
            torchbearer.MODEL: model,
            torchbearer.CRITERION: criterion,
            torchbearer.OPTIMIZER: optimizer if optimizer is not None else MockOptimizer(),
            torchbearer.METRIC_LIST: MetricList(metrics),
            torchbearer.CALLBACK_LIST: CallbackList(callbacks),
            torchbearer.DEVICE: 'cpu',
            torchbearer.DATA_TYPE: torch.float32,
            torchbearer.SELF: self,
            torchbearer.HISTORY: [],
            torchbearer.BACKWARD_ARGS: {},
            torchbearer.TRAIN_GENERATOR: None,
            torchbearer.VALIDATION_GENERATOR: None,
            torchbearer.TEST_GENERATOR: None,
            torchbearer.TRAIN_STEPS: None,
            torchbearer.VALIDATION_STEPS: None,
            torchbearer.TEST_STEPS: None,
            torchbearer.TRAIN_DATA: None,
            torchbearer.VALIDATION_DATA: None,
            torchbearer.TEST_DATA: None,
            torchbearer.INF_TRAIN_LOADING: False,
            torchbearer.LOADER: None
        })

        self.state[torchbearer.CALLBACK_LIST].on_init(self.state)
Esempio n. 5
0
    def __init__(self, model, optimizer=None, criterion=None, metrics=None, callbacks=None, pass_state=False):
        """Set up the trial's state dictionary.

        Args:
            model: Model stored in state under ``torchbearer.MODEL``.
            optimizer: Optimizer stored under ``torchbearer.OPTIMIZER``; a ``MockOptimizer`` is used when None.
            criterion: Loss callable taking two positional arguments. When None, a placeholder returning a
                zero tensor on the same device as its second (target) argument is used.
            metrics (list, optional): Metrics wrapped in a ``MetricList``. None means no metrics.
            callbacks (list, optional): Callbacks wrapped in a ``CallbackList``. None means no callbacks.
            pass_state (bool): Stored as ``self.pass_state``.
        """
        # None (not []) as the default so each call gets its own fresh list instead of
        # sharing a single default object between every Trial.
        if metrics is None:
            metrics = []
        if callbacks is None:
            callbacks = []

        if criterion is None:
            def criterion(_, y_true):
                return torch.zeros(1, device=y_true.device)

        self.pass_state = pass_state

        self.state = State()
        self.state.update({
            torchbearer.MODEL: model,
            torchbearer.CRITERION: criterion,
            torchbearer.OPTIMIZER: optimizer if optimizer is not None else MockOptimizer(),
            torchbearer.METRIC_LIST: MetricList(metrics),
            torchbearer.CALLBACK_LIST: CallbackList(callbacks),
            torchbearer.DEVICE: 'cpu',
            torchbearer.DATA_TYPE: torch.float32,
            torchbearer.SELF: self,
            torchbearer.HISTORY: [],
            torchbearer.BACKWARD_ARGS: {},
            torchbearer.TRAIN_GENERATOR: None,
            torchbearer.VALIDATION_GENERATOR: None,
            torchbearer.TEST_GENERATOR: None,
            torchbearer.TRAIN_STEPS: None,
            torchbearer.VALIDATION_STEPS: None,
            torchbearer.TEST_STEPS: None
        })
Esempio n. 6
0
class TestCallbackList(TestCase):
    """Tests for CallbackList hook dispatch and nesting, using plain MagicMock callbacks."""

    def __init__(self, methodName='runTest'):
        super().__init__(methodName)
        self.callback_1 = MagicMock()
        self.callback_2 = MagicMock()
        callbacks = [self.callback_1, self.callback_2]
        self.list = CallbackList(callbacks)

    def test_for_list(self):
        """on_start on the list is forwarded to every wrapped callback."""
        self.list.on_start({})
        # assertEqual reports both values on failure, unlike assertTrue(a == b).
        self.assertEqual(self.callback_1.method_calls[0][0], 'on_start')
        self.assertEqual(self.callback_2.method_calls[0][0], 'on_start')

    def test_list_in_list(self):
        """A CallbackList built from another exposes the inner callbacks directly."""
        callback = 'test'
        clist = CallbackList([callback])
        clist2 = CallbackList([clist])
        self.assertEqual(clist2.callback_list[0], 'test')
Esempio n. 7
0
    def replay(self, callbacks=None, verbose=2, one_batch=False):  # TODO: Should we track if testing passes have happened?
        """ Replay the fit passes stored in history with given callbacks, useful when reloading a saved Trial. Note that only progress and metric information is populated in state during a replay.

        Args:
            callbacks (list): List of callbacks to be run during the replay. The list is copied, not modified.
                Defaults to no extra callbacks.
            verbose (int): If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress
            one_batch (bool): If True, only one batch per epoch is replayed. If False, all batches are replayed

        Returns:
            Trial: self
        """
        history = self.state[torchbearer.HISTORY]
        # Copy before appending the printer. Appending to the argument directly mutated the
        # caller's list and, with the old ``[]`` default, made printers pile up across calls.
        callbacks = list(callbacks) if callbacks is not None else []
        callbacks.append(get_printer(verbose=verbose, validation_label_letter='v'))
        callbacks = CallbackList(callbacks)

        # Shallow copy of the trial state: the keys written below go to the copy, not self.state.
        state = State()
        state.update(self.state)
        state[torchbearer.STOP_TRAINING] = False
        state[torchbearer.MAX_EPOCHS] = len(history)

        callbacks.on_start(state)
        for i in range(len(history)):
            state[torchbearer.EPOCH] = i
            if not one_batch:
                # history[i][0] holds the (train_steps, validation_steps) pair for epoch i.
                state[torchbearer.TRAIN_STEPS], state[torchbearer.VALIDATION_STEPS] = history[i][0]
            else:
                state[torchbearer.TRAIN_STEPS], state[torchbearer.VALIDATION_STEPS] = 1, 1
            state[torchbearer.METRICS] = history[i][1]

            self._replay_pass(state, callbacks)
        callbacks.on_end(state)

        return self
Esempio n. 8
0
    def replay(self,
               callbacks=None,
               verbose=2
               ):  # TODO: Should we track if testing passes have happened?
        """ Replay the fit passes stored in history with given callbacks, useful when reloading a saved Trial. Note that only progress and metric information is populated in state during a replay.

        :param callbacks: List of callbacks to be run during the replay. The list is copied, not modified.
            Defaults to no extra callbacks.
        :type callbacks: list
        :param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no training progress
        :type verbose: int
        :return: self
        :rtype: Trial
        """
        history = self.state[torchbearer.HISTORY]
        # Copy before appending the printer. Appending to the argument directly mutated the
        # caller's list and, with the old ``[]`` default, made printers pile up across calls.
        callbacks = list(callbacks) if callbacks is not None else []
        callbacks.append(
            get_printer(verbose=verbose, validation_label_letter='v'))
        callbacks = CallbackList(callbacks)

        # Shallow copy of the trial state: the keys written below go to the copy, not self.state.
        state = State()
        state.update(self.state)
        state[torchbearer.STOP_TRAINING] = False
        state[torchbearer.MAX_EPOCHS] = len(history)

        callbacks.on_start(state)
        for i in range(len(history)):
            state[torchbearer.EPOCH] = i
            # history[i][0] holds the (train_steps, validation_steps) pair for epoch i.
            state[torchbearer.TRAIN_STEPS], state[
                torchbearer.VALIDATION_STEPS] = history[i][0]
            state[torchbearer.METRICS] = history[i][1]

            self._replay_pass(state, callbacks)
        callbacks.on_end(state)

        # The docstring promises ``self`` (fluent API); previously None was returned.
        return self
Esempio n. 9
0
class TestCallbackList(TestCase):
    """Tests for CallbackList: dispatch, nesting, copying and state (de)serialisation."""

    def __init__(self, methodName='runTest'):
        super(TestCallbackList, self).__init__(methodName)
        # spec'd mocks only expose attributes that exist on the real callbacks.
        self.callback_1 = MagicMock(spec=torchbearer.callbacks.printer.Tqdm())
        self.callback_2 = MagicMock(
            spec=torchbearer.callbacks.tensor_board.TensorBoard())
        callbacks = [self.callback_1, self.callback_2]
        self.list = CallbackList(callbacks)

    def test_state_dict(self):
        """state_dict collects each callback's state and class, in order."""
        self.callback_1.state_dict = Mock(return_value='test_1')
        self.callback_2.state_dict = Mock(return_value='test_2')

        state = self.list.state_dict()

        self.assertEqual(self.callback_1.state_dict.call_count, 1)
        self.assertEqual(self.callback_2.state_dict.call_count, 1)
        self.assertEqual(state[CallbackList.CALLBACK_STATES][0], 'test_1')
        self.assertEqual(state[CallbackList.CALLBACK_STATES][1], 'test_2')
        self.assertEqual(state[CallbackList.CALLBACK_TYPES][0],
                         Tqdm().__class__)
        self.assertEqual(state[CallbackList.CALLBACK_TYPES][1],
                         TensorBoard().__class__)

    def test_load_state_dict(self):
        """load_state_dict forwards each saved state and warns on class mismatch."""
        self.callback_1.load_state_dict = Mock(return_value='test_1')
        self.callback_2.load_state_dict = Mock(return_value='test_2')

        self.callback_1.state_dict = Mock(return_value='test_1')
        self.callback_2.state_dict = Mock(return_value='test_2')

        state = self.list.state_dict()
        self.list.load_state_dict(state)

        self.callback_1.load_state_dict.assert_called_once_with('test_1')
        self.callback_2.load_state_dict.assert_called_once_with('test_2')

        # Reverse the recorded classes so they no longer match the live callbacks.
        state = self.list.state_dict()
        state[CallbackList.CALLBACK_TYPES] = list(
            reversed(state[CallbackList.CALLBACK_TYPES]))

        with warnings.catch_warnings(record=True) as w:
            # Without 'always', an outer filter (e.g. -W ignore or a test-runner
            # filter) can stop the warning from being recorded at all.
            warnings.simplefilter('always')
            self.list.load_state_dict(state)
        self.assertEqual(len(w), 1)
        self.assertTrue(issubclass(w[-1].category, UserWarning))
        self.assertIn(
            'Callback classes did not match, expected: [\'TensorBoard\', \'Tqdm\']',
            str(w[-1].message))

    def test_for_list(self):
        """on_start on the list is forwarded to every wrapped callback."""
        self.list.on_start({})
        self.assertEqual(self.callback_1.method_calls[0][0], 'on_start')
        self.assertEqual(self.callback_2.method_calls[0][0], 'on_start')

    def test_list_in_list(self):
        """A CallbackList built from another exposes the inner callbacks directly."""
        callback = 'test'
        clist = CallbackList([callback])
        clist2 = CallbackList([clist])
        self.assertEqual(clist2.callback_list[0], 'test')

    def test_iter_copy(self):
        """Copies are distinct objects with the same callbacks; iteration yields callbacks."""
        callback = 'test'
        clist = CallbackList([callback])
        cpy = clist.__copy__()
        self.assertEqual(cpy.callback_list[0], 'test')
        self.assertIsNot(cpy, clist)
        cpy = clist.copy()
        self.assertEqual(cpy.callback_list[0], 'test')
        self.assertIsNot(cpy, clist)
        for cback in clist:
            self.assertEqual(cback, 'test')
Esempio n. 10
0
 def test_list_in_list(self):
     """A CallbackList built from another CallbackList exposes the inner
     list's callbacks directly in its own callback_list."""
     inner = CallbackList(['test'])
     outer = CallbackList([inner])
     first = outer.callback_list[0]
     self.assertTrue(first == 'test')
Esempio n. 11
0
 def __init__(self, methodName='runTest'):
     """Create two plain MagicMock callbacks and a CallbackList wrapping them.

     Args:
         methodName: Name of the test method this TestCase instance will run.
     """
     super().__init__(methodName)
     self.callback_1, self.callback_2 = MagicMock(), MagicMock()
     self.list = CallbackList([self.callback_1, self.callback_2])