def test_energy_threshold_termination(self):
        """The loop stops once the best energy reaches the -1 threshold.

        The BQM has a single spin variable ``a`` with linear bias 1. Its ground
        state is ``a = -1`` with energy -1, while the starting sample ``a = 1``
        has energy +1. The same threshold is exercised with three ways of
        specifying ``key``:

        * an ``operator.attrgetter``;
        * a dotted attribute string;
        * the loop's default key.
        """
        class BruteForce(Runnable):
            # Replaces the state's samples with dimod's exhaustive solution
            # of the state's problem.
            def next(self, state):
                solved = dimod.ExactSolver().sample(state.problem)
                return state.updated(samples=solved)

        model = dimod.BinaryQuadraticModel({'a': 1}, {}, 0, dimod.SPIN)
        initial = State.from_sample({'a': 1}, model)

        key_variants = (
            {'key': operator.attrgetter('samples.first.energy')},
            {'key': 'samples.first.energy'},
            {},  # default key -- presumably the best sample energy; confirm
        )
        for extra in key_variants:
            # partial(operator.ge, -1)(k) evaluates ``-1 >= k``, i.e. stop
            # once the key value is at or below -1.
            loop = LoopUntilNoImprovement(
                BruteForce(), terminate=partial(operator.ge, -1), **extra)
            final = loop.run(initial).result()
            self.assertEqual(final.samples.first.energy, -1)
Ejemplo n.º 2
0
    def test_terminate_predicate(self):
        """A user-supplied ``terminate`` predicate on the key ends the loop."""
        class Increment(Runnable):
            def next(self, state):
                bumped = state.cnt + 1
                return state.updated(cnt=bumped)

        def read_count(state):
            return state.cnt

        def reached_three(count):
            return count >= 3

        loop = LoopUntilNoImprovement(Increment(),
                                      key=read_count,
                                      terminate=reached_three)
        final = loop.run(State(cnt=0)).result()

        # Expect the loop to stop with the counter at exactly 3.
        self.assertEqual(final.cnt, 3)
Ejemplo n.º 3
0
    def test_convergence(self):
        """A key that never changes triggers the ``convergence`` stop."""
        class Increment(Runnable):
            def next(self, state):
                return state.updated(cnt=state.cnt + 1)

        def constant_key(state):
            # Always the same value, so the key never changes between iterations.
            return None

        loop = LoopUntilNoImprovement(
            Increment(), max_iter=1000, convergence=100, key=constant_key)
        final = loop.run(State(cnt=0)).result()

        # convergence (100) is expected to stop the loop well before max_iter (1000).
        self.assertEqual(final.cnt, 100)
Ejemplo n.º 4
0
    def test_validation(self):
        """Wrapping a SIMO (single-in, multi-out) runnable raises TypeError."""
        class Splitter(Runnable, traits.SIMO):
            def next(self, state):
                duplicated = States(state, state)
                return duplicated

        with self.assertRaises(TypeError):
            subject = Splitter()
            LoopUntilNoImprovement(subject)
Ejemplo n.º 5
0
    def test_max_iter(self):
        """``max_iter`` bounds the loop in each supported configuration."""
        class Increment(Runnable):
            def next(self, state):
                return state.updated(cnt=state.cnt + 1)

        configurations = [
            # iterate for `max_iter`; the convergence threshold is larger
            dict(max_iter=100, convergence=1000, key=lambda _: None),
            # `key` function not needed if `convergence` undefined
            dict(max_iter=100, convergence=None),
            # `convergence` not needed for simple finite loop
            dict(max_iter=100),
        ]
        for options in configurations:
            loop = LoopUntilNoImprovement(Increment(), **options)
            final = loop.run(State(cnt=0)).result()
            self.assertEqual(final.cnt, 100)
Ejemplo n.º 6
0
    def test_timeout(self):
        """``max_time`` stops the loop based on ``time.time()`` readings.

        ``time.time`` is patched to return 0, 1, 2, ... on successive calls.
        Elapsed "time" is therefore driven by how often the loop reads the
        clock, which is an implementation detail of LoopUntilNoImprovement.
        """
        class Increment(Runnable):
            def next(self, state):
                return state.updated(cnt=state.cnt + 1)

        # (max_time, expected number of iterations)
        cases = [
            (2, 2),    # timeout after exactly two runs
            (2.5, 3),  # fractional limit is expected to allow one more run
        ]
        for limit, expected in cases:
            with mock.patch('time.time', side_effect=itertools.count(), create=True):
                loop = LoopUntilNoImprovement(Increment(), max_time=limit)
                final = loop.run(State(cnt=0)).result()
                self.assertEqual(final.cnt, expected)