def test_every_iteration_model_updater_with_cost():
    """
    Tests that the model updater can use a different attribute
    from loop_state as the training targets
    """

    class MockModel(IModel):
        """Minimal model that simply records whatever ``set_data`` receives."""

        def optimize(self):
            pass

        def set_data(self, X: np.ndarray, Y: np.ndarray):
            self._X = X
            self._Y = Y

        @property
        def X(self):
            return self._X

        @property
        def Y(self):
            return self._Y

    mock_model = MockModel()
    updater = FixedIntervalUpdater(mock_model, 1, lambda loop_state: loop_state.cost)

    inputs = np.random.rand(5, 1)
    cost = np.random.rand(5, 1)

    loop_state_mock = mock.create_autospec(LoopState)
    loop_state_mock.iteration = 1
    # Assign the arrays as plain attributes. Writing ``attr.return_value(...)``
    # merely *calls* the auto-created child mock and configures nothing, which
    # would leave ``X`` and ``cost`` as mocks rather than real data.
    loop_state_mock.X = inputs
    loop_state_mock.cost = cost

    updater.update(loop_state_mock)

    # The custom extractor's output must be used as the training targets (Y),
    # while the inputs still come from the loop state's X. ``array_equal`` is
    # used because it also requires the shapes to match exactly.
    assert np.array_equal(mock_model.X, inputs)
    assert np.array_equal(mock_model.Y, cost)
def test_every_iteration_model_updater():
    """
    Tests that an updater with interval 1 passes the loop state data to the
    model and optimizes it when ``update`` is called
    """
    mock_model = mock.create_autospec(IModel)
    # ``return_value`` must be assigned; calling it only invokes a child mock.
    mock_model.optimize.return_value = None
    updater = FixedIntervalUpdater(mock_model, 1)

    inputs = np.random.rand(5, 1)
    targets = np.random.rand(5, 1)

    loop_state_mock = mock.create_autospec(LoopState)
    loop_state_mock.iteration = 1
    # Plain attribute assignment so the updater reads real arrays from the
    # loop state instead of unconfigured mocks.
    loop_state_mock.X = inputs
    loop_state_mock.Y = targets

    updater.update(loop_state_mock)

    # With no custom extractor the targets are expected to be loop_state.Y.
    # The same array objects are passed through, so the call comparison
    # matches on identity without an element-wise array comparison.
    mock_model.set_data.assert_called_once_with(inputs, targets)
    mock_model.optimize.assert_called_once()