Exemplo n.º 1
0
 def setup(self):
     """Create shared fixtures and fill the policy with random actions.

     Sampled actions are applied one after another until the policy rejects one.
     """
     self.config = get_config()
     self.action_space = ActionSpace(get_push_groups())
     self.client_environment = get_random_client_environment()
     self.policy = Policy(self.action_space)
     # apply_action returns falsy for the first action it refuses; stop there
     while self.policy.apply_action(self.action_space.sample()):
         pass
Exemplo n.º 2
0
 def test_apply_action_noop_as_first_action(self):
     """A no-op as the very first action is rejected and leaves the policy empty."""
     policy = Policy(get_action_space())
     # the noop must be reported as not applied
     assert not policy.apply_action(Action())
     # neither push nor preload entries may have been recorded
     assert not list(policy.push)
     assert not list(policy.preload)
     assert len(policy) == 0
Exemplo n.º 3
0
    def test_apply_push_action(self):
        """One push action yields exactly one source with exactly one pushed resource."""
        action_space = get_action_space()
        policy = Policy(action_space)

        action_id = (1, (0, 0, 1), (0, 0))
        # decoded separately so the expected source/push can be compared below
        expected = action_space.decode_action(action_id)
        applied = policy.apply_action(action_id)
        entries = list(policy.push)

        assert applied
        assert len(policy) == 1
        assert len(entries) == 1
        source, pushed = entries[0]
        assert len(pushed) == 1
        assert source == expected.source
        assert pushed == {expected.push}
Exemplo n.º 4
0
    def test_apply_action_noop_as_second_action(self):
        """A no-op after an accepted push action is rejected and changes nothing."""
        policy = Policy(get_action_space())

        # a real push action first, which should be accepted
        applied = policy.apply_action((1, (0, 0, 1), (0, 0)))
        push_before = list(policy.push)
        assert applied
        assert push_before
        assert len(policy) == 1

        # then a noop, which must be rejected without touching the policy
        assert not policy.apply_action(Action())
        assert push_before == list(policy.push)
        assert len(policy) == 1
Exemplo n.º 5
0
def replay(args):
    """
    Starts a replay environment for the given replay directory, including setting up
    interfaces, running a DNS server, and configuring and running an nginx server to
    serve the requests.

    The optional certificate, key and per-resource-latency paths are made absolute.
    If ``args.policy`` is set, that JSON file is loaded with ``Policy.from_dict``.
    SIGTERM is routed to ``sigterm_handler``. The function then blocks forever
    while the ``start_server`` context is open.
    """

    def _absolute_or_none(path):
        # unset (falsy) CLI paths stay None instead of being resolved
        return os.path.abspath(path) if path else None

    cert_path = _absolute_or_none(args.cert_path)
    key_path = _absolute_or_none(args.key_path)
    per_resource_latency = _absolute_or_none(args.per_resource_latency)

    policy = None
    if args.policy:
        log.debug("reading policy", push_policy=args.policy)
        with open(args.policy, "r") as policy_file:
            policy_dict = json.load(policy_file)
        policy = Policy.from_dict(policy_dict)

    # handle sigterm gracefully
    signal.signal(signal.SIGTERM, sigterm_handler)

    server_options = {
        "cache_time": args.cache_time,
        "extract_critical_requests": args.extract_critical_requests,
        "enable_http2": args.enable_http2,
    }
    one_day_in_seconds = 24 * 60 * 60
    with start_server(args.replay_dir, cert_path, key_path, policy,
                      per_resource_latency, **server_options):
        # keep the process (and therefore the server context) alive
        while True:
            time.sleep(one_day_in_seconds)
Exemplo n.º 6
0
def replay(args):
    """
    Starts a replay environment for the given replay directory, including setting up
    interfaces, running a DNS server, and configuring and running an nginx server to
    serve the requests.

    The optional certificate and key paths are made absolute. If ``args.policy`` is
    set, that JSON file is loaded with ``Policy.from_dict``. The function then blocks
    forever while the ``start_server`` context is open.
    """

    def _absolute_or_none(path):
        # unset (falsy) CLI paths stay None instead of being resolved
        return os.path.abspath(path) if path else None

    cert_path = _absolute_or_none(args.cert_path)
    key_path = _absolute_or_none(args.key_path)

    policy = None
    if args.policy:
        log.debug("reading policy", push_policy=args.policy)
        with open(args.policy, "r") as policy_file:
            policy_dict = json.load(policy_file)
        policy = Policy.from_dict(policy_dict)

    server_options = {
        "cache_time": args.cache_time,
        "extract_critical_requests": args.extract_critical_requests,
    }
    one_day_in_seconds = 24 * 60 * 60
    with start_server(args.replay_dir, cert_path, key_path, policy,
                      **server_options):
        # keep the process (and therefore the server context) alive
        while True:
            time.sleep(one_day_in_seconds)
