def __init__(self, methodName='runTest'):
    """Initialise the test case and the fixtures used by its tests.

    Two mocks are created, specced from a ``Tqdm`` printer instance and a
    ``TensorBoard`` callback instance respectively, and both are wrapped in
    the ``CallbackList`` stored on ``self.list``.

    Args:
        methodName (str): Name of the test method to run; forwarded to the
            parent initialiser.
    """
    super(TestCallbackList, self).__init__(methodName)

    # Each spec object is instantiated right before the mock that uses it,
    # keeping the same construction order as a single inline expression.
    printer_spec = torchbearer.callbacks.printer.Tqdm()
    self.callback_1 = MagicMock(spec=printer_spec)

    board_spec = torchbearer.callbacks.tensor_board.TensorBoard()
    self.callback_2 = MagicMock(spec=board_spec)

    self.list = CallbackList([self.callback_1, self.callback_2])
def test_iter_copy(self):
    """Check copying and iterating a ``CallbackList``.

    Both ``__copy__`` and ``copy`` must return an object distinct from the
    original whose ``callback_list`` still starts with the wrapped callback,
    and iterating the original must yield that callback.
    """
    callback = 'test'
    clist = CallbackList([callback])

    # assertEqual / assertIsNot report the offending values on failure,
    # unlike assertTrue on a pre-evaluated boolean.
    cpy = clist.__copy__()
    self.assertEqual(cpy.callback_list[0], callback)
    self.assertIsNot(cpy, clist)

    cpy = clist.copy()
    self.assertEqual(cpy.callback_list[0], callback)
    self.assertIsNot(cpy, clist)

    for cback in clist:
        self.assertEqual(cback, callback)
class TestCallbackList(TestCase):
    """Tests for ``CallbackList`` using two specced callback mocks."""

    def __init__(self, methodName='runTest'):
        """Build the mocks and the ``CallbackList`` shared by the tests.

        Args:
            methodName (str): Name of the test method to run; forwarded to
                ``TestCase``.
        """
        super().__init__(methodName)
        self.callback_1 = MagicMock(spec=torchbearer.callbacks.printer.Tqdm())
        self.callback_2 = MagicMock(
            spec=torchbearer.callbacks.tensor_board.TensorBoard())
        callbacks = [self.callback_1, self.callback_2]
        self.list = CallbackList(callbacks)

    def test_state_dict(self):
        """``state_dict`` calls each callback once and records state and type."""
        self.callback_1.state_dict = Mock(return_value='test_1')
        self.callback_2.state_dict = Mock(return_value='test_2')

        state = self.list.state_dict()

        self.assertEqual(self.callback_1.state_dict.call_count, 1)
        self.assertEqual(self.callback_2.state_dict.call_count, 1)
        self.assertEqual(state[CallbackList.CALLBACK_STATES][0], 'test_1')
        self.assertEqual(state[CallbackList.CALLBACK_STATES][1], 'test_2')
        self.assertEqual(state[CallbackList.CALLBACK_TYPES][0],
                         Tqdm().__class__)
        self.assertEqual(state[CallbackList.CALLBACK_TYPES][1],
                         TensorBoard().__class__)

    def test_load_state_dict(self):
        """``load_state_dict`` forwards each state and warns on type mismatch."""
        self.callback_1.load_state_dict = Mock(return_value='test_1')
        self.callback_2.load_state_dict = Mock(return_value='test_2')
        self.callback_1.state_dict = Mock(return_value='test_1')
        self.callback_2.state_dict = Mock(return_value='test_2')

        state = self.list.state_dict()
        self.list.load_state_dict(state)

        self.callback_1.load_state_dict.assert_called_once_with('test_1')
        self.callback_2.load_state_dict.assert_called_once_with('test_2')

        # Reverse the recorded types so they no longer line up with the
        # callbacks held by the list.
        state = self.list.state_dict()
        state[CallbackList.CALLBACK_TYPES] = list(
            reversed(state[CallbackList.CALLBACK_TYPES]))

        # NOTE(review): ``msg`` is only the message reported if no warning is
        # raised; it is not matched against the warning text. Use
        # ``assertWarnsRegex`` if the text itself should be verified.
        with self.assertWarns(
                UserWarning,
                msg=
                'Callback classes did not match, expected: {\'TensorBoard\', \'Tqdm\'}'
        ):
            self.list.load_state_dict(state)

    def test_for_list(self):
        """Calling ``on_start`` on the list reaches every wrapped callback."""
        self.list.on_start({})
        # assertEqual shows the recorded method name when the check fails.
        self.assertEqual(self.callback_1.method_calls[0][0], 'on_start')
        self.assertEqual(self.callback_2.method_calls[0][0], 'on_start')

    def test_list_in_list(self):
        """A ``CallbackList`` given another list exposes the inner callback."""
        callback = 'test'
        clist = CallbackList([callback])
        clist2 = CallbackList([clist])
        self.assertEqual(clist2.callback_list[0], callback)
