def test_solver_cpp(solver_cpp, parallel, shared_memory):
    """Solve ``GridDomain`` with a registered (C++-backed) solver and check the plan.

    ``parallel`` and ``shared_memory_proxy`` are only passed to solver classes
    whose ``__init__`` signature declares those parameters.

    Args:
        solver_cpp: solver description dict with keys ``"entry"`` (registered
            solver name passed to ``load_registered_solver``), ``"config"``
            (constructor keyword arguments) and ``"optimal"`` (when true and
            ``parallel`` is false, the plan must have cost 18 and 18 steps).
        parallel: forwarded as the solver's ``parallel`` argument when supported.
        shared_memory: when true and supported, a ``GridShmProxy`` instance is
            passed as ``shared_memory_proxy``.
    """
    noexcept = True
    if solver_cpp["entry"] == "UCT":
        pytest.skip("There is a heap corruption in MCTS solver")
    try:
        dom = GridDomain()
        solver_type = load_registered_solver(solver_cpp["entry"])
        # Copy so the config dict shared between parametrized cases is not mutated.
        solver_args = deepcopy(solver_cpp["config"])
        init_params = inspect.signature(solver_type.__init__).parameters
        if "parallel" in init_params:
            solver_args["parallel"] = parallel
        if "shared_memory_proxy" in init_params and shared_memory:
            solver_args["shared_memory_proxy"] = GridShmProxy()
        solver_args["domain_factory"] = lambda: GridDomain()
        with solver_type(**solver_args) as slv:
            GridDomain.solve_with(slv)
            plan, cost = get_plan(dom, slv)
    except Exception as e:
        print(e)
        noexcept = False
    # Check ``noexcept`` first: if an exception fired before ``dom`` or
    # ``solver_type`` was bound, touching them would raise NameError and hide
    # the real failure. Plan optimality is not checked for parallel runs.
    assert (
        noexcept
        and solver_type.check_domain(dom)
        and (
            (not solver_cpp["optimal"])
            or parallel
            or (cost == 18 and len(plan) == 18)
        )
    )
def do_test_python(solver_python, result):
    """Solve ``GridDomain`` with a registered solver and report success via ``result``.

    Args:
        solver_python: solver description dict with keys ``'entry'``
            (registered solver name), ``'config'`` (constructor keyword
            arguments) and ``'optimal'`` (when true, the plan must have cost 18
            and 18 steps).
        result: object with ``send``/``close`` methods — presumably the sending
            end of a ``multiprocessing.Pipe`` (TODO confirm with caller). The
            success flag is sent once and ``result`` is always closed.
    """
    noexcept = True
    try:
        dom = GridDomain()
        solver_type = load_registered_solver(solver_python['entry'])
        solver_args = solver_python['config']
        with solver_type(**solver_args) as slv:
            GridDomain.solve_with(slv)
            plan, cost = get_plan(dom, slv)
    except Exception as e:
        print(e)
        noexcept = False
    try:
        # Check ``noexcept`` first so an early failure (``dom``/``solver_type``
        # never bound) yields a falsy verdict instead of a NameError that
        # would skip ``send`` entirely.
        result.send(
            noexcept
            and solver_type.check_domain(dom)
            and (
                (not solver_python['optimal'])
                or (cost == 18 and len(plan) == 18)
            )
        )
    finally:
        # Always release the connection, even if computing the verdict raised.
        result.close()
def test_solve_python(solver_python):
    """Solve ``GridDomain`` with a registered Python solver and check the plan.

    Args:
        solver_python: solver description dict with keys ``"entry"``
            (registered solver name), ``"config"`` (constructor keyword
            arguments) and ``"optimal"`` (when true, the plan must have cost 18
            and 18 steps). For the ``"StableBaseline"`` entry, ``algo_class``
            is set to ``PPO``.
    """
    noexcept = True
    try:
        dom = GridDomain()
        solver_type = load_registered_solver(solver_python["entry"])
        # Copy so the config dict shared between parametrized cases is not mutated.
        solver_args = deepcopy(solver_python["config"])
        if solver_python["entry"] == "StableBaseline":
            solver_args["algo_class"] = PPO
        with solver_type(**solver_args) as slv:
            GridDomain.solve_with(slv)
            plan, cost = get_plan(dom, slv)
    except Exception as e:
        print(e)
        noexcept = False
    # Check ``noexcept`` first: if an exception fired before ``dom`` or
    # ``solver_type`` was bound, touching them would raise NameError and hide
    # the real failure.
    assert (
        noexcept
        and solver_type.check_domain(dom)
        and ((not solver_python["optimal"]) or (cost == 18 and len(plan) == 18))
    )
