def setup(self):
    """Build a config, action space, client environment, and a saturated policy."""
    self.config = get_config()
    self.action_space = ActionSpace(get_push_groups())
    self.client_environment = get_random_client_environment()
    self.policy = Policy(self.action_space)
    # Keep applying sampled actions until one is rejected, saturating the policy.
    while self.policy.apply_action(self.action_space.sample()):
        pass
def test_apply_action_noop_as_first_action(self):
    """A noop as the very first action must be rejected and leave the policy empty."""
    action_space = get_action_space()
    policy = Policy(action_space)
    assert not policy.apply_action(Action())  # the noop was not applied
    assert not list(policy.push)              # no push URLs were recorded
    assert not list(policy.preload)           # no preload URLs were recorded
    assert len(policy) == 0                   # policy length is still 0
def test_apply_push_action(self):
    """Applying a single encoded push action records exactly one (source, push) pair."""
    action_space = get_action_space()
    policy = Policy(action_space)
    action_id = (1, (0, 0, 1), (0, 0))
    action = action_space.decode_action(action_id)
    assert policy.apply_action(action_id)
    recorded = list(policy.push)
    assert len(policy) == 1
    assert len(recorded) == 1
    assert len(recorded[0][1]) == 1
    # the single entry maps the decoded action's source to its pushed resource
    assert recorded[0][0] == action.source
    assert recorded[0][1] == {action.push}
def test_apply_action_noop_as_second_action(self):
    """A noop after a real action must be rejected and leave the policy unchanged."""
    action_space = get_action_space()
    policy = Policy(action_space)
    # first apply a real push action
    assert policy.apply_action((1, (0, 0, 1), (0, 0)))
    before = list(policy.push)
    assert before
    assert len(policy) == 1
    # then a noop, which must change nothing
    assert not policy.apply_action(Action())
    assert before == list(policy.push)
    assert len(policy) == 1
def replay(args):
    """
    Starts a replay environment for the given replay directory, including setting
    up interfaces, running a DNS server, and configuring and running an nginx
    server to serve the requests
    """
    # Resolve optional filesystem arguments to absolute paths (None if absent).
    def _absolute(path):
        return os.path.abspath(path) if path else None

    cert_path = _absolute(args.cert_path)
    key_path = _absolute(args.key_path)
    per_resource_latency = _absolute(args.per_resource_latency)

    # Load the optional push/preload policy from its JSON file.
    policy = None
    if args.policy:
        log.debug("reading policy", push_policy=args.policy)
        with open(args.policy, "r") as policy_file:
            policy = Policy.from_dict(json.load(policy_file))

    # handle sigterm gracefully
    signal.signal(signal.SIGTERM, sigterm_handler)

    with start_server(args.replay_dir,
                      cert_path,
                      key_path,
                      policy,
                      per_resource_latency,
                      cache_time=args.cache_time,
                      extract_critical_requests=args.extract_critical_requests,
                      enable_http2=args.enable_http2):
        # Block forever (in day-long sleeps) until the process is signalled.
        while True:
            time.sleep(86400)
def replay(args):
    """
    Starts a replay environment for the given replay directory, including setting
    up interfaces, running a DNS server, and configuring and running an nginx
    server to serve the requests
    """
    cert_path = os.path.abspath(args.cert_path) if args.cert_path else None
    key_path = os.path.abspath(args.key_path) if args.key_path else None

    # Load the optional push/preload policy from its JSON file.
    policy = None
    if args.policy:
        log.debug("reading policy", push_policy=args.policy)
        with open(args.policy, "r") as policy_file:
            policy = Policy.from_dict(json.load(policy_file))

    with start_server(
            args.replay_dir,
            cert_path,
            key_path,
            policy,
            cache_time=args.cache_time,
            extract_critical_requests=args.extract_critical_requests,
    ):
        # Block forever (in day-long sleeps) until the process is killed.
        while True:
            time.sleep(86400)
