def test_view_manifest_only_trainable(self):
    """
    `view_manifest --trainable` lists every trainable group and, before the
    execution graph section, none of the non-trainable ones.
    """
    har = har_from_json(get_har_json())
    resources = har_entries_to_resources(har)
    groups = resource_list_to_push_groups(
        resources, train_domain_globs=["*reddit.com"])
    config = EnvironmentConfig(
        replay_dir="",
        request_url="https://www.reddit.com/",
        push_groups=groups,
        har_resources=resources,
    )

    with mock.patch("builtins.print") as mock_print:
        with tempfile.NamedTemporaryFile() as config_file:
            config.save_file(config_file.name)
            view_manifest(["--trainable", config_file.name])

    assert mock_print.call_count > 5
    # bare print() calls have no positional args and contribute no text
    printed_text = "\n".join(
        printed[0][0] for printed in mock_print.call_args_list if printed[0])
    assert config.replay_dir in printed_text
    assert config.request_url in printed_text

    trainable = [g for g in config.push_groups if g.trainable]
    untrainable = [g for g in config.push_groups if not g.trainable]
    assert all(g.name in printed_text for g in trainable)
    # the execution graph section may mention any group, so only inspect the text before it
    before_graph, *_ = printed_text.split("Execution Graph")
    assert not any(g.name in before_graph for g in untrainable)
def test_runs_successfully(self, mock_capture_har_in_mahimahi):
    """
    `preprocess` writes a manifest for the recorded site whose push groups are
    all non-trainable (the URL does not match the default train glob's usage
    here) and whose HAR resources match the first captured HAR.
    """
    site = "https://cs.ucla.edu"
    hars = [generate_har() for _ in range(STABLE_SET_NUM_RUNS + 1)]
    first_har = hars[0]
    expected_resources = har_entries_to_resources(first_har)
    mock_capture_har_in_mahimahi.return_value = first_har

    with tempfile.NamedTemporaryFile() as output_file, \
            tempfile.TemporaryDirectory() as output_dir:
        replay_patch = mock.patch(
            "blaze.preprocess.record.capture_har_in_replay_server",
            new=HarReturner(hars))
        with replay_patch:
            preprocess([
                site, "--output", output_file.name, "--record_dir", output_dir
            ])

        manifest = EnvironmentConfig.load_file(output_file.name)
        assert manifest.replay_dir == output_dir
        assert manifest.request_url == site
        assert manifest.push_groups
        # since we passed cs.ucla.edu as URL, nothing should be trainable
        assert not any(group.trainable for group in manifest.push_groups)
        assert manifest.har_resources == expected_resources

        client_env = get_default_client_environment()
        expected_config = get_config(
            EnvironmentConfig(replay_dir=output_dir, request_url=site))
        assert mock_capture_har_in_mahimahi.call_count == 1
        mock_capture_har_in_mahimahi.assert_called_with(
            site, expected_config, client_env)
def update_manifest(args):
    """
    Update the manifest file with the specified changes. This will replace the
    manifest file with the update version, unless --save_as is specified, in which
    case it will create a new manifest file in that location and leave the original
    one unmodified.

    Supported changes to the manifest's replay_dir:
      * --replay_dir_path_prefix: keep the replay folder's name but move it under
        the given prefix directory.
      * --replay_dir_folder_name: keep the parent directory but rename the folder.
    """
    save_as = args.save_as or args.manifest_file
    log.info(
        "Updating manifest",
        manifest_file=args.manifest_file,
        replay_dir_path_prefix=args.replay_dir_path_prefix,
        replay_dir_folder_name=args.replay_dir_folder_name,
        save_as=save_as,
    )

    env_config = EnvironmentConfig.load_file(args.manifest_file)
    new_replay_dir = env_config.replay_dir
    if args.replay_dir_path_prefix or args.replay_dir_folder_name:
        # Strip trailing separators before splitting the path. Otherwise
        # basename("/a/b/") is "" and dirname("/a/b/") is "/a/b", which loses
        # the folder name or renames the wrong component.
        # `or new_replay_dir` keeps a bare root such as "/" intact.
        new_replay_dir = new_replay_dir.rstrip(os.sep) or new_replay_dir
    if args.replay_dir_path_prefix:
        new_replay_dir = os.path.join(args.replay_dir_path_prefix,
                                      os.path.basename(new_replay_dir))
    if args.replay_dir_folder_name:
        new_replay_dir = os.path.join(os.path.dirname(new_replay_dir),
                                      args.replay_dir_folder_name)

    new_env_config = env_config._replace(replay_dir=new_replay_dir)
    new_env_config.save_file(save_as)
