Exemple #1
0
    def test_view_manifest_only_trainable(self):
        """ With `--trainable`, non-trainable groups must not be printed before the execution graph """
        har = har_from_json(get_har_json())
        resources = har_entries_to_resources(har)
        groups = resource_list_to_push_groups(
            resources, train_domain_globs=["*reddit.com"])
        config = EnvironmentConfig(
            replay_dir="",
            request_url="https://www.reddit.com/",
            push_groups=groups,
            har_resources=resources,
        )

        # capture everything printed while the temporary manifest file exists
        with mock.patch("builtins.print") as mock_print, \
                tempfile.NamedTemporaryFile() as config_file:
            config.save_file(config_file.name)
            view_manifest(["--trainable", config_file.name])
        assert mock_print.call_count > 5

        # collect the first positional argument of every non-empty print call
        printed_lines = []
        for call in mock_print.call_args_list:
            positional = call[0]
            if positional:
                printed_lines.append(positional[0])
        printed_text = "\n".join(printed_lines)

        assert config.replay_dir in printed_text
        assert config.request_url in printed_text
        for group in config.push_groups:
            if group.trainable:
                assert group.name in printed_text

        # only the text printed ahead of the execution graph is checked for hidden groups
        pre_graph_text = printed_text.partition("Execution Graph")[0]
        for group in config.push_groups:
            if not group.trainable:
                assert group.name not in pre_graph_text
Exemple #2
0
    def test_runs_successfully(self, mock_capture_har_in_mahimahi):
        """ Runs `preprocess` against stubbed HAR captures and checks the manifest it saves """
        hars = []
        for _ in range(STABLE_SET_NUM_RUNS + 1):
            hars.append(generate_har())
        first_har = hars[0]
        expected_resources = har_entries_to_resources(first_har)
        mock_capture_har_in_mahimahi.return_value = first_har

        with tempfile.NamedTemporaryFile() as output_file, \
                tempfile.TemporaryDirectory() as output_dir:
            har_returner = HarReturner(hars)
            with mock.patch(
                    "blaze.preprocess.record.capture_har_in_replay_server",
                    new=har_returner):
                preprocess([
                    "https://cs.ucla.edu",
                    "--output",
                    output_file.name,
                    "--record_dir",
                    output_dir,
                ])

            saved = EnvironmentConfig.load_file(output_file.name)
            assert saved.replay_dir == output_dir
            assert saved.request_url == "https://cs.ucla.edu"
            assert saved.push_groups
            # since we passed cs.ucla.edu as URL, nothing should be trainable
            for group in saved.push_groups:
                assert not group.trainable
            assert saved.har_resources == expected_resources

        # the mocked capture must have been invoked exactly once with the default setup
        expected_client_env = get_default_client_environment()
        expected_env_config = EnvironmentConfig(
            replay_dir=output_dir, request_url="https://cs.ucla.edu")
        expected_config = get_config(expected_env_config)

        assert mock_capture_har_in_mahimahi.call_count == 1
        mock_capture_har_in_mahimahi.assert_called_with(
            "https://cs.ucla.edu", expected_config, expected_client_env)
Exemple #3
0
def update_manifest(args):
    """
    Update the manifest file with the specified changes. This will replace the manifest file with the
    update version, unless --save_as is specified, in which case it will create a new manifest file in
    that location and leave the original one unmodified.

    Reads `args.manifest_file`, `args.save_as`, `args.replay_dir_path_prefix` and
    `args.replay_dir_folder_name`. When a path prefix is given, the replay directory keeps its
    folder name but is moved under that prefix; when a folder name is given, the last path
    component of the replay directory is replaced by it. Both may be combined.
    """
    save_as = args.save_as or args.manifest_file
    log.info(
        "Updating manifest",
        manifest_file=args.manifest_file,
        replay_dir_path_prefix=args.replay_dir_path_prefix,
        replay_dir_folder_name=args.replay_dir_folder_name,
        save_as=save_as,
    )
    env_config = EnvironmentConfig.load_file(args.manifest_file)

    new_replay_dir = env_config.replay_dir
    rewrite_requested = bool(args.replay_dir_path_prefix
                             or args.replay_dir_folder_name)
    if rewrite_requested and new_replay_dir:
        # os.path.basename("/a/b/") is "" and os.path.dirname("/a/b/") is "/a/b", so a stored
        # path with a trailing separator would lose (or nest) its folder name below. Drop the
        # trailing separators first, but keep a bare root such as "/" untouched.
        separators = os.sep + (os.altsep or "")
        new_replay_dir = new_replay_dir.rstrip(separators) or new_replay_dir

    if args.replay_dir_path_prefix:
        new_replay_dir = os.path.join(args.replay_dir_path_prefix,
                                      os.path.basename(new_replay_dir))

    if args.replay_dir_folder_name:
        new_replay_dir = os.path.join(os.path.dirname(new_replay_dir),
                                      args.replay_dir_folder_name)

    # only the replay directory changes; every other manifest field is carried over
    new_env_config = env_config._replace(replay_dir=new_replay_dir)
    new_env_config.save_file(save_as)
