Beispiel #1
    def test_observation_with_nonempty_policy_with_default_actions(self):
        """Observations stay valid when the policy also has default pushes.

        A random non-trainable push group with more than two resources is
        chosen as the "default" group: its first resource is registered as the
        default push source for each of its other resources, and the group is
        excluded from the action space. Random actions are then applied and,
        after each one, the observation must lie in the observation space and
        encode the observable push/preload sources of the policy.
        """
        # use all push groups except the chosen default group
        candidate_push_groups = [
            i for i, group in enumerate(self.push_groups)
            if len(group.resources) > 2 and not group.trainable
        ]
        default_group_idx = random.choice(candidate_push_groups)
        default_group = self.push_groups[default_group_idx]
        remaining_groups = [
            group for i, group in enumerate(self.push_groups)
            if i != default_group_idx
        ]
        action_space = ActionSpace(remaining_groups)
        policy = Policy(action_space)

        # apply some default action
        for push in default_group.resources[1:]:
            policy.add_default_push_action(default_group.resources[0], push)

        # do some actions and check the observation space over time
        for _ in range(len(action_space) - 1):
            # get an action and apply it in the policy, so the checks below
            # run against a policy that actually grows over time
            action_id = action_space.sample()
            policy.apply_action(action_space.decode_action(action_id))

            # get the observation
            obs = get_observation(self.client_environment, self.push_groups,
                                  policy, set())
            assert self.observation_space.contains(obs)

            # make sure the push sources are recorded correctly: entry [-2]
            # holds source_id + 1 (0 is reserved for "not pushed", see below)
            for (source, push) in policy.observable_push:
                for push_res in push:
                    assert obs["resources"][str(
                        push_res.order)][-2] == source.source_id + 1

            # make sure the preload sources are recorded correctly: entry [-1]
            # holds the source's order + 1 (0 is reserved for "not preloaded")
            for (source, preload) in policy.observable_preload:
                for preload_res in preload:
                    assert obs["resources"][str(
                        preload_res.order)][-1] == source.order + 1

            # check that all other resources are neither pushed nor preloaded
            pushed_res = {push_res.order
                          for (_, push) in policy.observable_push
                          for push_res in push}
            preloaded_res = {preload_res.order
                             for (_, preload) in policy.observable_preload
                             for preload_res in preload}
            assert all(res[-2] == 0 for order, res in obs["resources"].items()
                       if int(order) not in pushed_res)
            assert all(res[-1] == 0 for order, res in obs["resources"].items()
                       if int(order) not in preloaded_res)
Beispiel #2
    def test_apply_action_noop_as_second_action(self):
        """A noop applied after a real push is rejected and changes nothing."""
        policy = Policy(get_action_space())

        # first a real push action, which must succeed
        first_applied = policy.apply_action((1, (0, 0, 1), (0, 0)))
        push_before = list(policy.push)
        assert first_applied
        assert push_before
        assert len(policy) == 1

        # then a noop, which must be rejected and leave the policy untouched
        second_applied = policy.apply_action(Action())
        assert not second_applied
        assert push_before == list(policy.push)
        assert len(policy) == 1
Beispiel #3
    def test_resource_push_from(self):
        """Push/preload source lookups reflect the actions that were applied.

        Before an action is applied its target has no recorded source; after
        applying it, the lookup returns exactly the action's source object.
        """
        action_space = get_action_space()
        policy = Policy(action_space)

        # draw actions until we hit a real (non-noop) push action
        chosen = Action()
        while True:
            if not chosen.is_noop and chosen.is_push:
                break
            chosen = action_space.decode_action(action_space.sample())
        assert policy.resource_pushed_from(chosen.push) is None
        assert policy.apply_action(chosen)
        assert policy.resource_pushed_from(chosen.push) is chosen.source

        # keep drawing until we hit a real (non-noop) preload action
        while True:
            if not chosen.is_noop and not chosen.is_push:
                break
            chosen = action_space.decode_action(action_space.sample())
        assert policy.resource_preloaded_from(chosen.push) is None
        assert policy.apply_action(chosen)
        assert policy.resource_preloaded_from(chosen.push) is chosen.source
