Example #1
    def test_runs_successfully(self, mock_capture_har_in_mahimahi):
        hars = [generate_har() for _ in range(STABLE_SET_NUM_RUNS + 1)]
        har_resources = har_entries_to_resources(hars[0])
        mock_capture_har_in_mahimahi.return_value = hars[0]
        with tempfile.NamedTemporaryFile() as output_file:
            with tempfile.TemporaryDirectory() as output_dir:
                with mock.patch(
                        "blaze.preprocess.record.capture_har_in_replay_server",
                        new=HarReturner(hars)):
                    preprocess([
                        "https://cs.ucla.edu", "--output", output_file.name,
                        "--record_dir", output_dir
                    ])

                config = EnvironmentConfig.load_file(output_file.name)
                assert config.replay_dir == output_dir
                assert config.request_url == "https://cs.ucla.edu"
                assert config.push_groups
                # since we passed cs.ucla.edu as the URL, nothing should be trainable
                assert all(not group.trainable for group in config.push_groups)
                assert config.har_resources == har_resources

        client_env = get_default_client_environment()
        config = get_config(
            EnvironmentConfig(replay_dir=output_dir,
                              request_url="https://cs.ucla.edu"))

        assert mock_capture_har_in_mahimahi.call_count == 1
        mock_capture_har_in_mahimahi.assert_called_with(
            "https://cs.ucla.edu", config, client_env)
Example #2
def train(args):
    """
    Trains a model to generate push policies for the given website. This command takes as input the
    manifest file generated by `blaze preprocess` and outputs a model that can be served.
    """
    # check for ambiguous options
    if args.resume and args.no_resume:
        log.error(
            "invalid options: cannot specify both --resume and --no-resume")
        sys.exit(1)

    log.info("starting train", name=args.name, model=args.model)

    # import specified model
    if args.model == "A3C":
        from blaze.model import a3c as model
    elif args.model == "APEX":
        from blaze.model import apex as model
    elif args.model == "PPO":
        from blaze.model import ppo as model

    # compute resume flag and initialize training
    resume = False if args.no_resume else True if args.resume else "prompt"
    train_config = TrainConfig(experiment_name=args.name,
                               num_workers=args.workers,
                               resume=resume)
    env_config = EnvironmentConfig.load_file(args.manifest_file)
    config = get_config(env_config,
                        reward_func=args.reward_func,
                        use_aft=args.use_aft)
    model.train(train_config, config)
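
The resume expression above encodes a tri-state flag: --no-resume forces False, --resume forces True, and passing neither yields the sentinel string "prompt", deferring the decision to the trainer. A standalone sketch of the same logic (resolve_resume is a hypothetical helper name):

def resolve_resume(resume: bool, no_resume: bool):
    # Hypothetical helper mirroring the expression in train(); the
    # conditional chain is right-associative, so --no-resume wins first.
    return False if no_resume else True if resume else "prompt"

assert resolve_resume(resume=False, no_resume=True) is False
assert resolve_resume(resume=True, no_resume=False) is True
assert resolve_resume(resume=False, no_resume=False) == "prompt"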
Example #3
    def test_get_default_config(self):
        conf = config.get_config()
        assert isinstance(conf, config.Config)
        assert conf.http2push_image == config.DEFAULT_HTTP2PUSH_IMAGE
        assert conf.chrome_bin == config.DEFAULT_CHROME_BIN
        assert conf.env_config is None
        assert conf.client_env is None
        assert conf.reward_func is None
Example #4
def _test_push(
    *,
    manifest: str,
    iterations: Optional[int] = 1,
    max_retries: Optional[int] = 0,
    policy_generator: Callable[[EnvironmentConfig], Policy],
    bandwidth: Optional[int],
    latency: Optional[int],
    cpu_slowdown: Optional[int],
    only_simulator: Optional[bool],
    speed_index: Optional[bool],
    cache_time: Optional[int],
    user_data_dir: Optional[str],
):
    env_config = EnvironmentConfig.load_file(manifest)
    default_client_env = get_default_client_environment()
    client_env = get_client_environment_from_parameters(
        bandwidth or default_client_env.bandwidth,
        latency or default_client_env.latency,
        cpu_slowdown or default_client_env.cpu_slowdown,
    )

    data = {
        "client_env": client_env._asdict(),
        "url": env_config.request_url,
        "cache": "warm" if user_data_dir else "cold",
        "metric": "speed_index" if speed_index else "plt",
        "cache_time": cache_time,
    }

    if not only_simulator:
        config = get_config(env_config)
        plt, push_plts, policies = _get_results_in_replay_server(
            config, client_env, iterations, max_retries, policy_generator,
            cache_time, user_data_dir, speed_index)
        data["replay_server"] = {
            "without_policy": plt,
            "with_policy": [{
                "plt": push_plt,
                "policy": policy.as_dict
            } for (push_plt, policy) in zip(push_plts, policies)],
        }

    else:
        policies = [policy_generator(env_config) for _ in range(iterations)]

    sim = Simulator(env_config)
    data["simulator"] = {
        "without_policy": sim.simulate_load_time(client_env),
        "with_policy": [{
            "plt": sim.simulate_load_time(client_env, policy),
            "policy": policy.as_dict
        } for policy in policies],
    }

    print(json.dumps(data, indent=4))
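
Since every parameter of _test_push is keyword-only (note the bare *), callers must name each argument. A hedged usage sketch, assuming a manifest produced by blaze preprocess and a policy file on disk; the paths and fixed_policy_generator are illustrative, not part of the original code:

import json

def fixed_policy_generator(env_config):
    # Hypothetical: return the same pre-computed policy for every iteration.
    with open("/tmp/policy.json") as policy_file:
        return Policy.from_dict(json.load(policy_file))

_test_push(
    manifest="/tmp/manifest.json",      # hypothetical manifest path
    iterations=3,
    max_retries=1,
    policy_generator=fixed_policy_generator,
    bandwidth=None,                     # None falls back to the default client env
    latency=None,
    cpu_slowdown=None,
    only_simulator=True,                # skip the replay server entirely
    speed_index=False,
    cache_time=None,
    user_data_dir=None,
)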
Example #5
    def test_init(self):
        env = self.environment
        assert isinstance(env, Environment)
        assert isinstance(env.client_environment, ClientEnvironment)
        assert isinstance(env.action_space, ActionSpace)
        assert isinstance(env.analyzer, Analyzer)
        assert isinstance(env.policy, Policy)
        assert env.config.env_config.push_groups == get_config().env_config.push_groups
        assert env.action_space.push_groups == self.push_groups
        assert env.policy.action_space == env.action_space
Example #6
def record(args):
    """
    Records a website using Mahimahi and stores the recorded files in the specified directory. In order
    to use it with blaze, you must preprocess it using `blaze preprocess` to generate a training
    manifest.
    """
    log.info("recording website",
             website=args.website,
             record_dir=args.record_dir)

    config = get_config()
    log.debug("using configuration", **config._asdict())
    record_webpage(args.website, args.record_dir, config)
Example #7
    def test_instantiate_creates_model_with_given_environment(self):
        env_config = get_env_config()
        client_env = get_random_client_environment()
        config = get_config(env_config, client_env)

        saved_model = SavedModel(MockAgent, Environment, "/tmp/model_location",
                                 {})
        model_instance = saved_model.instantiate(config)
        assert isinstance(model_instance, ModelInstance)
        assert isinstance(model_instance.agent, MockAgent)
        assert model_instance.agent.kwargs["env"] == Environment
        assert model_instance.agent.kwargs["config"] == {"env_config": config}
        assert model_instance.agent.file_path == saved_model.location
        assert model_instance.config == config
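
The assertions above imply that MockAgent is a test double which records its constructor arguments and the checkpoint it was pointed at. A sketch consistent with those assertions (hypothetical; the real test fixture may differ):

class MockAgent:
    # Hypothetical test double: capture the kwargs passed by
    # SavedModel.instantiate and the restored checkpoint path.
    def __init__(self, config=None, env=None):
        self.kwargs = {"config": config, "env": env}
        self.file_path = None

    def restore(self, file_path):
        self.file_path = file_path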
Example #8
    def test_environment_with_cached_urls(self):
        config = get_config()
        resources = [
            res for group in config.env_config.push_groups
            for res in group.resources
        ]
        mask = [random.randint(0, 2) for _ in range(len(resources))]
        cached = [res for (res, include) in zip(resources, mask) if include]
        cached_urls = set(res.url for res in cached)

        config = config.with_mutations(cached_urls=cached_urls)
        env = Environment(config)

        obs = env.observation
        for res in cached:
            assert obs["resources"][str(res.order)][1] == 1
Example #9
def preprocess(args):
    """
    Preprocesses a website for training. Automatically discovers linked pages up to a certain depth
    and finds the stable set of page dependencies. The page load is recorded and stored, and a
    training manifest is written.
    """
    domain = Url.parse(args.website).domain
    train_domain_globs = args.train_domain_globs or ["*{}*".format(domain)]
    log.info("preprocessing website",
             website=args.website,
             record_dir=args.record_dir,
             train_domain_globs=train_domain_globs)

    config = get_config(env_config=EnvironmentConfig(
        replay_dir=args.record_dir, request_url=args.website))
    client_env = get_default_client_environment()
    log.debug("using configuration", **config._asdict())

    log.info("capturing execution")
    har_resources = har_entries_to_resources(
        capture_har_in_replay_server(args.website, config, client_env))

    log.info("finding dependency stable set...")
    res_list = find_url_stable_set(args.website, config)

    log.info("found total dependencies", total=len(res_list))
    push_groups = resource_list_to_push_groups(
        res_list, train_domain_globs=train_domain_globs)

    if args.extract_critical_requests:
        log.info("extracting critical requests")
        push_groups = annotate_critical_requests(args.website, config,
                                                 client_env, push_groups)
        critical_resources = set(res.url for group in push_groups
                                 for res in group.resources if res.critical)
        log.debug("critical resources", resources=critical_resources)

    log.info("finding cacheable objects")
    push_groups = annotate_cacheable_objects(args.record_dir, push_groups)

    log.info("generating configuration...")
    env_config = EnvironmentConfig(replay_dir=args.record_dir,
                                   request_url=args.website,
                                   push_groups=push_groups,
                                   har_resources=har_resources)
    env_config.save_file(args.output)
    log.info("successfully prepared website for training", output=args.output)
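
The manifest that preprocess writes is the same EnvironmentConfig file the other commands consume. A minimal round-trip sketch, assuming preprocess was run against https://cs.ucla.edu with a hypothetical output path:

# Hypothetical path; reload the manifest written by preprocess above.
env_config = EnvironmentConfig.load_file("/tmp/manifest.json")
assert env_config.request_url == "https://cs.ucla.edu"
trainable = [group for group in env_config.push_groups if group.trainable]
print("trainable push groups:", len(trainable))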
Example #10
    def test_train_ppo(self, mock_train):
        env_config = get_env_config()
        train_config = TrainConfig(experiment_name="experiment_name",
                                   num_workers=4)
        config = get_config(env_config, reward_func=1, use_aft=False)
        with tempfile.NamedTemporaryFile() as env_file:
            env_config.save_file(env_file.name)
            train([
                train_config.experiment_name,
                "--workers",
                str(train_config.num_workers),
                "--model",
                "PPO",
                "--manifest_file",
                env_file.name,
            ])

        mock_train.assert_called_once()
        mock_train.assert_called_with(train_config, config)
Example #11
    def test_observation_when_environment_is_created_with_dict(self):
        env = Environment(get_config()._asdict())
        obs = env.observation
        assert obs and isinstance(obs, dict)
        assert self.environment.observation_space.contains(obs)
Example #12
    def test_get_config_with_env_config(self):
        conf = config.get_config(get_env_config())
        assert conf.env_config == get_env_config()
Example #13
    def test_get_config_with_other_properties(self):
        client_env = get_random_client_environment()
        conf = config.get_config(get_env_config(), client_env, 0)
        assert conf.env_config == get_env_config()
        assert conf.client_env == client_env
        assert conf.reward_func == 0
Example #14
    def test_get_config_with_override(self):
        conf = config.get_config()
        assert isinstance(conf, config.Config)
        assert conf.http2push_image == "test_image"
        assert conf.chrome_bin == "test_chrome"
        assert conf.env_config is None
Example #15
    def setup(self):
        self.client_environment = get_random_client_environment()
        self.env_config = get_env_config()
        self.config = get_config(self.env_config, self.client_environment)
        self.trainable_push_groups = self.env_config.trainable_push_groups
Example #16
    def test_init_with_dict_env(self):
        env = Environment(get_config()._asdict())
        assert isinstance(env, Environment)
        assert env.config == get_config()
Example #17
    def setup(self):
        self.environment = Environment(get_config())
        self.environment.action_space.seed(2048)
        self.push_groups = self.environment.env_config.push_groups
        self.trainable_push_groups = self.environment.env_config.trainable_push_groups
Example #18
    def test_items(self):
        conf = config.get_config()
        items = conf.items()
        assert all(len(v) == 2 for v in items)
        assert len(items) == 7
Example #19
    def test_train_compiles(self, mock_run_experiments, _):
        ppo.train(get_train_config(), get_config(get_env_config()))
        mock_run_experiments.assert_called_once()
Example #20
    def __init__(self, saved_model: SavedModel, config=get_config()):
        self.config = config
        self.saved_model = saved_model
        self.policies: Dict[str, policy_service_pb2.Policy] = {}
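
Note that the config=get_config() default is evaluated once, when the class body is executed, so every instance constructed without an explicit config shares that single default object. A standalone sketch of the same Python behavior (Holder and make_default are hypothetical names):

def make_default():
    print("default evaluated")   # printed once, at definition time
    return object()

class Holder:
    def __init__(self, config=make_default()):
        self.config = config

a = Holder()
b = Holder()
assert a.config is b.config      # both instances share the one default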
Example #21
def page_load_time(args):
    """
    Captures a webpage and calculates the median page load time for a given website
    in a fast, no-latency Mahimahi shell. Then simulates the load based on profiling
    the page in the same Mahimahi shell.
    """
    # Validate the arguments
    if args.latency is not None and args.latency < 0:
        log.critical("provided latency must be greater than or equal to 0")
        sys.exit(1)
    if args.bandwidth is not None and args.bandwidth <= 0:
        log.critical("provided bandwidth must be greater than 0")
        sys.exit(1)
    if args.cpu_slowdown is not None and args.cpu_slowdown not in {1, 2, 4}:
        log.critical("provided cpu slowdown must be 1, 2, or 4")
        sys.exit(1)

    # Setup the client environment
    default_client_env = get_default_client_environment()
    client_env = get_client_environment_from_parameters(
        args.bandwidth or default_client_env.bandwidth,
        args.latency or default_client_env.latency,
        args.cpu_slowdown or default_client_env.cpu_slowdown,
    )

    # If a push/preload policy was specified, read it
    policy = None
    if args.policy:
        log.debug("reading policy", push_policy=args.policy)
        with open(args.policy, "r") as policy_file:
            policy_dict = json.load(policy_file)
        policy = Policy.from_dict(policy_dict)

    env_config = EnvironmentConfig.load_file(args.from_manifest)
    config = get_config(env_config)

    log.info("calculating page load time",
             manifest=args.from_manifest,
             url=env_config.request_url)
    plt, orig_plt = 0, 0
    if not args.only_simulator:
        if not args.speed_index:
            orig_plt, *_ = get_page_load_time_in_replay_server(
                request_url=config.env_config.request_url,
                client_env=client_env,
                config=config,
                cache_time=args.cache_time,
                user_data_dir=args.user_data_dir,
            )
            if policy:
                plt, *_ = get_page_load_time_in_replay_server(
                    request_url=config.env_config.request_url,
                    client_env=client_env,
                    config=config,
                    policy=policy,
                    cache_time=args.cache_time,
                    user_data_dir=args.user_data_dir,
                )
        else:
            orig_plt = get_speed_index_in_replay_server(
                request_url=config.env_config.request_url,
                client_env=client_env,
                config=config,
                cache_time=args.cache_time,
                user_data_dir=args.user_data_dir,
            )
            if policy:
                plt = get_speed_index_in_replay_server(
                    request_url=config.env_config.request_url,
                    client_env=client_env,
                    config=config,
                    policy=policy,
                    cache_time=args.cache_time,
                    user_data_dir=args.user_data_dir,
                )

    log.debug("running simulator...")
    sim = Simulator(env_config)
    orig_sim_plt = sim.simulate_load_time(client_env)
    sim_plt = sim.simulate_load_time(client_env, policy)

    print(
        json.dumps(
            {
                "client_env": client_env._asdict(),
                "metric": "speed_index" if args.speed_index else "plt",
                "cache": "warm" if args.user_data_dir else "cold",
                "cache_time": args.cache_time,
                "replay_server": {
                    "with_policy": plt,
                    "without_policy": orig_plt
                },
                "simulator": {
                    "with_policy": sim_plt,
                    "without_policy": orig_sim_plt
                },
            },
            indent=4,
        ))
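
For a quick simulator-only comparison without the replay server (the --only_simulator path above), the same pieces compose directly. A hedged sketch using only names that appear in this listing, with hypothetical file paths:

import json

env_config = EnvironmentConfig.load_file("/tmp/manifest.json")  # hypothetical path
client_env = get_default_client_environment()

with open("/tmp/policy.json") as policy_file:                   # hypothetical policy file
    policy = Policy.from_dict(json.load(policy_file))

sim = Simulator(env_config)
print("without policy:", sim.simulate_load_time(client_env))
print("with policy:", sim.simulate_load_time(client_env, policy))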
Example #22
def evaluate(args):
    """
    Instantiates the given model and checkpoint and queries it for the policy corresponding to the given
    client and network conditions. Also allows running the generated policy through the simulator and
    replay server to get the PLTs and compare them under different conditions.
    """
    log.info("evaluating model...",
             model=args.model,
             location=args.location,
             manifest=args.manifest)
    client_env = get_client_environment_from_parameters(
        args.bandwidth, args.latency, args.cpu_slowdown)
    manifest = EnvironmentConfig.load_file(args.manifest)

    cached_urls = set(
        res.url for group in manifest.push_groups for res in group.resources
        if args.cache_time is not None and res.cache_time > args.cache_time)

    log.debug("using cached resources", cached_urls=cached_urls)
    config = get_config(manifest, client_env, args.reward_func).with_mutations(
        cached_urls=cached_urls, use_aft=args.use_aft)

    if args.model == "A3C":
        from blaze.model import a3c as model
    elif args.model == "APEX":
        from blaze.model import apex as model
    elif args.model == "PPO":
        from blaze.model import ppo as model

    import ray

    ray.init(num_cpus=2, log_to_driver=False)

    saved_model = model.get_model(args.location)
    instance = saved_model.instantiate(config)
    policy = instance.policy
    data = policy.as_dict

    if args.verbose or args.run_simulator or args.run_replay_server:
        data = {
            "manifest": args.manifest,
            "location": args.location,
            "client_env": client_env._asdict(),
            "policy": policy.as_dict,
        }

    if args.run_simulator:
        sim = Simulator(manifest)
        sim_plt = sim.simulate_load_time(client_env)
        push_plt = sim.simulate_load_time(client_env, policy)
        data["simulator"] = {
            "without_policy": sim_plt,
            "with_policy": push_plt
        }

    if args.run_replay_server:
        *_, plts = get_page_load_time_in_replay_server(
            config.env_config.request_url, client_env, config)
        *_, push_plts = get_page_load_time_in_replay_server(
            config.env_config.request_url, client_env, config, policy=policy)
        data["replay_server"] = {
            "without_policy": plts,
            "with_policy": push_plts
        }

    print(json.dumps(data, indent=4))
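
One readability note: the cached_urls comprehension above re-evaluates args.cache_time is not None for every resource. An equivalent sketch that hoists the check out of the loop:

cached_urls = set()
if args.cache_time is not None:
    cached_urls = {
        res.url
        for group in manifest.push_groups
        for res in group.resources
        if res.cache_time > args.cache_time
    }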