Exemplo n.º 7
0
    def test_apply_push_action_same_source_resource(self):
        """Two pushes from the same source are grouped under a single source entry."""
        action_space = get_action_space()
        policy = Policy(action_space)

        first, second = (
            action_space.decode_action(action_id)
            for action_id in ((1, (0, 0, 1), (0, 0)), (1, (0, 0, 2), (0, 0))))
        assert policy.apply_action(first)
        assert policy.apply_action(second)

        entries = list(policy.push)
        assert len(policy) == 2
        assert len(entries) == 1
        source, pushed = entries[0]
        assert len(pushed) == 2
        assert source == first.source
        assert source == second.source
        assert pushed == {first.push, second.push}
Exemplo n.º 8
0
    def test_as_dict(self):
        """Every pushed/preloaded URL appears under its source's URL in ``as_dict``."""
        action_space = get_action_space()
        policy = Policy(action_space)
        for _ in range(10):
            sampled = action_space.decode_action(action_space.sample())
            policy.apply_action(sampled)
            action_space.use_action(sampled)

        policy_dict = policy.as_dict

        def assert_listed(section, entries):
            # the section is only looked up for sources that have dependencies
            for source, deps in entries:
                for dep in deps:
                    listed = [entry["url"] for entry in policy_dict[section][source.url]]
                    assert dep.url in listed

        assert_listed("push", policy.push)
        assert_listed("preload", policy.preload)
Exemplo n.º 9
0
 def test_create_push_policy(self):
     """Every push pair in the service's policy comes from the page's push groups."""
     service = PolicyService(self.saved_model)
     response = service.create_policy(self.page)
     assert response
     assert isinstance(response, policy_service_pb2.Policy)

     allowed_pairs = [
         (src.url, dst.url)
         for (src, dst) in convert_push_groups_to_push_pairs(self.push_groups)
     ]
     parsed = Policy.from_dict(json.loads(response.policy))
     for source, pushed in parsed.push:
         for res in pushed:
             assert (source.url, res.url) in allowed_pairs
Exemplo n.º 10
0
    def test_apply_multiple_actions(self):
        """Apply random actions until the policy holds 10 of them.

        Then check that every applied action appears in either the push map or the
        preload map of the policy.
        """
        action_space = get_action_space()
        policy = Policy(action_space)

        actions = []
        while len(policy) < 10:
            action = action_space.decode_action(action_space.sample())
            action_applied = policy.apply_action(action)
            # mark the action as used exactly once, whether or not it was applied
            # (the original called use_action a second time for applied actions)
            action_space.use_action(action)

            if action_applied:
                actions.append(action)

        for action in actions:
            assert any(action.source == source and action.push == push
                       for source, res in policy.push for push in res) or any(
                           action.source == source and action.push == push
                           for source, res in policy.preload for push in res)
Exemplo n.º 11
0
    def test_from_dict(self):
        """``Policy.from_dict`` builds Resource-based push/preload maps mirroring the dict."""
        policy_dict = {
            "push": {
                "A": [{
                    "url": "B",
                    "type": "SCRIPT"
                }, {
                    "url": "C",
                    "type": "IMAGE"
                }]
            },
            "preload": {
                "B": [{
                    "url": "D",
                    "type": "IMAGE"
                }],
                "G": [{
                    "url": "E",
                    "type": "CSS"
                }, {
                    "url": "F",
                    "type": "FONT"
                }],
            },
        }
        policy = Policy.from_dict(policy_dict)
        # a policy loaded from a dict has no action space and no recorded actions
        assert policy.total_actions == 0
        assert not policy.action_space

        for (source, deps) in policy.push:
            assert isinstance(source, Resource)
            assert all(isinstance(push, Resource) for push in deps)
            expected_urls = [p["url"] for p in policy_dict["push"][source.url]]
            # sort both sides so the check does not depend on the order in which
            # entries happen to be written in policy_dict
            assert sorted(expected_urls) == sorted(push.url for push in deps)
            for push in deps:
                assert policy.push_to_source[push] == source
                assert push.url in expected_urls

        for (source, deps) in policy.preload:
            assert isinstance(source, Resource)
            assert all(isinstance(push, Resource) for push in deps)
            expected_urls = [p["url"] for p in policy_dict["preload"][source.url]]
            assert sorted(expected_urls) == sorted(push.url for push in deps)
            for push in deps:
                assert policy.preload_to_source[push] == source
                assert push.url in expected_urls

        assert len(policy.source_to_push) == len(policy_dict["push"])
        assert len(policy.source_to_preload) == len(policy_dict["preload"])
