def test_forward_x(self):
    """The closure should pass ``state[torchbearer.X]`` to the model as its first positional argument."""
    opt = Mock()

    def model_forward(x):
        return None

    # Autospec the stub so calls are recorded and checked against model_forward's signature.
    model = create_autospec(model_forward)
    x = 'test'
    state = {
        torchbearer.X: x,
        torchbearer.MODEL: model,
        torchbearer.Y_TRUE: None,
        torchbearer.CRITERION: lambda x, y: MagicMock(),
        torchbearer.LOSS: None,
        torchbearer.OPTIMIZER: opt,
        torchbearer.CALLBACK_LIST: MagicMock(),
        torchbearer.BACKWARD_ARGS: {},
    }
    closure = base_closure(torchbearer.X, torchbearer.MODEL, torchbearer.Y_PRED, torchbearer.Y_TRUE,
                           torchbearer.CRITERION, torchbearer.LOSS, torchbearer.OPTIMIZER)
    closure(state)
    # assertEqual reports both values on failure; assertTrue(a == b) would only say "False is not true".
    self.assertEqual(model.call_args[0][0], x)
def test_callback_list(self):
    """One closure call should fire on_forward, on_criterion and on_backward exactly once each."""
    opt = Mock()
    callback_list = Mock()
    callback_list.on_forward = Mock()
    callback_list.on_criterion = Mock()
    callback_list.on_backward = Mock()
    state = {
        torchbearer.X: None,
        torchbearer.MODEL: lambda x: None,
        torchbearer.Y_TRUE: None,
        torchbearer.CRITERION: lambda state: MagicMock(),
        torchbearer.LOSS: None,
        torchbearer.OPTIMIZER: opt,
        torchbearer.CALLBACK_LIST: callback_list,
        torchbearer.BACKWARD_ARGS: {},
    }
    closure = base_closure(torchbearer.X, torchbearer.MODEL, torchbearer.Y_PRED, torchbearer.Y_TRUE,
                           torchbearer.CRITERION, torchbearer.LOSS, torchbearer.OPTIMIZER)
    closure(state)
    # assertEqual shows the actual call count when a hook fires zero or several times.
    self.assertEqual(callback_list.on_forward.call_count, 1)
    self.assertEqual(callback_list.on_criterion.call_count, 1)
    self.assertEqual(callback_list.on_backward.call_count, 1)
def test_loss_state(self):
    """A criterion whose signature takes only ``state`` should be handed the full state dict."""
    def loss_sig(state):
        return None

    criterion = create_autospec(loss_sig)
    predicted, target = 'yp', 'yt'
    optimizer = Mock()

    state = {
        torchbearer.X: None,
        torchbearer.MODEL: lambda x: predicted,
        torchbearer.Y_TRUE: target,
        torchbearer.CRITERION: criterion,
        torchbearer.LOSS: None,
        torchbearer.OPTIMIZER: optimizer,
        torchbearer.CALLBACK_LIST: MagicMock(),
        torchbearer.BACKWARD_ARGS: {},
    }
    keys = (torchbearer.X, torchbearer.MODEL, torchbearer.Y_PRED, torchbearer.Y_TRUE,
            torchbearer.CRITERION, torchbearer.LOSS, torchbearer.OPTIMIZER)
    base_closure(*keys)(state)

    # The first positional argument the criterion received must be the state itself.
    received = criterion.call_args[0][0]
    self.assertDictEqual(received, state)