def test_apply_push_action_same_source_resource(self):
    """Two pushes sharing one source collapse into a single policy entry."""
    action_space = get_action_space()
    policy = Policy(action_space)
    first = action_space.decode_action((1, (0, 0, 1), (0, 0)))
    second = action_space.decode_action((1, (0, 0, 2), (0, 0)))
    assert policy.apply_action(first)
    assert policy.apply_action(second)
    recorded = list(policy.push)
    assert len(policy) == 2
    # one entry with both pushed resources, keyed by the shared source
    assert len(recorded) == 1
    assert len(recorded[0][1]) == 2
    assert recorded[0][0] == first.source
    assert recorded[0][0] == second.source
    assert recorded[0][1] == {first.push, second.push}
def test_as_dict(self):
    """Every recorded push/preload URL must appear under its source in as_dict."""
    action_space = get_action_space()
    policy = Policy(action_space)
    # apply ten sampled actions so the policy has some content
    for _ in range(10):
        action = action_space.decode_action(action_space.sample())
        policy.apply_action(action)
        action_space.use_action(action)
    policy_dict = policy.as_dict
    for (source, push) in policy.push:
        urls = [pp["url"] for pp in policy_dict["push"][source.url]]
        assert all(p.url in urls for p in push)
    for (source, preload) in policy.preload:
        urls = [pp["url"] for pp in policy_dict["preload"][source.url]]
        assert all(p.url in urls for p in preload)
def test_create_push_policy(self):
    """The policy service must only emit (source, push) pairs valid for the page."""
    ps = PolicyService(self.saved_model)
    policy = ps.create_policy(self.page)
    assert policy
    assert isinstance(policy, policy_service_pb2.Policy)
    # build the set of legal URL pairs from the page's push groups
    valid_pairs = [(s.url, p.url)
                   for (s, p) in convert_push_groups_to_push_pairs(self.push_groups)]
    decoded = Policy.from_dict(json.loads(policy.policy))
    for (source, push_res) in decoded.push:
        for push in push_res:
            assert (source.url, push.url) in valid_pairs
def test_apply_multiple_actions(self):
    """Every applied action must be visible in the policy's push or preload lists."""
    action_space = get_action_space()
    policy = Policy(action_space)
    actions = []
    while len(policy) < 10:
        action_id = action_space.sample()
        action = action_space.decode_action(action_id)
        action_applied = policy.apply_action(action)
        # Mark the action as used exactly once, whether or not it applied.
        # (The original called use_action a second time for applied actions,
        # which is redundant — Environment.step also uses a single call.)
        action_space.use_action(action)
        if action_applied:
            actions.append(action)
    for action in actions:
        print(action)
        # the action must appear either as a push or as a preload pair
        assert any(action.source == source and action.push == push
                   for source, res in policy.push
                   for push in res) or any(
                       action.source == source and action.push == push
                       for source, res in policy.preload
                       for push in res)
def test_from_dict(self):
    """Deserializing a policy dict must reproduce its push/preload structure."""
    policy_dict = {
        "push": {
            "A": [{"url": "B", "type": "SCRIPT"}, {"url": "C", "type": "IMAGE"}]
        },
        "preload": {
            "B": [{"url": "D", "type": "IMAGE"}],
            "G": [{"url": "E", "type": "CSS"}, {"url": "F", "type": "FONT"}],
        },
    }
    policy = Policy.from_dict(policy_dict)
    # a deserialized policy carries no action space and no actions taken
    assert policy.total_actions == 0
    assert not policy.action_space
    for (source, deps) in policy.push:
        assert isinstance(source, Resource)
        assert all(isinstance(push, Resource) for push in deps)
        expected_urls = [p["url"] for p in policy_dict["push"][source.url]]
        assert expected_urls == sorted(push.url for push in deps)
        for push in deps:
            assert policy.push_to_source[push] == source
            assert push.url in expected_urls
    for (source, deps) in policy.preload:
        assert isinstance(source, Resource)
        assert all(isinstance(push, Resource) for push in deps)
        expected_urls = [p["url"] for p in policy_dict["preload"][source.url]]
        assert expected_urls == sorted(push.url for push in deps)
        for push in deps:
            assert policy.preload_to_source[push] == source
            assert push.url in expected_urls
    assert len(policy.source_to_push) == len(policy_dict["push"])
    assert len(policy.source_to_preload) == len(policy_dict["preload"])
