def run_single_run(iw_parameters, iw_variants, run_num, Domain, sim_dt, sim_budget, horizon, numberRollouts, runNum, seed):
    """Run one episode of a UCT lookahead agent and pickle its total reward.

    The driver invokes this as ``run_single_run.remote(...)`` (a Ray task), so
    every dependency is imported inside the function body.

    Args:
        iw_parameters: Sequence of keyword-argument dicts; entry ``run_num``
            is unpacked into ``wizluk.policies.UCT``.
        iw_variants: Dict whose ``'Name'`` list gives a name per entry of
            ``iw_parameters``; used as the output filename prefix.
        run_num: Index into ``iw_parameters`` and ``iw_variants['Name']``.
        Domain: Gym environment id passed to ``gym.make``.
        sim_dt, sim_budget, numberRollouts: Only used to build the output
            filename in this function.
        horizon: Maximum number of environment steps in the episode (also
            part of the output filename).
        runNum: Repetition index, used only in the output filename.
        seed: Seed applied to NumPy, Python ``random`` and the environment.

    Returns:
        The summed episode reward. The same value is also pickled to
        ``../results/<name>_<Domain>_simdt_<sim_dt>_simBud_<sim_budget>_Horizon_<horizon>_numRoll_<numberRollouts>_runNum_<runNum>.dat``.
    """
    import copy
    import os
    import pickle
    import random
    import sys
    import traceback
    import warnings

    import gym
    import numpy as np

    import wizluk
    import wizluk.envs  # presumably registers wizluk's Gym environments for gym.make -- confirm
    import wizluk.policies

    def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
        # Drop-in replacement for warnings.showwarning that also prints the
        # call stack, so the origin of each warning can be located.
        log = file if hasattr(file, 'write') else sys.stderr
        traceback.print_stack(file=log)
        log.write(warnings.formatwarning(message, category, filename, lineno, line))

    warnings.showwarning = warn_with_traceback
    wizluk.setup_logger("iw_gridworld_v1.log")

    env = gym.make(Domain)
    try:
        # Seed every RNG source before the planner is built so runs are reproducible.
        np.random.seed(seed)
        env.seed(seed)
        random.seed(seed)

        IW_Rollout = wizluk.policies.UCT(**(iw_parameters[run_num]))
        # NOTE(review): `domain` is hard-coded rather than taken from `Domain`;
        # confirm this is intended for non-GridWorld environments.
        IW_Rollout_agent = wizluk.agents.LookaheadAgent(env, IW_Rollout, name='IW_Rollout', domain='GridWorld-16x16-v1')
        IW_Rollout_df = {}
        IW_Rollout_agent.init_evaluation_statistics(IW_Rollout_df)
        # Observations are flattened to a (1, S) row before being handed to
        # observe_transition.
        S = np.prod(env.observation_space.shape)

        x = env.reset()
        x0 = copy.deepcopy(x)  # keep the initial state for the statistics call below
        IW_Rollout_agent.start_episode()
        score = 0.0
        for s in range(horizon):
            wizluk.logger.debug("action number: {}".format(s))
            x_flat = np.reshape(x, [1, S])
            u = IW_Rollout_agent.get_action(x)
            x_next, reward, done, _info = env.step(u)
            x_next_flat = np.reshape(x_next, [1, S])
            IW_Rollout_agent.observe_transition(x_flat, u, reward, x_next_flat, done, False)
            x = x_next
            score += reward
            if done:
                break
        IW_Rollout_agent.stop_episode()
        IW_Rollout_agent.collect_evaluation_statistics(IW_Rollout_df, x0)
    finally:
        # Release the environment even if planning or stepping raised.
        env.close()

    # Many Ray workers may race to create this directory; exist_ok makes that safe.
    results_dir = os.path.join('..', 'results')
    os.makedirs(results_dir, exist_ok=True)
    filename = '{}_{}_simdt_{}_simBud_{}_Horizon_{}_numRoll_{}_runNum_{}.dat'.format(
        iw_variants['Name'][run_num], Domain, sim_dt, sim_budget, horizon, numberRollouts, runNum)
    with open(os.path.join(results_dir, filename), 'wb') as output:
        pickle.dump(score, output, pickle.HIGHEST_PROTOCOL)
    return score
