Example #1
# Assumed imports for this excerpt (stable-baselines v2); model_predict() is a
# helper defined alongside this test in the same module.
import os

import pytest

from stable_baselines import DDPG, HER, SAC
from stable_baselines.common.bit_flipping_env import BitFlippingEnv
from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines.her import HERGoalEnvWrapper

N_BITS = 10  # size of the bit-flipping environment used throughout this test


def test_model_manipulation(model_class, goal_selection_strategy):
    env = BitFlippingEnv(N_BITS,
                         continuous=model_class in [DDPG, SAC],
                         max_steps=N_BITS)
    env = DummyVecEnv([lambda: env])

    model = HER('MlpPolicy',
                env,
                model_class,
                n_sampled_goal=3,
                goal_selection_strategy=goal_selection_strategy,
                verbose=0)
    model.learn(1000)

    model_predict(model, env, n_steps=100, additional_check=None)

    model.save('./test_her')
    del model

    # NOTE: HER does not support VecEnvWrapper yet
    with pytest.raises(AssertionError):
        model = HER.load('./test_her', env=VecNormalize(env))

    model = HER.load('./test_her')

    # Check that the model raises an error when the env
    # is not wrapped (or when no env was passed to the model)
    with pytest.raises(ValueError):
        model.predict(env.reset())

    env_ = BitFlippingEnv(N_BITS,
                          continuous=model_class in [DDPG, SAC],
                          max_steps=N_BITS)
    env_ = HERGoalEnvWrapper(env_)

    model_predict(model, env_, n_steps=100, additional_check=None)

    model.set_env(env)
    model.learn(1000)

    model_predict(model, env_, n_steps=100, additional_check=None)

    assert model.n_sampled_goal == 3

    del model

    env = BitFlippingEnv(N_BITS,
                         continuous=model_class in [DDPG, SAC],
                         max_steps=N_BITS)
    model = HER.load('./test_her', env=env)
    model.learn(1000)

    model_predict(model, env_, n_steps=100, additional_check=None)

    assert model.n_sampled_goal == 3

    if os.path.isfile('./test_her.pkl'):
        os.remove('./test_her.pkl')
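The two arguments are normally injected by pytest; below is a minimal hand-rolled driver, assuming the imports above. The model list and strategy names are illustrative, not the original parametrization:

if __name__ == '__main__':
    # Hedged driver: @pytest.mark.parametrize would normally supply these arguments
    for strategy in ['future', 'final', 'episode', 'random']:
        for cls in [DDPG, SAC]:
            test_model_manipulation(cls, strategy)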
Example #2
def launchAgent():
    import Reinforcement_AI.env.d_image_env as image_env
    from stable_baselines import DQN, HER, DDPG, PPO2
    from stable_baselines.common import make_vec_env

    model_name = "PPO2"

    if model_name == "HER":
        model = HER(
            "CnnPolicy",
            env=image_env.DetailedMiniMapEnv(),
            model_class=DQN
        )
    elif model_name == "DDPG":
        model = DDPG(
            policy="CnnPolicy",
            env=image_env.DDPGImageEnv(),
            normalize_observations=True
        )
    elif model_name == "PPO2":
        # PPO2 expects a vectorized environment
        env = make_vec_env(image_env.DetailedMiniMapEnv, n_envs=1)
        model = PPO2(
            policy="CnnPolicy",
            env=env,
            verbose=1
        )
    else:
        model = DQN(
            "CnnPolicy",  # policy
            env=image_env.DetailedMiniMapEnv(),  # environment
            double_q=True,  # Double Q enable
            prioritized_replay=True,  # Replay buffer enabled
            verbose=0  # log print
        )

    for i in range(1000):
        if i != 0:
            if model_name == "HER":
                model = HER.load("detailedmap_HER_" + str(i))
                model.set_env(image_env.DetailedMiniMapEnv())
            elif model_name == "DDPG":
                model = DDPG.load("detailedmap_DDPG_" + str(i))
                model.set_env(image_env.DDPGImageEnv())
            elif model_name == "PPO2":
                # reuse the vectorized env created above
                model = PPO2.load("detailedmap_PPO2_" + str(i), env)
            else:
                model = DQN.load("detailedmap_DQN_" + str(i))
                model.set_env(image_env.DetailedMiniMapEnv())

        # Train one chunk, checkpoint, and reload the checkpoint next iteration
        model.learn(total_timesteps=3900)

        model.save("detailedmap_" + model_name + "_" + str(i + 1))
        del model
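launchAgent() assumes image_env.DetailedMiniMapEnv is a gym.Env producing image observations compatible with CnnPolicy. A minimal sketch of that interface, with placeholder shapes and rewards rather than the real environment:

import gym
import numpy as np
from gym import spaces

class MiniMapEnvSketch(gym.Env):
    """Placeholder for the interface DetailedMiniMapEnv must satisfy."""
    def __init__(self):
        # CnnPolicy expects HxWxC uint8 image observations
        self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 3), dtype=np.uint8)
        self.action_space = spaces.Discrete(4)

    def reset(self):
        return np.zeros(self.observation_space.shape, dtype=np.uint8)

    def step(self, action):
        obs = np.zeros(self.observation_space.shape, dtype=np.uint8)
        reward, done, info = 0.0, False, {}
        return obs, reward, done, info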
Example #3
# Assumed imports for this excerpt (stable-baselines v2). get_paths, get_env,
# get_alg, get_policy and create_training_callback are project-specific helpers
# defined elsewhere in the surrounding codebase.
import os

import numpy as np

from stable_baselines import logger
from stable_baselines.bench import Monitor
from stable_baselines.common import set_global_seeds
from stable_baselines.common.misc_util import mpi_rank_or_zero
from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize


def train(params, model=None, path=None):
    if model: # indicate in filename that this is a finetune
        if params['name']:
            params['name'] += '_Finetune'
        else:
            params['name'] = 'Finetune'
    
    data_dir, tb_path = get_paths(params, path=path)
    print("Training Parameters: ", params)
    os.makedirs(data_dir, exist_ok=True)
    # Save parameters immediately
    params.save(data_dir)

    rank = mpi_rank_or_zero()
    if rank != 0:
        logger.set_level(logger.DISABLED)
    
    def make_env(i):
        env = get_env(params)
        env = Monitor(env, data_dir + '/' + str(i), allow_early_resets=params['early_reset'])
        return env

    use_her = params['env_args'].get('use_her', False)

    if use_her:
        # HER (stable-baselines v2) trains on a single, non-vectorized env
        env = make_env(0)
        goal_selection_strategy = 'future'
    else:
        # Bind i by value so each worker captures its own index
        # (avoids the late-binding closure pitfall)
        env = DummyVecEnv([(lambda n: lambda: make_env(n))(i) for i in range(params['num_proc'])])

    if model:  # finetuning: sanity-check that the action spaces match
        print("Model action space", model.action_space, model.action_space.low)
        print("Env action space", env.action_space, env.action_space.low)
    if params['normalize']:
        env = VecNormalize(env)
    if params['seed']:
        seed = params['seed'] + 100000 * rank
        set_global_seeds(seed)
        params['alg_args']['seed'] = seed
    if 'noise' in params and params['noise']:
        from stable_baselines.ddpg import OrnsteinUhlenbeckActionNoise
        n_actions = env.action_space.shape[-1]
        params['alg_args']['action_noise'] = OrnsteinUhlenbeckActionNoise(
            mean=np.zeros(n_actions), sigma=float(params['noise']) * np.ones(n_actions))
    
    if model is None:
        alg = get_alg(params)
        policy = get_policy(params)
        if use_her:
            from stable_baselines import HER
            model = HER(policy, env, alg, n_sampled_goal=4,
                        goal_selection_strategy=goal_selection_strategy, verbose=1,
                        tensorboard_log=tb_path, policy_kwargs=params['policy_args'],
                        **params['alg_args'])
        else:
            model = alg(policy, env, verbose=1, tensorboard_log=tb_path,
                        policy_kwargs=params['policy_args'], **params['alg_args'])
    else:
        model.set_env(env)

    model.learn(total_timesteps=params['timesteps'], log_interval=params['log_interval'],
                callback=create_training_callback(data_dir, freq=params['eval_freq'],
                                                  checkpoint_freq=params['checkpoint_freq']))
    print("######## SAVING MODEL TO", data_dir)
    model.save(data_dir + '/final_model')
    if params['normalize']:
        env.save(data_dir + '/normalized_environment.env')
    env.close()
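train() reads a number of keys from params, which must also be dict-like with a save() method (params.save(data_dir) is called above). A hedged sketch of the shape that object appears to have, inferred only from the accesses in this function; all values are illustrative:

# Hypothetical example only; the real params object is defined by the
# surrounding project and must additionally provide save(data_dir).
example_params = {
    'name': 'run1',
    'env_args': {'use_her': False},
    'num_proc': 4,
    'early_reset': False,
    'normalize': True,
    'seed': 1,
    'noise': 0.1,              # OU-noise sigma; only used if truthy
    'timesteps': 100000,
    'log_interval': 10,
    'eval_freq': 5000,
    'checkpoint_freq': 10000,
    'alg_args': {},
    'policy_args': {},
}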