def view_manifest(args):
    """
    View the prepared manifest from `blaze preprocess`

    Prints the request URL, the replay directory, the names of the trainable
    push groups, then each push group (only trainable ones when
    `args.trainable` is set) with one line per resource, and finally the
    simulator's execution map.
    """
    log.info("loading manifest", manifest_file=args.manifest_file)
    env_config = EnvironmentConfig.load_file(args.manifest_file)

    print("[[ Request URL ]]\n{}\n".format(env_config.request_url))
    print("[[ Replay Dir ]]\n{}\n".format(env_config.replay_dir))
    trainable_names = [group.name for group in env_config.trainable_push_groups]
    print("[[ Trainable Groups ]]\n{}\n".format("\n".join(trainable_names)))

    heading_prefix = "Trainable " if args.trainable else ""
    print("[[ {}Push Groups ]]".format(heading_prefix))
    shown_groups = (group for group in env_config.push_groups
                    if not args.trainable or group.trainable)
    for group in shown_groups:
        print(" [{id}: {name} ({num} resources)]".format(
            id=group.id, name=group.name, num=len(group.resources)))
        for res in group.resources:
            resource_path = Url.parse(res.url).resource
            if len(resource_path) > 64:
                # truncate so the entry fits the 64-character column, ellipsis included
                resource_path = resource_path[:61] + "..."
            print(
                " {order:<3} {url:<64} {type:<6} {size:>8} B cache: {cache}s {crit}"
                .format(
                    order=res.order,
                    url=resource_path,
                    type=res.type.name,
                    size=res.size,
                    cache=res.cache_time,
                    crit="critical" if res.critical else "",
                ))
        print()

    print("[[ Execution Graph ]]")
    Simulator(env_config).print_execution_map()
def test_push(args):
    """
    Runs a pre-defined test on the given webpage

    With policy_type "all" a push-and-preload-everything policy generator is
    used. Otherwise a random policy generator is used. Its weight is 0 for
    "preload", 1 for "push" and None for anything else. When a user data
    directory is given, the cacheable files recorded in the manifest's replay
    directory are passed to it as cached URLs over both http and https.
    Always returns 0.
    """
    if args.policy_type == "all":
        policy_generator = _push_preload_all_policy_generator()
    else:
        if args.policy_type == "preload":
            weight = 0
        elif args.policy_type == "push":
            weight = 1
        else:
            weight = None

        cached_urls = set()
        if args.user_data_dir:
            replay_dir = EnvironmentConfig.load_file(args.from_manifest).replay_dir
            store = FileStore(replay_dir)
            cached_urls = {
                f"{scheme}://{cached.host}{cached.uri}"
                for cached in store.cacheable_files
                for scheme in ("http", "https")
            }
        policy_generator = _random_push_preload_policy_generator(
            weight, cached_urls)

    _test_push(
        manifest=args.from_manifest,
        iterations=args.iterations,
        max_retries=args.max_retries,
        policy_generator=policy_generator,
        bandwidth=args.bandwidth,
        latency=args.latency,
        cpu_slowdown=args.cpu_slowdown,
        only_simulator=args.only_simulator,
        speed_index=args.speed_index,
        cache_time=args.cache_time,
        user_data_dir=args.user_data_dir,
    )
    return 0