Exemplo n.º 12
0
    def test_observation_with_nonempty_policy(self):
        """Apply random actions one at a time and check the observation after each.

        Each observation must stay inside the observation space and must encode the
        push/preload source of every affected resource.
        """
        action_space = ActionSpace(self.push_groups)
        policy = Policy(action_space)

        # do some actions and check the observation space over time
        for _ in range(len(action_space) - 1):
            # get an action and apply it in the policy
            # NOTE(review): the raw sampled id is passed to apply_action without
            # decode_action — confirm apply_action accepts ids in this version
            action_id = action_space.sample()
            policy.apply_action(action_id)

            # get the observation (with no cached URLs)
            obs = get_observation(self.client_environment, self.push_groups,
                                  policy, set())
            assert self.observation_space.contains(obs)

            # make sure the push sources are recorded correctly: the second-to-last
            # column holds the pushing source's source_id
            for (source, push) in policy.push:
                for push_res in push:
                    # +1 so that 0 can mean "not pushed" (checked below)
                    assert obs["resources"][str(
                        push_res.order)][-2] == source.source_id + 1

            # make sure the preload sources are recorded correctly: the last column
            # holds the preloading source's page-load order (not its source_id)
            for (source, preload) in policy.preload:
                for push_res in preload:
                    # +1 so that 0 can mean "not preloaded" (checked below)
                    assert obs["resources"][str(
                        push_res.order)][-1] == source.order + 1

            # check that all other resources are neither pushed nor preloaded
            pushed_res = set(push_res.order for (source, push) in policy.push
                             for push_res in push)
            preloaded_res = set(push_res.order
                                for (source, push) in policy.preload
                                for push_res in push)
            assert all(res[-2] == 0 for order, res in obs["resources"].items()
                       if int(order) not in pushed_res)
            assert all(res[-1] == 0 for order, res in obs["resources"].items()
                       if int(order) not in preloaded_res)
Exemplo n.º 13
0
    def get_policy(self, url: str, client_env: ClientEnvironment,
                   manifest: EnvironmentConfig) -> Policy:
        """Query the policy service for a push policy for the given configuration.

        Builds a ``policy_service_pb2.Page`` from the URL, the client's bandwidth,
        latency and CPU slowdown, and the serialized manifest. It sends that page to
        ``GetPolicy`` on ``self.stub`` and decodes the JSON ``policy`` field of the
        response into a ``Policy``.
        """
        response = self.stub.GetPolicy(
            policy_service_pb2.Page(
                url=url,
                bandwidth_kbps=client_env.bandwidth,
                latency_ms=client_env.latency,
                cpu_slowdown=client_env.cpu_slowdown,
                manifest=manifest.serialize(),
            ))
        policy_dict = json.loads(response.policy)
        return Policy.from_dict(policy_dict)
Exemplo n.º 14
0
    def test_resource_push_from(self):
        """The policy reports the source of a resource only after the action is applied.

        Covers ``resource_pushed_from`` and ``resource_preloaded_from``.
        """
        action_space = get_action_space()
        policy = Policy(action_space)

        def draw(accept):
            # sample decoded actions (at least one) until `accept` holds
            candidate = action_space.decode_action(action_space.sample())
            while not accept(candidate):
                candidate = action_space.decode_action(action_space.sample())
            return candidate

        push_action = draw(lambda a: not a.is_noop and a.is_push)
        assert policy.resource_pushed_from(push_action.push) is None
        assert policy.apply_action(push_action)
        assert policy.resource_pushed_from(push_action.push) is push_action.source

        preload_action = draw(lambda a: not a.is_noop and not a.is_push)
        assert policy.resource_preloaded_from(preload_action.push) is None
        assert policy.apply_action(preload_action)
        assert policy.resource_preloaded_from(
            preload_action.push) is preload_action.source