Exemple #4
0
def view_manifest(args):
    """ Print a readable summary of the manifest prepared by `blaze preprocess` """
    log.info("loading manifest", manifest_file=args.manifest_file)
    env_config = EnvironmentConfig.load_file(args.manifest_file)

    print("[[ Request URL ]]\n{}\n".format(env_config.request_url))
    print("[[ Replay Dir ]]\n{}\n".format(env_config.replay_dir))

    trainable_names = [
        group.name for group in env_config.trainable_push_groups
    ]
    print("[[ Trainable Groups ]]\n{}\n".format("\n".join(trainable_names)))

    heading_prefix = "Trainable " if args.trainable else ""
    print("[[ {}Push Groups ]]".format(heading_prefix))

    group_template = "  [{id}: {name} ({num} resources)]"
    row_template = "    {order:<3}  {url:<64}  {type:<6} {size:>8} B  cache: {cache}s  {crit}"

    # with --trainable, groups that are not trainable are left out of the listing
    shown_groups = [
        group for group in env_config.push_groups
        if not args.trainable or group.trainable
    ]
    for group in shown_groups:
        print(
            group_template.format(id=group.id,
                                  name=group.name,
                                  num=len(group.resources)))
        for res in group.resources:
            resource_path = Url.parse(res.url).resource
            # keep the URL column at 64 characters by eliding long paths
            if len(resource_path) <= 64:
                shown_url = resource_path
            else:
                shown_url = resource_path[:61] + "..."
            criticality = "critical" if res.critical else ""
            print(
                row_template.format(
                    order=res.order,
                    url=shown_url,
                    type=res.type.name,
                    size=res.size,
                    cache=res.cache_time,
                    crit=criticality,
                ))
        print()

    print("[[ Execution Graph ]]")
    Simulator(env_config).print_execution_map()
Exemple #5
0
def test_push(args):
    """
    Runs a pre-defined push/preload test on the webpage described by the given manifest
    """
    if args.policy_type == "all":
        policy_generator = _push_preload_all_policy_generator()
    else:
        # weight 0 selects preload only, 1 selects push only, None leaves the choice open
        if args.policy_type == "preload":
            weight = 0
        elif args.policy_type == "push":
            weight = 1
        else:
            weight = None

        cached_urls = set()
        if args.user_data_dir:
            env_config = EnvironmentConfig.load_file(args.from_manifest)
            filestore = FileStore(env_config.replay_dir)
            # register every cacheable file under both schemes
            for cacheable in filestore.cacheable_files:
                cached_urls.update((
                    f"http://{cacheable.host}{cacheable.uri}",
                    f"https://{cacheable.host}{cacheable.uri}",
                ))

        policy_generator = _random_push_preload_policy_generator(
            weight, cached_urls)

    test_options = {
        "manifest": args.from_manifest,
        "iterations": args.iterations,
        "max_retries": args.max_retries,
        "policy_generator": policy_generator,
        "bandwidth": args.bandwidth,
        "latency": args.latency,
        "cpu_slowdown": args.cpu_slowdown,
        "only_simulator": args.only_simulator,
        "speed_index": args.speed_index,
        "cache_time": args.cache_time,
        "user_data_dir": args.user_data_dir,
    }
    _test_push(**test_options)
    return 0