def train(args):
    """
    Trains a model to generate push policies for the given website. This command
    takes as input the manifest file generated by `blaze preprocess` and outputs
    a model that can be served.

    Exits with status 1 when both --resume and --no-resume are given, or when
    the requested model is not one of "A3C", "APEX" or "PPO".
    """
    # check for ambiguous options
    if args.resume and args.no_resume:
        log.error(
            "invalid options: cannot specify both --resume and --no-resume")
        sys.exit(1)

    log.info("starting train", name=args.name, model=args.model)

    # import specified model; exactly one branch can match
    if args.model == "A3C":
        from blaze.model import a3c as model
    elif args.model == "APEX":
        from blaze.model import apex as model
    elif args.model == "PPO":
        from blaze.model import ppo as model
    else:
        # without this, `model` would be unbound and model.train() would crash
        # later with UnboundLocalError
        log.error("invalid options: unknown model", model=args.model)
        sys.exit(1)

    # compute resume flag and initialize training
    resume = False if args.no_resume else True if args.resume else "prompt"
    train_config = TrainConfig(experiment_name=args.name,
                               num_workers=args.workers,
                               resume=resume)
    env_config = EnvironmentConfig.load_file(args.manifest_file)
    config = get_config(env_config,
                        reward_func=args.reward_func,
                        use_aft=args.use_aft)
    model.train(train_config, config)
def preprocess(args):
    """
    Preprocesses a website for training. Automatically discovers linked pages up
    to a certain depth and finds the stable set of page dependencies. The page
    load is recorded and stored and a training manifest is outputted.

    Steps: capture a HAR in the replay server, find the URL stable set, group it
    into push groups (trainable per `train_domain_globs`, defaulting to a glob
    on the site's domain), optionally mark critical requests, mark cacheable
    objects, and save the resulting EnvironmentConfig to `args.output`.
    """
    site_url = args.website
    record_dir = args.record_dir
    domain = Url.parse(site_url).domain
    globs = args.train_domain_globs
    if not globs:
        globs = ["*{}*".format(domain)]
    log.info("preprocessing website",
             website=site_url,
             record_dir=record_dir,
             train_domain_globs=globs)

    config = get_config(env_config=EnvironmentConfig(replay_dir=record_dir,
                                                     request_url=site_url))
    client_env = get_default_client_environment()
    log.debug("using configuration", **config._asdict())

    log.info("capturing execution")
    captured_har = capture_har_in_replay_server(site_url, config, client_env)
    har_resources = har_entries_to_resources(captured_har)

    log.info("finding dependency stable set...")
    stable_set = find_url_stable_set(site_url, config)
    log.info("found total dependencies", total=len(stable_set))
    push_groups = resource_list_to_push_groups(stable_set,
                                               train_domain_globs=globs)

    if args.extract_critical_requests:
        log.info("extracting critical requests")
        push_groups = annotate_critical_requests(site_url, config, client_env,
                                                 push_groups)
        critical_urls = {
            res.url
            for group in push_groups
            for res in group.resources
            if res.critical
        }
        log.debug("critical resources", resources=critical_urls)

    log.info("finding cacheable objects")
    push_groups = annotate_cacheable_objects(record_dir, push_groups)

    log.info("generating configuration...")
    manifest = EnvironmentConfig(replay_dir=record_dir,
                                 request_url=site_url,
                                 push_groups=push_groups,
                                 har_resources=har_resources)
    manifest.save_file(args.output)
    log.info("successfully prepared website for training", output=args.output)