def __init__(self, model, optimizer=None, criterion=None, metrics=None, callbacks=None, verbose=2):
    """Populate the initial state and notify callbacks via ``on_init``.

    Args:
        model: Object stored in state under ``torchbearer.MODEL``.
        optimizer: Object stored under ``torchbearer.OPTIMIZER``; a
            ``MockOptimizer`` is used when this is None.
        criterion: Callable stored under ``torchbearer.CRITERION``. When
            None, a criterion that ignores its arguments and returns a
            single zero tensor (with ``requires_grad=True``) on the state's
            device and data type is used.
        metrics (list): Metrics wrapped in a ``MetricList``. None (the
            default) means no metrics.
        callbacks (list): Callbacks wrapped in a ``CallbackList``. None (the
            default) means no callbacks.
        verbose (int): Stored on ``self.verbose``.
    """
    # Use None sentinels rather than ``[]`` defaults so that instances never
    # share one module-level list object.
    metrics = [] if metrics is None else metrics
    callbacks = [] if callbacks is None else callbacks

    if criterion is None:
        def criterion(_, __):
            # ``self.state`` is looked up when the criterion is called, so
            # the device / dtype current at that time are used.
            return torch.zeros(1,
                               device=self.state[torchbearer.DEVICE],
                               dtype=self.state[torchbearer.DATA_TYPE],
                               requires_grad=True)

    self.verbose = verbose
    self.closure = base_closure(torchbearer.X, torchbearer.MODEL,
                                torchbearer.Y_PRED, torchbearer.Y_TRUE,
                                torchbearer.CRITERION, torchbearer.LOSS,
                                torchbearer.OPTIMIZER)

    self.state = State()
    self.state.update({
        torchbearer.MODEL: model,
        torchbearer.CRITERION: criterion,
        torchbearer.OPTIMIZER: optimizer if optimizer is not None else MockOptimizer(),
        torchbearer.METRIC_LIST: MetricList(metrics),
        torchbearer.CALLBACK_LIST: CallbackList(callbacks),
        torchbearer.DEVICE: 'cpu',
        torchbearer.DATA_TYPE: torch.float32,
        torchbearer.SELF: self,
        torchbearer.HISTORY: [],
        torchbearer.BACKWARD_ARGS: {},
        # Generators, step counts and data start unset.
        torchbearer.TRAIN_GENERATOR: None,
        torchbearer.VALIDATION_GENERATOR: None,
        torchbearer.TEST_GENERATOR: None,
        torchbearer.TRAIN_STEPS: None,
        torchbearer.VALIDATION_STEPS: None,
        torchbearer.TEST_STEPS: None,
        torchbearer.TRAIN_DATA: None,
        torchbearer.VALIDATION_DATA: None,
        torchbearer.TEST_DATA: None,
        torchbearer.INF_TRAIN_LOADING: False,
        torchbearer.LOADER: None,
    })

    self.state[torchbearer.CALLBACK_LIST].on_init(self.state)