Exemple #6
0
def train(args):
    """
    Trains a model to generate push policies for the given website. This command takes as input the
    manifest file generated by `blaze preprocess` and outputs a model that can be served.

    Exits with status 1 when both --resume and --no-resume are given, or when `args.model` is not
    one of "A3C", "APEX" or "PPO".
    """
    # check for ambiguous options
    if args.resume and args.no_resume:
        log.error(
            "invalid options: cannot specify both --resume and --no-resume")
        sys.exit(1)

    log.info("starting train", name=args.name, model=args.model)

    # import specified model; the imports stay local so only the selected backend is loaded
    if args.model == "A3C":
        from blaze.model import a3c as model
    elif args.model == "APEX":
        from blaze.model import apex as model
    elif args.model == "PPO":
        from blaze.model import ppo as model
    else:
        # without this branch `model` would be unbound and the call at the end of this
        # function would fail with an unrelated-looking UnboundLocalError
        log.error("invalid options: unknown model", model=args.model)
        sys.exit(1)

    # compute resume flag and initialize training:
    # --no-resume -> False, --resume -> True, neither -> "prompt"
    resume = False if args.no_resume else True if args.resume else "prompt"
    train_config = TrainConfig(experiment_name=args.name,
                               num_workers=args.workers,
                               resume=resume)
    env_config = EnvironmentConfig.load_file(args.manifest_file)
    config = get_config(env_config,
                        reward_func=args.reward_func,
                        use_aft=args.use_aft)
    model.train(train_config, config)
Exemple #7
0
def preprocess(args):
    """
    Prepares a website for training. Captures the page load in the replay server, finds the
    stable set of page dependencies, groups them into push groups (optionally annotated with
    critical requests, always annotated with cacheable objects) and saves the resulting
    training manifest to `args.output`.
    """
    website = args.website
    record_dir = args.record_dir

    domain = Url.parse(website).domain
    # default to training on everything served from the website's own domain
    train_domain_globs = args.train_domain_globs or [f"*{domain}*"]
    log.info("preprocessing website",
             website=website,
             record_dir=record_dir,
             train_domain_globs=train_domain_globs)

    recording_env = EnvironmentConfig(replay_dir=record_dir,
                                      request_url=website)
    config = get_config(env_config=recording_env)
    client_env = get_default_client_environment()
    log.debug("using configuration", **config._asdict())

    log.info("capturing execution")
    captured_har = capture_har_in_replay_server(website, config, client_env)
    har_resources = har_entries_to_resources(captured_har)

    log.info("finding dependency stable set...")
    res_list = find_url_stable_set(website, config)

    log.info("found total dependencies", total=len(res_list))
    push_groups = resource_list_to_push_groups(
        res_list, train_domain_globs=train_domain_globs)

    if args.extract_critical_requests:
        log.info("extracting critical requests")
        push_groups = annotate_critical_requests(website, config, client_env,
                                                 push_groups)
        critical_resources = {
            res.url
            for group in push_groups for res in group.resources
            if res.critical
        }
        log.debug("critical resources", resources=critical_resources)

    log.info("finding cacheable objects")
    push_groups = annotate_cacheable_objects(record_dir, push_groups)

    log.info("generating configuration...")
    manifest = EnvironmentConfig(
        replay_dir=record_dir,
        request_url=website,
        push_groups=push_groups,
        har_resources=har_resources,
    )
    manifest.save_file(args.output)
    log.info("successfully prepared website for training", output=args.output)