def test_observation_with_nonempty_policy(self):
    """Apply actions one at a time and validate the observation after each."""
    action_space = ActionSpace(self.push_groups)
    policy = Policy(action_space)
    for _ in range(len(action_space) - 1):
        # apply one sampled action, then inspect the resulting observation
        policy.apply_action(action_space.sample())
        obs = get_observation(self.client_environment, self.push_groups,
                              policy, set())
        assert self.observation_space.contains(obs)
        # pushed resources record their source's id (+1 since we have defined it that way)
        for (source, push) in policy.push:
            for push_res in push:
                assert obs["resources"][str(push_res.order)][-2] == source.source_id + 1
        # preloaded resources record their source's order (+1 since we have defined it that way)
        for (source, preload) in policy.preload:
            for push_res in preload:
                assert obs["resources"][str(push_res.order)][-1] == source.order + 1
        # everything else must be marked neither pushed nor preloaded
        pushed_orders = {push_res.order
                         for (source, push) in policy.push
                         for push_res in push}
        preloaded_orders = {push_res.order
                            for (source, push) in policy.preload
                            for push_res in push}
        assert all(res[-2] == 0 for order, res in obs["resources"].items()
                   if int(order) not in pushed_orders)
        assert all(res[-1] == 0 for order, res in obs["resources"].items()
                   if int(order) not in preloaded_orders)
def get_policy(self, url: str, client_env: ClientEnvironment,
               manifest: EnvironmentConfig) -> Policy:
    """ Queries the policy service for a push policy for the given configuration """
    # Build the request from the URL, client environment, and serialized manifest.
    page = policy_service_pb2.Page(
        url=url,
        bandwidth_kbps=client_env.bandwidth,
        latency_ms=client_env.latency,
        cpu_slowdown=client_env.cpu_slowdown,
        manifest=manifest.serialize(),
    )
    response = self.stub.GetPolicy(page)
    # The service returns the policy as a JSON string; decode it into a Policy.
    return Policy.from_dict(json.loads(response.policy))
def test_resource_push_from(self):
    """resource_pushed_from/resource_preloaded_from reflect applied actions."""
    action_space = get_action_space()
    policy = Policy(action_space)
    # sample until we get a real push action
    action = Action()
    while action.is_noop or not action.is_push:
        action = action_space.decode_action(action_space.sample())
    assert policy.resource_pushed_from(action.push) is None
    assert policy.apply_action(action)
    assert policy.resource_pushed_from(action.push) is action.source
    # sample until we get a real preload action
    while action.is_noop or action.is_push:
        action = action_space.decode_action(action_space.sample())
    assert policy.resource_preloaded_from(action.push) is None
    assert policy.apply_action(action)
    assert policy.resource_preloaded_from(action.push) is action.source
def initialize_environment(self, client_environment: ClientEnvironment,
                           cached_urls: Optional[Set[str]] = None):
    """ Initialize the environment """
    log.info(
        "initialized environment",
        network_type=client.NetworkType(client_environment.network_type),
        network_speed=client.NetworkSpeed(client_environment.network_speed),
        device_speed=client.DeviceSpeed(client_environment.device_speed),
        bandwidth=client_environment.bandwidth,
        latency=client_environment.latency,
        cpu_slowdown=client_environment.cpu_slowdown,
        loss=client_environment.loss,
        reward_func=self.analyzer.reward_func_num,
        cached_urls=cached_urls,
    )
    # Cache scenarios in hours
    scenarios = [0, 0, 0, 0, 0, 1, 2, 4, 12, 24]
    cache_time = self.np_random.choice(scenarios)
    if cached_urls is not None:
        # an explicit cache set always wins over the sampled scenario
        self.cached_urls = cached_urls
    elif cache_time == 0:
        self.cached_urls = set()
    else:
        # anything cacheable for at least the sampled window counts as cached
        self.cached_urls = set(
            res.url for group in self.env_config.push_groups
            for res in group.resources
            if res.cache_time >= (cache_time * 60 * 60))
    self.client_environment = client_environment
    self.analyzer.reset(self.client_environment, self.cached_urls)
    # only the largest PROPORTION_DEPLOYED fraction of domains get push deployed
    num_domains_deployed = math.ceil(PROPORTION_DEPLOYED *
                                     len(self.env_config.push_groups))
    deployed_groups = sorted(self.env_config.push_groups,
                             key=lambda g: len(g.resources),
                             reverse=True)[:num_domains_deployed]
    self.action_space = ActionSpace(deployed_groups)
    self.policy = Policy(self.action_space)
