예제 #1
0
 def setup(self):
     """Create a page, the push groups, an action space and a saved model."""
     # The page is requested under a freshly drawn client environment.
     self.client_environment = get_random_client_environment()
     self.page = get_page("http://example.com", self.client_environment)
     all_groups = get_push_groups()
     self.push_groups = all_groups
     # Only groups flagged as trainable feed the action space.
     self.trainable_push_groups = list(
         filter(lambda grp: grp.trainable, all_groups)
     )
     self.action_space = ActionSpace(self.trainable_push_groups)
     mock_agent = mock_agent_with_action_space(self.action_space)
     self.saved_model = SavedModel(mock_agent, Environment, "", {})
예제 #2
0
 def setup(self):
     """Build a config, an action space and a policy fed random actions."""
     self.config = get_config()
     self.action_space = ActionSpace(get_push_groups())
     self.client_environment = get_random_client_environment()
     self.policy = Policy(self.action_space)
     # Keep applying randomly sampled actions for as long as apply_action
     # returns a truthy value; stop on the first falsy result.
     while self.policy.apply_action(self.action_space.sample()):
         pass
예제 #3
0
 def test_get_random_env(self):
     """A random client environment has correctly typed fields."""
     env = client.get_random_client_environment()
     assert isinstance(env, client.ClientEnvironment)
     # Categorical fields must use the client module's own types.
     expected_types = (
         ("network_type", client.NetworkType),
         ("network_speed", client.NetworkSpeed),
         ("device_speed", client.DeviceSpeed),
     )
     for attr_name, attr_type in expected_types:
         assert isinstance(getattr(env, attr_name), attr_type)
     # Numeric fields must be positive ints.
     assert isinstance(env.bandwidth, int)
     assert env.bandwidth > 0
     assert isinstance(env.latency, int)
     assert env.latency > 0
예제 #4
0
 def test_generates_correct_range_for_network(self):
     """Sampled bandwidth/latency lie within the ranges for their network."""
     for _ in range(100):
         sample = client.get_random_client_environment()
         min_bw, max_bw = client.network_to_bandwidth_range(
             sample.network_type, sample.network_speed)
         min_lat, max_lat = client.network_to_latency_range(
             sample.network_type)
         # Both bounds are inclusive.
         assert min_bw <= sample.bandwidth <= max_bw
         assert min_lat <= sample.latency <= max_lat
예제 #5
0
def get_page(
    url: str, client_environment=None
) -> policy_service_pb2.Page:
    """Build a ``policy_service_pb2.Page`` request for *url*.

    Args:
        url: URL placed in the request.
        client_environment: Object supplying ``bandwidth``, ``latency`` and
            ``cpu_slowdown`` for the request. When omitted (``None``), a new
            environment is drawn from ``get_random_client_environment()`` on
            every call.

    Returns:
        A ``Page`` message whose ``manifest`` is the serialized result of
        ``get_env_config()``.
    """
    # Resolved here rather than in the signature: a default expression is
    # evaluated once at definition (import) time, so every call omitting the
    # argument would share the same "random" environment.
    if client_environment is None:
        client_environment = get_random_client_environment()
    return policy_service_pb2.Page(
        url=url,
        bandwidth_kbps=client_environment.bandwidth,
        latency_ms=client_environment.latency,
        cpu_slowdown=client_environment.cpu_slowdown,
        manifest=get_env_config().serialize(),
    )
예제 #6
0
    def test_instantiate_creates_model_with_given_environment(self):
        """instantiate() wires agent, environment, config and location."""
        env_config = get_env_config()
        config = get_config(env_config, get_random_client_environment())

        location = "/tmp/model_location"
        saved_model = SavedModel(MockAgent, Environment, location, {})
        instance = saved_model.instantiate(config)

        assert isinstance(instance, ModelInstance)
        assert isinstance(instance.agent, MockAgent)
        # The agent is constructed with the environment class and a config
        # dict wrapping the given config.
        assert instance.agent.kwargs["env"] == Environment
        assert instance.agent.kwargs["config"] == {"env_config": config}
        assert instance.agent.file_path == saved_model.location
        assert instance.config == config
예제 #7
0
    def test_get_policy(self):
        """A running server returns a non-empty policy to a gRPC client."""
        server = Server(self.serve_config)
        policy_service = PolicyService(self.saved_model)
        server.set_policy_service(policy_service)
        try:
            server.start()
            # create the client
            address = "{}:{}".format(self.serve_config.host,
                                     self.serve_config.port)
            # Use the channel as a context manager so it is always closed
            # instead of leaking for the rest of the test session.
            with grpc.insecure_channel(address) as channel:
                # Wait until the server actually accepts connections rather
                # than sleeping a fixed interval, which is slow and can
                # still race with server startup on a loaded machine.
                grpc.channel_ready_future(channel).result(timeout=10)
                client_stub = Client(channel)
                policy = client_stub.get_policy(
                    url="https://www.example.com",
                    client_env=client.get_random_client_environment(),
                    manifest=self.env_config,
                )

                assert policy is not None
                assert len(list(policy.push)) + len(list(policy.preload)) > 0
        finally:
            server.stop()
예제 #8
0
 def setup(self):
     """Create a config, a policy over its push groups, and a client env."""
     self.config = get_config()
     groups = self.config.env_config.push_groups
     self.policy = Policy(ActionSpace(groups))
     self.client_environment = get_random_client_environment()
예제 #9
0
 def setup(self):
     """Build a client env, an env config, and the config combining them."""
     client_env = get_random_client_environment()
     env_config = get_env_config()
     self.client_environment = client_env
     self.env_config = env_config
     self.config = get_config(env_config, client_env)
     self.trainable_push_groups = env_config.trainable_push_groups
예제 #10
0
 def setup(self):
     """Prepare push groups, an observation space and a random client env."""
     (self.push_groups,
      self.observation_space,
      self.client_environment) = (get_push_groups(),
                                  get_observation_space(),
                                  get_random_client_environment())
예제 #11
0
 def test_latency_rand_is_even(self):
     """Each of ten sampled client environments has an even latency."""
     samples = (client.get_random_client_environment() for _ in range(10))
     assert all(sample.latency % 2 == 0 for sample in samples)
예제 #12
0
 def test_bandwidth_rand_is_multiple_of_100(self):
     """Each of ten sampled client environments has bandwidth % 100 == 0."""
     samples = (client.get_random_client_environment() for _ in range(10))
     assert all(sample.bandwidth % 100 == 0 for sample in samples)
예제 #13
0
def get_mahimahi_config() -> MahiMahiConfig:
    """Assemble a ``MahiMahiConfig`` for tests.

    Combines the result of ``get_config()``, a ``Policy`` over an
    ``ActionSpace`` of all ``get_push_groups()``, and a random client
    environment.
    """
    config = get_config()
    policy = Policy(ActionSpace(get_push_groups()))
    environment = get_random_client_environment()
    return MahiMahiConfig(
        config=config,
        policy=policy,
        client_environment=environment,
    )