Exemple #8
0
def _test_push(
    *,
    manifest: str,
    iterations: Optional[int] = 1,
    max_retries: Optional[int] = 0,
    policy_generator: Callable[[EnvironmentConfig], Policy],
    bandwidth: Optional[int],
    latency: Optional[int],
    cpu_slowdown: Optional[int],
    only_simulator: Optional[bool],
    speed_index: Optional[bool],
    cache_time: Optional[int],
    user_data_dir: Optional[str],
):
    """
    Loads the manifest, evaluates generated policies in the simulator (and in the replay
    server unless `only_simulator` is set) and prints the collected results as JSON.
    """
    env_config = EnvironmentConfig.load_file(manifest)

    # unset (falsy) network/CPU parameters fall back to the default client environment
    defaults = get_default_client_environment()
    client_env = get_client_environment_from_parameters(
        bandwidth or defaults.bandwidth,
        latency or defaults.latency,
        cpu_slowdown or defaults.cpu_slowdown,
    )

    results = {
        "client_env": client_env._asdict(),
        "url": env_config.request_url,
        "cache": "warm" if user_data_dir else "cold",
        "metric": "speed_index" if speed_index else "plt",
        "cache_time": cache_time,
    }

    if only_simulator:
        # no replay server run, so the policies are generated here instead
        policies = [policy_generator(env_config) for _ in range(iterations)]
    else:
        config = get_config(env_config)
        baseline_plt, push_plts, policies = _get_results_in_replay_server(
            config, client_env, iterations, max_retries, policy_generator,
            cache_time, user_data_dir, speed_index)
        replay_with_policy = []
        for pushed_plt, policy in zip(push_plts, policies):
            replay_with_policy.append({
                "plt": pushed_plt,
                "policy": policy.as_dict
            })
        results["replay_server"] = {
            "without_policy": baseline_plt,
            "with_policy": replay_with_policy,
        }

    sim = Simulator(env_config)
    simulated_baseline = sim.simulate_load_time(client_env)
    simulated_with_policy = []
    for policy in policies:
        simulated_with_policy.append({
            "plt": sim.simulate_load_time(client_env, policy),
            "policy": policy.as_dict
        })
    results["simulator"] = {
        "without_policy": simulated_baseline,
        "with_policy": simulated_with_policy,
    }

    print(json.dumps(results, indent=4))
Exemple #9
0
def get_env_config() -> EnvironmentConfig:
    """ Builds a fixture config whose HAR resources are all push group resources sorted by order """
    push_groups = get_push_groups()

    # flatten the resources of a second set of push groups, then order them
    ordered_resources = []
    for group in get_push_groups():
        ordered_resources.extend(group.resources)
    ordered_resources.sort(key=lambda res: res.order)

    return EnvironmentConfig(
        replay_dir="/tmp/replay_dir",
        request_url="http://example.com/",
        push_groups=push_groups,
        har_resources=ordered_resources,
    )
Exemple #10
0
def random_push_policy(args):
    """
    Prints a randomly generated push policy for the recorded website in the given manifest
    """
    log.info("generating a random policy", policy_type=args.policy_type)
    env_config = EnvironmentConfig.load_file(args.from_manifest)

    # weight 0 selects preload only, 1 selects push only, None leaves the choice open
    if args.policy_type == "preload":
        weight = 0
    elif args.policy_type == "push":
        weight = 1
    else:
        weight = None

    generate_policy = _random_push_preload_policy_generator(weight)
    policy = generate_policy(env_config)

    policy_json = json.dumps(policy.as_dict, indent=4)
    print(policy_json)
Exemple #11
0
    def test_raises_on_no_replay_dir(self):
        """ Capturing must raise ValueError for a default config and for an empty replay dir """
        url = "https://www.cs.ucla.edu"

        default_config = _get_config()
        with pytest.raises(ValueError):
            capture_har_in_replay_server(url, default_config, self.client_env)

        empty_dir_env = EnvironmentConfig(request_url=url, replay_dir="")
        empty_dir_config = _get_config(empty_dir_env)
        with pytest.raises(ValueError):
            capture_har_in_replay_server(url, empty_dir_config,
                                         self.client_env)
Exemple #12
0
    def get_policy(self, url: str, client_env: ClientEnvironment,
                   manifest: EnvironmentConfig) -> Policy:
        """ Requests a push policy from the policy service for the given page, client and manifest """
        page_fields = {
            "url": url,
            "bandwidth_kbps": client_env.bandwidth,
            "latency_ms": client_env.latency,
            "cpu_slowdown": client_env.cpu_slowdown,
            "manifest": manifest.serialize(),
        }
        page = policy_service_pb2.Page(**page_fields)

        response = self.stub.GetPolicy(page)
        # the service returns the policy as a JSON document
        policy_dict = json.loads(response.policy)
        return Policy.from_dict(policy_dict)