def test_push_preload_list_for_source(self):
    """push/preload_set_for_resource must mirror what was actually applied."""
    action_space = get_action_space()
    policy = Policy(action_space)
    expected_push = collections.defaultdict(set)
    expected_preload = collections.defaultdict(set)
    # apply sampled actions until the policy holds 10, mirroring each applied
    # action in our own source -> resource maps
    while len(policy) < 10:
        action = action_space.decode_action(action_space.sample())
        if policy.apply_action(action):
            if action.is_push:
                expected_push[action.source].add(action.push)
            if action.is_preload:
                expected_preload[action.source].add(action.push)
    assert expected_push
    assert expected_preload
    for (source, push_set) in expected_push.items():
        assert policy.push_set_for_resource(source) == push_set
    for (source, preload_set) in expected_preload.items():
        assert policy.preload_set_for_resource(source) == preload_set
def test_observation_with_cached_urls(self):
    """Resources passed as cached must have their cached feature set in the observation."""
    action_space = ActionSpace(self.push_groups)
    policy = Policy(action_space)
    resources = [res for group in self.push_groups for res in group.resources]
    # NOTE(review): randint(0, 2) yields 0-2 and is used as a truthy flag, so each
    # resource is included with probability 2/3 rather than 1/2 — confirm intended.
    flags = [random.randint(0, 2) for _ in range(len(resources))]
    cached = [res for (res, include) in zip(resources, flags) if include]
    cached_urls = set(res.url for res in cached)
    obs = get_observation(self.client_environment, self.push_groups, policy,
                          cached_urls)
    # feature index 1 is the "cached" flag
    for res in cached:
        assert obs["resources"][str(res.order)][1] == 1
def test_observation_with_nonempty_policy_with_default_actions(self):
    """Observations must reflect only observable push/preload entries when
    default actions exist alongside trained ones."""
    # pick a non-trainable group with more than 2 resources as the default group
    candidate_push_groups = [
        i for i, group in enumerate(self.push_groups)
        if len(group.resources) > 2 and not group.trainable
    ]
    default_group_idx = random.choice(candidate_push_groups)
    default_group = self.push_groups[default_group_idx]
    remaining_groups = [
        group for i, group in enumerate(self.push_groups)
        if i != default_group_idx
    ]
    action_space = ActionSpace(remaining_groups)
    policy = Policy(action_space)
    # push everything in the default group from its first resource
    for push in default_group.resources[1:]:
        policy.add_default_push_action(default_group.resources[0], push)
    # apply actions one at a time and validate the observation after each
    for _ in range(len(action_space) - 1):
        policy.apply_action(action_space.sample())
        obs = get_observation(self.client_environment, self.push_groups,
                              policy, set())
        assert self.observation_space.contains(obs)
        # observable pushes record their source's id (+1 since we have defined it that way)
        for (source, push) in policy.observable_push:
            for push_res in push:
                assert obs["resources"][str(push_res.order)][-2] == source.source_id + 1
        # observable preloads record their source's order (+1 since we have defined it that way)
        for (source, preload) in policy.observable_preload:
            for push_res in preload:
                assert obs["resources"][str(push_res.order)][-1] == source.order + 1
        # all other resources must be marked neither pushed nor preloaded
        pushed_orders = {push_res.order
                         for (source, push) in policy.observable_push
                         for push_res in push}
        preloaded_orders = {push_res.order
                            for (source, push) in policy.observable_preload
                            for push_res in push}
        assert all(res[-2] == 0 for order, res in obs["resources"].items()
                   if int(order) not in pushed_orders)
        assert all(res[-1] == 0 for order, res in obs["resources"].items()
                   if int(order) not in preloaded_orders)