def _test_push(
    *,
    manifest: str,
    iterations: Optional[int] = 1,
    max_retries: Optional[int] = 0,
    policy_generator: Callable[[EnvironmentConfig], Policy],
    bandwidth: Optional[int],
    latency: Optional[int],
    cpu_slowdown: Optional[int],
    only_simulator: Optional[bool],
    speed_index: Optional[bool],
    cache_time: Optional[int],
    user_data_dir: Optional[str],
):
    """
    Evaluate policies produced by `policy_generator` against a recorded page and
    print the results as JSON.

    Unless `only_simulator` is set, page load times (or speed index) are measured
    in the replay server, with and without the generated policies. The simulator
    is always run on the same policies. Network and CPU parameters that are not
    given fall back to the default client environment.
    """
    env_config = EnvironmentConfig.load_file(manifest)
    default_client_env = get_default_client_environment()
    client_env = get_client_environment_from_parameters(
        bandwidth or default_client_env.bandwidth,
        # 0 ms is a legitimate latency; only fall back when none was given
        latency if latency is not None else default_client_env.latency,
        cpu_slowdown or default_client_env.cpu_slowdown,
    )

    data = {
        "client_env": client_env._asdict(),
        "url": env_config.request_url,
        "cache": "warm" if user_data_dir else "cold",
        "metric": "speed_index" if speed_index else "plt",
        "cache_time": cache_time,
    }

    if not only_simulator:
        config = get_config(env_config)
        plt, push_plts, policies = _get_results_in_replay_server(
            config, client_env, iterations, max_retries, policy_generator,
            cache_time, user_data_dir, speed_index)
        data["replay_server"] = {
            "without_policy": plt,
            "with_policy": [{
                "plt": push_plt,
                "policy": policy.as_dict
            } for (push_plt, policy) in zip(push_plts, policies)],
        }
    else:
        policies = [policy_generator(env_config) for _ in range(iterations)]

    # the simulator evaluates the same policies used (or generated) above
    sim = Simulator(env_config)
    data["simulator"] = {
        "without_policy": sim.simulate_load_time(client_env),
        "with_policy": [{
            "plt": sim.simulate_load_time(client_env, policy),
            "policy": policy.as_dict
        } for policy in policies],
    }

    print(json.dumps(data, indent=4))
def get_env_config() -> EnvironmentConfig:
    """
    Build a fixture EnvironmentConfig for example.com.

    Its har_resources are every resource from its own push groups, sorted by
    `order`. get_push_groups() is called only once, so the push groups and the
    HAR resources are guaranteed to come from the same set of groups.
    """
    push_groups = get_push_groups()
    return EnvironmentConfig(
        replay_dir="/tmp/replay_dir",
        request_url="http://example.com/",
        push_groups=push_groups,
        har_resources=sorted(
            (res for group in push_groups for res in group.resources),
            key=lambda r: r.order),
    )
def random_push_policy(args):
    """
    Outputs a random push policy for the given recorded website

    The generator weight is 0 for policy_type "preload", 1 for "push" and None
    for anything else. The resulting policy is printed as indented JSON.
    """
    log.info("generating a random policy", policy_type=args.policy_type)
    env_config = EnvironmentConfig.load_file(args.from_manifest)

    if args.policy_type == "preload":
        weight = 0
    elif args.policy_type == "push":
        weight = 1
    else:
        weight = None

    make_policy = _random_push_preload_policy_generator(weight)
    policy = make_policy(env_config)
    print(json.dumps(policy.as_dict, indent=4))
def test_raises_on_no_replay_dir(self):
    """
    capture_har_in_replay_server raises ValueError both for the default config
    and for a config whose replay_dir is empty.
    """
    url = "https://www.cs.ucla.edu"

    def assert_rejected(config):
        with pytest.raises(ValueError):
            capture_har_in_replay_server(url, config, self.client_env)

    assert_rejected(_get_config())
    assert_rejected(
        _get_config(EnvironmentConfig(request_url=url, replay_dir="")))
def get_policy(self, url: str, client_env: ClientEnvironment,
               manifest: EnvironmentConfig) -> Policy:
    """
    Queries the policy service for a push policy for the given configuration

    Sends a `Page` message (URL, bandwidth, latency, CPU slowdown and the
    serialized manifest) to the stub's GetPolicy RPC. The `policy` field of the
    response is decoded as JSON into a Policy.
    """
    request = policy_service_pb2.Page(
        url=url,
        bandwidth_kbps=client_env.bandwidth,
        latency_ms=client_env.latency,
        cpu_slowdown=client_env.cpu_slowdown,
        manifest=manifest.serialize(),
    )
    response = self.stub.GetPolicy(request)
    policy_dict = json.loads(response.policy)
    return Policy.from_dict(policy_dict)
