コード例 #1
0
def test_every_iteration_model_updater_with_cost():
    """
    Tests that the model updater can use a different attribute from loop_state as the training targets.

    The updater is built with a targets extractor that reads ``loop_state.cost``. After a single
    update the recording model is expected to hold ``loop_state.X`` as its inputs and
    ``loop_state.cost`` as its targets.
    """

    class MockModel(IModel):
        """Minimal model that simply records whatever is passed to ``set_data``."""

        def optimize(self):
            pass

        def set_data(self, X: np.ndarray, Y: np.ndarray):
            self._X = X
            self._Y = Y

        @property
        def X(self):
            return self._X

        @property
        def Y(self):
            return self._Y

    mock_model = MockModel()
    updater = FixedIntervalUpdater(mock_model, 1, lambda loop_state: loop_state.cost)

    # Distinct, deterministic arrays so that mixing up inputs and targets cannot go unnoticed.
    inputs = np.arange(5.0).reshape(5, 1)
    cost = np.arange(5.0, 10.0).reshape(5, 1)

    loop_state_mock = mock.create_autospec(LoopState)
    loop_state_mock.iteration = 1
    # Assign the attributes directly. Writing ``loop_state_mock.X.return_value(array)`` only
    # *calls* a child mock and discards the array, leaving ``X``/``cost`` as bare mocks and
    # making any array comparison against them meaningless.
    loop_state_mock.X = inputs
    loop_state_mock.cost = cost

    updater.update(loop_state_mock)

    # The inputs must be forwarded unchanged and the extracted cost must become the targets.
    assert np.array_equal(mock_model.X, inputs)
    assert np.array_equal(mock_model.Y, cost)
コード例 #2
0
def test_every_iteration_model_updater():
    """
    Tests that an updater with an interval of 1 updates the model on the given iteration.

    The model is an autospecced ``IModel`` mock; the test passes when one call to
    ``updater.update`` results in exactly one ``set_data`` call and one ``optimize`` call.
    """
    mock_model = mock.create_autospec(IModel)
    # ``return_value`` has to be assigned; calling it (``return_value(None)``) configures nothing.
    mock_model.optimize.return_value = None
    updater = FixedIntervalUpdater(mock_model, 1)

    loop_state_mock = mock.create_autospec(LoopState)
    loop_state_mock.iteration = 1
    # Provide real arrays as plain attributes so the updater sees actual data, not child mocks.
    loop_state_mock.X = np.random.rand(5, 1)
    loop_state_mock.Y = np.random.rand(5, 1)

    updater.update(loop_state_mock)

    mock_model.set_data.assert_called_once()
    mock_model.optimize.assert_called_once()