Exemple #13
0
 def test_pickle(self):
     """ A config saved to disk must load back with the same URL, replay dir, groups and resources """
     original = get_env_config()
     with tempfile.NamedTemporaryFile() as tmp_file:
         original.save_file(tmp_file.name)
         restored = EnvironmentConfig.load_file(tmp_file.name)

         assert original.request_url == restored.request_url
         assert original.replay_dir == restored.replay_dir
         assert len(original.push_groups) == len(restored.push_groups)

         # lengths were checked above, so pairing up by position covers every group
         for restored_group, group in zip(restored.push_groups, original.push_groups):
             assert restored_group.name == group.name
             assert len(restored_group.resources) == len(group.resources)
             for restored_res, res in zip(restored_group.resources, group.resources):
                 assert restored_res == res
Exemple #14
0
    def test_simulator(self):
        """ The simulator must report a positive load time for the sample HAR """
        har = har_from_json(get_har_json())
        resources = har_entries_to_resources(har)
        env_config = EnvironmentConfig(
            replay_dir="",
            request_url="https://www.reddit.com/",
            push_groups=resource_list_to_push_groups(resources),
            har_resources=resources,
        )

        client_env = get_fast_mobile_client_environment()

        load_time_ms = Simulator(env_config).simulate_load_time(client_env)
        assert load_time_ms > 0
Exemple #15
0
    def test_view_manifest(self):
        """ The printed manifest must mention the URL, replay dir, every group and every resource """
        har_json = get_har_json()
        resources = har_entries_to_resources(har_from_json(har_json))
        config = EnvironmentConfig(
            replay_dir="",
            request_url="https://www.reddit.com/",
            push_groups=resource_list_to_push_groups(resources),
            har_resources=resources,
        )

        # capture everything printed while the temporary manifest file exists
        with mock.patch("builtins.print") as mock_print, \
                tempfile.NamedTemporaryFile() as config_file:
            config.save_file(config_file.name)
            view_manifest([config_file.name])
        assert mock_print.call_count > 5

        # collect the first positional argument of every non-empty print call
        printed_lines = []
        for call in mock_print.call_args_list:
            positional = call[0]
            if positional:
                printed_lines.append(positional[0])
        printed_text = "\n".join(printed_lines)

        assert config.replay_dir in printed_text
        assert config.request_url in printed_text
        for group in config.push_groups:
            assert group.name in printed_text
        # long resource paths are elided when printed, so only their first 61 characters are compared
        for group in config.push_groups:
            for res in group.resources:
                assert Url.parse(res.url).resource[:61] in printed_text
Exemple #16
0
    def test_calls_capture_har_with_correct_arguments(self, mock_run,
                                                      mock_open):
        """ The first `mock_run` call must start with `docker` and end with the URL; the HAR is returned """
        url = "https://www.cs.ucla.edu"
        mock_run.return_value = subprocess.CompletedProcess(args=[],
                                                            returncode=0)

        env_config = EnvironmentConfig(request_url=url, replay_dir="/tmp/dir")
        config = _get_config(env_config)
        har = capture_har_in_replay_server(url, config, self.client_env)

        # first positional argument of the first recorded call
        first_call_args = mock_run.call_args_list[0][0]
        command = first_call_args[0]
        assert command[0] == "docker"
        assert command[-1] == "https://www.cs.ucla.edu"
        assert har == self.har
Exemple #17
0
    def test_writes_mahimahi_files_correctly(self, mock_run, mock_open,
                                             mock_tmpdir):
        """ The first two files opened during capture must be under the temp dir, in write mode """
        tmp_dir = "/tmp/blaze_test_123"
        url = "https://www.cs.ucla.edu"
        mock_run.return_value = subprocess.CompletedProcess(args=[],
                                                            returncode=0)
        mock_tmpdir.return_value.__enter__.return_value = tmp_dir
        env_config = EnvironmentConfig(request_url=url, replay_dir=tmp_dir)
        config = _get_config(env_config)

        capture_har_in_replay_server(url, config, self.client_env)

        # positional arguments (path, mode) of the first two open() calls
        first_open_args = mock_open.call_args_list[0][0]
        second_open_args = mock_open.call_args_list[1][0]
        assert first_open_args[0].startswith(tmp_dir)
        assert second_open_args[0].startswith(tmp_dir)
        assert first_open_args[1] == "w"
        assert second_open_args[1] == "w"
