Example #1
def setup_eval(config, stats_path, seed=0):
    env_fun = my_utils.import_env(config["env_name"])
    config_tmp = deepcopy(config)
    config_tmp["seed"] = seed
    # Build the vectorized env from the seeded copy of the config
    env = DummyVecEnv(
        [lambda: env_fun(config_tmp) for _ in range(config["n_envs"])])
    return env
def main():
    args = parse_args()
    algo_config = my_utils.read_config(args["algo_config"])
    env_config = my_utils.read_config(args["env_config"])
    config = {**args, **algo_config, **env_config}

    print(config)

    for s in ["agents", "agents_cp", "tb"]:
        if not os.path.exists(s):
            os.makedirs(s)

    # Random ID of this session
    if config["default_session_ID"] is None:
        config["session_ID"] = ''.join(
            random.choices(string.ascii_uppercase + string.digits, k=3))
    else:
        config["session_ID"] = "TST"

    # Import correct env by name
    env_fun = my_utils.import_env(config["env_name"])
    env = env_fun(config)

    policy = my_utils.make_policy(env, config)

    if config["log_tb_all"]:
        tb_writer = SummaryWriter(f'tb/{config["session_ID"]}')
        config["tb_writer"] = tb_writer
    else:
        config["tb_writer"] = None

    maml_rl_trainer = MAMLRLTrainer(env, policy, config)

    if config["train"] or socket.gethostname() == "goedel":
        t1 = time.time()
        maml_rl_trainer.meta_train()
        t2 = time.time()

        print("Training time: {}".format(t2 - t1))
        print(config)

    if config["test"] and socket.gethostname() != "goedel":
        if not args["train"]:
            policy.load_state_dict(T.load(config["test_agent_path"]))
        maml_rl_trainer.test(env, policy)
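
The example above calls a parse_args helper that is not shown. A minimal sketch of what it might look like, assuming argparse and a plain-dict return value; the flags simply mirror the keys that main() reads (algo_config, env_config, train, test, default_session_ID, test_agent_path):

import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="Train or test a MAML-RL agent")
    parser.add_argument("--algo_config", type=str, required=True,
                        help="Path to the algorithm config file")
    parser.add_argument("--env_config", type=str, required=True,
                        help="Path to the environment config file")
    parser.add_argument("--train", action="store_true")
    parser.add_argument("--test", action="store_true")
    parser.add_argument("--default_session_ID", type=str, default=None)
    parser.add_argument("--test_agent_path", type=str, default=None)
    # Return a plain dict so it can be merged into the combined config
    return vars(parser.parse_args())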
Example #3
def setup_train_partial(config, setup_dirs=True):
    T.set_num_threads(1)
    if setup_dirs:
        for s in ["agents", "agents_cp", "tb"]:
            if not os.path.exists(s):
                os.makedirs(s)

    # Random ID of this session
    if config["default_session_ID"] is None:
        config["session_ID"] = ''.join(
            random.choices('ABCDEFGHJKLMNPQRSTUVWXYZ', k=3))
    else:
        config["session_ID"] = config["default_session_ID"]

    # Import correct env by name
    env_fun = my_utils.import_env(config["env_name"])
    env = env_fun(config)

    return env
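
A minimal sketch of how this partial setup might be consumed; the make_policy call mirrors Examples #1 and #5 and is an assumption here:

env = setup_train_partial(config)
policy = my_utils.make_policy(env, config)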
Example #4
def setup_train(config, setup_dirs=True):
    T.set_num_threads(1)
    if setup_dirs:
        for s in ["agents", "agents_cp", "tb"]:
            if not os.path.exists(s):
                os.makedirs(s)

    # Random ID of this session
    if config["default_session_ID"] is None:
        config["session_ID"] = ''.join(
            random.choices('ABCDEFGHJKLMNPQRSTUVWXYZ', k=3))
    else:
        config["session_ID"] = config["default_session_ID"]

    stats_path = "agents/{}_vecnorm.pkl".format(config["session_ID"])

    # Import correct env by name
    env_fun = my_utils.import_env(config["env_name"])
    env = env_fun(config)
    model = make_model(config, env)

    checkpoint_callback = CheckpointCallback(save_freq=100000,
                                             save_path='agents_cp/',
                                             name_prefix=config["session_ID"],
                                             verbose=1)

    # Separate evaluation env
    config_eval = deepcopy(config)
    config_eval["animate"] = False
    eval_env = env_fun(config_eval)
    # Use deterministic actions for evaluation
    eval_callback = EvalCallback(eval_env,
                                 eval_freq=10000,
                                 deterministic=True,
                                 render=False)
    callback_list = CallbackList([checkpoint_callback, eval_callback])

    return env, model, callback_list, stats_path
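
A minimal sketch of wiring the returned objects together with stable-baselines3; the "iters" config key and the save path are placeholders, not part of the snippet:

env, model, callback_list, stats_path = setup_train(config)
model.learn(total_timesteps=config["iters"], callback=callback_list)
model.save("agents/{}_SB_policy".format(config["session_ID"]))
env.close()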
Example #5
def objective(trial, config):
    # Hexapod
    config["std"] = trial.suggest_uniform('std', 0.1, 1.6)

    for s in ["agents", "agents_cp", "tb"]:
        if not os.path.exists(s):
            os.makedirs(s)

        # Random ID of this session
        if config["default_session_ID"] is None:
            config["session_ID"] = ''.join(
                random.choices('ABCDEFGHJKLMNPQRSTUVWXYZ', k=3))
        else:
            config["session_ID"] = "TST"

    # Import correct env by name
    env_fun = my_utils.import_env(config["env_name"])
    env = env_fun(config)

    policy = my_utils.make_policy(env, config)

    if config["algo"] == "cma":
        train(env, policy, config)
    elif config["algo"] == "optuna":
        train_optuna(env, policy, config)
    else:
        print("Algorithm not implemented")
        exit()

    avg_episode_rew = test_agent(env,
                                 policy,
                                 config["N_test"],
                                 print_rew=False)

    env.close()
    del env

    return avg_episode_rew
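
A minimal sketch of handing this objective to Optuna; the study direction and trial count are assumptions:

import optuna

study = optuna.create_study(direction="maximize")
study.optimize(lambda trial: objective(trial, config), n_trials=50)
print(study.best_params)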
Example #6
def setup_train(config, setup_dirs=True):
    if setup_dirs:
        for s in ["agents", "agents_cp", "tb"]:
            if not os.path.exists(s):
                os.makedirs(s)

    # Random ID of this session
    if config["default_session_ID"] is None:
        config["session_ID"] = ''.join(
            random.choices('ABCDEFGHJKLMNPQRSTUVWXYZ', k=3))
    else:
        config["session_ID"] = config["default_session_ID"]

    stats_path = "agents/{}_vecnorm.pkl".format(config["session_ID"])

    # Import correct env by name
    env_fun = my_utils.import_env(config["env_name"])
    if config["dummy_vec_env"]:
        vec_env = DummyVecEnv(
            [lambda: env_fun(config) for _ in range(config["n_envs"])])
    else:
        vec_env = SubprocVecEnv(
            [lambda: env_fun(config) for _ in range(config["n_envs"])],
            start_method='fork')
    env = VecNormalize(vec_env,
                       gamma=config["gamma"],
                       norm_obs=config["norm_obs"],
                       norm_reward=config["norm_reward"])
    model = make_model(config, env, None)

    checkpoint_callback = CheckpointCallback(save_freq=300000,
                                             save_path='agents_cp/',
                                             name_prefix=config["session_ID"],
                                             verbose=1)

    return env, model, checkpoint_callback, stats_path
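
A minimal sketch of the training side that consumes the returned stats_path, so the VecNormalize statistics can be reused at evaluation time; the "iters" config key is a placeholder:

env, model, checkpoint_callback, stats_path = setup_train(config)
model.learn(total_timesteps=config["iters"], callback=checkpoint_callback)
model.save("agents/{}_SB_policy".format(config["session_ID"]))
env.save(stats_path)  # VecNormalize.save pickles the running obs/reward statistics
env.close()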
Example #7
def setup_eval(config, stats_path, seed=0):
    T.set_num_threads(1)
    env_fun = my_utils.import_env(config["env_name"])
    config_tmp = deepcopy(config)
    config_tmp["seed"] = seed
    return env_fun(config_tmp)
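
This setup_eval returns a single raw environment. If the agent was trained behind VecNormalize as in Example #6, the saved statistics can be restored around it; setup_eval_normalized is a hypothetical helper, not part of the snippet:

from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

def setup_eval_normalized(config, stats_path, seed=0):
    raw_env = setup_eval(config, stats_path, seed=seed)
    venv = VecNormalize.load(stats_path, DummyVecEnv([lambda: raw_env]))
    venv.training = False      # freeze running statistics during evaluation
    venv.norm_reward = False   # report unnormalized rewards
    return venv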
Example #8
                    log_interval=1)
        t2 = time.time()

        # Make tb run script inside tb dir
        if os.path.exists(os.path.join("tb", config["session_ID"])):
            copyfile("tb_runner.py",
                     os.path.join("tb", config["session_ID"], "tb_runner.py"))

        print("Training time: {}".format(t2 - t1))
        pprint(config)

        model.save("agents/{}_SB_policy".format(config["session_ID"]))
        env.close()

    if args["test"] and socket.gethostname() != "goedel":
        env_fun = my_utils.import_env(config["env_name"])
        config["seed"] = 1337

        env = env_fun(config)
        env.training = False
        env.norm_reward = False

        model = TD3.load("agents/{}".format(args["test_agent_path"]))
        # Normal testing
        N_test = 20
        total_rew = test_agent(env, model, deterministic=False, N=N_test)
        #total_rew = test_agent_mirrored(env, model, deterministic=False, N=N_test, perm=[-1, 0, 3, 2])
        print(f"Total test rew: {total_rew / N_test}")

        # Testing for permutation
        # N_test = 10
    def setup_test(self):
        env_fun = my_utils.import_env(self.config["env_name"])
        env = DummyVecEnv([lambda: env_fun(self.config)])
        policy = my_utils.make_par_policy(env, self.config)
        policy.load_state_dict(T.load(self.config["test_agent_path"]))
        return env, policy
    def setup_train(self, setup_dirs=True):
        if setup_dirs:
            for s in ["agents", "agents_cp", "tb"]:
                if not os.path.exists(s):
                    os.makedirs(s)

        # Random ID of this session
        if self.config["default_session_ID"] is None:
            self.config["session_ID"] = ''.join(
                random.choices('ABCDEFGHJKLMNPQRSTUVWXYZ', k=3))
        else:
            self.config["session_ID"] = "TST"

        # Import correct env by name
        self.env_fun = my_utils.import_env(self.config["env_name"])
        vec_env = SubprocVecEnv(
            [lambda: self.env_fun(self.config)
             for _ in range(self.config["n_envs"])],
            start_method='fork')
        self.env = VecNormalize(vec_env,
                                gamma=self.config["gamma"],
                                norm_obs=self.config["norm_obs"],
                                norm_reward=self.config["norm_reward"])

        self.policy = my_utils.make_par_policy(self.env, self.config)

        self.config["tb_writer"] = None
        if self.config["log_tb"] and setup_dirs:
            tb_writer = SummaryWriter(f'tb/{self.config["session_ID"]}')
            self.config["tb_writer"] = tb_writer

        self.config["sdir"] = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            f'agents/{self.config["session_ID"]}_AC_policy.p')

        self.policy_optim = None

        if self.config["policy_optim"] == "rmsprop":
            self.policy_optim = T.optim.RMSprop(
                self.policy.parameters(),
                lr=self.config["policy_learning_rate"],
                weight_decay=self.config["weight_decay"],
                eps=1e-8,
                momentum=self.config["momentum"])

        if self.config["policy_optim"] == "sgd":
            self.policy_optim = T.optim.SGD(
                self.policy.parameters(),
                lr=self.config["policy_learning_rate"],
                weight_decay=self.config["weight_decay"],
                momentum=self.config["momentum"])
        if self.config["policy_optim"] == "adam":
            self.policy_optim = T.optim.Adam(
                self.policy.parameters(),
                lr=self.config["policy_learning_rate"],
                weight_decay=self.config["weight_decay"])
        assert self.policy_optim is not None

        self.replay_buffer = ReplayBuffer()

        self.global_step_ctr = 0

        return self.env, self.policy, self.policy_optim, self.replay_buffer