class TestMahiMahiConfig:
    """Tests for MahiMahiConfig construction and command generation."""

    def setup(self):
        self.config = get_config()
        self.action_space = ActionSpace(get_push_groups())
        self.client_environment = get_random_client_environment()
        self.policy = Policy(self.action_space)
        # saturate the policy: apply sampled actions until one is rejected
        while self.policy.apply_action(self.action_space.sample()):
            pass

    def test_init_without_policy(self):
        mm_config = MahiMahiConfig(self.config)
        assert isinstance(mm_config, MahiMahiConfig)
        assert mm_config.policy is None
        assert mm_config.client_environment is None

    def test_init_without_client_environment(self):
        mm_config = MahiMahiConfig(self.config, policy=self.policy)
        assert isinstance(mm_config, MahiMahiConfig)
        assert mm_config.policy is self.policy
        assert mm_config.client_environment is None

    def test_init_with_client_environment(self):
        mm_config = MahiMahiConfig(self.config,
                                   policy=self.policy,
                                   client_environment=self.client_environment)
        assert isinstance(mm_config, MahiMahiConfig)
        assert mm_config.policy is self.policy
        assert mm_config.client_environment is self.client_environment

    def test_record_shell_with_cmd(self):
        # the wrapped command is the record command with the target appended
        save_dir = "/tmp/save_dir"
        mm_config = MahiMahiConfig(self.config, policy=self.policy)
        cmd = mm_config.record_shell_with_cmd(save_dir, ["a", "command"])
        assert cmd == (mm_config.record_cmd(save_dir) + ["a", "command"])

    def test_record_cmd(self):
        save_dir = "/tmp/save_dir"
        mm_config = MahiMahiConfig(self.config)
        record_cmd = mm_config.record_cmd(save_dir)
        assert record_cmd[0] == "mm-webrecord"
        assert record_cmd[1] == save_dir

    def test_formatted_trace_file(self):
        # the trace file must match the formatted trace for the env's bandwidth
        mm_config = MahiMahiConfig(self.config,
                                   policy=self.policy,
                                   client_environment=self.client_environment)
        trace_lines = trace_for_kbps(self.client_environment.bandwidth)
        formatted = format_trace_lines(trace_lines)
        assert mm_config.formatted_trace_file == formatted
def test_get_default_observation(self):
    """The default observation encodes the client env and every resource correctly."""
    action_space = ActionSpace(self.push_groups)
    policy = Policy(action_space)
    obs = get_observation(self.client_environment, self.push_groups, policy,
                          set())
    assert isinstance(obs, dict)
    assert self.observation_space.contains(obs)
    # the client environment is captured in the observation
    assert obs["client"]["network_type"] == self.client_environment.network_type.value
    assert obs["client"]["device_speed"] == self.client_environment.device_speed.value
    # nothing is pushed or preloaded with an empty policy
    assert all(res[-2] == 0 for res in obs["resources"].values())
    assert all(res[-1] == 0 for res in obs["resources"].values())
    # each known resource is encoded field-for-field
    for group in self.push_groups:
        for res in group.resources:
            expected = np.array((
                1,                  # resource is enabled
                0,                  # resource is not cached
                group.id,           # the resource's domain id
                res.source_id,      # relative offset from its domain top
                res.order + 1,      # absolute offset from the start of the page load
                res.initiator + 1,  # the resource's initiator
                res.type.value,     # resource type
                res.size // 1000,   # resource size in KB
                0,                  # not pushed
                0,                  # not preloaded
            ))
            assert np.array_equal(obs["resources"][str(res.order)], expected)
    # all slots beyond the last real resource are zero-filled
    max_order = max(r.order for group in self.push_groups
                    for r in group.resources)
    for i in range(max_order + 1, MAX_RESOURCES):
        assert np.array_equal(obs["resources"][str(i)], np.array([0] * 10))
