def test_branch(self):
    # Build a Branch explicitly from a one-element components tuple.
    explicit_branch = Branch(components=(self.RunnableA(), ))
    self.assertEqual(self.children(explicit_branch), ['RunnableA'])

    # Build a Branch implicitly with `|`; children must keep operand order.
    piped_branch = self.RunnableA() | self.RunnableB()
    self.assertEqual(self.children(piped_branch), ['RunnableA', 'RunnableB'])
def test_branch_with_single_component(self):
    """Traits requirements from inner runnable must be reflected in branch."""

    class ValidSISO(Runnable, traits.SISO):
        def next(self, state):
            return state

    # A branch wrapping a SISO component must reject multi-state input,
    # but pass a single State through unchanged.
    with self.assertRaises(traits.StateDimensionalityError):
        Branch(components=(ValidSISO(),)).run(States()).result()
    self.assertEqual(Branch(components=(ValidSISO(),)).run(State(x=1)).result().x, 1)

    class InvalidSISO(Runnable, traits.SISO):
        def next(self, state):
            return States(state, state)

    # Each expectation gets its own context manager. Inside a shared
    # `assertRaises`, the second call would never run because the first
    # one raises.
    # Single-state input is accepted, but this component returns States,
    # which should violate the SISO output requirement.
    with self.assertRaises(traits.StateDimensionalityError):
        Branch(components=(InvalidSISO(),)).run(State()).result()
    # Multi-state input should be rejected before `next` is reached
    # (same input check as in the ValidSISO case above).
    with self.assertRaises(traits.StateDimensionalityError):
        Branch(components=(InvalidSISO(),)).run(States()).result()

    # input: list of states with subproblem
    # output: list of states with subsamples
    class SubproblemSamplerMIMO(Runnable, traits.MIMO, traits.SubproblemSampler):
        def next(self, states):
            return States(State(subsamples=1), State(subsamples=2))

    # MIMO component: a single State input is the wrong dimensionality.
    with self.assertRaises(traits.StateDimensionalityError):
        Branch(components=(SubproblemSamplerMIMO(),)).run(State()).result()
    # Correct dimensionality, but the input state lacks the required `subproblem`.
    with self.assertRaises(traits.StateTraitMissingError):
        Branch(components=(SubproblemSamplerMIMO(),)).run(States(State())).result()

    r = Branch(components=(SubproblemSamplerMIMO(),)).run(States(State(subproblem=True))).result()
    self.assertEqual(r[0].subsamples, 1)
    self.assertEqual(r[1].subsamples, 2)
def test_composition(self):
    class A(Runnable):
        def next(self, state):
            return state.updated(x=state.x + 1)

    class B(Runnable):
        def next(self, state):
            return state.updated(x=state.x * 7)

    a, b = A(), B()
    initial = State(x=1)

    # Explicit two-stage branch: increment, then multiply by 7.
    first_branch = Branch(components=(a, b))
    self.assertEqual(first_branch.components, (a, b))
    self.assertEqual(first_branch.run(initial).result().x, (initial.x + 1) * 7)

    # Piping more runnables onto an existing branch yields a flat component
    # tuple, applied left to right.
    extended_branch = first_branch | b | a
    self.assertEqual(extended_branch.components, (a, b, b, a))
    self.assertEqual(extended_branch.run(initial).result().x, (initial.x + 1) * 7 * 7 + 1)

    # Composing with a non-runnable operand is rejected.
    with self.assertRaises(TypeError):
        first_branch | 1
def test_stop(self):
    class Stoppable(Runnable):
        def init(self, state):
            self.stopped = False

        def next(self, state):
            return state

        def halt(self):
            # Record that the halt hook was invoked.
            self.stopped = True

    branch = Branch([Stoppable()])
    branch.run(State())
    # NOTE(review): the future returned by run() is not awaited before stop().
    # If init() can execute after halt(), `stopped` would be reset to False.
    # Confirm that run/stop ordering guarantees this cannot happen.
    branch.stop()

    # Iterating the branch yields its components; the first (and only) one
    # should have been halted.
    first_component = next(iter(branch))
    self.assertTrue(first_component.stopped)
def __or__(self, other):
    """Pipe ``self`` into ``other`` (left to right).

    Returns:
        A new runnable :class:`Branch` whose components are ``(self, other)``.
    """
    pipeline = (self, other)
    return Branch(components=pipeline)
def test_empty(self):
    # Constructing a Branch with no components must be rejected.
    self.assertRaises(ValueError, Branch)