def do_test_cpp(solver_cpp, parallel, shared_memory, result):
    """Solve ``GridDomain`` with a registered solver and report success via ``result``.

    ``parallel`` and ``shared_memory_proxy`` are only passed to solver classes
    whose ``__init__`` signature declares those parameters.

    Args:
        solver_cpp: solver description dict with keys ``'entry'`` (registered
            solver name), ``'config'`` (constructor keyword arguments; left
            unmodified) and ``'optimal'`` (when true and ``parallel`` is false,
            the plan must have cost 18 and 18 steps).
        parallel: forwarded as the solver's ``parallel`` argument when supported.
        shared_memory: when true and supported, a ``GridShmProxy`` instance is
            passed as ``shared_memory_proxy``.
        result: object with ``send``/``close`` methods — presumably the sending
            end of a ``multiprocessing.Pipe`` (TODO confirm with caller). The
            success flag is sent once and ``result`` is always closed.
    """
    noexcept = True
    try:
        dom = GridDomain()
        solver_type = load_registered_solver(solver_cpp['entry'])
        # Shallow copy: only top-level keys are added below, and writing them
        # into the caller's dict would leak e.g. 'shared_memory_proxy' into
        # later runs that share the same config.
        solver_args = dict(solver_cpp['config'])
        init_params = inspect.signature(solver_type.__init__).parameters
        if 'parallel' in init_params:
            solver_args['parallel'] = parallel
        if 'shared_memory_proxy' in init_params and shared_memory:
            solver_args['shared_memory_proxy'] = GridShmProxy()
        solver_args['domain_factory'] = lambda: GridDomain()
        with solver_type(**solver_args) as slv:
            GridDomain.solve_with(slv)
            plan, cost = get_plan(dom, slv)
    except Exception as e:
        print(e)
        noexcept = False
    try:
        # Check ``noexcept`` first so an early failure (``dom``/``solver_type``
        # never bound) yields a falsy verdict instead of a NameError that
        # would skip ``send`` entirely. Optimality is not checked in parallel.
        result.send(
            noexcept
            and solver_type.check_domain(dom)
            and (
                (not solver_cpp['optimal'])
                or parallel
                or (cost == 18 and len(plan) == 18)
            )
        )
    finally:
        # Always release the connection, even if computing the verdict raised.
        result.close()
{ 'name': 'PPO (deep reinforcement learning)', 'entry': 'StableBaseline', 'config': { 'algo_class': PPO, 'baselines_policy': 'MlpPolicy', 'learn_config': { 'total_timesteps': 30000 }, 'verbose': 1 } } ] # Load solvers (filtering out badly installed ones) solvers = map(lambda s: dict(s, entry=load_registered_solver(s['entry'])), try_solvers) solvers = list(filter(lambda s: s['entry'] is not None, solvers)) solvers.insert(0, { 'name': 'Random Walk', 'entry': None }) # Add Random Walk as option # Run loop to ask user input domain = Maze() while True: # Ask user input to select solver choice = int( input('\nChoose a solver:\n{solvers}\n'.format(solvers='\n'.join( ['0. Quit'] + [f'{i + 1}. {s["name"]}' for i, s in enumerate(solvers)]))))
Value( cost=sqrt( (d.num_cols - 1 - s.x) ** 2 + (d.num_rows - 1 - s.y) ** 2 ) ), 10000, ), "parallel": True, "debug_logs": False, }, }, ] # Load solvers (filtering out badly installed ones) solvers = map( lambda s: dict(s, entry=load_registered_solver(s["entry"])), try_solvers ) solvers = list(filter(lambda s: s["entry"] is not None, solvers)) # Run loop to ask user input domain = MyDomain() # MyDomain(5,5) with tqdm(total=len(solvers) * 100) as pbar: for s in solvers: solver_type = s["entry"] for i in range(50): s["config"]["shared_memory_proxy"] = None with solver_type(**s["config"]) as solver: MyDomain.solve_with(solver) # ,lambda:MyDomain(5,5)) rollout( domain,
# Copyright (c) AIRBUS and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from pprint import pprint

from skdecide.utils import (
    get_registered_domains,
    get_registered_solvers,
    load_registered_domain,
    load_registered_solver,
)

if __name__ == "__main__":
    # Print every registered domain, then every registered solver, each as a
    # dict mapping the registered name to what the matching
    # ``load_registered_*`` function returns for it.
    for header, list_names, load in (
        (
            "\nAll registered domains:\n-----------------------",
            get_registered_domains,
            load_registered_domain,
        ),
        (
            "\nAll registered solvers:\n-----------------------",
            get_registered_solvers,
            load_registered_solver,
        ),
    ):
        print(header)
        registry = {}
        for name in list_names():
            registry[name] = load(name)
        pprint(registry)
# BFWS (classical planning) {'name': 'BFWS (planning) - (num_rows * num_cols) binary encoding (1 binary variable <=> 1 cell)', 'entry': 'BFWS', 'need_domain_factory': True, 'config': {'state_features': lambda d, s: d.state_features(s), 'heuristic': lambda d, s: d.heuristic(s), 'termination_checker': lambda d, s: d.is_goal(s), 'parallel': False, 'debug_logs': False}} ] # Load domains (filtering out badly installed ones) domains = map(lambda d: dict(d, entry=load_registered_domain(d['entry'])), try_domains) domains = list(filter(lambda d: d['entry'] is not None, domains)) # Load solvers (filtering out badly installed ones) solvers = map(lambda s: dict(s, entry=load_registered_solver(s['entry'])), try_solvers) solvers = list(filter(lambda s: s['entry'] is not None, solvers)) solvers.insert(0, {'name': 'Random Walk', 'entry': None}) # Add Random Walk as option # Run loop to ask user input solver_candidates = [s['entry'] for s in solvers if s['entry'] is not None] while True: # Ask user input to select domain domain_choice = int(input('\nChoose a domain:\n{domains}\n'.format( domains='\n'.join([f'{i + 1}. {d["name"]}' for i, d in enumerate(domains)])))) selected_domain = domains[domain_choice - 1] domain_type = selected_domain['entry'] domain = domain_type(**selected_domain['config']) while True: # Match solvers compatible with selected domain