Beispiel #4
    def test_as_dict(self):
        """Every pushed/preloaded URL in the policy appears in ``as_dict``.

        ``as_dict["push"]`` / ``as_dict["preload"]`` are indexed by the source
        URL and hold entries with a ``"url"`` key for each target resource.
        """
        action_space = get_action_space()
        policy = Policy(action_space)
        # apply some random actions so the dict actually has content to check
        for _ in range(10):
            action = action_space.decode_action(action_space.sample())
            policy.apply_action(action)

        policy_dict = policy.as_dict
        for (source, push) in policy.push:
            pushed_urls = [pp["url"] for pp in policy_dict["push"][source.url]]
            assert all(p.url in pushed_urls for p in push)
        for (source, preload) in policy.preload:
            preloaded_urls = [
                pp["url"] for pp in policy_dict["preload"][source.url]
            ]
            assert all(p.url in preloaded_urls for p in preload)
Beispiel #5
 def test_apply_action_noop_as_first_action(self):
     """Applying a noop to a fresh policy is rejected and leaves it empty."""
     policy = Policy(get_action_space())

     # the noop must be reported as not applied ...
     assert not policy.apply_action(Action())
     # ... and must not have added any push or preload URLs
     assert not list(policy.push)
     assert not list(policy.preload)
     assert len(policy) == 0
Beispiel #6
    def test_observation_with_nonempty_policy(self):
        """Observations stay valid and encode push/preload sources correctly.

        Random actions over all push groups are applied to a fresh policy.
        After each one the observation must lie in the observation space, mark
        pushed resources with their source's ``source_id + 1`` (entry [-2])
        and preloaded resources with their source's ``order + 1`` (entry
        [-1]), and leave those entries at 0 for every other resource.
        """
        action_space = ActionSpace(self.push_groups)
        policy = Policy(action_space)

        # do some actions and check the observation space over time
        for _ in range(len(action_space) - 1):
            # get an action and apply it in the policy, so the checks below
            # run against a policy that actually grows over time
            action_id = action_space.sample()
            policy.apply_action(action_space.decode_action(action_id))

            # get the observation
            obs = get_observation(self.client_environment, self.push_groups,
                                  policy, set())
            assert self.observation_space.contains(obs)

            # make sure the push sources are recorded correctly
            # (+1 so that 0 can mean "not pushed")
            for (source, push) in policy.push:
                for push_res in push:
                    assert obs["resources"][str(
                        push_res.order)][-2] == source.source_id + 1

            # make sure the preload sources are recorded correctly
            # (+1 so that 0 can mean "not preloaded")
            for (source, preload) in policy.preload:
                for preload_res in preload:
                    assert obs["resources"][str(
                        preload_res.order)][-1] == source.order + 1

            # check that all other resources are neither pushed nor preloaded
            pushed_res = {push_res.order
                          for (_, push) in policy.push
                          for push_res in push}
            preloaded_res = {preload_res.order
                             for (_, preload) in policy.preload
                             for preload_res in preload}
            assert all(res[-2] == 0 for order, res in obs["resources"].items()
                       if int(order) not in pushed_res)
            assert all(res[-1] == 0 for order, res in obs["resources"].items()
                       if int(order) not in preloaded_res)
Beispiel #7
    def test_apply_push_action(self):
        """A single push action yields one source pushing one resource."""
        action_space = get_action_space()
        policy = Policy(action_space)

        action_id = (1, (0, 0, 1), (0, 0))
        expected = action_space.decode_action(action_id)
        was_applied = policy.apply_action(action_id)
        entries = list(policy.push)

        assert was_applied
        assert len(policy) == 1
        assert len(entries) == 1
        source, pushed = entries[0]
        assert len(pushed) == 1
        assert source == expected.source
        assert pushed == {expected.push}