def __init__(self, model, optimizer=None, criterion=None, metrics=None, callbacks=None, pass_state=False):
    """Populate the initial state for ``model``.

    Args:
        model: Object stored in state under ``torchbearer.MODEL``.
        optimizer: Object stored under ``torchbearer.OPTIMIZER``; a
            ``MockOptimizer`` is used when this is None.
        criterion: Callable stored under ``torchbearer.CRITERION``. When
            None, a criterion returning a single zero tensor on the device
            of its second argument is used.
        metrics (list): Metrics wrapped in a ``MetricList``. None (the
            default) means no metrics.
        callbacks (list): Callbacks wrapped in a ``CallbackList``. None (the
            default) means no callbacks.
        pass_state (bool): Stored on ``self.pass_state``.
    """
    # Use None sentinels rather than ``[]`` defaults so that instances never
    # share one module-level list object.
    metrics = [] if metrics is None else metrics
    callbacks = [] if callbacks is None else callbacks

    if criterion is None:
        def criterion(_, y_true):
            return torch.zeros(1, device=y_true.device)

    self.pass_state = pass_state

    self.state = State()
    self.state.update({
        torchbearer.MODEL: model,
        torchbearer.CRITERION: criterion,
        torchbearer.OPTIMIZER: optimizer if optimizer is not None else MockOptimizer(),
        torchbearer.METRIC_LIST: MetricList(metrics),
        torchbearer.CALLBACK_LIST: CallbackList(callbacks),
        torchbearer.DEVICE: 'cpu',
        torchbearer.DATA_TYPE: torch.float32,
        torchbearer.SELF: self,
        torchbearer.HISTORY: [],
        torchbearer.BACKWARD_ARGS: {},
        # Generators and step counts start unset.
        torchbearer.TRAIN_GENERATOR: None,
        torchbearer.VALIDATION_GENERATOR: None,
        torchbearer.TEST_GENERATOR: None,
        torchbearer.TRAIN_STEPS: None,
        torchbearer.VALIDATION_STEPS: None,
        torchbearer.TEST_STEPS: None,
    })
class TestCallbackList(TestCase):
    """Tests for ``CallbackList`` using two plain ``MagicMock`` callbacks."""

    def __init__(self, methodName='runTest'):
        """Build the mocks and the ``CallbackList`` shared by the tests.

        Args:
            methodName (str): Name of the test method to run; forwarded to
                ``TestCase``.
        """
        super().__init__(methodName)
        self.callback_1 = MagicMock()
        self.callback_2 = MagicMock()
        callbacks = [self.callback_1, self.callback_2]
        self.list = CallbackList(callbacks)

    def test_for_list(self):
        """Calling ``on_start`` on the list reaches every wrapped callback."""
        self.list.on_start({})
        # assertEqual shows the recorded method name when the check fails,
        # unlike assertTrue on a pre-evaluated comparison.
        self.assertEqual(self.callback_1.method_calls[0][0], 'on_start')
        self.assertEqual(self.callback_2.method_calls[0][0], 'on_start')

    def test_list_in_list(self):
        """A ``CallbackList`` given another list exposes the inner callback."""
        callback = 'test'
        clist = CallbackList([callback])
        clist2 = CallbackList([clist])
        self.assertEqual(clist2.callback_list[0], callback)
