コード例 #1
0
ファイル: test_linear_shift.py プロジェクト: Fhrozen/chainer
    def setUp(self):
        """Create a LinearShift extension and a trainer with a mocked updater.

        ``self.value_range`` and ``self.time_range`` are read but never set
        in this method; presumably a test-parameterization decorator
        supplies them (not visible here -- confirm).
        """
        optimizer = mock.MagicMock()
        self.optimizer = optimizer
        self.extension = extensions.LinearShift(
            'x', self.value_range, self.time_range, optimizer)

        # The trigger is built from a (2, 'iteration') interval spec.
        interval = 2
        self.interval = interval
        self.trigger = training.get_trigger((interval, 'iteration'))

        trainer = testing.get_trainer_with_mock_updater(self.trigger)
        self.trainer = trainer
        # Have the mocked updater return the very optimizer mock that the
        # extension was constructed with.
        trainer.updater.get_optimizer.return_value = optimizer
コード例 #2
0
    def __init__(self, *args, **kwargs):
        """Consume this class's own options, then defer to the base class.

        Keyword arguments removed from ``kwargs`` before ``super().__init__``:
            model: stored as ``self.recognizer``; required (``KeyError`` if
                missing).
            tensorboard_handle: stored as ``self.tensorboard_handle``;
                defaults to ``None``.
            tensorboard_log_interval: spec handed to ``get_trigger`` to build
                ``self.tensorboard_trigger``; defaults to ``(1, 'iteration')``.

        Every other positional and keyword argument is forwarded unchanged
        to the parent initializer.
        """
        default_log_interval = (1, 'iteration')
        self.recognizer = kwargs.pop('model')
        self.tensorboard_handle = kwargs.pop('tensorboard_handle', None)
        log_interval = kwargs.pop('tensorboard_log_interval',
                                  default_log_interval)
        self.tensorboard_trigger = get_trigger(log_interval)

        # A stand-in trainer whose ``updater`` is this object. Presumably it
        # exists so trigger objects that expect a trainer can be evaluated
        # from inside the updater (confirm against callers).
        fake_trainer = Mock()
        fake_trainer.updater = self
        self.mocked_trainer = fake_trainer

        super().__init__(*args, **kwargs)
コード例 #3
0
ファイル: test_step_shift.py プロジェクト: asi1024/chainer
    def setUp(self):
        """Create a StepShift extension and a trainer with a mocked updater.

        ``self.gamma``, ``self.step``, ``self.init``, ``self.target`` and
        ``self.expect`` are read but not set here; presumably they come from
        a test-parameterization decorator (confirm).
        """
        optimizer = mock.MagicMock()
        self.optimizer = optimizer
        self.extension = extensions.StepShift(
            'x', self.gamma, self.step, self.init, self.target, optimizer)

        interval = 1
        self.interval = interval
        # Repeat each expected value ``interval`` times. With interval == 1
        # the result is an element-by-element copy of the original list.
        repeated = []
        for value in self.expect:
            repeated.extend([value] * interval)
        self.expect = repeated
        self.trigger = training.get_trigger((interval, 'iteration'))

        trainer = testing.get_trainer_with_mock_updater(self.trigger)
        self.trainer = trainer
        # The mocked updater hands back the optimizer given to the extension.
        trainer.updater.get_optimizer.return_value = optimizer
コード例 #4
0
    def setUp(self):
        """Set up a LinearShift extension wired to a mock-updater trainer.

        ``self.value_range`` / ``self.time_range`` must already exist on the
        test instance; they are presumably injected by a parameterization
        decorator outside this snippet (confirm).
        """
        opt_mock = mock.MagicMock()
        self.optimizer = opt_mock
        shift_args = ('x', self.value_range, self.time_range, opt_mock)
        self.extension = extensions.LinearShift(*shift_args)

        step_interval = 2
        self.interval = step_interval
        trigger_spec = (step_interval, 'iteration')
        self.trigger = training.get_trigger(trigger_spec)

        mocked_trainer = testing.get_trainer_with_mock_updater(self.trigger)
        self.trainer = mocked_trainer
        # Route the updater's optimizer lookup to the same mock the extension
        # holds.
        mocked_trainer.updater.get_optimizer.return_value = opt_mock
