Example #1
0
    def test_make_create_env(self):
        """Test that make_create_env builds an environment from flow_params.

        A flow_params dict is assembled, passed to ``make_create_env``
        together with a version number, and the returned environment name
        and the environment produced by the returned constructor are compared
        against the values that were supplied.
        """
        # use a flow_params dict derived from flow/benchmarks/figureeight0.py
        vehicles = Vehicles()
        vehicles.add(veh_id="human",
                     acceleration_controller=(IDMController, {
                         "noise": 0.2
                     }),
                     routing_controller=(ContinuousRouter, {}),
                     speed_mode="no_collide",
                     num_vehicles=13)
        vehicles.add(veh_id="rl",
                     acceleration_controller=(RLController, {}),
                     routing_controller=(ContinuousRouter, {}),
                     speed_mode="no_collide",
                     num_vehicles=1)

        flow_params = dict(
            exp_tag="figure_eight_0",
            env_name="AccelEnv",
            scenario="Figure8Scenario",
            generator="Figure8Generator",
            sumo=SumoParams(
                sim_step=0.1,
                sumo_binary="sumo",
            ),
            env=EnvParams(
                horizon=1500,
                additional_params={
                    "target_velocity": 20,
                    "max_accel": 3,
                    "max_decel": 3,
                },
            ),
            net=NetParams(
                no_internal_links=False,
                additional_params={
                    "radius_ring": 30,
                    "lanes": 1,
                    "speed_limit": 30,
                    "resolution": 40,
                },
            ),
            veh=vehicles,
            initial=InitialConfig(),
            tls=TrafficLights(),
        )

        # an arbitrary version number; only used to build the expected name
        version = 23434

        # call make_create_env
        create_env, env_name = make_create_env(params=flow_params,
                                               version=version)

        # the registered name must be "<env_name>-v<version>"
        self.assertEqual(
            env_name, '{}-v{}'.format(flow_params["env_name"], version))

        # create the gym environment
        # NOTE(review): nothing in this test shuts the created environment
        # down afterwards -- confirm whether an explicit teardown is needed.
        env = create_env()
        inner_env = env.env
        scenario = inner_env.scenario

        # Note that we expect the port number in sumo_params to change, and
        # that this feature is in fact needed to avoid race conditions, so the
        # expected value is synchronized with the actual one before comparing.
        flow_params["sumo"].port = inner_env.sumo_params.port

        # check that each of the parameter objects matches what was passed in
        self.assertEqual(inner_env.env_params.__dict__,
                         flow_params["env"].__dict__)
        self.assertEqual(inner_env.sumo_params.__dict__,
                         flow_params["sumo"].__dict__)
        self.assertEqual(inner_env.traffic_lights.__dict__,
                         flow_params["tls"].__dict__)
        self.assertEqual(scenario.net_params.__dict__,
                         flow_params["net"].__dict__)
        self.assertEqual(scenario.initial_config.__dict__,
                         flow_params["initial"].__dict__)

        # check that the classes named in flow_params were the ones used
        self.assertEqual(inner_env.__class__.__name__,
                         flow_params["env_name"])
        self.assertEqual(scenario.__class__.__name__,
                         flow_params["scenario"])
        self.assertEqual(scenario.generator_class.__name__,
                         flow_params["generator"])
