def GetFeatureSpace(env):
    """Derive the observation feature space bounds from an environment.

    The environment is reset once to obtain a sample observation; its length
    determines how many vehicles a fresh ``NearestAgentsObserver`` is
    configured for. The observer is then reset on the environment's world and
    its x, y, theta and velocity ranges are repeated once per vehicle slot to
    form the bounds.

    Args:
        env (gym.environment): The environment from which the feature space
            is extracted. Note that it is reset as a side effect.

    Returns:
        FeatureSpace: The feature space with the low and high bounds set.
    """

    class FeatureSpace:
        """Minimal holder for the low and high bounds of the feature space."""

        def __init__(self, low, high):
            """Store the given low and high bounds."""
            self.low = low
            self.high = high

    sample_observation = env.reset()

    nearest_observer = NearestAgentsObserver()
    # Four features per vehicle; one of the slots is not counted in
    # _max_num_vehicles, hence the "- 1" (and the "+ 1" further below).
    nearest_observer._max_num_vehicles = len(sample_observation) // 4 - 1
    nearest_observer.Reset(env._world)

    ranges_per_vehicle = [
        nearest_observer._world_x_range,
        nearest_observer._world_y_range,
        nearest_observer._ThetaRange,
        nearest_observer._VelocityRange,
    ]
    vehicle_slots = nearest_observer._max_num_vehicles + 1

    # Presumably each range is a (low, high) pair, so after transposing row 0
    # holds every lower bound and row 1 every upper bound -- TODO confirm.
    bounds = np.array(ranges_per_vehicle * vehicle_slots).T
    return FeatureSpace(bounds[0], bounds[1])
def main():
    """Configure and launch one experiment run.

    Steps, in order:
      1. Parse the arguments and load/configure the parameter server.
      2. Build the behavior model, evaluator, observer and viewer.
      3. Point the agent's summary and checkpoint paths into the experiment
         directory and save the parameters.
      4. Serialize and release the scenario database, then open it as a
         benchmark database and fetch its first scenario generator.
      5. Create the environment, reset the observer on the first scenario's
         world state, and sanity-check the action space size.
      6. Run the experiment and save the parameters once more.
    """
    args = configure_args()

    # When not running locally, project files are looked up under this prefix.
    dir_prefix = "" if is_local else "hy-iqn-lfd-full-exp.runfiles/hythe/"

    logging.info(f"Executing job: {args.jobname}")
    logging.info(f"Experiment server at: {os.getcwd()}")

    params_path = os.path.join(dir_prefix, params_file)
    params = ParameterServer(filename=params_path, log_if_default=True)
    params = configure_params(params, seed=args.jobname)

    experiment_id = params["Experiment"]["random_seed"]
    params_filename = os.path.join(
        params["Experiment"]["dir"],
        "params_{}.json".format(experiment_id),
    )

    behavior = BehaviorDiscreteMacroActionsML(params)
    evaluator = GoalReached(params)
    observer = NearestAgentsObserver(params)
    viewer = MPViewer(
        params=params,
        x_range=[-35, 35],
        y_range=[-35, 35],
        follow_agent_id=True,
    )

    # Extract params and save the experiment parameters.
    summary_path = os.path.join(
        params["Experiment"]["dir"], "agent/summaries")
    params["ML"]["BaseAgent"]["SummaryPath"] = summary_path
    checkpoint_path = os.path.join(
        params["Experiment"]["dir"], "agent/checkpoints")
    params["ML"]["BaseAgent"]["CheckpointPath"] = checkpoint_path
    params.Save(filename=params_filename)

    separator = '-' * 60
    logging.info(separator)
    logging.info("Writing params to :{}".format(params_filename))
    logging.info(separator)

    # Database creation.
    serializer = DatabaseSerializer(
        test_scenarios=1,
        test_world_steps=2,
        num_serialize_scenarios=num_demo_scenarios,
    )
    database_dir = os.path.join(dir_prefix, "configuration/database")
    serializer.process(
        database_dir,
        filter_sets="**/**/interaction_merging_light_dense_1D.json",
    )
    release_path = serializer.release(version="lfd_offline")
    database = BenchmarkDatabase(database_root=release_path)
    scenario_generator, _, _ = database.get_scenario_generator(0)

    env = HyDiscreteHighway(
        params=params,
        scenario_generation=scenario_generator,
        behavior=behavior,
        evaluator=evaluator,
        observer=observer,
        viewer=viewer,
        render=False,
    )

    # Reset the observer on the world of the first generated scenario.
    scenario, _ = scenario_generator.get_next_scenario()
    observer.Reset(scenario.GetWorldState())

    assert env.action_space._n == 8, "Action Space is incorrect!"

    run(args, params, env, db=database)

    # Save the parameters again now that the run has finished.
    params.Save(params_filename)
    logging.info(separator)
    logging.info("Writing params to :{}".format(params_filename))
    logging.info(separator)