Beispiel #1
0
    def test_forward_x(self):
        """A closure call must pass state[X] to the model as its first positional argument."""
        opt = Mock()

        def model_forward(x):
            return None

        # Autospec the one-argument signature so the mock enforces it like a real function would.
        model = create_autospec(model_forward)

        x = 'test'

        state = {
            torchbearer.X: x,
            torchbearer.MODEL: model,
            torchbearer.Y_TRUE: None,
            torchbearer.CRITERION: lambda x, y: MagicMock(),
            torchbearer.LOSS: None,
            torchbearer.OPTIMIZER: opt,
            torchbearer.CALLBACK_LIST: MagicMock(),
            torchbearer.BACKWARD_ARGS: {}
        }

        closure = base_closure(torchbearer.X, torchbearer.MODEL,
                               torchbearer.Y_PRED, torchbearer.Y_TRUE,
                               torchbearer.CRITERION, torchbearer.LOSS,
                               torchbearer.OPTIMIZER)
        closure(state)
        # assertEqual reports both values on failure; assertTrue(a == b) only reports "False is not true".
        self.assertEqual(model.call_args[0][0], x)
Beispiel #2
0
    def test_callback_list(self):
        """One closure call fires on_forward, on_criterion and on_backward exactly once each."""
        opt = Mock()
        callback_list = Mock()
        callback_list.on_forward = Mock()
        callback_list.on_criterion = Mock()
        callback_list.on_backward = Mock()

        state = {
            torchbearer.X: None,
            torchbearer.MODEL: lambda x: None,
            torchbearer.Y_TRUE: None,
            torchbearer.CRITERION: lambda state: MagicMock(),
            torchbearer.LOSS: None,
            torchbearer.OPTIMIZER: opt,
            torchbearer.CALLBACK_LIST: callback_list,
            torchbearer.BACKWARD_ARGS: {}
        }

        closure = base_closure(torchbearer.X, torchbearer.MODEL,
                               torchbearer.Y_PRED, torchbearer.Y_TRUE,
                               torchbearer.CRITERION, torchbearer.LOSS,
                               torchbearer.OPTIMIZER)
        closure(state)
        # assertEqual shows the actual call count when a hook fires zero or several times.
        self.assertEqual(callback_list.on_forward.call_count, 1)
        self.assertEqual(callback_list.on_criterion.call_count, 1)
        self.assertEqual(callback_list.on_backward.call_count, 1)
Beispiel #3
0
    def test_loss_state(self):
        """A criterion whose only parameter is ``state`` receives the full state dict."""
        predicted = 'yp'
        target = 'yt'

        def criterion_signature(state):
            return None

        # Autospec the single-parameter signature so the mock enforces it.
        criterion_mock = create_autospec(criterion_signature)

        closure = base_closure(torchbearer.X, torchbearer.MODEL,
                               torchbearer.Y_PRED, torchbearer.Y_TRUE,
                               torchbearer.CRITERION, torchbearer.LOSS,
                               torchbearer.OPTIMIZER)

        state = {
            torchbearer.X: None,
            torchbearer.MODEL: lambda x: predicted,
            torchbearer.Y_TRUE: target,
            torchbearer.CRITERION: criterion_mock,
            torchbearer.LOSS: None,
            torchbearer.OPTIMIZER: Mock(),
            torchbearer.CALLBACK_LIST: MagicMock(),
            torchbearer.BACKWARD_ARGS: {},
        }
        closure(state)

        # The first positional argument of the criterion call should be the state dict itself.
        first_positional = criterion_mock.call_args[0][0]
        self.assertDictEqual(first_positional, state)