def test_pickle(self):
    """
    A config saved with save_file loads back with the same URL, replay dir,
    push group names and resources.
    """
    original = get_env_config()
    with tempfile.NamedTemporaryFile() as tmp_file:
        original.save_file(tmp_file.name)
        loaded = EnvironmentConfig.load_file(tmp_file.name)

    assert loaded.request_url == original.request_url
    assert loaded.replay_dir == original.replay_dir
    assert len(loaded.push_groups) == len(original.push_groups)
    # lengths were checked above, so zip() pairs every group / resource
    for expected_group, loaded_group in zip(original.push_groups,
                                            loaded.push_groups):
        assert loaded_group.name == expected_group.name
        assert len(loaded_group.resources) == len(expected_group.resources)
        for expected_res, loaded_res in zip(expected_group.resources,
                                            loaded_group.resources):
            assert loaded_res == expected_res
def test_simulator(self):
    """Simulating the reddit HAR fixture yields a positive load time."""
    resources = har_entries_to_resources(har_from_json(get_har_json()))
    env_config = EnvironmentConfig(
        replay_dir="",
        request_url="https://www.reddit.com/",
        push_groups=resource_list_to_push_groups(resources),
        har_resources=resources,
    )
    client_env = get_fast_mobile_client_environment()

    load_time_ms = Simulator(env_config).simulate_load_time(client_env)
    assert load_time_ms > 0
def test_view_manifest(self):
    """
    `view_manifest` prints the replay dir, request URL, every push group name
    and (the first 61 characters of) every resource path.
    """
    resources = har_entries_to_resources(har_from_json(get_har_json()))
    config = EnvironmentConfig(
        replay_dir="",
        request_url="https://www.reddit.com/",
        push_groups=resource_list_to_push_groups(resources),
        har_resources=resources,
    )

    with mock.patch("builtins.print") as mock_print:
        with tempfile.NamedTemporaryFile() as config_file:
            config.save_file(config_file.name)
            view_manifest([config_file.name])

    assert mock_print.call_count > 5
    # bare print() calls have no positional args and contribute no text
    printed_text = "\n".join(
        printed[0][0] for printed in mock_print.call_args_list if printed[0])
    assert config.replay_dir in printed_text
    assert config.request_url in printed_text
    assert all(group.name in printed_text for group in config.push_groups)
    # long paths are truncated to 61 characters plus "..." in the output
    all_resources = (res for group in config.push_groups
                     for res in group.resources)
    assert all(Url.parse(res.url).resource[:61] in printed_text
               for res in all_resources)
def test_calls_capture_har_with_correct_arguments(self, mock_run, mock_open):
    """
    capture_har_in_replay_server runs a `docker` command ending with the URL and
    returns the expected HAR.
    """
    url = "https://www.cs.ucla.edu"
    mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0)
    config = _get_config(EnvironmentConfig(request_url=url, replay_dir="/tmp/dir"))

    har = capture_har_in_replay_server(url, config, self.client_env)

    first_call_args, _ = mock_run.call_args_list[0]
    command = first_call_args[0]
    assert command[0] == "docker"
    assert command[-1] == url
    assert har == self.har
def test_writes_mahimahi_files_correctly(self, mock_run, mock_open, mock_tmpdir):
    """
    The first two files opened by capture_har_in_replay_server are opened for
    writing ("w") inside the temporary directory.
    """
    tmp_dir = "/tmp/blaze_test_123"
    url = "https://www.cs.ucla.edu"
    mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0)
    mock_tmpdir.return_value.__enter__.return_value = tmp_dir
    config = _get_config(EnvironmentConfig(request_url=url, replay_dir=tmp_dir))

    capture_har_in_replay_server(url, config, self.client_env)

    opened = mock_open.call_args_list
    first_path, second_path = opened[0][0][0], opened[1][0][0]
    assert first_path.startswith(tmp_dir)
    assert second_path.startswith(tmp_dir)
    assert opened[0][0][1] == "w"
    assert opened[1][0][1] == "w"
