示例#1
0
    def test_observation_with_nonempty_policy_with_default_actions(self):
        """Check observations when the policy also holds default push actions.

        One non-trainable push group with more than two resources is chosen
        as the "default" group and left out of the action space. Its first
        resource is registered (via ``add_default_push_action``) as the push
        source for each of the group's other resources. Random actions are
        then applied, and after each one the observation must:

        * be contained in ``self.observation_space``;
        * record ``source.source_id + 1`` in the second-to-last entry of every
          resource in ``policy.observable_push``;
        * record ``source.order + 1`` in the last entry of every resource in
          ``policy.observable_preload``;
        * hold 0 in those entries for every other resource.
        """
        # candidates for the default group: non-trainable groups with > 2 resources
        candidate_push_groups = [
            i for i, group in enumerate(self.push_groups)
            if len(group.resources) > 2 and not group.trainable
        ]
        # random.choice raises an opaque IndexError on an empty list; state the
        # fixture requirement explicitly instead
        assert candidate_push_groups, \
            "need a non-trainable push group with more than 2 resources"
        default_group_idx = random.choice(candidate_push_groups)
        default_group = self.push_groups[default_group_idx]
        # use all push groups except the chosen default group
        remaining_groups = [
            group for i, group in enumerate(self.push_groups)
            if i != default_group_idx
        ]
        action_space = ActionSpace(remaining_groups)
        policy = Policy(action_space)

        # apply some default actions: push every other resource of the default
        # group from its first resource
        for push in default_group.resources[1:]:
            policy.add_default_push_action(default_group.resources[0], push)

        # do some actions and check the observation space over time
        for _ in range(len(action_space) - 1):
            # get an action and apply it in the policy
            action_id = action_space.sample()
            policy.apply_action(action_id)

            # get the observation
            obs = get_observation(self.client_environment, self.push_groups,
                                  policy, set())
            assert self.observation_space.contains(obs)
            resources = obs["resources"]

            # make sure the push sources are recorded correctly
            # (+1 since we have defined it that way)
            for source, push in policy.observable_push:
                for push_res in push:
                    assert (resources[str(push_res.order)][-2]
                            == source.source_id + 1)

            # make sure the preload sources are recorded correctly
            # (+1 since we have defined it that way)
            for source, preload in policy.observable_preload:
                for preload_res in preload:
                    assert (resources[str(preload_res.order)][-1]
                            == source.order + 1)

            # check that all other resources are neither pushed nor preloaded
            pushed_res = {push_res.order
                          for _, push in policy.observable_push
                          for push_res in push}
            preloaded_res = {preload_res.order
                             for _, preload in policy.observable_preload
                             for preload_res in preload}
            assert all(res[-2] == 0 for order, res in resources.items()
                       if int(order) not in pushed_res)
            assert all(res[-1] == 0 for order, res in resources.items()
                       if int(order) not in preloaded_res)