def __init__(self, model, optimizer=None, criterion=None, metrics=None, callbacks=None, verbose=2):
    """Build the trial's state dictionary and notify callbacks that initialisation happened.

    Args:
        model: Stored under ``torchbearer.MODEL``.
        optimizer: Stored under ``torchbearer.OPTIMIZER``. When None, a ``MockOptimizer()`` is stored instead.
        criterion: Stored under ``torchbearer.CRITERION``. When None, a stand-in is used that ignores its
            two arguments and returns a one-element zero tensor with ``requires_grad=True``, using the
            state's ``DEVICE`` and ``DATA_TYPE``.
        metrics: Wrapped in a ``MetricList``. None (the default) means no metrics.
        callbacks: Wrapped in a ``CallbackList``. None (the default) means no callbacks.
        verbose: Stored on ``self.verbose``.
    """
    # None sentinels instead of ``[]`` defaults: a list default is created once at definition
    # time and would be shared by every Trial constructed without that argument.
    metrics = [] if metrics is None else metrics
    callbacks = [] if callbacks is None else callbacks

    if criterion is None:
        def criterion(_, __):
            # self.state is read at call time; it does not exist yet when this function is defined.
            return torch.zeros(1, device=self.state[torchbearer.DEVICE],
                               dtype=self.state[torchbearer.DATA_TYPE], requires_grad=True)

    self.verbose = verbose

    self.closure = base_closure(torchbearer.X, torchbearer.MODEL, torchbearer.Y_PRED, torchbearer.Y_TRUE,
                                torchbearer.CRITERION, torchbearer.LOSS, torchbearer.OPTIMIZER)
    self.state = State()
    self.state.update({
        torchbearer.MODEL: model,
        torchbearer.CRITERION: criterion,
        torchbearer.OPTIMIZER: optimizer if optimizer is not None else MockOptimizer(),
        torchbearer.METRIC_LIST: MetricList(metrics),
        torchbearer.CALLBACK_LIST: CallbackList(callbacks),
        torchbearer.DEVICE: 'cpu',
        torchbearer.DATA_TYPE: torch.float32,
        torchbearer.SELF: self,
        torchbearer.HISTORY: [],
        torchbearer.BACKWARD_ARGS: {},
        torchbearer.TRAIN_GENERATOR: None,
        torchbearer.VALIDATION_GENERATOR: None,
        torchbearer.TEST_GENERATOR: None,
        torchbearer.TRAIN_STEPS: None,
        torchbearer.VALIDATION_STEPS: None,
        torchbearer.TEST_STEPS: None,
        torchbearer.TRAIN_DATA: None,
        torchbearer.VALIDATION_DATA: None,
        torchbearer.TEST_DATA: None,
        torchbearer.INF_TRAIN_LOADING: False,
        torchbearer.LOADER: None,
    })

    self.state[torchbearer.CALLBACK_LIST].on_init(self.state)
def test_loss_multiple_output_no_state(self):
    """With tuple predictions and targets and a criterion without ``state``, the criterion should
    receive ``(y_pred1, y_pred2, y_true1, y_true2)`` as its positional arguments."""
    opt = Mock()
    y_pred1 = 'yp1'
    y_pred2 = 'yp2'
    y_true1 = 'yt1'
    y_true2 = 'yt2'

    def loss_sig(y_pred1, y_pred2, y_true1, y_true2):
        return None

    crit = create_autospec(loss_sig)
    state = {
        torchbearer.X: None,
        torchbearer.MODEL: lambda x: (y_pred1, y_pred2),
        torchbearer.Y_TRUE: (y_true1, y_true2),
        torchbearer.CRITERION: crit,
        torchbearer.LOSS: None,
        torchbearer.OPTIMIZER: opt,
        torchbearer.CALLBACK_LIST: MagicMock(),
        torchbearer.BACKWARD_ARGS: {},
    }
    closure = base_closure(torchbearer.X, torchbearer.MODEL, torchbearer.Y_PRED, torchbearer.Y_TRUE,
                           torchbearer.CRITERION, torchbearer.LOSS, torchbearer.OPTIMIZER)
    closure(state)
    # assertEqual on tuples prints an element-wise diff on failure.
    self.assertEqual(crit.call_args[0], (y_pred1, y_pred2, y_true1, y_true2))
def test_forward_multiple_x_and_state(self):
    """With a list ``X`` and a model that accepts ``state``, the model should receive each input
    positionally and the state dict as the ``state`` keyword argument."""
    opt = Mock()

    def model_forward(x1, x2, state):
        return None

    model = create_autospec(model_forward)
    x1 = 'test1'
    x2 = 'test2'
    state = {
        torchbearer.X: [x1, x2],
        torchbearer.MODEL: model,
        torchbearer.Y_TRUE: None,
        torchbearer.CRITERION: lambda x, y: MagicMock(),
        torchbearer.LOSS: None,
        torchbearer.OPTIMIZER: opt,
        torchbearer.CALLBACK_LIST: MagicMock(),
        torchbearer.BACKWARD_ARGS: {},
    }
    closure = base_closure(torchbearer.X, torchbearer.MODEL, torchbearer.Y_PRED, torchbearer.Y_TRUE,
                           torchbearer.CRITERION, torchbearer.LOSS, torchbearer.OPTIMIZER)
    closure(state)
    args, kwargs = model.call_args[0], model.call_args[1]
    self.assertEqual(args[0], x1)
    self.assertEqual(args[1], x2)
    self.assertDictEqual(kwargs['state'], state)
