def test_client(self, capsys):
    """Start a policy server, run the ``query`` CLI against it, and check the printed policy.

    The CLI output captured from stdout must parse as JSON and be non-empty.
    The server is always stopped, even if an assertion fails.
    """
    policy_server = Server(self.serve_config)
    policy_server.set_policy_service(PolicyService(self.saved_model))
    try:
        policy_server.start()
        with tempfile.NamedTemporaryFile() as tmp_manifest:
            # The CLI reads the environment config from a manifest file on disk.
            get_env_config().save_file(tmp_manifest.name)
            cli_args = [
                "--manifest", tmp_manifest.name,
                "-l", "10",
                "-b", "24000",
                "--host", str(self.serve_config.host),
                "--port", str(self.serve_config.port),
            ]
            query(cli_args)
            captured = capsys.readouterr()
            policy = json.loads(captured.out)
            assert policy
            assert len(policy) > 0
    finally:
        policy_server.stop()
def setup(self):
    """Build the env/serve configs and a SavedModel wrapping a mock agent with a seeded action space."""
    env_config = get_env_config()
    self.env_config = env_config
    self.trainable_push_groups = env_config.trainable_push_groups
    self.serve_config = get_serve_config()

    space = ActionSpace(self.trainable_push_groups)
    # Fixed seed; presumably keeps the action space's sampling identical across runs.
    space.seed(2048)
    self.action_space = space

    self.mock_agent = mock_agent_with_action_space(space)
    self.saved_model = SavedModel(self.mock_agent, Environment, "/tmp/model_location", {})
def get_page(url: str, client_environment=None) -> policy_service_pb2.Page:
    """Build a ``policy_service_pb2.Page`` request message for ``url``.

    Args:
        url: URL placed in the page message.
        client_environment: Object supplying ``bandwidth``, ``latency`` and
            ``cpu_slowdown``. When omitted (or ``None``), a fresh environment
            is drawn from ``get_random_client_environment()`` on each call.

    Returns:
        A ``Page`` whose network/CPU fields come from ``client_environment``
        and whose ``manifest`` is the serialized ``get_env_config()``.
    """
    # Draw the default per call. A call in the signature would be evaluated
    # only once, at definition time, and the same "random" environment would
    # then be shared by every call that omits the argument.
    if client_environment is None:
        client_environment = get_random_client_environment()
    return policy_service_pb2.Page(
        url=url,
        bandwidth_kbps=client_environment.bandwidth,
        latency_ms=client_environment.latency,
        cpu_slowdown=client_environment.cpu_slowdown,
        manifest=get_env_config().serialize(),
    )
def test_pickle(self):
    """Round-trip an EnvironmentConfig through save_file/load_file and compare its contents."""
    original = get_env_config()
    with tempfile.NamedTemporaryFile() as tmp_file:
        original.save_file(tmp_file.name)
        loaded = EnvironmentConfig.load_file(tmp_file.name)

        assert original.request_url == loaded.request_url
        assert original.replay_dir == loaded.replay_dir
        # Lengths are checked first, so the zips below never silently truncate.
        assert len(original.push_groups) == len(loaded.push_groups)
        for group, loaded_group in zip(original.push_groups, loaded.push_groups):
            assert loaded_group.name == group.name
            assert len(loaded_group.resources) == len(group.resources)
            for res, loaded_res in zip(group.resources, loaded_group.resources):
                assert loaded_res == res
def test_instantiate_creates_model_with_given_environment(self):
    """SavedModel.instantiate should wrap a MockAgent built from the saved env class and config."""
    config = get_config(get_env_config(), get_random_client_environment())
    saved_model = SavedModel(MockAgent, Environment, "/tmp/model_location", {})

    instance = saved_model.instantiate(config)
    assert isinstance(instance, ModelInstance)

    agent = instance.agent
    assert isinstance(agent, MockAgent)
    agent_kwargs = agent.kwargs
    assert agent_kwargs["env"] == Environment
    assert agent_kwargs["config"] == {"env_config": config}
    assert agent.file_path == saved_model.location
    assert instance.config == config
