def test_observation_with_nonempty_policy_with_default_actions(self):
    """Check observations when one push group is installed as default push actions.

    A random non-trainable push group with more than two resources is removed
    from the action space and its resources are registered as default pushes
    from the group's first resource. Sampled actions are then applied one at a
    time, and after each the observation is validated against the policy's
    observable push/preload state.
    """
    # indices of groups eligible to become the default group
    eligible_indices = [
        idx
        for idx, candidate in enumerate(self.push_groups)
        if len(candidate.resources) > 2 and not candidate.trainable
    ]
    default_group_idx = random.choice(eligible_indices)
    default_group = self.push_groups[default_group_idx]

    # the action space is built from every group except the default one
    trainable_groups = [
        candidate
        for idx, candidate in enumerate(self.push_groups)
        if idx != default_group_idx
    ]
    action_space = ActionSpace(trainable_groups)
    policy = Policy(action_space)

    # register every later resource of the default group as a default push
    # from the group's first resource
    for pushed_resource in default_group.resources[1:]:
        policy.add_default_push_action(default_group.resources[0], pushed_resource)

    # apply sampled actions one by one and validate the observation each time
    for _ in range(len(action_space) - 1):
        policy.apply_action(action_space.sample())

        obs = get_observation(self.client_environment, self.push_groups, policy, set())
        assert self.observation_space.contains(obs)
        encoded = obs["resources"]

        # second-to-last slot: push source, expected as source_id + 1
        for source, pushed in policy.observable_push:
            for resource in pushed:
                assert encoded[str(resource.order)][-2] == source.source_id + 1

        # last slot: preload source, expected as order + 1
        for source, preloaded in policy.observable_preload:
            for resource in preloaded:
                assert encoded[str(resource.order)][-1] == source.order + 1

        # every resource that is not pushed/preloaded must have a 0 in its slot
        pushed_orders = {
            resource.order
            for _source, pushed in policy.observable_push
            for resource in pushed
        }
        preloaded_orders = {
            resource.order
            for _source, preloaded in policy.observable_preload
            for resource in preloaded
        }
        assert all(
            encoding[-2] == 0
            for order, encoding in encoded.items()
            if int(order) not in pushed_orders
        )
        assert all(
            encoding[-1] == 0
            for order, encoding in encoded.items()
            if int(order) not in preloaded_orders
        )
class TestMahiMahiConfig:
    """Tests for constructing ``MahiMahiConfig`` and for its record command and trace output."""

    def setup_method(self):
        """Create the config, action space, client environment and policy for each test.

        This uses pytest's native ``setup_method`` hook instead of the
        nose-style ``setup`` name: pytest deprecated nose-style hooks in 7.2
        and no longer calls them from 8.0, which would leave ``self.config``
        and the other fixtures unset.
        """
        self.config = get_config()
        self.action_space = ActionSpace(get_push_groups())
        self.client_environment = get_random_client_environment()
        self.policy = Policy(self.action_space)

        # keep applying sampled actions until apply_action returns a falsy value
        applied = True
        while applied:
            applied = self.policy.apply_action(self.action_space.sample())

    def test_init_without_policy(self):
        """With only a config, both ``policy`` and ``client_environment`` are None."""
        mm_config = MahiMahiConfig(self.config)
        assert isinstance(mm_config, MahiMahiConfig)
        assert mm_config.policy is None
        assert mm_config.client_environment is None

    def test_init_without_client_environment(self):
        """A supplied policy is stored as-is; ``client_environment`` stays None."""
        mm_config = MahiMahiConfig(self.config, policy=self.policy)
        assert isinstance(mm_config, MahiMahiConfig)
        assert mm_config.policy is self.policy
        assert mm_config.client_environment is None

    def test_init_with_client_environment(self):
        """Both the supplied policy and client environment are stored as-is."""
        mm_config = MahiMahiConfig(
            self.config,
            policy=self.policy,
            client_environment=self.client_environment,
        )
        assert isinstance(mm_config, MahiMahiConfig)
        assert mm_config.policy is self.policy
        assert mm_config.client_environment is self.client_environment

    def test_record_shell_with_cmd(self):
        """The shell command is the record command followed by the given command."""
        save_dir = "/tmp/save_dir"
        mm_config = MahiMahiConfig(self.config, policy=self.policy)
        cmd = mm_config.record_shell_with_cmd(save_dir, ["a", "command"])
        assert cmd == (mm_config.record_cmd(save_dir) + ["a", "command"])

    def test_record_cmd(self):
        """The record command starts with ``mm-webrecord`` and the save directory."""
        save_dir = "/tmp/save_dir"
        mm_config = MahiMahiConfig(self.config)
        record_cmd = mm_config.record_cmd(save_dir)
        assert record_cmd[0] == "mm-webrecord"
        assert record_cmd[1] == save_dir

    def test_formatted_trace_file(self):
        """The formatted trace matches the one built from the client's bandwidth."""
        mm_config = MahiMahiConfig(
            self.config,
            policy=self.policy,
            client_environment=self.client_environment,
        )
        trace_lines = trace_for_kbps(self.client_environment.bandwidth)
        formatted = format_trace_lines(trace_lines)
        assert mm_config.formatted_trace_file == formatted
def test_observation_with_nonempty_policy(self):
    """Check that observations track the policy as sampled actions are applied.

    Actions are sampled from an action space built over all push groups and
    applied one at a time. After each one the observation must lie in the
    observation space and agree with the policy's push/preload state.
    """
    action_space = ActionSpace(self.push_groups)
    policy = Policy(action_space)

    # apply sampled actions one by one and validate the observation each time
    for _ in range(len(action_space) - 1):
        policy.apply_action(action_space.sample())

        obs = get_observation(self.client_environment, self.push_groups, policy, set())
        assert self.observation_space.contains(obs)
        encoded = obs["resources"]

        # second-to-last slot: push source, expected as source_id + 1
        for source, pushed in policy.push:
            for resource in pushed:
                assert encoded[str(resource.order)][-2] == source.source_id + 1

        # last slot: preload source, expected as order + 1
        for source, preloaded in policy.preload:
            for resource in preloaded:
                assert encoded[str(resource.order)][-1] == source.order + 1

        # every resource that is not pushed/preloaded must have a 0 in its slot
        pushed_orders = {
            resource.order
            for _source, pushed in policy.push
            for resource in pushed
        }
        preloaded_orders = {
            resource.order
            for _source, preloaded in policy.preload
            for resource in preloaded
        }
        assert all(
            encoding[-2] == 0
            for order, encoding in encoded.items()
            if int(order) not in pushed_orders
        )
        assert all(
            encoding[-1] == 0
            for order, encoding in encoded.items()
            if int(order) not in preloaded_orders
        )
def get_action(action_space: ActionSpace) -> ActionIDType:
    """Return an action ID from ``action_space`` whose decoded action is not a noop.

    Starts from ``NOOP_ACTION_ID`` and keeps drawing fresh samples for as long
    as the decoded candidate reports ``is_noop``.
    """
    candidate = NOOP_ACTION_ID
    while True:
        if not action_space.decode_action(candidate).is_noop:
            return candidate
        candidate = action_space.sample()