Beispiel #8
    def test_apply_push_action_same_source_resource(self):
        """Two pushes from the same source are grouped under one entry."""
        action_space = get_action_space()
        policy = Policy(action_space)

        # decode both actions up front, then apply them in order
        first, second = (
            action_space.decode_action(action_id)
            for action_id in ((1, (0, 0, 1), (0, 0)), (1, (0, 0, 2), (0, 0)))
        )
        for action in (first, second):
            assert policy.apply_action(action)

        entries = list(policy.push)
        assert len(policy) == 2
        assert len(entries) == 1
        source, pushed = entries[0]
        assert len(pushed) == 2
        assert source == first.source
        assert source == second.source
        assert pushed == {first.push, second.push}
Beispiel #9
class TestMahiMahiConfig:
    """Tests for ``MahiMahiConfig`` construction, commands and trace output."""

    def setup_method(self):
        """Create a config, a random client environment and a random policy.

        Named ``setup_method`` (pytest's xunit-style per-test hook) rather
        than the nose-style ``setup``, which pytest deprecated in 7.2 and no
        longer calls as of 8.0.
        """
        self.config = get_config()
        self.action_space = ActionSpace(get_push_groups())
        self.client_environment = get_random_client_environment()
        self.policy = Policy(self.action_space)
        # keep applying random actions until the first one is rejected
        applied = True
        while applied:
            applied = self.policy.apply_action(self.action_space.sample())

    def test_init_without_policy(self):
        """Without extra arguments, policy and client environment are None."""
        mm_config = MahiMahiConfig(self.config)
        assert isinstance(mm_config, MahiMahiConfig)
        assert mm_config.policy is None
        assert mm_config.client_environment is None

    def test_init_without_client_environment(self):
        """The given policy is stored as-is; no client environment is set."""
        mm_config = MahiMahiConfig(self.config, policy=self.policy)
        assert isinstance(mm_config, MahiMahiConfig)
        assert mm_config.policy is self.policy
        assert mm_config.client_environment is None

    def test_init_with_client_environment(self):
        """Both the policy and the client environment are stored as-is."""
        mm_config = MahiMahiConfig(
            self.config,
            policy=self.policy,
            client_environment=self.client_environment,
        )
        assert isinstance(mm_config, MahiMahiConfig)
        assert mm_config.policy is self.policy
        assert mm_config.client_environment is self.client_environment

    def test_record_shell_with_cmd(self):
        """The command is appended after the record command."""
        save_dir = "/tmp/save_dir"
        mm_config = MahiMahiConfig(self.config, policy=self.policy)
        cmd = mm_config.record_shell_with_cmd(save_dir, ["a", "command"])
        assert cmd == (mm_config.record_cmd(save_dir) + ["a", "command"])

    def test_record_cmd(self):
        """The record command starts with ``mm-webrecord <save_dir>``."""
        save_dir = "/tmp/save_dir"
        mm_config = MahiMahiConfig(self.config)
        record_cmd = mm_config.record_cmd(save_dir)
        assert record_cmd[0] == "mm-webrecord"
        assert record_cmd[1] == save_dir

    def test_formatted_trace_file(self):
        """The trace file matches the formatted trace for the bandwidth."""
        mm_config = MahiMahiConfig(
            self.config,
            policy=self.policy,
            client_environment=self.client_environment,
        )
        trace_lines = trace_for_kbps(self.client_environment.bandwidth)
        formatted = format_trace_lines(trace_lines)
        assert mm_config.formatted_trace_file == formatted
Beispiel #10
    def test_apply_multiple_actions(self):
        """Every successfully applied action shows up in push or preload."""
        action_space = get_action_space()
        policy = Policy(action_space)

        actions = []
        while len(policy) < 10:
            action_id = action_space.sample()
            action = action_space.decode_action(action_id)
            action_applied = policy.apply_action(action)

            if action_applied:
                # remember it so we can verify it ended up in the policy
                actions.append(action)

        for action in actions:
            assert any(action.source == source and action.push == push
                       for source, res in policy.push for push in res) or any(
                           action.source == source and action.push == push
                           for source, res in policy.preload for push in res)
