Example #1
    def test_set_reward_signal_nan(self) -> None:
        """Test the reward signal if no new state was generated."""
        predictor = MCTS()

        state = flexmock()

        assert predictor._policy == {}
        predictor._next_state = flexmock()
        predictor.set_reward_signal(state, ("numpy", "2.0.0", "https://pypi.org/simple"), math.nan)
        assert predictor._next_state is None
        assert predictor._policy == {}
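
The assertions above only pin down the NaN branch: when no new state came out of the expansion, the predictor forgets the scheduled next state and leaves its policy untouched. A minimal sketch of that behaviour, with a hypothetical class standing in for the real MCTS predictor:

import math


class _MctsNanSketch:
    """Hypothetical minimal predictor; models only the NaN branch asserted above."""

    def __init__(self) -> None:
        self._policy = {}
        self._next_state = None

    def set_reward_signal(self, state, package_tuple, reward) -> None:
        if math.isnan(reward):
            # No new state was generated: drop the scheduled state, keep the policy as-is.
            self._next_state = None
            return
        # Non-NaN rewards are exercised by the later examples.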
Example #2
    def test_pre_run(self) -> None:
        """Test calling pre-run for the initialization part."""
        flexmock(TemporalDifference)
        TemporalDifference.should_receive("pre_run").once()

        state = flexmock()
        predictor = MCTS()
        assert predictor._next_state is None
        predictor._next_state = state
        predictor.pre_run()
        assert predictor._next_state is None
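
The test constrains pre_run in two ways: the scheduled next state must be cleared and the parent TemporalDifference.pre_run must still run exactly once. A hedged sketch under those constraints; the stand-in parent class exists only to keep the snippet self-contained:

class _TemporalDifferencePreRunSketch:
    """Stand-in parent so the sketch is self-contained."""

    def pre_run(self) -> None:
        pass


class _MctsPreRunSketch(_TemporalDifferencePreRunSketch):
    def __init__(self) -> None:
        self._next_state = None

    def pre_run(self) -> None:
        # Drop any state carried over from a previous resolver run, then let the
        # parent class perform its own initialization (the test expects both).
        self._next_state = None
        super().pre_run()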
Example #3
    def test_run_next_state(self, context: Context) -> None:
        """Test running the predictor when the next state is scheduled."""
        state = flexmock()
        unresolved_dependency = ("tensorflow", "2.0.0", "https://pypi.org/simple")
        state.should_receive("get_random_unresolved_dependency").with_args(prefer_recent=True).and_return(
            unresolved_dependency
        ).once()

        predictor = MCTS()
        predictor._next_state = state
        context.beam.should_receive("get_last").and_return(state).once()
        context.iteration = 1000000  # Some big number not to hit the heat-up part.
        with predictor.assigned_context(context):
            assert predictor.run() == (state, unresolved_dependency)
Example #4
    def test_run_heat_up(self, context: Context, next_state) -> None:
        """Test running the predictor in the "heat-up" phase regardless next state being set."""
        state = flexmock()
        unresolved_dependency = ("tensorflow", "2.0.0", "https://pypi.org/simple")

        predictor = MCTS()
        predictor._next_state = next_state

        flexmock(TemporalDifference)
        TemporalDifference.should_receive("run").with_args().and_return(state, unresolved_dependency).once()

        context.iteration = 1  # Some small number to hit the heat-up part.
        with predictor.assigned_context(context):
            assert predictor.run() == (state, unresolved_dependency)
Example #5
    def test_run_next_state_no_last(self, context: Context) -> None:
        """Test running the predictor when the next state is not last state added to beam."""
        state = flexmock()
        unresolved_dependency = ("tensorflow", "2.0.0", "https://pypi.org/simple")

        predictor = MCTS()
        predictor._next_state = flexmock()
        context.beam.should_receive("get_last").and_return(flexmock()).once()

        flexmock(TemporalDifference)
        TemporalDifference.should_receive("run").with_args().and_return(state, unresolved_dependency).once()

        context.iteration = 1000000  # Some big number not to hit the heat-up part.
        with predictor.assigned_context(context):
            assert predictor.run() == (state, unresolved_dependency)
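
Examples #3 to #5 together pin down how run() dispatches: during heat-up it always defers to TemporalDifference, afterwards it keeps expanding the scheduled next state only while that state is still the newest one in the beam, and otherwise it falls back to the parent strategy. A hedged sketch of that dispatch; the class names and the _HEAT_UP_ITERATIONS threshold are illustrative, not the project's real API:

class _TemporalDifferenceRunSketch:
    """Stand-in parent so the sketch is self-contained."""

    def run(self):
        return "td-expansion"  # placeholder for the temporal-difference step


class _MctsRunSketch(_TemporalDifferenceRunSketch):
    _HEAT_UP_ITERATIONS = 1000  # hypothetical cut-off; the tests only use "small" vs. "big"

    def __init__(self, context) -> None:
        self.context = context
        self._next_state = None

    def run(self):
        if self.context.iteration < self._HEAT_UP_ITERATIONS:
            # Heat-up phase (example #4): behave like plain temporal-difference
            # learning regardless of any scheduled next state.
            return super().run()

        next_state = self._next_state
        if next_state is not None and next_state is self.context.beam.get_last():
            # The scheduled state is still the newest one in the beam (example #3):
            # keep expanding it with a recently added unresolved dependency.
            return next_state, next_state.get_random_unresolved_dependency(prefer_recent=True)

        # The scheduled state is stale or absent (example #5): fall back to the parent.
        return super().run()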
Example #6
    def test_set_reward_signal_inf(self) -> None:
        """Test the reward signal if a final state was generated."""
        predictor = MCTS()

        state = flexmock(score=3.1)
        state.should_receive("iter_resolved_dependencies").and_return(
            [
                ("numpy", "2.0.0", "https://pypi.org/simple"),
                ("tensorflow", "2.0.0", "https://thoth-station.ninja/simple"),
            ]
        )
        # numpy was already seen, tensorflow was not seen yet
        predictor._policy = {
            ("numpy", "2.0.0", "https://pypi.org/simple"): [2.3, 100],
        }
        predictor._next_state = flexmock()
        predictor.set_reward_signal(state, ("numpy", "2.0.0", "https://pypi.org/simple"), math.inf)
        assert predictor._next_state is None
        assert predictor._policy == {
            ("numpy", "2.0.0", "https://pypi.org/simple"): [2.3 + 3.1, 101],
            ("tensorflow", "2.0.0", "https://thoth-station.ninja/simple"): [3.1, 1],
        }
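
An infinite reward marks a final state: the assertions show the state's score being credited to every resolved package and the visit count being incremented, after which the scheduled state is dropped. A hedged sketch of just that policy update, using the same [cumulative score, visit count] layout the assertions rely on:

import math


class _MctsInfSketch:
    """Hypothetical minimal predictor; models only the infinite-reward branch asserted above."""

    def __init__(self) -> None:
        self._policy = {}  # package tuple -> [cumulative score, visit count]
        self._next_state = None

    def set_reward_signal(self, state, package_tuple, reward) -> None:
        if math.isinf(reward):
            # A final state was reached: back-propagate its score to every resolved
            # package, bump the visit counts, and stop expanding this state.
            for resolved in state.iter_resolved_dependencies():
                entry = self._policy.setdefault(resolved, [0.0, 0])
                entry[0] += state.score
                entry[1] += 1
            self._next_state = None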