示例#2
0
class TestMahiMahiConfig:
    """Tests for MahiMahiConfig construction, record commands and trace files."""

    def setup_method(self):
        """Create the fixtures shared by every test in this class.

        Builds a config, an action space over ``get_push_groups()``, a random
        client environment, and a policy. Sampled actions are applied to the
        policy until ``apply_action`` returns a falsy value.

        This was previously the nose-style ``setup`` hook. pytest deprecated
        that hook in 7.2 and stopped calling it in 8.0, which would leave
        these attributes unset. ``setup_method`` is the xunit-style hook
        pytest calls before each test method.
        """
        self.config = get_config()
        self.action_space = ActionSpace(get_push_groups())
        self.client_environment = get_random_client_environment()
        self.policy = Policy(self.action_space)
        # keep applying sampled actions until one is reported as not applied
        applied = True
        while applied:
            applied = self.policy.apply_action(self.action_space.sample())

    def test_init_without_policy(self):
        """Without a policy or client environment, both attributes are None."""
        mm_config = MahiMahiConfig(self.config)
        assert isinstance(mm_config, MahiMahiConfig)
        assert mm_config.policy is None
        assert mm_config.client_environment is None

    def test_init_without_client_environment(self):
        """The given policy is stored as-is; the client environment stays None."""
        mm_config = MahiMahiConfig(self.config, policy=self.policy)
        assert isinstance(mm_config, MahiMahiConfig)
        assert mm_config.policy is self.policy
        assert mm_config.client_environment is None

    def test_init_with_client_environment(self):
        """Both the policy and the client environment are stored as-is."""
        mm_config = MahiMahiConfig(self.config, policy=self.policy, client_environment=self.client_environment)
        assert isinstance(mm_config, MahiMahiConfig)
        assert mm_config.policy is self.policy
        assert mm_config.client_environment is self.client_environment

    def test_record_shell_with_cmd(self):
        """The shell command is the record command followed by the user command."""
        save_dir = "/tmp/save_dir"
        mm_config = MahiMahiConfig(self.config, policy=self.policy)
        cmd = mm_config.record_shell_with_cmd(save_dir, ["a", "command"])
        assert cmd == (mm_config.record_cmd(save_dir) + ["a", "command"])

    def test_record_cmd(self):
        """The record command starts with ``mm-webrecord`` and the save dir."""
        save_dir = "/tmp/save_dir"
        mm_config = MahiMahiConfig(self.config)
        record_cmd = mm_config.record_cmd(save_dir)
        assert record_cmd[0] == "mm-webrecord"
        assert record_cmd[1] == save_dir

    def test_formatted_trace_file(self):
        """The trace file matches the formatted trace for the env's bandwidth."""
        mm_config = MahiMahiConfig(self.config, policy=self.policy, client_environment=self.client_environment)
        trace_lines = trace_for_kbps(self.client_environment.bandwidth)
        formatted = format_trace_lines(trace_lines)
        assert mm_config.formatted_trace_file == formatted
示例#3
0
    def test_observation_with_nonempty_policy(self):
        """Check observations as random actions are applied to a policy.

        The policy's action space covers all of ``self.push_groups``. After
        each applied action the observation must:

        * be contained in ``self.observation_space``;
        * record ``source.source_id + 1`` in the second-to-last entry of every
          resource in ``policy.push``;
        * record ``source.order + 1`` in the last entry of every resource in
          ``policy.preload``;
        * hold 0 in those entries for every other resource.
        """
        action_space = ActionSpace(self.push_groups)
        policy = Policy(action_space)

        # do some actions and check the observation space over time
        for _ in range(len(action_space) - 1):
            # get an action and apply it in the policy
            action_id = action_space.sample()
            policy.apply_action(action_id)

            # get the observation
            obs = get_observation(self.client_environment, self.push_groups,
                                  policy, set())
            assert self.observation_space.contains(obs)
            resources = obs["resources"]

            # make sure the push sources are recorded correctly
            # (+1 since we have defined it that way)
            for source, push in policy.push:
                for push_res in push:
                    assert (resources[str(push_res.order)][-2]
                            == source.source_id + 1)

            # make sure the preload sources are recorded correctly
            # (+1 since we have defined it that way)
            for source, preload in policy.preload:
                for preload_res in preload:
                    assert (resources[str(preload_res.order)][-1]
                            == source.order + 1)

            # check that all other resources are neither pushed nor preloaded
            pushed_res = {push_res.order
                          for _, push in policy.push
                          for push_res in push}
            preloaded_res = {preload_res.order
                             for _, preload in policy.preload
                             for preload_res in preload}
            assert all(res[-2] == 0 for order, res in resources.items()
                       if int(order) not in pushed_res)
            assert all(res[-1] == 0 for order, res in resources.items()
                       if int(order) not in preloaded_res)
示例#4
0
def get_action(action_space: ActionSpace) -> ActionIDType:
    """Return a sampled action ID that does not decode to a no-op.

    Starts from ``NOOP_ACTION_ID`` and keeps drawing ``action_space.sample()``
    until ``action_space.decode_action(...)`` yields an action whose
    ``is_noop`` is false. That ID is returned.

    NOTE(review): this never terminates if the action space can only
    produce no-op actions.
    """
    candidate = NOOP_ACTION_ID
    while True:
        # presumably NOOP_ACTION_ID decodes to a no-op, so at least one
        # sample is normally drawn before returning
        if not action_space.decode_action(candidate).is_noop:
            return candidate
        candidate = action_space.sample()