Example #2
0
    def test_encoder_and_get_flow_params(self):
        """Tests both FlowParamsEncoder and get_flow_params.

        FlowParamsEncoder is used to serialize the data from a flow_params dict
        for replay by the visualizer later. Then, the get_flow_params method is
        used to try and read the parameters from the config, and the result is
        checked to match the original flow_params.
        """
        # use a flow_params dict derived from flow/benchmarks/merge0.py
        vehicles = Vehicles()
        vehicles.add(veh_id="human",
                     acceleration_controller=(IDMController, {}),
                     speed_mode="no_collide",
                     num_vehicles=5)
        vehicles.add(veh_id="rl",
                     acceleration_controller=(RLController, {}),
                     speed_mode="no_collide",
                     num_vehicles=0)

        inflow = InFlows()
        inflow.add(veh_type="human",
                   edge="inflow_highway",
                   vehs_per_hour=1800,
                   departLane="free",
                   departSpeed=10)
        inflow.add(veh_type="rl",
                   edge="inflow_highway",
                   vehs_per_hour=200,
                   departLane="free",
                   departSpeed=10)
        inflow.add(veh_type="human",
                   edge="inflow_merge",
                   vehs_per_hour=100,
                   departLane="free",
                   departSpeed=7.5)

        flow_params = dict(
            exp_tag="merge_0",
            env_name="WaveAttenuationMergePOEnv",
            scenario="MergeScenario",
            generator="MergeGenerator",
            sumo=SumoParams(
                restart_instance=True,
                sim_step=0.5,
                sumo_binary="sumo",
            ),
            env=EnvParams(
                horizon=750,
                sims_per_step=2,
                warmup_steps=0,
                additional_params={
                    "max_accel": 1.5,
                    "max_decel": 1.5,
                    "target_velocity": 20,
                    "num_rl": 5,
                },
            ),
            net=NetParams(
                in_flows=inflow,
                no_internal_links=False,
                additional_params={
                    "merge_length": 100,
                    "pre_merge_length": 500,
                    "post_merge_length": 100,
                    "merge_lanes": 1,
                    "highway_lanes": 1,
                    "speed_limit": 30,
                },
            ),
            veh=vehicles,
            initial=InitialConfig(),
            tls=TrafficLights(),
        )

        # create a config dict with space for the flow_params dict
        config = {"env_config": {}}

        # save the flow params for replay
        flow_json = json.dumps(flow_params,
                               cls=FlowParamsEncoder,
                               sort_keys=True,
                               indent=4)
        config['env_config']['flow_params'] = flow_json

        json_out_file = os.path.expanduser('params.json')
        try:
            # dump the config so we can fetch it
            with open(json_out_file, 'w+') as outfile:
                json.dump(config,
                          outfile,
                          cls=FlowParamsEncoder,
                          sort_keys=True,
                          indent=4)

            # fetch values using utility function `get_flow_params`
            imported_flow_params = get_flow_params(config)
        finally:
            # delete the created file even if dumping or importing failed, so
            # a broken run does not leave params.json behind
            if os.path.exists(json_out_file):
                os.remove(json_out_file)

        # test that the inflows are correct
        self.assertEqual(imported_flow_params["net"].in_flows.__dict__,
                         flow_params["net"].in_flows.__dict__)

        # the inflows were compared above via their __dict__; clear them on
        # both sides before comparing the rest of the net params
        imported_flow_params["net"].in_flows = None
        flow_params["net"].in_flows = None

        # make sure the rest of the imported flow_params match the originals
        for param in ("env", "initial", "tls", "sumo", "net"):
            self.assertEqual(imported_flow_params[param].__dict__,
                             flow_params[param].__dict__)

        for param in ("exp_tag", "env_name", "scenario", "generator"):
            self.assertEqual(imported_flow_params[param], flow_params[param])

        def search_dicts(obj1, obj2):
            """Recursively compare obj1 against obj2, key by key.

            Walks every key of ``obj1``, descending into nested dictionaries
            and into lists of dictionaries. Returns False as soon as a
            mismatch is found and True otherwise.

            Values that compare unequal are still accepted when the value
            from ``obj1`` is an instance of the type of the value from
            ``obj2`` (presumably so that objects without value equality can
            pass -- confirm). Keys that exist only in ``obj2`` are not
            inspected.
            """
            for key in obj1:
                # a key missing from the second dict is a mismatch
                if key not in obj2:
                    return False
                val1, val2 = obj1[key], obj2[key]

                # if the element is a non-empty list, either compare the two
                # lists, or if the list contains dictionaries, look at each
                # dictionary recursively to check for mismatches
                if isinstance(val1, list) and val1:
                    if isinstance(val1[0], dict):
                        for i, sub_dict in enumerate(val1):
                            if not search_dicts(sub_dict, val2[i]):
                                return False
                    elif val1 != val2:
                        return False

                # if the element is a dict, run through it recursively to
                # determine if the separate elements of the dict match
                if isinstance(val1, (dict, collections.OrderedDict)):
                    if not search_dicts(val1, val2):
                        return False
                # otherwise, unequal elements are only a mismatch when they
                # are not of the same type
                elif val1 != val2:
                    if not isinstance(val1, type(val2)):
                        return False
            return True

        # make sure that the Vehicles class that was imported matches the
        # original one
        self.assertTrue(
            search_dicts(imported_flow_params["veh"].__dict__,
                         flow_params["veh"].__dict__),
            "imported vehicles do not match the original vehicles")