Exemplo n.º 15
0
    def initialize_environment(self,
                               client_environment: ClientEnvironment,
                               cached_urls: Optional[Set[str]] = None):
        """ Initialize the environment

        Sets the client environment, decides which URLs count as cached, resets the
        analyzer, and creates a fresh ``ActionSpace`` and empty ``Policy``.

        Args:
            client_environment: network/device conditions for this episode.
            cached_urls: URLs to treat as cached. If None, they are derived from a
                randomly drawn cache scenario (see below).
        """
        log.info(
            "initialized environment",
            network_type=client.NetworkType(client_environment.network_type),
            network_speed=client.NetworkSpeed(
                client_environment.network_speed),
            device_speed=client.DeviceSpeed(client_environment.device_speed),
            bandwidth=client_environment.bandwidth,
            latency=client_environment.latency,
            cpu_slowdown=client_environment.cpu_slowdown,
            loss=client_environment.loss,
            reward_func=self.analyzer.reward_func_num,
            cached_urls=cached_urls,
        )
        # Cache scenarios in hours
        # (half of the entries are 0, i.e. a cold cache half of the time)
        scenarios = [0, 0, 0, 0, 0, 1, 2, 4, 12, 24]
        # NOTE: drawn even when cached_urls is given, so the RNG stream advances
        # identically either way
        cache_time = self.np_random.choice(scenarios)
        # a resource counts as cached when its cache_time is at least the scenario's
        # hours converted to seconds (res.cache_time presumably in seconds — confirm)
        self.cached_urls = (cached_urls if cached_urls is not None else
                            set() if cache_time == 0 else set(
                                res.url
                                for group in self.env_config.push_groups
                                for res in group.resources
                                if res.cache_time >= (cache_time * 60 * 60)))

        self.client_environment = client_environment
        self.analyzer.reset(self.client_environment, self.cached_urls)

        # only the PROPORTION_DEPLOYED fraction (rounded up) of push groups with the
        # most resources are exposed to the action space
        num_domains_deployed = math.ceil(PROPORTION_DEPLOYED *
                                         len(self.env_config.push_groups))
        push_groups = sorted(self.env_config.push_groups,
                             key=lambda g: len(g.resources),
                             reverse=True)[:num_domains_deployed]

        self.action_space = ActionSpace(push_groups)
        self.policy = Policy(self.action_space)
Exemplo n.º 16
0
    def test_push_preload_list_for_source(self):
        """Per-source push/preload sets match the actions actually applied to each source."""
        action_space = get_action_space()
        policy = Policy(action_space)

        pushed_by_source = collections.defaultdict(set)
        preloaded_by_source = collections.defaultdict(set)

        while len(policy) < 10:
            action = action_space.decode_action(action_space.sample())
            if not policy.apply_action(action):
                continue
            if action.is_push:
                pushed_by_source[action.source].add(action.push)
            if action.is_preload:
                preloaded_by_source[action.source].add(action.push)

        assert pushed_by_source
        assert preloaded_by_source
        for source, expected in pushed_by_source.items():
            assert policy.push_set_for_resource(source) == expected
        for source, expected in preloaded_by_source.items():
            assert policy.preload_set_for_resource(source) == expected
Exemplo n.º 17
0
    def test_observation_with_cached_urls(self):
        """Resources marked as cached have the cached flag (column 1) set to 1."""
        policy = Policy(ActionSpace(self.push_groups))

        all_resources = [
            res for group in self.push_groups for res in group.resources
        ]
        # randint(0, 2) is inclusive, so each resource is cached with probability 2/3
        cached = [res for res in all_resources if random.randint(0, 2)]
        cached_urls = {res.url for res in cached}

        obs = get_observation(self.client_environment, self.push_groups,
                              policy, cached_urls)
        for res in cached:
            assert obs["resources"][str(res.order)][1] == 1
Exemplo n.º 18
0
    def test_observation_with_nonempty_policy_with_default_actions(self):
        """Like the plain non-empty-policy test, but with one non-trainable push group.

        That group is handled through default push actions outside the action space.
        Only ``observable_push`` / ``observable_preload`` entries are expected in the
        observation.
        """
        # use all push groups except the chosen default group
        # (candidates: non-trainable groups with more than two resources)
        candidate_push_groups = [
            i for i, group in enumerate(self.push_groups)
            if len(group.resources) > 2 and not group.trainable
        ]
        default_group_idx = random.choice(candidate_push_groups)
        default_group = self.push_groups[default_group_idx]
        remaining_groups = [
            group for i, group in enumerate(self.push_groups)
            if i != default_group_idx
        ]
        action_space = ActionSpace(remaining_groups)
        policy = Policy(action_space)

        # apply some default action: push every other resource of the default
        # group from its first resource
        for push in default_group.resources[1:]:
            policy.add_default_push_action(default_group.resources[0], push)

        # do some actions and check the observation space over time
        for _ in range(len(action_space) - 1):
            # get an action and apply it in the policy
            # NOTE(review): the raw sampled id is passed to apply_action without
            # decode_action — confirm apply_action accepts ids in this version
            action_id = action_space.sample()
            policy.apply_action(action_id)

            # get the observation (with no cached URLs)
            obs = get_observation(self.client_environment, self.push_groups,
                                  policy, set())
            assert self.observation_space.contains(obs)

            # make sure the push sources are recorded correctly: the second-to-last
            # column holds the pushing source's source_id
            for (source, push) in policy.observable_push:
                for push_res in push:
                    # +1 so that 0 can mean "not pushed" (checked below)
                    assert obs["resources"][str(
                        push_res.order)][-2] == source.source_id + 1

            # make sure the preload sources are recorded correctly: the last column
            # holds the preloading source's page-load order (not its source_id)
            for (source, preload) in policy.observable_preload:
                for push_res in preload:
                    # +1 so that 0 can mean "not preloaded" (checked below)
                    assert obs["resources"][str(
                        push_res.order)][-1] == source.order + 1

            # check that all other resources are neither pushed nor preloaded
            pushed_res = set(push_res.order
                             for (source, push) in policy.observable_push
                             for push_res in push)
            preloaded_res = set(push_res.order
                                for (source, push) in policy.observable_preload
                                for push_res in push)
            assert all(res[-2] == 0 for order, res in obs["resources"].items()
                       if int(order) not in pushed_res)
            assert all(res[-1] == 0 for order, res in obs["resources"].items()
                       if int(order) not in preloaded_res)
