Пример #1
0
    def _test_fitting(self, params, logs, has_valid=True):
        """Check the epoch logs and the callback calls recorded for a fit.

        Args:
            params: Dict read for its 'epochs' and 'steps' entries; it is also
                the exact argument expected in the recorded ``set_params`` call.
            logs: Sequence of per-epoch log dicts; one entry per epoch is
                expected, numbered from 1 through its 'epoch' key.
            has_valid: When True, each epoch log is additionally expected to
                carry the ``val_``-prefixed metrics and a ``val_loss`` entry.
        """
        self.assertEqual(len(logs), params['epochs'])

        # Loss values are not pinned down: ANY compares equal to anything.
        expected_train = {**dict(zip(self.metrics_names, self.metrics_values)), 'loss': ANY}
        expected_epoch = expected_train
        if has_valid:
            prefixed_names = ['val_' + name for name in self.metrics_names]
            expected_epoch = dict(expected_train)
            expected_epoch.update(zip(prefixed_names, self.metrics_values))
            expected_epoch['val_loss'] = ANY

        for epoch_idx, epoch_log in enumerate(logs, start=1):
            self.assertEqual(epoch_log, {**expected_epoch, 'epoch': epoch_idx})

        # Build the full callback sequence expected after the two setup calls.
        expected_calls = [call.on_train_begin({})]
        for epoch_idx in range(1, params['epochs'] + 1):
            expected_calls.append(call.on_epoch_begin(epoch_idx, {}))
            for step_idx in range(1, params['steps'] + 1):
                batch_logs = {'batch': step_idx, 'size': ANY, **expected_train}
                expected_calls += [
                    call.on_batch_begin(step_idx, {}),
                    call.on_backward_end(step_idx),
                    call.on_batch_end(step_idx, batch_logs),
                ]
            epoch_logs = {'epoch': epoch_idx, **expected_epoch}
            expected_calls.append(call.on_epoch_end(epoch_idx, epoch_logs))
        expected_calls.append(call.on_train_end({}))

        recorded = self.mock_callback.method_calls

        # set_model and set_params must be the first two calls, in either order.
        setup_calls = recorded[:2]
        self.assertIn(call.set_model(self.model), setup_calls)
        self.assertIn(call.set_params(params), setup_calls)

        self.assertEqual(len(recorded), len(expected_calls) + 2)
        self.assertEqual(recorded[2:], expected_calls)
Пример #2
0
    def _test_batch_delay(self, epoch_delay, batch_in_epoch_delay):
        """Fit with a batch-delayed callback and check what the mock records.

        The delay handed to ``DelayCallback`` is ``epoch_delay`` whole epochs
        plus ``batch_in_epoch_delay`` further batches. The mock is then
        expected to record ``on_train_begin``, followed by epoch and batch
        events starting at epoch ``epoch_delay + 1``, step
        ``batch_in_epoch_delay + 1``, and finally ``on_train_end``.

        Args:
            epoch_delay: Number of full epochs covered by the delay.
            batch_in_epoch_delay: Extra batches of delay inside the first
                epoch that is reported.
        """
        settings = DelayCallbackTest
        total_batch_delay = epoch_delay * settings.steps_per_epoch + batch_in_epoch_delay
        delayed_callback = DelayCallback(self.mock_callback, batch_delay=total_batch_delay)
        train_gen = some_data_generator(settings.batch_size)
        valid_gen = some_data_generator(settings.batch_size)
        # The returned logs are not inspected; only the recorded calls matter.
        self.model.fit_generator(
            train_gen,
            valid_gen,
            epochs=settings.epochs,
            steps_per_epoch=settings.steps_per_epoch,
            validation_steps=settings.steps_per_epoch,
            callbacks=[delayed_callback],
        )
        params = {'epochs': settings.epochs, 'steps': settings.steps_per_epoch}

        first_epoch = epoch_delay + 1
        expected_calls = [call.on_train_begin({})]
        for epoch_idx in range(first_epoch, settings.epochs + 1):
            expected_calls.append(call.on_epoch_begin(epoch_idx, {}))
            # Only the first reported epoch starts part-way through its steps.
            if epoch_idx == first_epoch:
                first_step = batch_in_epoch_delay + 1
            else:
                first_step = 1
            for step_idx in range(first_step, params['steps'] + 1):
                batch_logs = {'batch': step_idx, 'size': settings.batch_size, **self.train_dict}
                expected_calls += [
                    call.on_batch_begin(step_idx, {}),
                    call.on_backward_end(step_idx),
                    call.on_batch_end(step_idx, batch_logs),
                ]
            epoch_logs = {'epoch': epoch_idx, **self.log_dict}
            expected_calls.append(call.on_epoch_end(epoch_idx, epoch_logs))
        expected_calls.append(call.on_train_end({}))

        recorded = self.mock_callback.method_calls

        # set_model and set_params must be the first two calls, in either order.
        setup_calls = recorded[:2]
        self.assertIn(call.set_model(self.model), setup_calls)
        self.assertIn(call.set_params(params), setup_calls)

        self.assertEqual(len(recorded), len(expected_calls) + 2)
        self.assertEqual(recorded[2:], expected_calls)