def _test_new_iteration(self) -> None:
    """Test marking a new iteration in a resolution round (disabled variant).

    The leading underscore keeps default test discovery from collecting this
    method, so it never runs.

    NOTE(review): this variant calls ``add_state`` with two arguments
    (``score, state``) and relies on ``get_last()``. ``test_new_iteration`` calls
    ``add_state(state)`` with a single argument. This may be a leftover from an
    older ``Beam`` API; confirm before re-enabling.
    """
    candidate_beam = Beam(width=2)
    assert not list(candidate_beam.iter_states())
    assert candidate_beam.get_last() is None

    # First iteration: one state goes in and becomes the most recent one.
    first_state = State(score=0.0)
    candidate_beam.add_state(first_state.score, first_state)
    assert list(candidate_beam.iter_states()) == [first_state]
    assert candidate_beam.get_last() is first_state

    # Switching to a new iteration must leave both the stored states and the
    # "last added" marker untouched: iterations do not interleave.
    candidate_beam.new_iteration()
    assert list(candidate_beam.iter_states()) == [first_state]
    assert candidate_beam.get_last() is first_state

    # Second iteration: both states are kept; the newer one is now "last".
    second_state = State(score=0.1)
    candidate_beam.add_state(second_state.score, second_state)
    assert first_state in candidate_beam.iter_states()
    assert second_state in candidate_beam.iter_states()
    assert candidate_beam.get_last() is second_state
def test_new_iteration(self) -> None:
    """Test marking a new iteration in a resolution round.

    Drives a width-2 beam through four iterations and checks that:

    * ``iter_new_added_states`` reports only states added since the last
      ``new_iteration()`` call, while ``iter_states`` keeps earlier states;
    * states pushed out by the width limit are reported by neither iterator,
      including a state evicted in the same iteration it was added;
    * ``iter_new_added_states_sorted`` orders by score, descending by default,
      and keeps insertion order for equal scores in both directions.
    """
    beam = Beam(width=2)
    assert list(beam.iter_states()) == []
    assert list(beam.iter_new_added_states()) == []

    # Iteration 1: a single state is both in the beam and newly added.
    state01 = State(score=0.0)
    beam.add_state(state01)
    assert list(beam.iter_states()) == [state01]
    assert list(beam.iter_new_added_states()) == [state01]

    # A new iteration keeps the beam contents but clears the "newly added" set.
    beam.new_iteration()
    assert list(beam.iter_states()) == [state01]
    assert list(beam.iter_new_added_states()) == []

    # Iteration 2: only the state added in this iteration counts as new.
    state02 = State(score=0.1)
    beam.add_state(state02)
    assert state01 in beam.iter_states()
    assert state02 in beam.iter_states()
    assert list(beam.iter_new_added_states()) == [state02]

    # Iteration 3: three higher-scored states overflow the width-2 beam. Only
    # the two best survive, so state03 is evicted within its own iteration, and
    # state01/state02 from earlier iterations are evicted as well.
    beam.new_iteration()
    assert list(beam.iter_new_added_states()) == []
    state03 = State(score=0.2)
    beam.add_state(state03)
    state04 = State(score=0.3)
    beam.add_state(state04)
    state05 = State(score=0.4)
    beam.add_state(state05)
    assert beam.size == 2
    assert state04 in beam.iter_states()
    assert state05 in beam.iter_states()
    assert state03 not in beam.iter_states()

    new_added = list(beam.iter_new_added_states())
    assert len(new_added) == 2
    assert state04 in new_added
    assert state05 in new_added
    assert state03 not in new_added
    # Default ordering is by score, descending.
    assert list(beam.iter_new_added_states_sorted()) == [state05, state04]
    assert list(beam.iter_new_added_states_sorted(reverse=True)) == [state05, state04]
    assert list(beam.iter_new_added_states_sorted(reverse=False)) == [state04, state05]

    # Iteration 4: two equal scores that beat everything kept from iteration 3,
    # so both previous states must be evicted.
    beam.new_iteration()
    assert list(beam.iter_new_added_states()) == []
    state06 = State(score=1.0)
    beam.add_state(state06)
    state07 = State(score=1.0)
    beam.add_state(state07)
    assert beam.size == 2
    assert state06 in beam.iter_states()
    assert state07 in beam.iter_states()
    assert state04 not in beam.iter_states()
    assert state05 not in beam.iter_states()
    # Ties keep insertion order regardless of the sort direction.
    assert list(beam.iter_new_added_states_sorted()) == [state06, state07]
    assert list(beam.iter_new_added_states_sorted(reverse=True)) == [state06, state07]
    assert list(beam.iter_new_added_states_sorted(reverse=False)) == [state06, state07]