Exemplo n.º 19
0
class TestMahiMahiConfig:
    """Tests for building a MahiMahiConfig and the commands and trace it produces."""

    def setup(self):
        self.config = get_config()
        self.action_space = ActionSpace(get_push_groups())
        self.client_environment = get_random_client_environment()
        self.policy = Policy(self.action_space)
        # keep applying random actions until the policy refuses one
        while self.policy.apply_action(self.action_space.sample()):
            pass

    def test_init_without_policy(self):
        config = MahiMahiConfig(self.config)
        assert isinstance(config, MahiMahiConfig)
        assert config.policy is None
        assert config.client_environment is None

    def test_init_without_client_environment(self):
        config = MahiMahiConfig(self.config, policy=self.policy)
        assert isinstance(config, MahiMahiConfig)
        assert config.policy is self.policy
        assert config.client_environment is None

    def test_init_with_client_environment(self):
        config = MahiMahiConfig(
            self.config,
            policy=self.policy,
            client_environment=self.client_environment,
        )
        assert isinstance(config, MahiMahiConfig)
        assert config.policy is self.policy
        assert config.client_environment is self.client_environment

    def test_record_shell_with_cmd(self):
        target_dir = "/tmp/save_dir"
        config = MahiMahiConfig(self.config, policy=self.policy)
        shell_cmd = config.record_shell_with_cmd(target_dir, ["a", "command"])
        # the user command is appended to the plain record command
        assert shell_cmd == (config.record_cmd(target_dir) + ["a", "command"])

    def test_record_cmd(self):
        target_dir = "/tmp/save_dir"
        cmd = MahiMahiConfig(self.config).record_cmd(target_dir)
        assert cmd[0] == "mm-webrecord"
        assert cmd[1] == target_dir

    def test_formatted_trace_file(self):
        config = MahiMahiConfig(
            self.config,
            policy=self.policy,
            client_environment=self.client_environment,
        )
        expected = format_trace_lines(
            trace_for_kbps(self.client_environment.bandwidth))
        assert config.formatted_trace_file == expected
Exemplo n.º 20
0
    def test_get_default_observation(self):
        """An empty policy yields a valid observation with every resource encoded.

        No resource is pushed or preloaded, and slots past the last resource are
        zero-filled.
        """
        policy = Policy(ActionSpace(self.push_groups))

        obs = get_observation(self.client_environment, self.push_groups,
                              policy, set())
        assert isinstance(obs, dict)
        assert self.observation_space.contains(obs)

        # the client environment is captured in the "client" section
        client_obs = obs["client"]
        assert client_obs["network_type"] == self.client_environment.network_type.value
        assert client_obs["device_speed"] == self.client_environment.device_speed.value

        rows = obs["resources"]
        # nothing is pushed (second-to-last column) or preloaded (last column) yet
        assert all(row[-2] == 0 for row in rows.values())
        assert all(row[-1] == 0 for row in rows.values())

        # every resource of every push group is encoded column by column
        for group in self.push_groups:
            for res in group.resources:
                expected = np.array((
                    1,  # resource is enabled
                    0,  # resource is not cached
                    group.id,  # the resource's domain id
                    res.source_id,  # relative offset from its domain top
                    res.order + 1,  # absolute offset from the start of the page load
                    res.initiator + 1,  # the resource's initiator
                    res.type.value,  # resource type
                    res.size // 1000,  # resource size in KB
                    0,  # not pushed
                    0,  # not preloaded
                ))
                assert np.array_equal(rows[str(res.order)], expected)

        # slots past the highest resource order are all zeros
        last_order = max(r.order for group in self.push_groups
                         for r in group.resources)
        empty_row = np.array([0] * 10)
        for slot in range(last_order + 1, MAX_RESOURCES):
            assert np.array_equal(rows[str(slot)], empty_row)
