def test_energy_threshold_termination(self):
    """`terminate` ends the loop once the first sample's energy reaches -1.

    The same threshold predicate is exercised with `key` supplied as an
    attribute-getter callable, as a dotted attribute-path string, and
    omitted entirely (falling back to the loop's default key).
    """
    class ExactSolver(Runnable):
        """Replace the state's samples with dimod's exhaustive solution."""
        def next(self, state):
            exact = dimod.ExactSolver().sample(state.problem)
            return state.updated(samples=exact)

    bqm = dimod.BinaryQuadraticModel({'a': 1}, {}, 0, dimod.SPIN)
    initial = State.from_sample({'a': 1}, bqm)

    # partial(operator.ge, -1)(key) evaluates `-1 >= key`, i.e. key <= -1
    stop_at_minus_one = partial(operator.ge, -1)

    key_variants = (
        {'key': operator.attrgetter('samples.first.energy')},
        {'key': 'samples.first.energy'},
        {},  # no explicit key: use LoopUntilNoImprovement's default
    )
    for extra in key_variants:
        loop = LoopUntilNoImprovement(
            ExactSolver(), terminate=stop_at_minus_one, **extra)
        result = loop.run(initial).result()
        self.assertEqual(result.samples.first.energy, -1)
def test_terminate_predicate(self):
    """A user-supplied `terminate` predicate on the key stops the loop at 3."""
    class Increment(Runnable):
        """Bump the state's `cnt` counter by one."""
        def next(self, state):
            bumped = state.cnt + 1
            return state.updated(cnt=bumped)

    def counter(state):
        return state.cnt

    def reached_three(value):
        return value >= 3

    loop = LoopUntilNoImprovement(
        Increment(), key=counter, terminate=reached_three)
    final = loop.run(State(cnt=0)).result()
    self.assertEqual(final.cnt, 3)
def test_convergence(self):
    """With a key that is always None, the loop exits after `convergence`
    iterations even though `max_iter` would allow many more."""
    class Increment(Runnable):
        """Bump the state's `cnt` counter by one."""
        def next(self, state):
            bumped = state.cnt + 1
            return state.updated(cnt=bumped)

    def constant_key(_state):
        return None

    loop = LoopUntilNoImprovement(
        Increment(), max_iter=1000, convergence=100, key=constant_key)
    final = loop.run(State(cnt=0)).result()
    self.assertEqual(final.cnt, 100)
def test_validation(self):
    """Wrapping a runnable that carries the SIMO trait raises TypeError."""
    class FanOut(Runnable, traits.SIMO):
        """Emit two copies of the incoming state."""
        def next(self, state):
            return States(state, state)

    with self.assertRaises(TypeError):
        LoopUntilNoImprovement(FanOut())
def test_max_iter(self):
    """`max_iter` caps the iteration count across several configurations."""
    class Increment(Runnable):
        """Bump the state's `cnt` counter by one."""
        def next(self, state):
            bumped = state.cnt + 1
            return state.updated(cnt=bumped)

    scenarios = (
        # iterate for `max_iter`
        {'max_iter': 100, 'convergence': 1000, 'key': lambda _: None},
        # `key` function not needed if `convergence` undefined
        {'max_iter': 100, 'convergence': None},
        # `convergence` not needed for simple finite loop
        {'max_iter': 100},
    )
    for options in scenarios:
        loop = LoopUntilNoImprovement(Increment(), **options)
        final = loop.run(State(cnt=0)).result()
        self.assertEqual(final.cnt, 100)
def test_timeout(self):
    """`max_time` stops the loop once the (mocked) elapsed time hits the limit.

    ``time.time`` is patched to return 0, 1, 2, ... on successive calls, so
    the clock advances deterministically. The expected run counts below
    encode how the loop observes that clock relative to each iteration.
    """
    class Inc(Runnable):
        """Bump the state's `cnt` counter by one."""
        def next(self, state):
            return state.updated(cnt=state.cnt + 1)

    cases = (
        # limit of 2: the timeout is hit exactly after the second run
        (2, 2),
        # limit of 2.5 falls between the second and third runs; it is not
        # exceeded until the third run has finished, so three runs complete
        (2.5, 3),
    )
    for max_time, expected_runs in cases:
        # a fresh counter per case so each loop's mocked clock starts at 0
        with mock.patch('time.time', side_effect=itertools.count(), create=True):
            loop = LoopUntilNoImprovement(Inc(), max_time=max_time)
            state = loop.run(State(cnt=0)).result()
            self.assertEqual(state.cnt, expected_runs)