Exemple #1
0
    def test_client(self, capsys):
        """Start a policy server and check that ``query`` prints a non-empty policy.

        The manifest from ``get_env_config()`` is written to a temporary file.
        Its path is passed to ``query`` together with the server's host and
        port. Whatever ``query`` writes to stdout must decode as JSON to a
        truthy, non-empty value. The server is always stopped, even on failure.
        """
        server = Server(self.serve_config)
        service = PolicyService(self.saved_model)
        server.set_policy_service(service)

        try:
            server.start()

            with tempfile.NamedTemporaryFile() as manifest:
                env_config = get_env_config()
                env_config.save_file(manifest.name)
                cli_args = [
                    "--manifest", manifest.name,
                    "-l", "10",
                    "-b", "24000",
                    "--host", str(self.serve_config.host),
                    "--port", str(self.serve_config.port),
                ]
                query(cli_args)

            captured = capsys.readouterr()
            policy = json.loads(captured.out)
            assert policy
            assert len(policy) > 0
        finally:
            server.stop()
Exemple #2
0
 def setup(self):
     """Build the fixtures shared by the tests in this class.

     Stores the environment and serve configs and the trainable push groups.
     Also stores an action space over those groups (seeded with 2048), a mock
     agent bound to that action space, and a ``SavedModel`` wrapping the mock
     agent.
     """
     env_config = get_env_config()
     self.env_config = env_config
     self.trainable_push_groups = env_config.trainable_push_groups
     self.serve_config = get_serve_config()

     space = ActionSpace(self.trainable_push_groups)
     self.action_space = space
     # Fixed seed keeps any sampling from the action space reproducible.
     space.seed(2048)

     self.mock_agent = mock_agent_with_action_space(space)
     self.saved_model = SavedModel(
         self.mock_agent, Environment, "/tmp/model_location", {}
     )
Exemple #3
0
def get_page(
    url: str, client_environment=None
) -> policy_service_pb2.Page:
    """Build a ``policy_service_pb2.Page`` request for ``url``.

    Args:
        url: The page URL to place in the request.
        client_environment: Object providing ``bandwidth``, ``latency`` and
            ``cpu_slowdown``. When omitted (or ``None``), a fresh environment
            is obtained from ``get_random_client_environment()`` on every call.

    Returns:
        A ``Page`` message carrying the client environment's network/CPU
        settings and the serialized ``get_env_config()`` as its manifest.
    """
    # Previously the default was ``get_random_client_environment()`` in the
    # signature. That is evaluated only once, when the function is defined, so
    # every call omitting the argument silently reused the same "random"
    # environment, and the helper ran at import time. Draw it per call instead.
    if client_environment is None:
        client_environment = get_random_client_environment()
    return policy_service_pb2.Page(
        url=url,
        bandwidth_kbps=client_environment.bandwidth,
        latency_ms=client_environment.latency,
        cpu_slowdown=client_environment.cpu_slowdown,
        manifest=get_env_config().serialize(),
    )
Exemple #4
0
 def test_pickle(self):
     """Round-trip an ``EnvironmentConfig`` through ``save_file``/``load_file``.

     Checks that the request URL, the replay directory, and every push group's
     name and resources (in order) survive the round trip.
     """
     original = get_env_config()
     with tempfile.NamedTemporaryFile() as tmp_file:
         original.save_file(tmp_file.name)
         restored = EnvironmentConfig.load_file(tmp_file.name)

         assert original.request_url == restored.request_url
         assert original.replay_dir == restored.replay_dir
         # The lengths are pinned first, so the zips below cannot silently
         # truncate.
         assert len(original.push_groups) == len(restored.push_groups)
         for group, restored_group in zip(original.push_groups, restored.push_groups):
             assert restored_group.name == group.name
             assert len(restored_group.resources) == len(group.resources)
             for res, restored_res in zip(group.resources, restored_group.resources):
                 assert restored_res == res
Exemple #5
0
    def test_instantiate_creates_model_with_given_environment(self):
        """``SavedModel.instantiate`` builds a ``ModelInstance`` wired to its inputs.

        The resulting agent must be a ``MockAgent``, constructed with the saved
        environment class and ``{"env_config": config}`` and pointed at the
        saved model's location. The instance must also keep the given config.
        """
        env_config = get_env_config()
        client_env = get_random_client_environment()
        config = get_config(env_config, client_env)
        saved_model = SavedModel(MockAgent, Environment, "/tmp/model_location", {})

        instance = saved_model.instantiate(config)
        assert isinstance(instance, ModelInstance)

        agent = instance.agent
        assert isinstance(agent, MockAgent)
        assert agent.kwargs["env"] == Environment
        assert agent.kwargs["config"] == {"env_config": config}
        assert agent.file_path == saved_model.location
        assert instance.config == config