Exemplo n.º 21
0
    def _generator(env_config: EnvironmentConfig) -> Policy:
        """Generate a random push/preload policy for ``env_config``.

        Candidates are non-cached resources whose type is in ``dist``. The function
        draws ``n`` of them (``n`` uniform in [1, total]) without replacement via
        ``_choose_with_dist``. For each chosen resource, a coin weighted by
        ``push_weight`` (or a random weight when it is None) decides:
        - push from a resource at a lower index in the same push group, or
        - preload from a resource earlier in page-load order.
        Returns an empty ``Policy`` when at most one candidate exists.

        NOTE(review): ``dist``, ``cached_urls``, ``push_weight`` and
        ``_choose_with_dist`` come from the enclosing scope, which is not shown here.
        """
        push_groups = env_config.push_groups
        # Collect all resources and group them by type
        all_resources = sorted(
            [res for group in push_groups for res in group.resources],
            key=lambda res: res.order)
        res_by_type = collections.defaultdict(list)
        for res in all_resources:
            # Only consider non-cached objects in the push resource type distribution
            if res.type in dist and res.url not in cached_urls:
                res_by_type[res.type].append(res)

        # choose the number of resources to push/preload
        total = sum(map(len, res_by_type.values()))
        if total <= 1:
            return Policy()

        n = random.randint(1, total)
        # choose the weight factor between push and preload
        weight = push_weight if push_weight is not None else random.random()

        # Choose n resources based on the resource type distribution without replacement
        log.debug("generating push-preload policy",
                  num_resources=len(all_resources),
                  total_size=n,
                  push_weight=weight)
        res = []
        for _ in range(n):
            # presumably (type key, index into res_by_type[g], chosen resource)
            g, r, s = _choose_with_dist(res_by_type, dist)
            # drop the chosen resource so it cannot be drawn again
            res_by_type[g].pop(r)
            res.append(s)

        policy = Policy()

        for r in res:
            # no earlier candidate source exists for these; randint(0, -1) would
            # raise ValueError below
            if r.source_id == 0 or r.order == 0:
                continue
            push = random.random() < weight
            policy.steps_taken += 1
            if push:
                # push source: any resource before r within its own push group
                # (assumes group resources are indexed by source_id — confirm)
                source = random.randint(0, r.source_id - 1)
                policy.add_default_push_action(
                    push_groups[r.group_id].resources[source], r)
            else:
                # preload source: any resource earlier in the page load
                # (assumes list index == order, i.e. orders are contiguous — confirm)
                source = random.randint(0, r.order - 1)
                policy.add_default_preload_action(all_resources[source], r)

        return policy
Exemplo n.º 22
0
 def _generator(env_config: EnvironmentConfig) -> Policy:
     """Generate a domain-based push/preload policy for ``env_config``.

     Every resource except those with ``source_id == 0`` or ``order == 0`` is
     either pushed or preloaded:
     - pushed when its URL's netloc equals that of ``env_config.request_url``, from
       a random lower-index resource of its push group;
     - preloaded otherwise, from a random resource earlier in page-load order.
     """
     push_groups = env_config.push_groups
     # Collect all resources and group them by type
     # (in practice: a flat list of all resources sorted by page-load order)
     all_resources = sorted(
         [res for group in push_groups for res in group.resources],
         key=lambda res: res.order)
     # push vs preload is decided per resource by comparing its domain with the
     # domain of the page's request URL
     main_domain = urllib.parse.urlparse(env_config.request_url)
     policy = Policy()
     for r in all_resources:
         # no earlier candidate source exists for these; randint(0, -1) would
         # raise ValueError below
         if r.source_id == 0 or r.order == 0:
             continue
         request_domain = urllib.parse.urlparse(r.url)
         push = request_domain.netloc == main_domain.netloc
         policy.steps_taken += 1
         if push:
             # push source: any resource before r within its own push group
             # (assumes group resources are indexed by source_id — confirm)
             source = random.randint(0, r.source_id - 1)
             policy.add_default_push_action(
                 push_groups[r.group_id].resources[source], r)
         else:
             # preload source: any resource earlier in the page load
             # (assumes list index == order, i.e. orders are contiguous — confirm)
             source = random.randint(0, r.order - 1)
             policy.add_default_preload_action(all_resources[source], r)
     return policy