# Example #2
# 0
def run_experiment(Domain, sim_dt, sim_budget, horizon, numberRollouts, N, seeds, seed, includeHeur, includeKuth):
    """Launch ``N`` seeded Ray runs of ``run_single_run`` per enabled planner variant.

    Args:
        Domain: Gym environment id forwarded to every run.
        sim_dt: Planner ``budget`` (also part of each run's output filename).
        sim_budget: Planner ``sim_budget``.
        horizon: Planner ``horizon`` and maximum episode length.
        numberRollouts: ``num_rollouts`` for the Knuth variant.
        N: Number of repetitions per variant.
        seeds: Indexable collection of seeds. ``seeds[seed]`` seeds Python's
            global ``random`` module. Repetition ``k`` of a variant then uses
            ``seeds[start + k * skip]``, with ``start`` drawn from [0, 50000] and
            ``skip`` from [1, 201], so ``seeds`` must be large enough for
            those indices.
        seed: Index into ``seeds`` selecting the driver seed.
        includeHeur: If true, run the ``"OneStep_heur"`` variant.
        includeKuth: If true, run the ``"OneStep_knuth"`` variant.

    Returns:
        0 once every submitted run has completed.

    Raises:
        IndexError: (for list-like ``seeds``) if a derived seed index is out of
            range. This is raised before any Ray task is submitted.
    """
    import gc
    import random
    import sys
    import traceback
    import warnings

    import ray
    import wizluk

    def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
        # Drop-in replacement for warnings.showwarning that also prints the
        # call stack, so the origin of each warning can be located.
        log = file if hasattr(file, 'write') else sys.stderr
        traceback.print_stack(file=log)
        log.write(warnings.formatwarning(message, category, filename, lineno, line))

    warnings.showwarning = warn_with_traceback
    wizluk.setup_logger("iw_gridworld_v1.log")

    # Settings shared by every variant; each variant adds its own cost-to-go estimator.
    common_parameters = {
        "budget": sim_dt,
        "sim_budget": sim_budget,
        "horizon": horizon,
        "atari": "True",
        "caching": "Partial",
    }
    knuth_parameters = dict(common_parameters, cost_to_go_est="knuth", num_rollouts=numberRollouts)
    heur_parameters = dict(common_parameters, cost_to_go_est="heuristic")

    # Order matters: run_num indexes both lists, and knuth comes before heur.
    enabled = []
    if includeKuth:
        enabled.append(("OneStep_knuth", knuth_parameters))
    if includeHeur:
        enabled.append(("OneStep_heur", heur_parameters))
    iw_variants = {'Name': [name for name, _ in enabled]}
    iw_parameters = [params for _, params in enabled]

    # Derive every run's seed before launching anything, so a bad index fails
    # fast instead of leaving already-submitted Ray tasks running. The randint
    # calls happen in the same order as before, so the seeds are unchanged.
    random.seed(seeds[seed])
    planned_runs = []
    for run_num in range(len(iw_parameters)):
        seed_start = random.randint(0, 50000)
        seed_skip = random.randint(1, 201)
        planned_runs.extend(
            (run_num, k, seeds[seed_start + k * seed_skip]) for k in range(N))

    listOfRuns = [
        run_single_run.remote(iw_parameters, iw_variants, run_num, Domain, sim_dt, sim_budget,
                              horizon, numberRollouts, k, run_seed)
        for run_num, k, run_seed in planned_runs
    ]
    ray.get(listOfRuns)  # block until all runs finish; they write their own results
    gc.collect()
    return 0