def _generator(env_config: EnvironmentConfig) -> Policy:
    """Generate a random push/preload policy drawn from the closure's type
    distribution (`dist`), skipping resources in `cached_urls`."""
    push_groups = env_config.push_groups
    # Flatten all resources in page-load order.
    all_resources = sorted(
        [res for group in push_groups for res in group.resources],
        key=lambda res: res.order)
    res_by_type = collections.defaultdict(list)
    for res in all_resources:
        # Only consider non-cached objects in the push resource type distribution
        if res.type in dist and res.url not in cached_urls:
            res_by_type[res.type].append(res)
    # choose the number of resources to push/preload
    total = sum(len(bucket) for bucket in res_by_type.values())
    if total <= 1:
        return Policy()
    n = random.randint(1, total)
    # choose the weight factor between push and preload
    weight = push_weight if push_weight is not None else random.random()
    log.debug("generating push-preload policy",
              num_resources=len(all_resources),
              total_size=n,
              push_weight=weight)
    # Draw n resources without replacement according to the type distribution.
    chosen = []
    for _ in range(n):
        bucket, idx, picked = _choose_with_dist(res_by_type, dist)
        res_by_type[bucket].pop(idx)
        chosen.append(picked)
    policy = Policy()
    for r in chosen:
        # the first resource of a group or of the page cannot be a target
        if r.source_id == 0 or r.order == 0:
            continue
        push = random.random() < weight
        policy.steps_taken += 1
        if push:
            # push from an earlier resource in the same group
            source = random.randint(0, r.source_id - 1)
            policy.add_default_push_action(
                push_groups[r.group_id].resources[source], r)
        else:
            # preload from any earlier resource in the page load
            source = random.randint(0, r.order - 1)
            policy.add_default_preload_action(all_resources[source], r)
    return policy
def _generator(env_config: EnvironmentConfig) -> Policy:
    """Generate a policy that pushes same-origin resources and preloads the rest."""
    push_groups = env_config.push_groups
    # Flatten all resources in page-load order.
    all_resources = sorted(
        [res for group in push_groups for res in group.resources],
        key=lambda res: res.order)
    main_domain = urllib.parse.urlparse(env_config.request_url)
    policy = Policy()
    for r in all_resources:
        # the first resource of a group or of the page cannot be a target
        if r.source_id == 0 or r.order == 0:
            continue
        # same-origin resources are pushed; cross-origin ones are preloaded
        request_domain = urllib.parse.urlparse(r.url)
        push = request_domain.netloc == main_domain.netloc
        policy.steps_taken += 1
        if push:
            source = random.randint(0, r.source_id - 1)
            policy.add_default_push_action(
                push_groups[r.group_id].resources[source], r)
        else:
            source = random.randint(0, r.order - 1)
            policy.add_default_preload_action(all_resources[source], r)
    return policy