def test_cluster(self, capsys):
    """Run the cluster CLI over three manifest variants with an APTED server fed all-zero distances.

    Every entry in the JSON printed to stdout is expected to equal 0.
    """
    apted_port = 24451
    zero_distances = [0] * 100
    base_config = get_env_config()
    with tempfile.TemporaryDirectory() as manifest_dir:
        # Write three manifests that differ only by a suffix on request_url.
        for idx in range(3):
            variant = base_config._replace(request_url=base_config.request_url + str(idx))
            variant.save_file(f"{manifest_dir}/{idx}.manifest")
        with apted_server(apted_port, zero_distances):
            cluster(["--apted_port", str(apted_port), manifest_dir])
            clusters = json.loads(capsys.readouterr().out)
            assert clusters
            assert len(clusters) > 0
            for mapping in clusters.values():
                assert mapping == 0
def test_train_ppo(self, mock_train):
    """The train CLI with ``--model PPO`` should invoke the mocked trainer once with the expected configs."""
    env_config = get_env_config()
    expected_train_config = TrainConfig(experiment_name="experiment_name", num_workers=4)
    expected_config = get_config(env_config, reward_func=1, use_aft=False)
    with tempfile.NamedTemporaryFile() as manifest:
        env_config.save_file(manifest.name)
        argv = [
            expected_train_config.experiment_name,
            "--workers", str(expected_train_config.num_workers),
            "--model", "PPO",
            "--manifest_file", manifest.name,
        ]
        train(argv)
        # Equivalent to assert_called_once() followed by assert_called_with(...).
        mock_train.assert_called_once_with(expected_train_config, expected_config)
def test_train_compiles(self, mock_run_experiments, _):
    """ppo.train should run end to end and call the mocked ``mock_run_experiments`` exactly once."""
    train_config = get_train_config()
    run_config = get_config(get_env_config())
    ppo.train(train_config, run_config)
    mock_run_experiments.assert_called_once()
def test(self):
    """The APTED-backed distance function should return the distance the test server was given."""
    expected_distances = [0]
    with apted_server(PORT, expected_distances):
        distance = create_apted_distance_function(PORT)
        lhs = get_env_config()
        rhs = get_env_config()
        assert distance(lhs, rhs) == expected_distances[0]
def setup(self):
    """Create a random client environment, the env config, and the config combining both."""
    client_environment = get_random_client_environment()
    env_config = get_env_config()
    self.client_environment = client_environment
    self.env_config = env_config
    self.config = get_config(env_config, client_environment)
    self.trainable_push_groups = env_config.trainable_push_groups
def test_trainable_push_groups(self):
    """trainable_push_groups must be exactly the push groups flagged ``trainable``."""
    env_config = get_env_config()
    # Every reported trainable group is flagged trainable and belongs to push_groups.
    assert all(g.trainable for g in env_config.trainable_push_groups)
    assert all(g in env_config.push_groups for g in env_config.trainable_push_groups)
    # Conversely, no non-trainable push group appears among the trainable ones.
    non_trainable = (g for g in env_config.push_groups if not g.trainable)
    assert not any(g in env_config.trainable_push_groups for g in non_trainable)
def test_get_config_with_other_properties(self):
    """get_config should store the env config, client environment and reward function it is given."""
    client_environment = get_random_client_environment()
    built = config.get_config(get_env_config(), client_environment, 0)
    expected_env_config = get_env_config()
    assert built.env_config == expected_env_config
    assert built.client_env == client_environment
    assert built.reward_func == 0
def test_get_config_with_env_config(self):
    """get_config given only an env config should expose an equal env config."""
    built = config.get_config(get_env_config())
    expected_env_config = get_env_config()
    assert built.env_config == expected_env_config