Beispiel #4
0
    def __init__(self, model, optimizer=None, criterion=None, metrics=None, callbacks=None, verbose=2):
        """Build the state dictionary and the training closure, then fire the on_init callbacks.

        Args:
            model: Stored under ``torchbearer.MODEL``.
            optimizer: Stored under ``torchbearer.OPTIMIZER``. ``MockOptimizer()`` is used when None.
            criterion: Stored under ``torchbearer.CRITERION``. When None, a placeholder that ignores
                its two arguments and returns a one-element zero tensor is used.
            metrics: Iterable of metrics wrapped in a ``MetricList``. None (the default) means no metrics.
            callbacks: Iterable of callbacks wrapped in a ``CallbackList``. Its ``on_init`` hook is
                called with the new state. None (the default) means no callbacks.
            verbose: Stored unchanged as ``self.verbose``.
        """
        # None sentinels instead of `[]` defaults: a default list is created once and shared by
        # every instance, so any downstream mutation of it would leak between objects.
        if metrics is None:
            metrics = []
        if callbacks is None:
            callbacks = []

        if criterion is None:
            def criterion(_, __):
                # DEVICE / DATA_TYPE are read from self.state when the criterion runs, not when it is
                # defined, so the zero loss follows whatever device/dtype the state holds at that time.
                # requires_grad=True presumably lets a backward() call on this loss succeed.
                return torch.zeros(1, device=self.state[torchbearer.DEVICE], dtype=self.state[torchbearer.DATA_TYPE], requires_grad=True)

        self.verbose = verbose

        # Closure wired to the X / MODEL / Y_PRED / Y_TRUE / CRITERION / LOSS / OPTIMIZER state keys.
        self.closure = base_closure(torchbearer.X, torchbearer.MODEL, torchbearer.Y_PRED, torchbearer.Y_TRUE, torchbearer.CRITERION, torchbearer.LOSS, torchbearer.OPTIMIZER)
        self.state = State()
        self.state.update({
            torchbearer.MODEL: model,
            torchbearer.CRITERION: criterion,
            torchbearer.OPTIMIZER: optimizer if optimizer is not None else MockOptimizer(),
            torchbearer.METRIC_LIST: MetricList(metrics),
            torchbearer.CALLBACK_LIST: CallbackList(callbacks),
            torchbearer.DEVICE: 'cpu',
            torchbearer.DATA_TYPE: torch.float32,
            torchbearer.SELF: self,
            torchbearer.HISTORY: [],
            torchbearer.BACKWARD_ARGS: {},
            torchbearer.TRAIN_GENERATOR: None,
            torchbearer.VALIDATION_GENERATOR: None,
            torchbearer.TEST_GENERATOR: None,
            torchbearer.TRAIN_STEPS: None,
            torchbearer.VALIDATION_STEPS: None,
            torchbearer.TEST_STEPS: None,
            torchbearer.TRAIN_DATA: None,
            torchbearer.VALIDATION_DATA: None,
            torchbearer.TEST_DATA: None,
            torchbearer.INF_TRAIN_LOADING: False,
            torchbearer.LOADER: None
        })

        self.state[torchbearer.CALLBACK_LIST].on_init(self.state)
Beispiel #5
0
    def test_loss_multiple_output_no_state(self):
        """A tuple model output and a tuple Y_TRUE are passed to a state-less criterion as
        separate positional arguments: predictions first, then targets."""
        opt = Mock()

        y_pred1 = 'yp1'
        y_pred2 = 'yp2'
        y_true1 = 'yt1'
        y_true2 = 'yt2'

        def loss_sig(y_pred1, y_pred2, y_true1, y_true2):
            return None

        crit = create_autospec(loss_sig)

        state = {
            torchbearer.X: None,
            torchbearer.MODEL: lambda x: (y_pred1, y_pred2),
            torchbearer.Y_TRUE: (y_true1, y_true2),
            torchbearer.CRITERION: crit,
            torchbearer.LOSS: None,
            torchbearer.OPTIMIZER: opt,
            torchbearer.CALLBACK_LIST: MagicMock(),
            torchbearer.BACKWARD_ARGS: {}
        }

        closure = base_closure(torchbearer.X, torchbearer.MODEL,
                               torchbearer.Y_PRED, torchbearer.Y_TRUE,
                               torchbearer.CRITERION, torchbearer.LOSS,
                               torchbearer.OPTIMIZER)
        closure(state)
        # assertEqual prints a tuple diff on failure instead of "False is not true".
        self.assertEqual(crit.call_args[0], (y_pred1, y_pred2, y_true1, y_true2))
Beispiel #6
0
    def test_forward_multiple_x_and_state(self):
        """A list X is unpacked into positional model arguments, and the state dict is passed
        as the ``state`` keyword argument."""
        opt = Mock()

        def model_forward(x1, x2, state):
            return None

        model = create_autospec(model_forward)

        x1 = 'test1'
        x2 = 'test2'

        state = {
            torchbearer.X: [x1, x2],
            torchbearer.MODEL: model,
            torchbearer.Y_TRUE: None,
            torchbearer.CRITERION: lambda x, y: MagicMock(),
            torchbearer.LOSS: None,
            torchbearer.OPTIMIZER: opt,
            torchbearer.CALLBACK_LIST: MagicMock(),
            torchbearer.BACKWARD_ARGS: {}
        }

        closure = base_closure(torchbearer.X, torchbearer.MODEL,
                               torchbearer.Y_PRED, torchbearer.Y_TRUE,
                               torchbearer.CRITERION, torchbearer.LOSS,
                               torchbearer.OPTIMIZER)
        closure(state)
        # assertEqual reports the mismatching values; assertTrue(a == b) does not.
        self.assertEqual(model.call_args[0][0], x1)
        self.assertEqual(model.call_args[0][1], x2)
        self.assertDictEqual(model.call_args[1]['state'], state)
    def test_state_type_error(self):
        """When the model raises TypeError, the closure raises an exception whose first argument has length 2.

        NOTE(review): self.state_model_with_e is defined elsewhere in the class; presumably it
        builds a model that raises the given exception — confirm.
        """
        state = {
            torchbearer.X: None,
            torchbearer.MODEL: self.state_model_with_e(TypeError('test')),
            torchbearer.CRITERION: lambda state: MagicMock(),
            torchbearer.OPTIMIZER: Mock(),
            torchbearer.CALLBACK_LIST: Mock(),
            torchbearer.BACKWARD_ARGS: {}
        }

        closure = base_closure(torchbearer.X, torchbearer.MODEL,
                               torchbearer.Y_PRED, torchbearer.Y_TRUE,
                               torchbearer.CRITERION, torchbearer.LOSS,
                               torchbearer.OPTIMIZER)

        with self.assertRaises(Exception) as context:
            closure(state)

        # assertEqual shows the actual length when it is not 2.
        self.assertEqual(len(context.exception.args[0]), 2)