Beispiel #11
    def test_push_preload_list_for_source(self):
        """Per-source push/preload sets match the applied actions.

        Applied actions are mirrored into ``push_map`` / ``preload_map``
        (source -> set of targets), which must then equal what
        ``push_set_for_resource`` / ``preload_set_for_resource`` report.
        """
        action_space = get_action_space()
        policy = Policy(action_space)

        push_map = collections.defaultdict(set)
        preload_map = collections.defaultdict(set)

        while len(policy) < 10:
            action_id = action_space.sample()
            action = action_space.decode_action(action_id)
            action_applied = policy.apply_action(action)

            if action_applied:
                # for both kinds of action the target resource is `.push`
                if action.is_push:
                    push_map[action.source].add(action.push)
                if action.is_preload:
                    preload_map[action.source].add(action.push)

        assert push_map
        assert preload_map
        for (source, push_set) in push_map.items():
            assert policy.push_set_for_resource(source) == push_set
        for (source, preload_set) in preload_map.items():
            assert policy.preload_set_for_resource(source) == preload_set
Beispiel #12
class Environment(gym.Env):
    """
    Environment virtualizes a randomly chosen network and browser environment and
    facilitates the training for a given web page. This includes action selection, policy
    generation, and evaluation of the policy/action in the simulated environment.
    """
    def __init__(self, config: Union[Config, dict]):
        """Build the environment from a ``Config`` or a dict of its kwargs.

        Sets up the observation space, the cached-URL set, and the reward
        ``Analyzer``. It also creates placeholders for the per-episode client
        environment, action space and policy.
        """
        # make sure config is an instance of Config or a dict
        # (note: this check is an `assert`, so it is skipped under `python -O`)
        assert isinstance(config, (Config, dict))
        config = config if isinstance(config, Config) else Config(**config)

        self.config = config
        self.env_config = config.env_config
        # NOTE(review): the next two lines look truncated/merged. A logging
        # call seems to have been fused onto the RandomState assignment, and
        # as written they are not valid Python. Restore from upstream.
        self.np_random = np.random.RandomState()"initialized trainable push groups",
                     for group in self.env_config.trainable_push_groups

        self.observation_space = get_observation_space()
        self.cached_urls = config.cached_urls or set()
        # reward_func / use_aft fall back to 0 / False when unset in config
        self.analyzer = Analyzer(self.config, config.reward_func or 0,
                                 config.use_aft or False)

        # filled in per episode (see initialize_environment)
        self.client_environment: Optional[ClientEnvironment] = None
        self.action_space: Optional[ActionSpace] = None
        self.policy: Optional[Policy] = None
        # NOTE(review): dangling fragment. It is presumably the tail of a call
        # that initializes the environment with a random fast-LTE client
        # environment; the start of the statement is missing. Confirm.
            or client.get_random_fast_lte_client_environment(),

    def seed(self, seed=None):
        """Seed the environment's random number generator.

        ``self.np_random`` drives the random cache-scenario choice in
        ``initialize_environment``, so seeding it makes that choice
        reproducible.

        :param seed: value passed to ``numpy.random.RandomState.seed``;
            ``None`` reseeds from fresh OS entropy.
        :return: a list containing the seed that was passed, following the
            gym ``Env.seed`` convention of returning the seeds used.
        """
        self.np_random.seed(seed)
        return [seed]

    def reset(self):
        """Return the current observation (gym ``reset`` hook).

        NOTE(review): this does not re-initialize the client environment,
        cached URLs, action space or policy. If ``reset`` is meant to start a
        fresh episode, a call to ``initialize_environment`` may be missing
        here. Also, ``observation`` is defined in this class with no visible
        ``@property`` decorator; if there is none, this returns the bound
        method rather than an observation. Confirm both.
        """
        return self.observation

    def initialize_environment(self,
                               client_environment: ClientEnvironment,
                               cached_urls: Optional[Set[str]] = None):
        """ Initialize the environment

        Chooses the set of cached URLs and resets the analyzer for the given
        client environment. It then builds a new action space and policy from
        the push groups.

        :param client_environment: stored on ``self`` and passed to
            ``self.analyzer.reset``.
        :param cached_urls: URLs to treat as cached. When None, a cache age
            (in hours) is drawn from ``scenarios``. An age of 0 means nothing
            is cached; otherwise, resources whose ``cache_time`` is at least
            that age converted to seconds are selected.
        """
        # NOTE(review): dangling fragment, apparently left over from a logging
        # call whose opening line is missing. Not valid Python as written;
        # restore from upstream.
            "initialized environment",
        # Cache scenarios in hours
        # Half of the entries are 0, so a uniform draw caches nothing 50% of
        # the time. The draw happens even when `cached_urls` is given, so the
        # RNG advances identically in both cases.
        scenarios = [0, 0, 0, 0, 0, 1, 2, 4, 12, 24]
        cache_time = self.np_random.choice(scenarios)
        # NOTE(review): the element expression of the `set(...)` generator
        # below is missing (presumably something like `res.url`); confirm.
        self.cached_urls = (cached_urls if cached_urls is not None else
                            set() if cache_time == 0 else set(
                                for group in self.env_config.push_groups
                                for res in group.resources
                                if res.cache_time >= (cache_time * 60 * 60)))

        self.client_environment = client_environment
        self.analyzer.reset(self.client_environment, self.cached_urls)

        # NOTE(review): the two statements below are truncated (unclosed
        # `math.ceil(` and `sorted(` calls). `num_domains_deployed` is not used
        # in the visible code. Restore both from upstream.
        num_domains_deployed = math.ceil(PROPORTION_DEPLOYED *
        push_groups = sorted(self.env_config.push_groups,
                             key=lambda g: len(g.resources),

        # the action space and policy are rebuilt from scratch on every call
        self.action_space = ActionSpace(push_groups)
        self.policy = Policy(self.action_space)

    def step(self, action: ActionIDType):
        """Apply one action to the policy and return a gym-style transition.

        :param action: action ID, decoded with ``self.action_space``.
        :return: ``(observation, reward, done, info)``.
            - ``reward`` is ``NOOP_ACTION_REWARD`` unless the action was
              applied; in that case it comes from
              ``self.analyzer.get_reward(self.policy)``.
            - ``done`` is True exactly when the action was NOT applied.
            - ``info`` holds the decoded action and ``self.policy.as_dict``.
        """
        # decode the action and apply it to the policy
        decoded_action = self.action_space.decode_action(action)
        action_applied = self.policy.apply_action(decoded_action)

        # make sure the action isn't used again
        # NOTE(review): the statement this comment refers to appears to be
        # missing. The original line also carried a fused "trying action"
        # fragment, likely from a logging call. Nothing below marks the action
        # as used.

        reward = NOOP_ACTION_REWARD
        if action_applied:
            # only applied actions are evaluated by the analyzer
            reward = self.analyzer.get_reward(self.policy)
            # NOTE(review): the next line is the tail of a call (likely a
            # "got reward" log statement) whose opening is missing. It is
            # mis-indented and not valid Python as written.
  "got reward", action=repr(decoded_action), reward=reward)

        info = {"action": decoded_action, "policy": self.policy.as_dict}
        return self.observation, reward, not action_applied, info

    def render(self, mode="human"):
        """Delegate rendering to the parent ``gym.Env`` implementation."""
        return super().render(mode=mode)

    def observation(self):
        """ Returns an observation for the current state of the environment """
        return get_observation(self.client_environment,
                               self.env_config.push_groups, self.policy,