Exemple #6
0
    def test_cluster(self, capsys):
        """Run the ``cluster`` CLI over three manifests against ``apted_server``.

        Three manifests are written to a temporary directory. They differ only
        in ``request_url``, which gets a suffix of 0, 1 or 2. The server is
        started on ``port`` with a list of 100 zero distances. The JSON
        ``cluster`` prints to stdout must be non-empty, and every value in it
        must equal 0.
        """
        port = 24451
        distances = [0] * 100
        env_config = get_env_config()
        with tempfile.TemporaryDirectory() as tmp_dir:
            for i in range(3):
                # Suffix request_url so each saved manifest is distinct.
                variant = env_config._replace(request_url=env_config.request_url + str(i))
                variant.save_file(f"{tmp_dir}/{i}.manifest")

            with apted_server(port, distances):
                cluster(["--apted_port", str(port), tmp_dir])

            resp = json.loads(capsys.readouterr().out)
            assert resp
            assert len(resp) > 0
            # Only the values are checked; the URL keys were previously bound
            # to an unused loop variable.
            for mapping in resp.values():
                assert mapping == 0
Exemple #7
0
    def test_train_ppo(self, mock_train):
        """``train`` with ``--model PPO`` calls ``mock_train`` once with the expected configs.

        The env config is saved to a temporary manifest file and passed via
        ``--manifest_file`` along with the experiment name and worker count.
        ``mock_train`` must then be called exactly once with
        ``(train_config, config)``.
        """
        env_config = get_env_config()
        train_config = TrainConfig(experiment_name="experiment_name", num_workers=4)
        config = get_config(env_config, reward_func=1, use_aft=False)

        with tempfile.NamedTemporaryFile() as env_file:
            env_config.save_file(env_file.name)
            argv = [
                train_config.experiment_name,
                "--workers", str(train_config.num_workers),
                "--model", "PPO",
                "--manifest_file", env_file.name,
            ]
            train(argv)

        mock_train.assert_called_once_with(train_config, config)
Exemple #8
0
 def test_train_compiles(self, mock_run_experiments, _):
     """``ppo.train`` completes and calls ``mock_run_experiments`` exactly once."""
     train_config = get_train_config()
     config = get_config(get_env_config())
     ppo.train(train_config, config)
     mock_run_experiments.assert_called_once()
Exemple #9
0
 def test(self):
     """The APTED distance function returns the distance ``apted_server`` was given."""
     distances = [0]
     with apted_server(PORT, distances):
         distance_func = create_apted_distance_function(PORT)
         lhs = get_env_config()
         rhs = get_env_config()
         assert distance_func(lhs, rhs) == distances[0]
Exemple #10
0
 def setup(self):
     """Store a random client environment, the env config, and their combined config."""
     client_environment = get_random_client_environment()
     env_config = get_env_config()
     self.client_environment = client_environment
     self.env_config = env_config
     self.config = get_config(env_config, client_environment)
     self.trainable_push_groups = env_config.trainable_push_groups
Exemple #11
0
 def test_trainable_push_groups(self):
     """Check ``trainable_push_groups`` against ``push_groups``.

     Every entry must be trainable and must belong to ``push_groups``, and no
     non-trainable push group may appear in ``trainable_push_groups``.
     """
     env_config = get_env_config()
     assert all(group.trainable for group in env_config.trainable_push_groups)
     assert all(group in env_config.push_groups for group in env_config.trainable_push_groups)
     non_trainable = (g for g in env_config.push_groups if not g.trainable)
     assert all(g not in env_config.trainable_push_groups for g in non_trainable)
Exemple #12
0
 def test_get_config_with_other_properties(self):
     """``config.get_config`` keeps the env config, client env and reward func it is given."""
     client_env = get_random_client_environment()
     env_config = get_env_config()
     conf = config.get_config(env_config, client_env, 0)

     expected_env_config = get_env_config()
     assert conf.env_config == expected_env_config
     assert conf.client_env == client_env
     assert conf.reward_func == 0
Exemple #13
0
 def test_get_config_with_env_config(self):
     """``config.get_config`` called with only an env config keeps an equal env config."""
     env_config = get_env_config()
     conf = config.get_config(env_config)
     assert conf.env_config == get_env_config()