def page_load_time(args):
    """
    Captures a webpage and calculates the median page load time for a given
    website in a fast, no-latency Mahimahi shell. Then simulates the load based
    on profiling the page in the same Mahimahi shell.
    """
    # Validate the arguments
    if args.latency is not None and args.latency < 0:
        log.critical("provided latency must be greater or equal to 0")
        sys.exit(1)
    if args.bandwidth is not None and args.bandwidth <= 0:
        log.critical("provided bandwidth must be greater than 0")
        sys.exit(1)
    if args.cpu_slowdown is not None and args.cpu_slowdown not in {1, 2, 4}:
        # fixed typo: "slodown" -> "slowdown"
        log.critical("provided cpu slowdown must be 1, 2, or 4")
        sys.exit(1)

    # Setup the client environment. Use explicit None checks instead of `or`
    # so that a valid latency of 0 (allowed by the validation above) is not
    # silently replaced with the default.
    default_client_env = get_default_client_environment()
    client_env = get_client_environment_from_parameters(
        args.bandwidth if args.bandwidth is not None else default_client_env.bandwidth,
        args.latency if args.latency is not None else default_client_env.latency,
        args.cpu_slowdown if args.cpu_slowdown is not None else default_client_env.cpu_slowdown,
    )

    # If a push/preload policy was specified, read it
    policy = None
    if args.policy:
        log.debug("reading policy", push_policy=args.policy)
        with open(args.policy, "r") as policy_file:
            policy_dict = json.load(policy_file)
        policy = Policy.from_dict(policy_dict)

    env_config = EnvironmentConfig.load_file(args.from_manifest)
    config = get_config(env_config)

    log.info("calculating page load time",
             manifest=args.from_manifest,
             url=env_config.request_url)
    plt, orig_plt = 0, 0
    if not args.only_simulator:
        if not args.speed_index:
            # measure PLT without, then (if given) with the policy
            orig_plt, *_ = get_page_load_time_in_replay_server(
                request_url=config.env_config.request_url,
                client_env=client_env,
                config=config,
                cache_time=args.cache_time,
                user_data_dir=args.user_data_dir,
            )
            if policy:
                plt, *_ = get_page_load_time_in_replay_server(
                    request_url=config.env_config.request_url,
                    client_env=client_env,
                    config=config,
                    policy=policy,
                    cache_time=args.cache_time,
                    user_data_dir=args.user_data_dir,
                )
        else:
            # measure speed index without, then (if given) with the policy
            orig_plt = get_speed_index_in_replay_server(
                request_url=config.env_config.request_url,
                client_env=client_env,
                config=config,
                cache_time=args.cache_time,
                user_data_dir=args.user_data_dir,
            )
            if policy:
                plt = get_speed_index_in_replay_server(
                    request_url=config.env_config.request_url,
                    client_env=client_env,
                    config=config,
                    policy=policy,
                    cache_time=args.cache_time,
                    user_data_dir=args.user_data_dir,
                )

    log.debug("running simulator...")
    sim = Simulator(env_config)
    orig_sim_plt = sim.simulate_load_time(client_env)
    sim_plt = sim.simulate_load_time(client_env, policy)

    print(
        json.dumps(
            {
                "client_env": client_env._asdict(),
                "metric": "speed_index" if args.speed_index else "plt",
                "cache": "warm" if args.user_data_dir else "cold",
                "cache_time": args.cache_time,
                "replay_server": {
                    "with_policy": plt,
                    "without_policy": orig_plt
                },
                "simulator": {
                    "with_policy": sim_plt,
                    "without_policy": orig_sim_plt
                },
            },
            indent=4,
        ))
def setup(self):
    """Create a config, an empty policy over its push groups, and a random client env."""
    self.config = get_config()
    self.policy = Policy(ActionSpace(self.config.env_config.push_groups))
    self.client_environment = get_random_client_environment()
def test_init(self):
    """A freshly constructed Policy keeps a reference to its action space."""
    action_space = get_action_space()
    policy = Policy(action_space)
    assert isinstance(policy, Policy)
    assert policy.action_space == action_space
def get_mahimahi_config() -> MahiMahiConfig:
    """Build a MahiMahiConfig with a fresh config, an empty policy, and a random client env."""
    config = get_config()
    policy = Policy(ActionSpace(get_push_groups()))
    client_env = get_random_client_environment()
    return MahiMahiConfig(config=config,
                          policy=policy,
                          client_environment=client_env)