Exemple #18
0
def query(args):
    """
    Asks a model served over gRPC for the push policy of the given manifest and prints it as JSON.
    """
    log.info("querying server...", host=args.host, port=args.port)

    target = f"{args.host}:{args.port}"
    client = Client(grpc.insecure_channel(target))

    manifest = EnvironmentConfig.load_file(args.manifest)
    client_env = get_client_environment_from_parameters(
        args.bandwidth,
        args.latency,
        args.cpu_slowdown,
    )
    policy = client.get_policy(
        url=manifest.request_url,
        client_env=client_env,
        manifest=manifest,
    )

    policy_json = json.dumps(policy.as_dict, indent=4)
    print(policy_json)
Exemple #19
0
 def test_compiles(self):
     """ An EnvironmentConfig can be constructed with an empty list of push groups """
     config = EnvironmentConfig(
         replay_dir="/replay/dir",
         request_url="http://example.com",
         push_groups=[],
     )
     assert isinstance(config, EnvironmentConfig)
Exemple #20
0
 def read_file(fpath):
     """ Logs the path being read and returns the EnvironmentConfig loaded from it """
     log.debug("reading file...", file=fpath)
     loaded_config = EnvironmentConfig.load_file(fpath)
     return loaded_config
Exemple #21
0
def page_load_time(args):
    """
    Captures a webpage and calculates the median page load time for a given website
    in a fast, no-latency Mahimahi shell. Then simulates the load based on profiling
    the page in the same Mahimahi shell.

    Exits with status 1 when the latency, bandwidth or CPU slowdown options are out of range.
    Unless `args.only_simulator` is set, the page is measured in the replay server (page load
    time, or speed index when `args.speed_index` is set), once without a policy and, if
    `args.policy` names a policy file, once with it. The simulator is always run. All results
    are printed as a JSON document.
    """
    # Validate the arguments
    if args.latency is not None and args.latency < 0:
        log.critical("provided latency must be greater or equal to 0")
        sys.exit(1)
    if args.bandwidth is not None and args.bandwidth <= 0:
        log.critical("provided bandwidth must be greater than 0")
        sys.exit(1)
    if args.cpu_slowdown is not None and args.cpu_slowdown not in {1, 2, 4}:
        log.critical("provided cpu slodown must be 1, 2, or 4")
        sys.exit(1)

    # Setup the client environment. Defaults are used only for options that were not given:
    # falling back with `or` would also discard an explicit latency of 0, which the
    # validation above accepts.
    default_client_env = get_default_client_environment()
    client_env = get_client_environment_from_parameters(
        default_client_env.bandwidth
        if args.bandwidth is None else args.bandwidth,
        default_client_env.latency if args.latency is None else args.latency,
        default_client_env.cpu_slowdown
        if args.cpu_slowdown is None else args.cpu_slowdown,
    )

    # If a push/preload policy was specified, read it
    policy = None
    if args.policy:
        log.debug("reading policy", push_policy=args.policy)
        with open(args.policy, "r") as policy_file:
            policy_dict = json.load(policy_file)
        policy = Policy.from_dict(policy_dict)

    env_config = EnvironmentConfig.load_file(args.from_manifest)
    config = get_config(env_config)

    log.info("calculating page load time",
             manifest=args.from_manifest,
             url=env_config.request_url)
    plt, orig_plt = 0, 0
    if not args.only_simulator:
        # keyword arguments shared by every replay server measurement below
        replay_kwargs = {
            "request_url": config.env_config.request_url,
            "client_env": client_env,
            "config": config,
            "cache_time": args.cache_time,
            "user_data_dir": args.user_data_dir,
        }
        if not args.speed_index:
            # only the first element of the returned sequence is reported
            orig_plt, *_ = get_page_load_time_in_replay_server(
                **replay_kwargs)
            if policy:
                plt, *_ = get_page_load_time_in_replay_server(
                    policy=policy, **replay_kwargs)
        else:
            orig_plt = get_speed_index_in_replay_server(**replay_kwargs)
            if policy:
                plt = get_speed_index_in_replay_server(policy=policy,
                                                       **replay_kwargs)

    log.debug("running simulator...")
    sim = Simulator(env_config)
    orig_sim_plt = sim.simulate_load_time(client_env)
    sim_plt = sim.simulate_load_time(client_env, policy)

    print(
        json.dumps(
            {
                "client_env": client_env._asdict(),
                "metric": "speed_index" if args.speed_index else "plt",
                "cache": "warm" if args.user_data_dir else "cold",
                "cache_time": args.cache_time,
                "replay_server": {
                    "with_policy": plt,
                    "without_policy": orig_plt
                },
                "simulator": {
                    "with_policy": sim_plt,
                    "without_policy": orig_sim_plt
                },
            },
            indent=4,
        ))