def query(args):
    """
    Queries a trained model that is served on a gRPC server.

    Loads the manifest, builds a client environment from the bandwidth,
    latency and CPU slowdown arguments, and asks the server at host:port for a
    policy for the manifest's request URL. The policy is printed as indented
    JSON. The gRPC channel is closed once the response is received.
    """
    log.info("querying server...", host=args.host, port=args.port)
    # use the channel as a context manager so it is closed instead of leaked
    with grpc.insecure_channel(f"{args.host}:{args.port}") as channel:
        client = Client(channel)
        manifest = EnvironmentConfig.load_file(args.manifest)
        client_env = get_client_environment_from_parameters(
            args.bandwidth, args.latency, args.cpu_slowdown)
        policy = client.get_policy(url=manifest.request_url,
                                   client_env=client_env,
                                   manifest=manifest)
    print(json.dumps(policy.as_dict, indent=4))
def test_compiles(self):
    """An EnvironmentConfig can be built with an empty list of push groups."""
    env_config = EnvironmentConfig(
        push_groups=[],
        request_url="http://example.com",
        replay_dir="/replay/dir",
    )
    assert isinstance(env_config, EnvironmentConfig)
def read_file(fpath):
    """Load and return the EnvironmentConfig stored at `fpath` (the path is logged at debug level)."""
    log.debug("reading file...", file=fpath)
    env_config = EnvironmentConfig.load_file(fpath)
    return env_config
def page_load_time(args):
    """
    Captures a webpage and calculates the median page load time for a given website
    in a fast, no-latency Mahimahi shell. Then simulates the load based on profiling
    the page in the same Mahimahi shell.

    Validates latency (>= 0), bandwidth (> 0) and CPU slowdown (1, 2 or 4) and
    exits with status 1 on invalid values. Unless --only_simulator is given, the
    replay server measures page load time, or speed index with --speed_index,
    with and without the optional policy. The simulator always runs. Results
    are printed as JSON.
    """
    # Validate the arguments
    if args.latency is not None and args.latency < 0:
        log.critical("provided latency must be greater or equal to 0")
        sys.exit(1)
    if args.bandwidth is not None and args.bandwidth <= 0:
        log.critical("provided bandwidth must be greater than 0")
        sys.exit(1)
    if args.cpu_slowdown is not None and args.cpu_slowdown not in {1, 2, 4}:
        log.critical("provided cpu slowdown must be 1, 2, or 4")
        sys.exit(1)

    # Setup the client environment
    default_client_env = get_default_client_environment()
    client_env = get_client_environment_from_parameters(
        args.bandwidth or default_client_env.bandwidth,
        # a latency of 0 passes the validation above; `or` would replace it
        # with the default, so only fall back when no latency was given
        args.latency
        if args.latency is not None else default_client_env.latency,
        args.cpu_slowdown or default_client_env.cpu_slowdown,
    )

    # If a push/preload policy was specified, read it
    policy = None
    if args.policy:
        log.debug("reading policy", push_policy=args.policy)
        with open(args.policy, "r") as policy_file:
            policy_dict = json.load(policy_file)
        policy = Policy.from_dict(policy_dict)

    env_config = EnvironmentConfig.load_file(args.from_manifest)
    config = get_config(env_config)
    log.info("calculating page load time",
             manifest=args.from_manifest,
             url=env_config.request_url)

    plt, orig_plt = 0, 0
    if not args.only_simulator:
        if not args.speed_index:
            orig_plt, *_ = get_page_load_time_in_replay_server(
                request_url=config.env_config.request_url,
                client_env=client_env,
                config=config,
                cache_time=args.cache_time,
                user_data_dir=args.user_data_dir,
            )
            if policy:
                plt, *_ = get_page_load_time_in_replay_server(
                    request_url=config.env_config.request_url,
                    client_env=client_env,
                    config=config,
                    policy=policy,
                    cache_time=args.cache_time,
                    user_data_dir=args.user_data_dir,
                )
        else:
            orig_plt = get_speed_index_in_replay_server(
                request_url=config.env_config.request_url,
                client_env=client_env,
                config=config,
                cache_time=args.cache_time,
                user_data_dir=args.user_data_dir,
            )
            if policy:
                plt = get_speed_index_in_replay_server(
                    request_url=config.env_config.request_url,
                    client_env=client_env,
                    config=config,
                    policy=policy,
                    cache_time=args.cache_time,
                    user_data_dir=args.user_data_dir,
                )

    log.debug("running simulator...")
    sim = Simulator(env_config)
    orig_sim_plt = sim.simulate_load_time(client_env)
    sim_plt = sim.simulate_load_time(client_env, policy)

    print(
        json.dumps(
            {
                "client_env": client_env._asdict(),
                "metric": "speed_index" if args.speed_index else "plt",
                "cache": "warm" if args.user_data_dir else "cold",
                "cache_time": args.cache_time,
                "replay_server": {
                    "with_policy": plt,
                    "without_policy": orig_plt
                },
                "simulator": {
                    "with_policy": sim_plt,
                    "without_policy": orig_sim_plt
                },
            },
            indent=4,
        ))
