Ejemplo n.º 1
0
    def test_remove(self) -> None:
        """Exercise ``Beam.remove`` and check size, membership and max afterwards."""
        beam = Beam(width=2)

        # Every state is created, added and removed again in turn; the beam is
        # expected to be empty once all three rounds are done.
        states = []
        for score in (0.0, 1.0, 2.0):
            state = State(score=score)
            beam.add_state(state)
            beam.remove(state)
            states.append(state)

        lowest, middle, highest = states

        assert beam.size == 0

        # Three states go into a beam created with width=2; the assertions below
        # expect the lowest-scoring one not to be listed afterwards.
        for state in states:
            beam.add_state(state)

        assert lowest not in beam.iter_states()
        assert middle in beam.iter_states()
        assert highest in beam.iter_states()

        # TODO: uncomment once fixed https://github.com/thoth-station/adviser/issues/1541
        # with pytest.raises(ValueError):
        #     beam.remove(lowest)

        assert beam.max() is highest

        # Removing the middle state must leave the highest one as the only entry.
        beam.remove(middle)

        assert beam.max() is highest
        assert middle not in beam.iter_states()
        assert highest in beam.iter_states()
        assert beam.size == 1

        # Removing the last remaining state empties the beam.
        beam.remove(highest)

        assert beam.size == 0
        assert highest not in beam.iter_states()
Ejemplo n.º 2
0
    def test_remove(self) -> None:
        """Exercise ``Beam.remove``, including removal of a state the beam does not list."""
        beam = Beam(width=2)

        # Every state is created, added and removed again in turn; the beam is
        # expected to be empty once all three rounds are done.
        created = []
        for score in (0.0, 1.0, 2.0):
            candidate = State(score=score)
            beam.add_state(candidate)
            beam.remove(candidate)
            created.append(candidate)

        lowest, middle, highest = created

        assert beam.size == 0

        # Three states go into a beam created with width=2; the assertions below
        # expect the lowest-scoring one not to be listed afterwards.
        for candidate in created:
            beam.add_state(candidate)

        assert lowest not in beam.iter_states()
        assert middle in beam.iter_states()
        assert highest in beam.iter_states()

        # The lowest state is no longer listed, so removing it is expected to fail.
        with pytest.raises(KeyError):
            beam.remove(lowest)

        # Removing the middle state must leave the highest one as the only entry.
        beam.remove(middle)

        assert middle not in beam.iter_states()
        assert highest in beam.iter_states()
        assert beam.size == 1

        # Removing the last remaining state empties the beam.
        beam.remove(highest)

        assert beam.size == 0
        assert highest not in beam.iter_states()