Beispiel #8
0
    def test_opt(self):
        """One closure call invokes the optimizer's zero_grad exactly once."""
        opt = Mock()
        opt.zero_grad = Mock()
        state = {
            torchbearer.X: None,
            torchbearer.MODEL: lambda x: None,
            torchbearer.Y_TRUE: None,
            torchbearer.CRITERION: lambda x, y: MagicMock(),
            torchbearer.LOSS: None,
            torchbearer.OPTIMIZER: opt,
            torchbearer.CALLBACK_LIST: MagicMock(),
            torchbearer.BACKWARD_ARGS: {}
        }

        closure = base_closure(torchbearer.X, torchbearer.MODEL,
                               torchbearer.Y_PRED, torchbearer.Y_TRUE,
                               torchbearer.CRITERION, torchbearer.LOSS,
                               torchbearer.OPTIMIZER)
        closure(state)
        # assertEqual reports the actual call count on failure.
        self.assertEqual(opt.zero_grad.call_count, 1)
Beispiel #9
0
    def test_backward(self):
        """One closure call invokes backward() exactly once on the loss returned by the criterion."""
        opt = Mock()
        loss = Mock()
        loss.backward = Mock()

        state = {
            torchbearer.X: None,
            torchbearer.MODEL: lambda x: None,
            torchbearer.Y_TRUE: None,
            torchbearer.CRITERION: lambda state: loss,
            torchbearer.LOSS: None,
            torchbearer.OPTIMIZER: opt,
            torchbearer.CALLBACK_LIST: MagicMock(),
            torchbearer.BACKWARD_ARGS: {}
        }

        closure = base_closure(torchbearer.X, torchbearer.MODEL,
                               torchbearer.Y_PRED, torchbearer.Y_TRUE,
                               torchbearer.CRITERION, torchbearer.LOSS,
                               torchbearer.OPTIMIZER)
        closure(state)
        # assertEqual reports the actual call count on failure.
        self.assertEqual(loss.backward.call_count, 1)
Beispiel #10
0
# NOTE(review): dataset, batch_size, lr, Generator, Discriminator, GEN_OPT, DISC_MODEL, DISC_IMGS,
# DISC_CRIT and DISC_OPT are defined outside this excerpt.
# Shuffled loader. drop_last=True discards a final partial batch, so every batch has batch_size samples.
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         drop_last=True)

# Model and optimizer
generator = Generator()
discriminator = Discriminator()
# Each network gets its own Adam optimizer with the same lr and betas=(0.5, 0.999).
optimizer_G = torch.optim.Adam(generator.parameters(),
                               lr=lr,
                               betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(),
                               lr=lr,
                               betas=(0.5, 0.999))

# Generator step: input tb.X -> tb.MODEL -> tb.Y_PRED, then tb.CRITERION against tb.Y_TRUE into
# tb.LOSS, using the optimizer stored under GEN_OPT.
# NOTE(review): presumably tb.MODEL / GEN_OPT end up holding generator / optimizer_G -- set elsewhere.
closure_gen = base_closure(tb.X, tb.MODEL, tb.Y_PRED, tb.Y_TRUE, tb.CRITERION,
                           tb.LOSS, GEN_OPT)
# Discriminator step: feeds tb.Y_PRED (written by closure_gen) into DISC_MODEL, scores it with
# DISC_CRIT against DISC_IMGS, writes the loss to tb.LOSS, and uses the optimizer under DISC_OPT.
# NOTE(review): the prediction key is None here, so the discriminator output may be stored under
# state[None]. tb.LOSS is shared with closure_gen, so this step presumably overwrites the generator
# loss in state -- confirm intended.
closure_disc = base_closure(tb.Y_PRED, DISC_MODEL, None, DISC_IMGS, DISC_CRIT,
                            tb.LOSS, DISC_OPT)


def closure(state):
    """Run one generator update, then one discriminator update, on ``state``.

    Each update runs its base_closure and then calls ``step()`` on the optimizer
    stored in ``state`` under the matching key.
    """
    def _update(run_closure, optimizer_key):
        # Forward/criterion/backward via the closure, then apply the optimizer step.
        run_closure(state)
        state[optimizer_key].step()

    _update(closure_gen, GEN_OPT)
    _update(closure_disc, DISC_OPT)


from torchbearer.metrics import mean, running_mean
# Track the built-in 'loss' metric plus mean(running_mean(...)) wrappers over D_LOSS and G_LOSS
# (presumably the discriminator / generator loss keys defined earlier in the file -- confirm).
metrics = ['loss', mean(running_mean(D_LOSS)), mean(running_mean(G_LOSS))]

trial = tb.Trial(generator,