Exemplo n.º 23
0
def page_load_time(args):
    """
    Captures a webpage and calculates the median page load time for a given website
    in a fast, no-latency Mahimahi shell. Then simulates the load based on profiling
    the page in the same Mahimahi shell.

    Replay-server and simulator measurements, each with and without the optional
    policy, are printed to stdout as JSON. When ``args.speed_index`` is set, the
    replay server reports speed index instead of page load time. Exits with status 1
    on an invalid latency, bandwidth or CPU slowdown.
    """
    # Validate the arguments
    if args.latency is not None and args.latency < 0:
        log.critical("provided latency must be greater or equal to 0")
        sys.exit(1)
    if args.bandwidth is not None and args.bandwidth <= 0:
        log.critical("provided bandwidth must be greater than 0")
        sys.exit(1)
    if args.cpu_slowdown is not None and args.cpu_slowdown not in {1, 2, 4}:
        log.critical("provided cpu slowdown must be 1, 2, or 4")
        sys.exit(1)

    def _or_default(value, default):
        # Fall back only when the option was omitted. A plain `or` would also
        # discard an explicit 0, e.g. `--latency 0`, which validation allows.
        return default if value is None else value

    # Setup the client environment
    default_client_env = get_default_client_environment()
    client_env = get_client_environment_from_parameters(
        _or_default(args.bandwidth, default_client_env.bandwidth),
        _or_default(args.latency, default_client_env.latency),
        _or_default(args.cpu_slowdown, default_client_env.cpu_slowdown),
    )

    # If a push/preload policy was specified, read it
    policy = None
    if args.policy:
        log.debug("reading policy", push_policy=args.policy)
        with open(args.policy, "r") as policy_file:
            policy_dict = json.load(policy_file)
        policy = Policy.from_dict(policy_dict)

    env_config = EnvironmentConfig.load_file(args.from_manifest)
    config = get_config(env_config)

    def _measure_in_replay_server(**extra):
        """Run one replay-server measurement (speed index or PLT per args.speed_index)."""
        options = dict(
            request_url=config.env_config.request_url,
            client_env=client_env,
            config=config,
            cache_time=args.cache_time,
            user_data_dir=args.user_data_dir,
            **extra,
        )
        if args.speed_index:
            return get_speed_index_in_replay_server(**options)
        # the PLT helper returns extra values; only the first is the load time
        result, *_ = get_page_load_time_in_replay_server(**options)
        return result

    log.info("calculating page load time",
             manifest=args.from_manifest,
             url=env_config.request_url)
    plt, orig_plt = 0, 0
    if not args.only_simulator:
        orig_plt = _measure_in_replay_server()
        if policy:
            plt = _measure_in_replay_server(policy=policy)

    log.debug("running simulator...")
    sim = Simulator(env_config)
    orig_sim_plt = sim.simulate_load_time(client_env)
    sim_plt = sim.simulate_load_time(client_env, policy)

    print(
        json.dumps(
            {
                "client_env": client_env._asdict(),
                "metric": "speed_index" if args.speed_index else "plt",
                "cache": "warm" if args.user_data_dir else "cold",
                "cache_time": args.cache_time,
                "replay_server": {
                    "with_policy": plt,
                    "without_policy": orig_plt
                },
                "simulator": {
                    "with_policy": sim_plt,
                    "without_policy": orig_sim_plt
                },
            },
            indent=4,
        ))
Exemplo n.º 24
0
 def setup(self):
     """Create a fresh config, a Policy over its push groups, and a random client env."""
     self.config = get_config()
     push_groups = self.config.env_config.push_groups
     self.policy = Policy(ActionSpace(push_groups))
     self.client_environment = get_random_client_environment()
Exemplo n.º 25
0
 def test_init(self):
     """A Policy built from an action space is a Policy and keeps that action space."""
     space = get_action_space()
     created = Policy(space)
     assert isinstance(created, Policy)
     assert created.action_space == space
Exemplo n.º 26
0
def get_mahimahi_config() -> MahiMahiConfig:
    """Build a MahiMahiConfig for tests.

    Uses the default config, a fresh Policy over the push groups, and a random
    client environment.
    """
    config = get_config()
    fresh_policy = Policy(ActionSpace(get_push_groups()))
    client_environment = get_random_client_environment()
    return MahiMahiConfig(
        config=config,
        policy=fresh_policy,
        client_environment=client_environment,
    )