コード例 #5
0
    def setUp(self):
        """Set up a StepShift extension wired to a mock-updater trainer.

        Reads ``gamma``, ``step``, ``init``, ``target`` and ``expect`` from the
        test instance; presumably these are provided by a parameterization
        decorator not shown here (confirm).
        """
        opt_mock = mock.MagicMock()
        self.optimizer = opt_mock
        shift_args = ('x', self.gamma, self.step, self.init, self.target,
                      opt_mock)
        self.extension = extensions.StepShift(*shift_args)

        step_interval = 1
        self.interval = step_interval
        # Each expected value is duplicated ``step_interval`` times in order.
        self.expect = [
            item
            for item in self.expect
            for _repeat in range(step_interval)
        ]
        trigger_spec = (step_interval, 'iteration')
        self.trigger = training.get_trigger(trigger_spec)

        mocked_trainer = testing.get_trainer_with_mock_updater(self.trigger)
        self.trainer = mocked_trainer
        mocked_trainer.updater.get_optimizer.return_value = opt_mock
コード例 #6
0
    def __init__(self, *args, **kwargs):
        """Consume this class's own options, initialize the base, add regularizers.

        Keyword arguments removed from ``kwargs`` before ``super().__init__``:
            models: a two-item iterable unpacked into ``self.localizer`` and
                ``self.recognizer``; required (``KeyError`` if missing).
            tensorboard_handle: stored as ``self.tensorboard_handle``;
                defaults to ``None``.
            tensorboard_log_interval: spec handed to ``get_trigger`` to build
                ``self.tensorboard_trigger``; defaults to ``(1, 'iteration')``.
            recognizer_update_interval: stored as
                ``self.recognizer_update_interval``; defaults to ``1``.

        Remaining arguments are forwarded to the parent initializer. After
        that, ``self.regularizers`` is set to two loss calculators, both built
        from ``self.localizer.xp``.
        """
        default_log_interval = (1, 'iteration')
        localizer, recognizer = kwargs.pop('models')
        self.localizer = localizer
        self.recognizer = recognizer
        self.tensorboard_handle = kwargs.pop('tensorboard_handle', None)
        log_interval = kwargs.pop('tensorboard_log_interval',
                                  default_log_interval)
        self.recognizer_update_interval = kwargs.pop(
            'recognizer_update_interval', 1)
        self.tensorboard_trigger = get_trigger(log_interval)

        # A stand-in trainer whose ``updater`` is this object. Presumably it
        # lets trainer-based triggers run from inside the updater (confirm).
        fake_trainer = Mock()
        fake_trainer.updater = self
        self.mocked_trainer = fake_trainer

        super().__init__(*args, **kwargs)

        # Both regularizers are built from the localizer's array module.
        array_module = self.localizer.xp
        self.regularizers = [
            DirectionLossCalculator(array_module),
            OutOfImageLossCalculator(array_module),
        ]
コード例 #7
0
 def setUp(self):
     """Grab a mocked trainer, its 'main' optimizer, and a 2-iteration trigger.

     ``_get_mocked_trainer`` is a helper defined outside this snippet.
     """
     trainer = _get_mocked_trainer()
     self.trainer = trainer
     self.optimizer = trainer.updater.get_optimizer('main')
     step_interval = 2
     self.interval = step_interval
     self.trigger = training.get_trigger((step_interval, 'iteration'))
コード例 #8
0
 def setUp(self):
     """Prepare a mocked trainer, its 'main' optimizer, and an interval trigger.

     Relies on ``_get_mocked_trainer``, a helper that is not shown here.
     """
     mocked = _get_mocked_trainer()
     self.trainer = mocked
     updater = mocked.updater
     self.optimizer = updater.get_optimizer('main')
     self.interval = 2
     trigger_spec = (self.interval, 'iteration')
     self.trigger = training.get_trigger(trigger_spec)