def test_state_type_error(self):
    """A model raising ``TypeError`` should make the closure raise an exception whose first
    argument has length 2."""
    state = {
        torchbearer.X: None,
        # NOTE(review): state_model_with_e is a helper on the test class (not visible here);
        # presumably it builds a model that raises the given exception — confirm.
        torchbearer.MODEL: self.state_model_with_e(TypeError('test')),
        torchbearer.CRITERION: lambda state: MagicMock(),
        torchbearer.OPTIMIZER: Mock(),
        torchbearer.CALLBACK_LIST: Mock(),
        torchbearer.BACKWARD_ARGS: {},
    }
    closure = base_closure(torchbearer.X, torchbearer.MODEL, torchbearer.Y_PRED, torchbearer.Y_TRUE,
                           torchbearer.CRITERION, torchbearer.LOSS, torchbearer.OPTIMIZER)
    with self.assertRaises(Exception) as context:
        closure(state)
    # Must sit outside the ``with``: code after the raising call inside it would never run.
    # NOTE(review): presumably args[0] collects the errors from two model-call attempts
    # (with and without ``state``) — confirm against base_closure.
    self.assertEqual(len(context.exception.args[0]), 2)
def test_opt(self):
    """One closure call should call the optimizer's ``zero_grad`` exactly once."""
    opt = Mock()
    opt.zero_grad = Mock()
    state = {
        torchbearer.X: None,
        torchbearer.MODEL: lambda x: None,
        torchbearer.Y_TRUE: None,
        torchbearer.CRITERION: lambda x, y: MagicMock(),
        torchbearer.LOSS: None,
        torchbearer.OPTIMIZER: opt,
        torchbearer.CALLBACK_LIST: MagicMock(),
        torchbearer.BACKWARD_ARGS: {},
    }
    closure = base_closure(torchbearer.X, torchbearer.MODEL, torchbearer.Y_PRED, torchbearer.Y_TRUE,
                           torchbearer.CRITERION, torchbearer.LOSS, torchbearer.OPTIMIZER)
    closure(state)
    self.assertEqual(opt.zero_grad.call_count, 1)
def test_backward(self):
    """One closure call should call ``backward`` on the criterion's returned loss exactly once."""
    opt = Mock()
    loss = Mock()
    loss.backward = Mock()
    state = {
        torchbearer.X: None,
        torchbearer.MODEL: lambda x: None,
        torchbearer.Y_TRUE: None,
        torchbearer.CRITERION: lambda state: loss,
        torchbearer.LOSS: None,
        torchbearer.OPTIMIZER: opt,
        torchbearer.CALLBACK_LIST: MagicMock(),
        torchbearer.BACKWARD_ARGS: {},
    }
    closure = base_closure(torchbearer.X, torchbearer.MODEL, torchbearer.Y_PRED, torchbearer.Y_TRUE,
                           torchbearer.CRITERION, torchbearer.LOSS, torchbearer.OPTIMIZER)
    closure(state)
    self.assertEqual(loss.backward.call_count, 1)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True) # Model and optimizer generator = Generator() discriminator = Discriminator() optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999)) optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999)) closure_gen = base_closure(tb.X, tb.MODEL, tb.Y_PRED, tb.Y_TRUE, tb.CRITERION, tb.LOSS, GEN_OPT) closure_disc = base_closure(tb.Y_PRED, DISC_MODEL, None, DISC_IMGS, DISC_CRIT, tb.LOSS, DISC_OPT) def closure(state): closure_gen(state) state[GEN_OPT].step() closure_disc(state) state[DISC_OPT].step() from torchbearer.metrics import mean, running_mean metrics = ['loss', mean(running_mean(D_LOSS)), mean(running_mean(G_LOSS))] trial = tb.Trial(generator,