def evaluate(args):
    """
    Instantiate the given model and checkpoint and query it for the policy
    corresponding to the given client and network conditions. Also allows running
    the generated policy through the simulator and replay server to get the PLTs
    and compare them under different conditions.

    Prints only the policy by default. With --verbose, --run_simulator or
    --run_replay_server, it prints a JSON object that also contains the manifest
    path, model location, client environment and the requested PLT results.
    Exits with status 1 if the model is not "A3C", "APEX" or "PPO".
    """
    log.info("evaluating model...",
             model=args.model,
             location=args.location,
             manifest=args.manifest)
    client_env = get_client_environment_from_parameters(
        args.bandwidth, args.latency, args.cpu_slowdown)
    manifest = EnvironmentConfig.load_file(args.manifest)

    # Resources whose cache_time exceeds the requested cache_time are passed to
    # the config as cached_urls. With no cache_time the set stays empty.
    cached_urls = set()
    if args.cache_time is not None:
        cached_urls = {
            res.url
            for group in manifest.push_groups
            for res in group.resources
            if res.cache_time > args.cache_time
        }
    log.debug("using cached resources", cached_urls=cached_urls)
    config = get_config(manifest, client_env, args.reward_func).with_mutations(
        cached_urls=cached_urls, use_aft=args.use_aft)

    # exactly one branch can match
    if args.model == "A3C":
        from blaze.model import a3c as model
    elif args.model == "APEX":
        from blaze.model import apex as model
    elif args.model == "PPO":
        from blaze.model import ppo as model
    else:
        # without this, `model` would be unbound and the code below would crash
        # with UnboundLocalError
        import sys
        log.error("invalid options: unknown model", model=args.model)
        sys.exit(1)
    import ray

    ray.init(num_cpus=2, log_to_driver=False)

    saved_model = model.get_model(args.location)
    instance = saved_model.instantiate(config)
    policy = instance.policy
    data = policy.as_dict

    if args.verbose or args.run_simulator or args.run_replay_server:
        data = {
            "manifest": args.manifest,
            "location": args.location,
            "client_env": client_env._asdict(),
            "policy": policy.as_dict,
        }

    if args.run_simulator:
        sim = Simulator(manifest)
        sim_plt = sim.simulate_load_time(client_env)
        push_plt = sim.simulate_load_time(client_env, policy)
        data["simulator"] = {
            "without_policy": sim_plt,
            "with_policy": push_plt
        }

    if args.run_replay_server:
        *_, plts = get_page_load_time_in_replay_server(
            config.env_config.request_url, client_env, config)
        *_, push_plts = get_page_load_time_in_replay_server(
            config.env_config.request_url, client_env, config, policy=policy)
        data["replay_server"] = {
            "without_policy": plts,
            "with_policy": push_plts
        }

    print(json.dumps(data, indent=4))