def replay(self, callbacks=None, verbose=2, one_batch=False):  # TODO: Should we track if testing passes have happened?
    """ Replay the fit passes stored in history with given callbacks, useful when reloading a saved Trial. Note that only
    progress and metric information is populated in state during a replay.

    Args:
        callbacks (list): List of callbacks to be run during the replay. The given list is not modified. None
            (the default) means no extra callbacks
        verbose (int): If 2: use tqdm on batch, If 1: use tqdm on epoch, If 0: display no training progress
        one_batch (bool): If True, only one batch per epoch is replayed. If False, all batches are replayed

    Returns:
        Trial: self
    """
    history = self.state[torchbearer.HISTORY]

    # Work on a private copy: appending the printer to the argument itself
    # would grow a shared default list on every call and would also leak the
    # printer into the caller's own list.
    callbacks = [] if callbacks is None else list(callbacks)
    callbacks.append(get_printer(verbose=verbose, validation_label_letter='v'))
    callbacks = CallbackList(callbacks)

    # Replay against a separate State seeded from the current one.
    state = State()
    state.update(self.state)
    state[torchbearer.STOP_TRAINING] = False
    state[torchbearer.MAX_EPOCHS] = len(history)

    callbacks.on_start(state)
    for epoch, entry in enumerate(history):
        state[torchbearer.EPOCH] = epoch
        if not one_batch:
            # entry[0] holds the (train steps, validation steps) pair
            state[torchbearer.TRAIN_STEPS], state[torchbearer.VALIDATION_STEPS] = entry[0]
        else:
            state[torchbearer.TRAIN_STEPS], state[torchbearer.VALIDATION_STEPS] = 1, 1
        state[torchbearer.METRICS] = entry[1]
        self._replay_pass(state, callbacks)
    callbacks.on_end(state)

    return self
def replay(self, callbacks=None, verbose=2):  # TODO: Should we track if testing passes have happened?
    """ Replay the fit passes stored in history with given callbacks, useful when reloading a saved Trial. Note that only
    progress and metric information is populated in state during a replay.

    :param callbacks: List of callbacks to be run during the replay. The given list is not modified. None (the \
        default) means no extra callbacks
    :type callbacks: list
    :param verbose: If 2: use tqdm on batch, If 1: use tqdm on epoch, Else: display no training progress
    :type verbose: int
    :return: self
    :rtype: Trial
    """
    history = self.state[torchbearer.HISTORY]

    # Work on a private copy: appending the printer to the argument itself
    # would grow a shared default list on every call and would also leak the
    # printer into the caller's own list.
    callbacks = [] if callbacks is None else list(callbacks)
    callbacks.append(
        get_printer(verbose=verbose, validation_label_letter='v'))
    callbacks = CallbackList(callbacks)

    # Replay against a separate State seeded from the current one.
    state = State()
    state.update(self.state)
    state[torchbearer.STOP_TRAINING] = False
    state[torchbearer.MAX_EPOCHS] = len(history)

    callbacks.on_start(state)
    for epoch, entry in enumerate(history):
        state[torchbearer.EPOCH] = epoch
        # entry[0] holds the (train steps, validation steps) pair
        state[torchbearer.TRAIN_STEPS], state[
            torchbearer.VALIDATION_STEPS] = entry[0]
        state[torchbearer.METRICS] = entry[1]
        self._replay_pass(state, callbacks)
    callbacks.on_end(state)

    # The documented contract is to return the trial itself.
    return self