Exemple #22
0
def evaluate(args):
    """
    Instantiate the given model and checkpoint and query it for the policy corresponding to the given
    client and network conditions. Also allows running the generated policy through the simulator and
    replay server to get the PLTs and compare them under different conditions.

    Prints the policy as JSON; with `args.verbose`, `args.run_simulator` or
    `args.run_replay_server` the output also carries the manifest, model location and client
    environment, plus the measurements that were requested.

    Raises:
        ValueError: if `args.model` is not one of "A3C", "APEX" or "PPO".
    """
    log.info("evaluating model...",
             model=args.model,
             location=args.location,
             manifest=args.manifest)
    client_env = get_client_environment_from_parameters(
        args.bandwidth, args.latency, args.cpu_slowdown)
    manifest = EnvironmentConfig.load_file(args.manifest)

    # a resource counts as cached when its cache time exceeds the requested one; without a
    # requested cache time nothing is cached, so the scan is skipped entirely
    cached_urls = set()
    if args.cache_time is not None:
        cached_urls = {
            res.url
            for group in manifest.push_groups for res in group.resources
            if res.cache_time > args.cache_time
        }

    log.debug("using cached resources", cached_urls=cached_urls)
    config = get_config(manifest, client_env, args.reward_func).with_mutations(
        cached_urls=cached_urls, use_aft=args.use_aft)

    # the imports stay local so only the selected backend is loaded
    if args.model == "A3C":
        from blaze.model import a3c as model
    elif args.model == "APEX":
        from blaze.model import apex as model
    elif args.model == "PPO":
        from blaze.model import ppo as model
    else:
        # fail before Ray is started; otherwise `model` would be unbound and the lookup
        # below would fail with an unrelated-looking UnboundLocalError
        raise ValueError("unsupported model: {!r}".format(args.model))

    import ray

    ray.init(num_cpus=2, log_to_driver=False)

    saved_model = model.get_model(args.location)
    instance = saved_model.instantiate(config)
    policy = instance.policy
    data = policy.as_dict

    # any extra output wraps the bare policy in a richer report
    if args.verbose or args.run_simulator or args.run_replay_server:
        data = {
            "manifest": args.manifest,
            "location": args.location,
            "client_env": client_env._asdict(),
            "policy": policy.as_dict,
        }

    if args.run_simulator:
        sim = Simulator(manifest)
        sim_plt = sim.simulate_load_time(client_env)
        push_plt = sim.simulate_load_time(client_env, policy)
        data["simulator"] = {
            "without_policy": sim_plt,
            "with_policy": push_plt
        }

    if args.run_replay_server:
        # only the last element of each returned sequence is reported
        *_, plts = get_page_load_time_in_replay_server(
            config.env_config.request_url, client_env, config)
        *_, push_plts = get_page_load_time_in_replay_server(
            config.env_config.request_url, client_env, config, policy=policy)
        data["replay_server"] = {
            "without_policy": plts,
            "with_policy": push_plts
        }

    print(json.dumps(data, indent=4))