class Environment(gym.Env):
    """
    Environment virtualizes a randomly chosen network and browser environment and
    facilitates the training for a given web page. This includes action selection,
    policy generation, and evaluation of the policy/action in the simulated
    environment.
    """
    def __init__(self, config: Union[Config, dict]):
        # make sure config is an instance of Config or a dict
        assert isinstance(config, (Config, dict))
        if not isinstance(config, Config):
            config = Config(**config)
        self.config = config
        self.env_config = config.env_config
        self.np_random = np.random.RandomState()
        log.info("initialized trainable push groups",
                 groups=[
                     group.name
                     for group in self.env_config.trainable_push_groups
                 ])
        self.observation_space = get_observation_space()
        self.cached_urls = config.cached_urls or set()
        self.analyzer = Analyzer(self.config, config.reward_func or 0,
                                 config.use_aft or False)
        # populated by initialize_environment below
        self.client_environment: Optional[ClientEnvironment] = None
        self.action_space: Optional[ActionSpace] = None
        self.policy: Optional[Policy] = None
        self.initialize_environment(
            self.config.client_env
            or client.get_random_fast_lte_client_environment(),
            self.config.cached_urls)

    def seed(self, seed=None):
        self.np_random.seed(seed)

    def reset(self):
        # re-randomize the client environment and return the fresh observation
        self.initialize_environment(
            client.get_random_fast_lte_client_environment(),
            self.config.cached_urls)
        return self.observation

    def initialize_environment(self, client_environment: ClientEnvironment,
                               cached_urls: Optional[Set[str]] = None):
        """ Initialize the environment """
        log.info(
            "initialized environment",
            network_type=client.NetworkType(client_environment.network_type),
            network_speed=client.NetworkSpeed(
                client_environment.network_speed),
            device_speed=client.DeviceSpeed(client_environment.device_speed),
            bandwidth=client_environment.bandwidth,
            latency=client_environment.latency,
            cpu_slowdown=client_environment.cpu_slowdown,
            loss=client_environment.loss,
            reward_func=self.analyzer.reward_func_num,
            cached_urls=cached_urls,
        )
        # Cache scenarios in hours
        scenarios = [0, 0, 0, 0, 0, 1, 2, 4, 12, 24]
        cache_time = self.np_random.choice(scenarios)
        if cached_urls is not None:
            # an explicit cache set always wins over the sampled scenario
            self.cached_urls = cached_urls
        elif cache_time == 0:
            self.cached_urls = set()
        else:
            # anything cacheable for at least the sampled window counts as cached
            self.cached_urls = set(
                res.url for group in self.env_config.push_groups
                for res in group.resources
                if res.cache_time >= (cache_time * 60 * 60))
        self.client_environment = client_environment
        self.analyzer.reset(self.client_environment, self.cached_urls)
        # only the largest PROPORTION_DEPLOYED fraction of domains participate
        num_domains_deployed = math.ceil(PROPORTION_DEPLOYED *
                                         len(self.env_config.push_groups))
        deployed_groups = sorted(self.env_config.push_groups,
                                 key=lambda g: len(g.resources),
                                 reverse=True)[:num_domains_deployed]
        self.action_space = ActionSpace(deployed_groups)
        self.policy = Policy(self.action_space)

    def step(self, action: ActionIDType):
        # decode the action and apply it to the policy
        decoded_action = self.action_space.decode_action(action)
        action_applied = self.policy.apply_action(decoded_action)
        # make sure the action isn't used again
        log.info("trying action",
                 action_id=action,
                 action=repr(decoded_action),
                 steps_taken=self.policy.steps_taken)
        self.action_space.use_action(decoded_action)
        reward = NOOP_ACTION_REWARD
        if action_applied:
            reward = self.analyzer.get_reward(self.policy)
        log.info("got reward", action=repr(decoded_action), reward=reward)
        info = {"action": decoded_action, "policy": self.policy.as_dict}
        # the episode ends when an action fails to apply
        return self.observation, reward, not action_applied, info

    def render(self, mode="human"):
        return super(Environment, self).render(mode=mode)

    @property
    def observation(self):
        """ Returns an observation for the current state of the environment """
        return get_observation(self.client_environment,
                               self.env_config.push_groups, self.policy,
                               self.cached_urls)