class TestCallbackList(TestCase):
    """Tests for ``CallbackList`` using two specced callback mocks."""

    def __init__(self, methodName='runTest'):
        """Build the mocks and the ``CallbackList`` shared by the tests.

        Args:
            methodName (str): Name of the test method to run; forwarded to
                ``TestCase``.
        """
        super(TestCallbackList, self).__init__(methodName)
        self.callback_1 = MagicMock(spec=torchbearer.callbacks.printer.Tqdm())
        self.callback_2 = MagicMock(
            spec=torchbearer.callbacks.tensor_board.TensorBoard())
        callbacks = [self.callback_1, self.callback_2]
        self.list = CallbackList(callbacks)

    def test_state_dict(self):
        """``state_dict`` calls each callback once and records state and type."""
        self.callback_1.state_dict = Mock(return_value='test_1')
        self.callback_2.state_dict = Mock(return_value='test_2')

        state = self.list.state_dict()

        self.assertEqual(self.callback_1.state_dict.call_count, 1)
        self.assertEqual(self.callback_2.state_dict.call_count, 1)
        self.assertEqual(state[CallbackList.CALLBACK_STATES][0], 'test_1')
        self.assertEqual(state[CallbackList.CALLBACK_STATES][1], 'test_2')
        self.assertEqual(state[CallbackList.CALLBACK_TYPES][0],
                         Tqdm().__class__)
        self.assertEqual(state[CallbackList.CALLBACK_TYPES][1],
                         TensorBoard().__class__)

    def test_load_state_dict(self):
        """``load_state_dict`` forwards each state and warns on type mismatch."""
        self.callback_1.load_state_dict = Mock(return_value='test_1')
        self.callback_2.load_state_dict = Mock(return_value='test_2')
        self.callback_1.state_dict = Mock(return_value='test_1')
        self.callback_2.state_dict = Mock(return_value='test_2')

        state = self.list.state_dict()
        self.list.load_state_dict(state)

        self.callback_1.load_state_dict.assert_called_once_with('test_1')
        self.callback_2.load_state_dict.assert_called_once_with('test_2')

        # Reverse the recorded types so they no longer line up with the
        # callbacks held by the list.
        state = self.list.state_dict()
        state[CallbackList.CALLBACK_TYPES] = list(
            reversed(state[CallbackList.CALLBACK_TYPES]))

        with warnings.catch_warnings(record=True) as w:
            # Without an explicit filter, a warning already emitted from the
            # same location earlier in the process may be suppressed, which
            # would make the count below depend on test order.
            warnings.simplefilter('always', UserWarning)
            self.list.load_state_dict(state)
            self.assertEqual(len(w), 1)
            self.assertTrue(issubclass(w[-1].category, UserWarning))
            self.assertIn(
                'Callback classes did not match, expected: [\'TensorBoard\', \'Tqdm\']',
                str(w[-1].message))

    def test_for_list(self):
        """Calling ``on_start`` on the list reaches every wrapped callback."""
        self.list.on_start({})
        # assertEqual shows the recorded method name when the check fails.
        self.assertEqual(self.callback_1.method_calls[0][0], 'on_start')
        self.assertEqual(self.callback_2.method_calls[0][0], 'on_start')

    def test_list_in_list(self):
        """A ``CallbackList`` given another list exposes the inner callback."""
        callback = 'test'
        clist = CallbackList([callback])
        clist2 = CallbackList([clist])
        self.assertEqual(clist2.callback_list[0], callback)

    def test_iter_copy(self):
        """Copies are distinct objects with the same callback; iteration works."""
        callback = 'test'
        clist = CallbackList([callback])

        cpy = clist.__copy__()
        self.assertEqual(cpy.callback_list[0], callback)
        self.assertIsNot(cpy, clist)

        cpy = clist.copy()
        self.assertEqual(cpy.callback_list[0], callback)
        self.assertIsNot(cpy, clist)

        for cback in clist:
            self.assertEqual(cback, callback)
def test_list_in_list(self):
    """Check that a ``CallbackList`` built from another list exposes the inner callback.

    The outer list's ``callback_list`` must start with the callback held by
    the inner list, not with the inner ``CallbackList`` object itself.
    """
    callback = 'test'
    clist = CallbackList([callback])
    clist2 = CallbackList([clist])
    # assertEqual reports the actual first entry when the check fails,
    # unlike assertTrue on a pre-evaluated comparison.
    self.assertEqual(clist2.callback_list[0], callback)
def __init__(self, methodName='runTest'):
    """Initialise the test case and the fixtures used by its tests.

    Two unconstrained ``MagicMock`` callbacks are created and wrapped in the
    ``CallbackList`` stored on ``self.list``.

    Args:
        methodName (str): Name of the test method to run; forwarded to the
            parent initialiser.
    """
    super().__init__(methodName)

    first_mock = MagicMock()
    second_mock = MagicMock()
    self.callback_1 = first_mock
    self.callback_2 = second_mock

    self.list = CallbackList([first_mock, second_mock])