Exemplo n.º 27
0
class Environment(gym.Env):
    """
    Environment virtualizes a randomly chosen network and browser environment and
    facilitates the training for a given web page. This includes action selection, policy
    generation, and evaluation of the policy/action in the simulated environment.

    Each episode starts from an empty ``Policy`` over an ``ActionSpace`` built in
    ``initialize_environment``. ``step`` applies one action, and the episode ends
    (``done``) as soon as an action is not applied.
    """
    def __init__(self, config: Union[Config, dict]):
        """
        Args:
            config: a ``Config``, or a dict of keyword arguments used to build one.
        """
        # make sure config is an instance of Config or a dict
        assert isinstance(config, (Config, dict))
        config = config if isinstance(config, Config) else Config(**config)

        self.config = config
        self.env_config = config.env_config
        # per-environment RNG (reseeded via seed()); used to draw cache scenarios
        self.np_random = np.random.RandomState()

        log.info("initialized trainable push groups",
                 groups=[
                     group.name
                     for group in self.env_config.trainable_push_groups
                 ])

        self.observation_space = get_observation_space()
        # overwritten by initialize_environment() below
        self.cached_urls = config.cached_urls or set()
        # reward_func / use_aft fall back to 0 / False when unset in the config
        self.analyzer = Analyzer(self.config, config.reward_func or 0,
                                 config.use_aft or False)

        # all three are populated by initialize_environment()
        self.client_environment: Optional[ClientEnvironment] = None
        self.action_space: Optional[ActionSpace] = None
        self.policy: Optional[Policy] = None
        # use the configured client environment if any, else a random fast-LTE one
        self.initialize_environment(
            self.config.client_env
            or client.get_random_fast_lte_client_environment(),
            self.config.cached_urls)

    def seed(self, seed=None):
        """Seed the environment's numpy RNG."""
        # NOTE(review): gym's seed() convention is to return the list of seeds used;
        # this returns None — confirm no caller relies on the return value
        self.np_random.seed(seed)

    def reset(self):
        """Start a new episode and return its initial observation.

        The new episode uses a freshly drawn random fast-LTE client environment.
        """
        # NOTE(review): unlike __init__, this ignores config.client_env — confirm
        # that re-randomizing the client environment on every reset is intended
        self.initialize_environment(
            client.get_random_fast_lte_client_environment(),
            self.config.cached_urls)
        return self.observation

    def initialize_environment(self,
                               client_environment: ClientEnvironment,
                               cached_urls: Optional[Set[str]] = None):
        """ Initialize the environment

        Sets the client environment, decides which URLs count as cached, resets the
        analyzer, and creates a fresh ``ActionSpace`` and empty ``Policy``.

        Args:
            client_environment: network/device conditions for this episode.
            cached_urls: URLs to treat as cached. If None, they are derived from a
                randomly drawn cache scenario (see below).
        """
        log.info(
            "initialized environment",
            network_type=client.NetworkType(client_environment.network_type),
            network_speed=client.NetworkSpeed(
                client_environment.network_speed),
            device_speed=client.DeviceSpeed(client_environment.device_speed),
            bandwidth=client_environment.bandwidth,
            latency=client_environment.latency,
            cpu_slowdown=client_environment.cpu_slowdown,
            loss=client_environment.loss,
            reward_func=self.analyzer.reward_func_num,
            cached_urls=cached_urls,
        )
        # Cache scenarios in hours
        # (half of the entries are 0, i.e. a cold cache half of the time)
        scenarios = [0, 0, 0, 0, 0, 1, 2, 4, 12, 24]
        # NOTE: drawn even when cached_urls is given, so the RNG stream advances
        # identically either way
        cache_time = self.np_random.choice(scenarios)
        # a resource counts as cached when its cache_time is at least the scenario's
        # hours converted to seconds (res.cache_time presumably in seconds — confirm)
        self.cached_urls = (cached_urls if cached_urls is not None else
                            set() if cache_time == 0 else set(
                                res.url
                                for group in self.env_config.push_groups
                                for res in group.resources
                                if res.cache_time >= (cache_time * 60 * 60)))

        self.client_environment = client_environment
        self.analyzer.reset(self.client_environment, self.cached_urls)

        # only the PROPORTION_DEPLOYED fraction (rounded up) of push groups with the
        # most resources are exposed to the action space
        num_domains_deployed = math.ceil(PROPORTION_DEPLOYED *
                                         len(self.env_config.push_groups))
        push_groups = sorted(self.env_config.push_groups,
                             key=lambda g: len(g.resources),
                             reverse=True)[:num_domains_deployed]

        self.action_space = ActionSpace(push_groups)
        self.policy = Policy(self.action_space)

    def step(self, action: ActionIDType):
        """Apply one action and return ``(observation, reward, done, info)``.

        ``reward`` is ``NOOP_ACTION_REWARD`` unless the action was applied, in which
        case the analyzer scores the updated policy. ``done`` is True exactly when
        the action was not applied. ``info`` carries the decoded action and the
        policy as a dict.
        """
        # decode the action and apply it to the policy
        decoded_action = self.action_space.decode_action(action)
        action_applied = self.policy.apply_action(decoded_action)

        # make sure the action isn't used again
        log.info("trying action",
                 action_id=action,
                 action=repr(decoded_action),
                 steps_taken=self.policy.steps_taken)
        self.action_space.use_action(decoded_action)

        reward = NOOP_ACTION_REWARD
        if action_applied:
            reward = self.analyzer.get_reward(self.policy)
            log.info("got reward", action=repr(decoded_action), reward=reward)

        info = {"action": decoded_action, "policy": self.policy.as_dict}
        return self.observation, reward, not action_applied, info

    def render(self, mode="human"):
        """Delegate rendering to ``gym.Env.render``."""
        return super(Environment, self).render(mode=mode)

    @property
    def observation(self):
        """ Returns an observation for the current state of the environment """
        return get_observation(self.client_environment,
                               self.env_config.push_groups, self.policy,
                               self.cached_urls)