def testBranchingSingleBeamEntry(self):
    """Check beam search with branching enabled but only one beam entry.

    Runs five single-step iterations with ``beam_size=1`` and
    ``branch_factor=32`` from an empty sequence and an initial state of 1,
    then verifies the returned sequence, state and score.
    """
    sequence, state, score = beam_search.beam_search(
        initial_sequence=[],
        initial_state=1,
        generate_step_fn=self._generate_step_fn,
        num_steps=5,
        beam_size=1,
        branch_factor=32,
        steps_per_iteration=1,
    )

    # Here the beam search should greedily choose ones.
    # NOTE: the assertions must stay on their own lines, after the comment;
    # on the same line as the comment they would silently never run.
    self.assertEqual(sequence, [1, 1, 1, 1, 1])
    self.assertEqual(state, 1)
    self.assertEqual(score, 5)
def testNoBranchingMultipleBeamEntries(self):
    """Check beam search with a wide beam and no branching.

    Runs five single-step iterations with ``beam_size=32`` and
    ``branch_factor=1`` from an empty sequence and an initial state of 1,
    then verifies the returned sequence, state and score.
    """
    sequence, state, score = beam_search.beam_search(
        initial_sequence=[],
        initial_state=1,
        generate_step_fn=self._generate_step_fn,
        num_steps=5,
        beam_size=32,
        branch_factor=1,
        steps_per_iteration=1,
    )

    # Here the beam has enough capacity to find the optimal solution without
    # branching.
    # NOTE: the assertions must stay on their own lines, after the comment;
    # on the same line as the comment they would silently never run.
    self.assertEqual(sequence, [0, 0, 0, 0, 1])
    self.assertEqual(state, 1)
    self.assertEqual(score, 16)
def testNoBranchingMultipleStepsPerIteration(self):
    """Check beam search with one beam entry, no branching, two steps per iteration.

    Runs five steps with ``beam_size=1``, ``branch_factor=1`` and
    ``steps_per_iteration=2`` from an empty sequence and an initial state
    of 1, then verifies the returned sequence, state and score.
    """
    sequence, state, score = beam_search.beam_search(
        initial_sequence=[],
        initial_state=1,
        generate_step_fn=self._generate_step_fn,
        num_steps=5,
        beam_size=1,
        branch_factor=1,
        steps_per_iteration=2,
    )

    # As in testNoBranchingSingleStepPerIteration, the counter should never
    # reach one as only a single sequence is ever considered.
    # NOTE: the assertions must stay on their own lines, after the comment;
    # on the same line as the comment they would silently never run.
    self.assertEqual(sequence, [0, 0, 0, 0, 0])
    self.assertEqual(state, 32)
    self.assertEqual(score, 0)
def testNoBranchingSingleStepPerIteration(self):
    """Check beam search with one beam entry, no branching, one step per iteration.

    Runs five single-step iterations with ``beam_size=1`` and
    ``branch_factor=1`` from an empty sequence and an initial state of 1,
    then verifies the returned sequence, state and score.
    """
    sequence, state, score = beam_search.beam_search(
        initial_sequence=[],
        initial_state=1,
        generate_step_fn=self._generate_step_fn,
        num_steps=5,
        beam_size=1,
        branch_factor=1,
        steps_per_iteration=1,
    )

    # The generator should emit all zeros, as only a single sequence is ever
    # considered so the counter doesn't reach one.
    # NOTE: the assertions must stay on their own lines, after the comment;
    # on the same line as the comment they would silently never run.
    self.assertEqual(sequence, [0, 0, 0, 0, 0])
    self.assertEqual(state, 32)
    self.assertEqual(score, 0)